Skip to content

Commit

Permalink
Outbound queue improvements (#4084)
Browse files Browse the repository at this point in the history
This extends the previous work in #3733 with the following:

1. Remove buffer coalescing, as this could result in a race condition
during the `writev` syscall in rare circumstances
2. Add a third buffer size, to ensure that we aren't allocating more
than we need now that coalescing is removed
3. Refactor buffer handling in the WebSocket code to reduce allocations
and ensure owned buffers aren't incorrectly being pooled resulting in
further race conditions

Fixes nats-io/nats.ws#194.

Signed-off-by: Neil Twigg <neil@nats.io>
  • Loading branch information
derekcollison committed Apr 21, 2023
2 parents e96ae0b + 5f88434 commit 01041ca
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 108 deletions.
68 changes: 33 additions & 35 deletions server/client.go
Expand Up @@ -302,7 +302,8 @@ type outbound struct {
stc chan struct{} // Stall chan we create to slow down producers on overrun, e.g. fan-in.
}

// Underlying array sizes for the three tiers of pooled outbound buffers.
// Three tiers keep small writes (e.g. pings) from pinning a large backing
// array now that queued buffers are no longer coalesced.
const (
	nbPoolSizeSmall  = 512   // Underlying array size of small buffer
	nbPoolSizeMedium = 4096  // Underlying array size of medium buffer
	nbPoolSizeLarge  = 65536 // Underlying array size of large buffer
)

var nbPoolSmall = &sync.Pool{
Expand All @@ -312,18 +313,44 @@ var nbPoolSmall = &sync.Pool{
},
}

// nbPoolMedium recycles pointers to fixed-size medium buffer arrays.
var nbPoolMedium = &sync.Pool{
	New: func() any {
		return new([nbPoolSizeMedium]byte)
	},
}

// nbPoolLarge recycles pointers to fixed-size large buffer arrays.
var nbPoolLarge = &sync.Pool{
	New: func() any {
		return new([nbPoolSizeLarge]byte)
	},
}

// nbPoolGet returns a zero-length []byte backed by the smallest pool tier
// whose underlying array can hold sz bytes. Requests larger than the large
// tier still return a large buffer; callers are expected to chunk the data
// across multiple buffers. Buffers should be returned via nbPoolPut.
func nbPoolGet(sz int) []byte {
	// Direct returns per tier; the previous version accumulated into a
	// local named "new", which shadowed the builtin of the same name.
	switch {
	case sz <= nbPoolSizeSmall:
		return nbPoolSmall.Get().(*[nbPoolSizeSmall]byte)[:0]
	case sz <= nbPoolSizeMedium:
		return nbPoolMedium.Get().(*[nbPoolSizeMedium]byte)[:0]
	default:
		return nbPoolLarge.Get().(*[nbPoolSizeLarge]byte)[:0]
	}
}

func nbPoolPut(b []byte) {
switch cap(b) {
case nbPoolSizeSmall:
b := (*[nbPoolSizeSmall]byte)(b[0:nbPoolSizeSmall])
nbPoolSmall.Put(b)
case nbPoolSizeMedium:
b := (*[nbPoolSizeMedium]byte)(b[0:nbPoolSizeMedium])
nbPoolMedium.Put(b)
case nbPoolSizeLarge:
b := (*[nbPoolSizeLarge]byte)(b[0:nbPoolSizeLarge])
nbPoolLarge.Put(b)
Expand Down Expand Up @@ -1481,7 +1508,7 @@ func (c *client) flushOutbound() bool {
if err != nil && err != io.ErrShortWrite {
// Handle timeout error (slow consumer) differently
if ne, ok := err.(net.Error); ok && ne.Timeout() {
if closed := c.handleWriteTimeout(n, attempted, len(c.out.nb)); closed {
if closed := c.handleWriteTimeout(n, attempted, len(orig)); closed {
return true
}
} else {
Expand Down Expand Up @@ -2014,43 +2041,14 @@ func (c *client) queueOutbound(data []byte) {
// without affecting the original "data" slice.
toBuffer := data

// All of the queued []byte have a fixed capacity, so if there's a []byte
// at the tail of the buffer list that isn't full yet, we should top that
// up first. This helps to ensure we aren't pulling more []bytes from the
// pool than we need to.
if len(c.out.nb) > 0 {
last := &c.out.nb[len(c.out.nb)-1]
if free := cap(*last) - len(*last); free > 0 {
if l := len(toBuffer); l < free {
free = l
}
*last = append(*last, toBuffer[:free]...)
toBuffer = toBuffer[free:]
}
}

// Now we can push the rest of the data into new []bytes from the pool
// in fixed size chunks. This ensures we don't go over the capacity of any
// of the buffers and end up reallocating.
for len(toBuffer) > 0 {
var new []byte
if len(c.out.nb) == 0 && len(toBuffer) <= nbPoolSizeSmall {
// If the buffer is empty, try to allocate a small buffer if the
// message will fit in it. This will help for cases like pings.
new = nbPoolSmall.Get().(*[nbPoolSizeSmall]byte)[:0]
} else {
// If "nb" isn't empty, default to large buffers in all cases as
// this means we are always coalescing future messages into
// larger buffers. Reduces the number of buffers into writev.
new = nbPoolLarge.Get().(*[nbPoolSizeLarge]byte)[:0]
}
l := len(toBuffer)
if c := cap(new); l > c {
l = c
}
new = append(new, toBuffer[:l]...)
c.out.nb = append(c.out.nb, new)
toBuffer = toBuffer[l:]
new := nbPoolGet(len(toBuffer))
n := copy(new[:cap(new)], toBuffer)
c.out.nb = append(c.out.nb, new[:n])
toBuffer = toBuffer[n:]
}

// Check for slow consumer via pending bytes limit.
Expand Down
60 changes: 0 additions & 60 deletions server/client_test.go
Expand Up @@ -1483,66 +1483,6 @@ func TestWildcardCharsInLiteralSubjectWorks(t *testing.T) {
}
}

// This test ensures that coalescing into the fixed-size output
// queues works as expected. When bytes are queued up, they should
// not overflow a buffer until the capacity is exceeded, at which
// point a new buffer should be added.
//
// NOTE(review): this exercises the tail-buffer top-up path in
// queueOutbound; the commit message above says coalescing is being
// removed (race during writev), which is presumably why this test is
// deleted in this change.
func TestClientOutboundQueueCoalesce(t *testing.T) {
opts := DefaultOptions()
s := RunServer(opts)
defer s.Shutdown()

// Connect one client so exactly one server-side client struct exists
// in the global account below.
nc, err := nats.Connect(fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port))
if err != nil {
t.Fatalf("Error on connect: %v", err)
}
defer nc.Close()

clients := s.GlobalAccount().getClients()
if len(clients) != 1 {
t.Fatal("Expecting a client to exist")
}
client := clients[0]
// Hold the client lock for the whole test since we read and mutate the
// internal outbound state (client.out.nb) directly.
client.mu.Lock()
defer client.mu.Unlock()

// First up, queue something small into the queue.
client.queueOutbound([]byte{1, 2, 3, 4, 5})

if len(client.out.nb) != 1 {
t.Fatal("Expecting a single queued buffer")
}
if l := len(client.out.nb[0]); l != 5 {
t.Fatalf("Expecting only 5 bytes in the first queued buffer, found %d instead", l)
}

// Then queue up a few more bytes, but not enough
// to overflow into the next buffer.
client.queueOutbound([]byte{6, 7, 8, 9, 10})

if len(client.out.nb) != 1 {
t.Fatal("Expecting a single queued buffer")
}
if l := len(client.out.nb[0]); l != 10 {
t.Fatalf("Expecting 10 bytes in the first queued buffer, found %d instead", l)
}

// Finally, queue up something that is guaranteed
// to overflow.
// A full-capacity small buffer cannot fit entirely into the remaining
// space of the first queued buffer, so the tail must spill into a
// second buffer.
b := nbPoolSmall.Get().(*[nbPoolSizeSmall]byte)[:]
b = b[:cap(b)]
client.queueOutbound(b)
if len(client.out.nb) != 2 {
t.Fatal("Expecting buffer to have overflowed")
}
if l := len(client.out.nb[0]); l != cap(b) {
t.Fatalf("Expecting %d bytes in the first queued buffer, found %d instead", cap(b), l)
}
if l := len(client.out.nb[1]); l != 10 {
t.Fatalf("Expecting 10 bytes in the second queued buffer, found %d instead", l)
}
}

// This test ensures that outbound queues don't cause a run on
// memory when sending something to lots of clients.
func TestClientOutboundQueueMemory(t *testing.T) {
Expand Down
44 changes: 31 additions & 13 deletions server/websocket.go
Expand Up @@ -452,7 +452,9 @@ func (c *client) wsHandleControlFrame(r *wsReadInfo, frameType wsOpCode, nc io.R
}
}
}
c.wsEnqueueControlMessage(wsCloseMessage, wsCreateCloseMessage(status, body))
clm := wsCreateCloseMessage(status, body)
c.wsEnqueueControlMessage(wsCloseMessage, clm)
nbPoolPut(clm) // wsEnqueueControlMessage has taken a copy.
// Return io.EOF so that readLoop will close the connection as ClientClosed
// after processing pending buffers.
return pos, io.EOF
Expand Down Expand Up @@ -502,7 +504,7 @@ func wsIsControlFrame(frameType wsOpCode) bool {
// Create the frame header.
// Encodes the frame type and optional compression flag, and the size of the
// payload, returning the sized header slice and the masking key (if any).
// The header comes from the buffer pool (nbPoolGet), so the caller is
// responsible for returning it with nbPoolPut once it has been written out.
func wsCreateFrameHeader(useMasking, compressed bool, frameType wsOpCode, l int) ([]byte, []byte) {
	// The diff rendering left both the old make() line and the new pooled
	// line; only the pooled allocation is kept here.
	fh := nbPoolGet(wsMaxFrameHeaderSize)[:wsMaxFrameHeaderSize]
	n, key := wsFillFrameHeader(fh, useMasking, wsFirstFrame, wsFinalFrame, compressed, frameType, l)
	return fh[:n], key
}
Expand Down Expand Up @@ -596,11 +598,13 @@ func (c *client) wsEnqueueControlMessageLocked(controlMsg wsOpCode, payload []by
if useMasking {
sz += 4
}
cm := make([]byte, sz+len(payload))
cm := nbPoolGet(sz + len(payload))
cm = cm[:cap(cm)]
n, key := wsFillFrameHeader(cm, useMasking, wsFirstFrame, wsFinalFrame, wsUncompressedFrame, controlMsg, len(payload))
cm = cm[:n]
// Note that payload is optional.
if len(payload) > 0 {
copy(cm[n:], payload)
cm = append(cm, payload...)
if useMasking {
wsMaskBuf(key, cm[n:])
}
Expand Down Expand Up @@ -646,6 +650,7 @@ func (c *client) wsEnqueueCloseMessage(reason ClosedState) {
}
body := wsCreateCloseMessage(status, reason.String())
c.wsEnqueueControlMessageLocked(wsCloseMessage, body)
nbPoolPut(body) // wsEnqueueControlMessageLocked has taken a copy.
}

// Create and then enqueue a close message with a protocol error and the
Expand All @@ -655,6 +660,7 @@ func (c *client) wsEnqueueCloseMessage(reason ClosedState) {
// wsHandleProtocolError creates and enqueues a close message carrying a
// protocol-error status and the given reason, then returns the reason as an
// error so the read loop can terminate the connection.
func (c *client) wsHandleProtocolError(message string) error {
	buf := wsCreateCloseMessage(wsCloseStatusProtocolError, message)
	c.wsEnqueueControlMessage(wsCloseMessage, buf)
	nbPoolPut(buf) // wsEnqueueControlMessage has taken a copy.
	// Use an explicit %s verb: fmt.Errorf(message) with a non-constant
	// format string would misinterpret any '%' in message (go vet printf).
	return fmt.Errorf("%s", message)
}

Expand All @@ -671,7 +677,7 @@ func wsCreateCloseMessage(status int, body string) []byte {
body = body[:wsMaxControlPayloadSize-5]
body += "..."
}
buf := make([]byte, 2+len(body))
buf := nbPoolGet(2 + len(body))[:2+len(body)]
// We need to have a 2 byte unsigned int that represents the error status code
// https://tools.ietf.org/html/rfc6455#section-5.5.1
binary.BigEndian.PutUint16(buf[:2], uint16(status))
Expand Down Expand Up @@ -1298,6 +1304,7 @@ func (c *client) wsCollapsePtoNB() (net.Buffers, int64) {
var csz int
for _, b := range nb {
cp.Write(b)
nbPoolPut(b) // No longer needed as contents written to compressor.
}
if err := cp.Flush(); err != nil {
c.Errorf("Error during compression: %v", err)
Expand All @@ -1314,24 +1321,33 @@ func (c *client) wsCollapsePtoNB() (net.Buffers, int64) {
} else {
final = true
}
fh := make([]byte, wsMaxFrameHeaderSize)
// Only the first frame should be marked as compressed, so pass
// `first` for the compressed boolean.
fh := nbPoolGet(wsMaxFrameHeaderSize)[:wsMaxFrameHeaderSize]
n, key := wsFillFrameHeader(fh, mask, first, final, first, wsBinaryMessage, lp)
if mask {
wsMaskBuf(key, p[:lp])
}
bufs = append(bufs, fh[:n], p[:lp])
new := nbPoolGet(wsFrameSizeForBrowsers)
lp = copy(new[:wsFrameSizeForBrowsers], p[:lp])
bufs = append(bufs, fh[:n], new[:lp])
csz += n + lp
p = p[lp:]
}
} else {
h, key := wsCreateFrameHeader(mask, true, wsBinaryMessage, len(p))
ol := len(p)
h, key := wsCreateFrameHeader(mask, true, wsBinaryMessage, ol)
if mask {
wsMaskBuf(key, p)
}
bufs = append(bufs, h, p)
csz = len(h) + len(p)
bufs = append(bufs, h)
for len(p) > 0 {
new := nbPoolGet(len(p))
n := copy(new[:cap(new)], p)
bufs = append(bufs, new[:n])
p = p[n:]
}
csz = len(h) + ol
}
// Add to pb the compressed data size (including headers), but
// remove the original uncompressed data size that was added
Expand All @@ -1343,7 +1359,7 @@ func (c *client) wsCollapsePtoNB() (net.Buffers, int64) {
if mfs > 0 {
// We are limiting the frame size.
startFrame := func() int {
bufs = append(bufs, make([]byte, wsMaxFrameHeaderSize))
bufs = append(bufs, nbPoolGet(wsMaxFrameHeaderSize)[:wsMaxFrameHeaderSize])
return len(bufs) - 1
}
endFrame := func(idx, size int) {
Expand Down Expand Up @@ -1376,8 +1392,10 @@ func (c *client) wsCollapsePtoNB() (net.Buffers, int64) {
if endStart {
fhIdx = startFrame()
}
bufs = append(bufs, b[:total])
b = b[total:]
new := nbPoolGet(total)
n := copy(new[:cap(new)], b[:total])
bufs = append(bufs, new[:n])
b = b[n:]
}
}
if total > 0 {
Expand Down

0 comments on commit 01041ca

Please sign in to comment.