support protocol rule

This commit is contained in:
Yuzuki616 2023-07-20 21:14:18 +08:00
parent 7cc57fe2ba
commit adf98fbc81
6 changed files with 92 additions and 48 deletions

View File

@ -4,18 +4,16 @@ import (
"bytes"
"encoding/base64"
"fmt"
log "github.com/sirupsen/logrus"
coreConf "github.com/xtls/xray-core/infra/conf"
"os"
"reflect"
"regexp"
"strconv"
"strings"
"time"
"github.com/Yuzuki616/V2bX/common/crypt"
"github.com/goccy/go-json"
log "github.com/sirupsen/logrus"
coreConf "github.com/xtls/xray-core/infra/conf"
)
type CommonNodeRsp struct {
@ -58,7 +56,7 @@ type HysteriaNodeRsp struct {
type NodeInfo struct {
Id int
Type string
Rules []*regexp.Regexp
Rules Rules
Host string
Port int
Network string
@ -75,6 +73,11 @@ type NodeInfo struct {
PullInterval time.Duration
}
type Rules struct {
Regexp []string
Protocol []string
}
type V2rayExtraConfig struct {
EnableVless string `json:"EnableVless"`
VlessFlow string `json:"VlessFlow"`
@ -131,7 +134,13 @@ func (c *Client) GetNodeInfo() (node *NodeInfo, err error) {
switch common.Routes[i].Action {
case "block":
for _, v := range matchs {
node.Rules = append(node.Rules, regexp.MustCompile(v))
if strings.HasPrefix(v, "protocol:") {
// protocol
node.Rules.Protocol = append(node.Rules.Protocol, strings.TrimPrefix(v, "protocol:"))
} else {
// domain
node.Rules.Regexp = append(node.Rules.Regexp, strings.TrimPrefix(v, "regexp:"))
}
}
case "dns":
if matchs[0] != "main" {

View File

@ -337,7 +337,13 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
reader: outbound.Reader.(*pipe.Reader),
}
outbound.Reader = cReader
result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network)
result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network, l)
if _, ok := err.(limitedError); ok {
newError(err).AtInfo().WriteToLog()
common.Close(outbound.Writer)
common.Interrupt(outbound.Reader)
return
}
if err == nil {
content.Protocol = result.Protocol()
}
@ -380,7 +386,13 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
reader: outbound.Reader.(*pipe.Reader),
}
outbound.Reader = cReader
result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network)
result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network, nil)
if _, ok := err.(limitedError); ok {
newError(err).AtInfo().WriteToLog()
common.Close(outbound.Writer)
common.Interrupt(outbound.Reader)
return
}
if err == nil {
content.Protocol = result.Protocol()
}
@ -400,18 +412,50 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
return nil
}
func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, network net.Network) (SniffResult, error) {
type limitedError string
func (l limitedError) Error() string {
return string(l)
}
func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, network net.Network, l *limiter.Limiter) (result SniffResult, err error) {
payload := buf.New()
defer payload.Release()
defer func() {
if err != nil {
return
}
// Check if domain and protocol hit the rule
sessionInbound := session.InboundFromContext(ctx)
// Whether the inbound connection contains a user
if sessionInbound.User != nil {
if l == nil {
l, err = limiter.GetLimiter(sessionInbound.Tag)
if err != nil {
return
}
}
if l.CheckDomainRule(result.Domain()) {
err = limitedError(fmt.Sprintf(
"User %s access domain %s reject by rule",
sessionInbound.User.Email,
result.Domain()))
}
if l.CheckProtocolRule(result.Protocol()) {
err = limitedError(fmt.Sprintf(
"User %s access protocol %s reject by rule",
sessionInbound.User.Email,
result.Protocol()))
}
}
}()
sniffer := NewSniffer(ctx)
metaresult, metadataErr := sniffer.SniffMetadata(ctx)
if metadataOnly {
return metaresult, metadataErr
}
contentResult, contentErr := func() (SniffResult, error) {
totalAttempt := 0
for {
@ -460,32 +504,17 @@ func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.
}
}
}
var handler outbound.Handler
// Check if domain and protocol hit the rule
// del connect count
if l != nil {
sessionInbound := session.InboundFromContext(ctx)
// Whether the inbound connection contains a user
if sessionInbound.User != nil {
if l == nil {
var err error
l, err = limiter.GetLimiter(sessionInbound.Tag)
if err != nil {
newError("Get limiter error: ", err).AtError().WriteToLog()
common.Close(link.Writer)
common.Interrupt(link.Reader)
return
}
} else if destination.Network == net.Network_TCP {
if destination.Network == net.Network_TCP {
defer func() {
l.ConnLimiter.DelConnCount(sessionInbound.User.Email, sessionInbound.Source.Address.IP().String())
}()
}
if l.CheckDomainRule(destination.Address.String()) {
newError(fmt.Sprintf("User %s access %s reject by rule", sessionInbound.User.Email, destination.String())).AtError().WriteToLog()
common.Close(link.Writer)
common.Interrupt(link.Reader)
return
}
}

View File

@ -3,15 +3,16 @@ package limiter
import (
"errors"
"fmt"
"regexp"
"sync"
"time"
"github.com/Yuzuki616/V2bX/api/panel"
"github.com/Yuzuki616/V2bX/common/format"
"github.com/Yuzuki616/V2bX/conf"
"github.com/juju/ratelimit"
log "github.com/sirupsen/logrus"
"github.com/xtls/xray-core/common/task"
"regexp"
"sync"
"time"
)
var limitLock sync.RWMutex
@ -32,7 +33,7 @@ func Init() {
}
type Limiter struct {
Rules []*regexp.Regexp
DomainRules []*regexp.Regexp
ProtocolRules []string
SpeedLimit int
UserLimitInfo *sync.Map // Key: Uid value: UserLimitInfo

View File

@ -1,14 +1,15 @@
package limiter
import (
"reflect"
"regexp"
"github.com/Yuzuki616/V2bX/api/panel"
)
func (l *Limiter) CheckDomainRule(destination string) (reject bool) {
// have rule
for i := range l.Rules {
if l.Rules[i].MatchString(destination) {
for i := range l.DomainRules {
if l.DomainRules[i].MatchString(destination) {
reject = true
break
}
@ -26,9 +27,11 @@ func (l *Limiter) CheckProtocolRule(protocol string) (reject bool) {
return
}
func (l *Limiter) UpdateRule(newRuleList []*regexp.Regexp) error {
if !reflect.DeepEqual(l.Rules, newRuleList) {
l.Rules = newRuleList
func (l *Limiter) UpdateRule(rule *panel.Rules) error {
l.DomainRules = make([]*regexp.Regexp, len(rule.Regexp))
for i := range rule.Regexp {
l.DomainRules[i] = regexp.MustCompile(rule.Regexp[i])
}
l.ProtocolRules = rule.Protocol
return nil
}

View File

@ -3,6 +3,7 @@ package node
import (
"errors"
"fmt"
"github.com/Yuzuki616/V2bX/api/iprecoder"
"github.com/Yuzuki616/V2bX/api/panel"
"github.com/Yuzuki616/V2bX/common/task"
@ -58,7 +59,7 @@ func (c *Controller) Start() error {
// add limiter
l := limiter.AddLimiter(c.Tag, &c.LimitConfig, c.userList)
// add rule limiter
if err = l.UpdateRule(c.nodeInfo.Rules); err != nil {
if err = l.UpdateRule(&c.nodeInfo.Rules); err != nil {
return fmt.Errorf("update rule error: %s", err)
}
if c.nodeInfo.Tls || c.nodeInfo.Type == "hysteria" {

View File

@ -1,11 +1,12 @@
package node
import (
"time"
"github.com/Yuzuki616/V2bX/common/task"
vCore "github.com/Yuzuki616/V2bX/core"
"github.com/Yuzuki616/V2bX/limiter"
log "github.com/sirupsen/logrus"
"time"
)
func (c *Controller) initTask() {
@ -108,7 +109,7 @@ func (c *Controller) nodeInfoMonitor() (err error) {
}).Error("Add users failed")
return nil
}
err = l.UpdateRule(newNodeInfo.Rules)
err = l.UpdateRule(&newNodeInfo.Rules)
if err != nil {
log.WithFields(log.Fields{
"tag": c.Tag,