diff --git a/cmd/dashboard/controller/common_page.go b/cmd/dashboard/controller/common_page.go index 7719115..8182d50 100644 --- a/cmd/dashboard/controller/common_page.go +++ b/cmd/dashboard/controller/common_page.go @@ -3,12 +3,8 @@ package controller import ( "errors" "fmt" - "log" "net/http" - "regexp" "strconv" - "strings" - "sync" "time" "github.com/gin-gonic/gin" @@ -24,22 +20,13 @@ import ( "github.com/naiba/nezha/pkg/utils" "github.com/naiba/nezha/pkg/websocketx" "github.com/naiba/nezha/proto" + "github.com/naiba/nezha/service/rpc" "github.com/naiba/nezha/service/singleton" ) -type terminalContext struct { - agentConn *websocketx.Conn - userConn *websocketx.Conn - serverID uint64 - host string - useSSL bool -} - type commonPage struct { - r *gin.Engine - terminals map[string]*terminalContext - terminalsLock *sync.Mutex - requestGroup singleflight.Group + r *gin.Engine + requestGroup singleflight.Group } func (cp *commonPage) serve() { @@ -68,7 +55,6 @@ type viewPasswordForm struct { func (p *commonPage) issueViewPassword(c *gin.Context) { var vpf viewPasswordForm err := c.ShouldBind(&vpf) - log.Println("bingo", vpf) var hash []byte if err == nil && vpf.Password != singleton.Conf.Site.ViewPassword { err = errors.New(singleton.Localizer.MustLocalize(&i18n.LocalizeConfig{MessageID: "WrongAccessPassword"})) @@ -283,8 +269,6 @@ type Data struct { Servers []*model.Server `json:"servers,omitempty"` } -var cloudflareCookiesValidator = regexp.MustCompile("^[A-Za-z0-9-_]+$") - func (cp *commonPage) ws(c *gin.Context) { conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { @@ -322,9 +306,7 @@ func (cp *commonPage) ws(c *gin.Context) { func (cp *commonPage) terminal(c *gin.Context) { terminalID := c.Param("id") - cp.terminalsLock.Lock() - if terminalID == "" || cp.terminals[terminalID] == nil { - cp.terminalsLock.Unlock() + if _, err := rpc.NezhaHandlerSingleton.GetStream(terminalID); err != nil { mygin.ShowErrorPage(c, mygin.ErrInfo{ Code: http.StatusForbidden, Title: "无权访问", @@ -334,104 +316,7 @@ func (cp *commonPage) terminal(c *gin.Context) { }, true) return } - - terminal := cp.terminals[terminalID] - cp.terminalsLock.Unlock() - - defer func() { - // 清理 context - cp.terminalsLock.Lock() - defer cp.terminalsLock.Unlock() - delete(cp.terminals, terminalID) - }() - - var isAgent bool - - if _, authorized := c.Get(model.CtxKeyAuthorizedUser); !authorized { - singleton.ServerLock.RLock() - _, hasID := singleton.SecretToID[c.Request.Header.Get("Secret")] - singleton.ServerLock.RUnlock() - if !hasID { - mygin.ShowErrorPage(c, mygin.ErrInfo{ - Code: http.StatusForbidden, - Title: "无权访问", - Msg: "用户未登录或非法终端", - Link: "/", - Btn: "返回首页", - }, true) - return - } - if terminal.userConn == nil { - mygin.ShowErrorPage(c, mygin.ErrInfo{ - Code: http.StatusForbidden, - Title: "无权访问", - Msg: "用户不在线", - Link: "/", - Btn: "返回首页", - }, true) - return - } - if terminal.agentConn != nil { - mygin.ShowErrorPage(c, mygin.ErrInfo{ - Code: http.StatusInternalServerError, - Title: "连接已存在", - Msg: "Websocket协议切换失败", - Link: "/", - Btn: "返回首页", - }, true) - return - } - isAgent = true - } else { - singleton.ServerLock.RLock() - server := singleton.ServerList[terminal.serverID] - singleton.ServerLock.RUnlock() - if server == nil || server.TaskStream == nil { - mygin.ShowErrorPage(c, mygin.ErrInfo{ - Code: http.StatusForbidden, - Title: "请求失败", - Msg: "服务器不存在或处于离线状态", - Link: "/server", - Btn: "返回重试", - }, true) - return - } - cloudflareCookies, _ := c.Cookie("CF_Authorization") - // Cloudflare Cookies 合法性验证 - // 其应该包含.分隔的三组BASE64-URL编码 - if cloudflareCookies != "" { - encodedCookies := strings.Split(cloudflareCookies, ".") - if len(encodedCookies) == 3 { - for i := 0; i < 3; i++ { - if !cloudflareCookiesValidator.MatchString(encodedCookies[i]) { - cloudflareCookies = "" - break - } - } - } else { - cloudflareCookies = "" - } - } - terminalData, _ := utils.Json.Marshal(&model.TerminalTask{ - Host: terminal.host, - UseSSL: terminal.useSSL, - Session: terminalID, - Cookie: cloudflareCookies, - }) - if err := server.TaskStream.Send(&proto.Task{ - Type: model.TaskTypeTerminal, - Data: string(terminalData), - }); err != nil { - mygin.ShowErrorPage(c, mygin.ErrInfo{ - Code: http.StatusForbidden, - Title: "请求失败", - Msg: "Agent信令下发失败", - Link: "/server", - Btn: "返回重试", - }, true) - return - } - } + defer rpc.NezhaHandlerSingleton.CloseStream(terminalID) wsConn, err := upgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { @@ -447,36 +332,7 @@ func (cp *commonPage) terminal(c *gin.Context) { return } defer wsConn.Close() - conn := &websocketx.Conn{Conn: wsConn} - - log.Printf("NEZHA>> terminal connected %t %q", isAgent, c.Request.URL) - defer log.Printf("NEZHA>> terminal disconnected %t %q", isAgent, c.Request.URL) - - if isAgent { - terminal.agentConn = conn - defer func() { - // Agent断开链接时断开用户连接 - if terminal.userConn != nil { - terminal.userConn.Close() - } - }() - } else { - terminal.userConn = conn - defer func() { - // 用户断开链接时断开 Agent 连接 - if terminal.agentConn != nil { - terminal.agentConn.Close() - } - }() - } - - deadlineCh := make(chan interface{}) - go func() { - // 对方连接超时 - connectDeadline := time.NewTimer(time.Second * 15) - <-connectDeadline.C - deadlineCh <- struct{}{} - }() + conn := websocketx.NewConn(wsConn) go func() { // PING 保活 @@ -488,58 +344,11 @@ func (cp *commonPage) terminal(c *gin.Context) { } }() - dataCh := make(chan []byte) - errorCh := make(chan error) - go func() { - for { - msgType, data, err := conn.ReadMessage() - if err != nil { - errorCh <- err - return - } - // 将文本消息转换为命令输入 - if msgType == websocket.TextMessage { - data = append([]byte{0}, data...) - } - dataCh <- data - } - }() - - var dataBuffer [][]byte - var distConn *websocketx.Conn - checkDistConn := func() { - if distConn == nil { - if isAgent { - distConn = terminal.userConn - } else { - distConn = terminal.agentConn - } - } + if err = rpc.NezhaHandlerSingleton.UserConnected(terminalID, conn); err != nil { + return } - for { - select { - case <-deadlineCh: - checkDistConn() - if distConn == nil { - return - } - case <-errorCh: - return - case data := <-dataCh: - dataBuffer = append(dataBuffer, data) - checkDistConn() - if distConn != nil { - for i := 0; i < len(dataBuffer); i++ { - err = distConn.WriteMessage(websocket.BinaryMessage, dataBuffer[i]) - if err != nil { - return - } - } - dataBuffer = dataBuffer[:0] - } - } - } + rpc.NezhaHandlerSingleton.StartStream(terminalID, time.Second*10) } type createTerminalRequest struct { @@ -585,6 +394,8 @@ func (cp *commonPage) createTerminal(c *gin.Context) { return } + rpc.NezhaHandlerSingleton.CreateStream(id) + singleton.ServerLock.RLock() server := singleton.ServerList[createTerminalReq.ID] singleton.ServerLock.RUnlock() @@ -599,13 +410,21 @@ func (cp *commonPage) createTerminal(c *gin.Context) { return } - cp.terminalsLock.Lock() - defer cp.terminalsLock.Unlock() - - cp.terminals[id] = &terminalContext{ - serverID: createTerminalReq.ID, - host: createTerminalReq.Host, - useSSL: createTerminalReq.Protocol == "https:", + terminalData, _ := utils.Json.Marshal(&model.TerminalTask{ + StreamID: id, + }) + if err := server.TaskStream.Send(&proto.Task{ + Type: model.TaskTypeTerminalGRPC, + Data: string(terminalData), + }); err != nil { + mygin.ShowErrorPage(c, mygin.ErrInfo{ + Code: http.StatusForbidden, + Title: "请求失败", + Msg: "Agent信令下发失败", + Link: "/server", + Btn: "返回重试", + }, true) + return } c.HTML(http.StatusOK, "dashboard-"+singleton.Conf.Site.DashboardTheme+"/terminal", mygin.CommonEnvironment(c, gin.H{ diff --git a/cmd/dashboard/controller/controller.go b/cmd/dashboard/controller/controller.go index ece1d56..8ede497 100644 --- a/cmd/dashboard/controller/controller.go +++ b/cmd/dashboard/controller/controller.go @@ -9,7 +9,6 @@ import ( "os" "strconv" "strings" - "sync" "time" "code.cloudfoundry.org/bytefmt" @@ -68,7 +67,7 @@ func ServeWeb(port uint) *http.Server { func routers(r *gin.Engine) { // 通用页面 - cp := commonPage{r: r, terminals: make(map[string]*terminalContext), terminalsLock: new(sync.Mutex)} + cp := commonPage{r: r} cp.serve() // 游客页面 gp := guestPage{r} diff --git a/cmd/dashboard/main.go b/cmd/dashboard/main.go index db5dce8..2db7fed 100644 --- a/cmd/dashboard/main.go +++ b/cmd/dashboard/main.go @@ -59,6 +59,7 @@ func main() { return } + // TODO 使用 cmux 在同一端口服务 HTTP 和 gRPC singleton.CleanMonitorHistory() go rpc.ServeRPC(singleton.Conf.GRPCPort) serviceSentinelDispatchBus := make(chan model.Monitor) // 用于传递服务监控任务信息的channel diff --git a/cmd/dashboard/rpc/rpc.go b/cmd/dashboard/rpc/rpc.go index 540c4ad..5d013d7 100644 --- a/cmd/dashboard/rpc/rpc.go +++ b/cmd/dashboard/rpc/rpc.go @@ -14,9 +14,8 @@ import ( func ServeRPC(port uint) { server := grpc.NewServer() - pb.RegisterNezhaServiceServer(server, &rpcService.NezhaHandler{ - Auth: &rpcService.AuthHandler{}, - }) + rpcService.NezhaHandlerSingleton = rpcService.NewNezhaHandler() + pb.RegisterNezhaServiceServer(server, rpcService.NezhaHandlerSingleton) listen, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) if err != nil { panic(err) diff --git a/go.mod b/go.mod index 8d9438b..679494b 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,6 @@ require ( github.com/ory/graceful v0.1.3 github.com/patrickmn/go-cache v2.1.0+incompatible github.com/robfig/cron/v3 v3.0.1 - github.com/samber/lo v1.39.0 github.com/spf13/pflag v1.0.5 github.com/spf13/viper v1.18.2 github.com/xanzy/go-gitlab v0.103.0 @@ -28,7 +27,7 @@ require ( golang.org/x/sync v0.7.0 golang.org/x/text v0.16.0 google.golang.org/grpc v1.63.0 - google.golang.org/protobuf v1.33.0 + google.golang.org/protobuf v1.34.2 gorm.io/driver/sqlite v1.5.5 gorm.io/gorm v1.25.10 sigs.k8s.io/yaml v1.4.0 diff --git a/go.sum b/go.sum index 26b1d67..933d897 100644 --- a/go.sum +++ b/go.sum @@ -152,8 +152,6 @@ github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6ke github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4= github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ= -github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA= -github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA= github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= @@ -240,8 +238,8 @@ google.golang.org/grpc v1.63.0 h1:WjKe+dnvABXyPJMD7KDNLxtoGk5tgk+YFWN6cBWjZE8= google.golang.org/grpc v1.63.0/go.mod h1:WAX/8DgncnokcFUldAxq7GeB5DXHDbMF+lLvDomNkRA= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= -google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= +google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= diff --git a/model/monitor.go b/model/monitor.go index 69f3ad2..b858593 100644 --- a/model/monitor.go +++ b/model/monitor.go @@ -20,17 +20,11 @@ const ( TaskTypeTerminal TaskTypeUpgrade TaskTypeKeepalive + TaskTypeTerminalGRPC ) type TerminalTask struct { - // websocket 主机名 - Host string `json:"host,omitempty"` - // 是否启用 SSL - UseSSL bool `json:"use_ssl,omitempty"` - // 会话标识 - Session string `json:"session,omitempty"` - // Agent在连接Server时需要的额外Cookie信息 - Cookie string `json:"cookie,omitempty"` + StreamID string } const ( diff --git a/pkg/grpcx/io_stream_wrapper.go b/pkg/grpcx/io_stream_wrapper.go new file mode 100644 index 0000000..72855c4 --- /dev/null +++ b/pkg/grpcx/io_stream_wrapper.go @@ -0,0 +1,65 @@ +package grpcx + +import ( + "context" + "io" + "sync/atomic" + + "github.com/naiba/nezha/proto" +) + +var _ io.ReadWriteCloser = &IOStreamWrapper{} + +type IOStream interface { + Recv() (*proto.IOStreamData, error) + Send(*proto.IOStreamData) error + Context() context.Context +} + +type IOStreamWrapper struct { + IOStream + dataBuf []byte + closed *atomic.Bool + closeCh chan struct{} +} + +func NewIOStreamWrapper(stream IOStream) *IOStreamWrapper { + return &IOStreamWrapper{ + IOStream: stream, + closeCh: make(chan struct{}), + closed: new(atomic.Bool), + } +} + +func (iw *IOStreamWrapper) Read(p []byte) (n int, err error) { + if len(iw.dataBuf) > 0 { + n := copy(p, iw.dataBuf) + iw.dataBuf = iw.dataBuf[n:] + return n, nil + } + var data *proto.IOStreamData + if data, err = iw.Recv(); err != nil { + return 0, err + } + n = copy(p, data.Data) + if n < len(data.Data) { + iw.dataBuf = data.Data[n:] + } + return n, nil +} + +func (iw *IOStreamWrapper) Write(p []byte) (n int, err error) { + err = iw.Send(&proto.IOStreamData{Data: p}) + return len(p), err +} + +func (iw *IOStreamWrapper) Close() error { + if iw.closed.CompareAndSwap(false, true) { + close(iw.closeCh) + } + return nil +} + +func (iw *IOStreamWrapper) Wait() { + <-iw.closeCh +} diff --git a/pkg/websocketx/safe_conn.go b/pkg/websocketx/safe_conn.go index f442c12..3b45c75 100644 --- a/pkg/websocketx/safe_conn.go +++ b/pkg/websocketx/safe_conn.go @@ -4,22 +4,44 @@ import ( "sync" "github.com/gorilla/websocket" - "github.com/samber/lo" ) type Conn struct { *websocket.Conn - writeLock sync.Mutex + writeLock *sync.Mutex + dataBuf []byte } -func (conn *Conn) WriteMessage(msgType int, data []byte) error { +func NewConn(conn *websocket.Conn) *Conn { + return &Conn{Conn: conn, writeLock: new(sync.Mutex)} +} + +func (conn *Conn) Write(data []byte) (int, error) { conn.writeLock.Lock() defer conn.writeLock.Unlock() - var err error - lo.TryCatchWithErrorValue(func() error { - return conn.Conn.WriteMessage(msgType, data) - }, func(res any) { - err = res.(error) - }) - return err + if err := conn.Conn.WriteMessage(websocket.BinaryMessage, data); err != nil { + return 0, err + } + return len(data), nil +} + +func (conn *Conn) Read(data []byte) (int, error) { + if len(conn.dataBuf) > 0 { + n := copy(data, conn.dataBuf) + conn.dataBuf = conn.dataBuf[n:] + return n, nil + } + mType, innerData, err := conn.Conn.ReadMessage() + if err != nil { + return 0, err + } + // 将文本消息转换为命令输入 + if mType == websocket.TextMessage { + innerData = append([]byte{0}, innerData...) + } + n := copy(data, innerData) + if n < len(innerData) { + conn.dataBuf = innerData[n:] + } + return n, nil } diff --git a/proto/nezha.pb.go b/proto/nezha.pb.go index ad13b26..0d7720e 100644 --- a/proto/nezha.pb.go +++ b/proto/nezha.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.34.1 -// protoc v5.26.1 +// protoc-gen-go v1.34.2 +// protoc v5.27.1 // source: proto/nezha.proto package proto @@ -582,6 +582,53 @@ func (x *Receipt) GetProced() bool { return false } +type IOStreamData struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Data []byte `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"` +} + +func (x *IOStreamData) Reset() { + *x = IOStreamData{} + if protoimpl.UnsafeEnabled { + mi := &file_proto_nezha_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *IOStreamData) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*IOStreamData) ProtoMessage() {} + +func (x *IOStreamData) ProtoReflect() protoreflect.Message { + mi := &file_proto_nezha_proto_msgTypes[6] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use IOStreamData.ProtoReflect.Descriptor instead. +func (*IOStreamData) Descriptor() ([]byte, []int) { + return file_proto_nezha_proto_rawDescGZIP(), []int{6} +} + +func (x *IOStreamData) GetData() []byte { + if x != nil { + return x.Data + } + return nil +} + var File_proto_nezha_proto protoreflect.FileDescriptor var file_proto_nezha_proto_rawDesc = []byte{ @@ -663,21 +710,27 @@ var file_proto_nezha_proto_rawDesc = []byte{ 0x65, 0x73, 0x73, 0x66, 0x75, 0x6c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x66, 0x75, 0x6c, 0x22, 0x21, 0x0a, 0x07, 0x52, 0x65, 0x63, 0x65, 0x69, 0x70, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x72, 0x6f, 0x63, 0x65, 0x64, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x06, 0x70, 0x72, 0x6f, 0x63, 0x65, 0x64, 0x32, 0xd6, 0x01, 0x0a, 0x0c, - 0x4e, 0x65, 0x7a, 0x68, 0x61, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x33, 0x0a, 0x11, - 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x53, 0x74, 0x61, 0x74, - 0x65, 0x12, 0x0c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x1a, - 0x0e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x65, 0x69, 0x70, 0x74, 0x22, - 0x00, 0x12, 0x31, 0x0a, 0x10, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x79, 0x73, 0x74, 0x65, - 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x0b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x48, 0x6f, - 0x73, 0x74, 0x1a, 0x0e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x65, 0x69, - 0x70, 0x74, 0x22, 0x00, 0x12, 0x31, 0x0a, 0x0a, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x54, 0x61, - 0x73, 0x6b, 0x12, 0x11, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x54, 0x61, 0x73, 0x6b, 0x52, - 0x65, 0x73, 0x75, 0x6c, 0x74, 0x1a, 0x0e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, - 0x63, 0x65, 0x69, 0x70, 0x74, 0x22, 0x00, 0x12, 0x2b, 0x0a, 0x0b, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x54, 0x61, 0x73, 0x6b, 0x12, 0x0b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x48, - 0x6f, 0x73, 0x74, 0x1a, 0x0b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x54, 0x61, 0x73, 0x6b, - 0x22, 0x00, 0x30, 0x01, 0x42, 0x09, 0x5a, 0x07, 0x2e, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, + 0x01, 0x28, 0x08, 0x52, 0x06, 0x70, 0x72, 0x6f, 0x63, 0x65, 0x64, 0x22, 0x22, 0x0a, 0x0c, 0x49, + 0x4f, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x44, 0x61, 0x74, 0x61, 0x12, 0x12, 0x0a, 0x04, 0x64, + 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x64, 0x61, 0x74, 0x61, 0x32, + 0x92, 0x02, 0x0a, 0x0c, 0x4e, 0x65, 0x7a, 0x68, 0x61, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, + 0x12, 0x33, 0x0a, 0x11, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, + 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x53, 0x74, + 0x61, 0x74, 0x65, 0x1a, 0x0e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x65, + 0x69, 0x70, 0x74, 0x22, 0x00, 0x12, 0x31, 0x0a, 0x10, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x53, + 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x0b, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x1a, 0x0e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, + 0x65, 0x63, 0x65, 0x69, 0x70, 0x74, 0x22, 0x00, 0x12, 0x31, 0x0a, 0x0a, 0x52, 0x65, 0x70, 0x6f, + 0x72, 0x74, 0x54, 0x61, 0x73, 0x6b, 0x12, 0x11, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x54, + 0x61, 0x73, 0x6b, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x1a, 0x0e, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x65, 0x69, 0x70, 0x74, 0x22, 0x00, 0x12, 0x2b, 0x0a, 0x0b, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x54, 0x61, 0x73, 0x6b, 0x12, 0x0b, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x1a, 0x0b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, + 0x54, 0x61, 0x73, 0x6b, 0x22, 0x00, 0x30, 0x01, 0x12, 0x3a, 0x0a, 0x08, 0x49, 0x4f, 0x53, 0x74, + 0x72, 0x65, 0x61, 0x6d, 0x12, 0x13, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x49, 0x4f, 0x53, + 0x74, 0x72, 0x65, 0x61, 0x6d, 0x44, 0x61, 0x74, 0x61, 0x1a, 0x13, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x2e, 0x49, 0x4f, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x44, 0x61, 0x74, 0x61, 0x22, 0x00, + 0x28, 0x01, 0x30, 0x01, 0x42, 0x09, 0x5a, 0x07, 0x2e, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } @@ -693,14 +746,15 @@ func file_proto_nezha_proto_rawDescGZIP() []byte { return file_proto_nezha_proto_rawDescData } -var file_proto_nezha_proto_msgTypes = make([]protoimpl.MessageInfo, 6) -var file_proto_nezha_proto_goTypes = []interface{}{ +var file_proto_nezha_proto_msgTypes = make([]protoimpl.MessageInfo, 7) +var file_proto_nezha_proto_goTypes = []any{ (*Host)(nil), // 0: proto.Host (*State)(nil), // 1: proto.State (*State_SensorTemperature)(nil), // 2: proto.State_SensorTemperature (*Task)(nil), // 3: proto.Task (*TaskResult)(nil), // 4: proto.TaskResult (*Receipt)(nil), // 5: proto.Receipt + (*IOStreamData)(nil), // 6: proto.IOStreamData } var file_proto_nezha_proto_depIdxs = []int32{ 2, // 0: proto.State.temperatures:type_name -> proto.State_SensorTemperature @@ -708,12 +762,14 @@ var file_proto_nezha_proto_depIdxs = []int32{ 0, // 2: proto.NezhaService.ReportSystemInfo:input_type -> proto.Host 4, // 3: proto.NezhaService.ReportTask:input_type -> proto.TaskResult 0, // 4: proto.NezhaService.RequestTask:input_type -> proto.Host - 5, // 5: proto.NezhaService.ReportSystemState:output_type -> proto.Receipt - 5, // 6: proto.NezhaService.ReportSystemInfo:output_type -> proto.Receipt - 5, // 7: proto.NezhaService.ReportTask:output_type -> proto.Receipt - 3, // 8: proto.NezhaService.RequestTask:output_type -> proto.Task - 5, // [5:9] is the sub-list for method output_type - 1, // [1:5] is the sub-list for method input_type + 6, // 5: proto.NezhaService.IOStream:input_type -> proto.IOStreamData + 5, // 6: proto.NezhaService.ReportSystemState:output_type -> proto.Receipt + 5, // 7: proto.NezhaService.ReportSystemInfo:output_type -> proto.Receipt + 5, // 8: proto.NezhaService.ReportTask:output_type -> proto.Receipt + 3, // 9: proto.NezhaService.RequestTask:output_type -> proto.Task + 6, // 10: proto.NezhaService.IOStream:output_type -> proto.IOStreamData + 6, // [6:11] is the sub-list for method output_type + 1, // [1:6] is the sub-list for method input_type 1, // [1:1] is the sub-list for extension type_name 1, // [1:1] is the sub-list for extension extendee 0, // [0:1] is the sub-list for field type_name @@ -725,7 +781,7 @@ func file_proto_nezha_proto_init() { return } if !protoimpl.UnsafeEnabled { - file_proto_nezha_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + file_proto_nezha_proto_msgTypes[0].Exporter = func(v any, i int) any { switch v := v.(*Host); i { case 0: return &v.state @@ -737,7 +793,7 @@ func file_proto_nezha_proto_init() { return nil } } - file_proto_nezha_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + file_proto_nezha_proto_msgTypes[1].Exporter = func(v any, i int) any { switch v := v.(*State); i { case 0: return &v.state @@ -749,7 +805,7 @@ func file_proto_nezha_proto_init() { return nil } } - file_proto_nezha_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + file_proto_nezha_proto_msgTypes[2].Exporter = func(v any, i int) any { switch v := v.(*State_SensorTemperature); i { case 0: return &v.state @@ -761,7 +817,7 @@ func file_proto_nezha_proto_init() { return nil } } - file_proto_nezha_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + file_proto_nezha_proto_msgTypes[3].Exporter = func(v any, i int) any { switch v := v.(*Task); i { case 0: return &v.state @@ -773,7 +829,7 @@ func file_proto_nezha_proto_init() { return nil } } - file_proto_nezha_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + file_proto_nezha_proto_msgTypes[4].Exporter = func(v any, i int) any { switch v := v.(*TaskResult); i { case 0: return &v.state @@ -785,7 +841,7 @@ func file_proto_nezha_proto_init() { return nil } } - file_proto_nezha_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { + file_proto_nezha_proto_msgTypes[5].Exporter = func(v any, i int) any { switch v := v.(*Receipt); i { case 0: return &v.state @@ -797,6 +853,18 @@ func file_proto_nezha_proto_init() { return nil } } + file_proto_nezha_proto_msgTypes[6].Exporter = func(v any, i int) any { + switch v := v.(*IOStreamData); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } type x struct{} out := protoimpl.TypeBuilder{ @@ -804,7 +872,7 @@ func file_proto_nezha_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_proto_nezha_proto_rawDesc, NumEnums: 0, - NumMessages: 6, + NumMessages: 7, NumExtensions: 0, NumServices: 1, }, diff --git a/proto/nezha.proto b/proto/nezha.proto index c3ff79c..da62c9a 100644 --- a/proto/nezha.proto +++ b/proto/nezha.proto @@ -8,6 +8,7 @@ service NezhaService { rpc ReportSystemInfo(Host)returns(Receipt){} rpc ReportTask(TaskResult)returns(Receipt){} rpc RequestTask(Host)returns(stream Task){} + rpc IOStream(stream IOStreamData)returns(stream IOStreamData){} } message Host { @@ -68,3 +69,7 @@ message TaskResult { message Receipt{ bool proced = 1; } + +message IOStreamData { + bytes data = 1; +} diff --git a/proto/nezha_grpc.pb.go b/proto/nezha_grpc.pb.go index 34f45a2..2c59636 100644 --- a/proto/nezha_grpc.pb.go +++ b/proto/nezha_grpc.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: // - protoc-gen-go-grpc v1.3.0 -// - protoc v3.21.12 +// - protoc v5.27.1 // source: proto/nezha.proto package proto @@ -23,6 +23,7 @@ const ( NezhaService_ReportSystemInfo_FullMethodName = "/proto.NezhaService/ReportSystemInfo" NezhaService_ReportTask_FullMethodName = "/proto.NezhaService/ReportTask" NezhaService_RequestTask_FullMethodName = "/proto.NezhaService/RequestTask" + NezhaService_IOStream_FullMethodName = "/proto.NezhaService/IOStream" ) // NezhaServiceClient is the client API for NezhaService service. @@ -33,6 +34,7 @@ type NezhaServiceClient interface { ReportSystemInfo(ctx context.Context, in *Host, opts ...grpc.CallOption) (*Receipt, error) ReportTask(ctx context.Context, in *TaskResult, opts ...grpc.CallOption) (*Receipt, error) RequestTask(ctx context.Context, in *Host, opts ...grpc.CallOption) (NezhaService_RequestTaskClient, error) + IOStream(ctx context.Context, opts ...grpc.CallOption) (NezhaService_IOStreamClient, error) } type nezhaServiceClient struct { @@ -102,6 +104,37 @@ func (x *nezhaServiceRequestTaskClient) Recv() (*Task, error) { return m, nil } +func (c *nezhaServiceClient) IOStream(ctx context.Context, opts ...grpc.CallOption) (NezhaService_IOStreamClient, error) { + stream, err := c.cc.NewStream(ctx, &NezhaService_ServiceDesc.Streams[1], NezhaService_IOStream_FullMethodName, opts...) + if err != nil { + return nil, err + } + x := &nezhaServiceIOStreamClient{stream} + return x, nil +} + +type NezhaService_IOStreamClient interface { + Send(*IOStreamData) error + Recv() (*IOStreamData, error) + grpc.ClientStream +} + +type nezhaServiceIOStreamClient struct { + grpc.ClientStream +} + +func (x *nezhaServiceIOStreamClient) Send(m *IOStreamData) error { + return x.ClientStream.SendMsg(m) +} + +func (x *nezhaServiceIOStreamClient) Recv() (*IOStreamData, error) { + m := new(IOStreamData) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + // NezhaServiceServer is the server API for NezhaService service. // All implementations should embed UnimplementedNezhaServiceServer // for forward compatibility @@ -110,6 +143,7 @@ type NezhaServiceServer interface { ReportSystemInfo(context.Context, *Host) (*Receipt, error) ReportTask(context.Context, *TaskResult) (*Receipt, error) RequestTask(*Host, NezhaService_RequestTaskServer) error + IOStream(NezhaService_IOStreamServer) error } // UnimplementedNezhaServiceServer should be embedded to have forward compatible implementations. @@ -128,6 +162,9 @@ func (UnimplementedNezhaServiceServer) ReportTask(context.Context, *TaskResult) func (UnimplementedNezhaServiceServer) RequestTask(*Host, NezhaService_RequestTaskServer) error { return status.Errorf(codes.Unimplemented, "method RequestTask not implemented") } +func (UnimplementedNezhaServiceServer) IOStream(NezhaService_IOStreamServer) error { + return status.Errorf(codes.Unimplemented, "method IOStream not implemented") +} // UnsafeNezhaServiceServer may be embedded to opt out of forward compatibility for this service. // Use of this interface is not recommended, as added methods to NezhaServiceServer will @@ -215,6 +252,32 @@ func (x *nezhaServiceRequestTaskServer) Send(m *Task) error { return x.ServerStream.SendMsg(m) } +func _NezhaService_IOStream_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(NezhaServiceServer).IOStream(&nezhaServiceIOStreamServer{stream}) +} + +type NezhaService_IOStreamServer interface { + Send(*IOStreamData) error + Recv() (*IOStreamData, error) + grpc.ServerStream +} + +type nezhaServiceIOStreamServer struct { + grpc.ServerStream +} + +func (x *nezhaServiceIOStreamServer) Send(m *IOStreamData) error { + return x.ServerStream.SendMsg(m) +} + +func (x *nezhaServiceIOStreamServer) Recv() (*IOStreamData, error) { + m := new(IOStreamData) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + // NezhaService_ServiceDesc is the grpc.ServiceDesc for NezhaService service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -241,6 +304,12 @@ var NezhaService_ServiceDesc = grpc.ServiceDesc{ Handler: _NezhaService_RequestTask_Handler, ServerStreams: true, }, + { + StreamName: "IOStream", + Handler: _NezhaService_IOStream_Handler, + ServerStreams: true, + ClientStreams: true, + }, }, Metadata: "proto/nezha.proto", } diff --git a/script/proto.sh b/script/proto.sh index 1147e16..fb334e8 100755 --- a/script/proto.sh +++ b/script/proto.sh @@ -1 +1,3 @@ -protoc --go-grpc_out="require_unimplemented_servers=false:." --go_out="." proto/*.proto \ No newline at end of file +protoc --go-grpc_out="require_unimplemented_servers=false:." --go_out="." proto/*.proto +rm -rf ../agent/proto +cp -r proto ../agent \ No newline at end of file diff --git a/service/rpc/auth.go b/service/rpc/auth.go index 88d9017..a189916 100644 --- a/service/rpc/auth.go +++ b/service/rpc/auth.go @@ -10,19 +10,19 @@ import ( "github.com/naiba/nezha/service/singleton" ) -type AuthHandler struct { +type authHandler struct { ClientSecret string } -func (a *AuthHandler) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { +func (a *authHandler) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { return map[string]string{"client_secret": a.ClientSecret}, nil } -func (a *AuthHandler) RequireTransportSecurity() bool { +func (a *authHandler) RequireTransportSecurity() bool { return false } -func (a *AuthHandler) Check(ctx context.Context) (uint64, error) { +func (a *authHandler) Check(ctx context.Context) (uint64, error) { md, ok := metadata.FromIncomingContext(ctx) if !ok { return 0, status.Errorf(codes.Unauthenticated, "获取 metaData 失败") diff --git a/service/rpc/io_stream.go b/service/rpc/io_stream.go new file mode 100644 index 0000000..d21d95f --- /dev/null +++ b/service/rpc/io_stream.go @@ -0,0 +1,140 @@ +package rpc + +import ( + "errors" + "io" + "sync/atomic" + "time" +) + +type ioStreamContext struct { + userIo io.ReadWriteCloser + agentIo io.ReadWriteCloser + userIoConnectCh chan struct{} + agentIoConnectCh chan struct{} +} + +func (s *NezhaHandler) CreateStream(streamId string) { + s.ioStreamMutex.Lock() + defer s.ioStreamMutex.Unlock() + + s.ioStreams[streamId] = &ioStreamContext{ + userIoConnectCh: make(chan struct{}), + agentIoConnectCh: make(chan struct{}), + } +} + +func (s *NezhaHandler) GetStream(streamId string) (*ioStreamContext, error) { + s.ioStreamMutex.RLock() + defer s.ioStreamMutex.RUnlock() + + if ctx, ok := s.ioStreams[streamId]; ok { + return ctx, nil + } + + return nil, errors.New("stream not found") +} + +func (s *NezhaHandler) CloseStream(streamId string) error { + s.ioStreamMutex.Lock() + defer s.ioStreamMutex.Unlock() + + if ctx, ok := s.ioStreams[streamId]; ok { + if ctx.userIo != nil { + ctx.userIo.Close() + } + if ctx.agentIo != nil { + ctx.agentIo.Close() + } + delete(s.ioStreams, streamId) + } + + return nil +} + +func (s *NezhaHandler) UserConnected(streamId string, userIo io.ReadWriteCloser) error { + stream, err := s.GetStream(streamId) + if err != nil { + return err + } + + stream.userIo = userIo + close(stream.userIoConnectCh) + + return nil +} + +func (s *NezhaHandler) AgentConnected(streamId string, agentIo io.ReadWriteCloser) error { + stream, err := s.GetStream(streamId) + if err != nil { + return err + } + + stream.agentIo = agentIo + close(stream.agentIoConnectCh) + + return nil +} + +func (s *NezhaHandler) StartStream(streamId string, timeout time.Duration) error { + stream, err := s.GetStream(streamId) + if err != nil { + return err + } + + timeoutTimer := time.NewTimer(timeout) + +LOOP: + for { + select { + case <-stream.userIoConnectCh: + if stream.agentIo != nil { + timeoutTimer.Stop() + break LOOP + } + case <-stream.agentIoConnectCh: + if stream.userIo != nil { + timeoutTimer.Stop() + break LOOP + } + case <-time.After(timeout): + break LOOP + } + } + + if stream.userIo == nil && stream.agentIo == nil { + return errors.New("timeout: no connection established") + } + if stream.userIo == nil { + return errors.New("timeout: user connection not established") + } + if stream.agentIo == nil { + return errors.New("timeout: agent connection not established") + } + + isDone := new(atomic.Bool) + endCh := make(chan struct{}) + + go func() { + _, innerErr := io.Copy(stream.userIo, stream.agentIo) + if innerErr != nil { + err = innerErr + } + if isDone.CompareAndSwap(false, true) { + close(endCh) + } + }() + go func() { + _, innerErr := io.Copy(stream.agentIo, stream.userIo) + if innerErr != nil { + err = innerErr + } + if isDone.CompareAndSwap(false, true) { + close(endCh) + } + }() + + <-endCh + + return err +} diff --git a/service/rpc/nezha.go b/service/rpc/nezha.go index e79c429..852107d 100644 --- a/service/rpc/nezha.go +++ b/service/rpc/nezha.go @@ -3,11 +3,14 @@ package rpc import ( "context" "fmt" - "github.com/naiba/nezha/pkg/ddns" - "github.com/naiba/nezha/pkg/utils" "log" + "sync" "time" + "github.com/naiba/nezha/pkg/ddns" + "github.com/naiba/nezha/pkg/grpcx" + "github.com/naiba/nezha/pkg/utils" + "github.com/jinzhu/copier" "github.com/nicksnyder/go-i18n/v2/i18n" @@ -16,8 +19,20 @@ import ( "github.com/naiba/nezha/service/singleton" ) +var NezhaHandlerSingleton *NezhaHandler + type NezhaHandler struct { - Auth *AuthHandler + Auth *authHandler + ioStreams map[string]*ioStreamContext + ioStreamMutex *sync.RWMutex +} + +func NewNezhaHandler() *NezhaHandler { + return &NezhaHandler{ + Auth: &authHandler{}, + ioStreamMutex: new(sync.RWMutex), + ioStreams: make(map[string]*ioStreamContext), + } } func (s *NezhaHandler) ReportTask(c context.Context, r *pb.TaskResult) (*pb.Receipt, error) { @@ -177,3 +192,28 @@ func (s *NezhaHandler) ReportSystemInfo(c context.Context, r *pb.Host) (*pb.Rece singleton.ServerList[clientID].Host = &host return &pb.Receipt{Proced: true}, nil } + +func (s *NezhaHandler) IOStream(stream pb.NezhaService_IOStreamServer) error { + if _, err := s.Auth.Check(stream.Context()); err != nil { + return err + } + id, err := stream.Recv() + if err != nil { + return err + } + if id == nil || len(id.Data) < 4 || (id.Data[0] != 0xff && id.Data[1] != 0x05 && id.Data[2] != 0xff && id.Data[3] == 0x05) { + return fmt.Errorf("invalid stream id") + } + + streamId := string(id.Data[4:]) + + if _, err := s.GetStream(streamId); err != nil { + return err + } + iw := grpcx.NewIOStreamWrapper(stream) + if err := s.AgentConnected(streamId, iw); err != nil { + return err + } + iw.Wait() + return nil +}