From 16d26f8e7b6c09e2142e8906bee021b148ca7ecf Mon Sep 17 00:00:00 2001 From: Waldemar Quevedo Date: Tue, 14 Dec 2021 02:34:56 -0800 Subject: [PATCH] js: make js.Subscribe context aware Can now attach a context to a subscription so that it is unsubscribed and/or consumer deleted via propagation of cancellation via parent context. Signed-off-by: Waldemar Quevedo --- js.go | 26 +++++++++++++ kv.go | 3 ++ nats.go | 14 +++++++ test/js_test.go | 99 ++++++++++++++++++++++++++++++++++++++++++++++++- test/kv_test.go | 31 ++++++++++++++++ 5 files changed, 172 insertions(+), 1 deletion(-) diff --git a/js.go b/js.go index 38d7be5a3..46f86888b 100644 --- a/js.go +++ b/js.go @@ -859,6 +859,11 @@ func (ctx ContextOpt) configurePublish(opts *pubOpts) error { return nil } +func (ctx ContextOpt) configureSubscribe(opts *subOpts) error { + opts.ctx = ctx + return nil +} + func (ctx ContextOpt) configurePull(opts *pullOpts) error { opts.ctx = ctx return nil @@ -965,6 +970,9 @@ type jsSub struct { fcd uint64 fciseq uint64 csfct *time.Timer + + // Cancellation function to cancel context on drain/unsubscribe. + cancel func() } // Deletes the JS Consumer. @@ -1243,6 +1251,7 @@ func (js *js) subscribe(subj, queue string, cb MsgHandler, ch chan *Msg, isSync, consumer = o.consumer isDurable = o.cfg.Durable != _EMPTY_ consumerBound = o.bound + ctx = o.ctx notFoundErr bool lookupErr bool nc = js.nc @@ -1389,6 +1398,13 @@ func (js *js) subscribe(subj, queue string, cb MsgHandler, ch chan *Msg, isSync, deliver = nc.newInbox() } + // In case this has a context, then create a child context that + // is possible to cancel via unsubscribe / drain. + var cancel func() + if ctx != nil { + ctx, cancel = context.WithCancel(ctx) + } + jsi := &jsSub{ js: js, stream: stream, @@ -1401,6 +1417,7 @@ func (js *js) subscribe(subj, queue string, cb MsgHandler, ch chan *Msg, isSync, pull: isPullMode, nms: nms, psubj: subj, + cancel: cancel, } // Check if we are manual ack. @@ -1539,6 +1556,14 @@ func (js *js) subscribe(subj, queue string, cb MsgHandler, ch chan *Msg, isSync, sub.chanSubcheckForFlowControlResponse() } + // Wait for context to get canceled if there is one. + if ctx != nil { + go func() { + <-ctx.Done() + sub.Unsubscribe() + }() + } + return sub, nil } @@ -1953,6 +1978,7 @@ type subOpts struct { mack bool // For an ordered consumer. ordered bool + ctx context.Context } // OrderedConsumer will create a fifo direct/ephemeral consumer for in order delivery of messages. diff --git a/kv.go b/kv.go index deaefde2a..a727249fd 100644 --- a/kv.go +++ b/kv.go @@ -689,6 +689,9 @@ func (kv *kvs) Watch(keys string, opts ...WatchOpt) (KeyWatcher, error) { if o.metaOnly { subOpts = append(subOpts, HeadersOnly()) } + if o.ctx != nil { + subOpts = append(subOpts, Context(o.ctx)) + } sub, err := kv.js.Subscribe(keys, update, subOpts...) if err != nil { return nil, err diff --git a/nats.go b/nats.go index 72759717e..3090191c6 100644 --- a/nats.go +++ b/nats.go @@ -4144,6 +4144,20 @@ func (nc *Conn) unsubscribe(sub *Subscription, max int, drainMode bool) error { nc.bw.appendString(fmt.Sprintf(unsubProto, s.sid, maxStr)) nc.kickFlusher() } + + // For JetStream subscriptions cancel the attached context if there is any. + var cancel func() + sub.mu.Lock() + jsi := sub.jsi + if jsi != nil { + cancel = jsi.cancel + jsi.cancel = nil + } + sub.mu.Unlock() + if cancel != nil { + cancel() + } + return nil } diff --git a/test/js_test.go b/test/js_test.go index 079ea70f2..9042da724 100644 --- a/test/js_test.go +++ b/test/js_test.go @@ -5956,7 +5956,7 @@ func TestJetStreamBindConsumer(t *testing.T) { if ci != nil && !ci.PushBound { return nil } - return fmt.Errorf("Conusmer %q still active", "push") + return fmt.Errorf("Consumer %q still active", "push") }) } checkConsInactive() @@ -6724,3 +6724,100 @@ func testJetStreamFetchContext(t *testing.T, srvs ...*jsServer) { } }) } + +func TestJetStreamSubscribeContextCancel(t *testing.T) { + s := RunBasicJetStreamServer() + defer s.Shutdown() + + if config := s.JetStreamConfig(); config != nil { + defer os.RemoveAll(config.StoreDir) + } + + nc, err := nats.Connect(s.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + js, err := nc.JetStream() + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + // Create the stream using our client API. + _, err = js.AddStream(&nats.StreamConfig{ + Name: "TEST", + Subjects: []string{"foo", "bar", "baz", "foo.*"}, + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + toSend := 100 + for i := 0; i < toSend; i++ { + js.Publish("bar", []byte("foo")) + } + + t.Run("cancel unsubscribes and deletes ephemeral", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ch := make(chan *nats.Msg, 100) + sub, err := js.Subscribe("bar", func(msg *nats.Msg) { + ch <- msg + + // Cancel will unsubscribe and remove the subscription + // of the consumer. + if len(ch) >= 50 { + cancel() + } + }, nats.Context(ctx)) + if err != nil { + t.Fatal(err) + } + + select { + case <-ctx.Done(): + case <-time.After(3 * time.Second): + t.Fatal("Timed out waiting for context to be canceled") + } + + // Consumer should not be present since unsubscribe already called. + checkFor(t, 2*time.Second, 15*time.Millisecond, func() error { + info, err := sub.ConsumerInfo() + if err != nil && err == nats.ErrConsumerNotFound { + return nil + } + return fmt.Errorf("Consumer still active, got: %v (info=%+v)", err, info) + }) + + got := len(ch) + expected := 50 + if got < expected { + t.Errorf("Expected to receive at least %d messages, got: %d", expected, got) + } + }) + + t.Run("unsubscribe cancels child context", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sub, err := js.Subscribe("bar", func(msg *nats.Msg) {}, nats.Context(ctx)) + if err != nil { + t.Fatal(err) + } + err = sub.Unsubscribe() + if err != nil { + t.Fatal(err) + } + + // Consumer should not be present since unsubscribe already called. + checkFor(t, 2*time.Second, 15*time.Millisecond, func() error { + info, err := sub.ConsumerInfo() + if err != nil && err == nats.ErrConsumerNotFound { + return nil + } + return fmt.Errorf("Consumer still active, got: %v (info=%+v)", err, info) + }) + }) +} diff --git a/test/kv_test.go b/test/kv_test.go index 6c8602603..4d035bb31 100644 --- a/test/kv_test.go +++ b/test/kv_test.go @@ -14,6 +14,7 @@ package test import ( + "context" "fmt" "os" "reflect" @@ -228,6 +229,36 @@ func TestKeyValueWatch(t *testing.T) { expectInitDone() } +func TestKeyValueWatchContext(t *testing.T) { + s := RunBasicJetStreamServer() + defer shutdown(s) + + nc, js := jsClient(t, s) + defer nc.Close() + + kv, err := js.CreateKeyValue(&nats.KeyValueConfig{Bucket: "WATCHCTX"}) + expectOk(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + watcher, err := kv.WatchAll(nats.Context(ctx)) + expectOk(t, err) + defer watcher.Stop() + + // Trigger unsubscribe internally. + cancel() + + // Wait for a bit for unsubscribe to be done. + time.Sleep(500 * time.Millisecond) + + // Stopping watch that is already stopped via cancellation propagation is an error. + err = watcher.Stop() + if err == nil || err != nats.ErrBadSubscription { + t.Errorf("Expected invalid subscription, got: %v", err) + } +} + func TestKeyValueBindStore(t *testing.T) { s := RunBasicJetStreamServer() defer shutdown(s)