From e7679a3fa66a506746bdddd7e8f6902a32d48d0b Mon Sep 17 00:00:00 2001 From: naiba Date: Sat, 23 Nov 2024 00:58:15 +0800 Subject: [PATCH] refactor: ip data type --- cmd/dashboard/controller/waf/waf.go | 30 +-------- cmd/dashboard/controller/waf/waf_test.go | 29 --------- model/waf.go | 59 ++++++++++++++++- model/waf_test.go | 83 ++++++++++++++++++++++++ 4 files changed, 142 insertions(+), 59 deletions(-) delete mode 100644 cmd/dashboard/controller/waf/waf_test.go create mode 100644 model/waf_test.go diff --git a/cmd/dashboard/controller/waf/waf.go b/cmd/dashboard/controller/waf/waf.go index 0a58ed1..38d720c 100644 --- a/cmd/dashboard/controller/waf/waf.go +++ b/cmd/dashboard/controller/waf/waf.go @@ -2,18 +2,14 @@ package waf import ( _ "embed" - "errors" - "log" - "math/big" "net/http" "net/netip" "strings" - "time" "github.com/gin-gonic/gin" + "github.com/naiba/nezha/model" "github.com/naiba/nezha/service/singleton" - "gorm.io/gorm" ) //go:embed waf.html @@ -55,33 +51,13 @@ func Waf(c *gin.Context) { c.Next() return } - var w model.WAF - if err := singleton.DB.First(&w, "ip = ?", realipAddr).Error; err != nil { - if err != gorm.ErrRecordNotFound { - ShowBlockPage(c, err) - return - } - } - now := time.Now().Unix() - if w.LastBlockTimestamp+pow(w.Count, 4) > uint64(now) { - log.Println(w.Count, w.LastBlockTimestamp+pow(w.Count, 4)-uint64(now)) - ShowBlockPage(c, errors.New("you are blocked by nezha WAF")) + if err := model.CheckIP(singleton.DB, realipAddr); err != nil { + ShowBlockPage(c, err) return } c.Next() } -func pow(x, y uint64) uint64 { - base := big.NewInt(0).SetUint64(x) - exp := big.NewInt(0).SetUint64(y) - result := big.NewInt(1) - result.Exp(base, exp, nil) - if !result.IsUint64() { - return ^uint64(0) // return max uint64 value on overflow - } - return result.Uint64() -} - func ShowBlockPage(c *gin.Context, err error) { c.Writer.WriteHeader(http.StatusForbidden) c.Header("Content-Type", "text/html; charset=utf-8") diff --git a/cmd/dashboard/controller/waf/waf_test.go b/cmd/dashboard/controller/waf/waf_test.go deleted file mode 100644 index 9f43c63..0000000 --- a/cmd/dashboard/controller/waf/waf_test.go +++ /dev/null @@ -1,29 +0,0 @@ -package waf - -import ( - "math" - "testing" -) - -func TestPow(t *testing.T) { - tests := []struct { - x, - y, - expect uint64 - }{ - {2, 64, math.MaxUint64}, // 2 的 64 次方,超过 uint64 最大值 - {uint64(1 << 63), 2, math.MaxUint64}, // 大数平方,可能溢出 - {uint64(^uint64(0)), 2, math.MaxUint64}, // uint64 最大值的平方,溢出 - {2, 3, 8}, - {5, 0, 1}, - {3, 1, 3}, - {0, 5, 0}, - } - - for _, tt := range tests { - result := pow(tt.x, tt.y) - if result != tt.expect { - t.Errorf("pow(%d, %d) = %d; expect %d", tt.x, tt.y, result, tt.expect) - } - } -} diff --git a/model/waf.go b/model/waf.go index d179723..6c53b98 100644 --- a/model/waf.go +++ b/model/waf.go @@ -2,6 +2,9 @@ package model import ( "errors" + "log" + "math/big" + "net/netip" "time" "gorm.io/gorm" @@ -14,7 +17,7 @@ const ( ) type WAF struct { - IP string `gorm:"type:binary(16);primaryKey" json:"ip,omitempty"` + IP []byte `gorm:"type:binary(16);primaryKey" json:"ip,omitempty"` Count uint64 `json:"count,omitempty"` LastBlockReason uint8 `json:"last_block_reason,omitempty"` LastBlockTimestamp uint64 `json:"last_block_timestamp,omitempty"` @@ -24,17 +27,67 @@ func (w *WAF) TableName() string { return "waf" } +func ipStringToBinary(ip string) ([]byte, error) { + addr, err := netip.ParseAddr(ip) + if err != nil { + return nil, err + } + b := addr.As16() + return b[:], nil +} + +func binaryToIPString(b []byte) string { + var addr16 [16]byte + copy(addr16[:], b) + addr := netip.AddrFrom16(addr16) + return addr.Unmap().String() +} + +func CheckIP(db *gorm.DB, ip string) error { + ipBinary, err := ipStringToBinary(ip) + if err != nil { + return err + } + var w WAF + if err := db.First(&w, "ip = ?", ipBinary).Error; err != nil { + if err != gorm.ErrRecordNotFound { + return err + } + } + now := time.Now().Unix() + if w.LastBlockTimestamp+pow(w.Count, 4) > uint64(now) { + log.Println(w.Count, w.LastBlockTimestamp+pow(w.Count, 4)-uint64(now)) + return errors.New("you are blocked by nezha WAF") + } + return nil +} + func BlockIP(db *gorm.DB, ip string, reason uint8) error { if ip == "" { return errors.New("empty ip") } + ipBinary, err := ipStringToBinary(ip) + if err != nil { + return err + } var w WAF w.LastBlockReason = reason w.LastBlockTimestamp = uint64(time.Now().Unix()) return db.Transaction(func(tx *gorm.DB) error { - if err := tx.FirstOrCreate(&w, WAF{IP: ip}).Error; err != nil { + if err := tx.FirstOrCreate(&w, WAF{IP: ipBinary}).Error; err != nil { return err } - return tx.Exec("UPDATE waf SET count = count + 1 WHERE ip = ?", ip).Error + return tx.Exec("UPDATE waf SET count = count + 1 WHERE ip = ?", ipBinary).Error }) } + +func pow(x, y uint64) uint64 { + base := big.NewInt(0).SetUint64(x) + exp := big.NewInt(0).SetUint64(y) + result := big.NewInt(1) + result.Exp(base, exp, nil) + if !result.IsUint64() { + return ^uint64(0) // return max uint64 value on overflow + } + return result.Uint64() +} diff --git a/model/waf_test.go b/model/waf_test.go new file mode 100644 index 0000000..0d18c25 --- /dev/null +++ b/model/waf_test.go @@ -0,0 +1,83 @@ +package model + +import ( + "reflect" + "testing" +) + +func TestIPStringToBinary(t *testing.T) { + cases := []struct { + ip string + want []byte + expectError bool + }{ + // 有效的 IPv4 地址 + { + ip: "192.168.1.1", + want: []byte{ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 192, 168, 1, 1, + }, + expectError: false, + }, + // 有效的 IPv6 地址 + { + ip: "2001:db8::68", + want: []byte{ + 32, 1, 13, 184, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 104, + }, + expectError: false, + }, + // 无效的 IP 地址 + { + ip: "invalid_ip", + want: []byte{}, + expectError: true, + }, + } + + for _, c := range cases { + got, err := ipStringToBinary(c.ip) + if (err != nil) != c.expectError { + t.Errorf("IPStringToBinary(%q) error = %v, expect error = %v", c.ip, err, c.expectError) + continue + } + if err == nil && !reflect.DeepEqual(got, c.want) { + t.Errorf("IPStringToBinary(%q) = %v, want %v", c.ip, got, c.want) + } + } +} + +func TestBinaryToIPString(t *testing.T) { + cases := []struct { + binary []byte + want string + }{ + // IPv4 地址(IPv4 映射的 IPv6 地址格式) + { + binary: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 192, 168, 1, 1}, + want: "192.168.1.1", + }, + // 其他测试用例 + { + binary: []byte{32, 1, 13, 184, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 104}, + want: "2001:db8::68", + }, + // 全零值 + { + binary: []byte{}, + want: "::", + }, + // IPv4 映射的 IPv6 地址 + { + binary: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 127, 0, 0, 1}, + want: "127.0.0.1", + }, + } + + for _, c := range cases { + got := binaryToIPString(c.binary) + if got != c.want { + t.Errorf("BinaryToIPString(%v) = %q, 期望 %q", c.binary, got, c.want) + } + } +}