From aa20c973121e75496915471defe5205a2311bb88 Mon Sep 17 00:00:00 2001 From: naiba Date: Sun, 20 Oct 2024 23:23:04 +0800 Subject: [PATCH] refactor agent auth & server api --- .github/workflows/test.yml | 4 + cmd/dashboard/controller/api_v1.go | 2 +- cmd/dashboard/controller/common_page.go | 76 ++--------- cmd/dashboard/controller/controller.go | 12 +- cmd/dashboard/controller/jwt.go | 52 ++++++-- cmd/dashboard/controller/member_api.go | 159 ++++------------------- cmd/dashboard/controller/server.go | 130 ++++++++++++++++++ cmd/dashboard/controller/server_group.go | 36 +++++ cmd/dashboard/controller/ws.go | 96 ++++++++++++++ model/config.go | 29 +++-- model/notification_test.go | 2 - model/server.go | 43 ++---- model/server_api.go | 29 +++++ model/server_group.go | 1 + model/server_test.go | 30 ----- pkg/utils/utils.go | 14 ++ service/rpc/auth.go | 31 ++++- service/singleton/api.go | 58 ++++----- service/singleton/server.go | 14 +- 19 files changed, 488 insertions(+), 330 deletions(-) create mode 100644 cmd/dashboard/controller/server.go create mode 100644 cmd/dashboard/controller/server_group.go create mode 100644 cmd/dashboard/controller/ws.go create mode 100644 model/server_api.go delete mode 100644 model/server_test.go diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 6f13264..abd374c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -8,6 +8,10 @@ on: - "go.sum" - "resource/**" - ".github/workflows/test.yml" + pull_request: + branches: + - master + jobs: tests: diff --git a/cmd/dashboard/controller/api_v1.go b/cmd/dashboard/controller/api_v1.go index 112e4cd..4eb24f1 100644 --- a/cmd/dashboard/controller/api_v1.go +++ b/cmd/dashboard/controller/api_v1.go @@ -72,7 +72,7 @@ func (v *apiV1) serverDetails(c *gin.Context) { } tag := c.Query("tag") if tag != "" { - c.JSON(200, singleton.ServerAPI.GetStatusByTag(tag)) + // c.JSON(200, singleton.ServerAPI.GetStatusByTag(tag)) return } if len(idList) != 0 { diff --git a/cmd/dashboard/controller/common_page.go b/cmd/dashboard/controller/common_page.go index 1fabf36..cceb987 100644 --- a/cmd/dashboard/controller/common_page.go +++ b/cmd/dashboard/controller/common_page.go @@ -10,7 +10,6 @@ import ( "github.com/gorilla/websocket" "github.com/hashicorp/go-uuid" "github.com/jinzhu/copier" - "golang.org/x/sync/singleflight" "github.com/naiba/nezha/model" "github.com/naiba/nezha/pkg/utils" @@ -21,8 +20,7 @@ import ( ) type commonPage struct { - r *gin.Engine - requestGroup singleflight.Group + r *gin.Engine } func (cp *commonPage) serve() { @@ -37,14 +35,13 @@ func (cp *commonPage) serve() { // TODO: 界面直接跳转使用该接口 cr.GET("/network/:id", cp.network) cr.GET("/network", cp.network) - cr.GET("/ws", cp.ws) cr.POST("/terminal", cp.createTerminal) cr.GET("/file", cp.createFM) cr.GET("/file/:id", cp.fm) } func (p *commonPage) service(c *gin.Context) { - res, _, _ := p.requestGroup.Do("servicePage", func() (interface{}, error) { + res, _, _ := requestGroup.Do("servicePage", func() (interface{}, error) { singleton.AlertsLock.RLock() defer singleton.AlertsLock.RUnlock() var stats map[uint64]model.ServiceItemResponse @@ -71,7 +68,7 @@ func (p *commonPage) service(c *gin.Context) { func (cp *commonPage) network(c *gin.Context) { var ( monitorHistory *model.MonitorHistory - servers []*model.Server + servers []model.Server serverIdsWithMonitor []uint64 monitorInfos = []byte("{}") id uint64 @@ -148,7 +145,7 @@ func (cp *commonPage) network(c *gin.Context) { for _, server := range singleton.SortedServerList { for _, id := range serverIdsWithMonitor { if server.ID == id { - servers = append(servers, server) + servers = append(servers, *server) } } } @@ -156,14 +153,14 @@ func (cp *commonPage) network(c *gin.Context) { for _, server := range singleton.SortedServerListForGuest { for _, id := range serverIdsWithMonitor { if server.ID == id { - servers = append(servers, server) + servers = append(servers, *server) } } } } - serversBytes, _ := utils.Json.Marshal(Data{ - Now: time.Now().Unix() * 1000, - Servers: servers, + serversBytes, _ := utils.Json.Marshal(model.StreamServerData{ + Now: time.Now().Unix() * 1000, + // Servers: servers, }) c.HTML(http.StatusOK, "", gin.H{ @@ -176,7 +173,7 @@ func (cp *commonPage) getServerStat(c *gin.Context, withPublicNote bool) ([]byte _, isMember := c.Get(model.CtxKeyAuthorizedUser) var isViewPasswordVerfied bool authorized := isMember || isViewPasswordVerfied - v, err, _ := cp.requestGroup.Do(fmt.Sprintf("serverStats::%t", authorized), func() (interface{}, error) { + v, err, _ := requestGroup.Do(fmt.Sprintf("serverStats::%t", authorized), func() (interface{}, error) { singleton.SortedServerLock.RLock() defer singleton.SortedServerLock.RUnlock() @@ -187,18 +184,18 @@ func (cp *commonPage) getServerStat(c *gin.Context, withPublicNote bool) ([]byte serverList = singleton.SortedServerListForGuest } - var servers []*model.Server + var servers []model.Server for _, server := range serverList { item := *server if !withPublicNote { item.PublicNote = "" } - servers = append(servers, &item) + servers = append(servers, item) } - return utils.Json.Marshal(Data{ - Now: time.Now().Unix() * 1000, - Servers: servers, + return utils.Json.Marshal(model.StreamServerData{ + Now: time.Now().Unix() * 1000, + // Servers: servers, }) }) return v.([]byte), err @@ -223,51 +220,6 @@ func (cp *commonPage) home(c *gin.Context) { }) } -var upgrader = websocket.Upgrader{ - ReadBufferSize: 32768, - WriteBufferSize: 32768, -} - -type Data struct { - Now int64 `json:"now,omitempty"` - Servers []*model.Server `json:"servers,omitempty"` -} - -func (cp *commonPage) ws(c *gin.Context) { - conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) - if err != nil { - // mygin.ShowErrorPage(c, mygin.ErrInfo{ - // Code: http.StatusInternalServerError, - // // Title: singleton.Localizer.MustLocalize(&i18n.LocalizeConfig{ - // // MessageID: "NetworkError", - // // }), - // Msg: "Websocket协议切换失败", - // Link: "/", - // Btn: "返回首页", - // }, true) - return - } - defer conn.Close() - count := 0 - for { - stat, err := cp.getServerStat(c, false) - if err != nil { - continue - } - if err := conn.WriteMessage(websocket.TextMessage, stat); err != nil { - break - } - count += 1 - if count%4 == 0 { - err = conn.WriteMessage(websocket.PingMessage, []byte{}) - if err != nil { - break - } - } - time.Sleep(time.Second * 2) - } -} - func (cp *commonPage) terminal(c *gin.Context) { streamId := c.Param("id") if _, err := rpc.NezhaHandlerSingleton.GetStream(streamId); err != nil { diff --git a/cmd/dashboard/controller/controller.go b/cmd/dashboard/controller/controller.go index f2adeeb..3564299 100644 --- a/cmd/dashboard/controller/controller.go +++ b/cmd/dashboard/controller/controller.go @@ -49,13 +49,21 @@ func routers(r *gin.Engine) { if err != nil { log.Fatal("JWT Error:" + err.Error()) } + if err := authMiddleware.MiddlewareInit(); err != nil { + log.Fatal("authMiddleware.MiddlewareInit Error:" + err.Error()) + } api := r.Group("api/v1") - api.Use(handlerMiddleWare(authMiddleware)) - api.POST("/login", authMiddleware.LoginHandler) + unrequiredAuth := api.Group("", unrquiredAuthMiddleware(authMiddleware)) + unrequiredAuth.GET("/ws/server", serverStream) + unrequiredAuth.GET("/server-group", listServerGroup) + auth := api.Group("", authMiddleware.MiddlewareFunc()) auth.GET("/refresh_token", authMiddleware.RefreshHandler) + auth.PATCH("/server/:id", editServer) + + api.DELETE("/batch-delete/server", batchDeleteServer) // 通用页面 // cp := commonPage{r: r} diff --git a/cmd/dashboard/controller/jwt.go b/cmd/dashboard/controller/jwt.go index 1ae5733..de41342 100644 --- a/cmd/dashboard/controller/jwt.go +++ b/cmd/dashboard/controller/jwt.go @@ -1,8 +1,8 @@ package controller import ( + "encoding/json" "fmt" - "log" "net/http" "time" @@ -16,7 +16,7 @@ import ( func initParams() *jwt.GinJWTMiddleware { return &jwt.GinJWTMiddleware{ Realm: singleton.Conf.SiteName, - Key: []byte(singleton.Conf.SecretKey), + Key: []byte(singleton.Conf.JWTSecretKey), CookieName: "nz-jwt", Timeout: time.Hour, MaxRefresh: time.Hour, @@ -44,15 +44,6 @@ func initParams() *jwt.GinJWTMiddleware { } } -func handlerMiddleWare(authMiddleware *jwt.GinJWTMiddleware) gin.HandlerFunc { - return func(context *gin.Context) { - errInit := authMiddleware.MiddlewareInit() - if errInit != nil { - log.Fatal("authMiddleware.MiddlewareInit() Error:" + errInit.Error()) - } - } -} - func payloadFunc() func(data interface{}) jwt.MapClaims { return func(data interface{}) jwt.MapClaims { if v, ok := data.(string); ok { @@ -81,7 +72,7 @@ func identityHandler() func(c *gin.Context) interface{} { // @Schemes // @Description user login // @Accept json -// @param request body model.LoginRequest true "Login Request" +// @param loginRequest body model.LoginRequest true "Login Request" // @Produce json // @Success 200 {object} model.CommonResponse[model.LoginResponse] // @Router /login [post] @@ -152,3 +143,40 @@ func refreshResponse(c *gin.Context, code int, token string, expire time.Time) { }, }) } + +func unrquiredAuthMiddleware(mw *jwt.GinJWTMiddleware) func(c *gin.Context) { + return func(c *gin.Context) { + claims, err := mw.GetClaimsFromJWT(c) + if err != nil { + return + } + + switch v := claims["exp"].(type) { + case nil: + return + case float64: + if int64(v) < mw.TimeFunc().Unix() { + return + } + case json.Number: + n, err := v.Int64() + if err != nil { + return + } + if n < mw.TimeFunc().Unix() { + return + } + default: + return + } + + c.Set("JWT_PAYLOAD", claims) + identity := mw.IdentityHandler(c) + + if identity != nil { + c.Set(mw.IdentityKey, identity) + } + + c.Next() + } +} diff --git a/cmd/dashboard/controller/member_api.go b/cmd/dashboard/controller/member_api.go index 739db34..e869096 100644 --- a/cmd/dashboard/controller/member_api.go +++ b/cmd/dashboard/controller/member_api.go @@ -13,7 +13,6 @@ import ( "github.com/gin-gonic/gin" "github.com/jinzhu/copier" "golang.org/x/net/idna" - "gorm.io/gorm" "github.com/naiba/nezha/model" "github.com/naiba/nezha/pkg/utils" @@ -44,7 +43,6 @@ func (ma *memberAPI) serve() { mr.GET("/cron/:id/manual", ma.manualTrigger) mr.POST("/force-update", ma.forceUpdate) mr.POST("/batch-update-server-group", ma.batchUpdateServerGroup) - mr.POST("/batch-delete-server", ma.batchDeleteServer) mr.POST("/notification", ma.addOrEditNotification) mr.POST("/ddns", ma.addOrEditDDNS) mr.POST("/nat", ma.addOrEditNAT) @@ -188,25 +186,6 @@ func (ma *memberAPI) delete(c *gin.Context) { var err error switch c.Param("model") { - case "server": - err := singleton.DB.Transaction(func(tx *gorm.DB) error { - err = singleton.DB.Unscoped().Delete(&model.Server{}, "id = ?", id).Error - if err != nil { - return err - } - err = singleton.DB.Unscoped().Delete(&model.MonitorHistory{}, "server_id = ?", id).Error - if err != nil { - return err - } - return nil - }) - if err == nil { - // 删除服务器 - singleton.ServerLock.Lock() - onServerDelete(id) - singleton.ServerLock.Unlock() - singleton.ReSortServer() - } case "notification": err = singleton.DB.Unscoped().Delete(&model.Notification{}, "id = ?", id).Error if err == nil { @@ -346,10 +325,8 @@ func (ma *memberAPI) addOrEditServer(c *gin.Context) { err := c.ShouldBindJSON(&sf) if err == nil { s.Name = sf.Name - s.Secret = sf.Secret s.DisplayIndex = sf.DisplayIndex s.ID = sf.ID - s.Tag = sf.Tag s.Note = sf.Note s.PublicNote = sf.PublicNote s.HideForGuest = sf.HideForGuest == "on" @@ -358,7 +335,7 @@ func (ma *memberAPI) addOrEditServer(c *gin.Context) { err = utils.Json.Unmarshal([]byte(sf.DDNSProfilesRaw), &s.DDNSProfiles) if err == nil { if s.ID == 0 { - s.Secret, err = utils.GenerateRandomString(18) + _, err = utils.GenerateRandomString(18) if err == nil { err = singleton.DB.Create(&s).Error } @@ -378,34 +355,6 @@ func (ma *memberAPI) addOrEditServer(c *gin.Context) { if isEdit { singleton.ServerLock.Lock() s.CopyFromRunningServer(singleton.ServerList[s.ID]) - // 如果修改了 Secret - if s.Secret != singleton.ServerList[s.ID].Secret { - // 删除旧 Secret-ID 绑定关系 - singleton.SecretToID[s.Secret] = s.ID - // 设置新的 Secret-ID 绑定关系 - delete(singleton.SecretToID, singleton.ServerList[s.ID].Secret) - } - // 如果修改了Tag - oldTag := singleton.ServerList[s.ID].Tag - newTag := s.Tag - if newTag != oldTag { - index := -1 - for i := 0; i < len(singleton.ServerTagToIDList[oldTag]); i++ { - if singleton.ServerTagToIDList[oldTag][i] == s.ID { - index = i - break - } - } - if index > -1 { - // 删除旧 Tag-ID 绑定关系 - singleton.ServerTagToIDList[oldTag] = append(singleton.ServerTagToIDList[oldTag][:index], singleton.ServerTagToIDList[oldTag][index+1:]...) - if len(singleton.ServerTagToIDList[oldTag]) == 0 { - delete(singleton.ServerTagToIDList, oldTag) - } - } - // 设置新的 Tag-ID 绑定关系 - singleton.ServerTagToIDList[newTag] = append(singleton.ServerTagToIDList[newTag], s.ID) - } singleton.ServerList[s.ID] = &s singleton.ServerLock.Unlock() } else { @@ -413,9 +362,7 @@ func (ma *memberAPI) addOrEditServer(c *gin.Context) { s.State = &model.HostState{} s.TaskCloseLock = new(sync.Mutex) singleton.ServerLock.Lock() - singleton.SecretToID[s.Secret] = s.ID singleton.ServerList[s.ID] = &s - singleton.ServerTagToIDList[s.Tag] = append(singleton.ServerTagToIDList[s.Tag], s.ID) singleton.ServerLock.Unlock() } singleton.ReSortServer() @@ -636,28 +583,28 @@ func (ma *memberAPI) batchUpdateServerGroup(c *gin.Context) { serverId := req.Servers[i] var s model.Server copier.Copy(&s, singleton.ServerList[serverId]) - s.Tag = req.Group - // 如果修改了Ta - oldTag := singleton.ServerList[serverId].Tag - newTag := s.Tag - if newTag != oldTag { - index := -1 - for i := 0; i < len(singleton.ServerTagToIDList[oldTag]); i++ { - if singleton.ServerTagToIDList[oldTag][i] == s.ID { - index = i - break - } - } - if index > -1 { - // 删除旧 Tag-ID 绑定关系 - singleton.ServerTagToIDList[oldTag] = append(singleton.ServerTagToIDList[oldTag][:index], singleton.ServerTagToIDList[oldTag][index+1:]...) - if len(singleton.ServerTagToIDList[oldTag]) == 0 { - delete(singleton.ServerTagToIDList, oldTag) - } - } - // 设置新的 Tag-ID 绑定关系 - singleton.ServerTagToIDList[newTag] = append(singleton.ServerTagToIDList[newTag], s.ID) - } + // s.Tag = req.Group + // // 如果修改了Ta + // oldTag := singleton.ServerList[serverId].Tag + // newTag := s.Tag + // if newTag != oldTag { + // index := -1 + // for i := 0; i < len(singleton.ServerTagToIDList[oldTag]); i++ { + // if singleton.ServerTagToIDList[oldTag][i] == s.ID { + // index = i + // break + // } + // } + // if index > -1 { + // // 删除旧 Tag-ID 绑定关系 + // singleton.ServerTagToIDList[oldTag] = append(singleton.ServerTagToIDList[oldTag][:index], singleton.ServerTagToIDList[oldTag][index+1:]...) + // if len(singleton.ServerTagToIDList[oldTag]) == 0 { + // delete(singleton.ServerTagToIDList, oldTag) + // } + // } + // // 设置新的 Tag-ID 绑定关系 + // singleton.ServerTagToIDList[newTag] = append(singleton.ServerTagToIDList[newTag], s.ID) + // } singleton.ServerList[s.ID] = &s } @@ -1067,63 +1014,3 @@ func (ma *memberAPI) updateSetting(c *gin.Context) { Code: http.StatusOK, }) } - -func (ma *memberAPI) batchDeleteServer(c *gin.Context) { - var servers []uint64 - if err := c.ShouldBindJSON(&servers); err != nil { - c.JSON(http.StatusOK, model.Response{ - Code: http.StatusBadRequest, - Message: err.Error(), - }) - return - } - if err := singleton.DB.Unscoped().Delete(&model.Server{}, "id in (?)", servers).Error; err != nil { - c.JSON(http.StatusOK, model.Response{ - Code: http.StatusBadRequest, - Message: err.Error(), - }) - return - } - singleton.ServerLock.Lock() - for i := 0; i < len(servers); i++ { - id := servers[i] - onServerDelete(id) - } - singleton.ServerLock.Unlock() - singleton.ReSortServer() - c.JSON(http.StatusOK, model.Response{ - Code: http.StatusOK, - }) -} - -func onServerDelete(id uint64) { - tag := singleton.ServerList[id].Tag - delete(singleton.SecretToID, singleton.ServerList[id].Secret) - delete(singleton.ServerList, id) - index := -1 - for i := 0; i < len(singleton.ServerTagToIDList[tag]); i++ { - if singleton.ServerTagToIDList[tag][i] == id { - index = i - break - } - } - if index > -1 { - - singleton.ServerTagToIDList[tag] = append(singleton.ServerTagToIDList[tag][:index], singleton.ServerTagToIDList[tag][index+1:]...) - if len(singleton.ServerTagToIDList[tag]) == 0 { - delete(singleton.ServerTagToIDList, tag) - } - } - - singleton.AlertsLock.Lock() - for i := 0; i < len(singleton.Alerts); i++ { - if singleton.AlertsCycleTransferStatsStore[singleton.Alerts[i].ID] != nil { - delete(singleton.AlertsCycleTransferStatsStore[singleton.Alerts[i].ID].ServerName, id) - delete(singleton.AlertsCycleTransferStatsStore[singleton.Alerts[i].ID].Transfer, id) - delete(singleton.AlertsCycleTransferStatsStore[singleton.Alerts[i].ID].NextUpdate, id) - } - } - singleton.AlertsLock.Unlock() - - singleton.DB.Unscoped().Delete(&model.Transfer{}, "server_id = ?", id) -} diff --git a/cmd/dashboard/controller/server.go b/cmd/dashboard/controller/server.go new file mode 100644 index 0000000..daab265 --- /dev/null +++ b/cmd/dashboard/controller/server.go @@ -0,0 +1,130 @@ +package controller + +import ( + "net/http" + "strconv" + + "github.com/gin-gonic/gin" + "github.com/naiba/nezha/model" + "github.com/naiba/nezha/pkg/utils" + "github.com/naiba/nezha/service/singleton" +) + +// Edit server +// @Summary Edit server +// @Security BearerAuth +// @Schemes +// @Description Edit server +// @Tags auth required +// @Produce json +// @Success 200 {object} model.CommonResponse[any] +// @Router /server/{id} [patch] +func editServer(c *gin.Context) { + idStr := c.Param("id") + id, err := strconv.ParseUint(idStr, 10, 64) + if err != nil { + c.JSON(http.StatusOK, model.CommonResponse[interface{}]{ + Success: false, + Error: err.Error(), + }) + return + } + var sf model.EditServer + var s model.Server + if err := c.ShouldBindJSON(&sf); err != nil { + c.JSON(http.StatusOK, model.CommonResponse[interface{}]{ + Success: false, + Error: err.Error(), + }) + return + } + s.Name = sf.Name + s.DisplayIndex = sf.DisplayIndex + s.ID = id + s.Note = sf.Note + s.PublicNote = sf.PublicNote + s.HideForGuest = sf.HideForGuest + s.EnableDDNS = sf.EnableDDNS + s.DDNSProfiles = sf.DDNSProfiles + ddnsProfilesRaw, err := utils.Json.Marshal(s.DDNSProfiles) + if err != nil { + c.JSON(http.StatusOK, model.CommonResponse[interface{}]{ + Success: false, + Error: err.Error(), + }) + return + } + s.DDNSProfilesRaw = string(ddnsProfilesRaw) + + if err := singleton.DB.Save(&s).Error; err != nil { + c.JSON(http.StatusOK, model.CommonResponse[interface{}]{ + Success: false, + Error: err.Error(), + }) + return + } + + singleton.ServerLock.Lock() + s.CopyFromRunningServer(singleton.ServerList[s.ID]) + singleton.ServerList[s.ID] = &s + singleton.ServerLock.Unlock() + singleton.ReSortServer() + c.JSON(http.StatusOK, model.Response{ + Code: http.StatusOK, + }) +} + +// Batch delete server +// @Summary Batch delete server +// @Security BearerAuth +// @Schemes +// @Description Batch delete server +// @Tags auth required +// @Accept json +// @param request body []uint64 true "id list" +// @Produce json +// @Success 200 {object} model.CommonResponse[any] +// @Router /batch-delete/server [post] +func batchDeleteServer(c *gin.Context) { + var servers []uint64 + if err := c.ShouldBindJSON(&servers); err != nil { + c.JSON(http.StatusOK, model.CommonResponse[interface{}]{ + Success: false, + Error: err.Error(), + }) + return + } + + if err := singleton.DB.Unscoped().Delete(&model.Server{}, "id in (?)", servers).Error; err != nil { + c.JSON(http.StatusOK, model.CommonResponse[interface{}]{ + Success: false, + Error: err.Error(), + }) + return + } + + singleton.ServerLock.Lock() + for i := 0; i < len(servers); i++ { + id := servers[i] + delete(singleton.ServerList, id) + + singleton.AlertsLock.Lock() + for i := 0; i < len(singleton.Alerts); i++ { + if singleton.AlertsCycleTransferStatsStore[singleton.Alerts[i].ID] != nil { + delete(singleton.AlertsCycleTransferStatsStore[singleton.Alerts[i].ID].ServerName, id) + delete(singleton.AlertsCycleTransferStatsStore[singleton.Alerts[i].ID].Transfer, id) + delete(singleton.AlertsCycleTransferStatsStore[singleton.Alerts[i].ID].NextUpdate, id) + } + } + singleton.AlertsLock.Unlock() + + singleton.DB.Unscoped().Delete(&model.Transfer{}, "server_id = ?", id) + } + singleton.ServerLock.Unlock() + + singleton.ReSortServer() + + c.JSON(http.StatusOK, model.CommonResponse[interface{}]{ + Success: true, + }) +} diff --git a/cmd/dashboard/controller/server_group.go b/cmd/dashboard/controller/server_group.go new file mode 100644 index 0000000..5950b56 --- /dev/null +++ b/cmd/dashboard/controller/server_group.go @@ -0,0 +1,36 @@ +package controller + +import ( + "log" + "net/http" + + "github.com/gin-gonic/gin" + "github.com/naiba/nezha/model" + "github.com/naiba/nezha/pkg/utils" + "github.com/naiba/nezha/service/singleton" +) + +// List server group +// @Summary List server group +// @Schemes +// @Description List server group +// @Security BearerAuth +// @Tags common +// @Produce json +// @Success 200 {object} model.CommonResponse[[]model.ServerGroup] +// @Router /server-group [get] +func listServerGroup(c *gin.Context) { + authorizedUser, has := c.Get(model.CtxKeyAuthorizedUser) + log.Println("bingo test", authorizedUser, has) + var sg []model.ServerGroup + err := singleton.DB.Find(&sg).Error + c.JSON(http.StatusOK, model.CommonResponse[[]model.ServerGroup]{ + Success: err == nil, + Data: sg, + Error: utils.IfOrFn[string](err == nil, func() string { + return err.Error() + }, func() string { + return "" + }), + }) +} diff --git a/cmd/dashboard/controller/ws.go b/cmd/dashboard/controller/ws.go new file mode 100644 index 0000000..262ce74 --- /dev/null +++ b/cmd/dashboard/controller/ws.go @@ -0,0 +1,96 @@ +package controller + +import ( + "fmt" + "net/http" + "time" + + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "github.com/naiba/nezha/model" + "github.com/naiba/nezha/pkg/utils" + "github.com/naiba/nezha/service/singleton" + "golang.org/x/sync/singleflight" +) + +var upgrader = websocket.Upgrader{ + ReadBufferSize: 32768, + WriteBufferSize: 32768, +} + +// Websocket server stream +// @Summary Websocket server stream +// @tags common +// @Schemes +// @Description Websocket server stream +// @security BearerAuth +// @Produce json +// @Success 200 {object} model.StreamServerData +// @Router /ws/server [get] +func serverStream(c *gin.Context) { + conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) + if err != nil { + c.JSON(http.StatusOK, model.CommonResponse[interface{}]{ + Success: false, + Error: err.Error(), + }) + return + } + defer conn.Close() + count := 0 + for { + stat, err := getServerStat(c, count == 0) + if err != nil { + continue + } + if err := conn.WriteMessage(websocket.TextMessage, stat); err != nil { + break + } + count += 1 + if count%4 == 0 { + err = conn.WriteMessage(websocket.PingMessage, []byte{}) + if err != nil { + break + } + } + time.Sleep(time.Second * 2) + } +} + +var requestGroup singleflight.Group + +func getServerStat(c *gin.Context, withPublicNote bool) ([]byte, error) { + _, isMember := c.Get(model.CtxKeyAuthorizedUser) + authorized := isMember // TODO || isViewPasswordVerfied + v, err, _ := requestGroup.Do(fmt.Sprintf("serverStats::%t", authorized), func() (interface{}, error) { + singleton.SortedServerLock.RLock() + defer singleton.SortedServerLock.RUnlock() + + var serverList []*model.Server + if authorized { + serverList = singleton.SortedServerList + } else { + serverList = singleton.SortedServerListForGuest + } + + var servers []model.StreamServer + for i := 0; i < len(serverList); i++ { + server := serverList[i] + servers = append(servers, model.StreamServer{ + ID: server.ID, + Name: server.Name, + PublicNote: utils.IfOr(withPublicNote, server.PublicNote, ""), + DisplayIndex: server.DisplayIndex, + Host: server.Host, + State: server.State, + LastActive: server.LastActive, + }) + } + + return utils.Json.Marshal(model.StreamServerData{ + Now: time.Now().Unix() * 1000, + Servers: servers, + }) + }) + return v.([]byte), err +} diff --git a/model/config.go b/model/config.go index 855f898..c3f239b 100644 --- a/model/config.go +++ b/model/config.go @@ -80,13 +80,14 @@ func (c *AgentConfig) Save() error { type Config struct { Debug bool // debug模式开关 - Language string // 系统语言,默认 zh-CN - SiteName string - SecretKey string - ListenPort uint - InstallHost string - TLS bool - Location string // 时区,默认为 Asia/Shanghai + Language string // 系统语言,默认 zh-CN + SiteName string + JWTSecretKey string + AgentSecretKey string + ListenPort uint + InstallHost string + TLS bool + Location string // 时区,默认为 Asia/Shanghai EnablePlainIPInNotification bool // 通知信息IP不打码 @@ -132,8 +133,18 @@ func (c *Config) Read(path string) error { if c.AvgPingCount == 0 { c.AvgPingCount = 2 } - if c.SecretKey == "" { - c.SecretKey, err = utils.GenerateRandomString(1024) + if c.JWTSecretKey == "" { + c.JWTSecretKey, err = utils.GenerateRandomString(1024) + if err != nil { + return err + } + if err = c.Save(); err != nil { + return err + } + } + + if c.AgentSecretKey == "" { + c.AgentSecretKey, err = utils.GenerateRandomString(32) if err != nil { return err } diff --git a/model/notification_test.go b/model/notification_test.go index 7a4302a..0a5112c 100644 --- a/model/notification_test.go +++ b/model/notification_test.go @@ -37,8 +37,6 @@ func execCase(t *testing.T, item testSt) { server := Server{ Common: Common{}, Name: "ServerName", - Tag: "", - Secret: "", Note: "", DisplayIndex: 0, Host: &Host{ diff --git a/model/server.go b/model/server.go index 2427ae3..090f186 100644 --- a/model/server.go +++ b/model/server.go @@ -1,8 +1,6 @@ package model import ( - "fmt" - "html/template" "log" "sync" "time" @@ -15,21 +13,21 @@ import ( type Server struct { Common - Name string - Tag string // 分组名 - Secret string `gorm:"uniqueIndex" json:"-"` - Note string `json:"-"` // 管理员可见备注 - PublicNote string `json:"PublicNote,omitempty"` // 公开备注 - DisplayIndex int // 展示排序,越大越靠前 - HideForGuest bool // 对游客隐藏 - EnableDDNS bool // 启用DDNS - DDNSProfiles []uint64 `gorm:"-" json:"-"` // DDNS配置 + Name string `json:"name,omitempty"` + UUID string `json:"uuid,omitempty" gorm:"unique"` + Note string `json:"note,omitempty"` // 管理员可见备注 + PublicNote string `json:"public_note,omitempty"` // 公开备注 + DisplayIndex int `json:"display_index,omitempty"` // 展示排序,越大越靠前 + HideForGuest bool `json:"hide_for_guest,omitempty"` // 对游客隐藏 + EnableDDNS bool `json:"enable_ddns,omitempty"` // 启用DDNS DDNSProfilesRaw string `gorm:"default:'[]';column:ddns_profiles_raw" json:"-"` - Host *Host `gorm:"-"` - State *HostState `gorm:"-"` - LastActive time.Time `gorm:"-"` + DDNSProfiles []uint64 `gorm:"-" json:"ddns_profiles,omitempty"` // DDNS配置 + + Host *Host `gorm:"-" json:"host,omitempty"` + State *HostState `gorm:"-" json:"state,omitempty"` + LastActive time.Time `gorm:"-" json:"last_active,omitempty"` TaskClose chan error `gorm:"-" json:"-"` TaskCloseLock *sync.Mutex `gorm:"-" json:"-"` @@ -59,20 +57,3 @@ func (s *Server) AfterFind(tx *gorm.DB) error { } return nil } - -func boolToString(b bool) string { - if b { - return "true" - } - return "false" -} - -func (s Server) MarshalForDashboard() template.JS { - name, _ := utils.Json.Marshal(s.Name) - tag, _ := utils.Json.Marshal(s.Tag) - note, _ := utils.Json.Marshal(s.Note) - secret, _ := utils.Json.Marshal(s.Secret) - ddnsProfilesRaw, _ := utils.Json.Marshal(s.DDNSProfilesRaw) - publicNote, _ := utils.Json.Marshal(s.PublicNote) - return template.JS(fmt.Sprintf(`{"ID":%d,"Name":%s,"Secret":%s,"DisplayIndex":%d,"Tag":%s,"Note":%s,"HideForGuest": %s,"EnableDDNS": %s,"DDNSProfilesRaw": %s,"PublicNote": %s}`, s.ID, name, secret, s.DisplayIndex, tag, note, boolToString(s.HideForGuest), boolToString(s.EnableDDNS), ddnsProfilesRaw, publicNote)) -} diff --git a/model/server_api.go b/model/server_api.go new file mode 100644 index 0000000..e0b8fc9 --- /dev/null +++ b/model/server_api.go @@ -0,0 +1,29 @@ +package model + +import "time" + +type StreamServer struct { + ID uint64 `json:"id,omitempty"` + Name string `json:"name,omitempty"` + PublicNote string `json:"public_note,omitempty"` // 公开备注,只第一个数据包有值 + DisplayIndex int `json:"display_index,omitempty"` // 展示排序,越大越靠前 + + Host *Host `json:"host,omitempty"` + State *HostState `json:"state,omitempty"` + LastActive time.Time `json:"last_active,omitempty"` +} + +type StreamServerData struct { + Now int64 `json:"now,omitempty"` + Servers []StreamServer `json:"servers,omitempty"` +} + +type EditServer struct { + Name string `json:"name,omitempty"` + Note string `json:"note,omitempty"` // 管理员可见备注 + PublicNote string `json:"public_note,omitempty"` // 公开备注 + DisplayIndex int `json:"display_index,omitempty"` // 展示排序,越大越靠前 + HideForGuest bool `json:"hide_for_guest,omitempty"` // 对游客隐藏 + EnableDDNS bool `json:"enable_ddns,omitempty"` // 启用DDNS + DDNSProfiles []uint64 `gorm:"-" json:"ddns_profiles,omitempty"` // DDNS配置 +} diff --git a/model/server_group.go b/model/server_group.go index db558e9..98f1400 100644 --- a/model/server_group.go +++ b/model/server_group.go @@ -2,5 +2,6 @@ package model type ServerGroup struct { Common + Name string `json:"name"` } diff --git a/model/server_test.go b/model/server_test.go deleted file mode 100644 index 7a4d91f..0000000 --- a/model/server_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package model - -import ( - "testing" - - "github.com/naiba/nezha/pkg/utils" -) - -func TestServerMarshal(t *testing.T) { - patterns := []string{ - "asd > asd", - "asd \" asd", - "asd } asd", - } - - for i := 0; i < len(patterns); i++ { - server := Server{ - Name: patterns[i], - Tag: patterns[i], - } - serverStr := string(server.MarshalForDashboard()) - var serverRestore Server - if utils.Json.Unmarshal([]byte(serverStr), &serverRestore) != nil { - t.Fatalf("Error: %s", serverStr) - } - if server.Name != serverRestore.Name { - t.Fatalf("Expected %s, but got %s", server.Name, serverRestore.Name) - } - } -} diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index 5c45c28..c47a4c9 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -90,3 +90,17 @@ func Uint64SubInt64(a uint64, b int64) uint64 { } return a - uint64(b) } + +func IfOr[T any](a bool, x, y T) T { + if a { + return x + } + return y +} + +func IfOrFn[T any](a bool, x, y func() T) T { + if a { + return x() + } + return y() +} diff --git a/service/rpc/auth.go b/service/rpc/auth.go index a189916..921bd12 100644 --- a/service/rpc/auth.go +++ b/service/rpc/auth.go @@ -2,20 +2,23 @@ package rpc import ( "context" + "sync" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" + "github.com/naiba/nezha/model" "github.com/naiba/nezha/service/singleton" ) type authHandler struct { ClientSecret string + ClientUUID string } func (a *authHandler) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { - return map[string]string{"client_secret": a.ClientSecret}, nil + return map[string]string{"client_secret": a.ClientSecret, "client_uuid": a.ClientUUID}, nil } func (a *authHandler) RequireTransportSecurity() bool { @@ -33,15 +36,29 @@ func (a *authHandler) Check(ctx context.Context) (uint64, error) { clientSecret = value[0] } + if clientSecret != singleton.Conf.AgentSecretKey { + return 0, status.Errorf(codes.Unauthenticated, "客户端认证失败") + } + + var clientUUID string + if value, ok := md["client_uuid"]; ok { + clientUUID = value[0] + } + singleton.ServerLock.RLock() defer singleton.ServerLock.RUnlock() - clientID, hasID := singleton.SecretToID[clientSecret] + clientID, hasID := singleton.ServerUUIDToID[clientUUID] if !hasID { - return 0, status.Errorf(codes.Unauthenticated, "客户端认证失败") - } - _, hasServer := singleton.ServerList[clientID] - if !hasServer { - return 0, status.Errorf(codes.Unauthenticated, "客户端认证失败") + s := model.Server{UUID: clientUUID} + if err := singleton.DB.Create(&s).Error; err != nil { + return 0, status.Errorf(codes.Unauthenticated, err.Error()) + } + s.Host = &model.Host{} + s.State = &model.HostState{} + s.TaskCloseLock = new(sync.Mutex) + singleton.ServerList[s.ID] = &s + singleton.ServerUUIDToID[clientUUID] = s.ID } + return clientID, nil } diff --git a/service/singleton/api.go b/service/singleton/api.go index f86cb72..0f813ac 100644 --- a/service/singleton/api.go +++ b/service/singleton/api.go @@ -103,9 +103,9 @@ func (s *ServerAPIService) GetStatusByIDList(idList []uint64) *ServerStatusRespo } ipv4, ipv6, validIP := utils.SplitIPAddr(server.Host.IP) info := CommonServerInfo{ - ID: server.ID, - Name: server.Name, - Tag: server.Tag, + ID: server.ID, + Name: server.Name, + // Tag: server.Tag, LastActive: server.LastActive.Unix(), IPV4: ipv4, IPV6: ipv6, @@ -125,9 +125,9 @@ func (s *ServerAPIService) GetStatusByIDList(idList []uint64) *ServerStatusRespo } // GetStatusByTag 获取传入分组的所有服务器状态信息 -func (s *ServerAPIService) GetStatusByTag(tag string) *ServerStatusResponse { - return s.GetStatusByIDList(ServerTagToIDList[tag]) -} +// func (s *ServerAPIService) GetStatusByTag(tag string) *ServerStatusResponse { +// return s.GetStatusByIDList(ServerTagToIDList[tag]) +// } // GetAllStatus 获取所有服务器状态信息 func (s *ServerAPIService) GetAllStatus() *ServerStatusResponse { @@ -143,9 +143,9 @@ func (s *ServerAPIService) GetAllStatus() *ServerStatusResponse { } ipv4, ipv6, validIP := utils.SplitIPAddr(host.IP) info := CommonServerInfo{ - ID: v.ID, - Name: v.Name, - Tag: v.Tag, + ID: v.ID, + Name: v.Name, + // Tag: v.Tag, LastActive: v.LastActive.Unix(), IPV4: ipv4, IPV6: ipv6, @@ -173,23 +173,23 @@ func (s *ServerAPIService) GetListByTag(tag string) *ServerInfoResponse { ServerLock.RLock() defer ServerLock.RUnlock() - for _, v := range ServerTagToIDList[tag] { - host := ServerList[v].Host - if host == nil { - continue - } - ipv4, ipv6, validIP := utils.SplitIPAddr(host.IP) - info := &CommonServerInfo{ - ID: v, - Name: ServerList[v].Name, - Tag: ServerList[v].Tag, - LastActive: ServerList[v].LastActive.Unix(), - IPV4: ipv4, - IPV6: ipv6, - ValidIP: validIP, - } - res.Result = append(res.Result, info) - } + // for _, v := range ServerTagToIDList[tag] { + // host := ServerList[v].Host + // if host == nil { + // continue + // } + // ipv4, ipv6, validIP := utils.SplitIPAddr(host.IP) + // info := &CommonServerInfo{ + // ID: v, + // Name: ServerList[v].Name, + // Tag: ServerList[v].Tag, + // LastActive: ServerList[v].LastActive.Unix(), + // IPV4: ipv4, + // IPV6: ipv6, + // ValidIP: validIP, + // } + // res.Result = append(res.Result, info) + // } res.CommonResponse = CommonResponse{ Code: 0, Message: "success", @@ -211,9 +211,9 @@ func (s *ServerAPIService) GetAllList() *ServerInfoResponse { } ipv4, ipv6, validIP := utils.SplitIPAddr(host.IP) info := &CommonServerInfo{ - ID: v.ID, - Name: v.Name, - Tag: v.Tag, + ID: v.ID, + Name: v.Name, + // Tag: v.Tag, LastActive: v.LastActive.Unix(), IPV4: ipv4, IPV6: ipv6, diff --git a/service/singleton/server.go b/service/singleton/server.go index 21b58a3..3404271 100644 --- a/service/singleton/server.go +++ b/service/singleton/server.go @@ -8,21 +8,18 @@ import ( ) var ( - ServerList map[uint64]*model.Server // [ServerID] -> model.Server - SecretToID map[string]uint64 // [ServerSecret] -> ServerID - ServerTagToIDList map[string][]uint64 // [ServerTag] -> ServerID - ServerLock sync.RWMutex + ServerList map[uint64]*model.Server // [ServerID] -> model.Server + ServerUUIDToID map[string]uint64 // [ServerUUID] -> ServerID + ServerLock sync.RWMutex SortedServerList []*model.Server // 用于存储服务器列表的 slice,按照服务器 ID 排序 SortedServerListForGuest []*model.Server SortedServerLock sync.RWMutex ) -// InitServer 初始化 ServerID <-> Secret 的映射 func InitServer() { ServerList = make(map[uint64]*model.Server) - SecretToID = make(map[string]uint64) - ServerTagToIDList = make(map[string][]uint64) + ServerUUIDToID = make(map[string]uint64) } // loadServers 加载服务器列表并根据ID排序 @@ -36,8 +33,7 @@ func loadServers() { innerS.State = &model.HostState{} innerS.TaskCloseLock = new(sync.Mutex) ServerList[innerS.ID] = &innerS - SecretToID[innerS.Secret] = innerS.ID - ServerTagToIDList[innerS.Tag] = append(ServerTagToIDList[innerS.Tag], innerS.ID) + ServerUUIDToID[innerS.UUID] = innerS.ID } ReSortServer() }