Skip to content

Commit

Permalink
js: unblock batch requests on 408 with at least a message
Browse files Browse the repository at this point in the history
Signed-off-by: Waldemar Quevedo <wally@synadia.com>
  • Loading branch information
wallyqs committed Sep 15, 2021
1 parent 1714547 commit 673391b
Show file tree
Hide file tree
Showing 2 changed files with 296 additions and 6 deletions.
21 changes: 15 additions & 6 deletions js.go
Expand Up @@ -2074,7 +2074,7 @@ var errNoMessages = errors.New("nats: no messages")
// Returns if the given message is a user message or not, and if
// `checkSts` is true, returns appropriate error based on the
// content of the status (404, etc..)
func checkMsg(msg *Msg, checkSts bool) (usrMsg bool, err error) {
func checkMsg(msg *Msg, checkSts, cancelWhen408 bool) (usrMsg bool, err error) {
// Assume user message
usrMsg = true

Expand Down Expand Up @@ -2105,9 +2105,13 @@ func checkMsg(msg *Msg, checkSts bool) (usrMsg bool, err error) {
case reqTimeoutSts:
// Older servers may send a 408 when a request in the server was expired
// and interest is still found, which will be the case for our
// implementation. Regardless, ignore 408 errors, the caller will
// go back to wait for the next message.
err = nil
// implementation. Regardless, ignore 408 errors until receiving at least
// one message.
if !cancelWhen408 {
err = nil
return
}
fallthrough
default:
err = fmt.Errorf("nats: %s", msg.Header.Get(descrHdr))
}
Expand Down Expand Up @@ -2207,7 +2211,7 @@ func (sub *Subscription) Fetch(batch int, opts ...PullOpt) ([]*Msg, error) {
// or status message, however, we don't care about values of status
// messages at this point in the Fetch() call, so checkMsg can't
// return an error.
if usrMsg, _ := checkMsg(msg, false); usrMsg {
if usrMsg, _ := checkMsg(msg, false, false); usrMsg {
msgs = append(msgs, msg)
}
}
Expand Down Expand Up @@ -2243,7 +2247,12 @@ func (sub *Subscription) Fetch(batch int, opts ...PullOpt) ([]*Msg, error) {
if err == nil {
var usrMsg bool

usrMsg, err = checkMsg(msg, true)
// If we already got at least a message but we are prompted
// with a 408 status, then unblock the request and return
// the messages that have arrived already.
cancelWhen408 := len(msgs) >= 1

usrMsg, err = checkMsg(msg, true, cancelWhen408)
if err == nil && usrMsg {
msgs = append(msgs, msg)
} else if noWait && (err == errNoMessages) && len(msgs) == 0 {
Expand Down
281 changes: 281 additions & 0 deletions test/js_test.go
Expand Up @@ -4498,6 +4498,287 @@ func testJetStreamMirror_Source(t *testing.T, nodes ...*jsServer) {
})
}

// TestJetStream_PullSubscribeMaxWaiting runs the pull-subscribe max-waiting
// scenario once per (cluster size, replica count) combination, each in its
// own subtest with a dedicated stream.
func TestJetStream_PullSubscribeMaxWaiting(t *testing.T) {
	clusterSizes := []int{1, 3}
	replicaCounts := []int{1}

	for _, size := range clusterSizes {
		for _, repl := range replicaCounts {
			// Skip combinations asking for more than one replica on a
			// single node.
			if size == 1 && repl > 1 {
				continue
			}

			label := fmt.Sprintf("psub n=%d r=%d", size, repl)
			t.Run(label, func(t *testing.T) {
				streamName := fmt.Sprintf("PSUBMAX%d%d", size, repl)
				cfg := &nats.StreamConfig{
					Name: streamName,
					// NOTE(review): Replicas is taken from the cluster size,
					// not from the replica count used in the subtest label;
					// confirm this is intended.
					Replicas: size,
				}
				withJSClusterAndStream(t, streamName, size, cfg, testJetStream_PullSubscribeMaxWaiting)
			})
		}
	}
}

// testJetStream_PullSubscribeMaxWaiting checks how Fetch behaves on a durable
// pull consumer created with a small max-waiting limit.
//
// It connects to the first server in srvs, creates a pull subscriber on
// subject with a max-waiting limit of maxWaiting, and then verifies:
//   - a Fetch that lingers returns the single message published meanwhile;
//   - repeated Fetch calls that time out pile up as waiting requests on the
//     consumer (observed through ConsumerInfo().NumWaiting) up to the limit;
//   - in the "blocking fetch" subtest, long-running Fetch calls that already
//     hold at least one message return those messages instead of blocking,
//     while short Fetch calls time out until a new message is published.
func testJetStream_PullSubscribeMaxWaiting(t *testing.T, subject string, srvs ...*jsServer) {
	// Max number of outstanding pull requests allowed on the consumer.
	const maxWaiting = 5

	srv := srvs[0]
	nc, err := nats.Connect(srv.ClientURL())
	if err != nil {
		t.Fatal(err)
	}
	defer nc.Close()

	js, err := nc.JetStream()
	if err != nil {
		t.Fatal(err)
	}

	// Create pull subscriber with a lower max waiting limit.
	sub, err := js.PullSubscribe(subject, "durable", nats.PullMaxWaiting(maxWaiting))
	if err != nil {
		t.Fatal(err)
	}

	// Delay for a bit the first message being received. The publish error is
	// not checked here because the callback runs outside the test goroutine;
	// a failed publish surfaces as a failed Fetch below.
	time.AfterFunc(200*time.Millisecond, func() {
		js.Publish(subject, []byte("hello"))
	})

	// Only one message available so the request will linger and return a single message.
	msgs, err := sub.Fetch(2, nats.MaxWait(500*time.Millisecond))
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}
	if len(msgs) != 1 {
		t.Fatalf("Expected one message to be delivered, got: %v", len(msgs))
	}
	msg := msgs[0]
	expected := "hello"
	got := string(msg.Data)
	if got != expected {
		t.Fatalf("Expected: %v, got: %v", expected, got)
	}

	info, err := sub.ConsumerInfo()
	if err != nil {
		t.Fatal(err)
	}
	if info.NumWaiting != 1 {
		t.Fatalf("Expected 1 pending requests, got: %v", info.NumWaiting)
	}

	// Make a few requests to start getting 408 Request Timeout errors.
	for i := 0; i < maxWaiting-1; i++ {
		msgs, err = sub.Fetch(2, nats.MaxWait(200*time.Millisecond))
		if err == nil {
			t.Fatal("Unexpected success")
		}
		if len(msgs) != 0 {
			t.Fatal("Expected no messages")
		}

		// Each timed out request adds one more waiting request on top of
		// the one left over from the first fetch.
		info, err = sub.ConsumerInfo()
		if err != nil {
			t.Fatal(err)
		}
		if info.NumWaiting != i+2 {
			t.Fatalf("Expected %v pending requests, got: %v", i+2, info.NumWaiting)
		}
	}

	// There should be a maximum number of waiting requests now.
	info, err = sub.ConsumerInfo()
	if err != nil {
		t.Fatal(err)
	}
	if info.NumWaiting != maxWaiting {
		t.Fatalf("Expected %v pending requests, got: %v", maxWaiting, info.NumWaiting)
	}

	// Making an extra request that fails which will expire some of the old requests.
	msgs, err = sub.Fetch(2, nats.MaxWait(200*time.Millisecond))
	if err == nil {
		t.Fatal("Unexpected success")
	}
	if len(msgs) != 0 {
		t.Fatal("Expected no messages")
	}
	info, err = sub.ConsumerInfo()
	if err != nil {
		t.Fatal(err)
	}
	if info.NumWaiting != 1 {
		t.Fatalf("Expected 1 pending requests, got: %v", info.NumWaiting)
	}

	// Send another message with a delay...
	time.AfterFunc(200*time.Millisecond, func() {
		js.Publish(subject, []byte("bar"))
	})

	// Fetch and wait...
	msgs, err = sub.Fetch(1, nats.MaxWait(500*time.Millisecond))
	if err != nil {
		t.Fatal(err)
	}
	// Validate the length before indexing so that an empty result is
	// reported as a test failure rather than an index out of range panic.
	if len(msgs) != 1 {
		t.Fatalf("Expected one message to be delivered, got: %v", len(msgs))
	}
	msg = msgs[0]
	expected = "bar"
	got = string(msg.Data)
	if got != expected {
		t.Fatalf("Expected: %v, got: %v", expected, got)
	}

	// There should be no waiting pull requests since they got expired after fetch succeeded.
	info, err = sub.ConsumerInfo()
	if err != nil {
		t.Fatal(err)
	}
	if info.NumWaiting != 0 {
		t.Fatalf("Expected no pending requests, got: %v", info.NumWaiting)
	}

	t.Run("blocking fetch", func(t *testing.T) {
		// Create requests that take a longer time and will exhaust
		// the number of waiting requests so that the rest will be blocked.
		var (
			msgCh       = make(chan *nats.Msg, maxWaiting)
			errCh       = make(chan error, maxWaiting)
			expectedMap = make(map[string]bool)
		)
		for i := 0; i < maxWaiting; i++ {
			expectedMap[fmt.Sprintf("quux:%d", i)] = false
			go func() {
				// Can only have at most maxWaiting inflight fetch requests so
				// these will block until they receive at least a message
				// and a 408 response status.
				batch, err := sub.Fetch(100, nats.MaxWait(30*time.Second))
				if err != nil {
					errCh <- err
					return
				}
				for _, m := range batch {
					msgCh <- m
				}
			}()
		}

		// Give some time to the fetch requests to linger.
		timeout := time.Now().Add(2 * time.Second)
		for time.Now().Before(timeout) {
			info, err = sub.ConsumerInfo()
			if err != nil {
				t.Fatal(err)
			}
			if info.NumWaiting == maxWaiting {
				break
			}
		}
		if info.NumWaiting != maxWaiting {
			t.Fatalf("Expected %v pull requests, got: %v", maxWaiting, info.NumWaiting)
		}

		// Send max number of messages that will be received by the first batch.
		for i := 0; i < maxWaiting; i++ {
			if _, err := js.Publish(subject, []byte(fmt.Sprintf("quux:%v", i))); err != nil {
				t.Fatalf("Unexpected error on publish: %v", err)
			}
		}

		ctx, done := context.WithTimeout(context.Background(), 4*time.Second)
		defer done()

		// Delay sending message so that the fetch request starts waiting.
		time.AfterFunc(500*time.Millisecond, func() {
			js.Publish(subject, []byte("quux:5"))
		})

		var (
			recent []*nats.Msg
			errs   []error
		)

	Loop:
		for {
			select {
			case <-ctx.Done():
				break Loop
			default:
			}

			// These will timeout until all the original blocking fetch requests are done,
			// or they receive at least a message after receiving a 408.
			m, fetchErr := sub.Fetch(1, nats.MaxWait(100*time.Millisecond))
			if fetchErr != nil {
				errs = append(errs, fetchErr)
			}
			if len(m) > 0 {
				// The pull requests from the first set of fetches are
				// still expected to be accounted for at this point.
				info, err = sub.ConsumerInfo()
				if err != nil {
					t.Fatal(err)
				}
				if info.NumWaiting != maxWaiting {
					t.Fatalf("Expected: %v, got: %v", maxWaiting, info.NumWaiting)
				}
			}
			recent = append(recent, m...)
			if len(recent) > 0 {
				break Loop
			}
		}
		if len(errs) == 0 {
			t.Fatalf("Expected at least an error, got: %v", len(errs))
		}
		if len(recent) != 1 {
			t.Fatalf("Expected one message to be delivered to recent fetch, got: %v", len(recent))
		}
		for _, e := range errs {
			if e != nats.ErrTimeout {
				t.Fatalf("Expected nats timeout, got: %v", e)
			}
		}
		info, err = sub.ConsumerInfo()
		if err != nil {
			t.Fatal(err)
		}
		if info.NumWaiting != maxWaiting {
			t.Fatalf("Expected %v pending pull requests, got: %v", maxWaiting, info.NumWaiting)
		}
		if len(msgCh) != maxWaiting {
			t.Fatalf("Expected %v messages to be delivered on first set of fetch requests, got: %v", maxWaiting, len(msgCh))
		}
		// Drain without blocking: the length check above guarantees that
		// every receive has a message ready.
		for i := 0; i < maxWaiting; i++ {
			select {
			case m := <-msgCh:
				if _, ok := expectedMap[string(m.Data)]; ok {
					expectedMap[string(m.Data)] = true
				}
			default:
				t.Fatal("Unexpected blocking channel")
			}
		}
		for k, v := range expectedMap {
			if !v {
				t.Fatalf("Expected message %v", k)
			}
		}

		// Message received after the first set of goroutines have timed out,
		// so that following fetch requests are unblocked.
		msg = recent[0]
		expected = "quux:5"
		got = string(msg.Data)
		if got != expected {
			t.Fatalf("Expected: %v, got: %v", expected, got)
		}

		select {
		case err := <-errCh:
			t.Fatalf("Unexpected error: %v", err)
		default:
		}
	})
}

func TestJetStream_ClusterMultipleSubscribe(t *testing.T) {
nodes := []int{1, 3}
replicas := []int{1}
Expand Down

0 comments on commit 673391b

Please sign in to comment.