Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ update: tls.ClientHelloInfo in Ctx #2011

Merged
merged 14 commits into from Aug 19, 2022
2 changes: 2 additions & 0 deletions app.go
Expand Up @@ -114,6 +114,8 @@ type App struct {
// Latest route & group
latestRoute *Route
latestGroup *Group
// TLS handler
tlsHandler *tlsHandler
}

// Config is a struct holding the server settings.
Expand Down
24 changes: 24 additions & 0 deletions ctx.go
Expand Up @@ -7,6 +7,7 @@ package fiber
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"encoding/xml"
"errors"
Expand Down Expand Up @@ -65,6 +66,18 @@ type Ctx struct {
fasthttp *fasthttp.RequestCtx // Reference to *fasthttp.RequestCtx
matched bool // Non use route matched
viewBindMap *dictpool.Dict // Default view map to bind template engine
tlsHandler *tlsHandler // Contains information from a ClientHello message in order to guide application logic
}

// tlsHandle object
type tlsHandler struct {
clientHelloInfo *tls.ClientHelloInfo
}

// GetClientInfo Callback function to set CHI
func (t *tlsHandler) GetClientInfo(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
t.clientHelloInfo = info
return nil, nil
}

// Range data for c.Range
Expand Down Expand Up @@ -130,6 +143,8 @@ func (app *App) AcquireCtx(fctx *fasthttp.RequestCtx) *Ctx {
c.fasthttp = fctx
// reset base uri
c.baseURI = ""
// Attach tlsHandler object to context
c.tlsHandler = app.tlsHandler
efectn marked this conversation as resolved.
Show resolved Hide resolved
// Prettify path
c.configDependentPaths()
return c
Expand Down Expand Up @@ -797,6 +812,15 @@ func (c *Ctx) MultipartForm() (*multipart.Form, error) {
return c.fasthttp.MultipartForm()
}

// ClientHelloInfo return CHI from context
func (c *Ctx) ClientHelloInfo() *tls.ClientHelloInfo {
if c.tlsHandler != nil {
return c.tlsHandler.clientHelloInfo
}

return nil
}

// Next executes the next method in the stack that matches the current route.
func (c *Ctx) Next() (err error) {
// Increment handler index
Expand Down
62 changes: 62 additions & 0 deletions ctx_test.go
Expand Up @@ -12,6 +12,7 @@ import (
"bytes"
"compress/gzip"
"context"
"crypto/tls"
"encoding/xml"
"errors"
"fmt"
Expand Down Expand Up @@ -1248,6 +1249,67 @@ func Test_Ctx_Method(t *testing.T) {
utils.AssertEqual(t, MethodPost, c.Method())
}

// go test -run Test_Ctx_ClientHelloInfo
func Test_Ctx_ClientHelloInfo(t *testing.T) {
t.Parallel()
app := New()
app.Get("/ServerName", func(c *Ctx) error {
result := c.ClientHelloInfo()
if result != nil {
return c.SendString(result.ServerName)
}

return c.SendString("ClientHelloInfo is nil")
})
app.Get("/SignatureSchemes", func(c *Ctx) error {
result := c.ClientHelloInfo()
if result != nil {
return c.JSON(result.SignatureSchemes)
}

return c.SendString("ClientHelloInfo is nil")
})
app.Get("/SupportedVersions", func(c *Ctx) error {
result := c.ClientHelloInfo()
if result != nil {
return c.JSON(result.SupportedVersions)
}

return c.SendString("ClientHelloInfo is nil")
})

// Test without TLS handler
resp, _ := app.Test(httptest.NewRequest(MethodGet, "/ServerName", nil))
body, _ := ioutil.ReadAll(resp.Body)
utils.AssertEqual(t, []byte("ClientHelloInfo is nil"), body)

// Test with TLS Handler
const (
PSSWithSHA256 = 0x0804
VersionTLS13 = 0x0304
)
app.tlsHandler = &tlsHandler{clientHelloInfo: &tls.ClientHelloInfo{
ServerName: "example.golang",
SignatureSchemes: []tls.SignatureScheme{PSSWithSHA256},
SupportedVersions: []uint16{VersionTLS13},
}}

// Test ServerName
resp, _ = app.Test(httptest.NewRequest(MethodGet, "/ServerName", nil))
body, _ = ioutil.ReadAll(resp.Body)
utils.AssertEqual(t, []byte("example.golang"), body)

// Test SignatureSchemes
resp, _ = app.Test(httptest.NewRequest(MethodGet, "/SignatureSchemes", nil))
body, _ = ioutil.ReadAll(resp.Body)
utils.AssertEqual(t, "["+strconv.Itoa(PSSWithSHA256)+"]", string(body))

// Test SupportedVersions
resp, _ = app.Test(httptest.NewRequest(MethodGet, "/SupportedVersions", nil))
body, _ = ioutil.ReadAll(resp.Body)
utils.AssertEqual(t, "["+strconv.Itoa(VersionTLS13)+"]", string(body))
}

// go test -run Test_Ctx_InvalidMethod
func Test_Ctx_InvalidMethod(t *testing.T) {
t.Parallel()
Expand Down
37 changes: 26 additions & 11 deletions listen.go
Expand Up @@ -82,22 +82,28 @@ func (app *App) ListenTLS(addr, certFile, keyFile string) error {
if len(certFile) == 0 || len(keyFile) == 0 {
return errors.New("tls: provide a valid cert or key path")
}
// Set TLS config with handler
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return fmt.Errorf("tls: cannot load TLS key pair from certFile=%q and keyFile=%q: %s", certFile, keyFile, err)
}
tlsHandler := &tlsHandler{}
config := &tls.Config{
MinVersion: tls.VersionTLS12,
Certificates: []tls.Certificate{
cert,
},
GetCertificate: tlsHandler.GetClientInfo,
}
// Prefork is supported
if app.config.Prefork {
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return fmt.Errorf("tls: cannot load TLS key pair from certFile=%q and keyFile=%q: %s", certFile, keyFile, err)
}
config := &tls.Config{
MinVersion: tls.VersionTLS12,
Certificates: []tls.Certificate{
cert,
},
}
return app.prefork(app.config.Network, addr, config)
}

// Setup listener
ln, err := net.Listen(app.config.Network, addr)
ln = tls.NewListener(ln, config)

if err != nil {
return err
}
Expand All @@ -111,8 +117,12 @@ func (app *App) ListenTLS(addr, certFile, keyFile string) error {
if app.config.EnablePrintRoutes {
app.printRoutesMessage()
}

// Attach the tlsHandler to the config
app.tlsHandler = tlsHandler

// Start listening
return app.server.ServeTLS(ln, certFile, keyFile)
return app.server.Serve(ln)
}

// ListenMutualTLS serves HTTPS requests from the given addr.
Expand All @@ -137,13 +147,15 @@ func (app *App) ListenMutualTLS(addr, certFile, keyFile, clientCertFile string)
clientCertPool := x509.NewCertPool()
clientCertPool.AppendCertsFromPEM(clientCACert)

tlsHandler := &tlsHandler{}
config := &tls.Config{
MinVersion: tls.VersionTLS12,
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: clientCertPool,
Certificates: []tls.Certificate{
cert,
},
GetCertificate: tlsHandler.GetClientInfo,
}

// Prefork is supported
Expand All @@ -170,6 +182,9 @@ func (app *App) ListenMutualTLS(addr, certFile, keyFile, clientCertFile string)
app.printRoutesMessage()
}

// Attach the tlsHandler to the config
app.tlsHandler = tlsHandler

// Start listening
return app.server.Serve(ln)
}
Expand Down