nezha/cmd/dashboard/controller/ws.go
naiba 5128bfff61
Some checks are pending
CodeQL / Analyze (go) (push) Waiting to run
CodeQL / Analyze (javascript) (push) Waiting to run
Contributors / contributors (push) Waiting to run
Sync / sync-to-jihulab (push) Waiting to run
Run Tests / tests (macos) (push) Waiting to run
Run Tests / tests (ubuntu) (push) Waiting to run
Run Tests / tests (windows) (push) Waiting to run
feat: show online user id
2024-12-24 23:28:37 +08:00

198 lines
4.3 KiB
Go

package controller
import (
"fmt"
"net"
"net/http"
"net/url"
"time"
"unicode/utf8"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/hashicorp/go-uuid"
"golang.org/x/sync/singleflight"
"github.com/nezhahq/nezha/model"
"github.com/nezhahq/nezha/pkg/utils"
"github.com/nezhahq/nezha/service/singleton"
)
var upgrader *websocket.Upgrader
func InitUpgrader() {
var checkOrigin func(r *http.Request) bool
// Allow CORS from loopback addresses in debug mode
if singleton.Conf.Debug {
checkOrigin = func(r *http.Request) bool {
if checkSameOrigin(r) {
return true
}
hostAddr := r.Host
host, _, err := net.SplitHostPort(hostAddr)
if err != nil {
return false
}
if ip := net.ParseIP(host); ip != nil {
if ip.IsLoopback() {
return true
}
} else {
// Handle domains like "localhost"
ip, err := net.LookupHost(host)
if err != nil || len(ip) == 0 {
return false
}
if netIP := net.ParseIP(ip[0]); netIP != nil && netIP.IsLoopback() {
return true
}
}
return false
}
}
upgrader = &websocket.Upgrader{
ReadBufferSize: 32768,
WriteBufferSize: 32768,
CheckOrigin: checkOrigin,
}
}
func equalASCIIFold(s, t string) bool {
for s != "" && t != "" {
sr, size := utf8.DecodeRuneInString(s)
s = s[size:]
tr, size := utf8.DecodeRuneInString(t)
t = t[size:]
if sr == tr {
continue
}
if 'A' <= sr && sr <= 'Z' {
sr = sr + 'a' - 'A'
}
if 'A' <= tr && tr <= 'Z' {
tr = tr + 'a' - 'A'
}
if sr != tr {
return false
}
}
return s == t
}
func checkSameOrigin(r *http.Request) bool {
origin := r.Header["Origin"]
if len(origin) == 0 {
return true
}
u, err := url.Parse(origin[0])
if err != nil {
return false
}
return equalASCIIFold(u.Host, r.Host)
}
// Websocket server stream
// @Summary Websocket server stream
// @tags common
// @Schemes
// @Description Websocket server stream
// @security BearerAuth
// @Produce json
// @Success 200 {object} model.StreamServerData
// @Router /ws/server [get]
func serverStream(c *gin.Context) (any, error) {
connId, err := uuid.GenerateUUID()
if err != nil {
return nil, newWsError("%v", err)
}
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
return nil, newWsError("%v", err)
}
defer conn.Close()
userIp := c.GetString(model.CtxKeyRealIPStr)
if userIp == "" {
userIp = c.RemoteIP()
}
u, isMember := c.Get(model.CtxKeyAuthorizedUser)
var userId uint64
if isMember {
userId = u.(*model.User).ID
}
singleton.AddOnlineUser(connId, &model.OnlineUser{
UserID: userId,
IP: userIp,
ConnectedAt: time.Now(),
Conn: conn,
})
defer singleton.RemoveOnlineUser(connId)
count := 0
for {
stat, err := getServerStat(count == 0, isMember)
if err != nil {
continue
}
if err := conn.WriteMessage(websocket.TextMessage, stat); err != nil {
break
}
count += 1
if count%4 == 0 {
err = conn.WriteMessage(websocket.PingMessage, []byte{})
if err != nil {
break
}
}
time.Sleep(time.Second * 2)
}
return nil, newWsError("")
}
var requestGroup singleflight.Group
func getServerStat(withPublicNote, authorized bool) ([]byte, error) {
v, err, _ := requestGroup.Do(fmt.Sprintf("serverStats::%t", authorized), func() (interface{}, error) {
singleton.SortedServerLock.RLock()
defer singleton.SortedServerLock.RUnlock()
var serverList []*model.Server
if authorized {
serverList = singleton.SortedServerList
} else {
serverList = singleton.SortedServerListForGuest
}
servers := make([]model.StreamServer, 0, len(serverList))
for _, server := range serverList {
var countryCode string
if server.GeoIP != nil {
countryCode = server.GeoIP.CountryCode
}
servers = append(servers, model.StreamServer{
ID: server.ID,
Name: server.Name,
PublicNote: utils.IfOr(withPublicNote, server.PublicNote, ""),
DisplayIndex: server.DisplayIndex,
Host: utils.IfOr(authorized, server.Host, server.Host.Filter()),
State: server.State,
CountryCode: countryCode,
LastActive: server.LastActive,
})
}
return utils.Json.Marshal(model.StreamServerData{
Now: time.Now().Unix() * 1000,
Online: singleton.GetOnlineUserCount(),
Servers: servers,
})
})
return v.([]byte), err
}