1
1
package dtls
2
2
3
3
import (
4
+ "bytes"
4
5
"crypto/tls"
5
6
"crypto/x509"
7
+ "fmt"
6
8
"strings"
7
9
)
8
10
9
- func (c * handshakeConfig ) getCertificate (serverName string ) (* tls.Certificate , error ) {
10
- c .mu .Lock ()
11
- defer c .mu .Unlock ()
11
+ // ClientHelloInfo contains information from a ClientHello message in order to
12
+ // guide application logic in the GetCertificate.
13
+ type ClientHelloInfo struct {
14
+ // ServerName indicates the name of the server requested by the client
15
+ // in order to support virtual hosting. ServerName is only set if the
16
+ // client is using SNI (see RFC 4366, Section 3.1).
17
+ ServerName string
12
18
13
- if c .nameToCertificate == nil {
14
- nameToCertificate := make (map [string ]* tls.Certificate )
15
- for i := range c .localCertificates {
16
- cert := & c .localCertificates [i ]
17
- x509Cert := cert .Leaf
18
- if x509Cert == nil {
19
- var parseErr error
20
- x509Cert , parseErr = x509 .ParseCertificate (cert .Certificate [0 ])
21
- if parseErr != nil {
22
- continue
23
- }
19
+ // CipherSuites lists the CipherSuites supported by the client (e.g.
20
+ // TLS_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256).
21
+ CipherSuites []CipherSuiteID
22
+ }
23
+
24
+ // CertificateRequestInfo contains information from a server's
25
+ // CertificateRequest message, which is used to demand a certificate and proof
26
+ // of control from a client.
27
+ type CertificateRequestInfo struct {
28
+ // AcceptableCAs contains zero or more, DER-encoded, X.501
29
+ // Distinguished Names. These are the names of root or intermediate CAs
30
+ // that the server wishes the returned certificate to be signed by. An
31
+ // empty slice indicates that the server has no preference.
32
+ AcceptableCAs [][]byte
33
+ }
34
+
35
+ // SupportsCertificate returns nil if the provided certificate is supported by
36
+ // the server that sent the CertificateRequest. Otherwise, it returns an error
37
+ // describing the reason for the incompatibility.
38
+ // NOTE: original src: https://github.com/golang/go/blob/29b9a328d268d53833d2cc063d1d8b4bf6852675/src/crypto/tls/common.go#L1273
39
+ func (cri * CertificateRequestInfo ) SupportsCertificate (c * tls.Certificate ) error {
40
+ if len (cri .AcceptableCAs ) == 0 {
41
+ return nil
42
+ }
43
+
44
+ for j , cert := range c .Certificate {
45
+ x509Cert := c .Leaf
46
+ // Parse the certificate if this isn't the leaf node, or if
47
+ // chain.Leaf was nil.
48
+ if j != 0 || x509Cert == nil {
49
+ var err error
50
+ if x509Cert , err = x509 .ParseCertificate (cert ); err != nil {
51
+ return fmt .Errorf ("failed to parse certificate #%d in the chain: %w" , j , err )
24
52
}
25
- if len (x509Cert .Subject .CommonName ) > 0 {
26
- nameToCertificate [strings .ToLower (x509Cert .Subject .CommonName )] = cert
53
+ }
54
+
55
+ for _ , ca := range cri .AcceptableCAs {
56
+ if bytes .Equal (x509Cert .RawIssuer , ca ) {
57
+ return nil
27
58
}
28
- for _ , san := range x509Cert .DNSNames {
29
- nameToCertificate [strings .ToLower (san )] = cert
59
+ }
60
+ }
61
+ return errNotAcceptableCertificateChain
62
+ }
63
+
64
+ func (c * handshakeConfig ) setNameToCertificateLocked () {
65
+ nameToCertificate := make (map [string ]* tls.Certificate )
66
+ for i := range c .localCertificates {
67
+ cert := & c .localCertificates [i ]
68
+ x509Cert := cert .Leaf
69
+ if x509Cert == nil {
70
+ var parseErr error
71
+ x509Cert , parseErr = x509 .ParseCertificate (cert .Certificate [0 ])
72
+ if parseErr != nil {
73
+ continue
30
74
}
31
75
}
32
- c .nameToCertificate = nameToCertificate
76
+ if len (x509Cert .Subject .CommonName ) > 0 {
77
+ nameToCertificate [strings .ToLower (x509Cert .Subject .CommonName )] = cert
78
+ }
79
+ for _ , san := range x509Cert .DNSNames {
80
+ nameToCertificate [strings .ToLower (san )] = cert
81
+ }
82
+ }
83
+ c .nameToCertificate = nameToCertificate
84
+ }
85
+
86
+ func (c * handshakeConfig ) getCertificate (clientHelloInfo * ClientHelloInfo ) (* tls.Certificate , error ) {
87
+ c .mu .Lock ()
88
+ defer c .mu .Unlock ()
89
+
90
+ if c .localGetCertificate != nil &&
91
+ (len (c .localCertificates ) == 0 || len (clientHelloInfo .ServerName ) > 0 ) {
92
+ cert , err := c .localGetCertificate (clientHelloInfo )
93
+ if cert != nil || err != nil {
94
+ return cert , err
95
+ }
96
+ }
97
+
98
+ if c .nameToCertificate == nil {
99
+ c .setNameToCertificateLocked ()
33
100
}
34
101
35
102
if len (c .localCertificates ) == 0 {
@@ -41,11 +108,11 @@ func (c *handshakeConfig) getCertificate(serverName string) (*tls.Certificate, e
41
108
return & c .localCertificates [0 ], nil
42
109
}
43
110
44
- if len (serverName ) == 0 {
111
+ if len (clientHelloInfo . ServerName ) == 0 {
45
112
return & c .localCertificates [0 ], nil
46
113
}
47
114
48
- name := strings .TrimRight (strings .ToLower (serverName ), "." )
115
+ name := strings .TrimRight (strings .ToLower (clientHelloInfo . ServerName ), "." )
49
116
50
117
if cert , ok := c .nameToCertificate [name ]; ok {
51
118
return cert , nil
@@ -65,3 +132,23 @@ func (c *handshakeConfig) getCertificate(serverName string) (*tls.Certificate, e
65
132
// If nothing matches, return the first certificate.
66
133
return & c .localCertificates [0 ], nil
67
134
}
135
+
136
+ // NOTE: original src: https://github.com/golang/go/blob/29b9a328d268d53833d2cc063d1d8b4bf6852675/src/crypto/tls/handshake_client.go#L974
137
+ func (c * handshakeConfig ) getClientCertificate (cri * CertificateRequestInfo ) (* tls.Certificate , error ) {
138
+ c .mu .Lock ()
139
+ defer c .mu .Unlock ()
140
+ if c .localGetClientCertificate != nil {
141
+ return c .localGetClientCertificate (cri )
142
+ }
143
+
144
+ for i := range c .localCertificates {
145
+ chain := c .localCertificates [i ]
146
+ if err := cri .SupportsCertificate (& chain ); err != nil {
147
+ continue
148
+ }
149
+ return & chain , nil
150
+ }
151
+
152
+ // No acceptable certificate found. Don't send a certificate.
153
+ return new (tls.Certificate ), nil
154
+ }
0 commit comments