Skip to content

Commit

Permalink
Fix client cert handling when cert reloading is enabled (#697)
Browse files Browse the repository at this point in the history
Since cloudprober is the client, the dynamic cert loading needs to be
implemented in GetClientCertificate rather than GetCertificate.
  • Loading branch information
cbroglie committed Mar 14, 2024
1 parent 17c6b95 commit b0f610d
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 7 deletions.
2 changes: 1 addition & 1 deletion internal/tlsconfig/tlsconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func UpdateTLSConfig(tlsConfig *tls.Config, c *configpb.TLSConfig) error {
if c.GetReloadIntervalSec() > 0 {
key := [2]string{certF, keyF}

tlsConfig.GetCertificate = func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
tlsConfig.GetClientCertificate = func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
global.mu.RLock()
entry, ok := global.cache[key]
global.mu.RUnlock()
Expand Down
33 changes: 27 additions & 6 deletions internal/tlsconfig/tlsconfig_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ package tlsconfig
import (
"crypto/tls"
"crypto/x509"
"net/http"
"net/http/httptest"
"os"
"testing"
"time"
Expand Down Expand Up @@ -112,27 +114,46 @@ func TestUpdateTLSConfig(t *testing.T) {
assert.Equal(t, tt.serverName, tlsConfig.ServerName, "ServerName mismatch")

if !tt.dynamic {
assert.Nil(t, tlsConfig.GetCertificate, "GetCertificate should be nil")
assert.Nil(t, tlsConfig.GetClientCertificate, "GetClientCertificate should be nil")
assert.Len(t, tlsConfig.Certificates, 1, "Certificates should have one entry")
parseAndVerifyCert(t, tlsConfig.Certificates[0], tt.wantCN)
}

if tt.dynamic {
assert.NotNil(t, tlsConfig.GetCertificate, "GetCertificate should not be nil")
assert.NotNil(t, tlsConfig.GetClientCertificate, "GetClientCertificate should not be nil")
assert.Equal(t, 0, len(tlsConfig.Certificates), "Certificates should be empty")

cert, err := tlsConfig.GetCertificate(nil)
assert.NoError(t, err, "Error getting TLS certificate")
cert, err := tlsConfig.GetClientCertificate(nil)
assert.NoError(t, err, "Error getting client TLS certificate")
parseAndVerifyCert(t, *cert, tt.wantCN)

if tt.nextCert[0] != "" {
writeTestCert(tt.nextCert)
time.Sleep(1 * time.Second)
cert, err := tlsConfig.GetCertificate(nil)
assert.NoError(t, err, "Error getting TLS certificate")
cert, err := tlsConfig.GetClientCertificate(nil)
assert.NoError(t, err, "Error getting client TLS certificate")
parseAndVerifyCert(t, *cert, tt.wantNextCN)
}
}

ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
ts.TLS = &tls.Config{ClientAuth: tls.RequireAnyClientCert}
ts.StartTLS()
defer ts.Close()

tlsConfig = tlsConfig.Clone()
tlsConfig.InsecureSkipVerify = true
client := http.Client{
Transport: &http.Transport{
TLSClientConfig: tlsConfig,
},
}
res, err := client.Get(ts.URL)
assert.NoError(t, err)
res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
})
}
}

0 comments on commit b0f610d

Please sign in to comment.