add protocol rule,fix limit not work

This commit is contained in:
yuzuki999 2022-06-12 21:10:20 +08:00
parent da5f94247d
commit 2d279aa202
8 changed files with 83 additions and 30 deletions

View File

@ -47,7 +47,7 @@ type Client struct {
SpeedLimit float64 SpeedLimit float64
DeviceLimit int DeviceLimit int
LocalRuleList []DetectRule LocalRuleList []DetectRule
RemoteRuleCache *Rule RemoteRuleCache *[]Rule
access sync.Mutex access sync.Mutex
NodeInfoRspMd5 [16]byte NodeInfoRspMd5 [16]byte
NodeRuleRspMd5 [16]byte NodeRuleRspMd5 [16]byte

View File

@ -5,6 +5,6 @@ type API interface {
GetUserList() (userList *[]UserInfo, err error) GetUserList() (userList *[]UserInfo, err error)
ReportUserTraffic(userTraffic *[]UserTraffic) (err error) ReportUserTraffic(userTraffic *[]UserTraffic) (err error)
Describe() ClientInfo Describe() ClientInfo
GetNodeRule() (ruleList *[]DetectRule, err error) GetNodeRule() (ruleList *[]DetectRule, protocolList *[]string, err error)
Debug() Debug()
} }

View File

@ -77,7 +77,7 @@ type SSConfig struct {
type V2rayConfig struct { type V2rayConfig struct {
Inbounds []conf.InboundDetourConfig `json:"inbounds"` Inbounds []conf.InboundDetourConfig `json:"inbounds"`
Routing *struct { Routing *struct {
Rules []json.RawMessage `json:"rules"` Rules json.RawMessage `json:"rules"`
} `json:"routing"` } `json:"routing"`
} }
@ -160,25 +160,32 @@ func (c *Client) GetNodeInfo() (nodeInfo *NodeInfo, err error) {
return nodeInfo, nil return nodeInfo, nil
} }
func (c *Client) GetNodeRule() (*[]DetectRule, error) { func (c *Client) GetNodeRule() (*[]DetectRule, *[]string, error) {
ruleList := c.LocalRuleList ruleList := c.LocalRuleList
if c.NodeType != "V2ray" || c.RemoteRuleCache == nil { if c.NodeType != "V2ray" || c.RemoteRuleCache == nil {
return &ruleList, nil return &ruleList, nil, nil
} }
// V2board only support the rule for v2ray // V2board only support the rule for v2ray
// fix: reuse config response // fix: reuse config response
c.access.Lock() c.access.Lock()
defer c.access.Unlock() defer c.access.Unlock()
for i, rule := range c.RemoteRuleCache.Domain { if len(*c.RemoteRuleCache) >= 2 {
for i, rule := range (*c.RemoteRuleCache)[1].Domain {
ruleListItem := DetectRule{ ruleListItem := DetectRule{
ID: i, ID: i,
Pattern: regexp.MustCompile(rule), Pattern: regexp.MustCompile(rule),
} }
ruleList = append(ruleList, ruleListItem) ruleList = append(ruleList, ruleListItem)
} }
}
var protocolList []string
if len(*c.RemoteRuleCache) >= 3 {
for _, str := range (*c.RemoteRuleCache)[2].Protocol {
protocolList = append(protocolList, str)
}
}
c.RemoteRuleCache = nil c.RemoteRuleCache = nil
return &ruleList, nil return &ruleList, &protocolList, nil
} }
// ParseTrojanNodeResponse parse the response for the given nodeinfor format // ParseTrojanNodeResponse parse the response for the given nodeinfor format
@ -238,8 +245,8 @@ func (c *Client) ParseV2rayNodeResponse(body []byte, notParseNode, parseRule boo
return nil, fmt.Errorf("unmarshal nodeinfo error: %s", err) return nil, fmt.Errorf("unmarshal nodeinfo error: %s", err)
} }
if parseRule { if parseRule {
c.RemoteRuleCache = &Rule{} c.RemoteRuleCache = &[]Rule{}
err := json.Unmarshal(node.V2ray.Routing.Rules[1], c.RemoteRuleCache) err := json.Unmarshal(node.V2ray.Routing.Rules, c.RemoteRuleCache)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
} }

View File

@ -5,10 +5,6 @@ package mydispatcher
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"sync"
"time"
"github.com/Yuzuki616/V2bX/common/limiter" "github.com/Yuzuki616/V2bX/common/limiter"
"github.com/Yuzuki616/V2bX/common/rule" "github.com/Yuzuki616/V2bX/common/rule"
"github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common"
@ -26,6 +22,9 @@ import (
"github.com/xtls/xray-core/features/stats" "github.com/xtls/xray-core/features/stats"
"github.com/xtls/xray-core/transport" "github.com/xtls/xray-core/transport"
"github.com/xtls/xray-core/transport/pipe" "github.com/xtls/xray-core/transport/pipe"
"strings"
"sync"
"time"
) )
var errSniffingTimeout = newError("timeout on sniffing") var errSniffingTimeout = newError("timeout on sniffing")
@ -320,7 +319,7 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
} }
switch { switch {
case !sniffingRequest.Enabled: case !sniffingRequest.Enabled:
go d.routedDispatch(ctx, outbound, destination) go d.routedDispatch(ctx, outbound, destination, "")
case destination.Network != net.Network_TCP: case destination.Network != net.Network_TCP:
// Only metadata sniff will be used for non tcp connection // Only metadata sniff will be used for non tcp connection
result, err := sniffer(ctx, nil, true) result, err := sniffer(ctx, nil, true)
@ -337,7 +336,7 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
} }
} }
} }
go d.routedDispatch(ctx, outbound, destination) go d.routedDispatch(ctx, outbound, destination, content.Protocol)
default: default:
go func() { go func() {
cReader := &cachedReader{ cReader := &cachedReader{
@ -358,7 +357,7 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
ob.Target = destination ob.Target = destination
} }
} }
d.routedDispatch(ctx, outbound, destination) d.routedDispatch(ctx, outbound, destination, content.Protocol)
}() }()
} }
return inbound, nil return inbound, nil
@ -381,7 +380,7 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
sniffingRequest := content.SniffingRequest sniffingRequest := content.SniffingRequest
switch { switch {
case !sniffingRequest.Enabled: case !sniffingRequest.Enabled:
go d.routedDispatch(ctx, outbound, destination) go d.routedDispatch(ctx, outbound, destination, "")
case destination.Network != net.Network_TCP: case destination.Network != net.Network_TCP:
// Only metadata sniff will be used for non tcp connection // Only metadata sniff will be used for non tcp connection
result, err := sniffer(ctx, nil, true) result, err := sniffer(ctx, nil, true)
@ -398,7 +397,7 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
} }
} }
} }
go d.routedDispatch(ctx, outbound, destination) go d.routedDispatch(ctx, outbound, destination, content.Protocol)
default: default:
go func() { go func() {
cReader := &cachedReader{ cReader := &cachedReader{
@ -419,7 +418,7 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
ob.Target = destination ob.Target = destination
} }
} }
d.routedDispatch(ctx, outbound, destination) d.routedDispatch(ctx, outbound, destination, content.Protocol)
}() }()
} }
return nil return nil
@ -471,7 +470,7 @@ func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool) (Sni
return contentResult, contentErr return contentResult, contentErr
} }
func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) { func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination, protocol string) {
ob := session.OutboundFromContext(ctx) ob := session.OutboundFromContext(ctx)
if hosts, ok := d.dns.(dns.HostsLookup); ok && destination.Address.Family().IsDomain() { if hosts, ok := d.dns.(dns.HostsLookup); ok && destination.Address.Family().IsDomain() {
proxied := hosts.LookupHosts(ob.Target.String()) proxied := hosts.LookupHosts(ob.Target.String())
@ -499,6 +498,13 @@ func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.
return return
}*/ }*/
if sessionInbound.User != nil { if sessionInbound.User != nil {
if d.RuleManager.ProtocolDetect(sessionInbound.Tag, protocol) {
newError(fmt.Sprintf("User %s access %s reject by protocol rule", sessionInbound.User.Email, destination.String())).AtError().WriteToLog()
newError("destination is reject by protocol rule")
common.Close(link.Writer)
common.Interrupt(link.Reader)
return
}
if d.RuleManager.Detect(sessionInbound.Tag, destination.String(), sessionInbound.User.Email) { if d.RuleManager.Detect(sessionInbound.Tag, destination.String(), sessionInbound.User.Email) {
newError(fmt.Sprintf("User %s access %s reject by rule", sessionInbound.User.Email, destination.String())).AtError().WriteToLog() newError(fmt.Sprintf("User %s access %s reject by rule", sessionInbound.User.Email, destination.String())).AtError().WriteToLog()
newError("destination is reject by rule") newError("destination is reject by rule")
@ -539,7 +545,7 @@ func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.
} }
if handler == nil { if handler == nil {
handler = d.ohm.GetHandler(inTag) // Default outbound hander tag should be as same as the inbound tag handler = d.ohm.GetHandler(inTag) // Default outbound handier tag should be as same as the inbound tag
} }
// If there is no outbound with tag as same as the inbound tag // If there is no outbound with tag as same as the inbound tag

View File

@ -49,7 +49,7 @@ func (l *Limiter) AddInboundLimiter(tag string, nodeInfo *api.NodeInfo, userList
if (*userList)[i].DeviceLimit == 0 { if (*userList)[i].DeviceLimit == 0 {
(*userList)[i].DeviceLimit = nodeInfo.DeviceLimit (*userList)[i].DeviceLimit = nodeInfo.DeviceLimit
} }
userMap.Store(fmt.Sprintf("%s|%d|%d", tag, (*userList)[i].Port, (*userList)[i].UID), UserInfo{ userMap.Store(fmt.Sprintf("%s|%s|%d", tag, (*userList)[i].GetUserEmail(), (*userList)[i].UID), UserInfo{
UID: (*userList)[i].UID, UID: (*userList)[i].UID,
SpeedLimit: (*userList)[i].SpeedLimit, SpeedLimit: (*userList)[i].SpeedLimit,
DeviceLimit: (*userList)[i].DeviceLimit, DeviceLimit: (*userList)[i].DeviceLimit,

View File

@ -14,12 +14,14 @@ import (
type Rule struct { type Rule struct {
InboundRule *sync.Map // Key: Tag, Value: []api.DetectRule InboundRule *sync.Map // Key: Tag, Value: []api.DetectRule
InboundProtocolRule *sync.Map // Key: Tag, Value: []string
InboundDetectResult *sync.Map // key: Tag, Value: mapset.NewSet []api.DetectResult InboundDetectResult *sync.Map // key: Tag, Value: mapset.NewSet []api.DetectResult
} }
func New() *Rule { func New() *Rule {
return &Rule{ return &Rule{
InboundRule: new(sync.Map), InboundRule: new(sync.Map),
InboundProtocolRule: new(sync.Map),
InboundDetectResult: new(sync.Map), InboundDetectResult: new(sync.Map),
} }
} }
@ -34,6 +36,16 @@ func (r *Rule) UpdateRule(tag string, newRuleList []api.DetectRule) error {
return nil return nil
} }
func (r *Rule) UpdateProtocolRule(tag string, ruleList []string) error {
if value, ok := r.InboundProtocolRule.LoadOrStore(tag, ruleList); ok {
old := value.([]string)
if !reflect.DeepEqual(old, ruleList) {
r.InboundProtocolRule.Store(tag, ruleList)
}
}
return nil
}
func (r *Rule) GetDetectResult(tag string) (*[]api.DetectResult, error) { func (r *Rule) GetDetectResult(tag string) (*[]api.DetectResult, error) {
detectResult := make([]api.DetectResult, 0) detectResult := make([]api.DetectResult, 0)
if value, ok := r.InboundDetectResult.LoadAndDelete(tag); ok { if value, ok := r.InboundDetectResult.LoadAndDelete(tag); ok {
@ -80,3 +92,14 @@ func (r *Rule) Detect(tag string, destination string, email string) (reject bool
} }
return reject return reject
} }
func (r *Rule) ProtocolDetect(tag string, protocol string) bool {
if value, ok := r.InboundProtocolRule.Load(tag); ok {
ruleList := value.([]string)
for _, r := range ruleList {
if r == protocol {
return true
}
}
}
return false
}

View File

@ -158,6 +158,12 @@ func (c *Controller) UpdateRule(tag string, newRuleList []api.DetectRule) error
return err return err
} }
func (c *Controller) UpdateProtocolRule(tag string, newRuleList []string) error {
dispather := c.server.GetFeature(routing.DispatcherType()).(*mydispatcher.DefaultDispatcher)
err := dispather.RuleManager.UpdateProtocolRule(tag, newRuleList)
return err
}
func (c *Controller) GetDetectResult(tag string) (*[]api.DetectResult, error) { func (c *Controller) GetDetectResult(tag string) (*[]api.DetectResult, error) {
dispather := c.server.GetFeature(routing.DispatcherType()).(*mydispatcher.DefaultDispatcher) dispather := c.server.GetFeature(routing.DispatcherType()).(*mydispatcher.DefaultDispatcher)
return dispather.RuleManager.GetDetectResult(tag) return dispather.RuleManager.GetDetectResult(tag)

View File

@ -71,12 +71,17 @@ func (c *Controller) Start() error {
} }
// Add Rule Manager // Add Rule Manager
if !c.config.DisableGetRule { if !c.config.DisableGetRule {
if ruleList, err := c.apiClient.GetNodeRule(); err != nil { if ruleList, protocolRule, err := c.apiClient.GetNodeRule(); err != nil {
log.Printf("Get rule list filed: %s", err) log.Printf("Get rule list filed: %s", err)
} else if len(*ruleList) > 0 { } else if len(*ruleList) > 0 {
if err := c.UpdateRule(c.Tag, *ruleList); err != nil { if err := c.UpdateRule(c.Tag, *ruleList); err != nil {
log.Print(err) log.Print(err)
} }
if len(*protocolRule) > 0 {
if err := c.UpdateProtocolRule(c.Tag, *protocolRule); err != nil {
log.Print(err)
}
}
} }
} }
c.nodeInfoMonitorPeriodic = &task.Periodic{ c.nodeInfoMonitorPeriodic = &task.Periodic{
@ -160,12 +165,18 @@ func (c *Controller) nodeInfoMonitor() (err error) {
// Check Rule // Check Rule
if !c.config.DisableGetRule { if !c.config.DisableGetRule {
if ruleList, err := c.apiClient.GetNodeRule(); err != nil { if ruleList, protocolRule, err := c.apiClient.GetNodeRule(); err != nil {
log.Printf("Get rule list filed: %s", err) log.Printf("Get rule list filed: %s", err)
} else if ruleList != nil { } else if len(*ruleList) > 0 {
if err := c.UpdateRule(c.Tag, *ruleList); err != nil { if err := c.UpdateRule(c.Tag, *ruleList); err != nil {
log.Print(err) log.Print(err)
} }
if len(*protocolRule) > 0 {
if err := c.UpdateProtocolRule(c.Tag, *protocolRule); err != nil {
log.Print(err)
}
}
} }
} }