From 7455accf5855f97f05c9480cd0a94cbfa5363e46 Mon Sep 17 00:00:00 2001 From: Jamie Curnow Date: Mon, 27 Feb 2023 17:21:40 +1000 Subject: [PATCH] Proper 404's for objects --- .gitignore | 2 +- .../api/handler/certificate_authorities.go | 32 ++++++++++------ backend/internal/api/handler/certificates.go | 14 ++++--- backend/internal/api/handler/dns_providers.go | 38 +++++++++++++------ backend/internal/api/handler/hosts.go | 15 +++++--- .../internal/api/handler/nginx_templates.go | 13 ++++--- backend/internal/api/handler/not_found.go | 4 +- backend/internal/api/handler/settings.go | 22 +++++++---- backend/internal/api/handler/streams.go | 32 ++++++++++------ backend/internal/api/handler/upstreams.go | 6 +-- backend/internal/api/handler/users.go | 36 +++++++++++------- backend/internal/api/http/responses.go | 10 +++++ backend/internal/dnsproviders/common.go | 5 +-- backend/internal/errors/errors.go | 1 + 14 files changed, 149 insertions(+), 81 deletions(-) diff --git a/.gitignore b/.gitignore index 35104b2..9b72e48 100644 --- a/.gitignore +++ b/.gitignore @@ -23,4 +23,4 @@ dist backend/embed/acme.sh docker/dev/resolv.conf docker/dev/dnsrouter-config.json.tmp - +thunder-tests diff --git a/backend/internal/api/handler/certificate_authorities.go b/backend/internal/api/handler/certificate_authorities.go index c822cc0..bd0a03e 100644 --- a/backend/internal/api/handler/certificate_authorities.go +++ b/backend/internal/api/handler/certificate_authorities.go @@ -1,6 +1,7 @@ package handler import ( + "database/sql" "encoding/json" "fmt" "net/http" @@ -43,11 +44,14 @@ func GetCertificateAuthority() func(http.ResponseWriter, *http.Request) { return } - cert, err := certificateauthority.GetByID(caID) - if err != nil { + item, err := certificateauthority.GetByID(caID) + switch err { + case sql.ErrNoRows: + h.NotFound(w, r) + case nil: + h.ResultResponseJSON(w, r, http.StatusOK, item) + default: h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil) - } else { - h.ResultResponseJSON(w, r, http.StatusOK, cert) } } } @@ -95,9 +99,10 @@ func UpdateCertificateAuthority() func(http.ResponseWriter, *http.Request) { } ca, err := certificateauthority.GetByID(caID) - if err != nil { - h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil) - } else { + switch err { + case sql.ErrNoRows: + h.NotFound(w, r) + case nil: bodyBytes, _ := r.Context().Value(c.BodyCtxKey).([]byte) err := json.Unmarshal(bodyBytes, &ca) if err != nil { @@ -116,6 +121,8 @@ func UpdateCertificateAuthority() func(http.ResponseWriter, *http.Request) { } h.ResultResponseJSON(w, r, http.StatusOK, ca) + default: + h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil) } } } @@ -131,11 +138,14 @@ func DeleteCertificateAuthority() func(http.ResponseWriter, *http.Request) { return } - cert, err := certificateauthority.GetByID(caID) - if err != nil { + item, err := certificateauthority.GetByID(caID) + switch err { + case sql.ErrNoRows: + h.NotFound(w, r) + case nil: + h.ResultResponseJSON(w, r, http.StatusOK, item.Delete()) + default: h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil) - } else { - h.ResultResponseJSON(w, r, http.StatusOK, cert.Delete()) } } } diff --git a/backend/internal/api/handler/certificates.go b/backend/internal/api/handler/certificates.go index cc4580b..f27d96a 100644 --- a/backend/internal/api/handler/certificates.go +++ b/backend/internal/api/handler/certificates.go @@ -49,7 +49,7 @@ func GetCertificate() func(http.ResponseWriter, *http.Request) { item, err := certificate.GetByID(certificateID) switch err { case sql.ErrNoRows: - h.ResultErrorJSON(w, r, http.StatusNotFound, "Not found", nil) + h.NotFound(w, r) case nil: // nolint: errcheck,gosec item.Expand(getExpandFromContext(r)) @@ -100,10 +100,10 @@ func UpdateCertificate() func(http.ResponseWriter, *http.Request) { } certificateObject, err := certificate.GetByID(certificateID) - if err != nil { - h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil) - } else { - + switch err { + case sql.ErrNoRows: + h.NotFound(w, r) + case nil: // This is a special endpoint, as it needs to verify the schema payload // based on the certificate type, without being given a type in the payload. // The middleware would normally handle this. @@ -133,6 +133,8 @@ func UpdateCertificate() func(http.ResponseWriter, *http.Request) { configureCertificate(certificateObject) h.ResultResponseJSON(w, r, http.StatusOK, certificateObject) + default: + h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil) } } } @@ -151,7 +153,7 @@ func DeleteCertificate() func(http.ResponseWriter, *http.Request) { item, err := certificate.GetByID(certificateID) switch err { case sql.ErrNoRows: - h.ResultErrorJSON(w, r, http.StatusNotFound, "Not found", nil) + h.NotFound(w, r) case nil: // Ensure that this upstream isn't in use by a host cnt := host.GetCertificateUseCount(certificateID) diff --git a/backend/internal/api/handler/dns_providers.go b/backend/internal/api/handler/dns_providers.go index 0f5c8f1..5320f98 100644 --- a/backend/internal/api/handler/dns_providers.go +++ b/backend/internal/api/handler/dns_providers.go @@ -1,6 +1,7 @@ package handler import ( + "database/sql" "encoding/json" "fmt" "net/http" @@ -10,6 +11,7 @@ import ( "npm/internal/api/middleware" "npm/internal/dnsproviders" "npm/internal/entity/dnsprovider" + "npm/internal/errors" ) // GetDNSProviders will return a list of DNS Providers @@ -43,10 +45,13 @@ func GetDNSProvider() func(http.ResponseWriter, *http.Request) { } item, err := dnsprovider.GetByID(providerID) - if err != nil { - h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil) - } else { + switch err { + case sql.ErrNoRows: + h.NotFound(w, r) + case nil: h.ResultResponseJSON(w, r, http.StatusOK, item) + default: + h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil) } } } @@ -89,9 +94,10 @@ func UpdateDNSProvider() func(http.ResponseWriter, *http.Request) { } item, err := dnsprovider.GetByID(providerID) - if err != nil { - h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil) - } else { + switch err { + case sql.ErrNoRows: + h.NotFound(w, r) + case nil: bodyBytes, _ := r.Context().Value(c.BodyCtxKey).([]byte) err := json.Unmarshal(bodyBytes, &item) if err != nil { @@ -105,6 +111,8 @@ func UpdateDNSProvider() func(http.ResponseWriter, *http.Request) { } h.ResultResponseJSON(w, r, http.StatusOK, item) + default: + h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil) } } } @@ -121,10 +129,13 @@ func DeleteDNSProvider() func(http.ResponseWriter, *http.Request) { } item, err := dnsprovider.GetByID(providerID) - if err != nil { - h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil) - } else { + switch err { + case sql.ErrNoRows: + h.NotFound(w, r) + case nil: h.ResultResponseJSON(w, r, http.StatusOK, item.Delete()) + default: + h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil) } } } @@ -149,10 +160,13 @@ func GetAcmeshProvider() func(http.ResponseWriter, *http.Request) { } item, getErr := dnsproviders.Get(acmeshID) - if err != nil { - h.ResultErrorJSON(w, r, http.StatusBadRequest, getErr.Error(), nil) - } else { + switch getErr { + case errors.ErrProviderNotFound: + h.NotFound(w, r) + case nil: h.ResultResponseJSON(w, r, http.StatusOK, item) + default: + h.ResultErrorJSON(w, r, http.StatusBadRequest, getErr.Error(), nil) } } } diff --git a/backend/internal/api/handler/hosts.go b/backend/internal/api/handler/hosts.go index 64e1e21..daec999 100644 --- a/backend/internal/api/handler/hosts.go +++ b/backend/internal/api/handler/hosts.go @@ -49,7 +49,7 @@ func GetHost() func(http.ResponseWriter, *http.Request) { item, err := host.GetByID(hostID) switch err { case sql.ErrNoRows: - h.ResultErrorJSON(w, r, http.StatusNotFound, "Not found", nil) + h.NotFound(w, r) case nil: // nolint: errcheck,gosec item.Expand(getExpandFromContext(r)) @@ -110,9 +110,10 @@ func UpdateHost() func(http.ResponseWriter, *http.Request) { } hostObject, err := host.GetByID(hostID) - if err != nil { - h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil) - } else { + switch err { + case sql.ErrNoRows: + h.NotFound(w, r) + case nil: bodyBytes, _ := r.Context().Value(c.BodyCtxKey).([]byte) err := json.Unmarshal(bodyBytes, &hostObject) if err != nil { @@ -136,6 +137,8 @@ func UpdateHost() func(http.ResponseWriter, *http.Request) { configureHost(hostObject) h.ResultResponseJSON(w, r, http.StatusOK, hostObject) + default: + h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil) } } } @@ -154,7 +157,7 @@ func DeleteHost() func(http.ResponseWriter, *http.Request) { item, err := host.GetByID(hostID) switch err { case sql.ErrNoRows: - h.ResultErrorJSON(w, r, http.StatusNotFound, "Not found", nil) + h.NotFound(w, r) case nil: h.ResultResponseJSON(w, r, http.StatusOK, item.Delete()) configureHost(item) @@ -179,7 +182,7 @@ func GetHostNginxConfig(format string) func(http.ResponseWriter, *http.Request) item, err := host.GetByID(hostID) switch err { case sql.ErrNoRows: - h.ResultErrorJSON(w, r, http.StatusNotFound, "Not found", nil) + h.NotFound(w, r) case nil: // Get the config from disk content, nErr := nginx.GetHostConfigContent(item) diff --git a/backend/internal/api/handler/nginx_templates.go b/backend/internal/api/handler/nginx_templates.go index a459691..b80caa1 100644 --- a/backend/internal/api/handler/nginx_templates.go +++ b/backend/internal/api/handler/nginx_templates.go @@ -45,7 +45,7 @@ func GetNginxTemplate() func(http.ResponseWriter, *http.Request) { item, err := nginxtemplate.GetByID(templateID) switch err { case sql.ErrNoRows: - h.ResultErrorJSON(w, r, http.StatusNotFound, "Not found", nil) + h.NotFound(w, r) case nil: h.ResultResponseJSON(w, r, http.StatusOK, item) default: @@ -94,9 +94,10 @@ func UpdateNginxTemplate() func(http.ResponseWriter, *http.Request) { // reconfigure, _ := getQueryVarBool(r, "reconfigure", false, false) nginxTemplate, err := nginxtemplate.GetByID(templateID) - if err != nil { - h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil) - } else { + switch err { + case sql.ErrNoRows: + h.NotFound(w, r) + case nil: bodyBytes, _ := r.Context().Value(c.BodyCtxKey).([]byte) err := json.Unmarshal(bodyBytes, &nginxTemplate) if err != nil { @@ -110,6 +111,8 @@ func UpdateNginxTemplate() func(http.ResponseWriter, *http.Request) { } h.ResultResponseJSON(w, r, http.StatusOK, nginxTemplate) + default: + h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil) } } } @@ -128,7 +131,7 @@ func DeleteNginxTemplate() func(http.ResponseWriter, *http.Request) { item, err := nginxtemplate.GetByID(templateID) switch err { case sql.ErrNoRows: - h.ResultErrorJSON(w, r, http.StatusNotFound, "Not found", nil) + h.NotFound(w, r) case nil: h.ResultResponseJSON(w, r, http.StatusOK, item.Delete()) default: diff --git a/backend/internal/api/handler/not_found.go b/backend/internal/api/handler/not_found.go index 341be08..49ed277 100644 --- a/backend/internal/api/handler/not_found.go +++ b/backend/internal/api/handler/not_found.go @@ -34,13 +34,13 @@ func NotFound() func(http.ResponseWriter, *http.Request) { if err == errIsDir { err = tryRead(assetsSub, "index.html", w) if err != nil { - h.ResultErrorJSON(w, r, http.StatusNotFound, "Not found", nil) + h.NotFound(w, r) } } else if err == nil { return } - h.ResultErrorJSON(w, r, http.StatusNotFound, "Not found", nil) + h.NotFound(w, r) } } diff --git a/backend/internal/api/handler/settings.go b/backend/internal/api/handler/settings.go index b0e1d7a..2af313f 100644 --- a/backend/internal/api/handler/settings.go +++ b/backend/internal/api/handler/settings.go @@ -1,6 +1,7 @@ package handler import ( + "database/sql" "encoding/json" "fmt" "net/http" @@ -38,11 +39,14 @@ func GetSetting() func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { name := chi.URLParam(r, "name") - sett, err := setting.GetByName(name) - if err != nil { + item, err := setting.GetByName(name) + switch err { + case sql.ErrNoRows: + h.NotFound(w, r) + case nil: + h.ResultResponseJSON(w, r, http.StatusOK, item) + default: h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil) - } else { - h.ResultResponseJSON(w, r, http.StatusOK, sett) } } } @@ -76,10 +80,10 @@ func UpdateSetting() func(http.ResponseWriter, *http.Request) { settingName := chi.URLParam(r, "name") setting, err := setting.GetByName(settingName) - if err != nil { - h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil) - } else { - + switch err { + case sql.ErrNoRows: + h.NotFound(w, r) + case nil: bodyBytes, _ := r.Context().Value(c.BodyCtxKey).([]byte) err := json.Unmarshal(bodyBytes, &setting) if err != nil { @@ -93,6 +97,8 @@ func UpdateSetting() func(http.ResponseWriter, *http.Request) { } h.ResultResponseJSON(w, r, http.StatusOK, setting) + default: + h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil) } } } diff --git a/backend/internal/api/handler/streams.go b/backend/internal/api/handler/streams.go index 17c668e..c3f8fa0 100644 --- a/backend/internal/api/handler/streams.go +++ b/backend/internal/api/handler/streams.go @@ -1,6 +1,7 @@ package handler import ( + "database/sql" "encoding/json" "fmt" "net/http" @@ -41,11 +42,14 @@ func GetStream() func(http.ResponseWriter, *http.Request) { return } - host, err := stream.GetByID(hostID) - if err != nil { + item, err := stream.GetByID(hostID) + switch err { + case sql.ErrNoRows: + h.NotFound(w, r) + case nil: + h.ResultResponseJSON(w, r, http.StatusOK, item) + default: h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil) - } else { - h.ResultResponseJSON(w, r, http.StatusOK, host) } } } @@ -88,9 +92,10 @@ func UpdateStream() func(http.ResponseWriter, *http.Request) { } host, err := stream.GetByID(hostID) - if err != nil { - h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil) - } else { + switch err { + case sql.ErrNoRows: + h.NotFound(w, r) + case nil: bodyBytes, _ := r.Context().Value(c.BodyCtxKey).([]byte) err := json.Unmarshal(bodyBytes, &host) if err != nil { @@ -104,6 +109,8 @@ func UpdateStream() func(http.ResponseWriter, *http.Request) { } h.ResultResponseJSON(w, r, http.StatusOK, host) + default: + h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil) } } } @@ -119,11 +126,14 @@ func DeleteStream() func(http.ResponseWriter, *http.Request) { return } - host, err := stream.GetByID(hostID) - if err != nil { + item, err := stream.GetByID(hostID) + switch err { + case sql.ErrNoRows: + h.NotFound(w, r) + case nil: + h.ResultResponseJSON(w, r, http.StatusOK, item.Delete()) + default: h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil) - } else { - h.ResultResponseJSON(w, r, http.StatusOK, host.Delete()) } } } diff --git a/backend/internal/api/handler/upstreams.go b/backend/internal/api/handler/upstreams.go index 4794ece..7a470f9 100644 --- a/backend/internal/api/handler/upstreams.go +++ b/backend/internal/api/handler/upstreams.go @@ -50,7 +50,7 @@ func GetUpstream() func(http.ResponseWriter, *http.Request) { item, err := upstream.GetByID(upstreamID) switch err { case sql.ErrNoRows: - h.ResultErrorJSON(w, r, http.StatusNotFound, "Not found", nil) + h.NotFound(w, r) case nil: // nolint: errcheck,gosec item.Expand(getExpandFromContext(r)) @@ -150,7 +150,7 @@ func DeleteUpstream() func(http.ResponseWriter, *http.Request) { item, err := upstream.GetByID(upstreamID) switch err { case sql.ErrNoRows: - h.ResultErrorJSON(w, r, http.StatusNotFound, "Not found", nil) + h.NotFound(w, r) case nil: // Ensure that this upstream isn't in use by a host cnt := host.GetUpstreamUseCount(upstreamID) @@ -181,7 +181,7 @@ func GetUpstreamNginxConfig(format string) func(http.ResponseWriter, *http.Reque item, err := upstream.GetByID(upstreamID) switch err { case sql.ErrNoRows: - h.ResultErrorJSON(w, r, http.StatusNotFound, "Not found", nil) + h.NotFound(w, r) case nil: // Get the config from disk content, nErr := nginx.GetUpstreamConfigContent(item) diff --git a/backend/internal/api/handler/users.go b/backend/internal/api/handler/users.go index 192ff7b..f9b06f3 100644 --- a/backend/internal/api/handler/users.go +++ b/backend/internal/api/handler/users.go @@ -1,6 +1,7 @@ package handler import ( + "database/sql" "encoding/json" "net/http" @@ -45,13 +46,16 @@ func GetUser() func(http.ResponseWriter, *http.Request) { return } - userObject, err := user.GetByID(userID) - if err != nil { - h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil) - } else { + item, err := user.GetByID(userID) + switch err { + case sql.ErrNoRows: + h.NotFound(w, r) + case nil: // nolint: errcheck,gosec - userObject.Expand(getExpandFromContext(r)) - h.ResultResponseJSON(w, r, http.StatusOK, userObject) + item.Expand(getExpandFromContext(r)) + h.ResultResponseJSON(w, r, http.StatusOK, item) + default: + h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil) } } } @@ -67,9 +71,10 @@ func UpdateUser() func(http.ResponseWriter, *http.Request) { } userObject, err := user.GetByID(userID) - if err != nil { - h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil) - } else { + switch err { + case sql.ErrNoRows: + h.NotFound(w, r) + case nil: // nolint: errcheck,gosec userObject.Expand([]string{"capabilities"}) bodyBytes, _ := r.Context().Value(c.BodyCtxKey).([]byte) @@ -106,6 +111,8 @@ func UpdateUser() func(http.ResponseWriter, *http.Request) { userObject.Expand(getExpandFromContext(r)) h.ResultResponseJSON(w, r, http.StatusOK, userObject) + default: + h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil) } } } @@ -127,11 +134,14 @@ func DeleteUser() func(http.ResponseWriter, *http.Request) { return } - user, err := user.GetByID(userID) - if err != nil { + item, err := user.GetByID(userID) + switch err { + case sql.ErrNoRows: + h.NotFound(w, r) + case nil: + h.ResultResponseJSON(w, r, http.StatusOK, item.Delete()) + default: h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil) - } else { - h.ResultResponseJSON(w, r, http.StatusOK, user.Delete()) } } } diff --git a/backend/internal/api/http/responses.go b/backend/internal/api/http/responses.go index 881503d..66b2c0c 100644 --- a/backend/internal/api/http/responses.go +++ b/backend/internal/api/http/responses.go @@ -81,6 +81,16 @@ func ResultErrorJSON(w http.ResponseWriter, r *http.Request, status int, message ResultResponseJSON(w, r, status, errorResponse) } +// NotFound will return a 404 response +func NotFound(w http.ResponseWriter, r *http.Request) { + errorResponse := ErrorResponse{ + Code: http.StatusNotFound, + Message: "Not found", + } + + ResultResponseJSON(w, r, http.StatusNotFound, errorResponse) +} + // ResultResponseText will write the result as text to the http output func ResultResponseText(w http.ResponseWriter, r *http.Request, status int, content string) { w.Header().Set("Content-Type", "text/plain; charset=utf-8") diff --git a/backend/internal/dnsproviders/common.go b/backend/internal/dnsproviders/common.go index 0e8bb4b..b552ff8 100644 --- a/backend/internal/dnsproviders/common.go +++ b/backend/internal/dnsproviders/common.go @@ -2,8 +2,7 @@ package dnsproviders import ( "encoding/json" - - "github.com/rotisserie/eris" + "npm/internal/errors" ) // providerField should mimick jsonschema, so that @@ -112,5 +111,5 @@ func Get(provider string) (Provider, error) { if item, found := all[provider]; found { return item, nil } - return Provider{}, eris.New("provider_not_found") + return Provider{}, errors.ErrProviderNotFound } diff --git a/backend/internal/errors/errors.go b/backend/internal/errors/errors.go index 314c59e..80f864d 100644 --- a/backend/internal/errors/errors.go +++ b/backend/internal/errors/errors.go @@ -15,4 +15,5 @@ var ( ErrValidationFailed = eris.New("request-failed-validation") ErrCurrentPasswordInvalid = eris.New("current-password-invalid") ErrCABundleDoesNotExist = eris.New("ca-bundle-does-not-exist") + ErrProviderNotFound = eris.New("provider_not_found") )