diff --git a/cmd/dashboard/controller/controller.go b/cmd/dashboard/controller/controller.go index 37b08b9..af5dca7 100644 --- a/cmd/dashboard/controller/controller.go +++ b/cmd/dashboard/controller/controller.go @@ -79,7 +79,7 @@ func routers(r *gin.Engine, frontendDist fs.FS) { auth.GET("/profile", commonHandler(getProfile)) auth.POST("/profile", commonHandler(updateProfile)) - auth.GET("/user", commonHandler(listUser)) + auth.GET("/user", adminHandler(listUser)) auth.POST("/user", adminHandler(createUser)) auth.POST("/batch-delete/user", adminHandler(batchDeleteUser)) diff --git a/cmd/dashboard/controller/user.go b/cmd/dashboard/controller/user.go index 52f9d70..12b236f 100644 --- a/cmd/dashboard/controller/user.go +++ b/cmd/dashboard/controller/user.go @@ -126,6 +126,7 @@ func createUser(c *gin.Context) (uint64, error) { return 0, err } + singleton.OnUserUpdate(&u) return u.ID, nil } @@ -150,5 +151,6 @@ func batchDeleteUser(c *gin.Context) (any, error) { return nil, singleton.Localizer.ErrorT("can't delete yourself") } + singleton.OnUserDelete(ids) return nil, singleton.DB.Where("id IN (?)", ids).Delete(&model.User{}).Error } diff --git a/model/common.go b/model/common.go index 6b83911..05344f2 100644 --- a/model/common.go +++ b/model/common.go @@ -20,7 +20,7 @@ type Common struct { // Do not use soft deletion // DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at,omitempty"` - UserID uint64 `json:"user_id,omitempty"` + UserID uint64 `json:"-"` } func (c *Common) GetID() uint64 { diff --git a/model/user.go b/model/user.go index fe8a1ed..3e987d3 100644 --- a/model/user.go +++ b/model/user.go @@ -1,5 +1,10 @@ package model +import ( + "github.com/nezhahq/nezha/pkg/utils" + "gorm.io/gorm" +) + const ( RoleAdmin uint8 = iota RoleMember @@ -7,9 +12,20 @@ const ( type User struct { Common - Username string `json:"username,omitempty" gorm:"uniqueIndex"` - Password string `json:"password,omitempty" gorm:"type:char(72)"` - Role uint8 `json:"role,omitempty"` + Username string `json:"username,omitempty" gorm:"uniqueIndex"` + Password string `json:"password,omitempty" gorm:"type:char(72)"` + Role uint8 `json:"role,omitempty"` + AgentSecret string `json:"agent_secret,omitempty" gorm:"type:char(32)"` +} + +func (u *User) BeforeSave(tx *gorm.DB) error { + key, err := utils.GenerateRandomString(32) + if err != nil { + return err + } + + u.AgentSecret = key + return nil } type Profile struct { diff --git a/service/rpc/auth.go b/service/rpc/auth.go index 1709168..817f981 100644 --- a/service/rpc/auth.go +++ b/service/rpc/auth.go @@ -2,6 +2,7 @@ package rpc import ( "context" + "crypto/subtle" "strings" petname "github.com/dustinkirkland/golang-petname" @@ -36,10 +37,14 @@ func (a *authHandler) Check(ctx context.Context) (uint64, error) { ip, _ := ctx.Value(model.CtxKeyRealIP{}).(string) - if clientSecret != singleton.Conf.AgentSecretKey { + singleton.UserLock.RLock() + userId, ok := singleton.AgentSecretToUserId[clientSecret] + if !ok && subtle.ConstantTimeCompare([]byte(clientSecret), []byte(singleton.Conf.AgentSecretKey)) != 1 { + singleton.UserLock.RUnlock() model.BlockIP(singleton.DB, ip, model.WAFBlockReasonTypeAgentAuthFail) return 0, status.Error(codes.Unauthenticated, "客户端认证失败") } + singleton.UserLock.RUnlock() model.ClearIP(singleton.DB, ip) @@ -53,21 +58,26 @@ func (a *authHandler) Check(ctx context.Context) (uint64, error) { } singleton.ServerLock.RLock() - defer singleton.ServerLock.RUnlock() - clientID, hasID := singleton.ServerUUIDToID[clientUUID] + singleton.ServerLock.RUnlock() + if !hasID { - s := model.Server{UUID: clientUUID, Name: petname.Generate(2, "-")} + s := model.Server{UUID: clientUUID, Name: petname.Generate(2, "-"), Common: model.Common{ + UserID: userId, + }} if err := singleton.DB.Create(&s).Error; err != nil { return 0, status.Error(codes.Unauthenticated, err.Error()) } s.Host = &model.Host{} s.State = &model.HostState{} s.GeoIP = &model.GeoIP{} - // generate a random silly server name + + singleton.ServerLock.Lock() singleton.ServerList[s.ID] = &s singleton.ServerUUIDToID[clientUUID] = s.ID + singleton.ServerLock.Unlock() singleton.ReSortServer() + clientID = s.ID } diff --git a/service/singleton/ddns.go b/service/singleton/ddns.go index 7f35dab..9a196a1 100644 --- a/service/singleton/ddns.go +++ b/service/singleton/ddns.go @@ -24,12 +24,10 @@ var ( func initDDNS() { DB.Find(&DDNSList) - DDNSCacheLock.Lock() DDNSCache = make(map[uint64]*model.DDNSProfile) for i := 0; i < len(DDNSList); i++ { DDNSCache[DDNSList[i].ID] = DDNSList[i] } - DDNSCacheLock.Unlock() OnNameserverUpdate() } diff --git a/service/singleton/nat.go b/service/singleton/nat.go index 7ac2897..cacc4a7 100644 --- a/service/singleton/nat.go +++ b/service/singleton/nat.go @@ -19,8 +19,6 @@ var ( func initNAT() { DB.Find(&NATList) - NATCacheRwLock.Lock() - defer NATCacheRwLock.Unlock() NATCache = make(map[string]*model.NAT) for i := 0; i < len(NATList); i++ { NATCache[NATList[i].Domain] = NATList[i] diff --git a/service/singleton/notification.go b/service/singleton/notification.go index 498f000..caf1cc4 100644 --- a/service/singleton/notification.go +++ b/service/singleton/notification.go @@ -30,7 +30,7 @@ var ( ) // InitNotification 初始化 GroupID <-> ID <-> Notification 的映射 -func InitNotification() { +func initNotification() { NotificationList = make(map[uint64]map[uint64]*model.Notification) NotificationIDToGroups = make(map[uint64]map[uint64]struct{}) NotificationGroup = make(map[uint64]string) @@ -38,9 +38,7 @@ func InitNotification() { // loadNotifications 从 DB 初始化通知方式相关参数 func loadNotifications() { - InitNotification() - NotificationsLock.Lock() - + initNotification() groupNotifications := make(map[uint64][]uint64) var ngn []model.NotificationGroupNotification if err := DB.Find(&ngn).Error; err != nil { @@ -74,8 +72,6 @@ func loadNotifications() { } } } - - NotificationsLock.Unlock() } func UpdateNotificationList() { diff --git a/service/singleton/servicesentinel.go b/service/singleton/servicesentinel.go index 2ef5c56..fe7abe6 100644 --- a/service/singleton/servicesentinel.go +++ b/service/singleton/servicesentinel.go @@ -192,13 +192,6 @@ func (ss *ServiceSentinel) loadServiceHistory() { panic(err) } - ss.serviceResponseDataStoreLock.Lock() - defer ss.serviceResponseDataStoreLock.Unlock() - ss.monthlyStatusLock.Lock() - defer ss.monthlyStatusLock.Unlock() - ss.ServicesLock.Lock() - defer ss.ServicesLock.Unlock() - for i := 0; i < len(services); i++ { task := *services[i] // 通过cron定时将服务监控任务传递给任务调度管道 diff --git a/service/singleton/singleton.go b/service/singleton/singleton.go index 8596bcc..c284e91 100644 --- a/service/singleton/singleton.go +++ b/service/singleton/singleton.go @@ -40,6 +40,7 @@ func InitTimezoneAndCache() { // LoadSingleton 加载子服务并执行 func LoadSingleton() { + initUser() // 加载用户ID绑定表 initI18n() // 加载本地化服务 loadNotifications() // 加载通知服务 loadServers() // 加载服务器列表 diff --git a/service/singleton/user.go b/service/singleton/user.go new file mode 100644 index 0000000..763344e --- /dev/null +++ b/service/singleton/user.go @@ -0,0 +1,54 @@ +package singleton + +import ( + "sync" + + "github.com/nezhahq/nezha/model" +) + +var ( + UserIdToAgentSecret map[uint64]string + AgentSecretToUserId map[string]uint64 + + UserLock sync.RWMutex +) + +func initUser() { + UserIdToAgentSecret = make(map[uint64]string) + AgentSecretToUserId = make(map[string]uint64) + + var users []model.User + DB.Find(&users) + + for _, u := range users { + UserIdToAgentSecret[u.ID] = u.AgentSecret + AgentSecretToUserId[u.AgentSecret] = u.ID + } +} + +func OnUserUpdate(u *model.User) { + UserLock.Lock() + defer UserLock.Unlock() + + if u == nil { + return + } + + UserIdToAgentSecret[u.ID] = u.AgentSecret + AgentSecretToUserId[u.AgentSecret] = u.ID +} + +func OnUserDelete(id []uint64) { + UserLock.Lock() + defer UserLock.Unlock() + + if len(id) < 1 { + return + } + + for _, uid := range id { + secret := UserIdToAgentSecret[uid] + delete(AgentSecretToUserId, secret) + delete(UserIdToAgentSecret, uid) + } +}