Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIXED] Websocket: deadlock on authentication failure #926

Merged
merged 1 commit into from Mar 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 15 additions & 0 deletions nats.go
Expand Up @@ -1649,6 +1649,17 @@ func (w *natsWriter) doneWithPending() {
w.pending = nil
}

// Notify the reader that we are done with the connect, where "read" operations
// happen synchronously and under the connection lock. After this point, "read"
// will be happening from the read loop, without the connection lock.
//
// Note: this runs under the connection lock.
func (r *natsReader) doneWithConnect() {
if wsr, ok := r.r.(*websocketReader); ok {
wsr.doneWithConnect()
}
}

func (r *natsReader) Read() ([]byte, error) {
if r.off >= 0 {
off := r.off
Expand Down Expand Up @@ -1977,6 +1988,10 @@ func (nc *Conn) processConnectInit() error {
go nc.readLoop()
go nc.flusher()

// Notify the reader that we are done with the connect handshake, where
// reads were done synchronously and under the connection lock.
nc.br.doneWithConnect()

return nil
}

Expand Down
33 changes: 23 additions & 10 deletions ws.go
Expand Up @@ -81,6 +81,7 @@ type websocketReader struct {
ib []byte
ff bool
fc bool
nl bool
dc *wsDecompressor
nc *Conn
}
Expand Down Expand Up @@ -180,6 +181,15 @@ func wsNewReader(r io.Reader) *websocketReader {
return &websocketReader{r: r, ff: true}
}

// From now on, reads will be from the readLoop and we will need to
// acquire the connection lock should we have to send/write a control
// message from handleControlFrame.
//
// Note: this runs under the connection lock.
func (r *websocketReader) doneWithConnect() {
r.nl = true
}

func (r *websocketReader) Read(p []byte) (int, error) {
var err error
var buf []byte
Expand Down Expand Up @@ -402,12 +412,12 @@ func (r *websocketReader) handleControlFrame(frameType wsOpCode, buf []byte, pos
}
}
}
r.nc.wsEnqueueCloseMsg(status, body)
r.nc.wsEnqueueCloseMsg(r.nl, status, body)
// Return io.EOF so that readLoop will close the connection as client closed
// after processing pending buffers.
return pos, io.EOF
case wsPingMessage:
r.nc.wsEnqueueControlMsg(wsPongMessage, payload)
r.nc.wsEnqueueControlMsg(r.nl, wsPongMessage, payload)
case wsPongMessage:
// Nothing to do..
}
Expand Down Expand Up @@ -644,14 +654,16 @@ func (nc *Conn) wsClose() {
nc.wsEnqueueCloseMsgLocked(wsCloseStatusNormalClosure, _EMPTY_)
}

func (nc *Conn) wsEnqueueCloseMsg(status int, payload string) {
func (nc *Conn) wsEnqueueCloseMsg(needsLock bool, status int, payload string) {
// In some low-level unit tests it will happen...
if nc == nil {
return
}
nc.mu.Lock()
if needsLock {
nc.mu.Lock()
defer nc.mu.Unlock()
}
nc.wsEnqueueCloseMsgLocked(status, payload)
nc.mu.Unlock()
}

func (nc *Conn) wsEnqueueCloseMsgLocked(status int, payload string) {
Expand All @@ -675,25 +687,26 @@ func (nc *Conn) wsEnqueueCloseMsgLocked(status int, payload string) {
nc.bw.flush()
}

func (nc *Conn) wsEnqueueControlMsg(frameType wsOpCode, payload []byte) {
func (nc *Conn) wsEnqueueControlMsg(needsLock bool, frameType wsOpCode, payload []byte) {
// In some low-level unit tests it will happen...
if nc == nil {
return
}
fh, key := wsCreateFrameHeader(false, frameType, len(payload))
nc.mu.Lock()
if needsLock {
nc.mu.Lock()
defer nc.mu.Unlock()
}
wr, ok := nc.bw.w.(*websocketWriter)
if !ok {
nc.mu.Unlock()
return
}
fh, key := wsCreateFrameHeader(false, frameType, len(payload))
wr.ctrlFrames = append(wr.ctrlFrames, fh)
if len(payload) > 0 {
wsMaskBuf(key, payload)
wr.ctrlFrames = append(wr.ctrlFrames, payload)
}
nc.bw.flush()
nc.mu.Unlock()
}

func wsPMCExtensionSupport(header http.Header) (bool, bool) {
Expand Down
23 changes: 22 additions & 1 deletion ws_test.go
Expand Up @@ -22,6 +22,7 @@ import (
"io"
"math/rand"
"reflect"
"runtime"
"strings"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -599,7 +600,7 @@ func TestWSControlFrames(t *testing.T) {
defer nc.Close()

// Enqueue a PING and make sure that we don't break
nc.wsEnqueueControlMsg(wsPingMessage, []byte("this is a ping payload"))
nc.wsEnqueueControlMsg(true, wsPingMessage, []byte("this is a ping payload"))
select {
case e := <-dch:
t.Fatal(e)
Expand Down Expand Up @@ -1087,3 +1088,23 @@ func TestWSStress(t *testing.T) {
})
}
}

func TestWSNoDeadlockOnAuthFailure(t *testing.T) {
o := testWSGetDefaultOptions(t, false)
o.Username = "user"
o.Password = "pwd"
s := RunServerWithOptions(o)
defer s.Shutdown()

tm := time.AfterFunc(3*time.Second, func() {
buf := make([]byte, 1000000)
n := runtime.Stack(buf, true)
panic(fmt.Sprintf("Test has probably deadlocked!\n%s\n", buf[:n]))
})

if _, err := Connect(fmt.Sprintf("ws://127.0.0.1:%d", o.Websocket.Port)); err == nil {
t.Fatal("Expected auth error, did not get any error")
}

tm.Stop()
}