From 581d92d46d1bf279dfbd160188bc3cf6498c11fe Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Wed, 27 Jul 2022 12:17:50 +0200 Subject: [PATCH 1/3] zstd: Add DecodeAllCapLimit WithDecodeAllCapLimit will limit DecodeAll to decoding cap(dst)-len(dst) bytes, or any size set in WithDecoderMaxMemory. This can be used to limit decoding to a specific maximum output size. Disabled by default. Fixes #647 --- zstd/decoder.go | 31 ++++++++++++++++++++++++++++++- zstd/decoder_options.go | 12 ++++++++++++ zstd/decoder_test.go | 32 ++++++++++++++++++++++++++++++++ zstd/framedec.go | 19 +++++++++++++++++-- 4 files changed, 91 insertions(+), 3 deletions(-) diff --git a/zstd/decoder.go b/zstd/decoder.go index d212f4737f..22419e1500 100644 --- a/zstd/decoder.go +++ b/zstd/decoder.go @@ -355,6 +355,15 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) { } if frame.FrameContentSize != fcsUnknown { if frame.FrameContentSize > d.o.maxDecodedSize-uint64(len(dst)) { + if debugDecoder { + println("decoder size exceeded:", frame.FrameContentSize, ">", d.o.maxDecodedSize-uint64(len(dst))) + } + return dst, ErrDecoderSizeExceeded + } + if d.o.limitToCap && frame.FrameContentSize > uint64(cap(dst)-len(dst)) { + if debugDecoder { + println("decoder size exceeded:", frame.FrameContentSize, ">", cap(dst)-len(dst)) + } return dst, ErrDecoderSizeExceeded } if cap(dst)-len(dst) < int(frame.FrameContentSize) { @@ -364,7 +373,7 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) { } } - if cap(dst) == 0 { + if cap(dst) == 0 && !d.o.limitToCap { // Allocate len(input) * 2 by default if nothing is provided // and we didn't get frame content size. size := len(input) * 2 @@ -493,6 +502,22 @@ func (d *Decoder) nextBlockSync() (ok bool) { d.current.err = ErrDecoderSizeExceeded return false } + if d.frame.FrameContentSize != fcsUnknown { + if !d.o.limitToCap && d.frame.FrameContentSize > d.o.maxDecodedSize { + if debugDecoder { + println("decoder size exceeded, fcs:", d.frame.FrameContentSize, "> mds", d.o.maxDecodedSize) + } + d.current.err = ErrDecoderSizeExceeded + return false + } + if d.o.limitToCap && d.frame.FrameContentSize > uint64(cap(d.frame.history.b)-len(d.frame.history.b)) { + if debugDecoder { + println("decoder size exceeded, fcs:", d.frame.FrameContentSize, "> cap", cap(d.frame.history.b)-len(d.frame.history.b)) + } + d.current.err = ErrDecoderSizeExceeded + return false + } + } d.syncStream.decodedFrame = 0 d.syncStream.inFrame = true @@ -852,6 +877,10 @@ decodeStream: } } if err == nil && d.frame.WindowSize > d.o.maxWindowSize { + if debugDecoder { + println("decoder size exceeded, fws:", d.frame.WindowSize, "> mws:", d.o.maxWindowSize) + } + err = ErrDecoderSizeExceeded } if err != nil { diff --git a/zstd/decoder_options.go b/zstd/decoder_options.go index c70e6fa0f7..666c2715fe 100644 --- a/zstd/decoder_options.go +++ b/zstd/decoder_options.go @@ -20,6 +20,7 @@ type decoderOptions struct { maxWindowSize uint64 dicts []dict ignoreChecksum bool + limitToCap bool } func (o *decoderOptions) setDefault() { @@ -114,6 +115,17 @@ func WithDecoderMaxWindow(size uint64) DOption { } } +// WithDecodeAllCapLimit will limit DecodeAll to decoding cap(dst)-len(dst) bytes, +// or any size set in WithDecoderMaxMemory. +// This can be used to limit decoding to a specific maximum output size. +// Disabled by default. +func WithDecodeAllCapLimit(b bool) DOption { + return func(o *decoderOptions) error { + o.limitToCap = b + return nil + } +} + // IgnoreChecksum allows to forcibly ignore checksum checking. func IgnoreChecksum(b bool) DOption { return func(o *decoderOptions) error { diff --git a/zstd/decoder_test.go b/zstd/decoder_test.go index 30ec9bfee4..0c7e3a779e 100644 --- a/zstd/decoder_test.go +++ b/zstd/decoder_test.go @@ -19,6 +19,7 @@ import ( "path/filepath" "reflect" "runtime" + "strconv" "strings" "sync" "testing" @@ -1900,3 +1901,34 @@ func timeout(after time.Duration) (cancel func()) { close(cc) } } + +func TestWithDecodeAllCapLimit(t *testing.T) { + enc, _ := NewWriter(nil, WithZeroFrames(true), WithWindowSize(4<<10)) + dec, _ := NewReader(nil, WithDecodeAllCapLimit(true)) + for sz := 0; sz < 1<<20; sz = (sz + 1) * 2 { + sz := sz + t.Run(strconv.Itoa(sz), func(t *testing.T) { + encoded := enc.EncodeAll(make([]byte, sz), nil) + for i := sz - 1; i < sz+1; i++ { + if i < 0 { + continue + } + const existinglen = 5 + got, err := dec.DecodeAll(encoded, make([]byte, existinglen, i+existinglen)) + if i < sz { + if err != ErrDecoderSizeExceeded { + t.Errorf("cap: %d, want %v, got %v", i, ErrDecoderSizeExceeded, err) + } + } else { + if err != nil { + t.Errorf("cap: %d, want %v, got %v", i, nil, err) + continue + } + if len(got) != existinglen+i { + t.Errorf("cap: %d, want output size %d, got %d", i, existinglen+i, len(got)) + } + } + } + }) + } +} diff --git a/zstd/framedec.go b/zstd/framedec.go index 9568a4ba31..0a93a2065a 100644 --- a/zstd/framedec.go +++ b/zstd/framedec.go @@ -353,12 +353,23 @@ func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) { // Store input length, so we only check new data. crcStart := len(dst) d.history.decoders.maxSyncLen = 0 + if d.o.limitToCap { + d.history.decoders.maxSyncLen = uint64(cap(dst) - len(dst)) + } if d.FrameContentSize != fcsUnknown { - d.history.decoders.maxSyncLen = d.FrameContentSize + uint64(len(dst)) + if !d.o.limitToCap || d.FrameContentSize+uint64(len(dst)) < d.history.decoders.maxSyncLen { + d.history.decoders.maxSyncLen = d.FrameContentSize + uint64(len(dst)) + } if d.history.decoders.maxSyncLen > d.o.maxDecodedSize { + if debugDecoder { + println("maxSyncLen:", d.history.decoders.maxSyncLen, "> maxDecodedSize:", d.o.maxDecodedSize) + } return dst, ErrDecoderSizeExceeded } - if uint64(cap(dst)) < d.history.decoders.maxSyncLen { + if debugDecoder { + println("maxSyncLen:", d.history.decoders.maxSyncLen) + } + if !d.o.limitToCap && uint64(cap(dst)-len(dst)) < d.history.decoders.maxSyncLen { // Alloc for output dst2 := make([]byte, len(dst), d.history.decoders.maxSyncLen+compressedBlockOverAlloc) copy(dst2, dst) @@ -382,6 +393,10 @@ func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) { err = ErrDecoderSizeExceeded break } + if d.o.limitToCap && len(d.history.b) > cap(dst) { + err = ErrDecoderSizeExceeded + break + } if uint64(len(d.history.b)-crcStart) > d.FrameContentSize { println("runDecoder: FrameContentSize exceeded", uint64(len(d.history.b)-crcStart), ">", d.FrameContentSize) err = ErrFrameSizeExceeded From 19b37ba85efa9bc10174aee7b41bc7fc2f19ddb1 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Wed, 27 Jul 2022 12:28:42 +0200 Subject: [PATCH 2/3] Remove changes from nextBlockSync. --- zstd/decoder.go | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/zstd/decoder.go b/zstd/decoder.go index 22419e1500..219c5c19b0 100644 --- a/zstd/decoder.go +++ b/zstd/decoder.go @@ -502,22 +502,6 @@ func (d *Decoder) nextBlockSync() (ok bool) { d.current.err = ErrDecoderSizeExceeded return false } - if d.frame.FrameContentSize != fcsUnknown { - if !d.o.limitToCap && d.frame.FrameContentSize > d.o.maxDecodedSize { - if debugDecoder { - println("decoder size exceeded, fcs:", d.frame.FrameContentSize, "> mds", d.o.maxDecodedSize) - } - d.current.err = ErrDecoderSizeExceeded - return false - } - if d.o.limitToCap && d.frame.FrameContentSize > uint64(cap(d.frame.history.b)-len(d.frame.history.b)) { - if debugDecoder { - println("decoder size exceeded, fcs:", d.frame.FrameContentSize, "> cap", cap(d.frame.history.b)-len(d.frame.history.b)) - } - d.current.err = ErrDecoderSizeExceeded - return false - } - } d.syncStream.decodedFrame = 0 d.syncStream.inFrame = true From da421eda687941dcf085d235731936e8ae677a25 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Wed, 27 Jul 2022 13:54:04 +0200 Subject: [PATCH 3/3] ADd more tests. Correct better for initial size. --- zstd/decoder.go | 10 ++++--- zstd/decoder_test.go | 64 +++++++++++++++++++++++++++++--------------- zstd/framedec.go | 4 ++- 3 files changed, 53 insertions(+), 25 deletions(-) diff --git a/zstd/decoder.go b/zstd/decoder.go index 219c5c19b0..6104eb7936 100644 --- a/zstd/decoder.go +++ b/zstd/decoder.go @@ -312,6 +312,7 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) { // Grab a block decoder and frame decoder. block := <-d.decoders frame := block.localFrame + initialSize := len(dst) defer func() { if debugDecoder { printf("re-adding decoder: %p", block) @@ -354,15 +355,15 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) { return dst, ErrWindowSizeExceeded } if frame.FrameContentSize != fcsUnknown { - if frame.FrameContentSize > d.o.maxDecodedSize-uint64(len(dst)) { + if frame.FrameContentSize > d.o.maxDecodedSize-uint64(len(dst)-initialSize) { if debugDecoder { - println("decoder size exceeded:", frame.FrameContentSize, ">", d.o.maxDecodedSize-uint64(len(dst))) + println("decoder size exceeded; fcs:", frame.FrameContentSize, "> mcs:", d.o.maxDecodedSize-uint64(len(dst)-initialSize), "len:", len(dst)) } return dst, ErrDecoderSizeExceeded } if d.o.limitToCap && frame.FrameContentSize > uint64(cap(dst)-len(dst)) { if debugDecoder { - println("decoder size exceeded:", frame.FrameContentSize, ">", cap(dst)-len(dst)) + println("decoder size exceeded; fcs:", frame.FrameContentSize, "> (cap-len)", cap(dst)-len(dst)) } return dst, ErrDecoderSizeExceeded } @@ -391,6 +392,9 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) { if err != nil { return dst, err } + if uint64(len(dst)-initialSize) > d.o.maxDecodedSize { + return dst, ErrDecoderSizeExceeded + } if len(frame.bBuf) == 0 { if debugDecoder { println("frame dbuf empty") diff --git a/zstd/decoder_test.go b/zstd/decoder_test.go index 0c7e3a779e..971b5cdb42 100644 --- a/zstd/decoder_test.go +++ b/zstd/decoder_test.go @@ -1903,30 +1903,52 @@ func timeout(after time.Duration) (cancel func()) { } func TestWithDecodeAllCapLimit(t *testing.T) { - enc, _ := NewWriter(nil, WithZeroFrames(true), WithWindowSize(4<<10)) - dec, _ := NewReader(nil, WithDecodeAllCapLimit(true)) + var encs []*Encoder + var decs []*Decoder + addEnc := func(e *Encoder, _ error) { + encs = append(encs, e) + } + addDec := func(d *Decoder, _ error) { + decs = append(decs, d) + } + addEnc(NewWriter(nil, WithZeroFrames(true), WithWindowSize(4<<10))) + addEnc(NewWriter(nil, WithEncoderConcurrency(1), WithWindowSize(4<<10))) + addEnc(NewWriter(nil, WithZeroFrames(false), WithWindowSize(4<<10))) + addEnc(NewWriter(nil, WithWindowSize(128<<10))) + addDec(NewReader(nil, WithDecodeAllCapLimit(true))) + addDec(NewReader(nil, WithDecodeAllCapLimit(true), WithDecoderConcurrency(1))) + addDec(NewReader(nil, WithDecodeAllCapLimit(true), WithDecoderLowmem(true))) + addDec(NewReader(nil, WithDecodeAllCapLimit(true), WithDecoderMaxWindow(128<<10))) + addDec(NewReader(nil, WithDecodeAllCapLimit(true), WithDecoderMaxMemory(1<<20))) for sz := 0; sz < 1<<20; sz = (sz + 1) * 2 { sz := sz t.Run(strconv.Itoa(sz), func(t *testing.T) { - encoded := enc.EncodeAll(make([]byte, sz), nil) - for i := sz - 1; i < sz+1; i++ { - if i < 0 { - continue - } - const existinglen = 5 - got, err := dec.DecodeAll(encoded, make([]byte, existinglen, i+existinglen)) - if i < sz { - if err != ErrDecoderSizeExceeded { - t.Errorf("cap: %d, want %v, got %v", i, ErrDecoderSizeExceeded, err) - } - } else { - if err != nil { - t.Errorf("cap: %d, want %v, got %v", i, nil, err) - continue - } - if len(got) != existinglen+i { - t.Errorf("cap: %d, want output size %d, got %d", i, existinglen+i, len(got)) - } + t.Parallel() + for ei, enc := range encs { + for di, dec := range decs { + t.Run(fmt.Sprintf("e%d:d%d", ei, di), func(t *testing.T) { + encoded := enc.EncodeAll(make([]byte, sz), nil) + for i := sz - 1; i < sz+1; i++ { + if i < 0 { + continue + } + const existinglen = 5 + got, err := dec.DecodeAll(encoded, make([]byte, existinglen, i+existinglen)) + if i < sz { + if err != ErrDecoderSizeExceeded { + t.Errorf("cap: %d, want %v, got %v", i, ErrDecoderSizeExceeded, err) + } + } else { + if err != nil { + t.Errorf("cap: %d, want %v, got %v", i, nil, err) + continue + } + if len(got) != existinglen+i { + t.Errorf("cap: %d, want output size %d, got %d", i, existinglen+i, len(got)) + } + } + } + }) } } }) diff --git a/zstd/framedec.go b/zstd/framedec.go index 0a93a2065a..1559a20386 100644 --- a/zstd/framedec.go +++ b/zstd/framedec.go @@ -389,11 +389,13 @@ func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) { if err != nil { break } - if uint64(len(d.history.b)) > d.o.maxDecodedSize { + if uint64(len(d.history.b)-crcStart) > d.o.maxDecodedSize { + println("runDecoder: maxDecodedSize exceeded", uint64(len(d.history.b)-crcStart), ">", d.o.maxDecodedSize) err = ErrDecoderSizeExceeded break } if d.o.limitToCap && len(d.history.b) > cap(dst) { + println("runDecoder: cap exceeded", uint64(len(d.history.b)), ">", cap(dst)) err = ErrDecoderSizeExceeded break }