mirror of
https://github.com/nezhahq/nezha.git
synced 2025-02-02 01:28:13 -05:00
prevent writing response to websocket connections (#457)
This commit is contained in:
parent
d086e98711
commit
380973a200
@ -136,6 +136,22 @@ func (ge *gormError) Error() string {
|
|||||||
return fmt.Sprintf(ge.msg, ge.a...)
|
return fmt.Sprintf(ge.msg, ge.a...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type wsError struct {
|
||||||
|
msg string
|
||||||
|
a []interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newWsError(format string, args ...interface{}) error {
|
||||||
|
return &wsError{
|
||||||
|
msg: format,
|
||||||
|
a: args,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (we *wsError) Error() string {
|
||||||
|
return fmt.Sprintf(we.msg, we.a...)
|
||||||
|
}
|
||||||
|
|
||||||
func commonHandler[T any](handler handlerFunc[T]) func(*gin.Context) {
|
func commonHandler[T any](handler handlerFunc[T]) func(*gin.Context) {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
data, err := handler(c)
|
data, err := handler(c)
|
||||||
@ -143,11 +159,18 @@ func commonHandler[T any](handler handlerFunc[T]) func(*gin.Context) {
|
|||||||
c.JSON(http.StatusOK, model.CommonResponse[T]{Success: true, Data: data})
|
c.JSON(http.StatusOK, model.CommonResponse[T]{Success: true, Data: data})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if _, ok := err.(*gormError); ok {
|
switch err.(type) {
|
||||||
|
case *gormError:
|
||||||
log.Printf("NEZHA>> gorm error: %v", err)
|
log.Printf("NEZHA>> gorm error: %v", err)
|
||||||
c.JSON(http.StatusOK, newErrorResponse(errors.New("database error")))
|
c.JSON(http.StatusOK, newErrorResponse(errors.New("database error")))
|
||||||
return
|
return
|
||||||
} else {
|
case *wsError:
|
||||||
|
// Connection is upgraded to WebSocket, so c.Writer is no longer usable
|
||||||
|
if msg := err.Error(); msg != "" {
|
||||||
|
log.Printf("NEZHA>> websocket error: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
default:
|
||||||
c.JSON(http.StatusOK, newErrorResponse(err))
|
c.JSON(http.StatusOK, newErrorResponse(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -21,7 +21,7 @@ import (
|
|||||||
// @Description Create an "attached" FM. It is advised to only call this within a terminal session.
|
// @Description Create an "attached" FM. It is advised to only call this within a terminal session.
|
||||||
// @Tags auth required
|
// @Tags auth required
|
||||||
// @Accept json
|
// @Accept json
|
||||||
// @Param id path uint true "Server ID"
|
// @Param id query uint true "Server ID"
|
||||||
// @Produce json
|
// @Produce json
|
||||||
// @Success 200 {object} model.CreateFMResponse
|
// @Success 200 {object} model.CreateFMResponse
|
||||||
// @Router /file [get]
|
// @Router /file [get]
|
||||||
@ -66,6 +66,7 @@ func createFM(c *gin.Context) (*model.CreateFMResponse, error) {
|
|||||||
// @Description Start FM stream
|
// @Description Start FM stream
|
||||||
// @Tags auth required
|
// @Tags auth required
|
||||||
// @Param id path string true "Stream UUID"
|
// @Param id path string true "Stream UUID"
|
||||||
|
// @Success 200 {object} model.CommonResponse[any]
|
||||||
// @Router /ws/file/{id} [get]
|
// @Router /ws/file/{id} [get]
|
||||||
func fmStream(c *gin.Context) (any, error) {
|
func fmStream(c *gin.Context) (any, error) {
|
||||||
streamId := c.Param("id")
|
streamId := c.Param("id")
|
||||||
@ -92,8 +93,8 @@ func fmStream(c *gin.Context) (any, error) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
if err = rpc.NezhaHandlerSingleton.UserConnected(streamId, conn); err != nil {
|
if err = rpc.NezhaHandlerSingleton.UserConnected(streamId, conn); err != nil {
|
||||||
return nil, err
|
return nil, newWsError("%v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, rpc.NezhaHandlerSingleton.StartStream(streamId, time.Second*10)
|
return nil, newWsError("%v", rpc.NezhaHandlerSingleton.StartStream(streamId, time.Second*10))
|
||||||
}
|
}
|
||||||
|
@ -66,6 +66,7 @@ func createTerminal(c *gin.Context) (*model.CreateTerminalResponse, error) {
|
|||||||
// @Description Terminal stream
|
// @Description Terminal stream
|
||||||
// @Tags auth required
|
// @Tags auth required
|
||||||
// @Param id path string true "Stream UUID"
|
// @Param id path string true "Stream UUID"
|
||||||
|
// @Success 200 {object} model.CommonResponse[any]
|
||||||
// @Router /ws/terminal/{id} [get]
|
// @Router /ws/terminal/{id} [get]
|
||||||
func terminalStream(c *gin.Context) (any, error) {
|
func terminalStream(c *gin.Context) (any, error) {
|
||||||
streamId := c.Param("id")
|
streamId := c.Param("id")
|
||||||
@ -92,8 +93,8 @@ func terminalStream(c *gin.Context) (any, error) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
if err = rpc.NezhaHandlerSingleton.UserConnected(streamId, conn); err != nil {
|
if err = rpc.NezhaHandlerSingleton.UserConnected(streamId, conn); err != nil {
|
||||||
return nil, err
|
return nil, newWsError("%v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, rpc.NezhaHandlerSingleton.StartStream(streamId, time.Second*10)
|
return nil, newWsError("%v", rpc.NezhaHandlerSingleton.StartStream(streamId, time.Second*10))
|
||||||
}
|
}
|
||||||
|
@ -51,7 +51,7 @@ func serverStream(c *gin.Context) (any, error) {
|
|||||||
}
|
}
|
||||||
time.Sleep(time.Second * 2)
|
time.Sleep(time.Second * 2)
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, newWsError("")
|
||||||
}
|
}
|
||||||
|
|
||||||
var requestGroup singleflight.Group
|
var requestGroup singleflight.Group
|
||||||
|
Loading…
Reference in New Issue
Block a user