package model

import (
	"errors"
	"math/big"
	"time"

	"github.com/nezhahq/nezha/pkg/utils"
	"gorm.io/gorm"
)

// Block reasons stored in WAF.BlockReason. Numbering starts at 1; the zero
// value is deliberately skipped.
const (
	_ uint8 = iota
	WAFBlockReasonTypeLoginFail
	WAFBlockReasonTypeBruteForceToken
	WAFBlockReasonTypeAgentAuthFail
	WAFBlockReasonTypeManual
	WAFBlockReasonTypeBruteForceOauth2
)

// Reserved block identifiers, numbered upward from -127.
// NOTE(review): presumably these stand in for a user ID when a block is not
// tied to a concrete user — confirm against callers.
const (
	BlockIDgRPC = -127 + iota
	BlockIDToken
	BlockIDUnknownUser
	BlockIDManual
)

// WAFApiMock has the same fields as WAF except that the IP is a string and
// there are no gorm tags.
// NOTE(review): presumably used only to describe the API payload — confirm.
type WAFApiMock struct {
	IP              string `json:"ip,omitempty"`
	BlockIdentifier int64  `json:"block_identifier,omitempty"`
	BlockReason     uint8  `json:"block_reason,omitempty"`
	BlockTimestamp  uint64 `json:"block_timestamp,omitempty"`
	Count           uint64 `json:"count,omitempty"`
}

// WAF is one block record, keyed by the binary IP and the block identifier.
type WAF struct {
	IP              []byte `gorm:"type:binary(16);primaryKey" json:"ip,omitempty"`
	BlockIdentifier int64  `gorm:"primaryKey" json:"block_identifier,omitempty"`
	BlockReason     uint8  `json:"block_reason,omitempty"`
	BlockTimestamp  uint64 `gorm:"index" json:"block_timestamp,omitempty"`
	Count           uint64 `json:"count,omitempty"`
}

// TableName returns the name of the table backing WAF.
func (w *WAF) TableName() string {
	return "nz_waf"
}

// CheckIP reports whether ip is currently blocked.
//
// It returns nil when ip is empty or has no record. Otherwise the block lasts
// until the most recent block_timestamp plus the fourth power of the summed
// count (Unix seconds, see powAdd); while that moment is still in the future a
// non-nil error is returned. Parse and query errors are returned unchanged.
func CheckIP(db *gorm.DB, ip string) error {
	if ip == "" {
		return nil
	}

	ipBinary, err := utils.IPStringToBinary(ip)
	if err != nil {
		return err
	}

	var blockTimestamp uint64
	result := db.Model(&WAF{}).Order("block_timestamp desc").Select("block_timestamp").Where("ip = ?", ipBinary).Limit(1).Find(&blockTimestamp)
	if result.Error != nil {
		return result.Error
	}

	// No record was found for this IP, so it is not blocked.
	if result.RowsAffected < 1 {
		return nil
	}

	var count uint64
	if err := db.Model(&WAF{}).Select("SUM(count)").Where("ip = ?", ipBinary).Scan(&count).Error; err != nil {
		return err
	}

	now := time.Now().Unix()
	if powAdd(count, 4, blockTimestamp) > uint64(now) {
		return errors.New("you were blocked by nezha WAF")
	}
	return nil
}

// UnblockIP deletes the record for ip under the block identifier uid.
// An empty ip is a no-op.
func UnblockIP(db *gorm.DB, ip string, uid int64) error {
	if ip == "" {
		return nil
	}

	ipBinary, err := utils.IPStringToBinary(ip)
	if err != nil {
		return err
	}

	return db.Unscoped().Delete(&WAF{}, "ip = ? and block_identifier = ?", ipBinary, uid).Error
}

// BatchUnblockIP deletes every record whose IP appears in ip, regardless of
// block identifier. Entries that fail to parse are skipped; if nothing
// parseable remains, no query is issued.
func BatchUnblockIP(db *gorm.DB, ip []string) error {
	if len(ip) < 1 {
		return nil
	}

	ips := make([][]byte, 0, len(ip))
	for _, s := range ip {
		ipBinary, err := utils.IPStringToBinary(s)
		if err != nil {
			// Best effort: one malformed address must not prevent
			// unblocking the others.
			continue
		}
		ips = append(ips, ipBinary)
	}

	// Nothing valid to delete; skip the round trip to the database.
	if len(ips) == 0 {
		return nil
	}

	return db.Unscoped().Delete(&WAF{}, "ip in (?)", ips).Error
}

// BlockIP records a block for ip under the block identifier uid.
//
// The row is created if missing, then its reason and timestamp are refreshed.
// A manual block (WAFBlockReasonTypeManual) sets count to 99999; any other
// reason increments it by one. Both steps run in a single transaction. An
// empty ip is a no-op.
func BlockIP(db *gorm.DB, ip string, reason uint8, uid int64) error {
	if ip == "" {
		return nil
	}

	ipBinary, err := utils.IPStringToBinary(ip)
	if err != nil {
		return err
	}

	w := WAF{
		IP:              ipBinary,
		BlockIdentifier: uid,
	}
	now := uint64(time.Now().Unix())

	var count interface{}
	if reason == WAFBlockReasonTypeManual {
		count = 99999
	} else {
		count = gorm.Expr("count + 1")
	}

	return db.Transaction(func(tx *gorm.DB) error {
		// Make sure the row exists so the UPDATE below always has a target.
		if err := tx.Where(&w).Attrs(WAF{
			BlockReason:    reason,
			BlockTimestamp: now,
		}).FirstOrCreate(&w).Error; err != nil {
			return err
		}
		return tx.Exec("UPDATE nz_waf SET count = ?, block_reason = ?, block_timestamp = ? WHERE ip = ? and block_identifier = ?", count, reason, now, ipBinary, uid).Error
	})
}

// powAdd returns x^y + z, saturating at the maximum uint64 value on overflow.
// The result is raised to at least z+3 (itself saturated), assuming
// utils.IfOr(cond, a, b) yields a when cond is true — confirm in pkg/utils.
func powAdd(x, y, z uint64) uint64 {
	const maxUint64 = ^uint64(0)

	base := big.NewInt(0).SetUint64(x)
	exp := big.NewInt(0).SetUint64(y)
	result := big.NewInt(1)
	result.Exp(base, exp, nil)
	result.Add(result, big.NewInt(0).SetUint64(z))
	if !result.IsUint64() {
		return maxUint64
	}

	// z+3 wraps around when z is within 3 of the maximum; clamp instead so the
	// floor never becomes smaller than z.
	floor := z + 3
	if floor < z {
		floor = maxUint64
	}

	ret := result.Uint64()
	return utils.IfOr(ret < floor, floor, ret)
}