Skip to content

Commit

Permalink
✨ update: tls.ClientHelloInfo in Ctx (#2011)
Browse files Browse the repository at this point in the history
* Update go.mod

* wip

* wip

* wip

* wip

* wip

* Move tlsHandler from Config to App

* Use NewError instead of panic

* Add a test with ServerName

* Add some tests on ClientHelloInfo

* fix missing import

* remove unnecessary ctx field.

Co-authored-by: RW <rene@gofiber.io>
Co-authored-by: Muhammed Efe Çetin <efectn@protonmail.com>
  • Loading branch information
3 people committed Aug 19, 2022
1 parent f031e08 commit 2edcf95
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 11 deletions.
2 changes: 2 additions & 0 deletions app.go
Expand Up @@ -114,6 +114,8 @@ type App struct {
// Latest route & group
latestRoute *Route
latestGroup *Group
// TLS handler
tlsHandler *tlsHandler
}

// Config is a struct holding the server settings.
Expand Down
21 changes: 21 additions & 0 deletions ctx.go
Expand Up @@ -7,6 +7,7 @@ package fiber
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"encoding/xml"
"errors"
Expand Down Expand Up @@ -67,6 +68,17 @@ type Ctx struct {
viewBindMap *dictpool.Dict // Default view map to bind template engine
}

// tlsHandle object
type tlsHandler struct {
clientHelloInfo *tls.ClientHelloInfo
}

// GetClientInfo Callback function to set CHI
func (t *tlsHandler) GetClientInfo(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
t.clientHelloInfo = info
return nil, nil
}

// Range data for c.Range
type Range struct {
Type string
Expand Down Expand Up @@ -797,6 +809,15 @@ func (c *Ctx) MultipartForm() (*multipart.Form, error) {
return c.fasthttp.MultipartForm()
}

// ClientHelloInfo return CHI from context
func (c *Ctx) ClientHelloInfo() *tls.ClientHelloInfo {
if c.app.tlsHandler != nil {
return c.app.tlsHandler.clientHelloInfo
}

return nil
}

// Next executes the next method in the stack that matches the current route.
func (c *Ctx) Next() (err error) {
// Increment handler index
Expand Down
62 changes: 62 additions & 0 deletions ctx_test.go
Expand Up @@ -12,6 +12,7 @@ import (
"bytes"
"compress/gzip"
"context"
"crypto/tls"
"encoding/xml"
"errors"
"fmt"
Expand Down Expand Up @@ -1248,6 +1249,67 @@ func Test_Ctx_Method(t *testing.T) {
utils.AssertEqual(t, MethodPost, c.Method())
}

// go test -run Test_Ctx_ClientHelloInfo
func Test_Ctx_ClientHelloInfo(t *testing.T) {
t.Parallel()
app := New()
app.Get("/ServerName", func(c *Ctx) error {
result := c.ClientHelloInfo()
if result != nil {
return c.SendString(result.ServerName)
}

return c.SendString("ClientHelloInfo is nil")
})
app.Get("/SignatureSchemes", func(c *Ctx) error {
result := c.ClientHelloInfo()
if result != nil {
return c.JSON(result.SignatureSchemes)
}

return c.SendString("ClientHelloInfo is nil")
})
app.Get("/SupportedVersions", func(c *Ctx) error {
result := c.ClientHelloInfo()
if result != nil {
return c.JSON(result.SupportedVersions)
}

return c.SendString("ClientHelloInfo is nil")
})

// Test without TLS handler
resp, _ := app.Test(httptest.NewRequest(MethodGet, "/ServerName", nil))
body, _ := ioutil.ReadAll(resp.Body)
utils.AssertEqual(t, []byte("ClientHelloInfo is nil"), body)

// Test with TLS Handler
const (
PSSWithSHA256 = 0x0804
VersionTLS13 = 0x0304
)
app.tlsHandler = &tlsHandler{clientHelloInfo: &tls.ClientHelloInfo{
ServerName: "example.golang",
SignatureSchemes: []tls.SignatureScheme{PSSWithSHA256},
SupportedVersions: []uint16{VersionTLS13},
}}

// Test ServerName
resp, _ = app.Test(httptest.NewRequest(MethodGet, "/ServerName", nil))
body, _ = ioutil.ReadAll(resp.Body)
utils.AssertEqual(t, []byte("example.golang"), body)

// Test SignatureSchemes
resp, _ = app.Test(httptest.NewRequest(MethodGet, "/SignatureSchemes", nil))
body, _ = ioutil.ReadAll(resp.Body)
utils.AssertEqual(t, "["+strconv.Itoa(PSSWithSHA256)+"]", string(body))

// Test SupportedVersions
resp, _ = app.Test(httptest.NewRequest(MethodGet, "/SupportedVersions", nil))
body, _ = ioutil.ReadAll(resp.Body)
utils.AssertEqual(t, "["+strconv.Itoa(VersionTLS13)+"]", string(body))
}

// go test -run Test_Ctx_InvalidMethod
func Test_Ctx_InvalidMethod(t *testing.T) {
t.Parallel()
Expand Down
37 changes: 26 additions & 11 deletions listen.go
Expand Up @@ -82,22 +82,28 @@ func (app *App) ListenTLS(addr, certFile, keyFile string) error {
if len(certFile) == 0 || len(keyFile) == 0 {
return errors.New("tls: provide a valid cert or key path")
}
// Set TLS config with handler
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return fmt.Errorf("tls: cannot load TLS key pair from certFile=%q and keyFile=%q: %s", certFile, keyFile, err)
}
tlsHandler := &tlsHandler{}
config := &tls.Config{
MinVersion: tls.VersionTLS12,
Certificates: []tls.Certificate{
cert,
},
GetCertificate: tlsHandler.GetClientInfo,
}
// Prefork is supported
if app.config.Prefork {
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return fmt.Errorf("tls: cannot load TLS key pair from certFile=%q and keyFile=%q: %s", certFile, keyFile, err)
}
config := &tls.Config{
MinVersion: tls.VersionTLS12,
Certificates: []tls.Certificate{
cert,
},
}
return app.prefork(app.config.Network, addr, config)
}

// Setup listener
ln, err := net.Listen(app.config.Network, addr)
ln = tls.NewListener(ln, config)

if err != nil {
return err
}
Expand All @@ -111,8 +117,12 @@ func (app *App) ListenTLS(addr, certFile, keyFile string) error {
if app.config.EnablePrintRoutes {
app.printRoutesMessage()
}

// Attach the tlsHandler to the config
app.tlsHandler = tlsHandler

// Start listening
return app.server.ServeTLS(ln, certFile, keyFile)
return app.server.Serve(ln)
}

// ListenMutualTLS serves HTTPS requests from the given addr.
Expand All @@ -137,13 +147,15 @@ func (app *App) ListenMutualTLS(addr, certFile, keyFile, clientCertFile string)
clientCertPool := x509.NewCertPool()
clientCertPool.AppendCertsFromPEM(clientCACert)

tlsHandler := &tlsHandler{}
config := &tls.Config{
MinVersion: tls.VersionTLS12,
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: clientCertPool,
Certificates: []tls.Certificate{
cert,
},
GetCertificate: tlsHandler.GetClientInfo,
}

// Prefork is supported
Expand All @@ -170,6 +182,9 @@ func (app *App) ListenMutualTLS(addr, certFile, keyFile, clientCertFile string)
app.printRoutesMessage()
}

// Attach the tlsHandler to the config
app.tlsHandler = tlsHandler

// Start listening
return app.server.Serve(ln)
}
Expand Down

1 comment on commit 2edcf95

@ReneWerner87
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: 2edcf95 Previous: 4adda50 Ratio
Benchmark_Ctx_Protocol 15.62 ns/op 0 B/op 0 allocs/op 2.814 ns/op 0 B/op 0 allocs/op 5.55

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.