V2bX/core/sing/hook.go

169 lines
4.0 KiB
Go
Raw Normal View History

2023-07-27 21:13:11 -04:00
package sing
import (
"context"
2025-01-10 02:33:05 -05:00
"fmt"
2023-07-27 21:13:11 -04:00
"net"
2023-07-29 06:47:47 -04:00
"sync"
"github.com/InazumaV/V2bX/common/format"
2023-07-29 07:27:15 -04:00
"github.com/InazumaV/V2bX/common/rate"
2023-07-29 06:47:47 -04:00
2023-07-29 07:27:15 -04:00
"github.com/InazumaV/V2bX/limiter"
2023-07-27 21:13:11 -04:00
2023-07-29 07:27:15 -04:00
"github.com/InazumaV/V2bX/common/counter"
2023-10-26 01:06:43 -04:00
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log"
2023-07-27 21:13:11 -04:00
N "github.com/sagernet/sing/common/network"
)
2024-12-12 16:22:44 -05:00
var _ adapter.ConnectionTracker = (*HookServer)(nil)
2023-07-27 21:13:11 -04:00
type HookServer struct {
counter sync.Map //map[string]*counter.TrafficCounter
userconn sync.Map //map[string][]net.Conn
2023-07-27 21:13:11 -04:00
}
2023-09-13 14:25:33 -04:00
// ModeList returns no modes: HookServer always reports a nil slice.
func (h *HookServer) ModeList() []string {
	var modes []string
	return modes
}
2025-01-10 02:33:05 -05:00
func NewHookServer() *HookServer {
server := &HookServer{
counter: sync.Map{},
userconn: sync.Map{},
2023-07-27 21:13:11 -04:00
}
return server
2023-07-27 21:13:11 -04:00
}
2024-12-12 16:22:44 -05:00
// RoutedConnection is called for every routed stream connection.
//
// The steps are:
//   - Look up the limiter registered for m.Inbound. If there is none, log a
//     warning and return conn unchanged.
//   - Reject the connection if CheckLimit reports the user is limited by IP
//     or connection count. Otherwise wrap conn in a rate limiter when
//     CheckLimit returns a bucket.
//   - Reject destinations matching the limiter's domain or protocol rules.
//   - Wrap conn with the user's traffic counter for this inbound.
//   - Record the wrapped conn under the user tag so CloseConnections can
//     close it later.
//
// On rejection the connection is closed and the closed conn is returned.
func (h *HookServer) RoutedConnection(_ context.Context, conn net.Conn, m adapter.InboundContext, _ adapter.Rule, _ adapter.Outbound) net.Conn {
	l, err := limiter.GetLimiter(m.Inbound)
	if err != nil {
		log.Warn("get limiter for ", m.Inbound, " error: ", err)
		return conn
	}
	taguuid := format.UserTag(m.Inbound, m.User)
	ip := m.Source.Addr.String()
	if b, r := l.CheckLimit(taguuid, ip, true, true); r {
		conn.Close()
		log.Error("[", m.Inbound, "] ", "Limited ", m.User, " by ip or conn")
		return conn
	} else if b != nil {
		conn = rate.NewConnRateLimiter(conn, b)
	}
	if l != nil {
		destStr := m.Destination.AddrString()
		protocol := m.Destination.Network()
		if l.CheckDomainRule(destStr) {
			log.Error(fmt.Sprintf(
				"User %s access domain %s reject by rule",
				m.User,
				destStr))
			conn.Close()
			return conn
		}
		if len(protocol) != 0 {
			if l.CheckProtocolRule(protocol) {
				log.Error(fmt.Sprintf(
					"User %s access protocol %s reject by rule",
					m.User,
					protocol))
				conn.Close()
				return conn
			}
		}
	}

	// Fast path: Load does not allocate. On a miss, LoadOrStore ensures that
	// concurrent first connections on the same inbound share one counter.
	// A plain Store would let one counter overwrite another, losing traffic.
	c, found := h.counter.Load(m.Inbound)
	if !found {
		c, _ = h.counter.LoadOrStore(m.Inbound, counter.NewTrafficCounter())
	}
	conn = counter.NewConnCounter(conn, c.(*counter.TrafficCounter).GetCounter(m.User))

	// Copy-on-write: clip the capacity so append always allocates a fresh
	// backing array. The stored slice, which other goroutines may read or
	// append to, is never written in place.
	var list []net.Conn
	if v, exist := h.userconn.Load(taguuid); exist {
		if prev, isList := v.([]net.Conn); isList {
			list = prev[:len(prev):len(prev)]
		}
	}
	h.userconn.Store(taguuid, append(list, conn))

	// NOTE(review): the Load/Store pair above is still not atomic. Two
	// simultaneous connections for the same user can each store their own
	// list, and one entry is lost from tracking. Entries are also only removed
	// by CloseConnections, so a user's list grows with every connection until
	// then. A mutex-guarded per-user set with removal on Close would fix both.
	return conn
}
2024-12-12 16:22:44 -05:00
func (h *HookServer) RoutedPacketConnection(_ context.Context, conn N.PacketConn, m adapter.InboundContext, _ adapter.Rule, _ adapter.Outbound) N.PacketConn {
l, err := limiter.GetLimiter(m.Inbound)
if err != nil {
2023-10-26 01:06:43 -04:00
log.Warn("get limiter for ", m.Inbound, " error: ", err)
2024-12-12 16:22:44 -05:00
return conn
}
ip := m.Source.Addr.String()
taguuid := format.UserTag(m.Inbound, m.User)
if b, r := l.CheckLimit(taguuid, ip, false, false); r {
conn.Close()
2023-10-26 01:06:43 -04:00
log.Error("[", m.Inbound, "] ", "Limited ", m.User, " by ip or conn")
2024-12-12 16:22:44 -05:00
return conn
} else if b != nil {
//conn = rate.NewPacketConnCounter(conn, b)
}
2025-01-10 02:33:05 -05:00
if l != nil {
destStr := m.Destination.AddrString()
protocol := m.Destination.Network()
if l.CheckDomainRule(destStr) {
log.Error(fmt.Sprintf(
"User %s access domain %s reject by rule",
m.User,
destStr))
conn.Close()
return conn
2023-10-13 03:32:06 -04:00
}
2025-01-10 02:33:05 -05:00
if len(protocol) != 0 {
if l.CheckProtocolRule(protocol) {
log.Error(fmt.Sprintf(
"User %s access protocol %s reject by rule",
m.User,
protocol))
conn.Close()
return conn
}
2023-10-13 03:32:06 -04:00
}
}
var t *counter.TrafficCounter
if c, ok := h.counter.Load(m.Inbound); !ok {
t = counter.NewTrafficCounter()
h.counter.Store(m.Inbound, t)
2023-07-29 06:47:47 -04:00
} else {
t = c.(*counter.TrafficCounter)
}
conn = counter.NewPacketConnCounter(conn, t.GetCounter(m.User))
return conn
}
// CloseConnections closes every stream connection recorded for the given
// users on inbound tag and stops tracking them.
//
// Errors from Close are logged rather than returned, so the result is
// always nil.
func (h *HookServer) CloseConnections(tag string, uuids []string) error {
	for _, uuid := range uuids {
		taguuid := format.UserTag(tag, uuid)
		// LoadAndDelete takes the list and removes the key in one step. With
		// the old Load ... Delete sequence, a connection appended by a
		// concurrent RoutedConnection was deleted from tracking without being
		// closed.
		v, loaded := h.userconn.LoadAndDelete(taguuid)
		if !loaded {
			continue
		}
		connList, ok := v.([]net.Conn)
		if !ok {
			// Unexpected value type; the key has already been removed above.
			continue
		}
		for _, conn := range connList {
			if err := conn.Close(); err != nil {
				log.Error("close conn error: ", err)
			}
		}
	}
	return nil
}