From 55f5c89c1ca142ba79702a159074ac9cbb9389fe Mon Sep 17 00:00:00 2001 From: UUBulb <35923940+uubulb@users.noreply.github.com> Date: Thu, 10 Oct 2024 00:08:16 +0800 Subject: [PATCH] feat: description file for custom theme; use gjson (#433) * feat: description file for custom theme; use gjson * fix gosec * remove outdated stuff --- cmd/dashboard/controller/controller.go | 66 +++++++++------- go.mod | 3 + go.sum | 6 ++ model/config.go | 1 - model/notification.go | 8 +- pkg/ddns/cloudflare.go | 103 ++++++++----------------- pkg/ddns/ddns.go | 7 ++ pkg/ddns/tencentcloud.go | 98 ++++++++++------------- pkg/utils/gjson.go | 36 +++++++++ pkg/utils/hfs.go | 33 ++++++++ resource/resource.go | 14 +++- service/rpc/nezha.go | 2 +- 12 files changed, 216 insertions(+), 161 deletions(-) create mode 100644 pkg/utils/gjson.go create mode 100644 pkg/utils/hfs.go diff --git a/cmd/dashboard/controller/controller.go b/cmd/dashboard/controller/controller.go index 50d4745..518769b 100644 --- a/cmd/dashboard/controller/controller.go +++ b/cmd/dashboard/controller/controller.go @@ -3,10 +3,10 @@ package controller import ( "fmt" "html/template" - "io/fs" "log" "net/http" "os" + "path/filepath" "strconv" "strings" "time" @@ -34,28 +34,16 @@ func ServeWeb(port uint) *http.Server { pprof.Register(r) } r.Use(natGateway) - if os.Getenv("NZ_LOCAL_TEMPLATE") == "true" { - r.SetFuncMap(funcMap) - r.Use(mygin.RecordPath) - r.Static("/static", "resource/static") - r.LoadHTMLGlob("resource/template/**/*.html") - } else { - tmpl := template.New("").Funcs(funcMap) - var err error - tmpl, err = tmpl.ParseFS(resource.TemplateFS, "template/**/*.html") - if err != nil { - panic(err) - } - tmpl = loadThirdPartyTemplates(tmpl) - r.SetHTMLTemplate(tmpl) - r.Use(mygin.RecordPath) - staticFs, err := fs.Sub(resource.StaticFS, "static") - if err != nil { - panic(err) - } - r.StaticFS("/static", http.FS(staticFs)) + tmpl := template.New("").Funcs(funcMap) + var err error + tmpl, err = tmpl.ParseFS(resource.TemplateFS, "template/**/*.html") + if err != nil { + panic(err) } - r.Static("/static-custom", "resource/static/custom") + tmpl = loadThirdPartyTemplates(tmpl) + r.SetHTMLTemplate(tmpl) + r.Use(mygin.RecordPath) + r.StaticFS("/static", http.FS(resource.StaticFS)) routers(r) page404 := func(c *gin.Context) { mygin.ShowErrorPage(c, mygin.ErrInfo{ @@ -106,14 +94,40 @@ func loadThirdPartyTemplates(tmpl *template.Template) *template.Template { if !theme.IsDir() { continue } - // load templates - t, err := ret.ParseGlob(fmt.Sprintf("resource/template/%s/*.html", theme.Name())) - if err != nil { - log.Printf("NEZHA>> Error parsing templates %s error: %v", theme.Name(), err) + + themeDir := theme.Name() + if !strings.HasPrefix(themeDir, "theme-") { + log.Printf("NEZHA>> Invalid theme name: %s", themeDir) continue } + + descPath := filepath.Join("resource", "template", themeDir, "theme.json") + desc, err := os.ReadFile(filepath.Clean(descPath)) + if err != nil { + log.Printf("NEZHA>> Error opening %s config: %v", themeDir, err) + continue + } + + themeName, err := utils.GjsonGet(desc, "name") + if err != nil { + log.Printf("NEZHA>> Error opening %s config: not a valid description file", theme.Name()) + continue + } + + // load templates + templatePath := filepath.Join("resource", "template", themeDir, "*.html") + t, err := ret.ParseGlob(templatePath) + if err != nil { + log.Printf("NEZHA>> Error parsing templates %s: %v", themeDir, err) + continue + } + + themeKey := strings.TrimPrefix(themeDir, "theme-") + model.Themes[themeKey] = themeName.String() + ret = t } + return ret } diff --git a/go.mod b/go.mod index bface7c..beae780 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,7 @@ require ( github.com/robfig/cron/v3 v3.0.1 github.com/spf13/pflag v1.0.5 github.com/spf13/viper v1.18.2 + github.com/tidwall/gjson v1.18.0 github.com/xanzy/go-gitlab v0.103.0 golang.org/x/crypto v0.25.0 golang.org/x/net v0.27.0 @@ -70,6 +71,8 @@ require ( github.com/spf13/afero v1.11.0 // indirect github.com/spf13/cast v1.6.0 // indirect github.com/subosito/gotenv v1.6.0 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.11 // indirect go.uber.org/atomic v1.9.0 // indirect diff --git a/go.sum b/go.sum index e0cbf40..03c89dd 100644 --- a/go.sum +++ b/go.sum @@ -180,6 +180,12 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= +github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M= diff --git a/model/config.go b/model/config.go index 91d7da7..17fe7e3 100644 --- a/model/config.go +++ b/model/config.go @@ -23,7 +23,6 @@ var Themes = map[string]string{ "hotaru": "Hotaru", "angel-kanade": "AngelKanade", "server-status": "ServerStatus", - "custom": "Custom(local)", } var DashboardThemes = map[string]string{ diff --git a/model/notification.go b/model/notification.go index dd9c128..4207367 100644 --- a/model/notification.go +++ b/model/notification.go @@ -71,8 +71,8 @@ func (ns *NotificationServerBundle) reqBody(message string) (string, error) { return string(msgBytes)[1 : len(msgBytes)-1] }), nil case NotificationRequestTypeForm: - var data map[string]string - if err := utils.Json.Unmarshal([]byte(n.RequestBody), &data); err != nil { + data, err := utils.GjsonParseStringMap(n.RequestBody) + if err != nil { return "", err } params := url.Values{} @@ -99,8 +99,8 @@ func (n *Notification) setRequestHeader(req *http.Request) error { if n.RequestHeader == "" { return nil } - var m map[string]string - if err := utils.Json.Unmarshal([]byte(n.RequestHeader), &m); err != nil { + m, err := utils.GjsonParseStringMap(n.RequestHeader) + if err != nil { return err } for k, v := range m { diff --git a/pkg/ddns/cloudflare.go b/pkg/ddns/cloudflare.go index e19ab79..ef14d9d 100644 --- a/pkg/ddns/cloudflare.go +++ b/pkg/ddns/cloudflare.go @@ -2,6 +2,7 @@ package ddns import ( "bytes" + "errors" "fmt" "io" "log" @@ -14,10 +15,13 @@ import ( const baseEndpoint = "https://api.cloudflare.com/client/v4/zones" type ProviderCloudflare struct { + isIpv4 bool + domainConfig *DomainConfig secret string zoneId string + ipAddr string recordId string - domainConfig *DomainConfig + recordType string } type cfReq struct { @@ -28,13 +32,6 @@ type cfReq struct { Proxied bool `json:"proxied"` } -type cfResp struct { - Result []struct { - ID string `json:"id"` - Name string `json:"name"` - } `json:"result"` -} - func NewProviderCloudflare(s string) *ProviderCloudflare { return &ProviderCloudflare{ secret: s, @@ -54,13 +51,19 @@ func (provider *ProviderCloudflare) UpdateDomain(domainConfig *DomainConfig) err // 当IPv4和IPv6同时成功才算作成功 if provider.domainConfig.EnableIPv4 { - if err = provider.addDomainRecord(true); err != nil { + provider.isIpv4 = true + provider.recordType = getRecordString(provider.isIpv4) + provider.ipAddr = provider.domainConfig.Ipv4Addr + if err = provider.addDomainRecord(); err != nil { return err } } if provider.domainConfig.EnableIpv6 { - if err = provider.addDomainRecord(false); err != nil { + provider.isIpv4 = false + provider.recordType = getRecordString(provider.isIpv4) + provider.ipAddr = provider.domainConfig.Ipv6Addr + if err = provider.addDomainRecord(); err != nil { return err } } @@ -68,19 +71,18 @@ func (provider *ProviderCloudflare) UpdateDomain(domainConfig *DomainConfig) err return nil } -func (provider *ProviderCloudflare) addDomainRecord(isIpv4 bool) error { - err := provider.findDNSRecord(isIpv4) +func (provider *ProviderCloudflare) addDomainRecord() error { + err := provider.findDNSRecord() if err != nil { + if errors.Is(err, utils.ErrGjsonNotFound) { + // 添加 DNS 记录 + return provider.createDNSRecord() + } return fmt.Errorf("查找 DNS 记录时出错: %s", err) } - if provider.recordId == "" { - // 添加 DNS 记录 - return provider.createDNSRecord(isIpv4) - } else { - // 更新 DNS 记录 - return provider.updateDNSRecord(isIpv4) - } + // 更新 DNS 记录 + return provider.updateDNSRecord() } func (provider *ProviderCloudflare) getZoneID() error { @@ -96,35 +98,22 @@ func (provider *ProviderCloudflare) getZoneID() error { return err } - res := &cfResp{} - err = utils.Json.Unmarshal(body, res) + result, err := utils.GjsonGet(body, "result.0.id") if err != nil { return err } - result := res.Result - if len(result) > 0 { - provider.zoneId = result[0].ID - return nil - } - - return fmt.Errorf("找不到 Zone ID") + provider.zoneId = result.String() + return nil } -func (provider *ProviderCloudflare) findDNSRecord(isIPv4 bool) error { - var ipType string - if isIPv4 { - ipType = "A" - } else { - ipType = "AAAA" - } - +func (provider *ProviderCloudflare) findDNSRecord() error { de, _ := url.JoinPath(baseEndpoint, provider.zoneId, "dns_records") du, _ := url.Parse(de) q := du.Query() q.Set("name", provider.domainConfig.FullDomain) - q.Set("type", ipType) + q.Set("type", provider.recordType) du.RawQuery = q.Encode() body, err := provider.sendRequest("GET", du.String(), nil) @@ -132,36 +121,21 @@ func (provider *ProviderCloudflare) findDNSRecord(isIPv4 bool) error { return err } - res := &cfResp{} - err = utils.Json.Unmarshal(body, res) + result, err := utils.GjsonGet(body, "result.0.id") if err != nil { return err } - result := res.Result - if len(result) > 0 { - provider.recordId = result[0].ID - return nil - } - + provider.recordId = result.String() return nil } -func (provider *ProviderCloudflare) createDNSRecord(isIPv4 bool) error { - var ipType, ipAddr string - if isIPv4 { - ipType = "A" - ipAddr = provider.domainConfig.Ipv4Addr - } else { - ipType = "AAAA" - ipAddr = provider.domainConfig.Ipv6Addr - } - +func (provider *ProviderCloudflare) createDNSRecord() error { de, _ := url.JoinPath(baseEndpoint, provider.zoneId, "dns_records") data := &cfReq{ Name: provider.domainConfig.FullDomain, - Type: ipType, - Content: ipAddr, + Type: provider.recordType, + Content: provider.ipAddr, TTL: 60, Proxied: false, } @@ -171,21 +145,12 @@ func (provider *ProviderCloudflare) createDNSRecord(isIPv4 bool) error { return err } -func (provider *ProviderCloudflare) updateDNSRecord(isIPv4 bool) error { - var ipType, ipAddr string - if isIPv4 { - ipType = "A" - ipAddr = provider.domainConfig.Ipv4Addr - } else { - ipType = "AAAA" - ipAddr = provider.domainConfig.Ipv6Addr - } - +func (provider *ProviderCloudflare) updateDNSRecord() error { de, _ := url.JoinPath(baseEndpoint, provider.zoneId, "dns_records", provider.recordId) data := &cfReq{ Name: provider.domainConfig.FullDomain, - Type: ipType, - Content: ipAddr, + Type: provider.recordType, + Content: provider.ipAddr, TTL: 60, Proxied: false, } diff --git a/pkg/ddns/ddns.go b/pkg/ddns/ddns.go index 2bf9251..b3e45ff 100644 --- a/pkg/ddns/ddns.go +++ b/pkg/ddns/ddns.go @@ -20,3 +20,10 @@ func splitDomain(domain string) (prefix string, realDomain string) { prefix = domain[:len(domain)-len(realDomain)-1] return prefix, realDomain } + +func getRecordString(isIpv4 bool) string { + if isIpv4 { + return "A" + } + return "AAAA" +} diff --git a/pkg/ddns/tencentcloud.go b/pkg/ddns/tencentcloud.go index fb75d79..61d7568 100644 --- a/pkg/ddns/tencentcloud.go +++ b/pkg/ddns/tencentcloud.go @@ -5,6 +5,7 @@ import ( "crypto/hmac" "crypto/sha256" "encoding/hex" + "errors" "fmt" "io" "log" @@ -19,10 +20,14 @@ import ( const te = "https://dnspod.tencentcloudapi.com" type ProviderTencentCloud struct { + isIpv4 bool + domainConfig *DomainConfig + recordID uint64 + recordType string secretID string secretKey string - domainConfig *DomainConfig - resp *tcResp + errCode string + ipAddr string } type tcReq struct { @@ -36,18 +41,6 @@ type tcReq struct { RecordId uint64 `json:"RecordId,omitempty"` } -type tcResp struct { - Response struct { - RecordList []struct { - RecordId uint64 - Value string - } - Error struct { - Code string - } - } -} - func NewProviderTencentCloud(id, key string) *ProviderTencentCloud { return &ProviderTencentCloud{ secretID: id, @@ -64,13 +57,19 @@ func (provider *ProviderTencentCloud) UpdateDomain(domainConfig *DomainConfig) e // 当IPv4和IPv6同时成功才算作成功 var err error if provider.domainConfig.EnableIPv4 { - if err = provider.addDomainRecord(true); err != nil { + provider.isIpv4 = true + provider.recordType = getRecordString(provider.isIpv4) + provider.ipAddr = provider.domainConfig.Ipv4Addr + if err = provider.addDomainRecord(); err != nil { return err } } if provider.domainConfig.EnableIpv6 { - if err = provider.addDomainRecord(false); err != nil { + provider.isIpv4 = false + provider.recordType = getRecordString(provider.isIpv4) + provider.ipAddr = provider.domainConfig.Ipv6Addr + if err = provider.addDomainRecord(); err != nil { return err } } @@ -78,33 +77,26 @@ func (provider *ProviderTencentCloud) UpdateDomain(domainConfig *DomainConfig) e return err } -func (provider *ProviderTencentCloud) addDomainRecord(isIpv4 bool) error { - err := provider.findDNSRecord(isIpv4) +func (provider *ProviderTencentCloud) addDomainRecord() error { + err := provider.findDNSRecord() if err != nil { return fmt.Errorf("查找 DNS 记录时出错: %s", err) } - if provider.resp.Response.Error.Code == "ResourceNotFound.NoDataOfRecord" { // 没有找到 DNS 记录 - return provider.createDNSRecord(isIpv4) - } else if provider.resp.Response.Error.Code != "" { - return fmt.Errorf("查询 DNS 记录时出错,错误代码为: %s", provider.resp.Response.Error.Code) + if provider.errCode == "ResourceNotFound.NoDataOfRecord" { // 没有找到 DNS 记录 + return provider.createDNSRecord() + } else if provider.errCode != "" { + return fmt.Errorf("查询 DNS 记录时出错,错误代码为: %s", provider.errCode) } // 默认情况下更新 DNS 记录 - return provider.updateDNSRecord(isIpv4) + return provider.updateDNSRecord() } -func (provider *ProviderTencentCloud) findDNSRecord(isIPv4 bool) error { - var ipType string - if isIPv4 { - ipType = "A" - } else { - ipType = "AAAA" - } - +func (provider *ProviderTencentCloud) findDNSRecord() error { prefix, realDomain := splitDomain(provider.domainConfig.FullDomain) data := &tcReq{ - RecordType: ipType, + RecordType: provider.recordType, Domain: realDomain, RecordLine: "默认", Subdomain: prefix, @@ -116,32 +108,29 @@ func (provider *ProviderTencentCloud) findDNSRecord(isIPv4 bool) error { return err } - provider.resp = &tcResp{} - err = utils.Json.Unmarshal(body, provider.resp) + result, err := utils.GjsonGet(body, "Response.RecordList.0.RecordId") if err != nil { + if errors.Is(err, utils.ErrGjsonNotFound) { + if errCode, err := utils.GjsonGet(body, "Response.Error.Code"); err == nil { + provider.errCode = errCode.String() + return nil + } + } return err } + provider.recordID = result.Uint() return nil } -func (provider *ProviderTencentCloud) createDNSRecord(isIPv4 bool) error { - var ipType, ipAddr string - if isIPv4 { - ipType = "A" - ipAddr = provider.domainConfig.Ipv4Addr - } else { - ipType = "AAAA" - ipAddr = provider.domainConfig.Ipv6Addr - } - +func (provider *ProviderTencentCloud) createDNSRecord() error { prefix, realDomain := splitDomain(provider.domainConfig.FullDomain) data := &tcReq{ - RecordType: ipType, + RecordType: provider.recordType, RecordLine: "默认", Domain: realDomain, SubDomain: prefix, - Value: ipAddr, + Value: provider.ipAddr, TTL: 600, } @@ -150,25 +139,16 @@ func (provider *ProviderTencentCloud) createDNSRecord(isIPv4 bool) error { return err } -func (provider *ProviderTencentCloud) updateDNSRecord(isIPv4 bool) error { - var ipType, ipAddr string - if isIPv4 { - ipType = "A" - ipAddr = provider.domainConfig.Ipv4Addr - } else { - ipType = "AAAA" - ipAddr = provider.domainConfig.Ipv6Addr - } - +func (provider *ProviderTencentCloud) updateDNSRecord() error { prefix, realDomain := splitDomain(provider.domainConfig.FullDomain) data := &tcReq{ - RecordType: ipType, + RecordType: provider.recordType, RecordLine: "默认", Domain: realDomain, SubDomain: prefix, - Value: ipAddr, + Value: provider.ipAddr, TTL: 600, - RecordId: provider.resp.Response.RecordList[0].RecordId, + RecordId: provider.recordID, } jsonData, _ := utils.Json.Marshal(data) diff --git a/pkg/utils/gjson.go b/pkg/utils/gjson.go new file mode 100644 index 0000000..de377dd --- /dev/null +++ b/pkg/utils/gjson.go @@ -0,0 +1,36 @@ +package utils + +import ( + "errors" + + "github.com/tidwall/gjson" +) + +var ( + ErrGjsonNotFound = errors.New("specified path does not exist") + ErrGjsonWrongType = errors.New("wrong type") +) + +func GjsonGet(json []byte, path string) (gjson.Result, error) { + result := gjson.GetBytes(json, path) + if !result.Exists() { + return result, ErrGjsonNotFound + } + + return result, nil +} + +func GjsonParseStringMap(jsonObject string) (map[string]string, error) { + result := gjson.Parse(jsonObject) + if !result.IsObject() { + return nil, ErrGjsonWrongType + } + + ret := make(map[string]string) + result.ForEach(func(key, value gjson.Result) bool { + ret[key.String()] = value.String() + return true + }) + + return ret, nil +} diff --git a/pkg/utils/hfs.go b/pkg/utils/hfs.go new file mode 100644 index 0000000..de62e31 --- /dev/null +++ b/pkg/utils/hfs.go @@ -0,0 +1,33 @@ +package utils + +import ( + "embed" + "io/fs" + "os" +) + +// HybridFS combines embed.FS and os.DirFS. +type HybridFS struct { + embedFS, dir fs.FS +} + +func NewHybridFS(embed embed.FS, subDir string, localDir string) (*HybridFS, error) { + subFS, err := fs.Sub(embed, subDir) + if err != nil { + return nil, err + } + + return &HybridFS{ + embedFS: subFS, + dir: os.DirFS(localDir), + }, nil +} + +func (hfs *HybridFS) Open(name string) (fs.File, error) { + // Ensure embed files are not replaced + if file, err := hfs.embedFS.Open(name); err == nil { + return file, nil + } + + return hfs.dir.Open(name) +} diff --git a/resource/resource.go b/resource/resource.go index 635b4bc..ca9d965 100644 --- a/resource/resource.go +++ b/resource/resource.go @@ -2,10 +2,14 @@ package resource import ( "embed" + + "github.com/naiba/nezha/pkg/utils" ) +var StaticFS *utils.HybridFS + //go:embed static -var StaticFS embed.FS +var staticFS embed.FS //go:embed template var TemplateFS embed.FS @@ -13,6 +17,14 @@ var TemplateFS embed.FS //go:embed l10n var I18nFS embed.FS +func init() { + var err error + StaticFS, err = utils.NewHybridFS(staticFS, "static", "resource/static/custom") + if err != nil { + panic(err) + } +} + func IsTemplateFileExist(name string) bool { _, err := TemplateFS.Open(name) return err == nil diff --git a/service/rpc/nezha.go b/service/rpc/nezha.go index 10300c8..a1a7b57 100644 --- a/service/rpc/nezha.go +++ b/service/rpc/nezha.go @@ -155,8 +155,8 @@ func (s *NezhaHandler) ReportSystemInfo(c context.Context, r *pb.Host) (*pb.Rece Ipv4Addr: ipv4, Ipv6Addr: ipv6, } - go singleton.RetryableUpdateDomain(provider, config, maxRetries) + go singleton.RetryableUpdateDomain(provider, config, maxRetries) } else { // 虽然会在启动时panic, 可以断言不会走这个分支, 但是考虑到动态加载配置或者其它情况, 这里输出一下方便检查奇奇怪怪的BUG log.Printf("NEZHA>> 未找到对应的DDNS配置(%s), 或者是provider填写不正确, 请前往config.yml检查你的设置", singleton.ServerList[clientID].DDNSProfile)