mirror of
https://github.com/nezhahq/nezha.git
synced 2025-01-22 12:48:14 -05:00
refactor nat
This commit is contained in:
parent
4635bcf44f
commit
c9ec634857
@ -6,20 +6,15 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
jwt "github.com/appleboy/gin-jwt/v2"
|
jwt "github.com/appleboy/gin-jwt/v2"
|
||||||
"github.com/gin-contrib/pprof"
|
"github.com/gin-contrib/pprof"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/hashicorp/go-uuid"
|
|
||||||
swaggerfiles "github.com/swaggo/files"
|
swaggerfiles "github.com/swaggo/files"
|
||||||
ginSwagger "github.com/swaggo/gin-swagger"
|
ginSwagger "github.com/swaggo/gin-swagger"
|
||||||
|
|
||||||
docs "github.com/naiba/nezha/cmd/dashboard/docs"
|
docs "github.com/naiba/nezha/cmd/dashboard/docs"
|
||||||
"github.com/naiba/nezha/model"
|
"github.com/naiba/nezha/model"
|
||||||
"github.com/naiba/nezha/pkg/utils"
|
|
||||||
"github.com/naiba/nezha/proto"
|
|
||||||
"github.com/naiba/nezha/service/rpc"
|
|
||||||
"github.com/naiba/nezha/service/singleton"
|
"github.com/naiba/nezha/service/singleton"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -31,7 +26,6 @@ func ServeWeb() http.Handler {
|
|||||||
gin.SetMode(gin.DebugMode)
|
gin.SetMode(gin.DebugMode)
|
||||||
pprof.Register(r)
|
pprof.Register(r)
|
||||||
}
|
}
|
||||||
r.Use(natGateway)
|
|
||||||
if singleton.Conf.Debug {
|
if singleton.Conf.Debug {
|
||||||
r.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerfiles.Handler))
|
r.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerfiles.Handler))
|
||||||
}
|
}
|
||||||
@ -83,67 +77,6 @@ func routers(r *gin.Engine) {
|
|||||||
auth.POST("/batch-delete/ddns", commonHandler(batchDeleteDDNS))
|
auth.POST("/batch-delete/ddns", commonHandler(batchDeleteDDNS))
|
||||||
}
|
}
|
||||||
|
|
||||||
func natGateway(c *gin.Context) {
|
|
||||||
natConfig := singleton.GetNATConfigByDomain(c.Request.Host)
|
|
||||||
if natConfig == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
singleton.ServerLock.RLock()
|
|
||||||
server := singleton.ServerList[natConfig.ServerID]
|
|
||||||
singleton.ServerLock.RUnlock()
|
|
||||||
if server == nil || server.TaskStream == nil {
|
|
||||||
c.Writer.WriteString("server not found or not connected")
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
streamId, err := uuid.GenerateUUID()
|
|
||||||
if err != nil {
|
|
||||||
c.Writer.WriteString(fmt.Sprintf("stream id error: %v", err))
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
rpc.NezhaHandlerSingleton.CreateStream(streamId)
|
|
||||||
defer rpc.NezhaHandlerSingleton.CloseStream(streamId)
|
|
||||||
|
|
||||||
taskData, err := utils.Json.Marshal(model.TaskNAT{
|
|
||||||
StreamID: streamId,
|
|
||||||
Host: natConfig.Host,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
c.Writer.WriteString(fmt.Sprintf("task data error: %v", err))
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := server.TaskStream.Send(&proto.Task{
|
|
||||||
Type: model.TaskTypeNAT,
|
|
||||||
Data: string(taskData),
|
|
||||||
}); err != nil {
|
|
||||||
c.Writer.WriteString(fmt.Sprintf("send task error: %v", err))
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
w, err := utils.NewRequestWrapper(c.Request, c.Writer)
|
|
||||||
if err != nil {
|
|
||||||
c.Writer.WriteString(fmt.Sprintf("request wrapper error: %v", err))
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := rpc.NezhaHandlerSingleton.UserConnected(streamId, w); err != nil {
|
|
||||||
c.Writer.WriteString(fmt.Sprintf("user connected error: %v", err))
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
rpc.NezhaHandlerSingleton.StartStream(streamId, time.Second*10)
|
|
||||||
c.Abort()
|
|
||||||
}
|
|
||||||
|
|
||||||
func recordPath(c *gin.Context) {
|
func recordPath(c *gin.Context) {
|
||||||
url := c.Request.URL.String()
|
url := c.Request.URL.String()
|
||||||
for _, p := range c.Params {
|
for _, p := range c.Params {
|
||||||
|
@ -159,6 +159,11 @@ func dispatchReportInfoTask() {
|
|||||||
|
|
||||||
func newHTTPandGRPCMux(httpHandler http.Handler, grpcHandler http.Handler) http.Handler {
|
func newHTTPandGRPCMux(httpHandler http.Handler, grpcHandler http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
natConfig := singleton.GetNATConfigByDomain(r.Host)
|
||||||
|
if natConfig != nil {
|
||||||
|
rpc.ServeNAT(w, r, natConfig)
|
||||||
|
return
|
||||||
|
}
|
||||||
if r.ProtoMajor == 2 && r.Header.Get("Content-Type") == "application/grpc" &&
|
if r.ProtoMajor == 2 && r.Header.Get("Content-Type") == "application/grpc" &&
|
||||||
strings.HasPrefix(r.URL.Path, "/"+proto.NezhaService_ServiceDesc.ServiceName) {
|
strings.HasPrefix(r.URL.Path, "/"+proto.NezhaService_ServiceDesc.ServiceName) {
|
||||||
grpcHandler.ServeHTTP(w, r)
|
grpcHandler.ServeHTTP(w, r)
|
||||||
|
@ -1,10 +1,16 @@
|
|||||||
package rpc
|
package rpc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-uuid"
|
||||||
"github.com/naiba/nezha/model"
|
"github.com/naiba/nezha/model"
|
||||||
pb "github.com/naiba/nezha/proto"
|
"github.com/naiba/nezha/pkg/utils"
|
||||||
|
"github.com/naiba/nezha/proto"
|
||||||
rpcService "github.com/naiba/nezha/service/rpc"
|
rpcService "github.com/naiba/nezha/service/rpc"
|
||||||
"github.com/naiba/nezha/service/singleton"
|
"github.com/naiba/nezha/service/singleton"
|
||||||
)
|
)
|
||||||
@ -12,7 +18,7 @@ import (
|
|||||||
func ServeRPC() *grpc.Server {
|
func ServeRPC() *grpc.Server {
|
||||||
server := grpc.NewServer()
|
server := grpc.NewServer()
|
||||||
rpcService.NezhaHandlerSingleton = rpcService.NewNezhaHandler()
|
rpcService.NezhaHandlerSingleton = rpcService.NewNezhaHandler()
|
||||||
pb.RegisterNezhaServiceServer(server, rpcService.NezhaHandlerSingleton)
|
proto.RegisterNezhaServiceServer(server, rpcService.NezhaHandlerSingleton)
|
||||||
return server
|
return server
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -69,7 +75,62 @@ func DispatchKeepalive() {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
singleton.SortedServerList[i].TaskStream.Send(&pb.Task{Type: model.TaskTypeKeepalive})
|
singleton.SortedServerList[i].TaskStream.Send(&proto.Task{Type: model.TaskTypeKeepalive})
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ServeNAT(w http.ResponseWriter, r *http.Request, natConfig *model.NAT) {
|
||||||
|
singleton.ServerLock.RLock()
|
||||||
|
server := singleton.ServerList[natConfig.ServerID]
|
||||||
|
singleton.ServerLock.RUnlock()
|
||||||
|
if server == nil || server.TaskStream == nil {
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
w.Write([]byte("server not found or not connected"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
streamId, err := uuid.GenerateUUID()
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
w.Write([]byte(fmt.Sprintf("stream id error: %v", err)))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rpcService.NezhaHandlerSingleton.CreateStream(streamId)
|
||||||
|
defer rpcService.NezhaHandlerSingleton.CloseStream(streamId)
|
||||||
|
|
||||||
|
taskData, err := utils.Json.Marshal(model.TaskNAT{
|
||||||
|
StreamID: streamId,
|
||||||
|
Host: natConfig.Host,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
w.Write([]byte(fmt.Sprintf("task data error: %v", err)))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := server.TaskStream.Send(&proto.Task{
|
||||||
|
Type: model.TaskTypeNAT,
|
||||||
|
Data: string(taskData),
|
||||||
|
}); err != nil {
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
w.Write([]byte(fmt.Sprintf("send task error: %v", err)))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
wWrapped, err := utils.NewRequestWrapper(r, w)
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
w.Write([]byte(fmt.Sprintf("request wrapper error: %v", err)))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rpcService.NezhaHandlerSingleton.UserConnected(streamId, wWrapped); err != nil {
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
w.Write([]byte(fmt.Sprintf("user connected error: %v", err)))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rpcService.NezhaHandlerSingleton.StartStream(streamId, time.Second*10)
|
||||||
|
}
|
||||||
|
@ -2,11 +2,10 @@ package utils
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ io.ReadWriteCloser = &RequestWrapper{}
|
var _ io.ReadWriteCloser = &RequestWrapper{}
|
||||||
@ -17,8 +16,12 @@ type RequestWrapper struct {
|
|||||||
writer net.Conn
|
writer net.Conn
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRequestWrapper(req *http.Request, writer gin.ResponseWriter) (*RequestWrapper, error) {
|
func NewRequestWrapper(req *http.Request, writer http.ResponseWriter) (*RequestWrapper, error) {
|
||||||
conn, _, err := writer.Hijack()
|
hj, ok := writer.(http.Hijacker)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("http server does not support hijacking")
|
||||||
|
}
|
||||||
|
conn, _, err := hj.Hijack()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user