From 2d279aa2027960ceb5ccd9424104458d4a13d9ea Mon Sep 17 00:00:00 2001 From: yuzuki999 Date: Sun, 12 Jun 2022 21:10:20 +0800 Subject: [PATCH] add protocol rule,fix limit not work --- api/api.go | 2 +- api/interface.go | 2 +- api/node.go | 31 +++++++++++++++++++------------ app/mydispatcher/default.go | 30 ++++++++++++++++++------------ common/limiter/limiter.go | 2 +- common/rule/rule.go | 23 +++++++++++++++++++++++ service/controller/control.go | 6 ++++++ service/controller/controller.go | 17 ++++++++++++++--- 8 files changed, 83 insertions(+), 30 deletions(-) diff --git a/api/api.go b/api/api.go index 0a251b8..3757cbe 100644 --- a/api/api.go +++ b/api/api.go @@ -47,7 +47,7 @@ type Client struct { SpeedLimit float64 DeviceLimit int LocalRuleList []DetectRule - RemoteRuleCache *Rule + RemoteRuleCache *[]Rule access sync.Mutex NodeInfoRspMd5 [16]byte NodeRuleRspMd5 [16]byte diff --git a/api/interface.go b/api/interface.go index 40e92ab..80ef116 100644 --- a/api/interface.go +++ b/api/interface.go @@ -5,6 +5,6 @@ type API interface { GetUserList() (userList *[]UserInfo, err error) ReportUserTraffic(userTraffic *[]UserTraffic) (err error) Describe() ClientInfo - GetNodeRule() (ruleList *[]DetectRule, err error) + GetNodeRule() (ruleList *[]DetectRule, protocolList *[]string, err error) Debug() } diff --git a/api/node.go b/api/node.go index 81f448f..10a585e 100644 --- a/api/node.go +++ b/api/node.go @@ -77,7 +77,7 @@ type SSConfig struct { type V2rayConfig struct { Inbounds []conf.InboundDetourConfig `json:"inbounds"` Routing *struct { - Rules []json.RawMessage `json:"rules"` + Rules json.RawMessage `json:"rules"` } `json:"routing"` } @@ -160,25 +160,32 @@ func (c *Client) GetNodeInfo() (nodeInfo *NodeInfo, err error) { return nodeInfo, nil } -func (c *Client) GetNodeRule() (*[]DetectRule, error) { +func (c *Client) GetNodeRule() (*[]DetectRule, *[]string, error) { ruleList := c.LocalRuleList if c.NodeType != "V2ray" || c.RemoteRuleCache == nil { - return &ruleList, nil + return &ruleList, nil, nil } - // V2board only support the rule for v2ray // fix: reuse config response c.access.Lock() defer c.access.Unlock() - for i, rule := range c.RemoteRuleCache.Domain { - ruleListItem := DetectRule{ - ID: i, - Pattern: regexp.MustCompile(rule), + if len(*c.RemoteRuleCache) >= 2 { + for i, rule := range (*c.RemoteRuleCache)[1].Domain { + ruleListItem := DetectRule{ + ID: i, + Pattern: regexp.MustCompile(rule), + } + ruleList = append(ruleList, ruleListItem) + } + } + var protocolList []string + if len(*c.RemoteRuleCache) >= 3 { + for _, str := range (*c.RemoteRuleCache)[2].Protocol { + protocolList = append(protocolList, str) } - ruleList = append(ruleList, ruleListItem) } c.RemoteRuleCache = nil - return &ruleList, nil + return &ruleList, &protocolList, nil } // ParseTrojanNodeResponse parse the response for the given nodeinfor format @@ -238,8 +245,8 @@ func (c *Client) ParseV2rayNodeResponse(body []byte, notParseNode, parseRule boo return nil, fmt.Errorf("unmarshal nodeinfo error: %s", err) } if parseRule { - c.RemoteRuleCache = &Rule{} - err := json.Unmarshal(node.V2ray.Routing.Rules[1], c.RemoteRuleCache) + c.RemoteRuleCache = &[]Rule{} + err := json.Unmarshal(node.V2ray.Routing.Rules, c.RemoteRuleCache) if err != nil { log.Println(err) } diff --git a/app/mydispatcher/default.go b/app/mydispatcher/default.go index b9b6c9a..591e22e 100644 --- a/app/mydispatcher/default.go +++ b/app/mydispatcher/default.go @@ -5,10 +5,6 @@ package mydispatcher import ( "context" "fmt" - "strings" - "sync" - "time" - "github.com/Yuzuki616/V2bX/common/limiter" "github.com/Yuzuki616/V2bX/common/rule" "github.com/xtls/xray-core/common" @@ -26,6 +22,9 @@ import ( "github.com/xtls/xray-core/features/stats" "github.com/xtls/xray-core/transport" "github.com/xtls/xray-core/transport/pipe" + "strings" + "sync" + "time" ) var errSniffingTimeout = newError("timeout on sniffing") @@ -320,7 +319,7 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin } switch { case !sniffingRequest.Enabled: - go d.routedDispatch(ctx, outbound, destination) + go d.routedDispatch(ctx, outbound, destination, "") case destination.Network != net.Network_TCP: // Only metadata sniff will be used for non tcp connection result, err := sniffer(ctx, nil, true) @@ -337,7 +336,7 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin } } } - go d.routedDispatch(ctx, outbound, destination) + go d.routedDispatch(ctx, outbound, destination, content.Protocol) default: go func() { cReader := &cachedReader{ @@ -358,7 +357,7 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin ob.Target = destination } } - d.routedDispatch(ctx, outbound, destination) + d.routedDispatch(ctx, outbound, destination, content.Protocol) }() } return inbound, nil @@ -381,7 +380,7 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De sniffingRequest := content.SniffingRequest switch { case !sniffingRequest.Enabled: - go d.routedDispatch(ctx, outbound, destination) + go d.routedDispatch(ctx, outbound, destination, "") case destination.Network != net.Network_TCP: // Only metadata sniff will be used for non tcp connection result, err := sniffer(ctx, nil, true) @@ -398,7 +397,7 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De } } } - go d.routedDispatch(ctx, outbound, destination) + go d.routedDispatch(ctx, outbound, destination, content.Protocol) default: go func() { cReader := &cachedReader{ @@ -419,7 +418,7 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De ob.Target = destination } } - d.routedDispatch(ctx, outbound, destination) + d.routedDispatch(ctx, outbound, destination, content.Protocol) }() } return nil @@ -471,7 +470,7 @@ func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool) (Sni return contentResult, contentErr } -func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) { +func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination, protocol string) { ob := session.OutboundFromContext(ctx) if hosts, ok := d.dns.(dns.HostsLookup); ok && destination.Address.Family().IsDomain() { proxied := hosts.LookupHosts(ob.Target.String()) @@ -499,6 +498,13 @@ func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport. return }*/ if sessionInbound.User != nil { + if d.RuleManager.ProtocolDetect(sessionInbound.Tag, protocol) { + newError(fmt.Sprintf("User %s access %s reject by protocol rule", sessionInbound.User.Email, destination.String())).AtError().WriteToLog() + newError("destination is reject by protocol rule") + common.Close(link.Writer) + common.Interrupt(link.Reader) + return + } if d.RuleManager.Detect(sessionInbound.Tag, destination.String(), sessionInbound.User.Email) { newError(fmt.Sprintf("User %s access %s reject by rule", sessionInbound.User.Email, destination.String())).AtError().WriteToLog() newError("destination is reject by rule") @@ -539,7 +545,7 @@ func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport. } if handler == nil { - handler = d.ohm.GetHandler(inTag) // Default outbound hander tag should be as same as the inbound tag + handler = d.ohm.GetHandler(inTag) // Default outbound handier tag should be as same as the inbound tag } // If there is no outbound with tag as same as the inbound tag diff --git a/common/limiter/limiter.go b/common/limiter/limiter.go index fa4b40e..c06800c 100644 --- a/common/limiter/limiter.go +++ b/common/limiter/limiter.go @@ -49,7 +49,7 @@ func (l *Limiter) AddInboundLimiter(tag string, nodeInfo *api.NodeInfo, userList if (*userList)[i].DeviceLimit == 0 { (*userList)[i].DeviceLimit = nodeInfo.DeviceLimit } - userMap.Store(fmt.Sprintf("%s|%d|%d", tag, (*userList)[i].Port, (*userList)[i].UID), UserInfo{ + userMap.Store(fmt.Sprintf("%s|%s|%d", tag, (*userList)[i].GetUserEmail(), (*userList)[i].UID), UserInfo{ UID: (*userList)[i].UID, SpeedLimit: (*userList)[i].SpeedLimit, DeviceLimit: (*userList)[i].DeviceLimit, diff --git a/common/rule/rule.go b/common/rule/rule.go index ee87c6a..260fea7 100644 --- a/common/rule/rule.go +++ b/common/rule/rule.go @@ -14,12 +14,14 @@ import ( type Rule struct { InboundRule *sync.Map // Key: Tag, Value: []api.DetectRule + InboundProtocolRule *sync.Map // Key: Tag, Value: []string InboundDetectResult *sync.Map // key: Tag, Value: mapset.NewSet []api.DetectResult } func New() *Rule { return &Rule{ InboundRule: new(sync.Map), + InboundProtocolRule: new(sync.Map), InboundDetectResult: new(sync.Map), } } @@ -34,6 +36,16 @@ func (r *Rule) UpdateRule(tag string, newRuleList []api.DetectRule) error { return nil } +func (r *Rule) UpdateProtocolRule(tag string, ruleList []string) error { + if value, ok := r.InboundProtocolRule.LoadOrStore(tag, ruleList); ok { + old := value.([]string) + if !reflect.DeepEqual(old, ruleList) { + r.InboundProtocolRule.Store(tag, ruleList) + } + } + return nil +} + func (r *Rule) GetDetectResult(tag string) (*[]api.DetectResult, error) { detectResult := make([]api.DetectResult, 0) if value, ok := r.InboundDetectResult.LoadAndDelete(tag); ok { @@ -80,3 +92,14 @@ func (r *Rule) Detect(tag string, destination string, email string) (reject bool } return reject } +func (r *Rule) ProtocolDetect(tag string, protocol string) bool { + if value, ok := r.InboundProtocolRule.Load(tag); ok { + ruleList := value.([]string) + for _, r := range ruleList { + if r == protocol { + return true + } + } + } + return false +} diff --git a/service/controller/control.go b/service/controller/control.go index ba855e7..da6301b 100644 --- a/service/controller/control.go +++ b/service/controller/control.go @@ -158,6 +158,12 @@ func (c *Controller) UpdateRule(tag string, newRuleList []api.DetectRule) error return err } +func (c *Controller) UpdateProtocolRule(tag string, newRuleList []string) error { + dispather := c.server.GetFeature(routing.DispatcherType()).(*mydispatcher.DefaultDispatcher) + err := dispather.RuleManager.UpdateProtocolRule(tag, newRuleList) + return err +} + func (c *Controller) GetDetectResult(tag string) (*[]api.DetectResult, error) { dispather := c.server.GetFeature(routing.DispatcherType()).(*mydispatcher.DefaultDispatcher) return dispather.RuleManager.GetDetectResult(tag) diff --git a/service/controller/controller.go b/service/controller/controller.go index cd14f00..58ee3e9 100644 --- a/service/controller/controller.go +++ b/service/controller/controller.go @@ -71,12 +71,17 @@ func (c *Controller) Start() error { } // Add Rule Manager if !c.config.DisableGetRule { - if ruleList, err := c.apiClient.GetNodeRule(); err != nil { + if ruleList, protocolRule, err := c.apiClient.GetNodeRule(); err != nil { log.Printf("Get rule list filed: %s", err) } else if len(*ruleList) > 0 { if err := c.UpdateRule(c.Tag, *ruleList); err != nil { log.Print(err) } + if len(*protocolRule) > 0 { + if err := c.UpdateProtocolRule(c.Tag, *protocolRule); err != nil { + log.Print(err) + } + } } } c.nodeInfoMonitorPeriodic = &task.Periodic{ @@ -160,12 +165,18 @@ func (c *Controller) nodeInfoMonitor() (err error) { // Check Rule if !c.config.DisableGetRule { - if ruleList, err := c.apiClient.GetNodeRule(); err != nil { + if ruleList, protocolRule, err := c.apiClient.GetNodeRule(); err != nil { log.Printf("Get rule list filed: %s", err) - } else if ruleList != nil { + } else if len(*ruleList) > 0 { if err := c.UpdateRule(c.Tag, *ruleList); err != nil { log.Print(err) } + if len(*protocolRule) > 0 { + if err := c.UpdateProtocolRule(c.Tag, *protocolRule); err != nil { + log.Print(err) + } + } + } }