diff --git a/cmd/dashboard/controller/oauth2.go b/cmd/dashboard/controller/oauth2.go index d36e077..1c4eac6 100644 --- a/cmd/dashboard/controller/oauth2.go +++ b/cmd/dashboard/controller/oauth2.go @@ -2,9 +2,13 @@ package controller import ( "context" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" "fmt" "io" "net/http" + "sort" "strconv" "strings" @@ -52,7 +56,10 @@ func oauth2redirect(c *gin.Context) (*model.Oauth2LoginResponse, error) { return nil, singleton.Localizer.ErrorT("provider not found") } o2conf := o2confRaw.Setup(getRedirectURL(c)) - + if provider == "Telegram" { + // 直接返回配置中的 AuthURL + return &model.Oauth2LoginResponse{Redirect: o2confRaw.Endpoint.AuthURL}, nil + } randomString, err := utils.GenerateRandomString(32) if err != nil { return nil, err @@ -117,6 +124,69 @@ func unbindOauth2(c *gin.Context) (any, error) { // @Router /api/v1/oauth2/callback [get] func oauth2callback(jwtConfig *jwt.GinJWTMiddleware) func(c *gin.Context) (any, error) { return func(c *gin.Context) (any, error) { + // 通过判断请求参数来确定是否是 Telegram 回调 + if c.Query("id") != "" && c.Query("auth_date") != "" && c.Query("hash") != "" { + queryParams := make(map[string]string) + for k, v := range c.Request.URL.Query() { + if len(v) > 0 { + queryParams[k] = v[0] + } + } + o2confRaw, has := singleton.Conf.Oauth2["Telegram"] + if !has { + return nil, singleton.Localizer.ErrorT("provider not found") + } + + // 验证 Telegram Hash数据 + if valid, err := verifyTelegramAuth(queryParams, o2confRaw.ClientID); err != nil { + return nil, err + } else if !valid { + return nil, singleton.Localizer.ErrorT("invalid Telegram auth data") + } + + var bind model.Oauth2Bind + provider := "telegram" + openId := queryParams["id"] + + u, authorized := c.Get(model.CtxKeyAuthorizedUser) + if authorized { + user := u.(*model.User) + result := singleton.DB.Where("provider = ? AND open_id = ?", provider, openId).Limit(1).Find(&bind) + if result.Error != nil && result.Error != gorm.ErrRecordNotFound { + return nil, newGormError("%v", result.Error) + } + bind.UserID = user.ID + bind.Provider = provider + bind.OpenID = openId + + if result.Error == gorm.ErrRecordNotFound { + result = singleton.DB.Create(&bind) + } else { + result = singleton.DB.Save(&bind) + } + if result.Error != nil { + return nil, newGormError("%v", result.Error) + } + + c.Redirect(http.StatusFound, "/dashboard/profile?oauth2=true") + } else { + if err := singleton.DB.Where("provider = ? AND open_id = ?", provider, openId).First(&bind).Error; err != nil { + return nil, singleton.Localizer.ErrorT("oauth2 user not binded yet") + } + + tokenString, _, err := jwtConfig.TokenGenerator(fmt.Sprintf("%d", bind.UserID)) + if err != nil { + return nil, err + } + + jwtConfig.SetCookie(c, tokenString) + c.Redirect(http.StatusFound, "/dashboard/login?oauth2=true") + } + + return nil, errNoop + } + + // 其他 OAuth2 提供商的原有逻辑 callbackData := &model.Oauth2Callback{ State: c.Query("state"), Code: c.Query("code"), @@ -188,7 +258,68 @@ func oauth2callback(jwtConfig *jwt.GinJWTMiddleware) func(c *gin.Context) (any, } } +func verifyTelegramAuth(data map[string]string, botToken string) (bool, error) { + // 只保留需要验证的字段 + requiredFields := []string{"id", "first_name", "last_name", "username", "photo_url", "auth_date"} + checkData := make(map[string]string) + + // 只复制需要的字段 + for _, field := range requiredFields { + if value, exists := data[field]; exists { + checkData[field] = value + } + } + + var dataCheckString string + keys := make([]string, 0, len(checkData)) + for k := range checkData { + if k != "hash" { + keys = append(keys, k) + } + } + sort.Strings(keys) + + for _, k := range keys { + if len(dataCheckString) > 0 { + dataCheckString += "\n" + } + dataCheckString += fmt.Sprintf("%s=%s", k, checkData[k]) + } + + // 先对 bot token 进行 SHA256 哈希作为密钥 + secretKeyHash := sha256.Sum256([]byte(botToken)) + + // 使用哈希后的密钥计算 HMAC + h := hmac.New(sha256.New, secretKeyHash[:]) + h.Write([]byte(dataCheckString)) + hash := hex.EncodeToString(h.Sum(nil)) + + return hash == data["hash"], nil +} + func exchangeOpenId(c *gin.Context, o2confRaw *model.Oauth2Config, callbackData *model.Oauth2Callback) (string, error) { + // 处理Telegram Widget OAuth + if strings.ToLower(c.Param("provider")) == "telegram" { + // 解析查询参数 + queryParams := make(map[string]string) + for k, v := range c.Request.URL.Query() { + if len(v) > 0 { + queryParams[k] = v[0] + } + } + + // 验证数据 + if valid, err := verifyTelegramAuth(queryParams, o2confRaw.ClientID); err != nil { + return "", err + } else if !valid { + return "", singleton.Localizer.ErrorT("invalid Telegram auth data") + } + + // 返回Telegram用户ID + return queryParams["id"], nil + } + + // 原有OAuth2处理逻辑 o2conf := o2confRaw.Setup(getRedirectURL(c)) ctx := context.Background()