mirror of
https://github.com/nezhahq/nezha.git
synced 2025-01-22 12:48:14 -05:00
feat: waf 🤡
This commit is contained in:
parent
d699d0ee87
commit
17b02640a9
@ -5,7 +5,6 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@ -16,6 +15,7 @@ import (
|
||||
swaggerfiles "github.com/swaggo/files"
|
||||
ginSwagger "github.com/swaggo/gin-swagger"
|
||||
|
||||
"github.com/naiba/nezha/cmd/dashboard/controller/waf"
|
||||
docs "github.com/naiba/nezha/cmd/dashboard/docs"
|
||||
"github.com/naiba/nezha/model"
|
||||
"github.com/naiba/nezha/service/singleton"
|
||||
@ -34,39 +34,14 @@ func ServeWeb() http.Handler {
|
||||
r.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerfiles.Handler))
|
||||
}
|
||||
|
||||
r.Use(realIp)
|
||||
r.Use(waf.RealIp)
|
||||
r.Use(waf.Waf)
|
||||
r.Use(recordPath)
|
||||
routers(r)
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func realIp(c *gin.Context) {
|
||||
if singleton.Conf.RealIPHeader == "" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
if singleton.Conf.RealIPHeader == model.ConfigUsePeerIP {
|
||||
c.Set(model.CtxKeyRealIPStr, c.RemoteIP())
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
vals := c.Request.Header.Get(singleton.Conf.RealIPHeader)
|
||||
if vals == "" {
|
||||
c.AbortWithStatusJSON(http.StatusOK, model.CommonResponse[any]{Success: false, Error: "real ip header not found"})
|
||||
return
|
||||
}
|
||||
ip, err := netip.ParseAddr(vals)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusOK, model.CommonResponse[any]{Success: false, Error: err.Error()})
|
||||
return
|
||||
}
|
||||
c.Set(model.CtxKeyRealIPStr, ip.String())
|
||||
c.Next()
|
||||
}
|
||||
|
||||
func routers(r *gin.Engine) {
|
||||
authMiddleware, err := jwt.New(initParams())
|
||||
if err != nil {
|
||||
@ -154,7 +129,6 @@ func routers(r *gin.Engine) {
|
||||
}
|
||||
|
||||
func recordPath(c *gin.Context) {
|
||||
log.Printf("bingo web real ip: %s", c.GetString(model.CtxKeyRealIPStr))
|
||||
url := c.Request.URL.String()
|
||||
for _, p := range c.Params {
|
||||
url = strings.Replace(url, p.Value, ":"+p.Key, 1)
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"github.com/naiba/nezha/cmd/dashboard/controller/waf"
|
||||
"github.com/naiba/nezha/model"
|
||||
"github.com/naiba/nezha/pkg/utils"
|
||||
"github.com/naiba/nezha/service/singleton"
|
||||
@ -87,10 +88,12 @@ func authenticator() func(c *gin.Context) (interface{}, error) {
|
||||
|
||||
var user model.User
|
||||
if err := singleton.DB.Select("id", "password").Where("username = ?", loginVals.Username).First(&user).Error; err != nil {
|
||||
model.BlockIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.WAFBlockReasonTypeLoginFail)
|
||||
return nil, jwt.ErrFailedAuthentication
|
||||
}
|
||||
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(loginVals.Password)); err != nil {
|
||||
model.BlockIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.WAFBlockReasonTypeLoginFail)
|
||||
return nil, jwt.ErrFailedAuthentication
|
||||
}
|
||||
|
||||
@ -163,6 +166,10 @@ func optionalAuthMiddleware(mw *jwt.GinJWTMiddleware) func(c *gin.Context) {
|
||||
identity := mw.IdentityHandler(c)
|
||||
|
||||
if identity != nil {
|
||||
if err := model.BlockIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.WAFBlockReasonTypeBruteForceToken); err != nil {
|
||||
waf.ShowBlockPage(c, err)
|
||||
return
|
||||
}
|
||||
c.Set(mw.IdentityKey, identity)
|
||||
}
|
||||
|
||||
|
90
cmd/dashboard/controller/waf/waf.go
Normal file
90
cmd/dashboard/controller/waf/waf.go
Normal file
@ -0,0 +1,90 @@
|
||||
package waf
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"errors"
|
||||
"log"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/naiba/nezha/model"
|
||||
"github.com/naiba/nezha/service/singleton"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
//go:embed waf.html
|
||||
var errorPageTemplate string
|
||||
|
||||
func RealIp(c *gin.Context) {
|
||||
if singleton.Conf.RealIPHeader == "" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
if singleton.Conf.RealIPHeader == model.ConfigUsePeerIP {
|
||||
c.Set(model.CtxKeyRealIPStr, c.RemoteIP())
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
vals := c.Request.Header.Get(singleton.Conf.RealIPHeader)
|
||||
if vals == "" {
|
||||
c.AbortWithStatusJSON(http.StatusOK, model.CommonResponse[any]{Success: false, Error: "real ip header not found"})
|
||||
return
|
||||
}
|
||||
ip, err := netip.ParseAddr(vals)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusOK, model.CommonResponse[any]{Success: false, Error: err.Error()})
|
||||
return
|
||||
}
|
||||
c.Set(model.CtxKeyRealIPStr, ip.String())
|
||||
c.Next()
|
||||
}
|
||||
|
||||
func Waf(c *gin.Context) {
|
||||
if singleton.Conf.RealIPHeader == "" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
realipAddr := c.GetString(model.CtxKeyRealIPStr)
|
||||
if realipAddr == "" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
var w model.WAF
|
||||
if err := singleton.DB.First(&w, "ip = ?", realipAddr).Error; err != nil {
|
||||
if err != gorm.ErrRecordNotFound {
|
||||
ShowBlockPage(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
now := time.Now().Unix()
|
||||
if w.LastBlockTimestamp+pow(w.Count, 4) > uint64(now) {
|
||||
log.Println(w.Count, w.LastBlockTimestamp+pow(w.Count, 4)-uint64(now))
|
||||
ShowBlockPage(c, errors.New("you are blocked by nezha WAF"))
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
|
||||
func pow(x, y uint64) uint64 {
|
||||
base := big.NewInt(0).SetUint64(x)
|
||||
exp := big.NewInt(0).SetUint64(y)
|
||||
result := big.NewInt(1)
|
||||
result.Exp(base, exp, nil)
|
||||
if !result.IsUint64() {
|
||||
return ^uint64(0) // return max uint64 value on overflow
|
||||
}
|
||||
return result.Uint64()
|
||||
}
|
||||
|
||||
func ShowBlockPage(c *gin.Context, err error) {
|
||||
c.Writer.WriteHeader(http.StatusForbidden)
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.Writer.WriteString(strings.Replace(errorPageTemplate, "{error}", err.Error(), 1))
|
||||
c.Abort()
|
||||
}
|
39
cmd/dashboard/controller/waf/waf.html
Normal file
39
cmd/dashboard/controller/waf/waf.html
Normal file
@ -0,0 +1,39 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Blocked</title>
|
||||
<style>
|
||||
body {
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
height: 90vh;
|
||||
font-weight: bolder;
|
||||
font-family: 'Courier New', Courier, monospace;
|
||||
}
|
||||
main {
|
||||
text-align: center;
|
||||
}
|
||||
.emoji {
|
||||
font-size: 200px;
|
||||
}
|
||||
p.secondary {
|
||||
font-size: 12px;
|
||||
color: #888;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
|
||||
<body>
|
||||
<main>
|
||||
<div class="emoji">🤡</div>
|
||||
<h1>Blocked</h1>
|
||||
<p>{error}</p>
|
||||
<p class="secondary">nezha WAF</p>
|
||||
</main>
|
||||
</body>
|
||||
|
||||
</html>
|
29
cmd/dashboard/controller/waf/waf_test.go
Normal file
29
cmd/dashboard/controller/waf/waf_test.go
Normal file
@ -0,0 +1,29 @@
|
||||
package waf
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPow(t *testing.T) {
|
||||
tests := []struct {
|
||||
x,
|
||||
y,
|
||||
expect uint64
|
||||
}{
|
||||
{2, 64, math.MaxUint64}, // 2 的 64 次方,超过 uint64 最大值
|
||||
{uint64(1 << 63), 2, math.MaxUint64}, // 大数平方,可能溢出
|
||||
{uint64(^uint64(0)), 2, math.MaxUint64}, // uint64 最大值的平方,溢出
|
||||
{2, 3, 8},
|
||||
{5, 0, 1},
|
||||
{3, 1, 3},
|
||||
{0, 5, 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := pow(tt.x, tt.y)
|
||||
if result != tt.expect {
|
||||
t.Errorf("pow(%d, %d) = %d; expect %d", tt.x, tt.y, result, tt.expect)
|
||||
}
|
||||
}
|
||||
}
|
@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
@ -48,7 +49,9 @@ func getRealIp(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
|
||||
if len(vals) == 0 {
|
||||
return nil, fmt.Errorf("real ip header not found")
|
||||
}
|
||||
ip, err := netip.ParseAddr(vals[0])
|
||||
a := strings.Split(vals[0], ",")
|
||||
h := strings.TrimSpace(a[len(a)-1])
|
||||
ip, err := netip.ParseAddr(h)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
40
model/waf.go
Normal file
40
model/waf.go
Normal file
@ -0,0 +1,40 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
_ uint8 = iota
|
||||
WAFBlockReasonTypeLoginFail
|
||||
WAFBlockReasonTypeBruteForceToken
|
||||
)
|
||||
|
||||
type WAF struct {
|
||||
IP string `gorm:"type:binary(16);primaryKey" json:"ip,omitempty"`
|
||||
Count uint64 `json:"count,omitempty"`
|
||||
LastBlockReason uint8 `json:"last_block_reason,omitempty"`
|
||||
LastBlockTimestamp uint64 `json:"last_block_timestamp,omitempty"`
|
||||
}
|
||||
|
||||
func (w *WAF) TableName() string {
|
||||
return "waf"
|
||||
}
|
||||
|
||||
func BlockIP(db *gorm.DB, ip string, reason uint8) error {
|
||||
if ip == "" {
|
||||
return errors.New("empty ip")
|
||||
}
|
||||
var w WAF
|
||||
w.LastBlockReason = reason
|
||||
w.LastBlockTimestamp = uint64(time.Now().Unix())
|
||||
return db.Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.FirstOrCreate(&w, WAF{IP: ip}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Exec("UPDATE waf SET count = count + 1 WHERE ip = ?", ip).Error
|
||||
})
|
||||
}
|
@ -2,7 +2,6 @@ package rpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"sync"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
@ -25,9 +24,6 @@ func (a *authHandler) Check(ctx context.Context) (uint64, error) {
|
||||
return 0, status.Errorf(codes.Unauthenticated, "获取 metaData 失败")
|
||||
}
|
||||
|
||||
realIp := ctx.Value(model.CtxKeyRealIP{})
|
||||
log.Printf("bingo rpc realIp: %s, metadata: %v", realIp, md)
|
||||
|
||||
var clientSecret string
|
||||
if value, ok := md["client_secret"]; ok {
|
||||
clientSecret = value[0]
|
||||
|
@ -65,7 +65,8 @@ func InitDBFromPath(path string) {
|
||||
err = DB.AutoMigrate(model.Server{}, model.User{}, model.ServerGroup{}, model.NotificationGroup{},
|
||||
model.Notification{}, model.AlertRule{}, model.Service{}, model.NotificationGroupNotification{},
|
||||
model.ServiceHistory{}, model.Cron{}, model.Transfer{}, model.ServerGroupServer{}, model.UserGroup{},
|
||||
model.UserGroupUser{}, model.NAT{}, model.DDNSProfile{}, model.NotificationGroupNotification{})
|
||||
model.UserGroupUser{}, model.NAT{}, model.DDNSProfile{}, model.NotificationGroupNotification{},
|
||||
model.WAF{})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user