diff --git a/js.go b/js.go index 2a2a0edc5..539b55e69 100644 --- a/js.go +++ b/js.go @@ -2078,7 +2078,11 @@ func checkMsg(msg *Msg, checkSts bool) (usrMsg bool, err error) { // 404 indicates that there are no messages. err = errNoMessages case reqTimeoutSts: - err = ErrTimeout + // Older servers may send a 408 when a request in the server was expired + // and interest is still found, which will be the case for our + // implementation. Regardless, ignore 408 errors, the caller will + // go back to wait for the next message. + err = nil default: err = fmt.Errorf("nats: %s", msg.Header.Get(descrHdr)) } @@ -2090,6 +2094,9 @@ func (sub *Subscription) Fetch(batch int, opts ...PullOpt) ([]*Msg, error) { if sub == nil { return nil, ErrBadSubscription } + if batch < 1 { + return nil, ErrInvalidArg + } var o pullOpts for _, opt := range opts { @@ -2182,19 +2189,31 @@ func (sub *Subscription) Fetch(batch int, opts ...PullOpt) ([]*Msg, error) { if err == nil && len(msgs) < batch { // For batch real size of 1, it does not make sense to set no_wait in // the request. - batchSize := batch - len(msgs) - noWait := batchSize > 1 - nr := &nextRequest{Batch: batchSize, NoWait: noWait} - req, _ := json.Marshal(nr) + noWait := batch-len(msgs) > 1 + var nr nextRequest - err = nc.PublishRequest(nms, rply, req) - for err == nil && len(msgs) < batch { + sendReq := func() error { ttl -= time.Since(start) if ttl < 0 { - ttl = 0 + // At this point consider that we have timed-out + return context.DeadlineExceeded + } + // Make our request expiration a bit shorter than the current timeout. + expires := ttl + if ttl >= 20*time.Millisecond { + expires = ttl - 10*time.Millisecond } - // Ask for next message and waits if there are no messages + nr.Batch = batch - len(msgs) + nr.Expires = expires + nr.NoWait = noWait + req, _ := json.Marshal(nr) + return nc.PublishRequest(nms, rply, req) + } + + err = sendReq() + for err == nil && len(msgs) < batch { + // Ask for next message and wait if there are no messages msg, err = sub.nextMsgWithContext(ctx, true, true) if err == nil { var usrMsg bool @@ -2207,27 +2226,7 @@ func (sub *Subscription) Fetch(batch int, opts ...PullOpt) ([]*Msg, error) { // not collected any message, then resend request to // wait this time. noWait = false - - ttl -= time.Since(start) - if ttl < 0 { - // At this point consider that we have timed-out - err = context.DeadlineExceeded - break - } - - // Make our request expiration a bit shorter than the - // current timeout. - expires := ttl - if ttl >= 20*time.Millisecond { - expires = ttl - 10*time.Millisecond - } - - nr.Batch = batch - len(msgs) - nr.Expires = expires - nr.NoWait = false - req, _ = json.Marshal(nr) - - err = nc.PublishRequest(nms, rply, req) + err = sendReq() } } } diff --git a/js_test.go b/js_test.go index 9e0eba8ed..90c03fe64 100644 --- a/js_test.go +++ b/js_test.go @@ -771,3 +771,61 @@ func TestJetStreamFlowControlStalled(t *testing.T) { t.Fatal("Library did not send FC") } } + +func TestJetStreamExpiredPullRequests(t *testing.T) { + s := RunBasicJetStreamServer() + defer s.Shutdown() + + if config := s.JetStreamConfig(); config != nil { + defer os.RemoveAll(config.StoreDir) + } + + nc, err := Connect(s.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + js, err := nc.JetStream() + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + _, err = js.AddStream(&StreamConfig{ + Name: "TEST", + Subjects: []string{"foo"}, + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + sub, err := js.PullSubscribe("foo", "bar", PullMaxWaiting(2)) + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + // Make sure that we reject batch < 1 + if _, err := sub.Fetch(0); err == nil { + t.Fatal("Expected error, did not get one") + } + if _, err := sub.Fetch(-1); err == nil { + t.Fatal("Expected error, did not get one") + } + + // Send 2 fetch requests + for i := 0; i < 2; i++ { + if _, err = sub.Fetch(1, MaxWait(15*time.Millisecond)); err == nil { + t.Fatalf("Expected error, got none") + } + } + // Wait before the above expire + time.Sleep(50 * time.Millisecond) + batches := []int{1, 10} + for _, bsz := range batches { + start := time.Now() + _, err = sub.Fetch(bsz, MaxWait(250*time.Millisecond)) + dur := time.Since(start) + if err == nil || dur < 50*time.Millisecond { + t.Fatalf("Expected error and wait for 250ms, got err=%v and dur=%v", err, dur) + } + } +} diff --git a/test/js_test.go b/test/js_test.go index f7cc4ed43..71af4d97b 100644 --- a/test/js_test.go +++ b/test/js_test.go @@ -842,9 +842,9 @@ func TestJetStreamAckPending_Pull(t *testing.T) { for time.Now().Before(timeout) { ms, err := sub.Fetch(ackPendingLimit) if err != nil || (ms != nil && len(ms) == 0) { + time.Sleep(50 * time.Millisecond) continue } - msgs = append(msgs, ms...) if len(msgs) >= expected { break @@ -852,7 +852,7 @@ func TestJetStreamAckPending_Pull(t *testing.T) { time.Sleep(10 * time.Millisecond) } if len(msgs) < expected { - t.Errorf("Expected %v, got %v", expected, pending) + t.Fatalf("Expected %v, got %v", expected, pending) } info, err := sub.ConsumerInfo() @@ -863,37 +863,37 @@ func TestJetStreamAckPending_Pull(t *testing.T) { got := info.NumRedelivered expected = 3 if got < expected { - t.Errorf("Expected %v, got: %v", expected, got) + t.Fatalf("Expected %v, got: %v", expected, got) } got = info.NumAckPending expected = 3 if got < expected { - t.Errorf("Expected %v, got: %v", expected, got) + t.Fatalf("Expected %v, got: %v", expected, got) } got = info.NumWaiting expected = 0 if got != expected { - t.Errorf("Expected %v, got: %v", expected, got) + t.Fatalf("Expected %v, got: %v", expected, got) } got = int(info.NumPending) expected = 0 if got != expected { - t.Errorf("Expected %v, got: %v", expected, got) + t.Fatalf("Expected %v, got: %v", expected, got) } got = info.Config.MaxAckPending expected = 3 if got != expected { - t.Errorf("Expected %v, got %v", expected, pending) + t.Fatalf("Expected %v, got %v", expected, pending) } got = info.Config.MaxDeliver expected = 5 if got != expected { - t.Errorf("Expected %v, got %v", expected, pending) + t.Fatalf("Expected %v, got %v", expected, pending) } acks := map[int]int{} @@ -913,7 +913,7 @@ func TestJetStreamAckPending_Pull(t *testing.T) { meta, err := m.Metadata() if err != nil { - t.Errorf("Unexpected error: %v", err) + t.Fatalf("Unexpected error: %v", err) } acks[int(meta.Sequence.Stream)]++ @@ -921,26 +921,26 @@ func TestJetStreamAckPending_Pull(t *testing.T) { ackPending-- } if int(meta.NumPending) != ackPending { - t.Errorf("Expected %v, got %v", ackPending, meta.NumPending) + t.Fatalf("Expected %v, got %v", ackPending, meta.NumPending) } } got = len(acks) expected = 3 if got != expected { - t.Errorf("Expected %v, got %v", expected, got) + t.Fatalf("Expected %v, got %v", expected, got) } expected = 5 for _, got := range acks { if got != expected { - t.Errorf("Expected %v, got %v", expected, got) + t.Fatalf("Expected %v, got %v", expected, got) } } _, err = sub.Fetch(1, nats.MaxWait(100*time.Millisecond)) if err != nats.ErrTimeout { - t.Errorf("Expected timeout, got: %v", err) + t.Fatalf("Expected timeout, got: %v", err) } }