Skip to content

Commit

Permalink
websocket: fix message cache concurrent release problem
Browse files Browse the repository at this point in the history
  • Loading branch information
lesismal committed Apr 17, 2024
1 parent c3c7638 commit 84bf8b8
Showing 1 changed file with 25 additions and 16 deletions.
41 changes: 25 additions & 16 deletions nbhttp/websocket/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,14 +140,18 @@ func (c *Conn) CompressionEnabled() bool {
return c.compress
}

func (c *Conn) handleDataFrame(opcode MessageType, fin bool, data []byte) {
func (c *Conn) handleDataFrame(opcode MessageType, fin bool, body []byte) {
h := c.dataFrameHandler
if c.isBlockingMod {
h(c, opcode, fin, data)
h(c, opcode, fin, body)
} else {
c.Execute(func() {
h(c, opcode, fin, data)
})
if !c.Execute(func() {
h(c, opcode, fin, body)
}) {
if len(body) > 0 {
c.Engine.BodyAllocator.Free(body)
}
}
}
}

Expand Down Expand Up @@ -343,6 +347,7 @@ func (c *Conn) Parse(data []byte) error {
var err error
var body []byte
var frame []byte
var message []byte
var protocolMessage []byte
var opcode MessageType
var ok, fin, compress bool
Expand All @@ -367,7 +372,7 @@ func (c *Conn) Parse(data []byte) error {
bl := len(body)
if c.dataFrameHandler != nil {
if bl > 0 {
frame = c.Engine.BodyAllocator.Malloc(bl)
frame = allocator.Malloc(bl)
copy(frame, body)
}
if c.msgType == TextMessage && len(frame) > 0 && !c.Engine.CheckUtf8(frame) {
Expand All @@ -379,16 +384,20 @@ func (c *Conn) Parse(data []byte) error {
if c.messageHandler != nil {
if bl > 0 {
if c.message == nil {
c.message = c.Engine.BodyAllocator.Malloc(len(body))
c.message = allocator.Malloc(len(body))
copy(c.message, body)
} else {
c.message = c.Engine.BodyAllocator.Append(c.message, body...)
c.message = allocator.Append(c.message, body...)
}
}
if fin {
message = c.message
c.message = nil
}
}
case PingMessage, PongMessage, CloseMessage:
if len(body) > 0 {
protocolMessage = c.Engine.BodyAllocator.Malloc(len(body))
protocolMessage = allocator.Malloc(len(body))
copy(protocolMessage, body)
}
default:
Expand All @@ -413,23 +422,23 @@ func (c *Conn) Parse(data []byte) error {
var b []byte
var rc io.ReadCloser
if c.Engine.WebsocketDecompressor != nil {
rc = c.Engine.WebsocketDecompressor(io.MultiReader(bytes.NewBuffer(c.message), strings.NewReader(flateReaderTail)))
rc = c.Engine.WebsocketDecompressor(io.MultiReader(bytes.NewBuffer(message), strings.NewReader(flateReaderTail)))
} else {
rc = decompressReader(io.MultiReader(bytes.NewBuffer(c.message), strings.NewReader(flateReaderTail)))
rc = decompressReader(io.MultiReader(bytes.NewBuffer(message), strings.NewReader(flateReaderTail)))
}
b, err = c.readAll(rc, len(c.message)*2)
c.Engine.BodyAllocator.Free(c.message)
c.message = b
b, err = c.readAll(rc, len(message)*2)
allocator.Free(message)
message = b
rc.Close()
if err != nil {
return err
}
}
c.handleMessage(c.msgType, c.message)
c.handleMessage(c.msgType, message)
}
c.compress = false
c.expectingFragments = false
c.message = nil
message = nil
c.msgType = 0
} else {
c.expectingFragments = true
Expand Down

0 comments on commit 84bf8b8

Please sign in to comment.