mirror of
https://github.com/nezhahq/nezha.git
synced 2025-02-08 12:38:13 -05:00
feat: waf 🤡
This commit is contained in:
parent
d699d0ee87
commit
17b02640a9
@ -5,7 +5,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
@ -16,6 +15,7 @@ import (
|
|||||||
swaggerfiles "github.com/swaggo/files"
|
swaggerfiles "github.com/swaggo/files"
|
||||||
ginSwagger "github.com/swaggo/gin-swagger"
|
ginSwagger "github.com/swaggo/gin-swagger"
|
||||||
|
|
||||||
|
"github.com/naiba/nezha/cmd/dashboard/controller/waf"
|
||||||
docs "github.com/naiba/nezha/cmd/dashboard/docs"
|
docs "github.com/naiba/nezha/cmd/dashboard/docs"
|
||||||
"github.com/naiba/nezha/model"
|
"github.com/naiba/nezha/model"
|
||||||
"github.com/naiba/nezha/service/singleton"
|
"github.com/naiba/nezha/service/singleton"
|
||||||
@ -34,39 +34,14 @@ func ServeWeb() http.Handler {
|
|||||||
r.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerfiles.Handler))
|
r.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerfiles.Handler))
|
||||||
}
|
}
|
||||||
|
|
||||||
r.Use(realIp)
|
r.Use(waf.RealIp)
|
||||||
|
r.Use(waf.Waf)
|
||||||
r.Use(recordPath)
|
r.Use(recordPath)
|
||||||
routers(r)
|
routers(r)
|
||||||
|
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
func realIp(c *gin.Context) {
|
|
||||||
if singleton.Conf.RealIPHeader == "" {
|
|
||||||
c.Next()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if singleton.Conf.RealIPHeader == model.ConfigUsePeerIP {
|
|
||||||
c.Set(model.CtxKeyRealIPStr, c.RemoteIP())
|
|
||||||
c.Next()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
vals := c.Request.Header.Get(singleton.Conf.RealIPHeader)
|
|
||||||
if vals == "" {
|
|
||||||
c.AbortWithStatusJSON(http.StatusOK, model.CommonResponse[any]{Success: false, Error: "real ip header not found"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ip, err := netip.ParseAddr(vals)
|
|
||||||
if err != nil {
|
|
||||||
c.AbortWithStatusJSON(http.StatusOK, model.CommonResponse[any]{Success: false, Error: err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.Set(model.CtxKeyRealIPStr, ip.String())
|
|
||||||
c.Next()
|
|
||||||
}
|
|
||||||
|
|
||||||
func routers(r *gin.Engine) {
|
func routers(r *gin.Engine) {
|
||||||
authMiddleware, err := jwt.New(initParams())
|
authMiddleware, err := jwt.New(initParams())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -154,7 +129,6 @@ func routers(r *gin.Engine) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func recordPath(c *gin.Context) {
|
func recordPath(c *gin.Context) {
|
||||||
log.Printf("bingo web real ip: %s", c.GetString(model.CtxKeyRealIPStr))
|
|
||||||
url := c.Request.URL.String()
|
url := c.Request.URL.String()
|
||||||
for _, p := range c.Params {
|
for _, p := range c.Params {
|
||||||
url = strings.Replace(url, p.Value, ":"+p.Key, 1)
|
url = strings.Replace(url, p.Value, ":"+p.Key, 1)
|
||||||
|
@ -9,6 +9,7 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
|
||||||
|
"github.com/naiba/nezha/cmd/dashboard/controller/waf"
|
||||||
"github.com/naiba/nezha/model"
|
"github.com/naiba/nezha/model"
|
||||||
"github.com/naiba/nezha/pkg/utils"
|
"github.com/naiba/nezha/pkg/utils"
|
||||||
"github.com/naiba/nezha/service/singleton"
|
"github.com/naiba/nezha/service/singleton"
|
||||||
@ -87,10 +88,12 @@ func authenticator() func(c *gin.Context) (interface{}, error) {
|
|||||||
|
|
||||||
var user model.User
|
var user model.User
|
||||||
if err := singleton.DB.Select("id", "password").Where("username = ?", loginVals.Username).First(&user).Error; err != nil {
|
if err := singleton.DB.Select("id", "password").Where("username = ?", loginVals.Username).First(&user).Error; err != nil {
|
||||||
|
model.BlockIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.WAFBlockReasonTypeLoginFail)
|
||||||
return nil, jwt.ErrFailedAuthentication
|
return nil, jwt.ErrFailedAuthentication
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(loginVals.Password)); err != nil {
|
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(loginVals.Password)); err != nil {
|
||||||
|
model.BlockIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.WAFBlockReasonTypeLoginFail)
|
||||||
return nil, jwt.ErrFailedAuthentication
|
return nil, jwt.ErrFailedAuthentication
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -163,6 +166,10 @@ func optionalAuthMiddleware(mw *jwt.GinJWTMiddleware) func(c *gin.Context) {
|
|||||||
identity := mw.IdentityHandler(c)
|
identity := mw.IdentityHandler(c)
|
||||||
|
|
||||||
if identity != nil {
|
if identity != nil {
|
||||||
|
if err := model.BlockIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.WAFBlockReasonTypeBruteForceToken); err != nil {
|
||||||
|
waf.ShowBlockPage(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
c.Set(mw.IdentityKey, identity)
|
c.Set(mw.IdentityKey, identity)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
90
cmd/dashboard/controller/waf/waf.go
Normal file
90
cmd/dashboard/controller/waf/waf.go
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
package waf
|
||||||
|
|
||||||
|
import (
|
||||||
|
_ "embed"
|
||||||
|
"errors"
|
||||||
|
"log"
|
||||||
|
"math/big"
|
||||||
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/naiba/nezha/model"
|
||||||
|
"github.com/naiba/nezha/service/singleton"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed waf.html
|
||||||
|
var errorPageTemplate string
|
||||||
|
|
||||||
|
func RealIp(c *gin.Context) {
|
||||||
|
if singleton.Conf.RealIPHeader == "" {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if singleton.Conf.RealIPHeader == model.ConfigUsePeerIP {
|
||||||
|
c.Set(model.CtxKeyRealIPStr, c.RemoteIP())
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
vals := c.Request.Header.Get(singleton.Conf.RealIPHeader)
|
||||||
|
if vals == "" {
|
||||||
|
c.AbortWithStatusJSON(http.StatusOK, model.CommonResponse[any]{Success: false, Error: "real ip header not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ip, err := netip.ParseAddr(vals)
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusOK, model.CommonResponse[any]{Success: false, Error: err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Set(model.CtxKeyRealIPStr, ip.String())
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
|
||||||
|
func Waf(c *gin.Context) {
|
||||||
|
if singleton.Conf.RealIPHeader == "" {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
realipAddr := c.GetString(model.CtxKeyRealIPStr)
|
||||||
|
if realipAddr == "" {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var w model.WAF
|
||||||
|
if err := singleton.DB.First(&w, "ip = ?", realipAddr).Error; err != nil {
|
||||||
|
if err != gorm.ErrRecordNotFound {
|
||||||
|
ShowBlockPage(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
now := time.Now().Unix()
|
||||||
|
if w.LastBlockTimestamp+pow(w.Count, 4) > uint64(now) {
|
||||||
|
log.Println(w.Count, w.LastBlockTimestamp+pow(w.Count, 4)-uint64(now))
|
||||||
|
ShowBlockPage(c, errors.New("you are blocked by nezha WAF"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
|
||||||
|
func pow(x, y uint64) uint64 {
|
||||||
|
base := big.NewInt(0).SetUint64(x)
|
||||||
|
exp := big.NewInt(0).SetUint64(y)
|
||||||
|
result := big.NewInt(1)
|
||||||
|
result.Exp(base, exp, nil)
|
||||||
|
if !result.IsUint64() {
|
||||||
|
return ^uint64(0) // return max uint64 value on overflow
|
||||||
|
}
|
||||||
|
return result.Uint64()
|
||||||
|
}
|
||||||
|
|
||||||
|
func ShowBlockPage(c *gin.Context, err error) {
|
||||||
|
c.Writer.WriteHeader(http.StatusForbidden)
|
||||||
|
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||||
|
c.Writer.WriteString(strings.Replace(errorPageTemplate, "{error}", err.Error(), 1))
|
||||||
|
c.Abort()
|
||||||
|
}
|
39
cmd/dashboard/controller/waf/waf.html
Normal file
39
cmd/dashboard/controller/waf/waf.html
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>Blocked</title>
|
||||||
|
<style>
|
||||||
|
body {
|
||||||
|
display: flex;
|
||||||
|
justify-content: center;
|
||||||
|
align-items: center;
|
||||||
|
height: 90vh;
|
||||||
|
font-weight: bolder;
|
||||||
|
font-family: 'Courier New', Courier, monospace;
|
||||||
|
}
|
||||||
|
main {
|
||||||
|
text-align: center;
|
||||||
|
}
|
||||||
|
.emoji {
|
||||||
|
font-size: 200px;
|
||||||
|
}
|
||||||
|
p.secondary {
|
||||||
|
font-size: 12px;
|
||||||
|
color: #888;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
|
||||||
|
<body>
|
||||||
|
<main>
|
||||||
|
<div class="emoji">🤡</div>
|
||||||
|
<h1>Blocked</h1>
|
||||||
|
<p>{error}</p>
|
||||||
|
<p class="secondary">nezha WAF</p>
|
||||||
|
</main>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
</html>
|
29
cmd/dashboard/controller/waf/waf_test.go
Normal file
29
cmd/dashboard/controller/waf/waf_test.go
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
package waf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPow(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
x,
|
||||||
|
y,
|
||||||
|
expect uint64
|
||||||
|
}{
|
||||||
|
{2, 64, math.MaxUint64}, // 2 的 64 次方,超过 uint64 最大值
|
||||||
|
{uint64(1 << 63), 2, math.MaxUint64}, // 大数平方,可能溢出
|
||||||
|
{uint64(^uint64(0)), 2, math.MaxUint64}, // uint64 最大值的平方,溢出
|
||||||
|
{2, 3, 8},
|
||||||
|
{5, 0, 1},
|
||||||
|
{3, 1, 3},
|
||||||
|
{0, 5, 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
result := pow(tt.x, tt.y)
|
||||||
|
if result != tt.expect {
|
||||||
|
t.Errorf("pow(%d, %d) = %d; expect %d", tt.x, tt.y, result, tt.expect)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
@ -48,7 +49,9 @@ func getRealIp(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
|
|||||||
if len(vals) == 0 {
|
if len(vals) == 0 {
|
||||||
return nil, fmt.Errorf("real ip header not found")
|
return nil, fmt.Errorf("real ip header not found")
|
||||||
}
|
}
|
||||||
ip, err := netip.ParseAddr(vals[0])
|
a := strings.Split(vals[0], ",")
|
||||||
|
h := strings.TrimSpace(a[len(a)-1])
|
||||||
|
ip, err := netip.ParseAddr(h)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
40
model/waf.go
Normal file
40
model/waf.go
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
_ uint8 = iota
|
||||||
|
WAFBlockReasonTypeLoginFail
|
||||||
|
WAFBlockReasonTypeBruteForceToken
|
||||||
|
)
|
||||||
|
|
||||||
|
type WAF struct {
|
||||||
|
IP string `gorm:"type:binary(16);primaryKey" json:"ip,omitempty"`
|
||||||
|
Count uint64 `json:"count,omitempty"`
|
||||||
|
LastBlockReason uint8 `json:"last_block_reason,omitempty"`
|
||||||
|
LastBlockTimestamp uint64 `json:"last_block_timestamp,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *WAF) TableName() string {
|
||||||
|
return "waf"
|
||||||
|
}
|
||||||
|
|
||||||
|
func BlockIP(db *gorm.DB, ip string, reason uint8) error {
|
||||||
|
if ip == "" {
|
||||||
|
return errors.New("empty ip")
|
||||||
|
}
|
||||||
|
var w WAF
|
||||||
|
w.LastBlockReason = reason
|
||||||
|
w.LastBlockTimestamp = uint64(time.Now().Unix())
|
||||||
|
return db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
if err := tx.FirstOrCreate(&w, WAF{IP: ip}).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return tx.Exec("UPDATE waf SET count = count + 1 WHERE ip = ?", ip).Error
|
||||||
|
})
|
||||||
|
}
|
@ -2,7 +2,6 @@ package rpc
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"log"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
@ -25,9 +24,6 @@ func (a *authHandler) Check(ctx context.Context) (uint64, error) {
|
|||||||
return 0, status.Errorf(codes.Unauthenticated, "获取 metaData 失败")
|
return 0, status.Errorf(codes.Unauthenticated, "获取 metaData 失败")
|
||||||
}
|
}
|
||||||
|
|
||||||
realIp := ctx.Value(model.CtxKeyRealIP{})
|
|
||||||
log.Printf("bingo rpc realIp: %s, metadata: %v", realIp, md)
|
|
||||||
|
|
||||||
var clientSecret string
|
var clientSecret string
|
||||||
if value, ok := md["client_secret"]; ok {
|
if value, ok := md["client_secret"]; ok {
|
||||||
clientSecret = value[0]
|
clientSecret = value[0]
|
||||||
|
@ -65,7 +65,8 @@ func InitDBFromPath(path string) {
|
|||||||
err = DB.AutoMigrate(model.Server{}, model.User{}, model.ServerGroup{}, model.NotificationGroup{},
|
err = DB.AutoMigrate(model.Server{}, model.User{}, model.ServerGroup{}, model.NotificationGroup{},
|
||||||
model.Notification{}, model.AlertRule{}, model.Service{}, model.NotificationGroupNotification{},
|
model.Notification{}, model.AlertRule{}, model.Service{}, model.NotificationGroupNotification{},
|
||||||
model.ServiceHistory{}, model.Cron{}, model.Transfer{}, model.ServerGroupServer{}, model.UserGroup{},
|
model.ServiceHistory{}, model.Cron{}, model.Transfer{}, model.ServerGroupServer{}, model.UserGroup{},
|
||||||
model.UserGroupUser{}, model.NAT{}, model.DDNSProfile{}, model.NotificationGroupNotification{})
|
model.UserGroupUser{}, model.NAT{}, model.DDNSProfile{}, model.NotificationGroupNotification{},
|
||||||
|
model.WAF{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user