mirror of
https://github.com/nezhahq/nezha.git
synced 2025-01-22 12:48:14 -05:00
refactor: ip data type
This commit is contained in:
parent
68f6da436d
commit
e7679a3fa6
@ -2,18 +2,14 @@ package waf
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"errors"
|
||||
"log"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/naiba/nezha/model"
|
||||
"github.com/naiba/nezha/service/singleton"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
//go:embed waf.html
|
||||
@ -55,33 +51,13 @@ func Waf(c *gin.Context) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
var w model.WAF
|
||||
if err := singleton.DB.First(&w, "ip = ?", realipAddr).Error; err != nil {
|
||||
if err != gorm.ErrRecordNotFound {
|
||||
ShowBlockPage(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
now := time.Now().Unix()
|
||||
if w.LastBlockTimestamp+pow(w.Count, 4) > uint64(now) {
|
||||
log.Println(w.Count, w.LastBlockTimestamp+pow(w.Count, 4)-uint64(now))
|
||||
ShowBlockPage(c, errors.New("you are blocked by nezha WAF"))
|
||||
if err := model.CheckIP(singleton.DB, realipAddr); err != nil {
|
||||
ShowBlockPage(c, err)
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
|
||||
func pow(x, y uint64) uint64 {
|
||||
base := big.NewInt(0).SetUint64(x)
|
||||
exp := big.NewInt(0).SetUint64(y)
|
||||
result := big.NewInt(1)
|
||||
result.Exp(base, exp, nil)
|
||||
if !result.IsUint64() {
|
||||
return ^uint64(0) // return max uint64 value on overflow
|
||||
}
|
||||
return result.Uint64()
|
||||
}
|
||||
|
||||
func ShowBlockPage(c *gin.Context, err error) {
|
||||
c.Writer.WriteHeader(http.StatusForbidden)
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
|
@ -1,29 +0,0 @@
|
||||
package waf
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPow(t *testing.T) {
|
||||
tests := []struct {
|
||||
x,
|
||||
y,
|
||||
expect uint64
|
||||
}{
|
||||
{2, 64, math.MaxUint64}, // 2 的 64 次方,超过 uint64 最大值
|
||||
{uint64(1 << 63), 2, math.MaxUint64}, // 大数平方,可能溢出
|
||||
{uint64(^uint64(0)), 2, math.MaxUint64}, // uint64 最大值的平方,溢出
|
||||
{2, 3, 8},
|
||||
{5, 0, 1},
|
||||
{3, 1, 3},
|
||||
{0, 5, 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := pow(tt.x, tt.y)
|
||||
if result != tt.expect {
|
||||
t.Errorf("pow(%d, %d) = %d; expect %d", tt.x, tt.y, result, tt.expect)
|
||||
}
|
||||
}
|
||||
}
|
59
model/waf.go
59
model/waf.go
@ -2,6 +2,9 @@ package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"math/big"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@ -14,7 +17,7 @@ const (
|
||||
)
|
||||
|
||||
type WAF struct {
|
||||
IP string `gorm:"type:binary(16);primaryKey" json:"ip,omitempty"`
|
||||
IP []byte `gorm:"type:binary(16);primaryKey" json:"ip,omitempty"`
|
||||
Count uint64 `json:"count,omitempty"`
|
||||
LastBlockReason uint8 `json:"last_block_reason,omitempty"`
|
||||
LastBlockTimestamp uint64 `json:"last_block_timestamp,omitempty"`
|
||||
@ -24,17 +27,67 @@ func (w *WAF) TableName() string {
|
||||
return "waf"
|
||||
}
|
||||
|
||||
func ipStringToBinary(ip string) ([]byte, error) {
|
||||
addr, err := netip.ParseAddr(ip)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b := addr.As16()
|
||||
return b[:], nil
|
||||
}
|
||||
|
||||
func binaryToIPString(b []byte) string {
|
||||
var addr16 [16]byte
|
||||
copy(addr16[:], b)
|
||||
addr := netip.AddrFrom16(addr16)
|
||||
return addr.Unmap().String()
|
||||
}
|
||||
|
||||
func CheckIP(db *gorm.DB, ip string) error {
|
||||
ipBinary, err := ipStringToBinary(ip)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var w WAF
|
||||
if err := db.First(&w, "ip = ?", ipBinary).Error; err != nil {
|
||||
if err != gorm.ErrRecordNotFound {
|
||||
return err
|
||||
}
|
||||
}
|
||||
now := time.Now().Unix()
|
||||
if w.LastBlockTimestamp+pow(w.Count, 4) > uint64(now) {
|
||||
log.Println(w.Count, w.LastBlockTimestamp+pow(w.Count, 4)-uint64(now))
|
||||
return errors.New("you are blocked by nezha WAF")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func BlockIP(db *gorm.DB, ip string, reason uint8) error {
|
||||
if ip == "" {
|
||||
return errors.New("empty ip")
|
||||
}
|
||||
ipBinary, err := ipStringToBinary(ip)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var w WAF
|
||||
w.LastBlockReason = reason
|
||||
w.LastBlockTimestamp = uint64(time.Now().Unix())
|
||||
return db.Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.FirstOrCreate(&w, WAF{IP: ip}).Error; err != nil {
|
||||
if err := tx.FirstOrCreate(&w, WAF{IP: ipBinary}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Exec("UPDATE waf SET count = count + 1 WHERE ip = ?", ip).Error
|
||||
return tx.Exec("UPDATE waf SET count = count + 1 WHERE ip = ?", ipBinary).Error
|
||||
})
|
||||
}
|
||||
|
||||
func pow(x, y uint64) uint64 {
|
||||
base := big.NewInt(0).SetUint64(x)
|
||||
exp := big.NewInt(0).SetUint64(y)
|
||||
result := big.NewInt(1)
|
||||
result.Exp(base, exp, nil)
|
||||
if !result.IsUint64() {
|
||||
return ^uint64(0) // return max uint64 value on overflow
|
||||
}
|
||||
return result.Uint64()
|
||||
}
|
||||
|
83
model/waf_test.go
Normal file
83
model/waf_test.go
Normal file
@ -0,0 +1,83 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIPStringToBinary(t *testing.T) {
|
||||
cases := []struct {
|
||||
ip string
|
||||
want []byte
|
||||
expectError bool
|
||||
}{
|
||||
// 有效的 IPv4 地址
|
||||
{
|
||||
ip: "192.168.1.1",
|
||||
want: []byte{
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 192, 168, 1, 1,
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
// 有效的 IPv6 地址
|
||||
{
|
||||
ip: "2001:db8::68",
|
||||
want: []byte{
|
||||
32, 1, 13, 184, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 104,
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
// 无效的 IP 地址
|
||||
{
|
||||
ip: "invalid_ip",
|
||||
want: []byte{},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
got, err := ipStringToBinary(c.ip)
|
||||
if (err != nil) != c.expectError {
|
||||
t.Errorf("IPStringToBinary(%q) error = %v, expect error = %v", c.ip, err, c.expectError)
|
||||
continue
|
||||
}
|
||||
if err == nil && !reflect.DeepEqual(got, c.want) {
|
||||
t.Errorf("IPStringToBinary(%q) = %v, want %v", c.ip, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBinaryToIPString(t *testing.T) {
|
||||
cases := []struct {
|
||||
binary []byte
|
||||
want string
|
||||
}{
|
||||
// IPv4 地址(IPv4 映射的 IPv6 地址格式)
|
||||
{
|
||||
binary: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 192, 168, 1, 1},
|
||||
want: "192.168.1.1",
|
||||
},
|
||||
// 其他测试用例
|
||||
{
|
||||
binary: []byte{32, 1, 13, 184, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 104},
|
||||
want: "2001:db8::68",
|
||||
},
|
||||
// 全零值
|
||||
{
|
||||
binary: []byte{},
|
||||
want: "::",
|
||||
},
|
||||
// IPv4 映射的 IPv6 地址
|
||||
{
|
||||
binary: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 127, 0, 0, 1},
|
||||
want: "127.0.0.1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
got := binaryToIPString(c.binary)
|
||||
if got != c.want {
|
||||
t.Errorf("BinaryToIPString(%v) = %q, 期望 %q", c.binary, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user