From 2edcf95f57104c5c43bb44531b0563c476c7e1e7 Mon Sep 17 00:00:00 2001 From: Thomas Date: Fri, 19 Aug 2022 08:19:22 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20update:=20tls.ClientHelloInfo=20in?= =?UTF-8?q?=20Ctx=20(#2011)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update go.mod * wip * wip * wip * wip * wip * Move tlsHandler from Config to App * Use NewError instead of panic * Add a test with ServerName * Add some tests on ClientHelloInfo * fix missing import * remove unnecessary ctx field. Co-authored-by: RW Co-authored-by: Muhammed Efe Çetin --- app.go | 2 ++ ctx.go | 21 ++++++++++++++++++ ctx_test.go | 62 +++++++++++++++++++++++++++++++++++++++++++++++++++++ listen.go | 37 ++++++++++++++++++++++---------- 4 files changed, 111 insertions(+), 11 deletions(-) diff --git a/app.go b/app.go index fc2b68e0bd..f311617e3a 100644 --- a/app.go +++ b/app.go @@ -114,6 +114,8 @@ type App struct { // Latest route & group latestRoute *Route latestGroup *Group + // TLS handler + tlsHandler *tlsHandler } // Config is a struct holding the server settings. diff --git a/ctx.go b/ctx.go index 30aed27682..d8818ff855 100644 --- a/ctx.go +++ b/ctx.go @@ -7,6 +7,7 @@ package fiber import ( "bytes" "context" + "crypto/tls" "encoding/json" "encoding/xml" "errors" @@ -67,6 +68,17 @@ type Ctx struct { viewBindMap *dictpool.Dict // Default view map to bind template engine } +// tlsHandle object +type tlsHandler struct { + clientHelloInfo *tls.ClientHelloInfo +} + +// GetClientInfo Callback function to set CHI +func (t *tlsHandler) GetClientInfo(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + t.clientHelloInfo = info + return nil, nil +} + // Range data for c.Range type Range struct { Type string @@ -797,6 +809,15 @@ func (c *Ctx) MultipartForm() (*multipart.Form, error) { return c.fasthttp.MultipartForm() } +// ClientHelloInfo return CHI from context +func (c *Ctx) ClientHelloInfo() *tls.ClientHelloInfo { + if c.app.tlsHandler != nil { + return c.app.tlsHandler.clientHelloInfo + } + + return nil +} + // Next executes the next method in the stack that matches the current route. func (c *Ctx) Next() (err error) { // Increment handler index diff --git a/ctx_test.go b/ctx_test.go index 3af6e806b0..825cbb8966 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -12,6 +12,7 @@ import ( "bytes" "compress/gzip" "context" + "crypto/tls" "encoding/xml" "errors" "fmt" @@ -1248,6 +1249,67 @@ func Test_Ctx_Method(t *testing.T) { utils.AssertEqual(t, MethodPost, c.Method()) } +// go test -run Test_Ctx_ClientHelloInfo +func Test_Ctx_ClientHelloInfo(t *testing.T) { + t.Parallel() + app := New() + app.Get("/ServerName", func(c *Ctx) error { + result := c.ClientHelloInfo() + if result != nil { + return c.SendString(result.ServerName) + } + + return c.SendString("ClientHelloInfo is nil") + }) + app.Get("/SignatureSchemes", func(c *Ctx) error { + result := c.ClientHelloInfo() + if result != nil { + return c.JSON(result.SignatureSchemes) + } + + return c.SendString("ClientHelloInfo is nil") + }) + app.Get("/SupportedVersions", func(c *Ctx) error { + result := c.ClientHelloInfo() + if result != nil { + return c.JSON(result.SupportedVersions) + } + + return c.SendString("ClientHelloInfo is nil") + }) + + // Test without TLS handler + resp, _ := app.Test(httptest.NewRequest(MethodGet, "/ServerName", nil)) + body, _ := ioutil.ReadAll(resp.Body) + utils.AssertEqual(t, []byte("ClientHelloInfo is nil"), body) + + // Test with TLS Handler + const ( + PSSWithSHA256 = 0x0804 + VersionTLS13 = 0x0304 + ) + app.tlsHandler = &tlsHandler{clientHelloInfo: &tls.ClientHelloInfo{ + ServerName: "example.golang", + SignatureSchemes: []tls.SignatureScheme{PSSWithSHA256}, + SupportedVersions: []uint16{VersionTLS13}, + }} + + // Test ServerName + resp, _ = app.Test(httptest.NewRequest(MethodGet, "/ServerName", nil)) + body, _ = ioutil.ReadAll(resp.Body) + utils.AssertEqual(t, []byte("example.golang"), body) + + // Test SignatureSchemes + resp, _ = app.Test(httptest.NewRequest(MethodGet, "/SignatureSchemes", nil)) + body, _ = ioutil.ReadAll(resp.Body) + utils.AssertEqual(t, "["+strconv.Itoa(PSSWithSHA256)+"]", string(body)) + + // Test SupportedVersions + resp, _ = app.Test(httptest.NewRequest(MethodGet, "/SupportedVersions", nil)) + body, _ = ioutil.ReadAll(resp.Body) + utils.AssertEqual(t, "["+strconv.Itoa(VersionTLS13)+"]", string(body)) +} + // go test -run Test_Ctx_InvalidMethod func Test_Ctx_InvalidMethod(t *testing.T) { t.Parallel() diff --git a/listen.go b/listen.go index d2ab361ef4..4d8d25b72d 100644 --- a/listen.go +++ b/listen.go @@ -82,22 +82,28 @@ func (app *App) ListenTLS(addr, certFile, keyFile string) error { if len(certFile) == 0 || len(keyFile) == 0 { return errors.New("tls: provide a valid cert or key path") } + // Set TLS config with handler + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return fmt.Errorf("tls: cannot load TLS key pair from certFile=%q and keyFile=%q: %s", certFile, keyFile, err) + } + tlsHandler := &tlsHandler{} + config := &tls.Config{ + MinVersion: tls.VersionTLS12, + Certificates: []tls.Certificate{ + cert, + }, + GetCertificate: tlsHandler.GetClientInfo, + } // Prefork is supported if app.config.Prefork { - cert, err := tls.LoadX509KeyPair(certFile, keyFile) - if err != nil { - return fmt.Errorf("tls: cannot load TLS key pair from certFile=%q and keyFile=%q: %s", certFile, keyFile, err) - } - config := &tls.Config{ - MinVersion: tls.VersionTLS12, - Certificates: []tls.Certificate{ - cert, - }, - } return app.prefork(app.config.Network, addr, config) } + // Setup listener ln, err := net.Listen(app.config.Network, addr) + ln = tls.NewListener(ln, config) + if err != nil { return err } @@ -111,8 +117,12 @@ func (app *App) ListenTLS(addr, certFile, keyFile string) error { if app.config.EnablePrintRoutes { app.printRoutesMessage() } + + // Attach the tlsHandler to the config + app.tlsHandler = tlsHandler + // Start listening - return app.server.ServeTLS(ln, certFile, keyFile) + return app.server.Serve(ln) } // ListenMutualTLS serves HTTPS requests from the given addr. @@ -137,6 +147,7 @@ func (app *App) ListenMutualTLS(addr, certFile, keyFile, clientCertFile string) clientCertPool := x509.NewCertPool() clientCertPool.AppendCertsFromPEM(clientCACert) + tlsHandler := &tlsHandler{} config := &tls.Config{ MinVersion: tls.VersionTLS12, ClientAuth: tls.RequireAndVerifyClientCert, @@ -144,6 +155,7 @@ func (app *App) ListenMutualTLS(addr, certFile, keyFile, clientCertFile string) Certificates: []tls.Certificate{ cert, }, + GetCertificate: tlsHandler.GetClientInfo, } // Prefork is supported @@ -170,6 +182,9 @@ func (app *App) ListenMutualTLS(addr, certFile, keyFile, clientCertFile string) app.printRoutesMessage() } + // Attach the tlsHandler to the config + app.tlsHandler = tlsHandler + // Start listening return app.server.Serve(ln) }