Skip to content

Commit

Permalink
Merge pull request #872 from nats-io/js-subscribe-ctx
Browse files Browse the repository at this point in the history
js: make js.Subscribe context aware
  • Loading branch information
wallyqs committed Dec 14, 2021
2 parents d7c1d78 + 16d26f8 commit b337a5c
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 1 deletion.
26 changes: 26 additions & 0 deletions js.go
Expand Up @@ -859,6 +859,11 @@ func (ctx ContextOpt) configurePublish(opts *pubOpts) error {
return nil
}

func (ctx ContextOpt) configureSubscribe(opts *subOpts) error {
opts.ctx = ctx
return nil
}

func (ctx ContextOpt) configurePull(opts *pullOpts) error {
opts.ctx = ctx
return nil
Expand Down Expand Up @@ -965,6 +970,9 @@ type jsSub struct {
fcd uint64
fciseq uint64
csfct *time.Timer

// Cancellation function to cancel context on drain/unsubscribe.
cancel func()
}

// Deletes the JS Consumer.
Expand Down Expand Up @@ -1243,6 +1251,7 @@ func (js *js) subscribe(subj, queue string, cb MsgHandler, ch chan *Msg, isSync,
consumer = o.consumer
isDurable = o.cfg.Durable != _EMPTY_
consumerBound = o.bound
ctx = o.ctx
notFoundErr bool
lookupErr bool
nc = js.nc
Expand Down Expand Up @@ -1389,6 +1398,13 @@ func (js *js) subscribe(subj, queue string, cb MsgHandler, ch chan *Msg, isSync,
deliver = nc.newInbox()
}

// In case this has a context, then create a child context that
// is possible to cancel via unsubscribe / drain.
var cancel func()
if ctx != nil {
ctx, cancel = context.WithCancel(ctx)
}

jsi := &jsSub{
js: js,
stream: stream,
Expand All @@ -1401,6 +1417,7 @@ func (js *js) subscribe(subj, queue string, cb MsgHandler, ch chan *Msg, isSync,
pull: isPullMode,
nms: nms,
psubj: subj,
cancel: cancel,
}

// Check if we are manual ack.
Expand Down Expand Up @@ -1539,6 +1556,14 @@ func (js *js) subscribe(subj, queue string, cb MsgHandler, ch chan *Msg, isSync,
sub.chanSubcheckForFlowControlResponse()
}

// Wait for context to get canceled if there is one.
if ctx != nil {
go func() {
<-ctx.Done()
sub.Unsubscribe()
}()
}

return sub, nil
}

Expand Down Expand Up @@ -1953,6 +1978,7 @@ type subOpts struct {
mack bool
// For an ordered consumer.
ordered bool
ctx context.Context
}

// OrderedConsumer will create a fifo direct/ephemeral consumer for in order delivery of messages.
Expand Down
3 changes: 3 additions & 0 deletions kv.go
Expand Up @@ -689,6 +689,9 @@ func (kv *kvs) Watch(keys string, opts ...WatchOpt) (KeyWatcher, error) {
if o.metaOnly {
subOpts = append(subOpts, HeadersOnly())
}
if o.ctx != nil {
subOpts = append(subOpts, Context(o.ctx))
}
sub, err := kv.js.Subscribe(keys, update, subOpts...)
if err != nil {
return nil, err
Expand Down
14 changes: 14 additions & 0 deletions nats.go
Expand Up @@ -4144,6 +4144,20 @@ func (nc *Conn) unsubscribe(sub *Subscription, max int, drainMode bool) error {
nc.bw.appendString(fmt.Sprintf(unsubProto, s.sid, maxStr))
nc.kickFlusher()
}

// For JetStream subscriptions cancel the attached context if there is any.
var cancel func()
sub.mu.Lock()
jsi := sub.jsi
if jsi != nil {
cancel = jsi.cancel
jsi.cancel = nil
}
sub.mu.Unlock()
if cancel != nil {
cancel()
}

return nil
}

Expand Down
99 changes: 98 additions & 1 deletion test/js_test.go
Expand Up @@ -5956,7 +5956,7 @@ func TestJetStreamBindConsumer(t *testing.T) {
if ci != nil && !ci.PushBound {
return nil
}
return fmt.Errorf("Conusmer %q still active", "push")
return fmt.Errorf("Consumer %q still active", "push")
})
}
checkConsInactive()
Expand Down Expand Up @@ -6724,3 +6724,100 @@ func testJetStreamFetchContext(t *testing.T, srvs ...*jsServer) {
}
})
}

func TestJetStreamSubscribeContextCancel(t *testing.T) {
s := RunBasicJetStreamServer()
defer s.Shutdown()

if config := s.JetStreamConfig(); config != nil {
defer os.RemoveAll(config.StoreDir)
}

nc, err := nats.Connect(s.ClientURL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer nc.Close()

js, err := nc.JetStream()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

// Create the stream using our client API.
_, err = js.AddStream(&nats.StreamConfig{
Name: "TEST",
Subjects: []string{"foo", "bar", "baz", "foo.*"},
})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

toSend := 100
for i := 0; i < toSend; i++ {
js.Publish("bar", []byte("foo"))
}

t.Run("cancel unsubscribes and deletes ephemeral", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

ch := make(chan *nats.Msg, 100)
sub, err := js.Subscribe("bar", func(msg *nats.Msg) {
ch <- msg

// Cancel will unsubscribe and remove the subscription
// of the consumer.
if len(ch) >= 50 {
cancel()
}
}, nats.Context(ctx))
if err != nil {
t.Fatal(err)
}

select {
case <-ctx.Done():
case <-time.After(3 * time.Second):
t.Fatal("Timed out waiting for context to be canceled")
}

// Consumer should not be present since unsubscribe already called.
checkFor(t, 2*time.Second, 15*time.Millisecond, func() error {
info, err := sub.ConsumerInfo()
if err != nil && err == nats.ErrConsumerNotFound {
return nil
}
return fmt.Errorf("Consumer still active, got: %v (info=%+v)", err, info)
})

got := len(ch)
expected := 50
if got < expected {
t.Errorf("Expected to receive at least %d messages, got: %d", expected, got)
}
})

t.Run("unsubscribe cancels child context", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

sub, err := js.Subscribe("bar", func(msg *nats.Msg) {}, nats.Context(ctx))
if err != nil {
t.Fatal(err)
}
err = sub.Unsubscribe()
if err != nil {
t.Fatal(err)
}

// Consumer should not be present since unsubscribe already called.
checkFor(t, 2*time.Second, 15*time.Millisecond, func() error {
info, err := sub.ConsumerInfo()
if err != nil && err == nats.ErrConsumerNotFound {
return nil
}
return fmt.Errorf("Consumer still active, got: %v (info=%+v)", err, info)
})
})
}
31 changes: 31 additions & 0 deletions test/kv_test.go
Expand Up @@ -14,6 +14,7 @@
package test

import (
"context"
"fmt"
"os"
"reflect"
Expand Down Expand Up @@ -228,6 +229,36 @@ func TestKeyValueWatch(t *testing.T) {
expectInitDone()
}

func TestKeyValueWatchContext(t *testing.T) {
s := RunBasicJetStreamServer()
defer shutdown(s)

nc, js := jsClient(t, s)
defer nc.Close()

kv, err := js.CreateKeyValue(&nats.KeyValueConfig{Bucket: "WATCHCTX"})
expectOk(t, err)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

watcher, err := kv.WatchAll(nats.Context(ctx))
expectOk(t, err)
defer watcher.Stop()

// Trigger unsubscribe internally.
cancel()

// Wait for a bit for unsubscribe to be done.
time.Sleep(500 * time.Millisecond)

// Stopping watch that is already stopped via cancellation propagation is an error.
err = watcher.Stop()
if err == nil || err != nats.ErrBadSubscription {
t.Errorf("Expected invalid subscription, got: %v", err)
}
}

func TestKeyValueBindStore(t *testing.T) {
s := RunBasicJetStreamServer()
defer shutdown(s)
Expand Down

0 comments on commit b337a5c

Please sign in to comment.