From fee53a83841541c319505c36200d305b58fc0648 Mon Sep 17 00:00:00 2001 From: yuzuki999 Date: Fri, 16 Jun 2023 10:04:39 +0800 Subject: [PATCH] fix some bugs support ip limit for hy --- common/task/task.go | 83 ++++++++++++++++++++++++++++++++++ conf/conf.go | 50 +------------------- conf/watch.go | 52 +++++++++++++++++++++ core/hy/node.go | 9 +++- core/hy/server.go | 10 +++- core/xray/user.go | 1 - node/cert.go | 14 ++++++ node/controller.go | 44 ++++++------------ node/task.go | 108 ++++++++++---------------------------------- node/user.go | 30 ++++++++++++ 10 files changed, 236 insertions(+), 165 deletions(-) create mode 100644 common/task/task.go create mode 100644 conf/watch.go diff --git a/common/task/task.go b/common/task/task.go new file mode 100644 index 0000000..b04eeb7 --- /dev/null +++ b/common/task/task.go @@ -0,0 +1,83 @@ +package task + +import ( + "sync" + "time" +) + +// Task is a task that runs periodically. +type Task struct { + // Interval of the task being run + Interval time.Duration + // Execute is the task function + Execute func() error + + access sync.Mutex + timer *time.Timer + running bool +} + +func (t *Task) hasClosed() bool { + t.access.Lock() + defer t.access.Unlock() + + return !t.running +} + +func (t *Task) checkedExecute() error { + if t.hasClosed() { + return nil + } + + t.access.Lock() + defer t.access.Unlock() + + if !t.running { + return nil + } + + t.timer = time.AfterFunc(t.Interval, func() { + t.checkedExecute() + }) + + return nil +} + +// Start implements common.Runnable. +func (t *Task) Start(first bool) error { + t.access.Lock() + if t.running { + t.access.Unlock() + return nil + } + t.running = true + t.access.Unlock() + if first { + if err := t.Execute(); err != nil { + t.access.Lock() + t.running = false + t.access.Unlock() + return err + } + } + if err := t.checkedExecute(); err != nil { + t.access.Lock() + t.running = false + t.access.Unlock() + return err + } + + return nil +} + +// Close implements common.Closable. +func (t *Task) Close() { + t.access.Lock() + defer t.access.Unlock() + + t.running = false + if t.timer != nil { + t.timer.Stop() + t.timer = nil + } +} diff --git a/conf/conf.go b/conf/conf.go index e54997c..4f2c7e8 100644 --- a/conf/conf.go +++ b/conf/conf.go @@ -2,14 +2,9 @@ package conf import ( "fmt" - "io" - "log" - "os" - "path" - "time" - - "github.com/fsnotify/fsnotify" "gopkg.in/yaml.v3" + "io" + "os" ) type Conf struct { @@ -56,44 +51,3 @@ func (p *Conf) LoadFromPath(filePath string) error { } return nil } - -func (p *Conf) Watch(filePath string, reload func()) error { - watcher, err := fsnotify.NewWatcher() - if err != nil { - return fmt.Errorf("new watcher error: %s", err) - } - go func() { - var pre time.Time - defer watcher.Close() - for { - select { - case e := <-watcher.Events: - if e.Has(fsnotify.Chmod) { - continue - } - if pre.Add(1 * time.Second).After(time.Now()) { - continue - } - time.Sleep(2 * time.Second) - pre = time.Now() - log.Println("config dir changed, reloading...") - *p = *New() - err := p.LoadFromPath(filePath) - if err != nil { - log.Printf("reload config error: %s", err) - } - reload() - log.Println("reload config success") - case err := <-watcher.Errors: - if err != nil { - log.Printf("File watcher error: %s", err) - } - } - } - }() - err = watcher.Add(path.Dir(filePath)) - if err != nil { - return fmt.Errorf("watch file error: %s", err) - } - return nil -} diff --git a/conf/watch.go b/conf/watch.go new file mode 100644 index 0000000..eacf0fa --- /dev/null +++ b/conf/watch.go @@ -0,0 +1,52 @@ +package conf + +import ( + "fmt" + "github.com/fsnotify/fsnotify" + "log" + "path" + "time" +) + +func (p *Conf) Watch(filePath string, reload func()) error { + watcher, err := fsnotify.NewWatcher() + if err != nil { + return fmt.Errorf("new watcher error: %s", err) + } + go func() { + var pre time.Time + defer watcher.Close() + for { + select { + case e := <-watcher.Events: + if e.Has(fsnotify.Chmod) { + continue + } + if pre.Add(10 * time.Second).After(time.Now()) { + continue + } + pre = time.Now() + go func() { + time.Sleep(10 * time.Second) + log.Println("config dir changed, reloading...") + *p = *New() + err := p.LoadFromPath(filePath) + if err != nil { + log.Printf("reload config error: %s", err) + } + reload() + log.Println("reload config success") + }() + case err := <-watcher.Errors: + if err != nil { + log.Printf("File watcher error: %s", err) + } + } + } + }() + err = watcher.Add(path.Dir(filePath)) + if err != nil { + return fmt.Errorf("watch file error: %s", err) + } + return nil +} diff --git a/core/hy/node.go b/core/hy/node.go index a39908b..f97aa97 100644 --- a/core/hy/node.go +++ b/core/hy/node.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/Yuzuki616/V2bX/api/panel" "github.com/Yuzuki616/V2bX/conf" + "github.com/Yuzuki616/V2bX/limiter" "github.com/apernet/hysteria/core/cs" ) @@ -16,8 +17,12 @@ func (h *Hy) AddNode(tag string, info *panel.NodeInfo, c *conf.ControllerConfig) case "reality", "none", "": return errors.New("hysteria need normal tls cert") } - s := NewServer(tag) - err := s.runServer(info, c) + l, err := limiter.GetLimiter(tag) + if err != nil { + return fmt.Errorf("get limiter error: %s", err) + } + s := NewServer(tag, l) + err = s.runServer(info, c) if err != nil { return fmt.Errorf("run hy server error: %s", err) } diff --git a/core/hy/server.go b/core/hy/server.go index fda7f6c..6e34831 100644 --- a/core/hy/server.go +++ b/core/hy/server.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/Yuzuki616/V2bX/api/panel" "github.com/Yuzuki616/V2bX/conf" + "github.com/Yuzuki616/V2bX/limiter" "github.com/apernet/hysteria/core/sockopt" "io" "net" @@ -32,15 +33,17 @@ var serverPacketConnFuncFactoryMap = map[string]pktconns.ServerPacketConnFuncFac type Server struct { tag string + l *limiter.Limiter counter *UserTrafficCounter users sync.Map running atomic.Bool *cs.Server } -func NewServer(tag string) *Server { +func NewServer(tag string, l *limiter.Limiter) *Server { return &Server{ tag: tag, + l: l, } } @@ -173,6 +176,9 @@ func (s *Server) runServer(node *panel.NodeInfo, c *conf.ControllerConfig) error } func (s *Server) authByUser(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string) { + if _, r := s.l.CheckLimit(string(auth), addr.String(), false); r { + return false, "device limited" + } if _, ok := s.users.Load(string(auth)); ok { return true, "Done" } @@ -180,6 +186,7 @@ func (s *Server) authByUser(addr net.Addr, auth []byte, sSend uint64, sRecv uint } func (s *Server) connectFunc(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string) { + s.l.ConnLimiter.AddConnCount(addr.String(), string(auth), false) ok, msg := s.authByUser(addr, auth, sSend, sRecv) if !ok { logrus.WithFields(logrus.Fields{ @@ -196,6 +203,7 @@ func (s *Server) connectFunc(addr net.Addr, auth []byte, sSend uint64, sRecv uin } func (s *Server) disconnectFunc(addr net.Addr, auth []byte, err error) { + s.l.ConnLimiter.DelConnCount(addr.String(), string(auth)) logrus.WithFields(logrus.Fields{ "src": defaultIPMasker.Mask(addr.String()), "error": err, diff --git a/core/xray/user.go b/core/xray/user.go index 43f2c63..c636bc0 100644 --- a/core/xray/user.go +++ b/core/xray/user.go @@ -3,7 +3,6 @@ package xray import ( "context" "fmt" - "github.com/Yuzuki616/V2bX/common/builder" vCore "github.com/Yuzuki616/V2bX/core" "github.com/xtls/xray-core/common/protocol" diff --git a/node/cert.go b/node/cert.go index d80857c..500c631 100644 --- a/node/cert.go +++ b/node/cert.go @@ -4,8 +4,22 @@ import ( "fmt" "github.com/Yuzuki616/V2bX/common/file" "github.com/Yuzuki616/V2bX/node/lego" + "log" ) +func (c *Controller) renewCertTask() { + l, err := lego.New(c.CertConfig) + if err != nil { + log.Print("new lego error: ", err) + return + } + err = l.RenewCert() + if err != nil { + log.Print("renew cert error: ", err) + return + } +} + func (c *Controller) requestCert() error { if c.CertConfig.CertFile == "" || c.CertConfig.KeyFile == "" { return fmt.Errorf("cert file path or key file path not exist") diff --git a/node/controller.go b/node/controller.go index d9c1bb1..2753d4c 100644 --- a/node/controller.go +++ b/node/controller.go @@ -3,14 +3,13 @@ package node import ( "errors" "fmt" - "log" - "github.com/Yuzuki616/V2bX/api/iprecoder" "github.com/Yuzuki616/V2bX/api/panel" + "github.com/Yuzuki616/V2bX/common/task" "github.com/Yuzuki616/V2bX/conf" vCore "github.com/Yuzuki616/V2bX/core" "github.com/Yuzuki616/V2bX/limiter" - "github.com/xtls/xray-core/common/task" + "log" ) type Controller struct { @@ -20,11 +19,11 @@ type Controller struct { Tag string userList []panel.UserInfo ipRecorder iprecoder.IpRecorder - nodeInfoMonitorPeriodic *task.Periodic - userReportPeriodic *task.Periodic - renewCertPeriodic *task.Periodic - dynamicSpeedLimitPeriodic *task.Periodic - onlineIpReportPeriodic *task.Periodic + nodeInfoMonitorPeriodic *task.Task + userReportPeriodic *task.Task + renewCertPeriodic *task.Task + dynamicSpeedLimitPeriodic *task.Task + onlineIpReportPeriodic *task.Task *conf.ControllerConfig } @@ -93,38 +92,23 @@ func (c *Controller) Start() error { func (c *Controller) Close() error { limiter.DeleteLimiter(c.Tag) if c.nodeInfoMonitorPeriodic != nil { - err := c.nodeInfoMonitorPeriodic.Close() - if err != nil { - return fmt.Errorf("node info periodic close error: %s", err) - } + c.nodeInfoMonitorPeriodic.Close() } - if c.nodeInfoMonitorPeriodic != nil { - err := c.userReportPeriodic.Close() - if err != nil { - return fmt.Errorf("user report periodic close error: %s", err) - } + if c.userReportPeriodic != nil { + c.userReportPeriodic.Close() } if c.renewCertPeriodic != nil { - err := c.renewCertPeriodic.Close() - if err != nil { - return fmt.Errorf("renew cert periodic close error: %s", err) - } + c.renewCertPeriodic.Close() } if c.dynamicSpeedLimitPeriodic != nil { - err := c.dynamicSpeedLimitPeriodic.Close() - if err != nil { - return fmt.Errorf("dynamic speed limit periodic close error: %s", err) - } + c.dynamicSpeedLimitPeriodic.Close() } if c.onlineIpReportPeriodic != nil { - err := c.onlineIpReportPeriodic.Close() - if err != nil { - return fmt.Errorf("online ip report periodic close error: %s", err) - } + c.onlineIpReportPeriodic.Close() } return nil } func (c *Controller) buildNodeTag() string { - return fmt.Sprintf("%s_%s_%d", c.nodeInfo.Type, c.ListenIP, c.nodeInfo.Id) + return fmt.Sprintf("%s-%s-%d", c.apiClient.APIHost, c.nodeInfo.Type, c.nodeInfo.Id) } diff --git a/node/task.go b/node/task.go index b029ecc..768b7e2 100644 --- a/node/task.go +++ b/node/task.go @@ -2,52 +2,41 @@ package node import ( "fmt" - "log" - "runtime" - "time" - + "github.com/Yuzuki616/V2bX/common/task" vCore "github.com/Yuzuki616/V2bX/core" - - "github.com/Yuzuki616/V2bX/api/panel" "github.com/Yuzuki616/V2bX/limiter" - "github.com/Yuzuki616/V2bX/node/lego" - "github.com/xtls/xray-core/common/task" + "log" + "time" ) func (c *Controller) initTask() { // fetch node info task - c.nodeInfoMonitorPeriodic = &task.Periodic{ + c.nodeInfoMonitorPeriodic = &task.Task{ Interval: c.nodeInfo.PullInterval, Execute: c.nodeInfoMonitor, } // fetch user list task - c.userReportPeriodic = &task.Periodic{ + c.userReportPeriodic = &task.Task{ Interval: c.nodeInfo.PushInterval, - Execute: c.reportUserTraffic, + Execute: c.reportUserTrafficTask, } - log.Printf("[%s: %d] Start monitor node status", c.nodeInfo.Type, c.nodeInfo.Id) + log.Printf("[%s] Start monitor node status", c.Tag) // delay to start nodeInfoMonitor - go func() { - time.Sleep(c.nodeInfo.PullInterval) - _ = c.nodeInfoMonitorPeriodic.Start() - }() - log.Printf("[%s: %d] Start report node status", c.nodeInfo.Type, c.nodeInfo.Id) - // delay to start userReport - go func() { - time.Sleep(c.nodeInfo.PullInterval) - _ = c.userReportPeriodic.Start() - }() - if c.nodeInfo.Tls && c.CertConfig.CertMode != "none" && - (c.CertConfig.CertMode == "dns" || c.CertConfig.CertMode == "http") { - c.renewCertPeriodic = &task.Periodic{ - Interval: time.Hour * 24, - Execute: c.reportUserTraffic, + _ = c.nodeInfoMonitorPeriodic.Start(false) + log.Printf("[%s] Start report node status", c.Tag) + _ = c.userReportPeriodic.Start(false) + if c.nodeInfo.Tls { + switch c.CertConfig.CertMode { + case "reality", "none", "": + default: + c.renewCertPeriodic = &task.Task{ + Interval: time.Hour * 24, + Execute: c.reportUserTrafficTask, + } + log.Printf("[%s] Start renew cert", c.Tag) + // delay to start renewCert + _ = c.renewCertPeriodic.Start(true) } - log.Printf("[%s: %d] Start renew cert", c.nodeInfo.Type, c.nodeInfo.Id) - // delay to start renewCert - go func() { - _ = c.renewCertPeriodic.Start() - }() } } @@ -114,20 +103,14 @@ func (c *Controller) nodeInfoMonitor() (err error) { if c.nodeInfoMonitorPeriodic.Interval != newNodeInfo.PullInterval && newNodeInfo.PullInterval != 0 { c.nodeInfoMonitorPeriodic.Interval = newNodeInfo.PullInterval - _ = c.nodeInfoMonitorPeriodic.Close() - go func() { - time.Sleep(c.nodeInfoMonitorPeriodic.Interval) - _ = c.nodeInfoMonitorPeriodic.Start() - }() + c.nodeInfoMonitorPeriodic.Close() + _ = c.nodeInfoMonitorPeriodic.Start(false) } if c.userReportPeriodic.Interval != newNodeInfo.PushInterval && newNodeInfo.PushInterval != 0 { c.userReportPeriodic.Interval = newNodeInfo.PullInterval - _ = c.userReportPeriodic.Close() - go func() { - time.Sleep(c.userReportPeriodic.Interval) - _ = c.userReportPeriodic.Start() - }() + c.userReportPeriodic.Close() + _ = c.userReportPeriodic.Start(false) } } else { deleted, added := compareUserList(c.userList, newUserInfo) @@ -168,44 +151,3 @@ func (c *Controller) nodeInfoMonitor() (err error) { } return nil } - -func (c *Controller) reportUserTraffic() (err error) { - // Get User traffic - userTraffic := make([]panel.UserTraffic, 0) - for i := range c.userList { - up, down := c.server.GetUserTraffic(c.Tag, c.userList[i].Uuid, true) - if up > 0 || down > 0 { - if c.LimitConfig.EnableDynamicSpeedLimit { - c.userList[i].Traffic += up + down - } - userTraffic = append(userTraffic, panel.UserTraffic{ - UID: (c.userList)[i].Id, - Upload: up, - Download: down}) - } - } - if len(userTraffic) > 0 && !c.DisableUploadTraffic { - err = c.apiClient.ReportUserTraffic(userTraffic) - if err != nil { - log.Printf("Report user traffic faild: %s", err) - } else { - log.Printf("[%s: %d] Report %d online users", c.nodeInfo.Type, c.nodeInfo.Id, len(userTraffic)) - } - } - userTraffic = nil - runtime.GC() - return nil -} - -func (c *Controller) RenewCert() { - l, err := lego.New(c.CertConfig) - if err != nil { - log.Print("new lego error: ", err) - return - } - err = l.RenewCert() - if err != nil { - log.Print("renew cert error: ", err) - return - } -} diff --git a/node/user.go b/node/user.go index 31857f0..25b1318 100644 --- a/node/user.go +++ b/node/user.go @@ -2,9 +2,39 @@ package node import ( "github.com/Yuzuki616/V2bX/api/panel" + "log" + "runtime" "strconv" ) +func (c *Controller) reportUserTrafficTask() (err error) { + // Get User traffic + userTraffic := make([]panel.UserTraffic, 0) + for i := range c.userList { + up, down := c.server.GetUserTraffic(c.Tag, c.userList[i].Uuid, true) + if up > 0 || down > 0 { + if c.LimitConfig.EnableDynamicSpeedLimit { + c.userList[i].Traffic += up + down + } + userTraffic = append(userTraffic, panel.UserTraffic{ + UID: (c.userList)[i].Id, + Upload: up, + Download: down}) + } + } + if len(userTraffic) > 0 && !c.DisableUploadTraffic { + err = c.apiClient.ReportUserTraffic(userTraffic) + if err != nil { + log.Printf("Report user traffic faild: %s", err) + } else { + log.Printf("[%s] Report %d online users", c.Tag, len(userTraffic)) + } + } + userTraffic = nil + runtime.GC() + return nil +} + func compareUserList(old, new []panel.UserInfo) (deleted, added []panel.UserInfo) { tmp := map[string]struct{}{} tmp2 := map[string]struct{}{}