Skip to content

Commit 305d8dc

Browse files
committed Jul 29, 2024
kgo: allow record ctx cancelation to propagate a bit more
If a record's context is canceled, we now allow it to be failed in two more locations:

* while the producer ID is loading -- we can actually now cancel the producer ID loading request (which may also benefit people using transactions that want to force quit the client)
* while a sink is backing off due to request failures

For people using transactions, canceling a context now allows you to force quit in more areas, but the same caveat applies: your client will likely end up in an invalid transactional state and be unable to continue.

For #769.
1 parent d4982d7 commit 305d8dc

File tree

7 files changed

+191
-30
lines changed

7 files changed

+191
-30
lines changed
 

‎pkg/kgo/errors.go

+14-3
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ func isRetryableBrokerErr(err error) bool {
5353
}
5454
// We could have a retryable producer ID failure, which then bubbled up
5555
// as errProducerIDLoadFail so as to be retried later.
56-
if errors.Is(err, errProducerIDLoadFail) {
56+
if pe := (*errProducerIDLoadFail)(nil); errors.As(err, &pe) {
5757
return true
5858
}
5959
// We could have chosen a broker, and then a concurrent metadata update
@@ -139,8 +139,6 @@ var (
139139
// restart a new connection ourselves.
140140
errSaslReauthLoop = errors.New("the broker is repeatedly giving us sasl lifetimes that are too short to write a request")
141141

142-
errProducerIDLoadFail = errors.New("unable to initialize a producer ID due to request failures")
143-
144142
// A temporary error returned when Kafka replies with a different
145143
// correlation ID than we were expecting for the request the client
146144
// issued.
@@ -224,6 +222,19 @@ type ErrFirstReadEOF struct {
224222
err error
225223
}
226224

225+
type errProducerIDLoadFail struct {
226+
err error
227+
}
228+
229+
func (e *errProducerIDLoadFail) Error() string {
230+
if e.err == nil {
231+
return "unable to initialize a producer ID due to request failures"
232+
}
233+
return fmt.Sprintf("unable to initialize a producer ID due to request failures: %v", e.err)
234+
}
235+
236+
func (e *errProducerIDLoadFail) Unwrap() error { return e.err }
237+
227238
const (
228239
firstReadSASL uint8 = iota
229240
firstReadTLS

‎pkg/kgo/helpers_test.go

+26
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,32 @@ var (
5555
npartitionsAt int64
5656
)
5757

58+
type slowConn struct {
59+
net.Conn
60+
}
61+
62+
func (s *slowConn) Write(p []byte) (int, error) {
63+
time.Sleep(100 * time.Millisecond)
64+
return s.Conn.Write(p)
65+
}
66+
67+
func (s *slowConn) Read(p []byte) (int, error) {
68+
time.Sleep(100 * time.Millisecond)
69+
return s.Conn.Read(p)
70+
}
71+
72+
type slowDialer struct {
73+
d net.Dialer
74+
}
75+
76+
func (s *slowDialer) DialContext(ctx context.Context, network, host string) (net.Conn, error) {
77+
c, err := s.d.DialContext(ctx, network, host)
78+
if err != nil {
79+
return nil, err
80+
}
81+
return &slowConn{c}, nil
82+
}
83+
5884
func init() {
5985
var err error
6086
if n, _ := strconv.Atoi(os.Getenv("KGO_TEST_RF")); n > 0 {

‎pkg/kgo/produce_request_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ func TestIssue769(t *testing.T) {
150150
case <-timer.C:
151151
t.Fatal("expected record to fail within 3s")
152152
}
153-
if pe := (*errProducerIDLoadFail)(nil); !errors.As(rerr, &pe) || !errors.Is(pe.err, context.Canceled) {
153+
if pe := (*errProducerIDLoadFail)(nil); !errors.As(rerr, &pe) || !(errors.Is(pe.err, context.Canceled) || strings.Contains(pe.err.Error(), "canceled")) {
154154
t.Errorf("got %v != exp errProducerIDLoadFail{context.Canceled}", rerr)
155155
}
156156
}

‎pkg/kgo/producer.go

+17-9
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,11 @@ func (cl *Client) TryProduce(
362362
// retries. If any of these conditions are hit and it is currently safe to fail
363363
// records, all buffered records for the relevant partition are failed. Only
364364
// the first record's context in a batch is considered when determining whether
365-
// the batch should be canceled.
365+
// the batch should be canceled. A record is not safe to fail if the client
366+
// is idempotently producing and a request has been sent; in this case, the
367+
// client cannot know if the broker actually processed the request (if so, then
368+
// removing the records from the client will create errors the next time you
369+
// produce).
366370
//
367371
// If the client is transactional and a transaction has not been begun, the
368372
// promise is immediately called with an error corresponding to not being in a
@@ -679,7 +683,7 @@ func (cl *Client) ProducerID(ctx context.Context) (int64, int16, error) {
679683

680684
go func() {
681685
defer close(done)
682-
id, epoch, err = cl.producerID()
686+
id, epoch, err = cl.producerID(ctx2fn(ctx))
683687
}()
684688

685689
select {
@@ -701,7 +705,7 @@ var errReloadProducerID = errors.New("producer id needs reloading")
701705
// initProducerID initializes the client's producer ID for idempotent
702706
// producing only (no transactions, which are more special). After the first
703707
// load, this clears all buffered unknown topics.
704-
func (cl *Client) producerID() (int64, int16, error) {
708+
func (cl *Client) producerID(ctxFn func() context.Context) (int64, int16, error) {
705709
p := &cl.producer
706710

707711
id := p.id.Load().(*producerID)
@@ -730,7 +734,7 @@ func (cl *Client) producerID() (int64, int16, error) {
730734
}
731735
p.id.Store(id)
732736
} else {
733-
newID, keep := cl.doInitProducerID(id.id, id.epoch)
737+
newID, keep := cl.doInitProducerID(ctxFn, id.id, id.epoch)
734738
if keep {
735739
id = newID
736740
// Whenever we have a new producer ID, we need
@@ -748,7 +752,7 @@ func (cl *Client) producerID() (int64, int16, error) {
748752
id = &producerID{
749753
id: id.id,
750754
epoch: id.epoch,
751-
err: errProducerIDLoadFail,
755+
err: &errProducerIDLoadFail{newID.err},
752756
}
753757
}
754758
}
@@ -825,7 +829,7 @@ func (cl *Client) failProducerID(id int64, epoch int16, err error) {
825829

826830
// doInitProducerID inits the idempotent ID and potentially the transactional
827831
// producer epoch, returning whether to keep the result.
828-
func (cl *Client) doInitProducerID(lastID int64, lastEpoch int16) (*producerID, bool) {
832+
func (cl *Client) doInitProducerID(ctxFn func() context.Context, lastID int64, lastEpoch int16) (*producerID, bool) {
829833
cl.cfg.logger.Log(LogLevelInfo, "initializing producer id")
830834
req := kmsg.NewPtrInitProducerIDRequest()
831835
req.TransactionalID = cl.cfg.txnID
@@ -835,7 +839,8 @@ func (cl *Client) doInitProducerID(lastID int64, lastEpoch int16) (*producerID,
835839
req.TransactionTimeoutMillis = int32(cl.cfg.txnTimeout.Milliseconds())
836840
}
837841

838-
resp, err := req.RequestWith(cl.ctx, cl)
842+
ctx := ctxFn()
843+
resp, err := req.RequestWith(ctx, cl)
839844
if err != nil {
840845
if errors.Is(err, errUnknownRequestKey) || errors.Is(err, errBrokerTooOld) {
841846
cl.cfg.logger.Log(LogLevelInfo, "unable to initialize a producer id because the broker is too old or the client is pinned to an old version, continuing without a producer id")
@@ -940,13 +945,14 @@ func (cl *Client) addUnknownTopicRecord(pr promisedRec) {
940945
}
941946
unknown.buffered = append(unknown.buffered, pr)
942947
if len(unknown.buffered) == 1 {
943-
go cl.waitUnknownTopic(pr.ctx, pr.Topic, unknown)
948+
go cl.waitUnknownTopic(pr.ctx, pr.Record.Context, pr.Topic, unknown)
944949
}
945950
}
946951

947952
// waitUnknownTopic waits for a notification
948953
func (cl *Client) waitUnknownTopic(
949-
rctx context.Context,
954+
pctx context.Context, // context passed to Produce
955+
rctx context.Context, // context on the record itself
950956
topic string,
951957
unknown *unknownTopicProduces,
952958
) {
@@ -974,6 +980,8 @@ func (cl *Client) waitUnknownTopic(
974980

975981
for err == nil {
976982
select {
983+
case <-pctx.Done():
984+
err = pctx.Err()
977985
case <-rctx.Done():
978986
err = rctx.Err()
979987
case <-cl.ctx.Done():

‎pkg/kgo/sink.go

+116-5
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ func (s *sink) maybeBackoff() {
208208
select {
209209
case <-after.C:
210210
case <-s.cl.ctx.Done():
211+
case <-s.anyCtx().Done():
211212
}
212213
}
213214

@@ -247,6 +248,34 @@ func (s *sink) drain() {
247248
}
248249
}
249250

251+
// Returns the first context encountered ranging across all records.
252+
// This does not use defers to make it clear at the return that all
253+
// unlocks are called in proper order. Ideally, do not call this func
254+
// due to lock intensity.
255+
func (s *sink) anyCtx() context.Context {
256+
s.recBufsMu.Lock()
257+
for _, recBuf := range s.recBufs {
258+
recBuf.mu.Lock()
259+
if len(recBuf.batches) > 0 {
260+
batch0 := recBuf.batches[0]
261+
batch0.mu.Lock()
262+
if batch0.canFailFromLoadErrs && len(batch0.records) > 0 {
263+
r0 := batch0.records[0]
264+
if rctx := r0.cancelingCtx(); rctx != nil {
265+
batch0.mu.Unlock()
266+
recBuf.mu.Unlock()
267+
s.recBufsMu.Unlock()
268+
return rctx
269+
}
270+
}
271+
batch0.mu.Unlock()
272+
}
273+
recBuf.mu.Unlock()
274+
}
275+
s.recBufsMu.Unlock()
276+
return context.Background()
277+
}
278+
250279
func (s *sink) produce(sem <-chan struct{}) bool {
251280
var produced bool
252281
defer func() {
@@ -267,6 +296,7 @@ func (s *sink) produce(sem <-chan struct{}) bool {
267296
// - auth failure
268297
// - transactional: a produce failure that failed the producer ID
269298
// - AddPartitionsToTxn failure (see just below)
299+
// - some head-of-line context failure
270300
//
271301
// All but the first error is fatal. Recovery may be possible with
272302
// EndTransaction in specific cases, but regardless, all buffered
@@ -275,10 +305,71 @@ func (s *sink) produce(sem <-chan struct{}) bool {
275305
// NOTE: we init the producer ID before creating a request to ensure we
276306
// are always using the latest id/epoch with the proper sequence
277307
// numbers. (i.e., resetAllSequenceNumbers && producerID logic combo).
278-
id, epoch, err := s.cl.producerID()
308+
//
309+
// For the first-discovered-record-head-of-line context, we want to
310+
// avoid looking it up if possible (which is why producerID takes a
311+
// ctxFn). If we do use one, we want to be sure that the
312+
// context.Canceled error is from *that* context rather than the client
313+
// context or something else. So, we go through some special care to
314+
// track setting the ctx / looking up if it is canceled.
315+
var holCtxMu sync.Mutex
316+
var holCtx context.Context
317+
ctxFn := func() context.Context {
318+
holCtxMu.Lock()
319+
defer holCtxMu.Unlock()
320+
holCtx = s.anyCtx()
321+
return holCtx
322+
}
323+
isHolCtxDone := func() bool {
324+
holCtxMu.Lock()
325+
defer holCtxMu.Unlock()
326+
if holCtx == nil {
327+
return false
328+
}
329+
select {
330+
case <-holCtx.Done():
331+
return true
332+
default:
333+
}
334+
return false
335+
}
336+
337+
id, epoch, err := s.cl.producerID(ctxFn)
279338
if err != nil {
339+
var pe *errProducerIDLoadFail
280340
switch {
281-
case errors.Is(err, errProducerIDLoadFail):
341+
case errors.As(err, &pe):
342+
if errors.Is(pe.err, context.Canceled) && isHolCtxDone() {
343+
// Some head-of-line record in a partition had a context cancelation.
344+
// We look for any partition with HOL cancelations and fail them all.
345+
s.cl.cfg.logger.Log(LogLevelInfo, "the first record in some partition(s) had a context cancelation; failing all relevant partitions", "broker", logID(s.nodeID))
346+
s.recBufsMu.Lock()
347+
defer s.recBufsMu.Unlock()
348+
for _, recBuf := range s.recBufs {
349+
recBuf.mu.Lock()
350+
var failAll bool
351+
if len(recBuf.batches) > 0 {
352+
batch0 := recBuf.batches[0]
353+
batch0.mu.Lock()
354+
if batch0.canFailFromLoadErrs && len(batch0.records) > 0 {
355+
r0 := batch0.records[0]
356+
if rctx := r0.cancelingCtx(); rctx != nil {
357+
select {
358+
case <-rctx.Done():
359+
failAll = true // we must not call failAllRecords here, because failAllRecords locks batches!
360+
default:
361+
}
362+
}
363+
}
364+
batch0.mu.Unlock()
365+
}
366+
if failAll {
367+
recBuf.failAllRecords(err)
368+
}
369+
recBuf.mu.Unlock()
370+
}
371+
return true
372+
}
282373
s.cl.bumpRepeatedLoadErr(err)
283374
s.cl.cfg.logger.Log(LogLevelWarn, "unable to load producer ID, bumping client's buffered record load errors by 1 and retrying")
284375
return true // whatever caused our produce, we did nothing, so keep going
@@ -385,6 +476,9 @@ func (s *sink) doSequenced(
385476
promise: promise,
386477
}
387478

479+
// We can NOT use any record context. If we do, we force the request to
480+
// fail while also forcing the batch to be unfailable (due to no
481+
// response).
388482
br, err := s.cl.brokerOrErr(s.cl.ctx, s.nodeID, errUnknownBroker)
389483
if err != nil {
390484
wait.err = err
@@ -432,6 +526,11 @@ func (s *sink) doTxnReq(
432526
req.batches.eachOwnerLocked(seqRecBatch.removeFromTxn)
433527
}
434528
}()
529+
// We do NOT let record context cancelations fail this request: doing
530+
// so would put the transactional ID in an unknown state. This is
531+
// similar to the warning we give in the txn.go file, but the
532+
// difference there is the user knows explicitly at the function call
533+
// that canceling the context will opt them into invalid state.
435534
err = s.cl.doWithConcurrentTransactions(s.cl.ctx, "AddPartitionsToTxn", func() error {
436535
stripped, err = s.issueTxnReq(req, txnReq)
437536
return err
@@ -1422,6 +1521,16 @@ type promisedRec struct {
14221521
*Record
14231522
}
14241523

1524+
func (pr promisedRec) cancelingCtx() context.Context {
1525+
if pr.ctx.Done() != nil {
1526+
return pr.ctx
1527+
}
1528+
if pr.Context.Done() != nil {
1529+
return pr.Context
1530+
}
1531+
return nil
1532+
}
1533+
14251534
// recBatch is the type used for buffering records before they are written.
14261535
type recBatch struct {
14271536
owner *recBuf // who owns us
@@ -1454,10 +1563,12 @@ type recBatch struct {
14541563
// Returns an error if the batch should fail.
14551564
func (b *recBatch) maybeFailErr(cfg *cfg) error {
14561565
if len(b.records) > 0 {
1457-
ctx := b.records[0].ctx
1566+
r0 := &b.records[0]
14581567
select {
1459-
case <-ctx.Done():
1460-
return ctx.Err()
1568+
case <-r0.ctx.Done():
1569+
return r0.ctx.Err()
1570+
case <-r0.Context.Done():
1571+
return r0.Context.Err()
14611572
default:
14621573
}
14631574
}

‎pkg/kgo/source.go

+3
Original file line numberDiff line numberDiff line change
@@ -956,6 +956,9 @@ func (s *source) fetch(consumerSession *consumerSession, doneFetch chan<- struct
956956
// reload offsets *always* triggers a metadata update.
957957
if updateWhy != nil {
958958
why := updateWhy.reason(fmt.Sprintf("fetch had inner topic errors from broker %d", s.nodeID))
959+
// loadWithSessionNow triggers a metadata update IF there are
960+
// offsets to reload. If there are no offsets to reload, we
961+
// trigger one here.
959962
if !reloadOffsets.loadWithSessionNow(consumerSession, why) {
960963
if updateWhy.isOnly(kerr.UnknownTopicOrPartition) || updateWhy.isOnly(kerr.UnknownTopicID) {
961964
s.cl.triggerUpdateMetadata(false, why)

‎pkg/kgo/txn.go

+14-12
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ import (
1313
"github.com/twmb/franz-go/pkg/kerr"
1414
)
1515

16+
func ctx2fn(ctx context.Context) func() context.Context { return func() context.Context { return ctx } }
17+
1618
// TransactionEndTry is simply a named bool.
1719
type TransactionEndTry bool
1820

@@ -468,7 +470,7 @@ func (cl *Client) BeginTransaction() error {
468470
return errors.New("invalid attempt to begin a transaction while already in a transaction")
469471
}
470472

471-
needRecover, didRecover, err := cl.maybeRecoverProducerID()
473+
needRecover, didRecover, err := cl.maybeRecoverProducerID(context.Background())
472474
if needRecover && !didRecover {
473475
cl.cfg.logger.Log(LogLevelInfo, "unable to begin transaction due to unrecoverable producer id error", "err", err)
474476
return fmt.Errorf("producer ID has a fatal, unrecoverable error, err: %w", err)
@@ -557,7 +559,7 @@ func (cl *Client) EndAndBeginTransaction(
557559
// expect to be in one.
558560
defer func() {
559561
if rerr == nil {
560-
needRecover, didRecover, err := cl.maybeRecoverProducerID()
562+
needRecover, didRecover, err := cl.maybeRecoverProducerID(ctx)
561563
if needRecover && !didRecover {
562564
cl.cfg.logger.Log(LogLevelInfo, "unable to begin transaction due to unrecoverable producer id error", "err", err)
563565
rerr = fmt.Errorf("producer ID has a fatal, unrecoverable error, err: %w", err)
@@ -620,12 +622,12 @@ func (cl *Client) EndAndBeginTransaction(
620622
}
621623

622624
// From EndTransaction: if the pid has an error, we may try to recover.
623-
id, epoch, err := cl.producerID()
625+
id, epoch, err := cl.producerID(ctx2fn(ctx))
624626
if err != nil {
625627
if commit {
626628
return kerr.OperationNotAttempted
627629
}
628-
if _, didRecover, _ := cl.maybeRecoverProducerID(); didRecover {
630+
if _, didRecover, _ := cl.maybeRecoverProducerID(ctx); didRecover {
629631
return nil
630632
}
631633
}
@@ -882,7 +884,7 @@ func (cl *Client) EndTransaction(ctx context.Context, commit TransactionEndTry)
882884
return nil
883885
}
884886

885-
id, epoch, err := cl.producerID()
887+
id, epoch, err := cl.producerID(ctx2fn(ctx))
886888
if err != nil {
887889
if commit {
888890
return kerr.OperationNotAttempted
@@ -892,7 +894,7 @@ func (cl *Client) EndTransaction(ctx context.Context, commit TransactionEndTry)
892894
// there is no reason to issue an abort now that the id is
893895
// different. Otherwise, we issue our EndTxn which will likely
894896
// fail, but that is ok, we will just return error.
895-
_, didRecover, _ := cl.maybeRecoverProducerID()
897+
_, didRecover, _ := cl.maybeRecoverProducerID(ctx)
896898
if didRecover {
897899
return nil
898900
}
@@ -939,11 +941,11 @@ func (cl *Client) EndTransaction(ctx context.Context, commit TransactionEndTry)
939941
// error), whether it is possible to recover, and, if not, the error.
940942
//
941943
// We call this when beginning a transaction or when ending with an abort.
942-
func (cl *Client) maybeRecoverProducerID() (necessary, did bool, err error) {
944+
func (cl *Client) maybeRecoverProducerID(ctx context.Context) (necessary, did bool, err error) {
943945
cl.producer.mu.Lock()
944946
defer cl.producer.mu.Unlock()
945947

946-
id, epoch, err := cl.producerID()
948+
id, epoch, err := cl.producerID(ctx2fn(ctx))
947949
if err == nil {
948950
return false, false, nil
949951
}
@@ -1009,7 +1011,7 @@ start:
10091011
select {
10101012
case <-time.After(backoff):
10111013
case <-ctx.Done():
1012-
cl.cfg.logger.Log(LogLevelError, fmt.Sprintf("abandoning %s retry due to client ctx quitting", name))
1014+
cl.cfg.logger.Log(LogLevelError, fmt.Sprintf("abandoning %s retry due to request ctx quitting", name))
10131015
return err
10141016
case <-cl.ctx.Done():
10151017
cl.cfg.logger.Log(LogLevelError, fmt.Sprintf("abandoning %s retry due to client ctx quitting", name))
@@ -1081,7 +1083,7 @@ func (cl *Client) commitTransactionOffsets(
10811083
}
10821084

10831085
if !g.offsetsAddedToTxn {
1084-
if err := cl.addOffsetsToTxn(g.ctx, g.cfg.group); err != nil {
1086+
if err := cl.addOffsetsToTxn(ctx, g.cfg.group); err != nil {
10851087
if onDone != nil {
10861088
onDone(nil, nil, err)
10871089
}
@@ -1111,7 +1113,7 @@ func (cl *Client) commitTransactionOffsets(
11111113
// this initializes one if it is not yet initialized. This would only be the
11121114
// case if trying to commit before any records have been sent.
11131115
func (cl *Client) addOffsetsToTxn(ctx context.Context, group string) error {
1114-
id, epoch, err := cl.producerID()
1116+
id, epoch, err := cl.producerID(ctx2fn(ctx))
11151117
if err != nil {
11161118
return err
11171119
}
@@ -1218,7 +1220,7 @@ func (g *groupConsumer) prepareTxnOffsetCommit(ctx context.Context, uncommitted
12181220

12191221
// We're now generating the producerID before addOffsetsToTxn.
12201222
// We will not make this request until after addOffsetsToTxn, but it's possible to fail here due to a failed producerID.
1221-
id, epoch, err := g.cl.producerID()
1223+
id, epoch, err := g.cl.producerID(ctx2fn(ctx))
12221224
if err != nil {
12231225
return req, err
12241226
}

0 commit comments

Comments
 (0)
Please sign in to comment.