From 17b02640a93fc398959dca215c4b0bf6463511d7 Mon Sep 17 00:00:00 2001 From: naiba Date: Fri, 22 Nov 2024 23:57:25 +0800 Subject: [PATCH] feat: waf :clown_face: --- cmd/dashboard/controller/controller.go | 32 +-------- cmd/dashboard/controller/jwt.go | 7 ++ cmd/dashboard/controller/waf/waf.go | 90 ++++++++++++++++++++++++ cmd/dashboard/controller/waf/waf.html | 39 ++++++++++ cmd/dashboard/controller/waf/waf_test.go | 29 ++++++++ cmd/dashboard/rpc/rpc.go | 5 +- model/waf.go | 40 +++++++++++ service/rpc/auth.go | 4 -- service/singleton/singleton.go | 3 +- 9 files changed, 214 insertions(+), 35 deletions(-) create mode 100644 cmd/dashboard/controller/waf/waf.go create mode 100644 cmd/dashboard/controller/waf/waf.html create mode 100644 cmd/dashboard/controller/waf/waf_test.go create mode 100644 model/waf.go diff --git a/cmd/dashboard/controller/controller.go b/cmd/dashboard/controller/controller.go index 720466a..2efd00a 100644 --- a/cmd/dashboard/controller/controller.go +++ b/cmd/dashboard/controller/controller.go @@ -5,7 +5,6 @@ import ( "fmt" "log" "net/http" - "net/netip" "os" "path/filepath" "strings" @@ -16,6 +15,7 @@ import ( swaggerfiles "github.com/swaggo/files" ginSwagger "github.com/swaggo/gin-swagger" + "github.com/naiba/nezha/cmd/dashboard/controller/waf" docs "github.com/naiba/nezha/cmd/dashboard/docs" "github.com/naiba/nezha/model" "github.com/naiba/nezha/service/singleton" @@ -34,39 +34,14 @@ func ServeWeb() http.Handler { r.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerfiles.Handler)) } - r.Use(realIp) + r.Use(waf.RealIp) + r.Use(waf.Waf) r.Use(recordPath) routers(r) return r } -func realIp(c *gin.Context) { - if singleton.Conf.RealIPHeader == "" { - c.Next() - return - } - - if singleton.Conf.RealIPHeader == model.ConfigUsePeerIP { - c.Set(model.CtxKeyRealIPStr, c.RemoteIP()) - c.Next() - return - } - - vals := c.Request.Header.Get(singleton.Conf.RealIPHeader) - if vals == "" { - c.AbortWithStatusJSON(http.StatusOK, model.CommonResponse[any]{Success: false, Error: "real ip header not found"}) - return - } - ip, err := netip.ParseAddr(vals) - if err != nil { - c.AbortWithStatusJSON(http.StatusOK, model.CommonResponse[any]{Success: false, Error: err.Error()}) - return - } - c.Set(model.CtxKeyRealIPStr, ip.String()) - c.Next() -} - func routers(r *gin.Engine) { authMiddleware, err := jwt.New(initParams()) if err != nil { @@ -154,7 +129,6 @@ func routers(r *gin.Engine) { } func recordPath(c *gin.Context) { - log.Printf("bingo web real ip: %s", c.GetString(model.CtxKeyRealIPStr)) url := c.Request.URL.String() for _, p := range c.Params { url = strings.Replace(url, p.Value, ":"+p.Key, 1) diff --git a/cmd/dashboard/controller/jwt.go b/cmd/dashboard/controller/jwt.go index bc5ec4e..349ca9f 100644 --- a/cmd/dashboard/controller/jwt.go +++ b/cmd/dashboard/controller/jwt.go @@ -9,6 +9,7 @@ import ( "github.com/gin-gonic/gin" "golang.org/x/crypto/bcrypt" + "github.com/naiba/nezha/cmd/dashboard/controller/waf" "github.com/naiba/nezha/model" "github.com/naiba/nezha/pkg/utils" "github.com/naiba/nezha/service/singleton" @@ -87,10 +88,12 @@ func authenticator() func(c *gin.Context) (interface{}, error) { var user model.User if err := singleton.DB.Select("id", "password").Where("username = ?", loginVals.Username).First(&user).Error; err != nil { + model.BlockIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.WAFBlockReasonTypeLoginFail) return nil, jwt.ErrFailedAuthentication } if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(loginVals.Password)); err != nil { + model.BlockIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.WAFBlockReasonTypeLoginFail) return nil, jwt.ErrFailedAuthentication } @@ -163,6 +166,10 @@ func optionalAuthMiddleware(mw *jwt.GinJWTMiddleware) func(c *gin.Context) { identity := mw.IdentityHandler(c) if identity != nil { + if err := model.BlockIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.WAFBlockReasonTypeBruteForceToken); err != nil { + waf.ShowBlockPage(c, err) + return + } c.Set(mw.IdentityKey, identity) } diff --git a/cmd/dashboard/controller/waf/waf.go b/cmd/dashboard/controller/waf/waf.go new file mode 100644 index 0000000..4a559ea --- /dev/null +++ b/cmd/dashboard/controller/waf/waf.go @@ -0,0 +1,90 @@ +package waf + +import ( + _ "embed" + "errors" + "log" + "math/big" + "net/http" + "net/netip" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/naiba/nezha/model" + "github.com/naiba/nezha/service/singleton" + "gorm.io/gorm" +) + +//go:embed waf.html +var errorPageTemplate string + +func RealIp(c *gin.Context) { + if singleton.Conf.RealIPHeader == "" { + c.Next() + return + } + + if singleton.Conf.RealIPHeader == model.ConfigUsePeerIP { + c.Set(model.CtxKeyRealIPStr, c.RemoteIP()) + c.Next() + return + } + + vals := c.Request.Header.Get(singleton.Conf.RealIPHeader) + if vals == "" { + c.AbortWithStatusJSON(http.StatusOK, model.CommonResponse[any]{Success: false, Error: "real ip header not found"}) + return + } + ip, err := netip.ParseAddr(vals) + if err != nil { + c.AbortWithStatusJSON(http.StatusOK, model.CommonResponse[any]{Success: false, Error: err.Error()}) + return + } + c.Set(model.CtxKeyRealIPStr, ip.String()) + c.Next() +} + +func Waf(c *gin.Context) { + if singleton.Conf.RealIPHeader == "" { + c.Next() + return + } + realipAddr := c.GetString(model.CtxKeyRealIPStr) + if realipAddr == "" { + c.Next() + return + } + var w model.WAF + if err := singleton.DB.First(&w, "ip = ?", realipAddr).Error; err != nil { + if err != gorm.ErrRecordNotFound { + ShowBlockPage(c, err) + return + } + } + now := time.Now().Unix() + if w.LastBlockTimestamp+pow(w.Count, 4) > uint64(now) { + log.Println(w.Count, w.LastBlockTimestamp+pow(w.Count, 4)-uint64(now)) + ShowBlockPage(c, errors.New("you are blocked by nezha WAF")) + return + } + c.Next() +} + +func pow(x, y uint64) uint64 { + base := big.NewInt(0).SetUint64(x) + exp := big.NewInt(0).SetUint64(y) + result := big.NewInt(1) + result.Exp(base, exp, nil) + if !result.IsUint64() { + return ^uint64(0) // return max uint64 value on overflow + } + return result.Uint64() +} + +func ShowBlockPage(c *gin.Context, err error) { + c.Writer.WriteHeader(http.StatusForbidden) + c.Header("Content-Type", "text/html; charset=utf-8") + c.Writer.WriteString(strings.Replace(errorPageTemplate, "{error}", err.Error(), 1)) + c.Abort() +} diff --git a/cmd/dashboard/controller/waf/waf.html b/cmd/dashboard/controller/waf/waf.html new file mode 100644 index 0000000..c278295 --- /dev/null +++ b/cmd/dashboard/controller/waf/waf.html @@ -0,0 +1,39 @@ + + + + + + + Blocked + + + + +
+
🤡
+

Blocked

+

{error}

+

nezha WAF

+
+ + + \ No newline at end of file diff --git a/cmd/dashboard/controller/waf/waf_test.go b/cmd/dashboard/controller/waf/waf_test.go new file mode 100644 index 0000000..9f43c63 --- /dev/null +++ b/cmd/dashboard/controller/waf/waf_test.go @@ -0,0 +1,29 @@ +package waf + +import ( + "math" + "testing" +) + +func TestPow(t *testing.T) { + tests := []struct { + x, + y, + expect uint64 + }{ + {2, 64, math.MaxUint64}, // 2 的 64 次方,超过 uint64 最大值 + {uint64(1 << 63), 2, math.MaxUint64}, // 大数平方,可能溢出 + {uint64(^uint64(0)), 2, math.MaxUint64}, // uint64 最大值的平方,溢出 + {2, 3, 8}, + {5, 0, 1}, + {3, 1, 3}, + {0, 5, 0}, + } + + for _, tt := range tests { + result := pow(tt.x, tt.y) + if result != tt.expect { + t.Errorf("pow(%d, %d) = %d; expect %d", tt.x, tt.y, result, tt.expect) + } + } +} diff --git a/cmd/dashboard/rpc/rpc.go b/cmd/dashboard/rpc/rpc.go index ce5f059..c36b443 100644 --- a/cmd/dashboard/rpc/rpc.go +++ b/cmd/dashboard/rpc/rpc.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "net/netip" + "strings" "time" "google.golang.org/grpc" @@ -48,7 +49,9 @@ func getRealIp(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, if len(vals) == 0 { return nil, fmt.Errorf("real ip header not found") } - ip, err := netip.ParseAddr(vals[0]) + a := strings.Split(vals[0], ",") + h := strings.TrimSpace(a[len(a)-1]) + ip, err := netip.ParseAddr(h) if err != nil { return nil, err } diff --git a/model/waf.go b/model/waf.go new file mode 100644 index 0000000..d179723 --- /dev/null +++ b/model/waf.go @@ -0,0 +1,40 @@ +package model + +import ( + "errors" + "time" + + "gorm.io/gorm" +) + +const ( + _ uint8 = iota + WAFBlockReasonTypeLoginFail + WAFBlockReasonTypeBruteForceToken +) + +type WAF struct { + IP string `gorm:"type:binary(16);primaryKey" json:"ip,omitempty"` + Count uint64 `json:"count,omitempty"` + LastBlockReason uint8 `json:"last_block_reason,omitempty"` + LastBlockTimestamp uint64 `json:"last_block_timestamp,omitempty"` +} + +func (w *WAF) TableName() string { + return "waf" +} + +func BlockIP(db *gorm.DB, ip string, reason uint8) error { + if ip == "" { + return errors.New("empty ip") + } + var w WAF + w.LastBlockReason = reason + w.LastBlockTimestamp = uint64(time.Now().Unix()) + return db.Transaction(func(tx *gorm.DB) error { + if err := tx.FirstOrCreate(&w, WAF{IP: ip}).Error; err != nil { + return err + } + return tx.Exec("UPDATE waf SET count = count + 1 WHERE ip = ?", ip).Error + }) +} diff --git a/service/rpc/auth.go b/service/rpc/auth.go index b965958..d1762f1 100644 --- a/service/rpc/auth.go +++ b/service/rpc/auth.go @@ -2,7 +2,6 @@ package rpc import ( "context" - "log" "sync" "google.golang.org/grpc/codes" @@ -25,9 +24,6 @@ func (a *authHandler) Check(ctx context.Context) (uint64, error) { return 0, status.Errorf(codes.Unauthenticated, "获取 metaData 失败") } - realIp := ctx.Value(model.CtxKeyRealIP{}) - log.Printf("bingo rpc realIp: %s, metadata: %v", realIp, md) - var clientSecret string if value, ok := md["client_secret"]; ok { clientSecret = value[0] diff --git a/service/singleton/singleton.go b/service/singleton/singleton.go index b4548d9..f636d3d 100644 --- a/service/singleton/singleton.go +++ b/service/singleton/singleton.go @@ -65,7 +65,8 @@ func InitDBFromPath(path string) { err = DB.AutoMigrate(model.Server{}, model.User{}, model.ServerGroup{}, model.NotificationGroup{}, model.Notification{}, model.AlertRule{}, model.Service{}, model.NotificationGroupNotification{}, model.ServiceHistory{}, model.Cron{}, model.Transfer{}, model.ServerGroupServer{}, model.UserGroup{}, - model.UserGroupUser{}, model.NAT{}, model.DDNSProfile{}, model.NotificationGroupNotification{}) + model.UserGroupUser{}, model.NAT{}, model.DDNSProfile{}, model.NotificationGroupNotification{}, + model.WAF{}) if err != nil { panic(err) }