refactor nat

This commit is contained in:
naiba 2024-10-23 20:37:29 +08:00
parent 4635bcf44f
commit c9ec634857
4 changed files with 76 additions and 74 deletions

View File

@ -6,20 +6,15 @@ import (
"log" "log"
"net/http" "net/http"
"strings" "strings"
"time"
jwt "github.com/appleboy/gin-jwt/v2" jwt "github.com/appleboy/gin-jwt/v2"
"github.com/gin-contrib/pprof" "github.com/gin-contrib/pprof"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/hashicorp/go-uuid"
swaggerfiles "github.com/swaggo/files" swaggerfiles "github.com/swaggo/files"
ginSwagger "github.com/swaggo/gin-swagger" ginSwagger "github.com/swaggo/gin-swagger"
docs "github.com/naiba/nezha/cmd/dashboard/docs" docs "github.com/naiba/nezha/cmd/dashboard/docs"
"github.com/naiba/nezha/model" "github.com/naiba/nezha/model"
"github.com/naiba/nezha/pkg/utils"
"github.com/naiba/nezha/proto"
"github.com/naiba/nezha/service/rpc"
"github.com/naiba/nezha/service/singleton" "github.com/naiba/nezha/service/singleton"
) )
@ -31,7 +26,6 @@ func ServeWeb() http.Handler {
gin.SetMode(gin.DebugMode) gin.SetMode(gin.DebugMode)
pprof.Register(r) pprof.Register(r)
} }
r.Use(natGateway)
if singleton.Conf.Debug { if singleton.Conf.Debug {
r.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerfiles.Handler)) r.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerfiles.Handler))
} }
@ -83,67 +77,6 @@ func routers(r *gin.Engine) {
auth.POST("/batch-delete/ddns", commonHandler(batchDeleteDDNS)) auth.POST("/batch-delete/ddns", commonHandler(batchDeleteDDNS))
} }
func natGateway(c *gin.Context) {
natConfig := singleton.GetNATConfigByDomain(c.Request.Host)
if natConfig == nil {
return
}
singleton.ServerLock.RLock()
server := singleton.ServerList[natConfig.ServerID]
singleton.ServerLock.RUnlock()
if server == nil || server.TaskStream == nil {
c.Writer.WriteString("server not found or not connected")
c.Abort()
return
}
streamId, err := uuid.GenerateUUID()
if err != nil {
c.Writer.WriteString(fmt.Sprintf("stream id error: %v", err))
c.Abort()
return
}
rpc.NezhaHandlerSingleton.CreateStream(streamId)
defer rpc.NezhaHandlerSingleton.CloseStream(streamId)
taskData, err := utils.Json.Marshal(model.TaskNAT{
StreamID: streamId,
Host: natConfig.Host,
})
if err != nil {
c.Writer.WriteString(fmt.Sprintf("task data error: %v", err))
c.Abort()
return
}
if err := server.TaskStream.Send(&proto.Task{
Type: model.TaskTypeNAT,
Data: string(taskData),
}); err != nil {
c.Writer.WriteString(fmt.Sprintf("send task error: %v", err))
c.Abort()
return
}
w, err := utils.NewRequestWrapper(c.Request, c.Writer)
if err != nil {
c.Writer.WriteString(fmt.Sprintf("request wrapper error: %v", err))
c.Abort()
return
}
if err := rpc.NezhaHandlerSingleton.UserConnected(streamId, w); err != nil {
c.Writer.WriteString(fmt.Sprintf("user connected error: %v", err))
c.Abort()
return
}
rpc.NezhaHandlerSingleton.StartStream(streamId, time.Second*10)
c.Abort()
}
func recordPath(c *gin.Context) { func recordPath(c *gin.Context) {
url := c.Request.URL.String() url := c.Request.URL.String()
for _, p := range c.Params { for _, p := range c.Params {

View File

@ -159,6 +159,11 @@ func dispatchReportInfoTask() {
func newHTTPandGRPCMux(httpHandler http.Handler, grpcHandler http.Handler) http.Handler { func newHTTPandGRPCMux(httpHandler http.Handler, grpcHandler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
natConfig := singleton.GetNATConfigByDomain(r.Host)
if natConfig != nil {
rpc.ServeNAT(w, r, natConfig)
return
}
if r.ProtoMajor == 2 && r.Header.Get("Content-Type") == "application/grpc" && if r.ProtoMajor == 2 && r.Header.Get("Content-Type") == "application/grpc" &&
strings.HasPrefix(r.URL.Path, "/"+proto.NezhaService_ServiceDesc.ServiceName) { strings.HasPrefix(r.URL.Path, "/"+proto.NezhaService_ServiceDesc.ServiceName) {
grpcHandler.ServeHTTP(w, r) grpcHandler.ServeHTTP(w, r)

View File

@ -1,10 +1,16 @@
package rpc package rpc
import ( import (
"fmt"
"net/http"
"time"
"google.golang.org/grpc" "google.golang.org/grpc"
"github.com/hashicorp/go-uuid"
"github.com/naiba/nezha/model" "github.com/naiba/nezha/model"
pb "github.com/naiba/nezha/proto" "github.com/naiba/nezha/pkg/utils"
"github.com/naiba/nezha/proto"
rpcService "github.com/naiba/nezha/service/rpc" rpcService "github.com/naiba/nezha/service/rpc"
"github.com/naiba/nezha/service/singleton" "github.com/naiba/nezha/service/singleton"
) )
@ -12,7 +18,7 @@ import (
func ServeRPC() *grpc.Server { func ServeRPC() *grpc.Server {
server := grpc.NewServer() server := grpc.NewServer()
rpcService.NezhaHandlerSingleton = rpcService.NewNezhaHandler() rpcService.NezhaHandlerSingleton = rpcService.NewNezhaHandler()
pb.RegisterNezhaServiceServer(server, rpcService.NezhaHandlerSingleton) proto.RegisterNezhaServiceServer(server, rpcService.NezhaHandlerSingleton)
return server return server
} }
@ -69,7 +75,62 @@ func DispatchKeepalive() {
continue continue
} }
singleton.SortedServerList[i].TaskStream.Send(&pb.Task{Type: model.TaskTypeKeepalive}) singleton.SortedServerList[i].TaskStream.Send(&proto.Task{Type: model.TaskTypeKeepalive})
} }
}) })
} }
func ServeNAT(w http.ResponseWriter, r *http.Request, natConfig *model.NAT) {
singleton.ServerLock.RLock()
server := singleton.ServerList[natConfig.ServerID]
singleton.ServerLock.RUnlock()
if server == nil || server.TaskStream == nil {
w.WriteHeader(http.StatusServiceUnavailable)
w.Write([]byte("server not found or not connected"))
return
}
streamId, err := uuid.GenerateUUID()
if err != nil {
w.WriteHeader(http.StatusServiceUnavailable)
w.Write([]byte(fmt.Sprintf("stream id error: %v", err)))
return
}
rpcService.NezhaHandlerSingleton.CreateStream(streamId)
defer rpcService.NezhaHandlerSingleton.CloseStream(streamId)
taskData, err := utils.Json.Marshal(model.TaskNAT{
StreamID: streamId,
Host: natConfig.Host,
})
if err != nil {
w.WriteHeader(http.StatusServiceUnavailable)
w.Write([]byte(fmt.Sprintf("task data error: %v", err)))
return
}
if err := server.TaskStream.Send(&proto.Task{
Type: model.TaskTypeNAT,
Data: string(taskData),
}); err != nil {
w.WriteHeader(http.StatusServiceUnavailable)
w.Write([]byte(fmt.Sprintf("send task error: %v", err)))
return
}
wWrapped, err := utils.NewRequestWrapper(r, w)
if err != nil {
w.WriteHeader(http.StatusServiceUnavailable)
w.Write([]byte(fmt.Sprintf("request wrapper error: %v", err)))
return
}
if err := rpcService.NezhaHandlerSingleton.UserConnected(streamId, wWrapped); err != nil {
w.WriteHeader(http.StatusServiceUnavailable)
w.Write([]byte(fmt.Sprintf("user connected error: %v", err)))
return
}
rpcService.NezhaHandlerSingleton.StartStream(streamId, time.Second*10)
}

View File

@ -2,11 +2,10 @@ package utils
import ( import (
"bytes" "bytes"
"errors"
"io" "io"
"net" "net"
"net/http" "net/http"
"github.com/gin-gonic/gin"
) )
var _ io.ReadWriteCloser = &RequestWrapper{} var _ io.ReadWriteCloser = &RequestWrapper{}
@ -17,8 +16,12 @@ type RequestWrapper struct {
writer net.Conn writer net.Conn
} }
func NewRequestWrapper(req *http.Request, writer gin.ResponseWriter) (*RequestWrapper, error) { func NewRequestWrapper(req *http.Request, writer http.ResponseWriter) (*RequestWrapper, error) {
conn, _, err := writer.Hijack() hj, ok := writer.(http.Hijacker)
if !ok {
return nil, errors.New("http server does not support hijacking")
}
conn, _, err := hj.Hijack()
if err != nil { if err != nil {
return nil, err return nil, err
} }