diff --git a/cmd/dashboard/controller/alertrule.go b/cmd/dashboard/controller/alertrule.go new file mode 100644 index 0000000..37a1877 --- /dev/null +++ b/cmd/dashboard/controller/alertrule.go @@ -0,0 +1,177 @@ +package controller + +import ( + "errors" + "fmt" + "strconv" + "time" + + "github.com/gin-gonic/gin" + "github.com/jinzhu/copier" + + "github.com/naiba/nezha/model" + "github.com/naiba/nezha/service/singleton" +) + +// List Alert rules +// @Summary List Alert rules +// @Security BearerAuth +// @Schemes +// @Description List Alert rules +// @Tags auth required +// @Produce json +// @Success 200 {object} model.CommonResponse[[]model.AlertRule] +// @Router /alert-rule [get] +func listAlertRule(c *gin.Context) ([]*model.AlertRule, error) { + singleton.AlertsLock.RLock() + defer singleton.AlertsLock.RUnlock() + + var ar []*model.AlertRule + if err := copier.Copy(&ar, &singleton.Alerts); err != nil { + return nil, err + } + return ar, nil +} + +// Add Alert Rule +// @Summary Add Alert Rule +// @Security BearerAuth +// @Schemes +// @Description Add Alert Rule +// @Tags auth required +// @Accept json +// @param request body model.AlertRuleForm true "AlertRuleForm" +// @Produce json +// @Success 200 {object} model.CommonResponse[uint64] +// @Router /alert-rule [post] +func createAlertRule(c *gin.Context) (uint64, error) { + var arf model.AlertRuleForm + var r model.AlertRule + + if err := c.ShouldBindJSON(&arf); err != nil { + return 0, err + } + + if err := validateRule(&r); err != nil { + return 0, err + } + + r.Name = arf.Name + r.Rules = arf.Rules + r.FailTriggerTasks = arf.FailTriggerTasks + r.RecoverTriggerTasks = arf.RecoverTriggerTasks + r.NotificationGroupID = arf.NotificationGroupID + enable := arf.Enable + r.TriggerMode = arf.TriggerMode + r.Enable = &enable + r.ID = arf.ID + + if err := singleton.DB.Create(&r).Error; err != nil { + return 0, newGormError("%v", err) + } + + singleton.OnRefreshOrAddAlert(r) + return r.ID, nil +} + +// Update Alert Rule +// @Summary Update Alert Rule +// @Security BearerAuth +// @Schemes +// @Description Update Alert Rule +// @Tags auth required +// @Accept json +// @param id path uint true "Alert ID" +// @param request body model.AlertRuleForm true "AlertRuleForm" +// @Produce json +// @Success 200 {object} model.CommonResponse[any] +// @Router /alert-rule/{id} [patch] +func updateAlertRule(c *gin.Context) (any, error) { + idStr := c.Param("id") + id, err := strconv.ParseUint(idStr, 10, 64) + if err != nil { + return nil, err + } + + var arf model.AlertRuleForm + if err := c.ShouldBindJSON(&arf); err != nil { + return 0, err + } + + var r model.AlertRule + if err := singleton.DB.First(&r, id).Error; err != nil { + return nil, fmt.Errorf("alert id %d does not exist", id) + } + + if err := validateRule(&r); err != nil { + return 0, err + } + + r.Name = arf.Name + r.Rules = arf.Rules + r.FailTriggerTasks = arf.FailTriggerTasks + r.RecoverTriggerTasks = arf.RecoverTriggerTasks + r.NotificationGroupID = arf.NotificationGroupID + enable := arf.Enable + r.TriggerMode = arf.TriggerMode + r.Enable = &enable + r.ID = arf.ID + + if err := singleton.DB.Save(&r).Error; err != nil { + return 0, newGormError("%v", err) + } + + singleton.OnRefreshOrAddAlert(r) + return r.ID, nil +} + +// Batch delete Alert rules +// @Summary Batch delete Alert rules +// @Security BearerAuth +// @Schemes +// @Description Batch delete Alert rules +// @Tags auth required +// @Accept json +// @param request body []uint64 true "id list" +// @Produce json +// @Success 200 {object} model.CommonResponse[any] +// @Router /batch-delete/alert-rule [post] +func batchDeleteAlertRule(c *gin.Context) (any, error) { + var ar []uint64 + + if err := c.ShouldBindJSON(&ar); err != nil { + return nil, err + } + + if err := singleton.DB.Unscoped().Delete(&model.DDNSProfile{}, "id in (?)", ar).Error; err != nil { + return nil, newGormError("%v", err) + } + + singleton.OnDeleteAlert(ar) + return nil, nil +} + +func validateRule(r *model.AlertRule) error { + if len(r.Rules) > 0 { + for _, rule := range r.Rules { + if !rule.IsTransferDurationRule() { + if rule.Duration < 3 { + return errors.New("错误: Duration 至少为 3") + } + } else { + if rule.CycleInterval < 1 { + return errors.New("错误: cycle_interval 至少为 1") + } + if rule.CycleStart == nil { + return errors.New("错误: cycle_start 未设置") + } + if rule.CycleStart.After(time.Now()) { + return errors.New("错误: cycle_start 是个未来值") + } + } + } + } else { + return errors.New("至少定义一条规则") + } + return nil +} diff --git a/cmd/dashboard/controller/controller.go b/cmd/dashboard/controller/controller.go index 0376b15..2d0e673 100644 --- a/cmd/dashboard/controller/controller.go +++ b/cmd/dashboard/controller/controller.go @@ -92,6 +92,11 @@ func routers(r *gin.Engine) { auth.PATCH("/notification/:id", commonHandler(updateNotification)) auth.POST("/batch-delete/notification", commonHandler(batchDeleteNotification)) + auth.GET("/alert-rule", commonHandler(listAlertRule)) + auth.POST("/alert-rule", commonHandler(createAlertRule)) + auth.PATCH("/alert-rule/:id", commonHandler(updateAlertRule)) + auth.POST("/batch-delete/alert-rule", commonHandler(batchDeleteAlertRule)) + auth.GET("/ddns", commonHandler(listDDNS)) auth.GET("/ddns/providers", commonHandler(listProviders)) auth.POST("/ddns", commonHandler(createDDNS)) diff --git a/cmd/dashboard/controller/ddns.go b/cmd/dashboard/controller/ddns.go index 5384508..16c9c3d 100644 --- a/cmd/dashboard/controller/ddns.go +++ b/cmd/dashboard/controller/ddns.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "strconv" - "strings" "github.com/gin-gonic/gin" "github.com/jinzhu/copier" @@ -66,8 +65,7 @@ func createDDNS(c *gin.Context) (uint64, error) { p.EnableIPv6 = &enableIPv6 p.MaxRetries = df.MaxRetries p.Provider = df.Provider - p.DomainsRaw = df.DomainsRaw - p.Domains = strings.Split(p.DomainsRaw, ",") + p.Domains = df.Domains p.AccessID = df.AccessID p.AccessSecret = df.AccessSecret p.WebhookURL = df.WebhookURL @@ -137,8 +135,7 @@ func updateDDNS(c *gin.Context) (any, error) { p.EnableIPv6 = &enableIPv6 p.MaxRetries = df.MaxRetries p.Provider = df.Provider - p.DomainsRaw = df.DomainsRaw - p.Domains = strings.Split(p.DomainsRaw, ",") + p.Domains = df.Domains p.AccessID = df.AccessID p.AccessSecret = df.AccessSecret p.WebhookURL = df.WebhookURL diff --git a/cmd/dashboard/controller/member_api.go b/cmd/dashboard/controller/member_api.go index 409ead3..4e3e3a9 100644 --- a/cmd/dashboard/controller/member_api.go +++ b/cmd/dashboard/controller/member_api.go @@ -75,11 +75,6 @@ func (ma *memberAPI) delete(c *gin.Context) { } delete(singleton.Crons, id) } - case "alert-rule": - err = singleton.DB.Unscoped().Delete(&model.AlertRule{}, "id = ?", id).Error - if err == nil { - singleton.OnDeleteAlert(id) - } } if err != nil { c.JSON(http.StatusOK, model.Response{ diff --git a/model/alertrule.go b/model/alertrule.go index 8c68353..c375fd8 100644 --- a/model/alertrule.go +++ b/model/alertrule.go @@ -12,16 +12,16 @@ const ( type AlertRule struct { Common - Name string - RulesRaw string - Enable *bool - TriggerMode int `gorm:"default:0"` // 触发模式: 0-始终触发(默认) 1-单次触发 - NotificationGroupID uint64 // 该报警规则所在的通知组 - FailTriggerTasksRaw string `gorm:"default:'[]'"` - RecoverTriggerTasksRaw string `gorm:"default:'[]'"` - Rules []Rule `gorm:"-" json:"-"` - FailTriggerTasks []uint64 `gorm:"-" json:"-"` // 失败时执行的触发任务id - RecoverTriggerTasks []uint64 `gorm:"-" json:"-"` // 恢复时执行的触发任务id + Name string `json:"name,omitempty"` + RulesRaw string `json:"-"` + Enable *bool `json:"enable,omitempty"` + TriggerMode int `gorm:"default:0" json:"trigger_mode,omitempty"` // 触发模式: 0-始终触发(默认) 1-单次触发 + NotificationGroupID uint64 `json:"notification_group_id,omitempty"` // 该报警规则所在的通知组 + FailTriggerTasksRaw string `gorm:"default:'[]'" json:"-"` + RecoverTriggerTasksRaw string `gorm:"default:'[]'" json:"-"` + Rules []Rule `gorm:"-" json:"rules"` + FailTriggerTasks []uint64 `gorm:"-" json:"fail_trigger_tasks"` // 失败时执行的触发任务id + RecoverTriggerTasks []uint64 `gorm:"-" json:"recover_trigger_tasks"` // 恢复时执行的触发任务id } func (r *AlertRule) BeforeSave(tx *gorm.DB) error { diff --git a/model/alertrule_api.go b/model/alertrule_api.go new file mode 100644 index 0000000..8c93b3c --- /dev/null +++ b/model/alertrule_api.go @@ -0,0 +1,12 @@ +package model + +type AlertRuleForm struct { + ID uint64 `json:"id"` + Name string `json:"name"` + Rules []Rule `json:"rules"` + FailTriggerTasks []uint64 `json:"fail_trigger_tasks"` // 失败时触发的任务id + RecoverTriggerTasks []uint64 `json:"recover_trigger_tasks"` // 恢复时触发的任务id + NotificationGroupID uint64 `json:"notification_group_id"` + TriggerMode int `json:"trigger_mode"` + Enable bool `json:"enable"` +} diff --git a/model/ddns.go b/model/ddns.go index f63dfcb..989e412 100644 --- a/model/ddns.go +++ b/model/ddns.go @@ -3,6 +3,7 @@ package model import ( "strings" + "github.com/naiba/nezha/pkg/utils" "gorm.io/gorm" ) @@ -32,33 +33,25 @@ type DDNSProfile struct { WebhookRequestBody string `json:"webhook_request_body,omitempty"` WebhookHeaders string `json:"webhook_headers,omitempty"` Domains []string `json:"domains,omitempty" gorm:"-"` - DomainsRaw string `json:"domains_raw,omitempty"` + DomainsRaw string `json:"-"` } func (d DDNSProfile) TableName() string { return "ddns" } +func (d *DDNSProfile) BeforeSave(tx *gorm.DB) error { + if data, err := utils.Json.Marshal(d.Domains); err != nil { + return err + } else { + d.DomainsRaw = string(data) + } + return nil +} + func (d *DDNSProfile) AfterFind(tx *gorm.DB) error { if d.DomainsRaw != "" { d.Domains = strings.Split(d.DomainsRaw, ",") } return nil } - -type DDNSForm struct { - ID uint64 `json:"id,omitempty"` - MaxRetries uint64 `json:"max_retries,omitempty"` - EnableIPv4 bool `json:"enable_ipv4,omitempty"` - EnableIPv6 bool `json:"enable_ipv6,omitempty"` - Name string `json:"name,omitempty"` - Provider string `json:"provider,omitempty"` - DomainsRaw string `json:"domains_raw,omitempty"` - AccessID string `json:"access_id,omitempty"` - AccessSecret string `json:"access_secret,omitempty"` - WebhookURL string `json:"webhook_url,omitempty"` - WebhookMethod uint8 `json:"webhook_method,omitempty"` - WebhookRequestType uint8 `json:"webhook_request_type,omitempty"` - WebhookRequestBody string `json:"webhook_request_body,omitempty"` - WebhookHeaders string `json:"webhook_headers,omitempty"` -} diff --git a/model/ddns_api.go b/model/ddns_api.go new file mode 100644 index 0000000..e33b176 --- /dev/null +++ b/model/ddns_api.go @@ -0,0 +1,18 @@ +package model + +type DDNSForm struct { + ID uint64 `json:"id,omitempty"` + MaxRetries uint64 `json:"max_retries,omitempty"` + EnableIPv4 bool `json:"enable_ipv4,omitempty"` + EnableIPv6 bool `json:"enable_ipv6,omitempty"` + Name string `json:"name,omitempty"` + Provider string `json:"provider,omitempty"` + Domains []string `json:"domains,omitempty"` + AccessID string `json:"access_id,omitempty"` + AccessSecret string `json:"access_secret,omitempty"` + WebhookURL string `json:"webhook_url,omitempty"` + WebhookMethod uint8 `json:"webhook_method,omitempty"` + WebhookRequestType uint8 `json:"webhook_request_type,omitempty"` + WebhookRequestBody string `json:"webhook_request_body,omitempty"` + WebhookHeaders string `json:"webhook_headers,omitempty"` +} diff --git a/model/rule.go b/model/rule.go index b3e702d..de3706a 100644 --- a/model/rule.go +++ b/model/rule.go @@ -175,85 +175,85 @@ func (u *Rule) Snapshot(cycleTransferStats *CycleTransferStats, server *Server, } // IsTransferDurationRule 判断该规则是否属于周期流量规则 属于则返回true -func (rule Rule) IsTransferDurationRule() bool { - return strings.HasSuffix(rule.Type, "_cycle") +func (u *Rule) IsTransferDurationRule() bool { + return strings.HasSuffix(u.Type, "_cycle") } // GetTransferDurationStart 获取周期流量的起始时间 -func (rule Rule) GetTransferDurationStart() time.Time { +func (u *Rule) GetTransferDurationStart() time.Time { // Accept uppercase and lowercase - unit := strings.ToLower(rule.CycleUnit) - startTime := *rule.CycleStart + unit := strings.ToLower(u.CycleUnit) + startTime := *u.CycleStart var nextTime time.Time switch unit { case "year": - nextTime = startTime.AddDate(int(rule.CycleInterval), 0, 0) + nextTime = startTime.AddDate(int(u.CycleInterval), 0, 0) for time.Now().After(nextTime) { startTime = nextTime - nextTime = nextTime.AddDate(int(rule.CycleInterval), 0, 0) + nextTime = nextTime.AddDate(int(u.CycleInterval), 0, 0) } case "month": - nextTime = startTime.AddDate(0, int(rule.CycleInterval), 0) + nextTime = startTime.AddDate(0, int(u.CycleInterval), 0) for time.Now().After(nextTime) { startTime = nextTime - nextTime = nextTime.AddDate(0, int(rule.CycleInterval), 0) + nextTime = nextTime.AddDate(0, int(u.CycleInterval), 0) } case "week": - nextTime = startTime.AddDate(0, 0, 7*int(rule.CycleInterval)) + nextTime = startTime.AddDate(0, 0, 7*int(u.CycleInterval)) for time.Now().After(nextTime) { startTime = nextTime - nextTime = nextTime.AddDate(0, 0, 7*int(rule.CycleInterval)) + nextTime = nextTime.AddDate(0, 0, 7*int(u.CycleInterval)) } case "day": - nextTime = startTime.AddDate(0, 0, int(rule.CycleInterval)) + nextTime = startTime.AddDate(0, 0, int(u.CycleInterval)) for time.Now().After(nextTime) { startTime = nextTime - nextTime = nextTime.AddDate(0, 0, int(rule.CycleInterval)) + nextTime = nextTime.AddDate(0, 0, int(u.CycleInterval)) } default: // For hour unit or not set. - interval := 3600 * int64(rule.CycleInterval) - startTime = time.Unix(rule.CycleStart.Unix()+(time.Now().Unix()-rule.CycleStart.Unix())/interval*interval, 0) + interval := 3600 * int64(u.CycleInterval) + startTime = time.Unix(u.CycleStart.Unix()+(time.Now().Unix()-u.CycleStart.Unix())/interval*interval, 0) } return startTime } // GetTransferDurationEnd 获取周期流量结束时间 -func (rule Rule) GetTransferDurationEnd() time.Time { +func (u *Rule) GetTransferDurationEnd() time.Time { // Accept uppercase and lowercase - unit := strings.ToLower(rule.CycleUnit) - startTime := *rule.CycleStart + unit := strings.ToLower(u.CycleUnit) + startTime := *u.CycleStart var nextTime time.Time switch unit { case "year": - nextTime = startTime.AddDate(int(rule.CycleInterval), 0, 0) + nextTime = startTime.AddDate(int(u.CycleInterval), 0, 0) for time.Now().After(nextTime) { startTime = nextTime - nextTime = nextTime.AddDate(int(rule.CycleInterval), 0, 0) + nextTime = nextTime.AddDate(int(u.CycleInterval), 0, 0) } case "month": - nextTime = startTime.AddDate(0, int(rule.CycleInterval), 0) + nextTime = startTime.AddDate(0, int(u.CycleInterval), 0) for time.Now().After(nextTime) { startTime = nextTime - nextTime = nextTime.AddDate(0, int(rule.CycleInterval), 0) + nextTime = nextTime.AddDate(0, int(u.CycleInterval), 0) } case "week": - nextTime = startTime.AddDate(0, 0, 7*int(rule.CycleInterval)) + nextTime = startTime.AddDate(0, 0, 7*int(u.CycleInterval)) for time.Now().After(nextTime) { startTime = nextTime - nextTime = nextTime.AddDate(0, 0, 7*int(rule.CycleInterval)) + nextTime = nextTime.AddDate(0, 0, 7*int(u.CycleInterval)) } case "day": - nextTime = startTime.AddDate(0, 0, int(rule.CycleInterval)) + nextTime = startTime.AddDate(0, 0, int(u.CycleInterval)) for time.Now().After(nextTime) { startTime = nextTime - nextTime = nextTime.AddDate(0, 0, int(rule.CycleInterval)) + nextTime = nextTime.AddDate(0, 0, int(u.CycleInterval)) } default: // For hour unit or not set. - interval := 3600 * int64(rule.CycleInterval) - startTime = time.Unix(rule.CycleStart.Unix()+(time.Now().Unix()-rule.CycleStart.Unix())/interval*interval, 0) + interval := 3600 * int64(u.CycleInterval) + startTime = time.Unix(u.CycleStart.Unix()+(time.Now().Unix()-u.CycleStart.Unix())/interval*interval, 0) nextTime = time.Unix(startTime.Unix()+interval, 0) } diff --git a/service/singleton/alertsentinel.go b/service/singleton/alertsentinel.go index a15f17b..ea63a9f 100644 --- a/service/singleton/alertsentinel.go +++ b/service/singleton/alertsentinel.go @@ -27,7 +27,7 @@ var ( AlertsLock sync.RWMutex Alerts []*model.AlertRule alertsStore map[uint64]map[uint64][][]interface{} // [alert_id][server_id] -> 对应报警规则的检查结果 - alertsPrevState map[uint64]map[uint64]uint // [alert_id][server_id] -> 对应报警规则的上一次报警状态 + alertsPrevState map[uint64]map[uint64]uint8 // [alert_id][server_id] -> 对应报警规则的上一次报警状态 AlertsCycleTransferStatsStore map[uint64]*model.CycleTransferStats // [alert_id] -> 对应报警规则的周期流量统计 ) @@ -60,7 +60,7 @@ func addCycleTransferStatsInfo(alert *model.AlertRule) { // AlertSentinelStart 报警器启动 func AlertSentinelStart() { alertsStore = make(map[uint64]map[uint64][][]interface{}) - alertsPrevState = make(map[uint64]map[uint64]uint) + alertsPrevState = make(map[uint64]map[uint64]uint8) AlertsCycleTransferStatsStore = make(map[uint64]*model.CycleTransferStats) AlertsLock.Lock() if err := DB.Find(&Alerts).Error; err != nil { @@ -68,7 +68,7 @@ func AlertSentinelStart() { } for _, alert := range Alerts { alertsStore[alert.ID] = make(map[uint64][][]interface{}) - alertsPrevState[alert.ID] = make(map[uint64]uint) + alertsPrevState[alert.ID] = make(map[uint64]uint8) addCycleTransferStatsInfo(alert) } AlertsLock.Unlock() @@ -107,23 +107,26 @@ func OnRefreshOrAddAlert(alert model.AlertRule) { Alerts = append(Alerts, &alert) } alertsStore[alert.ID] = make(map[uint64][][]interface{}) - alertsPrevState[alert.ID] = make(map[uint64]uint) + alertsPrevState[alert.ID] = make(map[uint64]uint8) delete(AlertsCycleTransferStatsStore, alert.ID) addCycleTransferStatsInfo(&alert) } -func OnDeleteAlert(id uint64) { +func OnDeleteAlert(id []uint64) { AlertsLock.Lock() defer AlertsLock.Unlock() - delete(alertsStore, id) - delete(alertsPrevState, id) - for i := 0; i < len(Alerts); i++ { - if Alerts[i].ID == id { - Alerts = append(Alerts[:i], Alerts[i+1:]...) - i-- + for _, i := range id { + delete(alertsStore, i) + delete(alertsPrevState, i) + currentAlerts := Alerts[:0] + for _, alert := range Alerts { + if alert.ID != i { + currentAlerts = append(currentAlerts, alert) + } } + Alerts = currentAlerts + delete(AlertsCycleTransferStatsStore, i) } - delete(AlertsCycleTransferStatsStore, id) } // checkStatus 检查报警规则并发送报警