allow cors from loopback addresses in debug mode (#9)

This commit is contained in:
UUBulb 2024-11-18 13:26:41 +08:00 committed by GitHub
parent 8eec79d54f
commit be79b11e58
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 43 additions and 6 deletions

View File

@ -76,7 +76,7 @@ func fmStream(c *gin.Context) (any, error) {
wsConn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
return nil, err
return nil, newWsError("%v", err)
}
defer wsConn.Close()
conn := websocketx.NewConn(wsConn)

View File

@ -76,7 +76,7 @@ func terminalStream(c *gin.Context) (any, error) {
wsConn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
return nil, err
return nil, newWsError("%v", err)
}
defer wsConn.Close()
conn := websocketx.NewConn(wsConn)

View File

@ -2,6 +2,8 @@ package controller
import (
"fmt"
"net"
"net/http"
"time"
"github.com/gin-gonic/gin"
@ -13,9 +15,43 @@ import (
"github.com/naiba/nezha/service/singleton"
)
var upgrader = websocket.Upgrader{
ReadBufferSize: 32768,
WriteBufferSize: 32768,
var upgrader *websocket.Upgrader
func InitUpgrader() {
var checkOrigin func(r *http.Request) bool
// Allow CORS from loopback addresses in debug mode
if singleton.Conf.Debug {
checkOrigin = func(r *http.Request) bool {
hostAddr := r.Host
host, _, err := net.SplitHostPort(hostAddr)
if err != nil {
return false
}
if ip := net.ParseIP(host); ip != nil {
if ip.IsLoopback() {
return true
}
} else {
// Handle domains like "localhost"
ip, err := net.LookupHost(host)
if err != nil || len(ip) == 0 {
return false
}
if netIP := net.ParseIP(ip[0]); netIP != nil && netIP.IsLoopback() {
return true
}
}
return false
}
}
upgrader = &websocket.Upgrader{
ReadBufferSize: 32768,
WriteBufferSize: 32768,
CheckOrigin: checkOrigin,
}
}
// Websocket server stream
@ -30,7 +66,7 @@ var upgrader = websocket.Upgrader{
func serverStream(c *gin.Context) (any, error) {
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
return nil, err
return nil, newWsError("%v", err)
}
defer conn.Close()
count := 0

View File

@ -121,6 +121,7 @@ func main() {
grpcHandler := rpc.ServeRPC()
httpHandler := controller.ServeWeb()
controller.InitUpgrader()
muxHandler := newHTTPandGRPCMux(httpHandler, grpcHandler)
http2Server := &http2.Server{}