diff --git a/cmd/dashboard/controller/jwt.go b/cmd/dashboard/controller/jwt.go index 2e4a462..72a261e 100644 --- a/cmd/dashboard/controller/jwt.go +++ b/cmd/dashboard/controller/jwt.go @@ -88,18 +88,21 @@ func authenticator() func(c *gin.Context) (interface{}, error) { } var user model.User + realip := c.GetString(model.CtxKeyRealIPStr) if err := singleton.DB.Select("id", "password").Where("username = ?", loginVals.Username).First(&user).Error; err != nil { if err == gorm.ErrRecordNotFound { - model.BlockIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.WAFBlockReasonTypeLoginFail) + model.BlockIP(singleton.DB, realip, model.WAFBlockReasonTypeLoginFail, model.BlockIDUnknownUser) } return nil, jwt.ErrFailedAuthentication } if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(loginVals.Password)); err != nil { - model.BlockIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.WAFBlockReasonTypeLoginFail) + model.BlockIP(singleton.DB, realip, model.WAFBlockReasonTypeLoginFail, int64(user.ID)) return nil, jwt.ErrFailedAuthentication } + model.ClearIP(singleton.DB, realip, model.BlockIDUnknownUser) + model.ClearIP(singleton.DB, realip, int64(user.ID)) return utils.Itoa(user.ID), nil } } @@ -169,10 +172,10 @@ func optionalAuthMiddleware(mw *jwt.GinJWTMiddleware) func(c *gin.Context) { identity := mw.IdentityHandler(c) if identity != nil { - model.ClearIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr)) + model.ClearIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.BlockIDToken) c.Set(mw.IdentityKey, identity) } else { - if err := model.BlockIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.WAFBlockReasonTypeBruteForceToken); err != nil { + if err := model.BlockIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.WAFBlockReasonTypeBruteForceToken, model.BlockIDToken); err != nil { waf.ShowBlockPage(c, err) return } diff --git a/cmd/dashboard/controller/waf.go b/cmd/dashboard/controller/waf.go index cd37218..cd6fe9c 100644 --- a/cmd/dashboard/controller/waf.go +++ b/cmd/dashboard/controller/waf.go @@ -1,6 +1,8 @@ package controller import ( + "strconv" + "github.com/gin-gonic/gin" "github.com/nezhahq/nezha/model" @@ -13,12 +15,21 @@ import ( // @Schemes // @Description List server // @Tags auth required +// @Param page query uint false "Page number" // @Produce json // @Success 200 {object} model.CommonResponse[[]model.WAFApiMock] // @Router /waf [get] func listBlockedAddress(c *gin.Context) ([]*model.WAF, error) { + const pageSize = 25 + + page, err := strconv.ParseUint(c.Query("page"), 10, 64) + if err != nil || page < 1 { + page = 1 + } + offset := (page - 1) * pageSize + var waf []*model.WAF - if err := singleton.DB.Find(&waf).Error; err != nil { + if err := singleton.DB.Limit(pageSize).Offset(int(offset)).Find(&waf).Error; err != nil { return nil, err } diff --git a/model/user.go b/model/user.go index 3e987d3..5268298 100644 --- a/model/user.go +++ b/model/user.go @@ -19,6 +19,10 @@ type User struct { } func (u *User) BeforeSave(tx *gorm.DB) error { + if u.AgentSecret != "" { + return nil + } + key, err := utils.GenerateRandomString(32) if err != nil { return err diff --git a/model/waf.go b/model/waf.go index 63f5032..9e66a40 100644 --- a/model/waf.go +++ b/model/waf.go @@ -16,22 +16,30 @@ const ( WAFBlockReasonTypeAgentAuthFail ) +const ( + BlockIDgRPC = -127 + iota + BlockIDToken + BlockIDUnknownUser +) + type WAFApiMock struct { - IP string `json:"ip,omitempty"` - Count uint64 `json:"count,omitempty"` - LastBlockReason uint8 `json:"last_block_reason,omitempty"` - LastBlockTimestamp uint64 `json:"last_block_timestamp,omitempty"` + ID uint64 `json:"id,omitempty"` + IP string `json:"ip,omitempty"` + BlockReason uint8 `json:"block_reason,omitempty"` + BlockTimestamp uint64 `json:"block_timestamp,omitempty"` + BlockIdentifier uint64 `json:"block_identifier,omitempty"` } type WAF struct { - IP []byte `gorm:"type:binary(16);primaryKey" json:"ip,omitempty"` - Count uint64 `json:"count,omitempty"` - LastBlockReason uint8 `json:"last_block_reason,omitempty"` - LastBlockTimestamp uint64 `json:"last_block_timestamp,omitempty"` + ID uint64 `gorm:"primaryKey" json:"id,omitempty"` + IP []byte `gorm:"type:binary(16);index:idx_block_identifier" json:"ip,omitempty"` + BlockReason uint8 `json:"block_reason,omitempty"` + BlockTimestamp uint64 `json:"block_timestamp,omitempty"` + BlockIdentifier int64 `gorm:"index:idx_block_identifier" json:"block_identifier,omitempty"` } func (w *WAF) TableName() string { - return "waf" + return "nz_waf" } func CheckIP(db *gorm.DB, ip string) error { @@ -42,22 +50,31 @@ func CheckIP(db *gorm.DB, ip string) error { if err != nil { return err } - var w WAF - result := db.Limit(1).Find(&w, "ip = ?", ipBinary) - if result.Error != nil { + + var blockTimestamp uint64 + result := db.Model(&WAF{}).Select("block_timestamp").Order("id desc").Where("ip = ?", ipBinary).First(&blockTimestamp) + if result.Error != nil && result.Error != gorm.ErrRecordNotFound { return result.Error } - if result.RowsAffected == 0 { // 检查是否未找到记录 + + // 检查是否未找到记录 + if result.RowsAffected < 1 { return nil } + + var count int64 + if err := db.Model(&WAF{}).Where("ip = ?", ipBinary).Count(&count).Error; err != nil { + return err + } + now := time.Now().Unix() - if powAdd(w.Count, 4, w.LastBlockTimestamp) > uint64(now) { + if powAdd(uint64(count), 4, blockTimestamp) > uint64(now) { return errors.New("you are blocked by nezha WAF") } return nil } -func ClearIP(db *gorm.DB, ip string) error { +func ClearIP(db *gorm.DB, ip string, uid int64) error { if ip == "" { return nil } @@ -65,7 +82,7 @@ func ClearIP(db *gorm.DB, ip string) error { if err != nil { return err } - return db.Unscoped().Delete(&WAF{}, "ip = ?", ipBinary).Error + return db.Unscoped().Delete(&WAF{}, "ip = ? and block_identifier = ?", ipBinary, uid).Error } func BatchClearIP(db *gorm.DB, ip []string) error { @@ -83,7 +100,7 @@ func BatchClearIP(db *gorm.DB, ip []string) error { return db.Unscoped().Delete(&WAF{}, "ip in (?)", ips).Error } -func BlockIP(db *gorm.DB, ip string, reason uint8) error { +func BlockIP(db *gorm.DB, ip string, reason uint8, uid int64) error { if ip == "" { return nil } @@ -91,16 +108,20 @@ func BlockIP(db *gorm.DB, ip string, reason uint8) error { if err != nil { return err } - var w WAF - w.IP = ipBinary + w := WAF{ + IP: ipBinary, + BlockReason: reason, + BlockTimestamp: uint64(time.Now().Unix()), + BlockIdentifier: uid, + } return db.Transaction(func(tx *gorm.DB) error { - if err := tx.Where(&w).Attrs(WAF{ - LastBlockReason: reason, - LastBlockTimestamp: uint64(time.Now().Unix()), - }).FirstOrCreate(&w).Error; err != nil { - return err + var lastRecord WAF + if err := tx.Model(&WAF{}).Order("id desc").Where("ip = ?", ipBinary).First(&lastRecord).Error; err != nil { + if !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } } - return tx.Exec("UPDATE waf SET count = count + 1, last_block_reason = ?, last_block_timestamp = ? WHERE ip = ?", reason, uint64(time.Now().Unix()), ipBinary).Error + return tx.Create(&w).Error }) } diff --git a/service/rpc/auth.go b/service/rpc/auth.go index 817f981..1f68090 100644 --- a/service/rpc/auth.go +++ b/service/rpc/auth.go @@ -41,12 +41,12 @@ func (a *authHandler) Check(ctx context.Context) (uint64, error) { userId, ok := singleton.AgentSecretToUserId[clientSecret] if !ok && subtle.ConstantTimeCompare([]byte(clientSecret), []byte(singleton.Conf.AgentSecretKey)) != 1 { singleton.UserLock.RUnlock() - model.BlockIP(singleton.DB, ip, model.WAFBlockReasonTypeAgentAuthFail) + model.BlockIP(singleton.DB, ip, model.WAFBlockReasonTypeAgentAuthFail, model.BlockIDgRPC) return 0, status.Error(codes.Unauthenticated, "客户端认证失败") } singleton.UserLock.RUnlock() - model.ClearIP(singleton.DB, ip) + model.ClearIP(singleton.DB, ip, model.BlockIDgRPC) var clientUUID string if value, ok := md["client_uuid"]; ok {