From bcef8431c98087addcb2f0ab484ff295abe41a74 Mon Sep 17 00:00:00 2001 From: Gary Burd Date: Sat, 1 Jan 2022 08:43:22 -0800 Subject: [PATCH 1/7] Use context.Context in TLS handshake (#751) Continued work on #730. --- .circleci/config.yml | 2 +- client.go | 23 ++++++----------------- tls_handshake.go | 21 +++++++++++++++++++++ tls_handshake_116.go | 21 +++++++++++++++++++++ trace.go | 20 -------------------- trace_17.go | 13 ------------- 6 files changed, 49 insertions(+), 51 deletions(-) create mode 100644 tls_handshake.go create mode 100644 tls_handshake_116.go delete mode 100644 trace.go delete mode 100644 trace_17.go diff --git a/.circleci/config.yml b/.circleci/config.yml index 554a446..a0eb0ed 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -67,4 +67,4 @@ workflows: - test: matrix: parameters: - version: ["latest", "1.15", "1.14", "1.13", "1.12", "1.11"] + version: ["latest", "1.17", "1.16", "1.15", "1.14", "1.13", "1.12", "1.11"] diff --git a/client.go b/client.go index 196a659..a24c3ce 100644 --- a/client.go +++ b/client.go @@ -314,11 +314,12 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h tlsConn := tls.Client(netConn, cfg) netConn = tlsConn - var err error - if trace != nil { - err = doHandshakeWithTrace(trace, tlsConn, cfg) - } else { - err = doHandshake(tlsConn, cfg) + if trace != nil && trace.TLSHandshakeStart != nil { + trace.TLSHandshakeStart() + } + err := doHandshake(ctx, tlsConn, cfg) + if trace != nil && trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(tlsConn.ConnectionState(), err) } if err != nil { @@ -383,15 +384,3 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h netConn = nil // to avoid close in defer. return conn, resp, nil } - -func doHandshake(tlsConn *tls.Conn, cfg *tls.Config) error { - if err := tlsConn.Handshake(); err != nil { - return err - } - if !cfg.InsecureSkipVerify { - if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { - return err - } - } - return nil -} diff --git a/tls_handshake.go b/tls_handshake.go new file mode 100644 index 0000000..a62b68c --- /dev/null +++ b/tls_handshake.go @@ -0,0 +1,21 @@ +//go:build go1.17 +// +build go1.17 + +package websocket + +import ( + "context" + "crypto/tls" +) + +func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error { + if err := tlsConn.HandshakeContext(ctx); err != nil { + return err + } + if !cfg.InsecureSkipVerify { + if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { + return err + } + } + return nil +} diff --git a/tls_handshake_116.go b/tls_handshake_116.go new file mode 100644 index 0000000..e1b2b44 --- /dev/null +++ b/tls_handshake_116.go @@ -0,0 +1,21 @@ +//go:build !go1.17 +// +build !go1.17 + +package websocket + +import ( + "context" + "crypto/tls" +) + +func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error { + if err := tlsConn.Handshake(); err != nil { + return err + } + if !cfg.InsecureSkipVerify { + if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { + return err + } + } + return nil +} diff --git a/trace.go b/trace.go deleted file mode 100644 index 246a5d3..0000000 --- a/trace.go +++ /dev/null @@ -1,20 +0,0 @@ -//go:build go1.8 -// +build go1.8 - -package websocket - -import ( - "crypto/tls" - "net/http/httptrace" -) - -func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error { - if trace.TLSHandshakeStart != nil { - trace.TLSHandshakeStart() - } - err := doHandshake(tlsConn, cfg) - if trace.TLSHandshakeDone != nil { - trace.TLSHandshakeDone(tlsConn.ConnectionState(), err) - } - return err -} diff --git a/trace_17.go b/trace_17.go deleted file mode 100644 index f4be940..0000000 --- a/trace_17.go +++ /dev/null @@ -1,13 +0,0 @@ -//go:build !go1.8 -// +build !go1.8 - -package websocket - -import ( - "crypto/tls" - "net/http/httptrace" -) - -func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error { - return doHandshake(tlsConn, cfg) -} From beca1d39409212eff6678719a8ecf7761184adc8 Mon Sep 17 00:00:00 2001 From: Alexander Emelin Date: Sun, 2 Jan 2022 18:35:34 +0300 Subject: [PATCH 2/7] Fix broadcast benchmarks (#542) * do not use cached PreparedMessage in broadcast benchmarks * pick better name for benchmark method --- conn_broadcast_test.go | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/conn_broadcast_test.go b/conn_broadcast_test.go index cb88cbb..6e744fc 100644 --- a/conn_broadcast_test.go +++ b/conn_broadcast_test.go @@ -18,7 +18,6 @@ import ( // scenarios with many subscribers in one channel. type broadcastBench struct { w io.Writer - message *broadcastMessage closeCh chan struct{} doneCh chan struct{} count int32 @@ -52,14 +51,6 @@ func newBroadcastBench(usePrepared, compression bool) *broadcastBench { usePrepared: usePrepared, compression: compression, } - msg := &broadcastMessage{ - payload: textMessages(1)[0], - } - if usePrepared { - pm, _ := NewPreparedMessage(TextMessage, msg.payload) - msg.prepared = pm - } - bench.message = msg bench.makeConns(10000) return bench } @@ -78,7 +69,7 @@ func (b *broadcastBench) makeConns(numConns int) { for { select { case msg := <-c.msgCh: - if b.usePrepared { + if msg.prepared != nil { c.conn.WritePreparedMessage(msg.prepared) } else { c.conn.WriteMessage(TextMessage, msg.payload) @@ -100,9 +91,9 @@ func (b *broadcastBench) close() { close(b.closeCh) } -func (b *broadcastBench) runOnce() { +func (b *broadcastBench) broadcastOnce(msg *broadcastMessage) { for _, c := range b.conns { - c.msgCh <- b.message + c.msgCh <- msg } <-b.doneCh } @@ -114,17 +105,25 @@ func BenchmarkBroadcast(b *testing.B) { compression bool }{ {"NoCompression", false, false}, - {"WithCompression", false, true}, + {"Compression", false, true}, {"NoCompressionPrepared", true, false}, - {"WithCompressionPrepared", true, true}, + {"CompressionPrepared", true, true}, } + payload := textMessages(1)[0] for _, bm := range benchmarks { b.Run(bm.name, func(b *testing.B) { bench := newBroadcastBench(bm.usePrepared, bm.compression) defer bench.close() b.ResetTimer() for i := 0; i < b.N; i++ { - bench.runOnce() + message := &broadcastMessage{ + payload: payload, + } + if bench.usePrepared { + pm, _ := NewPreparedMessage(TextMessage, message.payload) + message.prepared = pm + } + bench.broadcastOnce(message) } b.ReportAllocs() }) From 2d6ee4c55cc9e9dc0eb5929f32a999213e25256f Mon Sep 17 00:00:00 2001 From: Gary Burd Date: Sun, 2 Jan 2022 11:21:21 -0800 Subject: [PATCH 3/7] Update autobahn example - Update instructions to use docker. - Cleanup config file. --- examples/autobahn/README.md | 9 +++++-- examples/autobahn/config/fuzzingclient.json | 29 +++++++++++++++++++++ examples/autobahn/fuzzingclient.json | 15 ----------- 3 files changed, 36 insertions(+), 17 deletions(-) create mode 100644 examples/autobahn/config/fuzzingclient.json delete mode 100644 examples/autobahn/fuzzingclient.json diff --git a/examples/autobahn/README.md b/examples/autobahn/README.md index dde8525..cc954fe 100644 --- a/examples/autobahn/README.md +++ b/examples/autobahn/README.md @@ -8,6 +8,11 @@ To test the server, run and start the client test driver - wstest -m fuzzingclient -s fuzzingclient.json + mkdir -p reports + docker run -it --rm \ + -v ${PWD}/config:/config \ + -v ${PWD}/reports:/reports \ + crossbario/autobahn-testsuite \ + wstest -m fuzzingclient -s /config/fuzzingclient.json -When the client completes, it writes a report to reports/clients/index.html. +When the client completes, it writes a report to reports/index.html. diff --git a/examples/autobahn/config/fuzzingclient.json b/examples/autobahn/config/fuzzingclient.json new file mode 100644 index 0000000..eda4e66 --- /dev/null +++ b/examples/autobahn/config/fuzzingclient.json @@ -0,0 +1,29 @@ +{ + "cases": ["*"], + "exclude-cases": [], + "exclude-agent-cases": {}, + "outdir": "/reports", + "options": {"failByDrop": false}, + "servers": [ + { + "agent": "ReadAllWriteMessage", + "url": "ws://host.docker.internal:9000/m" + }, + { + "agent": "ReadAllWritePreparedMessage", + "url": "ws://host.docker.internal:9000/p" + }, + { + "agent": "CopyFull", + "url": "ws://host.docker.internal:9000/f" + }, + { + "agent": "ReadAllWrite", + "url": "ws://host.docker.internal:9000/r" + }, + { + "agent": "CopyWriterOnly", + "url": "ws://host.docker.internal:9000/c" + } + ] +} diff --git a/examples/autobahn/fuzzingclient.json b/examples/autobahn/fuzzingclient.json deleted file mode 100644 index aa3a0bc..0000000 --- a/examples/autobahn/fuzzingclient.json +++ /dev/null @@ -1,15 +0,0 @@ - -{ - "options": {"failByDrop": false}, - "outdir": "./reports/clients", - "servers": [ - {"agent": "ReadAllWriteMessage", "url": "ws://localhost:9000/m", "options": {"version": 18}}, - {"agent": "ReadAllWritePreparedMessage", "url": "ws://localhost:9000/p", "options": {"version": 18}}, - {"agent": "ReadAllWrite", "url": "ws://localhost:9000/r", "options": {"version": 18}}, - {"agent": "CopyFull", "url": "ws://localhost:9000/f", "options": {"version": 18}}, - {"agent": "CopyWriterOnly", "url": "ws://localhost:9000/c", "options": {"version": 18}} - ], - "cases": ["*"], - "exclude-cases": [], - "exclude-agent-cases": {} -} From f0643a3a18bd24604a6131076f6419c7c518a956 Mon Sep 17 00:00:00 2001 From: Gary Burd Date: Sun, 2 Jan 2022 12:16:08 -0800 Subject: [PATCH 4/7] Improve protocol error messages To aid protocol error debugging, report all errors found in the first two bytes of a message header. --- conn.go | 57 ++++++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 40 insertions(+), 17 deletions(-) diff --git a/conn.go b/conn.go index ca46d2f..2296bf3 100644 --- a/conn.go +++ b/conn.go @@ -13,6 +13,7 @@ import ( "math/rand" "net" "strconv" + "strings" "sync" "time" "unicode/utf8" @@ -794,47 +795,69 @@ func (c *Conn) advanceFrame() (int, error) { } // 2. Read and parse first two bytes of frame header. + // To aid debugging, collect and report all errors in the first two bytes + // of the header. + + var errors []string p, err := c.read(2) if err != nil { return noFrame, err } - final := p[0]&finalBit != 0 frameType := int(p[0] & 0xf) + final := p[0]&finalBit != 0 + rsv1 := p[0]&rsv1Bit != 0 + rsv2 := p[0]&rsv2Bit != 0 + rsv3 := p[0]&rsv3Bit != 0 mask := p[1]&maskBit != 0 c.setReadRemaining(int64(p[1] & 0x7f)) c.readDecompress = false - if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 { - c.readDecompress = true - p[0] &^= rsv1Bit + if rsv1 { + if c.newDecompressionReader != nil { + c.readDecompress = true + } else { + errors = append(errors, "RSV1 set") + } } - if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 { - return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16)) + if rsv2 { + errors = append(errors, "RSV2 set") + } + + if rsv3 { + errors = append(errors, "RSV3 set") } switch frameType { case CloseMessage, PingMessage, PongMessage: if c.readRemaining > maxControlFramePayloadSize { - return noFrame, c.handleProtocolError("control frame length > 125") + errors = append(errors, "len > 125 for control") } if !final { - return noFrame, c.handleProtocolError("control frame not final") + errors = append(errors, "FIN not set on control") } case TextMessage, BinaryMessage: if !c.readFinal { - return noFrame, c.handleProtocolError("message start before final message frame") + errors = append(errors, "data before FIN") } c.readFinal = final case continuationFrame: if c.readFinal { - return noFrame, c.handleProtocolError("continuation after final message frame") + errors = append(errors, "continuation after FIN") } c.readFinal = final default: - return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType)) + errors = append(errors, "bad opcode "+strconv.Itoa(frameType)) + } + + if mask != c.isServer { + errors = append(errors, "bad MASK") + } + + if len(errors) > 0 { + return noFrame, c.handleProtocolError(strings.Join(errors, ", ")) } // 3. Read and parse frame length as per @@ -872,10 +895,6 @@ func (c *Conn) advanceFrame() (int, error) { // 4. Handle frame masking. - if mask != c.isServer { - return noFrame, c.handleProtocolError("incorrect mask flag") - } - if mask { c.readMaskPos = 0 p, err := c.read(len(c.readMaskKey)) @@ -935,7 +954,7 @@ func (c *Conn) advanceFrame() (int, error) { if len(payload) >= 2 { closeCode = int(binary.BigEndian.Uint16(payload)) if !isValidReceivedCloseCode(closeCode) { - return noFrame, c.handleProtocolError("invalid close code") + return noFrame, c.handleProtocolError("bad close code " + strconv.Itoa(closeCode)) } closeText = string(payload[2:]) if !utf8.ValidString(closeText) { @@ -952,7 +971,11 @@ func (c *Conn) advanceFrame() (int, error) { } func (c *Conn) handleProtocolError(message string) error { - c.WriteControl(CloseMessage, FormatCloseMessage(CloseProtocolError, message), time.Now().Add(writeWait)) + data := FormatCloseMessage(CloseProtocolError, message) + if len(data) > maxControlFramePayloadSize { + data = data[:maxControlFramePayloadSize] + } + c.WriteControl(CloseMessage, data, time.Now().Add(writeWait)) return errors.New("websocket: " + message) } From 4fad4036191b2a16b723b60a13a1690c5b8c27a9 Mon Sep 17 00:00:00 2001 From: Gary Burd Date: Sun, 2 Jan 2022 15:53:55 -0800 Subject: [PATCH 5/7] Remove support for Go 1.8 --- client.go | 7 +++++++ client_clone.go | 17 ----------------- client_clone_legacy.go | 39 --------------------------------------- conn.go | 6 ++++++ conn_write.go | 16 ---------------- conn_write_legacy.go | 19 ------------------- 6 files changed, 13 insertions(+), 91 deletions(-) delete mode 100644 client_clone.go delete mode 100644 client_clone_legacy.go delete mode 100644 conn_write.go delete mode 100644 conn_write_legacy.go diff --git a/client.go b/client.go index a24c3ce..ecbe584 100644 --- a/client.go +++ b/client.go @@ -384,3 +384,10 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h netConn = nil // to avoid close in defer. return conn, resp, nil } + +func cloneTLSConfig(cfg *tls.Config) *tls.Config { + if cfg == nil { + return &tls.Config{} + } + return cfg.Clone() +} diff --git a/client_clone.go b/client_clone.go deleted file mode 100644 index 4179c7a..0000000 --- a/client_clone.go +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build go1.8 -// +build go1.8 - -package websocket - -import "crypto/tls" - -func cloneTLSConfig(cfg *tls.Config) *tls.Config { - if cfg == nil { - return &tls.Config{} - } - return cfg.Clone() -} diff --git a/client_clone_legacy.go b/client_clone_legacy.go deleted file mode 100644 index 7e241a8..0000000 --- a/client_clone_legacy.go +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build !go1.8 -// +build !go1.8 - -package websocket - -import "crypto/tls" - -// cloneTLSConfig clones all public fields except the fields -// SessionTicketsDisabled and SessionTicketKey. This avoids copying the -// sync.Mutex in the sync.Once and makes it safe to call cloneTLSConfig on a -// config in active use. -func cloneTLSConfig(cfg *tls.Config) *tls.Config { - if cfg == nil { - return &tls.Config{} - } - return &tls.Config{ - Rand: cfg.Rand, - Time: cfg.Time, - Certificates: cfg.Certificates, - NameToCertificate: cfg.NameToCertificate, - GetCertificate: cfg.GetCertificate, - RootCAs: cfg.RootCAs, - NextProtos: cfg.NextProtos, - ServerName: cfg.ServerName, - ClientAuth: cfg.ClientAuth, - ClientCAs: cfg.ClientCAs, - InsecureSkipVerify: cfg.InsecureSkipVerify, - CipherSuites: cfg.CipherSuites, - PreferServerCipherSuites: cfg.PreferServerCipherSuites, - ClientSessionCache: cfg.ClientSessionCache, - MinVersion: cfg.MinVersion, - MaxVersion: cfg.MaxVersion, - CurvePreferences: cfg.CurvePreferences, - } -} diff --git a/conn.go b/conn.go index 2296bf3..331eebc 100644 --- a/conn.go +++ b/conn.go @@ -402,6 +402,12 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error return nil } +func (c *Conn) writeBufs(bufs ...[]byte) error { + b := net.Buffers(bufs) + _, err := b.WriteTo(c.conn) + return err +} + // WriteControl writes a control message with the given deadline. The allowed // message types are CloseMessage, PingMessage and PongMessage. func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error { diff --git a/conn_write.go b/conn_write.go deleted file mode 100644 index 497467a..0000000 --- a/conn_write.go +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build go1.8 -// +build go1.8 - -package websocket - -import "net" - -func (c *Conn) writeBufs(bufs ...[]byte) error { - b := net.Buffers(bufs) - _, err := b.WriteTo(c.conn) - return err -} diff --git a/conn_write_legacy.go b/conn_write_legacy.go deleted file mode 100644 index 8501a23..0000000 --- a/conn_write_legacy.go +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build !go1.8 -// +build !go1.8 - -package websocket - -func (c *Conn) writeBufs(bufs ...[]byte) error { - for _, buf := range bufs { - if len(buf) > 0 { - if _, err := c.conn.Write(buf); err != nil { - return err - } - } - } - return nil -} From 2f25f7843d3d0e4889e5e008dcbdd77fec378deb Mon Sep 17 00:00:00 2001 From: Gary Burd Date: Mon, 3 Jan 2022 17:49:10 -0800 Subject: [PATCH 6/7] Update README (#757) - Note that a new maintainer is needed. - Remove comparison with x/net/websocket. There's no need to describe the issues with that package now that the package's documentation points people here and elsewhere. --- README.md | 39 +++++++-------------------------------- 1 file changed, 7 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index 19aa2e7..2517a28 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,13 @@ Gorilla WebSocket is a [Go](http://golang.org/) implementation of the [WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol. + +--- + +⚠️ **[The Gorilla WebSocket Package is looking for a new maintainer](https://github.com/gorilla/websocket/issues/370)** + +--- + ### Documentation * [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc) @@ -30,35 +37,3 @@ The Gorilla WebSocket package passes the server tests in the [Autobahn Test Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn). -### Gorilla WebSocket compared with other packages - - - - - - - - - - - - - - - - - - -
github.com/gorillagolang.org/x/net
RFC 6455 Features
Passes Autobahn Test SuiteYesNo
Receive fragmented messageYesNo, see note 1
Send close messageYesNo
Send pings and receive pongsYesNo
Get the type of a received data messageYesYes, see note 2
Other Features
Compression ExtensionsExperimentalNo
Read message using io.ReaderYesNo, see note 3
Write message using io.WriteCloserYesNo, see note 3
- -Notes: - -1. Large messages are fragmented in [Chrome's new WebSocket implementation](http://www.ietf.org/mail-archive/web/hybi/current/msg10503.html). -2. The application can get the type of a received data message by implementing - a [Codec marshal](http://godoc.org/golang.org/x/net/websocket#Codec.Marshal) - function. -3. The go.net io.Reader and io.Writer operate across WebSocket frame boundaries. - Read returns when the input buffer is full or a frame boundary is - encountered. Each call to Write sends a single frame message. The Gorilla - io.Reader and io.WriteCloser operate on a single WebSocket message. - From 9111bb834a68b893cebbbaed5060bdbc1d9ab7d2 Mon Sep 17 00:00:00 2001 From: Lluis Campos Date: Tue, 4 Jan 2022 02:59:52 +0100 Subject: [PATCH 7/7] Dialer: add optional method NetDialTLSContext (#746) Fixes issue: https://github.com/gorilla/websocket/issues/745 With the previous interface, NetDial and NetDialContext were used for both TLS and non-TLS TCP connections, and afterwards TLSClientConfig was used to do the TLS handshake. While this API works for most cases, it prevents from using more advance authentication methods during the TLS handshake, as this is out of the control of the user. This commits introduces another a new dial method, NetDialTLSContext, which is used when dialing for TLS/TCP. The code then assumes that the handshake is done there and TLSClientConfig is not used. This API change is fully backwards compatible and it better aligns with net/http.Transport API, which has these two dial flavors. See: https://pkg.go.dev/net/http#Transport Signed-off-by: Lluis Campos --- client.go | 45 +++++++++-- client_server_test.go | 178 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 215 insertions(+), 8 deletions(-) diff --git a/client.go b/client.go index ecbe584..2efd835 100644 --- a/client.go +++ b/client.go @@ -56,9 +56,15 @@ type Dialer struct { NetDial func(network, addr string) (net.Conn, error) // NetDialContext specifies the dial function for creating TCP connections. If - // NetDialContext is nil, net.DialContext is used. + // NetDialContext is nil, NetDial is used. NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error) + // NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If + // NetDialTLSContext is nil, NetDialContext is used. + // If NetDialTLSContext is set, Dial assumes the TLS handshake is done there and + // TLSClientConfig is ignored. + NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error) + // Proxy specifies a function to return a proxy for a given // Request. If the function returns a non-nil error, the // request is aborted with the provided error. @@ -67,6 +73,8 @@ type Dialer struct { // TLSClientConfig specifies the TLS configuration to use with tls.Client. // If nil, the default configuration is used. + // If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake + // is done there and TLSClientConfig is ignored. TLSClientConfig *tls.Config // HandshakeTimeout specifies the duration for the handshake to complete. @@ -239,13 +247,32 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h // Get network dial function. var netDial func(network, add string) (net.Conn, error) - if d.NetDialContext != nil { - netDial = func(network, addr string) (net.Conn, error) { - return d.NetDialContext(ctx, network, addr) + switch u.Scheme { + case "http": + if d.NetDialContext != nil { + netDial = func(network, addr string) (net.Conn, error) { + return d.NetDialContext(ctx, network, addr) + } + } else if d.NetDial != nil { + netDial = d.NetDial } - } else if d.NetDial != nil { - netDial = d.NetDial - } else { + case "https": + if d.NetDialTLSContext != nil { + netDial = func(network, addr string) (net.Conn, error) { + return d.NetDialTLSContext(ctx, network, addr) + } + } else if d.NetDialContext != nil { + netDial = func(network, addr string) (net.Conn, error) { + return d.NetDialContext(ctx, network, addr) + } + } else if d.NetDial != nil { + netDial = d.NetDial + } + default: + return nil, nil, errMalformedURL + } + + if netDial == nil { netDialer := &net.Dialer{} netDial = func(network, addr string) (net.Conn, error) { return netDialer.DialContext(ctx, network, addr) @@ -306,7 +333,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h } }() - if u.Scheme == "https" { + if u.Scheme == "https" && d.NetDialTLSContext == nil { + // If NetDialTLSContext is set, assume that the TLS handshake has already been done + cfg := cloneTLSConfig(d.TLSClientConfig) if cfg.ServerName == "" { cfg.ServerName = hostNoPort diff --git a/client_server_test.go b/client_server_test.go index c03f2a9..e975e51 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -11,6 +11,7 @@ import ( "crypto/x509" "encoding/base64" "encoding/binary" + "errors" "fmt" "io" "io/ioutil" @@ -920,3 +921,180 @@ func TestEmptyTracingDialWithContext(t *testing.T) { defer ws.Close() sendRecv(t, ws) } + +// TestNetDialConnect tests selection of dial method between NetDial, NetDialContext, NetDialTLS or NetDialTLSContext +func TestNetDialConnect(t *testing.T) { + + upgrader := Upgrader{} + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if IsWebSocketUpgrade(r) { + c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}}) + if err != nil { + t.Fatal(err) + } + c.Close() + } else { + w.Header().Set("X-Test-Host", r.Host) + } + }) + + server := httptest.NewServer(handler) + defer server.Close() + + tlsServer := httptest.NewTLSServer(handler) + defer tlsServer.Close() + + testUrls := map[*httptest.Server]string{ + server: "ws://" + server.Listener.Addr().String() + "/", + tlsServer: "wss://" + tlsServer.Listener.Addr().String() + "/", + } + + cas := rootCAs(t, tlsServer) + tlsConfig := &tls.Config{ + RootCAs: cas, + ServerName: "example.com", + InsecureSkipVerify: false, + } + + tests := []struct { + name string + server *httptest.Server // server to use + netDial func(network, addr string) (net.Conn, error) + netDialContext func(ctx context.Context, network, addr string) (net.Conn, error) + netDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error) + tlsClientConfig *tls.Config + }{ + + { + name: "HTTP server, all NetDial* defined, shall use NetDialContext", + server: server, + netDial: func(network, addr string) (net.Conn, error) { + return nil, errors.New("NetDial should not be called") + }, + netDialContext: func(_ context.Context, network, addr string) (net.Conn, error) { + return net.Dial(network, addr) + }, + netDialTLSContext: func(_ context.Context, network, addr string) (net.Conn, error) { + return nil, errors.New("NetDialTLSContext should not be called") + }, + tlsClientConfig: nil, + }, + { + name: "HTTP server, all NetDial* undefined", + server: server, + netDial: nil, + netDialContext: nil, + netDialTLSContext: nil, + tlsClientConfig: nil, + }, + { + name: "HTTP server, NetDialContext undefined, shall fallback to NetDial", + server: server, + netDial: func(network, addr string) (net.Conn, error) { + return net.Dial(network, addr) + }, + netDialContext: nil, + netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, errors.New("NetDialTLSContext should not be called") + }, + tlsClientConfig: nil, + }, + { + name: "HTTPS server, all NetDial* defined, shall use NetDialTLSContext", + server: tlsServer, + netDial: func(network, addr string) (net.Conn, error) { + return nil, errors.New("NetDial should not be called") + }, + netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, errors.New("NetDialContext should not be called") + }, + netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + netConn, err := net.Dial(network, addr) + if err != nil { + return nil, err + } + tlsConn := tls.Client(netConn, tlsConfig) + err = tlsConn.Handshake() + if err != nil { + return nil, err + } + return tlsConn, nil + }, + tlsClientConfig: nil, + }, + { + name: "HTTPS server, NetDialTLSContext undefined, shall fallback to NetDialContext and do handshake", + server: tlsServer, + netDial: func(network, addr string) (net.Conn, error) { + return nil, errors.New("NetDial should not be called") + }, + netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return net.Dial(network, addr) + }, + netDialTLSContext: nil, + tlsClientConfig: tlsConfig, + }, + { + name: "HTTPS server, NetDialTLSContext and NetDialContext undefined, shall fallback to NetDial and do handshake", + server: tlsServer, + netDial: func(network, addr string) (net.Conn, error) { + return net.Dial(network, addr) + }, + netDialContext: nil, + netDialTLSContext: nil, + tlsClientConfig: tlsConfig, + }, + { + name: "HTTPS server, all NetDial* undefined", + server: tlsServer, + netDial: nil, + netDialContext: nil, + netDialTLSContext: nil, + tlsClientConfig: tlsConfig, + }, + { + name: "HTTPS server, all NetDialTLSContext defined, dummy TlsClientConfig defined, shall not do handshake", + server: tlsServer, + netDial: func(network, addr string) (net.Conn, error) { + return nil, errors.New("NetDial should not be called") + }, + netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, errors.New("NetDialContext should not be called") + }, + netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + netConn, err := net.Dial(network, addr) + if err != nil { + return nil, err + } + tlsConn := tls.Client(netConn, tlsConfig) + err = tlsConn.Handshake() + if err != nil { + return nil, err + } + return tlsConn, nil + }, + tlsClientConfig: &tls.Config{ + RootCAs: nil, + ServerName: "badserver.com", + InsecureSkipVerify: false, + }, + }, + } + + for _, tc := range tests { + dialer := Dialer{ + NetDial: tc.netDial, + NetDialContext: tc.netDialContext, + NetDialTLSContext: tc.netDialTLSContext, + TLSClientConfig: tc.tlsClientConfig, + } + + // Test websocket dial + c, _, err := dialer.Dial(testUrls[tc.server], nil) + if err != nil { + t.Errorf("FAILED %s, err: %s", tc.name, err.Error()) + } else { + c.Close() + } + } +}