Merge branch 'master' of github.com:nezhahq/nezha

This commit is contained in:
naiba 2024-12-31 23:58:49 +08:00
commit 126773974c
5 changed files with 16 additions and 12 deletions

View File

@ -198,6 +198,8 @@ func (we *wsError) Error() string {
return fmt.Sprintf(we.msg, we.a...)
}
var errNoop = errors.New("wrote")
func commonHandler[T any](handler handlerFunc[T]) func(*gin.Context) {
return func(c *gin.Context) {
handle(c, handler)
@ -240,7 +242,9 @@ func handle[T any](c *gin.Context, handler handlerFunc[T]) {
}
return
default:
c.JSON(http.StatusOK, newErrorResponse(err))
if !errors.Is(err, errNoop) {
c.JSON(http.StatusOK, newErrorResponse(err))
}
return
}
}

View File

@ -90,7 +90,7 @@ func authenticator() func(c *gin.Context) (interface{}, error) {
var user model.User
realip := c.GetString(model.CtxKeyRealIPStr)
if err := singleton.DB.Select("id", "password").Where("username = ?", loginVals.Username).First(&user).Error; err != nil {
if err := singleton.DB.Select("id", "password", "reject_password").Where("username = ?", loginVals.Username).First(&user).Error; err != nil {
if err == gorm.ErrRecordNotFound {
model.BlockIP(singleton.DB, realip, model.WAFBlockReasonTypeLoginFail, model.BlockIDUnknownUser)
}

View File

@ -7,7 +7,6 @@ import (
"net/http"
"strconv"
"strings"
"time"
jwt "github.com/appleboy/gin-jwt/v2"
"github.com/gin-gonic/gin"
@ -114,10 +113,10 @@ func unbindOauth2(c *gin.Context) (any, error) {
// @Produce json
// @Param state query string true "state"
// @Param code query string true "code"
// @Success 200 {object} model.LoginResponse
// @Success 200 {object} model.CommonResponse[any]
// @Router /api/v1/oauth2/callback [get]
func oauth2callback(jwtConfig *jwt.GinJWTMiddleware) func(c *gin.Context) (*model.LoginResponse, error) {
return func(c *gin.Context) (*model.LoginResponse, error) {
func oauth2callback(jwtConfig *jwt.GinJWTMiddleware) func(c *gin.Context) (any, error) {
return func(c *gin.Context) (any, error) {
callbackData := &model.Oauth2Callback{
State: c.Query("state"),
Code: c.Query("code"),
@ -146,6 +145,7 @@ func oauth2callback(jwtConfig *jwt.GinJWTMiddleware) func(c *gin.Context) (*mode
}
var bind model.Oauth2Bind
state.Provider = strings.ToLower(state.Provider)
switch state.Action {
case model.RTypeBind:
u, authorized := c.Get(model.CtxKeyAuthorizedUser)
@ -154,7 +154,7 @@ func oauth2callback(jwtConfig *jwt.GinJWTMiddleware) func(c *gin.Context) (*mode
}
user := u.(*model.User)
result := singleton.DB.Where("provider = ? AND open_id = ?", strings.ToLower(state.Provider), openId).Limit(1).Find(&bind)
result := singleton.DB.Where("provider = ? AND open_id = ?", state.Provider, openId).Limit(1).Find(&bind)
if result.Error != nil && result.Error != gorm.ErrRecordNotFound {
return nil, newGormError("%v", result.Error)
}
@ -171,12 +171,12 @@ func oauth2callback(jwtConfig *jwt.GinJWTMiddleware) func(c *gin.Context) (*mode
return nil, newGormError("%v", result.Error)
}
default:
if err := singleton.DB.Where("provider = ? AND open_id = ?", strings.ToLower(state.Provider), openId).First(&bind).Error; err != nil {
if err := singleton.DB.Where("provider = ? AND open_id = ?", state.Provider, openId).First(&bind).Error; err != nil {
return nil, singleton.Localizer.ErrorT("oauth2 user not binded yet")
}
}
tokenString, expire, err := jwtConfig.TokenGenerator(fmt.Sprintf("%d", bind.UserID))
tokenString, _, err := jwtConfig.TokenGenerator(fmt.Sprintf("%d", bind.UserID))
if err != nil {
return nil, err
}
@ -184,7 +184,7 @@ func oauth2callback(jwtConfig *jwt.GinJWTMiddleware) func(c *gin.Context) (*mode
jwtConfig.SetCookie(c, tokenString)
c.Redirect(http.StatusFound, utils.IfOr(state.Action == model.RTypeBind, "/dashboard/profile?oauth2=true", "/dashboard/login?oauth2=true"))
return &model.LoginResponse{Token: tokenString, Expire: expire.Format(time.RFC3339)}, nil
return nil, errNoop
}
}

View File

@ -74,7 +74,7 @@ func updateProfile(c *gin.Context) (any, error) {
}
var bindCount int64
if err := singleton.DB.Where("user_id = ?", auth.(*model.User).ID).Count(&bindCount).Error; err != nil {
if err := singleton.DB.Model(&model.Oauth2Bind{}).Where("user_id = ?", auth.(*model.User).ID).Count(&bindCount).Error; err != nil {
return nil, newGormError("%v", err)
}

View File

@ -10,5 +10,5 @@ type ProfileForm struct {
OriginalPassword string `json:"original_password,omitempty"`
NewUsername string `json:"new_username,omitempty"`
NewPassword string `json:"new_password,omitempty"`
RejectPassword bool `json:"reject_password,omitempty"`
RejectPassword bool `json:"reject_password,omitempty" validate:"optional"`
}