diff --git a/js.go b/js.go index e34b297d9..d9df981d8 100644 --- a/js.go +++ b/js.go @@ -2574,9 +2574,9 @@ func (sub *Subscription) Fetch(batch int, opts ...PullOpt) ([]*Msg, error) { // Check if context not done already before making the request. select { case <-ctx.Done(): - if ctx.Err() == context.Canceled { + if o.ctx != nil { // Timeout or Cancel triggered by context object option err = ctx.Err() - } else { + } else { // Timeout triggered by timeout option err = ErrTimeout } default: diff --git a/test/js_test.go b/test/js_test.go index 205a21139..de2efbf1a 100644 --- a/test/js_test.go +++ b/test/js_test.go @@ -6867,6 +6867,22 @@ func testJetStreamFetchContext(t *testing.T, srvs ...*jsServer) { t.Errorf("Expected %d pending messages, got: %d", pending, total) } }) + + t.Run("MaxWait timeout should return nats error", func(t *testing.T) { + _, err := sub.Fetch(1, nats.MaxWait(1*time.Nanosecond)) + if !errors.Is(err, nats.ErrTimeout) { + t.Fatalf("Expect ErrTimeout, got err=%#v", err) + } + }) + + t.Run("Context timeout should return context error", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond) + defer cancel() + _, err := sub.Fetch(1, nats.Context(ctx)) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("Expect context.DeadlineExceeded, got err=%#v", err) + } + }) } func TestJetStreamSubscribeContextCancel(t *testing.T) {