From 1d2f8d24f6d5e050cb8bf08c456b281f4a27df58 Mon Sep 17 00:00:00 2001 From: UUBulb <35923940+uubulb@users.noreply.github.com> Date: Fri, 28 Feb 2025 22:02:54 +0800 Subject: [PATCH] feat: update to go1.24 & support listening https (#1002) * feat: support listening https * refactor * modernize * support snake case in config * more precise control of config fields * update goreleaser config * remove kubeyaml * fix: expose agent_secret * chore --- .goreleaser.yml | 8 ++ cmd/dashboard/controller/controller.go | 8 +- cmd/dashboard/controller/fm.go | 4 +- cmd/dashboard/controller/jwt.go | 18 ++-- cmd/dashboard/controller/server.go | 11 +-- cmd/dashboard/controller/service.go | 6 +- cmd/dashboard/controller/setting.go | 31 ++++--- cmd/dashboard/controller/terminal.go | 4 +- cmd/dashboard/controller/ws.go | 5 +- cmd/dashboard/main.go | 46 ++++++++-- cmd/dashboard/rpc/rpc.go | 17 ++-- go.mod | 8 +- model/alertrule.go | 14 +-- model/common.go | 4 +- model/config.go | 116 +++++++++++++++---------- model/cron.go | 6 +- model/ddns.go | 6 +- model/notification.go | 3 +- model/server.go | 6 +- model/service.go | 14 +-- model/setting_api.go | 11 ++- model/waf.go | 2 +- pkg/utils/gjson.go | 8 +- pkg/utils/koanf.go | 71 +++++++++++++++ pkg/utils/utils.go | 63 ++++++++------ pkg/utils/utils_test.go | 2 +- service/singleton/alertsentinel.go | 2 +- service/singleton/user.go | 2 +- 28 files changed, 321 insertions(+), 175 deletions(-) create mode 100644 pkg/utils/koanf.go diff --git a/.goreleaser.yml b/.goreleaser.yml index c88cb92..55e342e 100644 --- a/.goreleaser.yml +++ b/.goreleaser.yml @@ -14,6 +14,8 @@ builds: flags: - -trimpath - -buildvcs=false + tags: + - go_json goos: - linux goarch: @@ -31,6 +33,8 @@ builds: flags: - -trimpath - -buildvcs=false + tags: + - go_json goos: - linux goarch: @@ -48,6 +52,8 @@ builds: flags: - -trimpath - -buildvcs=false + tags: + - go_json goos: - linux goarch: @@ -65,6 +71,8 @@ builds: flags: - -trimpath - -buildvcs=false + tags: + - go_json goos: - windows goarch: diff --git a/cmd/dashboard/controller/controller.go b/cmd/dashboard/controller/controller.go index ba268f9..f61474f 100644 --- a/cmd/dashboard/controller/controller.go +++ b/cmd/dashboard/controller/controller.go @@ -175,10 +175,10 @@ type pHandlerFunc[S ~[]E, E any] func(c *gin.Context) (*model.Value[S], error) // gorm errors here instead type gormError struct { msg string - a []interface{} + a []any } -func newGormError(format string, args ...interface{}) error { +func newGormError(format string, args ...any) error { return &gormError{ msg: format, a: args, @@ -191,10 +191,10 @@ func (ge *gormError) Error() string { type wsError struct { msg string - a []interface{} + a []any } -func newWsError(format string, args ...interface{}) error { +func newWsError(format string, args ...any) error { return &wsError{ msg: format, a: args, diff --git a/cmd/dashboard/controller/fm.go b/cmd/dashboard/controller/fm.go index 8124f47..8769911 100644 --- a/cmd/dashboard/controller/fm.go +++ b/cmd/dashboard/controller/fm.go @@ -5,11 +5,11 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/goccy/go-json" "github.com/gorilla/websocket" "github.com/hashicorp/go-uuid" "github.com/nezhahq/nezha/model" - "github.com/nezhahq/nezha/pkg/utils" "github.com/nezhahq/nezha/pkg/websocketx" "github.com/nezhahq/nezha/proto" "github.com/nezhahq/nezha/service/rpc" @@ -48,7 +48,7 @@ func createFM(c *gin.Context) (*model.CreateFMResponse, error) { rpc.NezhaHandlerSingleton.CreateStream(streamId) - fmData, _ := utils.Json.Marshal(&model.TaskFM{ + fmData, _ := json.Marshal(&model.TaskFM{ StreamID: streamId, }) if err := server.TaskStream.Send(&proto.Task{ diff --git a/cmd/dashboard/controller/jwt.go b/cmd/dashboard/controller/jwt.go index 337fec7..5323b95 100644 --- a/cmd/dashboard/controller/jwt.go +++ b/cmd/dashboard/controller/jwt.go @@ -1,12 +1,12 @@ package controller import ( - "encoding/json" "net/http" "time" jwt "github.com/appleboy/gin-jwt/v2" "github.com/gin-gonic/gin" + "github.com/goccy/go-json" "golang.org/x/crypto/bcrypt" "gorm.io/gorm" @@ -48,8 +48,8 @@ func initParams() *jwt.GinJWTMiddleware { } } -func payloadFunc() func(data interface{}) jwt.MapClaims { - return func(data interface{}) jwt.MapClaims { +func payloadFunc() func(data any) jwt.MapClaims { + return func(data any) jwt.MapClaims { if v, ok := data.(string); ok { return jwt.MapClaims{ model.CtxKeyAuthorizedUser: v, @@ -59,8 +59,8 @@ func payloadFunc() func(data interface{}) jwt.MapClaims { } } -func identityHandler() func(c *gin.Context) interface{} { - return func(c *gin.Context) interface{} { +func identityHandler() func(c *gin.Context) any { + return func(c *gin.Context) any { claims := jwt.ExtractClaims(c) userId := claims[model.CtxKeyAuthorizedUser].(string) var user model.User @@ -80,8 +80,8 @@ func identityHandler() func(c *gin.Context) interface{} { // @Produce json // @Success 200 {object} model.CommonResponse[model.LoginResponse] // @Router /login [post] -func authenticator() func(c *gin.Context) (interface{}, error) { - return func(c *gin.Context) (interface{}, error) { +func authenticator() func(c *gin.Context) (any, error) { + return func(c *gin.Context) (any, error) { var loginVals model.LoginRequest if err := c.ShouldBind(&loginVals); err != nil { return "", jwt.ErrMissingLoginValues @@ -113,8 +113,8 @@ func authenticator() func(c *gin.Context) (interface{}, error) { } } -func authorizator() func(data interface{}, c *gin.Context) bool { - return func(data interface{}, c *gin.Context) bool { +func authorizator() func(data any, c *gin.Context) bool { + return func(data any, c *gin.Context) bool { _, ok := data.(*model.User) return ok } diff --git a/cmd/dashboard/controller/server.go b/cmd/dashboard/controller/server.go index 5b039de..dc6839c 100644 --- a/cmd/dashboard/controller/server.go +++ b/cmd/dashboard/controller/server.go @@ -7,11 +7,11 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/goccy/go-json" "github.com/jinzhu/copier" "gorm.io/gorm" "github.com/nezhahq/nezha/model" - "github.com/nezhahq/nezha/pkg/utils" pb "github.com/nezhahq/nezha/proto" "github.com/nezhahq/nezha/service/singleton" ) @@ -81,13 +81,13 @@ func updateServer(c *gin.Context) (any, error) { s.DDNSProfiles = sf.DDNSProfiles s.OverrideDDNSDomains = sf.OverrideDDNSDomains - ddnsProfilesRaw, err := utils.Json.Marshal(s.DDNSProfiles) + ddnsProfilesRaw, err := json.Marshal(s.DDNSProfiles) if err != nil { return nil, err } s.DDNSProfilesRaw = string(ddnsProfilesRaw) - overrideDomainsRaw, err := utils.Json.Marshal(sf.OverrideDDNSDomains) + overrideDomainsRaw, err := json.Marshal(sf.OverrideDDNSDomains) if err != nil { return nil, err } @@ -281,10 +281,7 @@ func setServerConfig(c *gin.Context) (*model.ServerTaskResponse, error) { var respMu sync.Mutex for i := 0; i < len(servers); i += 10 { - end := i + 10 - if end > len(servers) { - end = len(servers) - } + end := min(i+10, len(servers)) group := servers[i:end] wg.Add(1) diff --git a/cmd/dashboard/controller/service.go b/cmd/dashboard/controller/service.go index 8cdf7e8..2d1bf6b 100644 --- a/cmd/dashboard/controller/service.go +++ b/cmd/dashboard/controller/service.go @@ -25,7 +25,7 @@ import ( // @Success 200 {object} model.CommonResponse[model.ServiceResponse] // @Router /service [get] func showService(c *gin.Context) (*model.ServiceResponse, error) { - res, err, _ := requestGroup.Do("list-service", func() (interface{}, error) { + res, err, _ := requestGroup.Do("list-service", func() (any, error) { singleton.AlertsLock.RLock() defer singleton.AlertsLock.RUnlock() stats := singleton.ServiceSentinelShared.CopyStats() @@ -41,8 +41,8 @@ func showService(c *gin.Context) (*model.ServiceResponse, error) { } return &model.ServiceResponse{ - Services: res.([]interface{})[0].(map[uint64]model.ServiceResponseItem), - CycleTransferStats: res.([]interface{})[1].(map[uint64]model.CycleTransferStats), + Services: res.([]any)[0].(map[uint64]model.ServiceResponseItem), + CycleTransferStats: res.([]any)[1].(map[uint64]model.CycleTransferStats), }, nil } diff --git a/cmd/dashboard/controller/setting.go b/cmd/dashboard/controller/setting.go index 333b321..2e15f70 100644 --- a/cmd/dashboard/controller/setting.go +++ b/cmd/dashboard/controller/setting.go @@ -17,9 +17,9 @@ import ( // @Security BearerAuth // @Tags common // @Produce json -// @Success 200 {object} model.CommonResponse[model.SettingResponse[model.Config]] +// @Success 200 {object} model.CommonResponse[model.SettingResponse] // @Router /setting [get] -func listConfig(c *gin.Context) (model.SettingResponse[any], error) { +func listConfig(c *gin.Context) (*model.SettingResponse, error) { u, authorized := c.Get(model.CtxKeyAuthorizedUser) var isAdmin bool if authorized { @@ -30,30 +30,29 @@ func listConfig(c *gin.Context) (model.SettingResponse[any], error) { config := *singleton.Conf config.Language = strings.Replace(config.Language, "_", "-", -1) - conf := model.SettingResponse[any]{ - Config: config, + conf := model.SettingResponse{ + Config: model.Setting{ + ConfigForGuests: config.ConfigForGuests, + ConfigDashboard: config.ConfigDashboard, + }, Version: singleton.Version, FrontendTemplates: singleton.FrontendTemplates, } if !authorized || !isAdmin { - configForGuests := model.ConfigForGuests{ - Language: config.Language, - SiteName: config.SiteName, - CustomCode: config.CustomCode, - CustomCodeDashboard: config.CustomCodeDashboard, - Oauth2Providers: config.Oauth2Providers, - } + configForGuests := config.ConfigForGuests if authorized { - configForGuests.TLS = singleton.Conf.TLS + configForGuests.AgentTLS = singleton.Conf.AgentTLS configForGuests.InstallHost = singleton.Conf.InstallHost } - conf = model.SettingResponse[any]{ - Config: configForGuests, + conf = model.SettingResponse{ + Config: model.Setting{ + ConfigForGuests: configForGuests, + }, } } - return conf, nil + return &conf, nil } // Edit config @@ -98,7 +97,7 @@ func updateConfig(c *gin.Context) (any, error) { singleton.Conf.CustomCode = sf.CustomCode singleton.Conf.CustomCodeDashboard = sf.CustomCodeDashboard singleton.Conf.RealIPHeader = sf.RealIPHeader - singleton.Conf.TLS = sf.TLS + singleton.Conf.AgentTLS = sf.AgentTLS singleton.Conf.UserTemplate = sf.UserTemplate if err := singleton.Conf.Save(); err != nil { diff --git a/cmd/dashboard/controller/terminal.go b/cmd/dashboard/controller/terminal.go index ec6a7d1..7899f1b 100644 --- a/cmd/dashboard/controller/terminal.go +++ b/cmd/dashboard/controller/terminal.go @@ -4,11 +4,11 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/goccy/go-json" "github.com/gorilla/websocket" "github.com/hashicorp/go-uuid" "github.com/nezhahq/nezha/model" - "github.com/nezhahq/nezha/pkg/utils" "github.com/nezhahq/nezha/pkg/websocketx" "github.com/nezhahq/nezha/proto" "github.com/nezhahq/nezha/service/rpc" @@ -46,7 +46,7 @@ func createTerminal(c *gin.Context) (*model.CreateTerminalResponse, error) { rpc.NezhaHandlerSingleton.CreateStream(streamId) - terminalData, _ := utils.Json.Marshal(&model.TerminalTask{ + terminalData, _ := json.Marshal(&model.TerminalTask{ StreamID: streamId, }) if err := server.TaskStream.Send(&proto.Task{ diff --git a/cmd/dashboard/controller/ws.go b/cmd/dashboard/controller/ws.go index d409ec5..0efb3c7 100644 --- a/cmd/dashboard/controller/ws.go +++ b/cmd/dashboard/controller/ws.go @@ -9,6 +9,7 @@ import ( "unicode/utf8" "github.com/gin-gonic/gin" + "github.com/goccy/go-json" "github.com/gorilla/websocket" "github.com/hashicorp/go-uuid" "golang.org/x/sync/singleflight" @@ -157,7 +158,7 @@ func serverStream(c *gin.Context) (any, error) { var requestGroup singleflight.Group func getServerStat(withPublicNote, authorized bool) ([]byte, error) { - v, err, _ := requestGroup.Do(fmt.Sprintf("serverStats::%t", authorized), func() (interface{}, error) { + v, err, _ := requestGroup.Do(fmt.Sprintf("serverStats::%t", authorized), func() (any, error) { var serverList []*model.Server if authorized { serverList = singleton.ServerShared.GetSortedList() @@ -183,7 +184,7 @@ func getServerStat(withPublicNote, authorized bool) ([]byte, error) { }) } - return utils.Json.Marshal(model.StreamServerData{ + return json.Marshal(model.StreamServerData{ Now: time.Now().Unix() * 1000, Online: singleton.GetOnlineUserCount(), Servers: servers, diff --git a/cmd/dashboard/main.go b/cmd/dashboard/main.go index c7c15b3..bafc384 100644 --- a/cmd/dashboard/main.go +++ b/cmd/dashboard/main.go @@ -2,7 +2,9 @@ package main import ( "context" + "crypto/tls" "embed" + "errors" "flag" "fmt" "log" @@ -16,8 +18,6 @@ import ( "github.com/gin-gonic/gin" "github.com/ory/graceful" "golang.org/x/crypto/bcrypt" - "golang.org/x/net/http2" - "golang.org/x/net/http2/h2c" "github.com/nezhahq/nezha/cmd/dashboard/controller" "github.com/nezhahq/nezha/cmd/dashboard/controller/waf" @@ -133,20 +133,54 @@ func main() { controller.InitUpgrader() muxHandler := newHTTPandGRPCMux(httpHandler, grpcHandler) - http2Server := &http2.Server{} - muxServer := &http.Server{Handler: h2c.NewHandler(muxHandler, http2Server), ReadHeaderTimeout: time.Second * 5} + muxServerHTTP := &http.Server{ + Handler: muxHandler, + ReadHeaderTimeout: time.Second * 5, + } + muxServerHTTP.Protocols = new(http.Protocols) + muxServerHTTP.Protocols.SetHTTP1(true) + muxServerHTTP.Protocols.SetUnencryptedHTTP2(true) + + var muxServerHTTPS *http.Server + if singleton.Conf.HTTPS.ListenPort != 0 { + muxServerHTTPS = &http.Server{ + Addr: fmt.Sprintf("%s:%d", singleton.Conf.ListenHost, singleton.Conf.HTTPS.ListenPort), + Handler: muxHandler, + ReadHeaderTimeout: time.Second * 5, + TLSConfig: &tls.Config{ + InsecureSkipVerify: singleton.Conf.HTTPS.InsecureTLS, + }, + } + } + + errChan := make(chan error, 2) if err := graceful.Graceful(func() error { log.Printf("NEZHA>> Dashboard::START ON %s:%d", singleton.Conf.ListenHost, singleton.Conf.ListenPort) - return muxServer.Serve(l) + if singleton.Conf.HTTPS.ListenPort != 0 { + go func() { + errChan <- muxServerHTTPS.ListenAndServeTLS(singleton.Conf.HTTPS.TLSCertPath, singleton.Conf.HTTPS.TLSKeyPath) + }() + log.Printf("NEZHA>> Dashboard::START ON %s:%d", singleton.Conf.ListenHost, singleton.Conf.HTTPS.ListenPort) + } + go func() { + errChan <- muxServerHTTP.Serve(l) + }() + return <-errChan }, func(c context.Context) error { log.Println("NEZHA>> Graceful::START") singleton.RecordTransferHourlyUsage() log.Println("NEZHA>> Graceful::END") - return muxServer.Shutdown(c) + err := muxServerHTTPS.Shutdown(c) + return errors.Join(muxServerHTTP.Shutdown(c), err) }); err != nil { log.Printf("NEZHA>> ERROR: %v", err) + if errors.Unwrap(err) != nil { + log.Printf("NEZHA>> ERROR HTTPS: %v", err) + } } + + close(errChan) } func newHTTPandGRPCMux(httpHandler http.Handler, grpcHandler http.Handler) http.Handler { diff --git a/cmd/dashboard/rpc/rpc.go b/cmd/dashboard/rpc/rpc.go index aa84398..215b959 100644 --- a/cmd/dashboard/rpc/rpc.go +++ b/cmd/dashboard/rpc/rpc.go @@ -8,6 +8,7 @@ import ( "net/netip" "time" + "github.com/goccy/go-json" "google.golang.org/grpc" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" @@ -27,7 +28,7 @@ func ServeRPC() *grpc.Server { return server } -func waf(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { +func waf(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { realip, _ := ctx.Value(model.CtxKeyRealIP{}).(string) if err := model.CheckIP(singleton.DB, realip); err != nil { return nil, err @@ -35,7 +36,7 @@ func waf(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handl return handler(ctx, req) } -func getRealIp(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { +func getRealIp(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { var ip, connectingIp string p, ok := peer.FromContext(ctx) if ok { @@ -133,20 +134,20 @@ func ServeNAT(w http.ResponseWriter, r *http.Request, natConfig *model.NAT) { streamId, err := uuid.GenerateUUID() if err != nil { w.WriteHeader(http.StatusServiceUnavailable) - w.Write([]byte(fmt.Sprintf("stream id error: %v", err))) + w.Write(fmt.Appendf(nil, "stream id error: %v", err)) return } rpcService.NezhaHandlerSingleton.CreateStream(streamId) defer rpcService.NezhaHandlerSingleton.CloseStream(streamId) - taskData, err := utils.Json.Marshal(model.TaskNAT{ + taskData, err := json.Marshal(model.TaskNAT{ StreamID: streamId, Host: natConfig.Host, }) if err != nil { w.WriteHeader(http.StatusServiceUnavailable) - w.Write([]byte(fmt.Sprintf("task data error: %v", err))) + w.Write(fmt.Appendf(nil, "task data error: %v", err)) return } @@ -155,20 +156,20 @@ func ServeNAT(w http.ResponseWriter, r *http.Request, natConfig *model.NAT) { Data: string(taskData), }); err != nil { w.WriteHeader(http.StatusServiceUnavailable) - w.Write([]byte(fmt.Sprintf("send task error: %v", err))) + w.Write(fmt.Appendf(nil, "send task error: %v", err)) return } wWrapped, err := utils.NewRequestWrapper(r, w) if err != nil { w.WriteHeader(http.StatusServiceUnavailable) - w.Write([]byte(fmt.Sprintf("request wrapper error: %v", err))) + w.Write(fmt.Appendf(nil, "request wrapper error: %v", err)) return } if err := rpcService.NezhaHandlerSingleton.UserConnected(streamId, wWrapped); err != nil { w.WriteHeader(http.StatusServiceUnavailable) - w.Write([]byte(fmt.Sprintf("user connected error: %v", err))) + w.Write(fmt.Appendf(nil, "user connected error: %v", err)) return } diff --git a/go.mod b/go.mod index d1a3663..00cace3 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/nezhahq/nezha -go 1.23.6 +go 1.24.0 require ( github.com/appleboy/gin-jwt/v2 v2.10.1 @@ -8,10 +8,11 @@ require ( github.com/dustinkirkland/golang-petname v0.0.0-20240428194347-eebcea082ee0 github.com/gin-contrib/pprof v1.5.2 github.com/gin-gonic/gin v1.10.0 + github.com/go-viper/mapstructure/v2 v2.2.1 + github.com/goccy/go-json v0.10.5 github.com/gorilla/websocket v1.5.3 github.com/hashicorp/go-uuid v1.0.3 github.com/jinzhu/copier v0.4.0 - github.com/json-iterator/go v1.1.12 github.com/knadh/koanf/parsers/yaml v0.1.0 github.com/knadh/koanf/providers/env v1.0.0 github.com/knadh/koanf/providers/file v1.1.2 @@ -56,12 +57,11 @@ require ( github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.25.0 // indirect - github.com/go-viper/mapstructure/v2 v2.2.1 // indirect - github.com/goccy/go-json v0.10.5 // indirect github.com/golang-jwt/jwt/v4 v4.5.1 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/josharian/intern v1.0.0 // indirect + github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.9 // indirect github.com/knadh/koanf/maps v0.1.1 // indirect github.com/leodido/go-urn v1.4.0 // indirect diff --git a/model/alertrule.go b/model/alertrule.go index 9b4fca0..459c04d 100644 --- a/model/alertrule.go +++ b/model/alertrule.go @@ -1,7 +1,7 @@ package model import ( - "github.com/nezhahq/nezha/pkg/utils" + "github.com/goccy/go-json" "gorm.io/gorm" ) @@ -25,17 +25,17 @@ type AlertRule struct { } func (r *AlertRule) BeforeSave(tx *gorm.DB) error { - if data, err := utils.Json.Marshal(r.Rules); err != nil { + if data, err := json.Marshal(r.Rules); err != nil { return err } else { r.RulesRaw = string(data) } - if data, err := utils.Json.Marshal(r.FailTriggerTasks); err != nil { + if data, err := json.Marshal(r.FailTriggerTasks); err != nil { return err } else { r.FailTriggerTasksRaw = string(data) } - if data, err := utils.Json.Marshal(r.RecoverTriggerTasks); err != nil { + if data, err := json.Marshal(r.RecoverTriggerTasks); err != nil { return err } else { r.RecoverTriggerTasksRaw = string(data) @@ -45,13 +45,13 @@ func (r *AlertRule) BeforeSave(tx *gorm.DB) error { func (r *AlertRule) AfterFind(tx *gorm.DB) error { var err error - if err = utils.Json.Unmarshal([]byte(r.RulesRaw), &r.Rules); err != nil { + if err = json.Unmarshal([]byte(r.RulesRaw), &r.Rules); err != nil { return err } - if err = utils.Json.Unmarshal([]byte(r.FailTriggerTasksRaw), &r.FailTriggerTasks); err != nil { + if err = json.Unmarshal([]byte(r.FailTriggerTasksRaw), &r.FailTriggerTasks); err != nil { return err } - if err = utils.Json.Unmarshal([]byte(r.RecoverTriggerTasksRaw), &r.RecoverTriggerTasks); err != nil { + if err = json.Unmarshal([]byte(r.RecoverTriggerTasksRaw), &r.RecoverTriggerTasks); err != nil { return err } return nil diff --git a/model/common.go b/model/common.go index 0cb6bf0..182b21b 100644 --- a/model/common.go +++ b/model/common.go @@ -77,7 +77,7 @@ func SearchByIDCtx[S ~[]E, E CommonInterface](c *gin.Context, x S) S { return any(l).(S) default: var s S - for _, idStr := range strings.Split(c.Query("id"), ",") { + for idStr := range strings.SplitSeq(c.Query("id"), ",") { id, err := strconv.ParseUint(idStr, 10, 64) if err != nil { continue @@ -93,7 +93,7 @@ func searchByIDCtxServer(c *gin.Context, x []*Server) []*Server { list1, list2 := SplitList(x) var clist1, clist2 []*Server - for _, idStr := range strings.Split(c.Query("id"), ",") { + for idStr := range strings.SplitSeq(c.Query("id"), ",") { id, err := strconv.ParseUint(idStr, 10, 64) if err != nil { continue diff --git a/model/config.go b/model/config.go index f7ec925..6c5486c 100644 --- a/model/config.go +++ b/model/config.go @@ -1,18 +1,16 @@ package model import ( - "maps" "os" "path/filepath" - "slices" "strconv" "strings" + "github.com/go-viper/mapstructure/v2" kyaml "github.com/knadh/koanf/parsers/yaml" "github.com/knadh/koanf/providers/env" "github.com/knadh/koanf/providers/file" "github.com/knadh/koanf/v2" - "gopkg.in/yaml.v3" "github.com/nezhahq/nezha/pkg/utils" ) @@ -24,52 +22,56 @@ const ( ) type ConfigForGuests struct { - Language string `json:"language"` - SiteName string `json:"site_name"` - CustomCode string `json:"custom_code,omitempty"` - CustomCodeDashboard string `json:"custom_code_dashboard,omitempty"` - Oauth2Providers []string `json:"oauth2_providers,omitempty"` + Language string `koanf:"language" json:"language"` // 系统语言,默认 zh_CN + SiteName string `koanf:"site_name" json:"site_name"` + CustomCode string `koanf:"custom_code" json:"custom_code,omitempty"` + CustomCodeDashboard string `koanf:"custom_code_dashboard" json:"custom_code_dashboard,omitempty"` + Oauth2Providers []string `koanf:"-" json:"oauth2_providers,omitempty"` // oauth2 供应商列表,无需配置,自动生成 - InstallHost string `json:"install_host,omitempty"` - TLS bool `json:"tls,omitempty"` + InstallHost string `koanf:"install_host" json:"install_host,omitempty"` + AgentTLS bool `koanf:"tls" json:"tls,omitempty"` // 用于前端判断生成的安装命令是否启用 TLS +} + +type ConfigDashboard struct { + Debug bool `koanf:"debug" json:"debug,omitempty"` // debug模式开关 + RealIPHeader string `koanf:"real_ip_header" json:"real_ip_header,omitempty"` // 真实IP + UserTemplate string `koanf:"user_template" json:"user_template,omitempty"` + AdminTemplate string `koanf:"admin_template" json:"admin_template,omitempty"` + Location string `koanf:"location" json:"location,omitempty"` // 时区,默认为 Asia/Shanghai + ForceAuth bool `koanf:"force_auth" json:"force_auth,omitempty"` // 强制要求认证 + AgentSecretKey string `koanf:"agent_secret_key" json:"agent_secret_key,omitempty"` + + EnablePlainIPInNotification bool `koanf:"enable_plain_ip_in_notification" json:"enable_plain_ip_in_notification,omitempty"` // 通知信息IP不打码 + + // IP变更提醒 + EnableIPChangeNotification bool `koanf:"enable_ip_change_notification" json:"enable_ip_change_notification,omitempty"` + IPChangeNotificationGroupID uint64 `koanf:"ip_change_notification_group_id" json:"ip_change_notification_group_id"` + Cover uint8 `koanf:"cover" json:"cover"` // 覆盖范围(0:提醒未被 IgnoredIPNotification 包含的所有服务器; 1:仅提醒被 IgnoredIPNotification 包含的服务器;) + IgnoredIPNotification string `koanf:"ignored_ip_notification" json:"ignored_ip_notification,omitempty"` // 特定服务器IP(多个服务器用逗号分隔) + + IgnoredIPNotificationServerIDs map[uint64]bool `koanf:"ignored_ip_notification_server_ids" json:"ignored_ip_notification_server_ids,omitempty"` // [ServerID] -> bool(值为true代表当前ServerID在特定服务器列表内) + AvgPingCount int `koanf:"avg_ping_count" json:"avg_ping_count,omitempty"` + DNSServers string `koanf:"dns_servers" json:"dns_servers,omitempty"` } type Config struct { - Debug bool `mapstructure:"debug" json:"debug,omitempty"` // debug模式开关 - RealIPHeader string `mapstructure:"real_ip_header" json:"real_ip_header,omitempty"` // 真实IP + ConfigForGuests + ConfigDashboard - Language string `mapstructure:"language" json:"language"` // 系统语言,默认 zh_CN - SiteName string `mapstructure:"site_name" json:"site_name"` - UserTemplate string `mapstructure:"user_template" json:"user_template,omitempty"` - AdminTemplate string `mapstructure:"admin_template" json:"admin_template,omitempty"` - JWTSecretKey string `mapstructure:"jwt_secret_key" json:"jwt_secret_key,omitempty"` - AgentSecretKey string `mapstructure:"agent_secret_key" json:"agent_secret_key,omitempty"` - ListenPort uint `mapstructure:"listen_port" json:"listen_port,omitempty"` - ListenHost string `mapstructure:"listen_host" json:"listen_host,omitempty"` - InstallHost string `mapstructure:"install_host" json:"install_host,omitempty"` - TLS bool `mapstructure:"tls" json:"tls,omitempty"` - Location string `mapstructure:"location" json:"location,omitempty"` // 时区,默认为 Asia/Shanghai - ForceAuth bool `mapstructure:"force_auth" json:"force_auth,omitempty"` // 强制要求认证 - - EnablePlainIPInNotification bool `mapstructure:"enable_plain_ip_in_notification" json:"enable_plain_ip_in_notification,omitempty"` // 通知信息IP不打码 - - // IP变更提醒 - EnableIPChangeNotification bool `mapstructure:"enable_ip_change_notification" json:"enable_ip_change_notification,omitempty"` - IPChangeNotificationGroupID uint64 `mapstructure:"ip_change_notification_group_id" json:"ip_change_notification_group_id"` - Cover uint8 `mapstructure:"cover" json:"cover"` // 覆盖范围(0:提醒未被 IgnoredIPNotification 包含的所有服务器; 1:仅提醒被 IgnoredIPNotification 包含的服务器;) - IgnoredIPNotification string `mapstructure:"ignored_ip_notification" json:"ignored_ip_notification,omitempty"` // 特定服务器IP(多个服务器用逗号分隔) - - IgnoredIPNotificationServerIDs map[uint64]bool `mapstructure:"ignored_ip_notification_server_ids" json:"ignored_ip_notification_server_ids,omitempty"` // [ServerID] -> bool(值为true代表当前ServerID在特定服务器列表内) - AvgPingCount int `mapstructure:"avg_ping_count" json:"avg_ping_count,omitempty"` - DNSServers string `mapstructure:"dns_servers" json:"dns_servers,omitempty"` - - CustomCode string `mapstructure:"custom_code" json:"custom_code,omitempty"` - CustomCodeDashboard string `mapstructure:"custom_code_dashboard" json:"custom_code_dashboard,omitempty"` + JWTSecretKey string `koanf:"jwt_secret_key" json:"jwt_secret_key,omitempty"` + ListenPort uint16 `koanf:"listen_port" json:"listen_port,omitempty"` + ListenHost string `koanf:"listen_host" json:"listen_host,omitempty"` // oauth2 配置 - Oauth2 map[string]*Oauth2Config `mapstructure:"oauth2" json:"oauth2,omitempty"` - // oauth2 供应商列表,无需配置,自动生成 - Oauth2Providers []string `yaml:"-" json:"oauth2_providers,omitempty"` + Oauth2 map[string]*Oauth2Config `koanf:"oauth2" json:"oauth2,omitempty"` + + // HTTPS 配置 + HTTPS struct { + ListenPort uint16 `koanf:"listen_port" json:"listen_port,omitempty"` + TLSCertPath string `koanf:"tls_cert_path" json:"tls_cert_path,omitempty"` + TLSKeyPath string `koanf:"tls_key_path" json:"tls_key_path,omitempty"` + InsecureTLS bool `koanf:"insecure_tls" json:"insecure_tls,omitempty"` + } `koanf:"https" json:"https"` k *koanf.Koanf `json:"-"` filePath string `json:"-"` @@ -94,10 +96,11 @@ func (c *Config) Read(path string, frontendTemplates []FrontendTemplate) error { } } - err = c.k.Unmarshal("", c) + err = c.k.UnmarshalWithConf("", c, koanfConf(c)) if err != nil { return err } + if c.ListenPort == 0 { c.ListenPort = 8008 } @@ -151,7 +154,7 @@ func (c *Config) Read(path string, frontendTemplates []FrontendTemplate) error { } } - c.Oauth2Providers = slices.Collect(maps.Keys(c.Oauth2)) + c.Oauth2Providers = utils.MapKeysToSlice(c.Oauth2) c.updateIgnoredIPNotificationID() return nil @@ -160,9 +163,8 @@ func (c *Config) Read(path string, frontendTemplates []FrontendTemplate) error { // updateIgnoredIPNotificationID 更新用于判断服务器ID是否属于特定服务器的map func (c *Config) updateIgnoredIPNotificationID() { c.IgnoredIPNotificationServerIDs = make(map[uint64]bool) - splitedIDs := strings.Split(c.IgnoredIPNotification, ",") - for i := 0; i < len(splitedIDs); i++ { - id, _ := strconv.ParseUint(splitedIDs[i], 10, 64) + for splitedID := range strings.SplitSeq(c.IgnoredIPNotification, ",") { + id, _ := strconv.ParseUint(splitedID, 10, 64) if id > 0 { c.IgnoredIPNotificationServerIDs[id] = true } @@ -172,7 +174,7 @@ func (c *Config) updateIgnoredIPNotificationID() { // Save 保存配置文件 func (c *Config) Save() error { c.updateIgnoredIPNotificationID() - data, err := yaml.Marshal(c) + data, err := c.k.Marshal(kyaml.Parser()) if err != nil { return err } @@ -184,3 +186,21 @@ func (c *Config) Save() error { return os.WriteFile(c.filePath, data, 0600) } + +func koanfConf(c any) koanf.UnmarshalConf { + return koanf.UnmarshalConf{ + DecoderConfig: &mapstructure.DecoderConfig{ + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + utils.TextUnmarshalerHookFunc()), + Metadata: nil, + Result: c, + WeaklyTypedInput: true, + MatchName: func(mapKey, fieldName string) bool { + return strings.EqualFold(mapKey, fieldName) || + strings.EqualFold(mapKey, strings.ReplaceAll(fieldName, "_", "")) + }, + Squash: true, + }, + } +} diff --git a/model/cron.go b/model/cron.go index 76a9c30..418c421 100644 --- a/model/cron.go +++ b/model/cron.go @@ -3,7 +3,7 @@ package model import ( "time" - "github.com/nezhahq/nezha/pkg/utils" + "github.com/goccy/go-json" "github.com/robfig/cron/v3" "gorm.io/gorm" ) @@ -34,7 +34,7 @@ type Cron struct { } func (c *Cron) BeforeSave(tx *gorm.DB) error { - if data, err := utils.Json.Marshal(c.Servers); err != nil { + if data, err := json.Marshal(c.Servers); err != nil { return err } else { c.ServersRaw = string(data) @@ -43,5 +43,5 @@ func (c *Cron) BeforeSave(tx *gorm.DB) error { } func (c *Cron) AfterFind(tx *gorm.DB) error { - return utils.Json.Unmarshal([]byte(c.ServersRaw), &c.Servers) + return json.Unmarshal([]byte(c.ServersRaw), &c.Servers) } diff --git a/model/ddns.go b/model/ddns.go index fe942b9..601476e 100644 --- a/model/ddns.go +++ b/model/ddns.go @@ -1,7 +1,7 @@ package model import ( - "github.com/nezhahq/nezha/pkg/utils" + "github.com/goccy/go-json" "gorm.io/gorm" ) @@ -39,7 +39,7 @@ func (d DDNSProfile) TableName() string { } func (d *DDNSProfile) BeforeSave(tx *gorm.DB) error { - if data, err := utils.Json.Marshal(d.Domains); err != nil { + if data, err := json.Marshal(d.Domains); err != nil { return err } else { d.DomainsRaw = string(data) @@ -48,5 +48,5 @@ func (d *DDNSProfile) BeforeSave(tx *gorm.DB) error { } func (d *DDNSProfile) AfterFind(tx *gorm.DB) error { - return utils.Json.Unmarshal([]byte(d.DomainsRaw), &d.Domains) + return json.Unmarshal([]byte(d.DomainsRaw), &d.Domains) } diff --git a/model/notification.go b/model/notification.go index dcde788..e6f570d 100644 --- a/model/notification.go +++ b/model/notification.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "github.com/goccy/go-json" "github.com/nezhahq/nezha/pkg/utils" ) @@ -66,7 +67,7 @@ func (ns *NotificationServerBundle) reqBody(message string) (string, error) { switch n.RequestType { case NotificationRequestTypeJSON: return ns.replaceParamsInString(n.RequestBody, message, func(msg string) string { - msgBytes, _ := utils.Json.Marshal(msg) + msgBytes, _ := json.Marshal(msg) return string(msgBytes)[1 : len(msgBytes)-1] }), nil case NotificationRequestTypeForm: diff --git a/model/server.go b/model/server.go index ffd54e9..e87276c 100644 --- a/model/server.go +++ b/model/server.go @@ -5,9 +5,9 @@ import ( "slices" "time" + "github.com/goccy/go-json" "gorm.io/gorm" - "github.com/nezhahq/nezha/pkg/utils" pb "github.com/nezhahq/nezha/proto" ) @@ -59,13 +59,13 @@ func (s *Server) CopyFromRunningServer(old *Server) { func (s *Server) AfterFind(tx *gorm.DB) error { if s.DDNSProfilesRaw != "" { - if err := utils.Json.Unmarshal([]byte(s.DDNSProfilesRaw), &s.DDNSProfiles); err != nil { + if err := json.Unmarshal([]byte(s.DDNSProfilesRaw), &s.DDNSProfiles); err != nil { log.Println("NEZHA>> Server.AfterFind:", err) return nil } } if s.OverrideDDNSDomainsRaw != "" { - if err := utils.Json.Unmarshal([]byte(s.OverrideDDNSDomainsRaw), &s.OverrideDDNSDomains); err != nil { + if err := json.Unmarshal([]byte(s.OverrideDDNSDomainsRaw), &s.OverrideDDNSDomains); err != nil { log.Println("NEZHA>> Server.AfterFind:", err) return nil } diff --git a/model/service.go b/model/service.go index 76e423c..0be9f38 100644 --- a/model/service.go +++ b/model/service.go @@ -4,10 +4,10 @@ import ( "fmt" "log" + "github.com/goccy/go-json" "github.com/robfig/cron/v3" "gorm.io/gorm" - "github.com/nezhahq/nezha/pkg/utils" pb "github.com/nezhahq/nezha/proto" ) @@ -91,17 +91,17 @@ func (m *Service) CronSpec() string { } func (m *Service) BeforeSave(tx *gorm.DB) error { - if data, err := utils.Json.Marshal(m.SkipServers); err != nil { + if data, err := json.Marshal(m.SkipServers); err != nil { return err } else { m.SkipServersRaw = string(data) } - if data, err := utils.Json.Marshal(m.FailTriggerTasks); err != nil { + if data, err := json.Marshal(m.FailTriggerTasks); err != nil { return err } else { m.FailTriggerTasksRaw = string(data) } - if data, err := utils.Json.Marshal(m.RecoverTriggerTasks); err != nil { + if data, err := json.Marshal(m.RecoverTriggerTasks); err != nil { return err } else { m.RecoverTriggerTasksRaw = string(data) @@ -111,16 +111,16 @@ func (m *Service) BeforeSave(tx *gorm.DB) error { func (m *Service) AfterFind(tx *gorm.DB) error { m.SkipServers = make(map[uint64]bool) - if err := utils.Json.Unmarshal([]byte(m.SkipServersRaw), &m.SkipServers); err != nil { + if err := json.Unmarshal([]byte(m.SkipServersRaw), &m.SkipServers); err != nil { log.Println("NEZHA>> Service.AfterFind:", err) return nil } // 加载触发任务列表 - if err := utils.Json.Unmarshal([]byte(m.FailTriggerTasksRaw), &m.FailTriggerTasks); err != nil { + if err := json.Unmarshal([]byte(m.FailTriggerTasksRaw), &m.FailTriggerTasks); err != nil { return err } - if err := utils.Json.Unmarshal([]byte(m.RecoverTriggerTasksRaw), &m.RecoverTriggerTasks); err != nil { + if err := json.Unmarshal([]byte(m.RecoverTriggerTasksRaw), &m.RecoverTriggerTasks); err != nil { return err } diff --git a/model/setting_api.go b/model/setting_api.go index bb445d7..1f2fa29 100644 --- a/model/setting_api.go +++ b/model/setting_api.go @@ -13,11 +13,16 @@ type SettingForm struct { RealIPHeader string `json:"real_ip_header,omitempty" validate:"optional"` // 真实IP UserTemplate string `json:"user_template,omitempty" validate:"optional"` - TLS bool `json:"tls,omitempty" validate:"optional"` + AgentTLS bool `json:"tls,omitempty" validate:"optional"` EnableIPChangeNotification bool `json:"enable_ip_change_notification,omitempty" validate:"optional"` EnablePlainIPInNotification bool `json:"enable_plain_ip_in_notification,omitempty" validate:"optional"` } +type Setting struct { + ConfigForGuests + ConfigDashboard +} + type FrontendTemplate struct { Path string `json:"path,omitempty"` Name string `json:"name,omitempty"` @@ -28,8 +33,8 @@ type FrontendTemplate struct { IsOfficial bool `json:"is_official,omitempty"` } -type SettingResponse[T any] struct { - Config T `json:"config,omitempty"` +type SettingResponse struct { + Config Setting `json:"config"` Version string `json:"version,omitempty"` FrontendTemplates []FrontendTemplate `json:"frontend_templates,omitempty"` diff --git a/model/waf.go b/model/waf.go index ae3ecbf..60959cb 100644 --- a/model/waf.go +++ b/model/waf.go @@ -117,7 +117,7 @@ func BlockIP(db *gorm.DB, ip string, reason uint8, uid int64) error { } now := uint64(time.Now().Unix()) - var count interface{} + var count any if reason == WAFBlockReasonTypeManual { count = 99999 } else { diff --git a/pkg/utils/gjson.go b/pkg/utils/gjson.go index 77594a7..15df2eb 100644 --- a/pkg/utils/gjson.go +++ b/pkg/utils/gjson.go @@ -33,9 +33,7 @@ func GjsonIter(json string) (iter.Seq2[string, string], error) { return nil, ErrGjsonWrongType } - return func(yield func(string, string) bool) { - result.ForEach(func(k, v gjson.Result) bool { - return yield(k.String(), v.String()) - }) - }, nil + return ConvertSeq2(result.ForEach, func(k, v gjson.Result) (string, string) { + return k.String(), v.String() + }), nil } diff --git a/pkg/utils/koanf.go b/pkg/utils/koanf.go new file mode 100644 index 0000000..dcc66b9 --- /dev/null +++ b/pkg/utils/koanf.go @@ -0,0 +1,71 @@ +package utils + +import ( + "encoding" + "reflect" + + "github.com/go-viper/mapstructure/v2" +) + +// TextUnmarshalerHookFunc is a fixed version of mapstructure.TextUnmarshallerHookFunc. +// This hook allows to additionally unmarshal text into custom string types that implement the encoding.Text(Un)Marshaler interface(s). +func TextUnmarshalerHookFunc() mapstructure.DecodeHookFuncType { + return func( + f reflect.Type, + t reflect.Type, + data any, + ) (any, error) { + if f.Kind() != reflect.String { + return data, nil + } + result := reflect.New(t).Interface() + unmarshaller, ok := result.(encoding.TextUnmarshaler) + if !ok { + return data, nil + } + + // default text representation is the actual value of the `from` string + var ( + dataVal = reflect.ValueOf(data) + text = []byte(dataVal.String()) + ) + if f.Kind() == t.Kind() { + // source and target are of underlying type string + var ( + err error + ptrVal = reflect.New(dataVal.Type()) + ) + if !ptrVal.Elem().CanSet() { + // cannot set, skip, this should not happen + if err := unmarshaller.UnmarshalText(text); err != nil { + return nil, err + } + return result, nil + } + ptrVal.Elem().Set(dataVal) + + // We need to assert that both, the value type and the pointer type + // do (not) implement the TextMarshaller interface before proceeding and simply + // using the string value of the string type. + // it might be the case that the internal string representation differs from + // the (un)marshalled string. + + for _, v := range []reflect.Value{dataVal, ptrVal} { + if marshaller, ok := v.Interface().(encoding.TextMarshaler); ok { + text, err = marshaller.MarshalText() + if err != nil { + return nil, err + } + break + } + } + } + + // text is either the source string's value or the source string type's marshaled value + // which may differ from its internal string value. + if err := unmarshaller.UnmarshalText(text); err != nil { + return nil, err + } + return result, nil + } +} diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index ad3bfc9..8aff3bc 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -1,6 +1,7 @@ package utils import ( + "cmp" "crypto/rand" "errors" "iter" @@ -13,24 +14,19 @@ import ( "strings" "golang.org/x/exp/constraints" - - jsoniter "github.com/json-iterator/go" ) var ( - Json = jsoniter.ConfigCompatibleWithStandardLibrary - DNSServers = []string{"8.8.8.8:53", "8.8.4.4:53", "1.1.1.1:53", "1.0.0.1:53"} -) -var ipv4Re = regexp.MustCompile(`(\d*\.).*(\.\d*)`) + ipv4Re = regexp.MustCompile(`(\d*\.).*(\.\d*)`) + ipv6Re = regexp.MustCompile(`(\w*:\w*:).*(:\w*:\w*)`) +) func ipv4Desensitize(ipv4Addr string) string { return ipv4Re.ReplaceAllString(ipv4Addr, "$1****$2") } -var ipv6Re = regexp.MustCompile(`(\w*:\w*:).*(:\w*:\w*)`) - func ipv6Desensitize(ipv6Addr string) string { return ipv6Re.ReplaceAllString(ipv6Addr, "$1****$2") } @@ -51,9 +47,11 @@ func IPStringToBinary(ip string) ([]byte, error) { } func BinaryToIPString(b []byte) string { - var addr16 [16]byte - copy(addr16[:], b) - addr := netip.AddrFrom16(addr16) + if len(b) < 16 { + return "::" + } + + addr := netip.AddrFrom16([16]byte(b)) return addr.Unmap().String() } @@ -74,7 +72,7 @@ func GenerateRandomString(n int) (string, error) { const letters = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" lettersLength := big.NewInt(int64(len(letters))) ret := make([]byte, n) - for i := 0; i < n; i++ { + for i := range n { num, err := rand.Int(rand.Reader, lettersLength) if err != nil { return "", err @@ -117,22 +115,35 @@ func MapValuesToSlice[Map ~map[K]V, K comparable, V any](m Map) []V { return slices.AppendSeq(s, maps.Values(m)) } -func Unique[T comparable](s []T) []T { - m := make(map[T]struct{}) - ret := make([]T, 0, len(s)) - for _, v := range s { - if _, ok := m[v]; !ok { - m[v] = struct{}{} - ret = append(ret, v) - } - } - return ret +func MapKeysToSlice[Map ~map[K]V, K comparable, V any](m Map) []K { + s := make([]K, 0, len(m)) + return slices.AppendSeq(s, maps.Keys(m)) } -func ConvertSeq[T, U any](seq iter.Seq[T], f func(e T) U) iter.Seq[U] { - return func(yield func(U) bool) { - for e := range seq { - if !yield(f(e)) { +func Unique[S ~[]E, E cmp.Ordered](list S) S { + if list == nil { + return nil + } + out := make([]E, len(list)) + copy(out, list) + slices.Sort(out) + return slices.Compact(out) +} + +func ConvertSeq[In, Out any](seq iter.Seq[In], f func(In) Out) iter.Seq[Out] { + return func(yield func(Out) bool) { + for in := range seq { + if !yield(f(in)) { + return + } + } + } +} + +func ConvertSeq2[KIn, VIn, KOut, VOut any](seq iter.Seq2[KIn, VIn], f func(KIn, VIn) (KOut, VOut)) iter.Seq2[KOut, VOut] { + return func(yield func(KOut, VOut) bool) { + for k, v := range seq { + if !yield(f(k, v)) { return } } diff --git a/pkg/utils/utils_test.go b/pkg/utils/utils_test.go index d9ee1ba..0ad5e97 100644 --- a/pkg/utils/utils_test.go +++ b/pkg/utils/utils_test.go @@ -43,7 +43,7 @@ func TestNotification(t *testing.T) { func TestGenerGenerateRandomString(t *testing.T) { generatedString := make(map[string]bool) - for i := 0; i < 100; i++ { + for range 100 { str, err := GenerateRandomString(32) if err != nil { t.Fatalf("Error: %s", err) diff --git a/service/singleton/alertsentinel.go b/service/singleton/alertsentinel.go index 276eb1f..ef8b41f 100644 --- a/service/singleton/alertsentinel.go +++ b/service/singleton/alertsentinel.go @@ -96,7 +96,7 @@ func OnRefreshOrAddAlert(alert *model.AlertRule) { delete(alertsStore, alert.ID) delete(alertsPrevState, alert.ID) var isEdit bool - for i := 0; i < len(Alerts); i++ { + for i := range Alerts { if Alerts[i].ID == alert.ID { Alerts[i] = alert isEdit = true diff --git a/service/singleton/user.go b/service/singleton/user.go index acf0db1..79c90eb 100644 --- a/service/singleton/user.go +++ b/service/singleton/user.go @@ -52,7 +52,7 @@ func OnUserUpdate(u *model.User) { AgentSecretToUserId[u.AgentSecret] = u.ID } -func OnUserDelete(id []uint64, errorFunc func(string, ...interface{}) error) error { +func OnUserDelete(id []uint64, errorFunc func(string, ...any) error) error { UserLock.Lock() defer UserLock.Unlock()