diff --git a/cmd/dashboard/controller/controller.go b/cmd/dashboard/controller/controller.go index 6454b28..268f125 100644 --- a/cmd/dashboard/controller/controller.go +++ b/cmd/dashboard/controller/controller.go @@ -56,6 +56,8 @@ func routers(r *gin.Engine, frontendDist fs.FS) { } api := r.Group("api/v1") api.POST("/login", authMiddleware.LoginHandler) + api.GET("/oauth2/:provider", commonHandler(oauth2redirect)) + api.POST("/oauth2/:provider/callback", commonHandler(oauth2callback(authMiddleware))) optionalAuth := api.Group("", optionalAuthMiddleware(authMiddleware)) optionalAuth.GET("/ws/server", commonHandler(serverStream)) @@ -79,6 +81,9 @@ func routers(r *gin.Engine, frontendDist fs.FS) { auth.GET("/profile", commonHandler(getProfile)) auth.POST("/profile", commonHandler(updateProfile)) + auth.POST("/oauth2/:provider/bind", commonHandler(bindOauth2)) + auth.POST("/oauth2/:provider/unbind", commonHandler(unbindOauth2)) + auth.GET("/user", adminHandler(listUser)) auth.POST("/user", adminHandler(createUser)) auth.POST("/batch-delete/user", adminHandler(batchDeleteUser)) diff --git a/cmd/dashboard/controller/oauth2.go b/cmd/dashboard/controller/oauth2.go new file mode 100644 index 0000000..5dd6f16 --- /dev/null +++ b/cmd/dashboard/controller/oauth2.go @@ -0,0 +1,240 @@ +package controller + +import ( + "context" + "fmt" + "io" + "strconv" + "strings" + "time" + + jwt "github.com/appleboy/gin-jwt/v2" + "github.com/gin-gonic/gin" + "github.com/patrickmn/go-cache" + "github.com/tidwall/gjson" + "golang.org/x/oauth2" + "gorm.io/gorm" + + "github.com/nezhahq/nezha/model" + "github.com/nezhahq/nezha/pkg/utils" + "github.com/nezhahq/nezha/service/singleton" +) + +type Oauth2LoginType uint8 + +const ( + _ Oauth2LoginType = iota + rTypeLogin + rTypeBind +) + +func getRedirectURL(c *gin.Context, provider string, rType Oauth2LoginType) string { + scheme := "http://" + referer := c.Request.Referer() + if forwardedProto := c.Request.Header.Get("X-Forwarded-Proto"); forwardedProto == "https" || strings.HasPrefix(referer, "https://") { + scheme = "https://" + } + var suffix string + if rType == rTypeLogin { + suffix = "/dashboard/login?provider=" + provider + } else if rType == rTypeBind { + suffix = "/dashboard/profile?provider=" + provider + } + return scheme + c.Request.Host + suffix +} + +// @Summary Get Oauth2 Redirect URL +// @Description Get Oauth2 Redirect URL +// @Produce json +// @Param provider path string true "provider" +// @Param type query int false "type" Enums(1, 2) default(1) +// @Success 200 {object} model.Oauth2LoginResponse +// @Router /api/v1/oauth2/{provider} [get] +func oauth2redirect(c *gin.Context) (*model.Oauth2LoginResponse, error) { + provider := c.Param("provider") + if provider == "" { + return nil, singleton.Localizer.ErrorT("provider is required") + } + + rTypeInt, err := strconv.Atoi(c.Query("type")) + if err != nil { + return nil, err + } + + o2confRaw, has := singleton.Conf.Oauth2[provider] + if !has { + return nil, singleton.Localizer.ErrorT("provider not found") + } + o2conf := o2confRaw.Setup(getRedirectURL(c, provider, Oauth2LoginType(rTypeInt))) + + randomString, err := utils.GenerateRandomString(32) + if err != nil { + return nil, err + } + state, stateKey := randomString[:16], randomString[16:] + singleton.Cache.Set(fmt.Sprintf("%s%s", model.CacheKeyOauth2State, stateKey), state, cache.DefaultExpiration) + + url := o2conf.AuthCodeURL(state, oauth2.AccessTypeOnline) + c.SetCookie("nz-o2s", stateKey, 60*5, "", "", false, false) + + return &model.Oauth2LoginResponse{Redirect: url}, nil +} + +func exchangeOpenId(c *gin.Context, o2confRaw *model.Oauth2Config, provider string, callbackData model.Oauth2Callback) (string, error) { + // 验证登录跳转时的 State + stateKey, err := c.Cookie("nz-o2s") + if err != nil { + return "", singleton.Localizer.ErrorT("invalid state key") + } + state, ok := singleton.Cache.Get(fmt.Sprintf("%s%s", model.CacheKeyOauth2State, stateKey)) + if !ok || state.(string) != callbackData.State { + return "", singleton.Localizer.ErrorT("invalid state key") + } + + o2conf := o2confRaw.Setup(getRedirectURL(c, provider, rTypeLogin)) + + ctx := context.Background() + + otk, err := o2conf.Exchange(ctx, callbackData.Code) + if err != nil { + return "", err + } + oauth2client := o2conf.Client(ctx, otk) + resp, err := oauth2client.Get(o2confRaw.UserInfoURL) + if err != nil { + return "", err + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + + return gjson.Get(string(body), o2confRaw.UserIDPath).String(), nil +} + +// @Summary Oauth2 Callback +// @Description Oauth2 Callback +// @Accept json +// @Produce json +// @Param provider path string true "provider" +// @Param body body model.Oauth2Callback true "body" +// @Success 200 {object} model.LoginResponse +// @Router /api/v1/oauth2/{provider}/callback [post] +func oauth2callback(jwtConfig *jwt.GinJWTMiddleware) func(c *gin.Context) (*model.LoginResponse, error) { + return func(c *gin.Context) (*model.LoginResponse, error) { + provider := c.Param("provider") + if provider == "" { + return nil, singleton.Localizer.ErrorT("provider is required") + } + + o2confRaw, has := singleton.Conf.Oauth2[provider] + if !has { + return nil, singleton.Localizer.ErrorT("provider not found") + } + provider = strings.ToLower(provider) + + var callbackData model.Oauth2Callback + if err := c.ShouldBind(&callbackData); err != nil { + return nil, err + } + + realip := c.GetString(model.CtxKeyRealIPStr) + + if callbackData.Code == "" { + model.BlockIP(singleton.DB, realip, model.WAFBlockReasonTypeBruteForceOauth2, model.BlockIDToken) + return nil, singleton.Localizer.ErrorT("code is required") + } + + openId, err := exchangeOpenId(c, o2confRaw, provider, callbackData) + if err != nil { + model.BlockIP(singleton.DB, realip, model.WAFBlockReasonTypeBruteForceOauth2, model.BlockIDToken) + return nil, err + } + + var bind model.Oauth2Bind + if err := singleton.DB.Where("provider = ? AND open_id = ?", provider, openId).First(&bind).Error; err != nil { + return nil, singleton.Localizer.ErrorT("oauth2 user not binded yet") + } + + tokenString, expire, err := jwtConfig.TokenGenerator(fmt.Sprintf("%d", bind.UserID)) + if err != nil { + return nil, err + } + + jwtConfig.SetCookie(c, tokenString) + + return &model.LoginResponse{Token: tokenString, Expire: expire.Format(time.RFC3339)}, nil + } +} + +// @Summary Bind Oauth2 +// @Description Bind Oauth2 +// @Accept json +// @Produce json +// @Param provider path string true "provider" +// @Param body body model.Oauth2Callback true "body" +// @Success 200 {object} any +// @Router /api/v1/oauth2/{provider}/bind [post] +func bindOauth2(c *gin.Context) (any, error) { + var bindData model.Oauth2Callback + if err := c.ShouldBind(&bindData); err != nil { + return nil, err + } + + provider := c.Param("provider") + o2conf, has := singleton.Conf.Oauth2[provider] + if !has { + return nil, singleton.Localizer.ErrorT("provider not found") + } + provider = strings.ToLower(provider) + + openId, err := exchangeOpenId(c, o2conf, provider, bindData) + if err != nil { + return nil, err + } + + u := c.MustGet(model.CtxKeyAuthorizedUser).(*model.User) + + var bind model.Oauth2Bind + result := singleton.DB.Where("provider = ? AND open_id = ?", provider, openId).Limit(1).Find(&bind) + if result.Error != nil && result.Error != gorm.ErrRecordNotFound { + return nil, newGormError("%v", result.Error) + } + bind.UserID = u.ID + bind.Provider = provider + bind.OpenID = openId + if result.Error == gorm.ErrRecordNotFound { + result = singleton.DB.Create(&bind) + } else { + result = singleton.DB.Save(&bind) + } + if result.Error != nil { + return nil, newGormError("%v", result.Error) + } + return nil, nil +} + +// @Summary Unbind Oauth2 +// @Description Unbind Oauth2 +// @Accept json +// @Produce json +// @Param provider path string true "provider" +// @Success 200 {object} any +// @Router /api/v1/oauth2/{provider}/unbind [post] +func unbindOauth2(c *gin.Context) (any, error) { + provider := c.Param("provider") + if provider == "" { + return nil, singleton.Localizer.ErrorT("provider is required") + } + _, has := singleton.Conf.Oauth2[provider] + if !has { + return nil, singleton.Localizer.ErrorT("provider not found") + } + provider = strings.ToLower(provider) + u := c.MustGet(model.CtxKeyAuthorizedUser).(*model.User) + if err := singleton.DB.Where("provider = ? AND user_id = ?", provider, u.ID).Delete(&model.Oauth2Bind{}).Error; err != nil { + return nil, newGormError("%v", err) + } + return nil, nil +} diff --git a/cmd/dashboard/controller/setting.go b/cmd/dashboard/controller/setting.go index bcb894a..0a03787 100644 --- a/cmd/dashboard/controller/setting.go +++ b/cmd/dashboard/controller/setting.go @@ -17,9 +17,9 @@ import ( // @Security BearerAuth // @Tags common // @Produce json -// @Success 200 {object} model.CommonResponse[model.SettingResponse] +// @Success 200 {object} model.CommonResponse[model.SettingResponse[model.Config]] // @Router /setting [get] -func listConfig(c *gin.Context) (model.SettingResponse, error) { +func listConfig(c *gin.Context) (model.SettingResponse[any], error) { u, authorized := c.Get(model.CtxKeyAuthorizedUser) var isAdmin bool if authorized { @@ -27,30 +27,32 @@ func listConfig(c *gin.Context) (model.SettingResponse, error) { isAdmin = user.Role == model.RoleAdmin } - conf := model.SettingResponse{ - Config: *singleton.Conf, + config := *singleton.Conf + config.Language = strings.Replace(config.Language, "_", "-", -1) + + conf := model.SettingResponse[any]{ + Config: config, Version: singleton.Version, FrontendTemplates: singleton.FrontendTemplates, } if !authorized || !isAdmin { - conf = model.SettingResponse{ - Config: model.Config{ - SiteName: conf.SiteName, - Language: conf.Language, - CustomCode: conf.CustomCode, - CustomCodeDashboard: conf.CustomCodeDashboard, - }, + configForGuests := model.ConfigForGuests{ + Language: config.Language, + SiteName: config.SiteName, + CustomCode: config.CustomCode, + CustomCodeDashboard: config.CustomCodeDashboard, + Oauth2Providers: config.Oauth2Providers, + } + if authorized { + config.TLS = singleton.Conf.TLS + config.InstallHost = singleton.Conf.InstallHost + } + conf = model.SettingResponse[any]{ + Config: configForGuests, } } - if !isAdmin { - conf.Config.TLS = singleton.Conf.TLS - conf.Config.InstallHost = singleton.Conf.InstallHost - } - - conf.Config.Language = strings.Replace(conf.Config.Language, "_", "-", -1) - return conf, nil } diff --git a/cmd/dashboard/controller/user.go b/cmd/dashboard/controller/user.go index 31a07dd..1a0f7ae 100644 --- a/cmd/dashboard/controller/user.go +++ b/cmd/dashboard/controller/user.go @@ -26,9 +26,18 @@ func getProfile(c *gin.Context) (*model.Profile, error) { if !ok { return nil, singleton.Localizer.ErrorT("unauthorized") } + var ob []model.Oauth2Bind + if err := singleton.DB.Where("user_id = ?", auth.(*model.User).ID).Find(&ob).Error; err != nil { + return nil, newGormError("%v", err) + } + var obMap = make(map[string]string) + for _, v := range ob { + obMap[v.Provider] = v.OpenID + } return &model.Profile{ - User: *auth.(*model.User), - LoginIP: c.GetString(model.CtxKeyRealIPStr), + User: *auth.(*model.User), + LoginIP: c.GetString(model.CtxKeyRealIPStr), + Oauth2Bind: obMap, }, nil } diff --git a/go.mod b/go.mod index f0c6f0e..1df547b 100644 --- a/go.mod +++ b/go.mod @@ -33,6 +33,7 @@ require ( golang.org/x/crypto v0.31.0 golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67 golang.org/x/net v0.33.0 + golang.org/x/oauth2 v0.23.0 golang.org/x/sync v0.10.0 google.golang.org/grpc v1.69.2 google.golang.org/protobuf v1.36.0 diff --git a/go.sum b/go.sum index 8c35b05..a103bcc 100644 --- a/go.sum +++ b/go.sum @@ -198,6 +198,8 @@ golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= +golang.org/x/oauth2 v0.23.0 h1:PbgcYx2W7i4LvjJWEbf0ngHV6qJYr86PkAV3bXdLEbs= +golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= diff --git a/model/api.go b/model/api.go index dbcea5b..926f127 100644 --- a/model/api.go +++ b/model/api.go @@ -4,6 +4,15 @@ const ( ApiErrorUnauthorized = 10001 ) +type Oauth2LoginResponse struct { + Redirect string `json:"redirect,omitempty"` +} + +type Oauth2Callback struct { + State string `json:"state,omitempty"` + Code string `json:"code,omitempty"` +} + type LoginRequest struct { Username string `json:"username,omitempty"` Password string `json:"password,omitempty"` diff --git a/model/common.go b/model/common.go index 47f1e91..0cb6bf0 100644 --- a/model/common.go +++ b/model/common.go @@ -16,6 +16,10 @@ const ( CtxKeyRealIPStr = "ckri" ) +const ( + CacheKeyOauth2State = "cko2s::" +) + type CtxKeyRealIP struct{} type CtxKeyConnectingIP struct{} diff --git a/model/config.go b/model/config.go index 6fcda34..4072f24 100644 --- a/model/config.go +++ b/model/config.go @@ -1,8 +1,10 @@ package model import ( + "maps" "os" "path/filepath" + "slices" "strconv" "strings" @@ -21,6 +23,17 @@ const ( ConfigCoverIgnoreAll ) +type ConfigForGuests struct { + Language string `json:"language"` + SiteName string `json:"site_name"` + CustomCode string `json:"custom_code,omitempty"` + CustomCodeDashboard string `json:"custom_code_dashboard,omitempty"` + Oauth2Providers []string `json:"oauth2_providers,omitempty"` + + InstallHost string `json:"install_host,omitempty"` + TLS bool `json:"tls,omitempty"` +} + type Config struct { Debug bool `mapstructure:"debug" json:"debug,omitempty"` // debug模式开关 RealIPHeader string `mapstructure:"real_ip_header" json:"real_ip_header,omitempty"` // 真实IP @@ -52,6 +65,11 @@ type Config struct { CustomCode string `mapstructure:"custom_code" json:"custom_code,omitempty"` CustomCodeDashboard string `mapstructure:"custom_code_dashboard" json:"custom_code_dashboard,omitempty"` + // oauth2 配置 + Oauth2 map[string]*Oauth2Config `mapstructure:"oauth2" json:"oauth2,omitempty"` + // oauth2 供应商列表,无需配置,自动生成 + Oauth2Providers []string `yaml:"-" json:"oauth2_providers,omitempty"` + k *koanf.Koanf `json:"-"` filePath string `json:"-"` } @@ -132,6 +150,8 @@ func (c *Config) Read(path string, frontendTemplates []FrontendTemplate) error { } } + c.Oauth2Providers = slices.Collect(maps.Keys(c.Oauth2)) + c.updateIgnoredIPNotificationID() return nil } diff --git a/model/oauth2bind.go b/model/oauth2bind.go new file mode 100644 index 0000000..3631510 --- /dev/null +++ b/model/oauth2bind.go @@ -0,0 +1,9 @@ +package model + +type Oauth2Bind struct { + Common + + UserID uint64 `gorm:"uniqueIndex:u_p_o" json:"user_id,omitempty"` + Provider string `gorm:"uniqueIndex:u_p_o" json:"provider,omitempty"` + OpenID string `gorm:"uniqueIndex:u_p_o" json:"open_id,omitempty"` +} diff --git a/model/oauth2config.go b/model/oauth2config.go new file mode 100644 index 0000000..49b88bf --- /dev/null +++ b/model/oauth2config.go @@ -0,0 +1,33 @@ +package model + +import ( + "golang.org/x/oauth2" +) + +type Oauth2Config struct { + ClientID string `mapstructure:"client_id" json:"client_id,omitempty"` + ClientSecret string `mapstructure:"client_secret" json:"client_secret,omitempty"` + Endpoint Oauth2Endpoint `mapstructure:"endpoint" json:"endpoint,omitempty"` + Scopes []string `mapstructure:"scopes" json:"scopes,omitempty"` + + UserInfoURL string `mapstructure:"user_info_url" json:"user_info_url,omitempty"` + UserIDPath string `mapstructure:"user_id_path" json:"user_id_path,omitempty"` +} + +type Oauth2Endpoint struct { + AuthURL string `mapstructure:"auth_url" json:"auth_url,omitempty"` + TokenURL string `mapstructure:"token_url" json:"token_url,omitempty"` +} + +func (c *Oauth2Config) Setup(redirectURL string) *oauth2.Config { + return &oauth2.Config{ + ClientID: c.ClientID, + ClientSecret: c.ClientSecret, + Endpoint: oauth2.Endpoint{ + AuthURL: c.Endpoint.AuthURL, + TokenURL: c.Endpoint.TokenURL, + }, + RedirectURL: redirectURL, + Scopes: c.Scopes, + } +} diff --git a/model/setting_api.go b/model/setting_api.go index d81efc5..bb445d7 100644 --- a/model/setting_api.go +++ b/model/setting_api.go @@ -28,8 +28,8 @@ type FrontendTemplate struct { IsOfficial bool `json:"is_official,omitempty"` } -type SettingResponse struct { - Config +type SettingResponse[T any] struct { + Config T `json:"config,omitempty"` Version string `json:"version,omitempty"` FrontendTemplates []FrontendTemplate `json:"frontend_templates,omitempty"` diff --git a/model/user.go b/model/user.go index bff8d54..24c1613 100644 --- a/model/user.go +++ b/model/user.go @@ -42,7 +42,8 @@ func (u *User) BeforeSave(tx *gorm.DB) error { type Profile struct { User - LoginIP string `json:"login_ip,omitempty"` + LoginIP string `json:"login_ip,omitempty"` + Oauth2Bind map[string]string `json:"oauth2_bind,omitempty"` } type OnlineUser struct { diff --git a/model/waf.go b/model/waf.go index 2567c6a..2983eca 100644 --- a/model/waf.go +++ b/model/waf.go @@ -15,6 +15,7 @@ const ( WAFBlockReasonTypeBruteForceToken WAFBlockReasonTypeAgentAuthFail WAFBlockReasonTypeManual + WAFBlockReasonTypeBruteForceOauth2 ) const ( diff --git a/service/singleton/singleton.go b/service/singleton/singleton.go index c284e91..c46471b 100644 --- a/service/singleton/singleton.go +++ b/service/singleton/singleton.go @@ -82,7 +82,7 @@ func InitDBFromPath(path string) { model.Notification{}, model.AlertRule{}, model.Service{}, model.NotificationGroupNotification{}, model.ServiceHistory{}, model.Cron{}, model.Transfer{}, model.ServerGroupServer{}, model.NAT{}, model.DDNSProfile{}, model.NotificationGroupNotification{}, - model.WAF{}) + model.WAF{}, model.Oauth2Bind{}) if err != nil { panic(err) }