
Commit 24fbb0f

Committed Jul 29, 2024
kgo: fix deadlock in Produce when using MaxBufferedBytes
Copying from the issue:

"""
1) Produce() record A (100 bytes)
2) Produce() record B (50 bytes), waiting for buffer to free
3) Produce() record C (50 bytes), waiting for buffer to free
4) Record A is produced, finishRecordPromise() gets called, detects it was over the limit, so publishes 1 message to waitBuffer
5) Record B is unblocked, finishRecordPromise() gets called, does not detect it was over the limit (only 50 bytes), so record C is never unblocked and will wait indefinitely on waitBuffer
"""

The fix requires adding a lock while producing. This reuses the existing lock on the `producer` type. This can lead to a few more spurious wakeups in other functions that use this same mutex, but that's fine.

The prior algorithm counted anything to produce immediately into the buffered records and bytes fields; the fix for #777 is not really possible unless we avoid counting the "buffered" aspect right away. Specifically, we need a goroutine looping with a sync.Cond that checks: IF we add the record, will we still be blocked? This allows us to always wake up all blocked goroutines (rather than one at a time, the problem this issue points out), and each goroutine can then check under the lock whether it still does not fit.

This also fixes an unreported bug: if a record WOULD be blocked but fails early (no topic, or not in a transaction while using a transactional client), the serial promise-finishing goroutine would deadlock.

Closes #777.
1 parent 58d20a1 commit 24fbb0f
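The cond-based re-check described in the commit message is the heart of the fix. Below is a minimal, standalone Go sketch of that pattern, not the franz-go internals: the boundedBuffer type and its names are illustrative only. Every release broadcasts, and every blocked producer re-checks under the lock whether its own record now fits, which is what prevents the single-wakeup deadlock from the issue.

package main

import (
	"fmt"
	"sync"
)

// boundedBuffer is a toy stand-in for the producer's byte accounting.
type boundedBuffer struct {
	mu       sync.Mutex
	c        *sync.Cond
	buffered int64 // bytes currently buffered
	maxBytes int64
}

func newBoundedBuffer(maxBytes int64) *boundedBuffer {
	b := &boundedBuffer{maxBytes: maxBytes}
	b.c = sync.NewCond(&b.mu)
	return b
}

// add blocks until size fits under maxBytes, then accounts for it.
// (The real client rejects records larger than the whole limit up front;
// this sketch assumes size <= maxBytes.)
func (b *boundedBuffer) add(size int64) {
	b.mu.Lock()
	defer b.mu.Unlock()
	// "IF we add the record, will we still be blocked?" Wait until no.
	// A broadcast wakes every waiter and each one re-evaluates its own
	// size, so a small record can proceed even while a larger one keeps
	// waiting.
	for b.buffered+size > b.maxBytes {
		b.c.Wait()
	}
	b.buffered += size
}

// done removes size from the buffer and wakes ALL waiters. The old scheme
// sent a single channel wakeup only when the *finishing* record thought the
// buffer was over the limit, so a small finishing record could wake nobody
// and leave later producers blocked forever.
func (b *boundedBuffer) done(size int64) {
	b.mu.Lock()
	b.buffered -= size
	b.mu.Unlock()
	b.c.Broadcast()
}

func main() {
	b := newBoundedBuffer(100)
	var wg sync.WaitGroup
	for _, size := range []int64{100, 50, 50} { // records A, B, C from the issue
		wg.Add(1)
		go func(size int64) {
			defer wg.Done()
			b.add(size)
			fmt.Println("buffered", size, "bytes")
			b.done(size)
		}(size)
	}
	wg.Wait()
}

franz-go additionally runs the wait loop in a separate goroutine so a blocked Produce can still honor client shutdown and context cancellation via a select, as the producer.go diff below shows.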

File tree: 4 files changed, +187 −71 lines


pkg/kgo/consumer_direct_test.go (+4 −2)
@@ -338,7 +338,8 @@ func TestPauseIssue489(t *testing.T) {
 					exit.Store(true)
 				}
 			})
-			time.Sleep(100 * time.Microsecond)
+			cl.Flush(ctx)
+			time.Sleep(50 * time.Microsecond)
 		}
 	}()
 	defer cancel()
@@ -416,7 +417,8 @@ func TestPauseIssueOct2023(t *testing.T) {
 					exit.Store(true)
 				}
 			})
-			time.Sleep(100 * time.Microsecond)
+			cl.Flush(ctx)
+			time.Sleep(50 * time.Microsecond)
 		}
 	}()
 	defer cancel()

pkg/kgo/produce_request_test.go (+62)
@@ -2,13 +2,75 @@ package kgo
 
 import (
 	"bytes"
+	"context"
+	"errors"
 	"hash/crc32"
+	"math/rand"
+	"strings"
+	"sync"
+	"sync/atomic"
 	"testing"
 
 	"github.com/twmb/franz-go/pkg/kbin"
 	"github.com/twmb/franz-go/pkg/kmsg"
 )
 
+func TestClient_Produce(t *testing.T) {
+	var (
+		topic, cleanup = tmpTopicPartitions(t, 1)
+		numWorkers     = 50
+		recsToWrite    = int64(20_000)
+
+		workers      sync.WaitGroup
+		writeSuccess atomic.Int64
+		writeFailure atomic.Int64
+
+		randRec = func() *Record {
+			return &Record{
+				Key:   []byte("test"),
+				Value: []byte(strings.Repeat("x", rand.Intn(1000))),
+				Topic: topic,
+			}
+		}
+	)
+	defer cleanup()
+
+	cl, _ := newTestClient(MaxBufferedBytes(5000))
+	defer cl.Close()
+
+	// Start N workers that will concurrently write to the same partition.
+	var recsWritten atomic.Int64
+	var fatal atomic.Bool
+	for i := 0; i < numWorkers; i++ {
+		workers.Add(1)
+
+		go func() {
+			defer workers.Done()
+
+			for recsWritten.Add(1) <= recsToWrite {
+				res := cl.ProduceSync(context.Background(), randRec())
+				if err := res.FirstErr(); err == nil {
+					writeSuccess.Add(1)
+				} else {
+					if !errors.Is(err, ErrMaxBuffered) {
+						t.Errorf("unexpected error: %v", err)
+						fatal.Store(true)
+					}
+
+					writeFailure.Add(1)
+				}
+			}
+		}()
+	}
+	workers.Wait()
+
+	t.Logf("writes succeeded: %d", writeSuccess.Load())
+	t.Logf("writes failed: %d", writeFailure.Load())
+	if fatal.Load() {
+		t.Fatal("failed")
+	}
+}
+
 // This file contains golden tests against kmsg AppendTo's to ensure our custom
 // encoding is correct.

pkg/kgo/producer.go (+120 −68)
@@ -14,9 +14,15 @@ import (
 )
 
 type producer struct {
-	bufferedRecords atomicI64
-	bufferedBytes   atomicI64
-	inflight        atomicI64 // high 16: # waiters, low 48: # inflight
+	inflight atomicI64 // high 16: # waiters, low 48: # inflight
+
+	// mu and c are used for flush and drain notifications; mu is used for
+	// a few other tight locks.
+	mu sync.Mutex
+	c  *sync.Cond
+
+	bufferedRecords int64
+	bufferedBytes   int64
 
 	cl *Client
 
@@ -45,19 +51,14 @@
 	// We must have a producer field for flushing; we cannot just have a
 	// field on recBufs that is toggled on flush. If we did, then a new
 	// recBuf could be created and records sent to while we are flushing.
-	flushing atomicI32 // >0 if flushing, can Flush many times concurrently
-	blocked  atomicI32 // >0 if over max recs or bytes
+	flushing     atomicI32 // >0 if flushing, can Flush many times concurrently
+	blocked      atomicI32 // >0 if over max recs or bytes
+	blockedBytes int64
 
 	aborting atomicI32 // >0 if aborting, can abort many times concurrently
 
-	idMu       sync.Mutex
-	idVersion  int16
-	waitBuffer chan struct{}
-
-	// mu and c are used for flush and drain notifications; mu is used for
-	// a few other tight locks.
-	mu sync.Mutex
-	c  *sync.Cond
+	idMu      sync.Mutex
+	idVersion int16
 
 	batchPromises ringBatchPromise
 	promisesMu    sync.Mutex
@@ -86,14 +87,18 @@
 // flushing records produced by your client (which can help determine network /
 // cluster health).
 func (cl *Client) BufferedProduceRecords() int64 {
-	return cl.producer.bufferedRecords.Load()
+	cl.producer.mu.Lock()
+	defer cl.producer.mu.Unlock()
+	return cl.producer.bufferedRecords + int64(cl.producer.blocked.Load())
 }
 
 // BufferedProduceBytes returns the number of bytes currently buffered for
 // producing within the client. This is the sum of all keys, values, and header
 // keys/values. See the related [BufferedProduceRecords] for more information.
 func (cl *Client) BufferedProduceBytes() int64 {
-	return cl.producer.bufferedBytes.Load()
+	cl.producer.mu.Lock()
+	defer cl.producer.mu.Unlock()
+	return cl.producer.bufferedBytes + cl.producer.blockedBytes
 }
 
 type unknownTopicProduces struct {
@@ -106,7 +111,6 @@ func (p *producer) init(cl *Client) {
 	p.cl = cl
 	p.topics = newTopicsPartitions()
 	p.unknownTopics = make(map[string]*unknownTopicProduces)
-	p.waitBuffer = make(chan struct{}, math.MaxInt32)
 	p.idVersion = -1
 	p.id.Store(&producerID{
 		id: -1,
@@ -397,58 +401,93 @@ func (cl *Client) produce(
 		}
 	}
 
-	var (
-		userSize     = r.userSize()
-		bufRecs      = p.bufferedRecords.Add(1)
-		bufBytes     = p.bufferedBytes.Add(userSize)
-		overMaxRecs  = bufRecs > cl.cfg.maxBufferedRecords
-		overMaxBytes bool
-	)
-	if cl.cfg.maxBufferedBytes > 0 {
-		if userSize > cl.cfg.maxBufferedBytes {
-			p.promiseRecord(promisedRec{ctx, promise, r}, kerr.MessageTooLarge)
-			return
-		}
-		overMaxBytes = bufBytes > cl.cfg.maxBufferedBytes
-	}
-
+	// We can now fail the rec after the buffered hook.
 	if r.Topic == "" {
-		p.promiseRecord(promisedRec{ctx, promise, r}, errNoTopic)
+		p.promiseRecordBeforeBuf(promisedRec{ctx, promise, r}, errNoTopic)
 		return
 	}
 	if cl.cfg.txnID != nil && !p.producingTxn.Load() {
-		p.promiseRecord(promisedRec{ctx, promise, r}, errNotInTransaction)
+		p.promiseRecordBeforeBuf(promisedRec{ctx, promise, r}, errNotInTransaction)
 		return
 	}
 
+	userSize := r.userSize()
+	if cl.cfg.maxBufferedBytes > 0 && userSize > cl.cfg.maxBufferedBytes {
+		p.promiseRecordBeforeBuf(promisedRec{ctx, promise, r}, kerr.MessageTooLarge)
+		return
+	}
+
+	// We have to grab the produce lock to check if this record will exceed
+	// configured limits. We try to keep the logic tight since this is
+	// effectively a global lock around producing.
+	var (
+		nextBufRecs, nextBufBytes int64
+		overMaxRecs, overMaxBytes bool
+
+		calcNums = func() {
+			nextBufRecs = p.bufferedRecords + 1
+			nextBufBytes = p.bufferedBytes + userSize
+			overMaxRecs = nextBufRecs > cl.cfg.maxBufferedRecords
+			overMaxBytes = cl.cfg.maxBufferedBytes > 0 && nextBufBytes > cl.cfg.maxBufferedBytes
+		}
+	)
+	p.mu.Lock()
+	calcNums()
 	if overMaxRecs || overMaxBytes {
+		if !block || cl.cfg.manualFlushing {
+			p.mu.Unlock()
+			p.promiseRecordBeforeBuf(promisedRec{ctx, promise, r}, ErrMaxBuffered)
+			return
+		}
+
+		// Before we potentially unlinger, add that we are blocked to
+		// ensure we do NOT start a linger anymore. We THEN wakeup
+		// anything that is actively lingering. Note that blocked is
+		// also used when finishing promises to see if we need to be
+		// notified.
+		p.blocked.Add(1)
+		p.blockedBytes += userSize
+		p.mu.Unlock()
+
 		cl.cfg.logger.Log(LogLevelDebug, "blocking Produce because we are either over max buffered records or max buffered bytes",
 			"over_max_records", overMaxRecs,
 			"over_max_bytes", overMaxBytes,
 		)
-		// Before we potentially unlinger, add that we are blocked.
-		// Lingering always checks blocked, so we will not start a
-		// linger while we are blocked. We THEN wakeup anything that
-		// is actively lingering.
-		cl.producer.blocked.Add(1)
+
 		cl.unlingerDueToMaxRecsBuffered()
-		// If the client ctx cancels or the produce ctx cancels, we
-		// need to un-count our buffering of this record. We also need
-		// to drain a slot from the waitBuffer chan, which could be
-		// sent to right when we are erroring.
+
+		// We keep the lock when we exit. If we are flushing, we want
+		// this blocked record to be produced before we return from
+		// flushing. This blocked record will be accounted for in the
+		// bufferedRecords addition below, after being removed from
+		// blocked in the goroutine.
+		wait := make(chan struct{})
+		var quit bool
+		go func() {
+			defer close(wait)
+			p.mu.Lock()
+			calcNums()
+			for !quit && (overMaxRecs || overMaxBytes) {
+				p.c.Wait()
+				calcNums()
+			}
+			p.blocked.Add(-1)
+			p.blockedBytes -= userSize
+		}()
+
 		drainBuffered := func(err error) {
-			p.promiseRecord(promisedRec{ctx, promise, r}, err)
-			<-p.waitBuffer
-			cl.producer.blocked.Add(-1)
-		}
-		if !block || cl.cfg.manualFlushing {
-			drainBuffered(ErrMaxBuffered)
-			return
+			p.mu.Lock()
+			quit = true
+			p.mu.Unlock()
+			p.c.Broadcast() // wake the goroutine above
+			<-wait
+			p.mu.Unlock() // we wait for the goroutine to exit, then unlock again (since the goroutine leaves the mutex locked)
+			p.promiseRecordBeforeBuf(promisedRec{ctx, promise, r}, err)
 		}
+
 		select {
-		case <-p.waitBuffer:
-			cl.cfg.logger.Log(LogLevelDebug, "Produce block signaled, continuing to produce")
-			cl.producer.blocked.Add(-1)
+		case <-wait:
+			cl.cfg.logger.Log(LogLevelDebug, "Produce block awoken, we now have space to produce, continuing to partition and produce")
 		case <-cl.ctx.Done():
 			drainBuffered(ErrClientClosed)
 			cl.cfg.logger.Log(LogLevelDebug, "client ctx canceled while blocked in Produce, returning")
@@ -459,6 +498,9 @@ func (cl *Client) produce(
 			return
 		}
 	}
+	p.bufferedRecords = nextBufRecs
+	p.bufferedBytes = nextBufBytes
+	p.mu.Unlock()
 
 	cl.partitionRecord(promisedRec{ctx, promise, r})
 }
@@ -468,6 +510,7 @@ type batchPromise struct {
 	pid       int64
 	epoch     int16
 	attrs     RecordAttrs
+	beforeBuf bool
 	partition int32
 	recs      []promisedRec
 	err       error
@@ -483,6 +526,10 @@ func (p *producer) promiseRecord(pr promisedRec, err error) {
 	p.promiseBatch(batchPromise{recs: []promisedRec{pr}, err: err})
 }
 
+func (p *producer) promiseRecordBeforeBuf(pr promisedRec, err error) {
+	p.promiseBatch(batchPromise{recs: []promisedRec{pr}, beforeBuf: true, err: err})
+}
+
 func (p *producer) finishPromises(b batchPromise) {
 	cl := p.cl
 	var more bool
@@ -495,7 +542,7 @@ start:
 		pr.ProducerID = b.pid
 		pr.ProducerEpoch = b.epoch
 		pr.Attrs = b.attrs
-		cl.finishRecordPromise(pr, b.err)
+		cl.finishRecordPromise(pr, b.err, b.beforeBuf)
 		b.recs[i] = promisedRec{}
 	}
 	p.promisesMu.Unlock()
@@ -509,7 +556,7 @@ start:
 	}
 }
 
-func (cl *Client) finishRecordPromise(pr promisedRec, err error) {
+func (cl *Client) finishRecordPromise(pr promisedRec, err error, beforeBuffering bool) {
 	p := &cl.producer
 
 	if p.hooks != nil && len(p.hooks.unbuffered) > 0 {
@@ -519,22 +566,27 @@ func (cl *Client) finishRecordPromise(pr promisedRec, err error) {
 	}
 
 	// Capture user size before potential modification by the promise.
+	//
+	// We call the promise before finishing the flush notification,
+	// allowing users of Flush to know all buf recs are done by the
+	// time we notify flush below.
 	userSize := pr.userSize()
-	nowBufBytes := p.bufferedBytes.Add(-userSize)
-	nowBufRecs := p.bufferedRecords.Add(-1)
-	wasOverMaxRecs := nowBufRecs >= cl.cfg.maxBufferedRecords
-	wasOverMaxBytes := cl.cfg.maxBufferedBytes > 0 && nowBufBytes+userSize > cl.cfg.maxBufferedBytes
-
-	// We call the promise before finishing the record; this allows users
-	// of Flush to know that all buffered records are completely done
-	// before Flush returns.
 	pr.promise(pr.Record, err)
 
-	if wasOverMaxRecs || wasOverMaxBytes {
-		p.waitBuffer <- struct{}{}
-	} else if nowBufRecs == 0 && p.flushing.Load() > 0 {
-		p.mu.Lock()
-		p.mu.Unlock() //nolint:gocritic,staticcheck // We use the lock as a barrier, unlocking immediately is safe.
+	// If this record was never buffered, its size was never accounted
+	// for on any p field: return early.
+	if beforeBuffering {
+		return
+	}
+
+	// Keep the lock as tight as possible: the broadcast can come after.
+	p.mu.Lock()
+	p.bufferedBytes -= userSize
+	p.bufferedRecords--
+	broadcast := p.blocked.Load() > 0 || p.bufferedRecords == 0 && p.flushing.Load() > 0
+	p.mu.Unlock()
+
+	if broadcast {
 		p.c.Broadcast()
 	}
 }
@@ -1021,7 +1073,7 @@ func (cl *Client) Flush(ctx context.Context) error {
 		defer p.mu.Unlock()
 		defer close(done)
 
-		for !quit && p.bufferedRecords.Load() > 0 {
+		for !quit && p.bufferedRecords+int64(p.blocked.Load()) > 0 {
 			p.c.Wait()
 		}
 	}()
pkg/kgo/sink.go (+1 −1)
@@ -258,7 +258,7 @@ func (s *sink) produce(sem <-chan struct{}) bool {
 	// We could have been triggered from a metadata update even though the
 	// user is not producing at all. If we have no buffered records, let's
 	// avoid potentially creating a producer ID.
-	if s.cl.producer.bufferedRecords.Load() == 0 {
+	if s.cl.BufferedProduceRecords() == 0 {
 		return false
 	}