Compare commits

..

No commits in common. "00a04941249e5152eb43273da871da2a53056368" and "9598657a81cdb1cd0edadc5317e0e653a0104683" have entirely different histories.

20 changed files with 90 additions and 289 deletions

View File

@ -62,7 +62,7 @@ func createAlertRule(c *gin.Context) (uint64, error) {
r.TriggerMode = arf.TriggerMode r.TriggerMode = arf.TriggerMode
r.Enable = &enable r.Enable = &enable
if err := validateRule(c, &r); err != nil { if err := validateRule(&r); err != nil {
return 0, err return 0, err
} }
@ -116,7 +116,7 @@ func updateAlertRule(c *gin.Context) (any, error) {
r.TriggerMode = arf.TriggerMode r.TriggerMode = arf.TriggerMode
r.Enable = &enable r.Enable = &enable
if err := validateRule(c, &r); err != nil { if err := validateRule(&r); err != nil {
return 0, err return 0, err
} }
@ -140,20 +140,22 @@ func updateAlertRule(c *gin.Context) (any, error) {
// @Success 200 {object} model.CommonResponse[any] // @Success 200 {object} model.CommonResponse[any]
// @Router /batch-delete/alert-rule [post] // @Router /batch-delete/alert-rule [post]
func batchDeleteAlertRule(c *gin.Context) (any, error) { func batchDeleteAlertRule(c *gin.Context) (any, error) {
var ar []uint64 var arr []uint64
if err := c.ShouldBindJSON(&ar); err != nil { if err := c.ShouldBindJSON(&arr); err != nil {
return nil, err return nil, err
} }
var ars []model.AlertRule var ars []model.AlertRule
if err := singleton.DB.Where("id in (?)", ar).Find(&ars).Error; err != nil { if err := singleton.DB.Where("id in (?)", arr).Find(&ars).Error; err != nil {
return nil, err return nil, err
} }
var ar []uint64
for _, a := range ars { for _, a := range ars {
if !a.HasPermission(c) { if !a.HasPermission(c) {
return nil, singleton.Localizer.ErrorT("permission denied") return nil, singleton.Localizer.ErrorT("permission denied")
} }
ar = append(ar, a.ID)
} }
if err := singleton.DB.Unscoped().Delete(&model.AlertRule{}, "id in (?)", ar).Error; err != nil { if err := singleton.DB.Unscoped().Delete(&model.AlertRule{}, "id in (?)", ar).Error; err != nil {
@ -164,20 +166,9 @@ func batchDeleteAlertRule(c *gin.Context) (any, error) {
return nil, nil return nil, nil
} }
func validateRule(c *gin.Context, r *model.AlertRule) error { func validateRule(r *model.AlertRule) error {
if len(r.Rules) > 0 { if len(r.Rules) > 0 {
for _, rule := range r.Rules { for _, rule := range r.Rules {
singleton.ServerLock.RLock()
for s := range rule.Ignore {
if server, ok := singleton.ServerList[s]; ok {
if !server.HasPermission(c) {
singleton.ServerLock.RUnlock()
return singleton.Localizer.ErrorT("permission denied")
}
}
}
singleton.ServerLock.RUnlock()
if !rule.IsTransferDurationRule() { if !rule.IsTransferDurationRule() {
if rule.Duration < 3 { if rule.Duration < 3 {
return singleton.Localizer.ErrorT("duration need to be at least 3") return singleton.Localizer.ErrorT("duration need to be at least 3")

View File

@ -129,7 +129,7 @@ func routers(r *gin.Engine, frontendDist fs.FS) {
auth.PATCH("/nat/:id", commonHandler(updateNAT)) auth.PATCH("/nat/:id", commonHandler(updateNAT))
auth.POST("/batch-delete/nat", commonHandler(batchDeleteNAT)) auth.POST("/batch-delete/nat", commonHandler(batchDeleteNAT))
auth.GET("/waf", pCommonHandler(listBlockedAddress)) auth.GET("/waf", commonHandler(listBlockedAddress))
auth.POST("/batch-delete/waf", adminHandler(batchDeleteBlockedAddress)) auth.POST("/batch-delete/waf", adminHandler(batchDeleteBlockedAddress))
auth.PATCH("/setting", adminHandler(updateConfig)) auth.PATCH("/setting", adminHandler(updateConfig))
@ -153,7 +153,6 @@ func newErrorResponse(err error) model.CommonResponse[any] {
} }
type handlerFunc[T any] func(c *gin.Context) (T, error) type handlerFunc[T any] func(c *gin.Context) (T, error)
type pHandlerFunc[S ~[]E, E any] func(c *gin.Context) (*model.Value[S], error)
// There are many error types in gorm, so create a custom type to represent all // There are many error types in gorm, so create a custom type to represent all
// gorm errors here instead // gorm errors here instead
@ -248,18 +247,6 @@ func listHandler[S ~[]E, E model.CommonInterface](handler handlerFunc[S]) func(*
} }
} }
func pCommonHandler[S ~[]E, E any](handler pHandlerFunc[S, E]) func(*gin.Context) {
return func(c *gin.Context) {
data, err := handler(c)
if err != nil {
c.JSON(http.StatusOK, newErrorResponse(err))
return
}
c.JSON(http.StatusOK, model.PaginatedResponse[S, E]{Success: true, Data: data})
}
}
func filter[S ~[]E, E model.CommonInterface](ctx *gin.Context, s S) S { func filter[S ~[]E, E model.CommonInterface](ctx *gin.Context, s S) S {
return slices.DeleteFunc(s, func(e E) bool { return slices.DeleteFunc(s, func(e E) bool {
return !e.HasPermission(ctx) return !e.HasPermission(ctx)

View File

@ -49,17 +49,6 @@ func createCron(c *gin.Context) (uint64, error) {
return 0, err return 0, err
} }
singleton.ServerLock.RLock()
for _, sid := range cf.Servers {
if server, ok := singleton.ServerList[sid]; ok {
if !server.HasPermission(c) {
singleton.ServerLock.RUnlock()
return 0, singleton.Localizer.ErrorT("permission denied")
}
}
}
singleton.ServerLock.RUnlock()
cr.UserID = getUid(c) cr.UserID = getUid(c)
cr.TaskType = cf.TaskType cr.TaskType = cf.TaskType
cr.Name = cf.Name cr.Name = cf.Name
@ -115,17 +104,6 @@ func updateCron(c *gin.Context) (any, error) {
return 0, err return 0, err
} }
singleton.ServerLock.RLock()
for _, sid := range cf.Servers {
if server, ok := singleton.ServerList[sid]; ok {
if !server.HasPermission(c) {
singleton.ServerLock.RUnlock()
return nil, singleton.Localizer.ErrorT("permission denied")
}
}
}
singleton.ServerLock.RUnlock()
var cr model.Cron var cr model.Cron
if err := singleton.DB.First(&cr, id).Error; err != nil { if err := singleton.DB.First(&cr, id).Error; err != nil {
return nil, singleton.Localizer.ErrorT("task id %d does not exist", id) return nil, singleton.Localizer.ErrorT("task id %d does not exist", id)
@ -210,18 +188,20 @@ func manualTriggerCron(c *gin.Context) (any, error) {
// @Success 200 {object} model.CommonResponse[any] // @Success 200 {object} model.CommonResponse[any]
// @Router /batch-delete/cron [post] // @Router /batch-delete/cron [post]
func batchDeleteCron(c *gin.Context) (any, error) { func batchDeleteCron(c *gin.Context) (any, error) {
var cr []uint64 var crr []uint64
if err := c.ShouldBindJSON(&cr); err != nil { if err := c.ShouldBindJSON(&crr); err != nil {
return nil, err return nil, err
} }
var cr []uint64
singleton.CronLock.RLock() singleton.CronLock.RLock()
for _, crID := range cr { for _, crID := range crr {
if crn, ok := singleton.Crons[crID]; ok { if crn, ok := singleton.Crons[crID]; ok {
if !crn.HasPermission(c) { if !crn.HasPermission(c) {
singleton.CronLock.RUnlock() singleton.CronLock.RUnlock()
return nil, singleton.Localizer.ErrorT("permission denied") return nil, singleton.Localizer.ErrorT("permission denied")
} }
cr = append(cr, crn.ID)
} }
} }
singleton.CronLock.RUnlock() singleton.CronLock.RUnlock()

View File

@ -177,19 +177,21 @@ func updateDDNS(c *gin.Context) (any, error) {
// @Success 200 {object} model.CommonResponse[any] // @Success 200 {object} model.CommonResponse[any]
// @Router /batch-delete/ddns [post] // @Router /batch-delete/ddns [post]
func batchDeleteDDNS(c *gin.Context) (any, error) { func batchDeleteDDNS(c *gin.Context) (any, error) {
var ddnsConfigs []uint64 var ddnsConfigsr []uint64
if err := c.ShouldBindJSON(&ddnsConfigs); err != nil { if err := c.ShouldBindJSON(&ddnsConfigsr); err != nil {
return nil, err return nil, err
} }
var ddnsConfigs []uint64
singleton.DDNSCacheLock.RLock() singleton.DDNSCacheLock.RLock()
for _, pid := range ddnsConfigs { for _, pid := range ddnsConfigsr {
if p, ok := singleton.DDNSCache[pid]; ok { if p, ok := singleton.DDNSCache[pid]; ok {
if !p.HasPermission(c) { if !p.HasPermission(c) {
singleton.DDNSCacheLock.RUnlock() singleton.DDNSCacheLock.RUnlock()
return nil, singleton.Localizer.ErrorT("permission denied") return nil, singleton.Localizer.ErrorT("permission denied")
} }
ddnsConfigs = append(ddnsConfigs, p.ID)
} }
} }
singleton.DDNSCacheLock.RUnlock() singleton.DDNSCacheLock.RUnlock()

View File

@ -51,15 +51,6 @@ func createNAT(c *gin.Context) (uint64, error) {
return 0, err return 0, err
} }
singleton.ServerLock.RLock()
if server, ok := singleton.ServerList[nf.ServerID]; ok {
if !server.HasPermission(c) {
singleton.ServerLock.RUnlock()
return 0, singleton.Localizer.ErrorT("permission denied")
}
}
singleton.ServerLock.RUnlock()
uid := getUid(c) uid := getUid(c)
n.UserID = uid n.UserID = uid
@ -102,15 +93,6 @@ func updateNAT(c *gin.Context) (any, error) {
return nil, err return nil, err
} }
singleton.ServerLock.RLock()
if server, ok := singleton.ServerList[nf.ServerID]; ok {
if !server.HasPermission(c) {
singleton.ServerLock.RUnlock()
return nil, singleton.Localizer.ErrorT("permission denied")
}
}
singleton.ServerLock.RUnlock()
var n model.NAT var n model.NAT
if err = singleton.DB.First(&n, id).Error; err != nil { if err = singleton.DB.First(&n, id).Error; err != nil {
return nil, singleton.Localizer.ErrorT("profile id %d does not exist", id) return nil, singleton.Localizer.ErrorT("profile id %d does not exist", id)
@ -146,18 +128,20 @@ func updateNAT(c *gin.Context) (any, error) {
// @Success 200 {object} model.CommonResponse[any] // @Success 200 {object} model.CommonResponse[any]
// @Router /batch-delete/nat [post] // @Router /batch-delete/nat [post]
func batchDeleteNAT(c *gin.Context) (any, error) { func batchDeleteNAT(c *gin.Context) (any, error) {
var n []uint64 var nr []uint64
if err := c.ShouldBindJSON(&n); err != nil { if err := c.ShouldBindJSON(&nr); err != nil {
return nil, err return nil, err
} }
var n []uint64
singleton.NATCacheRwLock.RLock() singleton.NATCacheRwLock.RLock()
for _, id := range n { for _, id := range nr {
if p, ok := singleton.NATCache[singleton.NATIDToDomain[id]]; ok { if p, ok := singleton.NATCache[singleton.NATIDToDomain[id]]; ok {
if !p.HasPermission(c) { if !p.HasPermission(c) {
singleton.NATCacheRwLock.RUnlock() singleton.NATCacheRwLock.RUnlock()
return nil, singleton.Localizer.ErrorT("permission denied") return nil, singleton.Localizer.ErrorT("permission denied")
} }
n = append(n, p.ID)
} }
} }
singleton.NATCacheRwLock.RUnlock() singleton.NATCacheRwLock.RUnlock()

View File

@ -153,17 +153,19 @@ func updateNotification(c *gin.Context) (any, error) {
// @Success 200 {object} model.CommonResponse[any] // @Success 200 {object} model.CommonResponse[any]
// @Router /batch-delete/notification [post] // @Router /batch-delete/notification [post]
func batchDeleteNotification(c *gin.Context) (any, error) { func batchDeleteNotification(c *gin.Context) (any, error) {
var n []uint64 var nr []uint64
if err := c.ShouldBindJSON(&n); err != nil { if err := c.ShouldBindJSON(&nr); err != nil {
return nil, err return nil, err
} }
var n []uint64
singleton.NotificationsLock.RLock() singleton.NotificationsLock.RLock()
for _, nid := range n { for _, nid := range nr {
if ns, ok := singleton.NotificationMap[nid]; ok { if ns, ok := singleton.NotificationMap[nid]; ok {
if !ns.HasPermission(c) { if !ns.HasPermission(c) {
return nil, singleton.Localizer.ErrorT("permission denied") return nil, singleton.Localizer.ErrorT("permission denied")
} }
n = append(n, ns.ID)
} }
} }
singleton.NotificationsLock.RUnlock() singleton.NotificationsLock.RUnlock()

View File

@ -68,17 +68,6 @@ func createNotificationGroup(c *gin.Context) (uint64, error) {
} }
ngf.Notifications = slices.Compact(ngf.Notifications) ngf.Notifications = slices.Compact(ngf.Notifications)
singleton.NotificationsLock.RLock()
for _, nid := range ngf.Notifications {
if n, ok := singleton.NotificationMap[nid]; ok {
if !n.HasPermission(c) {
singleton.NotificationsLock.RUnlock()
return 0, singleton.Localizer.ErrorT("permission denied")
}
}
}
singleton.NotificationsLock.RUnlock()
uid := getUid(c) uid := getUid(c)
var ng model.NotificationGroup var ng model.NotificationGroup
@ -143,18 +132,6 @@ func updateNotificationGroup(c *gin.Context) (any, error) {
if err := c.ShouldBindJSON(&ngf); err != nil { if err := c.ShouldBindJSON(&ngf); err != nil {
return nil, err return nil, err
} }
singleton.NotificationsLock.RLock()
for _, nid := range ngf.Notifications {
if n, ok := singleton.NotificationMap[nid]; ok {
if !n.HasPermission(c) {
singleton.NotificationsLock.RUnlock()
return nil, singleton.Localizer.ErrorT("permission denied")
}
}
}
singleton.NotificationsLock.RUnlock()
var ngDB model.NotificationGroup var ngDB model.NotificationGroup
if err := singleton.DB.First(&ngDB, id).Error; err != nil { if err := singleton.DB.First(&ngDB, id).Error; err != nil {
return nil, singleton.Localizer.ErrorT("group id %d does not exist", id) return nil, singleton.Localizer.ErrorT("group id %d does not exist", id)
@ -218,20 +195,22 @@ func updateNotificationGroup(c *gin.Context) (any, error) {
// @Success 200 {object} model.CommonResponse[any] // @Success 200 {object} model.CommonResponse[any]
// @Router /batch-delete/notification-group [post] // @Router /batch-delete/notification-group [post]
func batchDeleteNotificationGroup(c *gin.Context) (any, error) { func batchDeleteNotificationGroup(c *gin.Context) (any, error) {
var ngn []uint64 var ngnr []uint64
if err := c.ShouldBindJSON(&ngn); err != nil { if err := c.ShouldBindJSON(&ngnr); err != nil {
return nil, err return nil, err
} }
var ng []model.NotificationGroup var ng []model.NotificationGroup
if err := singleton.DB.Where("id in (?)", ngn).Find(&ng).Error; err != nil { if err := singleton.DB.Where("id in (?)", ngnr).Find(&ng).Error; err != nil {
return nil, err return nil, err
} }
var ngn []uint64
for _, n := range ng { for _, n := range ng {
if !n.HasPermission(c) { if !n.HasPermission(c) {
return nil, singleton.Localizer.ErrorT("permission denied") return nil, singleton.Localizer.ErrorT("permission denied")
} }
ngn = append(ngn, n.ID)
} }
err := singleton.DB.Transaction(func(tx *gorm.DB) error { err := singleton.DB.Transaction(func(tx *gorm.DB) error {

View File

@ -56,17 +56,6 @@ func updateServer(c *gin.Context) (any, error) {
return nil, err return nil, err
} }
singleton.DDNSCacheLock.RLock()
for _, pid := range sf.DDNSProfiles {
if p, ok := singleton.DDNSCache[pid]; ok {
if !p.HasPermission(c) {
singleton.DDNSCacheLock.RUnlock()
return nil, singleton.Localizer.ErrorT("permission denied")
}
}
}
singleton.DDNSCacheLock.RUnlock()
var s model.Server var s model.Server
if err := singleton.DB.First(&s, id).Error; err != nil { if err := singleton.DB.First(&s, id).Error; err != nil {
return nil, singleton.Localizer.ErrorT("server id %d does not exist", id) return nil, singleton.Localizer.ErrorT("server id %d does not exist", id)
@ -114,18 +103,20 @@ func updateServer(c *gin.Context) (any, error) {
// @Success 200 {object} model.CommonResponse[any] // @Success 200 {object} model.CommonResponse[any]
// @Router /batch-delete/server [post] // @Router /batch-delete/server [post]
func batchDeleteServer(c *gin.Context) (any, error) { func batchDeleteServer(c *gin.Context) (any, error) {
var servers []uint64 var serversRaw []uint64
if err := c.ShouldBindJSON(&servers); err != nil { if err := c.ShouldBindJSON(&serversRaw); err != nil {
return nil, err return nil, err
} }
var servers []uint64
singleton.ServerLock.RLock() singleton.ServerLock.RLock()
for _, sid := range servers { for _, sid := range serversRaw {
if s, ok := singleton.ServerList[sid]; ok { if s, ok := singleton.ServerList[sid]; ok {
if !s.HasPermission(c) { if !s.HasPermission(c) {
singleton.ServerLock.RUnlock() singleton.ServerLock.RUnlock()
return nil, singleton.Localizer.ErrorT("permission denied") return nil, singleton.Localizer.ErrorT("permission denied")
} }
servers = append(servers, s.ID)
} }
} }
singleton.ServerLock.RUnlock() singleton.ServerLock.RUnlock()

View File

@ -67,17 +67,6 @@ func createServerGroup(c *gin.Context) (uint64, error) {
} }
sgf.Servers = slices.Compact(sgf.Servers) sgf.Servers = slices.Compact(sgf.Servers)
singleton.ServerLock.RLock()
for _, sid := range sgf.Servers {
if server, ok := singleton.ServerList[sid]; ok {
if !server.HasPermission(c) {
singleton.ServerLock.RUnlock()
return 0, singleton.Localizer.ErrorT("permission denied")
}
}
}
singleton.ServerLock.RUnlock()
uid := getUid(c) uid := getUid(c)
var sg model.ServerGroup var sg model.ServerGroup
@ -142,17 +131,6 @@ func updateServerGroup(c *gin.Context) (any, error) {
} }
sg.Servers = slices.Compact(sg.Servers) sg.Servers = slices.Compact(sg.Servers)
singleton.ServerLock.RLock()
for _, sid := range sg.Servers {
if server, ok := singleton.ServerList[sid]; ok {
if !server.HasPermission(c) {
singleton.ServerLock.RUnlock()
return nil, singleton.Localizer.ErrorT("permission denied")
}
}
}
singleton.ServerLock.RUnlock()
var sgDB model.ServerGroup var sgDB model.ServerGroup
if err := singleton.DB.First(&sgDB, id).Error; err != nil { if err := singleton.DB.First(&sgDB, id).Error; err != nil {
return nil, singleton.Localizer.ErrorT("group id %d does not exist", id) return nil, singleton.Localizer.ErrorT("group id %d does not exist", id)
@ -214,20 +192,22 @@ func updateServerGroup(c *gin.Context) (any, error) {
// @Success 200 {object} model.CommonResponse[any] // @Success 200 {object} model.CommonResponse[any]
// @Router /batch-delete/server-group [post] // @Router /batch-delete/server-group [post]
func batchDeleteServerGroup(c *gin.Context) (any, error) { func batchDeleteServerGroup(c *gin.Context) (any, error) {
var sgs []uint64 var sgsr []uint64
if err := c.ShouldBindJSON(&sgs); err != nil { if err := c.ShouldBindJSON(&sgsr); err != nil {
return nil, err return nil, err
} }
var sg []model.ServerGroup var sg []model.ServerGroup
if err := singleton.DB.Where("id in (?)", sgs).Find(&sg).Error; err != nil { if err := singleton.DB.Where("id in (?)", sgsr).Find(&sg).Error; err != nil {
return nil, err return nil, err
} }
var sgs []uint64
for _, s := range sg { for _, s := range sg {
if !s.HasPermission(c) { if !s.HasPermission(c) {
return nil, singleton.Localizer.ErrorT("permission denied") return nil, singleton.Localizer.ErrorT("permission denied")
} }
sgs = append(sgs, s.ID)
} }
err := singleton.DB.Transaction(func(tx *gorm.DB) error { err := singleton.DB.Transaction(func(tx *gorm.DB) error {

View File

@ -210,10 +210,6 @@ func createService(c *gin.Context) (uint64, error) {
m.RecoverTriggerTasks = mf.RecoverTriggerTasks m.RecoverTriggerTasks = mf.RecoverTriggerTasks
m.FailTriggerTasks = mf.FailTriggerTasks m.FailTriggerTasks = mf.FailTriggerTasks
if err := validateServers(c, &m); err != nil {
return 0, err
}
if err := singleton.DB.Create(&m).Error; err != nil { if err := singleton.DB.Create(&m).Error; err != nil {
return 0, newGormError("%v", err) return 0, newGormError("%v", err)
} }
@ -288,10 +284,6 @@ func updateService(c *gin.Context) (any, error) {
m.RecoverTriggerTasks = mf.RecoverTriggerTasks m.RecoverTriggerTasks = mf.RecoverTriggerTasks
m.FailTriggerTasks = mf.FailTriggerTasks m.FailTriggerTasks = mf.FailTriggerTasks
if err := validateServers(c, &m); err != nil {
return 0, err
}
if err := singleton.DB.Save(&m).Error; err != nil { if err := singleton.DB.Save(&m).Error; err != nil {
return nil, newGormError("%v", err) return nil, newGormError("%v", err)
} }
@ -330,18 +322,20 @@ func updateService(c *gin.Context) (any, error) {
// @Success 200 {object} model.CommonResponse[any] // @Success 200 {object} model.CommonResponse[any]
// @Router /batch-delete/service [post] // @Router /batch-delete/service [post]
func batchDeleteService(c *gin.Context) (any, error) { func batchDeleteService(c *gin.Context) (any, error) {
var ids []uint64 var idsr []uint64
if err := c.ShouldBindJSON(&ids); err != nil { if err := c.ShouldBindJSON(&idsr); err != nil {
return nil, err return nil, err
} }
var ids []uint64
singleton.ServiceSentinelShared.ServicesLock.RLock() singleton.ServiceSentinelShared.ServicesLock.RLock()
for _, id := range ids { for _, id := range idsr {
if ss, ok := singleton.ServiceSentinelShared.Services[id]; ok { if ss, ok := singleton.ServiceSentinelShared.Services[id]; ok {
if !ss.HasPermission(c) { if !ss.HasPermission(c) {
singleton.ServiceSentinelShared.ServicesLock.RUnlock() singleton.ServiceSentinelShared.ServicesLock.RUnlock()
return nil, singleton.Localizer.ErrorT("permission denied") return nil, singleton.Localizer.ErrorT("permission denied")
} }
ids = append(ids, ss.ID)
} }
} }
singleton.ServiceSentinelShared.ServicesLock.RUnlock() singleton.ServiceSentinelShared.ServicesLock.RUnlock()
@ -359,18 +353,3 @@ func batchDeleteService(c *gin.Context) (any, error) {
singleton.ServiceSentinelShared.UpdateServiceList() singleton.ServiceSentinelShared.UpdateServiceList()
return nil, nil return nil, nil
} }
func validateServers(c *gin.Context, ss *model.Service) error {
singleton.ServerLock.RLock()
defer singleton.ServerLock.RUnlock()
for s := range ss.SkipServers {
if server, ok := singleton.ServerList[s]; ok {
if !server.HasPermission(c) {
return singleton.Localizer.ErrorT("permission denied")
}
}
}
return nil
}

View File

@ -18,17 +18,17 @@ import (
// @Param limit query uint false "Page limit" // @Param limit query uint false "Page limit"
// @Param offset query uint false "Page offset" // @Param offset query uint false "Page offset"
// @Produce json // @Produce json
// @Success 200 {object} model.PaginatedResponse[[]model.WAFApiMock, model.WAFApiMock] // @Success 200 {object} model.CommonResponse[[]model.WAFApiMock]
// @Router /waf [get] // @Router /waf [get]
func listBlockedAddress(c *gin.Context) (*model.Value[[]*model.WAF], error) { func listBlockedAddress(c *gin.Context) ([]*model.WAF, error) {
limit, err := strconv.Atoi(c.Query("limit")) limit, err := strconv.Atoi(c.Query("limit"))
if err != nil || limit < 1 { if err != nil || limit < 1 {
limit = 25 limit = 25
} }
offset, err := strconv.Atoi(c.Query("offset")) offset, err := strconv.Atoi(c.Query("offset"))
if err != nil || offset < 0 { if err != nil || offset < 1 {
offset = 0 offset = 1
} }
var waf []*model.WAF var waf []*model.WAF
@ -36,19 +36,7 @@ func listBlockedAddress(c *gin.Context) (*model.Value[[]*model.WAF], error) {
return nil, err return nil, err
} }
var total int64 return waf, nil
if err := singleton.DB.Model(&model.WAF{}).Count(&total).Error; err != nil {
return nil, err
}
return &model.Value[[]*model.WAF]{
Value: waf,
Pagination: model.Pagination{
Offset: offset,
Limit: limit,
Total: total,
},
}, nil
} }
// Batch delete blocked addresses // Batch delete blocked addresses

View File

@ -100,34 +100,12 @@ func DispatchTask(serviceSentinelDispatchBus <-chan model.Service) {
continue continue
} }
if task.Cover == model.ServiceCoverIgnoreAll && task.SkipServers[singleton.SortedServerList[workedServerIndex].ID] { if task.Cover == model.ServiceCoverIgnoreAll && task.SkipServers[singleton.SortedServerList[workedServerIndex].ID] {
server := singleton.SortedServerList[workedServerIndex] singleton.SortedServerList[workedServerIndex].TaskStream.Send(task.PB())
singleton.UserLock.RLock()
var role uint8
if u, ok := singleton.UserInfoMap[server.UserID]; !ok {
role = model.RoleMember
} else {
role = u.Role
}
singleton.UserLock.RUnlock()
if task.UserID == server.UserID || role == model.RoleAdmin {
singleton.SortedServerList[workedServerIndex].TaskStream.Send(task.PB())
}
workedServerIndex++ workedServerIndex++
continue continue
} }
if task.Cover == model.ServiceCoverAll && !task.SkipServers[singleton.SortedServerList[workedServerIndex].ID] { if task.Cover == model.ServiceCoverAll && !task.SkipServers[singleton.SortedServerList[workedServerIndex].ID] {
server := singleton.SortedServerList[workedServerIndex] singleton.SortedServerList[workedServerIndex].TaskStream.Send(task.PB())
singleton.UserLock.RLock()
var role uint8
if u, ok := singleton.UserInfoMap[server.UserID]; !ok {
role = model.RoleMember
} else {
role = u.Role
}
singleton.UserLock.RUnlock()
if task.UserID == server.UserID || role == model.RoleAdmin {
singleton.SortedServerList[workedServerIndex].TaskStream.Send(task.PB())
}
workedServerIndex++ workedServerIndex++
continue continue
} }

View File

@ -62,15 +62,10 @@ func (r *AlertRule) Enabled() bool {
} }
// Snapshot 对传入的Server进行该报警规则下所有type的检查 返回每项检查结果 // Snapshot 对传入的Server进行该报警规则下所有type的检查 返回每项检查结果
func (r *AlertRule) Snapshot(cycleTransferStats *CycleTransferStats, server *Server, db *gorm.DB, role uint8) []bool { func (r *AlertRule) Snapshot(cycleTransferStats *CycleTransferStats, server *Server, db *gorm.DB) []bool {
point := make([]bool, len(r.Rules)) point := make([]bool, 0, len(r.Rules))
for _, rule := range r.Rules {
if r.UserID != server.UserID && role != RoleAdmin { point = append(point, rule.Snapshot(cycleTransferStats, server, db))
return point
}
for i, rule := range r.Rules {
point[i] = rule.Snapshot(cycleTransferStats, server, db)
} }
return point return point
} }

View File

@ -15,23 +15,6 @@ type CommonResponse[T any] struct {
Error string `json:"error,omitempty"` Error string `json:"error,omitempty"`
} }
type PaginatedResponse[S ~[]E, E any] struct {
Success bool `json:"success,omitempty"`
Data *Value[S] `json:"data,omitempty"`
Error string `json:"error,omitempty"`
}
type Value[T any] struct {
Value T `json:"value,omitempty"`
Pagination Pagination `json:"pagination,omitempty"`
}
type Pagination struct {
Offset int `json:"offset,omitempty"`
Limit int `json:"limit,omitempty"`
Total int64 `json:"total,omitempty"`
}
type LoginResponse struct { type LoginResponse struct {
Token string `json:"token,omitempty"` Token string `json:"token,omitempty"`
Expire string `json:"expire,omitempty"` Expire string `json:"expire,omitempty"`

View File

@ -52,7 +52,7 @@ type CommonInterface interface {
HasPermission(*gin.Context) bool HasPermission(*gin.Context) bool
} }
func FindByUserID[S ~[]E, E CommonInterface](s S, uid uint64) []uint64 { func FindUserID[S ~[]E, E CommonInterface](s S, uid uint64) []uint64 {
var list []uint64 var list []uint64
for _, v := range s { for _, v := range s {
if v.GetUserID() == uid { if v.GetUserID() == uid {

View File

@ -18,12 +18,6 @@ type User struct {
AgentSecret string `json:"agent_secret,omitempty" gorm:"type:char(32)"` AgentSecret string `json:"agent_secret,omitempty" gorm:"type:char(32)"`
} }
type UserInfo struct {
Role uint8
_ [3]byte
AgentSecret string
}
func (u *User) BeforeSave(tx *gorm.DB) error { func (u *User) BeforeSave(tx *gorm.DB) error {
if u.AgentSecret != "" { if u.AgentSecret != "" {
return nil return nil

View File

@ -23,19 +23,19 @@ const (
) )
type WAFApiMock struct { type WAFApiMock struct {
ID uint64 `json:"id,omitempty"`
IP string `json:"ip,omitempty"` IP string `json:"ip,omitempty"`
BlockIdentifier int64 `json:"block_identifier,omitempty"`
BlockReason uint8 `json:"block_reason,omitempty"` BlockReason uint8 `json:"block_reason,omitempty"`
BlockTimestamp uint64 `json:"block_timestamp,omitempty"` BlockTimestamp uint64 `json:"block_timestamp,omitempty"`
Count uint64 `json:"count,omitempty"` BlockIdentifier uint64 `json:"block_identifier,omitempty"`
} }
type WAF struct { type WAF struct {
IP []byte `gorm:"type:binary(16);primaryKey" json:"ip,omitempty"` ID uint64 `gorm:"primaryKey" json:"id,omitempty"`
BlockIdentifier int64 `gorm:"primaryKey" json:"block_identifier,omitempty"` IP []byte `gorm:"type:binary(16);index:idx_block_identifier" json:"ip,omitempty"`
BlockReason uint8 `json:"block_reason,omitempty"` BlockReason uint8 `json:"block_reason,omitempty"`
BlockTimestamp uint64 `gorm:"index" json:"block_timestamp,omitempty"` BlockTimestamp uint64 `json:"block_timestamp,omitempty"`
Count uint64 `json:"count,omitempty"` BlockIdentifier int64 `gorm:"index:idx_block_identifier" json:"block_identifier,omitempty"`
} }
func (w *WAF) TableName() string { func (w *WAF) TableName() string {
@ -52,7 +52,7 @@ func CheckIP(db *gorm.DB, ip string) error {
} }
var blockTimestamp uint64 var blockTimestamp uint64
result := db.Model(&WAF{}).Order("block_timestamp desc").Select("block_timestamp").Where("ip = ?", ipBinary).Limit(1).Find(&blockTimestamp) result := db.Model(&WAF{}).Select("block_timestamp").Order("id desc").Where("ip = ?", ipBinary).Limit(1).Find(&blockTimestamp)
if result.Error != nil { if result.Error != nil {
return result.Error return result.Error
} }
@ -62,13 +62,13 @@ func CheckIP(db *gorm.DB, ip string) error {
return nil return nil
} }
var count uint64 var count int64
if err := db.Model(&WAF{}).Select("SUM(count)").Where("ip = ?", ipBinary).Scan(&count).Error; err != nil { if err := db.Model(&WAF{}).Where("ip = ?", ipBinary).Count(&count).Error; err != nil {
return err return err
} }
now := time.Now().Unix() now := time.Now().Unix()
if powAdd(count, 4, blockTimestamp) > uint64(now) { if powAdd(uint64(count), 4, blockTimestamp) > uint64(now) {
return errors.New("you are blocked by nezha WAF") return errors.New("you are blocked by nezha WAF")
} }
return nil return nil
@ -110,17 +110,18 @@ func BlockIP(db *gorm.DB, ip string, reason uint8, uid int64) error {
} }
w := WAF{ w := WAF{
IP: ipBinary, IP: ipBinary,
BlockReason: reason,
BlockTimestamp: uint64(time.Now().Unix()),
BlockIdentifier: uid, BlockIdentifier: uid,
} }
now := uint64(time.Now().Unix())
return db.Transaction(func(tx *gorm.DB) error { return db.Transaction(func(tx *gorm.DB) error {
if err := tx.Where(&w).Attrs(WAF{ var lastRecord WAF
BlockReason: reason, if err := tx.Model(&WAF{}).Order("id desc").Where("ip = ?", ipBinary).First(&lastRecord).Error; err != nil {
BlockTimestamp: now, if !errors.Is(err, gorm.ErrRecordNotFound) {
}).FirstOrCreate(&w).Error; err != nil { return err
return err }
} }
return tx.Exec("UPDATE nz_waf SET count = count + 1, block_reason = ?, block_timestamp = ? WHERE ip = ? and block_identifier = ?", reason, now, ipBinary, uid).Error return tx.Create(&w).Error
}) })
} }

View File

@ -2,6 +2,7 @@ package rpc
import ( import (
"context" "context"
"crypto/subtle"
"strings" "strings"
petname "github.com/dustinkirkland/golang-petname" petname "github.com/dustinkirkland/golang-petname"
@ -38,7 +39,7 @@ func (a *authHandler) Check(ctx context.Context) (uint64, error) {
singleton.UserLock.RLock() singleton.UserLock.RLock()
userId, ok := singleton.AgentSecretToUserId[clientSecret] userId, ok := singleton.AgentSecretToUserId[clientSecret]
if !ok && clientSecret != singleton.Conf.AgentSecretKey { if !ok && subtle.ConstantTimeCompare([]byte(clientSecret), []byte(singleton.Conf.AgentSecretKey)) != 1 {
singleton.UserLock.RUnlock() singleton.UserLock.RUnlock()
model.BlockIP(singleton.DB, ip, model.WAFBlockReasonTypeAgentAuthFail, model.BlockIDgRPC) model.BlockIP(singleton.DB, ip, model.WAFBlockReasonTypeAgentAuthFail, model.BlockIDgRPC)
return 0, status.Error(codes.Unauthenticated, "客户端认证失败") return 0, status.Error(codes.Unauthenticated, "客户端认证失败")

View File

@ -143,16 +143,8 @@ func checkStatus() {
} }
for _, server := range ServerList { for _, server := range ServerList {
// 监测点 // 监测点
UserLock.RLock()
var role uint8
if u, ok := UserInfoMap[server.UserID]; !ok {
role = model.RoleMember
} else {
role = u.Role
}
UserLock.RUnlock()
alertsStore[alert.ID][server.ID] = append(alertsStore[alert. alertsStore[alert.ID][server.ID] = append(alertsStore[alert.
ID][server.ID], alert.Snapshot(AlertsCycleTransferStatsStore[alert.ID], server, DB, role)) ID][server.ID], alert.Snapshot(AlertsCycleTransferStatsStore[alert.ID], server, DB))
// 发送通知,分为触发报警和恢复通知 // 发送通知,分为触发报警和恢复通知
max, passed := alert.Check(alertsStore[alert.ID][server.ID]) max, passed := alert.Check(alertsStore[alert.ID][server.ID])
// 保存当前服务器状态信息 // 保存当前服务器状态信息

View File

@ -8,24 +8,21 @@ import (
) )
var ( var (
UserInfoMap map[uint64]model.UserInfo UserIdToAgentSecret map[uint64]string
AgentSecretToUserId map[string]uint64 AgentSecretToUserId map[string]uint64
UserLock sync.RWMutex UserLock sync.RWMutex
) )
func initUser() { func initUser() {
UserInfoMap = make(map[uint64]model.UserInfo) UserIdToAgentSecret = make(map[uint64]string)
AgentSecretToUserId = make(map[string]uint64) AgentSecretToUserId = make(map[string]uint64)
var users []model.User var users []model.User
DB.Find(&users) DB.Find(&users)
for _, u := range users { for _, u := range users {
UserInfoMap[u.ID] = model.UserInfo{ UserIdToAgentSecret[u.ID] = u.AgentSecret
Role: u.Role,
AgentSecret: u.AgentSecret,
}
AgentSecretToUserId[u.AgentSecret] = u.ID AgentSecretToUserId[u.AgentSecret] = u.ID
} }
} }
@ -38,10 +35,7 @@ func OnUserUpdate(u *model.User) {
return return
} }
UserInfoMap[u.ID] = model.UserInfo{ UserIdToAgentSecret[u.ID] = u.AgentSecret
Role: u.Role,
AgentSecret: u.AgentSecret,
}
AgentSecretToUserId[u.AgentSecret] = u.ID AgentSecretToUserId[u.AgentSecret] = u.ID
} }
@ -61,7 +55,7 @@ func OnUserDelete(id []uint64, errorFunc func(string, ...interface{}) error) err
for _, uid := range id { for _, uid := range id {
err := DB.Transaction(func(tx *gorm.DB) error { err := DB.Transaction(func(tx *gorm.DB) error {
CronLock.RLock() CronLock.RLock()
crons = model.FindByUserID(CronList, uid) crons = model.FindUserID(CronList, uid)
CronLock.RUnlock() CronLock.RUnlock()
cron = len(crons) > 0 cron = len(crons) > 0
@ -72,7 +66,7 @@ func OnUserDelete(id []uint64, errorFunc func(string, ...interface{}) error) err
} }
SortedServerLock.RLock() SortedServerLock.RLock()
servers = model.FindByUserID(SortedServerList, uid) servers = model.FindUserID(SortedServerList, uid)
SortedServerLock.RUnlock() SortedServerLock.RUnlock()
server = len(servers) > 0 server = len(servers) > 0
@ -118,9 +112,9 @@ func OnUserDelete(id []uint64, errorFunc func(string, ...interface{}) error) err
OnServerDelete(servers) OnServerDelete(servers)
} }
secret := UserInfoMap[uid].AgentSecret secret := UserIdToAgentSecret[uid]
delete(AgentSecretToUserId, secret) delete(AgentSecretToUserId, secret)
delete(UserInfoMap, uid) delete(UserIdToAgentSecret, uid)
} }
if cron { if cron {