refactor agent auth & server api

This commit is contained in:
naiba 2024-10-20 23:23:04 +08:00
parent d3f907b5c3
commit aa20c97312
19 changed files with 488 additions and 330 deletions

View File

@ -8,6 +8,10 @@ on:
- "go.sum"
- "resource/**"
- ".github/workflows/test.yml"
pull_request:
branches:
- master
jobs:
tests:

View File

@ -72,7 +72,7 @@ func (v *apiV1) serverDetails(c *gin.Context) {
}
tag := c.Query("tag")
if tag != "" {
c.JSON(200, singleton.ServerAPI.GetStatusByTag(tag))
// c.JSON(200, singleton.ServerAPI.GetStatusByTag(tag))
return
}
if len(idList) != 0 {

View File

@ -10,7 +10,6 @@ import (
"github.com/gorilla/websocket"
"github.com/hashicorp/go-uuid"
"github.com/jinzhu/copier"
"golang.org/x/sync/singleflight"
"github.com/naiba/nezha/model"
"github.com/naiba/nezha/pkg/utils"
@ -21,8 +20,7 @@ import (
)
type commonPage struct {
r *gin.Engine
requestGroup singleflight.Group
r *gin.Engine
}
func (cp *commonPage) serve() {
@ -37,14 +35,13 @@ func (cp *commonPage) serve() {
// TODO: 界面直接跳转使用该接口
cr.GET("/network/:id", cp.network)
cr.GET("/network", cp.network)
cr.GET("/ws", cp.ws)
cr.POST("/terminal", cp.createTerminal)
cr.GET("/file", cp.createFM)
cr.GET("/file/:id", cp.fm)
}
func (p *commonPage) service(c *gin.Context) {
res, _, _ := p.requestGroup.Do("servicePage", func() (interface{}, error) {
res, _, _ := requestGroup.Do("servicePage", func() (interface{}, error) {
singleton.AlertsLock.RLock()
defer singleton.AlertsLock.RUnlock()
var stats map[uint64]model.ServiceItemResponse
@ -71,7 +68,7 @@ func (p *commonPage) service(c *gin.Context) {
func (cp *commonPage) network(c *gin.Context) {
var (
monitorHistory *model.MonitorHistory
servers []*model.Server
servers []model.Server
serverIdsWithMonitor []uint64
monitorInfos = []byte("{}")
id uint64
@ -148,7 +145,7 @@ func (cp *commonPage) network(c *gin.Context) {
for _, server := range singleton.SortedServerList {
for _, id := range serverIdsWithMonitor {
if server.ID == id {
servers = append(servers, server)
servers = append(servers, *server)
}
}
}
@ -156,14 +153,14 @@ func (cp *commonPage) network(c *gin.Context) {
for _, server := range singleton.SortedServerListForGuest {
for _, id := range serverIdsWithMonitor {
if server.ID == id {
servers = append(servers, server)
servers = append(servers, *server)
}
}
}
}
serversBytes, _ := utils.Json.Marshal(Data{
Now: time.Now().Unix() * 1000,
Servers: servers,
serversBytes, _ := utils.Json.Marshal(model.StreamServerData{
Now: time.Now().Unix() * 1000,
// Servers: servers,
})
c.HTML(http.StatusOK, "", gin.H{
@ -176,7 +173,7 @@ func (cp *commonPage) getServerStat(c *gin.Context, withPublicNote bool) ([]byte
_, isMember := c.Get(model.CtxKeyAuthorizedUser)
var isViewPasswordVerfied bool
authorized := isMember || isViewPasswordVerfied
v, err, _ := cp.requestGroup.Do(fmt.Sprintf("serverStats::%t", authorized), func() (interface{}, error) {
v, err, _ := requestGroup.Do(fmt.Sprintf("serverStats::%t", authorized), func() (interface{}, error) {
singleton.SortedServerLock.RLock()
defer singleton.SortedServerLock.RUnlock()
@ -187,18 +184,18 @@ func (cp *commonPage) getServerStat(c *gin.Context, withPublicNote bool) ([]byte
serverList = singleton.SortedServerListForGuest
}
var servers []*model.Server
var servers []model.Server
for _, server := range serverList {
item := *server
if !withPublicNote {
item.PublicNote = ""
}
servers = append(servers, &item)
servers = append(servers, item)
}
return utils.Json.Marshal(Data{
Now: time.Now().Unix() * 1000,
Servers: servers,
return utils.Json.Marshal(model.StreamServerData{
Now: time.Now().Unix() * 1000,
// Servers: servers,
})
})
return v.([]byte), err
@ -223,51 +220,6 @@ func (cp *commonPage) home(c *gin.Context) {
})
}
var upgrader = websocket.Upgrader{
ReadBufferSize: 32768,
WriteBufferSize: 32768,
}
type Data struct {
Now int64 `json:"now,omitempty"`
Servers []*model.Server `json:"servers,omitempty"`
}
func (cp *commonPage) ws(c *gin.Context) {
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
// mygin.ShowErrorPage(c, mygin.ErrInfo{
// Code: http.StatusInternalServerError,
// // Title: singleton.Localizer.MustLocalize(&i18n.LocalizeConfig{
// // MessageID: "NetworkError",
// // }),
// Msg: "Websocket协议切换失败",
// Link: "/",
// Btn: "返回首页",
// }, true)
return
}
defer conn.Close()
count := 0
for {
stat, err := cp.getServerStat(c, false)
if err != nil {
continue
}
if err := conn.WriteMessage(websocket.TextMessage, stat); err != nil {
break
}
count += 1
if count%4 == 0 {
err = conn.WriteMessage(websocket.PingMessage, []byte{})
if err != nil {
break
}
}
time.Sleep(time.Second * 2)
}
}
func (cp *commonPage) terminal(c *gin.Context) {
streamId := c.Param("id")
if _, err := rpc.NezhaHandlerSingleton.GetStream(streamId); err != nil {

View File

@ -49,13 +49,21 @@ func routers(r *gin.Engine) {
if err != nil {
log.Fatal("JWT Error:" + err.Error())
}
if err := authMiddleware.MiddlewareInit(); err != nil {
log.Fatal("authMiddleware.MiddlewareInit Error:" + err.Error())
}
api := r.Group("api/v1")
api.Use(handlerMiddleWare(authMiddleware))
api.POST("/login", authMiddleware.LoginHandler)
unrequiredAuth := api.Group("", unrquiredAuthMiddleware(authMiddleware))
unrequiredAuth.GET("/ws/server", serverStream)
unrequiredAuth.GET("/server-group", listServerGroup)
auth := api.Group("", authMiddleware.MiddlewareFunc())
auth.GET("/refresh_token", authMiddleware.RefreshHandler)
auth.PATCH("/server/:id", editServer)
api.DELETE("/batch-delete/server", batchDeleteServer)
// 通用页面
// cp := commonPage{r: r}

View File

@ -1,8 +1,8 @@
package controller
import (
"encoding/json"
"fmt"
"log"
"net/http"
"time"
@ -16,7 +16,7 @@ import (
func initParams() *jwt.GinJWTMiddleware {
return &jwt.GinJWTMiddleware{
Realm: singleton.Conf.SiteName,
Key: []byte(singleton.Conf.SecretKey),
Key: []byte(singleton.Conf.JWTSecretKey),
CookieName: "nz-jwt",
Timeout: time.Hour,
MaxRefresh: time.Hour,
@ -44,15 +44,6 @@ func initParams() *jwt.GinJWTMiddleware {
}
}
func handlerMiddleWare(authMiddleware *jwt.GinJWTMiddleware) gin.HandlerFunc {
return func(context *gin.Context) {
errInit := authMiddleware.MiddlewareInit()
if errInit != nil {
log.Fatal("authMiddleware.MiddlewareInit() Error:" + errInit.Error())
}
}
}
func payloadFunc() func(data interface{}) jwt.MapClaims {
return func(data interface{}) jwt.MapClaims {
if v, ok := data.(string); ok {
@ -81,7 +72,7 @@ func identityHandler() func(c *gin.Context) interface{} {
// @Schemes
// @Description user login
// @Accept json
// @param request body model.LoginRequest true "Login Request"
// @param loginRequest body model.LoginRequest true "Login Request"
// @Produce json
// @Success 200 {object} model.CommonResponse[model.LoginResponse]
// @Router /login [post]
@ -152,3 +143,40 @@ func refreshResponse(c *gin.Context, code int, token string, expire time.Time) {
},
})
}
func unrquiredAuthMiddleware(mw *jwt.GinJWTMiddleware) func(c *gin.Context) {
return func(c *gin.Context) {
claims, err := mw.GetClaimsFromJWT(c)
if err != nil {
return
}
switch v := claims["exp"].(type) {
case nil:
return
case float64:
if int64(v) < mw.TimeFunc().Unix() {
return
}
case json.Number:
n, err := v.Int64()
if err != nil {
return
}
if n < mw.TimeFunc().Unix() {
return
}
default:
return
}
c.Set("JWT_PAYLOAD", claims)
identity := mw.IdentityHandler(c)
if identity != nil {
c.Set(mw.IdentityKey, identity)
}
c.Next()
}
}

View File

@ -13,7 +13,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/jinzhu/copier"
"golang.org/x/net/idna"
"gorm.io/gorm"
"github.com/naiba/nezha/model"
"github.com/naiba/nezha/pkg/utils"
@ -44,7 +43,6 @@ func (ma *memberAPI) serve() {
mr.GET("/cron/:id/manual", ma.manualTrigger)
mr.POST("/force-update", ma.forceUpdate)
mr.POST("/batch-update-server-group", ma.batchUpdateServerGroup)
mr.POST("/batch-delete-server", ma.batchDeleteServer)
mr.POST("/notification", ma.addOrEditNotification)
mr.POST("/ddns", ma.addOrEditDDNS)
mr.POST("/nat", ma.addOrEditNAT)
@ -188,25 +186,6 @@ func (ma *memberAPI) delete(c *gin.Context) {
var err error
switch c.Param("model") {
case "server":
err := singleton.DB.Transaction(func(tx *gorm.DB) error {
err = singleton.DB.Unscoped().Delete(&model.Server{}, "id = ?", id).Error
if err != nil {
return err
}
err = singleton.DB.Unscoped().Delete(&model.MonitorHistory{}, "server_id = ?", id).Error
if err != nil {
return err
}
return nil
})
if err == nil {
// 删除服务器
singleton.ServerLock.Lock()
onServerDelete(id)
singleton.ServerLock.Unlock()
singleton.ReSortServer()
}
case "notification":
err = singleton.DB.Unscoped().Delete(&model.Notification{}, "id = ?", id).Error
if err == nil {
@ -346,10 +325,8 @@ func (ma *memberAPI) addOrEditServer(c *gin.Context) {
err := c.ShouldBindJSON(&sf)
if err == nil {
s.Name = sf.Name
s.Secret = sf.Secret
s.DisplayIndex = sf.DisplayIndex
s.ID = sf.ID
s.Tag = sf.Tag
s.Note = sf.Note
s.PublicNote = sf.PublicNote
s.HideForGuest = sf.HideForGuest == "on"
@ -358,7 +335,7 @@ func (ma *memberAPI) addOrEditServer(c *gin.Context) {
err = utils.Json.Unmarshal([]byte(sf.DDNSProfilesRaw), &s.DDNSProfiles)
if err == nil {
if s.ID == 0 {
s.Secret, err = utils.GenerateRandomString(18)
_, err = utils.GenerateRandomString(18)
if err == nil {
err = singleton.DB.Create(&s).Error
}
@ -378,34 +355,6 @@ func (ma *memberAPI) addOrEditServer(c *gin.Context) {
if isEdit {
singleton.ServerLock.Lock()
s.CopyFromRunningServer(singleton.ServerList[s.ID])
// 如果修改了 Secret
if s.Secret != singleton.ServerList[s.ID].Secret {
// 删除旧 Secret-ID 绑定关系
singleton.SecretToID[s.Secret] = s.ID
// 设置新的 Secret-ID 绑定关系
delete(singleton.SecretToID, singleton.ServerList[s.ID].Secret)
}
// 如果修改了Tag
oldTag := singleton.ServerList[s.ID].Tag
newTag := s.Tag
if newTag != oldTag {
index := -1
for i := 0; i < len(singleton.ServerTagToIDList[oldTag]); i++ {
if singleton.ServerTagToIDList[oldTag][i] == s.ID {
index = i
break
}
}
if index > -1 {
// 删除旧 Tag-ID 绑定关系
singleton.ServerTagToIDList[oldTag] = append(singleton.ServerTagToIDList[oldTag][:index], singleton.ServerTagToIDList[oldTag][index+1:]...)
if len(singleton.ServerTagToIDList[oldTag]) == 0 {
delete(singleton.ServerTagToIDList, oldTag)
}
}
// 设置新的 Tag-ID 绑定关系
singleton.ServerTagToIDList[newTag] = append(singleton.ServerTagToIDList[newTag], s.ID)
}
singleton.ServerList[s.ID] = &s
singleton.ServerLock.Unlock()
} else {
@ -413,9 +362,7 @@ func (ma *memberAPI) addOrEditServer(c *gin.Context) {
s.State = &model.HostState{}
s.TaskCloseLock = new(sync.Mutex)
singleton.ServerLock.Lock()
singleton.SecretToID[s.Secret] = s.ID
singleton.ServerList[s.ID] = &s
singleton.ServerTagToIDList[s.Tag] = append(singleton.ServerTagToIDList[s.Tag], s.ID)
singleton.ServerLock.Unlock()
}
singleton.ReSortServer()
@ -636,28 +583,28 @@ func (ma *memberAPI) batchUpdateServerGroup(c *gin.Context) {
serverId := req.Servers[i]
var s model.Server
copier.Copy(&s, singleton.ServerList[serverId])
s.Tag = req.Group
// 如果修改了Ta
oldTag := singleton.ServerList[serverId].Tag
newTag := s.Tag
if newTag != oldTag {
index := -1
for i := 0; i < len(singleton.ServerTagToIDList[oldTag]); i++ {
if singleton.ServerTagToIDList[oldTag][i] == s.ID {
index = i
break
}
}
if index > -1 {
// 删除旧 Tag-ID 绑定关系
singleton.ServerTagToIDList[oldTag] = append(singleton.ServerTagToIDList[oldTag][:index], singleton.ServerTagToIDList[oldTag][index+1:]...)
if len(singleton.ServerTagToIDList[oldTag]) == 0 {
delete(singleton.ServerTagToIDList, oldTag)
}
}
// 设置新的 Tag-ID 绑定关系
singleton.ServerTagToIDList[newTag] = append(singleton.ServerTagToIDList[newTag], s.ID)
}
// s.Tag = req.Group
// // 如果修改了Ta
// oldTag := singleton.ServerList[serverId].Tag
// newTag := s.Tag
// if newTag != oldTag {
// index := -1
// for i := 0; i < len(singleton.ServerTagToIDList[oldTag]); i++ {
// if singleton.ServerTagToIDList[oldTag][i] == s.ID {
// index = i
// break
// }
// }
// if index > -1 {
// // 删除旧 Tag-ID 绑定关系
// singleton.ServerTagToIDList[oldTag] = append(singleton.ServerTagToIDList[oldTag][:index], singleton.ServerTagToIDList[oldTag][index+1:]...)
// if len(singleton.ServerTagToIDList[oldTag]) == 0 {
// delete(singleton.ServerTagToIDList, oldTag)
// }
// }
// // 设置新的 Tag-ID 绑定关系
// singleton.ServerTagToIDList[newTag] = append(singleton.ServerTagToIDList[newTag], s.ID)
// }
singleton.ServerList[s.ID] = &s
}
@ -1067,63 +1014,3 @@ func (ma *memberAPI) updateSetting(c *gin.Context) {
Code: http.StatusOK,
})
}
func (ma *memberAPI) batchDeleteServer(c *gin.Context) {
var servers []uint64
if err := c.ShouldBindJSON(&servers); err != nil {
c.JSON(http.StatusOK, model.Response{
Code: http.StatusBadRequest,
Message: err.Error(),
})
return
}
if err := singleton.DB.Unscoped().Delete(&model.Server{}, "id in (?)", servers).Error; err != nil {
c.JSON(http.StatusOK, model.Response{
Code: http.StatusBadRequest,
Message: err.Error(),
})
return
}
singleton.ServerLock.Lock()
for i := 0; i < len(servers); i++ {
id := servers[i]
onServerDelete(id)
}
singleton.ServerLock.Unlock()
singleton.ReSortServer()
c.JSON(http.StatusOK, model.Response{
Code: http.StatusOK,
})
}
func onServerDelete(id uint64) {
tag := singleton.ServerList[id].Tag
delete(singleton.SecretToID, singleton.ServerList[id].Secret)
delete(singleton.ServerList, id)
index := -1
for i := 0; i < len(singleton.ServerTagToIDList[tag]); i++ {
if singleton.ServerTagToIDList[tag][i] == id {
index = i
break
}
}
if index > -1 {
singleton.ServerTagToIDList[tag] = append(singleton.ServerTagToIDList[tag][:index], singleton.ServerTagToIDList[tag][index+1:]...)
if len(singleton.ServerTagToIDList[tag]) == 0 {
delete(singleton.ServerTagToIDList, tag)
}
}
singleton.AlertsLock.Lock()
for i := 0; i < len(singleton.Alerts); i++ {
if singleton.AlertsCycleTransferStatsStore[singleton.Alerts[i].ID] != nil {
delete(singleton.AlertsCycleTransferStatsStore[singleton.Alerts[i].ID].ServerName, id)
delete(singleton.AlertsCycleTransferStatsStore[singleton.Alerts[i].ID].Transfer, id)
delete(singleton.AlertsCycleTransferStatsStore[singleton.Alerts[i].ID].NextUpdate, id)
}
}
singleton.AlertsLock.Unlock()
singleton.DB.Unscoped().Delete(&model.Transfer{}, "server_id = ?", id)
}

View File

@ -0,0 +1,130 @@
package controller
import (
"net/http"
"strconv"
"github.com/gin-gonic/gin"
"github.com/naiba/nezha/model"
"github.com/naiba/nezha/pkg/utils"
"github.com/naiba/nezha/service/singleton"
)
// Edit server
// @Summary Edit server
// @Security BearerAuth
// @Schemes
// @Description Edit server
// @Tags auth required
// @Produce json
// @Success 200 {object} model.CommonResponse[any]
// @Router /server/{id} [patch]
func editServer(c *gin.Context) {
idStr := c.Param("id")
id, err := strconv.ParseUint(idStr, 10, 64)
if err != nil {
c.JSON(http.StatusOK, model.CommonResponse[interface{}]{
Success: false,
Error: err.Error(),
})
return
}
var sf model.EditServer
var s model.Server
if err := c.ShouldBindJSON(&sf); err != nil {
c.JSON(http.StatusOK, model.CommonResponse[interface{}]{
Success: false,
Error: err.Error(),
})
return
}
s.Name = sf.Name
s.DisplayIndex = sf.DisplayIndex
s.ID = id
s.Note = sf.Note
s.PublicNote = sf.PublicNote
s.HideForGuest = sf.HideForGuest
s.EnableDDNS = sf.EnableDDNS
s.DDNSProfiles = sf.DDNSProfiles
ddnsProfilesRaw, err := utils.Json.Marshal(s.DDNSProfiles)
if err != nil {
c.JSON(http.StatusOK, model.CommonResponse[interface{}]{
Success: false,
Error: err.Error(),
})
return
}
s.DDNSProfilesRaw = string(ddnsProfilesRaw)
if err := singleton.DB.Save(&s).Error; err != nil {
c.JSON(http.StatusOK, model.CommonResponse[interface{}]{
Success: false,
Error: err.Error(),
})
return
}
singleton.ServerLock.Lock()
s.CopyFromRunningServer(singleton.ServerList[s.ID])
singleton.ServerList[s.ID] = &s
singleton.ServerLock.Unlock()
singleton.ReSortServer()
c.JSON(http.StatusOK, model.Response{
Code: http.StatusOK,
})
}
// Batch delete server
// @Summary Batch delete server
// @Security BearerAuth
// @Schemes
// @Description Batch delete server
// @Tags auth required
// @Accept json
// @param request body []uint64 true "id list"
// @Produce json
// @Success 200 {object} model.CommonResponse[any]
// @Router /batch-delete/server [post]
func batchDeleteServer(c *gin.Context) {
var servers []uint64
if err := c.ShouldBindJSON(&servers); err != nil {
c.JSON(http.StatusOK, model.CommonResponse[interface{}]{
Success: false,
Error: err.Error(),
})
return
}
if err := singleton.DB.Unscoped().Delete(&model.Server{}, "id in (?)", servers).Error; err != nil {
c.JSON(http.StatusOK, model.CommonResponse[interface{}]{
Success: false,
Error: err.Error(),
})
return
}
singleton.ServerLock.Lock()
for i := 0; i < len(servers); i++ {
id := servers[i]
delete(singleton.ServerList, id)
singleton.AlertsLock.Lock()
for i := 0; i < len(singleton.Alerts); i++ {
if singleton.AlertsCycleTransferStatsStore[singleton.Alerts[i].ID] != nil {
delete(singleton.AlertsCycleTransferStatsStore[singleton.Alerts[i].ID].ServerName, id)
delete(singleton.AlertsCycleTransferStatsStore[singleton.Alerts[i].ID].Transfer, id)
delete(singleton.AlertsCycleTransferStatsStore[singleton.Alerts[i].ID].NextUpdate, id)
}
}
singleton.AlertsLock.Unlock()
singleton.DB.Unscoped().Delete(&model.Transfer{}, "server_id = ?", id)
}
singleton.ServerLock.Unlock()
singleton.ReSortServer()
c.JSON(http.StatusOK, model.CommonResponse[interface{}]{
Success: true,
})
}

View File

@ -0,0 +1,36 @@
package controller
import (
"log"
"net/http"
"github.com/gin-gonic/gin"
"github.com/naiba/nezha/model"
"github.com/naiba/nezha/pkg/utils"
"github.com/naiba/nezha/service/singleton"
)
// List server group
// @Summary List server group
// @Schemes
// @Description List server group
// @Security BearerAuth
// @Tags common
// @Produce json
// @Success 200 {object} model.CommonResponse[[]model.ServerGroup]
// @Router /server-group [get]
func listServerGroup(c *gin.Context) {
authorizedUser, has := c.Get(model.CtxKeyAuthorizedUser)
log.Println("bingo test", authorizedUser, has)
var sg []model.ServerGroup
err := singleton.DB.Find(&sg).Error
c.JSON(http.StatusOK, model.CommonResponse[[]model.ServerGroup]{
Success: err == nil,
Data: sg,
Error: utils.IfOrFn[string](err == nil, func() string {
return err.Error()
}, func() string {
return ""
}),
})
}

View File

@ -0,0 +1,96 @@
package controller
import (
"fmt"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/naiba/nezha/model"
"github.com/naiba/nezha/pkg/utils"
"github.com/naiba/nezha/service/singleton"
"golang.org/x/sync/singleflight"
)
var upgrader = websocket.Upgrader{
ReadBufferSize: 32768,
WriteBufferSize: 32768,
}
// Websocket server stream
// @Summary Websocket server stream
// @tags common
// @Schemes
// @Description Websocket server stream
// @security BearerAuth
// @Produce json
// @Success 200 {object} model.StreamServerData
// @Router /ws/server [get]
func serverStream(c *gin.Context) {
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
c.JSON(http.StatusOK, model.CommonResponse[interface{}]{
Success: false,
Error: err.Error(),
})
return
}
defer conn.Close()
count := 0
for {
stat, err := getServerStat(c, count == 0)
if err != nil {
continue
}
if err := conn.WriteMessage(websocket.TextMessage, stat); err != nil {
break
}
count += 1
if count%4 == 0 {
err = conn.WriteMessage(websocket.PingMessage, []byte{})
if err != nil {
break
}
}
time.Sleep(time.Second * 2)
}
}
var requestGroup singleflight.Group
func getServerStat(c *gin.Context, withPublicNote bool) ([]byte, error) {
_, isMember := c.Get(model.CtxKeyAuthorizedUser)
authorized := isMember // TODO || isViewPasswordVerfied
v, err, _ := requestGroup.Do(fmt.Sprintf("serverStats::%t", authorized), func() (interface{}, error) {
singleton.SortedServerLock.RLock()
defer singleton.SortedServerLock.RUnlock()
var serverList []*model.Server
if authorized {
serverList = singleton.SortedServerList
} else {
serverList = singleton.SortedServerListForGuest
}
var servers []model.StreamServer
for i := 0; i < len(serverList); i++ {
server := serverList[i]
servers = append(servers, model.StreamServer{
ID: server.ID,
Name: server.Name,
PublicNote: utils.IfOr(withPublicNote, server.PublicNote, ""),
DisplayIndex: server.DisplayIndex,
Host: server.Host,
State: server.State,
LastActive: server.LastActive,
})
}
return utils.Json.Marshal(model.StreamServerData{
Now: time.Now().Unix() * 1000,
Servers: servers,
})
})
return v.([]byte), err
}

View File

@ -80,13 +80,14 @@ func (c *AgentConfig) Save() error {
type Config struct {
Debug bool // debug模式开关
Language string // 系统语言,默认 zh-CN
SiteName string
SecretKey string
ListenPort uint
InstallHost string
TLS bool
Location string // 时区,默认为 Asia/Shanghai
Language string // 系统语言,默认 zh-CN
SiteName string
JWTSecretKey string
AgentSecretKey string
ListenPort uint
InstallHost string
TLS bool
Location string // 时区,默认为 Asia/Shanghai
EnablePlainIPInNotification bool // 通知信息IP不打码
@ -132,8 +133,18 @@ func (c *Config) Read(path string) error {
if c.AvgPingCount == 0 {
c.AvgPingCount = 2
}
if c.SecretKey == "" {
c.SecretKey, err = utils.GenerateRandomString(1024)
if c.JWTSecretKey == "" {
c.JWTSecretKey, err = utils.GenerateRandomString(1024)
if err != nil {
return err
}
if err = c.Save(); err != nil {
return err
}
}
if c.AgentSecretKey == "" {
c.AgentSecretKey, err = utils.GenerateRandomString(32)
if err != nil {
return err
}

View File

@ -37,8 +37,6 @@ func execCase(t *testing.T, item testSt) {
server := Server{
Common: Common{},
Name: "ServerName",
Tag: "",
Secret: "",
Note: "",
DisplayIndex: 0,
Host: &Host{

View File

@ -1,8 +1,6 @@
package model
import (
"fmt"
"html/template"
"log"
"sync"
"time"
@ -15,21 +13,21 @@ import (
type Server struct {
Common
Name string
Tag string // 分组名
Secret string `gorm:"uniqueIndex" json:"-"`
Note string `json:"-"` // 管理员可见备注
PublicNote string `json:"PublicNote,omitempty"` // 公开备注
DisplayIndex int // 展示排序,越大越靠前
HideForGuest bool // 对游客隐藏
EnableDDNS bool // 启用DDNS
DDNSProfiles []uint64 `gorm:"-" json:"-"` // DDNS配置
Name string `json:"name,omitempty"`
UUID string `json:"uuid,omitempty" gorm:"unique"`
Note string `json:"note,omitempty"` // 管理员可见备注
PublicNote string `json:"public_note,omitempty"` // 公开备注
DisplayIndex int `json:"display_index,omitempty"` // 展示排序,越大越靠前
HideForGuest bool `json:"hide_for_guest,omitempty"` // 对游客隐藏
EnableDDNS bool `json:"enable_ddns,omitempty"` // 启用DDNS
DDNSProfilesRaw string `gorm:"default:'[]';column:ddns_profiles_raw" json:"-"`
Host *Host `gorm:"-"`
State *HostState `gorm:"-"`
LastActive time.Time `gorm:"-"`
DDNSProfiles []uint64 `gorm:"-" json:"ddns_profiles,omitempty"` // DDNS配置
Host *Host `gorm:"-" json:"host,omitempty"`
State *HostState `gorm:"-" json:"state,omitempty"`
LastActive time.Time `gorm:"-" json:"last_active,omitempty"`
TaskClose chan error `gorm:"-" json:"-"`
TaskCloseLock *sync.Mutex `gorm:"-" json:"-"`
@ -59,20 +57,3 @@ func (s *Server) AfterFind(tx *gorm.DB) error {
}
return nil
}
func boolToString(b bool) string {
if b {
return "true"
}
return "false"
}
func (s Server) MarshalForDashboard() template.JS {
name, _ := utils.Json.Marshal(s.Name)
tag, _ := utils.Json.Marshal(s.Tag)
note, _ := utils.Json.Marshal(s.Note)
secret, _ := utils.Json.Marshal(s.Secret)
ddnsProfilesRaw, _ := utils.Json.Marshal(s.DDNSProfilesRaw)
publicNote, _ := utils.Json.Marshal(s.PublicNote)
return template.JS(fmt.Sprintf(`{"ID":%d,"Name":%s,"Secret":%s,"DisplayIndex":%d,"Tag":%s,"Note":%s,"HideForGuest": %s,"EnableDDNS": %s,"DDNSProfilesRaw": %s,"PublicNote": %s}`, s.ID, name, secret, s.DisplayIndex, tag, note, boolToString(s.HideForGuest), boolToString(s.EnableDDNS), ddnsProfilesRaw, publicNote))
}

29
model/server_api.go Normal file
View File

@ -0,0 +1,29 @@
package model
import "time"
type StreamServer struct {
ID uint64 `json:"id,omitempty"`
Name string `json:"name,omitempty"`
PublicNote string `json:"public_note,omitempty"` // 公开备注,只第一个数据包有值
DisplayIndex int `json:"display_index,omitempty"` // 展示排序,越大越靠前
Host *Host `json:"host,omitempty"`
State *HostState `json:"state,omitempty"`
LastActive time.Time `json:"last_active,omitempty"`
}
type StreamServerData struct {
Now int64 `json:"now,omitempty"`
Servers []StreamServer `json:"servers,omitempty"`
}
type EditServer struct {
Name string `json:"name,omitempty"`
Note string `json:"note,omitempty"` // 管理员可见备注
PublicNote string `json:"public_note,omitempty"` // 公开备注
DisplayIndex int `json:"display_index,omitempty"` // 展示排序,越大越靠前
HideForGuest bool `json:"hide_for_guest,omitempty"` // 对游客隐藏
EnableDDNS bool `json:"enable_ddns,omitempty"` // 启用DDNS
DDNSProfiles []uint64 `gorm:"-" json:"ddns_profiles,omitempty"` // DDNS配置
}

View File

@ -2,5 +2,6 @@ package model
type ServerGroup struct {
Common
Name string `json:"name"`
}

View File

@ -1,30 +0,0 @@
package model
import (
"testing"
"github.com/naiba/nezha/pkg/utils"
)
func TestServerMarshal(t *testing.T) {
patterns := []string{
"asd > asd",
"asd \" asd",
"asd } asd",
}
for i := 0; i < len(patterns); i++ {
server := Server{
Name: patterns[i],
Tag: patterns[i],
}
serverStr := string(server.MarshalForDashboard())
var serverRestore Server
if utils.Json.Unmarshal([]byte(serverStr), &serverRestore) != nil {
t.Fatalf("Error: %s", serverStr)
}
if server.Name != serverRestore.Name {
t.Fatalf("Expected %s, but got %s", server.Name, serverRestore.Name)
}
}
}

View File

@ -90,3 +90,17 @@ func Uint64SubInt64(a uint64, b int64) uint64 {
}
return a - uint64(b)
}
func IfOr[T any](a bool, x, y T) T {
if a {
return x
}
return y
}
func IfOrFn[T any](a bool, x, y func() T) T {
if a {
return x()
}
return y()
}

View File

@ -2,20 +2,23 @@ package rpc
import (
"context"
"sync"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"github.com/naiba/nezha/model"
"github.com/naiba/nezha/service/singleton"
)
type authHandler struct {
ClientSecret string
ClientUUID string
}
func (a *authHandler) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
return map[string]string{"client_secret": a.ClientSecret}, nil
return map[string]string{"client_secret": a.ClientSecret, "client_uuid": a.ClientUUID}, nil
}
func (a *authHandler) RequireTransportSecurity() bool {
@ -33,15 +36,29 @@ func (a *authHandler) Check(ctx context.Context) (uint64, error) {
clientSecret = value[0]
}
if clientSecret != singleton.Conf.AgentSecretKey {
return 0, status.Errorf(codes.Unauthenticated, "客户端认证失败")
}
var clientUUID string
if value, ok := md["client_uuid"]; ok {
clientUUID = value[0]
}
singleton.ServerLock.RLock()
defer singleton.ServerLock.RUnlock()
clientID, hasID := singleton.SecretToID[clientSecret]
clientID, hasID := singleton.ServerUUIDToID[clientUUID]
if !hasID {
return 0, status.Errorf(codes.Unauthenticated, "客户端认证失败")
}
_, hasServer := singleton.ServerList[clientID]
if !hasServer {
return 0, status.Errorf(codes.Unauthenticated, "客户端认证失败")
s := model.Server{UUID: clientUUID}
if err := singleton.DB.Create(&s).Error; err != nil {
return 0, status.Errorf(codes.Unauthenticated, err.Error())
}
s.Host = &model.Host{}
s.State = &model.HostState{}
s.TaskCloseLock = new(sync.Mutex)
singleton.ServerList[s.ID] = &s
singleton.ServerUUIDToID[clientUUID] = s.ID
}
return clientID, nil
}

View File

@ -103,9 +103,9 @@ func (s *ServerAPIService) GetStatusByIDList(idList []uint64) *ServerStatusRespo
}
ipv4, ipv6, validIP := utils.SplitIPAddr(server.Host.IP)
info := CommonServerInfo{
ID: server.ID,
Name: server.Name,
Tag: server.Tag,
ID: server.ID,
Name: server.Name,
// Tag: server.Tag,
LastActive: server.LastActive.Unix(),
IPV4: ipv4,
IPV6: ipv6,
@ -125,9 +125,9 @@ func (s *ServerAPIService) GetStatusByIDList(idList []uint64) *ServerStatusRespo
}
// GetStatusByTag 获取传入分组的所有服务器状态信息
func (s *ServerAPIService) GetStatusByTag(tag string) *ServerStatusResponse {
return s.GetStatusByIDList(ServerTagToIDList[tag])
}
// func (s *ServerAPIService) GetStatusByTag(tag string) *ServerStatusResponse {
// return s.GetStatusByIDList(ServerTagToIDList[tag])
// }
// GetAllStatus 获取所有服务器状态信息
func (s *ServerAPIService) GetAllStatus() *ServerStatusResponse {
@ -143,9 +143,9 @@ func (s *ServerAPIService) GetAllStatus() *ServerStatusResponse {
}
ipv4, ipv6, validIP := utils.SplitIPAddr(host.IP)
info := CommonServerInfo{
ID: v.ID,
Name: v.Name,
Tag: v.Tag,
ID: v.ID,
Name: v.Name,
// Tag: v.Tag,
LastActive: v.LastActive.Unix(),
IPV4: ipv4,
IPV6: ipv6,
@ -173,23 +173,23 @@ func (s *ServerAPIService) GetListByTag(tag string) *ServerInfoResponse {
ServerLock.RLock()
defer ServerLock.RUnlock()
for _, v := range ServerTagToIDList[tag] {
host := ServerList[v].Host
if host == nil {
continue
}
ipv4, ipv6, validIP := utils.SplitIPAddr(host.IP)
info := &CommonServerInfo{
ID: v,
Name: ServerList[v].Name,
Tag: ServerList[v].Tag,
LastActive: ServerList[v].LastActive.Unix(),
IPV4: ipv4,
IPV6: ipv6,
ValidIP: validIP,
}
res.Result = append(res.Result, info)
}
// for _, v := range ServerTagToIDList[tag] {
// host := ServerList[v].Host
// if host == nil {
// continue
// }
// ipv4, ipv6, validIP := utils.SplitIPAddr(host.IP)
// info := &CommonServerInfo{
// ID: v,
// Name: ServerList[v].Name,
// Tag: ServerList[v].Tag,
// LastActive: ServerList[v].LastActive.Unix(),
// IPV4: ipv4,
// IPV6: ipv6,
// ValidIP: validIP,
// }
// res.Result = append(res.Result, info)
// }
res.CommonResponse = CommonResponse{
Code: 0,
Message: "success",
@ -211,9 +211,9 @@ func (s *ServerAPIService) GetAllList() *ServerInfoResponse {
}
ipv4, ipv6, validIP := utils.SplitIPAddr(host.IP)
info := &CommonServerInfo{
ID: v.ID,
Name: v.Name,
Tag: v.Tag,
ID: v.ID,
Name: v.Name,
// Tag: v.Tag,
LastActive: v.LastActive.Unix(),
IPV4: ipv4,
IPV6: ipv6,

View File

@ -8,21 +8,18 @@ import (
)
var (
ServerList map[uint64]*model.Server // [ServerID] -> model.Server
SecretToID map[string]uint64 // [ServerSecret] -> ServerID
ServerTagToIDList map[string][]uint64 // [ServerTag] -> ServerID
ServerLock sync.RWMutex
ServerList map[uint64]*model.Server // [ServerID] -> model.Server
ServerUUIDToID map[string]uint64 // [ServerUUID] -> ServerID
ServerLock sync.RWMutex
SortedServerList []*model.Server // 用于存储服务器列表的 slice按照服务器 ID 排序
SortedServerListForGuest []*model.Server
SortedServerLock sync.RWMutex
)
// InitServer 初始化 ServerID <-> Secret 的映射
func InitServer() {
ServerList = make(map[uint64]*model.Server)
SecretToID = make(map[string]uint64)
ServerTagToIDList = make(map[string][]uint64)
ServerUUIDToID = make(map[string]uint64)
}
// loadServers 加载服务器列表并根据ID排序
@ -36,8 +33,7 @@ func loadServers() {
innerS.State = &model.HostState{}
innerS.TaskCloseLock = new(sync.Mutex)
ServerList[innerS.ID] = &innerS
SecretToID[innerS.Secret] = innerS.ID
ServerTagToIDList[innerS.Tag] = append(ServerTagToIDList[innerS.Tag], innerS.ID)
ServerUUIDToID[innerS.UUID] = innerS.ID
}
ReSortServer()
}