diff --git a/context.go b/context.go index 037668fb7..300b6ebbd 100644 --- a/context.go +++ b/context.go @@ -21,20 +21,13 @@ import ( // RequestMsgWithContext takes a context, a subject and payload // in bytes and request expecting a single response. func (nc *Conn) RequestMsgWithContext(ctx context.Context, msg *Msg) (*Msg, error) { - var hdr []byte - var err error - - if len(msg.Header) > 0 { - if !nc.info.Headers { - return nil, ErrHeadersNotSupported - } - - hdr, err = msg.headerBytes() - if err != nil { - return nil, err - } + if msg == nil { + return nil, ErrInvalidMsg + } + hdr, err := msg.headerBytes() + if err != nil { + return nil, err } - return nc.requestWithContext(ctx, msg.Subject, hdr, msg.Data) } diff --git a/nats.go b/nats.go index 566fced52..050084b53 100644 --- a/nats.go +++ b/nats.go @@ -3383,21 +3383,10 @@ func (nc *Conn) PublishMsg(m *Msg) error { if m == nil { return ErrInvalidMsg } - - var hdr []byte - var err error - - if len(m.Header) > 0 { - if !nc.info.Headers { - return ErrHeadersNotSupported - } - - hdr, err = m.headerBytes() - if err != nil { - return err - } + hdr, err := m.headerBytes() + if err != nil { + return err } - return nc.publish(m.Subject, m.Reply, hdr, m.Data) } @@ -3423,6 +3412,12 @@ func (nc *Conn) publish(subj, reply string, hdr, data []byte) error { } nc.mu.Lock() + // Check if headers attempted to be sent to server that does not support them. + if len(hdr) > 0 && !nc.info.Headers { + nc.mu.Unlock() + return ErrHeadersNotSupported + } + if nc.isClosed() { nc.mu.Unlock() return ErrConnectionClosed @@ -3593,17 +3588,12 @@ func (nc *Conn) createNewRequestAndSend(subj string, hdr, data []byte) (chan *Ms // RequestMsg will send a request payload including optional headers and deliver // the response message, or an error, including a timeout if no message was received properly. func (nc *Conn) RequestMsg(msg *Msg, timeout time.Duration) (*Msg, error) { - var hdr []byte - var err error - - if len(msg.Header) > 0 { - if !nc.info.Headers { - return nil, ErrHeadersNotSupported - } - hdr, err = msg.headerBytes() - if err != nil { - return nil, err - } + if msg == nil { + return nil, ErrInvalidMsg + } + hdr, err := msg.headerBytes() + if err != nil { + return nil, err } return nc.request(msg.Subject, hdr, msg.Data, timeout) diff --git a/test/headers_test.go b/test/headers_test.go index 24b962326..d2d0f1bcf 100644 --- a/test/headers_test.go +++ b/test/headers_test.go @@ -14,15 +14,18 @@ package test import ( + "context" "fmt" "net/http" "reflect" "sort" + "sync" "testing" "time" "net/http/httptest" + "github.com/nats-io/nats-server/v2/server" natsserver "github.com/nats-io/nats-server/v2/test" "github.com/nats-io/nats.go" ) @@ -100,6 +103,95 @@ func TestRequestMsg(t *testing.T) { if resp.Header.Get("Hdr-Test") != "1" { t.Fatalf("Did not receive header in response") } + + if err = nc.PublishMsg(nil); err != nats.ErrInvalidMsg { + t.Errorf("Unexpected error: %v", err) + } + if _, err = nc.RequestMsg(nil, time.Second); err != nats.ErrInvalidMsg { + t.Errorf("Unexpected error: %v", err) + } + ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond) + defer cancel() + if _, err = nc.RequestMsgWithContext(ctx, nil); err != nats.ErrInvalidMsg { + t.Errorf("Unexpected error: %v", err) + } +} + +func TestRequestMsgRaceAsyncInfo(t *testing.T) { + s1Opts := natsserver.DefaultTestOptions + s1Opts.Host = "127.0.0.1" + s1Opts.Port = -1 + s1Opts.Cluster.Name = "CLUSTER" + s1Opts.Cluster.Host = "127.0.0.1" + s1Opts.Cluster.Port = -1 + s := natsserver.RunServer(&s1Opts) + defer s.Shutdown() + + nc, err := nats.Connect(s.ClientURL()) + if err != nil { + t.Fatalf("Error connecting to server: %v", err) + } + defer nc.Close() + + // Extra client with old request. + nc2, err := nats.Connect(s.ClientURL(), nats.UseOldRequestStyle()) + if err != nil { + t.Fatalf("Error connecting to server: %v", err) + } + defer nc2.Close() + + subject := "headers.test" + if _, err := nc.Subscribe(subject, func(m *nats.Msg) { + r := nats.NewMsg(m.Reply) + r.Header["Hdr-Test"] = []string{"bar"} + r.Data = []byte("+OK") + m.RespondMsg(r) + }); err != nil { + t.Fatalf("subscribe failed: %v", err) + } + nc.Flush() + + wg := sync.WaitGroup{} + wg.Add(1) + ch := make(chan struct{}) + go func() { + defer wg.Done() + + s2Opts := natsserver.DefaultTestOptions + s2Opts.Host = "127.0.0.1" + s2Opts.Port = -1 + s2Opts.Cluster.Name = "CLUSTER" + s2Opts.Cluster.Host = "127.0.0.1" + s2Opts.Cluster.Port = -1 + s2Opts.Routes = server.RoutesFromStr(fmt.Sprintf("nats://127.0.0.1:%d", s.ClusterAddr().Port)) + for { + s := natsserver.RunServer(&s2Opts) + s.Shutdown() + select { + case <-ch: + return + default: + } + } + }() + + msg := nats.NewMsg(subject) + msg.Header["Hdr-Test"] = []string{"quux"} + for i := 0; i < 100; i++ { + nc.RequestMsg(msg, time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + nc.RequestMsgWithContext(ctx, msg) + cancel() + + // Check with old style requests as well. + nc2.RequestMsg(msg, time.Second) + ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second) + nc2.RequestMsgWithContext(ctx2, msg) + cancel2() + } + + close(ch) + wg.Wait() } func TestNoHeaderSupport(t *testing.T) {