Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

js: make js.Subscribe context aware #872

Merged
merged 1 commit into from Dec 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 26 additions & 0 deletions js.go
Expand Up @@ -859,6 +859,11 @@ func (ctx ContextOpt) configurePublish(opts *pubOpts) error {
return nil
}

func (ctx ContextOpt) configureSubscribe(opts *subOpts) error {
opts.ctx = ctx
return nil
}

func (ctx ContextOpt) configurePull(opts *pullOpts) error {
opts.ctx = ctx
return nil
Expand Down Expand Up @@ -965,6 +970,9 @@ type jsSub struct {
fcd uint64
fciseq uint64
csfct *time.Timer

// Cancellation function to cancel context on drain/unsubscribe.
cancel func()
}

// Deletes the JS Consumer.
Expand Down Expand Up @@ -1243,6 +1251,7 @@ func (js *js) subscribe(subj, queue string, cb MsgHandler, ch chan *Msg, isSync,
consumer = o.consumer
isDurable = o.cfg.Durable != _EMPTY_
consumerBound = o.bound
ctx = o.ctx
notFoundErr bool
lookupErr bool
nc = js.nc
Expand Down Expand Up @@ -1389,6 +1398,13 @@ func (js *js) subscribe(subj, queue string, cb MsgHandler, ch chan *Msg, isSync,
deliver = nc.newInbox()
}

// In case this has a context, then create a child context that
// is possible to cancel via unsubscribe / drain.
var cancel func()
if ctx != nil {
ctx, cancel = context.WithCancel(ctx)
}

jsi := &jsSub{
js: js,
stream: stream,
Expand All @@ -1401,6 +1417,7 @@ func (js *js) subscribe(subj, queue string, cb MsgHandler, ch chan *Msg, isSync,
pull: isPullMode,
nms: nms,
psubj: subj,
cancel: cancel,
}

// Check if we are manual ack.
Expand Down Expand Up @@ -1539,6 +1556,14 @@ func (js *js) subscribe(subj, queue string, cb MsgHandler, ch chan *Msg, isSync,
sub.chanSubcheckForFlowControlResponse()
}

// Wait for context to get canceled if there is one.
if ctx != nil {
go func() {
<-ctx.Done()
sub.Unsubscribe()
}()
}

return sub, nil
}

Expand Down Expand Up @@ -1953,6 +1978,7 @@ type subOpts struct {
mack bool
// For an ordered consumer.
ordered bool
ctx context.Context
}

// OrderedConsumer will create a fifo direct/ephemeral consumer for in order delivery of messages.
Expand Down
3 changes: 3 additions & 0 deletions kv.go
Expand Up @@ -689,6 +689,9 @@ func (kv *kvs) Watch(keys string, opts ...WatchOpt) (KeyWatcher, error) {
if o.metaOnly {
subOpts = append(subOpts, HeadersOnly())
}
if o.ctx != nil {
subOpts = append(subOpts, Context(o.ctx))
}
sub, err := kv.js.Subscribe(keys, update, subOpts...)
if err != nil {
return nil, err
Expand Down
14 changes: 14 additions & 0 deletions nats.go
Expand Up @@ -4144,6 +4144,20 @@ func (nc *Conn) unsubscribe(sub *Subscription, max int, drainMode bool) error {
nc.bw.appendString(fmt.Sprintf(unsubProto, s.sid, maxStr))
nc.kickFlusher()
}

// For JetStream subscriptions cancel the attached context if there is any.
var cancel func()
sub.mu.Lock()
jsi := sub.jsi
if jsi != nil {
cancel = jsi.cancel
jsi.cancel = nil
}
sub.mu.Unlock()
if cancel != nil {
cancel()
}

return nil
}

Expand Down
99 changes: 98 additions & 1 deletion test/js_test.go
Expand Up @@ -5956,7 +5956,7 @@ func TestJetStreamBindConsumer(t *testing.T) {
if ci != nil && !ci.PushBound {
return nil
}
return fmt.Errorf("Conusmer %q still active", "push")
return fmt.Errorf("Consumer %q still active", "push")
})
}
checkConsInactive()
Expand Down Expand Up @@ -6724,3 +6724,100 @@ func testJetStreamFetchContext(t *testing.T, srvs ...*jsServer) {
}
})
}

func TestJetStreamSubscribeContextCancel(t *testing.T) {
s := RunBasicJetStreamServer()
defer s.Shutdown()

if config := s.JetStreamConfig(); config != nil {
defer os.RemoveAll(config.StoreDir)
}

nc, err := nats.Connect(s.ClientURL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer nc.Close()

js, err := nc.JetStream()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

// Create the stream using our client API.
_, err = js.AddStream(&nats.StreamConfig{
Name: "TEST",
Subjects: []string{"foo", "bar", "baz", "foo.*"},
})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

toSend := 100
for i := 0; i < toSend; i++ {
js.Publish("bar", []byte("foo"))
}

t.Run("cancel unsubscribes and deletes ephemeral", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

ch := make(chan *nats.Msg, 100)
sub, err := js.Subscribe("bar", func(msg *nats.Msg) {
ch <- msg

// Cancel will unsubscribe and remove the subscription
// of the consumer.
if len(ch) >= 50 {
cancel()
}
}, nats.Context(ctx))
if err != nil {
t.Fatal(err)
}

select {
case <-ctx.Done():
case <-time.After(3 * time.Second):
t.Fatal("Timed out waiting for context to be canceled")
}

// Consumer should not be present since unsubscribe already called.
checkFor(t, 2*time.Second, 15*time.Millisecond, func() error {
info, err := sub.ConsumerInfo()
if err != nil && err == nats.ErrConsumerNotFound {
return nil
}
return fmt.Errorf("Consumer still active, got: %v (info=%+v)", err, info)
})

got := len(ch)
expected := 50
if got < expected {
t.Errorf("Expected to receive at least %d messages, got: %d", expected, got)
}
})

t.Run("unsubscribe cancels child context", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

sub, err := js.Subscribe("bar", func(msg *nats.Msg) {}, nats.Context(ctx))
if err != nil {
t.Fatal(err)
}
err = sub.Unsubscribe()
if err != nil {
t.Fatal(err)
}

// Consumer should not be present since unsubscribe already called.
checkFor(t, 2*time.Second, 15*time.Millisecond, func() error {
info, err := sub.ConsumerInfo()
if err != nil && err == nats.ErrConsumerNotFound {
return nil
}
return fmt.Errorf("Consumer still active, got: %v (info=%+v)", err, info)
})
})
}
31 changes: 31 additions & 0 deletions test/kv_test.go
Expand Up @@ -14,6 +14,7 @@
package test

import (
"context"
"fmt"
"os"
"reflect"
Expand Down Expand Up @@ -228,6 +229,36 @@ func TestKeyValueWatch(t *testing.T) {
expectInitDone()
}

func TestKeyValueWatchContext(t *testing.T) {
s := RunBasicJetStreamServer()
defer shutdown(s)

nc, js := jsClient(t, s)
defer nc.Close()

kv, err := js.CreateKeyValue(&nats.KeyValueConfig{Bucket: "WATCHCTX"})
expectOk(t, err)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

watcher, err := kv.WatchAll(nats.Context(ctx))
expectOk(t, err)
defer watcher.Stop()

// Trigger unsubscribe internally.
cancel()

// Wait for a bit for unsubscribe to be done.
time.Sleep(500 * time.Millisecond)

// Stopping watch that is already stopped via cancellation propagation is an error.
err = watcher.Stop()
if err == nil || err != nats.ErrBadSubscription {
t.Errorf("Expected invalid subscription, got: %v", err)
}
}

func TestKeyValueBindStore(t *testing.T) {
s := RunBasicJetStreamServer()
defer shutdown(s)
Expand Down