From 22738b6244f1c84261cc5937f481bcdc48406677 Mon Sep 17 00:00:00 2001 From: naiba Date: Sat, 23 Nov 2024 12:43:02 +0800 Subject: [PATCH] improve: use stream reduce auth check time --- cmd/dashboard/rpc/rpc.go | 32 +++-- model/waf.go | 1 + pkg/grpcx/io_stream_wrapper.go | 2 +- pkg/utils/request_wrapper.go | 2 +- pkg/websocketx/safe_conn.go | 2 +- proto/nezha.pb.go | 64 ++++----- proto/nezha.proto | 2 +- proto/nezha_grpc.pb.go | 229 ++++++++++++++++++++++----------- service/rpc/auth.go | 14 +- service/rpc/nezha.go | 42 +++--- 10 files changed, 249 insertions(+), 141 deletions(-) diff --git a/cmd/dashboard/rpc/rpc.go b/cmd/dashboard/rpc/rpc.go index 9a70d89..415606c 100644 --- a/cmd/dashboard/rpc/rpc.go +++ b/cmd/dashboard/rpc/rpc.go @@ -20,17 +20,26 @@ import ( ) func ServeRPC() *grpc.Server { - server := grpc.NewServer(grpc.UnaryInterceptor(getRealIp)) + server := grpc.NewServer(grpc.ChainUnaryInterceptor(getRealIp, waf)) rpcService.NezhaHandlerSingleton = rpcService.NewNezhaHandler() proto.RegisterNezhaServiceServer(server, rpcService.NezhaHandlerSingleton) return server } +func waf(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + if err := model.CheckIP(singleton.DB, ctx.Value(model.CtxKeyRealIP{}).(string)); err != nil { + return nil, err + } + return handler(ctx, req) +} + func getRealIp(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { if singleton.Conf.RealIPHeader == "" { return handler(ctx, req) } + var ip string + if singleton.Conf.RealIPHeader == model.ConfigUsePeerIP { p, ok := peer.FromContext(ctx) if !ok { @@ -40,18 +49,19 @@ func getRealIp(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, if err != nil { return nil, err } - ctx = context.WithValue(ctx, model.CtxKeyRealIP{}, addrPort.Addr().String()) - return handler(ctx, req) + ip = addrPort.Addr().String() + } else { + vals := metadata.ValueFromIncomingContext(ctx, singleton.Conf.RealIPHeader) + if len(vals) == 0 { + return nil, fmt.Errorf("real ip header not found") + } + var err error + ip, err = utils.GetIPFromHeader(vals[0]) + if err != nil { + return nil, err + } } - vals := metadata.ValueFromIncomingContext(ctx, singleton.Conf.RealIPHeader) - if len(vals) == 0 { - return nil, fmt.Errorf("real ip header not found") - } - ip, err := utils.GetIPFromHeader(vals[0]) - if err != nil { - return nil, err - } ctx = context.WithValue(ctx, model.CtxKeyRealIP{}, ip) return handler(ctx, req) } diff --git a/model/waf.go b/model/waf.go index d71f5b1..e74135d 100644 --- a/model/waf.go +++ b/model/waf.go @@ -13,6 +13,7 @@ const ( _ uint8 = iota WAFBlockReasonTypeLoginFail WAFBlockReasonTypeBruteForceToken + WAFBlockReasonTypeAgentAuthFail ) type WAF struct { diff --git a/pkg/grpcx/io_stream_wrapper.go b/pkg/grpcx/io_stream_wrapper.go index 72855c4..8adaa64 100644 --- a/pkg/grpcx/io_stream_wrapper.go +++ b/pkg/grpcx/io_stream_wrapper.go @@ -8,7 +8,7 @@ import ( "github.com/naiba/nezha/proto" ) -var _ io.ReadWriteCloser = &IOStreamWrapper{} +var _ io.ReadWriteCloser = (*IOStreamWrapper)(nil) type IOStream interface { Recv() (*proto.IOStreamData, error) diff --git a/pkg/utils/request_wrapper.go b/pkg/utils/request_wrapper.go index e4f6103..e18e076 100644 --- a/pkg/utils/request_wrapper.go +++ b/pkg/utils/request_wrapper.go @@ -8,7 +8,7 @@ import ( "net/http" ) -var _ io.ReadWriteCloser = &RequestWrapper{} +var _ io.ReadWriteCloser = (*RequestWrapper)(nil) type RequestWrapper struct { req *http.Request diff --git a/pkg/websocketx/safe_conn.go b/pkg/websocketx/safe_conn.go index 512bddf..c38d487 100644 --- a/pkg/websocketx/safe_conn.go +++ b/pkg/websocketx/safe_conn.go @@ -7,7 +7,7 @@ import ( "github.com/gorilla/websocket" ) -var _ io.ReadWriteCloser = &Conn{} +var _ io.ReadWriteCloser = (*Conn)(nil) type Conn struct { *websocket.Conn diff --git a/proto/nezha.pb.go b/proto/nezha.pb.go index ec5c12b..8573cb1 100644 --- a/proto/nezha.pb.go +++ b/proto/nezha.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.34.1 -// protoc v5.28.3 +// protoc-gen-go v1.34.2 +// protoc v5.28.1 // source: proto/nezha.proto package proto @@ -820,28 +820,28 @@ var file_proto_nezha_proto_rawDesc = []byte{ 0x74, 0x72, 0x79, 0x43, 0x6f, 0x64, 0x65, 0x22, 0x2c, 0x0a, 0x02, 0x49, 0x50, 0x12, 0x12, 0x0a, 0x04, 0x69, 0x70, 0x76, 0x34, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x69, 0x70, 0x76, 0x34, 0x12, 0x12, 0x0a, 0x04, 0x69, 0x70, 0x76, 0x36, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x04, 0x69, 0x70, 0x76, 0x36, 0x32, 0xbf, 0x02, 0x0a, 0x0c, 0x4e, 0x65, 0x7a, 0x68, 0x61, 0x53, - 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x33, 0x0a, 0x11, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, + 0x04, 0x69, 0x70, 0x76, 0x36, 0x32, 0xc3, 0x02, 0x0a, 0x0c, 0x4e, 0x65, 0x7a, 0x68, 0x61, 0x53, + 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x37, 0x0a, 0x11, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x1a, 0x0e, 0x2e, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x65, 0x69, 0x70, 0x74, 0x22, 0x00, 0x12, 0x31, 0x0a, 0x10, 0x52, - 0x65, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x12, - 0x0b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x1a, 0x0e, 0x2e, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x65, 0x69, 0x70, 0x74, 0x22, 0x00, 0x12, 0x31, - 0x0a, 0x0a, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x54, 0x61, 0x73, 0x6b, 0x12, 0x11, 0x2e, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x54, 0x61, 0x73, 0x6b, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x1a, - 0x0e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x65, 0x69, 0x70, 0x74, 0x22, - 0x00, 0x12, 0x2b, 0x0a, 0x0b, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x54, 0x61, 0x73, 0x6b, - 0x12, 0x0b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x1a, 0x0b, 0x2e, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x54, 0x61, 0x73, 0x6b, 0x22, 0x00, 0x30, 0x01, 0x12, 0x3a, - 0x0a, 0x08, 0x49, 0x4f, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x13, 0x2e, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x2e, 0x49, 0x4f, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x44, 0x61, 0x74, 0x61, 0x1a, - 0x13, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x49, 0x4f, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, - 0x44, 0x61, 0x74, 0x61, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x12, 0x2b, 0x0a, 0x0b, 0x52, 0x65, - 0x70, 0x6f, 0x72, 0x74, 0x47, 0x65, 0x6f, 0x49, 0x50, 0x12, 0x0c, 0x2e, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x2e, 0x47, 0x65, 0x6f, 0x49, 0x50, 0x1a, 0x0c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, - 0x47, 0x65, 0x6f, 0x49, 0x50, 0x22, 0x00, 0x42, 0x09, 0x5a, 0x07, 0x2e, 0x2f, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x65, 0x69, 0x70, 0x74, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x12, + 0x31, 0x0a, 0x10, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, + 0x6e, 0x66, 0x6f, 0x12, 0x0b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x48, 0x6f, 0x73, 0x74, + 0x1a, 0x0e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x65, 0x69, 0x70, 0x74, + 0x22, 0x00, 0x12, 0x31, 0x0a, 0x0a, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x54, 0x61, 0x73, 0x6b, + 0x12, 0x11, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x54, 0x61, 0x73, 0x6b, 0x52, 0x65, 0x73, + 0x75, 0x6c, 0x74, 0x1a, 0x0e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x65, + 0x69, 0x70, 0x74, 0x22, 0x00, 0x12, 0x2b, 0x0a, 0x0b, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x54, 0x61, 0x73, 0x6b, 0x12, 0x0b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x48, 0x6f, 0x73, + 0x74, 0x1a, 0x0b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x54, 0x61, 0x73, 0x6b, 0x22, 0x00, + 0x30, 0x01, 0x12, 0x3a, 0x0a, 0x08, 0x49, 0x4f, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x13, + 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x49, 0x4f, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x44, + 0x61, 0x74, 0x61, 0x1a, 0x13, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x49, 0x4f, 0x53, 0x74, + 0x72, 0x65, 0x61, 0x6d, 0x44, 0x61, 0x74, 0x61, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x12, 0x2b, + 0x0a, 0x0b, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x47, 0x65, 0x6f, 0x49, 0x50, 0x12, 0x0c, 0x2e, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x47, 0x65, 0x6f, 0x49, 0x50, 0x1a, 0x0c, 0x2e, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x2e, 0x47, 0x65, 0x6f, 0x49, 0x50, 0x22, 0x00, 0x42, 0x09, 0x5a, 0x07, 0x2e, + 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -857,7 +857,7 @@ func file_proto_nezha_proto_rawDescGZIP() []byte { } var file_proto_nezha_proto_msgTypes = make([]protoimpl.MessageInfo, 9) -var file_proto_nezha_proto_goTypes = []interface{}{ +var file_proto_nezha_proto_goTypes = []any{ (*Host)(nil), // 0: proto.Host (*State)(nil), // 1: proto.State (*State_SensorTemperature)(nil), // 2: proto.State_SensorTemperature @@ -896,7 +896,7 @@ func file_proto_nezha_proto_init() { return } if !protoimpl.UnsafeEnabled { - file_proto_nezha_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + file_proto_nezha_proto_msgTypes[0].Exporter = func(v any, i int) any { switch v := v.(*Host); i { case 0: return &v.state @@ -908,7 +908,7 @@ func file_proto_nezha_proto_init() { return nil } } - file_proto_nezha_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + file_proto_nezha_proto_msgTypes[1].Exporter = func(v any, i int) any { switch v := v.(*State); i { case 0: return &v.state @@ -920,7 +920,7 @@ func file_proto_nezha_proto_init() { return nil } } - file_proto_nezha_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + file_proto_nezha_proto_msgTypes[2].Exporter = func(v any, i int) any { switch v := v.(*State_SensorTemperature); i { case 0: return &v.state @@ -932,7 +932,7 @@ func file_proto_nezha_proto_init() { return nil } } - file_proto_nezha_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + file_proto_nezha_proto_msgTypes[3].Exporter = func(v any, i int) any { switch v := v.(*Task); i { case 0: return &v.state @@ -944,7 +944,7 @@ func file_proto_nezha_proto_init() { return nil } } - file_proto_nezha_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + file_proto_nezha_proto_msgTypes[4].Exporter = func(v any, i int) any { switch v := v.(*TaskResult); i { case 0: return &v.state @@ -956,7 +956,7 @@ func file_proto_nezha_proto_init() { return nil } } - file_proto_nezha_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { + file_proto_nezha_proto_msgTypes[5].Exporter = func(v any, i int) any { switch v := v.(*Receipt); i { case 0: return &v.state @@ -968,7 +968,7 @@ func file_proto_nezha_proto_init() { return nil } } - file_proto_nezha_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} { + file_proto_nezha_proto_msgTypes[6].Exporter = func(v any, i int) any { switch v := v.(*IOStreamData); i { case 0: return &v.state @@ -980,7 +980,7 @@ func file_proto_nezha_proto_init() { return nil } } - file_proto_nezha_proto_msgTypes[7].Exporter = func(v interface{}, i int) interface{} { + file_proto_nezha_proto_msgTypes[7].Exporter = func(v any, i int) any { switch v := v.(*GeoIP); i { case 0: return &v.state @@ -992,7 +992,7 @@ func file_proto_nezha_proto_init() { return nil } } - file_proto_nezha_proto_msgTypes[8].Exporter = func(v interface{}, i int) interface{} { + file_proto_nezha_proto_msgTypes[8].Exporter = func(v any, i int) any { switch v := v.(*IP); i { case 0: return &v.state diff --git a/proto/nezha.proto b/proto/nezha.proto index 5f12b01..632bbd3 100644 --- a/proto/nezha.proto +++ b/proto/nezha.proto @@ -4,7 +4,7 @@ option go_package = "./proto"; package proto; service NezhaService { - rpc ReportSystemState(State) returns (Receipt) {} + rpc ReportSystemState(stream State) returns (stream Receipt) {} rpc ReportSystemInfo(Host) returns (Receipt) {} rpc ReportTask(TaskResult) returns (Receipt) {} rpc RequestTask(Host) returns (stream Task) {} diff --git a/proto/nezha_grpc.pb.go b/proto/nezha_grpc.pb.go index b719f76..0e80ed1 100644 --- a/proto/nezha_grpc.pb.go +++ b/proto/nezha_grpc.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.5.1 -// - protoc v5.28.3 +// - protoc-gen-go-grpc v1.3.0 +// - protoc v5.28.1 // source: proto/nezha.proto package proto @@ -15,8 +15,8 @@ import ( // This is a compile-time assertion to ensure that this generated file // is compatible with the grpc package it is being compiled against. -// Requires gRPC-Go v1.64.0 or later. -const _ = grpc.SupportPackageIsVersion9 +// Requires gRPC-Go v1.32.0 or later. +const _ = grpc.SupportPackageIsVersion7 const ( NezhaService_ReportSystemState_FullMethodName = "/proto.NezhaService/ReportSystemState" @@ -31,11 +31,11 @@ const ( // // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. type NezhaServiceClient interface { - ReportSystemState(ctx context.Context, in *State, opts ...grpc.CallOption) (*Receipt, error) + ReportSystemState(ctx context.Context, opts ...grpc.CallOption) (NezhaService_ReportSystemStateClient, error) ReportSystemInfo(ctx context.Context, in *Host, opts ...grpc.CallOption) (*Receipt, error) ReportTask(ctx context.Context, in *TaskResult, opts ...grpc.CallOption) (*Receipt, error) - RequestTask(ctx context.Context, in *Host, opts ...grpc.CallOption) (grpc.ServerStreamingClient[Task], error) - IOStream(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[IOStreamData, IOStreamData], error) + RequestTask(ctx context.Context, in *Host, opts ...grpc.CallOption) (NezhaService_RequestTaskClient, error) + IOStream(ctx context.Context, opts ...grpc.CallOption) (NezhaService_IOStreamClient, error) ReportGeoIP(ctx context.Context, in *GeoIP, opts ...grpc.CallOption) (*GeoIP, error) } @@ -47,20 +47,40 @@ func NewNezhaServiceClient(cc grpc.ClientConnInterface) NezhaServiceClient { return &nezhaServiceClient{cc} } -func (c *nezhaServiceClient) ReportSystemState(ctx context.Context, in *State, opts ...grpc.CallOption) (*Receipt, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) - out := new(Receipt) - err := c.cc.Invoke(ctx, NezhaService_ReportSystemState_FullMethodName, in, out, cOpts...) +func (c *nezhaServiceClient) ReportSystemState(ctx context.Context, opts ...grpc.CallOption) (NezhaService_ReportSystemStateClient, error) { + stream, err := c.cc.NewStream(ctx, &NezhaService_ServiceDesc.Streams[0], NezhaService_ReportSystemState_FullMethodName, opts...) if err != nil { return nil, err } - return out, nil + x := &nezhaServiceReportSystemStateClient{stream} + return x, nil +} + +type NezhaService_ReportSystemStateClient interface { + Send(*State) error + Recv() (*Receipt, error) + grpc.ClientStream +} + +type nezhaServiceReportSystemStateClient struct { + grpc.ClientStream +} + +func (x *nezhaServiceReportSystemStateClient) Send(m *State) error { + return x.ClientStream.SendMsg(m) +} + +func (x *nezhaServiceReportSystemStateClient) Recv() (*Receipt, error) { + m := new(Receipt) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil } func (c *nezhaServiceClient) ReportSystemInfo(ctx context.Context, in *Host, opts ...grpc.CallOption) (*Receipt, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(Receipt) - err := c.cc.Invoke(ctx, NezhaService_ReportSystemInfo_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, NezhaService_ReportSystemInfo_FullMethodName, in, out, opts...) if err != nil { return nil, err } @@ -68,22 +88,20 @@ func (c *nezhaServiceClient) ReportSystemInfo(ctx context.Context, in *Host, opt } func (c *nezhaServiceClient) ReportTask(ctx context.Context, in *TaskResult, opts ...grpc.CallOption) (*Receipt, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(Receipt) - err := c.cc.Invoke(ctx, NezhaService_ReportTask_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, NezhaService_ReportTask_FullMethodName, in, out, opts...) if err != nil { return nil, err } return out, nil } -func (c *nezhaServiceClient) RequestTask(ctx context.Context, in *Host, opts ...grpc.CallOption) (grpc.ServerStreamingClient[Task], error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) - stream, err := c.cc.NewStream(ctx, &NezhaService_ServiceDesc.Streams[0], NezhaService_RequestTask_FullMethodName, cOpts...) +func (c *nezhaServiceClient) RequestTask(ctx context.Context, in *Host, opts ...grpc.CallOption) (NezhaService_RequestTaskClient, error) { + stream, err := c.cc.NewStream(ctx, &NezhaService_ServiceDesc.Streams[1], NezhaService_RequestTask_FullMethodName, opts...) if err != nil { return nil, err } - x := &grpc.GenericClientStream[Host, Task]{ClientStream: stream} + x := &nezhaServiceRequestTaskClient{stream} if err := x.ClientStream.SendMsg(in); err != nil { return nil, err } @@ -93,26 +111,57 @@ func (c *nezhaServiceClient) RequestTask(ctx context.Context, in *Host, opts ... return x, nil } -// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. -type NezhaService_RequestTaskClient = grpc.ServerStreamingClient[Task] +type NezhaService_RequestTaskClient interface { + Recv() (*Task, error) + grpc.ClientStream +} -func (c *nezhaServiceClient) IOStream(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[IOStreamData, IOStreamData], error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) - stream, err := c.cc.NewStream(ctx, &NezhaService_ServiceDesc.Streams[1], NezhaService_IOStream_FullMethodName, cOpts...) +type nezhaServiceRequestTaskClient struct { + grpc.ClientStream +} + +func (x *nezhaServiceRequestTaskClient) Recv() (*Task, error) { + m := new(Task) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +func (c *nezhaServiceClient) IOStream(ctx context.Context, opts ...grpc.CallOption) (NezhaService_IOStreamClient, error) { + stream, err := c.cc.NewStream(ctx, &NezhaService_ServiceDesc.Streams[2], NezhaService_IOStream_FullMethodName, opts...) if err != nil { return nil, err } - x := &grpc.GenericClientStream[IOStreamData, IOStreamData]{ClientStream: stream} + x := &nezhaServiceIOStreamClient{stream} return x, nil } -// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. -type NezhaService_IOStreamClient = grpc.BidiStreamingClient[IOStreamData, IOStreamData] +type NezhaService_IOStreamClient interface { + Send(*IOStreamData) error + Recv() (*IOStreamData, error) + grpc.ClientStream +} + +type nezhaServiceIOStreamClient struct { + grpc.ClientStream +} + +func (x *nezhaServiceIOStreamClient) Send(m *IOStreamData) error { + return x.ClientStream.SendMsg(m) +} + +func (x *nezhaServiceIOStreamClient) Recv() (*IOStreamData, error) { + m := new(IOStreamData) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} func (c *nezhaServiceClient) ReportGeoIP(ctx context.Context, in *GeoIP, opts ...grpc.CallOption) (*GeoIP, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(GeoIP) - err := c.cc.Invoke(ctx, NezhaService_ReportGeoIP_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, NezhaService_ReportGeoIP_FullMethodName, in, out, opts...) if err != nil { return nil, err } @@ -121,25 +170,22 @@ func (c *nezhaServiceClient) ReportGeoIP(ctx context.Context, in *GeoIP, opts .. // NezhaServiceServer is the server API for NezhaService service. // All implementations should embed UnimplementedNezhaServiceServer -// for forward compatibility. +// for forward compatibility type NezhaServiceServer interface { - ReportSystemState(context.Context, *State) (*Receipt, error) + ReportSystemState(NezhaService_ReportSystemStateServer) error ReportSystemInfo(context.Context, *Host) (*Receipt, error) ReportTask(context.Context, *TaskResult) (*Receipt, error) - RequestTask(*Host, grpc.ServerStreamingServer[Task]) error - IOStream(grpc.BidiStreamingServer[IOStreamData, IOStreamData]) error + RequestTask(*Host, NezhaService_RequestTaskServer) error + IOStream(NezhaService_IOStreamServer) error ReportGeoIP(context.Context, *GeoIP) (*GeoIP, error) } -// UnimplementedNezhaServiceServer should be embedded to have -// forward compatible implementations. -// -// NOTE: this should be embedded by value instead of pointer to avoid a nil -// pointer dereference when methods are called. -type UnimplementedNezhaServiceServer struct{} +// UnimplementedNezhaServiceServer should be embedded to have forward compatible implementations. +type UnimplementedNezhaServiceServer struct { +} -func (UnimplementedNezhaServiceServer) ReportSystemState(context.Context, *State) (*Receipt, error) { - return nil, status.Errorf(codes.Unimplemented, "method ReportSystemState not implemented") +func (UnimplementedNezhaServiceServer) ReportSystemState(NezhaService_ReportSystemStateServer) error { + return status.Errorf(codes.Unimplemented, "method ReportSystemState not implemented") } func (UnimplementedNezhaServiceServer) ReportSystemInfo(context.Context, *Host) (*Receipt, error) { return nil, status.Errorf(codes.Unimplemented, "method ReportSystemInfo not implemented") @@ -147,16 +193,15 @@ func (UnimplementedNezhaServiceServer) ReportSystemInfo(context.Context, *Host) func (UnimplementedNezhaServiceServer) ReportTask(context.Context, *TaskResult) (*Receipt, error) { return nil, status.Errorf(codes.Unimplemented, "method ReportTask not implemented") } -func (UnimplementedNezhaServiceServer) RequestTask(*Host, grpc.ServerStreamingServer[Task]) error { +func (UnimplementedNezhaServiceServer) RequestTask(*Host, NezhaService_RequestTaskServer) error { return status.Errorf(codes.Unimplemented, "method RequestTask not implemented") } -func (UnimplementedNezhaServiceServer) IOStream(grpc.BidiStreamingServer[IOStreamData, IOStreamData]) error { +func (UnimplementedNezhaServiceServer) IOStream(NezhaService_IOStreamServer) error { return status.Errorf(codes.Unimplemented, "method IOStream not implemented") } func (UnimplementedNezhaServiceServer) ReportGeoIP(context.Context, *GeoIP) (*GeoIP, error) { return nil, status.Errorf(codes.Unimplemented, "method ReportGeoIP not implemented") } -func (UnimplementedNezhaServiceServer) testEmbeddedByValue() {} // UnsafeNezhaServiceServer may be embedded to opt out of forward compatibility for this service. // Use of this interface is not recommended, as added methods to NezhaServiceServer will @@ -166,32 +211,33 @@ type UnsafeNezhaServiceServer interface { } func RegisterNezhaServiceServer(s grpc.ServiceRegistrar, srv NezhaServiceServer) { - // If the following call panics, it indicates UnimplementedNezhaServiceServer was - // embedded by pointer and is nil. This will cause panics if an - // unimplemented method is ever invoked, so we test this at initialization - // time to prevent it from happening at runtime later due to I/O. - if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { - t.testEmbeddedByValue() - } s.RegisterService(&NezhaService_ServiceDesc, srv) } -func _NezhaService_ReportSystemState_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(State) - if err := dec(in); err != nil { +func _NezhaService_ReportSystemState_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(NezhaServiceServer).ReportSystemState(&nezhaServiceReportSystemStateServer{stream}) +} + +type NezhaService_ReportSystemStateServer interface { + Send(*Receipt) error + Recv() (*State, error) + grpc.ServerStream +} + +type nezhaServiceReportSystemStateServer struct { + grpc.ServerStream +} + +func (x *nezhaServiceReportSystemStateServer) Send(m *Receipt) error { + return x.ServerStream.SendMsg(m) +} + +func (x *nezhaServiceReportSystemStateServer) Recv() (*State, error) { + m := new(State) + if err := x.ServerStream.RecvMsg(m); err != nil { return nil, err } - if interceptor == nil { - return srv.(NezhaServiceServer).ReportSystemState(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: NezhaService_ReportSystemState_FullMethodName, - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(NezhaServiceServer).ReportSystemState(ctx, req.(*State)) - } - return interceptor(ctx, in, info, handler) + return m, nil } func _NezhaService_ReportSystemInfo_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { @@ -235,18 +281,47 @@ func _NezhaService_RequestTask_Handler(srv interface{}, stream grpc.ServerStream if err := stream.RecvMsg(m); err != nil { return err } - return srv.(NezhaServiceServer).RequestTask(m, &grpc.GenericServerStream[Host, Task]{ServerStream: stream}) + return srv.(NezhaServiceServer).RequestTask(m, &nezhaServiceRequestTaskServer{stream}) } -// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. -type NezhaService_RequestTaskServer = grpc.ServerStreamingServer[Task] +type NezhaService_RequestTaskServer interface { + Send(*Task) error + grpc.ServerStream +} + +type nezhaServiceRequestTaskServer struct { + grpc.ServerStream +} + +func (x *nezhaServiceRequestTaskServer) Send(m *Task) error { + return x.ServerStream.SendMsg(m) +} func _NezhaService_IOStream_Handler(srv interface{}, stream grpc.ServerStream) error { - return srv.(NezhaServiceServer).IOStream(&grpc.GenericServerStream[IOStreamData, IOStreamData]{ServerStream: stream}) + return srv.(NezhaServiceServer).IOStream(&nezhaServiceIOStreamServer{stream}) } -// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. -type NezhaService_IOStreamServer = grpc.BidiStreamingServer[IOStreamData, IOStreamData] +type NezhaService_IOStreamServer interface { + Send(*IOStreamData) error + Recv() (*IOStreamData, error) + grpc.ServerStream +} + +type nezhaServiceIOStreamServer struct { + grpc.ServerStream +} + +func (x *nezhaServiceIOStreamServer) Send(m *IOStreamData) error { + return x.ServerStream.SendMsg(m) +} + +func (x *nezhaServiceIOStreamServer) Recv() (*IOStreamData, error) { + m := new(IOStreamData) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} func _NezhaService_ReportGeoIP_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(GeoIP) @@ -273,10 +348,6 @@ var NezhaService_ServiceDesc = grpc.ServiceDesc{ ServiceName: "proto.NezhaService", HandlerType: (*NezhaServiceServer)(nil), Methods: []grpc.MethodDesc{ - { - MethodName: "ReportSystemState", - Handler: _NezhaService_ReportSystemState_Handler, - }, { MethodName: "ReportSystemInfo", Handler: _NezhaService_ReportSystemInfo_Handler, @@ -291,6 +362,12 @@ var NezhaService_ServiceDesc = grpc.ServiceDesc{ }, }, Streams: []grpc.StreamDesc{ + { + StreamName: "ReportSystemState", + Handler: _NezhaService_ReportSystemState_Handler, + ServerStreams: true, + ClientStreams: true, + }, { StreamName: "RequestTask", Handler: _NezhaService_RequestTask_Handler, diff --git a/service/rpc/auth.go b/service/rpc/auth.go index 8b49de9..adf869c 100644 --- a/service/rpc/auth.go +++ b/service/rpc/auth.go @@ -2,6 +2,7 @@ package rpc import ( "context" + "strings" "sync" petname "github.com/dustinkirkland/golang-petname" @@ -27,13 +28,22 @@ func (a *authHandler) Check(ctx context.Context) (uint64, error) { var clientSecret string if value, ok := md["client_secret"]; ok { - clientSecret = value[0] + clientSecret = strings.TrimSpace(value[0]) } - if clientSecret != singleton.Conf.AgentSecretKey { + if clientSecret == "" { return 0, status.Error(codes.Unauthenticated, "客户端认证失败") } + ip, _ := ctx.Value(model.CtxKeyRealIP{}).(string) + + if clientSecret != singleton.Conf.AgentSecretKey { + model.BlockIP(singleton.DB, ip, model.WAFBlockReasonTypeAgentAuthFail) + return 0, status.Error(codes.Unauthenticated, "客户端认证失败") + } + + model.ClearIP(singleton.DB, ip) + var clientUUID string if value, ok := md["client_uuid"]; ok { clientUUID = value[0] diff --git a/service/rpc/nezha.go b/service/rpc/nezha.go index 596e490..4065804 100644 --- a/service/rpc/nezha.go +++ b/service/rpc/nezha.go @@ -19,6 +19,8 @@ import ( "github.com/naiba/nezha/service/singleton" ) +var _ pb.NezhaServiceServer = (*NezhaHandler)(nil) + var NezhaHandlerSingleton *NezhaHandler type NezhaHandler struct { @@ -94,25 +96,33 @@ func (s *NezhaHandler) RequestTask(h *pb.Host, stream pb.NezhaService_RequestTas return <-closeCh } -func (s *NezhaHandler) ReportSystemState(c context.Context, r *pb.State) (*pb.Receipt, error) { - var clientID uint64 +func (s *NezhaHandler) ReportSystemState(stream pb.NezhaService_ReportSystemStateServer) error { var err error - if clientID, err = s.Auth.Check(c); err != nil { - return nil, err + var clientID uint64 + if clientID, err = s.Auth.Check(stream.Context()); err != nil { + return err } - state := model.PB2State(r) - singleton.ServerLock.RLock() - defer singleton.ServerLock.RUnlock() - singleton.ServerList[clientID].LastActive = time.Now() - singleton.ServerList[clientID].State = &state + var state *pb.State + for { + state, err = stream.Recv() + if err != nil { + log.Printf("NEZHA>> ReportSystemState eror: %v, clientID: %d\n", err, clientID) + return nil + } + state := model.PB2State(state) - // 应对 dashboard 重启的情况,如果从未记录过,先打点,等到小时时间点时入库 - if singleton.ServerList[clientID].PrevTransferInSnapshot == 0 || singleton.ServerList[clientID].PrevTransferOutSnapshot == 0 { - singleton.ServerList[clientID].PrevTransferInSnapshot = int64(state.NetInTransfer) - singleton.ServerList[clientID].PrevTransferOutSnapshot = int64(state.NetOutTransfer) + singleton.ServerLock.RLock() + singleton.ServerList[clientID].LastActive = time.Now() + singleton.ServerList[clientID].State = &state + // 应对 dashboard 重启的情况,如果从未记录过,先打点,等到小时时间点时入库 + if singleton.ServerList[clientID].PrevTransferInSnapshot == 0 || singleton.ServerList[clientID].PrevTransferOutSnapshot == 0 { + singleton.ServerList[clientID].PrevTransferInSnapshot = int64(state.NetInTransfer) + singleton.ServerList[clientID].PrevTransferOutSnapshot = int64(state.NetOutTransfer) + } + singleton.ServerLock.RUnlock() + + stream.Send(&pb.Receipt{Proced: true}) } - - return &pb.Receipt{Proced: true}, nil } func (s *NezhaHandler) ReportSystemInfo(c context.Context, r *pb.Host) (*pb.Receipt, error) { @@ -232,5 +242,5 @@ func (s *NezhaHandler) ReportGeoIP(c context.Context, r *pb.GeoIP) (*pb.GeoIP, e defer singleton.ServerLock.Unlock() singleton.ServerList[clientID].GeoIP = &geoip - return &pb.GeoIP{Ip: nil, CountryCode: location}, err + return &pb.GeoIP{Ip: nil, CountryCode: location}, nil }