Skip to content

Commit a04cfcc

Browse files
jkralikzllovesuki
andcommittedAug 15, 2022
Implement GetCertificate and GetClientCertificate
The goal is to close the feature parity gap with stdlib's tls package. Co-authored-by: Rachel Chen <rachel@chens.email>
1 parent 43968a2 commit a04cfcc

15 files changed

+519
-173
lines changed
 

‎certificate.go

+108-21
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,102 @@
11
package dtls
22

33
import (
4+
"bytes"
45
"crypto/tls"
56
"crypto/x509"
7+
"fmt"
68
"strings"
79
)
810

9-
func (c *handshakeConfig) getCertificate(serverName string) (*tls.Certificate, error) {
10-
c.mu.Lock()
11-
defer c.mu.Unlock()
11+
// ClientHelloInfo contains information from a ClientHello message in order to
12+
// guide application logic in the GetCertificate.
13+
type ClientHelloInfo struct {
14+
// ServerName indicates the name of the server requested by the client
15+
// in order to support virtual hosting. ServerName is only set if the
16+
// client is using SNI (see RFC 4366, Section 3.1).
17+
ServerName string
1218

13-
if c.nameToCertificate == nil {
14-
nameToCertificate := make(map[string]*tls.Certificate)
15-
for i := range c.localCertificates {
16-
cert := &c.localCertificates[i]
17-
x509Cert := cert.Leaf
18-
if x509Cert == nil {
19-
var parseErr error
20-
x509Cert, parseErr = x509.ParseCertificate(cert.Certificate[0])
21-
if parseErr != nil {
22-
continue
23-
}
19+
// CipherSuites lists the CipherSuites supported by the client (e.g.
20+
// TLS_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256).
21+
CipherSuites []CipherSuiteID
22+
}
23+
24+
// CertificateRequestInfo contains information from a server's
25+
// CertificateRequest message, which is used to demand a certificate and proof
26+
// of control from a client.
27+
type CertificateRequestInfo struct {
28+
// AcceptableCAs contains zero or more, DER-encoded, X.501
29+
// Distinguished Names. These are the names of root or intermediate CAs
30+
// that the server wishes the returned certificate to be signed by. An
31+
// empty slice indicates that the server has no preference.
32+
AcceptableCAs [][]byte
33+
}
34+
35+
// SupportsCertificate returns nil if the provided certificate is supported by
36+
// the server that sent the CertificateRequest. Otherwise, it returns an error
37+
// describing the reason for the incompatibility.
38+
// NOTE: original src: https://github.com/golang/go/blob/29b9a328d268d53833d2cc063d1d8b4bf6852675/src/crypto/tls/common.go#L1273
39+
func (cri *CertificateRequestInfo) SupportsCertificate(c *tls.Certificate) error {
40+
if len(cri.AcceptableCAs) == 0 {
41+
return nil
42+
}
43+
44+
for j, cert := range c.Certificate {
45+
x509Cert := c.Leaf
46+
// Parse the certificate if this isn't the leaf node, or if
47+
// chain.Leaf was nil.
48+
if j != 0 || x509Cert == nil {
49+
var err error
50+
if x509Cert, err = x509.ParseCertificate(cert); err != nil {
51+
return fmt.Errorf("failed to parse certificate #%d in the chain: %w", j, err)
2452
}
25-
if len(x509Cert.Subject.CommonName) > 0 {
26-
nameToCertificate[strings.ToLower(x509Cert.Subject.CommonName)] = cert
53+
}
54+
55+
for _, ca := range cri.AcceptableCAs {
56+
if bytes.Equal(x509Cert.RawIssuer, ca) {
57+
return nil
2758
}
28-
for _, san := range x509Cert.DNSNames {
29-
nameToCertificate[strings.ToLower(san)] = cert
59+
}
60+
}
61+
return errNotAcceptableCertificateChain
62+
}
63+
64+
func (c *handshakeConfig) setNameToCertificateLocked() {
65+
nameToCertificate := make(map[string]*tls.Certificate)
66+
for i := range c.localCertificates {
67+
cert := &c.localCertificates[i]
68+
x509Cert := cert.Leaf
69+
if x509Cert == nil {
70+
var parseErr error
71+
x509Cert, parseErr = x509.ParseCertificate(cert.Certificate[0])
72+
if parseErr != nil {
73+
continue
3074
}
3175
}
32-
c.nameToCertificate = nameToCertificate
76+
if len(x509Cert.Subject.CommonName) > 0 {
77+
nameToCertificate[strings.ToLower(x509Cert.Subject.CommonName)] = cert
78+
}
79+
for _, san := range x509Cert.DNSNames {
80+
nameToCertificate[strings.ToLower(san)] = cert
81+
}
82+
}
83+
c.nameToCertificate = nameToCertificate
84+
}
85+
86+
func (c *handshakeConfig) getCertificate(clientHelloInfo *ClientHelloInfo) (*tls.Certificate, error) {
87+
c.mu.Lock()
88+
defer c.mu.Unlock()
89+
90+
if c.localGetCertificate != nil &&
91+
(len(c.localCertificates) == 0 || len(clientHelloInfo.ServerName) > 0) {
92+
cert, err := c.localGetCertificate(clientHelloInfo)
93+
if cert != nil || err != nil {
94+
return cert, err
95+
}
96+
}
97+
98+
if c.nameToCertificate == nil {
99+
c.setNameToCertificateLocked()
33100
}
34101

35102
if len(c.localCertificates) == 0 {
@@ -41,11 +108,11 @@ func (c *handshakeConfig) getCertificate(serverName string) (*tls.Certificate, e
41108
return &c.localCertificates[0], nil
42109
}
43110

44-
if len(serverName) == 0 {
111+
if len(clientHelloInfo.ServerName) == 0 {
45112
return &c.localCertificates[0], nil
46113
}
47114

48-
name := strings.TrimRight(strings.ToLower(serverName), ".")
115+
name := strings.TrimRight(strings.ToLower(clientHelloInfo.ServerName), ".")
49116

50117
if cert, ok := c.nameToCertificate[name]; ok {
51118
return cert, nil
@@ -65,3 +132,23 @@ func (c *handshakeConfig) getCertificate(serverName string) (*tls.Certificate, e
65132
// If nothing matches, return the first certificate.
66133
return &c.localCertificates[0], nil
67134
}
135+
136+
// NOTE: original src: https://github.com/golang/go/blob/29b9a328d268d53833d2cc063d1d8b4bf6852675/src/crypto/tls/handshake_client.go#L974
137+
func (c *handshakeConfig) getClientCertificate(cri *CertificateRequestInfo) (*tls.Certificate, error) {
138+
c.mu.Lock()
139+
defer c.mu.Unlock()
140+
if c.localGetClientCertificate != nil {
141+
return c.localGetClientCertificate(cri)
142+
}
143+
144+
for i := range c.localCertificates {
145+
chain := c.localCertificates[i]
146+
if err := cri.SupportsCertificate(&chain); err != nil {
147+
continue
148+
}
149+
return &chain, nil
150+
}
151+
152+
// No acceptable certificate found. Don't send a certificate.
153+
return new(tls.Certificate), nil
154+
}

‎certificate_test.go

+38-14
Original file line numberDiff line numberDiff line change
@@ -24,47 +24,71 @@ func TestGetCertificate(t *testing.T) {
2424
t.Fatal(err)
2525
}
2626

27-
cfg := &handshakeConfig{
28-
localCertificates: []tls.Certificate{
29-
certificateRandom,
30-
certificateTest,
31-
certificateWildcard,
32-
},
33-
}
34-
3527
testCases := []struct {
28+
localCertificates []tls.Certificate
3629
desc string
3730
serverName string
3831
expectedCertificate tls.Certificate
32+
getCertificate func(info *ClientHelloInfo) (*tls.Certificate, error)
3933
}{
4034
{
41-
desc: "Simple match in CN",
35+
desc: "Simple match in CN",
36+
localCertificates: []tls.Certificate{
37+
certificateRandom,
38+
certificateTest,
39+
certificateWildcard,
40+
},
4241
serverName: "test.test",
4342
expectedCertificate: certificateTest,
4443
},
4544
{
46-
desc: "Simple match in SANs",
45+
desc: "Simple match in SANs",
46+
localCertificates: []tls.Certificate{
47+
certificateRandom,
48+
certificateTest,
49+
certificateWildcard,
50+
},
4751
serverName: "www.test.test",
4852
expectedCertificate: certificateTest,
4953
},
5054

5155
{
52-
desc: "Wildcard match",
56+
desc: "Wildcard match",
57+
localCertificates: []tls.Certificate{
58+
certificateRandom,
59+
certificateTest,
60+
certificateWildcard,
61+
},
5362
serverName: "foo.test.test",
5463
expectedCertificate: certificateWildcard,
5564
},
5665
{
57-
desc: "No match return first",
66+
desc: "No match return first",
67+
localCertificates: []tls.Certificate{
68+
certificateRandom,
69+
certificateTest,
70+
certificateWildcard,
71+
},
5872
serverName: "foo.bar",
5973
expectedCertificate: certificateRandom,
6074
},
75+
{
76+
desc: "Get certificate from callback",
77+
getCertificate: func(info *ClientHelloInfo) (*tls.Certificate, error) {
78+
return &certificateTest, nil
79+
},
80+
expectedCertificate: certificateTest,
81+
},
6182
}
6283

6384
for _, test := range testCases {
6485
test := test
65-
6686
t.Run(test.desc, func(t *testing.T) {
67-
cert, err := cfg.getCertificate(test.serverName)
87+
cfg := &handshakeConfig{
88+
localCertificates: test.localCertificates,
89+
localGetCertificate: test.getCertificate,
90+
}
91+
cert, err := cfg.getCertificate(&ClientHelloInfo{ServerName: test.serverName})
6892
if err != nil {
6993
t.Fatal(err)
7094
}

‎config.go

+26-1
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,27 @@ type Config struct {
147147
// If an ECC ciphersuite is configured and EllipticCurves is empty
148148
// it will default to X25519, P-256, P-384 in this specific order.
149149
EllipticCurves []elliptic.Curve
150+
151+
// GetCertificate returns a Certificate based on the given
152+
// ClientHelloInfo. It will only be called if the client supplies SNI
153+
// information or if Certificates is empty.
154+
//
155+
// If GetCertificate is nil or returns nil, then the certificate is
156+
// retrieved from NameToCertificate. If NameToCertificate is nil, the
157+
// best element of Certificates will be used.
158+
GetCertificate func(*ClientHelloInfo) (*tls.Certificate, error)
159+
160+
// GetClientCertificate, if not nil, is called when a server requests a
161+
// certificate from a client. If set, the contents of Certificates will
162+
// be ignored.
163+
//
164+
// If GetClientCertificate returns an error, the handshake will be
165+
// aborted and that error will be returned. Otherwise
166+
// GetClientCertificate must return a non-nil Certificate. If
167+
// Certificate.Certificate is empty then no certificate will be sent to
168+
// the server. If this is unacceptable to the server then it may abort
169+
// the handshake.
170+
GetClientCertificate func(*CertificateRequestInfo) (*tls.Certificate, error)
150171
}
151172

152173
func defaultConnectContextMaker() (context.Context, func()) {
@@ -160,6 +181,10 @@ func (c *Config) connectContextMaker() (context.Context, func()) {
160181
return c.ConnectContextMaker()
161182
}
162183

184+
func (c *Config) includeCertificateSuites() bool {
185+
return c.PSK == nil || len(c.Certificates) > 0 || c.GetCertificate != nil || c.GetClientCertificate != nil
186+
}
187+
163188
const defaultMTU = 1200 // bytes
164189

165190
var defaultCurves = []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P384} //nolint:gochecknoglobals
@@ -215,6 +240,6 @@ func validateConfig(config *Config) error {
215240
}
216241
}
217242

218-
_, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.PSK == nil || len(config.Certificates) > 0, config.PSK != nil)
243+
_, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil)
219244
return err
220245
}

0 commit comments

Comments
 (0)
Please sign in to comment.