support protocol rule

This commit is contained in:
Yuzuki616 2023-07-20 21:14:18 +08:00
parent 7cc57fe2ba
commit adf98fbc81
6 changed files with 92 additions and 48 deletions

View File

@ -4,18 +4,16 @@ import (
"bytes" "bytes"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
log "github.com/sirupsen/logrus"
coreConf "github.com/xtls/xray-core/infra/conf"
"os" "os"
"reflect" "reflect"
"regexp"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/Yuzuki616/V2bX/common/crypt" "github.com/Yuzuki616/V2bX/common/crypt"
"github.com/goccy/go-json" "github.com/goccy/go-json"
log "github.com/sirupsen/logrus"
coreConf "github.com/xtls/xray-core/infra/conf"
) )
type CommonNodeRsp struct { type CommonNodeRsp struct {
@ -58,7 +56,7 @@ type HysteriaNodeRsp struct {
type NodeInfo struct { type NodeInfo struct {
Id int Id int
Type string Type string
Rules []*regexp.Regexp Rules Rules
Host string Host string
Port int Port int
Network string Network string
@ -75,6 +73,11 @@ type NodeInfo struct {
PullInterval time.Duration PullInterval time.Duration
} }
type Rules struct {
Regexp []string
Protocol []string
}
type V2rayExtraConfig struct { type V2rayExtraConfig struct {
EnableVless string `json:"EnableVless"` EnableVless string `json:"EnableVless"`
VlessFlow string `json:"VlessFlow"` VlessFlow string `json:"VlessFlow"`
@ -131,7 +134,13 @@ func (c *Client) GetNodeInfo() (node *NodeInfo, err error) {
switch common.Routes[i].Action { switch common.Routes[i].Action {
case "block": case "block":
for _, v := range matchs { for _, v := range matchs {
node.Rules = append(node.Rules, regexp.MustCompile(v)) if strings.HasPrefix(v, "protocol:") {
// protocol
node.Rules.Protocol = append(node.Rules.Protocol, strings.TrimPrefix(v, "protocol:"))
} else {
// domain
node.Rules.Regexp = append(node.Rules.Regexp, strings.TrimPrefix(v, "regexp:"))
}
} }
case "dns": case "dns":
if matchs[0] != "main" { if matchs[0] != "main" {

View File

@ -337,7 +337,13 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
reader: outbound.Reader.(*pipe.Reader), reader: outbound.Reader.(*pipe.Reader),
} }
outbound.Reader = cReader outbound.Reader = cReader
result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network) result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network, l)
if _, ok := err.(limitedError); ok {
newError(err).AtInfo().WriteToLog()
common.Close(outbound.Writer)
common.Interrupt(outbound.Reader)
return
}
if err == nil { if err == nil {
content.Protocol = result.Protocol() content.Protocol = result.Protocol()
} }
@ -380,7 +386,13 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
reader: outbound.Reader.(*pipe.Reader), reader: outbound.Reader.(*pipe.Reader),
} }
outbound.Reader = cReader outbound.Reader = cReader
result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network) result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network, nil)
if _, ok := err.(limitedError); ok {
newError(err).AtInfo().WriteToLog()
common.Close(outbound.Writer)
common.Interrupt(outbound.Reader)
return
}
if err == nil { if err == nil {
content.Protocol = result.Protocol() content.Protocol = result.Protocol()
} }
@ -400,18 +412,50 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
return nil return nil
} }
func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, network net.Network) (SniffResult, error) { type limitedError string
func (l limitedError) Error() string {
return string(l)
}
func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, network net.Network, l *limiter.Limiter) (result SniffResult, err error) {
payload := buf.New() payload := buf.New()
defer payload.Release() defer payload.Release()
defer func() {
if err != nil {
return
}
// Check if domain and protocol hit the rule
sessionInbound := session.InboundFromContext(ctx)
// Whether the inbound connection contains a user
if sessionInbound.User != nil {
if l == nil {
l, err = limiter.GetLimiter(sessionInbound.Tag)
if err != nil {
return
}
}
if l.CheckDomainRule(result.Domain()) {
err = limitedError(fmt.Sprintf(
"User %s access domain %s reject by rule",
sessionInbound.User.Email,
result.Domain()))
}
if l.CheckProtocolRule(result.Protocol()) {
err = limitedError(fmt.Sprintf(
"User %s access protocol %s reject by rule",
sessionInbound.User.Email,
result.Protocol()))
}
}
}()
sniffer := NewSniffer(ctx) sniffer := NewSniffer(ctx)
metaresult, metadataErr := sniffer.SniffMetadata(ctx) metaresult, metadataErr := sniffer.SniffMetadata(ctx)
if metadataOnly { if metadataOnly {
return metaresult, metadataErr return metaresult, metadataErr
} }
contentResult, contentErr := func() (SniffResult, error) { contentResult, contentErr := func() (SniffResult, error) {
totalAttempt := 0 totalAttempt := 0
for { for {
@ -460,32 +504,17 @@ func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.
} }
} }
} }
var handler outbound.Handler var handler outbound.Handler
// Check if domain and protocol hit the rule // del connect count
sessionInbound := session.InboundFromContext(ctx) if l != nil {
// Whether the inbound connection contains a user sessionInbound := session.InboundFromContext(ctx)
if sessionInbound.User != nil { if sessionInbound.User != nil {
if l == nil { if destination.Network == net.Network_TCP {
var err error defer func() {
l, err = limiter.GetLimiter(sessionInbound.Tag) l.ConnLimiter.DelConnCount(sessionInbound.User.Email, sessionInbound.Source.Address.IP().String())
if err != nil { }()
newError("Get limiter error: ", err).AtError().WriteToLog()
common.Close(link.Writer)
common.Interrupt(link.Reader)
return
} }
} else if destination.Network == net.Network_TCP {
defer func() {
l.ConnLimiter.DelConnCount(sessionInbound.User.Email, sessionInbound.Source.Address.IP().String())
}()
}
if l.CheckDomainRule(destination.Address.String()) {
newError(fmt.Sprintf("User %s access %s reject by rule", sessionInbound.User.Email, destination.String())).AtError().WriteToLog()
common.Close(link.Writer)
common.Interrupt(link.Reader)
return
} }
} }

View File

@ -3,15 +3,16 @@ package limiter
import ( import (
"errors" "errors"
"fmt" "fmt"
"regexp"
"sync"
"time"
"github.com/Yuzuki616/V2bX/api/panel" "github.com/Yuzuki616/V2bX/api/panel"
"github.com/Yuzuki616/V2bX/common/format" "github.com/Yuzuki616/V2bX/common/format"
"github.com/Yuzuki616/V2bX/conf" "github.com/Yuzuki616/V2bX/conf"
"github.com/juju/ratelimit" "github.com/juju/ratelimit"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/xtls/xray-core/common/task" "github.com/xtls/xray-core/common/task"
"regexp"
"sync"
"time"
) )
var limitLock sync.RWMutex var limitLock sync.RWMutex
@ -32,7 +33,7 @@ func Init() {
} }
type Limiter struct { type Limiter struct {
Rules []*regexp.Regexp DomainRules []*regexp.Regexp
ProtocolRules []string ProtocolRules []string
SpeedLimit int SpeedLimit int
UserLimitInfo *sync.Map // Key: Uid value: UserLimitInfo UserLimitInfo *sync.Map // Key: Uid value: UserLimitInfo

View File

@ -1,14 +1,15 @@
package limiter package limiter
import ( import (
"reflect"
"regexp" "regexp"
"github.com/Yuzuki616/V2bX/api/panel"
) )
func (l *Limiter) CheckDomainRule(destination string) (reject bool) { func (l *Limiter) CheckDomainRule(destination string) (reject bool) {
// have rule // have rule
for i := range l.Rules { for i := range l.DomainRules {
if l.Rules[i].MatchString(destination) { if l.DomainRules[i].MatchString(destination) {
reject = true reject = true
break break
} }
@ -26,9 +27,11 @@ func (l *Limiter) CheckProtocolRule(protocol string) (reject bool) {
return return
} }
func (l *Limiter) UpdateRule(newRuleList []*regexp.Regexp) error { func (l *Limiter) UpdateRule(rule *panel.Rules) error {
if !reflect.DeepEqual(l.Rules, newRuleList) { l.DomainRules = make([]*regexp.Regexp, len(rule.Regexp))
l.Rules = newRuleList for i := range rule.Regexp {
l.DomainRules[i] = regexp.MustCompile(rule.Regexp[i])
} }
l.ProtocolRules = rule.Protocol
return nil return nil
} }

View File

@ -3,6 +3,7 @@ package node
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/Yuzuki616/V2bX/api/iprecoder" "github.com/Yuzuki616/V2bX/api/iprecoder"
"github.com/Yuzuki616/V2bX/api/panel" "github.com/Yuzuki616/V2bX/api/panel"
"github.com/Yuzuki616/V2bX/common/task" "github.com/Yuzuki616/V2bX/common/task"
@ -58,7 +59,7 @@ func (c *Controller) Start() error {
// add limiter // add limiter
l := limiter.AddLimiter(c.Tag, &c.LimitConfig, c.userList) l := limiter.AddLimiter(c.Tag, &c.LimitConfig, c.userList)
// add rule limiter // add rule limiter
if err = l.UpdateRule(c.nodeInfo.Rules); err != nil { if err = l.UpdateRule(&c.nodeInfo.Rules); err != nil {
return fmt.Errorf("update rule error: %s", err) return fmt.Errorf("update rule error: %s", err)
} }
if c.nodeInfo.Tls || c.nodeInfo.Type == "hysteria" { if c.nodeInfo.Tls || c.nodeInfo.Type == "hysteria" {

View File

@ -1,11 +1,12 @@
package node package node
import ( import (
"time"
"github.com/Yuzuki616/V2bX/common/task" "github.com/Yuzuki616/V2bX/common/task"
vCore "github.com/Yuzuki616/V2bX/core" vCore "github.com/Yuzuki616/V2bX/core"
"github.com/Yuzuki616/V2bX/limiter" "github.com/Yuzuki616/V2bX/limiter"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"time"
) )
func (c *Controller) initTask() { func (c *Controller) initTask() {
@ -108,7 +109,7 @@ func (c *Controller) nodeInfoMonitor() (err error) {
}).Error("Add users failed") }).Error("Add users failed")
return nil return nil
} }
err = l.UpdateRule(newNodeInfo.Rules) err = l.UpdateRule(&newNodeInfo.Rules)
if err != nil { if err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"tag": c.Tag, "tag": c.Tag,