From b123ca4fd031f924238631dcb23b2a8dd07b566b Mon Sep 17 00:00:00 2001 From: Jamie Curnow Date: Tue, 25 Jul 2023 11:59:02 +1000 Subject: [PATCH] Add backend unit tests --- backend/Taskfile.yml | 2 +- backend/internal/acme/acmesh.go | 10 +- backend/internal/acme/acmesh_test.go | 30 +++ backend/internal/api/context/context_test.go | 13 ++ backend/internal/api/filters/helpers.go | 208 ------------------- backend/internal/api/handler/certificates.go | 4 +- backend/internal/api/handler/helpers.go | 15 +- backend/internal/api/handler/helpers_test.go | 118 +++++++++++ backend/internal/api/handler/hosts.go | 6 +- backend/internal/api/handler/not_found.go | 21 +- backend/internal/api/handler/upstreams.go | 6 +- backend/internal/api/handler/users.go | 6 +- backend/internal/api/http/requests.go | 46 ---- backend/internal/api/http/responses.go | 6 + backend/internal/api/http/responses_test.go | 180 ++++++++++++++++ backend/internal/api/middleware/expansion.go | 9 + backend/internal/api/router_test.go | 2 +- 17 files changed, 399 insertions(+), 283 deletions(-) create mode 100644 backend/internal/api/context/context_test.go delete mode 100644 backend/internal/api/filters/helpers.go create mode 100644 backend/internal/api/handler/helpers_test.go delete mode 100644 backend/internal/api/http/requests.go create mode 100644 backend/internal/api/http/responses_test.go diff --git a/backend/Taskfile.yml b/backend/Taskfile.yml index 4de171e..512046e 100644 --- a/backend/Taskfile.yml +++ b/backend/Taskfile.yml @@ -1,4 +1,4 @@ -version: "2" +version: "3" tasks: default: diff --git a/backend/internal/acme/acmesh.go b/backend/internal/acme/acmesh.go index e096acf..a7e57d0 100644 --- a/backend/internal/acme/acmesh.go +++ b/backend/internal/acme/acmesh.go @@ -136,7 +136,15 @@ func getCommonArgs() []string { } // This is split out into it's own function so it's testable -func buildCertRequestArgs(domains []string, method, outputFullchainFile, outputKeyFile string, dnsProvider *dnsprovider.Model, ca *certificateauthority.Model, force bool) ([]string, error) { +func buildCertRequestArgs( + domains []string, + method, + outputFullchainFile, + outputKeyFile string, + dnsProvider *dnsprovider.Model, + ca *certificateauthority.Model, + force bool, +) ([]string, error) { // The argument order matters. // see https://github.com/acmesh-official/acme.sh/wiki/How-to-issue-a-cert#3-multiple-domains-san-mode--hybrid-mode // for multiple domains and note that the method of validation is required just after the domain arg, each time. diff --git a/backend/internal/acme/acmesh_test.go b/backend/internal/acme/acmesh_test.go index 08969b8..920ec31 100644 --- a/backend/internal/acme/acmesh_test.go +++ b/backend/internal/acme/acmesh_test.go @@ -193,3 +193,33 @@ func TestBuildCertRequestArgs(t *testing.T) { }) } } + +func TestGetAcmeShFilePath(t *testing.T) { + t.Run("basic test", func(t *testing.T) { + path, err := getAcmeShFilePath() + assert.Equal(t, "/bin/acme.sh", path) + assert.Equal(t, nil, err) + }) +} + +func TestGetCommonEnvVars(t *testing.T) { + t.Run("basic test", func(t *testing.T) { + expected := []string{ + "ACMESH_CONFIG_HOME=/data/.acme.sh/config", + "ACMESH_HOME=/data/.acme.sh", + "CERT_HOME=/data/.acme.sh/certs", + "LE_CONFIG_HOME=/data/.acme.sh/config", + "LE_WORKING_DIR=/data/.acme.sh", + } + vals := getCommonEnvVars() + assert.Equal(t, expected, vals) + }) +} + +func TestGetAcmeShVersion(t *testing.T) { + t.Run("basic test", func(t *testing.T) { + resp := GetAcmeShVersion() + assert.Greater(t, len(resp), 1) + assert.Equal(t, "v", resp[:1]) + }) +} diff --git a/backend/internal/api/context/context_test.go b/backend/internal/api/context/context_test.go new file mode 100644 index 0000000..0c5365d --- /dev/null +++ b/backend/internal/api/context/context_test.go @@ -0,0 +1,13 @@ +package context + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGetString(t *testing.T) { + t.Run("basic test", func(t *testing.T) { + assert.Equal(t, "context value: Body", BodyCtxKey.String()) + }) +} diff --git a/backend/internal/api/filters/helpers.go b/backend/internal/api/filters/helpers.go deleted file mode 100644 index 5f5d523..0000000 --- a/backend/internal/api/filters/helpers.go +++ /dev/null @@ -1,208 +0,0 @@ -package filters - -import ( - "fmt" - "strings" -) - -// NewFilterSchema is the main method to specify a new Filter Schema for use in Middleware -func NewFilterSchema(fieldSchemas []string) string { - return fmt.Sprintf(baseFilterSchema, strings.Join(fieldSchemas, ", ")) -} - -// BoolFieldSchema returns the Field Schema for a Boolean accepted value field -func BoolFieldSchema(fieldName string) string { - return fmt.Sprintf(`{ - "type": "object", - "properties": { - "field": { - "type": "string", - "pattern": "^%s$" - }, - "modifier": %s, - "value": { - "oneOf": [ - %s, - { - "type": "array", - "items": %s - } - ] - } - } - }`, fieldName, boolModifiers, filterBool, filterBool) -} - -// IntFieldSchema returns the Field Schema for a Integer accepted value field -func IntFieldSchema(fieldName string) string { - return fmt.Sprintf(`{ - "type": "object", - "properties": { - "field": { - "type": "string", - "pattern": "^%s$" - }, - "modifier": %s, - "value": { - "oneOf": [ - { - "type": "string", - "pattern": "^[0-9]+$" - }, - { - "type": "array", - "items": { - "type": "string", - "pattern": "^[0-9]+$" - } - } - ] - } - } - }`, fieldName, allModifiers) -} - -// StringFieldSchema returns the Field Schema for a String accepted value field -func StringFieldSchema(fieldName string) string { - return fmt.Sprintf(`{ - "type": "object", - "properties": { - "field": { - "type": "string", - "pattern": "^%s$" - }, - "modifier": %s, - "value": { - "oneOf": [ - %s, - { - "type": "array", - "items": %s - } - ] - } - } - }`, fieldName, stringModifiers, filterString, filterString) -} - -// RegexFieldSchema returns the Field Schema for a String accepted value field matching a Regex -func RegexFieldSchema(fieldName string, regex string) string { - return fmt.Sprintf(`{ - "type": "object", - "properties": { - "field": { - "type": "string", - "pattern": "^%s$" - }, - "modifier": %s, - "value": { - "oneOf": [ - { - "type": "string", - "pattern": "%s" - }, - { - "type": "array", - "items": { - "type": "string", - "pattern": "%s" - } - } - ] - } - } - }`, fieldName, stringModifiers, regex, regex) -} - -// DateFieldSchema returns the Field Schema for a String accepted value field matching a Date format -func DateFieldSchema(fieldName string) string { - return fmt.Sprintf(`{ - "type": "object", - "properties": { - "field": { - "type": "string", - "pattern": "^%s$" - }, - "modifier": %s, - "value": { - "oneOf": [ - { - "type": "string", - "pattern": "^([12]\\d{3}-(0[1-9]|1[0-2])-(0[1-9]|[12]\\d|3[01]))$" - }, - { - "type": "array", - "items": { - "type": "string", - "pattern": "^([12]\\d{3}-(0[1-9]|1[0-2])-(0[1-9]|[12]\\d|3[01]))$" - } - } - ] - } - } - }`, fieldName, allModifiers) -} - -// DateTimeFieldSchema returns the Field Schema for a String accepted value field matching a Date format -// 2020-03-01T10:30:00+10:00 -func DateTimeFieldSchema(fieldName string) string { - return fmt.Sprintf(`{ - "type": "object", - "properties": { - "field": { - "type": "string", - "pattern": "^%s$" - }, - "modifier": %s, - "value": { - "oneOf": [ - { - "type": "string", - "pattern": "^([12]\\d{3}-(0[1-9]|1[0-2])-(0[1-9]|[12]\\d|3[01]))$" - }, - { - "type": "array", - "items": { - "type": "string", - "pattern": "^([12]\\d{3}-(0[1-9]|1[0-2])-(0[1-9]|[12]\\d|3[01]))$" - } - } - ] - } - } - }`, fieldName, allModifiers) -} - -const allModifiers = `{ - "type": "string", - "pattern": "^(equals|not|contains|starts|ends|in|notin|min|max|greater|less)$" -}` - -const boolModifiers = `{ - "type": "string", - "pattern": "^(equals|not)$" -}` - -const stringModifiers = `{ - "type": "string", - "pattern": "^(equals|not|contains|starts|ends|in|notin)$" -}` - -const filterBool = `{ - "type": "string", - "pattern": "^(TRUE|true|t|yes|y|on|1|FALSE|f|false|n|no|off|0)$" -}` - -const filterString = `{ - "type": "string", - "minLength": 1 -}` - -const baseFilterSchema = `{ - "type": "array", - "items": { - "oneOf": [ - %s - ] - } -}` diff --git a/backend/internal/api/handler/certificates.go b/backend/internal/api/handler/certificates.go index aae8b2b..30860e4 100644 --- a/backend/internal/api/handler/certificates.go +++ b/backend/internal/api/handler/certificates.go @@ -26,7 +26,7 @@ func GetCertificates() func(http.ResponseWriter, *http.Request) { return } - certificates, err := certificate.List(pageInfo, middleware.GetFiltersFromContext(r), getExpandFromContext(r)) + certificates, err := certificate.List(pageInfo, middleware.GetFiltersFromContext(r), middleware.GetExpandFromContext(r)) if err != nil { h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil) } else { @@ -41,7 +41,7 @@ func GetCertificate() func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { if item := getCertificateFromRequest(w, r); item != nil { // nolint: errcheck,gosec - item.Expand(getExpandFromContext(r)) + item.Expand(middleware.GetExpandFromContext(r)) h.ResultResponseJSON(w, r, http.StatusOK, item) } } diff --git a/backend/internal/api/handler/helpers.go b/backend/internal/api/handler/helpers.go index d07ac05..8bc6a5e 100644 --- a/backend/internal/api/handler/helpers.go +++ b/backend/internal/api/handler/helpers.go @@ -4,8 +4,6 @@ import ( "net/http" "strconv" - "npm/internal/api/context" - "npm/internal/api/middleware" "npm/internal/model" "github.com/go-chi/chi/v5" @@ -15,7 +13,7 @@ import ( const defaultLimit = 10 func getPageInfoFromRequest(r *http.Request) (model.PageInfo, error) { - var pageInfo model.PageInfo + pageInfo := model.PageInfo{} var err error pageInfo.Offset, pageInfo.Limit, err = getPagination(r) @@ -23,7 +21,7 @@ func getPageInfoFromRequest(r *http.Request) (model.PageInfo, error) { return pageInfo, err } - pageInfo.Sort = middleware.GetSortFromContext(r) + // pageInfo.Sort = middleware.GetSortFromContext(r) return pageInfo, nil } @@ -93,12 +91,3 @@ func getPagination(r *http.Request) (int, int, error) { return offset, limit, nil } - -// getExpandFromContext returns the Expansion setting -func getExpandFromContext(r *http.Request) []string { - expand, ok := r.Context().Value(context.ExpansionCtxKey).([]string) - if !ok { - return nil - } - return expand -} diff --git a/backend/internal/api/handler/helpers_test.go b/backend/internal/api/handler/helpers_test.go new file mode 100644 index 0000000..e0f93f6 --- /dev/null +++ b/backend/internal/api/handler/helpers_test.go @@ -0,0 +1,118 @@ +package handler + +import ( + "net/http" + "net/http/httptest" + "npm/internal/model" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGetPageInfoFromRequest(t *testing.T) { + t.Run("basic test", func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/hosts", nil) + p, err := getPageInfoFromRequest(r) + + var nilStringSlice []string + var nilSortSlice []model.Sort + defaultSort := model.Sort{Field: "name", Direction: "asc"} + + assert.Equal(t, nil, err) + assert.Equal(t, 0, p.Offset) + assert.Equal(t, 10, p.Limit) + assert.Equal(t, nilStringSlice, p.Expand) + assert.Equal(t, nilSortSlice, p.Sort) + assert.Equal(t, []model.Sort{defaultSort}, p.GetSort(defaultSort)) + }) +} + +func TestGetQueryVarInt(t *testing.T) { + type want struct { + val int + err string + } + + tests := []struct { + name string + url string + queryVar string + required bool + defaultValue int + want want + }{ + { + name: "simple default", + url: "/hosts", + queryVar: "something", + required: false, + defaultValue: 100, + want: want{ + val: 100, + err: "", + }, + }, + { + name: "required flag", + url: "/hosts", + queryVar: "something", + required: true, + want: want{ + val: 0, + err: "something was not supplied in the request", + }, + }, + { + name: "simple get", + url: "/hosts?something=50", + queryVar: "something", + required: true, + want: want{ + val: 50, + err: "", + }, + }, + { + name: "invalid number", + url: "/hosts?something=aaa", + queryVar: "something", + required: true, + want: want{ + val: 0, + err: "", + }, + }, + { + name: "preceding zeros", + url: "/hosts?something=0000050", + queryVar: "something", + required: true, + want: want{ + val: 50, + err: "", + }, + }, + { + name: "decimals", + url: "/hosts?something=50.50", + queryVar: "something", + required: true, + want: want{ + val: 0, + err: "", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, tt.url, nil) + val, err := getQueryVarInt(r, tt.queryVar, tt.required, tt.defaultValue) + assert.Equal(t, tt.want.val, val) + if tt.want.err != "" { + assert.NotEqual(t, nil, err) + assert.Equal(t, tt.want.err, err.Error()) + } + }) + } +} diff --git a/backend/internal/api/handler/hosts.go b/backend/internal/api/handler/hosts.go index 4275f0d..18b8483 100644 --- a/backend/internal/api/handler/hosts.go +++ b/backend/internal/api/handler/hosts.go @@ -26,7 +26,7 @@ func GetHosts() func(http.ResponseWriter, *http.Request) { return } - hosts, err := host.List(pageInfo, middleware.GetFiltersFromContext(r), getExpandFromContext(r)) + hosts, err := host.List(pageInfo, middleware.GetFiltersFromContext(r), middleware.GetExpandFromContext(r)) if err != nil { h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil) } else { @@ -52,7 +52,7 @@ func GetHost() func(http.ResponseWriter, *http.Request) { h.NotFound(w, r) case nil: // nolint: errcheck,gosec - item.Expand(getExpandFromContext(r)) + item.Expand(middleware.GetExpandFromContext(r)) h.ResultResponseJSON(w, r, http.StatusOK, item) default: h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil) @@ -132,7 +132,7 @@ func UpdateHost() func(http.ResponseWriter, *http.Request) { } // nolint: errcheck,gosec - hostObject.Expand(getExpandFromContext(r)) + hostObject.Expand(middleware.GetExpandFromContext(r)) configureHost(hostObject) diff --git a/backend/internal/api/handler/not_found.go b/backend/internal/api/handler/not_found.go index 49ed277..7612604 100644 --- a/backend/internal/api/handler/not_found.go +++ b/backend/internal/api/handler/not_found.go @@ -25,14 +25,21 @@ func NotFound() func(http.ResponseWriter, *http.Request) { assetsSub, _ = fs.Sub(embed.Assets, "assets") return func(w http.ResponseWriter, r *http.Request) { + defaultFile := "index.html" path := strings.TrimLeft(r.URL.Path, "/") + + isAPI := false + if len(path) >= 3 && path[0:3] == "api" { + isAPI = true + } + if path == "" { - path = "index.html" + path = defaultFile } err := tryRead(assetsSub, path, w) if err == errIsDir { - err = tryRead(assetsSub, "index.html", w) + err = tryRead(assetsSub, defaultFile, w) if err != nil { h.NotFound(w, r) } @@ -40,6 +47,16 @@ func NotFound() func(http.ResponseWriter, *http.Request) { return } + // Check if the path has an extension and not in the "/api" path + ext := filepath.Ext(path) + if !isAPI && ext == "" { + // Not an api endpoint and Not a specific file, return the default index file + err := tryRead(assetsSub, defaultFile, w) + if err == nil { + return + } + } + h.NotFound(w, r) } } diff --git a/backend/internal/api/handler/upstreams.go b/backend/internal/api/handler/upstreams.go index 3add697..b060acb 100644 --- a/backend/internal/api/handler/upstreams.go +++ b/backend/internal/api/handler/upstreams.go @@ -27,7 +27,7 @@ func GetUpstreams() func(http.ResponseWriter, *http.Request) { return } - items, err := upstream.List(pageInfo, middleware.GetFiltersFromContext(r), getExpandFromContext(r)) + items, err := upstream.List(pageInfo, middleware.GetFiltersFromContext(r), middleware.GetExpandFromContext(r)) if err != nil { h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil) } else { @@ -53,7 +53,7 @@ func GetUpstream() func(http.ResponseWriter, *http.Request) { h.NotFound(w, r) case nil: // nolint: errcheck,gosec - item.Expand(getExpandFromContext(r)) + item.Expand(middleware.GetExpandFromContext(r)) h.ResultResponseJSON(w, r, http.StatusOK, item) default: h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil) @@ -127,7 +127,7 @@ func UpdateUpstream() func(http.ResponseWriter, *http.Request) { } // nolint: errcheck,gosec - // item.Expand(getExpandFromContext(r)) + // item.Expand(middleware.GetExpandFromContext(r)) configureUpstream(item) diff --git a/backend/internal/api/handler/users.go b/backend/internal/api/handler/users.go index 961a0eb..4258886 100644 --- a/backend/internal/api/handler/users.go +++ b/backend/internal/api/handler/users.go @@ -27,7 +27,7 @@ func GetUsers() func(http.ResponseWriter, *http.Request) { return } - users, err := user.List(pageInfo, middleware.GetFiltersFromContext(r), getExpandFromContext(r)) + users, err := user.List(pageInfo, middleware.GetFiltersFromContext(r), middleware.GetExpandFromContext(r)) if err != nil { h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil) } else { @@ -52,7 +52,7 @@ func GetUser() func(http.ResponseWriter, *http.Request) { h.NotFound(w, r) case nil: // nolint: errcheck,gosec - item.Expand(getExpandFromContext(r)) + item.Expand(middleware.GetExpandFromContext(r)) h.ResultResponseJSON(w, r, http.StatusOK, item) default: h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil) @@ -108,7 +108,7 @@ func UpdateUser() func(http.ResponseWriter, *http.Request) { } // nolint: errcheck,gosec - userObject.Expand(getExpandFromContext(r)) + userObject.Expand(middleware.GetExpandFromContext(r)) h.ResultResponseJSON(w, r, http.StatusOK, userObject) default: diff --git a/backend/internal/api/http/requests.go b/backend/internal/api/http/requests.go deleted file mode 100644 index a5b3583..0000000 --- a/backend/internal/api/http/requests.go +++ /dev/null @@ -1,46 +0,0 @@ -package http - -import ( - "context" - "encoding/json" - - "github.com/qri-io/jsonschema" - "github.com/rotisserie/eris" -) - -var ( - // ErrInvalidJSON is an error for invalid json - ErrInvalidJSON = eris.New("JSON is invalid") - // ErrInvalidPayload is an error for invalid incoming data - ErrInvalidPayload = eris.New("Payload is invalid") -) - -// ValidateRequestSchema takes a Schema and the Content to validate against it -func ValidateRequestSchema(schema string, requestBody []byte) ([]jsonschema.KeyError, error) { - var jsonErrors []jsonschema.KeyError - var schemaBytes = []byte(schema) - - // Make sure the body is valid JSON - if !isJSON(requestBody) { - return jsonErrors, ErrInvalidJSON - } - - rs := &jsonschema.Schema{} - if err := json.Unmarshal(schemaBytes, rs); err != nil { - return jsonErrors, err - } - - var validationErr error - ctx := context.TODO() - if jsonErrors, validationErr = rs.ValidateBytes(ctx, requestBody); len(jsonErrors) > 0 { - return jsonErrors, validationErr - } - - // Valid - return nil, nil -} - -func isJSON(bytes []byte) bool { - var js map[string]interface{} - return json.Unmarshal(bytes, &js) == nil -} diff --git a/backend/internal/api/http/responses.go b/backend/internal/api/http/responses.go index 66b2c0c..1bf4212 100644 --- a/backend/internal/api/http/responses.go +++ b/backend/internal/api/http/responses.go @@ -11,6 +11,12 @@ import ( "npm/internal/logger" "github.com/qri-io/jsonschema" + "github.com/rotisserie/eris" +) + +var ( + // ErrInvalidPayload is an error for invalid incoming data + ErrInvalidPayload = eris.New("Payload is invalid") ) // Response interface for standard API results diff --git a/backend/internal/api/http/responses_test.go b/backend/internal/api/http/responses_test.go new file mode 100644 index 0000000..1bc56f8 --- /dev/null +++ b/backend/internal/api/http/responses_test.go @@ -0,0 +1,180 @@ +package http + +import ( + "io" + "net/http" + "net/http/httptest" + "npm/internal/entity/user" + "npm/internal/model" + "testing" + + "github.com/qri-io/jsonschema" + "github.com/stretchr/testify/assert" +) + +func TestResultResponseJSON(t *testing.T) { + tests := []struct { + name string + status int + given interface{} + want string + }{ + { + name: "simple response", + status: http.StatusOK, + given: true, + want: "{\"result\":true}", + }, + { + name: "detailed response", + status: http.StatusBadRequest, + given: user.Model{ + ModelBase: model.ModelBase{ID: 10}, + Email: "me@example.com", + Name: "John Doe", + Nickname: "Jonny", + }, + want: "{\"result\":{\"id\":10,\"created_at\":0,\"updated_at\":0,\"name\":\"John Doe\",\"nickname\":\"Jonny\",\"email\":\"me@example.com\",\"is_disabled\":false,\"gravatar_url\":\"\"}}", + }, + { + name: "error response", + status: http.StatusNotFound, + given: ErrorResponse{ + Code: 404, + Message: "Not found", + Invalid: []string{"your", "page", "was", "not", "found"}, + }, + want: "{\"result\":null,\"error\":{\"code\":404,\"message\":\"Not found\",\"invalid\":[\"your\",\"page\",\"was\",\"not\",\"found\"]}}", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/anything", nil) + w := httptest.NewRecorder() + ResultResponseJSON(w, r, tt.status, tt.given) + res := w.Result() + defer res.Body.Close() + body, err := io.ReadAll(res.Body) + if err != nil { + t.Errorf("expected error to be nil got %v", err) + } + assert.Equal(t, tt.want, string(body)) + assert.Equal(t, tt.status, res.StatusCode) + assert.Equal(t, "application/json; charset=utf-8", res.Header.Get("Content-Type")) + }) + } +} + +func TestResultSchemaErrorJSON(t *testing.T) { + tests := []struct { + name string + given []jsonschema.KeyError + want string + }{ + { + name: "case a", + given: []jsonschema.KeyError{ + { + PropertyPath: "/something", + InvalidValue: "name", + Message: "Name cannot be empty", + }, + }, + want: "{\"result\":null,\"error\":{\"code\":400,\"message\":{},\"invalid\":[{\"propertyPath\":\"/something\",\"invalidValue\":\"name\",\"message\":\"Name cannot be empty\"}]}}", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/anything", nil) + w := httptest.NewRecorder() + ResultSchemaErrorJSON(w, r, tt.given) + res := w.Result() + defer res.Body.Close() + body, err := io.ReadAll(res.Body) + if err != nil { + t.Errorf("expected error to be nil got %v", err) + } + assert.Equal(t, tt.want, string(body)) + assert.Equal(t, 400, res.StatusCode) + assert.Equal(t, "application/json; charset=utf-8", res.Header.Get("Content-Type")) + }) + } +} + +func TestResultErrorJSON(t *testing.T) { + tests := []struct { + name string + status int + message string + extended interface{} + want string + }{ + { + name: "case a", + status: http.StatusBadGateway, + message: "Oh not something is not acceptable", + extended: nil, + want: "{\"result\":null,\"error\":{\"code\":502,\"message\":\"Oh not something is not acceptable\"}}", + }, + { + name: "case b", + status: http.StatusNotAcceptable, + message: "Oh not something is not acceptable again", + extended: []string{"name is not allowed", "dob is wrong or something"}, + want: "{\"result\":null,\"error\":{\"code\":406,\"message\":\"Oh not something is not acceptable again\",\"invalid\":[\"name is not allowed\",\"dob is wrong or something\"]}}", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/anything", nil) + w := httptest.NewRecorder() + ResultErrorJSON(w, r, tt.status, tt.message, tt.extended) + res := w.Result() + defer res.Body.Close() + body, err := io.ReadAll(res.Body) + if err != nil { + t.Errorf("expected error to be nil got %v", err) + } + assert.Equal(t, tt.want, string(body)) + assert.Equal(t, tt.status, res.StatusCode) + assert.Equal(t, "application/json; charset=utf-8", res.Header.Get("Content-Type")) + }) + } +} + +func TestNotFound(t *testing.T) { + t.Run("basic test", func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/anything", nil) + w := httptest.NewRecorder() + NotFound(w, r) + res := w.Result() + defer res.Body.Close() + body, err := io.ReadAll(res.Body) + if err != nil { + t.Errorf("expected error to be nil got %v", err) + } + assert.Equal(t, "{\"result\":null,\"error\":{\"code\":404,\"message\":\"Not found\"}}", string(body)) + assert.Equal(t, http.StatusNotFound, res.StatusCode) + assert.Equal(t, "application/json; charset=utf-8", res.Header.Get("Content-Type")) + }) +} + +func TestResultResponseText(t *testing.T) { + t.Run("basic test", func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/anything", nil) + w := httptest.NewRecorder() + ResultResponseText(w, r, http.StatusOK, "omg this works") + res := w.Result() + defer res.Body.Close() + body, err := io.ReadAll(res.Body) + if err != nil { + t.Errorf("expected error to be nil got %v", err) + } + assert.Equal(t, "omg this works", string(body)) + assert.Equal(t, http.StatusOK, res.StatusCode) + assert.Equal(t, "text/plain; charset=utf-8", res.Header.Get("Content-Type")) + }) +} diff --git a/backend/internal/api/middleware/expansion.go b/backend/internal/api/middleware/expansion.go index bedd3dc..266bfe4 100644 --- a/backend/internal/api/middleware/expansion.go +++ b/backend/internal/api/middleware/expansion.go @@ -22,3 +22,12 @@ func Expansion(next http.Handler) http.Handler { } }) } + +// GetExpandFromContext returns the Expansion setting +func GetExpandFromContext(r *http.Request) []string { + expand, ok := r.Context().Value(c.ExpansionCtxKey).([]string) + if !ok { + return nil + } + return expand +} diff --git a/backend/internal/api/router_test.go b/backend/internal/api/router_test.go index 1d0e1a6..0fb1cf9 100644 --- a/backend/internal/api/router_test.go +++ b/backend/internal/api/router_test.go @@ -35,7 +35,7 @@ func TestGetHealthz(t *testing.T) { func TestNonExistent(t *testing.T) { respRec := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/non-existent-endpoint", nil) + req, _ := http.NewRequest("GET", "/non-existent-endpoint.jpg", nil) r.ServeHTTP(respRec, req) assert.Equal(t, http.StatusNotFound, respRec.Code)