Skip to content

Commit

Permalink
Renamed GetCertificate and buildGetCertificateFunc
Browse files Browse the repository at this point in the history
  • Loading branch information
Cindy Xue committed Jul 24, 2020
1 parent 3891625 commit ec83efb
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 25 deletions.
14 changes: 7 additions & 7 deletions security/advancedtls/advancedtls.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ type ClientOptions struct {
// Certificates or GetClientCertificate indicates the certificates sent from
// the server to the client to prove server's identities. The rules for setting
// these two fields are:
// Either Certificates or GetCertificate must be set; the other will be ignored.
// Either Certificates or GetCertificates must be set; the other will be ignored.
type ServerOptions struct {
// If field Certificates is set, field GetClientCertificate will be ignored.
// The server will use Certificates every time when asked for a certificate,
Expand All @@ -166,7 +166,7 @@ type ServerOptions struct {
// invoke this function every time asked to present certificates to the
// client when a new connection is established. This is known as peer
// certificate reloading.
GetCertificate func(*tls.ClientHelloInfo) ([]*tls.Certificate, error)
GetCertificates func(*tls.ClientHelloInfo) ([]*tls.Certificate, error)
// VerifyPeer is a custom verification check after certificate signature
// check.
// If this is set, we will perform this customized check after doing the
Expand Down Expand Up @@ -210,8 +210,8 @@ func (o *ClientOptions) config() (*tls.Config, error) {
}

func (o *ServerOptions) config() (*tls.Config, error) {
if o.Certificates == nil && o.GetCertificate == nil {
return nil, fmt.Errorf("either Certificates or GetCertificate must be specified")
if o.Certificates == nil && o.GetCertificates == nil {
return nil, fmt.Errorf("either Certificates or GetCertificates must be specified")
}
if o.RequireClientCert && o.VType == SkipVerification && o.VerifyPeer == nil {
return nil, fmt.Errorf(
Expand All @@ -238,10 +238,10 @@ func (o *ServerOptions) config() (*tls.Config, error) {
Certificates: o.Certificates,
}
// The function getCertificateWithSNI is only able to perform SNI logic for go1.10 and above.
// It will return the first certificate in o.GetCertificate for go1.9.
if o.GetCertificate != nil {
// It will return the first certificate in o.GetCertificates for go1.9.
if o.GetCertificates != nil {
getCertificateWithSNI := func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
return buildGetCertificateFunc(clientHello, o)
return buildGetCertificates(clientHello, o)
}
config.GetCertificate = getCertificateWithSNI
}
Expand Down
4 changes: 2 additions & 2 deletions security/advancedtls/advancedtls_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -390,8 +390,8 @@ func TestEnd2End(t *testing.T) {
t.Run(test.desc, func(t *testing.T) {
// Start a server using ServerOptions in another goroutine.
serverOptions := &ServerOptions{
Certificates: test.serverCert,
GetCertificate: test.serverGetCert,
Certificates: test.serverCert,
GetCertificates: test.serverGetCert,
RootCertificateOptions: RootCertificateOptions{
RootCACerts: test.serverRoot,
GetRootCAs: test.serverGetRoot,
Expand Down
10 changes: 5 additions & 5 deletions security/advancedtls/advancedtls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -404,8 +404,8 @@ func TestClientServerHandshake(t *testing.T) {
}
// Start a server using ServerOptions in another goroutine.
serverOptions := &ServerOptions{
Certificates: test.serverCert,
GetCertificate: test.serverGetCert,
Certificates: test.serverCert,
GetCertificates: test.serverGetCert,
RootCertificateOptions: RootCertificateOptions{
RootCACerts: test.serverRoot,
GetRootCAs: test.serverGetRoot,
Expand Down Expand Up @@ -682,7 +682,7 @@ func TestOptionsConfig(t *testing.T) {
}
}

func TestGetCertificateSNI(t *testing.T) {
func TestGetCertificatesSNI(t *testing.T) {
// Load server certificates for setting the serverGetCert callback function.
serverPeerCert1, err := tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"),
testdata.Path("server_key_1.pem"))
Expand Down Expand Up @@ -737,7 +737,7 @@ func TestGetCertificateSNI(t *testing.T) {
test := test
t.Run(test.desc, func(t *testing.T) {
serverOptions := &ServerOptions{
GetCertificate: test.serverGetCert,
GetCertificates: test.serverGetCert,
}
serverConfig, err := serverOptions.config()
if err != nil {
Expand All @@ -756,7 +756,7 @@ func TestGetCertificateSNI(t *testing.T) {
t.Fatalf("Server is unable to parse peer certificates. Error: %v", err)
}
if !reflect.DeepEqual(*gotCertificate, test.expectedCertificate) {
t.Errorf("GetCertificate() = %v, want %v", gotCertificate, test.expectedCertificate)
t.Errorf("GetCertificates() = %v, want %v", gotCertificate, test.expectedCertificate)
}
})
}
Expand Down
12 changes: 6 additions & 6 deletions security/advancedtls/sni_110.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ import (
"fmt"
)

// The function buildGetCertificateFunc returns the certificate that matches the SNI field
// for the given ClientHelloInfo, defaulting to the first element of o.GetCertificate.
func buildGetCertificateFunc(clientHello *tls.ClientHelloInfo, o *ServerOptions) (*tls.Certificate, error) {
if o.GetCertificate == nil {
return nil, fmt.Errorf("function GetCertificate must be specified")
// The function buildGetCertificates returns the certificate that matches the SNI field
// for the given ClientHelloInfo, defaulting to the first element of o.GetCertificates.
func buildGetCertificates(clientHello *tls.ClientHelloInfo, o *ServerOptions) (*tls.Certificate, error) {
if o.GetCertificates == nil {
return nil, fmt.Errorf("function GetCertificates must be specified")
}
certificates, err := o.GetCertificate(clientHello)
certificates, err := o.GetCertificates(clientHello)
if err != nil {
return nil, err
}
Expand Down
10 changes: 5 additions & 5 deletions security/advancedtls/sni_before_110.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ import (
"fmt"
)

// The function buildGetCertificateFunc returns the first element of o.GetCertificate.
func buildGetCertificateFunc(clientHello *tls.ClientHelloInfo, o *ServerOptions) (*tls.Certificate, error) {
if o.GetCertificate == nil {
return nil, fmt.Errorf("function GetCertificate must be specified")
// The function buildGetCertificates returns the first element of o.GetCertificates.
func buildGetCertificates(clientHello *tls.ClientHelloInfo, o *ServerOptions) (*tls.Certificate, error) {
if o.GetCertificates == nil {
return nil, fmt.Errorf("function GetCertificates must be specified")
}
certificates, err := o.GetCertificate(clientHello)
certificates, err := o.GetCertificates(clientHello)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit ec83efb

Please sign in to comment.