nezha/cmd/dashboard/controller/ws.go

130 lines
3.0 KiB
Go
Raw Normal View History

2024-10-20 11:23:04 -04:00
package controller
import (
	"fmt"
	"net"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/gorilla/websocket"
	"golang.org/x/sync/singleflight"

	"github.com/naiba/nezha/model"
	"github.com/naiba/nezha/pkg/utils"
	"github.com/naiba/nezha/service/singleton"
)
// upgrader upgrades HTTP requests to websocket connections for serverStream.
// It is assigned by InitUpgrader and is nil until that has been called.
var upgrader *websocket.Upgrader
func InitUpgrader() {
var checkOrigin func(r *http.Request) bool
// Allow CORS from loopback addresses in debug mode
if singleton.Conf.Debug {
checkOrigin = func(r *http.Request) bool {
hostAddr := r.Host
host, _, err := net.SplitHostPort(hostAddr)
if err != nil {
return false
}
if ip := net.ParseIP(host); ip != nil {
if ip.IsLoopback() {
return true
}
} else {
// Handle domains like "localhost"
ip, err := net.LookupHost(host)
if err != nil || len(ip) == 0 {
return false
}
if netIP := net.ParseIP(ip[0]); netIP != nil && netIP.IsLoopback() {
return true
}
}
return false
}
}
upgrader = &websocket.Upgrader{
ReadBufferSize: 32768,
WriteBufferSize: 32768,
CheckOrigin: checkOrigin,
}
2024-10-20 11:23:04 -04:00
}
// Websocket server stream
// @Summary Websocket server stream
// @tags common
// @Schemes
// @Description Websocket server stream
// @security BearerAuth
// @Produce json
// @Success 200 {object} model.StreamServerData
// @Router /ws/server [get]
2024-10-23 05:56:51 -04:00
func serverStream(c *gin.Context) (any, error) {
2024-10-20 11:23:04 -04:00
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
return nil, newWsError("%v", err)
2024-10-20 11:23:04 -04:00
}
defer conn.Close()
count := 0
for {
stat, err := getServerStat(c, count == 0)
if err != nil {
continue
}
if err := conn.WriteMessage(websocket.TextMessage, stat); err != nil {
break
}
count += 1
if count%4 == 0 {
err = conn.WriteMessage(websocket.PingMessage, []byte{})
if err != nil {
break
}
}
time.Sleep(time.Second * 2)
}
return nil, newWsError("")
2024-10-20 11:23:04 -04:00
}
// requestGroup deduplicates concurrent snapshot builds in getServerStat:
// simultaneous calls with the same key share a single computed result.
var requestGroup singleflight.Group
func getServerStat(c *gin.Context, withPublicNote bool) ([]byte, error) {
_, isMember := c.Get(model.CtxKeyAuthorizedUser)
authorized := isMember // TODO || isViewPasswordVerfied
v, err, _ := requestGroup.Do(fmt.Sprintf("serverStats::%t", authorized), func() (interface{}, error) {
singleton.SortedServerLock.RLock()
defer singleton.SortedServerLock.RUnlock()
var serverList []*model.Server
if authorized {
serverList = singleton.SortedServerList
} else {
serverList = singleton.SortedServerListForGuest
}
var servers []model.StreamServer
for i := 0; i < len(serverList); i++ {
server := serverList[i]
servers = append(servers, model.StreamServer{
ID: server.ID,
Name: server.Name,
PublicNote: utils.IfOr(withPublicNote, server.PublicNote, ""),
DisplayIndex: server.DisplayIndex,
Host: server.Host,
2024-10-20 11:23:04 -04:00
State: server.State,
LastActive: server.LastActive,
})
}
return utils.Json.Marshal(model.StreamServerData{
Now: time.Now().Unix() * 1000,
Servers: servers,
})
})
return v.([]byte), err
}