From 4635bcf44fc95c7a2f342c58340dd982380f51a6 Mon Sep 17 00:00:00 2001 From: naiba Date: Wed, 23 Oct 2024 17:56:51 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=92=84=20refactor=20common=20handler?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cmd/dashboard/controller/controller.go | 25 +++++---- cmd/dashboard/controller/ddns.go | 71 ++++++++++-------------- cmd/dashboard/controller/server.go | 39 +++++-------- cmd/dashboard/controller/server_group.go | 61 +++++++++----------- cmd/dashboard/controller/terminal.go | 36 +++++------- cmd/dashboard/controller/user.go | 34 ++++++------ cmd/dashboard/controller/ws.go | 6 +- 7 files changed, 117 insertions(+), 155 deletions(-) diff --git a/cmd/dashboard/controller/controller.go b/cmd/dashboard/controller/controller.go index 089c601..2a8508e 100644 --- a/cmd/dashboard/controller/controller.go +++ b/cmd/dashboard/controller/controller.go @@ -159,7 +159,7 @@ func newErrorResponse(err error) model.CommonResponse[any] { } } -type handlerFunc func(c *gin.Context) error +type handlerFunc[T any] func(c *gin.Context) (T, error) // There are many error types in gorm, so create a custom type to represent all // gorm errors here instead @@ -179,17 +179,20 @@ func (ge *gormError) Error() string { return fmt.Sprintf(ge.msg, ge.a...) } -func commonHandler(handler handlerFunc) func(*gin.Context) { +func commonHandler[T any](handler handlerFunc[T]) func(*gin.Context) { return func(c *gin.Context) { - if err := handler(c); err != nil { - if _, ok := err.(*gormError); ok { - log.Printf("NEZHA>> gorm error: %v", err) - c.JSON(http.StatusOK, newErrorResponse(errors.New("database error"))) - return - } else { - c.JSON(http.StatusOK, newErrorResponse(err)) - return - } + data, err := handler(c) + if err == nil { + c.JSON(http.StatusOK, model.CommonResponse[T]{Success: true, Data: data}) + return + } + if _, ok := err.(*gormError); ok { + log.Printf("NEZHA>> gorm error: %v", err) + c.JSON(http.StatusOK, newErrorResponse(errors.New("database error"))) + return + } else { + c.JSON(http.StatusOK, newErrorResponse(err)) + return } } } diff --git a/cmd/dashboard/controller/ddns.go b/cmd/dashboard/controller/ddns.go index 059f7d4..8bae96d 100644 --- a/cmd/dashboard/controller/ddns.go +++ b/cmd/dashboard/controller/ddns.go @@ -3,7 +3,6 @@ package controller import ( "errors" "fmt" - "net/http" "strconv" "strings" @@ -23,18 +22,18 @@ import ( // @Accept json // @param request body model.DDNSForm true "DDNS Request" // @Produce json -// @Success 200 {object} model.CommonResponse[any] +// @Success 200 {object} model.CommonResponse[uint64] // @Router /ddns [post] -func createDDNS(c *gin.Context) error { +func createDDNS(c *gin.Context) (uint64, error) { var df model.DDNSForm var p model.DDNSProfile if err := c.ShouldBindJSON(&df); err != nil { - return err + return 0, err } if df.MaxRetries < 1 || df.MaxRetries > 10 { - return errors.New("重试次数必须为大于 1 且不超过 10 的整数") + return 0, errors.New("重试次数必须为大于 1 且不超过 10 的整数") } p.Name = df.Name @@ -58,20 +57,18 @@ func createDDNS(c *gin.Context) error { // IDN to ASCII domainValid, domainErr := idna.Lookup.ToASCII(domain) if domainErr != nil { - return fmt.Errorf("域名 %s 解析错误: %v", domain, domainErr) + return 0, fmt.Errorf("域名 %s 解析错误: %v", domain, domainErr) } p.Domains[n] = domainValid } if err := singleton.DB.Create(&p).Error; err != nil { - return newGormError("%v", err) + return 0, newGormError("%v", err) } singleton.OnDDNSUpdate() - c.JSON(http.StatusOK, model.CommonResponse[any]{ - Success: true, - }) - return nil + + return p.ID, nil } // Edit DDNS profile @@ -86,26 +83,26 @@ func createDDNS(c *gin.Context) error { // @Produce json // @Success 200 {object} model.CommonResponse[any] // @Router /ddns/{id} [patch] -func updateDDNS(c *gin.Context) error { +func updateDDNS(c *gin.Context) (any, error) { idStr := c.Param("id") id, err := strconv.ParseUint(idStr, 10, 64) if err != nil { - return err + return nil, err } var df model.DDNSForm if err := c.ShouldBindJSON(&df); err != nil { - return err + return nil, err } if df.MaxRetries < 1 || df.MaxRetries > 10 { - return errors.New("重试次数必须为大于 1 且不超过 10 的整数") + return nil, errors.New("重试次数必须为大于 1 且不超过 10 的整数") } var p model.DDNSProfile if err = singleton.DB.First(&p, id).Error; err != nil { - return fmt.Errorf("profile id %d does not exist", id) + return nil, fmt.Errorf("profile id %d does not exist", id) } p.Name = df.Name @@ -130,20 +127,18 @@ func updateDDNS(c *gin.Context) error { // IDN to ASCII domainValid, domainErr := idna.Lookup.ToASCII(domain) if domainErr != nil { - return fmt.Errorf("域名 %s 解析错误: %v", domain, domainErr) + return nil, fmt.Errorf("域名 %s 解析错误: %v", domain, domainErr) } p.Domains[n] = domainValid } if err = singleton.DB.Save(&p).Error; err != nil { - return newGormError("%v", err) + return nil, newGormError("%v", err) } singleton.OnDDNSUpdate() - c.JSON(http.StatusOK, model.CommonResponse[any]{ - Success: true, - }) - return nil + + return nil, nil } // Batch delete DDNS configurations @@ -157,22 +152,20 @@ func updateDDNS(c *gin.Context) error { // @Produce json // @Success 200 {object} model.CommonResponse[any] // @Router /batch-delete/ddns [post] -func batchDeleteDDNS(c *gin.Context) error { +func batchDeleteDDNS(c *gin.Context) (any, error) { var ddnsConfigs []uint64 if err := c.ShouldBindJSON(&ddnsConfigs); err != nil { - return err + return nil, err } if err := singleton.DB.Unscoped().Delete(&model.DDNSProfile{}, "id in (?)", ddnsConfigs).Error; err != nil { - return newGormError("%v", err) + return nil, newGormError("%v", err) } singleton.OnDDNSUpdate() - c.JSON(http.StatusOK, model.CommonResponse[any]{ - Success: true, - }) - return nil + + return nil, nil } // List DDNS Profiles @@ -185,7 +178,7 @@ func batchDeleteDDNS(c *gin.Context) error { // @Produce json // @Success 200 {object} model.CommonResponse[[]model.DDNSProfile] // @Router /ddns [get] -func listDDNS(c *gin.Context) error { +func listDDNS(c *gin.Context) ([]model.DDNSProfile, error) { var idList []uint64 idQuery := c.Query("id") @@ -195,7 +188,7 @@ func listDDNS(c *gin.Context) error { for _, v := range idListStr { id, err := strconv.ParseUint(v, 10, 64) if err != nil { - return err + return nil, err } idList = append(idList, id) } @@ -210,7 +203,7 @@ func listDDNS(c *gin.Context) error { if profile, ok := singleton.DDNSCache[id]; ok { ddnsProfiles = append(ddnsProfiles, *profile) } else { - return fmt.Errorf("profile id %d not found", id) + return nil, fmt.Errorf("profile id %d not found", id) } } } else { @@ -221,11 +214,7 @@ func listDDNS(c *gin.Context) error { } singleton.DDNSCacheLock.RUnlock() - c.JSON(http.StatusOK, model.CommonResponse[[]model.DDNSProfile]{ - Success: true, - Data: ddnsProfiles, - }) - return nil + return ddnsProfiles, nil } // List DDNS Providers @@ -237,10 +226,6 @@ func listDDNS(c *gin.Context) error { // @Produce json // @Success 200 {object} model.CommonResponse[[]string] // @Router /ddns/providers [get] -func listProviders(c *gin.Context) error { - c.JSON(http.StatusOK, model.CommonResponse[[]string]{ - Success: true, - Data: model.ProviderList, - }) - return nil +func listProviders(c *gin.Context) ([]string, error) { + return model.ProviderList, nil } diff --git a/cmd/dashboard/controller/server.go b/cmd/dashboard/controller/server.go index f2789c2..3633ee6 100644 --- a/cmd/dashboard/controller/server.go +++ b/cmd/dashboard/controller/server.go @@ -2,7 +2,6 @@ package controller import ( "fmt" - "net/http" "strconv" "github.com/gin-gonic/gin" @@ -21,15 +20,12 @@ import ( // @Produce json // @Success 200 {object} model.CommonResponse[any] // @Router /server [get] -func listServer(c *gin.Context) error { +func listServer(c *gin.Context) ([]model.Server, error) { var servers []model.Server if err := singleton.DB.Find(&servers).Error; err != nil { - return newGormError("%v", err) + return nil, newGormError("%v", err) } - c.JSON(http.StatusOK, model.CommonResponse[any]{ - Data: servers, - }) - return nil + return servers, nil } // Edit server @@ -44,20 +40,20 @@ func listServer(c *gin.Context) error { // @Produce json // @Success 200 {object} model.CommonResponse[any] // @Router /server/{id} [patch] -func updateServer(c *gin.Context) error { +func updateServer(c *gin.Context) (any, error) { idStr := c.Param("id") id, err := strconv.ParseUint(idStr, 10, 64) if err != nil { - return err + return nil, err } var sf model.ServerForm if err := c.ShouldBindJSON(&sf); err != nil { - return err + return nil, err } var s model.Server if err := singleton.DB.First(&s, id).Error; err != nil { - return fmt.Errorf("server id %d does not exist", id) + return nil, fmt.Errorf("server id %d does not exist", id) } s.Name = sf.Name @@ -70,12 +66,12 @@ func updateServer(c *gin.Context) error { s.DDNSProfiles = sf.DDNSProfiles ddnsProfilesRaw, err := utils.Json.Marshal(s.DDNSProfiles) if err != nil { - return err + return nil, err } s.DDNSProfilesRaw = string(ddnsProfilesRaw) if err := singleton.DB.Save(&s).Error; err != nil { - return newGormError("%v", err) + return nil, newGormError("%v", err) } singleton.ServerLock.Lock() @@ -83,10 +79,8 @@ func updateServer(c *gin.Context) error { singleton.ServerList[s.ID] = &s singleton.ServerLock.Unlock() singleton.ReSortServer() - c.JSON(http.StatusOK, model.Response{ - Code: http.StatusOK, - }) - return nil + + return nil, nil } // Batch delete server @@ -100,14 +94,14 @@ func updateServer(c *gin.Context) error { // @Produce json // @Success 200 {object} model.CommonResponse[any] // @Router /batch-delete/server [post] -func batchDeleteServer(c *gin.Context) error { +func batchDeleteServer(c *gin.Context) (any, error) { var servers []uint64 if err := c.ShouldBindJSON(&servers); err != nil { - return err + return nil, err } if err := singleton.DB.Unscoped().Delete(&model.Server{}, "id in (?)", servers).Error; err != nil { - return newGormError("%v", err) + return nil, newGormError("%v", err) } singleton.ServerLock.Lock() @@ -131,8 +125,5 @@ func batchDeleteServer(c *gin.Context) error { singleton.ReSortServer() - c.JSON(http.StatusOK, model.CommonResponse[any]{ - Success: true, - }) - return nil + return nil, nil } diff --git a/cmd/dashboard/controller/server_group.go b/cmd/dashboard/controller/server_group.go index a8d6224..cbbc084 100644 --- a/cmd/dashboard/controller/server_group.go +++ b/cmd/dashboard/controller/server_group.go @@ -2,7 +2,6 @@ package controller import ( "fmt" - "net/http" "github.com/gin-gonic/gin" "gorm.io/gorm" @@ -20,16 +19,16 @@ import ( // @Produce json // @Success 200 {object} model.CommonResponse[[]model.ServerGroupResponseItem] // @Router /server-group [get] -func listServerGroup(c *gin.Context) error { +func listServerGroup(c *gin.Context) ([]model.ServerGroupResponseItem, error) { var sg []model.ServerGroup if err := singleton.DB.Find(&sg).Error; err != nil { - return err + return nil, err } groupServers := make(map[uint64][]uint64, 0) var sgs []model.ServerGroupServer if err := singleton.DB.Find(&sgs).Error; err != nil { - return err + return nil, err } for _, s := range sgs { if _, ok := groupServers[s.ServerGroupId]; !ok { @@ -46,11 +45,7 @@ func listServerGroup(c *gin.Context) error { }) } - c.JSON(http.StatusOK, model.CommonResponse[[]model.ServerGroupResponseItem]{ - Success: true, - Data: sgRes, - }) - return nil + return sgRes, nil } // New server group @@ -62,12 +57,12 @@ func listServerGroup(c *gin.Context) error { // @Accept json // @Param body body model.ServerGroupForm true "ServerGroupForm" // @Produce json -// @Success 200 {object} model.CommonResponse[any] +// @Success 200 {object} model.CommonResponse[uint64] // @Router /server-group [post] -func createServerGroup(c *gin.Context) error { +func createServerGroup(c *gin.Context) (uint64, error) { var sgf model.ServerGroupForm if err := c.ShouldBindJSON(&sgf); err != nil { - return err + return 0, err } var sg model.ServerGroup @@ -75,13 +70,13 @@ func createServerGroup(c *gin.Context) error { var count int64 if err := singleton.DB.Model(&model.Server{}).Where("id = ?", sgf.Servers).Count(&count).Error; err != nil { - return err + return 0, err } if count != int64(len(sgf.Servers)) { - return fmt.Errorf("have invalid server id") + return 0, fmt.Errorf("have invalid server id") } - singleton.DB.Transaction(func(tx *gorm.DB) error { + err := singleton.DB.Transaction(func(tx *gorm.DB) error { if err := tx.Create(&sg).Error; err != nil { return err } @@ -95,11 +90,11 @@ func createServerGroup(c *gin.Context) error { } return nil }) + if err != nil { + return 0, newGormError("%v", err) + } - c.JSON(http.StatusOK, model.CommonResponse[any]{ - Success: true, - }) - return nil + return sg.ID, nil } // Edit server group @@ -114,24 +109,24 @@ func createServerGroup(c *gin.Context) error { // @Produce json // @Success 200 {object} model.CommonResponse[any] // @Router /server-group/{id} [patch] -func updateServerGroup(c *gin.Context) error { +func updateServerGroup(c *gin.Context) (any, error) { id := c.Param("id") var sg model.ServerGroupForm if err := c.ShouldBindJSON(&sg); err != nil { - return err + return nil, err } var sgDB model.ServerGroup if err := singleton.DB.First(&sgDB, id).Error; err != nil { - return fmt.Errorf("group id %s does not exist", id) + return nil, fmt.Errorf("group id %s does not exist", id) } sgDB.Name = sg.Name var count int64 if err := singleton.DB.Model(&model.Server{}).Where("id = ?", sg.Servers).Count(&count).Error; err != nil { - return err + return nil, err } if count != int64(len(sg.Servers)) { - return fmt.Errorf("have invalid server id") + return nil, fmt.Errorf("have invalid server id") } err := singleton.DB.Transaction(func(tx *gorm.DB) error { @@ -153,13 +148,10 @@ func updateServerGroup(c *gin.Context) error { return nil }) if err != nil { - return newGormError("%v", err) + return nil, newGormError("%v", err) } - c.JSON(http.StatusOK, model.CommonResponse[any]{ - Success: true, - }) - return nil + return nil, nil } // Batch delete server group @@ -173,10 +165,10 @@ func updateServerGroup(c *gin.Context) error { // @Produce json // @Success 200 {object} model.CommonResponse[any] // @Router /batch-delete/server-group [post] -func batchDeleteServerGroup(c *gin.Context) error { +func batchDeleteServerGroup(c *gin.Context) (any, error) { var sgs []uint64 if err := c.ShouldBindJSON(&sgs); err != nil { - return err + return nil, err } err := singleton.DB.Transaction(func(tx *gorm.DB) error { @@ -190,11 +182,8 @@ func batchDeleteServerGroup(c *gin.Context) error { }) if err != nil { - return newGormError("%v", err) + return nil, newGormError("%v", err) } - c.JSON(http.StatusOK, model.CommonResponse[any]{ - Success: true, - }) - return nil + return nil, nil } diff --git a/cmd/dashboard/controller/terminal.go b/cmd/dashboard/controller/terminal.go index b2d95d9..acc37a3 100644 --- a/cmd/dashboard/controller/terminal.go +++ b/cmd/dashboard/controller/terminal.go @@ -2,7 +2,6 @@ package controller import ( "errors" - "net/http" "time" "github.com/gin-gonic/gin" @@ -25,15 +24,15 @@ import ( // @Produce json // @Success 200 {object} model.CreateTerminalResponse // @Router /terminal [post] -func createTerminal(c *gin.Context) error { +func createTerminal(c *gin.Context) (*model.CreateTerminalResponse, error) { var createTerminalReq model.TerminalForm if err := c.ShouldBind(&createTerminalReq); err != nil { - return err + return nil, err } streamId, err := uuid.GenerateUUID() if err != nil { - return err + return nil, err } rpc.NezhaHandlerSingleton.CreateStream(streamId) @@ -42,7 +41,7 @@ func createTerminal(c *gin.Context) error { server := singleton.ServerList[createTerminalReq.ServerID] singleton.ServerLock.RUnlock() if server == nil || server.TaskStream == nil { - return errors.New("server not found or not connected") + return nil, errors.New("server not found or not connected") } terminalData, _ := utils.Json.Marshal(&model.TerminalTask{ @@ -52,19 +51,14 @@ func createTerminal(c *gin.Context) error { Type: model.TaskTypeTerminalGRPC, Data: string(terminalData), }); err != nil { - return err + return nil, err } - c.JSON(http.StatusOK, model.CommonResponse[model.CreateTerminalResponse]{ - Success: true, - Data: model.CreateTerminalResponse{ - SessionID: streamId, - ServerID: server.ID, - ServerName: server.Name, - }, - }) - - return nil + return &model.CreateTerminalResponse{ + SessionID: streamId, + ServerID: server.ID, + ServerName: server.Name, + }, nil } // TerminalStream web ssh terminal stream @@ -73,16 +67,16 @@ func createTerminal(c *gin.Context) error { // @Tags auth required // @Param id path string true "Stream ID" // @Router /terminal/{id} [get] -func terminalStream(c *gin.Context) error { +func terminalStream(c *gin.Context) (any, error) { streamId := c.Param("id") if _, err := rpc.NezhaHandlerSingleton.GetStream(streamId); err != nil { - return err + return nil, err } defer rpc.NezhaHandlerSingleton.CloseStream(streamId) wsConn, err := upgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { - return err + return nil, err } defer wsConn.Close() conn := websocketx.NewConn(wsConn) @@ -98,8 +92,8 @@ func terminalStream(c *gin.Context) error { }() if err = rpc.NezhaHandlerSingleton.UserConnected(streamId, conn); err != nil { - return err + return nil, err } - return rpc.NezhaHandlerSingleton.StartStream(streamId, time.Second*10) + return nil, rpc.NezhaHandlerSingleton.StartStream(streamId, time.Second*10) } diff --git a/cmd/dashboard/controller/user.go b/cmd/dashboard/controller/user.go index 84e9f44..0dbf29d 100644 --- a/cmd/dashboard/controller/user.go +++ b/cmd/dashboard/controller/user.go @@ -18,16 +18,12 @@ import ( // @Produce json // @Success 200 {object} model.CommonResponse[[]model.User] // @Router /user [get] -func listUser(c *gin.Context) error { +func listUser(c *gin.Context) ([]model.User, error) { var users []model.User if err := singleton.DB.Find(&users).Error; err != nil { - return err + return nil, err } - c.JSON(200, model.CommonResponse[[]model.User]{ - Success: true, - Data: users, - }) - return nil + return users, nil } // Create user @@ -39,19 +35,19 @@ func listUser(c *gin.Context) error { // @Accept json // @param request body model.UserForm true "User Request" // @Produce json -// @Success 200 {object} model.CommonResponse[any] +// @Success 200 {object} model.CommonResponse[uint64] // @Router /user [post] -func createUser(c *gin.Context) error { +func createUser(c *gin.Context) (uint64, error) { var uf model.UserForm if err := c.ShouldBindJSON(&uf); err != nil { - return err + return 0, err } if len(uf.Password) < 6 { - return errors.New("password length must be greater than 6") + return 0, errors.New("password length must be greater than 6") } if uf.Username == "" { - return errors.New("username can't be empty") + return 0, errors.New("username can't be empty") } var u model.User @@ -59,11 +55,15 @@ func createUser(c *gin.Context) error { hash, err := bcrypt.GenerateFromPassword([]byte(uf.Password), bcrypt.DefaultCost) if err != nil { - return err + return 0, err } u.Password = string(hash) - return singleton.DB.Create(&u).Error + if err := singleton.DB.Create(&u).Error; err != nil { + return 0, err + } + + return u.ID, nil } // Batch delete users @@ -77,10 +77,10 @@ func createUser(c *gin.Context) error { // @Produce json // @Success 200 {object} model.CommonResponse[any] // @Router /batch-delete/user [post] -func batchDeleteUser(c *gin.Context) error { +func batchDeleteUser(c *gin.Context) (any, error) { var ids []uint if err := c.ShouldBindJSON(&ids); err != nil { - return err + return nil, err } - return singleton.DB.Where("id IN (?)", ids).Delete(&model.User{}).Error + return nil, singleton.DB.Where("id IN (?)", ids).Delete(&model.User{}).Error } diff --git a/cmd/dashboard/controller/ws.go b/cmd/dashboard/controller/ws.go index a3f3472..c4f1711 100644 --- a/cmd/dashboard/controller/ws.go +++ b/cmd/dashboard/controller/ws.go @@ -27,10 +27,10 @@ var upgrader = websocket.Upgrader{ // @Produce json // @Success 200 {object} model.StreamServerData // @Router /ws/server [get] -func serverStream(c *gin.Context) error { +func serverStream(c *gin.Context) (any, error) { conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { - return err + return nil, err } defer conn.Close() count := 0 @@ -51,7 +51,7 @@ func serverStream(c *gin.Context) error { } time.Sleep(time.Second * 2) } - return nil + return nil, nil } var requestGroup singleflight.Group