From c9ec6348572600b26233027eba2f501821daef2f Mon Sep 17 00:00:00 2001 From: naiba Date: Wed, 23 Oct 2024 20:37:29 +0800 Subject: [PATCH] refactor nat --- cmd/dashboard/controller/controller.go | 67 -------------------------- cmd/dashboard/main.go | 5 ++ cmd/dashboard/rpc/rpc.go | 67 ++++++++++++++++++++++++-- pkg/utils/request_wrapper.go | 11 +++-- 4 files changed, 76 insertions(+), 74 deletions(-) diff --git a/cmd/dashboard/controller/controller.go b/cmd/dashboard/controller/controller.go index 2a8508e..1195a02 100644 --- a/cmd/dashboard/controller/controller.go +++ b/cmd/dashboard/controller/controller.go @@ -6,20 +6,15 @@ import ( "log" "net/http" "strings" - "time" jwt "github.com/appleboy/gin-jwt/v2" "github.com/gin-contrib/pprof" "github.com/gin-gonic/gin" - "github.com/hashicorp/go-uuid" swaggerfiles "github.com/swaggo/files" ginSwagger "github.com/swaggo/gin-swagger" docs "github.com/naiba/nezha/cmd/dashboard/docs" "github.com/naiba/nezha/model" - "github.com/naiba/nezha/pkg/utils" - "github.com/naiba/nezha/proto" - "github.com/naiba/nezha/service/rpc" "github.com/naiba/nezha/service/singleton" ) @@ -31,7 +26,6 @@ func ServeWeb() http.Handler { gin.SetMode(gin.DebugMode) pprof.Register(r) } - r.Use(natGateway) if singleton.Conf.Debug { r.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerfiles.Handler)) } @@ -83,67 +77,6 @@ func routers(r *gin.Engine) { auth.POST("/batch-delete/ddns", commonHandler(batchDeleteDDNS)) } -func natGateway(c *gin.Context) { - natConfig := singleton.GetNATConfigByDomain(c.Request.Host) - if natConfig == nil { - return - } - - singleton.ServerLock.RLock() - server := singleton.ServerList[natConfig.ServerID] - singleton.ServerLock.RUnlock() - if server == nil || server.TaskStream == nil { - c.Writer.WriteString("server not found or not connected") - c.Abort() - return - } - - streamId, err := uuid.GenerateUUID() - if err != nil { - c.Writer.WriteString(fmt.Sprintf("stream id error: %v", err)) - c.Abort() - return - } - - rpc.NezhaHandlerSingleton.CreateStream(streamId) - defer rpc.NezhaHandlerSingleton.CloseStream(streamId) - - taskData, err := utils.Json.Marshal(model.TaskNAT{ - StreamID: streamId, - Host: natConfig.Host, - }) - if err != nil { - c.Writer.WriteString(fmt.Sprintf("task data error: %v", err)) - c.Abort() - return - } - - if err := server.TaskStream.Send(&proto.Task{ - Type: model.TaskTypeNAT, - Data: string(taskData), - }); err != nil { - c.Writer.WriteString(fmt.Sprintf("send task error: %v", err)) - c.Abort() - return - } - - w, err := utils.NewRequestWrapper(c.Request, c.Writer) - if err != nil { - c.Writer.WriteString(fmt.Sprintf("request wrapper error: %v", err)) - c.Abort() - return - } - - if err := rpc.NezhaHandlerSingleton.UserConnected(streamId, w); err != nil { - c.Writer.WriteString(fmt.Sprintf("user connected error: %v", err)) - c.Abort() - return - } - - rpc.NezhaHandlerSingleton.StartStream(streamId, time.Second*10) - c.Abort() -} - func recordPath(c *gin.Context) { url := c.Request.URL.String() for _, p := range c.Params { diff --git a/cmd/dashboard/main.go b/cmd/dashboard/main.go index 6822a50..6f2d94a 100644 --- a/cmd/dashboard/main.go +++ b/cmd/dashboard/main.go @@ -159,6 +159,11 @@ func dispatchReportInfoTask() { func newHTTPandGRPCMux(httpHandler http.Handler, grpcHandler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + natConfig := singleton.GetNATConfigByDomain(r.Host) + if natConfig != nil { + rpc.ServeNAT(w, r, natConfig) + return + } if r.ProtoMajor == 2 && r.Header.Get("Content-Type") == "application/grpc" && strings.HasPrefix(r.URL.Path, "/"+proto.NezhaService_ServiceDesc.ServiceName) { grpcHandler.ServeHTTP(w, r) diff --git a/cmd/dashboard/rpc/rpc.go b/cmd/dashboard/rpc/rpc.go index b55d9b9..55689ec 100644 --- a/cmd/dashboard/rpc/rpc.go +++ b/cmd/dashboard/rpc/rpc.go @@ -1,10 +1,16 @@ package rpc import ( + "fmt" + "net/http" + "time" + "google.golang.org/grpc" + "github.com/hashicorp/go-uuid" "github.com/naiba/nezha/model" - pb "github.com/naiba/nezha/proto" + "github.com/naiba/nezha/pkg/utils" + "github.com/naiba/nezha/proto" rpcService "github.com/naiba/nezha/service/rpc" "github.com/naiba/nezha/service/singleton" ) @@ -12,7 +18,7 @@ import ( func ServeRPC() *grpc.Server { server := grpc.NewServer() rpcService.NezhaHandlerSingleton = rpcService.NewNezhaHandler() - pb.RegisterNezhaServiceServer(server, rpcService.NezhaHandlerSingleton) + proto.RegisterNezhaServiceServer(server, rpcService.NezhaHandlerSingleton) return server } @@ -69,7 +75,62 @@ func DispatchKeepalive() { continue } - singleton.SortedServerList[i].TaskStream.Send(&pb.Task{Type: model.TaskTypeKeepalive}) + singleton.SortedServerList[i].TaskStream.Send(&proto.Task{Type: model.TaskTypeKeepalive}) } }) } + +func ServeNAT(w http.ResponseWriter, r *http.Request, natConfig *model.NAT) { + singleton.ServerLock.RLock() + server := singleton.ServerList[natConfig.ServerID] + singleton.ServerLock.RUnlock() + if server == nil || server.TaskStream == nil { + w.WriteHeader(http.StatusServiceUnavailable) + w.Write([]byte("server not found or not connected")) + return + } + + streamId, err := uuid.GenerateUUID() + if err != nil { + w.WriteHeader(http.StatusServiceUnavailable) + w.Write([]byte(fmt.Sprintf("stream id error: %v", err))) + return + } + + rpcService.NezhaHandlerSingleton.CreateStream(streamId) + defer rpcService.NezhaHandlerSingleton.CloseStream(streamId) + + taskData, err := utils.Json.Marshal(model.TaskNAT{ + StreamID: streamId, + Host: natConfig.Host, + }) + if err != nil { + w.WriteHeader(http.StatusServiceUnavailable) + w.Write([]byte(fmt.Sprintf("task data error: %v", err))) + return + } + + if err := server.TaskStream.Send(&proto.Task{ + Type: model.TaskTypeNAT, + Data: string(taskData), + }); err != nil { + w.WriteHeader(http.StatusServiceUnavailable) + w.Write([]byte(fmt.Sprintf("send task error: %v", err))) + return + } + + wWrapped, err := utils.NewRequestWrapper(r, w) + if err != nil { + w.WriteHeader(http.StatusServiceUnavailable) + w.Write([]byte(fmt.Sprintf("request wrapper error: %v", err))) + return + } + + if err := rpcService.NezhaHandlerSingleton.UserConnected(streamId, wWrapped); err != nil { + w.WriteHeader(http.StatusServiceUnavailable) + w.Write([]byte(fmt.Sprintf("user connected error: %v", err))) + return + } + + rpcService.NezhaHandlerSingleton.StartStream(streamId, time.Second*10) +} diff --git a/pkg/utils/request_wrapper.go b/pkg/utils/request_wrapper.go index 9699129..e4f6103 100644 --- a/pkg/utils/request_wrapper.go +++ b/pkg/utils/request_wrapper.go @@ -2,11 +2,10 @@ package utils import ( "bytes" + "errors" "io" "net" "net/http" - - "github.com/gin-gonic/gin" ) var _ io.ReadWriteCloser = &RequestWrapper{} @@ -17,8 +16,12 @@ type RequestWrapper struct { writer net.Conn } -func NewRequestWrapper(req *http.Request, writer gin.ResponseWriter) (*RequestWrapper, error) { - conn, _, err := writer.Hijack() +func NewRequestWrapper(req *http.Request, writer http.ResponseWriter) (*RequestWrapper, error) { + hj, ok := writer.(http.Hijacker) + if !ok { + return nil, errors.New("http server does not support hijacking") + } + conn, _, err := hj.Hijack() if err != nil { return nil, err }