From 789296a562a8fb4efc62ec02bc5892cc3d05343c Mon Sep 17 00:00:00 2001 From: Cindy Xue Date: Fri, 24 Jul 2020 10:24:56 -0700 Subject: [PATCH] Resolved Easwars' comments --- security/advancedtls/advancedtls.go | 4 +- security/advancedtls/advancedtls_test.go | 52 ++++++++++++------------ 2 files changed, 27 insertions(+), 29 deletions(-) diff --git a/security/advancedtls/advancedtls.go b/security/advancedtls/advancedtls.go index 5ed699d0a83..183678c952b 100644 --- a/security/advancedtls/advancedtls.go +++ b/security/advancedtls/advancedtls.go @@ -237,12 +237,12 @@ func (o *ServerOptions) config() (*tls.Config, error) { ClientAuth: clientAuth, Certificates: o.Certificates, } - // The function getCertificateWithSNI is only able to perform SNI logic for go1.10 and above. - // It will return the first certificate in o.GetCertificates for go1.9. if o.GetCertificates != nil { getCertificateWithSNI := func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { return buildGetCertificates(clientHello, o) } + // GetCertificate is only able to perform SNI logic for go1.10 and above. + // It will return the first certificate in o.GetCertificates for go1.9. config.GetCertificate = getCertificateWithSNI } if clientCAs != nil { diff --git a/security/advancedtls/advancedtls_test.go b/security/advancedtls/advancedtls_test.go index af21a52f520..cc8d206ff25 100644 --- a/security/advancedtls/advancedtls_test.go +++ b/security/advancedtls/advancedtls_test.go @@ -683,54 +683,52 @@ func TestOptionsConfig(t *testing.T) { } func TestGetCertificatesSNI(t *testing.T) { - // Load server certificates for setting the serverGetCert callback function. - serverPeerCert1, err := tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"), - testdata.Path("server_key_1.pem")) + // Load server peer certificates for setting the serverGetCert callback function. + serverPeerCert1, err := tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"), testdata.Path("server_key_1.pem")) if err != nil { - t.Fatalf("Server is unable to parse peer certificates. Error: %v", err) + t.Fatalf("tls.LoadX509KeyPair(server_cert_1.pem, server_key_1.pem) failed: %v", err) } - serverPeerCert2, err := tls.LoadX509KeyPair(testdata.Path("server_cert_2.pem"), - testdata.Path("server_key_2.pem")) + serverPeerCert2, err := tls.LoadX509KeyPair(testdata.Path("server_cert_2.pem"), testdata.Path("server_key_2.pem")) if err != nil { - t.Fatalf("Server is unable to parse peer certificates. Error: %v", err) + t.Fatalf("tls.LoadX509KeyPair(server_cert_2.pem, server_key_2.pem) failed: %v", err) } - serverPeerCert3, err := tls.LoadX509KeyPair(testdata.Path("server_cert_3.pem"), - testdata.Path("server_key_3.pem")) + serverPeerCert3, err := tls.LoadX509KeyPair(testdata.Path("server_cert_3.pem"), testdata.Path("server_key_3.pem")) if err != nil { - t.Fatalf("Server is unable to parse peer certificates. Error: %v", err) + t.Fatalf("tls.LoadX509KeyPair(server_cert_3.pem, server_key_3.pem) failed: %v", err) } + tests := []struct { - desc string - serverGetCert func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) - serverName string - expectedCertificate tls.Certificate + desc string + serverGetCert func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) + serverName string + wantCert tls.Certificate }{ { - desc: "Selected certificate by SNI should be serverPeerCert1", + desc: "Select serverPeerCert1", serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) { return []*tls.Certificate{&serverPeerCert1, &serverPeerCert2, &serverPeerCert3}, nil }, // "foo.bar.com" is the common name on server certificate server_cert_1.pem. - serverName: "foo.bar.com", - expectedCertificate: serverPeerCert1, + serverName: "foo.bar.com", + wantCert: serverPeerCert1, }, { - desc: "Selected certificate by SNI should be serverPeerCert2", + desc: "Select serverPeerCert2", serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) { return []*tls.Certificate{&serverPeerCert1, &serverPeerCert2, &serverPeerCert3}, nil }, // "foo.bar.server2.com" is the common name on server certificate server_cert_2.pem. - serverName: "foo.bar.server2.com", - expectedCertificate: serverPeerCert2, + serverName: "foo.bar.server2.com", + wantCert: serverPeerCert2, }, { - desc: "Selected certificate by SNI should be serverPeerCert3", + desc: "Select serverPeerCert3", serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) { return []*tls.Certificate{&serverPeerCert1, &serverPeerCert2, &serverPeerCert3}, nil }, // "google.com" is one of the DNS names on server certificate server_cert_3.pem. - serverName: "google.com", - expectedCertificate: serverPeerCert3, + serverName: "google.com", + wantCert: serverPeerCert3, }, } for _, test := range tests { @@ -741,7 +739,7 @@ func TestGetCertificatesSNI(t *testing.T) { } serverConfig, err := serverOptions.config() if err != nil { - t.Fatalf("Unable to generate serverConfig. Error: %v", err) + t.Fatalf("serverOptions.config() failed: %v", err) } pointFormatUncompressed := uint8(0) clientHello := &tls.ClientHelloInfo{ @@ -753,10 +751,10 @@ func TestGetCertificatesSNI(t *testing.T) { } gotCertificate, err := serverConfig.GetCertificate(clientHello) if err != nil { - t.Fatalf("Server is unable to parse peer certificates. Error: %v", err) + t.Fatalf("serverConfig.GetCertificate(clientHello) failed: %v", err) } - if !reflect.DeepEqual(*gotCertificate, test.expectedCertificate) { - t.Errorf("GetCertificates() = %v, want %v", gotCertificate, test.expectedCertificate) + if !reflect.DeepEqual(*gotCertificate, test.wantCert) { + t.Errorf("GetCertificates() = %v, want %v", gotCertificate, test.wantCert) } }) }