Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIXED] Websocket decompression of continuation frames #755

Merged
merged 1 commit into from Jun 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions nats.go
Expand Up @@ -2482,6 +2482,13 @@ func (nc *Conn) readLoop() {
for {
buf, err := br.Read()
if err == nil {
// With websocket, it is possible that there is no error but
// also no buffer returned (either WS control message or read of a
// partial compressed message). We could call parse(buf) which
// would ignore an empty buffer, but simply go back to top of the loop.
if len(buf) == 0 {
continue
}
err = nc.parse(buf)
}
if err != nil {
Expand Down
157 changes: 101 additions & 56 deletions ws.go
Expand Up @@ -72,9 +72,6 @@ const (
// From https://tools.ietf.org/html/rfc6455#section-1.3
var wsGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")

// As per https://tools.ietf.org/html/rfc7692#section-7.2.2
// add 0x00, 0x00, 0xff, 0xff and then a final block so that flate reader
// does not report unexpected EOF.
var compressFinalBlock = []byte{0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff}

type websocketReader struct {
Expand All @@ -83,10 +80,16 @@ type websocketReader struct {
ib []byte
ff bool
fc bool
dc io.ReadCloser
dc *wsDecompressor
nc *Conn
}

type wsDecompressor struct {
flate io.ReadCloser
bufs [][]byte
off int
}

type websocketWriter struct {
w io.Writer
compress bool
Expand All @@ -97,57 +100,81 @@ type websocketWriter struct {
noMoreSend bool // if true, even if there is a Write() call, we should not send anything
}

type decompressorBuffer struct {
buf []byte
rem int
off int
final bool
}

func newDecompressorBuffer(buf []byte) *decompressorBuffer {
return &decompressorBuffer{buf: buf, rem: len(buf)}
}

func (d *decompressorBuffer) Read(p []byte) (int, error) {
if d.buf == nil {
func (d *wsDecompressor) Read(dst []byte) (int, error) {
if len(dst) == 0 {
return 0, nil
}
if len(d.bufs) == 0 {
return 0, io.EOF
}
lim := d.rem
if len(p) < lim {
lim = len(p)
copied := 0
rem := len(dst)
for buf := d.bufs[0]; buf != nil && rem > 0; {
n := len(buf[d.off:])
if n > rem {
n = rem
}
copy(dst[copied:], buf[d.off:d.off+n])
copied += n
rem -= n
d.off += n
buf = d.nextBuf()
}
n := copy(p, d.buf[d.off:d.off+lim])
d.off += n
d.rem -= n
d.checkRem()
return n, nil
return copied, nil
}

func (d *decompressorBuffer) checkRem() {
if d.rem != 0 {
return
}
if !d.final {
d.buf = compressFinalBlock
d.off = 0
d.rem = len(d.buf)
d.final = true
} else {
d.buf = nil
}
func (d *wsDecompressor) nextBuf() []byte {
// We still have remaining data in the first buffer
if d.off != len(d.bufs[0]) {
return d.bufs[0]
}
// We read the full first buffer. Reset offset.
d.off = 0
// We were at the last buffer, so we are done.
if len(d.bufs) == 1 {
d.bufs = nil
return nil
}
// Here we move to the next buffer.
d.bufs = d.bufs[1:]
return d.bufs[0]
}

func (d *decompressorBuffer) ReadByte() (byte, error) {
if d.buf == nil {
func (d *wsDecompressor) ReadByte() (byte, error) {
if len(d.bufs) == 0 {
return 0, io.EOF
}
b := d.buf[d.off]
b := d.bufs[0][d.off]
d.off++
d.rem--
d.checkRem()
d.nextBuf()
return b, nil
}

func (d *wsDecompressor) addBuf(b []byte) {
d.bufs = append(d.bufs, b)
}

func (d *wsDecompressor) decompress() ([]byte, error) {
d.off = 0
// As per https://tools.ietf.org/html/rfc7692#section-7.2.2
// add 0x00, 0x00, 0xff, 0xff and then a final block so that flate reader
// does not report unexpected EOF.
d.bufs = append(d.bufs, compressFinalBlock)
// Create or reset the decompressor with his object (wsDecompressor)
// that provides Read() and ReadByte() APIs that will consume from
// the compressed buffers (d.bufs).
if d.flate == nil {
d.flate = flate.NewReader(d)
} else {
d.flate.(flate.Resetter).Reset(d, nil)
}
// TODO: When Go 1.15 support is dropped, replace with io.ReadAll()
b, err := ioutil.ReadAll(d.flate)
// Now reset the compressed buffers list
d.bufs = nil
return b, err
}

func wsNewReader(r io.Reader) *websocketReader {
return &websocketReader{r: r, ff: true}
}
Expand Down Expand Up @@ -254,29 +281,47 @@ func (r *websocketReader) Read(p []byte) (int, error) {
}

var b []byte
// This ensures that we get the full payload for this frame.
b, pos, err = wsGet(r.r, buf, pos, rem)
if err != nil {
return 0, err
}
// We read the full frame.
rem = 0
addToPending := true
if r.fc {
br := newDecompressorBuffer(b)
if r.dc == nil {
r.dc = flate.NewReader(br)
} else {
r.dc.(flate.Resetter).Reset(br, nil)
}
// TODO: When Go 1.15 support is dropped, replace with io.ReadAll()
b, err = ioutil.ReadAll(r.dc)
if err != nil {
return 0, err
// Don't add to pending if we are not dealing with the final frame.
addToPending = r.ff
// Add the compressed payload buffer to the list.
r.addCBuf(b)
// Decompress only when this is the final frame.
if r.ff {
b, err = r.dc.decompress()
if err != nil {
return 0, err
}
r.fc = false
}
r.fc = false
}
r.pending = append(r.pending, b)
// Add to the pending list if dealing with uncompressed frames or
// after we have received the full compressed message and decompressed it.
if addToPending {
r.pending = append(r.pending, b)
}
}
// In case of compression, there may be nothing to drain
if len(r.pending) > 0 {
return r.drainPending(p), nil
}
return 0, nil
}

func (r *websocketReader) addCBuf(b []byte) {
if r.dc == nil {
r.dc = &wsDecompressor{}
}
// At this point we should have pending slices.
return r.drainPending(p), nil
// Add a copy of the incoming buffer to the list of compressed buffers.
r.dc.addBuf(append([]byte(nil), b...))
}

func (r *websocketReader) drainPending(p []byte) int {
Expand Down
88 changes: 64 additions & 24 deletions ws_test.go
Expand Up @@ -15,11 +15,13 @@ package nats

import (
"bytes"
"compress/flate"
"crypto/tls"
"encoding/binary"
"fmt"
"io"
"math/rand"
"reflect"
"strings"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -346,8 +348,8 @@ func TestWSControlFrameBetweenDataFrames(t *testing.T) {
}
}

func TestWSDecompressorBuffer(t *testing.T) {
br := newDecompressorBuffer([]byte("ABCDE"))
func TestWSDecompressor(t *testing.T) {
var br *wsDecompressor

p := make([]byte, 100)
checkRead := func(limit int, expected []byte) {
Expand Down Expand Up @@ -385,52 +387,48 @@ func TestWSDecompressorBuffer(t *testing.T) {
}
}

newDecompressor := func(str string) *wsDecompressor {
d := &wsDecompressor{}
d.addBuf([]byte(str))
return d
}

// Read with enough room
br = newDecompressor("ABCDE")
checkRead(100, []byte("ABCDE"))
checkRead(100, compressFinalBlock)
checkEOF()
checkEOFWithReadByte()

// Read with a partial from our buffer
br = newDecompressorBuffer([]byte("FGHIJ"))
br = newDecompressor("FGHIJ")
checkRead(2, []byte("FG"))
// Call with more than the end of our buffer. We will have to
// call again to start with the final block
// Call with more than the end of our buffer.
checkRead(10, []byte("HIJ"))
checkRead(10, compressFinalBlock)
checkEOF()
checkEOFWithReadByte()

// Read with a partial from our buffer
br = newDecompressorBuffer([]byte("KLMNO"))
br = newDecompressor("KLMNO")
checkRead(2, []byte("KL"))
// Call with exact number of bytes left for our buffer.
checkRead(3, []byte("MNO"))
checkRead(10, compressFinalBlock)
checkEOF()
checkEOFWithReadByte()

// Now check partial of the final block
br = newDecompressorBuffer([]byte("PQRST"))
checkRead(10, []byte("PQRST"))
checkRead(2, compressFinalBlock[:2])
checkRead(4, compressFinalBlock[2:6])
checkRead(3, compressFinalBlock[6:9])
checkEOF()
checkEOFWithReadByte()

// Finally, check ReadByte.
br = newDecompressorBuffer([]byte("UVWXYZ"))
br = newDecompressor("UVWXYZ")
checkRead(4, []byte("UVWX"))
checkReadByte('Y')
checkReadByte('Z')
checkReadByte(compressFinalBlock[0])
checkReadByte(compressFinalBlock[1])
checkRead(5, compressFinalBlock[2:7])
checkReadByte(compressFinalBlock[7])
checkReadByte(compressFinalBlock[8])
checkEOFWithReadByte()
checkEOF()

br = newDecompressor("ABC")
buf := make([]byte, 0)
n, err := br.Read(buf)
if n != 0 || err != nil {
t.Fatalf("Unexpected n=%v err=%v", n, err)
}
}

func TestWSNoMixingScheme(t *testing.T) {
Expand Down Expand Up @@ -734,6 +732,48 @@ func TestWSCompression(t *testing.T) {
}
}

func TestWSCompressionWithContinuationFrames(t *testing.T) {
uncompressed := []byte("this is an uncompressed message with AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
buf := &bytes.Buffer{}
compressor, _ := flate.NewWriter(buf, flate.BestSpeed)
compressor.Write(uncompressed)
compressor.Close()
b := buf.Bytes()
if len(b) < 30 {
panic("revisit test so that compressed buffer is more than 30 bytes long")
}

srbuf := &bytes.Buffer{}
// We are going to split this in several frames.
fh := []byte{66, 10}
srbuf.Write(fh)
srbuf.Write(b[:10])
fh = []byte{0, 10}
srbuf.Write(fh)
srbuf.Write(b[10:20])
fh = []byte{wsFinalBit, 0}
fh[1] = byte(len(b) - 20)
srbuf.Write(fh)
srbuf.Write(b[20:])

r := wsNewReader(srbuf)
rbuf := make([]byte, 100)
n, err := r.Read(rbuf[:15])
// Since we have a partial of compressed message, the library keeps track
// of buffer, but it can't return anything at this point, so n==0 err==nil
// is the expected result.
if n != 0 || err != nil {
t.Fatalf("Error reading: n=%v err=%v", n, err)
}
n, err = r.Read(rbuf)
if n != len(uncompressed) || err != nil {
t.Fatalf("Error reading: n=%v err=%v", n, err)
}
if !reflect.DeepEqual(uncompressed, rbuf[:n]) {
t.Fatalf("Unexpected uncompressed data: %v", rbuf[:n])
}
}

func TestWSWithTLS(t *testing.T) {
for _, test := range []struct {
name string
Expand Down