nezha/model/waf.go

150 lines
3.5 KiB
Go
Raw Normal View History

2024-11-22 10:57:25 -05:00
package model
import (
"errors"
2024-11-22 11:58:15 -05:00
"math/big"
2024-11-22 10:57:25 -05:00
"time"
2024-11-28 06:38:54 -05:00
"github.com/nezhahq/nezha/pkg/utils"
2024-11-22 10:57:25 -05:00
"gorm.io/gorm"
)
// WAF block reasons, stored in the block_reason column (WAF.BlockReason) by
// BlockIP. The numbers are written out explicitly because they are persisted:
// reordering or inserting a constant must never renumber existing reasons.
// Zero is deliberately left unused.
const (
	WAFBlockReasonTypeLoginFail        uint8 = 1
	WAFBlockReasonTypeBruteForceToken  uint8 = 2
	WAFBlockReasonTypeAgentAuthFail    uint8 = 3
	WAFBlockReasonTypeManual           uint8 = 4 // BlockIP sets count to 99999 for this reason
	WAFBlockReasonTypeBruteForceOauth2 uint8 = 5
)
// Reserved block identifiers, passed as the uid / BlockIdentifier of a WAF row
// when the block is not attributed to a specific user.
// NOTE(review): presumably negative so they never collide with real user IDs — confirm.
const (
	BlockIDgRPC        = -127
	BlockIDToken       = -126
	BlockIDUnknownUser = -125
	BlockIDManual      = -124
)
2024-11-30 08:33:18 -05:00
// WAFApiMock has the same JSON shape as WAF, except that IP is a plain string
// rather than the binary form stored in the database.
// NOTE(review): presumably used for API documentation / responses only; it
// carries no gorm tags — confirm against callers.
type WAFApiMock struct {
	IP              string `json:"ip,omitempty"`
	BlockIdentifier int64  `json:"block_identifier,omitempty"`
	BlockReason     uint8  `json:"block_reason,omitempty"`
	BlockTimestamp  uint64 `json:"block_timestamp,omitempty"`
	Count           uint64 `json:"count,omitempty"`
}
2024-11-22 10:57:25 -05:00
// WAF is one row of the nz_waf table. The primary key is the pair
// (IP, BlockIdentifier), so a single address may own several rows, one per
// identifier. IP holds the bytes produced by utils.IPStringToBinary (the column
// is declared binary(16)); BlockTimestamp is a Unix time in seconds as written
// by BlockIP and is indexed.
type WAF struct {
	IP              []byte `gorm:"type:binary(16);primaryKey" json:"ip,omitempty"`
	BlockIdentifier int64  `gorm:"primaryKey" json:"block_identifier,omitempty"`
	BlockReason     uint8  `json:"block_reason,omitempty"`
	BlockTimestamp  uint64 `gorm:"index" json:"block_timestamp,omitempty"`
	Count           uint64 `json:"count,omitempty"`
}

// TableName tells gorm that WAF rows live in the "nz_waf" table.
func (*WAF) TableName() string {
	return "nz_waf"
}
2024-11-22 11:58:15 -05:00
func CheckIP(db *gorm.DB, ip string) error {
2024-11-22 21:21:01 -05:00
if ip == "" {
return nil
}
ipBinary, err := utils.IPStringToBinary(ip)
2024-11-22 11:58:15 -05:00
if err != nil {
return err
}
var blockTimestamp uint64
result := db.Model(&WAF{}).Order("block_timestamp desc").Select("block_timestamp").Where("ip = ?", ipBinary).Limit(1).Find(&blockTimestamp)
if result.Error != nil {
return result.Error
}
// 检查是否未找到记录
if result.RowsAffected < 1 {
return nil
2024-11-22 11:58:15 -05:00
}
var count uint64
if err := db.Model(&WAF{}).Select("SUM(count)").Where("ip = ?", ipBinary).Scan(&count).Error; err != nil {
return err
}
2024-11-22 11:58:15 -05:00
now := time.Now().Unix()
if powAdd(count, 4, blockTimestamp) > uint64(now) {
2025-01-01 03:14:19 -05:00
return errors.New("you were blocked by nezha WAF")
2024-11-22 11:58:15 -05:00
}
return nil
}
2024-12-25 08:02:54 -05:00
func UnblockIP(db *gorm.DB, ip string, uid int64) error {
2024-11-22 21:21:01 -05:00
if ip == "" {
return nil
}
ipBinary, err := utils.IPStringToBinary(ip)
if err != nil {
return err
}
return db.Unscoped().Delete(&WAF{}, "ip = ? and block_identifier = ?", ipBinary, uid).Error
}
2024-12-25 08:02:54 -05:00
func BatchUnblockIP(db *gorm.DB, ip []string) error {
if len(ip) < 1 {
return nil
}
ips := make([][]byte, 0, len(ip))
for _, s := range ip {
ipBinary, err := utils.IPStringToBinary(s)
if err != nil {
continue
}
ips = append(ips, ipBinary)
}
return db.Unscoped().Delete(&WAF{}, "ip in (?)", ips).Error
2024-11-22 21:21:01 -05:00
}
func BlockIP(db *gorm.DB, ip string, reason uint8, uid int64) error {
2024-11-22 10:57:25 -05:00
if ip == "" {
2024-11-22 21:21:01 -05:00
return nil
2024-11-22 10:57:25 -05:00
}
2024-11-22 21:21:01 -05:00
ipBinary, err := utils.IPStringToBinary(ip)
2024-11-22 11:58:15 -05:00
if err != nil {
return err
}
w := WAF{
IP: ipBinary,
BlockIdentifier: uid,
}
now := uint64(time.Now().Unix())
var count interface{}
if reason == WAFBlockReasonTypeManual {
count = 99999
} else {
count = gorm.Expr("count + 1")
}
2024-11-22 10:57:25 -05:00
return db.Transaction(func(tx *gorm.DB) error {
2024-11-22 21:21:01 -05:00
if err := tx.Where(&w).Attrs(WAF{
BlockReason: reason,
BlockTimestamp: now,
2024-11-22 21:21:01 -05:00
}).FirstOrCreate(&w).Error; err != nil {
2024-11-22 10:57:25 -05:00
return err
}
return tx.Exec("UPDATE nz_waf SET count = ?, block_reason = ?, block_timestamp = ? WHERE ip = ? and block_identifier = ?", count, reason, now, ipBinary, uid).Error
2024-11-22 10:57:25 -05:00
})
}
2024-11-22 11:58:15 -05:00
2024-11-22 21:21:01 -05:00
// powAdd returns x**y + z, computed exactly with math/big.
//
// If the exact value does not fit in a uint64, the maximum uint64 is returned.
// Otherwise the result is never below z+3, so even x == 0 or x == 1 produces a
// value at least three past z. (As x**0 is 1, y == 0 also hits that floor for
// small z.)
func powAdd(x, y, z uint64) uint64 {
	total := new(big.Int).Exp(new(big.Int).SetUint64(x), new(big.Int).SetUint64(y), nil)
	total.Add(total, new(big.Int).SetUint64(z))

	if !total.IsUint64() {
		// Saturate instead of wrapping around.
		return ^uint64(0)
	}

	v := total.Uint64()
	if floor := z + 3; v < floor {
		return floor
	}
	return v
}