Compare commits

..

3 Commits

Author SHA1 Message Date
UUBulb
f127ff8141
Merge a5dbc5693d into d835aeb486 2024-12-18 09:26:34 +00:00
uubulb
a5dbc5693d feat: user-specific connection secret 2024-12-18 17:26:24 +08:00
uubulb
4754d283c9 update 2024-12-18 15:29:46 +08:00
16 changed files with 108 additions and 46 deletions

View File

@ -79,7 +79,7 @@ func routers(r *gin.Engine, frontendDist fs.FS) {
auth.GET("/profile", commonHandler(getProfile))
auth.POST("/profile", commonHandler(updateProfile))
auth.GET("/user", commonHandler(listUser))
auth.GET("/user", adminHandler(listUser))
auth.POST("/user", adminHandler(createUser))
auth.POST("/batch-delete/user", adminHandler(batchDeleteUser))
@ -97,7 +97,7 @@ func routers(r *gin.Engine, frontendDist fs.FS) {
auth.PATCH("/notification-group/:id", commonHandler(updateNotificationGroup))
auth.POST("/batch-delete/notification-group", commonHandler(batchDeleteNotificationGroup))
auth.GET("/server", listHandler(listServer))
auth.GET("/server", commonHandler(listServer))
auth.PATCH("/server/:id", commonHandler(updateServer))
auth.POST("/batch-delete/server", commonHandler(batchDeleteServer))
auth.POST("/force-update/server", commonHandler(forceUpdateServer))
@ -243,13 +243,13 @@ func listHandler[S ~[]E, E model.CommonInterface](handler handlerFunc[S]) func(*
return
}
c.JSON(http.StatusOK, filter(c, data))
c.JSON(http.StatusOK, model.CommonResponse[S]{Success: true, Data: filter(c, data)})
}
}
func filter[S ~[]E, E model.CommonInterface](ctx *gin.Context, s S) S {
return slices.DeleteFunc(s, func(e E) bool {
return e.HasPermission(ctx)
return !e.HasPermission(ctx)
})
}

View File

@ -168,6 +168,10 @@ func manualTriggerCron(c *gin.Context) (any, error) {
}
singleton.CronLock.RUnlock()
if !cr.HasPermission(c) {
return nil, singleton.Localizer.ErrorT("permission denied")
}
singleton.ManualTrigger(cr)
return nil, nil
}

View File

@ -201,7 +201,7 @@ func batchDeleteNotificationGroup(c *gin.Context) (any, error) {
}
var ng []model.NotificationGroup
if err := singleton.DB.Where("id in (?)", ng).Find(&ng).Error; err != nil {
if err := singleton.DB.Where("id in (?)", ngnr).Find(&ng).Error; err != nil {
return nil, err
}

View File

@ -178,6 +178,9 @@ func forceUpdateServer(c *gin.Context) (*model.ForceUpdateResponse, error) {
server := singleton.ServerList[sid]
singleton.ServerLock.RUnlock()
if server != nil && server.TaskStream != nil {
if !server.HasPermission(c) {
return nil, singleton.Localizer.ErrorT("permission denied")
}
if err := server.TaskStream.Send(&pb.Task{
Type: model.TaskTypeUpgrade,
}); err != nil {

View File

@ -126,6 +126,7 @@ func createUser(c *gin.Context) (uint64, error) {
return 0, err
}
singleton.OnUserUpdate(&u)
return u.ID, nil
}
@ -150,5 +151,6 @@ func batchDeleteUser(c *gin.Context) (any, error) {
return nil, singleton.Localizer.ErrorT("can't delete yourself")
}
singleton.OnUserDelete(ids)
return nil, singleton.DB.Where("id IN (?)", ids).Delete(&model.User{}).Error
}

View File

@ -20,7 +20,7 @@ type Common struct {
// Do not use soft deletion
// DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at,omitempty"`
UserID uint64 `json:"user_id,omitempty"`
UserID uint64 `json:"-"`
}
func (c *Common) GetID() uint64 {

View File

@ -1,5 +1,10 @@
package model
import (
"github.com/nezhahq/nezha/pkg/utils"
"gorm.io/gorm"
)
const (
RoleAdmin uint8 = iota
RoleMember
@ -7,9 +12,20 @@ const (
type User struct {
Common
Username string `json:"username,omitempty" gorm:"uniqueIndex"`
Password string `json:"password,omitempty" gorm:"type:char(72)"`
Role uint8 `json:"role,omitempty"`
Username string `json:"username,omitempty" gorm:"uniqueIndex"`
Password string `json:"password,omitempty" gorm:"type:char(72)"`
Role uint8 `json:"role,omitempty"`
AgentSecret string `json:"agent_secret,omitempty" gorm:"type:char(32)"`
}
func (u *User) BeforeSave(tx *gorm.DB) error {
key, err := utils.GenerateRandomString(32)
if err != nil {
return err
}
u.AgentSecret = key
return nil
}
type Profile struct {

View File

@ -1,6 +0,0 @@
package model
type UserGroup struct {
Common
Name string `json:"name"`
}

View File

@ -1,7 +0,0 @@
package model
type UserGroupUser struct {
Common
UserGroupId uint64 `json:"user_group_id"`
UserId uint64 `json:"user_id"`
}

View File

@ -2,6 +2,7 @@ package rpc
import (
"context"
"crypto/subtle"
"strings"
petname "github.com/dustinkirkland/golang-petname"
@ -36,10 +37,14 @@ func (a *authHandler) Check(ctx context.Context) (uint64, error) {
ip, _ := ctx.Value(model.CtxKeyRealIP{}).(string)
if clientSecret != singleton.Conf.AgentSecretKey {
singleton.UserLock.RLock()
userId, ok := singleton.AgentSecretToUserId[clientSecret]
if !ok && subtle.ConstantTimeCompare([]byte(clientSecret), []byte(singleton.Conf.AgentSecretKey)) != 1 {
singleton.UserLock.RUnlock()
model.BlockIP(singleton.DB, ip, model.WAFBlockReasonTypeAgentAuthFail)
return 0, status.Error(codes.Unauthenticated, "客户端认证失败")
}
singleton.UserLock.RUnlock()
model.ClearIP(singleton.DB, ip)
@ -53,21 +58,26 @@ func (a *authHandler) Check(ctx context.Context) (uint64, error) {
}
singleton.ServerLock.RLock()
defer singleton.ServerLock.RUnlock()
clientID, hasID := singleton.ServerUUIDToID[clientUUID]
singleton.ServerLock.RUnlock()
if !hasID {
s := model.Server{UUID: clientUUID, Name: petname.Generate(2, "-")}
s := model.Server{UUID: clientUUID, Name: petname.Generate(2, "-"), Common: model.Common{
UserID: userId,
}}
if err := singleton.DB.Create(&s).Error; err != nil {
return 0, status.Error(codes.Unauthenticated, err.Error())
}
s.Host = &model.Host{}
s.State = &model.HostState{}
s.GeoIP = &model.GeoIP{}
// generate a random silly server name
singleton.ServerLock.Lock()
singleton.ServerList[s.ID] = &s
singleton.ServerUUIDToID[clientUUID] = s.ID
singleton.ServerLock.Unlock()
singleton.ReSortServer()
clientID = s.ID
}

View File

@ -24,12 +24,10 @@ var (
func initDDNS() {
DB.Find(&DDNSList)
DDNSCacheLock.Lock()
DDNSCache = make(map[uint64]*model.DDNSProfile)
for i := 0; i < len(DDNSList); i++ {
DDNSCache[DDNSList[i].ID] = DDNSList[i]
}
DDNSCacheLock.Unlock()
OnNameserverUpdate()
}

View File

@ -19,8 +19,6 @@ var (
func initNAT() {
DB.Find(&NATList)
NATCacheRwLock.Lock()
defer NATCacheRwLock.Unlock()
NATCache = make(map[string]*model.NAT)
for i := 0; i < len(NATList); i++ {
NATCache[NATList[i].Domain] = NATList[i]

View File

@ -30,7 +30,7 @@ var (
)
// InitNotification 初始化 GroupID <-> ID <-> Notification 的映射
func InitNotification() {
func initNotification() {
NotificationList = make(map[uint64]map[uint64]*model.Notification)
NotificationIDToGroups = make(map[uint64]map[uint64]struct{})
NotificationGroup = make(map[uint64]string)
@ -38,9 +38,7 @@ func InitNotification() {
// loadNotifications 从 DB 初始化通知方式相关参数
func loadNotifications() {
InitNotification()
NotificationsLock.Lock()
initNotification()
groupNotifications := make(map[uint64][]uint64)
var ngn []model.NotificationGroupNotification
if err := DB.Find(&ngn).Error; err != nil {
@ -74,8 +72,6 @@ func loadNotifications() {
}
}
}
NotificationsLock.Unlock()
}
func UpdateNotificationList() {

View File

@ -192,13 +192,6 @@ func (ss *ServiceSentinel) loadServiceHistory() {
panic(err)
}
ss.serviceResponseDataStoreLock.Lock()
defer ss.serviceResponseDataStoreLock.Unlock()
ss.monthlyStatusLock.Lock()
defer ss.monthlyStatusLock.Unlock()
ss.ServicesLock.Lock()
defer ss.ServicesLock.Unlock()
for i := 0; i < len(services); i++ {
task := *services[i]
// 通过cron定时将服务监控任务传递给任务调度管道

View File

@ -40,6 +40,7 @@ func InitTimezoneAndCache() {
// LoadSingleton 加载子服务并执行
func LoadSingleton() {
initUser() // 加载用户ID绑定表
initI18n() // 加载本地化服务
loadNotifications() // 加载通知服务
loadServers() // 加载服务器列表
@ -79,8 +80,8 @@ func InitDBFromPath(path string) {
}
err = DB.AutoMigrate(model.Server{}, model.User{}, model.ServerGroup{}, model.NotificationGroup{},
model.Notification{}, model.AlertRule{}, model.Service{}, model.NotificationGroupNotification{},
model.ServiceHistory{}, model.Cron{}, model.Transfer{}, model.ServerGroupServer{}, model.UserGroup{},
model.UserGroupUser{}, model.NAT{}, model.DDNSProfile{}, model.NotificationGroupNotification{},
model.ServiceHistory{}, model.Cron{}, model.Transfer{}, model.ServerGroupServer{},
model.NAT{}, model.DDNSProfile{}, model.NotificationGroupNotification{},
model.WAF{})
if err != nil {
panic(err)

54
service/singleton/user.go Normal file
View File

@ -0,0 +1,54 @@
package singleton
import (
"sync"
"github.com/nezhahq/nezha/model"
)
var (
UserIdToAgentSecret map[uint64]string
AgentSecretToUserId map[string]uint64
UserLock sync.RWMutex
)
func initUser() {
UserIdToAgentSecret = make(map[uint64]string)
AgentSecretToUserId = make(map[string]uint64)
var users []model.User
DB.Find(&users)
for _, u := range users {
UserIdToAgentSecret[u.ID] = u.AgentSecret
AgentSecretToUserId[u.AgentSecret] = u.ID
}
}
func OnUserUpdate(u *model.User) {
UserLock.Lock()
defer UserLock.Unlock()
if u == nil {
return
}
UserIdToAgentSecret[u.ID] = u.AgentSecret
AgentSecretToUserId[u.AgentSecret] = u.ID
}
func OnUserDelete(id []uint64) {
UserLock.Lock()
defer UserLock.Unlock()
if len(id) < 1 {
return
}
for _, uid := range id {
secret := UserIdToAgentSecret[uid]
delete(AgentSecretToUserId, secret)
delete(UserIdToAgentSecret, uid)
}
}