feat: waf 🤡

This commit is contained in:
naiba 2024-11-22 23:57:25 +08:00
parent d699d0ee87
commit 17b02640a9
9 changed files with 214 additions and 35 deletions

View File

@ -5,7 +5,6 @@ import (
"fmt"
"log"
"net/http"
"net/netip"
"os"
"path/filepath"
"strings"
@ -16,6 +15,7 @@ import (
swaggerfiles "github.com/swaggo/files"
ginSwagger "github.com/swaggo/gin-swagger"
"github.com/naiba/nezha/cmd/dashboard/controller/waf"
docs "github.com/naiba/nezha/cmd/dashboard/docs"
"github.com/naiba/nezha/model"
"github.com/naiba/nezha/service/singleton"
@ -34,39 +34,14 @@ func ServeWeb() http.Handler {
r.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerfiles.Handler))
}
r.Use(realIp)
r.Use(waf.RealIp)
r.Use(waf.Waf)
r.Use(recordPath)
routers(r)
return r
}
func realIp(c *gin.Context) {
if singleton.Conf.RealIPHeader == "" {
c.Next()
return
}
if singleton.Conf.RealIPHeader == model.ConfigUsePeerIP {
c.Set(model.CtxKeyRealIPStr, c.RemoteIP())
c.Next()
return
}
vals := c.Request.Header.Get(singleton.Conf.RealIPHeader)
if vals == "" {
c.AbortWithStatusJSON(http.StatusOK, model.CommonResponse[any]{Success: false, Error: "real ip header not found"})
return
}
ip, err := netip.ParseAddr(vals)
if err != nil {
c.AbortWithStatusJSON(http.StatusOK, model.CommonResponse[any]{Success: false, Error: err.Error()})
return
}
c.Set(model.CtxKeyRealIPStr, ip.String())
c.Next()
}
func routers(r *gin.Engine) {
authMiddleware, err := jwt.New(initParams())
if err != nil {
@ -154,7 +129,6 @@ func routers(r *gin.Engine) {
}
func recordPath(c *gin.Context) {
log.Printf("bingo web real ip: %s", c.GetString(model.CtxKeyRealIPStr))
url := c.Request.URL.String()
for _, p := range c.Params {
url = strings.Replace(url, p.Value, ":"+p.Key, 1)

View File

@ -9,6 +9,7 @@ import (
"github.com/gin-gonic/gin"
"golang.org/x/crypto/bcrypt"
"github.com/naiba/nezha/cmd/dashboard/controller/waf"
"github.com/naiba/nezha/model"
"github.com/naiba/nezha/pkg/utils"
"github.com/naiba/nezha/service/singleton"
@ -87,10 +88,12 @@ func authenticator() func(c *gin.Context) (interface{}, error) {
var user model.User
if err := singleton.DB.Select("id", "password").Where("username = ?", loginVals.Username).First(&user).Error; err != nil {
model.BlockIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.WAFBlockReasonTypeLoginFail)
return nil, jwt.ErrFailedAuthentication
}
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(loginVals.Password)); err != nil {
model.BlockIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.WAFBlockReasonTypeLoginFail)
return nil, jwt.ErrFailedAuthentication
}
@ -163,6 +166,10 @@ func optionalAuthMiddleware(mw *jwt.GinJWTMiddleware) func(c *gin.Context) {
identity := mw.IdentityHandler(c)
if identity != nil {
if err := model.BlockIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.WAFBlockReasonTypeBruteForceToken); err != nil {
waf.ShowBlockPage(c, err)
return
}
c.Set(mw.IdentityKey, identity)
}

View File

@ -0,0 +1,90 @@
package waf
import (
_ "embed"
"errors"
"log"
"math/big"
"net/http"
"net/netip"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/naiba/nezha/model"
"github.com/naiba/nezha/service/singleton"
"gorm.io/gorm"
)
//go:embed waf.html
var errorPageTemplate string
func RealIp(c *gin.Context) {
if singleton.Conf.RealIPHeader == "" {
c.Next()
return
}
if singleton.Conf.RealIPHeader == model.ConfigUsePeerIP {
c.Set(model.CtxKeyRealIPStr, c.RemoteIP())
c.Next()
return
}
vals := c.Request.Header.Get(singleton.Conf.RealIPHeader)
if vals == "" {
c.AbortWithStatusJSON(http.StatusOK, model.CommonResponse[any]{Success: false, Error: "real ip header not found"})
return
}
ip, err := netip.ParseAddr(vals)
if err != nil {
c.AbortWithStatusJSON(http.StatusOK, model.CommonResponse[any]{Success: false, Error: err.Error()})
return
}
c.Set(model.CtxKeyRealIPStr, ip.String())
c.Next()
}
func Waf(c *gin.Context) {
if singleton.Conf.RealIPHeader == "" {
c.Next()
return
}
realipAddr := c.GetString(model.CtxKeyRealIPStr)
if realipAddr == "" {
c.Next()
return
}
var w model.WAF
if err := singleton.DB.First(&w, "ip = ?", realipAddr).Error; err != nil {
if err != gorm.ErrRecordNotFound {
ShowBlockPage(c, err)
return
}
}
now := time.Now().Unix()
if w.LastBlockTimestamp+pow(w.Count, 4) > uint64(now) {
log.Println(w.Count, w.LastBlockTimestamp+pow(w.Count, 4)-uint64(now))
ShowBlockPage(c, errors.New("you are blocked by nezha WAF"))
return
}
c.Next()
}
func pow(x, y uint64) uint64 {
base := big.NewInt(0).SetUint64(x)
exp := big.NewInt(0).SetUint64(y)
result := big.NewInt(1)
result.Exp(base, exp, nil)
if !result.IsUint64() {
return ^uint64(0) // return max uint64 value on overflow
}
return result.Uint64()
}
func ShowBlockPage(c *gin.Context, err error) {
c.Writer.WriteHeader(http.StatusForbidden)
c.Header("Content-Type", "text/html; charset=utf-8")
c.Writer.WriteString(strings.Replace(errorPageTemplate, "{error}", err.Error(), 1))
c.Abort()
}

View File

@ -0,0 +1,39 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Blocked</title>
<style>
body {
display: flex;
justify-content: center;
align-items: center;
height: 90vh;
font-weight: bolder;
font-family: 'Courier New', Courier, monospace;
}
main {
text-align: center;
}
.emoji {
font-size: 200px;
}
p.secondary {
font-size: 12px;
color: #888;
}
</style>
</head>
<body>
<main>
<div class="emoji">🤡</div>
<h1>Blocked</h1>
<p>{error}</p>
<p class="secondary">nezha WAF</p>
</main>
</body>
</html>

View File

@ -0,0 +1,29 @@
package waf
import (
"math"
"testing"
)
func TestPow(t *testing.T) {
tests := []struct {
x,
y,
expect uint64
}{
{2, 64, math.MaxUint64}, // 2 的 64 次方,超过 uint64 最大值
{uint64(1 << 63), 2, math.MaxUint64}, // 大数平方,可能溢出
{uint64(^uint64(0)), 2, math.MaxUint64}, // uint64 最大值的平方,溢出
{2, 3, 8},
{5, 0, 1},
{3, 1, 3},
{0, 5, 0},
}
for _, tt := range tests {
result := pow(tt.x, tt.y)
if result != tt.expect {
t.Errorf("pow(%d, %d) = %d; expect %d", tt.x, tt.y, result, tt.expect)
}
}
}

View File

@ -5,6 +5,7 @@ import (
"fmt"
"net/http"
"net/netip"
"strings"
"time"
"google.golang.org/grpc"
@ -48,7 +49,9 @@ func getRealIp(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
if len(vals) == 0 {
return nil, fmt.Errorf("real ip header not found")
}
ip, err := netip.ParseAddr(vals[0])
a := strings.Split(vals[0], ",")
h := strings.TrimSpace(a[len(a)-1])
ip, err := netip.ParseAddr(h)
if err != nil {
return nil, err
}

40
model/waf.go Normal file
View File

@ -0,0 +1,40 @@
package model
import (
"errors"
"time"
"gorm.io/gorm"
)
const (
_ uint8 = iota
WAFBlockReasonTypeLoginFail
WAFBlockReasonTypeBruteForceToken
)
type WAF struct {
IP string `gorm:"type:binary(16);primaryKey" json:"ip,omitempty"`
Count uint64 `json:"count,omitempty"`
LastBlockReason uint8 `json:"last_block_reason,omitempty"`
LastBlockTimestamp uint64 `json:"last_block_timestamp,omitempty"`
}
func (w *WAF) TableName() string {
return "waf"
}
func BlockIP(db *gorm.DB, ip string, reason uint8) error {
if ip == "" {
return errors.New("empty ip")
}
var w WAF
w.LastBlockReason = reason
w.LastBlockTimestamp = uint64(time.Now().Unix())
return db.Transaction(func(tx *gorm.DB) error {
if err := tx.FirstOrCreate(&w, WAF{IP: ip}).Error; err != nil {
return err
}
return tx.Exec("UPDATE waf SET count = count + 1 WHERE ip = ?", ip).Error
})
}

View File

@ -2,7 +2,6 @@ package rpc
import (
"context"
"log"
"sync"
"google.golang.org/grpc/codes"
@ -25,9 +24,6 @@ func (a *authHandler) Check(ctx context.Context) (uint64, error) {
return 0, status.Errorf(codes.Unauthenticated, "获取 metaData 失败")
}
realIp := ctx.Value(model.CtxKeyRealIP{})
log.Printf("bingo rpc realIp: %s, metadata: %v", realIp, md)
var clientSecret string
if value, ok := md["client_secret"]; ok {
clientSecret = value[0]

View File

@ -65,7 +65,8 @@ func InitDBFromPath(path string) {
err = DB.AutoMigrate(model.Server{}, model.User{}, model.ServerGroup{}, model.NotificationGroup{},
model.Notification{}, model.AlertRule{}, model.Service{}, model.NotificationGroupNotification{},
model.ServiceHistory{}, model.Cron{}, model.Transfer{}, model.ServerGroupServer{}, model.UserGroup{},
model.UserGroupUser{}, model.NAT{}, model.DDNSProfile{}, model.NotificationGroupNotification{})
model.UserGroupUser{}, model.NAT{}, model.DDNSProfile{}, model.NotificationGroupNotification{},
model.WAF{})
if err != nil {
panic(err)
}