Skip to content

Commit

Permalink
Add timeouts to tls Handshake
Browse files Browse the repository at this point in the history
Fixes #813 which suggested the code.
  • Loading branch information
erikdubbelboer committed May 25, 2020
1 parent 123f6a8 commit 2f92c68
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 1 deletion.
14 changes: 13 additions & 1 deletion server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1419,6 +1419,18 @@ func (s *Server) NextProto(key string, nph ServeHandler) {

func (s *Server) getNextProto(c net.Conn) (proto string, err error) {
if tlsConn, ok := c.(connTLSer); ok {
if s.ReadTimeout > 0 {
if err := c.SetReadDeadline(time.Now().Add(s.ReadTimeout)); err != nil {
panic(fmt.Sprintf("BUG: error in SetReadDeadline(%s): %s", s.ReadTimeout, err))
}
}

if s.WriteTimeout > 0 {
if err := c.SetWriteDeadline(time.Now().Add(s.WriteTimeout)); err != nil {
panic(fmt.Sprintf("BUG: error in SetWriteDeadline(%s): %s", s.WriteTimeout, err))
}
}

err = tlsConn.Handshake()
if err == nil {
proto = tlsConn.ConnectionState().NegotiatedProtocol
Expand Down Expand Up @@ -2179,7 +2191,7 @@ func (s *Server) serveConn(c net.Conn) (err error) {

if writeTimeout > 0 {
if err := c.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil {
panic(fmt.Sprintf("BUG: error in SetWriteDeadline(%s): %s", s.WriteTimeout, err))
panic(fmt.Sprintf("BUG: error in SetWriteDeadline(%s): %s", writeTimeout, err))
}
}

Expand Down
50 changes: 50 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -940,6 +940,56 @@ func TestServerTLS(t *testing.T) {
}
}

func TestServerTLSReadTimeout(t *testing.T) {
t.Parallel()

ln := fasthttputil.NewInmemoryListener()

certFile := "./ssl-cert-snakeoil.pem"
keyFile := "./ssl-cert-snakeoil.key"

s := &Server{
ReadTimeout: time.Millisecond * 50,
Logger: &testLogger{}, // Ignore log output.
Handler: func(ctx *RequestCtx) {
},
}

err := s.AppendCert(certFile, keyFile)
if err != nil {
t.Fatal(err)
}
go func() {
err = s.ServeTLS(ln, "", "")
if err != nil {
t.Error(err)
}
}()

c, err := ln.Dial()
if err != nil {
t.Error(err)
}

r := make(chan error)

go func() {
b := make([]byte, 1)
_, err := c.Read(b)
c.Close()
r <- err
}()

select {
case err = <-r:
case <-time.After(time.Millisecond * 100):
}

if err == nil {
t.Error("server didn't close connection after timeout")
}
}

func TestServerServeTLSEmbed(t *testing.T) {
t.Parallel()

Expand Down

0 comments on commit 2f92c68

Please sign in to comment.