From be79b11e582cbe5b9f3d15e9a0e3abda47e01ce2 Mon Sep 17 00:00:00 2001 From: UUBulb <35923940+uubulb@users.noreply.github.com> Date: Mon, 18 Nov 2024 13:26:41 +0800 Subject: [PATCH] allow cors from loopback addresses in debug mode (#9) --- cmd/dashboard/controller/fm.go | 2 +- cmd/dashboard/controller/terminal.go | 2 +- cmd/dashboard/controller/ws.go | 44 +++++++++++++++++++++++++--- cmd/dashboard/main.go | 1 + 4 files changed, 43 insertions(+), 6 deletions(-) diff --git a/cmd/dashboard/controller/fm.go b/cmd/dashboard/controller/fm.go index d03a41e..bdcd7a2 100644 --- a/cmd/dashboard/controller/fm.go +++ b/cmd/dashboard/controller/fm.go @@ -76,7 +76,7 @@ func fmStream(c *gin.Context) (any, error) { wsConn, err := upgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { - return nil, err + return nil, newWsError("%v", err) } defer wsConn.Close() conn := websocketx.NewConn(wsConn) diff --git a/cmd/dashboard/controller/terminal.go b/cmd/dashboard/controller/terminal.go index d1c577e..67466ba 100644 --- a/cmd/dashboard/controller/terminal.go +++ b/cmd/dashboard/controller/terminal.go @@ -76,7 +76,7 @@ func terminalStream(c *gin.Context) (any, error) { wsConn, err := upgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { - return nil, err + return nil, newWsError("%v", err) } defer wsConn.Close() conn := websocketx.NewConn(wsConn) diff --git a/cmd/dashboard/controller/ws.go b/cmd/dashboard/controller/ws.go index 62b8164..fc96bb0 100644 --- a/cmd/dashboard/controller/ws.go +++ b/cmd/dashboard/controller/ws.go @@ -2,6 +2,8 @@ package controller import ( "fmt" + "net" + "net/http" "time" "github.com/gin-gonic/gin" @@ -13,9 +15,43 @@ import ( "github.com/naiba/nezha/service/singleton" ) -var upgrader = websocket.Upgrader{ - ReadBufferSize: 32768, - WriteBufferSize: 32768, +var upgrader *websocket.Upgrader + +func InitUpgrader() { + var checkOrigin func(r *http.Request) bool + + // Allow CORS from loopback addresses in debug mode + if singleton.Conf.Debug { + checkOrigin = func(r *http.Request) bool { + hostAddr := r.Host + host, _, err := net.SplitHostPort(hostAddr) + if err != nil { + return false + } + if ip := net.ParseIP(host); ip != nil { + if ip.IsLoopback() { + return true + } + } else { + // Handle domains like "localhost" + ip, err := net.LookupHost(host) + if err != nil || len(ip) == 0 { + return false + } + if netIP := net.ParseIP(ip[0]); netIP != nil && netIP.IsLoopback() { + return true + } + } + + return false + } + } + + upgrader = &websocket.Upgrader{ + ReadBufferSize: 32768, + WriteBufferSize: 32768, + CheckOrigin: checkOrigin, + } } // Websocket server stream @@ -30,7 +66,7 @@ var upgrader = websocket.Upgrader{ func serverStream(c *gin.Context) (any, error) { conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { - return nil, err + return nil, newWsError("%v", err) } defer conn.Close() count := 0 diff --git a/cmd/dashboard/main.go b/cmd/dashboard/main.go index c9b247a..b97b70d 100644 --- a/cmd/dashboard/main.go +++ b/cmd/dashboard/main.go @@ -121,6 +121,7 @@ func main() { grpcHandler := rpc.ServeRPC() httpHandler := controller.ServeWeb() + controller.InitUpgrader() muxHandler := newHTTPandGRPCMux(httpHandler, grpcHandler) http2Server := &http2.Server{}