feat: terminal api

This commit is contained in:
naiba 2024-10-22 22:01:01 +08:00
parent 387da11f1b
commit f99edfd7bd
7 changed files with 121 additions and 478 deletions

View File

@ -1,111 +0,0 @@
package controller
import (
"strconv"
"strings"
"github.com/gin-gonic/gin"
"github.com/naiba/nezha/model"
"github.com/naiba/nezha/service/singleton"
)
type apiV1 struct {
r gin.IRouter
}
func (v *apiV1) serve() {
r := v.r.Group("")
// 强制认证的 API
// r.Use(mygin.Authorize(mygin.AuthorizeOption{
// MemberOnly: true,
// AllowAPI: true,
// IsPage: false,
// Msg: "访问此接口需要认证",
// Btn: "点此登录",
// Redirect: "/login",
// }))
r.GET("/server/list", v.serverList)
r.GET("/server/details", v.serverDetails)
// 不强制认证的 API
mr := v.r.Group("monitor")
// mr.Use(mygin.Authorize(mygin.AuthorizeOption{
// MemberOnly: false,
// IsPage: false,
// AllowAPI: true,
// Msg: "访问此接口需要认证",
// Btn: "点此登录",
// Redirect: "/login",
// }))
// mr.Use(mygin.ValidateViewPassword(mygin.ValidateViewPasswordOption{
// IsPage: false,
// AbortWhenFail: true,
// }))
mr.GET("/:id", v.monitorHistoriesById)
}
// serverList 获取服务器列表 不传入Query参数则获取全部
// header: Authorization: Token
// query: tag (服务器分组)
func (v *apiV1) serverList(c *gin.Context) {
tag := c.Query("tag")
if tag != "" {
c.JSON(200, singleton.ServerAPI.GetListByTag(tag))
return
}
c.JSON(200, singleton.ServerAPI.GetAllList())
}
// serverDetails 获取服务器信息 不传入Query参数则获取全部
// header: Authorization: Token
// query: id (服务器ID逗号分隔优先级高于tag查询)
// query: tag (服务器分组)
func (v *apiV1) serverDetails(c *gin.Context) {
var idList []uint64
idListStr := strings.Split(c.Query("id"), ",")
if c.Query("id") != "" {
idList = make([]uint64, len(idListStr))
for i, v := range idListStr {
id, _ := strconv.ParseUint(v, 10, 64)
idList[i] = id
}
}
tag := c.Query("tag")
if tag != "" {
// c.JSON(200, singleton.ServerAPI.GetStatusByTag(tag))
return
}
if len(idList) != 0 {
c.JSON(200, singleton.ServerAPI.GetStatusByIDList(idList))
return
}
c.JSON(200, singleton.ServerAPI.GetAllStatus())
}
func (v *apiV1) monitorHistoriesById(c *gin.Context) {
idStr := c.Param("id")
id, err := strconv.ParseUint(idStr, 10, 64)
if err != nil {
c.AbortWithStatusJSON(400, gin.H{"code": 400, "message": "id参数错误"})
return
}
server, ok := singleton.ServerList[id]
if !ok {
c.AbortWithStatusJSON(404, gin.H{
"code": 404,
"message": "id不存在",
})
return
}
_, isMember := c.Get(model.CtxKeyAuthorizedUser)
var isViewPasswordVerfied bool
authorized := isMember || isViewPasswordVerfied
if server.HideForGuest && !authorized {
c.AbortWithStatusJSON(403, gin.H{"code": 403, "message": "需要认证"})
return
}
c.JSON(200, singleton.MonitorAPI.GetMonitorHistories(map[string]any{"server_id": server.ID}))
}

View File

@ -1,7 +1,6 @@
package controller
import (
"fmt"
"net/http"
"strconv"
"time"
@ -25,17 +24,10 @@ type commonPage struct {
func (cp *commonPage) serve() {
cr := cp.r.Group("")
cr.GET("/terminal/:id", cp.terminal)
// cr.Use(mygin.ValidateViewPassword(mygin.ValidateViewPasswordOption{
// IsPage: true,
// AbortWhenFail: true,
// }))
cr.GET("/", cp.home)
cr.GET("/service", cp.service)
// TODO: 界面直接跳转使用该接口
cr.GET("/network/:id", cp.network)
cr.GET("/network", cp.network)
cr.POST("/terminal", cp.createTerminal)
cr.GET("/file", cp.createFM)
cr.GET("/file/:id", cp.fm)
}
@ -169,187 +161,6 @@ func (cp *commonPage) network(c *gin.Context) {
})
}
func (cp *commonPage) getServerStat(c *gin.Context, withPublicNote bool) ([]byte, error) {
_, isMember := c.Get(model.CtxKeyAuthorizedUser)
var isViewPasswordVerfied bool
authorized := isMember || isViewPasswordVerfied
v, err, _ := requestGroup.Do(fmt.Sprintf("serverStats::%t", authorized), func() (interface{}, error) {
singleton.SortedServerLock.RLock()
defer singleton.SortedServerLock.RUnlock()
var serverList []*model.Server
if authorized {
serverList = singleton.SortedServerList
} else {
serverList = singleton.SortedServerListForGuest
}
var servers []model.Server
for _, server := range serverList {
item := *server
if !withPublicNote {
item.PublicNote = ""
}
servers = append(servers, item)
}
return utils.Json.Marshal(model.StreamServerData{
Now: time.Now().Unix() * 1000,
// Servers: servers,
})
})
return v.([]byte), err
}
func (cp *commonPage) home(c *gin.Context) {
stat, err := cp.getServerStat(c, true)
if err != nil {
// mygin.ShowErrorPage(c, mygin.ErrInfo{
// Code: http.StatusInternalServerError,
// // Title: singleton.Localizer.MustLocalize(&i18n.LocalizeConfig{
// // MessageID: "SystemError",
// // }),
// Msg: "服务器状态获取失败",
// Link: "/",
// Btn: "返回首页",
// }, true)
return
}
c.HTML(http.StatusOK, "", gin.H{
"Servers": string(stat),
})
}
func (cp *commonPage) terminal(c *gin.Context) {
streamId := c.Param("id")
if _, err := rpc.NezhaHandlerSingleton.GetStream(streamId); err != nil {
// mygin.ShowErrorPage(c, mygin.ErrInfo{
// Code: http.StatusForbidden,
// Title: "无权访问",
// Msg: "终端会话不存在",
// Link: "/",
// Btn: "返回首页",
// }, true)
return
}
defer rpc.NezhaHandlerSingleton.CloseStream(streamId)
wsConn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
// mygin.ShowErrorPage(c, mygin.ErrInfo{
// Code: http.StatusInternalServerError,
// // Title: singleton.Localizer.MustLocalize(&i18n.LocalizeConfig{
// // MessageID: "NetworkError",
// // }),
// Msg: "Websocket协议切换失败",
// Link: "/",
// Btn: "返回首页",
// }, true)
return
}
defer wsConn.Close()
conn := websocketx.NewConn(wsConn)
go func() {
// PING 保活
for {
if err = conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
return
}
time.Sleep(time.Second * 10)
}
}()
if err = rpc.NezhaHandlerSingleton.UserConnected(streamId, conn); err != nil {
return
}
rpc.NezhaHandlerSingleton.StartStream(streamId, time.Second*10)
}
type createTerminalRequest struct {
Host string
Protocol string
ID uint64
}
func (cp *commonPage) createTerminal(c *gin.Context) {
if _, authorized := c.Get(model.CtxKeyAuthorizedUser); !authorized {
// mygin.ShowErrorPage(c, mygin.ErrInfo{
// Code: http.StatusForbidden,
// Title: "无权访问",
// Msg: "用户未登录",
// Link: "/login",
// Btn: "去登录",
// }, true)
return
}
var createTerminalReq createTerminalRequest
if err := c.ShouldBind(&createTerminalReq); err != nil {
// mygin.ShowErrorPage(c, mygin.ErrInfo{
// Code: http.StatusForbidden,
// Title: "请求失败",
// Msg: "请求参数有误:" + err.Error(),
// Link: "/server",
// Btn: "返回重试",
// }, true)
return
}
streamId, err := uuid.GenerateUUID()
if err != nil {
// mygin.ShowErrorPage(c, mygin.ErrInfo{
// Code: http.StatusInternalServerError,
// // Title: singleton.Localizer.MustLocalize(&i18n.LocalizeConfig{
// // MessageID: "SystemError",
// // }),
// Msg: "生成会话ID失败",
// Link: "/server",
// Btn: "返回重试",
// }, true)
return
}
rpc.NezhaHandlerSingleton.CreateStream(streamId)
singleton.ServerLock.RLock()
server := singleton.ServerList[createTerminalReq.ID]
singleton.ServerLock.RUnlock()
if server == nil || server.TaskStream == nil {
// mygin.ShowErrorPage(c, mygin.ErrInfo{
// Code: http.StatusForbidden,
// Title: "请求失败",
// Msg: "服务器不存在或处于离线状态",
// Link: "/server",
// Btn: "返回重试",
// }, true)
return
}
terminalData, _ := utils.Json.Marshal(&model.TerminalTask{
StreamID: streamId,
})
if err := server.TaskStream.Send(&proto.Task{
Type: model.TaskTypeTerminalGRPC,
Data: string(terminalData),
}); err != nil {
// mygin.ShowErrorPage(c, mygin.ErrInfo{
// Code: http.StatusForbidden,
// Title: "请求失败",
// Msg: "Agent信令下发失败",
// Link: "/server",
// Btn: "返回重试",
// }, true)
return
}
c.HTML(http.StatusOK, "", gin.H{
"SessionID": streamId,
"ServerName": server.Name,
"ServerID": server.ID,
})
}
func (cp *commonPage) fm(c *gin.Context) {
streamId := c.Param("id")
if _, err := rpc.NezhaHandlerSingleton.GetStream(streamId); err != nil {

View File

@ -61,8 +61,12 @@ func routers(r *gin.Engine) {
optionalAuth.GET("/server-group", commonHandler(listServerGroup))
auth := api.Group("", authMiddleware.MiddlewareFunc())
auth.GET("/refresh_token", authMiddleware.RefreshHandler)
auth.POST("/terminal", commonHandler(createTerminal))
auth.GET("/ws/terminal/:id", commonHandler(terminalStream))
auth.GET("/user", commonHandler(listUser))
auth.POST("/user", commonHandler(createUser))
auth.POST("/batch-delete/user", commonHandler(batchDeleteUser))

View File

@ -7,7 +7,6 @@ import (
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
@ -33,11 +32,6 @@ func (ma *memberAPI) serve() {
// Btn: "点此登录",
// Redirect: "/login",
// }))
mr.GET("/search-server", ma.searchServer)
mr.GET("/search-tasks", ma.searchTask)
mr.GET("/search-ddns", ma.searchDDNS)
mr.POST("/server", ma.addOrEditServer)
mr.POST("/monitor", ma.addOrEditMonitor)
mr.POST("/cron", ma.addOrEditCron)
mr.GET("/cron/:id/manual", ma.manualTrigger)
@ -53,13 +47,6 @@ func (ma *memberAPI) serve() {
mr.GET("/token", ma.getToken)
mr.POST("/token", ma.issueNewToken)
mr.DELETE("/token/:token", ma.deleteToken)
// API
v1 := ma.r.Group("v1")
{
apiv1 := &apiV1{v1}
apiv1.serve()
}
}
type apiResult struct {
@ -236,141 +223,6 @@ func (ma *memberAPI) delete(c *gin.Context) {
})
}
type searchResult struct {
Name string `json:"name,omitempty"`
Value uint64 `json:"value,omitempty"`
Text string `json:"text,omitempty"`
}
func (ma *memberAPI) searchServer(c *gin.Context) {
var servers []model.Server
likeWord := "%" + c.Query("word") + "%"
singleton.DB.Select("id,name").Where("id = ? OR name LIKE ? OR tag LIKE ? OR note LIKE ?",
c.Query("word"), likeWord, likeWord, likeWord).Find(&servers)
var resp []searchResult
for i := 0; i < len(servers); i++ {
resp = append(resp, searchResult{
Value: servers[i].ID,
Name: servers[i].Name,
Text: servers[i].Name,
})
}
c.JSON(http.StatusOK, map[string]interface{}{
"success": true,
"results": resp,
})
}
func (ma *memberAPI) searchTask(c *gin.Context) {
var tasks []model.Cron
likeWord := "%" + c.Query("word") + "%"
singleton.DB.Select("id,name").Where("id = ? OR name LIKE ?",
c.Query("word"), likeWord).Find(&tasks)
var resp []searchResult
for i := 0; i < len(tasks); i++ {
resp = append(resp, searchResult{
Value: tasks[i].ID,
Name: tasks[i].Name,
Text: tasks[i].Name,
})
}
c.JSON(http.StatusOK, map[string]interface{}{
"success": true,
"results": resp,
})
}
func (ma *memberAPI) searchDDNS(c *gin.Context) {
var ddns []model.DDNSProfile
likeWord := "%" + c.Query("word") + "%"
singleton.DB.Select("id,name").Where("id = ? OR name LIKE ?",
c.Query("word"), likeWord).Find(&ddns)
var resp []searchResult
for i := 0; i < len(ddns); i++ {
resp = append(resp, searchResult{
Value: ddns[i].ID,
Name: ddns[i].Name,
Text: ddns[i].Name,
})
}
c.JSON(http.StatusOK, map[string]interface{}{
"success": true,
"results": resp,
})
}
type serverForm struct {
ID uint64
Name string `binding:"required"`
DisplayIndex int
Secret string
Tag string
Note string
PublicNote string
HideForGuest string
EnableDDNS string
DDNSProfilesRaw string
}
func (ma *memberAPI) addOrEditServer(c *gin.Context) {
var sf serverForm
var s model.Server
var isEdit bool
err := c.ShouldBindJSON(&sf)
if err == nil {
s.Name = sf.Name
s.DisplayIndex = sf.DisplayIndex
s.ID = sf.ID
s.Note = sf.Note
s.PublicNote = sf.PublicNote
s.HideForGuest = sf.HideForGuest == "on"
s.EnableDDNS = sf.EnableDDNS == "on"
s.DDNSProfilesRaw = sf.DDNSProfilesRaw
err = utils.Json.Unmarshal([]byte(sf.DDNSProfilesRaw), &s.DDNSProfiles)
if err == nil {
if s.ID == 0 {
_, err = utils.GenerateRandomString(18)
if err == nil {
err = singleton.DB.Create(&s).Error
}
} else {
isEdit = true
err = singleton.DB.Save(&s).Error
}
}
}
if err != nil {
c.JSON(http.StatusOK, model.Response{
Code: http.StatusBadRequest,
Message: fmt.Sprintf("请求错误:%s", err),
})
return
}
if isEdit {
singleton.ServerLock.Lock()
s.CopyFromRunningServer(singleton.ServerList[s.ID])
singleton.ServerList[s.ID] = &s
singleton.ServerLock.Unlock()
} else {
s.Host = &model.Host{}
s.State = &model.HostState{}
s.TaskCloseLock = new(sync.Mutex)
singleton.ServerLock.Lock()
singleton.ServerList[s.ID] = &s
singleton.ServerLock.Unlock()
}
singleton.ReSortServer()
c.JSON(http.StatusOK, model.Response{
Code: http.StatusOK,
})
}
type monitorForm struct {
ID uint64
Name string

View File

@ -0,0 +1,105 @@
package controller
import (
"errors"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/hashicorp/go-uuid"
"github.com/naiba/nezha/model"
"github.com/naiba/nezha/pkg/utils"
"github.com/naiba/nezha/pkg/websocketx"
"github.com/naiba/nezha/proto"
"github.com/naiba/nezha/service/rpc"
"github.com/naiba/nezha/service/singleton"
)
// Create web ssh terminal
// @Summary Create web ssh terminal
// @Description Create web ssh terminal
// @Tags auth required
// @Accept json
// @Param terminal body model.TerminalForm true "TerminalForm"
// @Produce json
// @Success 200 {object} model.CreateTerminalResponse
// @Router /terminal [post]
func createTerminal(c *gin.Context) error {
var createTerminalReq model.TerminalForm
if err := c.ShouldBind(&createTerminalReq); err != nil {
return err
}
streamId, err := uuid.GenerateUUID()
if err != nil {
return err
}
rpc.NezhaHandlerSingleton.CreateStream(streamId)
singleton.ServerLock.RLock()
server := singleton.ServerList[createTerminalReq.ServerID]
singleton.ServerLock.RUnlock()
if server == nil || server.TaskStream == nil {
return errors.New("server not found or not connected")
}
terminalData, _ := utils.Json.Marshal(&model.TerminalTask{
StreamID: streamId,
})
if err := server.TaskStream.Send(&proto.Task{
Type: model.TaskTypeTerminalGRPC,
Data: string(terminalData),
}); err != nil {
return err
}
c.JSON(http.StatusOK, model.CommonResponse[model.CreateTerminalResponse]{
Success: true,
Data: model.CreateTerminalResponse{
SessionID: streamId,
ServerID: server.ID,
ServerName: server.Name,
},
})
return nil
}
// TerminalStream web ssh terminal stream
// @Summary Terminal stream
// @Description Terminal stream
// @Tags auth required
// @Param id path string true "Stream ID"
// @Router /terminal/{id} [get]
func terminalStream(c *gin.Context) error {
streamId := c.Param("id")
if _, err := rpc.NezhaHandlerSingleton.GetStream(streamId); err != nil {
return err
}
defer rpc.NezhaHandlerSingleton.CloseStream(streamId)
wsConn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
return err
}
defer wsConn.Close()
conn := websocketx.NewConn(wsConn)
go func() {
// PING 保活
for {
if err = conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
return
}
time.Sleep(time.Second * 10)
}
}()
if err = rpc.NezhaHandlerSingleton.UserConnected(streamId, conn); err != nil {
return err
}
return rpc.NezhaHandlerSingleton.StartStream(streamId, time.Second*10)
}

View File

@ -47,36 +47,6 @@ const (
ConfigCoverIgnoreAll
)
type AgentConfig struct {
HardDrivePartitionAllowlist []string
NICAllowlist map[string]bool
v *viper.Viper
}
// Read 从给定的文件目录加载配置文件
func (c *AgentConfig) Read(path string) error {
c.v = viper.New()
c.v.SetConfigFile(path)
err := c.v.ReadInConfig()
if err != nil {
return err
}
err = c.v.Unmarshal(c)
if err != nil {
return err
}
return nil
}
func (c *AgentConfig) Save() error {
data, err := yaml.Marshal(c)
if err != nil {
return err
}
return os.WriteFile(c.v.ConfigFileUsed(), data, 0600)
}
// Config 站点配置
type Config struct {
Debug bool // debug模式开关

12
model/terminal_api.go Normal file
View File

@ -0,0 +1,12 @@
package model
type TerminalForm struct {
Protocol string `json:"protocol,omitempty"`
ServerID uint64 `json:"server_id,omitempty"`
}
type CreateTerminalResponse struct {
SessionID string `json:"session_id,omitempty"`
ServerID uint64 `json:"server_id,omitempty"`
ServerName string `json:"server_name,omitempty"`
}