diff --git a/nats.go b/nats.go index 83b39bafb..606673f35 100644 --- a/nats.go +++ b/nats.go @@ -1649,6 +1649,17 @@ func (w *natsWriter) doneWithPending() { w.pending = nil } +// Notify the reader that we are done with the connect, where "read" operations +// happen synchronously and under the connection lock. After this point, "read" +// will be happening from the read loop, without the connection lock. +// +// Note: this runs under the connection lock. +func (r *natsReader) doneWithConnect() { + if wsr, ok := r.r.(*websocketReader); ok { + wsr.doneWithConnect() + } +} + func (r *natsReader) Read() ([]byte, error) { if r.off >= 0 { off := r.off @@ -1977,6 +1988,10 @@ func (nc *Conn) processConnectInit() error { go nc.readLoop() go nc.flusher() + // Notify the reader that we are done with the connect handshake, where + // reads were done synchronously and under the connection lock. + nc.br.doneWithConnect() + return nil } diff --git a/ws.go b/ws.go index 2ef3f7f46..4d2445455 100644 --- a/ws.go +++ b/ws.go @@ -81,6 +81,7 @@ type websocketReader struct { ib []byte ff bool fc bool + nl bool dc *wsDecompressor nc *Conn } @@ -180,6 +181,15 @@ func wsNewReader(r io.Reader) *websocketReader { return &websocketReader{r: r, ff: true} } +// From now on, reads will be from the readLoop and we will need to +// acquire the connection lock should we have to send/write a control +// message from handleControlFrame. +// +// Note: this runs under the connection lock. +func (r *websocketReader) doneWithConnect() { + r.nl = true +} + func (r *websocketReader) Read(p []byte) (int, error) { var err error var buf []byte @@ -402,12 +412,12 @@ func (r *websocketReader) handleControlFrame(frameType wsOpCode, buf []byte, pos } } } - r.nc.wsEnqueueCloseMsg(status, body) + r.nc.wsEnqueueCloseMsg(r.nl, status, body) // Return io.EOF so that readLoop will close the connection as client closed // after processing pending buffers. return pos, io.EOF case wsPingMessage: - r.nc.wsEnqueueControlMsg(wsPongMessage, payload) + r.nc.wsEnqueueControlMsg(r.nl, wsPongMessage, payload) case wsPongMessage: // Nothing to do.. } @@ -644,14 +654,16 @@ func (nc *Conn) wsClose() { nc.wsEnqueueCloseMsgLocked(wsCloseStatusNormalClosure, _EMPTY_) } -func (nc *Conn) wsEnqueueCloseMsg(status int, payload string) { +func (nc *Conn) wsEnqueueCloseMsg(needsLock bool, status int, payload string) { // In some low-level unit tests it will happen... if nc == nil { return } - nc.mu.Lock() + if needsLock { + nc.mu.Lock() + defer nc.mu.Unlock() + } nc.wsEnqueueCloseMsgLocked(status, payload) - nc.mu.Unlock() } func (nc *Conn) wsEnqueueCloseMsgLocked(status int, payload string) { @@ -675,25 +687,26 @@ func (nc *Conn) wsEnqueueCloseMsgLocked(status int, payload string) { nc.bw.flush() } -func (nc *Conn) wsEnqueueControlMsg(frameType wsOpCode, payload []byte) { +func (nc *Conn) wsEnqueueControlMsg(needsLock bool, frameType wsOpCode, payload []byte) { // In some low-level unit tests it will happen... if nc == nil { return } - fh, key := wsCreateFrameHeader(false, frameType, len(payload)) - nc.mu.Lock() + if needsLock { + nc.mu.Lock() + defer nc.mu.Unlock() + } wr, ok := nc.bw.w.(*websocketWriter) if !ok { - nc.mu.Unlock() return } + fh, key := wsCreateFrameHeader(false, frameType, len(payload)) wr.ctrlFrames = append(wr.ctrlFrames, fh) if len(payload) > 0 { wsMaskBuf(key, payload) wr.ctrlFrames = append(wr.ctrlFrames, payload) } nc.bw.flush() - nc.mu.Unlock() } func wsPMCExtensionSupport(header http.Header) (bool, bool) { diff --git a/ws_test.go b/ws_test.go index 551ce3f05..84229b7ba 100644 --- a/ws_test.go +++ b/ws_test.go @@ -22,6 +22,7 @@ import ( "io" "math/rand" "reflect" + "runtime" "strings" "sync" "sync/atomic" @@ -599,7 +600,7 @@ func TestWSControlFrames(t *testing.T) { defer nc.Close() // Enqueue a PING and make sure that we don't break - nc.wsEnqueueControlMsg(wsPingMessage, []byte("this is a ping payload")) + nc.wsEnqueueControlMsg(true, wsPingMessage, []byte("this is a ping payload")) select { case e := <-dch: t.Fatal(e) @@ -1087,3 +1088,23 @@ func TestWSStress(t *testing.T) { }) } } + +func TestWSNoDeadlockOnAuthFailure(t *testing.T) { + o := testWSGetDefaultOptions(t, false) + o.Username = "user" + o.Password = "pwd" + s := RunServerWithOptions(o) + defer s.Shutdown() + + tm := time.AfterFunc(3*time.Second, func() { + buf := make([]byte, 1000000) + n := runtime.Stack(buf, true) + panic(fmt.Sprintf("Test has probably deadlocked!\n%s\n", buf[:n])) + }) + + if _, err := Connect(fmt.Sprintf("ws://127.0.0.1:%d", o.Websocket.Port)); err == nil { + t.Fatal("Expected auth error, did not get any error") + } + + tm.Stop() +}