// Package router wires up the HTTP routes of the vanity server: static
// assets, the HTML index and per-package pages, and a JSON API for listing
// packages (OPTIONS /) and for adding, updating, and removing them
// (POST/PATCH/DELETE /, guarded by a shared token).
package router

import (
	"crypto/subtle"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"go.jolheiser.com/vanity/sdk"
	"go.jolheiser.com/vanity/server/database"

	"github.com/go-chi/chi/v5"
	"github.com/go-chi/chi/v5/middleware"
	"github.com/rs/zerolog/log"
)

// New builds the router for domain, backed by db. token is the shared secret
// that mutating requests must send in the sdk.TokenHeader header.
//
// Routes:
//
//	/_/*      static assets (the "/_/" prefix is stripped)
//	GET /     HTML index of all packages
//	OPTIONS / JSON sdk.Info describing all packages
//	POST /    add a package (token required)
//	PATCH /   update a package (token required)
//	DELETE /  remove a package (token required)
//	GET /*    per-package page, or the import page for go-get/git-import
func New(token, domain string, db *database.Database) *chi.Mux {
	r := chi.NewRouter()
	r.Use(middleware.Recoverer)
	r.Use(middleware.Timeout(60 * time.Second))

	r.Mount("/_/", http.StripPrefix("/_/", static()))

	r.Get("/", indexGET(domain, db))
	r.Options("/", infoPackages(db))
	r.Post("/", addUpdatePackage(db, token))
	r.Patch("/", addUpdatePackage(db, token))
	r.Delete("/", removePackage(db, token))
	r.Get("/*", vanityGET(domain, db))

	return r
}

// renderTemplate loads the template name for domain and executes it into res
// with data. A load failure is answered with 500 rather than executing
// whatever value tmpl returned alongside its error. Execution errors are only
// logged, because part of the response may already have been written.
func renderTemplate(res http.ResponseWriter, domain, name string, data interface{}) {
	tpl, err := tmpl(domain, name)
	if err != nil {
		log.Error().Msgf("could not load template %q: %v", name, err)
		http.Error(res, "could not load template", http.StatusInternalServerError)
		return
	}
	if err := tpl.Execute(res, data); err != nil {
		log.Error().Msgf("could not write response: %v", err)
	}
}

// hasValidToken reports whether req carries token in the sdk.TokenHeader
// header. The comparison is constant-time so response timing does not reveal
// how much of a guessed token matched.
//
// NOTE(review): if token is "" a request without the header is accepted
// (same as a plain == comparison); confirm callers never configure an empty token.
func hasValidToken(req *http.Request, token string) bool {
	given := req.Header.Get(sdk.TokenHeader)
	return subtle.ConstantTimeCompare([]byte(given), []byte(token)) == 1
}

// decodePackageBody reads the whole request body and unmarshals it as an
// sdk.Package. Any read or JSON error is returned to the caller.
func decodePackageBody(req *http.Request) (sdk.Package, error) {
	defer req.Body.Close()

	var pkg sdk.Package
	data, err := io.ReadAll(req.Body)
	if err != nil {
		return pkg, err
	}
	if err := json.Unmarshal(data, &pkg); err != nil {
		var zero sdk.Package
		return zero, err
	}
	return pkg, nil
}

// indexGET renders index.tmpl with every package in db.
// Responds 500 if the packages or the template cannot be loaded.
func indexGET(domain string, db *database.Database) http.HandlerFunc {
	return func(res http.ResponseWriter, req *http.Request) {
		packages, err := db.Packages()
		if err != nil {
			log.Error().Msgf("could not load packages: %v", err)
			http.Error(res, "could not load packages", http.StatusInternalServerError)
			return
		}
		renderTemplate(res, domain, "index.tmpl", map[string]interface{}{
			"Packages": packages,
			"Index":    true,
		})
	}
}

// vanityGET serves the page for the package named by the first segment of the
// wildcard path. Requests with a non-empty go-get or git-import query
// parameter get import.tmpl; everything else gets vanity.tmpl.
// Responds 404 for unknown packages and 500 on database or template errors.
func vanityGET(domain string, db *database.Database) http.HandlerFunc {
	return func(res http.ResponseWriter, req *http.Request) {
		// Only the first path segment is the package key; any remaining
		// segments (e.g. a sub-package path) are ignored for the lookup.
		key := strings.Split(chi.URLParam(req, "*"), "/")[0]

		pkg, err := db.Package(key)
		if err != nil {
			if database.IsErrPackageNotFound(err) {
				http.NotFound(res, req)
				return
			}
			log.Error().Msgf("could not load package %s: %v", key, err)
			http.Error(res, "could not load package", http.StatusInternalServerError)
			return
		}

		sdf, err := sdk.AnalyzeSDF(pkg)
		if err != nil {
			// Non-fatal: the page still renders using whatever sdf holds.
			// NOTE(review): assumes AnalyzeSDF returns a value (not a nil
			// pointer) on error, since sdf.Dir/sdf.File are read below — confirm.
			log.Warn().Msgf("could not get SDF for %s: %v", key, err)
		}

		data := map[string]interface{}{
			"Package":  pkg,
			"Module":   pkg.Module(domain),
			"GoSource": fmt.Sprintf("%s %s %s %s", pkg.Module(domain), pkg.CloneHTTP, sdf.Dir, sdf.File),
			"Index":    false,
		}

		name := "vanity.tmpl"
		q := req.URL.Query()
		if q.Get("go-get") != "" || q.Get("git-import") != "" {
			name = "import.tmpl"
		}
		renderTemplate(res, domain, name, data)
	}
}

// infoPackages responds with a JSON-encoded sdk.Info (server Version plus all
// packages), followed by a newline. Responds 500 if the packages cannot be
// loaded or encoded.
func infoPackages(db *database.Database) func(http.ResponseWriter, *http.Request) {
	return func(res http.ResponseWriter, req *http.Request) {
		packages, err := db.Packages()
		if err != nil {
			log.Error().Msgf("could not load packages: %v", err)
			http.Error(res, "could not load package", http.StatusInternalServerError)
			return
		}
		info := sdk.Info{
			Version:     Version,
			NumPackages: len(packages),
			Packages:    packages,
		}

		// Marshal before writing anything so an encoding failure can still
		// produce a clean 500 instead of a half-written body.
		body, err := json.Marshal(info)
		if err != nil {
			log.Error().Msgf("could not marshal info: %v", err)
			http.Error(res, "could not marshal info", http.StatusInternalServerError)
			return
		}
		res.Header().Set("Content-Type", "application/json")
		if _, err := res.Write(append(body, '\n')); err != nil {
			log.Error().Msgf("could not write response: %v", err)
		}
	}
}

// addUpdatePackage handles POST (create) and PATCH (update) of a package sent
// as a JSON sdk.Package body.
//
// Status codes: 401 bad token; 400 unreadable/invalid body; 409 POST of an
// existing package; 404 PATCH of a missing package; 500 database error;
// 201 after a successful POST; 200 otherwise.
func addUpdatePackage(db *database.Database, token string) func(http.ResponseWriter, *http.Request) {
	return func(res http.ResponseWriter, req *http.Request) {
		if !hasValidToken(req, token) {
			res.WriteHeader(http.StatusUnauthorized)
			return
		}

		pkg, err := decodePackageBody(req)
		if err != nil {
			res.WriteHeader(http.StatusBadRequest)
			return
		}

		exists, err := db.PackageJSON(pkg.Name)
		if err != nil && !database.IsErrPackageNotFound(err) {
			log.Error().Msgf("could not look up package %s: %v", pkg.Name, err)
			res.WriteHeader(http.StatusInternalServerError)
			return
		}

		// POST must not overwrite an existing package; PATCH requires one.
		status := http.StatusOK
		switch req.Method {
		case http.MethodPost:
			if exists != nil {
				res.WriteHeader(http.StatusConflict)
				return
			}
			status = http.StatusCreated
		case http.MethodPatch:
			if exists == nil {
				res.WriteHeader(http.StatusNotFound)
				return
			}
		}

		if err := db.PutPackage(pkg); err != nil {
			log.Error().Msgf("could not store package %s: %v", pkg.Name, err)
			res.WriteHeader(http.StatusInternalServerError)
			return
		}
		res.WriteHeader(status)
	}
}

// removePackage handles DELETE of the package named in a JSON sdk.Package
// body. Status codes: 401 bad token; 400 unreadable/invalid body; 500 if
// db.RemovePackage fails; 200 on success.
func removePackage(db *database.Database, token string) func(http.ResponseWriter, *http.Request) {
	return func(res http.ResponseWriter, req *http.Request) {
		if !hasValidToken(req, token) {
			res.WriteHeader(http.StatusUnauthorized)
			return
		}

		pkg, err := decodePackageBody(req)
		if err != nil {
			res.WriteHeader(http.StatusBadRequest)
			return
		}

		if err := db.RemovePackage(pkg.Name); err != nil {
			log.Error().Msgf("could not remove package %s: %v", pkg.Name, err)
			res.WriteHeader(http.StatusInternalServerError)
			return
		}
		res.WriteHeader(http.StatusOK)
	}
}