Merge branch 'master' of github.com:nezhahq/nezha

This commit is contained in:
naiba 2024-12-31 23:58:49 +08:00
commit 126773974c
5 changed files with 16 additions and 12 deletions

View File

@ -198,6 +198,8 @@ func (we *wsError) Error() string {
return fmt.Sprintf(we.msg, we.a...) return fmt.Sprintf(we.msg, we.a...)
} }
var errNoop = errors.New("wrote")
func commonHandler[T any](handler handlerFunc[T]) func(*gin.Context) { func commonHandler[T any](handler handlerFunc[T]) func(*gin.Context) {
return func(c *gin.Context) { return func(c *gin.Context) {
handle(c, handler) handle(c, handler)
@ -240,7 +242,9 @@ func handle[T any](c *gin.Context, handler handlerFunc[T]) {
} }
return return
default: default:
c.JSON(http.StatusOK, newErrorResponse(err)) if !errors.Is(err, errNoop) {
c.JSON(http.StatusOK, newErrorResponse(err))
}
return return
} }
} }

View File

@ -90,7 +90,7 @@ func authenticator() func(c *gin.Context) (interface{}, error) {
var user model.User var user model.User
realip := c.GetString(model.CtxKeyRealIPStr) realip := c.GetString(model.CtxKeyRealIPStr)
if err := singleton.DB.Select("id", "password").Where("username = ?", loginVals.Username).First(&user).Error; err != nil { if err := singleton.DB.Select("id", "password", "reject_password").Where("username = ?", loginVals.Username).First(&user).Error; err != nil {
if err == gorm.ErrRecordNotFound { if err == gorm.ErrRecordNotFound {
model.BlockIP(singleton.DB, realip, model.WAFBlockReasonTypeLoginFail, model.BlockIDUnknownUser) model.BlockIP(singleton.DB, realip, model.WAFBlockReasonTypeLoginFail, model.BlockIDUnknownUser)
} }

View File

@ -7,7 +7,6 @@ import (
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
"time"
jwt "github.com/appleboy/gin-jwt/v2" jwt "github.com/appleboy/gin-jwt/v2"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -114,10 +113,10 @@ func unbindOauth2(c *gin.Context) (any, error) {
// @Produce json // @Produce json
// @Param state query string true "state" // @Param state query string true "state"
// @Param code query string true "code" // @Param code query string true "code"
// @Success 200 {object} model.LoginResponse // @Success 200 {object} model.CommonResponse[any]
// @Router /api/v1/oauth2/callback [get] // @Router /api/v1/oauth2/callback [get]
func oauth2callback(jwtConfig *jwt.GinJWTMiddleware) func(c *gin.Context) (*model.LoginResponse, error) { func oauth2callback(jwtConfig *jwt.GinJWTMiddleware) func(c *gin.Context) (any, error) {
return func(c *gin.Context) (*model.LoginResponse, error) { return func(c *gin.Context) (any, error) {
callbackData := &model.Oauth2Callback{ callbackData := &model.Oauth2Callback{
State: c.Query("state"), State: c.Query("state"),
Code: c.Query("code"), Code: c.Query("code"),
@ -146,6 +145,7 @@ func oauth2callback(jwtConfig *jwt.GinJWTMiddleware) func(c *gin.Context) (*mode
} }
var bind model.Oauth2Bind var bind model.Oauth2Bind
state.Provider = strings.ToLower(state.Provider)
switch state.Action { switch state.Action {
case model.RTypeBind: case model.RTypeBind:
u, authorized := c.Get(model.CtxKeyAuthorizedUser) u, authorized := c.Get(model.CtxKeyAuthorizedUser)
@ -154,7 +154,7 @@ func oauth2callback(jwtConfig *jwt.GinJWTMiddleware) func(c *gin.Context) (*mode
} }
user := u.(*model.User) user := u.(*model.User)
result := singleton.DB.Where("provider = ? AND open_id = ?", strings.ToLower(state.Provider), openId).Limit(1).Find(&bind) result := singleton.DB.Where("provider = ? AND open_id = ?", state.Provider, openId).Limit(1).Find(&bind)
if result.Error != nil && result.Error != gorm.ErrRecordNotFound { if result.Error != nil && result.Error != gorm.ErrRecordNotFound {
return nil, newGormError("%v", result.Error) return nil, newGormError("%v", result.Error)
} }
@ -171,12 +171,12 @@ func oauth2callback(jwtConfig *jwt.GinJWTMiddleware) func(c *gin.Context) (*mode
return nil, newGormError("%v", result.Error) return nil, newGormError("%v", result.Error)
} }
default: default:
if err := singleton.DB.Where("provider = ? AND open_id = ?", strings.ToLower(state.Provider), openId).First(&bind).Error; err != nil { if err := singleton.DB.Where("provider = ? AND open_id = ?", state.Provider, openId).First(&bind).Error; err != nil {
return nil, singleton.Localizer.ErrorT("oauth2 user not binded yet") return nil, singleton.Localizer.ErrorT("oauth2 user not binded yet")
} }
} }
tokenString, expire, err := jwtConfig.TokenGenerator(fmt.Sprintf("%d", bind.UserID)) tokenString, _, err := jwtConfig.TokenGenerator(fmt.Sprintf("%d", bind.UserID))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -184,7 +184,7 @@ func oauth2callback(jwtConfig *jwt.GinJWTMiddleware) func(c *gin.Context) (*mode
jwtConfig.SetCookie(c, tokenString) jwtConfig.SetCookie(c, tokenString)
c.Redirect(http.StatusFound, utils.IfOr(state.Action == model.RTypeBind, "/dashboard/profile?oauth2=true", "/dashboard/login?oauth2=true")) c.Redirect(http.StatusFound, utils.IfOr(state.Action == model.RTypeBind, "/dashboard/profile?oauth2=true", "/dashboard/login?oauth2=true"))
return &model.LoginResponse{Token: tokenString, Expire: expire.Format(time.RFC3339)}, nil return nil, errNoop
} }
} }

View File

@ -74,7 +74,7 @@ func updateProfile(c *gin.Context) (any, error) {
} }
var bindCount int64 var bindCount int64
if err := singleton.DB.Where("user_id = ?", auth.(*model.User).ID).Count(&bindCount).Error; err != nil { if err := singleton.DB.Model(&model.Oauth2Bind{}).Where("user_id = ?", auth.(*model.User).ID).Count(&bindCount).Error; err != nil {
return nil, newGormError("%v", err) return nil, newGormError("%v", err)
} }

View File

@ -10,5 +10,5 @@ type ProfileForm struct {
OriginalPassword string `json:"original_password,omitempty"` OriginalPassword string `json:"original_password,omitempty"`
NewUsername string `json:"new_username,omitempty"` NewUsername string `json:"new_username,omitempty"`
NewPassword string `json:"new_password,omitempty"` NewPassword string `json:"new_password,omitempty"`
RejectPassword bool `json:"reject_password,omitempty"` RejectPassword bool `json:"reject_password,omitempty" validate:"optional"`
} }