mirror of
https://github.com/nezhahq/nezha.git
synced 2025-01-22 20:58:14 -05:00
fix: waf condition
This commit is contained in:
parent
867f840265
commit
cd42b1b9d5
@ -166,11 +166,13 @@ func optionalAuthMiddleware(mw *jwt.GinJWTMiddleware) func(c *gin.Context) {
|
|||||||
identity := mw.IdentityHandler(c)
|
identity := mw.IdentityHandler(c)
|
||||||
|
|
||||||
if identity != nil {
|
if identity != nil {
|
||||||
|
model.ClearIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr))
|
||||||
|
c.Set(mw.IdentityKey, identity)
|
||||||
|
} else {
|
||||||
if err := model.BlockIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.WAFBlockReasonTypeBruteForceToken); err != nil {
|
if err := model.BlockIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.WAFBlockReasonTypeBruteForceToken); err != nil {
|
||||||
waf.ShowBlockPage(c, err)
|
waf.ShowBlockPage(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Set(mw.IdentityKey, identity)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Next()
|
c.Next()
|
||||||
|
@ -3,12 +3,12 @@ package waf
|
|||||||
import (
|
import (
|
||||||
_ "embed"
|
_ "embed"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
"github.com/naiba/nezha/model"
|
"github.com/naiba/nezha/model"
|
||||||
|
"github.com/naiba/nezha/pkg/utils"
|
||||||
"github.com/naiba/nezha/service/singleton"
|
"github.com/naiba/nezha/service/singleton"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -32,26 +32,17 @@ func RealIp(c *gin.Context) {
|
|||||||
c.AbortWithStatusJSON(http.StatusOK, model.CommonResponse[any]{Success: false, Error: "real ip header not found"})
|
c.AbortWithStatusJSON(http.StatusOK, model.CommonResponse[any]{Success: false, Error: "real ip header not found"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ip, err := netip.ParseAddrPort(vals)
|
ip, err := utils.GetIPFromHeader(vals)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AbortWithStatusJSON(http.StatusOK, model.CommonResponse[any]{Success: false, Error: err.Error()})
|
c.AbortWithStatusJSON(http.StatusOK, model.CommonResponse[any]{Success: false, Error: err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Set(model.CtxKeyRealIPStr, ip.Addr().String())
|
c.Set(model.CtxKeyRealIPStr, ip)
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
|
|
||||||
func Waf(c *gin.Context) {
|
func Waf(c *gin.Context) {
|
||||||
if singleton.Conf.RealIPHeader == "" {
|
if err := model.CheckIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr)); err != nil {
|
||||||
c.Next()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
realipAddr := c.GetString(model.CtxKeyRealIPStr)
|
|
||||||
if realipAddr == "" {
|
|
||||||
c.Next()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := model.CheckIP(singleton.DB, realipAddr); err != nil {
|
|
||||||
ShowBlockPage(c, err)
|
ShowBlockPage(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -24,6 +24,13 @@
|
|||||||
font-size: 12px;
|
font-size: 12px;
|
||||||
color: #888;
|
color: #888;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@media (prefers-color-scheme: dark) {
|
||||||
|
body {
|
||||||
|
background-color: #111;
|
||||||
|
color: #007C41
|
||||||
|
}
|
||||||
|
}
|
||||||
</style>
|
</style>
|
||||||
</head>
|
</head>
|
||||||
|
|
||||||
|
@ -5,7 +5,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
@ -49,13 +48,11 @@ func getRealIp(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
|
|||||||
if len(vals) == 0 {
|
if len(vals) == 0 {
|
||||||
return nil, fmt.Errorf("real ip header not found")
|
return nil, fmt.Errorf("real ip header not found")
|
||||||
}
|
}
|
||||||
a := strings.Split(vals[0], ",")
|
ip, err := utils.GetIPFromHeader(vals[0])
|
||||||
h := strings.TrimSpace(a[len(a)-1])
|
|
||||||
ip, err := netip.ParseAddrPort(h)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
ctx = context.WithValue(ctx, model.CtxKeyRealIP{}, ip.Addr().String())
|
ctx = context.WithValue(ctx, model.CtxKeyRealIP{}, ip)
|
||||||
return handler(ctx, req)
|
return handler(ctx, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
58
model/waf.go
58
model/waf.go
@ -2,11 +2,10 @@ package model
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"log"
|
|
||||||
"math/big"
|
"math/big"
|
||||||
"net/netip"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/naiba/nezha/pkg/utils"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -27,24 +26,11 @@ func (w *WAF) TableName() string {
|
|||||||
return "waf"
|
return "waf"
|
||||||
}
|
}
|
||||||
|
|
||||||
func ipStringToBinary(ip string) ([]byte, error) {
|
|
||||||
addr, err := netip.ParseAddr(ip)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
b := addr.As16()
|
|
||||||
return b[:], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func binaryToIPString(b []byte) string {
|
|
||||||
var addr16 [16]byte
|
|
||||||
copy(addr16[:], b)
|
|
||||||
addr := netip.AddrFrom16(addr16)
|
|
||||||
return addr.Unmap().String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func CheckIP(db *gorm.DB, ip string) error {
|
func CheckIP(db *gorm.DB, ip string) error {
|
||||||
ipBinary, err := ipStringToBinary(ip)
|
if ip == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
ipBinary, err := utils.IPStringToBinary(ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -55,39 +41,53 @@ func CheckIP(db *gorm.DB, ip string) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
if w.LastBlockTimestamp+pow(w.Count, 4) > uint64(now) {
|
if powAdd(w.Count, 4, w.LastBlockTimestamp) > uint64(now) {
|
||||||
log.Println(w.Count, w.LastBlockTimestamp+pow(w.Count, 4)-uint64(now))
|
|
||||||
return errors.New("you are blocked by nezha WAF")
|
return errors.New("you are blocked by nezha WAF")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ClearIP(db *gorm.DB, ip string) error {
|
||||||
|
if ip == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
ipBinary, err := utils.IPStringToBinary(ip)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return db.Delete(&WAF{}, "ip = ?", ipBinary).Error
|
||||||
|
}
|
||||||
|
|
||||||
func BlockIP(db *gorm.DB, ip string, reason uint8) error {
|
func BlockIP(db *gorm.DB, ip string, reason uint8) error {
|
||||||
if ip == "" {
|
if ip == "" {
|
||||||
return errors.New("empty ip")
|
return nil
|
||||||
}
|
}
|
||||||
ipBinary, err := ipStringToBinary(ip)
|
ipBinary, err := utils.IPStringToBinary(ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
var w WAF
|
var w WAF
|
||||||
w.LastBlockReason = reason
|
w.IP = ipBinary
|
||||||
w.LastBlockTimestamp = uint64(time.Now().Unix())
|
|
||||||
return db.Transaction(func(tx *gorm.DB) error {
|
return db.Transaction(func(tx *gorm.DB) error {
|
||||||
if err := tx.FirstOrCreate(&w, WAF{IP: ipBinary}).Error; err != nil {
|
if err := tx.Where(&w).Attrs(WAF{
|
||||||
|
LastBlockReason: reason,
|
||||||
|
LastBlockTimestamp: uint64(time.Now().Unix()),
|
||||||
|
}).FirstOrCreate(&w).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return tx.Exec("UPDATE waf SET count = count + 1 WHERE ip = ?", ipBinary).Error
|
return tx.Exec("UPDATE waf SET count = count + 1, last_block_reason = ?, last_block_timestamp = ? WHERE ip = ?", reason, uint64(time.Now().Unix()), ipBinary).Error
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func pow(x, y uint64) uint64 {
|
func powAdd(x, y, z uint64) uint64 {
|
||||||
base := big.NewInt(0).SetUint64(x)
|
base := big.NewInt(0).SetUint64(x)
|
||||||
exp := big.NewInt(0).SetUint64(y)
|
exp := big.NewInt(0).SetUint64(y)
|
||||||
result := big.NewInt(1)
|
result := big.NewInt(1)
|
||||||
result.Exp(base, exp, nil)
|
result.Exp(base, exp, nil)
|
||||||
|
result.Add(result, big.NewInt(0).SetUint64(z))
|
||||||
if !result.IsUint64() {
|
if !result.IsUint64() {
|
||||||
return ^uint64(0) // return max uint64 value on overflow
|
return ^uint64(0) // return max uint64 value on overflow
|
||||||
}
|
}
|
||||||
return result.Uint64()
|
ret := result.Uint64()
|
||||||
|
return utils.IfOr(ret < z+3, z+3, ret)
|
||||||
}
|
}
|
||||||
|
@ -1,83 +0,0 @@
|
|||||||
package model
|
|
||||||
|
|
||||||
import (
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestIPStringToBinary(t *testing.T) {
|
|
||||||
cases := []struct {
|
|
||||||
ip string
|
|
||||||
want []byte
|
|
||||||
expectError bool
|
|
||||||
}{
|
|
||||||
// 有效的 IPv4 地址
|
|
||||||
{
|
|
||||||
ip: "192.168.1.1",
|
|
||||||
want: []byte{
|
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 192, 168, 1, 1,
|
|
||||||
},
|
|
||||||
expectError: false,
|
|
||||||
},
|
|
||||||
// 有效的 IPv6 地址
|
|
||||||
{
|
|
||||||
ip: "2001:db8::68",
|
|
||||||
want: []byte{
|
|
||||||
32, 1, 13, 184, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 104,
|
|
||||||
},
|
|
||||||
expectError: false,
|
|
||||||
},
|
|
||||||
// 无效的 IP 地址
|
|
||||||
{
|
|
||||||
ip: "invalid_ip",
|
|
||||||
want: []byte{},
|
|
||||||
expectError: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, c := range cases {
|
|
||||||
got, err := ipStringToBinary(c.ip)
|
|
||||||
if (err != nil) != c.expectError {
|
|
||||||
t.Errorf("IPStringToBinary(%q) error = %v, expect error = %v", c.ip, err, c.expectError)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err == nil && !reflect.DeepEqual(got, c.want) {
|
|
||||||
t.Errorf("IPStringToBinary(%q) = %v, want %v", c.ip, got, c.want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBinaryToIPString(t *testing.T) {
|
|
||||||
cases := []struct {
|
|
||||||
binary []byte
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
// IPv4 地址(IPv4 映射的 IPv6 地址格式)
|
|
||||||
{
|
|
||||||
binary: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 192, 168, 1, 1},
|
|
||||||
want: "192.168.1.1",
|
|
||||||
},
|
|
||||||
// 其他测试用例
|
|
||||||
{
|
|
||||||
binary: []byte{32, 1, 13, 184, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 104},
|
|
||||||
want: "2001:db8::68",
|
|
||||||
},
|
|
||||||
// 全零值
|
|
||||||
{
|
|
||||||
binary: []byte{},
|
|
||||||
want: "::",
|
|
||||||
},
|
|
||||||
// IPv4 映射的 IPv6 地址
|
|
||||||
{
|
|
||||||
binary: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 127, 0, 0, 1},
|
|
||||||
want: "127.0.0.1",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, c := range cases {
|
|
||||||
got := binaryToIPString(c.binary)
|
|
||||||
if got != c.want {
|
|
||||||
t.Errorf("BinaryToIPString(%v) = %q, 期望 %q", c.binary, got, c.want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@ -2,7 +2,9 @@ package utils
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
"errors"
|
||||||
"math/big"
|
"math/big"
|
||||||
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -37,6 +39,35 @@ func IPDesensitize(ipAddr string) string {
|
|||||||
return ipAddr
|
return ipAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func IPStringToBinary(ip string) ([]byte, error) {
|
||||||
|
addr, err := netip.ParseAddr(ip)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
b := addr.As16()
|
||||||
|
return b[:], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func BinaryToIPString(b []byte) string {
|
||||||
|
var addr16 [16]byte
|
||||||
|
copy(addr16[:], b)
|
||||||
|
addr := netip.AddrFrom16(addr16)
|
||||||
|
return addr.Unmap().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetIPFromHeader(headerValue string) (string, error) {
|
||||||
|
a := strings.Split(headerValue, ",")
|
||||||
|
h := strings.TrimSpace(a[len(a)-1])
|
||||||
|
ip, err := netip.ParseAddrPort(h)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if !ip.IsValid() {
|
||||||
|
return "", errors.New("invalid ip")
|
||||||
|
}
|
||||||
|
return ip.Addr().String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
// SplitIPAddr 传入/分割的v4v6混合地址,返回v4和v6地址与有效地址
|
// SplitIPAddr 传入/分割的v4v6混合地址,返回v4和v6地址与有效地址
|
||||||
func SplitIPAddr(v4v6Bundle string) (string, string, string) {
|
func SplitIPAddr(v4v6Bundle string) (string, string, string) {
|
||||||
ipList := strings.Split(v4v6Bundle, "/")
|
ipList := strings.Split(v4v6Bundle, "/")
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package utils
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -56,3 +57,80 @@ func TestGenerGenerateRandomString(t *testing.T) {
|
|||||||
generatedString[str] = true
|
generatedString[str] = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIPStringToBinary(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
ip string
|
||||||
|
want []byte
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
// 有效的 IPv4 地址
|
||||||
|
{
|
||||||
|
ip: "192.168.1.1",
|
||||||
|
want: []byte{
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 192, 168, 1, 1,
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
// 有效的 IPv6 地址
|
||||||
|
{
|
||||||
|
ip: "2001:db8::68",
|
||||||
|
want: []byte{
|
||||||
|
32, 1, 13, 184, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 104,
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
// 无效的 IP 地址
|
||||||
|
{
|
||||||
|
ip: "invalid_ip",
|
||||||
|
want: []byte{},
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range cases {
|
||||||
|
got, err := IPStringToBinary(c.ip)
|
||||||
|
if (err != nil) != c.expectError {
|
||||||
|
t.Errorf("IPStringToBinary(%q) error = %v, expect error = %v", c.ip, err, c.expectError)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err == nil && !reflect.DeepEqual(got, c.want) {
|
||||||
|
t.Errorf("IPStringToBinary(%q) = %v, want %v", c.ip, got, c.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBinaryToIPString(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
binary []byte
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
// IPv4 地址(IPv4 映射的 IPv6 地址格式)
|
||||||
|
{
|
||||||
|
binary: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 192, 168, 1, 1},
|
||||||
|
want: "192.168.1.1",
|
||||||
|
},
|
||||||
|
// 其他测试用例
|
||||||
|
{
|
||||||
|
binary: []byte{32, 1, 13, 184, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 104},
|
||||||
|
want: "2001:db8::68",
|
||||||
|
},
|
||||||
|
// 全零值
|
||||||
|
{
|
||||||
|
binary: []byte{},
|
||||||
|
want: "::",
|
||||||
|
},
|
||||||
|
// IPv4 映射的 IPv6 地址
|
||||||
|
{
|
||||||
|
binary: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 127, 0, 0, 1},
|
||||||
|
want: "127.0.0.1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range cases {
|
||||||
|
got := BinaryToIPString(c.binary)
|
||||||
|
if got != c.want {
|
||||||
|
t.Errorf("BinaryToIPString(%v) = %q, 期望 %q", c.binary, got, c.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user