feat: 去除 webTerminal 的 websocket 依赖

This commit is contained in:
naiba 2024-07-14 12:47:36 +08:00
parent 417f972659
commit 1c91fcffac
16 changed files with 497 additions and 277 deletions

View File

@ -3,12 +3,8 @@ package controller
import (
"errors"
"fmt"
"log"
"net/http"
"regexp"
"strconv"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
@ -24,22 +20,13 @@ import (
"github.com/naiba/nezha/pkg/utils"
"github.com/naiba/nezha/pkg/websocketx"
"github.com/naiba/nezha/proto"
"github.com/naiba/nezha/service/rpc"
"github.com/naiba/nezha/service/singleton"
)
type terminalContext struct {
agentConn *websocketx.Conn
userConn *websocketx.Conn
serverID uint64
host string
useSSL bool
}
type commonPage struct {
r *gin.Engine
terminals map[string]*terminalContext
terminalsLock *sync.Mutex
requestGroup singleflight.Group
r *gin.Engine
requestGroup singleflight.Group
}
func (cp *commonPage) serve() {
@ -68,7 +55,6 @@ type viewPasswordForm struct {
func (p *commonPage) issueViewPassword(c *gin.Context) {
var vpf viewPasswordForm
err := c.ShouldBind(&vpf)
log.Println("bingo", vpf)
var hash []byte
if err == nil && vpf.Password != singleton.Conf.Site.ViewPassword {
err = errors.New(singleton.Localizer.MustLocalize(&i18n.LocalizeConfig{MessageID: "WrongAccessPassword"}))
@ -283,8 +269,6 @@ type Data struct {
Servers []*model.Server `json:"servers,omitempty"`
}
var cloudflareCookiesValidator = regexp.MustCompile("^[A-Za-z0-9-_]+$")
func (cp *commonPage) ws(c *gin.Context) {
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
@ -322,9 +306,7 @@ func (cp *commonPage) ws(c *gin.Context) {
func (cp *commonPage) terminal(c *gin.Context) {
terminalID := c.Param("id")
cp.terminalsLock.Lock()
if terminalID == "" || cp.terminals[terminalID] == nil {
cp.terminalsLock.Unlock()
if _, err := rpc.NezhaHandlerSingleton.GetStream(terminalID); err != nil {
mygin.ShowErrorPage(c, mygin.ErrInfo{
Code: http.StatusForbidden,
Title: "无权访问",
@ -334,104 +316,7 @@ func (cp *commonPage) terminal(c *gin.Context) {
}, true)
return
}
terminal := cp.terminals[terminalID]
cp.terminalsLock.Unlock()
defer func() {
// 清理 context
cp.terminalsLock.Lock()
defer cp.terminalsLock.Unlock()
delete(cp.terminals, terminalID)
}()
var isAgent bool
if _, authorized := c.Get(model.CtxKeyAuthorizedUser); !authorized {
singleton.ServerLock.RLock()
_, hasID := singleton.SecretToID[c.Request.Header.Get("Secret")]
singleton.ServerLock.RUnlock()
if !hasID {
mygin.ShowErrorPage(c, mygin.ErrInfo{
Code: http.StatusForbidden,
Title: "无权访问",
Msg: "用户未登录或非法终端",
Link: "/",
Btn: "返回首页",
}, true)
return
}
if terminal.userConn == nil {
mygin.ShowErrorPage(c, mygin.ErrInfo{
Code: http.StatusForbidden,
Title: "无权访问",
Msg: "用户不在线",
Link: "/",
Btn: "返回首页",
}, true)
return
}
if terminal.agentConn != nil {
mygin.ShowErrorPage(c, mygin.ErrInfo{
Code: http.StatusInternalServerError,
Title: "连接已存在",
Msg: "Websocket协议切换失败",
Link: "/",
Btn: "返回首页",
}, true)
return
}
isAgent = true
} else {
singleton.ServerLock.RLock()
server := singleton.ServerList[terminal.serverID]
singleton.ServerLock.RUnlock()
if server == nil || server.TaskStream == nil {
mygin.ShowErrorPage(c, mygin.ErrInfo{
Code: http.StatusForbidden,
Title: "请求失败",
Msg: "服务器不存在或处于离线状态",
Link: "/server",
Btn: "返回重试",
}, true)
return
}
cloudflareCookies, _ := c.Cookie("CF_Authorization")
// Cloudflare Cookies 合法性验证
// 其应该包含.分隔的三组BASE64-URL编码
if cloudflareCookies != "" {
encodedCookies := strings.Split(cloudflareCookies, ".")
if len(encodedCookies) == 3 {
for i := 0; i < 3; i++ {
if !cloudflareCookiesValidator.MatchString(encodedCookies[i]) {
cloudflareCookies = ""
break
}
}
} else {
cloudflareCookies = ""
}
}
terminalData, _ := utils.Json.Marshal(&model.TerminalTask{
Host: terminal.host,
UseSSL: terminal.useSSL,
Session: terminalID,
Cookie: cloudflareCookies,
})
if err := server.TaskStream.Send(&proto.Task{
Type: model.TaskTypeTerminal,
Data: string(terminalData),
}); err != nil {
mygin.ShowErrorPage(c, mygin.ErrInfo{
Code: http.StatusForbidden,
Title: "请求失败",
Msg: "Agent信令下发失败",
Link: "/server",
Btn: "返回重试",
}, true)
return
}
}
defer rpc.NezhaHandlerSingleton.CloseStream(terminalID)
wsConn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
@ -447,36 +332,7 @@ func (cp *commonPage) terminal(c *gin.Context) {
return
}
defer wsConn.Close()
conn := &websocketx.Conn{Conn: wsConn}
log.Printf("NEZHA>> terminal connected %t %q", isAgent, c.Request.URL)
defer log.Printf("NEZHA>> terminal disconnected %t %q", isAgent, c.Request.URL)
if isAgent {
terminal.agentConn = conn
defer func() {
// Agent断开链接时断开用户连接
if terminal.userConn != nil {
terminal.userConn.Close()
}
}()
} else {
terminal.userConn = conn
defer func() {
// 用户断开链接时断开 Agent 连接
if terminal.agentConn != nil {
terminal.agentConn.Close()
}
}()
}
deadlineCh := make(chan interface{})
go func() {
// 对方连接超时
connectDeadline := time.NewTimer(time.Second * 15)
<-connectDeadline.C
deadlineCh <- struct{}{}
}()
conn := websocketx.NewConn(wsConn)
go func() {
// PING 保活
@ -488,58 +344,11 @@ func (cp *commonPage) terminal(c *gin.Context) {
}
}()
dataCh := make(chan []byte)
errorCh := make(chan error)
go func() {
for {
msgType, data, err := conn.ReadMessage()
if err != nil {
errorCh <- err
return
}
// 将文本消息转换为命令输入
if msgType == websocket.TextMessage {
data = append([]byte{0}, data...)
}
dataCh <- data
}
}()
var dataBuffer [][]byte
var distConn *websocketx.Conn
checkDistConn := func() {
if distConn == nil {
if isAgent {
distConn = terminal.userConn
} else {
distConn = terminal.agentConn
}
}
if err = rpc.NezhaHandlerSingleton.UserConnected(terminalID, conn); err != nil {
return
}
for {
select {
case <-deadlineCh:
checkDistConn()
if distConn == nil {
return
}
case <-errorCh:
return
case data := <-dataCh:
dataBuffer = append(dataBuffer, data)
checkDistConn()
if distConn != nil {
for i := 0; i < len(dataBuffer); i++ {
err = distConn.WriteMessage(websocket.BinaryMessage, dataBuffer[i])
if err != nil {
return
}
}
dataBuffer = dataBuffer[:0]
}
}
}
rpc.NezhaHandlerSingleton.StartStream(terminalID, time.Second*10)
}
type createTerminalRequest struct {
@ -585,6 +394,8 @@ func (cp *commonPage) createTerminal(c *gin.Context) {
return
}
rpc.NezhaHandlerSingleton.CreateStream(id)
singleton.ServerLock.RLock()
server := singleton.ServerList[createTerminalReq.ID]
singleton.ServerLock.RUnlock()
@ -599,13 +410,21 @@ func (cp *commonPage) createTerminal(c *gin.Context) {
return
}
cp.terminalsLock.Lock()
defer cp.terminalsLock.Unlock()
cp.terminals[id] = &terminalContext{
serverID: createTerminalReq.ID,
host: createTerminalReq.Host,
useSSL: createTerminalReq.Protocol == "https:",
terminalData, _ := utils.Json.Marshal(&model.TerminalTask{
StreamID: id,
})
if err := server.TaskStream.Send(&proto.Task{
Type: model.TaskTypeTerminalGRPC,
Data: string(terminalData),
}); err != nil {
mygin.ShowErrorPage(c, mygin.ErrInfo{
Code: http.StatusForbidden,
Title: "请求失败",
Msg: "Agent信令下发失败",
Link: "/server",
Btn: "返回重试",
}, true)
return
}
c.HTML(http.StatusOK, "dashboard-"+singleton.Conf.Site.DashboardTheme+"/terminal", mygin.CommonEnvironment(c, gin.H{

View File

@ -9,7 +9,6 @@ import (
"os"
"strconv"
"strings"
"sync"
"time"
"code.cloudfoundry.org/bytefmt"
@ -68,7 +67,7 @@ func ServeWeb(port uint) *http.Server {
func routers(r *gin.Engine) {
// 通用页面
cp := commonPage{r: r, terminals: make(map[string]*terminalContext), terminalsLock: new(sync.Mutex)}
cp := commonPage{r: r}
cp.serve()
// 游客页面
gp := guestPage{r}

View File

@ -59,6 +59,7 @@ func main() {
return
}
// TODO 使用 cmux 在同一端口服务 HTTP 和 gRPC
singleton.CleanMonitorHistory()
go rpc.ServeRPC(singleton.Conf.GRPCPort)
serviceSentinelDispatchBus := make(chan model.Monitor) // 用于传递服务监控任务信息的channel

View File

@ -14,9 +14,8 @@ import (
func ServeRPC(port uint) {
server := grpc.NewServer()
pb.RegisterNezhaServiceServer(server, &rpcService.NezhaHandler{
Auth: &rpcService.AuthHandler{},
})
rpcService.NezhaHandlerSingleton = rpcService.NewNezhaHandler()
pb.RegisterNezhaServiceServer(server, rpcService.NezhaHandlerSingleton)
listen, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
panic(err)

3
go.mod
View File

@ -18,7 +18,6 @@ require (
github.com/ory/graceful v0.1.3
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/robfig/cron/v3 v3.0.1
github.com/samber/lo v1.39.0
github.com/spf13/pflag v1.0.5
github.com/spf13/viper v1.18.2
github.com/xanzy/go-gitlab v0.103.0
@ -28,7 +27,7 @@ require (
golang.org/x/sync v0.7.0
golang.org/x/text v0.16.0
google.golang.org/grpc v1.63.0
google.golang.org/protobuf v1.33.0
google.golang.org/protobuf v1.34.2
gorm.io/driver/sqlite v1.5.5
gorm.io/gorm v1.25.10
sigs.k8s.io/yaml v1.4.0

6
go.sum
View File

@ -152,8 +152,6 @@ github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6ke
github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4=
github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE=
github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ=
github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA=
github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
@ -240,8 +238,8 @@ google.golang.org/grpc v1.63.0 h1:WjKe+dnvABXyPJMD7KDNLxtoGk5tgk+YFWN6cBWjZE8=
google.golang.org/grpc v1.63.0/go.mod h1:WAX/8DgncnokcFUldAxq7GeB5DXHDbMF+lLvDomNkRA=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=

View File

@ -20,17 +20,11 @@ const (
TaskTypeTerminal
TaskTypeUpgrade
TaskTypeKeepalive
TaskTypeTerminalGRPC
)
type TerminalTask struct {
// websocket 主机名
Host string `json:"host,omitempty"`
// 是否启用 SSL
UseSSL bool `json:"use_ssl,omitempty"`
// 会话标识
Session string `json:"session,omitempty"`
// Agent在连接Server时需要的额外Cookie信息
Cookie string `json:"cookie,omitempty"`
StreamID string
}
const (

View File

@ -0,0 +1,65 @@
package grpcx
import (
"context"
"io"
"sync/atomic"
"github.com/naiba/nezha/proto"
)
var _ io.ReadWriteCloser = &IOStreamWrapper{}
type IOStream interface {
Recv() (*proto.IOStreamData, error)
Send(*proto.IOStreamData) error
Context() context.Context
}
type IOStreamWrapper struct {
IOStream
dataBuf []byte
closed *atomic.Bool
closeCh chan struct{}
}
func NewIOStreamWrapper(stream IOStream) *IOStreamWrapper {
return &IOStreamWrapper{
IOStream: stream,
closeCh: make(chan struct{}),
closed: new(atomic.Bool),
}
}
func (iw *IOStreamWrapper) Read(p []byte) (n int, err error) {
if len(iw.dataBuf) > 0 {
n := copy(p, iw.dataBuf)
iw.dataBuf = iw.dataBuf[n:]
return n, nil
}
var data *proto.IOStreamData
if data, err = iw.Recv(); err != nil {
return 0, err
}
n = copy(p, data.Data)
if n < len(data.Data) {
iw.dataBuf = data.Data[n:]
}
return n, nil
}
func (iw *IOStreamWrapper) Write(p []byte) (n int, err error) {
err = iw.Send(&proto.IOStreamData{Data: p})
return len(p), err
}
func (iw *IOStreamWrapper) Close() error {
if iw.closed.CompareAndSwap(false, true) {
close(iw.closeCh)
}
return nil
}
func (iw *IOStreamWrapper) Wait() {
<-iw.closeCh
}

View File

@ -4,22 +4,44 @@ import (
"sync"
"github.com/gorilla/websocket"
"github.com/samber/lo"
)
type Conn struct {
*websocket.Conn
writeLock sync.Mutex
writeLock *sync.Mutex
dataBuf []byte
}
func (conn *Conn) WriteMessage(msgType int, data []byte) error {
func NewConn(conn *websocket.Conn) *Conn {
return &Conn{Conn: conn, writeLock: new(sync.Mutex)}
}
func (conn *Conn) Write(data []byte) (int, error) {
conn.writeLock.Lock()
defer conn.writeLock.Unlock()
var err error
lo.TryCatchWithErrorValue(func() error {
return conn.Conn.WriteMessage(msgType, data)
}, func(res any) {
err = res.(error)
})
return err
if err := conn.Conn.WriteMessage(websocket.BinaryMessage, data); err != nil {
return 0, err
}
return len(data), nil
}
func (conn *Conn) Read(data []byte) (int, error) {
if len(conn.dataBuf) > 0 {
n := copy(data, conn.dataBuf)
conn.dataBuf = conn.dataBuf[n:]
return n, nil
}
mType, innerData, err := conn.Conn.ReadMessage()
if err != nil {
return 0, err
}
// 将文本消息转换为命令输入
if mType == websocket.TextMessage {
innerData = append([]byte{0}, innerData...)
}
n := copy(data, innerData)
if n < len(innerData) {
conn.dataBuf = innerData[n:]
}
return n, nil
}

View File

@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.34.1
// protoc v5.26.1
// protoc-gen-go v1.34.2
// protoc v5.27.1
// source: proto/nezha.proto
package proto
@ -582,6 +582,53 @@ func (x *Receipt) GetProced() bool {
return false
}
type IOStreamData struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Data []byte `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"`
}
func (x *IOStreamData) Reset() {
*x = IOStreamData{}
if protoimpl.UnsafeEnabled {
mi := &file_proto_nezha_proto_msgTypes[6]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *IOStreamData) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*IOStreamData) ProtoMessage() {}
func (x *IOStreamData) ProtoReflect() protoreflect.Message {
mi := &file_proto_nezha_proto_msgTypes[6]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use IOStreamData.ProtoReflect.Descriptor instead.
func (*IOStreamData) Descriptor() ([]byte, []int) {
return file_proto_nezha_proto_rawDescGZIP(), []int{6}
}
func (x *IOStreamData) GetData() []byte {
if x != nil {
return x.Data
}
return nil
}
var File_proto_nezha_proto protoreflect.FileDescriptor
var file_proto_nezha_proto_rawDesc = []byte{
@ -663,21 +710,27 @@ var file_proto_nezha_proto_rawDesc = []byte{
0x65, 0x73, 0x73, 0x66, 0x75, 0x6c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x75,
0x63, 0x63, 0x65, 0x73, 0x73, 0x66, 0x75, 0x6c, 0x22, 0x21, 0x0a, 0x07, 0x52, 0x65, 0x63, 0x65,
0x69, 0x70, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x72, 0x6f, 0x63, 0x65, 0x64, 0x18, 0x01, 0x20,
0x01, 0x28, 0x08, 0x52, 0x06, 0x70, 0x72, 0x6f, 0x63, 0x65, 0x64, 0x32, 0xd6, 0x01, 0x0a, 0x0c,
0x4e, 0x65, 0x7a, 0x68, 0x61, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x33, 0x0a, 0x11,
0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x53, 0x74, 0x61, 0x74,
0x65, 0x12, 0x0c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x1a,
0x0e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x65, 0x69, 0x70, 0x74, 0x22,
0x00, 0x12, 0x31, 0x0a, 0x10, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x79, 0x73, 0x74, 0x65,
0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x0b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x48, 0x6f,
0x73, 0x74, 0x1a, 0x0e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x65, 0x69,
0x70, 0x74, 0x22, 0x00, 0x12, 0x31, 0x0a, 0x0a, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x54, 0x61,
0x73, 0x6b, 0x12, 0x11, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x54, 0x61, 0x73, 0x6b, 0x52,
0x65, 0x73, 0x75, 0x6c, 0x74, 0x1a, 0x0e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65,
0x63, 0x65, 0x69, 0x70, 0x74, 0x22, 0x00, 0x12, 0x2b, 0x0a, 0x0b, 0x52, 0x65, 0x71, 0x75, 0x65,
0x73, 0x74, 0x54, 0x61, 0x73, 0x6b, 0x12, 0x0b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x48,
0x6f, 0x73, 0x74, 0x1a, 0x0b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x54, 0x61, 0x73, 0x6b,
0x22, 0x00, 0x30, 0x01, 0x42, 0x09, 0x5a, 0x07, 0x2e, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62,
0x01, 0x28, 0x08, 0x52, 0x06, 0x70, 0x72, 0x6f, 0x63, 0x65, 0x64, 0x22, 0x22, 0x0a, 0x0c, 0x49,
0x4f, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x44, 0x61, 0x74, 0x61, 0x12, 0x12, 0x0a, 0x04, 0x64,
0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x64, 0x61, 0x74, 0x61, 0x32,
0x92, 0x02, 0x0a, 0x0c, 0x4e, 0x65, 0x7a, 0x68, 0x61, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65,
0x12, 0x33, 0x0a, 0x11, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d,
0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x53, 0x74,
0x61, 0x74, 0x65, 0x1a, 0x0e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x65,
0x69, 0x70, 0x74, 0x22, 0x00, 0x12, 0x31, 0x0a, 0x10, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x53,
0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x0b, 0x2e, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x1a, 0x0e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52,
0x65, 0x63, 0x65, 0x69, 0x70, 0x74, 0x22, 0x00, 0x12, 0x31, 0x0a, 0x0a, 0x52, 0x65, 0x70, 0x6f,
0x72, 0x74, 0x54, 0x61, 0x73, 0x6b, 0x12, 0x11, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x54,
0x61, 0x73, 0x6b, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x1a, 0x0e, 0x2e, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x2e, 0x52, 0x65, 0x63, 0x65, 0x69, 0x70, 0x74, 0x22, 0x00, 0x12, 0x2b, 0x0a, 0x0b, 0x52,
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x54, 0x61, 0x73, 0x6b, 0x12, 0x0b, 0x2e, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x1a, 0x0b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e,
0x54, 0x61, 0x73, 0x6b, 0x22, 0x00, 0x30, 0x01, 0x12, 0x3a, 0x0a, 0x08, 0x49, 0x4f, 0x53, 0x74,
0x72, 0x65, 0x61, 0x6d, 0x12, 0x13, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x49, 0x4f, 0x53,
0x74, 0x72, 0x65, 0x61, 0x6d, 0x44, 0x61, 0x74, 0x61, 0x1a, 0x13, 0x2e, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x2e, 0x49, 0x4f, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x44, 0x61, 0x74, 0x61, 0x22, 0x00,
0x28, 0x01, 0x30, 0x01, 0x42, 0x09, 0x5a, 0x07, 0x2e, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62,
0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
@ -693,14 +746,15 @@ func file_proto_nezha_proto_rawDescGZIP() []byte {
return file_proto_nezha_proto_rawDescData
}
var file_proto_nezha_proto_msgTypes = make([]protoimpl.MessageInfo, 6)
var file_proto_nezha_proto_goTypes = []interface{}{
var file_proto_nezha_proto_msgTypes = make([]protoimpl.MessageInfo, 7)
var file_proto_nezha_proto_goTypes = []any{
(*Host)(nil), // 0: proto.Host
(*State)(nil), // 1: proto.State
(*State_SensorTemperature)(nil), // 2: proto.State_SensorTemperature
(*Task)(nil), // 3: proto.Task
(*TaskResult)(nil), // 4: proto.TaskResult
(*Receipt)(nil), // 5: proto.Receipt
(*IOStreamData)(nil), // 6: proto.IOStreamData
}
var file_proto_nezha_proto_depIdxs = []int32{
2, // 0: proto.State.temperatures:type_name -> proto.State_SensorTemperature
@ -708,12 +762,14 @@ var file_proto_nezha_proto_depIdxs = []int32{
0, // 2: proto.NezhaService.ReportSystemInfo:input_type -> proto.Host
4, // 3: proto.NezhaService.ReportTask:input_type -> proto.TaskResult
0, // 4: proto.NezhaService.RequestTask:input_type -> proto.Host
5, // 5: proto.NezhaService.ReportSystemState:output_type -> proto.Receipt
5, // 6: proto.NezhaService.ReportSystemInfo:output_type -> proto.Receipt
5, // 7: proto.NezhaService.ReportTask:output_type -> proto.Receipt
3, // 8: proto.NezhaService.RequestTask:output_type -> proto.Task
5, // [5:9] is the sub-list for method output_type
1, // [1:5] is the sub-list for method input_type
6, // 5: proto.NezhaService.IOStream:input_type -> proto.IOStreamData
5, // 6: proto.NezhaService.ReportSystemState:output_type -> proto.Receipt
5, // 7: proto.NezhaService.ReportSystemInfo:output_type -> proto.Receipt
5, // 8: proto.NezhaService.ReportTask:output_type -> proto.Receipt
3, // 9: proto.NezhaService.RequestTask:output_type -> proto.Task
6, // 10: proto.NezhaService.IOStream:output_type -> proto.IOStreamData
6, // [6:11] is the sub-list for method output_type
1, // [1:6] is the sub-list for method input_type
1, // [1:1] is the sub-list for extension type_name
1, // [1:1] is the sub-list for extension extendee
0, // [0:1] is the sub-list for field type_name
@ -725,7 +781,7 @@ func file_proto_nezha_proto_init() {
return
}
if !protoimpl.UnsafeEnabled {
file_proto_nezha_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
file_proto_nezha_proto_msgTypes[0].Exporter = func(v any, i int) any {
switch v := v.(*Host); i {
case 0:
return &v.state
@ -737,7 +793,7 @@ func file_proto_nezha_proto_init() {
return nil
}
}
file_proto_nezha_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
file_proto_nezha_proto_msgTypes[1].Exporter = func(v any, i int) any {
switch v := v.(*State); i {
case 0:
return &v.state
@ -749,7 +805,7 @@ func file_proto_nezha_proto_init() {
return nil
}
}
file_proto_nezha_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
file_proto_nezha_proto_msgTypes[2].Exporter = func(v any, i int) any {
switch v := v.(*State_SensorTemperature); i {
case 0:
return &v.state
@ -761,7 +817,7 @@ func file_proto_nezha_proto_init() {
return nil
}
}
file_proto_nezha_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
file_proto_nezha_proto_msgTypes[3].Exporter = func(v any, i int) any {
switch v := v.(*Task); i {
case 0:
return &v.state
@ -773,7 +829,7 @@ func file_proto_nezha_proto_init() {
return nil
}
}
file_proto_nezha_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
file_proto_nezha_proto_msgTypes[4].Exporter = func(v any, i int) any {
switch v := v.(*TaskResult); i {
case 0:
return &v.state
@ -785,7 +841,7 @@ func file_proto_nezha_proto_init() {
return nil
}
}
file_proto_nezha_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} {
file_proto_nezha_proto_msgTypes[5].Exporter = func(v any, i int) any {
switch v := v.(*Receipt); i {
case 0:
return &v.state
@ -797,6 +853,18 @@ func file_proto_nezha_proto_init() {
return nil
}
}
file_proto_nezha_proto_msgTypes[6].Exporter = func(v any, i int) any {
switch v := v.(*IOStreamData); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
@ -804,7 +872,7 @@ func file_proto_nezha_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_proto_nezha_proto_rawDesc,
NumEnums: 0,
NumMessages: 6,
NumMessages: 7,
NumExtensions: 0,
NumServices: 1,
},

View File

@ -8,6 +8,7 @@ service NezhaService {
rpc ReportSystemInfo(Host)returns(Receipt){}
rpc ReportTask(TaskResult)returns(Receipt){}
rpc RequestTask(Host)returns(stream Task){}
rpc IOStream(stream IOStreamData)returns(stream IOStreamData){}
}
message Host {
@ -68,3 +69,7 @@ message TaskResult {
message Receipt{
bool proced = 1;
}
message IOStreamData {
bytes data = 1;
}

View File

@ -1,7 +1,7 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.3.0
// - protoc v3.21.12
// - protoc v5.27.1
// source: proto/nezha.proto
package proto
@ -23,6 +23,7 @@ const (
NezhaService_ReportSystemInfo_FullMethodName = "/proto.NezhaService/ReportSystemInfo"
NezhaService_ReportTask_FullMethodName = "/proto.NezhaService/ReportTask"
NezhaService_RequestTask_FullMethodName = "/proto.NezhaService/RequestTask"
NezhaService_IOStream_FullMethodName = "/proto.NezhaService/IOStream"
)
// NezhaServiceClient is the client API for NezhaService service.
@ -33,6 +34,7 @@ type NezhaServiceClient interface {
ReportSystemInfo(ctx context.Context, in *Host, opts ...grpc.CallOption) (*Receipt, error)
ReportTask(ctx context.Context, in *TaskResult, opts ...grpc.CallOption) (*Receipt, error)
RequestTask(ctx context.Context, in *Host, opts ...grpc.CallOption) (NezhaService_RequestTaskClient, error)
IOStream(ctx context.Context, opts ...grpc.CallOption) (NezhaService_IOStreamClient, error)
}
type nezhaServiceClient struct {
@ -102,6 +104,37 @@ func (x *nezhaServiceRequestTaskClient) Recv() (*Task, error) {
return m, nil
}
func (c *nezhaServiceClient) IOStream(ctx context.Context, opts ...grpc.CallOption) (NezhaService_IOStreamClient, error) {
stream, err := c.cc.NewStream(ctx, &NezhaService_ServiceDesc.Streams[1], NezhaService_IOStream_FullMethodName, opts...)
if err != nil {
return nil, err
}
x := &nezhaServiceIOStreamClient{stream}
return x, nil
}
type NezhaService_IOStreamClient interface {
Send(*IOStreamData) error
Recv() (*IOStreamData, error)
grpc.ClientStream
}
type nezhaServiceIOStreamClient struct {
grpc.ClientStream
}
func (x *nezhaServiceIOStreamClient) Send(m *IOStreamData) error {
return x.ClientStream.SendMsg(m)
}
func (x *nezhaServiceIOStreamClient) Recv() (*IOStreamData, error) {
m := new(IOStreamData)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// NezhaServiceServer is the server API for NezhaService service.
// All implementations should embed UnimplementedNezhaServiceServer
// for forward compatibility
@ -110,6 +143,7 @@ type NezhaServiceServer interface {
ReportSystemInfo(context.Context, *Host) (*Receipt, error)
ReportTask(context.Context, *TaskResult) (*Receipt, error)
RequestTask(*Host, NezhaService_RequestTaskServer) error
IOStream(NezhaService_IOStreamServer) error
}
// UnimplementedNezhaServiceServer should be embedded to have forward compatible implementations.
@ -128,6 +162,9 @@ func (UnimplementedNezhaServiceServer) ReportTask(context.Context, *TaskResult)
func (UnimplementedNezhaServiceServer) RequestTask(*Host, NezhaService_RequestTaskServer) error {
return status.Errorf(codes.Unimplemented, "method RequestTask not implemented")
}
func (UnimplementedNezhaServiceServer) IOStream(NezhaService_IOStreamServer) error {
return status.Errorf(codes.Unimplemented, "method IOStream not implemented")
}
// UnsafeNezhaServiceServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to NezhaServiceServer will
@ -215,6 +252,32 @@ func (x *nezhaServiceRequestTaskServer) Send(m *Task) error {
return x.ServerStream.SendMsg(m)
}
func _NezhaService_IOStream_Handler(srv interface{}, stream grpc.ServerStream) error {
return srv.(NezhaServiceServer).IOStream(&nezhaServiceIOStreamServer{stream})
}
type NezhaService_IOStreamServer interface {
Send(*IOStreamData) error
Recv() (*IOStreamData, error)
grpc.ServerStream
}
type nezhaServiceIOStreamServer struct {
grpc.ServerStream
}
func (x *nezhaServiceIOStreamServer) Send(m *IOStreamData) error {
return x.ServerStream.SendMsg(m)
}
func (x *nezhaServiceIOStreamServer) Recv() (*IOStreamData, error) {
m := new(IOStreamData)
if err := x.ServerStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// NezhaService_ServiceDesc is the grpc.ServiceDesc for NezhaService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@ -241,6 +304,12 @@ var NezhaService_ServiceDesc = grpc.ServiceDesc{
Handler: _NezhaService_RequestTask_Handler,
ServerStreams: true,
},
{
StreamName: "IOStream",
Handler: _NezhaService_IOStream_Handler,
ServerStreams: true,
ClientStreams: true,
},
},
Metadata: "proto/nezha.proto",
}

View File

@ -1 +1,3 @@
protoc --go-grpc_out="require_unimplemented_servers=false:." --go_out="." proto/*.proto
protoc --go-grpc_out="require_unimplemented_servers=false:." --go_out="." proto/*.proto
rm -rf ../agent/proto
cp -r proto ../agent

View File

@ -10,19 +10,19 @@ import (
"github.com/naiba/nezha/service/singleton"
)
type AuthHandler struct {
type authHandler struct {
ClientSecret string
}
func (a *AuthHandler) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
func (a *authHandler) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
return map[string]string{"client_secret": a.ClientSecret}, nil
}
func (a *AuthHandler) RequireTransportSecurity() bool {
func (a *authHandler) RequireTransportSecurity() bool {
return false
}
func (a *AuthHandler) Check(ctx context.Context) (uint64, error) {
func (a *authHandler) Check(ctx context.Context) (uint64, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return 0, status.Errorf(codes.Unauthenticated, "获取 metaData 失败")

140
service/rpc/io_stream.go Normal file
View File

@ -0,0 +1,140 @@
package rpc
import (
"errors"
"io"
"sync/atomic"
"time"
)
type ioStreamContext struct {
userIo io.ReadWriteCloser
agentIo io.ReadWriteCloser
userIoConnectCh chan struct{}
agentIoConnectCh chan struct{}
}
func (s *NezhaHandler) CreateStream(streamId string) {
s.ioStreamMutex.Lock()
defer s.ioStreamMutex.Unlock()
s.ioStreams[streamId] = &ioStreamContext{
userIoConnectCh: make(chan struct{}),
agentIoConnectCh: make(chan struct{}),
}
}
func (s *NezhaHandler) GetStream(streamId string) (*ioStreamContext, error) {
s.ioStreamMutex.RLock()
defer s.ioStreamMutex.RUnlock()
if ctx, ok := s.ioStreams[streamId]; ok {
return ctx, nil
}
return nil, errors.New("stream not found")
}
func (s *NezhaHandler) CloseStream(streamId string) error {
s.ioStreamMutex.Lock()
defer s.ioStreamMutex.Unlock()
if ctx, ok := s.ioStreams[streamId]; ok {
if ctx.userIo != nil {
ctx.userIo.Close()
}
if ctx.agentIo != nil {
ctx.agentIo.Close()
}
delete(s.ioStreams, streamId)
}
return nil
}
func (s *NezhaHandler) UserConnected(streamId string, userIo io.ReadWriteCloser) error {
stream, err := s.GetStream(streamId)
if err != nil {
return err
}
stream.userIo = userIo
close(stream.userIoConnectCh)
return nil
}
func (s *NezhaHandler) AgentConnected(streamId string, agentIo io.ReadWriteCloser) error {
stream, err := s.GetStream(streamId)
if err != nil {
return err
}
stream.agentIo = agentIo
close(stream.agentIoConnectCh)
return nil
}
func (s *NezhaHandler) StartStream(streamId string, timeout time.Duration) error {
stream, err := s.GetStream(streamId)
if err != nil {
return err
}
timeoutTimer := time.NewTimer(timeout)
LOOP:
for {
select {
case <-stream.userIoConnectCh:
if stream.agentIo != nil {
timeoutTimer.Stop()
break LOOP
}
case <-stream.agentIoConnectCh:
if stream.userIo != nil {
timeoutTimer.Stop()
break LOOP
}
case <-time.After(timeout):
break LOOP
}
}
if stream.userIo == nil && stream.agentIo == nil {
return errors.New("timeout: no connection established")
}
if stream.userIo == nil {
return errors.New("timeout: user connection not established")
}
if stream.agentIo == nil {
return errors.New("timeout: agent connection not established")
}
isDone := new(atomic.Bool)
endCh := make(chan struct{})
go func() {
_, innerErr := io.Copy(stream.userIo, stream.agentIo)
if innerErr != nil {
err = innerErr
}
if isDone.CompareAndSwap(false, true) {
close(endCh)
}
}()
go func() {
_, innerErr := io.Copy(stream.agentIo, stream.userIo)
if innerErr != nil {
err = innerErr
}
if isDone.CompareAndSwap(false, true) {
close(endCh)
}
}()
<-endCh
return err
}

View File

@ -3,11 +3,14 @@ package rpc
import (
"context"
"fmt"
"github.com/naiba/nezha/pkg/ddns"
"github.com/naiba/nezha/pkg/utils"
"log"
"sync"
"time"
"github.com/naiba/nezha/pkg/ddns"
"github.com/naiba/nezha/pkg/grpcx"
"github.com/naiba/nezha/pkg/utils"
"github.com/jinzhu/copier"
"github.com/nicksnyder/go-i18n/v2/i18n"
@ -16,8 +19,20 @@ import (
"github.com/naiba/nezha/service/singleton"
)
var NezhaHandlerSingleton *NezhaHandler
type NezhaHandler struct {
Auth *AuthHandler
Auth *authHandler
ioStreams map[string]*ioStreamContext
ioStreamMutex *sync.RWMutex
}
func NewNezhaHandler() *NezhaHandler {
return &NezhaHandler{
Auth: &authHandler{},
ioStreamMutex: new(sync.RWMutex),
ioStreams: make(map[string]*ioStreamContext),
}
}
func (s *NezhaHandler) ReportTask(c context.Context, r *pb.TaskResult) (*pb.Receipt, error) {
@ -177,3 +192,28 @@ func (s *NezhaHandler) ReportSystemInfo(c context.Context, r *pb.Host) (*pb.Rece
singleton.ServerList[clientID].Host = &host
return &pb.Receipt{Proced: true}, nil
}
func (s *NezhaHandler) IOStream(stream pb.NezhaService_IOStreamServer) error {
if _, err := s.Auth.Check(stream.Context()); err != nil {
return err
}
id, err := stream.Recv()
if err != nil {
return err
}
if id == nil || len(id.Data) < 4 || (id.Data[0] != 0xff && id.Data[1] != 0x05 && id.Data[2] != 0xff && id.Data[3] == 0x05) {
return fmt.Errorf("invalid stream id")
}
streamId := string(id.Data[4:])
if _, err := s.GetStream(streamId); err != nil {
return err
}
iw := grpcx.NewIOStreamWrapper(stream)
if err := s.AgentConnected(streamId, iw); err != nil {
return err
}
iw.Wait()
return nil
}