Compare commits

...

8 Commits

Author SHA1 Message Date
uubulb
00a0494124 some changes 2024-12-21 23:56:56 +08:00
uubulb
ddd782d623 cover? 2024-12-21 23:41:39 +08:00
uubulb
4085a928a8 1 2024-12-21 23:39:27 +08:00
uubulb
18513110cf switch to runtime check 2024-12-21 23:18:43 +08:00
uubulb
48c8ebc1e1 update permission checks 2024-12-21 22:22:07 +08:00
uubulb
12d534fa64 add pagination for waf api 2024-12-21 21:28:12 +08:00
uubulb
2f4cf5ad28 fix several problems 2024-12-21 20:25:31 +08:00
uubulb
162937cb35 update waf table 2024-12-21 20:10:25 +08:00
20 changed files with 289 additions and 90 deletions

View File

@ -62,7 +62,7 @@ func createAlertRule(c *gin.Context) (uint64, error) {
r.TriggerMode = arf.TriggerMode
r.Enable = &enable
if err := validateRule(&r); err != nil {
if err := validateRule(c, &r); err != nil {
return 0, err
}
@ -116,7 +116,7 @@ func updateAlertRule(c *gin.Context) (any, error) {
r.TriggerMode = arf.TriggerMode
r.Enable = &enable
if err := validateRule(&r); err != nil {
if err := validateRule(c, &r); err != nil {
return 0, err
}
@ -140,22 +140,20 @@ func updateAlertRule(c *gin.Context) (any, error) {
// @Success 200 {object} model.CommonResponse[any]
// @Router /batch-delete/alert-rule [post]
func batchDeleteAlertRule(c *gin.Context) (any, error) {
var arr []uint64
if err := c.ShouldBindJSON(&arr); err != nil {
var ar []uint64
if err := c.ShouldBindJSON(&ar); err != nil {
return nil, err
}
var ars []model.AlertRule
if err := singleton.DB.Where("id in (?)", arr).Find(&ars).Error; err != nil {
if err := singleton.DB.Where("id in (?)", ar).Find(&ars).Error; err != nil {
return nil, err
}
var ar []uint64
for _, a := range ars {
if !a.HasPermission(c) {
return nil, singleton.Localizer.ErrorT("permission denied")
}
ar = append(ar, a.ID)
}
if err := singleton.DB.Unscoped().Delete(&model.AlertRule{}, "id in (?)", ar).Error; err != nil {
@ -166,9 +164,20 @@ func batchDeleteAlertRule(c *gin.Context) (any, error) {
return nil, nil
}
func validateRule(r *model.AlertRule) error {
func validateRule(c *gin.Context, r *model.AlertRule) error {
if len(r.Rules) > 0 {
for _, rule := range r.Rules {
singleton.ServerLock.RLock()
for s := range rule.Ignore {
if server, ok := singleton.ServerList[s]; ok {
if !server.HasPermission(c) {
singleton.ServerLock.RUnlock()
return singleton.Localizer.ErrorT("permission denied")
}
}
}
singleton.ServerLock.RUnlock()
if !rule.IsTransferDurationRule() {
if rule.Duration < 3 {
return singleton.Localizer.ErrorT("duration need to be at least 3")

View File

@ -129,7 +129,7 @@ func routers(r *gin.Engine, frontendDist fs.FS) {
auth.PATCH("/nat/:id", commonHandler(updateNAT))
auth.POST("/batch-delete/nat", commonHandler(batchDeleteNAT))
auth.GET("/waf", commonHandler(listBlockedAddress))
auth.GET("/waf", pCommonHandler(listBlockedAddress))
auth.POST("/batch-delete/waf", adminHandler(batchDeleteBlockedAddress))
auth.PATCH("/setting", adminHandler(updateConfig))
@ -153,6 +153,7 @@ func newErrorResponse(err error) model.CommonResponse[any] {
}
type handlerFunc[T any] func(c *gin.Context) (T, error)
type pHandlerFunc[S ~[]E, E any] func(c *gin.Context) (*model.Value[S], error)
// There are many error types in gorm, so create a custom type to represent all
// gorm errors here instead
@ -247,6 +248,18 @@ func listHandler[S ~[]E, E model.CommonInterface](handler handlerFunc[S]) func(*
}
}
func pCommonHandler[S ~[]E, E any](handler pHandlerFunc[S, E]) func(*gin.Context) {
return func(c *gin.Context) {
data, err := handler(c)
if err != nil {
c.JSON(http.StatusOK, newErrorResponse(err))
return
}
c.JSON(http.StatusOK, model.PaginatedResponse[S, E]{Success: true, Data: data})
}
}
func filter[S ~[]E, E model.CommonInterface](ctx *gin.Context, s S) S {
return slices.DeleteFunc(s, func(e E) bool {
return !e.HasPermission(ctx)

View File

@ -49,6 +49,17 @@ func createCron(c *gin.Context) (uint64, error) {
return 0, err
}
singleton.ServerLock.RLock()
for _, sid := range cf.Servers {
if server, ok := singleton.ServerList[sid]; ok {
if !server.HasPermission(c) {
singleton.ServerLock.RUnlock()
return 0, singleton.Localizer.ErrorT("permission denied")
}
}
}
singleton.ServerLock.RUnlock()
cr.UserID = getUid(c)
cr.TaskType = cf.TaskType
cr.Name = cf.Name
@ -104,6 +115,17 @@ func updateCron(c *gin.Context) (any, error) {
return 0, err
}
singleton.ServerLock.RLock()
for _, sid := range cf.Servers {
if server, ok := singleton.ServerList[sid]; ok {
if !server.HasPermission(c) {
singleton.ServerLock.RUnlock()
return nil, singleton.Localizer.ErrorT("permission denied")
}
}
}
singleton.ServerLock.RUnlock()
var cr model.Cron
if err := singleton.DB.First(&cr, id).Error; err != nil {
return nil, singleton.Localizer.ErrorT("task id %d does not exist", id)
@ -188,20 +210,18 @@ func manualTriggerCron(c *gin.Context) (any, error) {
// @Success 200 {object} model.CommonResponse[any]
// @Router /batch-delete/cron [post]
func batchDeleteCron(c *gin.Context) (any, error) {
var crr []uint64
if err := c.ShouldBindJSON(&crr); err != nil {
var cr []uint64
if err := c.ShouldBindJSON(&cr); err != nil {
return nil, err
}
var cr []uint64
singleton.CronLock.RLock()
for _, crID := range crr {
for _, crID := range cr {
if crn, ok := singleton.Crons[crID]; ok {
if !crn.HasPermission(c) {
singleton.CronLock.RUnlock()
return nil, singleton.Localizer.ErrorT("permission denied")
}
cr = append(cr, crn.ID)
}
}
singleton.CronLock.RUnlock()

View File

@ -177,21 +177,19 @@ func updateDDNS(c *gin.Context) (any, error) {
// @Success 200 {object} model.CommonResponse[any]
// @Router /batch-delete/ddns [post]
func batchDeleteDDNS(c *gin.Context) (any, error) {
var ddnsConfigsr []uint64
var ddnsConfigs []uint64
if err := c.ShouldBindJSON(&ddnsConfigsr); err != nil {
if err := c.ShouldBindJSON(&ddnsConfigs); err != nil {
return nil, err
}
var ddnsConfigs []uint64
singleton.DDNSCacheLock.RLock()
for _, pid := range ddnsConfigsr {
for _, pid := range ddnsConfigs {
if p, ok := singleton.DDNSCache[pid]; ok {
if !p.HasPermission(c) {
singleton.DDNSCacheLock.RUnlock()
return nil, singleton.Localizer.ErrorT("permission denied")
}
ddnsConfigs = append(ddnsConfigs, p.ID)
}
}
singleton.DDNSCacheLock.RUnlock()

View File

@ -51,6 +51,15 @@ func createNAT(c *gin.Context) (uint64, error) {
return 0, err
}
singleton.ServerLock.RLock()
if server, ok := singleton.ServerList[nf.ServerID]; ok {
if !server.HasPermission(c) {
singleton.ServerLock.RUnlock()
return 0, singleton.Localizer.ErrorT("permission denied")
}
}
singleton.ServerLock.RUnlock()
uid := getUid(c)
n.UserID = uid
@ -93,6 +102,15 @@ func updateNAT(c *gin.Context) (any, error) {
return nil, err
}
singleton.ServerLock.RLock()
if server, ok := singleton.ServerList[nf.ServerID]; ok {
if !server.HasPermission(c) {
singleton.ServerLock.RUnlock()
return nil, singleton.Localizer.ErrorT("permission denied")
}
}
singleton.ServerLock.RUnlock()
var n model.NAT
if err = singleton.DB.First(&n, id).Error; err != nil {
return nil, singleton.Localizer.ErrorT("profile id %d does not exist", id)
@ -128,20 +146,18 @@ func updateNAT(c *gin.Context) (any, error) {
// @Success 200 {object} model.CommonResponse[any]
// @Router /batch-delete/nat [post]
func batchDeleteNAT(c *gin.Context) (any, error) {
var nr []uint64
if err := c.ShouldBindJSON(&nr); err != nil {
var n []uint64
if err := c.ShouldBindJSON(&n); err != nil {
return nil, err
}
var n []uint64
singleton.NATCacheRwLock.RLock()
for _, id := range nr {
for _, id := range n {
if p, ok := singleton.NATCache[singleton.NATIDToDomain[id]]; ok {
if !p.HasPermission(c) {
singleton.NATCacheRwLock.RUnlock()
return nil, singleton.Localizer.ErrorT("permission denied")
}
n = append(n, p.ID)
}
}
singleton.NATCacheRwLock.RUnlock()

View File

@ -153,19 +153,17 @@ func updateNotification(c *gin.Context) (any, error) {
// @Success 200 {object} model.CommonResponse[any]
// @Router /batch-delete/notification [post]
func batchDeleteNotification(c *gin.Context) (any, error) {
var nr []uint64
if err := c.ShouldBindJSON(&nr); err != nil {
var n []uint64
if err := c.ShouldBindJSON(&n); err != nil {
return nil, err
}
var n []uint64
singleton.NotificationsLock.RLock()
for _, nid := range nr {
for _, nid := range n {
if ns, ok := singleton.NotificationMap[nid]; ok {
if !ns.HasPermission(c) {
return nil, singleton.Localizer.ErrorT("permission denied")
}
n = append(n, ns.ID)
}
}
singleton.NotificationsLock.RUnlock()

View File

@ -68,6 +68,17 @@ func createNotificationGroup(c *gin.Context) (uint64, error) {
}
ngf.Notifications = slices.Compact(ngf.Notifications)
singleton.NotificationsLock.RLock()
for _, nid := range ngf.Notifications {
if n, ok := singleton.NotificationMap[nid]; ok {
if !n.HasPermission(c) {
singleton.NotificationsLock.RUnlock()
return 0, singleton.Localizer.ErrorT("permission denied")
}
}
}
singleton.NotificationsLock.RUnlock()
uid := getUid(c)
var ng model.NotificationGroup
@ -132,6 +143,18 @@ func updateNotificationGroup(c *gin.Context) (any, error) {
if err := c.ShouldBindJSON(&ngf); err != nil {
return nil, err
}
singleton.NotificationsLock.RLock()
for _, nid := range ngf.Notifications {
if n, ok := singleton.NotificationMap[nid]; ok {
if !n.HasPermission(c) {
singleton.NotificationsLock.RUnlock()
return nil, singleton.Localizer.ErrorT("permission denied")
}
}
}
singleton.NotificationsLock.RUnlock()
var ngDB model.NotificationGroup
if err := singleton.DB.First(&ngDB, id).Error; err != nil {
return nil, singleton.Localizer.ErrorT("group id %d does not exist", id)
@ -195,22 +218,20 @@ func updateNotificationGroup(c *gin.Context) (any, error) {
// @Success 200 {object} model.CommonResponse[any]
// @Router /batch-delete/notification-group [post]
func batchDeleteNotificationGroup(c *gin.Context) (any, error) {
var ngnr []uint64
if err := c.ShouldBindJSON(&ngnr); err != nil {
var ngn []uint64
if err := c.ShouldBindJSON(&ngn); err != nil {
return nil, err
}
var ng []model.NotificationGroup
if err := singleton.DB.Where("id in (?)", ngnr).Find(&ng).Error; err != nil {
if err := singleton.DB.Where("id in (?)", ngn).Find(&ng).Error; err != nil {
return nil, err
}
var ngn []uint64
for _, n := range ng {
if !n.HasPermission(c) {
return nil, singleton.Localizer.ErrorT("permission denied")
}
ngn = append(ngn, n.ID)
}
err := singleton.DB.Transaction(func(tx *gorm.DB) error {

View File

@ -56,6 +56,17 @@ func updateServer(c *gin.Context) (any, error) {
return nil, err
}
singleton.DDNSCacheLock.RLock()
for _, pid := range sf.DDNSProfiles {
if p, ok := singleton.DDNSCache[pid]; ok {
if !p.HasPermission(c) {
singleton.DDNSCacheLock.RUnlock()
return nil, singleton.Localizer.ErrorT("permission denied")
}
}
}
singleton.DDNSCacheLock.RUnlock()
var s model.Server
if err := singleton.DB.First(&s, id).Error; err != nil {
return nil, singleton.Localizer.ErrorT("server id %d does not exist", id)
@ -103,20 +114,18 @@ func updateServer(c *gin.Context) (any, error) {
// @Success 200 {object} model.CommonResponse[any]
// @Router /batch-delete/server [post]
func batchDeleteServer(c *gin.Context) (any, error) {
var serversRaw []uint64
if err := c.ShouldBindJSON(&serversRaw); err != nil {
var servers []uint64
if err := c.ShouldBindJSON(&servers); err != nil {
return nil, err
}
var servers []uint64
singleton.ServerLock.RLock()
for _, sid := range serversRaw {
for _, sid := range servers {
if s, ok := singleton.ServerList[sid]; ok {
if !s.HasPermission(c) {
singleton.ServerLock.RUnlock()
return nil, singleton.Localizer.ErrorT("permission denied")
}
servers = append(servers, s.ID)
}
}
singleton.ServerLock.RUnlock()

View File

@ -67,6 +67,17 @@ func createServerGroup(c *gin.Context) (uint64, error) {
}
sgf.Servers = slices.Compact(sgf.Servers)
singleton.ServerLock.RLock()
for _, sid := range sgf.Servers {
if server, ok := singleton.ServerList[sid]; ok {
if !server.HasPermission(c) {
singleton.ServerLock.RUnlock()
return 0, singleton.Localizer.ErrorT("permission denied")
}
}
}
singleton.ServerLock.RUnlock()
uid := getUid(c)
var sg model.ServerGroup
@ -131,6 +142,17 @@ func updateServerGroup(c *gin.Context) (any, error) {
}
sg.Servers = slices.Compact(sg.Servers)
singleton.ServerLock.RLock()
for _, sid := range sg.Servers {
if server, ok := singleton.ServerList[sid]; ok {
if !server.HasPermission(c) {
singleton.ServerLock.RUnlock()
return nil, singleton.Localizer.ErrorT("permission denied")
}
}
}
singleton.ServerLock.RUnlock()
var sgDB model.ServerGroup
if err := singleton.DB.First(&sgDB, id).Error; err != nil {
return nil, singleton.Localizer.ErrorT("group id %d does not exist", id)
@ -192,22 +214,20 @@ func updateServerGroup(c *gin.Context) (any, error) {
// @Success 200 {object} model.CommonResponse[any]
// @Router /batch-delete/server-group [post]
func batchDeleteServerGroup(c *gin.Context) (any, error) {
var sgsr []uint64
if err := c.ShouldBindJSON(&sgsr); err != nil {
var sgs []uint64
if err := c.ShouldBindJSON(&sgs); err != nil {
return nil, err
}
var sg []model.ServerGroup
if err := singleton.DB.Where("id in (?)", sgsr).Find(&sg).Error; err != nil {
if err := singleton.DB.Where("id in (?)", sgs).Find(&sg).Error; err != nil {
return nil, err
}
var sgs []uint64
for _, s := range sg {
if !s.HasPermission(c) {
return nil, singleton.Localizer.ErrorT("permission denied")
}
sgs = append(sgs, s.ID)
}
err := singleton.DB.Transaction(func(tx *gorm.DB) error {

View File

@ -210,6 +210,10 @@ func createService(c *gin.Context) (uint64, error) {
m.RecoverTriggerTasks = mf.RecoverTriggerTasks
m.FailTriggerTasks = mf.FailTriggerTasks
if err := validateServers(c, &m); err != nil {
return 0, err
}
if err := singleton.DB.Create(&m).Error; err != nil {
return 0, newGormError("%v", err)
}
@ -284,6 +288,10 @@ func updateService(c *gin.Context) (any, error) {
m.RecoverTriggerTasks = mf.RecoverTriggerTasks
m.FailTriggerTasks = mf.FailTriggerTasks
if err := validateServers(c, &m); err != nil {
return 0, err
}
if err := singleton.DB.Save(&m).Error; err != nil {
return nil, newGormError("%v", err)
}
@ -322,20 +330,18 @@ func updateService(c *gin.Context) (any, error) {
// @Success 200 {object} model.CommonResponse[any]
// @Router /batch-delete/service [post]
func batchDeleteService(c *gin.Context) (any, error) {
var idsr []uint64
if err := c.ShouldBindJSON(&idsr); err != nil {
var ids []uint64
if err := c.ShouldBindJSON(&ids); err != nil {
return nil, err
}
var ids []uint64
singleton.ServiceSentinelShared.ServicesLock.RLock()
for _, id := range idsr {
for _, id := range ids {
if ss, ok := singleton.ServiceSentinelShared.Services[id]; ok {
if !ss.HasPermission(c) {
singleton.ServiceSentinelShared.ServicesLock.RUnlock()
return nil, singleton.Localizer.ErrorT("permission denied")
}
ids = append(ids, ss.ID)
}
}
singleton.ServiceSentinelShared.ServicesLock.RUnlock()
@ -353,3 +359,18 @@ func batchDeleteService(c *gin.Context) (any, error) {
singleton.ServiceSentinelShared.UpdateServiceList()
return nil, nil
}
func validateServers(c *gin.Context, ss *model.Service) error {
singleton.ServerLock.RLock()
defer singleton.ServerLock.RUnlock()
for s := range ss.SkipServers {
if server, ok := singleton.ServerList[s]; ok {
if !server.HasPermission(c) {
return singleton.Localizer.ErrorT("permission denied")
}
}
}
return nil
}

View File

@ -18,17 +18,17 @@ import (
// @Param limit query uint false "Page limit"
// @Param offset query uint false "Page offset"
// @Produce json
// @Success 200 {object} model.CommonResponse[[]model.WAFApiMock]
// @Success 200 {object} model.PaginatedResponse[[]model.WAFApiMock, model.WAFApiMock]
// @Router /waf [get]
func listBlockedAddress(c *gin.Context) ([]*model.WAF, error) {
func listBlockedAddress(c *gin.Context) (*model.Value[[]*model.WAF], error) {
limit, err := strconv.Atoi(c.Query("limit"))
if err != nil || limit < 1 {
limit = 25
}
offset, err := strconv.Atoi(c.Query("offset"))
if err != nil || offset < 1 {
offset = 1
if err != nil || offset < 0 {
offset = 0
}
var waf []*model.WAF
@ -36,7 +36,19 @@ func listBlockedAddress(c *gin.Context) ([]*model.WAF, error) {
return nil, err
}
return waf, nil
var total int64
if err := singleton.DB.Model(&model.WAF{}).Count(&total).Error; err != nil {
return nil, err
}
return &model.Value[[]*model.WAF]{
Value: waf,
Pagination: model.Pagination{
Offset: offset,
Limit: limit,
Total: total,
},
}, nil
}
// Batch delete blocked addresses

View File

@ -100,12 +100,34 @@ func DispatchTask(serviceSentinelDispatchBus <-chan model.Service) {
continue
}
if task.Cover == model.ServiceCoverIgnoreAll && task.SkipServers[singleton.SortedServerList[workedServerIndex].ID] {
singleton.SortedServerList[workedServerIndex].TaskStream.Send(task.PB())
server := singleton.SortedServerList[workedServerIndex]
singleton.UserLock.RLock()
var role uint8
if u, ok := singleton.UserInfoMap[server.UserID]; !ok {
role = model.RoleMember
} else {
role = u.Role
}
singleton.UserLock.RUnlock()
if task.UserID == server.UserID || role == model.RoleAdmin {
singleton.SortedServerList[workedServerIndex].TaskStream.Send(task.PB())
}
workedServerIndex++
continue
}
if task.Cover == model.ServiceCoverAll && !task.SkipServers[singleton.SortedServerList[workedServerIndex].ID] {
singleton.SortedServerList[workedServerIndex].TaskStream.Send(task.PB())
server := singleton.SortedServerList[workedServerIndex]
singleton.UserLock.RLock()
var role uint8
if u, ok := singleton.UserInfoMap[server.UserID]; !ok {
role = model.RoleMember
} else {
role = u.Role
}
singleton.UserLock.RUnlock()
if task.UserID == server.UserID || role == model.RoleAdmin {
singleton.SortedServerList[workedServerIndex].TaskStream.Send(task.PB())
}
workedServerIndex++
continue
}

View File

@ -62,10 +62,15 @@ func (r *AlertRule) Enabled() bool {
}
// Snapshot 对传入的Server进行该报警规则下所有type的检查 返回每项检查结果
func (r *AlertRule) Snapshot(cycleTransferStats *CycleTransferStats, server *Server, db *gorm.DB) []bool {
point := make([]bool, 0, len(r.Rules))
for _, rule := range r.Rules {
point = append(point, rule.Snapshot(cycleTransferStats, server, db))
func (r *AlertRule) Snapshot(cycleTransferStats *CycleTransferStats, server *Server, db *gorm.DB, role uint8) []bool {
point := make([]bool, len(r.Rules))
if r.UserID != server.UserID && role != RoleAdmin {
return point
}
for i, rule := range r.Rules {
point[i] = rule.Snapshot(cycleTransferStats, server, db)
}
return point
}

View File

@ -15,6 +15,23 @@ type CommonResponse[T any] struct {
Error string `json:"error,omitempty"`
}
type PaginatedResponse[S ~[]E, E any] struct {
Success bool `json:"success,omitempty"`
Data *Value[S] `json:"data,omitempty"`
Error string `json:"error,omitempty"`
}
type Value[T any] struct {
Value T `json:"value,omitempty"`
Pagination Pagination `json:"pagination,omitempty"`
}
type Pagination struct {
Offset int `json:"offset,omitempty"`
Limit int `json:"limit,omitempty"`
Total int64 `json:"total,omitempty"`
}
type LoginResponse struct {
Token string `json:"token,omitempty"`
Expire string `json:"expire,omitempty"`

View File

@ -52,7 +52,7 @@ type CommonInterface interface {
HasPermission(*gin.Context) bool
}
func FindUserID[S ~[]E, E CommonInterface](s S, uid uint64) []uint64 {
func FindByUserID[S ~[]E, E CommonInterface](s S, uid uint64) []uint64 {
var list []uint64
for _, v := range s {
if v.GetUserID() == uid {

View File

@ -18,6 +18,12 @@ type User struct {
AgentSecret string `json:"agent_secret,omitempty" gorm:"type:char(32)"`
}
type UserInfo struct {
Role uint8
_ [3]byte
AgentSecret string
}
func (u *User) BeforeSave(tx *gorm.DB) error {
if u.AgentSecret != "" {
return nil

View File

@ -23,19 +23,19 @@ const (
)
type WAFApiMock struct {
ID uint64 `json:"id,omitempty"`
IP string `json:"ip,omitempty"`
BlockIdentifier int64 `json:"block_identifier,omitempty"`
BlockReason uint8 `json:"block_reason,omitempty"`
BlockTimestamp uint64 `json:"block_timestamp,omitempty"`
BlockIdentifier uint64 `json:"block_identifier,omitempty"`
Count uint64 `json:"count,omitempty"`
}
type WAF struct {
ID uint64 `gorm:"primaryKey" json:"id,omitempty"`
IP []byte `gorm:"type:binary(16);index:idx_block_identifier" json:"ip,omitempty"`
IP []byte `gorm:"type:binary(16);primaryKey" json:"ip,omitempty"`
BlockIdentifier int64 `gorm:"primaryKey" json:"block_identifier,omitempty"`
BlockReason uint8 `json:"block_reason,omitempty"`
BlockTimestamp uint64 `json:"block_timestamp,omitempty"`
BlockIdentifier int64 `gorm:"index:idx_block_identifier" json:"block_identifier,omitempty"`
BlockTimestamp uint64 `gorm:"index" json:"block_timestamp,omitempty"`
Count uint64 `json:"count,omitempty"`
}
func (w *WAF) TableName() string {
@ -52,7 +52,7 @@ func CheckIP(db *gorm.DB, ip string) error {
}
var blockTimestamp uint64
result := db.Model(&WAF{}).Select("block_timestamp").Order("id desc").Where("ip = ?", ipBinary).Limit(1).Find(&blockTimestamp)
result := db.Model(&WAF{}).Order("block_timestamp desc").Select("block_timestamp").Where("ip = ?", ipBinary).Limit(1).Find(&blockTimestamp)
if result.Error != nil {
return result.Error
}
@ -62,13 +62,13 @@ func CheckIP(db *gorm.DB, ip string) error {
return nil
}
var count int64
if err := db.Model(&WAF{}).Where("ip = ?", ipBinary).Count(&count).Error; err != nil {
var count uint64
if err := db.Model(&WAF{}).Select("SUM(count)").Where("ip = ?", ipBinary).Scan(&count).Error; err != nil {
return err
}
now := time.Now().Unix()
if powAdd(uint64(count), 4, blockTimestamp) > uint64(now) {
if powAdd(count, 4, blockTimestamp) > uint64(now) {
return errors.New("you are blocked by nezha WAF")
}
return nil
@ -110,18 +110,17 @@ func BlockIP(db *gorm.DB, ip string, reason uint8, uid int64) error {
}
w := WAF{
IP: ipBinary,
BlockReason: reason,
BlockTimestamp: uint64(time.Now().Unix()),
BlockIdentifier: uid,
}
now := uint64(time.Now().Unix())
return db.Transaction(func(tx *gorm.DB) error {
var lastRecord WAF
if err := tx.Model(&WAF{}).Order("id desc").Where("ip = ?", ipBinary).First(&lastRecord).Error; err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
if err := tx.Where(&w).Attrs(WAF{
BlockReason: reason,
BlockTimestamp: now,
}).FirstOrCreate(&w).Error; err != nil {
return err
}
return tx.Create(&w).Error
return tx.Exec("UPDATE nz_waf SET count = count + 1, block_reason = ?, block_timestamp = ? WHERE ip = ? and block_identifier = ?", reason, now, ipBinary, uid).Error
})
}

View File

@ -2,7 +2,6 @@ package rpc
import (
"context"
"crypto/subtle"
"strings"
petname "github.com/dustinkirkland/golang-petname"
@ -39,7 +38,7 @@ func (a *authHandler) Check(ctx context.Context) (uint64, error) {
singleton.UserLock.RLock()
userId, ok := singleton.AgentSecretToUserId[clientSecret]
if !ok && subtle.ConstantTimeCompare([]byte(clientSecret), []byte(singleton.Conf.AgentSecretKey)) != 1 {
if !ok && clientSecret != singleton.Conf.AgentSecretKey {
singleton.UserLock.RUnlock()
model.BlockIP(singleton.DB, ip, model.WAFBlockReasonTypeAgentAuthFail, model.BlockIDgRPC)
return 0, status.Error(codes.Unauthenticated, "客户端认证失败")

View File

@ -143,8 +143,16 @@ func checkStatus() {
}
for _, server := range ServerList {
// 监测点
UserLock.RLock()
var role uint8
if u, ok := UserInfoMap[server.UserID]; !ok {
role = model.RoleMember
} else {
role = u.Role
}
UserLock.RUnlock()
alertsStore[alert.ID][server.ID] = append(alertsStore[alert.
ID][server.ID], alert.Snapshot(AlertsCycleTransferStatsStore[alert.ID], server, DB))
ID][server.ID], alert.Snapshot(AlertsCycleTransferStatsStore[alert.ID], server, DB, role))
// 发送通知,分为触发报警和恢复通知
max, passed := alert.Check(alertsStore[alert.ID][server.ID])
// 保存当前服务器状态信息

View File

@ -8,21 +8,24 @@ import (
)
var (
UserIdToAgentSecret map[uint64]string
UserInfoMap map[uint64]model.UserInfo
AgentSecretToUserId map[string]uint64
UserLock sync.RWMutex
)
func initUser() {
UserIdToAgentSecret = make(map[uint64]string)
UserInfoMap = make(map[uint64]model.UserInfo)
AgentSecretToUserId = make(map[string]uint64)
var users []model.User
DB.Find(&users)
for _, u := range users {
UserIdToAgentSecret[u.ID] = u.AgentSecret
UserInfoMap[u.ID] = model.UserInfo{
Role: u.Role,
AgentSecret: u.AgentSecret,
}
AgentSecretToUserId[u.AgentSecret] = u.ID
}
}
@ -35,7 +38,10 @@ func OnUserUpdate(u *model.User) {
return
}
UserIdToAgentSecret[u.ID] = u.AgentSecret
UserInfoMap[u.ID] = model.UserInfo{
Role: u.Role,
AgentSecret: u.AgentSecret,
}
AgentSecretToUserId[u.AgentSecret] = u.ID
}
@ -55,7 +61,7 @@ func OnUserDelete(id []uint64, errorFunc func(string, ...interface{}) error) err
for _, uid := range id {
err := DB.Transaction(func(tx *gorm.DB) error {
CronLock.RLock()
crons = model.FindUserID(CronList, uid)
crons = model.FindByUserID(CronList, uid)
CronLock.RUnlock()
cron = len(crons) > 0
@ -66,7 +72,7 @@ func OnUserDelete(id []uint64, errorFunc func(string, ...interface{}) error) err
}
SortedServerLock.RLock()
servers = model.FindUserID(SortedServerList, uid)
servers = model.FindByUserID(SortedServerList, uid)
SortedServerLock.RUnlock()
server = len(servers) > 0
@ -112,9 +118,9 @@ func OnUserDelete(id []uint64, errorFunc func(string, ...interface{}) error) err
OnServerDelete(servers)
}
secret := UserIdToAgentSecret[uid]
secret := UserInfoMap[uid].AgentSecret
delete(AgentSecretToUserId, secret)
delete(UserIdToAgentSecret, uid)
delete(UserInfoMap, uid)
}
if cron {