diff --git a/cmd/dashboard/controller/controller.go b/cmd/dashboard/controller/controller.go index 1e9c054..1dcffa1 100644 --- a/cmd/dashboard/controller/controller.go +++ b/cmd/dashboard/controller/controller.go @@ -197,6 +197,8 @@ func (we *wsError) Error() string { return fmt.Sprintf(we.msg, we.a...) } +var errNoop = errors.New("wrote") + func commonHandler[T any](handler handlerFunc[T]) func(*gin.Context) { return func(c *gin.Context) { handle(c, handler) @@ -239,7 +241,9 @@ func handle[T any](c *gin.Context, handler handlerFunc[T]) { } return default: - c.JSON(http.StatusOK, newErrorResponse(err)) + if !errors.Is(err, errNoop) { + c.JSON(http.StatusOK, newErrorResponse(err)) + } return } } diff --git a/cmd/dashboard/controller/jwt.go b/cmd/dashboard/controller/jwt.go index adbcf1d..d80bd66 100644 --- a/cmd/dashboard/controller/jwt.go +++ b/cmd/dashboard/controller/jwt.go @@ -90,7 +90,7 @@ func authenticator() func(c *gin.Context) (interface{}, error) { var user model.User realip := c.GetString(model.CtxKeyRealIPStr) - if err := singleton.DB.Select("id", "password").Where("username = ?", loginVals.Username).First(&user).Error; err != nil { + if err := singleton.DB.Select("id", "password", "reject_password").Where("username = ?", loginVals.Username).First(&user).Error; err != nil { if err == gorm.ErrRecordNotFound { model.BlockIP(singleton.DB, realip, model.WAFBlockReasonTypeLoginFail, model.BlockIDUnknownUser) } diff --git a/cmd/dashboard/controller/oauth2.go b/cmd/dashboard/controller/oauth2.go index 10eb640..d36e077 100644 --- a/cmd/dashboard/controller/oauth2.go +++ b/cmd/dashboard/controller/oauth2.go @@ -7,7 +7,6 @@ import ( "net/http" "strconv" "strings" - "time" jwt "github.com/appleboy/gin-jwt/v2" "github.com/gin-gonic/gin" @@ -114,10 +113,10 @@ func unbindOauth2(c *gin.Context) (any, error) { // @Produce json // @Param state query string true "state" // @Param code query string true "code" -// @Success 200 {object} model.LoginResponse +// @Success 200 {object} model.CommonResponse[any] // @Router /api/v1/oauth2/callback [get] -func oauth2callback(jwtConfig *jwt.GinJWTMiddleware) func(c *gin.Context) (*model.LoginResponse, error) { - return func(c *gin.Context) (*model.LoginResponse, error) { +func oauth2callback(jwtConfig *jwt.GinJWTMiddleware) func(c *gin.Context) (any, error) { + return func(c *gin.Context) (any, error) { callbackData := &model.Oauth2Callback{ State: c.Query("state"), Code: c.Query("code"), @@ -146,6 +145,7 @@ func oauth2callback(jwtConfig *jwt.GinJWTMiddleware) func(c *gin.Context) (*mode } var bind model.Oauth2Bind + state.Provider = strings.ToLower(state.Provider) switch state.Action { case model.RTypeBind: u, authorized := c.Get(model.CtxKeyAuthorizedUser) @@ -154,7 +154,7 @@ func oauth2callback(jwtConfig *jwt.GinJWTMiddleware) func(c *gin.Context) (*mode } user := u.(*model.User) - result := singleton.DB.Where("provider = ? AND open_id = ?", strings.ToLower(state.Provider), openId).Limit(1).Find(&bind) + result := singleton.DB.Where("provider = ? AND open_id = ?", state.Provider, openId).Limit(1).Find(&bind) if result.Error != nil && result.Error != gorm.ErrRecordNotFound { return nil, newGormError("%v", result.Error) } @@ -171,12 +171,12 @@ func oauth2callback(jwtConfig *jwt.GinJWTMiddleware) func(c *gin.Context) (*mode return nil, newGormError("%v", result.Error) } default: - if err := singleton.DB.Where("provider = ? AND open_id = ?", strings.ToLower(state.Provider), openId).First(&bind).Error; err != nil { + if err := singleton.DB.Where("provider = ? AND open_id = ?", state.Provider, openId).First(&bind).Error; err != nil { return nil, singleton.Localizer.ErrorT("oauth2 user not binded yet") } } - tokenString, expire, err := jwtConfig.TokenGenerator(fmt.Sprintf("%d", bind.UserID)) + tokenString, _, err := jwtConfig.TokenGenerator(fmt.Sprintf("%d", bind.UserID)) if err != nil { return nil, err } @@ -184,7 +184,7 @@ func oauth2callback(jwtConfig *jwt.GinJWTMiddleware) func(c *gin.Context) (*mode jwtConfig.SetCookie(c, tokenString) c.Redirect(http.StatusFound, utils.IfOr(state.Action == model.RTypeBind, "/dashboard/profile?oauth2=true", "/dashboard/login?oauth2=true")) - return &model.LoginResponse{Token: tokenString, Expire: expire.Format(time.RFC3339)}, nil + return nil, errNoop } } diff --git a/cmd/dashboard/controller/user.go b/cmd/dashboard/controller/user.go index 44f7577..486597e 100644 --- a/cmd/dashboard/controller/user.go +++ b/cmd/dashboard/controller/user.go @@ -74,7 +74,7 @@ func updateProfile(c *gin.Context) (any, error) { } var bindCount int64 - if err := singleton.DB.Where("user_id = ?", auth.(*model.User).ID).Count(&bindCount).Error; err != nil { + if err := singleton.DB.Model(&model.Oauth2Bind{}).Where("user_id = ?", auth.(*model.User).ID).Count(&bindCount).Error; err != nil { return nil, newGormError("%v", err) } diff --git a/model/user_api.go b/model/user_api.go index 315fb50..c7fac9e 100644 --- a/model/user_api.go +++ b/model/user_api.go @@ -10,5 +10,5 @@ type ProfileForm struct { OriginalPassword string `json:"original_password,omitempty"` NewUsername string `json:"new_username,omitempty"` NewPassword string `json:"new_password,omitempty"` - RejectPassword bool `json:"reject_password,omitempty"` + RejectPassword bool `json:"reject_password,omitempty" validate:"optional"` }