Skip to content

Commit

Permalink
Merge pull request #934 from nats-io/request-msg-and-info-race
Browse files Browse the repository at this point in the history
Fix race with async INFO and header APIs
  • Loading branch information
wallyqs committed Mar 25, 2022
2 parents e0e03e3 + 0aa7f9d commit 31782c0
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 38 deletions.
19 changes: 6 additions & 13 deletions context.go
Expand Up @@ -21,20 +21,13 @@ import (
// RequestMsgWithContext takes a context, a subject and payload
// in bytes and request expecting a single response.
func (nc *Conn) RequestMsgWithContext(ctx context.Context, msg *Msg) (*Msg, error) {
var hdr []byte
var err error

if len(msg.Header) > 0 {
if !nc.info.Headers {
return nil, ErrHeadersNotSupported
}

hdr, err = msg.headerBytes()
if err != nil {
return nil, err
}
if msg == nil {
return nil, ErrInvalidMsg
}
hdr, err := msg.headerBytes()
if err != nil {
return nil, err
}

return nc.requestWithContext(ctx, msg.Subject, hdr, msg.Data)
}

Expand Down
40 changes: 15 additions & 25 deletions nats.go
Expand Up @@ -3383,21 +3383,10 @@ func (nc *Conn) PublishMsg(m *Msg) error {
if m == nil {
return ErrInvalidMsg
}

var hdr []byte
var err error

if len(m.Header) > 0 {
if !nc.info.Headers {
return ErrHeadersNotSupported
}

hdr, err = m.headerBytes()
if err != nil {
return err
}
hdr, err := m.headerBytes()
if err != nil {
return err
}

return nc.publish(m.Subject, m.Reply, hdr, m.Data)
}

Expand All @@ -3423,6 +3412,12 @@ func (nc *Conn) publish(subj, reply string, hdr, data []byte) error {
}
nc.mu.Lock()

// Check if headers attempted to be sent to server that does not support them.
if len(hdr) > 0 && !nc.info.Headers {
nc.mu.Unlock()
return ErrHeadersNotSupported
}

if nc.isClosed() {
nc.mu.Unlock()
return ErrConnectionClosed
Expand Down Expand Up @@ -3593,17 +3588,12 @@ func (nc *Conn) createNewRequestAndSend(subj string, hdr, data []byte) (chan *Ms
// RequestMsg will send a request payload including optional headers and deliver
// the response message, or an error, including a timeout if no message was received properly.
func (nc *Conn) RequestMsg(msg *Msg, timeout time.Duration) (*Msg, error) {
var hdr []byte
var err error

if len(msg.Header) > 0 {
if !nc.info.Headers {
return nil, ErrHeadersNotSupported
}
hdr, err = msg.headerBytes()
if err != nil {
return nil, err
}
if msg == nil {
return nil, ErrInvalidMsg
}
hdr, err := msg.headerBytes()
if err != nil {
return nil, err
}

return nc.request(msg.Subject, hdr, msg.Data, timeout)
Expand Down
92 changes: 92 additions & 0 deletions test/headers_test.go
Expand Up @@ -14,15 +14,18 @@
package test

import (
"context"
"fmt"
"net/http"
"reflect"
"sort"
"sync"
"testing"
"time"

"net/http/httptest"

"github.com/nats-io/nats-server/v2/server"
natsserver "github.com/nats-io/nats-server/v2/test"
"github.com/nats-io/nats.go"
)
Expand Down Expand Up @@ -100,6 +103,95 @@ func TestRequestMsg(t *testing.T) {
if resp.Header.Get("Hdr-Test") != "1" {
t.Fatalf("Did not receive header in response")
}

if err = nc.PublishMsg(nil); err != nats.ErrInvalidMsg {
t.Errorf("Unexpected error: %v", err)
}
if _, err = nc.RequestMsg(nil, time.Second); err != nats.ErrInvalidMsg {
t.Errorf("Unexpected error: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond)
defer cancel()
if _, err = nc.RequestMsgWithContext(ctx, nil); err != nats.ErrInvalidMsg {
t.Errorf("Unexpected error: %v", err)
}
}

func TestRequestMsgRaceAsyncInfo(t *testing.T) {
s1Opts := natsserver.DefaultTestOptions
s1Opts.Host = "127.0.0.1"
s1Opts.Port = -1
s1Opts.Cluster.Name = "CLUSTER"
s1Opts.Cluster.Host = "127.0.0.1"
s1Opts.Cluster.Port = -1
s := natsserver.RunServer(&s1Opts)
defer s.Shutdown()

nc, err := nats.Connect(s.ClientURL())
if err != nil {
t.Fatalf("Error connecting to server: %v", err)
}
defer nc.Close()

// Extra client with old request.
nc2, err := nats.Connect(s.ClientURL(), nats.UseOldRequestStyle())
if err != nil {
t.Fatalf("Error connecting to server: %v", err)
}
defer nc2.Close()

subject := "headers.test"
if _, err := nc.Subscribe(subject, func(m *nats.Msg) {
r := nats.NewMsg(m.Reply)
r.Header["Hdr-Test"] = []string{"bar"}
r.Data = []byte("+OK")
m.RespondMsg(r)
}); err != nil {
t.Fatalf("subscribe failed: %v", err)
}
nc.Flush()

wg := sync.WaitGroup{}
wg.Add(1)
ch := make(chan struct{})
go func() {
defer wg.Done()

s2Opts := natsserver.DefaultTestOptions
s2Opts.Host = "127.0.0.1"
s2Opts.Port = -1
s2Opts.Cluster.Name = "CLUSTER"
s2Opts.Cluster.Host = "127.0.0.1"
s2Opts.Cluster.Port = -1
s2Opts.Routes = server.RoutesFromStr(fmt.Sprintf("nats://127.0.0.1:%d", s.ClusterAddr().Port))
for {
s := natsserver.RunServer(&s2Opts)
s.Shutdown()
select {
case <-ch:
return
default:
}
}
}()

msg := nats.NewMsg(subject)
msg.Header["Hdr-Test"] = []string{"quux"}
for i := 0; i < 100; i++ {
nc.RequestMsg(msg, time.Second)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
nc.RequestMsgWithContext(ctx, msg)
cancel()

// Check with old style requests as well.
nc2.RequestMsg(msg, time.Second)
ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second)
nc2.RequestMsgWithContext(ctx2, msg)
cancel2()
}

close(ch)
wg.Wait()
}

func TestNoHeaderSupport(t *testing.T) {
Expand Down

0 comments on commit 31782c0

Please sign in to comment.