
Commit 5809dec

optimise: use byteBuffer pool in decompression

Committed Jul 15, 2024 · 1 parent a5f2b71 · commit 5809dec
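At a glance, the commit swaps the old sliceWriter pool for a shared sync.Pool of *bytes.Buffer (byteBuffers) and has the decompression path borrow from it as well. Below is a minimal, self-contained sketch of that borrow → Reset → use → Put cycle; the byteBuffers name and the 8 KiB seed size come from the diff, while scratchWork is a made-up stand-in for the real compress/decompress work:

package main

import (
    "bytes"
    "sync"
)

// A pool of reusable buffers, seeded at 8 KiB like the byteBuffers pool in the diff.
var byteBuffers = sync.Pool{New: func() any { return bytes.NewBuffer(make([]byte, 8<<10)) }}

// scratchWork borrows a buffer, uses it as scratch space, and returns it to the pool.
// Anything the caller wants to keep must be copied out before the deferred Put runs.
func scratchWork(src []byte) []byte {
    buf := byteBuffers.Get().(*bytes.Buffer)
    buf.Reset()                // discard whatever the previous user left behind
    defer byteBuffers.Put(buf) // hand the buffer back for reuse

    buf.Write(src) // stand-in for the real compression or decompression step

    // Copy out: the buffer's backing array may be reused as soon as Put runs.
    return append([]byte(nil), buf.Bytes()...)
}

func main() {
    _ = scratchWork([]byte("hello"))
}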

File tree

5 files changed: +125 -58


‎pkg/kgo/compression.go

+48 -26
@@ -14,15 +14,7 @@ import (
 	"github.com/pierrec/lz4/v4"
 )
 
-// sliceWriter a reusable slice as an io.Writer
-type sliceWriter struct{ inner []byte }
-
-func (s *sliceWriter) Write(p []byte) (int, error) {
-	s.inner = append(s.inner, p...)
-	return len(p), nil
-}
-
-var sliceWriters = sync.Pool{New: func() any { r := make([]byte, 8<<10); return &sliceWriter{inner: r} }}
+var byteBuffers = sync.Pool{New: func() any { return bytes.NewBuffer(make([]byte, 8<<10)) }}
 
 type codecType int8
 
@@ -175,9 +167,7 @@ type zstdEncoder struct {
 //
 // The writer should be put back to its pool after the returned slice is done
 // being used.
-func (c *compressor) compress(dst *sliceWriter, src []byte, produceRequestVersion int16) ([]byte, codecType) {
-	dst.inner = dst.inner[:0]
-
+func (c *compressor) compress(dst *bytes.Buffer, src []byte, produceRequestVersion int16) ([]byte, codecType) {
 	var use codecType
 	for _, option := range c.options {
 		if option == codecZstd && produceRequestVersion < 7 {
@@ -187,6 +177,7 @@ func (c *compressor) compress(dst *sliceWriter, src []byte, produceRequestVersio
 			break
 		}
 
+	var out []byte
 	switch use {
 	case codecNone:
 		return src, 0
@@ -200,10 +191,7 @@ func (c *compressor) compress(dst *sliceWriter, src []byte, produceRequestVersio
 		if err := gz.Close(); err != nil {
 			return nil, -1
 		}
-
-	case codecSnappy:
-		dst.inner = s2.EncodeSnappy(dst.inner[:cap(dst.inner)], src)
-
+		out = dst.Bytes()
 	case codecLZ4:
 		lz := c.lz4Pool.Get().(*lz4.Writer)
 		defer c.lz4Pool.Put(lz)
@@ -214,13 +202,34 @@ func (c *compressor) compress(dst *sliceWriter, src []byte, produceRequestVersio
 		if err := lz.Close(); err != nil {
 			return nil, -1
 		}
+		out = dst.Bytes()
+	case codecSnappy:
+		// Because the Snappy and Zstd codecs do not accept an io.Writer and
+		// instead take a []byte slice directly, the underlying []byte slice
+		// obtained from the pooled bytes.Buffer (`dst`) is passed here.
+		// As the buffer's `Write()` method is not used, its internal
+		// book-keeping goes out of sync, making the buffer unusable for
+		// further reading and writing through it (e.g. via `Bytes()`). For
+		// subsequent reads, the underlying slice has to be used directly.
+		//
+		// In this particular context that is acceptable: there are no
+		// subsequent operations performed on the buffer, it is immediately
+		// returned to the pool, and it is `Reset()` the next time it is
+		// obtained wherever `compress()` is called.
+		if l := s2.MaxEncodedLen(len(src)); l > dst.Cap() {
+			dst.Grow(l)
+		}
+		out = s2.EncodeSnappy(dst.Bytes(), src)
 	case codecZstd:
 		zstdEnc := c.zstdPool.Get().(*zstdEncoder)
 		defer c.zstdPool.Put(zstdEnc)
-		dst.inner = zstdEnc.inner.EncodeAll(src, dst.inner)
+		if l := zstdEnc.inner.MaxEncodedSize(len(src)); l > dst.Cap() {
+			dst.Grow(l)
+		}
+		out = zstdEnc.inner.EncodeAll(src, dst.Bytes())
 	}
 
-	return dst.inner, use
+	return out, use
 }
 
 type decompressor struct {
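The long comment added in the codecSnappy case above is the subtle part of this hunk: s2.EncodeSnappy and zstd's EncodeAll append into a caller-supplied []byte rather than writing through io.Writer, so the code grows the pooled buffer and hands over its backing slice via dst.Bytes(), after which the buffer's own bookkeeping no longer matches its contents. A hedged, standalone sketch of that interaction — s2.MaxEncodedLen and s2.EncodeSnappy are the real github.com/klauspost/compress/s2 functions used in the diff, while encodeSnappyInto and main are illustrative only:

package main

import (
    "bytes"
    "fmt"

    "github.com/klauspost/compress/s2"
)

// encodeSnappyInto compresses src using the backing array of a scratch buffer.
// After this call the buffer's internal length no longer reflects what was
// written into its array, so only the returned slice should be used.
func encodeSnappyInto(dst *bytes.Buffer, src []byte) []byte {
    if l := s2.MaxEncodedLen(len(src)); l > dst.Cap() {
        dst.Grow(l) // ensure the backing array can hold the worst-case output
    }
    // EncodeSnappy appends into the supplied slice; the buffer is only a
    // source of reusable memory here, exactly as in the diff above.
    return s2.EncodeSnappy(dst.Bytes(), src)
}

func main() {
    buf := bytes.NewBuffer(make([]byte, 0, 8<<10))
    out := encodeSnappyInto(buf, []byte("some payload to compress"))
    fmt.Println(len(out))
}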
@@ -259,38 +268,51 @@ type zstdDecoder struct {
 }
 
 func (d *decompressor) decompress(src []byte, codec byte) ([]byte, error) {
-	switch codecType(codec) {
-	case codecNone:
+	// Early return in case there is no compression
+	compCodec := codecType(codec)
+	if compCodec == codecNone {
 		return src, nil
+	}
+	out := byteBuffers.Get().(*bytes.Buffer)
+	out.Reset()
+	defer byteBuffers.Put(out)
+
+	switch compCodec {
 	case codecGzip:
 		ungz := d.ungzPool.Get().(*gzip.Reader)
 		defer d.ungzPool.Put(ungz)
 		if err := ungz.Reset(bytes.NewReader(src)); err != nil {
 			return nil, err
 		}
-		out := new(bytes.Buffer)
 		if _, err := io.Copy(out, ungz); err != nil {
 			return nil, err
 		}
-		return out.Bytes(), nil
+		return append([]byte(nil), out.Bytes()...), nil
 	case codecSnappy:
 		if len(src) > 16 && bytes.HasPrefix(src, xerialPfx) {
 			return xerialDecode(src)
 		}
-		return s2.Decode(nil, src)
+		decoded, err := s2.Decode(out.Bytes(), src)
+		if err != nil {
+			return nil, err
+		}
+		return append([]byte(nil), decoded...), nil
 	case codecLZ4:
 		unlz4 := d.unlz4Pool.Get().(*lz4.Reader)
 		defer d.unlz4Pool.Put(unlz4)
 		unlz4.Reset(bytes.NewReader(src))
-		out := new(bytes.Buffer)
 		if _, err := io.Copy(out, unlz4); err != nil {
 			return nil, err
 		}
-		return out.Bytes(), nil
+		return append([]byte(nil), out.Bytes()...), nil
 	case codecZstd:
 		unzstd := d.unzstdPool.Get().(*zstdDecoder)
 		defer d.unzstdPool.Put(unzstd)
-		return unzstd.inner.DecodeAll(src, nil)
+		decoded, err := unzstd.inner.DecodeAll(src, out.Bytes())
+		if err != nil {
+			return nil, err
+		}
+		return append([]byte(nil), decoded...), nil
 	default:
 		return nil, errors.New("unknown compression codec")
 	}
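In the decompression path the pooled buffer is returned with defer byteBuffers.Put(out), so every branch copies the decoded bytes out with append([]byte(nil), ...) before returning; handing back out.Bytes() directly would give the caller memory that the pool can recycle underneath it. A small, hedged sketch of why that copy matters, using the real github.com/klauspost/compress/zstd DecodeAll API; the pool, decompressZstd, and main here are illustrative, not kgo code:

package main

import (
    "bytes"
    "sync"

    "github.com/klauspost/compress/zstd"
)

var byteBuffers = sync.Pool{New: func() any { return bytes.NewBuffer(make([]byte, 8<<10)) }}

func decompressZstd(dec *zstd.Decoder, src []byte) ([]byte, error) {
    out := byteBuffers.Get().(*bytes.Buffer)
    out.Reset()
    defer byteBuffers.Put(out) // runs before the caller ever sees the result

    // DecodeAll appends to the provided slice, so the pooled buffer's
    // backing array serves as scratch space for the decoded payload.
    decoded, err := dec.DecodeAll(src, out.Bytes())
    if err != nil {
        return nil, err
    }
    // Copy before returning: once Put runs, another goroutine may Reset and
    // overwrite the same backing array.
    return append([]byte(nil), decoded...), nil
}

func main() {
    dec, _ := zstd.NewReader(nil) // DecodeAll-only decoder, no stream attached
    defer dec.Close()
    _, _ = decompressZstd(dec, nil) // illustrative call; real callers pass a zstd frame
}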

‎pkg/kgo/compression_test.go

+50 -13
@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"encoding/base64"
 	"fmt"
+	"math/rand"
 	"reflect"
 	"sync"
 	"testing"
@@ -46,9 +47,23 @@ func TestNewCompressor(t *testing.T) {
 }
 
 func TestCompressDecompress(t *testing.T) {
+	randStr := func(length int) []byte {
+		const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
+		b := make([]byte, length)
+		for i := range b {
+			b[i] = charset[rand.Intn(len(charset))]
+		}
+		return b
+	}
+
 	t.Parallel()
 	d := newDecompressor()
-	in := []byte("foo")
+	inputs := [][]byte{
+		randStr(1 << 2),
+		randStr(1 << 5),
+		randStr(1 << 8),
+	}
+
 	var wg sync.WaitGroup
 	for _, produceVersion := range []int16{
 		0, 7,
@@ -74,18 +89,21 @@ func TestCompressDecompress(t *testing.T) {
 			for i := 0; i < 3; i++ {
 				wg.Add(1)
 				go func() {
+					w := byteBuffers.Get().(*bytes.Buffer)
 					defer wg.Done()
-					w := sliceWriters.Get().(*sliceWriter)
-					defer sliceWriters.Put(w)
-					got, used := c.compress(w, in, produceVersion)
+					defer byteBuffers.Put(w)
+					for _, in := range inputs {
+						w.Reset()
 
-					got, err := d.decompress(got, byte(used))
-					if err != nil {
-						t.Errorf("unexpected decompress err: %v", err)
-						return
-					}
-					if !bytes.Equal(got, in) {
-						t.Errorf("got decompress %s != exp compress in %s", got, in)
+						got, used := c.compress(w, in, produceVersion)
+						got, err := d.decompress(got, byte(used))
+						if err != nil {
+							t.Errorf("unexpected decompress err: %v", err)
+							return
+						}
+						if !bytes.Equal(got, in) {
+							t.Errorf("got decompress %s != exp compress in %s", got, in)
+						}
 					}
 				}()
 			}
@@ -102,16 +120,35 @@ func BenchmarkCompress(b *testing.B) {
 		b.Run(fmt.Sprint(codec), func(b *testing.B) {
 			var afterSize int
 			for i := 0; i < b.N; i++ {
-				w := sliceWriters.Get().(*sliceWriter)
+				w := byteBuffers.Get().(*bytes.Buffer)
+				w.Reset()
 				after, _ := c.compress(w, in, 99)
 				afterSize = len(after)
-				sliceWriters.Put(w)
+				byteBuffers.Put(w)
 			}
 			b.Logf("%d => %d", len(in), afterSize)
 		})
 	}
 }
 
+func BenchmarkDecompress(b *testing.B) {
+	in := bytes.Repeat([]byte("abcdefghijklmno pqrs tuvwxy z"), 100)
+	for _, codec := range []codecType{codecGzip, codecSnappy, codecLZ4, codecZstd} {
+		c, _ := newCompressor(CompressionCodec{codec: codec})
+		w := byteBuffers.Get().(*bytes.Buffer)
+		w.Reset()
+		c.compress(w, in, 99)
+
+		b.Run(fmt.Sprint(codec), func(b *testing.B) {
+			for i := 0; i < b.N; i++ {
+				d := newDecompressor()
+				d.decompress(w.Bytes(), byte(codec))
+			}
+		})
+		byteBuffers.Put(w)
+	}
+}
+
 func Test_xerialDecode(t *testing.T) {
 	tests := []struct {
 		name string
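The new BenchmarkDecompress compresses one payload up front and then decompresses it repeatedly per codec. For intuition about what the buffer pool is meant to buy, here is a standalone, hedged sketch using only the standard library (none of these names exist in kgo); it exercises the same borrow → decode → copy-out → return shape and can be run with go test -bench=. -benchmem:

// pool_bench_test.go — a standalone sketch, not part of the kgo package.
package poolsketch

import (
    "bytes"
    "compress/gzip"
    "io"
    "sync"
    "testing"
)

var bufPool = sync.Pool{New: func() any { return bytes.NewBuffer(make([]byte, 8<<10)) }}

// compressed holds one gzip payload, built once, like the setup in BenchmarkDecompress.
var compressed = func() []byte {
    var b bytes.Buffer
    gw := gzip.NewWriter(&b)
    gw.Write(bytes.Repeat([]byte("abcdefghijklmno pqrs tuvwxy z"), 100))
    gw.Close()
    return b.Bytes()
}()

// BenchmarkPooledDecompress mirrors the shape of the decompress path in the
// diff: borrow a buffer, decode into it, copy out, return it to the pool.
// It only illustrates the buffer-reuse pattern, not a rigorous comparison.
func BenchmarkPooledDecompress(b *testing.B) {
    b.ReportAllocs()
    for i := 0; i < b.N; i++ {
        out := bufPool.Get().(*bytes.Buffer)
        out.Reset()
        gr, _ := gzip.NewReader(bytes.NewReader(compressed))
        io.Copy(out, gr)
        _ = append([]byte(nil), out.Bytes()...) // copy out before returning the buffer
        bufPool.Put(out)
    }
}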

‎pkg/kgo/logger.go

+12 -12
@@ -1,6 +1,7 @@
 package kgo
 
 import (
+	"bytes"
 	"fmt"
 	"io"
 	"strings"
@@ -73,28 +74,27 @@ type basicLogger struct {
 
 func (b *basicLogger) Level() LogLevel { return b.level }
 func (b *basicLogger) Log(level LogLevel, msg string, keyvals ...any) {
-	buf := sliceWriters.Get().(*sliceWriter)
-	defer sliceWriters.Put(buf)
+	buf := byteBuffers.Get().(*bytes.Buffer)
+	defer byteBuffers.Put(buf)
 
-	buf.inner = buf.inner[:0]
+	buf.Reset()
 	if b.pfxFn != nil {
-		buf.inner = append(buf.inner, b.pfxFn()...)
+		buf.WriteString(b.pfxFn())
 	}
-	buf.inner = append(buf.inner, '[')
-	buf.inner = append(buf.inner, level.String()...)
-	buf.inner = append(buf.inner, "] "...)
-	buf.inner = append(buf.inner, msg...)
+	buf.WriteByte('[')
+	buf.WriteString(level.String())
+	buf.WriteString("] ")
+	buf.WriteString(msg)
 
 	if len(keyvals) > 0 {
-		buf.inner = append(buf.inner, "; "...)
+		buf.WriteString("; ")
 		format := strings.Repeat("%v: %v, ", len(keyvals)/2)
 		format = format[:len(format)-2] // trim trailing comma and space
 		fmt.Fprintf(buf, format, keyvals...)
 	}
 
-	buf.inner = append(buf.inner, '\n')
-
-	b.dst.Write(buf.inner)
+	buf.WriteByte('\n')
+	b.dst.Write(buf.Bytes())
 }
 
 // nopLogger, the default logger, drops everything.

‎pkg/kgo/produce_request_test.go

+9 -3
@@ -163,7 +163,9 @@ func TestRecBatchAppendTo(t *testing.T) {
 	compressor, _ = newCompressor(CompressionCodec{codec: 2}) // snappy
 	{
 		kbatch.Attributes |= 0x0002 // snappy
-		kbatch.Records, _ = compressor.compress(sliceWriters.Get().(*sliceWriter), kbatch.Records, version)
+		w := byteBuffers.Get().(*bytes.Buffer)
+		w.Reset()
+		kbatch.Records, _ = compressor.compress(w, kbatch.Records, version)
 	}
 
 	fixFields()
@@ -254,7 +256,9 @@ func TestMessageSetAppendTo(t *testing.T) {
 		Offset:     1,
 		Attributes: 0x02,
 	}
-	kset0c.Value, _ = compressor.compress(sliceWriters.Get().(*sliceWriter), kset0raw, 1) // version 0, 1 use message set 0
+	w := byteBuffers.Get().(*bytes.Buffer)
+	w.Reset()
+	kset0c.Value, _ = compressor.compress(w, kset0raw, 1) // version 0, 1 use message set 0
 	kset0c.CRC = int32(crc32.ChecksumIEEE(kset0c.AppendTo(nil)[16:]))
 	kset0c.MessageSize = int32(len(kset0c.AppendTo(nil)[12:]))
 
@@ -265,7 +269,9 @@ func TestMessageSetAppendTo(t *testing.T) {
 		Offset:     1,
 		Attributes: 0x02,
 		Timestamp:  kset11.Timestamp,
 	}
-	kset1c.Value, _ = compressor.compress(sliceWriters.Get().(*sliceWriter), kset1raw, 2) // version 2 use message set 1
+	wbuf := byteBuffers.Get().(*bytes.Buffer)
+	wbuf.Reset()
+	kset1c.Value, _ = compressor.compress(wbuf, kset1raw, 2) // version 2 use message set 1
 	kset1c.CRC = int32(crc32.ChecksumIEEE(kset1c.AppendTo(nil)[16:]))
 	kset1c.MessageSize = int32(len(kset1c.AppendTo(nil)[12:]))
‎pkg/kgo/sink.go

+6 -4
@@ -2094,8 +2094,9 @@ func (b seqRecBatch) appendTo(
 	m.CompressedBytes = m.UncompressedBytes
 
 	if compressor != nil {
-		w := sliceWriters.Get().(*sliceWriter)
-		defer sliceWriters.Put(w)
+		w := byteBuffers.Get().(*bytes.Buffer)
+		defer byteBuffers.Put(w)
+		w.Reset()
 
 		compressed, codec := compressor.compress(w, toCompress, version)
 		if compressed != nil && // nil would be from an error
@@ -2175,8 +2176,9 @@ func (b seqRecBatch) appendToAsMessageSet(dst []byte, version uint8, compressor
 	m.CompressedBytes = m.UncompressedBytes
 
 	if compressor != nil {
-		w := sliceWriters.Get().(*sliceWriter)
-		defer sliceWriters.Put(w)
+		w := byteBuffers.Get().(*bytes.Buffer)
+		defer byteBuffers.Put(w)
+		w.Reset()
 
 		compressed, codec := compressor.compress(w, toCompress, int16(version))
 		inner := &Record{Value: compressed}
