From a0e17f9678283a6d2038df03e1911691a7acdc45 Mon Sep 17 00:00:00 2001 From: Jamie Curnow Date: Mon, 24 Jul 2023 11:49:08 +1000 Subject: [PATCH] Better checking for api sort param to prevent sql injection And moved filters out and cached object reflection --- backend/internal/api/context/context.go | 2 + backend/internal/api/handler/helpers.go | 43 +--- backend/internal/api/middleware/expansion.go | 2 +- backend/internal/api/middleware/filters.go | 118 ----------- backend/internal/api/middleware/list_query.go | 196 ++++++++++++++++++ backend/internal/api/router.go | 20 +- backend/internal/entity/filters.go | 58 +----- backend/internal/entity/lists.go | 2 +- backend/internal/entity/scopes.go | 2 +- backend/internal/model/filter.go | 6 + backend/internal/tags/filters.go | 53 +++++ backend/internal/tags/reflect.go | 33 +++ 12 files changed, 312 insertions(+), 223 deletions(-) delete mode 100644 backend/internal/api/middleware/filters.go create mode 100644 backend/internal/api/middleware/list_query.go create mode 100644 backend/internal/tags/filters.go create mode 100644 backend/internal/tags/reflect.go diff --git a/backend/internal/api/context/context.go b/backend/internal/api/context/context.go index f3bc957..dc9f32b 100644 --- a/backend/internal/api/context/context.go +++ b/backend/internal/api/context/context.go @@ -7,6 +7,8 @@ var ( UserIDCtxKey = &contextKey{"UserID"} // FiltersCtxKey is the name of the Filters value on the context FiltersCtxKey = &contextKey{"Filters"} + // SortCtxKey is the name of the Sort value on the context + SortCtxKey = &contextKey{"Sort"} // PrettyPrintCtxKey is the name of the pretty print context PrettyPrintCtxKey = &contextKey{"Pretty"} // ExpansionCtxKey is the name of the expansion context diff --git a/backend/internal/api/handler/helpers.go b/backend/internal/api/handler/helpers.go index d6f5e58..d07ac05 100644 --- a/backend/internal/api/handler/helpers.go +++ b/backend/internal/api/handler/helpers.go @@ -3,9 +3,9 @@ package handler import ( "net/http" "strconv" - "strings" "npm/internal/api/context" + "npm/internal/api/middleware" "npm/internal/model" "github.com/go-chi/chi/v5" @@ -23,50 +23,11 @@ func getPageInfoFromRequest(r *http.Request) (model.PageInfo, error) { return pageInfo, err } - pageInfo.Sort = getSortParameter(r) + pageInfo.Sort = middleware.GetSortFromContext(r) return pageInfo, nil } -func getSortParameter(r *http.Request) []model.Sort { - var sortFields []model.Sort - - queryValues := r.URL.Query() - sortString := queryValues.Get("sort") - if sortString == "" { - return sortFields - } - - // Split sort fields up in to slice - sorts := strings.Split(sortString, ",") - for _, sortItem := range sorts { - if strings.Contains(sortItem, ".") { - theseItems := strings.Split(sortItem, ".") - - switch strings.ToLower(theseItems[1]) { - case "desc": - fallthrough - case "descending": - theseItems[1] = "DESC" - default: - theseItems[1] = "ASC" - } - - sortFields = append(sortFields, model.Sort{ - Field: theseItems[0], - Direction: theseItems[1], - }) - } else { - sortFields = append(sortFields, model.Sort{ - Field: sortItem, - Direction: "ASC", - }) - } - } - - return sortFields -} - func getQueryVarInt(r *http.Request, varName string, required bool, defaultValue int) (int, error) { queryValues := r.URL.Query() varValue := queryValues.Get(varName) diff --git a/backend/internal/api/middleware/expansion.go b/backend/internal/api/middleware/expansion.go index 871bd27..bedd3dc 100644 --- a/backend/internal/api/middleware/expansion.go +++ b/backend/internal/api/middleware/expansion.go @@ -9,7 +9,7 @@ import ( ) // Expansion will determine whether the request should have objects expanded -// with ?expand=1 or ?expand=true +// with ?expand=item,item func Expansion(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { expandStr := r.URL.Query().Get("expand") diff --git a/backend/internal/api/middleware/filters.go b/backend/internal/api/middleware/filters.go deleted file mode 100644 index 21c9518..0000000 --- a/backend/internal/api/middleware/filters.go +++ /dev/null @@ -1,118 +0,0 @@ -package middleware - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - "strings" - - c "npm/internal/api/context" - h "npm/internal/api/http" - "npm/internal/entity" - "npm/internal/model" - "npm/internal/util" - - "github.com/qri-io/jsonschema" -) - -// Filters will accept a pre-defined schemaData to validate against the GET query params -// passed in to this endpoint. This will ensure that the filters are not injecting SQL. -// After we have determined what the Filters are to be, they are saved on the Context -// to be used later in other endpoints. -func Filters(obj interface{}) func(http.Handler) http.Handler { - schemaData := entity.GetFilterSchema(obj, true) - - reservedFilterKeys := []string{ - "limit", - "offset", - "sort", - "order", - "expand", - "t", // This is used as a timestamp paramater in some clients and can be ignored - } - - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var filters []model.Filter - for key, val := range r.URL.Query() { - key = strings.ToLower(key) - - // Split out the modifier from the field name and set a default modifier - var keyParts []string - keyParts = strings.Split(key, ":") - if len(keyParts) == 1 { - // Default modifier - keyParts = append(keyParts, "equals") - } - - // Only use this filter if it's not a reserved get param - if !util.SliceContainsItem(reservedFilterKeys, keyParts[0]) { - for _, valItem := range val { - // Check that the val isn't empty - if len(strings.TrimSpace(valItem)) > 0 { - valSlice := []string{valItem} - if keyParts[1] == "in" || keyParts[1] == "notin" { - valSlice = strings.Split(valItem, ",") - } - - filters = append(filters, model.Filter{ - Field: keyParts[0], - Modifier: keyParts[1], - Value: valSlice, - }) - } - } - } - } - - // Only validate schema if there are filters to validate - if len(filters) > 0 { - ctx := r.Context() - - // Marshal the Filters in to a JSON string so that the Schema Validation works against it - filterData, marshalErr := json.MarshalIndent(filters, "", " ") - if marshalErr != nil { - h.ResultErrorJSON(w, r, http.StatusInternalServerError, fmt.Sprintf("Schema Fatal: %v", marshalErr), nil) - return - } - - // Create root schema - rs := &jsonschema.Schema{} - if err := json.Unmarshal([]byte(schemaData), rs); err != nil { - h.ResultErrorJSON(w, r, http.StatusInternalServerError, fmt.Sprintf("Schema Fatal: %v", err), nil) - return - } - - // Validate it - errors, jsonError := rs.ValidateBytes(ctx, filterData) - if jsonError != nil { - h.ResultErrorJSON(w, r, http.StatusBadRequest, jsonError.Error(), nil) - return - } - - if len(errors) > 0 { - h.ResultErrorJSON(w, r, http.StatusBadRequest, "Invalid Filters", errors) - return - } - - // todo: populate filters object with the gorm database name - - ctx = context.WithValue(ctx, c.FiltersCtxKey, filters) - next.ServeHTTP(w, r.WithContext(ctx)) - } else { - next.ServeHTTP(w, r) - } - }) - } -} - -// GetFiltersFromContext returns the Filters -func GetFiltersFromContext(r *http.Request) []model.Filter { - filters, ok := r.Context().Value(c.FiltersCtxKey).([]model.Filter) - if !ok { - // the assertion failed - return nil - } - return filters -} diff --git a/backend/internal/api/middleware/list_query.go b/backend/internal/api/middleware/list_query.go new file mode 100644 index 0000000..fd8506f --- /dev/null +++ b/backend/internal/api/middleware/list_query.go @@ -0,0 +1,196 @@ +package middleware + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + + c "npm/internal/api/context" + h "npm/internal/api/http" + "npm/internal/entity" + "npm/internal/model" + "npm/internal/tags" + "npm/internal/util" + + "github.com/qri-io/jsonschema" +) + +// ListQuery will accept a pre-defined schemaData to validate against the GET query params +// passed in to this endpoint. This will ensure that the filters are not injecting SQL +// and the sort parameter is valid as well. +// After we have determined what the Filters are to be, they are saved on the Context +// to be used later in other endpoints. +func ListQuery(obj interface{}) func(http.Handler) http.Handler { + schemaData := entity.GetFilterSchema(obj, true) + filterMap := tags.GetFilterMap(obj) + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + ctx, statusCode, errMsg, errors := listQueryFilters(r, ctx, schemaData) + if statusCode > 0 { + h.ResultErrorJSON(w, r, statusCode, errMsg, errors) + return + } + + ctx, statusCode, errMsg = listQuerySort(r, filterMap, ctx) + if statusCode > 0 { + h.ResultErrorJSON(w, r, statusCode, errMsg, nil) + return + } + + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +func listQuerySort( + r *http.Request, + filterMap map[string]model.FilterMapValue, + ctx context.Context, +) (context.Context, int, string) { + var sortFields []model.Sort + + sortString := r.URL.Query().Get("sort") + if sortString == "" { + return ctx, 0, "" + } + + // Split sort fields up in to slice + sorts := strings.Split(sortString, ",") + for _, sortItem := range sorts { + if strings.Contains(sortItem, ".") { + theseItems := strings.Split(sortItem, ".") + + switch strings.ToLower(theseItems[1]) { + case "desc": + fallthrough + case "descending": + theseItems[1] = "DESC" + default: + theseItems[1] = "ASC" + } + + sortFields = append(sortFields, model.Sort{ + Field: theseItems[0], + Direction: theseItems[1], + }) + } else { + sortFields = append(sortFields, model.Sort{ + Field: sortItem, + Direction: "ASC", + }) + } + } + + // check against filter schema + for _, f := range sortFields { + if _, exists := filterMap[f.Field]; !exists { + return ctx, http.StatusBadRequest, "Invalid sort field" + } + } + + ctx = context.WithValue(ctx, c.SortCtxKey, sortFields) + + // No problems! + return ctx, 0, "" +} + +func listQueryFilters( + r *http.Request, + ctx context.Context, + schemaData string, +) (context.Context, int, string, interface{}) { + reservedFilterKeys := []string{ + "limit", + "offset", + "sort", + "expand", + "t", // This is used as a timestamp paramater in some clients and can be ignored + } + + var filters []model.Filter + for key, val := range r.URL.Query() { + key = strings.ToLower(key) + + // Split out the modifier from the field name and set a default modifier + var keyParts []string + keyParts = strings.Split(key, ":") + if len(keyParts) == 1 { + // Default modifier + keyParts = append(keyParts, "equals") + } + + // Only use this filter if it's not a reserved get param + if !util.SliceContainsItem(reservedFilterKeys, keyParts[0]) { + for _, valItem := range val { + // Check that the val isn't empty + if len(strings.TrimSpace(valItem)) > 0 { + valSlice := []string{valItem} + if keyParts[1] == "in" || keyParts[1] == "notin" { + valSlice = strings.Split(valItem, ",") + } + + filters = append(filters, model.Filter{ + Field: keyParts[0], + Modifier: keyParts[1], + Value: valSlice, + }) + } + } + } + } + + // Only validate schema if there are filters to validate + if len(filters) > 0 { + // Marshal the Filters in to a JSON string so that the Schema Validation works against it + filterData, marshalErr := json.MarshalIndent(filters, "", " ") + if marshalErr != nil { + return ctx, http.StatusInternalServerError, fmt.Sprintf("Schema Fatal: %v", marshalErr), nil + } + + // Create root schema + rs := &jsonschema.Schema{} + if err := json.Unmarshal([]byte(schemaData), rs); err != nil { + return ctx, http.StatusInternalServerError, fmt.Sprintf("Schema Fatal: %v", err), nil + } + + // Validate it + errors, jsonError := rs.ValidateBytes(ctx, filterData) + if jsonError != nil { + return ctx, http.StatusBadRequest, jsonError.Error(), nil + } + + if len(errors) > 0 { + return ctx, http.StatusBadRequest, "Invalid Filters", errors + } + + ctx = context.WithValue(ctx, c.FiltersCtxKey, filters) + } + + // No problems! + return ctx, 0, "", nil +} + +// GetFiltersFromContext returns the Filters +func GetFiltersFromContext(r *http.Request) []model.Filter { + filters, ok := r.Context().Value(c.FiltersCtxKey).([]model.Filter) + if !ok { + // the assertion failed + return nil + } + return filters +} + +// GetSortFromContext returns the Sort +func GetSortFromContext(r *http.Request) []model.Sort { + sorts, ok := r.Context().Value(c.SortCtxKey).([]model.Sort) + if !ok { + // the assertion failed + return nil + } + return sorts +} diff --git a/backend/internal/api/router.go b/backend/internal/api/router.go index f08fc93..069f86c 100644 --- a/backend/internal/api/router.go +++ b/backend/internal/api/router.go @@ -104,7 +104,7 @@ func applyRoutes(r chi.Router) chi.Router { // List r.With( middleware.Enforce(user.CapabilityUsersManage), - middleware.Filters(user.Model{}), + middleware.ListQuery(user.Model{}), ).Get("/", handler.GetUsers()) // Specific Item @@ -136,7 +136,7 @@ func applyRoutes(r chi.Router) chi.Router { r.With(middleware.EnforceSetup(true), middleware.Enforce(user.CapabilitySettingsManage)).Route("/settings", func(r chi.Router) { // List r.With( - middleware.Filters(setting.Model{}), + middleware.ListQuery(setting.Model{}), ).Get("/", handler.GetSettings()) r.Get("/{name}", handler.GetSetting()) @@ -151,7 +151,7 @@ func applyRoutes(r chi.Router) chi.Router { // List r.With( middleware.Enforce(user.CapabilityAccessListsView), - middleware.Filters(accesslist.Model{}), + middleware.ListQuery(accesslist.Model{}), ).Get("/", handler.GetAccessLists()) // Create @@ -175,7 +175,7 @@ func applyRoutes(r chi.Router) chi.Router { // List r.With( middleware.Enforce(user.CapabilityDNSProvidersView), - middleware.Filters(dnsprovider.Model{}), + middleware.ListQuery(dnsprovider.Model{}), ).Get("/", handler.GetDNSProviders()) // Create @@ -205,7 +205,7 @@ func applyRoutes(r chi.Router) chi.Router { // List r.With( middleware.Enforce(user.CapabilityCertificateAuthoritiesView), - middleware.Filters(certificateauthority.Model{}), + middleware.ListQuery(certificateauthority.Model{}), ).Get("/", handler.GetCertificateAuthorities()) // Create @@ -235,7 +235,7 @@ func applyRoutes(r chi.Router) chi.Router { // List r.With( middleware.Enforce(user.CapabilityCertificatesView), - middleware.Filters(certificate.Model{}), + middleware.ListQuery(certificate.Model{}), ).Get("/", handler.GetCertificates()) // Create @@ -262,7 +262,7 @@ func applyRoutes(r chi.Router) chi.Router { // List r.With( middleware.Enforce(user.CapabilityHostsView), - middleware.Filters(host.Model{}), + middleware.ListQuery(host.Model{}), ).Get("/", handler.GetHosts()) // Create @@ -288,7 +288,7 @@ func applyRoutes(r chi.Router) chi.Router { // List r.With( middleware.Enforce(user.CapabilityNginxTemplatesView), - middleware.Filters(nginxtemplate.Model{}), + middleware.ListQuery(nginxtemplate.Model{}), ).Get("/", handler.GetNginxTemplates()) // Create @@ -312,7 +312,7 @@ func applyRoutes(r chi.Router) chi.Router { // List r.With( middleware.Enforce(user.CapabilityStreamsView), - middleware.Filters(stream.Model{}), + middleware.ListQuery(stream.Model{}), ).Get("/", handler.GetStreams()) // Create @@ -336,7 +336,7 @@ func applyRoutes(r chi.Router) chi.Router { // List r.With( middleware.Enforce(user.CapabilityHostsView), - middleware.Filters(upstream.Model{}), + middleware.ListQuery(upstream.Model{}), ).Get("/", handler.GetUpstreams()) // Create diff --git a/backend/internal/entity/filters.go b/backend/internal/entity/filters.go index 317d275..8613f48 100644 --- a/backend/internal/entity/filters.go +++ b/backend/internal/entity/filters.go @@ -1,66 +1,22 @@ package entity import ( - "reflect" - "regexp" - "strings" + "npm/internal/model" + "npm/internal/tags" ) -type filterMapValue struct { - Type string - Field string -} - // GetFilterMap returns the filter map -func GetFilterMap(m interface{}, includeBaseEntity bool) map[string]filterMapValue { - filterMap := getFilterMapForInterface(m) +func GetFilterMap(m interface{}, includeBaseEntity bool) map[string]model.FilterMapValue { + filterMap := tags.GetFilterMap(m) if includeBaseEntity { - return mergeFilterMaps(getFilterMapForInterface(ModelBase{}), filterMap) + return mergeFilterMaps(tags.GetFilterMap(ModelBase{}), filterMap) } return filterMap } -func getFilterMapForInterface(m interface{}) map[string]filterMapValue { - var filterMap = make(map[string]filterMapValue) - - // TypeOf returns the reflection Type that represents the dynamic type of variable. - // If variable is a nil interface value, TypeOf returns nil. - t := reflect.TypeOf(m) - - // Iterate over all available fields and read the tag value - for i := 0; i < t.NumField(); i++ { - // Get the field, returns https://golang.org/pkg/reflect/#StructField - field := t.Field(i) - - // Get the field tag value - filterTag := field.Tag.Get("filter") - dbTag := field.Tag.Get("gorm") - if filterTag != "" && dbTag != "" && dbTag != "-" && filterTag != "-" { - // db can have many parts, we need to pull out the "column:value" part - dbField := field.Name - r := regexp.MustCompile(`(?:^|;)column:([^;|$]+)(?:$|;)`) - if matches := r.FindStringSubmatch(dbTag); len(matches) > 1 { - dbField = matches[1] - } - // Filter tag can be a 2 part thing: name,type - // ie: account_id,integer - // So we need to split and use the first part - parts := strings.Split(filterTag, ",") - if len(parts) > 1 { - filterMap[parts[0]] = filterMapValue{ - Type: parts[1], - Field: dbField, - } - } - } - } - - return filterMap -} - -func mergeFilterMaps(m1 map[string]filterMapValue, m2 map[string]filterMapValue) map[string]filterMapValue { - merged := make(map[string]filterMapValue, 0) +func mergeFilterMaps(m1 map[string]model.FilterMapValue, m2 map[string]model.FilterMapValue) map[string]model.FilterMapValue { + merged := make(map[string]model.FilterMapValue, 0) for k, v := range m1 { merged[k] = v } diff --git a/backend/internal/entity/lists.go b/backend/internal/entity/lists.go index 7b51e11..8507913 100644 --- a/backend/internal/entity/lists.go +++ b/backend/internal/entity/lists.go @@ -26,7 +26,7 @@ type ListResponse struct { func ListQueryBuilder( pageInfo *model.PageInfo, filters []model.Filter, - filterMap map[string]filterMapValue, + filterMap map[string]model.FilterMapValue, ) *gorm.DB { scopes := make([]func(*gorm.DB) *gorm.DB, 0) scopes = append(scopes, ScopeOffsetLimit(pageInfo)) diff --git a/backend/internal/entity/scopes.go b/backend/internal/entity/scopes.go index 9cca1c8..6bdcc07 100644 --- a/backend/internal/entity/scopes.go +++ b/backend/internal/entity/scopes.go @@ -36,7 +36,7 @@ func ScopeOrderBy(pageInfo *model.PageInfo, defaultSort model.Sort) func(db *gor } } -func ScopeFilters(filters []model.Filter, filterMap map[string]filterMapValue) func(db *gorm.DB) *gorm.DB { +func ScopeFilters(filters []model.Filter, filterMap map[string]model.FilterMapValue) func(db *gorm.DB) *gorm.DB { return func(db *gorm.DB) *gorm.DB { like := database.GetCaseInsensitiveLike() for _, f := range filters { diff --git a/backend/internal/model/filter.go b/backend/internal/model/filter.go index 8b88b70..fae252a 100644 --- a/backend/internal/model/filter.go +++ b/backend/internal/model/filter.go @@ -6,3 +6,9 @@ type Filter struct { Modifier string `json:"modifier"` Value []string `json:"value"` } + +// FilterMapValue ... +type FilterMapValue struct { + Type string + Field string +} diff --git a/backend/internal/tags/filters.go b/backend/internal/tags/filters.go new file mode 100644 index 0000000..bddb80b --- /dev/null +++ b/backend/internal/tags/filters.go @@ -0,0 +1,53 @@ +package tags + +import ( + "reflect" + "regexp" + "strings" + + "npm/internal/model" +) + +func GetFilterMap(m interface{}) map[string]model.FilterMapValue { + name := getName(m) + if val, exists := getCache(name); exists { + return val + } + + var filterMap = make(map[string]model.FilterMapValue) + + // TypeOf returns the reflection Type that represents the dynamic type of variable. + // If variable is a nil interface value, TypeOf returns nil. + t := reflect.TypeOf(m) + + // Iterate over all available fields and read the tag value + for i := 0; i < t.NumField(); i++ { + // Get the field, returns https://golang.org/pkg/reflect/#StructField + field := t.Field(i) + + // Get the field tag value + filterTag := field.Tag.Get("filter") + dbTag := field.Tag.Get("gorm") + if filterTag != "" && dbTag != "" && dbTag != "-" && filterTag != "-" { + // db can have many parts, we need to pull out the "column:value" part + dbField := field.Name + r := regexp.MustCompile(`(?:^|;)column:([^;|$]+)(?:$|;)`) + if matches := r.FindStringSubmatch(dbTag); len(matches) > 1 { + dbField = matches[1] + } + // Filter tag can be a 2 part thing: name,type + // ie: account_id,integer + // So we need to split and use the first part + parts := strings.Split(filterTag, ",") + if len(parts) > 1 { + filterMap[parts[0]] = model.FilterMapValue{ + Type: parts[1], + Field: dbField, + } + } + } + } + + setCache(name, filterMap) + return filterMap +} diff --git a/backend/internal/tags/reflect.go b/backend/internal/tags/reflect.go new file mode 100644 index 0000000..2f1f8b2 --- /dev/null +++ b/backend/internal/tags/reflect.go @@ -0,0 +1,33 @@ +package tags + +import ( + "fmt" + "reflect" + + "npm/internal/model" +) + +var tagCache map[string]map[string]model.FilterMapValue + +// getName returns the name of the type given +func getName(m interface{}) string { + fc := reflect.TypeOf(m) + return fmt.Sprint(fc) +} + +// getCache tries to find a cached item for this name +func getCache(name string) (map[string]model.FilterMapValue, bool) { + if val, ok := tagCache[name]; ok { + return val, true + } + return nil, false +} + +// setCache sets the name to this value +func setCache(name string, val map[string]model.FilterMapValue) { + // Hack to initialise empty map + if len(tagCache) == 0 { + tagCache = make(map[string]map[string]model.FilterMapValue, 0) + } + tagCache[name] = val +}