Skip to content

Commit

Permalink
[FIXED] Websocket decompression of continuation frames
Browse files Browse the repository at this point in the history
When compression is used, the full message payload need to be
assembled (in case there are continuation frames) before being
decompressed.

Signed-off-by: Ivan Kozlovic <ivan@synadia.com>
  • Loading branch information
kozlovic committed Jun 21, 2021
1 parent cc4f3b2 commit 9fc5222
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 80 deletions.
7 changes: 7 additions & 0 deletions nats.go
Expand Up @@ -2482,6 +2482,13 @@ func (nc *Conn) readLoop() {
for {
buf, err := br.Read()
if err == nil {
// With websocket, it is possible that there is no error but
// also no buffer returned (either WS control message or read of a
// partial compressed message). We could call parse(buf) which
// would ignore an empty buffer, but simply go back to top of the loop.
if len(buf) == 0 {
continue
}
err = nc.parse(buf)
}
if err != nil {
Expand Down
157 changes: 101 additions & 56 deletions ws.go
Expand Up @@ -72,9 +72,6 @@ const (
// From https://tools.ietf.org/html/rfc6455#section-1.3
var wsGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")

// As per https://tools.ietf.org/html/rfc7692#section-7.2.2
// add 0x00, 0x00, 0xff, 0xff and then a final block so that flate reader
// does not report unexpected EOF.
var compressFinalBlock = []byte{0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff}

type websocketReader struct {
Expand All @@ -83,10 +80,16 @@ type websocketReader struct {
ib []byte
ff bool
fc bool
dc io.ReadCloser
dc *wsDecompressor
nc *Conn
}

type wsDecompressor struct {
flate io.ReadCloser
bufs [][]byte
off int
}

type websocketWriter struct {
w io.Writer
compress bool
Expand All @@ -97,57 +100,81 @@ type websocketWriter struct {
noMoreSend bool // if true, even if there is a Write() call, we should not send anything
}

type decompressorBuffer struct {
buf []byte
rem int
off int
final bool
}

func newDecompressorBuffer(buf []byte) *decompressorBuffer {
return &decompressorBuffer{buf: buf, rem: len(buf)}
}

func (d *decompressorBuffer) Read(p []byte) (int, error) {
if d.buf == nil {
func (d *wsDecompressor) Read(dst []byte) (int, error) {
if len(dst) == 0 {
return 0, nil
}
if len(d.bufs) == 0 {
return 0, io.EOF
}
lim := d.rem
if len(p) < lim {
lim = len(p)
copied := 0
rem := len(dst)
for buf := d.bufs[0]; buf != nil && rem > 0; {
n := len(buf[d.off:])
if n > rem {
n = rem
}
copy(dst[copied:], buf[d.off:d.off+n])
copied += n
rem -= n
d.off += n
buf = d.nextBuf()
}
n := copy(p, d.buf[d.off:d.off+lim])
d.off += n
d.rem -= n
d.checkRem()
return n, nil
return copied, nil
}

func (d *decompressorBuffer) checkRem() {
if d.rem != 0 {
return
}
if !d.final {
d.buf = compressFinalBlock
d.off = 0
d.rem = len(d.buf)
d.final = true
} else {
d.buf = nil
}
func (d *wsDecompressor) nextBuf() []byte {
// We still have remaining data in the first buffer
if d.off != len(d.bufs[0]) {
return d.bufs[0]
}
// We read the full first buffer. Reset offset.
d.off = 0
// We were at the last buffer, so we are done.
if len(d.bufs) == 1 {
d.bufs = nil
return nil
}
// Here we move to the next buffer.
d.bufs = d.bufs[1:]
return d.bufs[0]
}

func (d *decompressorBuffer) ReadByte() (byte, error) {
if d.buf == nil {
func (d *wsDecompressor) ReadByte() (byte, error) {
if len(d.bufs) == 0 {
return 0, io.EOF
}
b := d.buf[d.off]
b := d.bufs[0][d.off]
d.off++
d.rem--
d.checkRem()
d.nextBuf()
return b, nil
}

func (d *wsDecompressor) addBuf(b []byte) {
d.bufs = append(d.bufs, b)
}

func (d *wsDecompressor) decompress() ([]byte, error) {
d.off = 0
// As per https://tools.ietf.org/html/rfc7692#section-7.2.2
// add 0x00, 0x00, 0xff, 0xff and then a final block so that flate reader
// does not report unexpected EOF.
d.bufs = append(d.bufs, compressFinalBlock)
// Create or reset the decompressor with his object (wsDecompressor)
// that provides Read() and ReadByte() APIs that will consume from
// the compressed buffers (d.bufs).
if d.flate == nil {
d.flate = flate.NewReader(d)
} else {
d.flate.(flate.Resetter).Reset(d, nil)
}
// TODO: When Go 1.15 support is dropped, replace with io.ReadAll()
b, err := ioutil.ReadAll(d.flate)
// Now reset the compressed buffers list
d.bufs = nil
return b, err
}

func wsNewReader(r io.Reader) *websocketReader {
return &websocketReader{r: r, ff: true}
}
Expand Down Expand Up @@ -254,29 +281,47 @@ func (r *websocketReader) Read(p []byte) (int, error) {
}

var b []byte
// This ensures that we get the full payload for this frame.
b, pos, err = wsGet(r.r, buf, pos, rem)
if err != nil {
return 0, err
}
// We read the full frame.
rem = 0
addToPending := true
if r.fc {
br := newDecompressorBuffer(b)
if r.dc == nil {
r.dc = flate.NewReader(br)
} else {
r.dc.(flate.Resetter).Reset(br, nil)
}
// TODO: When Go 1.15 support is dropped, replace with io.ReadAll()
b, err = ioutil.ReadAll(r.dc)
if err != nil {
return 0, err
// Don't add to pending if we are not dealing with the final frame.
addToPending = r.ff
// Add the compressed payload buffer to the list.
r.addCBuf(b)
// Decompress only when this is the final frame.
if r.ff {
b, err = r.dc.decompress()
if err != nil {
return 0, err
}
r.fc = false
}
r.fc = false
}
r.pending = append(r.pending, b)
// Add to the pending list if dealing with uncompressed frames or
// after we have received the full compressed message and decompressed it.
if addToPending {
r.pending = append(r.pending, b)
}
}
// In case of compression, there may be nothing to drain
if len(r.pending) > 0 {
return r.drainPending(p), nil
}
return 0, nil
}

func (r *websocketReader) addCBuf(b []byte) {
if r.dc == nil {
r.dc = &wsDecompressor{}
}
// At this point we should have pending slices.
return r.drainPending(p), nil
// Add a copy of the incoming buffer to the list of compressed buffers.
r.dc.addBuf(append([]byte(nil), b...))
}

func (r *websocketReader) drainPending(p []byte) int {
Expand Down
88 changes: 64 additions & 24 deletions ws_test.go
Expand Up @@ -15,11 +15,13 @@ package nats

import (
"bytes"
"compress/flate"
"crypto/tls"
"encoding/binary"
"fmt"
"io"
"math/rand"
"reflect"
"strings"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -346,8 +348,8 @@ func TestWSControlFrameBetweenDataFrames(t *testing.T) {
}
}

func TestWSDecompressorBuffer(t *testing.T) {
br := newDecompressorBuffer([]byte("ABCDE"))
func TestWSDecompressor(t *testing.T) {
var br *wsDecompressor

p := make([]byte, 100)
checkRead := func(limit int, expected []byte) {
Expand Down Expand Up @@ -385,52 +387,48 @@ func TestWSDecompressorBuffer(t *testing.T) {
}
}

newDecompressor := func(str string) *wsDecompressor {
d := &wsDecompressor{}
d.addBuf([]byte(str))
return d
}

// Read with enough room
br = newDecompressor("ABCDE")
checkRead(100, []byte("ABCDE"))
checkRead(100, compressFinalBlock)
checkEOF()
checkEOFWithReadByte()

// Read with a partial from our buffer
br = newDecompressorBuffer([]byte("FGHIJ"))
br = newDecompressor("FGHIJ")
checkRead(2, []byte("FG"))
// Call with more than the end of our buffer. We will have to
// call again to start with the final block
// Call with more than the end of our buffer.
checkRead(10, []byte("HIJ"))
checkRead(10, compressFinalBlock)
checkEOF()
checkEOFWithReadByte()

// Read with a partial from our buffer
br = newDecompressorBuffer([]byte("KLMNO"))
br = newDecompressor("KLMNO")
checkRead(2, []byte("KL"))
// Call with exact number of bytes left for our buffer.
checkRead(3, []byte("MNO"))
checkRead(10, compressFinalBlock)
checkEOF()
checkEOFWithReadByte()

// Now check partial of the final block
br = newDecompressorBuffer([]byte("PQRST"))
checkRead(10, []byte("PQRST"))
checkRead(2, compressFinalBlock[:2])
checkRead(4, compressFinalBlock[2:6])
checkRead(3, compressFinalBlock[6:9])
checkEOF()
checkEOFWithReadByte()

// Finally, check ReadByte.
br = newDecompressorBuffer([]byte("UVWXYZ"))
br = newDecompressor("UVWXYZ")
checkRead(4, []byte("UVWX"))
checkReadByte('Y')
checkReadByte('Z')
checkReadByte(compressFinalBlock[0])
checkReadByte(compressFinalBlock[1])
checkRead(5, compressFinalBlock[2:7])
checkReadByte(compressFinalBlock[7])
checkReadByte(compressFinalBlock[8])
checkEOFWithReadByte()
checkEOF()

br = newDecompressor("ABC")
buf := make([]byte, 0)
n, err := br.Read(buf)
if n != 0 || err != nil {
t.Fatalf("Unexpected n=%v err=%v", n, err)
}
}

func TestWSNoMixingScheme(t *testing.T) {
Expand Down Expand Up @@ -734,6 +732,48 @@ func TestWSCompression(t *testing.T) {
}
}

func TestWSCompressionWithContinuationFrames(t *testing.T) {
uncompressed := []byte("this is an uncompressed message with AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
buf := &bytes.Buffer{}
compressor, _ := flate.NewWriter(buf, flate.BestSpeed)
compressor.Write(uncompressed)
compressor.Close()
b := buf.Bytes()
if len(b) < 30 {
panic("revisit test so that compressed buffer is more than 30 bytes long")
}

srbuf := &bytes.Buffer{}
// We are going to split this in several frames.
fh := []byte{66, 10}
srbuf.Write(fh)
srbuf.Write(b[:10])
fh = []byte{0, 10}
srbuf.Write(fh)
srbuf.Write(b[10:20])
fh = []byte{wsFinalBit, 0}
fh[1] = byte(len(b) - 20)
srbuf.Write(fh)
srbuf.Write(b[20:])

r := wsNewReader(srbuf)
rbuf := make([]byte, 100)
n, err := r.Read(rbuf[:15])
// Since we have a partial of compressed message, the library keeps track
// of buffer, but it can't return anything at this point, so n==0 err==nil
// is the expected result.
if n != 0 || err != nil {
t.Fatalf("Error reading: n=%v err=%v", n, err)
}
n, err = r.Read(rbuf)
if n != len(uncompressed) || err != nil {
t.Fatalf("Error reading: n=%v err=%v", n, err)
}
if !reflect.DeepEqual(uncompressed, rbuf[:n]) {
t.Fatalf("Unexpected uncompressed data: %v", rbuf[:n])
}
}

func TestWSWithTLS(t *testing.T) {
for _, test := range []struct {
name string
Expand Down

0 comments on commit 9fc5222

Please sign in to comment.