nezha/cmd/dashboard/controller/oauth2.go

298 lines
9.8 KiB
Go
Raw Normal View History

2019-12-08 03:59:58 -05:00
package controller
import (
"context"
"encoding/json"
"errors"
2019-12-08 03:59:58 -05:00
"fmt"
"net/http"
2022-10-13 13:31:00 -04:00
"net/url"
2020-11-29 09:17:40 -05:00
"strings"
2022-12-16 10:34:14 -05:00
"time"
2019-12-08 03:59:58 -05:00
"github.com/coreos/go-oidc/v3/oidc"
2024-06-23 04:15:08 -04:00
"github.com/naiba/nezha/pkg/oidc/cloudflare"
myOidc "github.com/naiba/nezha/pkg/oidc/general"
2024-06-23 04:15:08 -04:00
"code.gitea.io/sdk/gitea"
2019-12-08 03:59:58 -05:00
"github.com/gin-gonic/gin"
GitHubAPI "github.com/google/go-github/v47/github"
"github.com/naiba/nezha/model"
"github.com/naiba/nezha/pkg/mygin"
"github.com/naiba/nezha/pkg/utils"
"github.com/naiba/nezha/service/singleton"
"github.com/patrickmn/go-cache"
"github.com/xanzy/go-gitlab"
2019-12-08 03:59:58 -05:00
"golang.org/x/oauth2"
GitHubOauth2 "golang.org/x/oauth2/github"
2022-09-16 12:08:27 -04:00
GitlabOauth2 "golang.org/x/oauth2/gitlab"
2019-12-08 03:59:58 -05:00
)
type oauth2controller struct {
r gin.IRoutes
oidcProvider *oidc.Provider
2019-12-08 03:59:58 -05:00
}
func (oa *oauth2controller) serve() {
oa.r.GET("/oauth2/login", oa.login)
oa.r.GET("/oauth2/callback", oa.callback)
}
func (oa *oauth2controller) getCommonOauth2Config(c *gin.Context) *oauth2.Config {
2022-01-08 22:54:14 -05:00
if singleton.Conf.Oauth2.Type == model.ConfigTypeGitee {
return &oauth2.Config{
2022-01-08 22:54:14 -05:00
ClientID: singleton.Conf.Oauth2.ClientID,
ClientSecret: singleton.Conf.Oauth2.ClientSecret,
Scopes: []string{},
Endpoint: oauth2.Endpoint{
AuthURL: "https://gitee.com/oauth/authorize",
TokenURL: "https://gitee.com/oauth/token",
},
RedirectURL: oa.getRedirectURL(c),
}
} else if singleton.Conf.Oauth2.Type == model.ConfigTypeGitlab {
return &oauth2.Config{
ClientID: singleton.Conf.Oauth2.ClientID,
ClientSecret: singleton.Conf.Oauth2.ClientSecret,
Scopes: []string{"read_user", "read_api"},
Endpoint: GitlabOauth2.Endpoint,
RedirectURL: oa.getRedirectURL(c),
}
} else if singleton.Conf.Oauth2.Type == model.ConfigTypeJihulab {
return &oauth2.Config{
ClientID: singleton.Conf.Oauth2.ClientID,
ClientSecret: singleton.Conf.Oauth2.ClientSecret,
Scopes: []string{"read_user", "read_api"},
Endpoint: oauth2.Endpoint{
AuthURL: "https://jihulab.com/oauth/authorize",
TokenURL: "https://jihulab.com/oauth/token",
},
RedirectURL: oa.getRedirectURL(c),
}
} else if singleton.Conf.Oauth2.Type == model.ConfigTypeGitea {
return &oauth2.Config{
ClientID: singleton.Conf.Oauth2.ClientID,
ClientSecret: singleton.Conf.Oauth2.ClientSecret,
Endpoint: oauth2.Endpoint{
AuthURL: fmt.Sprintf("%s/login/oauth/authorize", singleton.Conf.Oauth2.Endpoint),
TokenURL: fmt.Sprintf("%s/login/oauth/access_token", singleton.Conf.Oauth2.Endpoint),
},
RedirectURL: oa.getRedirectURL(c),
}
} else if singleton.Conf.Oauth2.Type == model.ConfigTypeCloudflare {
return &oauth2.Config{
ClientID: singleton.Conf.Oauth2.ClientID,
ClientSecret: singleton.Conf.Oauth2.ClientSecret,
Scopes: []string{"openid", "email", "profile", "groups"},
Endpoint: oauth2.Endpoint{
AuthURL: fmt.Sprintf("%s/cdn-cgi/access/sso/oidc/%s/authorization", singleton.Conf.Oauth2.Endpoint, singleton.Conf.Oauth2.ClientID),
TokenURL: fmt.Sprintf("%s/cdn-cgi/access/sso/oidc/%s/token", singleton.Conf.Oauth2.Endpoint, singleton.Conf.Oauth2.ClientID),
},
RedirectURL: oa.getRedirectURL(c),
}
} else if singleton.Conf.Oauth2.Type == model.ConfigTypeOidc {
var err error
oa.oidcProvider, err = oidc.NewProvider(c.Request.Context(), singleton.Conf.Oauth2.OidcIssuer)
if err != nil {
mygin.ShowErrorPage(c, mygin.ErrInfo{
Code: http.StatusBadRequest,
Title: fmt.Sprintf("Cannot get OIDC infomaion from issuer from %s", singleton.Conf.Oauth2.OidcIssuer),
Msg: err.Error(),
}, true)
return nil
}
scopes := strings.Split(singleton.Conf.Oauth2.OidcScopes, ",")
scopes = append(scopes, oidc.ScopeOpenID)
uniqueScopes := removeDuplicates(scopes)
return &oauth2.Config{
ClientID: singleton.Conf.Oauth2.ClientID,
ClientSecret: singleton.Conf.Oauth2.ClientSecret,
Scopes: uniqueScopes,
Endpoint: oa.oidcProvider.Endpoint(),
RedirectURL: oa.getRedirectURL(c),
}
} else {
return &oauth2.Config{
2022-01-08 22:54:14 -05:00
ClientID: singleton.Conf.Oauth2.ClientID,
ClientSecret: singleton.Conf.Oauth2.ClientSecret,
Scopes: []string{},
Endpoint: GitHubOauth2.Endpoint,
}
}
}
func (oa *oauth2controller) getRedirectURL(c *gin.Context) string {
2022-04-10 08:28:22 -04:00
scheme := "http://"
referer := c.Request.Referer()
if forwardedProto := c.Request.Header.Get("X-Forwarded-Proto"); forwardedProto == "https" || strings.HasPrefix(referer, "https://") {
2022-04-10 08:28:22 -04:00
scheme = "https://"
}
2022-04-10 08:28:22 -04:00
return scheme + c.Request.Host + "/oauth2/callback"
}
2019-12-08 03:59:58 -05:00
func (oa *oauth2controller) login(c *gin.Context) {
2022-12-16 10:34:14 -05:00
randomString, err := utils.GenerateRandomString(32)
if err != nil {
mygin.ShowErrorPage(c, mygin.ErrInfo{
Code: http.StatusBadRequest,
Title: "Something Wrong",
Msg: err.Error(),
}, true)
return
}
state, stateKey := randomString[:16], randomString[16:]
singleton.Cache.Set(fmt.Sprintf("%s%s", model.CacheKeyOauth2State, stateKey), state, cache.DefaultExpiration)
url := oa.getCommonOauth2Config(c).AuthCodeURL(state, oauth2.AccessTypeOnline)
c.SetCookie(singleton.Conf.Site.CookieName+"-sk", stateKey, 60*5, "", "", false, false)
2022-06-02 21:45:11 -04:00
c.HTML(http.StatusOK, "dashboard-"+singleton.Conf.Site.DashboardTheme+"/redirect", mygin.CommonEnvironment(c, gin.H{
"URL": url,
2022-04-30 09:39:25 -04:00
}))
2019-12-08 03:59:58 -05:00
}
func (oa *oauth2controller) callback(c *gin.Context) {
var err error
2019-12-08 03:59:58 -05:00
// 验证登录跳转时的 State
stateKey, err := c.Cookie(singleton.Conf.Site.CookieName + "-sk")
if err == nil {
state, ok := singleton.Cache.Get(fmt.Sprintf("%s%s", model.CacheKeyOauth2State, stateKey))
if !ok || state.(string) != c.Query("state") {
err = errors.New("非法的登录方式")
}
2019-12-08 03:59:58 -05:00
}
oauth2Config := oa.getCommonOauth2Config(c)
2019-12-08 03:59:58 -05:00
ctx := context.Background()
var otk *oauth2.Token
if err == nil {
otk, err = oauth2Config.Exchange(ctx, c.Query("code"))
}
var user model.User
if err == nil {
if singleton.Conf.Oauth2.Type == model.ConfigTypeGitlab || singleton.Conf.Oauth2.Type == model.ConfigTypeJihulab {
var gitlabApiClient *gitlab.Client
if singleton.Conf.Oauth2.Type == model.ConfigTypeGitlab {
gitlabApiClient, err = gitlab.NewOAuthClient(otk.AccessToken)
} else {
gitlabApiClient, err = gitlab.NewOAuthClient(otk.AccessToken, gitlab.WithBaseURL("https://jihulab.com/api/v4/"))
}
var u *gitlab.User
if err == nil {
u, _, err = gitlabApiClient.Users.CurrentUser()
}
if err == nil {
user = model.NewUserFromGitlab(u)
}
} else if singleton.Conf.Oauth2.Type == model.ConfigTypeGitea {
var giteaApiClient *gitea.Client
giteaApiClient, err = gitea.NewClient(singleton.Conf.Oauth2.Endpoint, gitea.SetToken(otk.AccessToken))
var u *gitea.User
if err == nil {
u, _, err = giteaApiClient.GetMyUserInfo()
}
if err == nil {
user = model.NewUserFromGitea(u)
}
} else if singleton.Conf.Oauth2.Type == model.ConfigTypeCloudflare {
client := oauth2Config.Client(context.Background(), otk)
resp, err := client.Get(fmt.Sprintf("%s/cdn-cgi/access/sso/oidc/%s/userinfo", singleton.Conf.Oauth2.Endpoint, singleton.Conf.Oauth2.ClientID))
if err == nil {
defer resp.Body.Close()
var cloudflareUserInfo *cloudflare.UserInfo
if err := json.NewDecoder(resp.Body).Decode(&cloudflareUserInfo); err == nil {
user = cloudflareUserInfo.MapToNezhaUser()
}
}
} else if singleton.Conf.Oauth2.Type == model.ConfigTypeOidc {
userInfo, err := oa.oidcProvider.UserInfo(c.Request.Context(), oauth2.StaticTokenSource(otk))
if err == nil {
loginClaim := singleton.Conf.Oauth2.OidcLoginClaim
groupClain := singleton.Conf.Oauth2.OidcGroupClaim
adminGroups := strings.Split(singleton.Conf.Oauth2.AdminGroups, ",")
autoCreate := singleton.Conf.Oauth2.OidcAutoCreate
var oidceUserInfo *myOidc.UserInfo
if err := userInfo.Claims(&oidceUserInfo); err == nil {
user = oidceUserInfo.MapToNezhaUser(loginClaim, groupClain, adminGroups, autoCreate)
}
}
} else {
var client *GitHubAPI.Client
oc := oauth2Config.Client(ctx, otk)
if singleton.Conf.Oauth2.Type == model.ConfigTypeGitee {
2022-10-13 13:31:00 -04:00
baseURL, _ := url.Parse("https://gitee.com/api/v5/")
uploadURL, _ := url.Parse("https://gitee.com/api/v5/uploads/")
client = GitHubAPI.NewClient(oc)
client.BaseURL = baseURL
client.UploadURL = uploadURL
} else {
client = GitHubAPI.NewClient(oc)
}
var gu *GitHubAPI.User
2024-06-23 04:15:08 -04:00
gu, _, err = client.Users.Get(ctx, "")
if err == nil {
user = model.NewUserFromGitHub(gu)
}
}
}
if err == nil && user.Login == "" {
err = errors.New("获取用户信息失败")
2019-12-08 03:59:58 -05:00
}
if err != nil || user.Login == "" {
2019-12-08 03:59:58 -05:00
mygin.ShowErrorPage(c, mygin.ErrInfo{
Code: http.StatusBadRequest,
Title: "登录失败",
Msg: fmt.Sprintf("错误信息:%s", err),
}, true)
return
}
2020-03-22 10:28:25 -04:00
var isAdmin bool
if user.SuperAdmin {
isAdmin = true
} else {
for _, admin := range strings.Split(singleton.Conf.Oauth2.Admin, ",") {
if admin != "" && strings.EqualFold(user.Login, admin) {
isAdmin = true
break
}
2020-03-22 10:28:25 -04:00
}
}
if !isAdmin {
2019-12-08 03:59:58 -05:00
mygin.ShowErrorPage(c, mygin.ErrInfo{
Code: http.StatusBadRequest,
Title: "登录失败",
Msg: fmt.Sprintf("错误信息:%s", "该用户不是本站点管理员,无法登录"),
}, true)
return
}
2022-12-16 10:34:14 -05:00
user.Token, err = utils.GenerateRandomString(32)
if err != nil {
mygin.ShowErrorPage(c, mygin.ErrInfo{
Code: http.StatusBadRequest,
Title: "Something wrong",
Msg: err.Error(),
}, true)
return
}
user.TokenExpired = time.Now().AddDate(0, 2, 0)
2022-01-08 22:54:14 -05:00
singleton.DB.Save(&user)
c.SetCookie(singleton.Conf.Site.CookieName, user.Token, 60*60*24, "", "", false, false)
2022-06-02 21:45:11 -04:00
c.HTML(http.StatusOK, "dashboard-"+singleton.Conf.Site.DashboardTheme+"/redirect", mygin.CommonEnvironment(c, gin.H{
"URL": "/",
2022-04-30 09:39:25 -04:00
}))
2019-12-08 03:59:58 -05:00
}
func removeDuplicates(elements []string) []string {
encountered := map[string]bool{}
result := []string{}
for _, v := range elements {
if !encountered[v] {
encountered[v] = true
result = append(result, v)
}
}
return result
}