From 76690683f2402cac2c132034611debd660ead77e Mon Sep 17 00:00:00 2001 From: Waldemar Quevedo Date: Tue, 14 Sep 2021 14:46:16 -0700 Subject: [PATCH] js: unblock batch requests on 408 with at least a message Signed-off-by: Waldemar Quevedo --- js.go | 10 +- test/js_test.go | 281 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 288 insertions(+), 3 deletions(-) diff --git a/js.go b/js.go index ffe8b5e1e..4d4adbdac 100644 --- a/js.go +++ b/js.go @@ -2105,9 +2105,9 @@ func checkMsg(msg *Msg, checkSts bool) (usrMsg bool, err error) { case reqTimeoutSts: // Older servers may send a 408 when a request in the server was expired // and interest is still found, which will be the case for our - // implementation. Regardless, ignore 408 errors, the caller will - // go back to wait for the next message. - err = nil + // implementation. Regardless, ignore 408 errors until receiving at least + // one message. + err = ErrTimeout default: err = fmt.Errorf("nats: %s", msg.Header.Get(descrHdr)) } @@ -2252,6 +2252,10 @@ func (sub *Subscription) Fetch(batch int, opts ...PullOpt) ([]*Msg, error) { // wait this time. noWait = false err = sendReq() + } else if err == ErrTimeout && len(msgs) == 0 { + // If we get a 408, we will bail if we already collected some + // messages, otherwise ignore and go back calling nextMsg. + err = nil } } } diff --git a/test/js_test.go b/test/js_test.go index 71af4d97b..b38535460 100644 --- a/test/js_test.go +++ b/test/js_test.go @@ -4498,6 +4498,287 @@ func testJetStreamMirror_Source(t *testing.T, nodes ...*jsServer) { }) } +func TestJetStream_PullSubscribeMaxWaiting(t *testing.T) { + nodes := []int{1, 3} + replicas := []int{1} + + for _, n := range nodes { + for _, r := range replicas { + if r > 1 && n == 1 { + continue + } + t.Run(fmt.Sprintf("psub n=%d r=%d", n, r), func(t *testing.T) { + name := fmt.Sprintf("PSUBMAX%d%d", n, r) + stream := &nats.StreamConfig{ + Name: name, + Replicas: n, + } + withJSClusterAndStream(t, name, n, stream, testJetStream_PullSubscribeMaxWaiting) + }) + } + } +} + +func testJetStream_PullSubscribeMaxWaiting(t *testing.T, subject string, srvs ...*jsServer) { + srv := srvs[0] + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatal(err) + } + defer nc.Close() + + js, err := nc.JetStream() + if err != nil { + t.Fatal(err) + } + + // Create pull subscriber with a lower max waiting limit. + sub, err := js.PullSubscribe(subject, "durable", nats.PullMaxWaiting(5)) + if err != nil { + t.Fatal(err) + } + + // Delay for a bit the first message being received. + time.AfterFunc(200*time.Millisecond, func() { + js.Publish(subject, []byte("hello")) + }) + + // Only one message available so the request will linger and return a single message. + msgs, err := sub.Fetch(2, nats.MaxWait(500*time.Millisecond)) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if len(msgs) != 1 { + t.Fatalf("Expected one message to be delivered, got: %v", len(msgs)) + } + msg := msgs[0] + expected := "hello" + got := string(msg.Data) + if got != expected { + t.Fatalf("Expected: %v, got: %v", expected, got) + } + + info, err := sub.ConsumerInfo() + if err != nil { + t.Fatal(err) + } + if info.NumWaiting != 1 { + t.Fatalf("Expected 1 pending requests, got: %v", info.NumWaiting) + } + + // Make a few requests to start getting 408 Request Timeout errors. + for i := 0; i < 4; i++ { + msgs, err = sub.Fetch(2, nats.MaxWait(200*time.Millisecond)) + if err == nil { + t.Fatal("Unexpected success") + } + if len(msgs) != 0 { + t.Fatal("Expected no messages") + } + + // There should be a maximum number of waiting requests now. + info, err = sub.ConsumerInfo() + if err != nil { + t.Fatal(err) + } + if info.NumWaiting != i+2 { + t.Fatalf("Expected %v pending requests, got: %v", i+2, info.NumWaiting) + } + } + + // There should be a maximum number of waiting requests now. + info, err = sub.ConsumerInfo() + if err != nil { + t.Fatal(err) + } + if info.NumWaiting != 5 { + t.Fatalf("Expected 5 pending requests, got: %v", info.NumWaiting) + } + + // Making an extra request that fails which will expire some of the old requests. + msgs, err = sub.Fetch(2, nats.MaxWait(200*time.Millisecond)) + if err == nil { + t.Fatal("Unexpected success") + } + if len(msgs) != 0 { + t.Fatal("Expected no messages") + } + info, err = sub.ConsumerInfo() + if err != nil { + t.Fatal(err) + } + if info.NumWaiting != 1 { + t.Fatalf("Expected 1 pending requests, got: %v", info.NumWaiting) + } + + // Send another message with a delay... + time.AfterFunc(200*time.Millisecond, func() { + js.Publish(subject, []byte("bar")) + }) + + // Fetch and wait... + msgs, err = sub.Fetch(1, nats.MaxWait(500*time.Millisecond)) + if err != nil { + t.Fatal(err) + } + msg = msgs[0] + expected = "bar" + got = string(msg.Data) + if got != expected { + t.Fatalf("Expected: %v, got: %v", expected, got) + } + if len(msgs) != 1 { + t.Fatalf("Expected one message to be delivered, got: %v", len(msgs)) + } + + // There should be no waiting pull requests since they got expired after fetch succeeded. + info, err = sub.ConsumerInfo() + if err != nil { + t.Fatal(err) + } + if info.NumWaiting != 0 { + t.Fatalf("Expected no pending requests, got: %v", info.NumWaiting) + } + + t.Run("blocking fetch", func(t *testing.T) { + // Create requests that take a longer time and will exhaust + // the number of waiting requests so that the rest will be blocked. + var ( + max = 5 + msgCh = make(chan *nats.Msg, max) + errCh = make(chan error, max) + expectedMap = make(map[string]bool) + ) + for i := 0; i < 5; i++ { + expectedMap[fmt.Sprintf("quux:%d", i)] = false + go func() { + // Can only have at most 5 inflight fetch requests so + // these will block until they receive at least a message + // and a 408 response status. + msgs, err := sub.Fetch(100, nats.MaxWait(30*time.Second)) + if err != nil { + errCh <- err + return + } + for _, msg := range msgs { + msgCh <- msg + } + }() + } + + // Give some time to the fetch requests to linger. + timeout := time.Now().Add(2 * time.Second) + for time.Now().Before(timeout) { + info, err = sub.ConsumerInfo() + if err != nil { + t.Fatal(err) + } + if info.NumWaiting == max { + break + } + } + if info.NumWaiting != max { + t.Fatalf("Expected %v pull requests, got: %v", max, info.NumWaiting) + } + + // Send max number of messages that will be received by the first batch. + for i := 0; i < max; i++ { + js.Publish(subject, []byte(fmt.Sprintf("quux:%v", i))) + } + + ctx, done := context.WithTimeout(context.Background(), 4*time.Second) + defer done() + + // Delay sending message so that the fetch request starts waiting. + time.AfterFunc(500*time.Millisecond, func() { + js.Publish(subject, []byte("quux:5")) + }) + + var ( + msgs = make([]*nats.Msg, 0) + errs = make([]error, 0) + ) + + Loop: + for { + select { + case <-ctx.Done(): + break Loop + default: + } + + // These will timeout until all the original blocking fetch requests are done, + // or they receive at least a message after receiving a 408. + m, err := sub.Fetch(1, nats.MaxWait(100*time.Millisecond)) + if err != nil { + errs = append(errs, err) + } + if len(m) > 0 { + info, _ = sub.ConsumerInfo() + if info.NumWaiting != 5 { + t.Fatalf("Expected: %v, got: %v", 0, info.NumWaiting) + } + } + msgs = append(msgs, m...) + if len(msgs) > 0 { + break Loop + } + } + if len(errs) == 0 { + t.Fatalf("Expected at least an error, got: %v", len(errs)) + } + if len(msgs) != 1 { + t.Fatalf("Expected one message to be delivered to recent fetch, got: %v", len(msgs)) + } + for _, e := range errs { + if e != nats.ErrTimeout { + t.Fatalf("Expected nats timeout, got: %v", e) + } + } + info, err = sub.ConsumerInfo() + if err != nil { + t.Fatal(err) + } + if info.NumWaiting != 5 { + t.Fatalf("Expected no pull requests (%v), got: %v", 0, info.NumWaiting) + } + if len(msgCh) != max { + t.Fatalf("Expected %v messages to be delivered on first set of fetch requests, got: %v", max, len(msgCh)) + } + var expected, got string + for i := 0; i < max; i++ { + select { + case msg := <-msgCh: + if _, ok := expectedMap[string(msg.Data)]; ok { + expectedMap[string(msg.Data)] = true + } + default: + t.Fatal("Unexpected blocking channel") + } + } + for k, v := range expectedMap { + if !v { + t.Fatalf("Expected message %v", k) + } + } + + // Message received after the first set of goroutines have timed out, + // so that following fetch requests are unblocked. + msg = msgs[0] + expected = "quux:5" + got = string(msg.Data) + if got != expected { + t.Fatalf("Expected: %v, got: %v", expected, got) + } + + select { + case err := <-errCh: + t.Fatalf("Unexpected error: %v", err) + default: + } + }) +} + func TestJetStream_ClusterMultipleSubscribe(t *testing.T) { nodes := []int{1, 3} replicas := []int{1}