fix some bugs

support ip limit for hy
This commit is contained in:
yuzuki999 2023-06-16 10:04:39 +08:00
parent 28acdc7087
commit fee53a8384
10 changed files with 236 additions and 165 deletions

83
common/task/task.go Normal file
View File

@ -0,0 +1,83 @@
package task
import (
"sync"
"time"
)
// Task is a task that runs periodically.
type Task struct {
// Interval of the task being run
Interval time.Duration
// Execute is the task function
Execute func() error
access sync.Mutex
timer *time.Timer
running bool
}
func (t *Task) hasClosed() bool {
t.access.Lock()
defer t.access.Unlock()
return !t.running
}
func (t *Task) checkedExecute() error {
if t.hasClosed() {
return nil
}
t.access.Lock()
defer t.access.Unlock()
if !t.running {
return nil
}
t.timer = time.AfterFunc(t.Interval, func() {
t.checkedExecute()
})
return nil
}
// Start implements common.Runnable.
func (t *Task) Start(first bool) error {
t.access.Lock()
if t.running {
t.access.Unlock()
return nil
}
t.running = true
t.access.Unlock()
if first {
if err := t.Execute(); err != nil {
t.access.Lock()
t.running = false
t.access.Unlock()
return err
}
}
if err := t.checkedExecute(); err != nil {
t.access.Lock()
t.running = false
t.access.Unlock()
return err
}
return nil
}
// Close implements common.Closable.
func (t *Task) Close() {
t.access.Lock()
defer t.access.Unlock()
t.running = false
if t.timer != nil {
t.timer.Stop()
t.timer = nil
}
}

View File

@ -2,14 +2,9 @@ package conf
import (
"fmt"
"io"
"log"
"os"
"path"
"time"
"github.com/fsnotify/fsnotify"
"gopkg.in/yaml.v3"
"io"
"os"
)
type Conf struct {
@ -56,44 +51,3 @@ func (p *Conf) LoadFromPath(filePath string) error {
}
return nil
}
func (p *Conf) Watch(filePath string, reload func()) error {
watcher, err := fsnotify.NewWatcher()
if err != nil {
return fmt.Errorf("new watcher error: %s", err)
}
go func() {
var pre time.Time
defer watcher.Close()
for {
select {
case e := <-watcher.Events:
if e.Has(fsnotify.Chmod) {
continue
}
if pre.Add(1 * time.Second).After(time.Now()) {
continue
}
time.Sleep(2 * time.Second)
pre = time.Now()
log.Println("config dir changed, reloading...")
*p = *New()
err := p.LoadFromPath(filePath)
if err != nil {
log.Printf("reload config error: %s", err)
}
reload()
log.Println("reload config success")
case err := <-watcher.Errors:
if err != nil {
log.Printf("File watcher error: %s", err)
}
}
}
}()
err = watcher.Add(path.Dir(filePath))
if err != nil {
return fmt.Errorf("watch file error: %s", err)
}
return nil
}

52
conf/watch.go Normal file
View File

@ -0,0 +1,52 @@
package conf
import (
"fmt"
"github.com/fsnotify/fsnotify"
"log"
"path"
"time"
)
func (p *Conf) Watch(filePath string, reload func()) error {
watcher, err := fsnotify.NewWatcher()
if err != nil {
return fmt.Errorf("new watcher error: %s", err)
}
go func() {
var pre time.Time
defer watcher.Close()
for {
select {
case e := <-watcher.Events:
if e.Has(fsnotify.Chmod) {
continue
}
if pre.Add(10 * time.Second).After(time.Now()) {
continue
}
pre = time.Now()
go func() {
time.Sleep(10 * time.Second)
log.Println("config dir changed, reloading...")
*p = *New()
err := p.LoadFromPath(filePath)
if err != nil {
log.Printf("reload config error: %s", err)
}
reload()
log.Println("reload config success")
}()
case err := <-watcher.Errors:
if err != nil {
log.Printf("File watcher error: %s", err)
}
}
}
}()
err = watcher.Add(path.Dir(filePath))
if err != nil {
return fmt.Errorf("watch file error: %s", err)
}
return nil
}

View File

@ -5,6 +5,7 @@ import (
"fmt"
"github.com/Yuzuki616/V2bX/api/panel"
"github.com/Yuzuki616/V2bX/conf"
"github.com/Yuzuki616/V2bX/limiter"
"github.com/apernet/hysteria/core/cs"
)
@ -16,8 +17,12 @@ func (h *Hy) AddNode(tag string, info *panel.NodeInfo, c *conf.ControllerConfig)
case "reality", "none", "":
return errors.New("hysteria need normal tls cert")
}
s := NewServer(tag)
err := s.runServer(info, c)
l, err := limiter.GetLimiter(tag)
if err != nil {
return fmt.Errorf("get limiter error: %s", err)
}
s := NewServer(tag, l)
err = s.runServer(info, c)
if err != nil {
return fmt.Errorf("run hy server error: %s", err)
}

View File

@ -5,6 +5,7 @@ import (
"fmt"
"github.com/Yuzuki616/V2bX/api/panel"
"github.com/Yuzuki616/V2bX/conf"
"github.com/Yuzuki616/V2bX/limiter"
"github.com/apernet/hysteria/core/sockopt"
"io"
"net"
@ -32,15 +33,17 @@ var serverPacketConnFuncFactoryMap = map[string]pktconns.ServerPacketConnFuncFac
type Server struct {
tag string
l *limiter.Limiter
counter *UserTrafficCounter
users sync.Map
running atomic.Bool
*cs.Server
}
func NewServer(tag string) *Server {
func NewServer(tag string, l *limiter.Limiter) *Server {
return &Server{
tag: tag,
l: l,
}
}
@ -173,6 +176,9 @@ func (s *Server) runServer(node *panel.NodeInfo, c *conf.ControllerConfig) error
}
func (s *Server) authByUser(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string) {
if _, r := s.l.CheckLimit(string(auth), addr.String(), false); r {
return false, "device limited"
}
if _, ok := s.users.Load(string(auth)); ok {
return true, "Done"
}
@ -180,6 +186,7 @@ func (s *Server) authByUser(addr net.Addr, auth []byte, sSend uint64, sRecv uint
}
func (s *Server) connectFunc(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string) {
s.l.ConnLimiter.AddConnCount(addr.String(), string(auth), false)
ok, msg := s.authByUser(addr, auth, sSend, sRecv)
if !ok {
logrus.WithFields(logrus.Fields{
@ -196,6 +203,7 @@ func (s *Server) connectFunc(addr net.Addr, auth []byte, sSend uint64, sRecv uin
}
func (s *Server) disconnectFunc(addr net.Addr, auth []byte, err error) {
s.l.ConnLimiter.DelConnCount(addr.String(), string(auth))
logrus.WithFields(logrus.Fields{
"src": defaultIPMasker.Mask(addr.String()),
"error": err,

View File

@ -3,7 +3,6 @@ package xray
import (
"context"
"fmt"
"github.com/Yuzuki616/V2bX/common/builder"
vCore "github.com/Yuzuki616/V2bX/core"
"github.com/xtls/xray-core/common/protocol"

View File

@ -4,8 +4,22 @@ import (
"fmt"
"github.com/Yuzuki616/V2bX/common/file"
"github.com/Yuzuki616/V2bX/node/lego"
"log"
)
func (c *Controller) renewCertTask() {
l, err := lego.New(c.CertConfig)
if err != nil {
log.Print("new lego error: ", err)
return
}
err = l.RenewCert()
if err != nil {
log.Print("renew cert error: ", err)
return
}
}
func (c *Controller) requestCert() error {
if c.CertConfig.CertFile == "" || c.CertConfig.KeyFile == "" {
return fmt.Errorf("cert file path or key file path not exist")

View File

@ -3,14 +3,13 @@ package node
import (
"errors"
"fmt"
"log"
"github.com/Yuzuki616/V2bX/api/iprecoder"
"github.com/Yuzuki616/V2bX/api/panel"
"github.com/Yuzuki616/V2bX/common/task"
"github.com/Yuzuki616/V2bX/conf"
vCore "github.com/Yuzuki616/V2bX/core"
"github.com/Yuzuki616/V2bX/limiter"
"github.com/xtls/xray-core/common/task"
"log"
)
type Controller struct {
@ -20,11 +19,11 @@ type Controller struct {
Tag string
userList []panel.UserInfo
ipRecorder iprecoder.IpRecorder
nodeInfoMonitorPeriodic *task.Periodic
userReportPeriodic *task.Periodic
renewCertPeriodic *task.Periodic
dynamicSpeedLimitPeriodic *task.Periodic
onlineIpReportPeriodic *task.Periodic
nodeInfoMonitorPeriodic *task.Task
userReportPeriodic *task.Task
renewCertPeriodic *task.Task
dynamicSpeedLimitPeriodic *task.Task
onlineIpReportPeriodic *task.Task
*conf.ControllerConfig
}
@ -93,38 +92,23 @@ func (c *Controller) Start() error {
func (c *Controller) Close() error {
limiter.DeleteLimiter(c.Tag)
if c.nodeInfoMonitorPeriodic != nil {
err := c.nodeInfoMonitorPeriodic.Close()
if err != nil {
return fmt.Errorf("node info periodic close error: %s", err)
}
c.nodeInfoMonitorPeriodic.Close()
}
if c.nodeInfoMonitorPeriodic != nil {
err := c.userReportPeriodic.Close()
if err != nil {
return fmt.Errorf("user report periodic close error: %s", err)
}
if c.userReportPeriodic != nil {
c.userReportPeriodic.Close()
}
if c.renewCertPeriodic != nil {
err := c.renewCertPeriodic.Close()
if err != nil {
return fmt.Errorf("renew cert periodic close error: %s", err)
}
c.renewCertPeriodic.Close()
}
if c.dynamicSpeedLimitPeriodic != nil {
err := c.dynamicSpeedLimitPeriodic.Close()
if err != nil {
return fmt.Errorf("dynamic speed limit periodic close error: %s", err)
}
c.dynamicSpeedLimitPeriodic.Close()
}
if c.onlineIpReportPeriodic != nil {
err := c.onlineIpReportPeriodic.Close()
if err != nil {
return fmt.Errorf("online ip report periodic close error: %s", err)
}
c.onlineIpReportPeriodic.Close()
}
return nil
}
func (c *Controller) buildNodeTag() string {
return fmt.Sprintf("%s_%s_%d", c.nodeInfo.Type, c.ListenIP, c.nodeInfo.Id)
return fmt.Sprintf("%s-%s-%d", c.apiClient.APIHost, c.nodeInfo.Type, c.nodeInfo.Id)
}

View File

@ -2,52 +2,41 @@ package node
import (
"fmt"
"log"
"runtime"
"time"
"github.com/Yuzuki616/V2bX/common/task"
vCore "github.com/Yuzuki616/V2bX/core"
"github.com/Yuzuki616/V2bX/api/panel"
"github.com/Yuzuki616/V2bX/limiter"
"github.com/Yuzuki616/V2bX/node/lego"
"github.com/xtls/xray-core/common/task"
"log"
"time"
)
func (c *Controller) initTask() {
// fetch node info task
c.nodeInfoMonitorPeriodic = &task.Periodic{
c.nodeInfoMonitorPeriodic = &task.Task{
Interval: c.nodeInfo.PullInterval,
Execute: c.nodeInfoMonitor,
}
// fetch user list task
c.userReportPeriodic = &task.Periodic{
c.userReportPeriodic = &task.Task{
Interval: c.nodeInfo.PushInterval,
Execute: c.reportUserTraffic,
Execute: c.reportUserTrafficTask,
}
log.Printf("[%s: %d] Start monitor node status", c.nodeInfo.Type, c.nodeInfo.Id)
log.Printf("[%s] Start monitor node status", c.Tag)
// delay to start nodeInfoMonitor
go func() {
time.Sleep(c.nodeInfo.PullInterval)
_ = c.nodeInfoMonitorPeriodic.Start()
}()
log.Printf("[%s: %d] Start report node status", c.nodeInfo.Type, c.nodeInfo.Id)
// delay to start userReport
go func() {
time.Sleep(c.nodeInfo.PullInterval)
_ = c.userReportPeriodic.Start()
}()
if c.nodeInfo.Tls && c.CertConfig.CertMode != "none" &&
(c.CertConfig.CertMode == "dns" || c.CertConfig.CertMode == "http") {
c.renewCertPeriodic = &task.Periodic{
Interval: time.Hour * 24,
Execute: c.reportUserTraffic,
_ = c.nodeInfoMonitorPeriodic.Start(false)
log.Printf("[%s] Start report node status", c.Tag)
_ = c.userReportPeriodic.Start(false)
if c.nodeInfo.Tls {
switch c.CertConfig.CertMode {
case "reality", "none", "":
default:
c.renewCertPeriodic = &task.Task{
Interval: time.Hour * 24,
Execute: c.reportUserTrafficTask,
}
log.Printf("[%s] Start renew cert", c.Tag)
// delay to start renewCert
_ = c.renewCertPeriodic.Start(true)
}
log.Printf("[%s: %d] Start renew cert", c.nodeInfo.Type, c.nodeInfo.Id)
// delay to start renewCert
go func() {
_ = c.renewCertPeriodic.Start()
}()
}
}
@ -114,20 +103,14 @@ func (c *Controller) nodeInfoMonitor() (err error) {
if c.nodeInfoMonitorPeriodic.Interval != newNodeInfo.PullInterval &&
newNodeInfo.PullInterval != 0 {
c.nodeInfoMonitorPeriodic.Interval = newNodeInfo.PullInterval
_ = c.nodeInfoMonitorPeriodic.Close()
go func() {
time.Sleep(c.nodeInfoMonitorPeriodic.Interval)
_ = c.nodeInfoMonitorPeriodic.Start()
}()
c.nodeInfoMonitorPeriodic.Close()
_ = c.nodeInfoMonitorPeriodic.Start(false)
}
if c.userReportPeriodic.Interval != newNodeInfo.PushInterval &&
newNodeInfo.PushInterval != 0 {
c.userReportPeriodic.Interval = newNodeInfo.PullInterval
_ = c.userReportPeriodic.Close()
go func() {
time.Sleep(c.userReportPeriodic.Interval)
_ = c.userReportPeriodic.Start()
}()
c.userReportPeriodic.Close()
_ = c.userReportPeriodic.Start(false)
}
} else {
deleted, added := compareUserList(c.userList, newUserInfo)
@ -168,44 +151,3 @@ func (c *Controller) nodeInfoMonitor() (err error) {
}
return nil
}
func (c *Controller) reportUserTraffic() (err error) {
// Get User traffic
userTraffic := make([]panel.UserTraffic, 0)
for i := range c.userList {
up, down := c.server.GetUserTraffic(c.Tag, c.userList[i].Uuid, true)
if up > 0 || down > 0 {
if c.LimitConfig.EnableDynamicSpeedLimit {
c.userList[i].Traffic += up + down
}
userTraffic = append(userTraffic, panel.UserTraffic{
UID: (c.userList)[i].Id,
Upload: up,
Download: down})
}
}
if len(userTraffic) > 0 && !c.DisableUploadTraffic {
err = c.apiClient.ReportUserTraffic(userTraffic)
if err != nil {
log.Printf("Report user traffic faild: %s", err)
} else {
log.Printf("[%s: %d] Report %d online users", c.nodeInfo.Type, c.nodeInfo.Id, len(userTraffic))
}
}
userTraffic = nil
runtime.GC()
return nil
}
func (c *Controller) RenewCert() {
l, err := lego.New(c.CertConfig)
if err != nil {
log.Print("new lego error: ", err)
return
}
err = l.RenewCert()
if err != nil {
log.Print("renew cert error: ", err)
return
}
}

View File

@ -2,9 +2,39 @@ package node
import (
"github.com/Yuzuki616/V2bX/api/panel"
"log"
"runtime"
"strconv"
)
func (c *Controller) reportUserTrafficTask() (err error) {
// Get User traffic
userTraffic := make([]panel.UserTraffic, 0)
for i := range c.userList {
up, down := c.server.GetUserTraffic(c.Tag, c.userList[i].Uuid, true)
if up > 0 || down > 0 {
if c.LimitConfig.EnableDynamicSpeedLimit {
c.userList[i].Traffic += up + down
}
userTraffic = append(userTraffic, panel.UserTraffic{
UID: (c.userList)[i].Id,
Upload: up,
Download: down})
}
}
if len(userTraffic) > 0 && !c.DisableUploadTraffic {
err = c.apiClient.ReportUserTraffic(userTraffic)
if err != nil {
log.Printf("Report user traffic faild: %s", err)
} else {
log.Printf("[%s] Report %d online users", c.Tag, len(userTraffic))
}
}
userTraffic = nil
runtime.GC()
return nil
}
func compareUserList(old, new []panel.UserInfo) (deleted, added []panel.UserInfo) {
tmp := map[string]struct{}{}
tmp2 := map[string]struct{}{}