From 653d0cf2e96b9ae9409a3b12a5a1c94bb6ab92e9 Mon Sep 17 00:00:00 2001 From: UUBulb <35923940+uubulb@users.noreply.github.com> Date: Sun, 22 Dec 2024 00:05:41 +0800 Subject: [PATCH] feat: user roles (#852) * [WIP] feat: user roles * update * update * admin handler * update * feat: user-specific connection secret * simplify some logics * cleanup * update waf * update user api error handling * update waf api * fix codeql * update waf table * fix several problems * add pagination for waf api * update permission checks * switch to runtime check * 1 * cover? * some changes --- cmd/dashboard/controller/alertrule.go | 36 ++++- cmd/dashboard/controller/controller.go | 113 +++++++++++---- cmd/dashboard/controller/cron.go | 55 ++++++- cmd/dashboard/controller/ddns.go | 16 +++ cmd/dashboard/controller/fm.go | 19 ++- cmd/dashboard/controller/jwt.go | 11 +- cmd/dashboard/controller/nat.go | 37 ++++- cmd/dashboard/controller/notification.go | 16 ++- .../controller/notification_group.go | 55 ++++++- cmd/dashboard/controller/server.go | 29 ++++ cmd/dashboard/controller/server_group.go | 55 ++++++- cmd/dashboard/controller/service.go | 43 ++++++ cmd/dashboard/controller/terminal.go | 19 ++- cmd/dashboard/controller/user.go | 5 +- cmd/dashboard/controller/waf.go | 34 ++++- cmd/dashboard/rpc/rpc.go | 26 +++- go.mod | 4 +- model/alertrule.go | 13 +- model/api.go | 17 +++ model/common.go | 43 ++++++ model/user.go | 36 ++++- model/user_group.go | 6 - model/user_group_user.go | 7 - model/waf.go | 62 +++++--- pkg/utils/utils.go | 7 + service/rpc/auth.go | 23 ++- service/singleton/alertsentinel.go | 10 +- service/singleton/crontask.go | 6 +- service/singleton/ddns.go | 8 +- service/singleton/nat.go | 8 +- service/singleton/notification.go | 14 +- service/singleton/server.go | 34 ++--- service/singleton/servicesentinel.go | 14 +- service/singleton/singleton.go | 5 +- service/singleton/user.go | 135 ++++++++++++++++++ 35 files changed, 841 insertions(+), 180 deletions(-) delete mode 100644 model/user_group.go delete mode 100644 model/user_group_user.go create mode 100644 service/singleton/user.go diff --git a/cmd/dashboard/controller/alertrule.go b/cmd/dashboard/controller/alertrule.go index 708f9cf..2b22734 100644 --- a/cmd/dashboard/controller/alertrule.go +++ b/cmd/dashboard/controller/alertrule.go @@ -50,6 +50,9 @@ func createAlertRule(c *gin.Context) (uint64, error) { return 0, err } + uid := getUid(c) + + r.UserID = uid r.Name = arf.Name r.Rules = arf.Rules r.FailTriggerTasks = arf.FailTriggerTasks @@ -59,7 +62,7 @@ func createAlertRule(c *gin.Context) (uint64, error) { r.TriggerMode = arf.TriggerMode r.Enable = &enable - if err := validateRule(&r); err != nil { + if err := validateRule(c, &r); err != nil { return 0, err } @@ -100,6 +103,10 @@ func updateAlertRule(c *gin.Context) (any, error) { return nil, singleton.Localizer.ErrorT("alert id %d does not exist", id) } + if !r.HasPermission(c) { + return nil, singleton.Localizer.ErrorT("permission denied") + } + r.Name = arf.Name r.Rules = arf.Rules r.FailTriggerTasks = arf.FailTriggerTasks @@ -109,7 +116,7 @@ func updateAlertRule(c *gin.Context) (any, error) { r.TriggerMode = arf.TriggerMode r.Enable = &enable - if err := validateRule(&r); err != nil { + if err := validateRule(c, &r); err != nil { return 0, err } @@ -134,11 +141,21 @@ func updateAlertRule(c *gin.Context) (any, error) { // @Router /batch-delete/alert-rule [post] func batchDeleteAlertRule(c *gin.Context) (any, error) { var ar []uint64 - if err := c.ShouldBindJSON(&ar); err != nil { return nil, err } + var ars []model.AlertRule + if err := singleton.DB.Where("id in (?)", ar).Find(&ars).Error; err != nil { + return nil, err + } + + for _, a := range ars { + if !a.HasPermission(c) { + return nil, singleton.Localizer.ErrorT("permission denied") + } + } + if err := singleton.DB.Unscoped().Delete(&model.AlertRule{}, "id in (?)", ar).Error; err != nil { return nil, newGormError("%v", err) } @@ -147,9 +164,20 @@ func batchDeleteAlertRule(c *gin.Context) (any, error) { return nil, nil } -func validateRule(r *model.AlertRule) error { +func validateRule(c *gin.Context, r *model.AlertRule) error { if len(r.Rules) > 0 { for _, rule := range r.Rules { + singleton.ServerLock.RLock() + for s := range rule.Ignore { + if server, ok := singleton.ServerList[s]; ok { + if !server.HasPermission(c) { + singleton.ServerLock.RUnlock() + return singleton.Localizer.ErrorT("permission denied") + } + } + } + singleton.ServerLock.RUnlock() + if !rule.IsTransferDurationRule() { if rule.Duration < 3 { return singleton.Localizer.ErrorT("duration need to be at least 3") diff --git a/cmd/dashboard/controller/controller.go b/cmd/dashboard/controller/controller.go index 9f0f4af..898b4e6 100644 --- a/cmd/dashboard/controller/controller.go +++ b/cmd/dashboard/controller/controller.go @@ -9,6 +9,7 @@ import ( "net/http" "os" "path" + "slices" "strings" jwt "github.com/appleboy/gin-jwt/v2" @@ -78,11 +79,11 @@ func routers(r *gin.Engine, frontendDist fs.FS) { auth.GET("/profile", commonHandler(getProfile)) auth.POST("/profile", commonHandler(updateProfile)) - auth.GET("/user", commonHandler(listUser)) - auth.POST("/user", commonHandler(createUser)) - auth.POST("/batch-delete/user", commonHandler(batchDeleteUser)) + auth.GET("/user", adminHandler(listUser)) + auth.POST("/user", adminHandler(createUser)) + auth.POST("/batch-delete/user", adminHandler(batchDeleteUser)) - auth.GET("/service/list", commonHandler(listService)) + auth.GET("/service/list", listHandler(listService)) auth.POST("/service", commonHandler(createService)) auth.PATCH("/service/:id", commonHandler(updateService)) auth.POST("/batch-delete/service", commonHandler(batchDeleteService)) @@ -96,42 +97,42 @@ func routers(r *gin.Engine, frontendDist fs.FS) { auth.PATCH("/notification-group/:id", commonHandler(updateNotificationGroup)) auth.POST("/batch-delete/notification-group", commonHandler(batchDeleteNotificationGroup)) - auth.GET("/server", commonHandler(listServer)) + auth.GET("/server", listHandler(listServer)) auth.PATCH("/server/:id", commonHandler(updateServer)) auth.POST("/batch-delete/server", commonHandler(batchDeleteServer)) auth.POST("/force-update/server", commonHandler(forceUpdateServer)) - auth.GET("/notification", commonHandler(listNotification)) + auth.GET("/notification", listHandler(listNotification)) auth.POST("/notification", commonHandler(createNotification)) auth.PATCH("/notification/:id", commonHandler(updateNotification)) auth.POST("/batch-delete/notification", commonHandler(batchDeleteNotification)) - auth.GET("/alert-rule", commonHandler(listAlertRule)) + auth.GET("/alert-rule", listHandler(listAlertRule)) auth.POST("/alert-rule", commonHandler(createAlertRule)) auth.PATCH("/alert-rule/:id", commonHandler(updateAlertRule)) auth.POST("/batch-delete/alert-rule", commonHandler(batchDeleteAlertRule)) - auth.GET("/cron", commonHandler(listCron)) + auth.GET("/cron", listHandler(listCron)) auth.POST("/cron", commonHandler(createCron)) auth.PATCH("/cron/:id", commonHandler(updateCron)) auth.GET("/cron/:id/manual", commonHandler(manualTriggerCron)) auth.POST("/batch-delete/cron", commonHandler(batchDeleteCron)) - auth.GET("/ddns", commonHandler(listDDNS)) + auth.GET("/ddns", listHandler(listDDNS)) auth.GET("/ddns/providers", commonHandler(listProviders)) auth.POST("/ddns", commonHandler(createDDNS)) auth.PATCH("/ddns/:id", commonHandler(updateDDNS)) auth.POST("/batch-delete/ddns", commonHandler(batchDeleteDDNS)) - auth.GET("/nat", commonHandler(listNAT)) + auth.GET("/nat", listHandler(listNAT)) auth.POST("/nat", commonHandler(createNAT)) auth.PATCH("/nat/:id", commonHandler(updateNAT)) auth.POST("/batch-delete/nat", commonHandler(batchDeleteNAT)) - auth.GET("/waf", commonHandler(listBlockedAddress)) - auth.POST("/batch-delete/waf", commonHandler(batchDeleteBlockedAddress)) + auth.GET("/waf", pCommonHandler(listBlockedAddress)) + auth.POST("/batch-delete/waf", adminHandler(batchDeleteBlockedAddress)) - auth.PATCH("/setting", commonHandler(updateConfig)) + auth.PATCH("/setting", adminHandler(updateConfig)) r.NoRoute(fallbackToFrontend(frontendDist)) } @@ -152,6 +153,7 @@ func newErrorResponse(err error) model.CommonResponse[any] { } type handlerFunc[T any] func(c *gin.Context) (T, error) +type pHandlerFunc[S ~[]E, E any] func(c *gin.Context) (*model.Value[S], error) // There are many error types in gorm, so create a custom type to represent all // gorm errors here instead @@ -189,29 +191,86 @@ func (we *wsError) Error() string { func commonHandler[T any](handler handlerFunc[T]) func(*gin.Context) { return func(c *gin.Context) { - data, err := handler(c) - if err == nil { - c.JSON(http.StatusOK, model.CommonResponse[T]{Success: true, Data: data}) + handle(c, handler) + } +} + +func adminHandler[T any](handler handlerFunc[T]) func(*gin.Context) { + return func(c *gin.Context) { + auth, ok := c.Get(model.CtxKeyAuthorizedUser) + if !ok { + c.JSON(http.StatusOK, newErrorResponse(singleton.Localizer.ErrorT("unauthorized"))) return } - switch err.(type) { - case *gormError: - log.Printf("NEZHA>> gorm error: %v", err) - c.JSON(http.StatusOK, newErrorResponse(singleton.Localizer.ErrorT("database error"))) + + user := *auth.(*model.User) + if user.Role != model.RoleAdmin { + c.JSON(http.StatusOK, newErrorResponse(singleton.Localizer.ErrorT("permission denied"))) return - case *wsError: - // Connection is upgraded to WebSocket, so c.Writer is no longer usable - if msg := err.Error(); msg != "" { - log.Printf("NEZHA>> websocket error: %v", err) - } - return - default: + } + + handle(c, handler) + } +} + +func handle[T any](c *gin.Context, handler handlerFunc[T]) { + data, err := handler(c) + if err == nil { + c.JSON(http.StatusOK, model.CommonResponse[T]{Success: true, Data: data}) + return + } + switch err.(type) { + case *gormError: + log.Printf("NEZHA>> gorm error: %v", err) + c.JSON(http.StatusOK, newErrorResponse(singleton.Localizer.ErrorT("database error"))) + return + case *wsError: + // Connection is upgraded to WebSocket, so c.Writer is no longer usable + if msg := err.Error(); msg != "" { + log.Printf("NEZHA>> websocket error: %v", err) + } + return + default: + c.JSON(http.StatusOK, newErrorResponse(err)) + return + } +} + +func listHandler[S ~[]E, E model.CommonInterface](handler handlerFunc[S]) func(*gin.Context) { + return func(c *gin.Context) { + data, err := handler(c) + if err != nil { c.JSON(http.StatusOK, newErrorResponse(err)) return } + + c.JSON(http.StatusOK, model.CommonResponse[S]{Success: true, Data: filter(c, data)}) } } +func pCommonHandler[S ~[]E, E any](handler pHandlerFunc[S, E]) func(*gin.Context) { + return func(c *gin.Context) { + data, err := handler(c) + if err != nil { + c.JSON(http.StatusOK, newErrorResponse(err)) + return + } + + c.JSON(http.StatusOK, model.PaginatedResponse[S, E]{Success: true, Data: data}) + } +} + +func filter[S ~[]E, E model.CommonInterface](ctx *gin.Context, s S) S { + return slices.DeleteFunc(s, func(e E) bool { + return !e.HasPermission(ctx) + }) +} + +func getUid(c *gin.Context) uint64 { + user, _ := c.MustGet(model.CtxKeyAuthorizedUser).(*model.User) + return user.ID +} + func fallbackToFrontend(frontendDist fs.FS) func(*gin.Context) { checkLocalFileOrFs := func(c *gin.Context, fs fs.FS, path string) bool { if _, err := os.Stat(path); err == nil { diff --git a/cmd/dashboard/controller/cron.go b/cmd/dashboard/controller/cron.go index d8d95de..4c54541 100644 --- a/cmd/dashboard/controller/cron.go +++ b/cmd/dashboard/controller/cron.go @@ -1,7 +1,6 @@ package controller import ( - "fmt" "strconv" "github.com/gin-gonic/gin" @@ -50,6 +49,18 @@ func createCron(c *gin.Context) (uint64, error) { return 0, err } + singleton.ServerLock.RLock() + for _, sid := range cf.Servers { + if server, ok := singleton.ServerList[sid]; ok { + if !server.HasPermission(c) { + singleton.ServerLock.RUnlock() + return 0, singleton.Localizer.ErrorT("permission denied") + } + } + } + singleton.ServerLock.RUnlock() + + cr.UserID = getUid(c) cr.TaskType = cf.TaskType cr.Name = cf.Name cr.Scheduler = cf.Scheduler @@ -104,9 +115,24 @@ func updateCron(c *gin.Context) (any, error) { return 0, err } + singleton.ServerLock.RLock() + for _, sid := range cf.Servers { + if server, ok := singleton.ServerList[sid]; ok { + if !server.HasPermission(c) { + singleton.ServerLock.RUnlock() + return nil, singleton.Localizer.ErrorT("permission denied") + } + } + } + singleton.ServerLock.RUnlock() + var cr model.Cron if err := singleton.DB.First(&cr, id).Error; err != nil { - return nil, fmt.Errorf("task id %d does not exist", id) + return nil, singleton.Localizer.ErrorT("task id %d does not exist", id) + } + + if !cr.HasPermission(c) { + return nil, singleton.Localizer.ErrorT("permission denied") } cr.TaskType = cf.TaskType @@ -156,12 +182,19 @@ func manualTriggerCron(c *gin.Context) (any, error) { return nil, err } - var cr model.Cron - if err := singleton.DB.First(&cr, id).Error; err != nil { + singleton.CronLock.RLock() + cr, ok := singleton.Crons[id] + if !ok { + singleton.CronLock.RUnlock() return nil, singleton.Localizer.ErrorT("task id %d does not exist", id) } + singleton.CronLock.RUnlock() - singleton.ManualTrigger(&cr) + if !cr.HasPermission(c) { + return nil, singleton.Localizer.ErrorT("permission denied") + } + + singleton.ManualTrigger(cr) return nil, nil } @@ -178,11 +211,21 @@ func manualTriggerCron(c *gin.Context) (any, error) { // @Router /batch-delete/cron [post] func batchDeleteCron(c *gin.Context) (any, error) { var cr []uint64 - if err := c.ShouldBindJSON(&cr); err != nil { return nil, err } + singleton.CronLock.RLock() + for _, crID := range cr { + if crn, ok := singleton.Crons[crID]; ok { + if !crn.HasPermission(c) { + singleton.CronLock.RUnlock() + return nil, singleton.Localizer.ErrorT("permission denied") + } + } + } + singleton.CronLock.RUnlock() + if err := singleton.DB.Unscoped().Delete(&model.Cron{}, "id in (?)", cr).Error; err != nil { return nil, newGormError("%v", err) } diff --git a/cmd/dashboard/controller/ddns.go b/cmd/dashboard/controller/ddns.go index b58c28e..c8a168b 100644 --- a/cmd/dashboard/controller/ddns.go +++ b/cmd/dashboard/controller/ddns.go @@ -56,6 +56,7 @@ func createDDNS(c *gin.Context) (uint64, error) { return 0, singleton.Localizer.ErrorT("the retry count must be an integer between 1 and 10") } + p.UserID = getUid(c) p.Name = df.Name enableIPv4 := df.EnableIPv4 enableIPv6 := df.EnableIPv6 @@ -125,6 +126,10 @@ func updateDDNS(c *gin.Context) (any, error) { return nil, singleton.Localizer.ErrorT("profile id %d does not exist", id) } + if !p.HasPermission(c) { + return nil, singleton.Localizer.ErrorT("permission denied") + } + p.Name = df.Name enableIPv4 := df.EnableIPv4 enableIPv6 := df.EnableIPv6 @@ -178,6 +183,17 @@ func batchDeleteDDNS(c *gin.Context) (any, error) { return nil, err } + singleton.DDNSCacheLock.RLock() + for _, pid := range ddnsConfigs { + if p, ok := singleton.DDNSCache[pid]; ok { + if !p.HasPermission(c) { + singleton.DDNSCacheLock.RUnlock() + return nil, singleton.Localizer.ErrorT("permission denied") + } + } + } + singleton.DDNSCacheLock.RUnlock() + if err := singleton.DB.Unscoped().Delete(&model.DDNSProfile{}, "id in (?)", ddnsConfigs).Error; err != nil { return nil, newGormError("%v", err) } diff --git a/cmd/dashboard/controller/fm.go b/cmd/dashboard/controller/fm.go index c413d62..955970b 100644 --- a/cmd/dashboard/controller/fm.go +++ b/cmd/dashboard/controller/fm.go @@ -7,6 +7,7 @@ import ( "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "github.com/hashicorp/go-uuid" + "github.com/nezhahq/nezha/model" "github.com/nezhahq/nezha/pkg/utils" "github.com/nezhahq/nezha/pkg/websocketx" @@ -31,13 +32,6 @@ func createFM(c *gin.Context) (*model.CreateFMResponse, error) { return nil, err } - streamId, err := uuid.GenerateUUID() - if err != nil { - return nil, err - } - - rpc.NezhaHandlerSingleton.CreateStream(streamId) - singleton.ServerLock.RLock() server := singleton.ServerList[id] singleton.ServerLock.RUnlock() @@ -45,6 +39,17 @@ func createFM(c *gin.Context) (*model.CreateFMResponse, error) { return nil, singleton.Localizer.ErrorT("server not found or not connected") } + if !server.HasPermission(c) { + return nil, singleton.Localizer.ErrorT("permission denied") + } + + streamId, err := uuid.GenerateUUID() + if err != nil { + return nil, err + } + + rpc.NezhaHandlerSingleton.CreateStream(streamId) + fmData, _ := utils.Json.Marshal(&model.TaskFM{ StreamID: streamId, }) diff --git a/cmd/dashboard/controller/jwt.go b/cmd/dashboard/controller/jwt.go index 2e4a462..72a261e 100644 --- a/cmd/dashboard/controller/jwt.go +++ b/cmd/dashboard/controller/jwt.go @@ -88,18 +88,21 @@ func authenticator() func(c *gin.Context) (interface{}, error) { } var user model.User + realip := c.GetString(model.CtxKeyRealIPStr) if err := singleton.DB.Select("id", "password").Where("username = ?", loginVals.Username).First(&user).Error; err != nil { if err == gorm.ErrRecordNotFound { - model.BlockIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.WAFBlockReasonTypeLoginFail) + model.BlockIP(singleton.DB, realip, model.WAFBlockReasonTypeLoginFail, model.BlockIDUnknownUser) } return nil, jwt.ErrFailedAuthentication } if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(loginVals.Password)); err != nil { - model.BlockIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.WAFBlockReasonTypeLoginFail) + model.BlockIP(singleton.DB, realip, model.WAFBlockReasonTypeLoginFail, int64(user.ID)) return nil, jwt.ErrFailedAuthentication } + model.ClearIP(singleton.DB, realip, model.BlockIDUnknownUser) + model.ClearIP(singleton.DB, realip, int64(user.ID)) return utils.Itoa(user.ID), nil } } @@ -169,10 +172,10 @@ func optionalAuthMiddleware(mw *jwt.GinJWTMiddleware) func(c *gin.Context) { identity := mw.IdentityHandler(c) if identity != nil { - model.ClearIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr)) + model.ClearIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.BlockIDToken) c.Set(mw.IdentityKey, identity) } else { - if err := model.BlockIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.WAFBlockReasonTypeBruteForceToken); err != nil { + if err := model.BlockIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.WAFBlockReasonTypeBruteForceToken, model.BlockIDToken); err != nil { waf.ShowBlockPage(c, err) return } diff --git a/cmd/dashboard/controller/nat.go b/cmd/dashboard/controller/nat.go index 6a2737e..20260ad 100644 --- a/cmd/dashboard/controller/nat.go +++ b/cmd/dashboard/controller/nat.go @@ -51,6 +51,18 @@ func createNAT(c *gin.Context) (uint64, error) { return 0, err } + singleton.ServerLock.RLock() + if server, ok := singleton.ServerList[nf.ServerID]; ok { + if !server.HasPermission(c) { + singleton.ServerLock.RUnlock() + return 0, singleton.Localizer.ErrorT("permission denied") + } + } + singleton.ServerLock.RUnlock() + + uid := getUid(c) + + n.UserID = uid n.Name = nf.Name n.Domain = nf.Domain n.Host = nf.Host @@ -90,11 +102,24 @@ func updateNAT(c *gin.Context) (any, error) { return nil, err } + singleton.ServerLock.RLock() + if server, ok := singleton.ServerList[nf.ServerID]; ok { + if !server.HasPermission(c) { + singleton.ServerLock.RUnlock() + return nil, singleton.Localizer.ErrorT("permission denied") + } + } + singleton.ServerLock.RUnlock() + var n model.NAT if err = singleton.DB.First(&n, id).Error; err != nil { return nil, singleton.Localizer.ErrorT("profile id %d does not exist", id) } + if !n.HasPermission(c) { + return nil, singleton.Localizer.ErrorT("permission denied") + } + n.Name = nf.Name n.Domain = nf.Domain n.Host = nf.Host @@ -122,11 +147,21 @@ func updateNAT(c *gin.Context) (any, error) { // @Router /batch-delete/nat [post] func batchDeleteNAT(c *gin.Context) (any, error) { var n []uint64 - if err := c.ShouldBindJSON(&n); err != nil { return nil, err } + singleton.NATCacheRwLock.RLock() + for _, id := range n { + if p, ok := singleton.NATCache[singleton.NATIDToDomain[id]]; ok { + if !p.HasPermission(c) { + singleton.NATCacheRwLock.RUnlock() + return nil, singleton.Localizer.ErrorT("permission denied") + } + } + } + singleton.NATCacheRwLock.RUnlock() + if err := singleton.DB.Unscoped().Delete(&model.NAT{}, "id in (?)", n).Error; err != nil { return nil, newGormError("%v", err) } diff --git a/cmd/dashboard/controller/notification.go b/cmd/dashboard/controller/notification.go index 78ac770..1180dcb 100644 --- a/cmd/dashboard/controller/notification.go +++ b/cmd/dashboard/controller/notification.go @@ -48,6 +48,7 @@ func createNotification(c *gin.Context) (uint64, error) { } var n model.Notification + n.UserID = getUid(c) n.Name = nf.Name n.RequestMethod = nf.RequestMethod n.RequestType = nf.RequestType @@ -106,6 +107,10 @@ func updateNotification(c *gin.Context) (any, error) { return nil, singleton.Localizer.ErrorT("notification id %d does not exist", id) } + if !n.HasPermission(c) { + return nil, singleton.Localizer.ErrorT("permission denied") + } + n.Name = nf.Name n.RequestMethod = nf.RequestMethod n.RequestType = nf.RequestType @@ -149,11 +154,20 @@ func updateNotification(c *gin.Context) (any, error) { // @Router /batch-delete/notification [post] func batchDeleteNotification(c *gin.Context) (any, error) { var n []uint64 - if err := c.ShouldBindJSON(&n); err != nil { return nil, err } + singleton.NotificationsLock.RLock() + for _, nid := range n { + if ns, ok := singleton.NotificationMap[nid]; ok { + if !ns.HasPermission(c) { + return nil, singleton.Localizer.ErrorT("permission denied") + } + } + } + singleton.NotificationsLock.RUnlock() + err := singleton.DB.Transaction(func(tx *gorm.DB) error { if err := tx.Unscoped().Delete(&model.Notification{}, "id in (?)", n).Error; err != nil { return err diff --git a/cmd/dashboard/controller/notification_group.go b/cmd/dashboard/controller/notification_group.go index 2e74dba..8f5eef6 100644 --- a/cmd/dashboard/controller/notification_group.go +++ b/cmd/dashboard/controller/notification_group.go @@ -20,7 +20,7 @@ import ( // @Produce json // @Success 200 {object} model.CommonResponse[[]model.NotificationGroupResponseItem] // @Router /notification-group [get] -func listNotificationGroup(c *gin.Context) ([]model.NotificationGroupResponseItem, error) { +func listNotificationGroup(c *gin.Context) ([]*model.NotificationGroupResponseItem, error) { var ng []model.NotificationGroup if err := singleton.DB.Find(&ng).Error; err != nil { return nil, err @@ -39,9 +39,9 @@ func listNotificationGroup(c *gin.Context) ([]model.NotificationGroupResponseIte groupNotifications[n.NotificationGroupID] = append(groupNotifications[n.NotificationGroupID], n.NotificationID) } - ngRes := make([]model.NotificationGroupResponseItem, 0, len(ng)) + ngRes := make([]*model.NotificationGroupResponseItem, 0, len(ng)) for _, n := range ng { - ngRes = append(ngRes, model.NotificationGroupResponseItem{ + ngRes = append(ngRes, &model.NotificationGroupResponseItem{ Group: n, Notifications: groupNotifications[n.ID], }) @@ -68,8 +68,22 @@ func createNotificationGroup(c *gin.Context) (uint64, error) { } ngf.Notifications = slices.Compact(ngf.Notifications) + singleton.NotificationsLock.RLock() + for _, nid := range ngf.Notifications { + if n, ok := singleton.NotificationMap[nid]; ok { + if !n.HasPermission(c) { + singleton.NotificationsLock.RUnlock() + return 0, singleton.Localizer.ErrorT("permission denied") + } + } + } + singleton.NotificationsLock.RUnlock() + + uid := getUid(c) + var ng model.NotificationGroup ng.Name = ngf.Name + ng.UserID = uid var count int64 if err := singleton.DB.Model(&model.Notification{}).Where("id in (?)", ngf.Notifications).Count(&count).Error; err != nil { @@ -86,6 +100,9 @@ func createNotificationGroup(c *gin.Context) (uint64, error) { } for _, n := range ngf.Notifications { if err := tx.Create(&model.NotificationGroupNotification{ + Common: model.Common{ + UserID: uid, + }, NotificationGroupID: ng.ID, NotificationID: n, }).Error; err != nil { @@ -126,11 +143,27 @@ func updateNotificationGroup(c *gin.Context) (any, error) { if err := c.ShouldBindJSON(&ngf); err != nil { return nil, err } + + singleton.NotificationsLock.RLock() + for _, nid := range ngf.Notifications { + if n, ok := singleton.NotificationMap[nid]; ok { + if !n.HasPermission(c) { + singleton.NotificationsLock.RUnlock() + return nil, singleton.Localizer.ErrorT("permission denied") + } + } + } + singleton.NotificationsLock.RUnlock() + var ngDB model.NotificationGroup if err := singleton.DB.First(&ngDB, id).Error; err != nil { return nil, singleton.Localizer.ErrorT("group id %d does not exist", id) } + if !ngDB.HasPermission(c) { + return nil, singleton.Localizer.ErrorT("permission denied") + } + ngDB.Name = ngf.Name ngf.Notifications = slices.Compact(ngf.Notifications) @@ -142,6 +175,8 @@ func updateNotificationGroup(c *gin.Context) (any, error) { return nil, singleton.Localizer.ErrorT("have invalid notification id") } + uid := getUid(c) + err = singleton.DB.Transaction(func(tx *gorm.DB) error { if err := tx.Save(&ngDB).Error; err != nil { return err @@ -152,6 +187,9 @@ func updateNotificationGroup(c *gin.Context) (any, error) { for _, n := range ngf.Notifications { if err := tx.Create(&model.NotificationGroupNotification{ + Common: model.Common{ + UserID: uid, + }, NotificationGroupID: ngDB.ID, NotificationID: n, }).Error; err != nil { @@ -185,6 +223,17 @@ func batchDeleteNotificationGroup(c *gin.Context) (any, error) { return nil, err } + var ng []model.NotificationGroup + if err := singleton.DB.Where("id in (?)", ngn).Find(&ng).Error; err != nil { + return nil, err + } + + for _, n := range ng { + if !n.HasPermission(c) { + return nil, singleton.Localizer.ErrorT("permission denied") + } + } + err := singleton.DB.Transaction(func(tx *gorm.DB) error { if err := tx.Unscoped().Delete(&model.NotificationGroup{}, "id in (?)", ngn).Error; err != nil { return err diff --git a/cmd/dashboard/controller/server.go b/cmd/dashboard/controller/server.go index 48bd56e..85d9fb7 100644 --- a/cmd/dashboard/controller/server.go +++ b/cmd/dashboard/controller/server.go @@ -56,11 +56,26 @@ func updateServer(c *gin.Context) (any, error) { return nil, err } + singleton.DDNSCacheLock.RLock() + for _, pid := range sf.DDNSProfiles { + if p, ok := singleton.DDNSCache[pid]; ok { + if !p.HasPermission(c) { + singleton.DDNSCacheLock.RUnlock() + return nil, singleton.Localizer.ErrorT("permission denied") + } + } + } + singleton.DDNSCacheLock.RUnlock() + var s model.Server if err := singleton.DB.First(&s, id).Error; err != nil { return nil, singleton.Localizer.ErrorT("server id %d does not exist", id) } + if !s.HasPermission(c) { + return nil, singleton.Localizer.ErrorT("permission denied") + } + s.Name = sf.Name s.DisplayIndex = sf.DisplayIndex s.Note = sf.Note @@ -104,6 +119,17 @@ func batchDeleteServer(c *gin.Context) (any, error) { return nil, err } + singleton.ServerLock.RLock() + for _, sid := range servers { + if s, ok := singleton.ServerList[sid]; ok { + if !s.HasPermission(c) { + singleton.ServerLock.RUnlock() + return nil, singleton.Localizer.ErrorT("permission denied") + } + } + } + singleton.ServerLock.RUnlock() + err := singleton.DB.Transaction(func(tx *gorm.DB) error { if err := tx.Unscoped().Delete(&model.Server{}, "id in (?)", servers).Error; err != nil { return err @@ -161,6 +187,9 @@ func forceUpdateServer(c *gin.Context) (*model.ForceUpdateResponse, error) { server := singleton.ServerList[sid] singleton.ServerLock.RUnlock() if server != nil && server.TaskStream != nil { + if !server.HasPermission(c) { + return nil, singleton.Localizer.ErrorT("permission denied") + } if err := server.TaskStream.Send(&pb.Task{ Type: model.TaskTypeUpgrade, }); err != nil { diff --git a/cmd/dashboard/controller/server_group.go b/cmd/dashboard/controller/server_group.go index 98b5c6f..7655a9b 100644 --- a/cmd/dashboard/controller/server_group.go +++ b/cmd/dashboard/controller/server_group.go @@ -20,7 +20,7 @@ import ( // @Produce json // @Success 200 {object} model.CommonResponse[[]model.ServerGroupResponseItem] // @Router /server-group [get] -func listServerGroup(c *gin.Context) ([]model.ServerGroupResponseItem, error) { +func listServerGroup(c *gin.Context) ([]*model.ServerGroupResponseItem, error) { var sg []model.ServerGroup if err := singleton.DB.Find(&sg).Error; err != nil { return nil, err @@ -38,9 +38,9 @@ func listServerGroup(c *gin.Context) ([]model.ServerGroupResponseItem, error) { groupServers[s.ServerGroupId] = append(groupServers[s.ServerGroupId], s.ServerId) } - var sgRes []model.ServerGroupResponseItem + var sgRes []*model.ServerGroupResponseItem for _, s := range sg { - sgRes = append(sgRes, model.ServerGroupResponseItem{ + sgRes = append(sgRes, &model.ServerGroupResponseItem{ Group: s, Servers: groupServers[s.ID], }) @@ -67,8 +67,22 @@ func createServerGroup(c *gin.Context) (uint64, error) { } sgf.Servers = slices.Compact(sgf.Servers) + singleton.ServerLock.RLock() + for _, sid := range sgf.Servers { + if server, ok := singleton.ServerList[sid]; ok { + if !server.HasPermission(c) { + singleton.ServerLock.RUnlock() + return 0, singleton.Localizer.ErrorT("permission denied") + } + } + } + singleton.ServerLock.RUnlock() + + uid := getUid(c) + var sg model.ServerGroup sg.Name = sgf.Name + sg.UserID = uid var count int64 if err := singleton.DB.Model(&model.Server{}).Where("id in (?)", sgf.Servers).Count(&count).Error; err != nil { @@ -84,6 +98,9 @@ func createServerGroup(c *gin.Context) (uint64, error) { } for _, s := range sgf.Servers { if err := tx.Create(&model.ServerGroupServer{ + Common: model.Common{ + UserID: uid, + }, ServerGroupId: sg.ID, ServerId: s, }).Error; err != nil { @@ -125,10 +142,26 @@ func updateServerGroup(c *gin.Context) (any, error) { } sg.Servers = slices.Compact(sg.Servers) + singleton.ServerLock.RLock() + for _, sid := range sg.Servers { + if server, ok := singleton.ServerList[sid]; ok { + if !server.HasPermission(c) { + singleton.ServerLock.RUnlock() + return nil, singleton.Localizer.ErrorT("permission denied") + } + } + } + singleton.ServerLock.RUnlock() + var sgDB model.ServerGroup if err := singleton.DB.First(&sgDB, id).Error; err != nil { return nil, singleton.Localizer.ErrorT("group id %d does not exist", id) } + + if !sgDB.HasPermission(c) { + return nil, singleton.Localizer.ErrorT("unauthorized") + } + sgDB.Name = sg.Name var count int64 @@ -139,6 +172,8 @@ func updateServerGroup(c *gin.Context) (any, error) { return nil, singleton.Localizer.ErrorT("have invalid server id") } + uid := getUid(c) + err = singleton.DB.Transaction(func(tx *gorm.DB) error { if err := tx.Save(&sgDB).Error; err != nil { return err @@ -149,6 +184,9 @@ func updateServerGroup(c *gin.Context) (any, error) { for _, s := range sg.Servers { if err := tx.Create(&model.ServerGroupServer{ + Common: model.Common{ + UserID: uid, + }, ServerGroupId: sgDB.ID, ServerId: s, }).Error; err != nil { @@ -181,6 +219,17 @@ func batchDeleteServerGroup(c *gin.Context) (any, error) { return nil, err } + var sg []model.ServerGroup + if err := singleton.DB.Where("id in (?)", sgs).Find(&sg).Error; err != nil { + return nil, err + } + + for _, s := range sg { + if !s.HasPermission(c) { + return nil, singleton.Localizer.ErrorT("permission denied") + } + } + err := singleton.DB.Transaction(func(tx *gorm.DB) error { if err := tx.Unscoped().Delete(&model.ServerGroup{}, "id in (?)", sgs).Error; err != nil { return err diff --git a/cmd/dashboard/controller/service.go b/cmd/dashboard/controller/service.go index 2956681..7395af5 100644 --- a/cmd/dashboard/controller/service.go +++ b/cmd/dashboard/controller/service.go @@ -190,7 +190,10 @@ func createService(c *gin.Context) (uint64, error) { return 0, err } + uid := getUid(c) + var m model.Service + m.UserID = uid m.Name = mf.Name m.Target = strings.TrimSpace(mf.Target) m.Type = mf.Type @@ -207,6 +210,10 @@ func createService(c *gin.Context) (uint64, error) { m.RecoverTriggerTasks = mf.RecoverTriggerTasks m.FailTriggerTasks = mf.FailTriggerTasks + if err := validateServers(c, &m); err != nil { + return 0, err + } + if err := singleton.DB.Create(&m).Error; err != nil { return 0, newGormError("%v", err) } @@ -260,6 +267,11 @@ func updateService(c *gin.Context) (any, error) { if err := singleton.DB.First(&m, id).Error; err != nil { return nil, singleton.Localizer.ErrorT("service id %d does not exist", id) } + + if !m.HasPermission(c) { + return nil, singleton.Localizer.ErrorT("permission denied") + } + m.Name = mf.Name m.Target = strings.TrimSpace(mf.Target) m.Type = mf.Type @@ -276,6 +288,10 @@ func updateService(c *gin.Context) (any, error) { m.RecoverTriggerTasks = mf.RecoverTriggerTasks m.FailTriggerTasks = mf.FailTriggerTasks + if err := validateServers(c, &m); err != nil { + return 0, err + } + if err := singleton.DB.Save(&m).Error; err != nil { return nil, newGormError("%v", err) } @@ -318,6 +334,18 @@ func batchDeleteService(c *gin.Context) (any, error) { if err := c.ShouldBindJSON(&ids); err != nil { return nil, err } + + singleton.ServiceSentinelShared.ServicesLock.RLock() + for _, id := range ids { + if ss, ok := singleton.ServiceSentinelShared.Services[id]; ok { + if !ss.HasPermission(c) { + singleton.ServiceSentinelShared.ServicesLock.RUnlock() + return nil, singleton.Localizer.ErrorT("permission denied") + } + } + } + singleton.ServiceSentinelShared.ServicesLock.RUnlock() + err := singleton.DB.Transaction(func(tx *gorm.DB) error { if err := tx.Unscoped().Delete(&model.Service{}, "id in (?)", ids).Error; err != nil { return err @@ -331,3 +359,18 @@ func batchDeleteService(c *gin.Context) (any, error) { singleton.ServiceSentinelShared.UpdateServiceList() return nil, nil } + +func validateServers(c *gin.Context, ss *model.Service) error { + singleton.ServerLock.RLock() + defer singleton.ServerLock.RUnlock() + + for s := range ss.SkipServers { + if server, ok := singleton.ServerList[s]; ok { + if !server.HasPermission(c) { + return singleton.Localizer.ErrorT("permission denied") + } + } + } + + return nil +} diff --git a/cmd/dashboard/controller/terminal.go b/cmd/dashboard/controller/terminal.go index a5011b0..3c1ef0c 100644 --- a/cmd/dashboard/controller/terminal.go +++ b/cmd/dashboard/controller/terminal.go @@ -6,6 +6,7 @@ import ( "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "github.com/hashicorp/go-uuid" + "github.com/nezhahq/nezha/model" "github.com/nezhahq/nezha/pkg/utils" "github.com/nezhahq/nezha/pkg/websocketx" @@ -29,13 +30,6 @@ func createTerminal(c *gin.Context) (*model.CreateTerminalResponse, error) { return nil, err } - streamId, err := uuid.GenerateUUID() - if err != nil { - return nil, err - } - - rpc.NezhaHandlerSingleton.CreateStream(streamId) - singleton.ServerLock.RLock() server := singleton.ServerList[createTerminalReq.ServerID] singleton.ServerLock.RUnlock() @@ -43,6 +37,17 @@ func createTerminal(c *gin.Context) (*model.CreateTerminalResponse, error) { return nil, singleton.Localizer.ErrorT("server not found or not connected") } + if !server.HasPermission(c) { + return nil, singleton.Localizer.ErrorT("permission denied") + } + + streamId, err := uuid.GenerateUUID() + if err != nil { + return nil, err + } + + rpc.NezhaHandlerSingleton.CreateStream(streamId) + terminalData, _ := utils.Json.Marshal(&model.TerminalTask{ StreamID: streamId, }) diff --git a/cmd/dashboard/controller/user.go b/cmd/dashboard/controller/user.go index 7c8b9b9..5b5c019 100644 --- a/cmd/dashboard/controller/user.go +++ b/cmd/dashboard/controller/user.go @@ -114,6 +114,7 @@ func createUser(c *gin.Context) (uint64, error) { var u model.User u.Username = uf.Username + u.Role = model.RoleMember hash, err := bcrypt.GenerateFromPassword([]byte(uf.Password), bcrypt.DefaultCost) if err != nil { @@ -125,6 +126,7 @@ func createUser(c *gin.Context) (uint64, error) { return 0, err } + singleton.OnUserUpdate(&u) return u.ID, nil } @@ -149,5 +151,6 @@ func batchDeleteUser(c *gin.Context) (any, error) { return nil, singleton.Localizer.ErrorT("can't delete yourself") } - return nil, singleton.DB.Where("id IN (?)", ids).Delete(&model.User{}).Error + err := singleton.OnUserDelete(ids, newGormError) + return nil, err } diff --git a/cmd/dashboard/controller/waf.go b/cmd/dashboard/controller/waf.go index cd37218..de90faa 100644 --- a/cmd/dashboard/controller/waf.go +++ b/cmd/dashboard/controller/waf.go @@ -1,6 +1,8 @@ package controller import ( + "strconv" + "github.com/gin-gonic/gin" "github.com/nezhahq/nezha/model" @@ -13,16 +15,40 @@ import ( // @Schemes // @Description List server // @Tags auth required +// @Param limit query uint false "Page limit" +// @Param offset query uint false "Page offset" // @Produce json -// @Success 200 {object} model.CommonResponse[[]model.WAFApiMock] +// @Success 200 {object} model.PaginatedResponse[[]model.WAFApiMock, model.WAFApiMock] // @Router /waf [get] -func listBlockedAddress(c *gin.Context) ([]*model.WAF, error) { +func listBlockedAddress(c *gin.Context) (*model.Value[[]*model.WAF], error) { + limit, err := strconv.Atoi(c.Query("limit")) + if err != nil || limit < 1 { + limit = 25 + } + + offset, err := strconv.Atoi(c.Query("offset")) + if err != nil || offset < 0 { + offset = 0 + } + var waf []*model.WAF - if err := singleton.DB.Find(&waf).Error; err != nil { + if err := singleton.DB.Limit(limit).Offset(offset).Find(&waf).Error; err != nil { return nil, err } - return waf, nil + var total int64 + if err := singleton.DB.Model(&model.WAF{}).Count(&total).Error; err != nil { + return nil, err + } + + return &model.Value[[]*model.WAF]{ + Value: waf, + Pagination: model.Pagination{ + Offset: offset, + Limit: limit, + Total: total, + }, + }, nil } // Batch delete blocked addresses diff --git a/cmd/dashboard/rpc/rpc.go b/cmd/dashboard/rpc/rpc.go index 1618870..51f6e12 100644 --- a/cmd/dashboard/rpc/rpc.go +++ b/cmd/dashboard/rpc/rpc.go @@ -100,12 +100,34 @@ func DispatchTask(serviceSentinelDispatchBus <-chan model.Service) { continue } if task.Cover == model.ServiceCoverIgnoreAll && task.SkipServers[singleton.SortedServerList[workedServerIndex].ID] { - singleton.SortedServerList[workedServerIndex].TaskStream.Send(task.PB()) + server := singleton.SortedServerList[workedServerIndex] + singleton.UserLock.RLock() + var role uint8 + if u, ok := singleton.UserInfoMap[server.UserID]; !ok { + role = model.RoleMember + } else { + role = u.Role + } + singleton.UserLock.RUnlock() + if task.UserID == server.UserID || role == model.RoleAdmin { + singleton.SortedServerList[workedServerIndex].TaskStream.Send(task.PB()) + } workedServerIndex++ continue } if task.Cover == model.ServiceCoverAll && !task.SkipServers[singleton.SortedServerList[workedServerIndex].ID] { - singleton.SortedServerList[workedServerIndex].TaskStream.Send(task.PB()) + server := singleton.SortedServerList[workedServerIndex] + singleton.UserLock.RLock() + var role uint8 + if u, ok := singleton.UserInfoMap[server.UserID]; !ok { + role = model.RoleMember + } else { + role = u.Role + } + singleton.UserLock.RUnlock() + if task.UserID == server.UserID || role == model.RoleAdmin { + singleton.SortedServerList[workedServerIndex].TaskStream.Send(task.PB()) + } workedServerIndex++ continue } diff --git a/go.mod b/go.mod index 28ef269..f0c6f0e 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,8 @@ module github.com/nezhahq/nezha -go 1.22.7 +go 1.23.0 -toolchain go1.23.1 +toolchain go1.23.2 require ( github.com/appleboy/gin-jwt/v2 v2.10.0 diff --git a/model/alertrule.go b/model/alertrule.go index 431dc86..b505cf5 100644 --- a/model/alertrule.go +++ b/model/alertrule.go @@ -62,10 +62,15 @@ func (r *AlertRule) Enabled() bool { } // Snapshot 对传入的Server进行该报警规则下所有type的检查 返回每项检查结果 -func (r *AlertRule) Snapshot(cycleTransferStats *CycleTransferStats, server *Server, db *gorm.DB) []bool { - point := make([]bool, 0, len(r.Rules)) - for _, rule := range r.Rules { - point = append(point, rule.Snapshot(cycleTransferStats, server, db)) +func (r *AlertRule) Snapshot(cycleTransferStats *CycleTransferStats, server *Server, db *gorm.DB, role uint8) []bool { + point := make([]bool, len(r.Rules)) + + if r.UserID != server.UserID && role != RoleAdmin { + return point + } + + for i, rule := range r.Rules { + point[i] = rule.Snapshot(cycleTransferStats, server, db) } return point } diff --git a/model/api.go b/model/api.go index f69d157..dbcea5b 100644 --- a/model/api.go +++ b/model/api.go @@ -15,6 +15,23 @@ type CommonResponse[T any] struct { Error string `json:"error,omitempty"` } +type PaginatedResponse[S ~[]E, E any] struct { + Success bool `json:"success,omitempty"` + Data *Value[S] `json:"data,omitempty"` + Error string `json:"error,omitempty"` +} + +type Value[T any] struct { + Value T `json:"value,omitempty"` + Pagination Pagination `json:"pagination,omitempty"` +} + +type Pagination struct { + Offset int `json:"offset,omitempty"` + Limit int `json:"limit,omitempty"` + Total int64 `json:"total,omitempty"` +} + type LoginResponse struct { Token string `json:"token,omitempty"` Expire string `json:"expire,omitempty"` diff --git a/model/common.go b/model/common.go index 294295e..6946783 100644 --- a/model/common.go +++ b/model/common.go @@ -2,6 +2,8 @@ package model import ( "time" + + "github.com/gin-gonic/gin" ) const ( @@ -18,6 +20,47 @@ type Common struct { UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at,omitempty"` // Do not use soft deletion // DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at,omitempty"` + + UserID uint64 `json:"-"` +} + +func (c *Common) GetID() uint64 { + return c.ID +} + +func (c *Common) GetUserID() uint64 { + return c.UserID +} + +func (c *Common) HasPermission(ctx *gin.Context) bool { + auth, ok := ctx.Get(CtxKeyAuthorizedUser) + if !ok { + return false + } + + user := *auth.(*User) + if user.Role == RoleAdmin { + return true + } + + return user.ID == c.UserID +} + +type CommonInterface interface { + GetID() uint64 + GetUserID() uint64 + HasPermission(*gin.Context) bool +} + +func FindByUserID[S ~[]E, E CommonInterface](s S, uid uint64) []uint64 { + var list []uint64 + for _, v := range s { + if v.GetUserID() == uid { + list = append(list, v.GetID()) + } + } + + return list } type Response struct { diff --git a/model/user.go b/model/user.go index e1f297b..29dfe4f 100644 --- a/model/user.go +++ b/model/user.go @@ -1,9 +1,41 @@ package model +import ( + "github.com/nezhahq/nezha/pkg/utils" + "gorm.io/gorm" +) + +const ( + RoleAdmin uint8 = iota + RoleMember +) + type User struct { Common - Username string `json:"username,omitempty" gorm:"uniqueIndex"` - Password string `json:"password,omitempty" gorm:"type:char(72)"` + Username string `json:"username,omitempty" gorm:"uniqueIndex"` + Password string `json:"password,omitempty" gorm:"type:char(72)"` + Role uint8 `json:"role,omitempty"` + AgentSecret string `json:"agent_secret,omitempty" gorm:"type:char(32)"` +} + +type UserInfo struct { + Role uint8 + _ [3]byte + AgentSecret string +} + +func (u *User) BeforeSave(tx *gorm.DB) error { + if u.AgentSecret != "" { + return nil + } + + key, err := utils.GenerateRandomString(32) + if err != nil { + return err + } + + u.AgentSecret = key + return nil } type Profile struct { diff --git a/model/user_group.go b/model/user_group.go deleted file mode 100644 index 45b0ecf..0000000 --- a/model/user_group.go +++ /dev/null @@ -1,6 +0,0 @@ -package model - -type UserGroup struct { - Common - Name string `json:"name"` -} diff --git a/model/user_group_user.go b/model/user_group_user.go deleted file mode 100644 index 080f35e..0000000 --- a/model/user_group_user.go +++ /dev/null @@ -1,7 +0,0 @@ -package model - -type UserGroupUser struct { - Common - UserGroupId uint64 `json:"user_group_id"` - UserId uint64 `json:"user_id"` -} diff --git a/model/waf.go b/model/waf.go index 63f5032..a0a62a7 100644 --- a/model/waf.go +++ b/model/waf.go @@ -16,22 +16,30 @@ const ( WAFBlockReasonTypeAgentAuthFail ) +const ( + BlockIDgRPC = -127 + iota + BlockIDToken + BlockIDUnknownUser +) + type WAFApiMock struct { - IP string `json:"ip,omitempty"` - Count uint64 `json:"count,omitempty"` - LastBlockReason uint8 `json:"last_block_reason,omitempty"` - LastBlockTimestamp uint64 `json:"last_block_timestamp,omitempty"` + IP string `json:"ip,omitempty"` + BlockIdentifier int64 `json:"block_identifier,omitempty"` + BlockReason uint8 `json:"block_reason,omitempty"` + BlockTimestamp uint64 `json:"block_timestamp,omitempty"` + Count uint64 `json:"count,omitempty"` } type WAF struct { - IP []byte `gorm:"type:binary(16);primaryKey" json:"ip,omitempty"` - Count uint64 `json:"count,omitempty"` - LastBlockReason uint8 `json:"last_block_reason,omitempty"` - LastBlockTimestamp uint64 `json:"last_block_timestamp,omitempty"` + IP []byte `gorm:"type:binary(16);primaryKey" json:"ip,omitempty"` + BlockIdentifier int64 `gorm:"primaryKey" json:"block_identifier,omitempty"` + BlockReason uint8 `json:"block_reason,omitempty"` + BlockTimestamp uint64 `gorm:"index" json:"block_timestamp,omitempty"` + Count uint64 `json:"count,omitempty"` } func (w *WAF) TableName() string { - return "waf" + return "nz_waf" } func CheckIP(db *gorm.DB, ip string) error { @@ -42,22 +50,31 @@ func CheckIP(db *gorm.DB, ip string) error { if err != nil { return err } - var w WAF - result := db.Limit(1).Find(&w, "ip = ?", ipBinary) + + var blockTimestamp uint64 + result := db.Model(&WAF{}).Order("block_timestamp desc").Select("block_timestamp").Where("ip = ?", ipBinary).Limit(1).Find(&blockTimestamp) if result.Error != nil { return result.Error } - if result.RowsAffected == 0 { // 检查是否未找到记录 + + // 检查是否未找到记录 + if result.RowsAffected < 1 { return nil } + + var count uint64 + if err := db.Model(&WAF{}).Select("SUM(count)").Where("ip = ?", ipBinary).Scan(&count).Error; err != nil { + return err + } + now := time.Now().Unix() - if powAdd(w.Count, 4, w.LastBlockTimestamp) > uint64(now) { + if powAdd(count, 4, blockTimestamp) > uint64(now) { return errors.New("you are blocked by nezha WAF") } return nil } -func ClearIP(db *gorm.DB, ip string) error { +func ClearIP(db *gorm.DB, ip string, uid int64) error { if ip == "" { return nil } @@ -65,7 +82,7 @@ func ClearIP(db *gorm.DB, ip string) error { if err != nil { return err } - return db.Unscoped().Delete(&WAF{}, "ip = ?", ipBinary).Error + return db.Unscoped().Delete(&WAF{}, "ip = ? and block_identifier = ?", ipBinary, uid).Error } func BatchClearIP(db *gorm.DB, ip []string) error { @@ -83,7 +100,7 @@ func BatchClearIP(db *gorm.DB, ip []string) error { return db.Unscoped().Delete(&WAF{}, "ip in (?)", ips).Error } -func BlockIP(db *gorm.DB, ip string, reason uint8) error { +func BlockIP(db *gorm.DB, ip string, reason uint8, uid int64) error { if ip == "" { return nil } @@ -91,16 +108,19 @@ func BlockIP(db *gorm.DB, ip string, reason uint8) error { if err != nil { return err } - var w WAF - w.IP = ipBinary + w := WAF{ + IP: ipBinary, + BlockIdentifier: uid, + } + now := uint64(time.Now().Unix()) return db.Transaction(func(tx *gorm.DB) error { if err := tx.Where(&w).Attrs(WAF{ - LastBlockReason: reason, - LastBlockTimestamp: uint64(time.Now().Unix()), + BlockReason: reason, + BlockTimestamp: now, }).FirstOrCreate(&w).Error; err != nil { return err } - return tx.Exec("UPDATE waf SET count = count + 1, last_block_reason = ?, last_block_timestamp = ? WHERE ip = ?", reason, uint64(time.Now().Unix()), ipBinary).Error + return tx.Exec("UPDATE nz_waf SET count = count + 1, block_reason = ?, block_timestamp = ? WHERE ip = ? and block_identifier = ?", reason, now, ipBinary, uid).Error }) } diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index 1ec9985..d3efd5a 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -3,10 +3,12 @@ package utils import ( "crypto/rand" "errors" + "maps" "math/big" "net/netip" "os" "regexp" + "slices" "strconv" "strings" @@ -145,3 +147,8 @@ func Itoa[T constraints.Integer](i T) string { return "" } } + +func MapValuesToSlice[Map ~map[K]V, K comparable, V any](m Map) []V { + s := make([]V, 0, len(m)) + return slices.AppendSeq(s, maps.Values(m)) +} diff --git a/service/rpc/auth.go b/service/rpc/auth.go index 1709168..435b615 100644 --- a/service/rpc/auth.go +++ b/service/rpc/auth.go @@ -36,12 +36,16 @@ func (a *authHandler) Check(ctx context.Context) (uint64, error) { ip, _ := ctx.Value(model.CtxKeyRealIP{}).(string) - if clientSecret != singleton.Conf.AgentSecretKey { - model.BlockIP(singleton.DB, ip, model.WAFBlockReasonTypeAgentAuthFail) + singleton.UserLock.RLock() + userId, ok := singleton.AgentSecretToUserId[clientSecret] + if !ok && clientSecret != singleton.Conf.AgentSecretKey { + singleton.UserLock.RUnlock() + model.BlockIP(singleton.DB, ip, model.WAFBlockReasonTypeAgentAuthFail, model.BlockIDgRPC) return 0, status.Error(codes.Unauthenticated, "客户端认证失败") } + singleton.UserLock.RUnlock() - model.ClearIP(singleton.DB, ip) + model.ClearIP(singleton.DB, ip, model.BlockIDgRPC) var clientUUID string if value, ok := md["client_uuid"]; ok { @@ -53,21 +57,26 @@ func (a *authHandler) Check(ctx context.Context) (uint64, error) { } singleton.ServerLock.RLock() - defer singleton.ServerLock.RUnlock() - clientID, hasID := singleton.ServerUUIDToID[clientUUID] + singleton.ServerLock.RUnlock() + if !hasID { - s := model.Server{UUID: clientUUID, Name: petname.Generate(2, "-")} + s := model.Server{UUID: clientUUID, Name: petname.Generate(2, "-"), Common: model.Common{ + UserID: userId, + }} if err := singleton.DB.Create(&s).Error; err != nil { return 0, status.Error(codes.Unauthenticated, err.Error()) } s.Host = &model.Host{} s.State = &model.HostState{} s.GeoIP = &model.GeoIP{} - // generate a random silly server name + + singleton.ServerLock.Lock() singleton.ServerList[s.ID] = &s singleton.ServerUUIDToID[clientUUID] = s.ID + singleton.ServerLock.Unlock() singleton.ReSortServer() + clientID = s.ID } diff --git a/service/singleton/alertsentinel.go b/service/singleton/alertsentinel.go index a4f672d..ec20a80 100644 --- a/service/singleton/alertsentinel.go +++ b/service/singleton/alertsentinel.go @@ -143,8 +143,16 @@ func checkStatus() { } for _, server := range ServerList { // 监测点 + UserLock.RLock() + var role uint8 + if u, ok := UserInfoMap[server.UserID]; !ok { + role = model.RoleMember + } else { + role = u.Role + } + UserLock.RUnlock() alertsStore[alert.ID][server.ID] = append(alertsStore[alert. - ID][server.ID], alert.Snapshot(AlertsCycleTransferStatsStore[alert.ID], server, DB)) + ID][server.ID], alert.Snapshot(AlertsCycleTransferStatsStore[alert.ID], server, DB, role)) // 发送通知,分为触发报警和恢复通知 max, passed := alert.Check(alertsStore[alert.ID][server.ID]) // 保存当前服务器状态信息 diff --git a/service/singleton/crontask.go b/service/singleton/crontask.go index 0005ef7..8fc63e4 100644 --- a/service/singleton/crontask.go +++ b/service/singleton/crontask.go @@ -12,6 +12,7 @@ import ( "github.com/robfig/cron/v3" "github.com/nezhahq/nezha/model" + "github.com/nezhahq/nezha/pkg/utils" pb "github.com/nezhahq/nezha/proto" ) @@ -79,10 +80,7 @@ func UpdateCronList() { CronLock.RLock() defer CronLock.RUnlock() - CronList = make([]*model.Cron, 0, len(Crons)) - for _, c := range Crons { - CronList = append(CronList, c) - } + CronList = utils.MapValuesToSlice(Crons) slices.SortFunc(CronList, func(a, b *model.Cron) int { return cmp.Compare(a.ID, b.ID) }) diff --git a/service/singleton/ddns.go b/service/singleton/ddns.go index 7f35dab..ddbc040 100644 --- a/service/singleton/ddns.go +++ b/service/singleton/ddns.go @@ -13,6 +13,7 @@ import ( ddns2 "github.com/nezhahq/nezha/pkg/ddns" "github.com/nezhahq/nezha/pkg/ddns/dummy" "github.com/nezhahq/nezha/pkg/ddns/webhook" + "github.com/nezhahq/nezha/pkg/utils" ) var ( @@ -24,12 +25,10 @@ var ( func initDDNS() { DB.Find(&DDNSList) - DDNSCacheLock.Lock() DDNSCache = make(map[uint64]*model.DDNSProfile) for i := 0; i < len(DDNSList); i++ { DDNSCache[DDNSList[i].ID] = DDNSList[i] } - DDNSCacheLock.Unlock() OnNameserverUpdate() } @@ -56,10 +55,7 @@ func UpdateDDNSList() { DDNSListLock.Lock() defer DDNSListLock.Unlock() - DDNSList = make([]*model.DDNSProfile, 0, len(DDNSCache)) - for _, p := range DDNSCache { - DDNSList = append(DDNSList, p) - } + DDNSList = utils.MapValuesToSlice(DDNSCache) slices.SortFunc(DDNSList, func(a, b *model.DDNSProfile) int { return cmp.Compare(a.ID, b.ID) }) diff --git a/service/singleton/nat.go b/service/singleton/nat.go index 7ac2897..e6f0323 100644 --- a/service/singleton/nat.go +++ b/service/singleton/nat.go @@ -6,6 +6,7 @@ import ( "sync" "github.com/nezhahq/nezha/model" + "github.com/nezhahq/nezha/pkg/utils" ) var ( @@ -19,8 +20,6 @@ var ( func initNAT() { DB.Find(&NATList) - NATCacheRwLock.Lock() - defer NATCacheRwLock.Unlock() NATCache = make(map[string]*model.NAT) for i := 0; i < len(NATList); i++ { NATCache[NATList[i].Domain] = NATList[i] @@ -59,10 +58,7 @@ func UpdateNATList() { NATListLock.Lock() defer NATListLock.Unlock() - NATList = make([]*model.NAT, 0, len(NATCache)) - for _, n := range NATCache { - NATList = append(NATList, n) - } + NATList = utils.MapValuesToSlice(NATCache) slices.SortFunc(NATList, func(a, b *model.NAT) int { return cmp.Compare(a.ID, b.ID) }) diff --git a/service/singleton/notification.go b/service/singleton/notification.go index 498f000..cb63087 100644 --- a/service/singleton/notification.go +++ b/service/singleton/notification.go @@ -9,6 +9,7 @@ import ( "time" "github.com/nezhahq/nezha/model" + "github.com/nezhahq/nezha/pkg/utils" ) const ( @@ -30,7 +31,7 @@ var ( ) // InitNotification 初始化 GroupID <-> ID <-> Notification 的映射 -func InitNotification() { +func initNotification() { NotificationList = make(map[uint64]map[uint64]*model.Notification) NotificationIDToGroups = make(map[uint64]map[uint64]struct{}) NotificationGroup = make(map[uint64]string) @@ -38,9 +39,7 @@ func InitNotification() { // loadNotifications 从 DB 初始化通知方式相关参数 func loadNotifications() { - InitNotification() - NotificationsLock.Lock() - + initNotification() groupNotifications := make(map[uint64][]uint64) var ngn []model.NotificationGroupNotification if err := DB.Find(&ngn).Error; err != nil { @@ -74,8 +73,6 @@ func loadNotifications() { } } } - - NotificationsLock.Unlock() } func UpdateNotificationList() { @@ -85,10 +82,7 @@ func UpdateNotificationList() { NotificationSortedLock.Lock() defer NotificationSortedLock.Unlock() - NotificationListSorted = make([]*model.Notification, 0, len(NotificationMap)) - for _, n := range NotificationMap { - NotificationListSorted = append(NotificationListSorted, n) - } + NotificationListSorted = utils.MapValuesToSlice(NotificationMap) slices.SortFunc(NotificationListSorted, func(a, b *model.Notification) int { return cmp.Compare(a.ID, b.ID) }) diff --git a/service/singleton/server.go b/service/singleton/server.go index 9450d9f..c6afb32 100644 --- a/service/singleton/server.go +++ b/service/singleton/server.go @@ -1,10 +1,12 @@ package singleton import ( - "sort" + "cmp" + "slices" "sync" "github.com/nezhahq/nezha/model" + "github.com/nezhahq/nezha/pkg/utils" ) var ( @@ -45,29 +47,21 @@ func ReSortServer() { SortedServerLock.Lock() defer SortedServerLock.Unlock() - SortedServerList = make([]*model.Server, 0, len(ServerList)) - SortedServerListForGuest = make([]*model.Server, 0) - for _, s := range ServerList { - SortedServerList = append(SortedServerList, s) + SortedServerList = utils.MapValuesToSlice(ServerList) + // 按照服务器 ID 排序的具体实现(ID越大越靠前) + slices.SortStableFunc(SortedServerList, func(a, b *model.Server) int { + if a.DisplayIndex == b.DisplayIndex { + return cmp.Compare(a.ID, b.ID) + } + return cmp.Compare(b.DisplayIndex, a.DisplayIndex) + }) + + SortedServerListForGuest = make([]*model.Server, 0, len(SortedServerList)) + for _, s := range SortedServerList { if !s.HideForGuest { SortedServerListForGuest = append(SortedServerListForGuest, s) } } - - // 按照服务器 ID 排序的具体实现(ID越大越靠前) - sort.SliceStable(SortedServerList, func(i, j int) bool { - if SortedServerList[i].DisplayIndex == SortedServerList[j].DisplayIndex { - return SortedServerList[i].ID < SortedServerList[j].ID - } - return SortedServerList[i].DisplayIndex > SortedServerList[j].DisplayIndex - }) - - sort.SliceStable(SortedServerListForGuest, func(i, j int) bool { - if SortedServerListForGuest[i].DisplayIndex == SortedServerListForGuest[j].DisplayIndex { - return SortedServerListForGuest[i].ID < SortedServerListForGuest[j].ID - } - return SortedServerListForGuest[i].DisplayIndex > SortedServerListForGuest[j].DisplayIndex - }) } func OnServerDelete(sid []uint64) { diff --git a/service/singleton/servicesentinel.go b/service/singleton/servicesentinel.go index 2ef5c56..8cbfd1b 100644 --- a/service/singleton/servicesentinel.go +++ b/service/singleton/servicesentinel.go @@ -11,6 +11,7 @@ import ( "github.com/jinzhu/copier" "github.com/nezhahq/nezha/model" + "github.com/nezhahq/nezha/pkg/utils" pb "github.com/nezhahq/nezha/proto" ) @@ -174,11 +175,7 @@ func (ss *ServiceSentinel) UpdateServiceList() { ss.ServiceListLock.Lock() defer ss.ServiceListLock.Unlock() - ss.ServiceList = make([]*model.Service, 0, len(ss.Services)) - for _, v := range ss.Services { - ss.ServiceList = append(ss.ServiceList, v) - } - + ss.ServiceList = utils.MapValuesToSlice(ss.Services) slices.SortFunc(ss.ServiceList, func(a, b *model.Service) int { return cmp.Compare(a.ID, b.ID) }) @@ -192,13 +189,6 @@ func (ss *ServiceSentinel) loadServiceHistory() { panic(err) } - ss.serviceResponseDataStoreLock.Lock() - defer ss.serviceResponseDataStoreLock.Unlock() - ss.monthlyStatusLock.Lock() - defer ss.monthlyStatusLock.Unlock() - ss.ServicesLock.Lock() - defer ss.ServicesLock.Unlock() - for i := 0; i < len(services); i++ { task := *services[i] // 通过cron定时将服务监控任务传递给任务调度管道 diff --git a/service/singleton/singleton.go b/service/singleton/singleton.go index ed2c564..d1a5f0f 100644 --- a/service/singleton/singleton.go +++ b/service/singleton/singleton.go @@ -42,6 +42,7 @@ func InitTimezoneAndCache() { // LoadSingleton 加载子服务并执行 func LoadSingleton() { + initUser() // 加载用户ID绑定表 initI18n() // 加载本地化服务 loadNotifications() // 加载通知服务 loadServers() // 加载服务器列表 @@ -81,8 +82,8 @@ func InitDBFromPath(path string) { } err = DB.AutoMigrate(model.Server{}, model.User{}, model.ServerGroup{}, model.NotificationGroup{}, model.Notification{}, model.AlertRule{}, model.Service{}, model.NotificationGroupNotification{}, - model.ServiceHistory{}, model.Cron{}, model.Transfer{}, model.ServerGroupServer{}, model.UserGroup{}, - model.UserGroupUser{}, model.NAT{}, model.DDNSProfile{}, model.NotificationGroupNotification{}, + model.ServiceHistory{}, model.Cron{}, model.Transfer{}, model.ServerGroupServer{}, + model.NAT{}, model.DDNSProfile{}, model.NotificationGroupNotification{}, model.WAF{}) if err != nil { panic(err) diff --git a/service/singleton/user.go b/service/singleton/user.go new file mode 100644 index 0000000..ab883d6 --- /dev/null +++ b/service/singleton/user.go @@ -0,0 +1,135 @@ +package singleton + +import ( + "sync" + + "github.com/nezhahq/nezha/model" + "gorm.io/gorm" +) + +var ( + UserInfoMap map[uint64]model.UserInfo + AgentSecretToUserId map[string]uint64 + + UserLock sync.RWMutex +) + +func initUser() { + UserInfoMap = make(map[uint64]model.UserInfo) + AgentSecretToUserId = make(map[string]uint64) + + var users []model.User + DB.Find(&users) + + for _, u := range users { + UserInfoMap[u.ID] = model.UserInfo{ + Role: u.Role, + AgentSecret: u.AgentSecret, + } + AgentSecretToUserId[u.AgentSecret] = u.ID + } +} + +func OnUserUpdate(u *model.User) { + UserLock.Lock() + defer UserLock.Unlock() + + if u == nil { + return + } + + UserInfoMap[u.ID] = model.UserInfo{ + Role: u.Role, + AgentSecret: u.AgentSecret, + } + AgentSecretToUserId[u.AgentSecret] = u.ID +} + +func OnUserDelete(id []uint64, errorFunc func(string, ...interface{}) error) error { + UserLock.Lock() + defer UserLock.Unlock() + + if len(id) < 1 { + return Localizer.ErrorT("user id not specified") + } + + var ( + cron, server bool + crons, servers []uint64 + ) + + for _, uid := range id { + err := DB.Transaction(func(tx *gorm.DB) error { + CronLock.RLock() + crons = model.FindByUserID(CronList, uid) + CronLock.RUnlock() + + cron = len(crons) > 0 + if cron { + if err := tx.Unscoped().Delete(&model.Cron{}, "id in (?)", crons).Error; err != nil { + return err + } + } + + SortedServerLock.RLock() + servers = model.FindByUserID(SortedServerList, uid) + SortedServerLock.RUnlock() + + server = len(servers) > 0 + if server { + if err := tx.Unscoped().Delete(&model.Server{}, "id in (?)", servers).Error; err != nil { + return err + } + if err := tx.Unscoped().Delete(&model.ServerGroupServer{}, "server_id in (?)", servers).Error; err != nil { + return err + } + } + + if err := tx.Unscoped().Delete(&model.Transfer{}, "server_id in (?)", servers).Error; err != nil { + return err + } + + if err := tx.Where("id IN (?)", id).Delete(&model.User{}).Error; err != nil { + return err + } + return nil + }) + + if err != nil { + return errorFunc("%v", err) + } + + if cron { + OnDeleteCron(crons) + } + + if server { + AlertsLock.Lock() + for _, sid := range servers { + for _, alert := range Alerts { + if AlertsCycleTransferStatsStore[alert.ID] != nil { + delete(AlertsCycleTransferStatsStore[alert.ID].ServerName, sid) + delete(AlertsCycleTransferStatsStore[alert.ID].Transfer, sid) + delete(AlertsCycleTransferStatsStore[alert.ID].NextUpdate, sid) + } + } + } + AlertsLock.Unlock() + OnServerDelete(servers) + } + + secret := UserInfoMap[uid].AgentSecret + delete(AgentSecretToUserId, secret) + delete(UserInfoMap, uid) + } + + if cron { + UpdateCronList() + } + + if server { + ReSortServer() + } + + return nil +}