From 8f86c1d37a4402ab7d8169409f035eebbb04b137 Mon Sep 17 00:00:00 2001 From: Waldemar Quevedo Date: Mon, 4 Oct 2021 18:12:50 -0700 Subject: [PATCH] js: Fix context usage with sub.Fetch and msg.Ack The deadline of a context is now used to calculate the time used for `expires` instead of the default `ttl` of the JetStream context which was 5s. This was preventing library users from passing a context with a custom timeout. This also disallows the usage of `context.Background` to make it explicit that `sub.Fetch` has to be used with a context that has a timeout since each fetch request has to include an expire time anyway. In case `context.WithCancel` is used, then a child context with the same duration as the JetStream context default timeout will be created. Also in case msg.Ack it was possible to pass both timeout and a context which would have been ambiguous and only context option being used. Signed-off-by: Waldemar Quevedo --- js.go | 44 ++++-- test/js_test.go | 354 +++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 328 insertions(+), 70 deletions(-) diff --git a/js.go b/js.go index 4d4adbdac..b1213a4dd 100644 --- a/js.go +++ b/js.go @@ -2148,6 +2148,8 @@ func (sub *Subscription) Fetch(batch int, opts ...PullOpt) ([]*Msg, error) { js := sub.jsi.js pmc := len(sub.mch) > 0 + // All fetch requests have an expiration, in case of no explicit expiration + // then the default timeout of the JetStream context is used. ttl := o.ttl if ttl == 0 { ttl = js.opts.wait @@ -2161,9 +2163,20 @@ func (sub *Subscription) Fetch(batch int, opts ...PullOpt) ([]*Msg, error) { err error cancel context.CancelFunc ) - if o.ctx == nil { + if ctx == nil { ctx, cancel = context.WithTimeout(context.Background(), ttl) defer cancel() + } else if _, hasDeadline := ctx.Deadline(); !hasDeadline { + // Prevent from passing the background context which will just block + // and cannot be canceled either. + if octx, ok := ctx.(ContextOpt); ok && octx.Context == context.Background() { + return nil, ErrNoDeadlineContext + } + + // If the context did not have a deadline, then create a new child context + // that will use the default timeout from the JS context. + ctx, cancel = context.WithTimeout(ctx, ttl) + defer cancel() } // Check if context not done already before making the request. @@ -2180,6 +2193,9 @@ func (sub *Subscription) Fetch(batch int, opts ...PullOpt) ([]*Msg, error) { return nil, err } + // Use the deadline of the context to base the expire times. + deadline, _ := ctx.Deadline() + ttl = time.Until(deadline) checkCtxErr := func(err error) error { if o.ctx == nil && err == context.DeadlineExceeded { return ErrTimeout @@ -2188,9 +2204,8 @@ func (sub *Subscription) Fetch(batch int, opts ...PullOpt) ([]*Msg, error) { } var ( - msgs = make([]*Msg, 0, batch) - msg *Msg - start = time.Now() + msgs = make([]*Msg, 0, batch) + msg *Msg ) for pmc && len(msgs) < batch { // Check next msg with booleans that say that this is an internal call @@ -2218,11 +2233,18 @@ func (sub *Subscription) Fetch(batch int, opts ...PullOpt) ([]*Msg, error) { var nr nextRequest sendReq := func() error { - ttl -= time.Since(start) - if ttl < 0 { - // At this point consider that we have timed-out - return context.DeadlineExceeded + // The current deadline for the context will be used + // to set the expires TTL for a fetch request. + deadline, _ = ctx.Deadline() + ttl = time.Until(deadline) + + // Check if context has already been canceled or expired. + select { + case <-ctx.Done(): + return ctx.Err() + default: } + // Make our request expiration a bit shorter than the current timeout. expires := ttl if ttl >= 20*time.Millisecond { @@ -2343,6 +2365,12 @@ func (m *Msg) ackReply(ackType []byte, sync bool, opts ...AckOpt) error { usesCtx := o.ctx != nil usesWait := o.ttl > 0 + + // Only allow either AckWait or Context option to set the timeout. + if usesWait && usesCtx { + return ErrContextAndTimeout + } + sync = sync || usesCtx || usesWait ctx := o.ctx wait := defaultRequestWait diff --git a/test/js_test.go b/test/js_test.go index 1c96656b9..b74cd93fc 100644 --- a/test/js_test.go +++ b/test/js_test.go @@ -2552,6 +2552,12 @@ func TestJetStreamSubscribe_AckPolicy(t *testing.T) { ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second) defer cancel() + // Prevent double context and ack wait options. + err = msg.AckSync(nats.Context(ctx), nats.AckWait(1*time.Second)) + if err != nats.ErrContextAndTimeout { + t.Errorf("Unexpected error: %v", err) + } + err = msg.AckSync(nats.Context(ctx)) if err != nil { t.Errorf("Unexpected error: %v", err) @@ -2587,6 +2593,13 @@ func TestJetStreamSubscribe_AckPolicy(t *testing.T) { if got != expected { t.Errorf("Expected %v, got %v", expected, got) } + + // Prevent double context and ack wait options. + err = msg.Nak(nats.Context(ctx), nats.AckWait(1*time.Second)) + if err != nats.ErrContextAndTimeout { + t.Errorf("Unexpected error: %v", err) + } + // Skip the message. err = msg.Nak() if err != nil { @@ -2616,6 +2629,13 @@ func TestJetStreamSubscribe_AckPolicy(t *testing.T) { if got != expected { t.Errorf("Expected %v, got %v", expected, got) } + + // Prevent double context and ack wait options. + err = msg.Term(nats.Context(ctx), nats.AckWait(1*time.Second)) + if err != nats.ErrContextAndTimeout { + t.Errorf("Unexpected error: %v", err) + } + err = msg.Term() if err != nil { t.Errorf("Unexpected error: %v", err) @@ -2649,6 +2669,13 @@ func TestJetStreamSubscribe_AckPolicy(t *testing.T) { if got != expected { t.Errorf("Expected %v, got %v", expected, got) } + + // Prevent double context and ack wait options. + err = msg.InProgress(nats.Context(ctx), nats.AckWait(1*time.Second)) + if err != nats.ErrContextAndTimeout { + t.Errorf("Unexpected error: %v", err) + } + err = msg.InProgress(nctx) if err != nil { t.Errorf("Unexpected error: %v", err) @@ -5524,68 +5551,6 @@ func testJetStreamFetchOptions(t *testing.T, srvs ...*jsServer) { } }) - t.Run("pull with context", func(t *testing.T) { - defer js.PurgeStream(subject) - - expected := 10 - sendMsgs(t, expected) - sub, err := js.PullSubscribe(subject, "batch-ctx") - if err != nil { - t.Fatal(err) - } - defer sub.Unsubscribe() - - ctx, cancel := context.WithCancel(context.Background()) - cancel() - - // Should fail with expired context. - _, err = sub.Fetch(expected, nats.Context(ctx)) - if err == nil { - t.Fatal("Unexpected success") - } - if err != context.Canceled { - t.Errorf("Expected context deadline exceeded error, got: %v", err) - } - - ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - msgs, err := sub.Fetch(expected, nats.Context(ctx)) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - got := len(msgs) - if got != expected { - t.Fatalf("Got %v messages, expected at least: %v", got, expected) - } - - for _, msg := range msgs { - msg.AckSync() - } - - // Next fetch will timeout since no more messages. - _, err = sub.Fetch(1, nats.MaxWait(250*time.Millisecond)) - if err != nats.ErrTimeout { - t.Errorf("Expected timeout fetching next message, got: %v", err) - } - - expected = 5 - sendMsgs(t, expected) - msgs, err = sub.Fetch(expected, nats.MaxWait(1*time.Second)) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - got = len(msgs) - if got != expected { - t.Fatalf("Got %v messages, expected at least: %v", got, expected) - } - - for _, msg := range msgs { - msg.Ack() - } - }) - t.Run("fetch after unsubscribe", func(t *testing.T) { defer js.PurgeStream(subject) @@ -6487,3 +6452,268 @@ func TestJetStreamMsgSubjectRewrite(t *testing.T) { t.Fatalf("Unexepcted data: %q", msg.Data) } } + +func TestJetStreamPullSubscribeFetchContext(t *testing.T) { + withJSCluster(t, "PULLCTX", 3, testJetStreamFetchContext) +} + +func testJetStreamFetchContext(t *testing.T, srvs ...*jsServer) { + srv := srvs[0] + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Error(err) + } + defer nc.Close() + + js, err := nc.JetStream() + if err != nil { + t.Fatal(err) + } + + subject := "WQ" + _, err = js.AddStream(&nats.StreamConfig{ + Name: subject, + Replicas: 3, + }) + if err != nil { + t.Fatal(err) + } + + sendMsgs := func(t *testing.T, totalMsgs int) { + t.Helper() + for i := 0; i < totalMsgs; i++ { + payload := fmt.Sprintf("i:%d", i) + _, err := js.Publish(subject, []byte(payload)) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + } + } + expected := 10 + sendMsgs(t, expected) + + sub, err := js.PullSubscribe(subject, "batch-ctx") + if err != nil { + t.Fatal(err) + } + defer sub.Unsubscribe() + + t.Run("ctx background", func(t *testing.T) { + _, err = sub.Fetch(expected, nats.Context(context.Background())) + if err == nil { + t.Fatal("Unexpected success") + } + if err != nats.ErrNoDeadlineContext { + t.Errorf("Expected context deadline error, got: %v", err) + } + }) + + t.Run("ctx canceled", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + cancel() + + _, err = sub.Fetch(expected, nats.Context(ctx)) + if err == nil { + t.Fatal("Unexpected success") + } + if err != context.Canceled { + t.Errorf("Expected context deadline error, got: %v", err) + } + + ctx, cancel = context.WithCancel(context.Background()) + cancel() + + _, err = sub.Fetch(expected, nats.Context(ctx)) + if err == nil { + t.Fatal("Unexpected success") + } + if err != context.Canceled { + t.Errorf("Expected context deadline error, got: %v", err) + } + }) + + t.Run("ctx timeout", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + msgs, err := sub.Fetch(expected, nats.Context(ctx)) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + got := len(msgs) + if got != expected { + t.Fatalf("Got %v messages, expected at least: %v", got, expected) + } + info, err := sub.ConsumerInfo() + if err != nil { + t.Error(err) + } + if info.NumAckPending != expected { + t.Errorf("Expected %d pending acks, got: %d", expected, info.NumAckPending) + } + + for _, msg := range msgs { + msg.AckSync() + } + + info, err = sub.ConsumerInfo() + if err != nil { + t.Error(err) + } + if info.NumAckPending > 0 { + t.Errorf("Expected no pending acks, got: %d", info.NumAckPending) + } + + // No messages at this point. + ctx, cancel = context.WithTimeout(ctx, 250*time.Millisecond) + defer cancel() + + _, err = sub.Fetch(1, nats.Context(ctx)) + if err != context.DeadlineExceeded { + t.Errorf("Expected deadline exceeded fetching next message, got: %v", err) + } + + // Send more messages then pull them with a new context + expected = 5 + sendMsgs(t, expected) + + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + // Single message fetch. + msgs, err = sub.Fetch(1, nats.Context(ctx)) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if len(msgs) != 1 { + t.Fatalf("Expected to receive a single message, got: %d", len(msgs)) + } + for _, msg := range msgs { + msg.Ack() + } + + // Fetch multiple messages. + expected = 4 + msgs, err = sub.Fetch(expected, nats.Context(ctx)) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + got = len(msgs) + if got != expected { + t.Fatalf("Got %v messages, expected at least: %v", got, expected) + } + for _, msg := range msgs { + msg.AckSync() + } + + info, err = sub.ConsumerInfo() + if err != nil { + t.Error(err) + } + if info.NumAckPending > 0 { + t.Errorf("Expected no pending acks, got: %d", info.NumAckPending) + } + }) + + t.Run("ctx with cancel", func(t *testing.T) { + // New JS context with slightly shorter timeout than default. + js, err = nc.JetStream(nats.MaxWait(2 * time.Second)) + if err != nil { + t.Fatal(err) + } + + sub, err := js.PullSubscribe(subject, "batch-cancel-ctx") + if err != nil { + t.Fatal(err) + } + defer sub.Unsubscribe() + + // Parent context + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Fetch all the messages as needed. + info, err := sub.ConsumerInfo() + if err != nil { + t.Fatal(err) + } + total := info.NumPending + + // Child context with timeout with the same duration as JS context timeout + // will be created to fetch next message. + msgs, err := sub.Fetch(1, nats.Context(ctx)) + if err != nil { + t.Fatal(err) + } + if len(msgs) != 1 { + t.Fatalf("Expected a message, got: %d", len(msgs)) + } + for _, msg := range msgs { + msg.AckSync() + } + + // Fetch the rest using same cancellation context. + expected := int(total - 1) + msgs, err = sub.Fetch(expected, nats.Context(ctx)) + if err != nil { + t.Fatal(err) + } + if len(msgs) != expected { + t.Fatalf("Expected %d messages, got: %d", expected, len(msgs)) + } + for _, msg := range msgs { + msg.AckSync() + } + + // Fetch more messages and wait for timeout since there are none. + _, err = sub.Fetch(expected, nats.Context(ctx)) + if err == nil { + t.Fatal("Unexpected success") + } + if err != context.DeadlineExceeded { + t.Fatalf("Expected deadline exceeded fetching next message, got: %v", err) + } + + // Original cancellation context is not yet canceled, it should still work. + if ctx.Err() != nil { + t.Fatalf("Expected no errors in original cancellation context, got: %v", ctx.Err()) + } + + // Should be possible to use the same context again. + sendMsgs(t, 5) + + // Get the next message to leave 4 pending. + var pending uint64 = 4 + msgs, err = sub.Fetch(1, nats.Context(ctx)) + if err != nil { + t.Fatal(err) + } + if len(msgs) != 1 { + t.Fatalf("Expected a message, got: %d", len(msgs)) + } + for _, msg := range msgs { + msg.AckSync() + } + + // Cancel finally. + cancel() + + _, err = sub.Fetch(1, nats.Context(ctx)) + if err == nil { + t.Fatal("Unexpected success") + } + if err != context.Canceled { + t.Fatalf("Expected deadline exceeded fetching next message, got: %v", err) + } + + info, err = sub.ConsumerInfo() + if err != nil { + t.Fatal(err) + } + total = info.NumPending + if total != pending { + t.Errorf("Expected %d pending messages, got: %d", pending, total) + } + }) +}