mirror of
https://github.com/xiaoxinpro/nginx-proxy-manager-zh.git
synced 2025-01-22 12:58:13 -05:00
Add backend unit tests
This commit is contained in:
parent
72b071dbaa
commit
b123ca4fd0
@ -1,4 +1,4 @@
|
||||
version: "2"
|
||||
version: "3"
|
||||
|
||||
tasks:
|
||||
default:
|
||||
|
@ -136,7 +136,15 @@ func getCommonArgs() []string {
|
||||
}
|
||||
|
||||
// This is split out into it's own function so it's testable
|
||||
func buildCertRequestArgs(domains []string, method, outputFullchainFile, outputKeyFile string, dnsProvider *dnsprovider.Model, ca *certificateauthority.Model, force bool) ([]string, error) {
|
||||
func buildCertRequestArgs(
|
||||
domains []string,
|
||||
method,
|
||||
outputFullchainFile,
|
||||
outputKeyFile string,
|
||||
dnsProvider *dnsprovider.Model,
|
||||
ca *certificateauthority.Model,
|
||||
force bool,
|
||||
) ([]string, error) {
|
||||
// The argument order matters.
|
||||
// see https://github.com/acmesh-official/acme.sh/wiki/How-to-issue-a-cert#3-multiple-domains-san-mode--hybrid-mode
|
||||
// for multiple domains and note that the method of validation is required just after the domain arg, each time.
|
||||
|
@ -193,3 +193,33 @@ func TestBuildCertRequestArgs(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAcmeShFilePath(t *testing.T) {
|
||||
t.Run("basic test", func(t *testing.T) {
|
||||
path, err := getAcmeShFilePath()
|
||||
assert.Equal(t, "/bin/acme.sh", path)
|
||||
assert.Equal(t, nil, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetCommonEnvVars(t *testing.T) {
|
||||
t.Run("basic test", func(t *testing.T) {
|
||||
expected := []string{
|
||||
"ACMESH_CONFIG_HOME=/data/.acme.sh/config",
|
||||
"ACMESH_HOME=/data/.acme.sh",
|
||||
"CERT_HOME=/data/.acme.sh/certs",
|
||||
"LE_CONFIG_HOME=/data/.acme.sh/config",
|
||||
"LE_WORKING_DIR=/data/.acme.sh",
|
||||
}
|
||||
vals := getCommonEnvVars()
|
||||
assert.Equal(t, expected, vals)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetAcmeShVersion(t *testing.T) {
|
||||
t.Run("basic test", func(t *testing.T) {
|
||||
resp := GetAcmeShVersion()
|
||||
assert.Greater(t, len(resp), 1)
|
||||
assert.Equal(t, "v", resp[:1])
|
||||
})
|
||||
}
|
||||
|
13
backend/internal/api/context/context_test.go
Normal file
13
backend/internal/api/context/context_test.go
Normal file
@ -0,0 +1,13 @@
|
||||
package context
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGetString(t *testing.T) {
|
||||
t.Run("basic test", func(t *testing.T) {
|
||||
assert.Equal(t, "context value: Body", BodyCtxKey.String())
|
||||
})
|
||||
}
|
@ -1,208 +0,0 @@
|
||||
package filters
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// NewFilterSchema is the main method to specify a new Filter Schema for use in Middleware
|
||||
func NewFilterSchema(fieldSchemas []string) string {
|
||||
return fmt.Sprintf(baseFilterSchema, strings.Join(fieldSchemas, ", "))
|
||||
}
|
||||
|
||||
// BoolFieldSchema returns the Field Schema for a Boolean accepted value field
|
||||
func BoolFieldSchema(fieldName string) string {
|
||||
return fmt.Sprintf(`{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"field": {
|
||||
"type": "string",
|
||||
"pattern": "^%s$"
|
||||
},
|
||||
"modifier": %s,
|
||||
"value": {
|
||||
"oneOf": [
|
||||
%s,
|
||||
{
|
||||
"type": "array",
|
||||
"items": %s
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}`, fieldName, boolModifiers, filterBool, filterBool)
|
||||
}
|
||||
|
||||
// IntFieldSchema returns the Field Schema for a Integer accepted value field
|
||||
func IntFieldSchema(fieldName string) string {
|
||||
return fmt.Sprintf(`{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"field": {
|
||||
"type": "string",
|
||||
"pattern": "^%s$"
|
||||
},
|
||||
"modifier": %s,
|
||||
"value": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "string",
|
||||
"pattern": "^[0-9]+$"
|
||||
},
|
||||
{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"pattern": "^[0-9]+$"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}`, fieldName, allModifiers)
|
||||
}
|
||||
|
||||
// StringFieldSchema returns the Field Schema for a String accepted value field
|
||||
func StringFieldSchema(fieldName string) string {
|
||||
return fmt.Sprintf(`{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"field": {
|
||||
"type": "string",
|
||||
"pattern": "^%s$"
|
||||
},
|
||||
"modifier": %s,
|
||||
"value": {
|
||||
"oneOf": [
|
||||
%s,
|
||||
{
|
||||
"type": "array",
|
||||
"items": %s
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}`, fieldName, stringModifiers, filterString, filterString)
|
||||
}
|
||||
|
||||
// RegexFieldSchema returns the Field Schema for a String accepted value field matching a Regex
|
||||
func RegexFieldSchema(fieldName string, regex string) string {
|
||||
return fmt.Sprintf(`{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"field": {
|
||||
"type": "string",
|
||||
"pattern": "^%s$"
|
||||
},
|
||||
"modifier": %s,
|
||||
"value": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "string",
|
||||
"pattern": "%s"
|
||||
},
|
||||
{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"pattern": "%s"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}`, fieldName, stringModifiers, regex, regex)
|
||||
}
|
||||
|
||||
// DateFieldSchema returns the Field Schema for a String accepted value field matching a Date format
|
||||
func DateFieldSchema(fieldName string) string {
|
||||
return fmt.Sprintf(`{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"field": {
|
||||
"type": "string",
|
||||
"pattern": "^%s$"
|
||||
},
|
||||
"modifier": %s,
|
||||
"value": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "string",
|
||||
"pattern": "^([12]\\d{3}-(0[1-9]|1[0-2])-(0[1-9]|[12]\\d|3[01]))$"
|
||||
},
|
||||
{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"pattern": "^([12]\\d{3}-(0[1-9]|1[0-2])-(0[1-9]|[12]\\d|3[01]))$"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}`, fieldName, allModifiers)
|
||||
}
|
||||
|
||||
// DateTimeFieldSchema returns the Field Schema for a String accepted value field matching a Date format
|
||||
// 2020-03-01T10:30:00+10:00
|
||||
func DateTimeFieldSchema(fieldName string) string {
|
||||
return fmt.Sprintf(`{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"field": {
|
||||
"type": "string",
|
||||
"pattern": "^%s$"
|
||||
},
|
||||
"modifier": %s,
|
||||
"value": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "string",
|
||||
"pattern": "^([12]\\d{3}-(0[1-9]|1[0-2])-(0[1-9]|[12]\\d|3[01]))$"
|
||||
},
|
||||
{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"pattern": "^([12]\\d{3}-(0[1-9]|1[0-2])-(0[1-9]|[12]\\d|3[01]))$"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}`, fieldName, allModifiers)
|
||||
}
|
||||
|
||||
const allModifiers = `{
|
||||
"type": "string",
|
||||
"pattern": "^(equals|not|contains|starts|ends|in|notin|min|max|greater|less)$"
|
||||
}`
|
||||
|
||||
const boolModifiers = `{
|
||||
"type": "string",
|
||||
"pattern": "^(equals|not)$"
|
||||
}`
|
||||
|
||||
const stringModifiers = `{
|
||||
"type": "string",
|
||||
"pattern": "^(equals|not|contains|starts|ends|in|notin)$"
|
||||
}`
|
||||
|
||||
const filterBool = `{
|
||||
"type": "string",
|
||||
"pattern": "^(TRUE|true|t|yes|y|on|1|FALSE|f|false|n|no|off|0)$"
|
||||
}`
|
||||
|
||||
const filterString = `{
|
||||
"type": "string",
|
||||
"minLength": 1
|
||||
}`
|
||||
|
||||
const baseFilterSchema = `{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"oneOf": [
|
||||
%s
|
||||
]
|
||||
}
|
||||
}`
|
@ -26,7 +26,7 @@ func GetCertificates() func(http.ResponseWriter, *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
certificates, err := certificate.List(pageInfo, middleware.GetFiltersFromContext(r), getExpandFromContext(r))
|
||||
certificates, err := certificate.List(pageInfo, middleware.GetFiltersFromContext(r), middleware.GetExpandFromContext(r))
|
||||
if err != nil {
|
||||
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
|
||||
} else {
|
||||
@ -41,7 +41,7 @@ func GetCertificate() func(http.ResponseWriter, *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if item := getCertificateFromRequest(w, r); item != nil {
|
||||
// nolint: errcheck,gosec
|
||||
item.Expand(getExpandFromContext(r))
|
||||
item.Expand(middleware.GetExpandFromContext(r))
|
||||
h.ResultResponseJSON(w, r, http.StatusOK, item)
|
||||
}
|
||||
}
|
||||
|
@ -4,8 +4,6 @@ import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"npm/internal/api/context"
|
||||
"npm/internal/api/middleware"
|
||||
"npm/internal/model"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
@ -15,7 +13,7 @@ import (
|
||||
const defaultLimit = 10
|
||||
|
||||
func getPageInfoFromRequest(r *http.Request) (model.PageInfo, error) {
|
||||
var pageInfo model.PageInfo
|
||||
pageInfo := model.PageInfo{}
|
||||
var err error
|
||||
|
||||
pageInfo.Offset, pageInfo.Limit, err = getPagination(r)
|
||||
@ -23,7 +21,7 @@ func getPageInfoFromRequest(r *http.Request) (model.PageInfo, error) {
|
||||
return pageInfo, err
|
||||
}
|
||||
|
||||
pageInfo.Sort = middleware.GetSortFromContext(r)
|
||||
// pageInfo.Sort = middleware.GetSortFromContext(r)
|
||||
|
||||
return pageInfo, nil
|
||||
}
|
||||
@ -93,12 +91,3 @@ func getPagination(r *http.Request) (int, int, error) {
|
||||
|
||||
return offset, limit, nil
|
||||
}
|
||||
|
||||
// getExpandFromContext returns the Expansion setting
|
||||
func getExpandFromContext(r *http.Request) []string {
|
||||
expand, ok := r.Context().Value(context.ExpansionCtxKey).([]string)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return expand
|
||||
}
|
||||
|
118
backend/internal/api/handler/helpers_test.go
Normal file
118
backend/internal/api/handler/helpers_test.go
Normal file
@ -0,0 +1,118 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"npm/internal/model"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGetPageInfoFromRequest(t *testing.T) {
|
||||
t.Run("basic test", func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodGet, "/hosts", nil)
|
||||
p, err := getPageInfoFromRequest(r)
|
||||
|
||||
var nilStringSlice []string
|
||||
var nilSortSlice []model.Sort
|
||||
defaultSort := model.Sort{Field: "name", Direction: "asc"}
|
||||
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, 0, p.Offset)
|
||||
assert.Equal(t, 10, p.Limit)
|
||||
assert.Equal(t, nilStringSlice, p.Expand)
|
||||
assert.Equal(t, nilSortSlice, p.Sort)
|
||||
assert.Equal(t, []model.Sort{defaultSort}, p.GetSort(defaultSort))
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetQueryVarInt(t *testing.T) {
|
||||
type want struct {
|
||||
val int
|
||||
err string
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
queryVar string
|
||||
required bool
|
||||
defaultValue int
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "simple default",
|
||||
url: "/hosts",
|
||||
queryVar: "something",
|
||||
required: false,
|
||||
defaultValue: 100,
|
||||
want: want{
|
||||
val: 100,
|
||||
err: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "required flag",
|
||||
url: "/hosts",
|
||||
queryVar: "something",
|
||||
required: true,
|
||||
want: want{
|
||||
val: 0,
|
||||
err: "something was not supplied in the request",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "simple get",
|
||||
url: "/hosts?something=50",
|
||||
queryVar: "something",
|
||||
required: true,
|
||||
want: want{
|
||||
val: 50,
|
||||
err: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid number",
|
||||
url: "/hosts?something=aaa",
|
||||
queryVar: "something",
|
||||
required: true,
|
||||
want: want{
|
||||
val: 0,
|
||||
err: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "preceding zeros",
|
||||
url: "/hosts?something=0000050",
|
||||
queryVar: "something",
|
||||
required: true,
|
||||
want: want{
|
||||
val: 50,
|
||||
err: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "decimals",
|
||||
url: "/hosts?something=50.50",
|
||||
queryVar: "something",
|
||||
required: true,
|
||||
want: want{
|
||||
val: 0,
|
||||
err: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodGet, tt.url, nil)
|
||||
val, err := getQueryVarInt(r, tt.queryVar, tt.required, tt.defaultValue)
|
||||
assert.Equal(t, tt.want.val, val)
|
||||
if tt.want.err != "" {
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, tt.want.err, err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -26,7 +26,7 @@ func GetHosts() func(http.ResponseWriter, *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
hosts, err := host.List(pageInfo, middleware.GetFiltersFromContext(r), getExpandFromContext(r))
|
||||
hosts, err := host.List(pageInfo, middleware.GetFiltersFromContext(r), middleware.GetExpandFromContext(r))
|
||||
if err != nil {
|
||||
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
|
||||
} else {
|
||||
@ -52,7 +52,7 @@ func GetHost() func(http.ResponseWriter, *http.Request) {
|
||||
h.NotFound(w, r)
|
||||
case nil:
|
||||
// nolint: errcheck,gosec
|
||||
item.Expand(getExpandFromContext(r))
|
||||
item.Expand(middleware.GetExpandFromContext(r))
|
||||
h.ResultResponseJSON(w, r, http.StatusOK, item)
|
||||
default:
|
||||
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
|
||||
@ -132,7 +132,7 @@ func UpdateHost() func(http.ResponseWriter, *http.Request) {
|
||||
}
|
||||
|
||||
// nolint: errcheck,gosec
|
||||
hostObject.Expand(getExpandFromContext(r))
|
||||
hostObject.Expand(middleware.GetExpandFromContext(r))
|
||||
|
||||
configureHost(hostObject)
|
||||
|
||||
|
@ -25,14 +25,21 @@ func NotFound() func(http.ResponseWriter, *http.Request) {
|
||||
assetsSub, _ = fs.Sub(embed.Assets, "assets")
|
||||
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
defaultFile := "index.html"
|
||||
path := strings.TrimLeft(r.URL.Path, "/")
|
||||
|
||||
isAPI := false
|
||||
if len(path) >= 3 && path[0:3] == "api" {
|
||||
isAPI = true
|
||||
}
|
||||
|
||||
if path == "" {
|
||||
path = "index.html"
|
||||
path = defaultFile
|
||||
}
|
||||
|
||||
err := tryRead(assetsSub, path, w)
|
||||
if err == errIsDir {
|
||||
err = tryRead(assetsSub, "index.html", w)
|
||||
err = tryRead(assetsSub, defaultFile, w)
|
||||
if err != nil {
|
||||
h.NotFound(w, r)
|
||||
}
|
||||
@ -40,6 +47,16 @@ func NotFound() func(http.ResponseWriter, *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the path has an extension and not in the "/api" path
|
||||
ext := filepath.Ext(path)
|
||||
if !isAPI && ext == "" {
|
||||
// Not an api endpoint and Not a specific file, return the default index file
|
||||
err := tryRead(assetsSub, defaultFile, w)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
h.NotFound(w, r)
|
||||
}
|
||||
}
|
||||
|
@ -27,7 +27,7 @@ func GetUpstreams() func(http.ResponseWriter, *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
items, err := upstream.List(pageInfo, middleware.GetFiltersFromContext(r), getExpandFromContext(r))
|
||||
items, err := upstream.List(pageInfo, middleware.GetFiltersFromContext(r), middleware.GetExpandFromContext(r))
|
||||
if err != nil {
|
||||
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
|
||||
} else {
|
||||
@ -53,7 +53,7 @@ func GetUpstream() func(http.ResponseWriter, *http.Request) {
|
||||
h.NotFound(w, r)
|
||||
case nil:
|
||||
// nolint: errcheck,gosec
|
||||
item.Expand(getExpandFromContext(r))
|
||||
item.Expand(middleware.GetExpandFromContext(r))
|
||||
h.ResultResponseJSON(w, r, http.StatusOK, item)
|
||||
default:
|
||||
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
|
||||
@ -127,7 +127,7 @@ func UpdateUpstream() func(http.ResponseWriter, *http.Request) {
|
||||
}
|
||||
|
||||
// nolint: errcheck,gosec
|
||||
// item.Expand(getExpandFromContext(r))
|
||||
// item.Expand(middleware.GetExpandFromContext(r))
|
||||
|
||||
configureUpstream(item)
|
||||
|
||||
|
@ -27,7 +27,7 @@ func GetUsers() func(http.ResponseWriter, *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
users, err := user.List(pageInfo, middleware.GetFiltersFromContext(r), getExpandFromContext(r))
|
||||
users, err := user.List(pageInfo, middleware.GetFiltersFromContext(r), middleware.GetExpandFromContext(r))
|
||||
if err != nil {
|
||||
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
|
||||
} else {
|
||||
@ -52,7 +52,7 @@ func GetUser() func(http.ResponseWriter, *http.Request) {
|
||||
h.NotFound(w, r)
|
||||
case nil:
|
||||
// nolint: errcheck,gosec
|
||||
item.Expand(getExpandFromContext(r))
|
||||
item.Expand(middleware.GetExpandFromContext(r))
|
||||
h.ResultResponseJSON(w, r, http.StatusOK, item)
|
||||
default:
|
||||
h.ResultErrorJSON(w, r, http.StatusBadRequest, err.Error(), nil)
|
||||
@ -108,7 +108,7 @@ func UpdateUser() func(http.ResponseWriter, *http.Request) {
|
||||
}
|
||||
|
||||
// nolint: errcheck,gosec
|
||||
userObject.Expand(getExpandFromContext(r))
|
||||
userObject.Expand(middleware.GetExpandFromContext(r))
|
||||
|
||||
h.ResultResponseJSON(w, r, http.StatusOK, userObject)
|
||||
default:
|
||||
|
@ -1,46 +0,0 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/qri-io/jsonschema"
|
||||
"github.com/rotisserie/eris"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrInvalidJSON is an error for invalid json
|
||||
ErrInvalidJSON = eris.New("JSON is invalid")
|
||||
// ErrInvalidPayload is an error for invalid incoming data
|
||||
ErrInvalidPayload = eris.New("Payload is invalid")
|
||||
)
|
||||
|
||||
// ValidateRequestSchema takes a Schema and the Content to validate against it
|
||||
func ValidateRequestSchema(schema string, requestBody []byte) ([]jsonschema.KeyError, error) {
|
||||
var jsonErrors []jsonschema.KeyError
|
||||
var schemaBytes = []byte(schema)
|
||||
|
||||
// Make sure the body is valid JSON
|
||||
if !isJSON(requestBody) {
|
||||
return jsonErrors, ErrInvalidJSON
|
||||
}
|
||||
|
||||
rs := &jsonschema.Schema{}
|
||||
if err := json.Unmarshal(schemaBytes, rs); err != nil {
|
||||
return jsonErrors, err
|
||||
}
|
||||
|
||||
var validationErr error
|
||||
ctx := context.TODO()
|
||||
if jsonErrors, validationErr = rs.ValidateBytes(ctx, requestBody); len(jsonErrors) > 0 {
|
||||
return jsonErrors, validationErr
|
||||
}
|
||||
|
||||
// Valid
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func isJSON(bytes []byte) bool {
|
||||
var js map[string]interface{}
|
||||
return json.Unmarshal(bytes, &js) == nil
|
||||
}
|
@ -11,6 +11,12 @@ import (
|
||||
"npm/internal/logger"
|
||||
|
||||
"github.com/qri-io/jsonschema"
|
||||
"github.com/rotisserie/eris"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrInvalidPayload is an error for invalid incoming data
|
||||
ErrInvalidPayload = eris.New("Payload is invalid")
|
||||
)
|
||||
|
||||
// Response interface for standard API results
|
||||
|
180
backend/internal/api/http/responses_test.go
Normal file
180
backend/internal/api/http/responses_test.go
Normal file
@ -0,0 +1,180 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"npm/internal/entity/user"
|
||||
"npm/internal/model"
|
||||
"testing"
|
||||
|
||||
"github.com/qri-io/jsonschema"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestResultResponseJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status int
|
||||
given interface{}
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "simple response",
|
||||
status: http.StatusOK,
|
||||
given: true,
|
||||
want: "{\"result\":true}",
|
||||
},
|
||||
{
|
||||
name: "detailed response",
|
||||
status: http.StatusBadRequest,
|
||||
given: user.Model{
|
||||
ModelBase: model.ModelBase{ID: 10},
|
||||
Email: "me@example.com",
|
||||
Name: "John Doe",
|
||||
Nickname: "Jonny",
|
||||
},
|
||||
want: "{\"result\":{\"id\":10,\"created_at\":0,\"updated_at\":0,\"name\":\"John Doe\",\"nickname\":\"Jonny\",\"email\":\"me@example.com\",\"is_disabled\":false,\"gravatar_url\":\"\"}}",
|
||||
},
|
||||
{
|
||||
name: "error response",
|
||||
status: http.StatusNotFound,
|
||||
given: ErrorResponse{
|
||||
Code: 404,
|
||||
Message: "Not found",
|
||||
Invalid: []string{"your", "page", "was", "not", "found"},
|
||||
},
|
||||
want: "{\"result\":null,\"error\":{\"code\":404,\"message\":\"Not found\",\"invalid\":[\"your\",\"page\",\"was\",\"not\",\"found\"]}}",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodGet, "/anything", nil)
|
||||
w := httptest.NewRecorder()
|
||||
ResultResponseJSON(w, r, tt.status, tt.given)
|
||||
res := w.Result()
|
||||
defer res.Body.Close()
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Errorf("expected error to be nil got %v", err)
|
||||
}
|
||||
assert.Equal(t, tt.want, string(body))
|
||||
assert.Equal(t, tt.status, res.StatusCode)
|
||||
assert.Equal(t, "application/json; charset=utf-8", res.Header.Get("Content-Type"))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResultSchemaErrorJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
given []jsonschema.KeyError
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "case a",
|
||||
given: []jsonschema.KeyError{
|
||||
{
|
||||
PropertyPath: "/something",
|
||||
InvalidValue: "name",
|
||||
Message: "Name cannot be empty",
|
||||
},
|
||||
},
|
||||
want: "{\"result\":null,\"error\":{\"code\":400,\"message\":{},\"invalid\":[{\"propertyPath\":\"/something\",\"invalidValue\":\"name\",\"message\":\"Name cannot be empty\"}]}}",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodGet, "/anything", nil)
|
||||
w := httptest.NewRecorder()
|
||||
ResultSchemaErrorJSON(w, r, tt.given)
|
||||
res := w.Result()
|
||||
defer res.Body.Close()
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Errorf("expected error to be nil got %v", err)
|
||||
}
|
||||
assert.Equal(t, tt.want, string(body))
|
||||
assert.Equal(t, 400, res.StatusCode)
|
||||
assert.Equal(t, "application/json; charset=utf-8", res.Header.Get("Content-Type"))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResultErrorJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status int
|
||||
message string
|
||||
extended interface{}
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "case a",
|
||||
status: http.StatusBadGateway,
|
||||
message: "Oh not something is not acceptable",
|
||||
extended: nil,
|
||||
want: "{\"result\":null,\"error\":{\"code\":502,\"message\":\"Oh not something is not acceptable\"}}",
|
||||
},
|
||||
{
|
||||
name: "case b",
|
||||
status: http.StatusNotAcceptable,
|
||||
message: "Oh not something is not acceptable again",
|
||||
extended: []string{"name is not allowed", "dob is wrong or something"},
|
||||
want: "{\"result\":null,\"error\":{\"code\":406,\"message\":\"Oh not something is not acceptable again\",\"invalid\":[\"name is not allowed\",\"dob is wrong or something\"]}}",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodGet, "/anything", nil)
|
||||
w := httptest.NewRecorder()
|
||||
ResultErrorJSON(w, r, tt.status, tt.message, tt.extended)
|
||||
res := w.Result()
|
||||
defer res.Body.Close()
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Errorf("expected error to be nil got %v", err)
|
||||
}
|
||||
assert.Equal(t, tt.want, string(body))
|
||||
assert.Equal(t, tt.status, res.StatusCode)
|
||||
assert.Equal(t, "application/json; charset=utf-8", res.Header.Get("Content-Type"))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotFound(t *testing.T) {
|
||||
t.Run("basic test", func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodGet, "/anything", nil)
|
||||
w := httptest.NewRecorder()
|
||||
NotFound(w, r)
|
||||
res := w.Result()
|
||||
defer res.Body.Close()
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Errorf("expected error to be nil got %v", err)
|
||||
}
|
||||
assert.Equal(t, "{\"result\":null,\"error\":{\"code\":404,\"message\":\"Not found\"}}", string(body))
|
||||
assert.Equal(t, http.StatusNotFound, res.StatusCode)
|
||||
assert.Equal(t, "application/json; charset=utf-8", res.Header.Get("Content-Type"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestResultResponseText(t *testing.T) {
|
||||
t.Run("basic test", func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodGet, "/anything", nil)
|
||||
w := httptest.NewRecorder()
|
||||
ResultResponseText(w, r, http.StatusOK, "omg this works")
|
||||
res := w.Result()
|
||||
defer res.Body.Close()
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Errorf("expected error to be nil got %v", err)
|
||||
}
|
||||
assert.Equal(t, "omg this works", string(body))
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
assert.Equal(t, "text/plain; charset=utf-8", res.Header.Get("Content-Type"))
|
||||
})
|
||||
}
|
@ -22,3 +22,12 @@ func Expansion(next http.Handler) http.Handler {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// GetExpandFromContext returns the Expansion setting
|
||||
func GetExpandFromContext(r *http.Request) []string {
|
||||
expand, ok := r.Context().Value(c.ExpansionCtxKey).([]string)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return expand
|
||||
}
|
||||
|
@ -35,7 +35,7 @@ func TestGetHealthz(t *testing.T) {
|
||||
|
||||
func TestNonExistent(t *testing.T) {
|
||||
respRec := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/non-existent-endpoint", nil)
|
||||
req, _ := http.NewRequest("GET", "/non-existent-endpoint.jpg", nil)
|
||||
|
||||
r.ServeHTTP(respRec, req)
|
||||
assert.Equal(t, http.StatusNotFound, respRec.Code)
|
||||
|
Loading…
Reference in New Issue
Block a user