feat: update to go1.24 & support listening https (#1002)

* feat: support listening https

* refactor

* modernize

* support snake case in config

* more precise control of config fields

* update goreleaser config

* remove kubeyaml

* fix: expose agent_secret

* chore
This commit is contained in:
UUBulb 2025-02-28 22:02:54 +08:00 committed by GitHub
parent e770398a11
commit 1d2f8d24f6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
28 changed files with 321 additions and 175 deletions

View File

@ -14,6 +14,8 @@ builds:
flags:
- -trimpath
- -buildvcs=false
tags:
- go_json
goos:
- linux
goarch:
@ -31,6 +33,8 @@ builds:
flags:
- -trimpath
- -buildvcs=false
tags:
- go_json
goos:
- linux
goarch:
@ -48,6 +52,8 @@ builds:
flags:
- -trimpath
- -buildvcs=false
tags:
- go_json
goos:
- linux
goarch:
@ -65,6 +71,8 @@ builds:
flags:
- -trimpath
- -buildvcs=false
tags:
- go_json
goos:
- windows
goarch:

View File

@ -175,10 +175,10 @@ type pHandlerFunc[S ~[]E, E any] func(c *gin.Context) (*model.Value[S], error)
// gorm errors here instead
type gormError struct {
msg string
a []interface{}
a []any
}
func newGormError(format string, args ...interface{}) error {
func newGormError(format string, args ...any) error {
return &gormError{
msg: format,
a: args,
@ -191,10 +191,10 @@ func (ge *gormError) Error() string {
type wsError struct {
msg string
a []interface{}
a []any
}
func newWsError(format string, args ...interface{}) error {
func newWsError(format string, args ...any) error {
return &wsError{
msg: format,
a: args,

View File

@ -5,11 +5,11 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/goccy/go-json"
"github.com/gorilla/websocket"
"github.com/hashicorp/go-uuid"
"github.com/nezhahq/nezha/model"
"github.com/nezhahq/nezha/pkg/utils"
"github.com/nezhahq/nezha/pkg/websocketx"
"github.com/nezhahq/nezha/proto"
"github.com/nezhahq/nezha/service/rpc"
@ -48,7 +48,7 @@ func createFM(c *gin.Context) (*model.CreateFMResponse, error) {
rpc.NezhaHandlerSingleton.CreateStream(streamId)
fmData, _ := utils.Json.Marshal(&model.TaskFM{
fmData, _ := json.Marshal(&model.TaskFM{
StreamID: streamId,
})
if err := server.TaskStream.Send(&proto.Task{

View File

@ -1,12 +1,12 @@
package controller
import (
"encoding/json"
"net/http"
"time"
jwt "github.com/appleboy/gin-jwt/v2"
"github.com/gin-gonic/gin"
"github.com/goccy/go-json"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
@ -48,8 +48,8 @@ func initParams() *jwt.GinJWTMiddleware {
}
}
func payloadFunc() func(data interface{}) jwt.MapClaims {
return func(data interface{}) jwt.MapClaims {
func payloadFunc() func(data any) jwt.MapClaims {
return func(data any) jwt.MapClaims {
if v, ok := data.(string); ok {
return jwt.MapClaims{
model.CtxKeyAuthorizedUser: v,
@ -59,8 +59,8 @@ func payloadFunc() func(data interface{}) jwt.MapClaims {
}
}
func identityHandler() func(c *gin.Context) interface{} {
return func(c *gin.Context) interface{} {
func identityHandler() func(c *gin.Context) any {
return func(c *gin.Context) any {
claims := jwt.ExtractClaims(c)
userId := claims[model.CtxKeyAuthorizedUser].(string)
var user model.User
@ -80,8 +80,8 @@ func identityHandler() func(c *gin.Context) interface{} {
// @Produce json
// @Success 200 {object} model.CommonResponse[model.LoginResponse]
// @Router /login [post]
func authenticator() func(c *gin.Context) (interface{}, error) {
return func(c *gin.Context) (interface{}, error) {
func authenticator() func(c *gin.Context) (any, error) {
return func(c *gin.Context) (any, error) {
var loginVals model.LoginRequest
if err := c.ShouldBind(&loginVals); err != nil {
return "", jwt.ErrMissingLoginValues
@ -113,8 +113,8 @@ func authenticator() func(c *gin.Context) (interface{}, error) {
}
}
func authorizator() func(data interface{}, c *gin.Context) bool {
return func(data interface{}, c *gin.Context) bool {
func authorizator() func(data any, c *gin.Context) bool {
return func(data any, c *gin.Context) bool {
_, ok := data.(*model.User)
return ok
}

View File

@ -7,11 +7,11 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/goccy/go-json"
"github.com/jinzhu/copier"
"gorm.io/gorm"
"github.com/nezhahq/nezha/model"
"github.com/nezhahq/nezha/pkg/utils"
pb "github.com/nezhahq/nezha/proto"
"github.com/nezhahq/nezha/service/singleton"
)
@ -81,13 +81,13 @@ func updateServer(c *gin.Context) (any, error) {
s.DDNSProfiles = sf.DDNSProfiles
s.OverrideDDNSDomains = sf.OverrideDDNSDomains
ddnsProfilesRaw, err := utils.Json.Marshal(s.DDNSProfiles)
ddnsProfilesRaw, err := json.Marshal(s.DDNSProfiles)
if err != nil {
return nil, err
}
s.DDNSProfilesRaw = string(ddnsProfilesRaw)
overrideDomainsRaw, err := utils.Json.Marshal(sf.OverrideDDNSDomains)
overrideDomainsRaw, err := json.Marshal(sf.OverrideDDNSDomains)
if err != nil {
return nil, err
}
@ -281,10 +281,7 @@ func setServerConfig(c *gin.Context) (*model.ServerTaskResponse, error) {
var respMu sync.Mutex
for i := 0; i < len(servers); i += 10 {
end := i + 10
if end > len(servers) {
end = len(servers)
}
end := min(i+10, len(servers))
group := servers[i:end]
wg.Add(1)

View File

@ -25,7 +25,7 @@ import (
// @Success 200 {object} model.CommonResponse[model.ServiceResponse]
// @Router /service [get]
func showService(c *gin.Context) (*model.ServiceResponse, error) {
res, err, _ := requestGroup.Do("list-service", func() (interface{}, error) {
res, err, _ := requestGroup.Do("list-service", func() (any, error) {
singleton.AlertsLock.RLock()
defer singleton.AlertsLock.RUnlock()
stats := singleton.ServiceSentinelShared.CopyStats()
@ -41,8 +41,8 @@ func showService(c *gin.Context) (*model.ServiceResponse, error) {
}
return &model.ServiceResponse{
Services: res.([]interface{})[0].(map[uint64]model.ServiceResponseItem),
CycleTransferStats: res.([]interface{})[1].(map[uint64]model.CycleTransferStats),
Services: res.([]any)[0].(map[uint64]model.ServiceResponseItem),
CycleTransferStats: res.([]any)[1].(map[uint64]model.CycleTransferStats),
}, nil
}

View File

@ -17,9 +17,9 @@ import (
// @Security BearerAuth
// @Tags common
// @Produce json
// @Success 200 {object} model.CommonResponse[model.SettingResponse[model.Config]]
// @Success 200 {object} model.CommonResponse[model.SettingResponse]
// @Router /setting [get]
func listConfig(c *gin.Context) (model.SettingResponse[any], error) {
func listConfig(c *gin.Context) (*model.SettingResponse, error) {
u, authorized := c.Get(model.CtxKeyAuthorizedUser)
var isAdmin bool
if authorized {
@ -30,30 +30,29 @@ func listConfig(c *gin.Context) (model.SettingResponse[any], error) {
config := *singleton.Conf
config.Language = strings.Replace(config.Language, "_", "-", -1)
conf := model.SettingResponse[any]{
Config: config,
conf := model.SettingResponse{
Config: model.Setting{
ConfigForGuests: config.ConfigForGuests,
ConfigDashboard: config.ConfigDashboard,
},
Version: singleton.Version,
FrontendTemplates: singleton.FrontendTemplates,
}
if !authorized || !isAdmin {
configForGuests := model.ConfigForGuests{
Language: config.Language,
SiteName: config.SiteName,
CustomCode: config.CustomCode,
CustomCodeDashboard: config.CustomCodeDashboard,
Oauth2Providers: config.Oauth2Providers,
}
configForGuests := config.ConfigForGuests
if authorized {
configForGuests.TLS = singleton.Conf.TLS
configForGuests.AgentTLS = singleton.Conf.AgentTLS
configForGuests.InstallHost = singleton.Conf.InstallHost
}
conf = model.SettingResponse[any]{
Config: configForGuests,
conf = model.SettingResponse{
Config: model.Setting{
ConfigForGuests: configForGuests,
},
}
}
return conf, nil
return &conf, nil
}
// Edit config
@ -98,7 +97,7 @@ func updateConfig(c *gin.Context) (any, error) {
singleton.Conf.CustomCode = sf.CustomCode
singleton.Conf.CustomCodeDashboard = sf.CustomCodeDashboard
singleton.Conf.RealIPHeader = sf.RealIPHeader
singleton.Conf.TLS = sf.TLS
singleton.Conf.AgentTLS = sf.AgentTLS
singleton.Conf.UserTemplate = sf.UserTemplate
if err := singleton.Conf.Save(); err != nil {

View File

@ -4,11 +4,11 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/goccy/go-json"
"github.com/gorilla/websocket"
"github.com/hashicorp/go-uuid"
"github.com/nezhahq/nezha/model"
"github.com/nezhahq/nezha/pkg/utils"
"github.com/nezhahq/nezha/pkg/websocketx"
"github.com/nezhahq/nezha/proto"
"github.com/nezhahq/nezha/service/rpc"
@ -46,7 +46,7 @@ func createTerminal(c *gin.Context) (*model.CreateTerminalResponse, error) {
rpc.NezhaHandlerSingleton.CreateStream(streamId)
terminalData, _ := utils.Json.Marshal(&model.TerminalTask{
terminalData, _ := json.Marshal(&model.TerminalTask{
StreamID: streamId,
})
if err := server.TaskStream.Send(&proto.Task{

View File

@ -9,6 +9,7 @@ import (
"unicode/utf8"
"github.com/gin-gonic/gin"
"github.com/goccy/go-json"
"github.com/gorilla/websocket"
"github.com/hashicorp/go-uuid"
"golang.org/x/sync/singleflight"
@ -157,7 +158,7 @@ func serverStream(c *gin.Context) (any, error) {
var requestGroup singleflight.Group
func getServerStat(withPublicNote, authorized bool) ([]byte, error) {
v, err, _ := requestGroup.Do(fmt.Sprintf("serverStats::%t", authorized), func() (interface{}, error) {
v, err, _ := requestGroup.Do(fmt.Sprintf("serverStats::%t", authorized), func() (any, error) {
var serverList []*model.Server
if authorized {
serverList = singleton.ServerShared.GetSortedList()
@ -183,7 +184,7 @@ func getServerStat(withPublicNote, authorized bool) ([]byte, error) {
})
}
return utils.Json.Marshal(model.StreamServerData{
return json.Marshal(model.StreamServerData{
Now: time.Now().Unix() * 1000,
Online: singleton.GetOnlineUserCount(),
Servers: servers,

View File

@ -2,7 +2,9 @@ package main
import (
"context"
"crypto/tls"
"embed"
"errors"
"flag"
"fmt"
"log"
@ -16,8 +18,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/ory/graceful"
"golang.org/x/crypto/bcrypt"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"github.com/nezhahq/nezha/cmd/dashboard/controller"
"github.com/nezhahq/nezha/cmd/dashboard/controller/waf"
@ -133,22 +133,56 @@ func main() {
controller.InitUpgrader()
muxHandler := newHTTPandGRPCMux(httpHandler, grpcHandler)
http2Server := &http2.Server{}
muxServer := &http.Server{Handler: h2c.NewHandler(muxHandler, http2Server), ReadHeaderTimeout: time.Second * 5}
muxServerHTTP := &http.Server{
Handler: muxHandler,
ReadHeaderTimeout: time.Second * 5,
}
muxServerHTTP.Protocols = new(http.Protocols)
muxServerHTTP.Protocols.SetHTTP1(true)
muxServerHTTP.Protocols.SetUnencryptedHTTP2(true)
var muxServerHTTPS *http.Server
if singleton.Conf.HTTPS.ListenPort != 0 {
muxServerHTTPS = &http.Server{
Addr: fmt.Sprintf("%s:%d", singleton.Conf.ListenHost, singleton.Conf.HTTPS.ListenPort),
Handler: muxHandler,
ReadHeaderTimeout: time.Second * 5,
TLSConfig: &tls.Config{
InsecureSkipVerify: singleton.Conf.HTTPS.InsecureTLS,
},
}
}
errChan := make(chan error, 2)
if err := graceful.Graceful(func() error {
log.Printf("NEZHA>> Dashboard::START ON %s:%d", singleton.Conf.ListenHost, singleton.Conf.ListenPort)
return muxServer.Serve(l)
if singleton.Conf.HTTPS.ListenPort != 0 {
go func() {
errChan <- muxServerHTTPS.ListenAndServeTLS(singleton.Conf.HTTPS.TLSCertPath, singleton.Conf.HTTPS.TLSKeyPath)
}()
log.Printf("NEZHA>> Dashboard::START ON %s:%d", singleton.Conf.ListenHost, singleton.Conf.HTTPS.ListenPort)
}
go func() {
errChan <- muxServerHTTP.Serve(l)
}()
return <-errChan
}, func(c context.Context) error {
log.Println("NEZHA>> Graceful::START")
singleton.RecordTransferHourlyUsage()
log.Println("NEZHA>> Graceful::END")
return muxServer.Shutdown(c)
err := muxServerHTTPS.Shutdown(c)
return errors.Join(muxServerHTTP.Shutdown(c), err)
}); err != nil {
log.Printf("NEZHA>> ERROR: %v", err)
if errors.Unwrap(err) != nil {
log.Printf("NEZHA>> ERROR HTTPS: %v", err)
}
}
close(errChan)
}
func newHTTPandGRPCMux(httpHandler http.Handler, grpcHandler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
natConfig := singleton.NATShared.GetNATConfigByDomain(r.Host)

View File

@ -8,6 +8,7 @@ import (
"net/netip"
"time"
"github.com/goccy/go-json"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
@ -27,7 +28,7 @@ func ServeRPC() *grpc.Server {
return server
}
func waf(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
func waf(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
realip, _ := ctx.Value(model.CtxKeyRealIP{}).(string)
if err := model.CheckIP(singleton.DB, realip); err != nil {
return nil, err
@ -35,7 +36,7 @@ func waf(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handl
return handler(ctx, req)
}
func getRealIp(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
func getRealIp(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
var ip, connectingIp string
p, ok := peer.FromContext(ctx)
if ok {
@ -133,20 +134,20 @@ func ServeNAT(w http.ResponseWriter, r *http.Request, natConfig *model.NAT) {
streamId, err := uuid.GenerateUUID()
if err != nil {
w.WriteHeader(http.StatusServiceUnavailable)
w.Write([]byte(fmt.Sprintf("stream id error: %v", err)))
w.Write(fmt.Appendf(nil, "stream id error: %v", err))
return
}
rpcService.NezhaHandlerSingleton.CreateStream(streamId)
defer rpcService.NezhaHandlerSingleton.CloseStream(streamId)
taskData, err := utils.Json.Marshal(model.TaskNAT{
taskData, err := json.Marshal(model.TaskNAT{
StreamID: streamId,
Host: natConfig.Host,
})
if err != nil {
w.WriteHeader(http.StatusServiceUnavailable)
w.Write([]byte(fmt.Sprintf("task data error: %v", err)))
w.Write(fmt.Appendf(nil, "task data error: %v", err))
return
}
@ -155,20 +156,20 @@ func ServeNAT(w http.ResponseWriter, r *http.Request, natConfig *model.NAT) {
Data: string(taskData),
}); err != nil {
w.WriteHeader(http.StatusServiceUnavailable)
w.Write([]byte(fmt.Sprintf("send task error: %v", err)))
w.Write(fmt.Appendf(nil, "send task error: %v", err))
return
}
wWrapped, err := utils.NewRequestWrapper(r, w)
if err != nil {
w.WriteHeader(http.StatusServiceUnavailable)
w.Write([]byte(fmt.Sprintf("request wrapper error: %v", err)))
w.Write(fmt.Appendf(nil, "request wrapper error: %v", err))
return
}
if err := rpcService.NezhaHandlerSingleton.UserConnected(streamId, wWrapped); err != nil {
w.WriteHeader(http.StatusServiceUnavailable)
w.Write([]byte(fmt.Sprintf("user connected error: %v", err)))
w.Write(fmt.Appendf(nil, "user connected error: %v", err))
return
}

8
go.mod
View File

@ -1,6 +1,6 @@
module github.com/nezhahq/nezha
go 1.23.6
go 1.24.0
require (
github.com/appleboy/gin-jwt/v2 v2.10.1
@ -8,10 +8,11 @@ require (
github.com/dustinkirkland/golang-petname v0.0.0-20240428194347-eebcea082ee0
github.com/gin-contrib/pprof v1.5.2
github.com/gin-gonic/gin v1.10.0
github.com/go-viper/mapstructure/v2 v2.2.1
github.com/goccy/go-json v0.10.5
github.com/gorilla/websocket v1.5.3
github.com/hashicorp/go-uuid v1.0.3
github.com/jinzhu/copier v0.4.0
github.com/json-iterator/go v1.1.12
github.com/knadh/koanf/parsers/yaml v0.1.0
github.com/knadh/koanf/providers/env v1.0.0
github.com/knadh/koanf/providers/file v1.1.2
@ -56,12 +57,11 @@ require (
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.25.0 // indirect
github.com/go-viper/mapstructure/v2 v2.2.1 // indirect
github.com/goccy/go-json v0.10.5 // indirect
github.com/golang-jwt/jwt/v4 v4.5.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.9 // indirect
github.com/knadh/koanf/maps v0.1.1 // indirect
github.com/leodido/go-urn v1.4.0 // indirect

View File

@ -1,7 +1,7 @@
package model
import (
"github.com/nezhahq/nezha/pkg/utils"
"github.com/goccy/go-json"
"gorm.io/gorm"
)
@ -25,17 +25,17 @@ type AlertRule struct {
}
func (r *AlertRule) BeforeSave(tx *gorm.DB) error {
if data, err := utils.Json.Marshal(r.Rules); err != nil {
if data, err := json.Marshal(r.Rules); err != nil {
return err
} else {
r.RulesRaw = string(data)
}
if data, err := utils.Json.Marshal(r.FailTriggerTasks); err != nil {
if data, err := json.Marshal(r.FailTriggerTasks); err != nil {
return err
} else {
r.FailTriggerTasksRaw = string(data)
}
if data, err := utils.Json.Marshal(r.RecoverTriggerTasks); err != nil {
if data, err := json.Marshal(r.RecoverTriggerTasks); err != nil {
return err
} else {
r.RecoverTriggerTasksRaw = string(data)
@ -45,13 +45,13 @@ func (r *AlertRule) BeforeSave(tx *gorm.DB) error {
func (r *AlertRule) AfterFind(tx *gorm.DB) error {
var err error
if err = utils.Json.Unmarshal([]byte(r.RulesRaw), &r.Rules); err != nil {
if err = json.Unmarshal([]byte(r.RulesRaw), &r.Rules); err != nil {
return err
}
if err = utils.Json.Unmarshal([]byte(r.FailTriggerTasksRaw), &r.FailTriggerTasks); err != nil {
if err = json.Unmarshal([]byte(r.FailTriggerTasksRaw), &r.FailTriggerTasks); err != nil {
return err
}
if err = utils.Json.Unmarshal([]byte(r.RecoverTriggerTasksRaw), &r.RecoverTriggerTasks); err != nil {
if err = json.Unmarshal([]byte(r.RecoverTriggerTasksRaw), &r.RecoverTriggerTasks); err != nil {
return err
}
return nil

View File

@ -77,7 +77,7 @@ func SearchByIDCtx[S ~[]E, E CommonInterface](c *gin.Context, x S) S {
return any(l).(S)
default:
var s S
for _, idStr := range strings.Split(c.Query("id"), ",") {
for idStr := range strings.SplitSeq(c.Query("id"), ",") {
id, err := strconv.ParseUint(idStr, 10, 64)
if err != nil {
continue
@ -93,7 +93,7 @@ func searchByIDCtxServer(c *gin.Context, x []*Server) []*Server {
list1, list2 := SplitList(x)
var clist1, clist2 []*Server
for _, idStr := range strings.Split(c.Query("id"), ",") {
for idStr := range strings.SplitSeq(c.Query("id"), ",") {
id, err := strconv.ParseUint(idStr, 10, 64)
if err != nil {
continue

View File

@ -1,18 +1,16 @@
package model
import (
"maps"
"os"
"path/filepath"
"slices"
"strconv"
"strings"
"github.com/go-viper/mapstructure/v2"
kyaml "github.com/knadh/koanf/parsers/yaml"
"github.com/knadh/koanf/providers/env"
"github.com/knadh/koanf/providers/file"
"github.com/knadh/koanf/v2"
"gopkg.in/yaml.v3"
"github.com/nezhahq/nezha/pkg/utils"
)
@ -24,52 +22,56 @@ const (
)
type ConfigForGuests struct {
Language string `json:"language"`
SiteName string `json:"site_name"`
CustomCode string `json:"custom_code,omitempty"`
CustomCodeDashboard string `json:"custom_code_dashboard,omitempty"`
Oauth2Providers []string `json:"oauth2_providers,omitempty"`
Language string `koanf:"language" json:"language"` // 系统语言,默认 zh_CN
SiteName string `koanf:"site_name" json:"site_name"`
CustomCode string `koanf:"custom_code" json:"custom_code,omitempty"`
CustomCodeDashboard string `koanf:"custom_code_dashboard" json:"custom_code_dashboard,omitempty"`
Oauth2Providers []string `koanf:"-" json:"oauth2_providers,omitempty"` // oauth2 供应商列表,无需配置,自动生成
InstallHost string `json:"install_host,omitempty"`
TLS bool `json:"tls,omitempty"`
InstallHost string `koanf:"install_host" json:"install_host,omitempty"`
AgentTLS bool `koanf:"tls" json:"tls,omitempty"` // 用于前端判断生成的安装命令是否启用 TLS
}
type ConfigDashboard struct {
Debug bool `koanf:"debug" json:"debug,omitempty"` // debug模式开关
RealIPHeader string `koanf:"real_ip_header" json:"real_ip_header,omitempty"` // 真实IP
UserTemplate string `koanf:"user_template" json:"user_template,omitempty"`
AdminTemplate string `koanf:"admin_template" json:"admin_template,omitempty"`
Location string `koanf:"location" json:"location,omitempty"` // 时区,默认为 Asia/Shanghai
ForceAuth bool `koanf:"force_auth" json:"force_auth,omitempty"` // 强制要求认证
AgentSecretKey string `koanf:"agent_secret_key" json:"agent_secret_key,omitempty"`
EnablePlainIPInNotification bool `koanf:"enable_plain_ip_in_notification" json:"enable_plain_ip_in_notification,omitempty"` // 通知信息IP不打码
// IP变更提醒
EnableIPChangeNotification bool `koanf:"enable_ip_change_notification" json:"enable_ip_change_notification,omitempty"`
IPChangeNotificationGroupID uint64 `koanf:"ip_change_notification_group_id" json:"ip_change_notification_group_id"`
Cover uint8 `koanf:"cover" json:"cover"` // 覆盖范围0:提醒未被 IgnoredIPNotification 包含的所有服务器; 1:仅提醒被 IgnoredIPNotification 包含的服务器;
IgnoredIPNotification string `koanf:"ignored_ip_notification" json:"ignored_ip_notification,omitempty"` // 特定服务器IP多个服务器用逗号分隔
IgnoredIPNotificationServerIDs map[uint64]bool `koanf:"ignored_ip_notification_server_ids" json:"ignored_ip_notification_server_ids,omitempty"` // [ServerID] -> bool(值为true代表当前ServerID在特定服务器列表内
AvgPingCount int `koanf:"avg_ping_count" json:"avg_ping_count,omitempty"`
DNSServers string `koanf:"dns_servers" json:"dns_servers,omitempty"`
}
type Config struct {
Debug bool `mapstructure:"debug" json:"debug,omitempty"` // debug模式开关
RealIPHeader string `mapstructure:"real_ip_header" json:"real_ip_header,omitempty"` // 真实IP
ConfigForGuests
ConfigDashboard
Language string `mapstructure:"language" json:"language"` // 系统语言,默认 zh_CN
SiteName string `mapstructure:"site_name" json:"site_name"`
UserTemplate string `mapstructure:"user_template" json:"user_template,omitempty"`
AdminTemplate string `mapstructure:"admin_template" json:"admin_template,omitempty"`
JWTSecretKey string `mapstructure:"jwt_secret_key" json:"jwt_secret_key,omitempty"`
AgentSecretKey string `mapstructure:"agent_secret_key" json:"agent_secret_key,omitempty"`
ListenPort uint `mapstructure:"listen_port" json:"listen_port,omitempty"`
ListenHost string `mapstructure:"listen_host" json:"listen_host,omitempty"`
InstallHost string `mapstructure:"install_host" json:"install_host,omitempty"`
TLS bool `mapstructure:"tls" json:"tls,omitempty"`
Location string `mapstructure:"location" json:"location,omitempty"` // 时区,默认为 Asia/Shanghai
ForceAuth bool `mapstructure:"force_auth" json:"force_auth,omitempty"` // 强制要求认证
EnablePlainIPInNotification bool `mapstructure:"enable_plain_ip_in_notification" json:"enable_plain_ip_in_notification,omitempty"` // 通知信息IP不打码
// IP变更提醒
EnableIPChangeNotification bool `mapstructure:"enable_ip_change_notification" json:"enable_ip_change_notification,omitempty"`
IPChangeNotificationGroupID uint64 `mapstructure:"ip_change_notification_group_id" json:"ip_change_notification_group_id"`
Cover uint8 `mapstructure:"cover" json:"cover"` // 覆盖范围0:提醒未被 IgnoredIPNotification 包含的所有服务器; 1:仅提醒被 IgnoredIPNotification 包含的服务器;
IgnoredIPNotification string `mapstructure:"ignored_ip_notification" json:"ignored_ip_notification,omitempty"` // 特定服务器IP多个服务器用逗号分隔
IgnoredIPNotificationServerIDs map[uint64]bool `mapstructure:"ignored_ip_notification_server_ids" json:"ignored_ip_notification_server_ids,omitempty"` // [ServerID] -> bool(值为true代表当前ServerID在特定服务器列表内
AvgPingCount int `mapstructure:"avg_ping_count" json:"avg_ping_count,omitempty"`
DNSServers string `mapstructure:"dns_servers" json:"dns_servers,omitempty"`
CustomCode string `mapstructure:"custom_code" json:"custom_code,omitempty"`
CustomCodeDashboard string `mapstructure:"custom_code_dashboard" json:"custom_code_dashboard,omitempty"`
JWTSecretKey string `koanf:"jwt_secret_key" json:"jwt_secret_key,omitempty"`
ListenPort uint16 `koanf:"listen_port" json:"listen_port,omitempty"`
ListenHost string `koanf:"listen_host" json:"listen_host,omitempty"`
// oauth2 配置
Oauth2 map[string]*Oauth2Config `mapstructure:"oauth2" json:"oauth2,omitempty"`
// oauth2 供应商列表,无需配置,自动生成
Oauth2Providers []string `yaml:"-" json:"oauth2_providers,omitempty"`
Oauth2 map[string]*Oauth2Config `koanf:"oauth2" json:"oauth2,omitempty"`
// HTTPS 配置
HTTPS struct {
ListenPort uint16 `koanf:"listen_port" json:"listen_port,omitempty"`
TLSCertPath string `koanf:"tls_cert_path" json:"tls_cert_path,omitempty"`
TLSKeyPath string `koanf:"tls_key_path" json:"tls_key_path,omitempty"`
InsecureTLS bool `koanf:"insecure_tls" json:"insecure_tls,omitempty"`
} `koanf:"https" json:"https"`
k *koanf.Koanf `json:"-"`
filePath string `json:"-"`
@ -94,10 +96,11 @@ func (c *Config) Read(path string, frontendTemplates []FrontendTemplate) error {
}
}
err = c.k.Unmarshal("", c)
err = c.k.UnmarshalWithConf("", c, koanfConf(c))
if err != nil {
return err
}
if c.ListenPort == 0 {
c.ListenPort = 8008
}
@ -151,7 +154,7 @@ func (c *Config) Read(path string, frontendTemplates []FrontendTemplate) error {
}
}
c.Oauth2Providers = slices.Collect(maps.Keys(c.Oauth2))
c.Oauth2Providers = utils.MapKeysToSlice(c.Oauth2)
c.updateIgnoredIPNotificationID()
return nil
@ -160,9 +163,8 @@ func (c *Config) Read(path string, frontendTemplates []FrontendTemplate) error {
// updateIgnoredIPNotificationID 更新用于判断服务器ID是否属于特定服务器的map
func (c *Config) updateIgnoredIPNotificationID() {
c.IgnoredIPNotificationServerIDs = make(map[uint64]bool)
splitedIDs := strings.Split(c.IgnoredIPNotification, ",")
for i := 0; i < len(splitedIDs); i++ {
id, _ := strconv.ParseUint(splitedIDs[i], 10, 64)
for splitedID := range strings.SplitSeq(c.IgnoredIPNotification, ",") {
id, _ := strconv.ParseUint(splitedID, 10, 64)
if id > 0 {
c.IgnoredIPNotificationServerIDs[id] = true
}
@ -172,7 +174,7 @@ func (c *Config) updateIgnoredIPNotificationID() {
// Save 保存配置文件
func (c *Config) Save() error {
c.updateIgnoredIPNotificationID()
data, err := yaml.Marshal(c)
data, err := c.k.Marshal(kyaml.Parser())
if err != nil {
return err
}
@ -184,3 +186,21 @@ func (c *Config) Save() error {
return os.WriteFile(c.filePath, data, 0600)
}
func koanfConf(c any) koanf.UnmarshalConf {
return koanf.UnmarshalConf{
DecoderConfig: &mapstructure.DecoderConfig{
DecodeHook: mapstructure.ComposeDecodeHookFunc(
mapstructure.StringToTimeDurationHookFunc(),
utils.TextUnmarshalerHookFunc()),
Metadata: nil,
Result: c,
WeaklyTypedInput: true,
MatchName: func(mapKey, fieldName string) bool {
return strings.EqualFold(mapKey, fieldName) ||
strings.EqualFold(mapKey, strings.ReplaceAll(fieldName, "_", ""))
},
Squash: true,
},
}
}

View File

@ -3,7 +3,7 @@ package model
import (
"time"
"github.com/nezhahq/nezha/pkg/utils"
"github.com/goccy/go-json"
"github.com/robfig/cron/v3"
"gorm.io/gorm"
)
@ -34,7 +34,7 @@ type Cron struct {
}
func (c *Cron) BeforeSave(tx *gorm.DB) error {
if data, err := utils.Json.Marshal(c.Servers); err != nil {
if data, err := json.Marshal(c.Servers); err != nil {
return err
} else {
c.ServersRaw = string(data)
@ -43,5 +43,5 @@ func (c *Cron) BeforeSave(tx *gorm.DB) error {
}
func (c *Cron) AfterFind(tx *gorm.DB) error {
return utils.Json.Unmarshal([]byte(c.ServersRaw), &c.Servers)
return json.Unmarshal([]byte(c.ServersRaw), &c.Servers)
}

View File

@ -1,7 +1,7 @@
package model
import (
"github.com/nezhahq/nezha/pkg/utils"
"github.com/goccy/go-json"
"gorm.io/gorm"
)
@ -39,7 +39,7 @@ func (d DDNSProfile) TableName() string {
}
func (d *DDNSProfile) BeforeSave(tx *gorm.DB) error {
if data, err := utils.Json.Marshal(d.Domains); err != nil {
if data, err := json.Marshal(d.Domains); err != nil {
return err
} else {
d.DomainsRaw = string(data)
@ -48,5 +48,5 @@ func (d *DDNSProfile) BeforeSave(tx *gorm.DB) error {
}
func (d *DDNSProfile) AfterFind(tx *gorm.DB) error {
return utils.Json.Unmarshal([]byte(d.DomainsRaw), &d.Domains)
return json.Unmarshal([]byte(d.DomainsRaw), &d.Domains)
}

View File

@ -9,6 +9,7 @@ import (
"strings"
"time"
"github.com/goccy/go-json"
"github.com/nezhahq/nezha/pkg/utils"
)
@ -66,7 +67,7 @@ func (ns *NotificationServerBundle) reqBody(message string) (string, error) {
switch n.RequestType {
case NotificationRequestTypeJSON:
return ns.replaceParamsInString(n.RequestBody, message, func(msg string) string {
msgBytes, _ := utils.Json.Marshal(msg)
msgBytes, _ := json.Marshal(msg)
return string(msgBytes)[1 : len(msgBytes)-1]
}), nil
case NotificationRequestTypeForm:

View File

@ -5,9 +5,9 @@ import (
"slices"
"time"
"github.com/goccy/go-json"
"gorm.io/gorm"
"github.com/nezhahq/nezha/pkg/utils"
pb "github.com/nezhahq/nezha/proto"
)
@ -59,13 +59,13 @@ func (s *Server) CopyFromRunningServer(old *Server) {
func (s *Server) AfterFind(tx *gorm.DB) error {
if s.DDNSProfilesRaw != "" {
if err := utils.Json.Unmarshal([]byte(s.DDNSProfilesRaw), &s.DDNSProfiles); err != nil {
if err := json.Unmarshal([]byte(s.DDNSProfilesRaw), &s.DDNSProfiles); err != nil {
log.Println("NEZHA>> Server.AfterFind:", err)
return nil
}
}
if s.OverrideDDNSDomainsRaw != "" {
if err := utils.Json.Unmarshal([]byte(s.OverrideDDNSDomainsRaw), &s.OverrideDDNSDomains); err != nil {
if err := json.Unmarshal([]byte(s.OverrideDDNSDomainsRaw), &s.OverrideDDNSDomains); err != nil {
log.Println("NEZHA>> Server.AfterFind:", err)
return nil
}

View File

@ -4,10 +4,10 @@ import (
"fmt"
"log"
"github.com/goccy/go-json"
"github.com/robfig/cron/v3"
"gorm.io/gorm"
"github.com/nezhahq/nezha/pkg/utils"
pb "github.com/nezhahq/nezha/proto"
)
@ -91,17 +91,17 @@ func (m *Service) CronSpec() string {
}
func (m *Service) BeforeSave(tx *gorm.DB) error {
if data, err := utils.Json.Marshal(m.SkipServers); err != nil {
if data, err := json.Marshal(m.SkipServers); err != nil {
return err
} else {
m.SkipServersRaw = string(data)
}
if data, err := utils.Json.Marshal(m.FailTriggerTasks); err != nil {
if data, err := json.Marshal(m.FailTriggerTasks); err != nil {
return err
} else {
m.FailTriggerTasksRaw = string(data)
}
if data, err := utils.Json.Marshal(m.RecoverTriggerTasks); err != nil {
if data, err := json.Marshal(m.RecoverTriggerTasks); err != nil {
return err
} else {
m.RecoverTriggerTasksRaw = string(data)
@ -111,16 +111,16 @@ func (m *Service) BeforeSave(tx *gorm.DB) error {
func (m *Service) AfterFind(tx *gorm.DB) error {
m.SkipServers = make(map[uint64]bool)
if err := utils.Json.Unmarshal([]byte(m.SkipServersRaw), &m.SkipServers); err != nil {
if err := json.Unmarshal([]byte(m.SkipServersRaw), &m.SkipServers); err != nil {
log.Println("NEZHA>> Service.AfterFind:", err)
return nil
}
// 加载触发任务列表
if err := utils.Json.Unmarshal([]byte(m.FailTriggerTasksRaw), &m.FailTriggerTasks); err != nil {
if err := json.Unmarshal([]byte(m.FailTriggerTasksRaw), &m.FailTriggerTasks); err != nil {
return err
}
if err := utils.Json.Unmarshal([]byte(m.RecoverTriggerTasksRaw), &m.RecoverTriggerTasks); err != nil {
if err := json.Unmarshal([]byte(m.RecoverTriggerTasksRaw), &m.RecoverTriggerTasks); err != nil {
return err
}

View File

@ -13,11 +13,16 @@ type SettingForm struct {
RealIPHeader string `json:"real_ip_header,omitempty" validate:"optional"` // 真实IP
UserTemplate string `json:"user_template,omitempty" validate:"optional"`
TLS bool `json:"tls,omitempty" validate:"optional"`
AgentTLS bool `json:"tls,omitempty" validate:"optional"`
EnableIPChangeNotification bool `json:"enable_ip_change_notification,omitempty" validate:"optional"`
EnablePlainIPInNotification bool `json:"enable_plain_ip_in_notification,omitempty" validate:"optional"`
}
type Setting struct {
ConfigForGuests
ConfigDashboard
}
type FrontendTemplate struct {
Path string `json:"path,omitempty"`
Name string `json:"name,omitempty"`
@ -28,8 +33,8 @@ type FrontendTemplate struct {
IsOfficial bool `json:"is_official,omitempty"`
}
type SettingResponse[T any] struct {
Config T `json:"config,omitempty"`
type SettingResponse struct {
Config Setting `json:"config"`
Version string `json:"version,omitempty"`
FrontendTemplates []FrontendTemplate `json:"frontend_templates,omitempty"`

View File

@ -117,7 +117,7 @@ func BlockIP(db *gorm.DB, ip string, reason uint8, uid int64) error {
}
now := uint64(time.Now().Unix())
var count interface{}
var count any
if reason == WAFBlockReasonTypeManual {
count = 99999
} else {

View File

@ -33,9 +33,7 @@ func GjsonIter(json string) (iter.Seq2[string, string], error) {
return nil, ErrGjsonWrongType
}
return func(yield func(string, string) bool) {
result.ForEach(func(k, v gjson.Result) bool {
return yield(k.String(), v.String())
})
}, nil
return ConvertSeq2(result.ForEach, func(k, v gjson.Result) (string, string) {
return k.String(), v.String()
}), nil
}

71
pkg/utils/koanf.go Normal file
View File

@ -0,0 +1,71 @@
package utils
import (
"encoding"
"reflect"
"github.com/go-viper/mapstructure/v2"
)
// TextUnmarshalerHookFunc is a fixed version of mapstructure.TextUnmarshallerHookFunc.
// This hook allows to additionally unmarshal text into custom string types that implement the encoding.Text(Un)Marshaler interface(s).
func TextUnmarshalerHookFunc() mapstructure.DecodeHookFuncType {
return func(
f reflect.Type,
t reflect.Type,
data any,
) (any, error) {
if f.Kind() != reflect.String {
return data, nil
}
result := reflect.New(t).Interface()
unmarshaller, ok := result.(encoding.TextUnmarshaler)
if !ok {
return data, nil
}
// default text representation is the actual value of the `from` string
var (
dataVal = reflect.ValueOf(data)
text = []byte(dataVal.String())
)
if f.Kind() == t.Kind() {
// source and target are of underlying type string
var (
err error
ptrVal = reflect.New(dataVal.Type())
)
if !ptrVal.Elem().CanSet() {
// cannot set, skip, this should not happen
if err := unmarshaller.UnmarshalText(text); err != nil {
return nil, err
}
return result, nil
}
ptrVal.Elem().Set(dataVal)
// We need to assert that both, the value type and the pointer type
// do (not) implement the TextMarshaller interface before proceeding and simply
// using the string value of the string type.
// it might be the case that the internal string representation differs from
// the (un)marshalled string.
for _, v := range []reflect.Value{dataVal, ptrVal} {
if marshaller, ok := v.Interface().(encoding.TextMarshaler); ok {
text, err = marshaller.MarshalText()
if err != nil {
return nil, err
}
break
}
}
}
// text is either the source string's value or the source string type's marshaled value
// which may differ from its internal string value.
if err := unmarshaller.UnmarshalText(text); err != nil {
return nil, err
}
return result, nil
}
}

View File

@ -1,6 +1,7 @@
package utils
import (
"cmp"
"crypto/rand"
"errors"
"iter"
@ -13,24 +14,19 @@ import (
"strings"
"golang.org/x/exp/constraints"
jsoniter "github.com/json-iterator/go"
)
var (
Json = jsoniter.ConfigCompatibleWithStandardLibrary
DNSServers = []string{"8.8.8.8:53", "8.8.4.4:53", "1.1.1.1:53", "1.0.0.1:53"}
)
var ipv4Re = regexp.MustCompile(`(\d*\.).*(\.\d*)`)
ipv4Re = regexp.MustCompile(`(\d*\.).*(\.\d*)`)
ipv6Re = regexp.MustCompile(`(\w*:\w*:).*(:\w*:\w*)`)
)
func ipv4Desensitize(ipv4Addr string) string {
return ipv4Re.ReplaceAllString(ipv4Addr, "$1****$2")
}
var ipv6Re = regexp.MustCompile(`(\w*:\w*:).*(:\w*:\w*)`)
func ipv6Desensitize(ipv6Addr string) string {
return ipv6Re.ReplaceAllString(ipv6Addr, "$1****$2")
}
@ -51,9 +47,11 @@ func IPStringToBinary(ip string) ([]byte, error) {
}
func BinaryToIPString(b []byte) string {
var addr16 [16]byte
copy(addr16[:], b)
addr := netip.AddrFrom16(addr16)
if len(b) < 16 {
return "::"
}
addr := netip.AddrFrom16([16]byte(b))
return addr.Unmap().String()
}
@ -74,7 +72,7 @@ func GenerateRandomString(n int) (string, error) {
const letters = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
lettersLength := big.NewInt(int64(len(letters)))
ret := make([]byte, n)
for i := 0; i < n; i++ {
for i := range n {
num, err := rand.Int(rand.Reader, lettersLength)
if err != nil {
return "", err
@ -117,22 +115,35 @@ func MapValuesToSlice[Map ~map[K]V, K comparable, V any](m Map) []V {
return slices.AppendSeq(s, maps.Values(m))
}
func Unique[T comparable](s []T) []T {
m := make(map[T]struct{})
ret := make([]T, 0, len(s))
for _, v := range s {
if _, ok := m[v]; !ok {
m[v] = struct{}{}
ret = append(ret, v)
}
}
return ret
func MapKeysToSlice[Map ~map[K]V, K comparable, V any](m Map) []K {
s := make([]K, 0, len(m))
return slices.AppendSeq(s, maps.Keys(m))
}
func ConvertSeq[T, U any](seq iter.Seq[T], f func(e T) U) iter.Seq[U] {
return func(yield func(U) bool) {
for e := range seq {
if !yield(f(e)) {
func Unique[S ~[]E, E cmp.Ordered](list S) S {
if list == nil {
return nil
}
out := make([]E, len(list))
copy(out, list)
slices.Sort(out)
return slices.Compact(out)
}
func ConvertSeq[In, Out any](seq iter.Seq[In], f func(In) Out) iter.Seq[Out] {
return func(yield func(Out) bool) {
for in := range seq {
if !yield(f(in)) {
return
}
}
}
}
func ConvertSeq2[KIn, VIn, KOut, VOut any](seq iter.Seq2[KIn, VIn], f func(KIn, VIn) (KOut, VOut)) iter.Seq2[KOut, VOut] {
return func(yield func(KOut, VOut) bool) {
for k, v := range seq {
if !yield(f(k, v)) {
return
}
}

View File

@ -43,7 +43,7 @@ func TestNotification(t *testing.T) {
func TestGenerGenerateRandomString(t *testing.T) {
generatedString := make(map[string]bool)
for i := 0; i < 100; i++ {
for range 100 {
str, err := GenerateRandomString(32)
if err != nil {
t.Fatalf("Error: %s", err)

View File

@ -96,7 +96,7 @@ func OnRefreshOrAddAlert(alert *model.AlertRule) {
delete(alertsStore, alert.ID)
delete(alertsPrevState, alert.ID)
var isEdit bool
for i := 0; i < len(Alerts); i++ {
for i := range Alerts {
if Alerts[i].ID == alert.ID {
Alerts[i] = alert
isEdit = true

View File

@ -52,7 +52,7 @@ func OnUserUpdate(u *model.User) {
AgentSecretToUserId[u.AgentSecret] = u.ID
}
func OnUserDelete(id []uint64, errorFunc func(string, ...interface{}) error) error {
func OnUserDelete(id []uint64, errorFunc func(string, ...any) error) error {
UserLock.Lock()
defer UserLock.Unlock()