diff --git a/cmd/dashboard/controller/jwt.go b/cmd/dashboard/controller/jwt.go index 349ca9f..06c46af 100644 --- a/cmd/dashboard/controller/jwt.go +++ b/cmd/dashboard/controller/jwt.go @@ -166,11 +166,13 @@ func optionalAuthMiddleware(mw *jwt.GinJWTMiddleware) func(c *gin.Context) { identity := mw.IdentityHandler(c) if identity != nil { + model.ClearIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr)) + c.Set(mw.IdentityKey, identity) + } else { if err := model.BlockIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.WAFBlockReasonTypeBruteForceToken); err != nil { waf.ShowBlockPage(c, err) return } - c.Set(mw.IdentityKey, identity) } c.Next() diff --git a/cmd/dashboard/controller/waf/waf.go b/cmd/dashboard/controller/waf/waf.go index 38d720c..9f47969 100644 --- a/cmd/dashboard/controller/waf/waf.go +++ b/cmd/dashboard/controller/waf/waf.go @@ -3,12 +3,12 @@ package waf import ( _ "embed" "net/http" - "net/netip" "strings" "github.com/gin-gonic/gin" "github.com/naiba/nezha/model" + "github.com/naiba/nezha/pkg/utils" "github.com/naiba/nezha/service/singleton" ) @@ -32,26 +32,17 @@ func RealIp(c *gin.Context) { c.AbortWithStatusJSON(http.StatusOK, model.CommonResponse[any]{Success: false, Error: "real ip header not found"}) return } - ip, err := netip.ParseAddrPort(vals) + ip, err := utils.GetIPFromHeader(vals) if err != nil { c.AbortWithStatusJSON(http.StatusOK, model.CommonResponse[any]{Success: false, Error: err.Error()}) return } - c.Set(model.CtxKeyRealIPStr, ip.Addr().String()) + c.Set(model.CtxKeyRealIPStr, ip) c.Next() } func Waf(c *gin.Context) { - if singleton.Conf.RealIPHeader == "" { - c.Next() - return - } - realipAddr := c.GetString(model.CtxKeyRealIPStr) - if realipAddr == "" { - c.Next() - return - } - if err := model.CheckIP(singleton.DB, realipAddr); err != nil { + if err := model.CheckIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr)); err != nil { ShowBlockPage(c, err) return } diff --git a/cmd/dashboard/controller/waf/waf.html b/cmd/dashboard/controller/waf/waf.html index c278295..518417c 100644 --- a/cmd/dashboard/controller/waf/waf.html +++ b/cmd/dashboard/controller/waf/waf.html @@ -24,6 +24,13 @@ font-size: 12px; color: #888; } + + @media (prefers-color-scheme: dark) { + body { + background-color: #111; + color: #007C41 + } + } diff --git a/cmd/dashboard/rpc/rpc.go b/cmd/dashboard/rpc/rpc.go index 00eb298..9a70d89 100644 --- a/cmd/dashboard/rpc/rpc.go +++ b/cmd/dashboard/rpc/rpc.go @@ -5,7 +5,6 @@ import ( "fmt" "net/http" "net/netip" - "strings" "time" "google.golang.org/grpc" @@ -49,13 +48,11 @@ func getRealIp(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, if len(vals) == 0 { return nil, fmt.Errorf("real ip header not found") } - a := strings.Split(vals[0], ",") - h := strings.TrimSpace(a[len(a)-1]) - ip, err := netip.ParseAddrPort(h) + ip, err := utils.GetIPFromHeader(vals[0]) if err != nil { return nil, err } - ctx = context.WithValue(ctx, model.CtxKeyRealIP{}, ip.Addr().String()) + ctx = context.WithValue(ctx, model.CtxKeyRealIP{}, ip) return handler(ctx, req) } diff --git a/model/waf.go b/model/waf.go index 6c53b98..d71f5b1 100644 --- a/model/waf.go +++ b/model/waf.go @@ -2,11 +2,10 @@ package model import ( "errors" - "log" "math/big" - "net/netip" "time" + "github.com/naiba/nezha/pkg/utils" "gorm.io/gorm" ) @@ -27,24 +26,11 @@ func (w *WAF) TableName() string { return "waf" } -func ipStringToBinary(ip string) ([]byte, error) { - addr, err := netip.ParseAddr(ip) - if err != nil { - return nil, err - } - b := addr.As16() - return b[:], nil -} - -func binaryToIPString(b []byte) string { - var addr16 [16]byte - copy(addr16[:], b) - addr := netip.AddrFrom16(addr16) - return addr.Unmap().String() -} - func CheckIP(db *gorm.DB, ip string) error { - ipBinary, err := ipStringToBinary(ip) + if ip == "" { + return nil + } + ipBinary, err := utils.IPStringToBinary(ip) if err != nil { return err } @@ -55,39 +41,53 @@ func CheckIP(db *gorm.DB, ip string) error { } } now := time.Now().Unix() - if w.LastBlockTimestamp+pow(w.Count, 4) > uint64(now) { - log.Println(w.Count, w.LastBlockTimestamp+pow(w.Count, 4)-uint64(now)) + if powAdd(w.Count, 4, w.LastBlockTimestamp) > uint64(now) { return errors.New("you are blocked by nezha WAF") } return nil } +func ClearIP(db *gorm.DB, ip string) error { + if ip == "" { + return nil + } + ipBinary, err := utils.IPStringToBinary(ip) + if err != nil { + return err + } + return db.Delete(&WAF{}, "ip = ?", ipBinary).Error +} + func BlockIP(db *gorm.DB, ip string, reason uint8) error { if ip == "" { - return errors.New("empty ip") + return nil } - ipBinary, err := ipStringToBinary(ip) + ipBinary, err := utils.IPStringToBinary(ip) if err != nil { return err } var w WAF - w.LastBlockReason = reason - w.LastBlockTimestamp = uint64(time.Now().Unix()) + w.IP = ipBinary return db.Transaction(func(tx *gorm.DB) error { - if err := tx.FirstOrCreate(&w, WAF{IP: ipBinary}).Error; err != nil { + if err := tx.Where(&w).Attrs(WAF{ + LastBlockReason: reason, + LastBlockTimestamp: uint64(time.Now().Unix()), + }).FirstOrCreate(&w).Error; err != nil { return err } - return tx.Exec("UPDATE waf SET count = count + 1 WHERE ip = ?", ipBinary).Error + return tx.Exec("UPDATE waf SET count = count + 1, last_block_reason = ?, last_block_timestamp = ? WHERE ip = ?", reason, uint64(time.Now().Unix()), ipBinary).Error }) } -func pow(x, y uint64) uint64 { +func powAdd(x, y, z uint64) uint64 { base := big.NewInt(0).SetUint64(x) exp := big.NewInt(0).SetUint64(y) result := big.NewInt(1) result.Exp(base, exp, nil) + result.Add(result, big.NewInt(0).SetUint64(z)) if !result.IsUint64() { return ^uint64(0) // return max uint64 value on overflow } - return result.Uint64() + ret := result.Uint64() + return utils.IfOr(ret < z+3, z+3, ret) } diff --git a/model/waf_test.go b/model/waf_test.go deleted file mode 100644 index 0d18c25..0000000 --- a/model/waf_test.go +++ /dev/null @@ -1,83 +0,0 @@ -package model - -import ( - "reflect" - "testing" -) - -func TestIPStringToBinary(t *testing.T) { - cases := []struct { - ip string - want []byte - expectError bool - }{ - // 有效的 IPv4 地址 - { - ip: "192.168.1.1", - want: []byte{ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 192, 168, 1, 1, - }, - expectError: false, - }, - // 有效的 IPv6 地址 - { - ip: "2001:db8::68", - want: []byte{ - 32, 1, 13, 184, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 104, - }, - expectError: false, - }, - // 无效的 IP 地址 - { - ip: "invalid_ip", - want: []byte{}, - expectError: true, - }, - } - - for _, c := range cases { - got, err := ipStringToBinary(c.ip) - if (err != nil) != c.expectError { - t.Errorf("IPStringToBinary(%q) error = %v, expect error = %v", c.ip, err, c.expectError) - continue - } - if err == nil && !reflect.DeepEqual(got, c.want) { - t.Errorf("IPStringToBinary(%q) = %v, want %v", c.ip, got, c.want) - } - } -} - -func TestBinaryToIPString(t *testing.T) { - cases := []struct { - binary []byte - want string - }{ - // IPv4 地址(IPv4 映射的 IPv6 地址格式) - { - binary: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 192, 168, 1, 1}, - want: "192.168.1.1", - }, - // 其他测试用例 - { - binary: []byte{32, 1, 13, 184, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 104}, - want: "2001:db8::68", - }, - // 全零值 - { - binary: []byte{}, - want: "::", - }, - // IPv4 映射的 IPv6 地址 - { - binary: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 127, 0, 0, 1}, - want: "127.0.0.1", - }, - } - - for _, c := range cases { - got := binaryToIPString(c.binary) - if got != c.want { - t.Errorf("BinaryToIPString(%v) = %q, 期望 %q", c.binary, got, c.want) - } - } -} diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index 7407129..4c04e95 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -2,7 +2,9 @@ package utils import ( "crypto/rand" + "errors" "math/big" + "net/netip" "os" "regexp" "strconv" @@ -37,6 +39,35 @@ func IPDesensitize(ipAddr string) string { return ipAddr } +func IPStringToBinary(ip string) ([]byte, error) { + addr, err := netip.ParseAddr(ip) + if err != nil { + return nil, err + } + b := addr.As16() + return b[:], nil +} + +func BinaryToIPString(b []byte) string { + var addr16 [16]byte + copy(addr16[:], b) + addr := netip.AddrFrom16(addr16) + return addr.Unmap().String() +} + +func GetIPFromHeader(headerValue string) (string, error) { + a := strings.Split(headerValue, ",") + h := strings.TrimSpace(a[len(a)-1]) + ip, err := netip.ParseAddrPort(h) + if err != nil { + return "", err + } + if !ip.IsValid() { + return "", errors.New("invalid ip") + } + return ip.Addr().String(), nil +} + // SplitIPAddr 传入/分割的v4v6混合地址,返回v4和v6地址与有效地址 func SplitIPAddr(v4v6Bundle string) (string, string, string) { ipList := strings.Split(v4v6Bundle, "/") diff --git a/pkg/utils/utils_test.go b/pkg/utils/utils_test.go index eabe1cf..34c6f8b 100644 --- a/pkg/utils/utils_test.go +++ b/pkg/utils/utils_test.go @@ -1,6 +1,7 @@ package utils import ( + "reflect" "testing" ) @@ -56,3 +57,80 @@ func TestGenerGenerateRandomString(t *testing.T) { generatedString[str] = true } } + +func TestIPStringToBinary(t *testing.T) { + cases := []struct { + ip string + want []byte + expectError bool + }{ + // 有效的 IPv4 地址 + { + ip: "192.168.1.1", + want: []byte{ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 192, 168, 1, 1, + }, + expectError: false, + }, + // 有效的 IPv6 地址 + { + ip: "2001:db8::68", + want: []byte{ + 32, 1, 13, 184, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 104, + }, + expectError: false, + }, + // 无效的 IP 地址 + { + ip: "invalid_ip", + want: []byte{}, + expectError: true, + }, + } + + for _, c := range cases { + got, err := IPStringToBinary(c.ip) + if (err != nil) != c.expectError { + t.Errorf("IPStringToBinary(%q) error = %v, expect error = %v", c.ip, err, c.expectError) + continue + } + if err == nil && !reflect.DeepEqual(got, c.want) { + t.Errorf("IPStringToBinary(%q) = %v, want %v", c.ip, got, c.want) + } + } +} + +func TestBinaryToIPString(t *testing.T) { + cases := []struct { + binary []byte + want string + }{ + // IPv4 地址(IPv4 映射的 IPv6 地址格式) + { + binary: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 192, 168, 1, 1}, + want: "192.168.1.1", + }, + // 其他测试用例 + { + binary: []byte{32, 1, 13, 184, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 104}, + want: "2001:db8::68", + }, + // 全零值 + { + binary: []byte{}, + want: "::", + }, + // IPv4 映射的 IPv6 地址 + { + binary: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 127, 0, 0, 1}, + want: "127.0.0.1", + }, + } + + for _, c := range cases { + got := BinaryToIPString(c.binary) + if got != c.want { + t.Errorf("BinaryToIPString(%v) = %q, 期望 %q", c.binary, got, c.want) + } + } +}