Skip to content

Commit

Permalink
advancedtls: add revocation support to client/server options (#4781)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhenLian committed Sep 27, 2021
1 parent 4555155 commit 710419d
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 29 deletions.
71 changes: 45 additions & 26 deletions security/advancedtls/advancedtls.go
Expand Up @@ -181,6 +181,9 @@ type ClientOptions struct {
RootOptions RootCertificateOptions
// VType is the verification type on the client side.
VType VerificationType
// RevocationConfig is the configurations for certificate revocation checks.
// It could be nil if such checks are not needed.
RevocationConfig *RevocationConfig
}

// ServerOptions contains the fields needed to be filled by the server.
Expand All @@ -199,6 +202,9 @@ type ServerOptions struct {
RequireClientCert bool
// VType is the verification type on the server side.
VType VerificationType
// RevocationConfig is the configurations for certificate revocation checks.
// It could be nil if such checks are not needed.
RevocationConfig *RevocationConfig
}

func (o *ClientOptions) config() (*tls.Config, error) {
Expand Down Expand Up @@ -356,11 +362,12 @@ func (o *ServerOptions) config() (*tls.Config, error) {
// advancedTLSCreds is the credentials required for authenticating a connection
// using TLS.
type advancedTLSCreds struct {
config *tls.Config
verifyFunc CustomVerificationFunc
getRootCAs func(params *GetRootCAsParams) (*GetRootCAsResults, error)
isClient bool
vType VerificationType
config *tls.Config
verifyFunc CustomVerificationFunc
getRootCAs func(params *GetRootCAsParams) (*GetRootCAsResults, error)
isClient bool
vType VerificationType
revocationConfig *RevocationConfig
}

func (c advancedTLSCreds) Info() credentials.ProtocolInfo {
Expand Down Expand Up @@ -451,6 +458,14 @@ func buildVerifyFunc(c *advancedTLSCreds,
return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
chains := verifiedChains
var leafCert *x509.Certificate
rawCertList := make([]*x509.Certificate, len(rawCerts))
for i, asn1Data := range rawCerts {
cert, err := x509.ParseCertificate(asn1Data)
if err != nil {
return err
}
rawCertList[i] = cert
}
if c.vType == CertAndHostVerification || c.vType == CertVerification {
// perform possible trust credential reloading and certificate check
rootCAs := c.config.RootCAs
Expand All @@ -469,14 +484,6 @@ func buildVerifyFunc(c *advancedTLSCreds,
rootCAs = results.TrustCerts
}
// Verify peers' certificates against RootCAs and get verifiedChains.
certs := make([]*x509.Certificate, len(rawCerts))
for i, asn1Data := range rawCerts {
cert, err := x509.ParseCertificate(asn1Data)
if err != nil {
return err
}
certs[i] = cert
}
keyUsages := []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}
if !c.isClient {
keyUsages = []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}
Expand All @@ -487,7 +494,7 @@ func buildVerifyFunc(c *advancedTLSCreds,
Intermediates: x509.NewCertPool(),
KeyUsages: keyUsages,
}
for _, cert := range certs[1:] {
for _, cert := range rawCertList[1:] {
opts.Intermediates.AddCert(cert)
}
// Perform default hostname check if specified.
Expand All @@ -501,11 +508,21 @@ func buildVerifyFunc(c *advancedTLSCreds,
opts.DNSName = parsedName
}
var err error
chains, err = certs[0].Verify(opts)
chains, err = rawCertList[0].Verify(opts)
if err != nil {
return err
}
leafCert = certs[0]
leafCert = rawCertList[0]
}
// Perform certificate revocation check if specified.
if c.revocationConfig != nil {
verifiedChains := chains
if verifiedChains == nil {
verifiedChains = [][]*x509.Certificate{rawCertList}
}
if err := CheckChainRevocation(verifiedChains, *c.revocationConfig); err != nil {
return err
}
}
// Perform custom verification check if specified.
if c.verifyFunc != nil {
Expand All @@ -529,11 +546,12 @@ func NewClientCreds(o *ClientOptions) (credentials.TransportCredentials, error)
return nil, err
}
tc := &advancedTLSCreds{
config: conf,
isClient: true,
getRootCAs: o.RootOptions.GetRootCertificates,
verifyFunc: o.VerifyPeer,
vType: o.VType,
config: conf,
isClient: true,
getRootCAs: o.RootOptions.GetRootCertificates,
verifyFunc: o.VerifyPeer,
vType: o.VType,
revocationConfig: o.RevocationConfig,
}
tc.config.NextProtos = credinternal.AppendH2ToNextProtos(tc.config.NextProtos)
return tc, nil
Expand All @@ -547,11 +565,12 @@ func NewServerCreds(o *ServerOptions) (credentials.TransportCredentials, error)
return nil, err
}
tc := &advancedTLSCreds{
config: conf,
isClient: false,
getRootCAs: o.RootOptions.GetRootCertificates,
verifyFunc: o.VerifyPeer,
vType: o.VType,
config: conf,
isClient: false,
getRootCAs: o.RootOptions.GetRootCertificates,
verifyFunc: o.VerifyPeer,
vType: o.VType,
revocationConfig: o.RevocationConfig,
}
tc.config.NextProtos = credinternal.AppendH2ToNextProtos(tc.config.NextProtos)
return tc, nil
Expand Down
4 changes: 2 additions & 2 deletions security/advancedtls/advancedtls_integration_test.go
Expand Up @@ -380,7 +380,7 @@ func (s) TestEnd2End(t *testing.T) {
}
clientTLSCreds, err := NewClientCreds(clientOptions)
if err != nil {
t.Fatalf("clientTLSCreds failed to create")
t.Fatalf("clientTLSCreds failed to create: %v", err)
}
// ------------------------Scenario 1------------------------------------
// stage = 0, initial connection should succeed
Expand Down Expand Up @@ -796,7 +796,7 @@ func (s) TestDefaultHostNameCheck(t *testing.T) {
}
clientTLSCreds, err := NewClientCreds(clientOptions)
if err != nil {
t.Fatalf("clientTLSCreds failed to create")
t.Fatalf("clientTLSCreds failed to create: %v", err)
}
shouldFail := false
if test.expectError {
Expand Down
36 changes: 35 additions & 1 deletion security/advancedtls/advancedtls_test.go
Expand Up @@ -27,10 +27,12 @@ import (
"net"
"testing"

lru "github.com/hashicorp/golang-lru"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/tls/certprovider"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/security/advancedtls/internal/testutils"
"google.golang.org/grpc/security/advancedtls/testdata"
)

type s struct {
Expand Down Expand Up @@ -339,6 +341,10 @@ func (s) TestClientServerHandshake(t *testing.T) {
getRootCAsForServerBad := func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
return nil, fmt.Errorf("bad root certificate reloading")
}
cache, err := lru.New(5)
if err != nil {
t.Fatalf("lru.New: err = %v", err)
}
for _, test := range []struct {
desc string
clientCert []tls.Certificate
Expand All @@ -349,6 +355,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
clientVType VerificationType
clientRootProvider certprovider.Provider
clientIdentityProvider certprovider.Provider
clientRevocationConfig *RevocationConfig
clientExpectHandshakeError bool
serverMutualTLS bool
serverCert []tls.Certificate
Expand All @@ -359,6 +366,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
serverVType VerificationType
serverRootProvider certprovider.Provider
serverIdentityProvider certprovider.Provider
serverRevocationConfig *RevocationConfig
serverExpectError bool
}{
// Client: nil setting except verifyFuncGood
Expand Down Expand Up @@ -642,6 +650,30 @@ func (s) TestClientServerHandshake(t *testing.T) {
serverRootProvider: fakeProvider{isClient: false},
serverVType: CertVerification,
},
// Client: set valid credentials with the revocation config
// Server: set valid credentials with the revocation config
// Expected Behavior: success, because non of the certificate chains sent in the connection are revoked
{
desc: "Client sets peer cert, reload root function with verifyFuncGood; Server sets peer cert, reload root function; mutualTLS",
clientCert: []tls.Certificate{cs.ClientCert1},
clientGetRoot: getRootCAsForClient,
clientVerifyFunc: clientVerifyFuncGood,
clientVType: CertVerification,
clientRevocationConfig: &RevocationConfig{
RootDir: testdata.Path("crl"),
AllowUndetermined: true,
Cache: cache,
},
serverMutualTLS: true,
serverCert: []tls.Certificate{cs.ServerCert1},
serverGetRoot: getRootCAsForServer,
serverVType: CertVerification,
serverRevocationConfig: &RevocationConfig{
RootDir: testdata.Path("crl"),
AllowUndetermined: true,
Cache: cache,
},
},
} {
test := test
t.Run(test.desc, func(t *testing.T) {
Expand All @@ -665,6 +697,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
RequireClientCert: test.serverMutualTLS,
VerifyPeer: test.serverVerifyFunc,
VType: test.serverVType,
RevocationConfig: test.serverRevocationConfig,
}
go func(done chan credentials.AuthInfo, lis net.Listener, serverOptions *ServerOptions) {
serverRawConn, err := lis.Accept()
Expand Down Expand Up @@ -706,7 +739,8 @@ func (s) TestClientServerHandshake(t *testing.T) {
GetRootCertificates: test.clientGetRoot,
RootProvider: test.clientRootProvider,
},
VType: test.clientVType,
VType: test.clientVType,
RevocationConfig: test.clientRevocationConfig,
}
clientTLS, err := NewClientCreds(clientOptions)
if err != nil {
Expand Down

0 comments on commit 710419d

Please sign in to comment.