From be51c5a1e6c933da1b886a431873951dc75f8727 Mon Sep 17 00:00:00 2001 From: naiba Date: Thu, 5 Dec 2024 00:11:34 +0800 Subject: [PATCH] feat: refactor grpc keepalive --- cmd/dashboard/main.go | 2 +- cmd/dashboard/rpc/rpc.go | 7 +-- model/notification_test.go | 1 - model/server.go | 7 +-- model/service.go | 2 +- proto/nezha.pb.go | 50 ++++++++--------- proto/nezha.proto | 3 +- proto/nezha_grpc.pb.go | 72 +++++++------------------ script/bootstrap.sh | 2 - service/rpc/auth.go | 2 - service/rpc/nezha.go | 105 ++++++++++++++++++------------------ service/singleton/server.go | 1 - 12 files changed, 102 insertions(+), 152 deletions(-) diff --git a/cmd/dashboard/main.go b/cmd/dashboard/main.go index 673592c..8bcf9b2 100644 --- a/cmd/dashboard/main.go +++ b/cmd/dashboard/main.go @@ -119,8 +119,8 @@ func main() { singleton.CleanServiceHistory() serviceSentinelDispatchBus := make(chan model.Service) // 用于传递服务监控任务信息的channel + rpc.DispatchKeepalive() go rpc.DispatchTask(serviceSentinelDispatchBus) - go rpc.DispatchKeepalive() go singleton.AlertSentinelStart() singleton.NewServiceSentinel(serviceSentinelDispatchBus) diff --git a/cmd/dashboard/rpc/rpc.go b/cmd/dashboard/rpc/rpc.go index 2f7ea47..35ec1b5 100644 --- a/cmd/dashboard/rpc/rpc.go +++ b/cmd/dashboard/rpc/rpc.go @@ -107,24 +107,19 @@ func DispatchTask(serviceSentinelDispatchBus <-chan model.Service) { workedServerIndex++ continue } - // 找到合适机器执行任务,跳出循环 - // singleton.SortedServerList[workedServerIndex].TaskStream.Send(task.PB()) - // workedServerIndex++ - // break } singleton.SortedServerLock.RUnlock() } } func DispatchKeepalive() { - singleton.Cron.AddFunc("@every 60s", func() { + singleton.Cron.AddFunc("@every 30s", func() { singleton.SortedServerLock.RLock() defer singleton.SortedServerLock.RUnlock() for i := 0; i < len(singleton.SortedServerList); i++ { if singleton.SortedServerList[i] == nil || singleton.SortedServerList[i].TaskStream == nil { continue } - singleton.SortedServerList[i].TaskStream.Send(&proto.Task{Type: model.TaskTypeKeepalive}) } }) diff --git a/model/notification_test.go b/model/notification_test.go index e7e4c63..480a319 100644 --- a/model/notification_test.go +++ b/model/notification_test.go @@ -75,7 +75,6 @@ func execCase(t *testing.T, item testSt) { CountryCode: "", }, LastActive: time.Time{}, - TaskClose: nil, TaskStream: nil, PrevTransferInSnapshot: 0, PrevTransferOutSnapshot: 0, diff --git a/model/server.go b/model/server.go index aa79f8f..240cb9a 100644 --- a/model/server.go +++ b/model/server.go @@ -2,7 +2,6 @@ package model import ( "log" - "sync" "time" "gorm.io/gorm" @@ -30,9 +29,7 @@ type Server struct { GeoIP *GeoIP `gorm:"-" json:"geoip,omitempty"` LastActive time.Time `gorm:"-" json:"last_active,omitempty"` - TaskClose chan error `gorm:"-" json:"-"` - TaskCloseLock *sync.Mutex `gorm:"-" json:"-"` - TaskStream pb.NezhaService_RequestTaskServer `gorm:"-" json:"-"` + TaskStream pb.NezhaService_RequestTaskServer `gorm:"-" json:"-"` PrevTransferInSnapshot int64 `gorm:"-" json:"-"` // 上次数据点时的入站使用量 PrevTransferOutSnapshot int64 `gorm:"-" json:"-"` // 上次数据点时的出站使用量 @@ -43,8 +40,6 @@ func (s *Server) CopyFromRunningServer(old *Server) { s.State = old.State s.GeoIP = old.GeoIP s.LastActive = old.LastActive - s.TaskClose = old.TaskClose - s.TaskCloseLock = old.TaskCloseLock s.TaskStream = old.TaskStream s.PrevTransferInSnapshot = old.PrevTransferInSnapshot s.PrevTransferOutSnapshot = old.PrevTransferOutSnapshot diff --git a/model/service.go b/model/service.go index 61789ef..2a423cd 100644 --- a/model/service.go +++ b/model/service.go @@ -127,5 +127,5 @@ func (m *Service) AfterFind(tx *gorm.DB) error { // IsServiceSentinelNeeded 判断该任务类型是否需要进行服务监控 需要则返回true func IsServiceSentinelNeeded(t uint64) bool { - return t != TaskTypeCommand && t != TaskTypeTerminalGRPC && t != TaskTypeUpgrade + return t != TaskTypeCommand && t != TaskTypeTerminalGRPC && t != TaskTypeUpgrade && t != TaskTypeKeepalive } diff --git a/proto/nezha.pb.go b/proto/nezha.pb.go index 8573cb1..72eb4ef 100644 --- a/proto/nezha.pb.go +++ b/proto/nezha.pb.go @@ -820,7 +820,7 @@ var file_proto_nezha_proto_rawDesc = []byte{ 0x74, 0x72, 0x79, 0x43, 0x6f, 0x64, 0x65, 0x22, 0x2c, 0x0a, 0x02, 0x49, 0x50, 0x12, 0x12, 0x0a, 0x04, 0x69, 0x70, 0x76, 0x34, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x69, 0x70, 0x76, 0x34, 0x12, 0x12, 0x0a, 0x04, 0x69, 0x70, 0x76, 0x36, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x04, 0x69, 0x70, 0x76, 0x36, 0x32, 0xc3, 0x02, 0x0a, 0x0c, 0x4e, 0x65, 0x7a, 0x68, 0x61, 0x53, + 0x04, 0x69, 0x70, 0x76, 0x36, 0x32, 0x98, 0x02, 0x0a, 0x0c, 0x4e, 0x65, 0x7a, 0x68, 0x61, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x37, 0x0a, 0x11, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x1a, 0x0e, 0x2e, 0x70, 0x72, 0x6f, 0x74, @@ -828,20 +828,18 @@ var file_proto_nezha_proto_rawDesc = []byte{ 0x31, 0x0a, 0x10, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x0b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x1a, 0x0e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x65, 0x69, 0x70, 0x74, - 0x22, 0x00, 0x12, 0x31, 0x0a, 0x0a, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x54, 0x61, 0x73, 0x6b, - 0x12, 0x11, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x54, 0x61, 0x73, 0x6b, 0x52, 0x65, 0x73, - 0x75, 0x6c, 0x74, 0x1a, 0x0e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x65, - 0x69, 0x70, 0x74, 0x22, 0x00, 0x12, 0x2b, 0x0a, 0x0b, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x54, 0x61, 0x73, 0x6b, 0x12, 0x0b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x48, 0x6f, 0x73, - 0x74, 0x1a, 0x0b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x54, 0x61, 0x73, 0x6b, 0x22, 0x00, - 0x30, 0x01, 0x12, 0x3a, 0x0a, 0x08, 0x49, 0x4f, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x13, - 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x49, 0x4f, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x44, - 0x61, 0x74, 0x61, 0x1a, 0x13, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x49, 0x4f, 0x53, 0x74, - 0x72, 0x65, 0x61, 0x6d, 0x44, 0x61, 0x74, 0x61, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x12, 0x2b, - 0x0a, 0x0b, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x47, 0x65, 0x6f, 0x49, 0x50, 0x12, 0x0c, 0x2e, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x47, 0x65, 0x6f, 0x49, 0x50, 0x1a, 0x0c, 0x2e, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x2e, 0x47, 0x65, 0x6f, 0x49, 0x50, 0x22, 0x00, 0x42, 0x09, 0x5a, 0x07, 0x2e, - 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x22, 0x00, 0x12, 0x33, 0x0a, 0x0b, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x54, 0x61, 0x73, + 0x6b, 0x12, 0x11, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x54, 0x61, 0x73, 0x6b, 0x52, 0x65, + 0x73, 0x75, 0x6c, 0x74, 0x1a, 0x0b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x54, 0x61, 0x73, + 0x6b, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x12, 0x3a, 0x0a, 0x08, 0x49, 0x4f, 0x53, 0x74, 0x72, + 0x65, 0x61, 0x6d, 0x12, 0x13, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x49, 0x4f, 0x53, 0x74, + 0x72, 0x65, 0x61, 0x6d, 0x44, 0x61, 0x74, 0x61, 0x1a, 0x13, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x2e, 0x49, 0x4f, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x44, 0x61, 0x74, 0x61, 0x22, 0x00, 0x28, + 0x01, 0x30, 0x01, 0x12, 0x2b, 0x0a, 0x0b, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x47, 0x65, 0x6f, + 0x49, 0x50, 0x12, 0x0c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x47, 0x65, 0x6f, 0x49, 0x50, + 0x1a, 0x0c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x47, 0x65, 0x6f, 0x49, 0x50, 0x22, 0x00, + 0x42, 0x09, 0x5a, 0x07, 0x2e, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x33, } var ( @@ -873,18 +871,16 @@ var file_proto_nezha_proto_depIdxs = []int32{ 8, // 1: proto.GeoIP.ip:type_name -> proto.IP 1, // 2: proto.NezhaService.ReportSystemState:input_type -> proto.State 0, // 3: proto.NezhaService.ReportSystemInfo:input_type -> proto.Host - 4, // 4: proto.NezhaService.ReportTask:input_type -> proto.TaskResult - 0, // 5: proto.NezhaService.RequestTask:input_type -> proto.Host - 6, // 6: proto.NezhaService.IOStream:input_type -> proto.IOStreamData - 7, // 7: proto.NezhaService.ReportGeoIP:input_type -> proto.GeoIP - 5, // 8: proto.NezhaService.ReportSystemState:output_type -> proto.Receipt - 5, // 9: proto.NezhaService.ReportSystemInfo:output_type -> proto.Receipt - 5, // 10: proto.NezhaService.ReportTask:output_type -> proto.Receipt - 3, // 11: proto.NezhaService.RequestTask:output_type -> proto.Task - 6, // 12: proto.NezhaService.IOStream:output_type -> proto.IOStreamData - 7, // 13: proto.NezhaService.ReportGeoIP:output_type -> proto.GeoIP - 8, // [8:14] is the sub-list for method output_type - 2, // [2:8] is the sub-list for method input_type + 4, // 4: proto.NezhaService.RequestTask:input_type -> proto.TaskResult + 6, // 5: proto.NezhaService.IOStream:input_type -> proto.IOStreamData + 7, // 6: proto.NezhaService.ReportGeoIP:input_type -> proto.GeoIP + 5, // 7: proto.NezhaService.ReportSystemState:output_type -> proto.Receipt + 5, // 8: proto.NezhaService.ReportSystemInfo:output_type -> proto.Receipt + 3, // 9: proto.NezhaService.RequestTask:output_type -> proto.Task + 6, // 10: proto.NezhaService.IOStream:output_type -> proto.IOStreamData + 7, // 11: proto.NezhaService.ReportGeoIP:output_type -> proto.GeoIP + 7, // [7:12] is the sub-list for method output_type + 2, // [2:7] is the sub-list for method input_type 2, // [2:2] is the sub-list for extension type_name 2, // [2:2] is the sub-list for extension extendee 0, // [0:2] is the sub-list for field type_name diff --git a/proto/nezha.proto b/proto/nezha.proto index 632bbd3..326e496 100644 --- a/proto/nezha.proto +++ b/proto/nezha.proto @@ -6,8 +6,7 @@ package proto; service NezhaService { rpc ReportSystemState(stream State) returns (stream Receipt) {} rpc ReportSystemInfo(Host) returns (Receipt) {} - rpc ReportTask(TaskResult) returns (Receipt) {} - rpc RequestTask(Host) returns (stream Task) {} + rpc RequestTask(stream TaskResult) returns (stream Task) {} rpc IOStream(stream IOStreamData) returns (stream IOStreamData) {} rpc ReportGeoIP(GeoIP) returns (GeoIP) {} } diff --git a/proto/nezha_grpc.pb.go b/proto/nezha_grpc.pb.go index 0e80ed1..1669475 100644 --- a/proto/nezha_grpc.pb.go +++ b/proto/nezha_grpc.pb.go @@ -21,7 +21,6 @@ const _ = grpc.SupportPackageIsVersion7 const ( NezhaService_ReportSystemState_FullMethodName = "/proto.NezhaService/ReportSystemState" NezhaService_ReportSystemInfo_FullMethodName = "/proto.NezhaService/ReportSystemInfo" - NezhaService_ReportTask_FullMethodName = "/proto.NezhaService/ReportTask" NezhaService_RequestTask_FullMethodName = "/proto.NezhaService/RequestTask" NezhaService_IOStream_FullMethodName = "/proto.NezhaService/IOStream" NezhaService_ReportGeoIP_FullMethodName = "/proto.NezhaService/ReportGeoIP" @@ -33,8 +32,7 @@ const ( type NezhaServiceClient interface { ReportSystemState(ctx context.Context, opts ...grpc.CallOption) (NezhaService_ReportSystemStateClient, error) ReportSystemInfo(ctx context.Context, in *Host, opts ...grpc.CallOption) (*Receipt, error) - ReportTask(ctx context.Context, in *TaskResult, opts ...grpc.CallOption) (*Receipt, error) - RequestTask(ctx context.Context, in *Host, opts ...grpc.CallOption) (NezhaService_RequestTaskClient, error) + RequestTask(ctx context.Context, opts ...grpc.CallOption) (NezhaService_RequestTaskClient, error) IOStream(ctx context.Context, opts ...grpc.CallOption) (NezhaService_IOStreamClient, error) ReportGeoIP(ctx context.Context, in *GeoIP, opts ...grpc.CallOption) (*GeoIP, error) } @@ -87,31 +85,17 @@ func (c *nezhaServiceClient) ReportSystemInfo(ctx context.Context, in *Host, opt return out, nil } -func (c *nezhaServiceClient) ReportTask(ctx context.Context, in *TaskResult, opts ...grpc.CallOption) (*Receipt, error) { - out := new(Receipt) - err := c.cc.Invoke(ctx, NezhaService_ReportTask_FullMethodName, in, out, opts...) - if err != nil { - return nil, err - } - return out, nil -} - -func (c *nezhaServiceClient) RequestTask(ctx context.Context, in *Host, opts ...grpc.CallOption) (NezhaService_RequestTaskClient, error) { +func (c *nezhaServiceClient) RequestTask(ctx context.Context, opts ...grpc.CallOption) (NezhaService_RequestTaskClient, error) { stream, err := c.cc.NewStream(ctx, &NezhaService_ServiceDesc.Streams[1], NezhaService_RequestTask_FullMethodName, opts...) if err != nil { return nil, err } x := &nezhaServiceRequestTaskClient{stream} - if err := x.ClientStream.SendMsg(in); err != nil { - return nil, err - } - if err := x.ClientStream.CloseSend(); err != nil { - return nil, err - } return x, nil } type NezhaService_RequestTaskClient interface { + Send(*TaskResult) error Recv() (*Task, error) grpc.ClientStream } @@ -120,6 +104,10 @@ type nezhaServiceRequestTaskClient struct { grpc.ClientStream } +func (x *nezhaServiceRequestTaskClient) Send(m *TaskResult) error { + return x.ClientStream.SendMsg(m) +} + func (x *nezhaServiceRequestTaskClient) Recv() (*Task, error) { m := new(Task) if err := x.ClientStream.RecvMsg(m); err != nil { @@ -174,8 +162,7 @@ func (c *nezhaServiceClient) ReportGeoIP(ctx context.Context, in *GeoIP, opts .. type NezhaServiceServer interface { ReportSystemState(NezhaService_ReportSystemStateServer) error ReportSystemInfo(context.Context, *Host) (*Receipt, error) - ReportTask(context.Context, *TaskResult) (*Receipt, error) - RequestTask(*Host, NezhaService_RequestTaskServer) error + RequestTask(NezhaService_RequestTaskServer) error IOStream(NezhaService_IOStreamServer) error ReportGeoIP(context.Context, *GeoIP) (*GeoIP, error) } @@ -190,10 +177,7 @@ func (UnimplementedNezhaServiceServer) ReportSystemState(NezhaService_ReportSyst func (UnimplementedNezhaServiceServer) ReportSystemInfo(context.Context, *Host) (*Receipt, error) { return nil, status.Errorf(codes.Unimplemented, "method ReportSystemInfo not implemented") } -func (UnimplementedNezhaServiceServer) ReportTask(context.Context, *TaskResult) (*Receipt, error) { - return nil, status.Errorf(codes.Unimplemented, "method ReportTask not implemented") -} -func (UnimplementedNezhaServiceServer) RequestTask(*Host, NezhaService_RequestTaskServer) error { +func (UnimplementedNezhaServiceServer) RequestTask(NezhaService_RequestTaskServer) error { return status.Errorf(codes.Unimplemented, "method RequestTask not implemented") } func (UnimplementedNezhaServiceServer) IOStream(NezhaService_IOStreamServer) error { @@ -258,34 +242,13 @@ func _NezhaService_ReportSystemInfo_Handler(srv interface{}, ctx context.Context return interceptor(ctx, in, info, handler) } -func _NezhaService_ReportTask_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(TaskResult) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(NezhaServiceServer).ReportTask(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: NezhaService_ReportTask_FullMethodName, - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(NezhaServiceServer).ReportTask(ctx, req.(*TaskResult)) - } - return interceptor(ctx, in, info, handler) -} - func _NezhaService_RequestTask_Handler(srv interface{}, stream grpc.ServerStream) error { - m := new(Host) - if err := stream.RecvMsg(m); err != nil { - return err - } - return srv.(NezhaServiceServer).RequestTask(m, &nezhaServiceRequestTaskServer{stream}) + return srv.(NezhaServiceServer).RequestTask(&nezhaServiceRequestTaskServer{stream}) } type NezhaService_RequestTaskServer interface { Send(*Task) error + Recv() (*TaskResult, error) grpc.ServerStream } @@ -297,6 +260,14 @@ func (x *nezhaServiceRequestTaskServer) Send(m *Task) error { return x.ServerStream.SendMsg(m) } +func (x *nezhaServiceRequestTaskServer) Recv() (*TaskResult, error) { + m := new(TaskResult) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + func _NezhaService_IOStream_Handler(srv interface{}, stream grpc.ServerStream) error { return srv.(NezhaServiceServer).IOStream(&nezhaServiceIOStreamServer{stream}) } @@ -352,10 +323,6 @@ var NezhaService_ServiceDesc = grpc.ServiceDesc{ MethodName: "ReportSystemInfo", Handler: _NezhaService_ReportSystemInfo_Handler, }, - { - MethodName: "ReportTask", - Handler: _NezhaService_ReportTask_Handler, - }, { MethodName: "ReportGeoIP", Handler: _NezhaService_ReportGeoIP_Handler, @@ -372,6 +339,7 @@ var NezhaService_ServiceDesc = grpc.ServiceDesc{ StreamName: "RequestTask", Handler: _NezhaService_RequestTask_Handler, ServerStreams: true, + ClientStreams: true, }, { StreamName: "IOStream", diff --git a/script/bootstrap.sh b/script/bootstrap.sh index 4f60128..60f8029 100755 --- a/script/bootstrap.sh +++ b/script/bootstrap.sh @@ -1,5 +1,3 @@ -touch ./cmd/dashboard/user-dist/a -touch ./cmd/dashboard/admin-dist/a swag init --pd -d . -g ./cmd/dashboard/main.go -o ./cmd/dashboard/docs --requiredByDefault protoc --go-grpc_out="require_unimplemented_servers=false:." --go_out="." proto/*.proto rm -rf ../agent/proto diff --git a/service/rpc/auth.go b/service/rpc/auth.go index 75ae871..6f0dda0 100644 --- a/service/rpc/auth.go +++ b/service/rpc/auth.go @@ -3,7 +3,6 @@ package rpc import ( "context" "strings" - "sync" petname "github.com/dustinkirkland/golang-petname" "github.com/hashicorp/go-uuid" @@ -64,7 +63,6 @@ func (a *authHandler) Check(ctx context.Context) (uint64, error) { } s.Host = &model.Host{} s.State = &model.HostState{} - s.TaskCloseLock = new(sync.Mutex) // generate a random silly server name singleton.ServerList[s.ID] = &s singleton.ServerUUIDToID[clientUUID] = s.ID diff --git a/service/rpc/nezha.go b/service/rpc/nezha.go index 7c11fe8..990011a 100644 --- a/service/rpc/nezha.go +++ b/service/rpc/nezha.go @@ -8,12 +8,11 @@ import ( "sync" "time" + "github.com/jinzhu/copier" "github.com/nezhahq/nezha/pkg/ddns" geoipx "github.com/nezhahq/nezha/pkg/geoip" "github.com/nezhahq/nezha/pkg/grpcx" - "github.com/jinzhu/copier" - "github.com/nezhahq/nezha/model" pb "github.com/nezhahq/nezha/proto" "github.com/nezhahq/nezha/service/singleton" @@ -37,63 +36,55 @@ func NewNezhaHandler() *NezhaHandler { } } -func (s *NezhaHandler) ReportTask(c context.Context, r *pb.TaskResult) (*pb.Receipt, error) { - var err error - var clientID uint64 - if clientID, err = s.Auth.Check(c); err != nil { - return nil, err - } - if r.GetType() == model.TaskTypeCommand { - // 处理上报的计划任务 - singleton.CronLock.RLock() - defer singleton.CronLock.RUnlock() - cr := singleton.Crons[r.GetId()] - if cr != nil { - singleton.ServerLock.RLock() - defer singleton.ServerLock.RUnlock() - // 保存当前服务器状态信息 - curServer := model.Server{} - copier.Copy(&curServer, singleton.ServerList[clientID]) - if cr.PushSuccessful && r.GetSuccessful() { - singleton.SendNotification(cr.NotificationGroupID, fmt.Sprintf("[%s] %s, %s\n%s", singleton.Localizer.T("Scheduled Task Executed Successfully"), - cr.Name, singleton.ServerList[clientID].Name, r.GetData()), nil, &curServer) - } - if !r.GetSuccessful() { - singleton.SendNotification(cr.NotificationGroupID, fmt.Sprintf("[%s] %s, %s\n%s", singleton.Localizer.T("Scheduled Task Executed Failed"), - cr.Name, singleton.ServerList[clientID].Name, r.GetData()), nil, &curServer) - } - singleton.DB.Model(cr).Updates(model.Cron{ - LastExecutedAt: time.Now().Add(time.Second * -1 * time.Duration(r.GetDelay())), - LastResult: r.GetSuccessful(), - }) - } - } else if model.IsServiceSentinelNeeded(r.GetType()) { - singleton.ServiceSentinelShared.Dispatch(singleton.ReportData{ - Data: r, - Reporter: clientID, - }) - } - return &pb.Receipt{Proced: true}, nil -} - -func (s *NezhaHandler) RequestTask(h *pb.Host, stream pb.NezhaService_RequestTaskServer) error { +func (s *NezhaHandler) RequestTask(stream pb.NezhaService_RequestTaskServer) error { var clientID uint64 var err error if clientID, err = s.Auth.Check(stream.Context()); err != nil { return err } - closeCh := make(chan error) + singleton.ServerLock.RLock() - singleton.ServerList[clientID].TaskCloseLock.Lock() - // 修复不断的请求 task 但是没有 return 导致内存泄漏 - if singleton.ServerList[clientID].TaskClose != nil { - close(singleton.ServerList[clientID].TaskClose) - } singleton.ServerList[clientID].TaskStream = stream - singleton.ServerList[clientID].TaskClose = closeCh - singleton.ServerList[clientID].TaskCloseLock.Unlock() singleton.ServerLock.RUnlock() - return <-closeCh + + var result *pb.TaskResult + for { + result, err = stream.Recv() + if err != nil { + log.Printf("NEZHA>> RequestTask error: %v, clientID: %d\n", err, clientID) + return nil + } + if result.GetType() == model.TaskTypeCommand { + // 处理上报的计划任务 + singleton.CronLock.RLock() + cr := singleton.Crons[result.GetId()] + singleton.CronLock.RUnlock() + if cr != nil { + // 保存当前服务器状态信息 + var curServer model.Server + singleton.ServerLock.RLock() + copier.Copy(&curServer, singleton.ServerList[clientID]) + singleton.ServerLock.RUnlock() + if cr.PushSuccessful && result.GetSuccessful() { + singleton.SendNotification(cr.NotificationGroupID, fmt.Sprintf("[%s] %s, %s\n%s", singleton.Localizer.T("Scheduled Task Executed Successfully"), + cr.Name, singleton.ServerList[clientID].Name, result.GetData()), nil, &curServer) + } + if !result.GetSuccessful() { + singleton.SendNotification(cr.NotificationGroupID, fmt.Sprintf("[%s] %s, %s\n%s", singleton.Localizer.T("Scheduled Task Executed Failed"), + cr.Name, singleton.ServerList[clientID].Name, result.GetData()), nil, &curServer) + } + singleton.DB.Model(cr).Updates(model.Cron{ + LastExecutedAt: time.Now().Add(time.Second * -1 * time.Duration(result.GetDelay())), + LastResult: result.GetSuccessful(), + }) + } + } else if model.IsServiceSentinelNeeded(result.GetType()) { + singleton.ServiceSentinelShared.Dispatch(singleton.ReportData{ + Data: result, + Reporter: clientID, + }) + } + } } func (s *NezhaHandler) ReportSystemState(stream pb.NezhaService_ReportSystemStateServer) error { @@ -157,10 +148,22 @@ func (s *NezhaHandler) IOStream(stream pb.NezhaService_IOStreamServer) error { if err != nil { return err } + + // ff05ff05 是 Nezha 的魔数,用于标识流 ID if id == nil || len(id.Data) < 4 || (id.Data[0] != 0xff && id.Data[1] != 0x05 && id.Data[2] != 0xff && id.Data[3] == 0x05) { return fmt.Errorf("invalid stream id") } + go func() { + for { + if err := stream.Send(&pb.IOStreamData{Data: []byte{}}); err != nil { + log.Printf("NEZHA>> IOStream keepAlive error: %v\n", err) + return + } + time.Sleep(time.Second * 30) + } + }() + streamId := string(id.Data[4:]) if _, err := s.GetStream(streamId); err != nil { diff --git a/service/singleton/server.go b/service/singleton/server.go index d6e3478..9450d9f 100644 --- a/service/singleton/server.go +++ b/service/singleton/server.go @@ -32,7 +32,6 @@ func loadServers() { innerS.Host = &model.Host{} innerS.State = &model.HostState{} innerS.GeoIP = new(model.GeoIP) - innerS.TaskCloseLock = new(sync.Mutex) ServerList[innerS.ID] = &innerS ServerUUIDToID[innerS.UUID] = innerS.ID }