add protocol rule,fix limit not work

This commit is contained in:
yuzuki999 2022-06-12 21:10:20 +08:00
parent da5f94247d
commit 2d279aa202
8 changed files with 83 additions and 30 deletions

View File

@ -47,7 +47,7 @@ type Client struct {
SpeedLimit float64
DeviceLimit int
LocalRuleList []DetectRule
RemoteRuleCache *Rule
RemoteRuleCache *[]Rule
access sync.Mutex
NodeInfoRspMd5 [16]byte
NodeRuleRspMd5 [16]byte

View File

@ -5,6 +5,6 @@ type API interface {
GetUserList() (userList *[]UserInfo, err error)
ReportUserTraffic(userTraffic *[]UserTraffic) (err error)
Describe() ClientInfo
GetNodeRule() (ruleList *[]DetectRule, err error)
GetNodeRule() (ruleList *[]DetectRule, protocolList *[]string, err error)
Debug()
}

View File

@ -77,7 +77,7 @@ type SSConfig struct {
type V2rayConfig struct {
Inbounds []conf.InboundDetourConfig `json:"inbounds"`
Routing *struct {
Rules []json.RawMessage `json:"rules"`
Rules json.RawMessage `json:"rules"`
} `json:"routing"`
}
@ -160,25 +160,32 @@ func (c *Client) GetNodeInfo() (nodeInfo *NodeInfo, err error) {
return nodeInfo, nil
}
func (c *Client) GetNodeRule() (*[]DetectRule, error) {
func (c *Client) GetNodeRule() (*[]DetectRule, *[]string, error) {
ruleList := c.LocalRuleList
if c.NodeType != "V2ray" || c.RemoteRuleCache == nil {
return &ruleList, nil
return &ruleList, nil, nil
}
// V2board only support the rule for v2ray
// fix: reuse config response
c.access.Lock()
defer c.access.Unlock()
for i, rule := range c.RemoteRuleCache.Domain {
ruleListItem := DetectRule{
ID: i,
Pattern: regexp.MustCompile(rule),
if len(*c.RemoteRuleCache) >= 2 {
for i, rule := range (*c.RemoteRuleCache)[1].Domain {
ruleListItem := DetectRule{
ID: i,
Pattern: regexp.MustCompile(rule),
}
ruleList = append(ruleList, ruleListItem)
}
}
var protocolList []string
if len(*c.RemoteRuleCache) >= 3 {
for _, str := range (*c.RemoteRuleCache)[2].Protocol {
protocolList = append(protocolList, str)
}
ruleList = append(ruleList, ruleListItem)
}
c.RemoteRuleCache = nil
return &ruleList, nil
return &ruleList, &protocolList, nil
}
// ParseTrojanNodeResponse parse the response for the given nodeinfor format
@ -238,8 +245,8 @@ func (c *Client) ParseV2rayNodeResponse(body []byte, notParseNode, parseRule boo
return nil, fmt.Errorf("unmarshal nodeinfo error: %s", err)
}
if parseRule {
c.RemoteRuleCache = &Rule{}
err := json.Unmarshal(node.V2ray.Routing.Rules[1], c.RemoteRuleCache)
c.RemoteRuleCache = &[]Rule{}
err := json.Unmarshal(node.V2ray.Routing.Rules, c.RemoteRuleCache)
if err != nil {
log.Println(err)
}

View File

@ -5,10 +5,6 @@ package mydispatcher
import (
"context"
"fmt"
"strings"
"sync"
"time"
"github.com/Yuzuki616/V2bX/common/limiter"
"github.com/Yuzuki616/V2bX/common/rule"
"github.com/xtls/xray-core/common"
@ -26,6 +22,9 @@ import (
"github.com/xtls/xray-core/features/stats"
"github.com/xtls/xray-core/transport"
"github.com/xtls/xray-core/transport/pipe"
"strings"
"sync"
"time"
)
var errSniffingTimeout = newError("timeout on sniffing")
@ -320,7 +319,7 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
}
switch {
case !sniffingRequest.Enabled:
go d.routedDispatch(ctx, outbound, destination)
go d.routedDispatch(ctx, outbound, destination, "")
case destination.Network != net.Network_TCP:
// Only metadata sniff will be used for non tcp connection
result, err := sniffer(ctx, nil, true)
@ -337,7 +336,7 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
}
}
}
go d.routedDispatch(ctx, outbound, destination)
go d.routedDispatch(ctx, outbound, destination, content.Protocol)
default:
go func() {
cReader := &cachedReader{
@ -358,7 +357,7 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
ob.Target = destination
}
}
d.routedDispatch(ctx, outbound, destination)
d.routedDispatch(ctx, outbound, destination, content.Protocol)
}()
}
return inbound, nil
@ -381,7 +380,7 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
sniffingRequest := content.SniffingRequest
switch {
case !sniffingRequest.Enabled:
go d.routedDispatch(ctx, outbound, destination)
go d.routedDispatch(ctx, outbound, destination, "")
case destination.Network != net.Network_TCP:
// Only metadata sniff will be used for non tcp connection
result, err := sniffer(ctx, nil, true)
@ -398,7 +397,7 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
}
}
}
go d.routedDispatch(ctx, outbound, destination)
go d.routedDispatch(ctx, outbound, destination, content.Protocol)
default:
go func() {
cReader := &cachedReader{
@ -419,7 +418,7 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
ob.Target = destination
}
}
d.routedDispatch(ctx, outbound, destination)
d.routedDispatch(ctx, outbound, destination, content.Protocol)
}()
}
return nil
@ -471,7 +470,7 @@ func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool) (Sni
return contentResult, contentErr
}
func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) {
func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination, protocol string) {
ob := session.OutboundFromContext(ctx)
if hosts, ok := d.dns.(dns.HostsLookup); ok && destination.Address.Family().IsDomain() {
proxied := hosts.LookupHosts(ob.Target.String())
@ -499,6 +498,13 @@ func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.
return
}*/
if sessionInbound.User != nil {
if d.RuleManager.ProtocolDetect(sessionInbound.Tag, protocol) {
newError(fmt.Sprintf("User %s access %s reject by protocol rule", sessionInbound.User.Email, destination.String())).AtError().WriteToLog()
newError("destination is reject by protocol rule")
common.Close(link.Writer)
common.Interrupt(link.Reader)
return
}
if d.RuleManager.Detect(sessionInbound.Tag, destination.String(), sessionInbound.User.Email) {
newError(fmt.Sprintf("User %s access %s reject by rule", sessionInbound.User.Email, destination.String())).AtError().WriteToLog()
newError("destination is reject by rule")
@ -539,7 +545,7 @@ func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.
}
if handler == nil {
handler = d.ohm.GetHandler(inTag) // Default outbound hander tag should be as same as the inbound tag
handler = d.ohm.GetHandler(inTag) // Default outbound handier tag should be as same as the inbound tag
}
// If there is no outbound with tag as same as the inbound tag

View File

@ -49,7 +49,7 @@ func (l *Limiter) AddInboundLimiter(tag string, nodeInfo *api.NodeInfo, userList
if (*userList)[i].DeviceLimit == 0 {
(*userList)[i].DeviceLimit = nodeInfo.DeviceLimit
}
userMap.Store(fmt.Sprintf("%s|%d|%d", tag, (*userList)[i].Port, (*userList)[i].UID), UserInfo{
userMap.Store(fmt.Sprintf("%s|%s|%d", tag, (*userList)[i].GetUserEmail(), (*userList)[i].UID), UserInfo{
UID: (*userList)[i].UID,
SpeedLimit: (*userList)[i].SpeedLimit,
DeviceLimit: (*userList)[i].DeviceLimit,

View File

@ -14,12 +14,14 @@ import (
type Rule struct {
InboundRule *sync.Map // Key: Tag, Value: []api.DetectRule
InboundProtocolRule *sync.Map // Key: Tag, Value: []string
InboundDetectResult *sync.Map // key: Tag, Value: mapset.NewSet []api.DetectResult
}
func New() *Rule {
return &Rule{
InboundRule: new(sync.Map),
InboundProtocolRule: new(sync.Map),
InboundDetectResult: new(sync.Map),
}
}
@ -34,6 +36,16 @@ func (r *Rule) UpdateRule(tag string, newRuleList []api.DetectRule) error {
return nil
}
func (r *Rule) UpdateProtocolRule(tag string, ruleList []string) error {
if value, ok := r.InboundProtocolRule.LoadOrStore(tag, ruleList); ok {
old := value.([]string)
if !reflect.DeepEqual(old, ruleList) {
r.InboundProtocolRule.Store(tag, ruleList)
}
}
return nil
}
func (r *Rule) GetDetectResult(tag string) (*[]api.DetectResult, error) {
detectResult := make([]api.DetectResult, 0)
if value, ok := r.InboundDetectResult.LoadAndDelete(tag); ok {
@ -80,3 +92,14 @@ func (r *Rule) Detect(tag string, destination string, email string) (reject bool
}
return reject
}
func (r *Rule) ProtocolDetect(tag string, protocol string) bool {
if value, ok := r.InboundProtocolRule.Load(tag); ok {
ruleList := value.([]string)
for _, r := range ruleList {
if r == protocol {
return true
}
}
}
return false
}

View File

@ -158,6 +158,12 @@ func (c *Controller) UpdateRule(tag string, newRuleList []api.DetectRule) error
return err
}
func (c *Controller) UpdateProtocolRule(tag string, newRuleList []string) error {
dispather := c.server.GetFeature(routing.DispatcherType()).(*mydispatcher.DefaultDispatcher)
err := dispather.RuleManager.UpdateProtocolRule(tag, newRuleList)
return err
}
func (c *Controller) GetDetectResult(tag string) (*[]api.DetectResult, error) {
dispather := c.server.GetFeature(routing.DispatcherType()).(*mydispatcher.DefaultDispatcher)
return dispather.RuleManager.GetDetectResult(tag)

View File

@ -71,12 +71,17 @@ func (c *Controller) Start() error {
}
// Add Rule Manager
if !c.config.DisableGetRule {
if ruleList, err := c.apiClient.GetNodeRule(); err != nil {
if ruleList, protocolRule, err := c.apiClient.GetNodeRule(); err != nil {
log.Printf("Get rule list filed: %s", err)
} else if len(*ruleList) > 0 {
if err := c.UpdateRule(c.Tag, *ruleList); err != nil {
log.Print(err)
}
if len(*protocolRule) > 0 {
if err := c.UpdateProtocolRule(c.Tag, *protocolRule); err != nil {
log.Print(err)
}
}
}
}
c.nodeInfoMonitorPeriodic = &task.Periodic{
@ -160,12 +165,18 @@ func (c *Controller) nodeInfoMonitor() (err error) {
// Check Rule
if !c.config.DisableGetRule {
if ruleList, err := c.apiClient.GetNodeRule(); err != nil {
if ruleList, protocolRule, err := c.apiClient.GetNodeRule(); err != nil {
log.Printf("Get rule list filed: %s", err)
} else if ruleList != nil {
} else if len(*ruleList) > 0 {
if err := c.UpdateRule(c.Tag, *ruleList); err != nil {
log.Print(err)
}
if len(*protocolRule) > 0 {
if err := c.UpdateProtocolRule(c.Tag, *protocolRule); err != nil {
log.Print(err)
}
}
}
}