feat: 去除 webTerminal 的 websocket 依赖

This commit is contained in:
naiba 2024-07-14 12:47:36 +08:00
parent 417f972659
commit 1c91fcffac
16 changed files with 497 additions and 277 deletions

View File

@ -3,12 +3,8 @@ package controller
import ( import (
"errors" "errors"
"fmt" "fmt"
"log"
"net/http" "net/http"
"regexp"
"strconv" "strconv"
"strings"
"sync"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -24,22 +20,13 @@ import (
"github.com/naiba/nezha/pkg/utils" "github.com/naiba/nezha/pkg/utils"
"github.com/naiba/nezha/pkg/websocketx" "github.com/naiba/nezha/pkg/websocketx"
"github.com/naiba/nezha/proto" "github.com/naiba/nezha/proto"
"github.com/naiba/nezha/service/rpc"
"github.com/naiba/nezha/service/singleton" "github.com/naiba/nezha/service/singleton"
) )
type terminalContext struct {
agentConn *websocketx.Conn
userConn *websocketx.Conn
serverID uint64
host string
useSSL bool
}
type commonPage struct { type commonPage struct {
r *gin.Engine r *gin.Engine
terminals map[string]*terminalContext requestGroup singleflight.Group
terminalsLock *sync.Mutex
requestGroup singleflight.Group
} }
func (cp *commonPage) serve() { func (cp *commonPage) serve() {
@ -68,7 +55,6 @@ type viewPasswordForm struct {
func (p *commonPage) issueViewPassword(c *gin.Context) { func (p *commonPage) issueViewPassword(c *gin.Context) {
var vpf viewPasswordForm var vpf viewPasswordForm
err := c.ShouldBind(&vpf) err := c.ShouldBind(&vpf)
log.Println("bingo", vpf)
var hash []byte var hash []byte
if err == nil && vpf.Password != singleton.Conf.Site.ViewPassword { if err == nil && vpf.Password != singleton.Conf.Site.ViewPassword {
err = errors.New(singleton.Localizer.MustLocalize(&i18n.LocalizeConfig{MessageID: "WrongAccessPassword"})) err = errors.New(singleton.Localizer.MustLocalize(&i18n.LocalizeConfig{MessageID: "WrongAccessPassword"}))
@ -283,8 +269,6 @@ type Data struct {
Servers []*model.Server `json:"servers,omitempty"` Servers []*model.Server `json:"servers,omitempty"`
} }
var cloudflareCookiesValidator = regexp.MustCompile("^[A-Za-z0-9-_]+$")
func (cp *commonPage) ws(c *gin.Context) { func (cp *commonPage) ws(c *gin.Context) {
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil { if err != nil {
@ -322,9 +306,7 @@ func (cp *commonPage) ws(c *gin.Context) {
func (cp *commonPage) terminal(c *gin.Context) { func (cp *commonPage) terminal(c *gin.Context) {
terminalID := c.Param("id") terminalID := c.Param("id")
cp.terminalsLock.Lock() if _, err := rpc.NezhaHandlerSingleton.GetStream(terminalID); err != nil {
if terminalID == "" || cp.terminals[terminalID] == nil {
cp.terminalsLock.Unlock()
mygin.ShowErrorPage(c, mygin.ErrInfo{ mygin.ShowErrorPage(c, mygin.ErrInfo{
Code: http.StatusForbidden, Code: http.StatusForbidden,
Title: "无权访问", Title: "无权访问",
@ -334,104 +316,7 @@ func (cp *commonPage) terminal(c *gin.Context) {
}, true) }, true)
return return
} }
defer rpc.NezhaHandlerSingleton.CloseStream(terminalID)
terminal := cp.terminals[terminalID]
cp.terminalsLock.Unlock()
defer func() {
// 清理 context
cp.terminalsLock.Lock()
defer cp.terminalsLock.Unlock()
delete(cp.terminals, terminalID)
}()
var isAgent bool
if _, authorized := c.Get(model.CtxKeyAuthorizedUser); !authorized {
singleton.ServerLock.RLock()
_, hasID := singleton.SecretToID[c.Request.Header.Get("Secret")]
singleton.ServerLock.RUnlock()
if !hasID {
mygin.ShowErrorPage(c, mygin.ErrInfo{
Code: http.StatusForbidden,
Title: "无权访问",
Msg: "用户未登录或非法终端",
Link: "/",
Btn: "返回首页",
}, true)
return
}
if terminal.userConn == nil {
mygin.ShowErrorPage(c, mygin.ErrInfo{
Code: http.StatusForbidden,
Title: "无权访问",
Msg: "用户不在线",
Link: "/",
Btn: "返回首页",
}, true)
return
}
if terminal.agentConn != nil {
mygin.ShowErrorPage(c, mygin.ErrInfo{
Code: http.StatusInternalServerError,
Title: "连接已存在",
Msg: "Websocket协议切换失败",
Link: "/",
Btn: "返回首页",
}, true)
return
}
isAgent = true
} else {
singleton.ServerLock.RLock()
server := singleton.ServerList[terminal.serverID]
singleton.ServerLock.RUnlock()
if server == nil || server.TaskStream == nil {
mygin.ShowErrorPage(c, mygin.ErrInfo{
Code: http.StatusForbidden,
Title: "请求失败",
Msg: "服务器不存在或处于离线状态",
Link: "/server",
Btn: "返回重试",
}, true)
return
}
cloudflareCookies, _ := c.Cookie("CF_Authorization")
// Cloudflare Cookies 合法性验证
// 其应该包含.分隔的三组BASE64-URL编码
if cloudflareCookies != "" {
encodedCookies := strings.Split(cloudflareCookies, ".")
if len(encodedCookies) == 3 {
for i := 0; i < 3; i++ {
if !cloudflareCookiesValidator.MatchString(encodedCookies[i]) {
cloudflareCookies = ""
break
}
}
} else {
cloudflareCookies = ""
}
}
terminalData, _ := utils.Json.Marshal(&model.TerminalTask{
Host: terminal.host,
UseSSL: terminal.useSSL,
Session: terminalID,
Cookie: cloudflareCookies,
})
if err := server.TaskStream.Send(&proto.Task{
Type: model.TaskTypeTerminal,
Data: string(terminalData),
}); err != nil {
mygin.ShowErrorPage(c, mygin.ErrInfo{
Code: http.StatusForbidden,
Title: "请求失败",
Msg: "Agent信令下发失败",
Link: "/server",
Btn: "返回重试",
}, true)
return
}
}
wsConn, err := upgrader.Upgrade(c.Writer, c.Request, nil) wsConn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil { if err != nil {
@ -447,36 +332,7 @@ func (cp *commonPage) terminal(c *gin.Context) {
return return
} }
defer wsConn.Close() defer wsConn.Close()
conn := &websocketx.Conn{Conn: wsConn} conn := websocketx.NewConn(wsConn)
log.Printf("NEZHA>> terminal connected %t %q", isAgent, c.Request.URL)
defer log.Printf("NEZHA>> terminal disconnected %t %q", isAgent, c.Request.URL)
if isAgent {
terminal.agentConn = conn
defer func() {
// Agent断开链接时断开用户连接
if terminal.userConn != nil {
terminal.userConn.Close()
}
}()
} else {
terminal.userConn = conn
defer func() {
// 用户断开链接时断开 Agent 连接
if terminal.agentConn != nil {
terminal.agentConn.Close()
}
}()
}
deadlineCh := make(chan interface{})
go func() {
// 对方连接超时
connectDeadline := time.NewTimer(time.Second * 15)
<-connectDeadline.C
deadlineCh <- struct{}{}
}()
go func() { go func() {
// PING 保活 // PING 保活
@ -488,58 +344,11 @@ func (cp *commonPage) terminal(c *gin.Context) {
} }
}() }()
dataCh := make(chan []byte) if err = rpc.NezhaHandlerSingleton.UserConnected(terminalID, conn); err != nil {
errorCh := make(chan error) return
go func() {
for {
msgType, data, err := conn.ReadMessage()
if err != nil {
errorCh <- err
return
}
// 将文本消息转换为命令输入
if msgType == websocket.TextMessage {
data = append([]byte{0}, data...)
}
dataCh <- data
}
}()
var dataBuffer [][]byte
var distConn *websocketx.Conn
checkDistConn := func() {
if distConn == nil {
if isAgent {
distConn = terminal.userConn
} else {
distConn = terminal.agentConn
}
}
} }
for { rpc.NezhaHandlerSingleton.StartStream(terminalID, time.Second*10)
select {
case <-deadlineCh:
checkDistConn()
if distConn == nil {
return
}
case <-errorCh:
return
case data := <-dataCh:
dataBuffer = append(dataBuffer, data)
checkDistConn()
if distConn != nil {
for i := 0; i < len(dataBuffer); i++ {
err = distConn.WriteMessage(websocket.BinaryMessage, dataBuffer[i])
if err != nil {
return
}
}
dataBuffer = dataBuffer[:0]
}
}
}
} }
type createTerminalRequest struct { type createTerminalRequest struct {
@ -585,6 +394,8 @@ func (cp *commonPage) createTerminal(c *gin.Context) {
return return
} }
rpc.NezhaHandlerSingleton.CreateStream(id)
singleton.ServerLock.RLock() singleton.ServerLock.RLock()
server := singleton.ServerList[createTerminalReq.ID] server := singleton.ServerList[createTerminalReq.ID]
singleton.ServerLock.RUnlock() singleton.ServerLock.RUnlock()
@ -599,13 +410,21 @@ func (cp *commonPage) createTerminal(c *gin.Context) {
return return
} }
cp.terminalsLock.Lock() terminalData, _ := utils.Json.Marshal(&model.TerminalTask{
defer cp.terminalsLock.Unlock() StreamID: id,
})
cp.terminals[id] = &terminalContext{ if err := server.TaskStream.Send(&proto.Task{
serverID: createTerminalReq.ID, Type: model.TaskTypeTerminalGRPC,
host: createTerminalReq.Host, Data: string(terminalData),
useSSL: createTerminalReq.Protocol == "https:", }); err != nil {
mygin.ShowErrorPage(c, mygin.ErrInfo{
Code: http.StatusForbidden,
Title: "请求失败",
Msg: "Agent信令下发失败",
Link: "/server",
Btn: "返回重试",
}, true)
return
} }
c.HTML(http.StatusOK, "dashboard-"+singleton.Conf.Site.DashboardTheme+"/terminal", mygin.CommonEnvironment(c, gin.H{ c.HTML(http.StatusOK, "dashboard-"+singleton.Conf.Site.DashboardTheme+"/terminal", mygin.CommonEnvironment(c, gin.H{

View File

@ -9,7 +9,6 @@ import (
"os" "os"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"code.cloudfoundry.org/bytefmt" "code.cloudfoundry.org/bytefmt"
@ -68,7 +67,7 @@ func ServeWeb(port uint) *http.Server {
func routers(r *gin.Engine) { func routers(r *gin.Engine) {
// 通用页面 // 通用页面
cp := commonPage{r: r, terminals: make(map[string]*terminalContext), terminalsLock: new(sync.Mutex)} cp := commonPage{r: r}
cp.serve() cp.serve()
// 游客页面 // 游客页面
gp := guestPage{r} gp := guestPage{r}

View File

@ -59,6 +59,7 @@ func main() {
return return
} }
// TODO 使用 cmux 在同一端口服务 HTTP 和 gRPC
singleton.CleanMonitorHistory() singleton.CleanMonitorHistory()
go rpc.ServeRPC(singleton.Conf.GRPCPort) go rpc.ServeRPC(singleton.Conf.GRPCPort)
serviceSentinelDispatchBus := make(chan model.Monitor) // 用于传递服务监控任务信息的channel serviceSentinelDispatchBus := make(chan model.Monitor) // 用于传递服务监控任务信息的channel

View File

@ -14,9 +14,8 @@ import (
func ServeRPC(port uint) { func ServeRPC(port uint) {
server := grpc.NewServer() server := grpc.NewServer()
pb.RegisterNezhaServiceServer(server, &rpcService.NezhaHandler{ rpcService.NezhaHandlerSingleton = rpcService.NewNezhaHandler()
Auth: &rpcService.AuthHandler{}, pb.RegisterNezhaServiceServer(server, rpcService.NezhaHandlerSingleton)
})
listen, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) listen, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil { if err != nil {
panic(err) panic(err)

3
go.mod
View File

@ -18,7 +18,6 @@ require (
github.com/ory/graceful v0.1.3 github.com/ory/graceful v0.1.3
github.com/patrickmn/go-cache v2.1.0+incompatible github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/robfig/cron/v3 v3.0.1 github.com/robfig/cron/v3 v3.0.1
github.com/samber/lo v1.39.0
github.com/spf13/pflag v1.0.5 github.com/spf13/pflag v1.0.5
github.com/spf13/viper v1.18.2 github.com/spf13/viper v1.18.2
github.com/xanzy/go-gitlab v0.103.0 github.com/xanzy/go-gitlab v0.103.0
@ -28,7 +27,7 @@ require (
golang.org/x/sync v0.7.0 golang.org/x/sync v0.7.0
golang.org/x/text v0.16.0 golang.org/x/text v0.16.0
google.golang.org/grpc v1.63.0 google.golang.org/grpc v1.63.0
google.golang.org/protobuf v1.33.0 google.golang.org/protobuf v1.34.2
gorm.io/driver/sqlite v1.5.5 gorm.io/driver/sqlite v1.5.5
gorm.io/gorm v1.25.10 gorm.io/gorm v1.25.10
sigs.k8s.io/yaml v1.4.0 sigs.k8s.io/yaml v1.4.0

6
go.sum
View File

@ -152,8 +152,6 @@ github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6ke
github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4= github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4=
github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE=
github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ= github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ=
github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA=
github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
@ -240,8 +238,8 @@ google.golang.org/grpc v1.63.0 h1:WjKe+dnvABXyPJMD7KDNLxtoGk5tgk+YFWN6cBWjZE8=
google.golang.org/grpc v1.63.0/go.mod h1:WAX/8DgncnokcFUldAxq7GeB5DXHDbMF+lLvDomNkRA= google.golang.org/grpc v1.63.0/go.mod h1:WAX/8DgncnokcFUldAxq7GeB5DXHDbMF+lLvDomNkRA=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=

View File

@ -20,17 +20,11 @@ const (
TaskTypeTerminal TaskTypeTerminal
TaskTypeUpgrade TaskTypeUpgrade
TaskTypeKeepalive TaskTypeKeepalive
TaskTypeTerminalGRPC
) )
type TerminalTask struct { type TerminalTask struct {
// websocket 主机名 StreamID string
Host string `json:"host,omitempty"`
// 是否启用 SSL
UseSSL bool `json:"use_ssl,omitempty"`
// 会话标识
Session string `json:"session,omitempty"`
// Agent在连接Server时需要的额外Cookie信息
Cookie string `json:"cookie,omitempty"`
} }
const ( const (

View File

@ -0,0 +1,65 @@
package grpcx
import (
"context"
"io"
"sync/atomic"
"github.com/naiba/nezha/proto"
)
var _ io.ReadWriteCloser = &IOStreamWrapper{}
type IOStream interface {
Recv() (*proto.IOStreamData, error)
Send(*proto.IOStreamData) error
Context() context.Context
}
type IOStreamWrapper struct {
IOStream
dataBuf []byte
closed *atomic.Bool
closeCh chan struct{}
}
func NewIOStreamWrapper(stream IOStream) *IOStreamWrapper {
return &IOStreamWrapper{
IOStream: stream,
closeCh: make(chan struct{}),
closed: new(atomic.Bool),
}
}
func (iw *IOStreamWrapper) Read(p []byte) (n int, err error) {
if len(iw.dataBuf) > 0 {
n := copy(p, iw.dataBuf)
iw.dataBuf = iw.dataBuf[n:]
return n, nil
}
var data *proto.IOStreamData
if data, err = iw.Recv(); err != nil {
return 0, err
}
n = copy(p, data.Data)
if n < len(data.Data) {
iw.dataBuf = data.Data[n:]
}
return n, nil
}
func (iw *IOStreamWrapper) Write(p []byte) (n int, err error) {
err = iw.Send(&proto.IOStreamData{Data: p})
return len(p), err
}
func (iw *IOStreamWrapper) Close() error {
if iw.closed.CompareAndSwap(false, true) {
close(iw.closeCh)
}
return nil
}
func (iw *IOStreamWrapper) Wait() {
<-iw.closeCh
}

View File

@ -4,22 +4,44 @@ import (
"sync" "sync"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/samber/lo"
) )
type Conn struct { type Conn struct {
*websocket.Conn *websocket.Conn
writeLock sync.Mutex writeLock *sync.Mutex
dataBuf []byte
} }
func (conn *Conn) WriteMessage(msgType int, data []byte) error { func NewConn(conn *websocket.Conn) *Conn {
return &Conn{Conn: conn, writeLock: new(sync.Mutex)}
}
func (conn *Conn) Write(data []byte) (int, error) {
conn.writeLock.Lock() conn.writeLock.Lock()
defer conn.writeLock.Unlock() defer conn.writeLock.Unlock()
var err error if err := conn.Conn.WriteMessage(websocket.BinaryMessage, data); err != nil {
lo.TryCatchWithErrorValue(func() error { return 0, err
return conn.Conn.WriteMessage(msgType, data) }
}, func(res any) { return len(data), nil
err = res.(error) }
})
return err func (conn *Conn) Read(data []byte) (int, error) {
if len(conn.dataBuf) > 0 {
n := copy(data, conn.dataBuf)
conn.dataBuf = conn.dataBuf[n:]
return n, nil
}
mType, innerData, err := conn.Conn.ReadMessage()
if err != nil {
return 0, err
}
// 将文本消息转换为命令输入
if mType == websocket.TextMessage {
innerData = append([]byte{0}, innerData...)
}
n := copy(data, innerData)
if n < len(innerData) {
conn.dataBuf = innerData[n:]
}
return n, nil
} }

View File

@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// versions: // versions:
// protoc-gen-go v1.34.1 // protoc-gen-go v1.34.2
// protoc v5.26.1 // protoc v5.27.1
// source: proto/nezha.proto // source: proto/nezha.proto
package proto package proto
@ -582,6 +582,53 @@ func (x *Receipt) GetProced() bool {
return false return false
} }
type IOStreamData struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Data []byte `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"`
}
func (x *IOStreamData) Reset() {
*x = IOStreamData{}
if protoimpl.UnsafeEnabled {
mi := &file_proto_nezha_proto_msgTypes[6]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *IOStreamData) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*IOStreamData) ProtoMessage() {}
func (x *IOStreamData) ProtoReflect() protoreflect.Message {
mi := &file_proto_nezha_proto_msgTypes[6]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use IOStreamData.ProtoReflect.Descriptor instead.
func (*IOStreamData) Descriptor() ([]byte, []int) {
return file_proto_nezha_proto_rawDescGZIP(), []int{6}
}
func (x *IOStreamData) GetData() []byte {
if x != nil {
return x.Data
}
return nil
}
var File_proto_nezha_proto protoreflect.FileDescriptor var File_proto_nezha_proto protoreflect.FileDescriptor
var file_proto_nezha_proto_rawDesc = []byte{ var file_proto_nezha_proto_rawDesc = []byte{
@ -663,21 +710,27 @@ var file_proto_nezha_proto_rawDesc = []byte{
0x65, 0x73, 0x73, 0x66, 0x75, 0x6c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x75, 0x65, 0x73, 0x73, 0x66, 0x75, 0x6c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x75,
0x63, 0x63, 0x65, 0x73, 0x73, 0x66, 0x75, 0x6c, 0x22, 0x21, 0x0a, 0x07, 0x52, 0x65, 0x63, 0x65, 0x63, 0x63, 0x65, 0x73, 0x73, 0x66, 0x75, 0x6c, 0x22, 0x21, 0x0a, 0x07, 0x52, 0x65, 0x63, 0x65,
0x69, 0x70, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x72, 0x6f, 0x63, 0x65, 0x64, 0x18, 0x01, 0x20, 0x69, 0x70, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x72, 0x6f, 0x63, 0x65, 0x64, 0x18, 0x01, 0x20,
0x01, 0x28, 0x08, 0x52, 0x06, 0x70, 0x72, 0x6f, 0x63, 0x65, 0x64, 0x32, 0xd6, 0x01, 0x0a, 0x0c, 0x01, 0x28, 0x08, 0x52, 0x06, 0x70, 0x72, 0x6f, 0x63, 0x65, 0x64, 0x22, 0x22, 0x0a, 0x0c, 0x49,
0x4e, 0x65, 0x7a, 0x68, 0x61, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x33, 0x0a, 0x11, 0x4f, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x44, 0x61, 0x74, 0x61, 0x12, 0x12, 0x0a, 0x04, 0x64,
0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x53, 0x74, 0x61, 0x74, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x64, 0x61, 0x74, 0x61, 0x32,
0x65, 0x12, 0x0c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x1a, 0x92, 0x02, 0x0a, 0x0c, 0x4e, 0x65, 0x7a, 0x68, 0x61, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65,
0x0e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x65, 0x69, 0x70, 0x74, 0x22, 0x12, 0x33, 0x0a, 0x11, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d,
0x00, 0x12, 0x31, 0x0a, 0x10, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x79, 0x73, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x53, 0x74,
0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x0b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x48, 0x6f, 0x61, 0x74, 0x65, 0x1a, 0x0e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x65,
0x73, 0x74, 0x1a, 0x0e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x65, 0x69, 0x69, 0x70, 0x74, 0x22, 0x00, 0x12, 0x31, 0x0a, 0x10, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x53,
0x70, 0x74, 0x22, 0x00, 0x12, 0x31, 0x0a, 0x0a, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x54, 0x61, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x0b, 0x2e, 0x70, 0x72, 0x6f, 0x74,
0x73, 0x6b, 0x12, 0x11, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x54, 0x61, 0x73, 0x6b, 0x52, 0x6f, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x1a, 0x0e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52,
0x65, 0x73, 0x75, 0x6c, 0x74, 0x1a, 0x0e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x65, 0x63, 0x65, 0x69, 0x70, 0x74, 0x22, 0x00, 0x12, 0x31, 0x0a, 0x0a, 0x52, 0x65, 0x70, 0x6f,
0x63, 0x65, 0x69, 0x70, 0x74, 0x22, 0x00, 0x12, 0x2b, 0x0a, 0x0b, 0x52, 0x65, 0x71, 0x75, 0x65, 0x72, 0x74, 0x54, 0x61, 0x73, 0x6b, 0x12, 0x11, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x54,
0x73, 0x74, 0x54, 0x61, 0x73, 0x6b, 0x12, 0x0b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x48, 0x61, 0x73, 0x6b, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x1a, 0x0e, 0x2e, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x73, 0x74, 0x1a, 0x0b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x54, 0x61, 0x73, 0x6b, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x65, 0x69, 0x70, 0x74, 0x22, 0x00, 0x12, 0x2b, 0x0a, 0x0b, 0x52,
0x22, 0x00, 0x30, 0x01, 0x42, 0x09, 0x5a, 0x07, 0x2e, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x54, 0x61, 0x73, 0x6b, 0x12, 0x0b, 0x2e, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x1a, 0x0b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e,
0x54, 0x61, 0x73, 0x6b, 0x22, 0x00, 0x30, 0x01, 0x12, 0x3a, 0x0a, 0x08, 0x49, 0x4f, 0x53, 0x74,
0x72, 0x65, 0x61, 0x6d, 0x12, 0x13, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x49, 0x4f, 0x53,
0x74, 0x72, 0x65, 0x61, 0x6d, 0x44, 0x61, 0x74, 0x61, 0x1a, 0x13, 0x2e, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x2e, 0x49, 0x4f, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x44, 0x61, 0x74, 0x61, 0x22, 0x00,
0x28, 0x01, 0x30, 0x01, 0x42, 0x09, 0x5a, 0x07, 0x2e, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62,
0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
} }
@ -693,14 +746,15 @@ func file_proto_nezha_proto_rawDescGZIP() []byte {
return file_proto_nezha_proto_rawDescData return file_proto_nezha_proto_rawDescData
} }
var file_proto_nezha_proto_msgTypes = make([]protoimpl.MessageInfo, 6) var file_proto_nezha_proto_msgTypes = make([]protoimpl.MessageInfo, 7)
var file_proto_nezha_proto_goTypes = []interface{}{ var file_proto_nezha_proto_goTypes = []any{
(*Host)(nil), // 0: proto.Host (*Host)(nil), // 0: proto.Host
(*State)(nil), // 1: proto.State (*State)(nil), // 1: proto.State
(*State_SensorTemperature)(nil), // 2: proto.State_SensorTemperature (*State_SensorTemperature)(nil), // 2: proto.State_SensorTemperature
(*Task)(nil), // 3: proto.Task (*Task)(nil), // 3: proto.Task
(*TaskResult)(nil), // 4: proto.TaskResult (*TaskResult)(nil), // 4: proto.TaskResult
(*Receipt)(nil), // 5: proto.Receipt (*Receipt)(nil), // 5: proto.Receipt
(*IOStreamData)(nil), // 6: proto.IOStreamData
} }
var file_proto_nezha_proto_depIdxs = []int32{ var file_proto_nezha_proto_depIdxs = []int32{
2, // 0: proto.State.temperatures:type_name -> proto.State_SensorTemperature 2, // 0: proto.State.temperatures:type_name -> proto.State_SensorTemperature
@ -708,12 +762,14 @@ var file_proto_nezha_proto_depIdxs = []int32{
0, // 2: proto.NezhaService.ReportSystemInfo:input_type -> proto.Host 0, // 2: proto.NezhaService.ReportSystemInfo:input_type -> proto.Host
4, // 3: proto.NezhaService.ReportTask:input_type -> proto.TaskResult 4, // 3: proto.NezhaService.ReportTask:input_type -> proto.TaskResult
0, // 4: proto.NezhaService.RequestTask:input_type -> proto.Host 0, // 4: proto.NezhaService.RequestTask:input_type -> proto.Host
5, // 5: proto.NezhaService.ReportSystemState:output_type -> proto.Receipt 6, // 5: proto.NezhaService.IOStream:input_type -> proto.IOStreamData
5, // 6: proto.NezhaService.ReportSystemInfo:output_type -> proto.Receipt 5, // 6: proto.NezhaService.ReportSystemState:output_type -> proto.Receipt
5, // 7: proto.NezhaService.ReportTask:output_type -> proto.Receipt 5, // 7: proto.NezhaService.ReportSystemInfo:output_type -> proto.Receipt
3, // 8: proto.NezhaService.RequestTask:output_type -> proto.Task 5, // 8: proto.NezhaService.ReportTask:output_type -> proto.Receipt
5, // [5:9] is the sub-list for method output_type 3, // 9: proto.NezhaService.RequestTask:output_type -> proto.Task
1, // [1:5] is the sub-list for method input_type 6, // 10: proto.NezhaService.IOStream:output_type -> proto.IOStreamData
6, // [6:11] is the sub-list for method output_type
1, // [1:6] is the sub-list for method input_type
1, // [1:1] is the sub-list for extension type_name 1, // [1:1] is the sub-list for extension type_name
1, // [1:1] is the sub-list for extension extendee 1, // [1:1] is the sub-list for extension extendee
0, // [0:1] is the sub-list for field type_name 0, // [0:1] is the sub-list for field type_name
@ -725,7 +781,7 @@ func file_proto_nezha_proto_init() {
return return
} }
if !protoimpl.UnsafeEnabled { if !protoimpl.UnsafeEnabled {
file_proto_nezha_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { file_proto_nezha_proto_msgTypes[0].Exporter = func(v any, i int) any {
switch v := v.(*Host); i { switch v := v.(*Host); i {
case 0: case 0:
return &v.state return &v.state
@ -737,7 +793,7 @@ func file_proto_nezha_proto_init() {
return nil return nil
} }
} }
file_proto_nezha_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { file_proto_nezha_proto_msgTypes[1].Exporter = func(v any, i int) any {
switch v := v.(*State); i { switch v := v.(*State); i {
case 0: case 0:
return &v.state return &v.state
@ -749,7 +805,7 @@ func file_proto_nezha_proto_init() {
return nil return nil
} }
} }
file_proto_nezha_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { file_proto_nezha_proto_msgTypes[2].Exporter = func(v any, i int) any {
switch v := v.(*State_SensorTemperature); i { switch v := v.(*State_SensorTemperature); i {
case 0: case 0:
return &v.state return &v.state
@ -761,7 +817,7 @@ func file_proto_nezha_proto_init() {
return nil return nil
} }
} }
file_proto_nezha_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { file_proto_nezha_proto_msgTypes[3].Exporter = func(v any, i int) any {
switch v := v.(*Task); i { switch v := v.(*Task); i {
case 0: case 0:
return &v.state return &v.state
@ -773,7 +829,7 @@ func file_proto_nezha_proto_init() {
return nil return nil
} }
} }
file_proto_nezha_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { file_proto_nezha_proto_msgTypes[4].Exporter = func(v any, i int) any {
switch v := v.(*TaskResult); i { switch v := v.(*TaskResult); i {
case 0: case 0:
return &v.state return &v.state
@ -785,7 +841,7 @@ func file_proto_nezha_proto_init() {
return nil return nil
} }
} }
file_proto_nezha_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { file_proto_nezha_proto_msgTypes[5].Exporter = func(v any, i int) any {
switch v := v.(*Receipt); i { switch v := v.(*Receipt); i {
case 0: case 0:
return &v.state return &v.state
@ -797,6 +853,18 @@ func file_proto_nezha_proto_init() {
return nil return nil
} }
} }
file_proto_nezha_proto_msgTypes[6].Exporter = func(v any, i int) any {
switch v := v.(*IOStreamData); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
} }
type x struct{} type x struct{}
out := protoimpl.TypeBuilder{ out := protoimpl.TypeBuilder{
@ -804,7 +872,7 @@ func file_proto_nezha_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(), GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_proto_nezha_proto_rawDesc, RawDescriptor: file_proto_nezha_proto_rawDesc,
NumEnums: 0, NumEnums: 0,
NumMessages: 6, NumMessages: 7,
NumExtensions: 0, NumExtensions: 0,
NumServices: 1, NumServices: 1,
}, },

View File

@ -8,6 +8,7 @@ service NezhaService {
rpc ReportSystemInfo(Host)returns(Receipt){} rpc ReportSystemInfo(Host)returns(Receipt){}
rpc ReportTask(TaskResult)returns(Receipt){} rpc ReportTask(TaskResult)returns(Receipt){}
rpc RequestTask(Host)returns(stream Task){} rpc RequestTask(Host)returns(stream Task){}
rpc IOStream(stream IOStreamData)returns(stream IOStreamData){}
} }
message Host { message Host {
@ -68,3 +69,7 @@ message TaskResult {
message Receipt{ message Receipt{
bool proced = 1; bool proced = 1;
} }
message IOStreamData {
bytes data = 1;
}

View File

@ -1,7 +1,7 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT. // Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions: // versions:
// - protoc-gen-go-grpc v1.3.0 // - protoc-gen-go-grpc v1.3.0
// - protoc v3.21.12 // - protoc v5.27.1
// source: proto/nezha.proto // source: proto/nezha.proto
package proto package proto
@ -23,6 +23,7 @@ const (
NezhaService_ReportSystemInfo_FullMethodName = "/proto.NezhaService/ReportSystemInfo" NezhaService_ReportSystemInfo_FullMethodName = "/proto.NezhaService/ReportSystemInfo"
NezhaService_ReportTask_FullMethodName = "/proto.NezhaService/ReportTask" NezhaService_ReportTask_FullMethodName = "/proto.NezhaService/ReportTask"
NezhaService_RequestTask_FullMethodName = "/proto.NezhaService/RequestTask" NezhaService_RequestTask_FullMethodName = "/proto.NezhaService/RequestTask"
NezhaService_IOStream_FullMethodName = "/proto.NezhaService/IOStream"
) )
// NezhaServiceClient is the client API for NezhaService service. // NezhaServiceClient is the client API for NezhaService service.
@ -33,6 +34,7 @@ type NezhaServiceClient interface {
ReportSystemInfo(ctx context.Context, in *Host, opts ...grpc.CallOption) (*Receipt, error) ReportSystemInfo(ctx context.Context, in *Host, opts ...grpc.CallOption) (*Receipt, error)
ReportTask(ctx context.Context, in *TaskResult, opts ...grpc.CallOption) (*Receipt, error) ReportTask(ctx context.Context, in *TaskResult, opts ...grpc.CallOption) (*Receipt, error)
RequestTask(ctx context.Context, in *Host, opts ...grpc.CallOption) (NezhaService_RequestTaskClient, error) RequestTask(ctx context.Context, in *Host, opts ...grpc.CallOption) (NezhaService_RequestTaskClient, error)
IOStream(ctx context.Context, opts ...grpc.CallOption) (NezhaService_IOStreamClient, error)
} }
type nezhaServiceClient struct { type nezhaServiceClient struct {
@ -102,6 +104,37 @@ func (x *nezhaServiceRequestTaskClient) Recv() (*Task, error) {
return m, nil return m, nil
} }
func (c *nezhaServiceClient) IOStream(ctx context.Context, opts ...grpc.CallOption) (NezhaService_IOStreamClient, error) {
stream, err := c.cc.NewStream(ctx, &NezhaService_ServiceDesc.Streams[1], NezhaService_IOStream_FullMethodName, opts...)
if err != nil {
return nil, err
}
x := &nezhaServiceIOStreamClient{stream}
return x, nil
}
type NezhaService_IOStreamClient interface {
Send(*IOStreamData) error
Recv() (*IOStreamData, error)
grpc.ClientStream
}
type nezhaServiceIOStreamClient struct {
grpc.ClientStream
}
func (x *nezhaServiceIOStreamClient) Send(m *IOStreamData) error {
return x.ClientStream.SendMsg(m)
}
func (x *nezhaServiceIOStreamClient) Recv() (*IOStreamData, error) {
m := new(IOStreamData)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// NezhaServiceServer is the server API for NezhaService service. // NezhaServiceServer is the server API for NezhaService service.
// All implementations should embed UnimplementedNezhaServiceServer // All implementations should embed UnimplementedNezhaServiceServer
// for forward compatibility // for forward compatibility
@ -110,6 +143,7 @@ type NezhaServiceServer interface {
ReportSystemInfo(context.Context, *Host) (*Receipt, error) ReportSystemInfo(context.Context, *Host) (*Receipt, error)
ReportTask(context.Context, *TaskResult) (*Receipt, error) ReportTask(context.Context, *TaskResult) (*Receipt, error)
RequestTask(*Host, NezhaService_RequestTaskServer) error RequestTask(*Host, NezhaService_RequestTaskServer) error
IOStream(NezhaService_IOStreamServer) error
} }
// UnimplementedNezhaServiceServer should be embedded to have forward compatible implementations. // UnimplementedNezhaServiceServer should be embedded to have forward compatible implementations.
@ -128,6 +162,9 @@ func (UnimplementedNezhaServiceServer) ReportTask(context.Context, *TaskResult)
func (UnimplementedNezhaServiceServer) RequestTask(*Host, NezhaService_RequestTaskServer) error { func (UnimplementedNezhaServiceServer) RequestTask(*Host, NezhaService_RequestTaskServer) error {
return status.Errorf(codes.Unimplemented, "method RequestTask not implemented") return status.Errorf(codes.Unimplemented, "method RequestTask not implemented")
} }
func (UnimplementedNezhaServiceServer) IOStream(NezhaService_IOStreamServer) error {
return status.Errorf(codes.Unimplemented, "method IOStream not implemented")
}
// UnsafeNezhaServiceServer may be embedded to opt out of forward compatibility for this service. // UnsafeNezhaServiceServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to NezhaServiceServer will // Use of this interface is not recommended, as added methods to NezhaServiceServer will
@ -215,6 +252,32 @@ func (x *nezhaServiceRequestTaskServer) Send(m *Task) error {
return x.ServerStream.SendMsg(m) return x.ServerStream.SendMsg(m)
} }
func _NezhaService_IOStream_Handler(srv interface{}, stream grpc.ServerStream) error {
return srv.(NezhaServiceServer).IOStream(&nezhaServiceIOStreamServer{stream})
}
type NezhaService_IOStreamServer interface {
Send(*IOStreamData) error
Recv() (*IOStreamData, error)
grpc.ServerStream
}
type nezhaServiceIOStreamServer struct {
grpc.ServerStream
}
func (x *nezhaServiceIOStreamServer) Send(m *IOStreamData) error {
return x.ServerStream.SendMsg(m)
}
func (x *nezhaServiceIOStreamServer) Recv() (*IOStreamData, error) {
m := new(IOStreamData)
if err := x.ServerStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// NezhaService_ServiceDesc is the grpc.ServiceDesc for NezhaService service. // NezhaService_ServiceDesc is the grpc.ServiceDesc for NezhaService service.
// It's only intended for direct use with grpc.RegisterService, // It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy) // and not to be introspected or modified (even as a copy)
@ -241,6 +304,12 @@ var NezhaService_ServiceDesc = grpc.ServiceDesc{
Handler: _NezhaService_RequestTask_Handler, Handler: _NezhaService_RequestTask_Handler,
ServerStreams: true, ServerStreams: true,
}, },
{
StreamName: "IOStream",
Handler: _NezhaService_IOStream_Handler,
ServerStreams: true,
ClientStreams: true,
},
}, },
Metadata: "proto/nezha.proto", Metadata: "proto/nezha.proto",
} }

View File

@ -1 +1,3 @@
protoc --go-grpc_out="require_unimplemented_servers=false:." --go_out="." proto/*.proto protoc --go-grpc_out="require_unimplemented_servers=false:." --go_out="." proto/*.proto
rm -rf ../agent/proto
cp -r proto ../agent

View File

@ -10,19 +10,19 @@ import (
"github.com/naiba/nezha/service/singleton" "github.com/naiba/nezha/service/singleton"
) )
type AuthHandler struct { type authHandler struct {
ClientSecret string ClientSecret string
} }
func (a *AuthHandler) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { func (a *authHandler) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
return map[string]string{"client_secret": a.ClientSecret}, nil return map[string]string{"client_secret": a.ClientSecret}, nil
} }
func (a *AuthHandler) RequireTransportSecurity() bool { func (a *authHandler) RequireTransportSecurity() bool {
return false return false
} }
func (a *AuthHandler) Check(ctx context.Context) (uint64, error) { func (a *authHandler) Check(ctx context.Context) (uint64, error) {
md, ok := metadata.FromIncomingContext(ctx) md, ok := metadata.FromIncomingContext(ctx)
if !ok { if !ok {
return 0, status.Errorf(codes.Unauthenticated, "获取 metaData 失败") return 0, status.Errorf(codes.Unauthenticated, "获取 metaData 失败")

140
service/rpc/io_stream.go Normal file
View File

@ -0,0 +1,140 @@
package rpc
import (
"errors"
"io"
"sync/atomic"
"time"
)
type ioStreamContext struct {
userIo io.ReadWriteCloser
agentIo io.ReadWriteCloser
userIoConnectCh chan struct{}
agentIoConnectCh chan struct{}
}
func (s *NezhaHandler) CreateStream(streamId string) {
s.ioStreamMutex.Lock()
defer s.ioStreamMutex.Unlock()
s.ioStreams[streamId] = &ioStreamContext{
userIoConnectCh: make(chan struct{}),
agentIoConnectCh: make(chan struct{}),
}
}
func (s *NezhaHandler) GetStream(streamId string) (*ioStreamContext, error) {
s.ioStreamMutex.RLock()
defer s.ioStreamMutex.RUnlock()
if ctx, ok := s.ioStreams[streamId]; ok {
return ctx, nil
}
return nil, errors.New("stream not found")
}
func (s *NezhaHandler) CloseStream(streamId string) error {
s.ioStreamMutex.Lock()
defer s.ioStreamMutex.Unlock()
if ctx, ok := s.ioStreams[streamId]; ok {
if ctx.userIo != nil {
ctx.userIo.Close()
}
if ctx.agentIo != nil {
ctx.agentIo.Close()
}
delete(s.ioStreams, streamId)
}
return nil
}
func (s *NezhaHandler) UserConnected(streamId string, userIo io.ReadWriteCloser) error {
stream, err := s.GetStream(streamId)
if err != nil {
return err
}
stream.userIo = userIo
close(stream.userIoConnectCh)
return nil
}
func (s *NezhaHandler) AgentConnected(streamId string, agentIo io.ReadWriteCloser) error {
stream, err := s.GetStream(streamId)
if err != nil {
return err
}
stream.agentIo = agentIo
close(stream.agentIoConnectCh)
return nil
}
func (s *NezhaHandler) StartStream(streamId string, timeout time.Duration) error {
stream, err := s.GetStream(streamId)
if err != nil {
return err
}
timeoutTimer := time.NewTimer(timeout)
LOOP:
for {
select {
case <-stream.userIoConnectCh:
if stream.agentIo != nil {
timeoutTimer.Stop()
break LOOP
}
case <-stream.agentIoConnectCh:
if stream.userIo != nil {
timeoutTimer.Stop()
break LOOP
}
case <-time.After(timeout):
break LOOP
}
}
if stream.userIo == nil && stream.agentIo == nil {
return errors.New("timeout: no connection established")
}
if stream.userIo == nil {
return errors.New("timeout: user connection not established")
}
if stream.agentIo == nil {
return errors.New("timeout: agent connection not established")
}
isDone := new(atomic.Bool)
endCh := make(chan struct{})
go func() {
_, innerErr := io.Copy(stream.userIo, stream.agentIo)
if innerErr != nil {
err = innerErr
}
if isDone.CompareAndSwap(false, true) {
close(endCh)
}
}()
go func() {
_, innerErr := io.Copy(stream.agentIo, stream.userIo)
if innerErr != nil {
err = innerErr
}
if isDone.CompareAndSwap(false, true) {
close(endCh)
}
}()
<-endCh
return err
}

View File

@ -3,11 +3,14 @@ package rpc
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/naiba/nezha/pkg/ddns"
"github.com/naiba/nezha/pkg/utils"
"log" "log"
"sync"
"time" "time"
"github.com/naiba/nezha/pkg/ddns"
"github.com/naiba/nezha/pkg/grpcx"
"github.com/naiba/nezha/pkg/utils"
"github.com/jinzhu/copier" "github.com/jinzhu/copier"
"github.com/nicksnyder/go-i18n/v2/i18n" "github.com/nicksnyder/go-i18n/v2/i18n"
@ -16,8 +19,20 @@ import (
"github.com/naiba/nezha/service/singleton" "github.com/naiba/nezha/service/singleton"
) )
var NezhaHandlerSingleton *NezhaHandler
type NezhaHandler struct { type NezhaHandler struct {
Auth *AuthHandler Auth *authHandler
ioStreams map[string]*ioStreamContext
ioStreamMutex *sync.RWMutex
}
func NewNezhaHandler() *NezhaHandler {
return &NezhaHandler{
Auth: &authHandler{},
ioStreamMutex: new(sync.RWMutex),
ioStreams: make(map[string]*ioStreamContext),
}
} }
func (s *NezhaHandler) ReportTask(c context.Context, r *pb.TaskResult) (*pb.Receipt, error) { func (s *NezhaHandler) ReportTask(c context.Context, r *pb.TaskResult) (*pb.Receipt, error) {
@ -177,3 +192,28 @@ func (s *NezhaHandler) ReportSystemInfo(c context.Context, r *pb.Host) (*pb.Rece
singleton.ServerList[clientID].Host = &host singleton.ServerList[clientID].Host = &host
return &pb.Receipt{Proced: true}, nil return &pb.Receipt{Proced: true}, nil
} }
func (s *NezhaHandler) IOStream(stream pb.NezhaService_IOStreamServer) error {
if _, err := s.Auth.Check(stream.Context()); err != nil {
return err
}
id, err := stream.Recv()
if err != nil {
return err
}
if id == nil || len(id.Data) < 4 || (id.Data[0] != 0xff && id.Data[1] != 0x05 && id.Data[2] != 0xff && id.Data[3] == 0x05) {
return fmt.Errorf("invalid stream id")
}
streamId := string(id.Data[4:])
if _, err := s.GetStream(streamId); err != nil {
return err
}
iw := grpcx.NewIOStreamWrapper(stream)
if err := s.AgentConnected(streamId, iw); err != nil {
return err
}
iw.Wait()
return nil
}