Skip to content

Commit

Permalink
Resolved Easwars' comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Cindy Xue committed Jul 24, 2020
1 parent ec83efb commit 789296a
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 29 deletions.
4 changes: 2 additions & 2 deletions security/advancedtls/advancedtls.go
Expand Up @@ -237,12 +237,12 @@ func (o *ServerOptions) config() (*tls.Config, error) {
ClientAuth: clientAuth,
Certificates: o.Certificates,
}
// The function getCertificateWithSNI is only able to perform SNI logic for go1.10 and above.
// It will return the first certificate in o.GetCertificates for go1.9.
if o.GetCertificates != nil {
getCertificateWithSNI := func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
return buildGetCertificates(clientHello, o)
}
// GetCertificate is only able to perform SNI logic for go1.10 and above.
// It will return the first certificate in o.GetCertificates for go1.9.
config.GetCertificate = getCertificateWithSNI
}
if clientCAs != nil {
Expand Down
52 changes: 25 additions & 27 deletions security/advancedtls/advancedtls_test.go
Expand Up @@ -683,54 +683,52 @@ func TestOptionsConfig(t *testing.T) {
}

func TestGetCertificatesSNI(t *testing.T) {
// Load server certificates for setting the serverGetCert callback function.
serverPeerCert1, err := tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"),
testdata.Path("server_key_1.pem"))
// Load server peer certificates for setting the serverGetCert callback function.
serverPeerCert1, err := tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"), testdata.Path("server_key_1.pem"))
if err != nil {
t.Fatalf("Server is unable to parse peer certificates. Error: %v", err)
t.Fatalf("tls.LoadX509KeyPair(server_cert_1.pem, server_key_1.pem) failed: %v", err)
}
serverPeerCert2, err := tls.LoadX509KeyPair(testdata.Path("server_cert_2.pem"),
testdata.Path("server_key_2.pem"))
serverPeerCert2, err := tls.LoadX509KeyPair(testdata.Path("server_cert_2.pem"), testdata.Path("server_key_2.pem"))
if err != nil {
t.Fatalf("Server is unable to parse peer certificates. Error: %v", err)
t.Fatalf("tls.LoadX509KeyPair(server_cert_2.pem, server_key_2.pem) failed: %v", err)
}
serverPeerCert3, err := tls.LoadX509KeyPair(testdata.Path("server_cert_3.pem"),
testdata.Path("server_key_3.pem"))
serverPeerCert3, err := tls.LoadX509KeyPair(testdata.Path("server_cert_3.pem"), testdata.Path("server_key_3.pem"))
if err != nil {
t.Fatalf("Server is unable to parse peer certificates. Error: %v", err)
t.Fatalf("tls.LoadX509KeyPair(server_cert_3.pem, server_key_3.pem) failed: %v", err)
}

tests := []struct {
desc string
serverGetCert func(*tls.ClientHelloInfo) ([]*tls.Certificate, error)
serverName string
expectedCertificate tls.Certificate
desc string
serverGetCert func(*tls.ClientHelloInfo) ([]*tls.Certificate, error)
serverName string
wantCert tls.Certificate
}{
{
desc: "Selected certificate by SNI should be serverPeerCert1",
desc: "Select serverPeerCert1",
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
return []*tls.Certificate{&serverPeerCert1, &serverPeerCert2, &serverPeerCert3}, nil
},
// "foo.bar.com" is the common name on server certificate server_cert_1.pem.
serverName: "foo.bar.com",
expectedCertificate: serverPeerCert1,
serverName: "foo.bar.com",
wantCert: serverPeerCert1,
},
{
desc: "Selected certificate by SNI should be serverPeerCert2",
desc: "Select serverPeerCert2",
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
return []*tls.Certificate{&serverPeerCert1, &serverPeerCert2, &serverPeerCert3}, nil
},
// "foo.bar.server2.com" is the common name on server certificate server_cert_2.pem.
serverName: "foo.bar.server2.com",
expectedCertificate: serverPeerCert2,
serverName: "foo.bar.server2.com",
wantCert: serverPeerCert2,
},
{
desc: "Selected certificate by SNI should be serverPeerCert3",
desc: "Select serverPeerCert3",
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
return []*tls.Certificate{&serverPeerCert1, &serverPeerCert2, &serverPeerCert3}, nil
},
// "google.com" is one of the DNS names on server certificate server_cert_3.pem.
serverName: "google.com",
expectedCertificate: serverPeerCert3,
serverName: "google.com",
wantCert: serverPeerCert3,
},
}
for _, test := range tests {
Expand All @@ -741,7 +739,7 @@ func TestGetCertificatesSNI(t *testing.T) {
}
serverConfig, err := serverOptions.config()
if err != nil {
t.Fatalf("Unable to generate serverConfig. Error: %v", err)
t.Fatalf("serverOptions.config() failed: %v", err)
}
pointFormatUncompressed := uint8(0)
clientHello := &tls.ClientHelloInfo{
Expand All @@ -753,10 +751,10 @@ func TestGetCertificatesSNI(t *testing.T) {
}
gotCertificate, err := serverConfig.GetCertificate(clientHello)
if err != nil {
t.Fatalf("Server is unable to parse peer certificates. Error: %v", err)
t.Fatalf("serverConfig.GetCertificate(clientHello) failed: %v", err)
}
if !reflect.DeepEqual(*gotCertificate, test.expectedCertificate) {
t.Errorf("GetCertificates() = %v, want %v", gotCertificate, test.expectedCertificate)
if !reflect.DeepEqual(*gotCertificate, test.wantCert) {
t.Errorf("GetCertificates() = %v, want %v", gotCertificate, test.wantCert)
}
})
}
Expand Down

0 comments on commit 789296a

Please sign in to comment.