Skip to content

Commit

Permalink
Merge pull request #755 from nats-io/fix_ws_compression
Browse files Browse the repository at this point in the history
[FIXED] Websocket decompression of continuation frames
  • Loading branch information
kozlovic committed Jun 21, 2021
2 parents cc4f3b2 + 9fc5222 commit 5a4a543
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 80 deletions.
7 changes: 7 additions & 0 deletions nats.go
Expand Up @@ -2482,6 +2482,13 @@ func (nc *Conn) readLoop() {
for {
buf, err := br.Read()
if err == nil {
// With websocket, it is possible that there is no error but
// also no buffer returned (either WS control message or read of a
// partial compressed message). We could call parse(buf) which
// would ignore an empty buffer, but simply go back to top of the loop.
if len(buf) == 0 {
continue
}
err = nc.parse(buf)
}
if err != nil {
Expand Down
157 changes: 101 additions & 56 deletions ws.go
Expand Up @@ -72,9 +72,6 @@ const (
// From https://tools.ietf.org/html/rfc6455#section-1.3
var wsGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")

// As per https://tools.ietf.org/html/rfc7692#section-7.2.2
// add 0x00, 0x00, 0xff, 0xff and then a final block so that flate reader
// does not report unexpected EOF.
var compressFinalBlock = []byte{0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff}

type websocketReader struct {
Expand All @@ -83,10 +80,16 @@ type websocketReader struct {
ib []byte
ff bool
fc bool
dc io.ReadCloser
dc *wsDecompressor
nc *Conn
}

type wsDecompressor struct {
flate io.ReadCloser
bufs [][]byte
off int
}

type websocketWriter struct {
w io.Writer
compress bool
Expand All @@ -97,57 +100,81 @@ type websocketWriter struct {
noMoreSend bool // if true, even if there is a Write() call, we should not send anything
}

type decompressorBuffer struct {
buf []byte
rem int
off int
final bool
}

func newDecompressorBuffer(buf []byte) *decompressorBuffer {
return &decompressorBuffer{buf: buf, rem: len(buf)}
}

func (d *decompressorBuffer) Read(p []byte) (int, error) {
if d.buf == nil {
func (d *wsDecompressor) Read(dst []byte) (int, error) {
if len(dst) == 0 {
return 0, nil
}
if len(d.bufs) == 0 {
return 0, io.EOF
}
lim := d.rem
if len(p) < lim {
lim = len(p)
copied := 0
rem := len(dst)
for buf := d.bufs[0]; buf != nil && rem > 0; {
n := len(buf[d.off:])
if n > rem {
n = rem
}
copy(dst[copied:], buf[d.off:d.off+n])
copied += n
rem -= n
d.off += n
buf = d.nextBuf()
}
n := copy(p, d.buf[d.off:d.off+lim])
d.off += n
d.rem -= n
d.checkRem()
return n, nil
return copied, nil
}

func (d *decompressorBuffer) checkRem() {
if d.rem != 0 {
return
}
if !d.final {
d.buf = compressFinalBlock
d.off = 0
d.rem = len(d.buf)
d.final = true
} else {
d.buf = nil
}
func (d *wsDecompressor) nextBuf() []byte {
// We still have remaining data in the first buffer
if d.off != len(d.bufs[0]) {
return d.bufs[0]
}
// We read the full first buffer. Reset offset.
d.off = 0
// We were at the last buffer, so we are done.
if len(d.bufs) == 1 {
d.bufs = nil
return nil
}
// Here we move to the next buffer.
d.bufs = d.bufs[1:]
return d.bufs[0]
}

func (d *decompressorBuffer) ReadByte() (byte, error) {
if d.buf == nil {
func (d *wsDecompressor) ReadByte() (byte, error) {
if len(d.bufs) == 0 {
return 0, io.EOF
}
b := d.buf[d.off]
b := d.bufs[0][d.off]
d.off++
d.rem--
d.checkRem()
d.nextBuf()
return b, nil
}

func (d *wsDecompressor) addBuf(b []byte) {
d.bufs = append(d.bufs, b)
}

func (d *wsDecompressor) decompress() ([]byte, error) {
d.off = 0
// As per https://tools.ietf.org/html/rfc7692#section-7.2.2
// add 0x00, 0x00, 0xff, 0xff and then a final block so that flate reader
// does not report unexpected EOF.
d.bufs = append(d.bufs, compressFinalBlock)
// Create or reset the decompressor with his object (wsDecompressor)
// that provides Read() and ReadByte() APIs that will consume from
// the compressed buffers (d.bufs).
if d.flate == nil {
d.flate = flate.NewReader(d)
} else {
d.flate.(flate.Resetter).Reset(d, nil)
}
// TODO: When Go 1.15 support is dropped, replace with io.ReadAll()
b, err := ioutil.ReadAll(d.flate)
// Now reset the compressed buffers list
d.bufs = nil
return b, err
}

func wsNewReader(r io.Reader) *websocketReader {
return &websocketReader{r: r, ff: true}
}
Expand Down Expand Up @@ -254,29 +281,47 @@ func (r *websocketReader) Read(p []byte) (int, error) {
}

var b []byte
// This ensures that we get the full payload for this frame.
b, pos, err = wsGet(r.r, buf, pos, rem)
if err != nil {
return 0, err
}
// We read the full frame.
rem = 0
addToPending := true
if r.fc {
br := newDecompressorBuffer(b)
if r.dc == nil {
r.dc = flate.NewReader(br)
} else {
r.dc.(flate.Resetter).Reset(br, nil)
}
// TODO: When Go 1.15 support is dropped, replace with io.ReadAll()
b, err = ioutil.ReadAll(r.dc)
if err != nil {
return 0, err
// Don't add to pending if we are not dealing with the final frame.
addToPending = r.ff
// Add the compressed payload buffer to the list.
r.addCBuf(b)
// Decompress only when this is the final frame.
if r.ff {
b, err = r.dc.decompress()
if err != nil {
return 0, err
}
r.fc = false
}
r.fc = false
}
r.pending = append(r.pending, b)
// Add to the pending list if dealing with uncompressed frames or
// after we have received the full compressed message and decompressed it.
if addToPending {
r.pending = append(r.pending, b)
}
}
// In case of compression, there may be nothing to drain
if len(r.pending) > 0 {
return r.drainPending(p), nil
}
return 0, nil
}

func (r *websocketReader) addCBuf(b []byte) {
if r.dc == nil {
r.dc = &wsDecompressor{}
}
// At this point we should have pending slices.
return r.drainPending(p), nil
// Add a copy of the incoming buffer to the list of compressed buffers.
r.dc.addBuf(append([]byte(nil), b...))
}

func (r *websocketReader) drainPending(p []byte) int {
Expand Down
88 changes: 64 additions & 24 deletions ws_test.go
Expand Up @@ -15,11 +15,13 @@ package nats

import (
"bytes"
"compress/flate"
"crypto/tls"
"encoding/binary"
"fmt"
"io"
"math/rand"
"reflect"
"strings"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -346,8 +348,8 @@ func TestWSControlFrameBetweenDataFrames(t *testing.T) {
}
}

func TestWSDecompressorBuffer(t *testing.T) {
br := newDecompressorBuffer([]byte("ABCDE"))
func TestWSDecompressor(t *testing.T) {
var br *wsDecompressor

p := make([]byte, 100)
checkRead := func(limit int, expected []byte) {
Expand Down Expand Up @@ -385,52 +387,48 @@ func TestWSDecompressorBuffer(t *testing.T) {
}
}

newDecompressor := func(str string) *wsDecompressor {
d := &wsDecompressor{}
d.addBuf([]byte(str))
return d
}

// Read with enough room
br = newDecompressor("ABCDE")
checkRead(100, []byte("ABCDE"))
checkRead(100, compressFinalBlock)
checkEOF()
checkEOFWithReadByte()

// Read with a partial from our buffer
br = newDecompressorBuffer([]byte("FGHIJ"))
br = newDecompressor("FGHIJ")
checkRead(2, []byte("FG"))
// Call with more than the end of our buffer. We will have to
// call again to start with the final block
// Call with more than the end of our buffer.
checkRead(10, []byte("HIJ"))
checkRead(10, compressFinalBlock)
checkEOF()
checkEOFWithReadByte()

// Read with a partial from our buffer
br = newDecompressorBuffer([]byte("KLMNO"))
br = newDecompressor("KLMNO")
checkRead(2, []byte("KL"))
// Call with exact number of bytes left for our buffer.
checkRead(3, []byte("MNO"))
checkRead(10, compressFinalBlock)
checkEOF()
checkEOFWithReadByte()

// Now check partial of the final block
br = newDecompressorBuffer([]byte("PQRST"))
checkRead(10, []byte("PQRST"))
checkRead(2, compressFinalBlock[:2])
checkRead(4, compressFinalBlock[2:6])
checkRead(3, compressFinalBlock[6:9])
checkEOF()
checkEOFWithReadByte()

// Finally, check ReadByte.
br = newDecompressorBuffer([]byte("UVWXYZ"))
br = newDecompressor("UVWXYZ")
checkRead(4, []byte("UVWX"))
checkReadByte('Y')
checkReadByte('Z')
checkReadByte(compressFinalBlock[0])
checkReadByte(compressFinalBlock[1])
checkRead(5, compressFinalBlock[2:7])
checkReadByte(compressFinalBlock[7])
checkReadByte(compressFinalBlock[8])
checkEOFWithReadByte()
checkEOF()

br = newDecompressor("ABC")
buf := make([]byte, 0)
n, err := br.Read(buf)
if n != 0 || err != nil {
t.Fatalf("Unexpected n=%v err=%v", n, err)
}
}

func TestWSNoMixingScheme(t *testing.T) {
Expand Down Expand Up @@ -734,6 +732,48 @@ func TestWSCompression(t *testing.T) {
}
}

func TestWSCompressionWithContinuationFrames(t *testing.T) {
uncompressed := []byte("this is an uncompressed message with AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
buf := &bytes.Buffer{}
compressor, _ := flate.NewWriter(buf, flate.BestSpeed)
compressor.Write(uncompressed)
compressor.Close()
b := buf.Bytes()
if len(b) < 30 {
panic("revisit test so that compressed buffer is more than 30 bytes long")
}

srbuf := &bytes.Buffer{}
// We are going to split this in several frames.
fh := []byte{66, 10}
srbuf.Write(fh)
srbuf.Write(b[:10])
fh = []byte{0, 10}
srbuf.Write(fh)
srbuf.Write(b[10:20])
fh = []byte{wsFinalBit, 0}
fh[1] = byte(len(b) - 20)
srbuf.Write(fh)
srbuf.Write(b[20:])

r := wsNewReader(srbuf)
rbuf := make([]byte, 100)
n, err := r.Read(rbuf[:15])
// Since we have a partial of compressed message, the library keeps track
// of buffer, but it can't return anything at this point, so n==0 err==nil
// is the expected result.
if n != 0 || err != nil {
t.Fatalf("Error reading: n=%v err=%v", n, err)
}
n, err = r.Read(rbuf)
if n != len(uncompressed) || err != nil {
t.Fatalf("Error reading: n=%v err=%v", n, err)
}
if !reflect.DeepEqual(uncompressed, rbuf[:n]) {
t.Fatalf("Unexpected uncompressed data: %v", rbuf[:n])
}
}

func TestWSWithTLS(t *testing.T) {
for _, test := range []struct {
name string
Expand Down

0 comments on commit 5a4a543

Please sign in to comment.