Compare commits

..

No commits in common. "4af7e8300435c798a6c0c84f389abaaaa5faa9f2" and "50ee62172fd50872dedff60c01eaf4eb08e1749d" have entirely different histories.

39 changed files with 193 additions and 1009 deletions

View File

@ -50,9 +50,6 @@ func createAlertRule(c *gin.Context) (uint64, error) {
return 0, err
}
uid := getUid(c)
r.UserID = uid
r.Name = arf.Name
r.Rules = arf.Rules
r.FailTriggerTasks = arf.FailTriggerTasks
@ -62,7 +59,7 @@ func createAlertRule(c *gin.Context) (uint64, error) {
r.TriggerMode = arf.TriggerMode
r.Enable = &enable
if err := validateRule(c, &r); err != nil {
if err := validateRule(&r); err != nil {
return 0, err
}
@ -103,10 +100,6 @@ func updateAlertRule(c *gin.Context) (any, error) {
return nil, singleton.Localizer.ErrorT("alert id %d does not exist", id)
}
if !r.HasPermission(c) {
return nil, singleton.Localizer.ErrorT("permission denied")
}
r.Name = arf.Name
r.Rules = arf.Rules
r.FailTriggerTasks = arf.FailTriggerTasks
@ -116,7 +109,7 @@ func updateAlertRule(c *gin.Context) (any, error) {
r.TriggerMode = arf.TriggerMode
r.Enable = &enable
if err := validateRule(c, &r); err != nil {
if err := validateRule(&r); err != nil {
return 0, err
}
@ -141,21 +134,11 @@ func updateAlertRule(c *gin.Context) (any, error) {
// @Router /batch-delete/alert-rule [post]
func batchDeleteAlertRule(c *gin.Context) (any, error) {
var ar []uint64
if err := c.ShouldBindJSON(&ar); err != nil {
return nil, err
}
var ars []model.AlertRule
if err := singleton.DB.Where("id in (?)", ar).Find(&ars).Error; err != nil {
return nil, err
}
for _, a := range ars {
if !a.HasPermission(c) {
return nil, singleton.Localizer.ErrorT("permission denied")
}
}
if err := singleton.DB.Unscoped().Delete(&model.AlertRule{}, "id in (?)", ar).Error; err != nil {
return nil, newGormError("%v", err)
}
@ -164,20 +147,9 @@ func batchDeleteAlertRule(c *gin.Context) (any, error) {
return nil, nil
}
func validateRule(c *gin.Context, r *model.AlertRule) error {
func validateRule(r *model.AlertRule) error {
if len(r.Rules) > 0 {
for _, rule := range r.Rules {
singleton.ServerLock.RLock()
for s := range rule.Ignore {
if server, ok := singleton.ServerList[s]; ok {
if !server.HasPermission(c) {
singleton.ServerLock.RUnlock()
return singleton.Localizer.ErrorT("permission denied")
}
}
}
singleton.ServerLock.RUnlock()
if !rule.IsTransferDurationRule() {
if rule.Duration < 3 {
return singleton.Localizer.ErrorT("duration need to be at least 3")

View File

@ -9,7 +9,6 @@ import (
"net/http"
"os"
"path"
"slices"
"strings"
jwt "github.com/appleboy/gin-jwt/v2"
@ -79,11 +78,11 @@ func routers(r *gin.Engine, frontendDist fs.FS) {
auth.GET("/profile", commonHandler(getProfile))
auth.POST("/profile", commonHandler(updateProfile))
auth.GET("/user", adminHandler(listUser))
auth.POST("/user", adminHandler(createUser))
auth.POST("/batch-delete/user", adminHandler(batchDeleteUser))
auth.GET("/user", commonHandler(listUser))
auth.POST("/user", commonHandler(createUser))
auth.POST("/batch-delete/user", commonHandler(batchDeleteUser))
auth.GET("/service/list", listHandler(listService))
auth.GET("/service/list", commonHandler(listService))
auth.POST("/service", commonHandler(createService))
auth.PATCH("/service/:id", commonHandler(updateService))
auth.POST("/batch-delete/service", commonHandler(batchDeleteService))
@ -97,45 +96,42 @@ func routers(r *gin.Engine, frontendDist fs.FS) {
auth.PATCH("/notification-group/:id", commonHandler(updateNotificationGroup))
auth.POST("/batch-delete/notification-group", commonHandler(batchDeleteNotificationGroup))
auth.GET("/server", listHandler(listServer))
auth.GET("/server", commonHandler(listServer))
auth.PATCH("/server/:id", commonHandler(updateServer))
auth.POST("/batch-delete/server", commonHandler(batchDeleteServer))
auth.POST("/force-update/server", commonHandler(forceUpdateServer))
auth.GET("/notification", listHandler(listNotification))
auth.GET("/notification", commonHandler(listNotification))
auth.POST("/notification", commonHandler(createNotification))
auth.PATCH("/notification/:id", commonHandler(updateNotification))
auth.POST("/batch-delete/notification", commonHandler(batchDeleteNotification))
auth.GET("/alert-rule", listHandler(listAlertRule))
auth.GET("/alert-rule", commonHandler(listAlertRule))
auth.POST("/alert-rule", commonHandler(createAlertRule))
auth.PATCH("/alert-rule/:id", commonHandler(updateAlertRule))
auth.POST("/batch-delete/alert-rule", commonHandler(batchDeleteAlertRule))
auth.GET("/cron", listHandler(listCron))
auth.GET("/cron", commonHandler(listCron))
auth.POST("/cron", commonHandler(createCron))
auth.PATCH("/cron/:id", commonHandler(updateCron))
auth.GET("/cron/:id/manual", commonHandler(manualTriggerCron))
auth.POST("/batch-delete/cron", commonHandler(batchDeleteCron))
auth.GET("/ddns", listHandler(listDDNS))
auth.GET("/ddns", commonHandler(listDDNS))
auth.GET("/ddns/providers", commonHandler(listProviders))
auth.POST("/ddns", commonHandler(createDDNS))
auth.PATCH("/ddns/:id", commonHandler(updateDDNS))
auth.POST("/batch-delete/ddns", commonHandler(batchDeleteDDNS))
auth.GET("/nat", listHandler(listNAT))
auth.GET("/nat", commonHandler(listNAT))
auth.POST("/nat", commonHandler(createNAT))
auth.PATCH("/nat/:id", commonHandler(updateNAT))
auth.POST("/batch-delete/nat", commonHandler(batchDeleteNAT))
auth.GET("/waf", pCommonHandler(listBlockedAddress))
auth.POST("/batch-delete/waf", adminHandler(batchDeleteBlockedAddress))
auth.GET("/waf", commonHandler(listBlockedAddress))
auth.POST("/batch-delete/waf", commonHandler(batchDeleteBlockedAddress))
auth.GET("/online-user", pCommonHandler(listOnlineUser))
auth.GET("/online-user/batch-block", adminHandler(batchBlockOnlineUser))
auth.PATCH("/setting", adminHandler(updateConfig))
auth.PATCH("/setting", commonHandler(updateConfig))
r.NoRoute(fallbackToFrontend(frontendDist))
}
@ -156,7 +152,6 @@ func newErrorResponse(err error) model.CommonResponse[any] {
}
type handlerFunc[T any] func(c *gin.Context) (T, error)
type pHandlerFunc[S ~[]E, E any] func(c *gin.Context) (*model.Value[S], error)
// There are many error types in gorm, so create a custom type to represent all
// gorm errors here instead
@ -193,87 +188,30 @@ func (we *wsError) Error() string {
}
func commonHandler[T any](handler handlerFunc[T]) func(*gin.Context) {
return func(c *gin.Context) {
handle(c, handler)
}
}
func adminHandler[T any](handler handlerFunc[T]) func(*gin.Context) {
return func(c *gin.Context) {
auth, ok := c.Get(model.CtxKeyAuthorizedUser)
if !ok {
c.JSON(http.StatusOK, newErrorResponse(singleton.Localizer.ErrorT("unauthorized")))
return
}
user := *auth.(*model.User)
if user.Role != model.RoleAdmin {
c.JSON(http.StatusOK, newErrorResponse(singleton.Localizer.ErrorT("permission denied")))
return
}
handle(c, handler)
}
}
func handle[T any](c *gin.Context, handler handlerFunc[T]) {
data, err := handler(c)
if err == nil {
c.JSON(http.StatusOK, model.CommonResponse[T]{Success: true, Data: data})
return
}
switch err.(type) {
case *gormError:
log.Printf("NEZHA>> gorm error: %v", err)
c.JSON(http.StatusOK, newErrorResponse(singleton.Localizer.ErrorT("database error")))
return
case *wsError:
// Connection is upgraded to WebSocket, so c.Writer is no longer usable
if msg := err.Error(); msg != "" {
log.Printf("NEZHA>> websocket error: %v", err)
}
return
default:
c.JSON(http.StatusOK, newErrorResponse(err))
return
}
}
func listHandler[S ~[]E, E model.CommonInterface](handler handlerFunc[S]) func(*gin.Context) {
return func(c *gin.Context) {
data, err := handler(c)
if err != nil {
if err == nil {
c.JSON(http.StatusOK, model.CommonResponse[T]{Success: true, Data: data})
return
}
switch err.(type) {
case *gormError:
log.Printf("NEZHA>> gorm error: %v", err)
c.JSON(http.StatusOK, newErrorResponse(singleton.Localizer.ErrorT("database error")))
return
case *wsError:
// Connection is upgraded to WebSocket, so c.Writer is no longer usable
if msg := err.Error(); msg != "" {
log.Printf("NEZHA>> websocket error: %v", err)
}
return
default:
c.JSON(http.StatusOK, newErrorResponse(err))
return
}
c.JSON(http.StatusOK, model.CommonResponse[S]{Success: true, Data: filter(c, data)})
}
}
func pCommonHandler[S ~[]E, E any](handler pHandlerFunc[S, E]) func(*gin.Context) {
return func(c *gin.Context) {
data, err := handler(c)
if err != nil {
c.JSON(http.StatusOK, newErrorResponse(err))
return
}
c.JSON(http.StatusOK, model.PaginatedResponse[S, E]{Success: true, Data: data})
}
}
func filter[S ~[]E, E model.CommonInterface](ctx *gin.Context, s S) S {
return slices.DeleteFunc(s, func(e E) bool {
return !e.HasPermission(ctx)
})
}
func getUid(c *gin.Context) uint64 {
user, _ := c.MustGet(model.CtxKeyAuthorizedUser).(*model.User)
return user.ID
}
func fallbackToFrontend(frontendDist fs.FS) func(*gin.Context) {
checkLocalFileOrFs := func(c *gin.Context, fs fs.FS, path string) bool {
if _, err := os.Stat(path); err == nil {

View File

@ -1,6 +1,7 @@
package controller
import (
"fmt"
"strconv"
"github.com/gin-gonic/gin"
@ -49,18 +50,6 @@ func createCron(c *gin.Context) (uint64, error) {
return 0, err
}
singleton.ServerLock.RLock()
for _, sid := range cf.Servers {
if server, ok := singleton.ServerList[sid]; ok {
if !server.HasPermission(c) {
singleton.ServerLock.RUnlock()
return 0, singleton.Localizer.ErrorT("permission denied")
}
}
}
singleton.ServerLock.RUnlock()
cr.UserID = getUid(c)
cr.TaskType = cf.TaskType
cr.Name = cf.Name
cr.Scheduler = cf.Scheduler
@ -115,24 +104,9 @@ func updateCron(c *gin.Context) (any, error) {
return 0, err
}
singleton.ServerLock.RLock()
for _, sid := range cf.Servers {
if server, ok := singleton.ServerList[sid]; ok {
if !server.HasPermission(c) {
singleton.ServerLock.RUnlock()
return nil, singleton.Localizer.ErrorT("permission denied")
}
}
}
singleton.ServerLock.RUnlock()
var cr model.Cron
if err := singleton.DB.First(&cr, id).Error; err != nil {
return nil, singleton.Localizer.ErrorT("task id %d does not exist", id)
}
if !cr.HasPermission(c) {
return nil, singleton.Localizer.ErrorT("permission denied")
return nil, fmt.Errorf("task id %d does not exist", id)
}
cr.TaskType = cf.TaskType
@ -182,19 +156,12 @@ func manualTriggerCron(c *gin.Context) (any, error) {
return nil, err
}
singleton.CronLock.RLock()
cr, ok := singleton.Crons[id]
if !ok {
singleton.CronLock.RUnlock()
var cr model.Cron
if err := singleton.DB.First(&cr, id).Error; err != nil {
return nil, singleton.Localizer.ErrorT("task id %d does not exist", id)
}
singleton.CronLock.RUnlock()
if !cr.HasPermission(c) {
return nil, singleton.Localizer.ErrorT("permission denied")
}
singleton.ManualTrigger(cr)
singleton.ManualTrigger(&cr)
return nil, nil
}
@ -211,21 +178,11 @@ func manualTriggerCron(c *gin.Context) (any, error) {
// @Router /batch-delete/cron [post]
func batchDeleteCron(c *gin.Context) (any, error) {
var cr []uint64
if err := c.ShouldBindJSON(&cr); err != nil {
return nil, err
}
singleton.CronLock.RLock()
for _, crID := range cr {
if crn, ok := singleton.Crons[crID]; ok {
if !crn.HasPermission(c) {
singleton.CronLock.RUnlock()
return nil, singleton.Localizer.ErrorT("permission denied")
}
}
}
singleton.CronLock.RUnlock()
if err := singleton.DB.Unscoped().Delete(&model.Cron{}, "id in (?)", cr).Error; err != nil {
return nil, newGormError("%v", err)
}

View File

@ -56,7 +56,6 @@ func createDDNS(c *gin.Context) (uint64, error) {
return 0, singleton.Localizer.ErrorT("the retry count must be an integer between 1 and 10")
}
p.UserID = getUid(c)
p.Name = df.Name
enableIPv4 := df.EnableIPv4
enableIPv6 := df.EnableIPv6
@ -126,10 +125,6 @@ func updateDDNS(c *gin.Context) (any, error) {
return nil, singleton.Localizer.ErrorT("profile id %d does not exist", id)
}
if !p.HasPermission(c) {
return nil, singleton.Localizer.ErrorT("permission denied")
}
p.Name = df.Name
enableIPv4 := df.EnableIPv4
enableIPv6 := df.EnableIPv6
@ -183,17 +178,6 @@ func batchDeleteDDNS(c *gin.Context) (any, error) {
return nil, err
}
singleton.DDNSCacheLock.RLock()
for _, pid := range ddnsConfigs {
if p, ok := singleton.DDNSCache[pid]; ok {
if !p.HasPermission(c) {
singleton.DDNSCacheLock.RUnlock()
return nil, singleton.Localizer.ErrorT("permission denied")
}
}
}
singleton.DDNSCacheLock.RUnlock()
if err := singleton.DB.Unscoped().Delete(&model.DDNSProfile{}, "id in (?)", ddnsConfigs).Error; err != nil {
return nil, newGormError("%v", err)
}

View File

@ -7,7 +7,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/hashicorp/go-uuid"
"github.com/nezhahq/nezha/model"
"github.com/nezhahq/nezha/pkg/utils"
"github.com/nezhahq/nezha/pkg/websocketx"
@ -32,17 +31,6 @@ func createFM(c *gin.Context) (*model.CreateFMResponse, error) {
return nil, err
}
singleton.ServerLock.RLock()
server := singleton.ServerList[id]
singleton.ServerLock.RUnlock()
if server == nil || server.TaskStream == nil {
return nil, singleton.Localizer.ErrorT("server not found or not connected")
}
if !server.HasPermission(c) {
return nil, singleton.Localizer.ErrorT("permission denied")
}
streamId, err := uuid.GenerateUUID()
if err != nil {
return nil, err
@ -50,6 +38,13 @@ func createFM(c *gin.Context) (*model.CreateFMResponse, error) {
rpc.NezhaHandlerSingleton.CreateStream(streamId)
singleton.ServerLock.RLock()
server := singleton.ServerList[id]
singleton.ServerLock.RUnlock()
if server == nil || server.TaskStream == nil {
return nil, singleton.Localizer.ErrorT("server not found or not connected")
}
fmData, _ := utils.Json.Marshal(&model.TaskFM{
StreamID: streamId,
})

View File

@ -88,21 +88,18 @@ func authenticator() func(c *gin.Context) (interface{}, error) {
}
var user model.User
realip := c.GetString(model.CtxKeyRealIPStr)
if err := singleton.DB.Select("id", "password").Where("username = ?", loginVals.Username).First(&user).Error; err != nil {
if err == gorm.ErrRecordNotFound {
model.BlockIP(singleton.DB, realip, model.WAFBlockReasonTypeLoginFail, model.BlockIDUnknownUser)
model.BlockIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.WAFBlockReasonTypeLoginFail)
}
return nil, jwt.ErrFailedAuthentication
}
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(loginVals.Password)); err != nil {
model.BlockIP(singleton.DB, realip, model.WAFBlockReasonTypeLoginFail, int64(user.ID))
model.BlockIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.WAFBlockReasonTypeLoginFail)
return nil, jwt.ErrFailedAuthentication
}
model.ClearIP(singleton.DB, realip, model.BlockIDUnknownUser)
model.ClearIP(singleton.DB, realip, int64(user.ID))
return utils.Itoa(user.ID), nil
}
}
@ -172,10 +169,10 @@ func optionalAuthMiddleware(mw *jwt.GinJWTMiddleware) func(c *gin.Context) {
identity := mw.IdentityHandler(c)
if identity != nil {
model.ClearIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.BlockIDToken)
model.ClearIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr))
c.Set(mw.IdentityKey, identity)
} else {
if err := model.BlockIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.WAFBlockReasonTypeBruteForceToken, model.BlockIDToken); err != nil {
if err := model.BlockIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.WAFBlockReasonTypeBruteForceToken); err != nil {
waf.ShowBlockPage(c, err)
return
}

View File

@ -51,18 +51,6 @@ func createNAT(c *gin.Context) (uint64, error) {
return 0, err
}
singleton.ServerLock.RLock()
if server, ok := singleton.ServerList[nf.ServerID]; ok {
if !server.HasPermission(c) {
singleton.ServerLock.RUnlock()
return 0, singleton.Localizer.ErrorT("permission denied")
}
}
singleton.ServerLock.RUnlock()
uid := getUid(c)
n.UserID = uid
n.Name = nf.Name
n.Domain = nf.Domain
n.Host = nf.Host
@ -102,24 +90,11 @@ func updateNAT(c *gin.Context) (any, error) {
return nil, err
}
singleton.ServerLock.RLock()
if server, ok := singleton.ServerList[nf.ServerID]; ok {
if !server.HasPermission(c) {
singleton.ServerLock.RUnlock()
return nil, singleton.Localizer.ErrorT("permission denied")
}
}
singleton.ServerLock.RUnlock()
var n model.NAT
if err = singleton.DB.First(&n, id).Error; err != nil {
return nil, singleton.Localizer.ErrorT("profile id %d does not exist", id)
}
if !n.HasPermission(c) {
return nil, singleton.Localizer.ErrorT("permission denied")
}
n.Name = nf.Name
n.Domain = nf.Domain
n.Host = nf.Host
@ -147,21 +122,11 @@ func updateNAT(c *gin.Context) (any, error) {
// @Router /batch-delete/nat [post]
func batchDeleteNAT(c *gin.Context) (any, error) {
var n []uint64
if err := c.ShouldBindJSON(&n); err != nil {
return nil, err
}
singleton.NATCacheRwLock.RLock()
for _, id := range n {
if p, ok := singleton.NATCache[singleton.NATIDToDomain[id]]; ok {
if !p.HasPermission(c) {
singleton.NATCacheRwLock.RUnlock()
return nil, singleton.Localizer.ErrorT("permission denied")
}
}
}
singleton.NATCacheRwLock.RUnlock()
if err := singleton.DB.Unscoped().Delete(&model.NAT{}, "id in (?)", n).Error; err != nil {
return nil, newGormError("%v", err)
}

View File

@ -48,7 +48,6 @@ func createNotification(c *gin.Context) (uint64, error) {
}
var n model.Notification
n.UserID = getUid(c)
n.Name = nf.Name
n.RequestMethod = nf.RequestMethod
n.RequestType = nf.RequestType
@ -107,10 +106,6 @@ func updateNotification(c *gin.Context) (any, error) {
return nil, singleton.Localizer.ErrorT("notification id %d does not exist", id)
}
if !n.HasPermission(c) {
return nil, singleton.Localizer.ErrorT("permission denied")
}
n.Name = nf.Name
n.RequestMethod = nf.RequestMethod
n.RequestType = nf.RequestType
@ -154,20 +149,11 @@ func updateNotification(c *gin.Context) (any, error) {
// @Router /batch-delete/notification [post]
func batchDeleteNotification(c *gin.Context) (any, error) {
var n []uint64
if err := c.ShouldBindJSON(&n); err != nil {
return nil, err
}
singleton.NotificationsLock.RLock()
for _, nid := range n {
if ns, ok := singleton.NotificationMap[nid]; ok {
if !ns.HasPermission(c) {
return nil, singleton.Localizer.ErrorT("permission denied")
}
}
}
singleton.NotificationsLock.RUnlock()
err := singleton.DB.Transaction(func(tx *gorm.DB) error {
if err := tx.Unscoped().Delete(&model.Notification{}, "id in (?)", n).Error; err != nil {
return err

View File

@ -20,7 +20,7 @@ import (
// @Produce json
// @Success 200 {object} model.CommonResponse[[]model.NotificationGroupResponseItem]
// @Router /notification-group [get]
func listNotificationGroup(c *gin.Context) ([]*model.NotificationGroupResponseItem, error) {
func listNotificationGroup(c *gin.Context) ([]model.NotificationGroupResponseItem, error) {
var ng []model.NotificationGroup
if err := singleton.DB.Find(&ng).Error; err != nil {
return nil, err
@ -39,9 +39,9 @@ func listNotificationGroup(c *gin.Context) ([]*model.NotificationGroupResponseIt
groupNotifications[n.NotificationGroupID] = append(groupNotifications[n.NotificationGroupID], n.NotificationID)
}
ngRes := make([]*model.NotificationGroupResponseItem, 0, len(ng))
ngRes := make([]model.NotificationGroupResponseItem, 0, len(ng))
for _, n := range ng {
ngRes = append(ngRes, &model.NotificationGroupResponseItem{
ngRes = append(ngRes, model.NotificationGroupResponseItem{
Group: n,
Notifications: groupNotifications[n.ID],
})
@ -68,22 +68,8 @@ func createNotificationGroup(c *gin.Context) (uint64, error) {
}
ngf.Notifications = slices.Compact(ngf.Notifications)
singleton.NotificationsLock.RLock()
for _, nid := range ngf.Notifications {
if n, ok := singleton.NotificationMap[nid]; ok {
if !n.HasPermission(c) {
singleton.NotificationsLock.RUnlock()
return 0, singleton.Localizer.ErrorT("permission denied")
}
}
}
singleton.NotificationsLock.RUnlock()
uid := getUid(c)
var ng model.NotificationGroup
ng.Name = ngf.Name
ng.UserID = uid
var count int64
if err := singleton.DB.Model(&model.Notification{}).Where("id in (?)", ngf.Notifications).Count(&count).Error; err != nil {
@ -100,9 +86,6 @@ func createNotificationGroup(c *gin.Context) (uint64, error) {
}
for _, n := range ngf.Notifications {
if err := tx.Create(&model.NotificationGroupNotification{
Common: model.Common{
UserID: uid,
},
NotificationGroupID: ng.ID,
NotificationID: n,
}).Error; err != nil {
@ -143,27 +126,11 @@ func updateNotificationGroup(c *gin.Context) (any, error) {
if err := c.ShouldBindJSON(&ngf); err != nil {
return nil, err
}
singleton.NotificationsLock.RLock()
for _, nid := range ngf.Notifications {
if n, ok := singleton.NotificationMap[nid]; ok {
if !n.HasPermission(c) {
singleton.NotificationsLock.RUnlock()
return nil, singleton.Localizer.ErrorT("permission denied")
}
}
}
singleton.NotificationsLock.RUnlock()
var ngDB model.NotificationGroup
if err := singleton.DB.First(&ngDB, id).Error; err != nil {
return nil, singleton.Localizer.ErrorT("group id %d does not exist", id)
}
if !ngDB.HasPermission(c) {
return nil, singleton.Localizer.ErrorT("permission denied")
}
ngDB.Name = ngf.Name
ngf.Notifications = slices.Compact(ngf.Notifications)
@ -175,8 +142,6 @@ func updateNotificationGroup(c *gin.Context) (any, error) {
return nil, singleton.Localizer.ErrorT("have invalid notification id")
}
uid := getUid(c)
err = singleton.DB.Transaction(func(tx *gorm.DB) error {
if err := tx.Save(&ngDB).Error; err != nil {
return err
@ -187,9 +152,6 @@ func updateNotificationGroup(c *gin.Context) (any, error) {
for _, n := range ngf.Notifications {
if err := tx.Create(&model.NotificationGroupNotification{
Common: model.Common{
UserID: uid,
},
NotificationGroupID: ngDB.ID,
NotificationID: n,
}).Error; err != nil {
@ -223,17 +185,6 @@ func batchDeleteNotificationGroup(c *gin.Context) (any, error) {
return nil, err
}
var ng []model.NotificationGroup
if err := singleton.DB.Where("id in (?)", ngn).Find(&ng).Error; err != nil {
return nil, err
}
for _, n := range ng {
if !n.HasPermission(c) {
return nil, singleton.Localizer.ErrorT("permission denied")
}
}
err := singleton.DB.Transaction(func(tx *gorm.DB) error {
if err := tx.Unscoped().Delete(&model.NotificationGroup{}, "id in (?)", ngn).Error; err != nil {
return err

View File

@ -56,26 +56,11 @@ func updateServer(c *gin.Context) (any, error) {
return nil, err
}
singleton.DDNSCacheLock.RLock()
for _, pid := range sf.DDNSProfiles {
if p, ok := singleton.DDNSCache[pid]; ok {
if !p.HasPermission(c) {
singleton.DDNSCacheLock.RUnlock()
return nil, singleton.Localizer.ErrorT("permission denied")
}
}
}
singleton.DDNSCacheLock.RUnlock()
var s model.Server
if err := singleton.DB.First(&s, id).Error; err != nil {
return nil, singleton.Localizer.ErrorT("server id %d does not exist", id)
}
if !s.HasPermission(c) {
return nil, singleton.Localizer.ErrorT("permission denied")
}
s.Name = sf.Name
s.DisplayIndex = sf.DisplayIndex
s.Note = sf.Note
@ -119,17 +104,6 @@ func batchDeleteServer(c *gin.Context) (any, error) {
return nil, err
}
singleton.ServerLock.RLock()
for _, sid := range servers {
if s, ok := singleton.ServerList[sid]; ok {
if !s.HasPermission(c) {
singleton.ServerLock.RUnlock()
return nil, singleton.Localizer.ErrorT("permission denied")
}
}
}
singleton.ServerLock.RUnlock()
err := singleton.DB.Transaction(func(tx *gorm.DB) error {
if err := tx.Unscoped().Delete(&model.Server{}, "id in (?)", servers).Error; err != nil {
return err
@ -187,9 +161,6 @@ func forceUpdateServer(c *gin.Context) (*model.ForceUpdateResponse, error) {
server := singleton.ServerList[sid]
singleton.ServerLock.RUnlock()
if server != nil && server.TaskStream != nil {
if !server.HasPermission(c) {
return nil, singleton.Localizer.ErrorT("permission denied")
}
if err := server.TaskStream.Send(&pb.Task{
Type: model.TaskTypeUpgrade,
}); err != nil {

View File

@ -20,7 +20,7 @@ import (
// @Produce json
// @Success 200 {object} model.CommonResponse[[]model.ServerGroupResponseItem]
// @Router /server-group [get]
func listServerGroup(c *gin.Context) ([]*model.ServerGroupResponseItem, error) {
func listServerGroup(c *gin.Context) ([]model.ServerGroupResponseItem, error) {
var sg []model.ServerGroup
if err := singleton.DB.Find(&sg).Error; err != nil {
return nil, err
@ -38,9 +38,9 @@ func listServerGroup(c *gin.Context) ([]*model.ServerGroupResponseItem, error) {
groupServers[s.ServerGroupId] = append(groupServers[s.ServerGroupId], s.ServerId)
}
var sgRes []*model.ServerGroupResponseItem
var sgRes []model.ServerGroupResponseItem
for _, s := range sg {
sgRes = append(sgRes, &model.ServerGroupResponseItem{
sgRes = append(sgRes, model.ServerGroupResponseItem{
Group: s,
Servers: groupServers[s.ID],
})
@ -67,22 +67,8 @@ func createServerGroup(c *gin.Context) (uint64, error) {
}
sgf.Servers = slices.Compact(sgf.Servers)
singleton.ServerLock.RLock()
for _, sid := range sgf.Servers {
if server, ok := singleton.ServerList[sid]; ok {
if !server.HasPermission(c) {
singleton.ServerLock.RUnlock()
return 0, singleton.Localizer.ErrorT("permission denied")
}
}
}
singleton.ServerLock.RUnlock()
uid := getUid(c)
var sg model.ServerGroup
sg.Name = sgf.Name
sg.UserID = uid
var count int64
if err := singleton.DB.Model(&model.Server{}).Where("id in (?)", sgf.Servers).Count(&count).Error; err != nil {
@ -98,9 +84,6 @@ func createServerGroup(c *gin.Context) (uint64, error) {
}
for _, s := range sgf.Servers {
if err := tx.Create(&model.ServerGroupServer{
Common: model.Common{
UserID: uid,
},
ServerGroupId: sg.ID,
ServerId: s,
}).Error; err != nil {
@ -142,26 +125,10 @@ func updateServerGroup(c *gin.Context) (any, error) {
}
sg.Servers = slices.Compact(sg.Servers)
singleton.ServerLock.RLock()
for _, sid := range sg.Servers {
if server, ok := singleton.ServerList[sid]; ok {
if !server.HasPermission(c) {
singleton.ServerLock.RUnlock()
return nil, singleton.Localizer.ErrorT("permission denied")
}
}
}
singleton.ServerLock.RUnlock()
var sgDB model.ServerGroup
if err := singleton.DB.First(&sgDB, id).Error; err != nil {
return nil, singleton.Localizer.ErrorT("group id %d does not exist", id)
}
if !sgDB.HasPermission(c) {
return nil, singleton.Localizer.ErrorT("unauthorized")
}
sgDB.Name = sg.Name
var count int64
@ -172,8 +139,6 @@ func updateServerGroup(c *gin.Context) (any, error) {
return nil, singleton.Localizer.ErrorT("have invalid server id")
}
uid := getUid(c)
err = singleton.DB.Transaction(func(tx *gorm.DB) error {
if err := tx.Save(&sgDB).Error; err != nil {
return err
@ -184,9 +149,6 @@ func updateServerGroup(c *gin.Context) (any, error) {
for _, s := range sg.Servers {
if err := tx.Create(&model.ServerGroupServer{
Common: model.Common{
UserID: uid,
},
ServerGroupId: sgDB.ID,
ServerId: s,
}).Error; err != nil {
@ -219,17 +181,6 @@ func batchDeleteServerGroup(c *gin.Context) (any, error) {
return nil, err
}
var sg []model.ServerGroup
if err := singleton.DB.Where("id in (?)", sgs).Find(&sg).Error; err != nil {
return nil, err
}
for _, s := range sg {
if !s.HasPermission(c) {
return nil, singleton.Localizer.ErrorT("permission denied")
}
}
err := singleton.DB.Transaction(func(tx *gorm.DB) error {
if err := tx.Unscoped().Delete(&model.ServerGroup{}, "id in (?)", sgs).Error; err != nil {
return err

View File

@ -190,10 +190,7 @@ func createService(c *gin.Context) (uint64, error) {
return 0, err
}
uid := getUid(c)
var m model.Service
m.UserID = uid
m.Name = mf.Name
m.Target = strings.TrimSpace(mf.Target)
m.Type = mf.Type
@ -210,10 +207,6 @@ func createService(c *gin.Context) (uint64, error) {
m.RecoverTriggerTasks = mf.RecoverTriggerTasks
m.FailTriggerTasks = mf.FailTriggerTasks
if err := validateServers(c, &m); err != nil {
return 0, err
}
if err := singleton.DB.Create(&m).Error; err != nil {
return 0, newGormError("%v", err)
}
@ -267,11 +260,6 @@ func updateService(c *gin.Context) (any, error) {
if err := singleton.DB.First(&m, id).Error; err != nil {
return nil, singleton.Localizer.ErrorT("service id %d does not exist", id)
}
if !m.HasPermission(c) {
return nil, singleton.Localizer.ErrorT("permission denied")
}
m.Name = mf.Name
m.Target = strings.TrimSpace(mf.Target)
m.Type = mf.Type
@ -288,10 +276,6 @@ func updateService(c *gin.Context) (any, error) {
m.RecoverTriggerTasks = mf.RecoverTriggerTasks
m.FailTriggerTasks = mf.FailTriggerTasks
if err := validateServers(c, &m); err != nil {
return 0, err
}
if err := singleton.DB.Save(&m).Error; err != nil {
return nil, newGormError("%v", err)
}
@ -334,18 +318,6 @@ func batchDeleteService(c *gin.Context) (any, error) {
if err := c.ShouldBindJSON(&ids); err != nil {
return nil, err
}
singleton.ServiceSentinelShared.ServicesLock.RLock()
for _, id := range ids {
if ss, ok := singleton.ServiceSentinelShared.Services[id]; ok {
if !ss.HasPermission(c) {
singleton.ServiceSentinelShared.ServicesLock.RUnlock()
return nil, singleton.Localizer.ErrorT("permission denied")
}
}
}
singleton.ServiceSentinelShared.ServicesLock.RUnlock()
err := singleton.DB.Transaction(func(tx *gorm.DB) error {
if err := tx.Unscoped().Delete(&model.Service{}, "id in (?)", ids).Error; err != nil {
return err
@ -359,18 +331,3 @@ func batchDeleteService(c *gin.Context) (any, error) {
singleton.ServiceSentinelShared.UpdateServiceList()
return nil, nil
}
func validateServers(c *gin.Context, ss *model.Service) error {
singleton.ServerLock.RLock()
defer singleton.ServerLock.RUnlock()
for s := range ss.SkipServers {
if server, ok := singleton.ServerList[s]; ok {
if !server.HasPermission(c) {
return singleton.Localizer.ErrorT("permission denied")
}
}
}
return nil
}

View File

@ -50,7 +50,7 @@ func listConfig(c *gin.Context) (model.SettingResponse, error) {
// @Security BearerAuth
// @Schemes
// @Description Edit config
// @Tags admin required
// @Tags auth required
// @Accept json
// @Param body body model.SettingForm true "SettingForm"
// @Produce json

View File

@ -6,7 +6,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/hashicorp/go-uuid"
"github.com/nezhahq/nezha/model"
"github.com/nezhahq/nezha/pkg/utils"
"github.com/nezhahq/nezha/pkg/websocketx"
@ -30,17 +29,6 @@ func createTerminal(c *gin.Context) (*model.CreateTerminalResponse, error) {
return nil, err
}
singleton.ServerLock.RLock()
server := singleton.ServerList[createTerminalReq.ServerID]
singleton.ServerLock.RUnlock()
if server == nil || server.TaskStream == nil {
return nil, singleton.Localizer.ErrorT("server not found or not connected")
}
if !server.HasPermission(c) {
return nil, singleton.Localizer.ErrorT("permission denied")
}
streamId, err := uuid.GenerateUUID()
if err != nil {
return nil, err
@ -48,6 +36,13 @@ func createTerminal(c *gin.Context) (*model.CreateTerminalResponse, error) {
rpc.NezhaHandlerSingleton.CreateStream(streamId)
singleton.ServerLock.RLock()
server := singleton.ServerList[createTerminalReq.ServerID]
singleton.ServerLock.RUnlock()
if server == nil || server.TaskStream == nil {
return nil, singleton.Localizer.ErrorT("server not found or not connected")
}
terminalData, _ := utils.Json.Marshal(&model.TerminalTask{
StreamID: streamId,
})

View File

@ -2,7 +2,6 @@ package controller
import (
"slices"
"strconv"
"github.com/gin-gonic/gin"
"golang.org/x/crypto/bcrypt"
@ -77,7 +76,7 @@ func updateProfile(c *gin.Context) (any, error) {
// @Security BearerAuth
// @Schemes
// @Description List user
// @Tags admin required
// @Tags auth required
// @Produce json
// @Success 200 {object} model.CommonResponse[[]model.User]
// @Router /user [get]
@ -94,7 +93,7 @@ func listUser(c *gin.Context) ([]model.User, error) {
// @Security BearerAuth
// @Schemes
// @Description Create user
// @Tags admin required
// @Tags auth required
// @Accept json
// @param request body model.UserForm true "User Request"
// @Produce json
@ -115,7 +114,6 @@ func createUser(c *gin.Context) (uint64, error) {
var u model.User
u.Username = uf.Username
u.Role = model.RoleMember
hash, err := bcrypt.GenerateFromPassword([]byte(uf.Password), bcrypt.DefaultCost)
if err != nil {
@ -127,7 +125,6 @@ func createUser(c *gin.Context) (uint64, error) {
return 0, err
}
singleton.OnUserUpdate(&u)
return u.ID, nil
}
@ -136,7 +133,7 @@ func createUser(c *gin.Context) (uint64, error) {
// @Security BearerAuth
// @Schemes
// @Description Batch delete users
// @Tags admin required
// @Tags auth required
// @Accept json
// @param request body []uint true "id list"
// @Produce json
@ -152,62 +149,5 @@ func batchDeleteUser(c *gin.Context) (any, error) {
return nil, singleton.Localizer.ErrorT("can't delete yourself")
}
err := singleton.OnUserDelete(ids, newGormError)
return nil, err
}
// List online users
// @Summary List online users
// @Security BearerAuth
// @Schemes
// @Description List online users
// @Tags auth required
// @Param limit query uint false "Page limit"
// @Param offset query uint false "Page offset"
// @Produce json
// @Success 200 {object} model.PaginatedResponse[[]model.OnlineUser, model.OnlineUser]
// @Router /online-user [get]
func listOnlineUser(c *gin.Context) (*model.Value[[]*model.OnlineUser], error) {
limit, err := strconv.Atoi(c.Query("limit"))
if err != nil || limit < 1 {
limit = 25
}
offset, err := strconv.Atoi(c.Query("offset"))
if err != nil || offset < 0 {
offset = 0
}
return &model.Value[[]*model.OnlineUser]{
Value: singleton.GetOnlineUsers(limit, offset),
Pagination: model.Pagination{
Offset: offset,
Limit: limit,
Total: int64(singleton.GetOnlineUserCount()),
},
}, nil
}
// Batch block online user
// @Summary Batch block online user
// @Security BearerAuth
// @Schemes
// @Description Batch block online user
// @Tags admin required
// @Accept json
// @Param request body []string true "block list"
// @Produce json
// @Success 200 {object} model.CommonResponse[any]
// @Router /online-user/batch-block [patch]
func batchBlockOnlineUser(c *gin.Context) (any, error) {
var list []string
if err := c.ShouldBindJSON(&list); err != nil {
return nil, err
}
if err := singleton.BlockByIPs(list); err != nil {
return nil, newGormError("%v", err)
}
return nil, nil
return nil, singleton.DB.Where("id IN (?)", ids).Delete(&model.User{}).Error
}

View File

@ -1,8 +1,6 @@
package controller
import (
"strconv"
"github.com/gin-gonic/gin"
"github.com/nezhahq/nezha/model"
@ -15,40 +13,16 @@ import (
// @Schemes
// @Description List server
// @Tags auth required
// @Param limit query uint false "Page limit"
// @Param offset query uint false "Page offset"
// @Produce json
// @Success 200 {object} model.PaginatedResponse[[]model.WAFApiMock, model.WAFApiMock]
// @Success 200 {object} model.CommonResponse[[]model.WAFApiMock]
// @Router /waf [get]
func listBlockedAddress(c *gin.Context) (*model.Value[[]*model.WAF], error) {
limit, err := strconv.Atoi(c.Query("limit"))
if err != nil || limit < 1 {
limit = 25
}
offset, err := strconv.Atoi(c.Query("offset"))
if err != nil || offset < 0 {
offset = 0
}
func listBlockedAddress(c *gin.Context) ([]*model.WAF, error) {
var waf []*model.WAF
if err := singleton.DB.Limit(limit).Offset(offset).Find(&waf).Error; err != nil {
if err := singleton.DB.Find(&waf).Error; err != nil {
return nil, err
}
var total int64
if err := singleton.DB.Model(&model.WAF{}).Count(&total).Error; err != nil {
return nil, err
}
return &model.Value[[]*model.WAF]{
Value: waf,
Pagination: model.Pagination{
Offset: offset,
Limit: limit,
Total: total,
},
}, nil
return waf, nil
}
// Batch delete blocked addresses
@ -56,7 +30,7 @@ func listBlockedAddress(c *gin.Context) (*model.Value[[]*model.WAF], error) {
// @Security BearerAuth
// @Schemes
// @Description Edit server
// @Tags admin required
// @Tags auth required
// @Accept json
// @Param request body []string true "block list"
// @Produce json

View File

@ -10,7 +10,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/hashicorp/go-uuid"
"golang.org/x/sync/singleflight"
"github.com/nezhahq/nezha/model"
@ -103,29 +102,13 @@ func checkSameOrigin(r *http.Request) bool {
// @Success 200 {object} model.StreamServerData
// @Router /ws/server [get]
func serverStream(c *gin.Context) (any, error) {
connId, err := uuid.GenerateUUID()
if err != nil {
return nil, newWsError("%v", err)
}
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
return nil, newWsError("%v", err)
}
defer conn.Close()
userIp := c.GetString(model.CtxKeyRealIPStr)
if userIp == "" {
userIp = c.RemoteIP()
}
singleton.AddOnlineUser(connId, &model.OnlineUser{
IP: userIp,
ConnectedAt: time.Now(),
Conn: conn,
})
defer singleton.RemoveOnlineUser(connId)
singleton.OnlineUsers.Add(1)
defer singleton.OnlineUsers.Add(^uint64(0))
count := 0
for {
stat, err := getServerStat(c, count == 0)
@ -183,7 +166,7 @@ func getServerStat(c *gin.Context, withPublicNote bool) ([]byte, error) {
return utils.Json.Marshal(model.StreamServerData{
Now: time.Now().Unix() * 1000,
Online: singleton.GetOnlineUserCount(),
Online: singleton.OnlineUsers.Load(),
Servers: servers,
})
})

View File

@ -100,34 +100,12 @@ func DispatchTask(serviceSentinelDispatchBus <-chan model.Service) {
continue
}
if task.Cover == model.ServiceCoverIgnoreAll && task.SkipServers[singleton.SortedServerList[workedServerIndex].ID] {
server := singleton.SortedServerList[workedServerIndex]
singleton.UserLock.RLock()
var role uint8
if u, ok := singleton.UserInfoMap[server.UserID]; !ok {
role = model.RoleMember
} else {
role = u.Role
}
singleton.UserLock.RUnlock()
if task.UserID == server.UserID || role == model.RoleAdmin {
singleton.SortedServerList[workedServerIndex].TaskStream.Send(task.PB())
}
singleton.SortedServerList[workedServerIndex].TaskStream.Send(task.PB())
workedServerIndex++
continue
}
if task.Cover == model.ServiceCoverAll && !task.SkipServers[singleton.SortedServerList[workedServerIndex].ID] {
server := singleton.SortedServerList[workedServerIndex]
singleton.UserLock.RLock()
var role uint8
if u, ok := singleton.UserInfoMap[server.UserID]; !ok {
role = model.RoleMember
} else {
role = u.Role
}
singleton.UserLock.RUnlock()
if task.UserID == server.UserID || role == model.RoleAdmin {
singleton.SortedServerList[workedServerIndex].TaskStream.Send(task.PB())
}
singleton.SortedServerList[workedServerIndex].TaskStream.Send(task.PB())
workedServerIndex++
continue
}

4
go.mod
View File

@ -1,8 +1,8 @@
module github.com/nezhahq/nezha
go 1.23.0
go 1.22.7
toolchain go1.23.2
toolchain go1.23.1
require (
github.com/appleboy/gin-jwt/v2 v2.10.0

View File

@ -62,15 +62,10 @@ func (r *AlertRule) Enabled() bool {
}
// Snapshot 对传入的Server进行该报警规则下所有type的检查 返回每项检查结果
func (r *AlertRule) Snapshot(cycleTransferStats *CycleTransferStats, server *Server, db *gorm.DB, role uint8) []bool {
point := make([]bool, len(r.Rules))
if r.UserID != server.UserID && role != RoleAdmin {
return point
}
for i, rule := range r.Rules {
point[i] = rule.Snapshot(cycleTransferStats, server, db)
func (r *AlertRule) Snapshot(cycleTransferStats *CycleTransferStats, server *Server, db *gorm.DB) []bool {
point := make([]bool, 0, len(r.Rules))
for _, rule := range r.Rules {
point = append(point, rule.Snapshot(cycleTransferStats, server, db))
}
return point
}

View File

@ -15,23 +15,6 @@ type CommonResponse[T any] struct {
Error string `json:"error,omitempty"`
}
type PaginatedResponse[S ~[]E, E any] struct {
Success bool `json:"success,omitempty"`
Data *Value[S] `json:"data,omitempty"`
Error string `json:"error,omitempty"`
}
type Value[T any] struct {
Value T `json:"value,omitempty"`
Pagination Pagination `json:"pagination,omitempty"`
}
type Pagination struct {
Offset int `json:"offset,omitempty"`
Limit int `json:"limit,omitempty"`
Total int64 `json:"total,omitempty"`
}
type LoginResponse struct {
Token string `json:"token,omitempty"`
Expire string `json:"expire,omitempty"`

View File

@ -2,8 +2,6 @@ package model
import (
"time"
"github.com/gin-gonic/gin"
)
const (
@ -20,47 +18,6 @@ type Common struct {
UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at,omitempty"`
// Do not use soft deletion
// DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at,omitempty"`
UserID uint64 `json:"-"`
}
func (c *Common) GetID() uint64 {
return c.ID
}
func (c *Common) GetUserID() uint64 {
return c.UserID
}
func (c *Common) HasPermission(ctx *gin.Context) bool {
auth, ok := ctx.Get(CtxKeyAuthorizedUser)
if !ok {
return false
}
user := *auth.(*User)
if user.Role == RoleAdmin {
return true
}
return user.ID == c.UserID
}
type CommonInterface interface {
GetID() uint64
GetUserID() uint64
HasPermission(*gin.Context) bool
}
func FindByUserID[S ~[]E, E CommonInterface](s S, uid uint64) []uint64 {
var list []uint64
for _, v := range s {
if v.GetUserID() == uid {
list = append(list, v.GetID())
}
}
return list
}
type Response struct {

View File

@ -16,7 +16,7 @@ type StreamServer struct {
type StreamServerData struct {
Now int64 `json:"now,omitempty"`
Online int `json:"online,omitempty"`
Online uint64 `json:"online,omitempty"`
Servers []StreamServer `json:"servers,omitempty"`
}

View File

@ -1,54 +1,12 @@
package model
import (
"time"
"github.com/gorilla/websocket"
"github.com/nezhahq/nezha/pkg/utils"
"gorm.io/gorm"
)
const (
RoleAdmin uint8 = iota
RoleMember
)
type User struct {
Common
Username string `json:"username,omitempty" gorm:"uniqueIndex"`
Password string `json:"password,omitempty" gorm:"type:char(72)"`
Role uint8 `json:"role,omitempty"`
AgentSecret string `json:"agent_secret,omitempty" gorm:"type:char(32)"`
}
type UserInfo struct {
Role uint8
AgentSecret string
}
func (u *User) BeforeSave(tx *gorm.DB) error {
if u.AgentSecret != "" {
return nil
}
key, err := utils.GenerateRandomString(32)
if err != nil {
return err
}
u.AgentSecret = key
return nil
Username string `json:"username,omitempty" gorm:"uniqueIndex"`
Password string `json:"password,omitempty" gorm:"type:char(72)"`
}
type Profile struct {
User
LoginIP string `json:"login_ip,omitempty"`
}
type OnlineUser struct {
UserID uint64 `json:"user_id,omitempty"`
ConnectedAt time.Time `json:"connected_at,omitempty"`
IP string `json:"ip,omitempty"`
Conn *websocket.Conn `json:"-"`
}

6
model/user_group.go Normal file
View File

@ -0,0 +1,6 @@
package model
type UserGroup struct {
Common
Name string `json:"name"`
}

7
model/user_group_user.go Normal file
View File

@ -0,0 +1,7 @@
package model
type UserGroupUser struct {
Common
UserGroupId uint64 `json:"user_group_id"`
UserId uint64 `json:"user_id"`
}

View File

@ -14,34 +14,24 @@ const (
WAFBlockReasonTypeLoginFail
WAFBlockReasonTypeBruteForceToken
WAFBlockReasonTypeAgentAuthFail
WAFBlockReasonTypeManual
)
const (
BlockIDgRPC = -127 + iota
BlockIDToken
BlockIDUnknownUser
BlockIDManual
)
type WAFApiMock struct {
IP string `json:"ip,omitempty"`
BlockIdentifier int64 `json:"block_identifier,omitempty"`
BlockReason uint8 `json:"block_reason,omitempty"`
BlockTimestamp uint64 `json:"block_timestamp,omitempty"`
Count uint64 `json:"count,omitempty"`
IP string `json:"ip,omitempty"`
Count uint64 `json:"count,omitempty"`
LastBlockReason uint8 `json:"last_block_reason,omitempty"`
LastBlockTimestamp uint64 `json:"last_block_timestamp,omitempty"`
}
type WAF struct {
IP []byte `gorm:"type:binary(16);primaryKey" json:"ip,omitempty"`
BlockIdentifier int64 `gorm:"primaryKey" json:"block_identifier,omitempty"`
BlockReason uint8 `json:"block_reason,omitempty"`
BlockTimestamp uint64 `gorm:"index" json:"block_timestamp,omitempty"`
Count uint64 `json:"count,omitempty"`
IP []byte `gorm:"type:binary(16);primaryKey" json:"ip,omitempty"`
Count uint64 `json:"count,omitempty"`
LastBlockReason uint8 `json:"last_block_reason,omitempty"`
LastBlockTimestamp uint64 `json:"last_block_timestamp,omitempty"`
}
func (w *WAF) TableName() string {
return "nz_waf"
return "waf"
}
func CheckIP(db *gorm.DB, ip string) error {
@ -52,31 +42,22 @@ func CheckIP(db *gorm.DB, ip string) error {
if err != nil {
return err
}
var blockTimestamp uint64
result := db.Model(&WAF{}).Order("block_timestamp desc").Select("block_timestamp").Where("ip = ?", ipBinary).Limit(1).Find(&blockTimestamp)
var w WAF
result := db.Limit(1).Find(&w, "ip = ?", ipBinary)
if result.Error != nil {
return result.Error
}
// 检查是否未找到记录
if result.RowsAffected < 1 {
if result.RowsAffected == 0 { // 检查是否未找到记录
return nil
}
var count uint64
if err := db.Model(&WAF{}).Select("SUM(count)").Where("ip = ?", ipBinary).Scan(&count).Error; err != nil {
return err
}
now := time.Now().Unix()
if powAdd(count, 4, blockTimestamp) > uint64(now) {
if powAdd(w.Count, 4, w.LastBlockTimestamp) > uint64(now) {
return errors.New("you are blocked by nezha WAF")
}
return nil
}
func ClearIP(db *gorm.DB, ip string, uid int64) error {
func ClearIP(db *gorm.DB, ip string) error {
if ip == "" {
return nil
}
@ -84,7 +65,7 @@ func ClearIP(db *gorm.DB, ip string, uid int64) error {
if err != nil {
return err
}
return db.Unscoped().Delete(&WAF{}, "ip = ? and block_identifier = ?", ipBinary, uid).Error
return db.Unscoped().Delete(&WAF{}, "ip = ?", ipBinary).Error
}
func BatchClearIP(db *gorm.DB, ip []string) error {
@ -102,7 +83,7 @@ func BatchClearIP(db *gorm.DB, ip []string) error {
return db.Unscoped().Delete(&WAF{}, "ip in (?)", ips).Error
}
func BlockIP(db *gorm.DB, ip string, reason uint8, uid int64) error {
func BlockIP(db *gorm.DB, ip string, reason uint8) error {
if ip == "" {
return nil
}
@ -110,19 +91,16 @@ func BlockIP(db *gorm.DB, ip string, reason uint8, uid int64) error {
if err != nil {
return err
}
w := WAF{
IP: ipBinary,
BlockIdentifier: uid,
}
now := uint64(time.Now().Unix())
var w WAF
w.IP = ipBinary
return db.Transaction(func(tx *gorm.DB) error {
if err := tx.Where(&w).Attrs(WAF{
BlockReason: reason,
BlockTimestamp: now,
LastBlockReason: reason,
LastBlockTimestamp: uint64(time.Now().Unix()),
}).FirstOrCreate(&w).Error; err != nil {
return err
}
return tx.Exec("UPDATE nz_waf SET count = count + 1, block_reason = ?, block_timestamp = ? WHERE ip = ? and block_identifier = ?", reason, now, ipBinary, uid).Error
return tx.Exec("UPDATE waf SET count = count + 1, last_block_reason = ?, last_block_timestamp = ? WHERE ip = ?", reason, uint64(time.Now().Unix()), ipBinary).Error
})
}

View File

@ -3,12 +3,10 @@ package utils
import (
"crypto/rand"
"errors"
"maps"
"math/big"
"net/netip"
"os"
"regexp"
"slices"
"strconv"
"strings"
@ -147,8 +145,3 @@ func Itoa[T constraints.Integer](i T) string {
return ""
}
}
func MapValuesToSlice[Map ~map[K]V, K comparable, V any](m Map) []V {
s := make([]V, 0, len(m))
return slices.AppendSeq(s, maps.Values(m))
}

View File

@ -36,16 +36,12 @@ func (a *authHandler) Check(ctx context.Context) (uint64, error) {
ip, _ := ctx.Value(model.CtxKeyRealIP{}).(string)
singleton.UserLock.RLock()
userId, ok := singleton.AgentSecretToUserId[clientSecret]
if !ok && clientSecret != singleton.Conf.AgentSecretKey {
singleton.UserLock.RUnlock()
model.BlockIP(singleton.DB, ip, model.WAFBlockReasonTypeAgentAuthFail, model.BlockIDgRPC)
if clientSecret != singleton.Conf.AgentSecretKey {
model.BlockIP(singleton.DB, ip, model.WAFBlockReasonTypeAgentAuthFail)
return 0, status.Error(codes.Unauthenticated, "客户端认证失败")
}
singleton.UserLock.RUnlock()
model.ClearIP(singleton.DB, ip, model.BlockIDgRPC)
model.ClearIP(singleton.DB, ip)
var clientUUID string
if value, ok := md["client_uuid"]; ok {
@ -57,26 +53,21 @@ func (a *authHandler) Check(ctx context.Context) (uint64, error) {
}
singleton.ServerLock.RLock()
clientID, hasID := singleton.ServerUUIDToID[clientUUID]
singleton.ServerLock.RUnlock()
defer singleton.ServerLock.RUnlock()
clientID, hasID := singleton.ServerUUIDToID[clientUUID]
if !hasID {
s := model.Server{UUID: clientUUID, Name: petname.Generate(2, "-"), Common: model.Common{
UserID: userId,
}}
s := model.Server{UUID: clientUUID, Name: petname.Generate(2, "-")}
if err := singleton.DB.Create(&s).Error; err != nil {
return 0, status.Error(codes.Unauthenticated, err.Error())
}
s.Host = &model.Host{}
s.State = &model.HostState{}
s.GeoIP = &model.GeoIP{}
singleton.ServerLock.Lock()
// generate a random silly server name
singleton.ServerList[s.ID] = &s
singleton.ServerUUIDToID[clientUUID] = s.ID
singleton.ServerLock.Unlock()
singleton.ReSortServer()
clientID = s.ID
}

View File

@ -143,16 +143,8 @@ func checkStatus() {
}
for _, server := range ServerList {
// 监测点
UserLock.RLock()
var role uint8
if u, ok := UserInfoMap[server.UserID]; !ok {
role = model.RoleMember
} else {
role = u.Role
}
UserLock.RUnlock()
alertsStore[alert.ID][server.ID] = append(alertsStore[alert.
ID][server.ID], alert.Snapshot(AlertsCycleTransferStatsStore[alert.ID], server, DB, role))
ID][server.ID], alert.Snapshot(AlertsCycleTransferStatsStore[alert.ID], server, DB))
// 发送通知,分为触发报警和恢复通知
max, passed := alert.Check(alertsStore[alert.ID][server.ID])
// 保存当前服务器状态信息

View File

@ -12,7 +12,6 @@ import (
"github.com/robfig/cron/v3"
"github.com/nezhahq/nezha/model"
"github.com/nezhahq/nezha/pkg/utils"
pb "github.com/nezhahq/nezha/proto"
)
@ -80,7 +79,10 @@ func UpdateCronList() {
CronLock.RLock()
defer CronLock.RUnlock()
CronList = utils.MapValuesToSlice(Crons)
CronList = make([]*model.Cron, 0, len(Crons))
for _, c := range Crons {
CronList = append(CronList, c)
}
slices.SortFunc(CronList, func(a, b *model.Cron) int {
return cmp.Compare(a.ID, b.ID)
})

View File

@ -13,7 +13,6 @@ import (
ddns2 "github.com/nezhahq/nezha/pkg/ddns"
"github.com/nezhahq/nezha/pkg/ddns/dummy"
"github.com/nezhahq/nezha/pkg/ddns/webhook"
"github.com/nezhahq/nezha/pkg/utils"
)
var (
@ -25,10 +24,12 @@ var (
func initDDNS() {
DB.Find(&DDNSList)
DDNSCacheLock.Lock()
DDNSCache = make(map[uint64]*model.DDNSProfile)
for i := 0; i < len(DDNSList); i++ {
DDNSCache[DDNSList[i].ID] = DDNSList[i]
}
DDNSCacheLock.Unlock()
OnNameserverUpdate()
}
@ -55,7 +56,10 @@ func UpdateDDNSList() {
DDNSListLock.Lock()
defer DDNSListLock.Unlock()
DDNSList = utils.MapValuesToSlice(DDNSCache)
DDNSList = make([]*model.DDNSProfile, 0, len(DDNSCache))
for _, p := range DDNSCache {
DDNSList = append(DDNSList, p)
}
slices.SortFunc(DDNSList, func(a, b *model.DDNSProfile) int {
return cmp.Compare(a.ID, b.ID)
})

View File

@ -6,7 +6,6 @@ import (
"sync"
"github.com/nezhahq/nezha/model"
"github.com/nezhahq/nezha/pkg/utils"
)
var (
@ -20,6 +19,8 @@ var (
func initNAT() {
DB.Find(&NATList)
NATCacheRwLock.Lock()
defer NATCacheRwLock.Unlock()
NATCache = make(map[string]*model.NAT)
for i := 0; i < len(NATList); i++ {
NATCache[NATList[i].Domain] = NATList[i]
@ -58,7 +59,10 @@ func UpdateNATList() {
NATListLock.Lock()
defer NATListLock.Unlock()
NATList = utils.MapValuesToSlice(NATCache)
NATList = make([]*model.NAT, 0, len(NATCache))
for _, n := range NATCache {
NATList = append(NATList, n)
}
slices.SortFunc(NATList, func(a, b *model.NAT) int {
return cmp.Compare(a.ID, b.ID)
})

View File

@ -9,7 +9,6 @@ import (
"time"
"github.com/nezhahq/nezha/model"
"github.com/nezhahq/nezha/pkg/utils"
)
const (
@ -31,7 +30,7 @@ var (
)
// InitNotification 初始化 GroupID <-> ID <-> Notification 的映射
func initNotification() {
func InitNotification() {
NotificationList = make(map[uint64]map[uint64]*model.Notification)
NotificationIDToGroups = make(map[uint64]map[uint64]struct{})
NotificationGroup = make(map[uint64]string)
@ -39,7 +38,9 @@ func initNotification() {
// loadNotifications 从 DB 初始化通知方式相关参数
func loadNotifications() {
initNotification()
InitNotification()
NotificationsLock.Lock()
groupNotifications := make(map[uint64][]uint64)
var ngn []model.NotificationGroupNotification
if err := DB.Find(&ngn).Error; err != nil {
@ -73,6 +74,8 @@ func loadNotifications() {
}
}
}
NotificationsLock.Unlock()
}
func UpdateNotificationList() {
@ -82,7 +85,10 @@ func UpdateNotificationList() {
NotificationSortedLock.Lock()
defer NotificationSortedLock.Unlock()
NotificationListSorted = utils.MapValuesToSlice(NotificationMap)
NotificationListSorted = make([]*model.Notification, 0, len(NotificationMap))
for _, n := range NotificationMap {
NotificationListSorted = append(NotificationListSorted, n)
}
slices.SortFunc(NotificationListSorted, func(a, b *model.Notification) int {
return cmp.Compare(a.ID, b.ID)
})

View File

@ -1,68 +0,0 @@
package singleton
import (
"slices"
"sync"
"github.com/nezhahq/nezha/model"
)
var (
OnlineUserMap = make(map[string]*model.OnlineUser)
OnlineUserMapLock = new(sync.Mutex)
)
func AddOnlineUser(connId string, user *model.OnlineUser) {
OnlineUserMapLock.Lock()
defer OnlineUserMapLock.Unlock()
OnlineUserMap[connId] = user
}
func RemoveOnlineUser(connId string) {
OnlineUserMapLock.Lock()
defer OnlineUserMapLock.Unlock()
delete(OnlineUserMap, connId)
}
func BlockByIPs(ipList []string) error {
OnlineUserMapLock.Lock()
defer OnlineUserMapLock.Unlock()
for _, ip := range ipList {
if err := model.BlockIP(DB, ip, model.WAFBlockReasonTypeManual, model.BlockIDManual); err != nil {
return err
}
for _, user := range OnlineUserMap {
if user.IP == ip && user.Conn != nil {
user.Conn.Close()
}
}
}
return nil
}
func GetOnlineUsers(limit, offset int) []*model.OnlineUser {
OnlineUserMapLock.Lock()
defer OnlineUserMapLock.Unlock()
var users []*model.OnlineUser
for _, user := range OnlineUserMap {
users = append(users, user)
}
slices.SortFunc(users, func(i, j *model.OnlineUser) int {
return i.ConnectedAt.Compare(j.ConnectedAt)
})
if offset > len(users) {
return nil
}
if offset+limit > len(users) {
return users[offset:]
}
return users[offset : offset+limit]
}
func GetOnlineUserCount() int {
OnlineUserMapLock.Lock()
defer OnlineUserMapLock.Unlock()
return len(OnlineUserMap)
}

View File

@ -1,12 +1,10 @@
package singleton
import (
"cmp"
"slices"
"sort"
"sync"
"github.com/nezhahq/nezha/model"
"github.com/nezhahq/nezha/pkg/utils"
)
var (
@ -47,21 +45,29 @@ func ReSortServer() {
SortedServerLock.Lock()
defer SortedServerLock.Unlock()
SortedServerList = utils.MapValuesToSlice(ServerList)
// 按照服务器 ID 排序的具体实现ID越大越靠前
slices.SortStableFunc(SortedServerList, func(a, b *model.Server) int {
if a.DisplayIndex == b.DisplayIndex {
return cmp.Compare(a.ID, b.ID)
}
return cmp.Compare(b.DisplayIndex, a.DisplayIndex)
})
SortedServerListForGuest = make([]*model.Server, 0, len(SortedServerList))
for _, s := range SortedServerList {
SortedServerList = make([]*model.Server, 0, len(ServerList))
SortedServerListForGuest = make([]*model.Server, 0)
for _, s := range ServerList {
SortedServerList = append(SortedServerList, s)
if !s.HideForGuest {
SortedServerListForGuest = append(SortedServerListForGuest, s)
}
}
// 按照服务器 ID 排序的具体实现ID越大越靠前
sort.SliceStable(SortedServerList, func(i, j int) bool {
if SortedServerList[i].DisplayIndex == SortedServerList[j].DisplayIndex {
return SortedServerList[i].ID < SortedServerList[j].ID
}
return SortedServerList[i].DisplayIndex > SortedServerList[j].DisplayIndex
})
sort.SliceStable(SortedServerListForGuest, func(i, j int) bool {
if SortedServerListForGuest[i].DisplayIndex == SortedServerListForGuest[j].DisplayIndex {
return SortedServerListForGuest[i].ID < SortedServerListForGuest[j].ID
}
return SortedServerListForGuest[i].DisplayIndex > SortedServerListForGuest[j].DisplayIndex
})
}
func OnServerDelete(sid []uint64) {

View File

@ -11,7 +11,6 @@ import (
"github.com/jinzhu/copier"
"github.com/nezhahq/nezha/model"
"github.com/nezhahq/nezha/pkg/utils"
pb "github.com/nezhahq/nezha/proto"
)
@ -175,7 +174,11 @@ func (ss *ServiceSentinel) UpdateServiceList() {
ss.ServiceListLock.Lock()
defer ss.ServiceListLock.Unlock()
ss.ServiceList = utils.MapValuesToSlice(ss.Services)
ss.ServiceList = make([]*model.Service, 0, len(ss.Services))
for _, v := range ss.Services {
ss.ServiceList = append(ss.ServiceList, v)
}
slices.SortFunc(ss.ServiceList, func(a, b *model.Service) int {
return cmp.Compare(a.ID, b.ID)
})
@ -189,6 +192,13 @@ func (ss *ServiceSentinel) loadServiceHistory() {
panic(err)
}
ss.serviceResponseDataStoreLock.Lock()
defer ss.serviceResponseDataStoreLock.Unlock()
ss.monthlyStatusLock.Lock()
defer ss.monthlyStatusLock.Unlock()
ss.ServicesLock.Lock()
defer ss.ServicesLock.Unlock()
for i := 0; i < len(services); i++ {
task := *services[i]
// 通过cron定时将服务监控任务传递给任务调度管道

View File

@ -3,6 +3,7 @@ package singleton
import (
_ "embed"
"log"
"sync/atomic"
"time"
"github.com/patrickmn/go-cache"
@ -23,6 +24,7 @@ var (
Loc *time.Location
FrontendTemplates []model.FrontendTemplate
DashboardBootTime = uint64(time.Now().Unix())
OnlineUsers = new(atomic.Uint64)
)
//go:embed frontend-templates.yaml
@ -40,7 +42,6 @@ func InitTimezoneAndCache() {
// LoadSingleton 加载子服务并执行
func LoadSingleton() {
initUser() // 加载用户ID绑定表
initI18n() // 加载本地化服务
loadNotifications() // 加载通知服务
loadServers() // 加载服务器列表
@ -80,8 +81,8 @@ func InitDBFromPath(path string) {
}
err = DB.AutoMigrate(model.Server{}, model.User{}, model.ServerGroup{}, model.NotificationGroup{},
model.Notification{}, model.AlertRule{}, model.Service{}, model.NotificationGroupNotification{},
model.ServiceHistory{}, model.Cron{}, model.Transfer{}, model.ServerGroupServer{},
model.NAT{}, model.DDNSProfile{}, model.NotificationGroupNotification{},
model.ServiceHistory{}, model.Cron{}, model.Transfer{}, model.ServerGroupServer{}, model.UserGroup{},
model.UserGroupUser{}, model.NAT{}, model.DDNSProfile{}, model.NotificationGroupNotification{},
model.WAF{})
if err != nil {
panic(err)

View File

@ -1,135 +0,0 @@
package singleton
import (
"sync"
"github.com/nezhahq/nezha/model"
"gorm.io/gorm"
)
var (
UserInfoMap map[uint64]model.UserInfo
AgentSecretToUserId map[string]uint64
UserLock sync.RWMutex
)
func initUser() {
UserInfoMap = make(map[uint64]model.UserInfo)
AgentSecretToUserId = make(map[string]uint64)
var users []model.User
DB.Find(&users)
for _, u := range users {
UserInfoMap[u.ID] = model.UserInfo{
Role: u.Role,
AgentSecret: u.AgentSecret,
}
AgentSecretToUserId[u.AgentSecret] = u.ID
}
}
func OnUserUpdate(u *model.User) {
UserLock.Lock()
defer UserLock.Unlock()
if u == nil {
return
}
UserInfoMap[u.ID] = model.UserInfo{
Role: u.Role,
AgentSecret: u.AgentSecret,
}
AgentSecretToUserId[u.AgentSecret] = u.ID
}
func OnUserDelete(id []uint64, errorFunc func(string, ...interface{}) error) error {
UserLock.Lock()
defer UserLock.Unlock()
if len(id) < 1 {
return Localizer.ErrorT("user id not specified")
}
var (
cron, server bool
crons, servers []uint64
)
for _, uid := range id {
err := DB.Transaction(func(tx *gorm.DB) error {
CronLock.RLock()
crons = model.FindByUserID(CronList, uid)
CronLock.RUnlock()
cron = len(crons) > 0
if cron {
if err := tx.Unscoped().Delete(&model.Cron{}, "id in (?)", crons).Error; err != nil {
return err
}
}
SortedServerLock.RLock()
servers = model.FindByUserID(SortedServerList, uid)
SortedServerLock.RUnlock()
server = len(servers) > 0
if server {
if err := tx.Unscoped().Delete(&model.Server{}, "id in (?)", servers).Error; err != nil {
return err
}
if err := tx.Unscoped().Delete(&model.ServerGroupServer{}, "server_id in (?)", servers).Error; err != nil {
return err
}
}
if err := tx.Unscoped().Delete(&model.Transfer{}, "server_id in (?)", servers).Error; err != nil {
return err
}
if err := tx.Where("id IN (?)", id).Delete(&model.User{}).Error; err != nil {
return err
}
return nil
})
if err != nil {
return errorFunc("%v", err)
}
if cron {
OnDeleteCron(crons)
}
if server {
AlertsLock.Lock()
for _, sid := range servers {
for _, alert := range Alerts {
if AlertsCycleTransferStatsStore[alert.ID] != nil {
delete(AlertsCycleTransferStatsStore[alert.ID].ServerName, sid)
delete(AlertsCycleTransferStatsStore[alert.ID].Transfer, sid)
delete(AlertsCycleTransferStatsStore[alert.ID].NextUpdate, sid)
}
}
}
AlertsLock.Unlock()
OnServerDelete(servers)
}
secret := UserInfoMap[uid].AgentSecret
delete(AgentSecretToUserId, secret)
delete(UserInfoMap, uid)
}
if cron {
UpdateCronList()
}
if server {
ReSortServer()
}
return nil
}