fix: waf condition

This commit is contained in:
naiba 2024-11-23 10:21:01 +08:00
parent 867f840265
commit cd42b1b9d5
8 changed files with 154 additions and 131 deletions

View File

@ -166,11 +166,13 @@ func optionalAuthMiddleware(mw *jwt.GinJWTMiddleware) func(c *gin.Context) {
identity := mw.IdentityHandler(c) identity := mw.IdentityHandler(c)
if identity != nil { if identity != nil {
model.ClearIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr))
c.Set(mw.IdentityKey, identity)
} else {
if err := model.BlockIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.WAFBlockReasonTypeBruteForceToken); err != nil { if err := model.BlockIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.WAFBlockReasonTypeBruteForceToken); err != nil {
waf.ShowBlockPage(c, err) waf.ShowBlockPage(c, err)
return return
} }
c.Set(mw.IdentityKey, identity)
} }
c.Next() c.Next()

View File

@ -3,12 +3,12 @@ package waf
import ( import (
_ "embed" _ "embed"
"net/http" "net/http"
"net/netip"
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/naiba/nezha/model" "github.com/naiba/nezha/model"
"github.com/naiba/nezha/pkg/utils"
"github.com/naiba/nezha/service/singleton" "github.com/naiba/nezha/service/singleton"
) )
@ -32,26 +32,17 @@ func RealIp(c *gin.Context) {
c.AbortWithStatusJSON(http.StatusOK, model.CommonResponse[any]{Success: false, Error: "real ip header not found"}) c.AbortWithStatusJSON(http.StatusOK, model.CommonResponse[any]{Success: false, Error: "real ip header not found"})
return return
} }
ip, err := netip.ParseAddrPort(vals) ip, err := utils.GetIPFromHeader(vals)
if err != nil { if err != nil {
c.AbortWithStatusJSON(http.StatusOK, model.CommonResponse[any]{Success: false, Error: err.Error()}) c.AbortWithStatusJSON(http.StatusOK, model.CommonResponse[any]{Success: false, Error: err.Error()})
return return
} }
c.Set(model.CtxKeyRealIPStr, ip.Addr().String()) c.Set(model.CtxKeyRealIPStr, ip)
c.Next() c.Next()
} }
func Waf(c *gin.Context) { func Waf(c *gin.Context) {
if singleton.Conf.RealIPHeader == "" { if err := model.CheckIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr)); err != nil {
c.Next()
return
}
realipAddr := c.GetString(model.CtxKeyRealIPStr)
if realipAddr == "" {
c.Next()
return
}
if err := model.CheckIP(singleton.DB, realipAddr); err != nil {
ShowBlockPage(c, err) ShowBlockPage(c, err)
return return
} }

View File

@ -24,6 +24,13 @@
font-size: 12px; font-size: 12px;
color: #888; color: #888;
} }
@media (prefers-color-scheme: dark) {
body {
background-color: #111;
color: #007C41
}
}
</style> </style>
</head> </head>

View File

@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/netip" "net/netip"
"strings"
"time" "time"
"google.golang.org/grpc" "google.golang.org/grpc"
@ -49,13 +48,11 @@ func getRealIp(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
if len(vals) == 0 { if len(vals) == 0 {
return nil, fmt.Errorf("real ip header not found") return nil, fmt.Errorf("real ip header not found")
} }
a := strings.Split(vals[0], ",") ip, err := utils.GetIPFromHeader(vals[0])
h := strings.TrimSpace(a[len(a)-1])
ip, err := netip.ParseAddrPort(h)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ctx = context.WithValue(ctx, model.CtxKeyRealIP{}, ip.Addr().String()) ctx = context.WithValue(ctx, model.CtxKeyRealIP{}, ip)
return handler(ctx, req) return handler(ctx, req)
} }

View File

@ -2,11 +2,10 @@ package model
import ( import (
"errors" "errors"
"log"
"math/big" "math/big"
"net/netip"
"time" "time"
"github.com/naiba/nezha/pkg/utils"
"gorm.io/gorm" "gorm.io/gorm"
) )
@ -27,24 +26,11 @@ func (w *WAF) TableName() string {
return "waf" return "waf"
} }
func ipStringToBinary(ip string) ([]byte, error) {
addr, err := netip.ParseAddr(ip)
if err != nil {
return nil, err
}
b := addr.As16()
return b[:], nil
}
func binaryToIPString(b []byte) string {
var addr16 [16]byte
copy(addr16[:], b)
addr := netip.AddrFrom16(addr16)
return addr.Unmap().String()
}
func CheckIP(db *gorm.DB, ip string) error { func CheckIP(db *gorm.DB, ip string) error {
ipBinary, err := ipStringToBinary(ip) if ip == "" {
return nil
}
ipBinary, err := utils.IPStringToBinary(ip)
if err != nil { if err != nil {
return err return err
} }
@ -55,39 +41,53 @@ func CheckIP(db *gorm.DB, ip string) error {
} }
} }
now := time.Now().Unix() now := time.Now().Unix()
if w.LastBlockTimestamp+pow(w.Count, 4) > uint64(now) { if powAdd(w.Count, 4, w.LastBlockTimestamp) > uint64(now) {
log.Println(w.Count, w.LastBlockTimestamp+pow(w.Count, 4)-uint64(now))
return errors.New("you are blocked by nezha WAF") return errors.New("you are blocked by nezha WAF")
} }
return nil return nil
} }
func ClearIP(db *gorm.DB, ip string) error {
if ip == "" {
return nil
}
ipBinary, err := utils.IPStringToBinary(ip)
if err != nil {
return err
}
return db.Delete(&WAF{}, "ip = ?", ipBinary).Error
}
func BlockIP(db *gorm.DB, ip string, reason uint8) error { func BlockIP(db *gorm.DB, ip string, reason uint8) error {
if ip == "" { if ip == "" {
return errors.New("empty ip") return nil
} }
ipBinary, err := ipStringToBinary(ip) ipBinary, err := utils.IPStringToBinary(ip)
if err != nil { if err != nil {
return err return err
} }
var w WAF var w WAF
w.LastBlockReason = reason w.IP = ipBinary
w.LastBlockTimestamp = uint64(time.Now().Unix())
return db.Transaction(func(tx *gorm.DB) error { return db.Transaction(func(tx *gorm.DB) error {
if err := tx.FirstOrCreate(&w, WAF{IP: ipBinary}).Error; err != nil { if err := tx.Where(&w).Attrs(WAF{
LastBlockReason: reason,
LastBlockTimestamp: uint64(time.Now().Unix()),
}).FirstOrCreate(&w).Error; err != nil {
return err return err
} }
return tx.Exec("UPDATE waf SET count = count + 1 WHERE ip = ?", ipBinary).Error return tx.Exec("UPDATE waf SET count = count + 1, last_block_reason = ?, last_block_timestamp = ? WHERE ip = ?", reason, uint64(time.Now().Unix()), ipBinary).Error
}) })
} }
func pow(x, y uint64) uint64 { func powAdd(x, y, z uint64) uint64 {
base := big.NewInt(0).SetUint64(x) base := big.NewInt(0).SetUint64(x)
exp := big.NewInt(0).SetUint64(y) exp := big.NewInt(0).SetUint64(y)
result := big.NewInt(1) result := big.NewInt(1)
result.Exp(base, exp, nil) result.Exp(base, exp, nil)
result.Add(result, big.NewInt(0).SetUint64(z))
if !result.IsUint64() { if !result.IsUint64() {
return ^uint64(0) // return max uint64 value on overflow return ^uint64(0) // return max uint64 value on overflow
} }
return result.Uint64() ret := result.Uint64()
return utils.IfOr(ret < z+3, z+3, ret)
} }

View File

@ -1,83 +0,0 @@
package model
import (
"reflect"
"testing"
)
func TestIPStringToBinary(t *testing.T) {
cases := []struct {
ip string
want []byte
expectError bool
}{
// 有效的 IPv4 地址
{
ip: "192.168.1.1",
want: []byte{
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 192, 168, 1, 1,
},
expectError: false,
},
// 有效的 IPv6 地址
{
ip: "2001:db8::68",
want: []byte{
32, 1, 13, 184, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 104,
},
expectError: false,
},
// 无效的 IP 地址
{
ip: "invalid_ip",
want: []byte{},
expectError: true,
},
}
for _, c := range cases {
got, err := ipStringToBinary(c.ip)
if (err != nil) != c.expectError {
t.Errorf("IPStringToBinary(%q) error = %v, expect error = %v", c.ip, err, c.expectError)
continue
}
if err == nil && !reflect.DeepEqual(got, c.want) {
t.Errorf("IPStringToBinary(%q) = %v, want %v", c.ip, got, c.want)
}
}
}
func TestBinaryToIPString(t *testing.T) {
cases := []struct {
binary []byte
want string
}{
// IPv4 地址IPv4 映射的 IPv6 地址格式)
{
binary: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 192, 168, 1, 1},
want: "192.168.1.1",
},
// 其他测试用例
{
binary: []byte{32, 1, 13, 184, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 104},
want: "2001:db8::68",
},
// 全零值
{
binary: []byte{},
want: "::",
},
// IPv4 映射的 IPv6 地址
{
binary: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 127, 0, 0, 1},
want: "127.0.0.1",
},
}
for _, c := range cases {
got := binaryToIPString(c.binary)
if got != c.want {
t.Errorf("BinaryToIPString(%v) = %q, 期望 %q", c.binary, got, c.want)
}
}
}

View File

@ -2,7 +2,9 @@ package utils
import ( import (
"crypto/rand" "crypto/rand"
"errors"
"math/big" "math/big"
"net/netip"
"os" "os"
"regexp" "regexp"
"strconv" "strconv"
@ -37,6 +39,35 @@ func IPDesensitize(ipAddr string) string {
return ipAddr return ipAddr
} }
func IPStringToBinary(ip string) ([]byte, error) {
addr, err := netip.ParseAddr(ip)
if err != nil {
return nil, err
}
b := addr.As16()
return b[:], nil
}
func BinaryToIPString(b []byte) string {
var addr16 [16]byte
copy(addr16[:], b)
addr := netip.AddrFrom16(addr16)
return addr.Unmap().String()
}
func GetIPFromHeader(headerValue string) (string, error) {
a := strings.Split(headerValue, ",")
h := strings.TrimSpace(a[len(a)-1])
ip, err := netip.ParseAddrPort(h)
if err != nil {
return "", err
}
if !ip.IsValid() {
return "", errors.New("invalid ip")
}
return ip.Addr().String(), nil
}
// SplitIPAddr 传入/分割的v4v6混合地址返回v4和v6地址与有效地址 // SplitIPAddr 传入/分割的v4v6混合地址返回v4和v6地址与有效地址
func SplitIPAddr(v4v6Bundle string) (string, string, string) { func SplitIPAddr(v4v6Bundle string) (string, string, string) {
ipList := strings.Split(v4v6Bundle, "/") ipList := strings.Split(v4v6Bundle, "/")

View File

@ -1,6 +1,7 @@
package utils package utils
import ( import (
"reflect"
"testing" "testing"
) )
@ -56,3 +57,80 @@ func TestGenerGenerateRandomString(t *testing.T) {
generatedString[str] = true generatedString[str] = true
} }
} }
func TestIPStringToBinary(t *testing.T) {
cases := []struct {
ip string
want []byte
expectError bool
}{
// 有效的 IPv4 地址
{
ip: "192.168.1.1",
want: []byte{
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 192, 168, 1, 1,
},
expectError: false,
},
// 有效的 IPv6 地址
{
ip: "2001:db8::68",
want: []byte{
32, 1, 13, 184, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 104,
},
expectError: false,
},
// 无效的 IP 地址
{
ip: "invalid_ip",
want: []byte{},
expectError: true,
},
}
for _, c := range cases {
got, err := IPStringToBinary(c.ip)
if (err != nil) != c.expectError {
t.Errorf("IPStringToBinary(%q) error = %v, expect error = %v", c.ip, err, c.expectError)
continue
}
if err == nil && !reflect.DeepEqual(got, c.want) {
t.Errorf("IPStringToBinary(%q) = %v, want %v", c.ip, got, c.want)
}
}
}
func TestBinaryToIPString(t *testing.T) {
cases := []struct {
binary []byte
want string
}{
// IPv4 地址IPv4 映射的 IPv6 地址格式)
{
binary: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 192, 168, 1, 1},
want: "192.168.1.1",
},
// 其他测试用例
{
binary: []byte{32, 1, 13, 184, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 104},
want: "2001:db8::68",
},
// 全零值
{
binary: []byte{},
want: "::",
},
// IPv4 映射的 IPv6 地址
{
binary: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 127, 0, 0, 1},
want: "127.0.0.1",
},
}
for _, c := range cases {
got := BinaryToIPString(c.binary)
if got != c.want {
t.Errorf("BinaryToIPString(%v) = %q, 期望 %q", c.binary, got, c.want)
}
}
}