nginx-proxy-manager-zh/backend/internal/api/middleware/list_query.go
2023-07-24 13:42:50 +10:00

196 lines
4.8 KiB
Go

package middleware
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
c "npm/internal/api/context"
h "npm/internal/api/http"
"npm/internal/model"
"npm/internal/tags"
"npm/internal/util"
"github.com/qri-io/jsonschema"
)
// ListQuery returns middleware for list endpoints that inspects the GET query
// parameters before the wrapped handler runs.
//
// The filter JSON schema (tags.GetFilterSchema) and the filter map
// (tags.GetFilterMap) are derived from obj once, when the middleware is built.
// For each request:
//   - the filter parameters are parsed and validated against the schema, the
//     intent being that filters cannot smuggle SQL into later queries;
//   - the "sort" parameter is parsed and each field is checked against the
//     filter map.
//
// Any parsed filters and sort fields are stored on the request context for
// later retrieval with GetFiltersFromContext and GetSortFromContext. If either
// step fails, an error JSON response is written and the next handler is not
// called.
func ListQuery(obj interface{}) func(http.Handler) http.Handler {
	schemaData := tags.GetFilterSchema(obj)
	filterMap := tags.GetFilterMap(obj)

	return func(next http.Handler) http.Handler {
		handler := func(w http.ResponseWriter, r *http.Request) {
			filteredCtx, code, msg, details := listQueryFilters(r, r.Context(), schemaData)
			if code > 0 {
				h.ResultErrorJSON(w, r, code, msg, details)
				return
			}

			sortedCtx, code, msg := listQuerySort(r, filterMap, filteredCtx)
			if code > 0 {
				h.ResultErrorJSON(w, r, code, msg, nil)
				return
			}

			next.ServeHTTP(w, r.WithContext(sortedCtx))
		}
		return http.HandlerFunc(handler)
	}
}
// listQuerySort parses the "sort" query parameter of r into a []model.Sort.
// When every field is present in filterMap, it stores the slice on ctx under
// c.SortCtxKey.
//
// The parameter is a comma-separated list whose items are "field" or
// "field.direction". The field is the text before the first dot. The direction
// is the text between the first and second dot: "desc" or "descending" (in any
// case) becomes "DESC", and anything else, including a missing direction,
// becomes "ASC".
//
// It returns the context, an HTTP status code (0 on success), and an error
// message. An empty or absent "sort" parameter returns ctx unchanged. If any
// field is unknown, ctx is returned unchanged with 400 "Invalid sort field".
//
// NOTE(review): unlike filter keys, sort field names are not lower-cased here;
// confirm filterMap keys use the case clients send.
func listQuerySort(
	r *http.Request,
	filterMap map[string]model.FilterMapValue,
	ctx context.Context,
) (context.Context, int, string) {
	raw := r.URL.Query().Get("sort")
	if raw == "" {
		return ctx, 0, ""
	}

	items := strings.Split(raw, ",")
	sortFields := make([]model.Sort, 0, len(items))
	for _, item := range items {
		// Default: the whole item is the field, sorted ascending.
		field, direction := item, "ASC"
		if parts := strings.Split(item, "."); len(parts) > 1 {
			field = parts[0]
			if d := strings.ToLower(parts[1]); d == "desc" || d == "descending" {
				direction = "DESC"
			}
		}
		sortFields = append(sortFields, model.Sort{Field: field, Direction: direction})
	}

	// A single unknown field rejects the whole sort parameter.
	for _, sf := range sortFields {
		if _, known := filterMap[sf.Field]; !known {
			return ctx, http.StatusBadRequest, "Invalid sort field"
		}
	}

	return context.WithValue(ctx, c.SortCtxKey, sortFields), 0, ""
}
// listQueryFilters builds a []model.Filter from the GET query parameters of r
// and validates it against the JSON schema in schemaData. When the filters are
// valid, it stores them on ctx under c.FiltersCtxKey.
//
// Parsing rules:
//   - Each key is lower-cased and read as "field" or "field:modifier". The
//     modifier defaults to "equals", and only the first two ':'-separated parts
//     are used.
//   - Keys whose field is reserved (limit, offset, sort, expand, t) are ignored.
//   - Blank values are ignored.
//   - For the "in" and "notin" modifiers, each value is split on commas into
//     several values. Blank entries are dropped, and the value is skipped if
//     none remain.
//
// It returns the context, an HTTP status code (0 on success), an error message,
// and, when schema validation fails, the validator's key errors. If no filters
// were found, ctx is returned unchanged and no validation is performed.
func listQueryFilters(
	r *http.Request,
	ctx context.Context,
	schemaData string,
) (context.Context, int, string, interface{}) {
	reservedFilterKeys := []string{
		"limit",
		"offset",
		"sort",
		"expand",
		"t", // Used as a timestamp parameter by some clients and can be ignored
	}

	var filters []model.Filter
	for key, val := range r.URL.Query() {
		// Split out the modifier from the field name and set a default modifier
		keyParts := strings.Split(strings.ToLower(key), ":")
		if len(keyParts) == 1 {
			keyParts = append(keyParts, "equals")
		}
		field, modifier := keyParts[0], keyParts[1]

		// Reserved GET params are never treated as filters
		if util.SliceContainsItem(reservedFilterKeys, field) {
			continue
		}

		isList := modifier == "in" || modifier == "notin"
		for _, valItem := range val {
			// Ignore blank values
			if strings.TrimSpace(valItem) == "" {
				continue
			}

			var valSlice []string
			if isList {
				// Drop blank list entries ("a,,b", "a,", ",") so they are treated
				// the same way as a blank value above.
				for _, item := range strings.Split(valItem, ",") {
					if strings.TrimSpace(item) != "" {
						valSlice = append(valSlice, item)
					}
				}
				if len(valSlice) == 0 {
					continue
				}
			} else {
				valSlice = []string{valItem}
			}

			filters = append(filters, model.Filter{
				Field:    field,
				Modifier: modifier,
				Value:    valSlice,
			})
		}
	}

	// Only validate the schema if there are filters to validate
	if len(filters) == 0 {
		return ctx, 0, "", nil
	}

	// The schema validator works on JSON bytes. Compact output is enough
	// because the encoded form is never shown to anyone.
	filterData, marshalErr := json.Marshal(filters)
	if marshalErr != nil {
		return ctx, http.StatusInternalServerError, fmt.Sprintf("Schema Fatal: %v", marshalErr), nil
	}

	// Create root schema
	rs := &jsonschema.Schema{}
	if err := json.Unmarshal([]byte(schemaData), rs); err != nil {
		return ctx, http.StatusInternalServerError, fmt.Sprintf("Schema Fatal: %v", err), nil
	}

	keyErrors, jsonErr := rs.ValidateBytes(ctx, filterData)
	if jsonErr != nil {
		return ctx, http.StatusBadRequest, jsonErr.Error(), nil
	}
	if len(keyErrors) > 0 {
		return ctx, http.StatusBadRequest, "Invalid Filters", keyErrors
	}

	return context.WithValue(ctx, c.FiltersCtxKey, filters), 0, "", nil
}
// GetFiltersFromContext returns the filters stored on r's context under
// c.FiltersCtxKey (as listQueryFilters does). It returns nil if nothing is
// stored there or the stored value is not a []model.Filter.
func GetFiltersFromContext(r *http.Request) []model.Filter {
	if filters, ok := r.Context().Value(c.FiltersCtxKey).([]model.Filter); ok {
		return filters
	}
	return nil
}
// GetSortFromContext returns the sort fields stored on r's context under
// c.SortCtxKey (as listQuerySort does). It returns nil if nothing is stored
// there or the stored value is not a []model.Sort.
func GetSortFromContext(r *http.Request) []model.Sort {
	stored := r.Context().Value(c.SortCtxKey)
	// A failed comma-ok assertion yields the zero value, i.e. a nil slice.
	sorts, _ := stored.([]model.Sort)
	return sorts
}