From adf98fbc81fcd6b9156bd0f124c79bb73da18891 Mon Sep 17 00:00:00 2001 From: Yuzuki616 Date: Thu, 20 Jul 2023 21:14:18 +0800 Subject: [PATCH] support protocol rule --- api/panel/node.go | 21 +++++-- core/xray/app/dispatcher/default.go | 87 +++++++++++++++++++---------- limiter/limiter.go | 9 +-- limiter/rule.go | 15 +++-- node/controller.go | 3 +- node/task.go | 5 +- 6 files changed, 92 insertions(+), 48 deletions(-) diff --git a/api/panel/node.go b/api/panel/node.go index 86f311f..d902060 100644 --- a/api/panel/node.go +++ b/api/panel/node.go @@ -4,18 +4,16 @@ import ( "bytes" "encoding/base64" "fmt" - log "github.com/sirupsen/logrus" - coreConf "github.com/xtls/xray-core/infra/conf" "os" "reflect" - "regexp" "strconv" "strings" "time" "github.com/Yuzuki616/V2bX/common/crypt" - "github.com/goccy/go-json" + log "github.com/sirupsen/logrus" + coreConf "github.com/xtls/xray-core/infra/conf" ) type CommonNodeRsp struct { @@ -58,7 +56,7 @@ type HysteriaNodeRsp struct { type NodeInfo struct { Id int Type string - Rules []*regexp.Regexp + Rules Rules Host string Port int Network string @@ -75,6 +73,11 @@ type NodeInfo struct { PullInterval time.Duration } +type Rules struct { + Regexp []string + Protocol []string +} + type V2rayExtraConfig struct { EnableVless string `json:"EnableVless"` VlessFlow string `json:"VlessFlow"` @@ -131,7 +134,13 @@ func (c *Client) GetNodeInfo() (node *NodeInfo, err error) { switch common.Routes[i].Action { case "block": for _, v := range matchs { - node.Rules = append(node.Rules, regexp.MustCompile(v)) + if strings.HasPrefix(v, "protocol:") { + // protocol + node.Rules.Protocol = append(node.Rules.Protocol, strings.TrimPrefix(v, "protocol:")) + } else { + // domain + node.Rules.Regexp = append(node.Rules.Regexp, strings.TrimPrefix(v, "regexp:")) + } } case "dns": if matchs[0] != "main" { diff --git a/core/xray/app/dispatcher/default.go b/core/xray/app/dispatcher/default.go index 7c33636..d872b7f 100644 --- a/core/xray/app/dispatcher/default.go +++ b/core/xray/app/dispatcher/default.go @@ -337,7 +337,13 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin reader: outbound.Reader.(*pipe.Reader), } outbound.Reader = cReader - result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network) + result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network, l) + if _, ok := err.(limitedError); ok { + newError(err).AtInfo().WriteToLog() + common.Close(outbound.Writer) + common.Interrupt(outbound.Reader) + return + } if err == nil { content.Protocol = result.Protocol() } @@ -380,7 +386,13 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De reader: outbound.Reader.(*pipe.Reader), } outbound.Reader = cReader - result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network) + result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network, nil) + if _, ok := err.(limitedError); ok { + newError(err).AtInfo().WriteToLog() + common.Close(outbound.Writer) + common.Interrupt(outbound.Reader) + return + } if err == nil { content.Protocol = result.Protocol() } @@ -400,18 +412,50 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De return nil } -func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, network net.Network) (SniffResult, error) { +type limitedError string + +func (l limitedError) Error() string { + return string(l) +} + +func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, network net.Network, l *limiter.Limiter) (result SniffResult, err error) { payload := buf.New() defer payload.Release() + defer func() { + if err != nil { + return + } + // Check if domain and protocol hit the rule + sessionInbound := session.InboundFromContext(ctx) + // Whether the inbound connection contains a user + if sessionInbound.User != nil { + if l == nil { + l, err = limiter.GetLimiter(sessionInbound.Tag) + if err != nil { + return + } + } + if l.CheckDomainRule(result.Domain()) { + err = limitedError(fmt.Sprintf( + "User %s access domain %s reject by rule", + sessionInbound.User.Email, + result.Domain())) + } + if l.CheckProtocolRule(result.Protocol()) { + err = limitedError(fmt.Sprintf( + "User %s access protocol %s reject by rule", + sessionInbound.User.Email, + result.Protocol())) + } + } + }() + sniffer := NewSniffer(ctx) - metaresult, metadataErr := sniffer.SniffMetadata(ctx) - if metadataOnly { return metaresult, metadataErr } - contentResult, contentErr := func() (SniffResult, error) { totalAttempt := 0 for { @@ -460,32 +504,17 @@ func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport. } } } - var handler outbound.Handler - // Check if domain and protocol hit the rule - sessionInbound := session.InboundFromContext(ctx) - // Whether the inbound connection contains a user - if sessionInbound.User != nil { - if l == nil { - var err error - l, err = limiter.GetLimiter(sessionInbound.Tag) - if err != nil { - newError("Get limiter error: ", err).AtError().WriteToLog() - common.Close(link.Writer) - common.Interrupt(link.Reader) - return + // del connect count + if l != nil { + sessionInbound := session.InboundFromContext(ctx) + if sessionInbound.User != nil { + if destination.Network == net.Network_TCP { + defer func() { + l.ConnLimiter.DelConnCount(sessionInbound.User.Email, sessionInbound.Source.Address.IP().String()) + }() } - } else if destination.Network == net.Network_TCP { - defer func() { - l.ConnLimiter.DelConnCount(sessionInbound.User.Email, sessionInbound.Source.Address.IP().String()) - }() - } - if l.CheckDomainRule(destination.Address.String()) { - newError(fmt.Sprintf("User %s access %s reject by rule", sessionInbound.User.Email, destination.String())).AtError().WriteToLog() - common.Close(link.Writer) - common.Interrupt(link.Reader) - return } } diff --git a/limiter/limiter.go b/limiter/limiter.go index 149538a..f86e92d 100644 --- a/limiter/limiter.go +++ b/limiter/limiter.go @@ -3,15 +3,16 @@ package limiter import ( "errors" "fmt" + "regexp" + "sync" + "time" + "github.com/Yuzuki616/V2bX/api/panel" "github.com/Yuzuki616/V2bX/common/format" "github.com/Yuzuki616/V2bX/conf" "github.com/juju/ratelimit" log "github.com/sirupsen/logrus" "github.com/xtls/xray-core/common/task" - "regexp" - "sync" - "time" ) var limitLock sync.RWMutex @@ -32,7 +33,7 @@ func Init() { } type Limiter struct { - Rules []*regexp.Regexp + DomainRules []*regexp.Regexp ProtocolRules []string SpeedLimit int UserLimitInfo *sync.Map // Key: Uid value: UserLimitInfo diff --git a/limiter/rule.go b/limiter/rule.go index ce4dcf5..4f6e4c8 100644 --- a/limiter/rule.go +++ b/limiter/rule.go @@ -1,14 +1,15 @@ package limiter import ( - "reflect" "regexp" + + "github.com/Yuzuki616/V2bX/api/panel" ) func (l *Limiter) CheckDomainRule(destination string) (reject bool) { // have rule - for i := range l.Rules { - if l.Rules[i].MatchString(destination) { + for i := range l.DomainRules { + if l.DomainRules[i].MatchString(destination) { reject = true break } @@ -26,9 +27,11 @@ func (l *Limiter) CheckProtocolRule(protocol string) (reject bool) { return } -func (l *Limiter) UpdateRule(newRuleList []*regexp.Regexp) error { - if !reflect.DeepEqual(l.Rules, newRuleList) { - l.Rules = newRuleList +func (l *Limiter) UpdateRule(rule *panel.Rules) error { + l.DomainRules = make([]*regexp.Regexp, len(rule.Regexp)) + for i := range rule.Regexp { + l.DomainRules[i] = regexp.MustCompile(rule.Regexp[i]) } + l.ProtocolRules = rule.Protocol return nil } diff --git a/node/controller.go b/node/controller.go index 725b2ad..f4e4853 100644 --- a/node/controller.go +++ b/node/controller.go @@ -3,6 +3,7 @@ package node import ( "errors" "fmt" + "github.com/Yuzuki616/V2bX/api/iprecoder" "github.com/Yuzuki616/V2bX/api/panel" "github.com/Yuzuki616/V2bX/common/task" @@ -58,7 +59,7 @@ func (c *Controller) Start() error { // add limiter l := limiter.AddLimiter(c.Tag, &c.LimitConfig, c.userList) // add rule limiter - if err = l.UpdateRule(c.nodeInfo.Rules); err != nil { + if err = l.UpdateRule(&c.nodeInfo.Rules); err != nil { return fmt.Errorf("update rule error: %s", err) } if c.nodeInfo.Tls || c.nodeInfo.Type == "hysteria" { diff --git a/node/task.go b/node/task.go index 550f11d..77f1985 100644 --- a/node/task.go +++ b/node/task.go @@ -1,11 +1,12 @@ package node import ( + "time" + "github.com/Yuzuki616/V2bX/common/task" vCore "github.com/Yuzuki616/V2bX/core" "github.com/Yuzuki616/V2bX/limiter" log "github.com/sirupsen/logrus" - "time" ) func (c *Controller) initTask() { @@ -108,7 +109,7 @@ func (c *Controller) nodeInfoMonitor() (err error) { }).Error("Add users failed") return nil } - err = l.UpdateRule(newNodeInfo.Rules) + err = l.UpdateRule(&newNodeInfo.Rules) if err != nil { log.WithFields(log.Fields{ "tag": c.Tag,