Skip to content

Commit 82c1271

Browse files
committedAug 3, 2022
Implement VerifyConnection as is in tls.Config
VerifyConnection, if not nil, is called after normal certificate verification/PSK and after VerifyPeerCertificate by either a TLS client or server. If it returns a non-nil error, the handshake is aborted and that error results. If normal verification fails then the handshake will abort before considering this callback. This callback will run for all connections regardless of InsecureSkipVerify or ClientAuth settings.
1 parent de299f5 commit 82c1271

File tree

6 files changed

+111
-10
lines changed

6 files changed

+111
-10
lines changed
 

‎config.go

+10
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,16 @@ type Config struct {
8383
// be considered but the verifiedChains will always be nil.
8484
VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
8585

86+
// VerifyConnection, if not nil, is called after normal certificate
87+
// verification/PSK and after VerifyPeerCertificate by either a TLS client
88+
// or server. If it returns a non-nil error, the handshake is aborted
89+
// and that error results.
90+
//
91+
// If normal verification fails then the handshake will abort before
92+
// considering this callback. This callback will run for all connections
93+
// regardless of InsecureSkipVerify or ClientAuth settings.
94+
VerifyConnection func(*State) error
95+
8696
// RootCAs defines the set of root certificate authorities
8797
// that one peer uses when verifying the other peer's certificates.
8898
// If RootCAs is nil, TLS uses the host's root CA set.

‎conn.go

+1
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient
172172
localCertificates: config.Certificates,
173173
insecureSkipVerify: config.InsecureSkipVerify,
174174
verifyPeerCertificate: config.VerifyPeerCertificate,
175+
verifyConnection: config.VerifyConnection,
175176
rootCAs: config.RootCAs,
176177
clientCAs: config.ClientCAs,
177178
customCipherSuites: config.CustomCipherSuites,

‎conn_test.go

+82-8
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"fmt"
1616
"io"
1717
"net"
18+
"strings"
1819
"sync"
1920
"sync/atomic"
2021
"testing"
@@ -466,15 +467,42 @@ func TestPSK(t *testing.T) {
466467
defer report()
467468

468469
for _, test := range []struct {
469-
Name string
470-
ServerIdentity []byte
471-
CipherSuites []CipherSuiteID
470+
Name string
471+
ServerIdentity []byte
472+
CipherSuites []CipherSuiteID
473+
ClientVerifyConnection func(*State) error
474+
ServerVerifyConnection func(*State) error
475+
WantFail bool
476+
ExpectedServerErr string
477+
ExpectedClientErr string
472478
}{
473479
{
474480
Name: "Server identity specified",
475481
ServerIdentity: []byte("Test Identity"),
476482
CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8},
477483
},
484+
{
485+
Name: "Server identity specified - Server verify connection fails",
486+
ServerIdentity: []byte("Test Identity"),
487+
CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8},
488+
ServerVerifyConnection: func(s *State) error {
489+
return errExample
490+
},
491+
WantFail: true,
492+
ExpectedServerErr: errExample.Error(),
493+
ExpectedClientErr: alert.BadCertificate.String(),
494+
},
495+
{
496+
Name: "Server identity specified - Client verify connection fails",
497+
ServerIdentity: []byte("Test Identity"),
498+
CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8},
499+
ClientVerifyConnection: func(s *State) error {
500+
return errExample
501+
},
502+
WantFail: true,
503+
ExpectedServerErr: alert.BadCertificate.String(),
504+
ExpectedClientErr: errExample.Error(),
505+
},
478506
{
479507
Name: "Server identity nil",
480508
ServerIdentity: nil,
@@ -513,8 +541,9 @@ func TestPSK(t *testing.T) {
513541

514542
return []byte{0xAB, 0xC1, 0x23}, nil
515543
},
516-
PSKIdentityHint: clientIdentity,
517-
CipherSuites: test.CipherSuites,
544+
PSKIdentityHint: clientIdentity,
545+
CipherSuites: test.CipherSuites,
546+
VerifyConnection: test.ClientVerifyConnection,
518547
}
519548

520549
c, err := testClient(ctx, ca, conf, false)
@@ -528,11 +557,22 @@ func TestPSK(t *testing.T) {
528557
}
529558
return []byte{0xAB, 0xC1, 0x23}, nil
530559
},
531-
PSKIdentityHint: test.ServerIdentity,
532-
CipherSuites: test.CipherSuites,
560+
PSKIdentityHint: test.ServerIdentity,
561+
CipherSuites: test.CipherSuites,
562+
VerifyConnection: test.ServerVerifyConnection,
533563
}
534564

535565
server, err := testServer(ctx, cb, config, false)
566+
if test.WantFail {
567+
res := <-clientRes
568+
if err == nil || !strings.Contains(err.Error(), test.ExpectedServerErr) {
569+
t.Fatalf("TestPSK: Server expected(%v) actual(%v)", test.ExpectedServerErr, err)
570+
}
571+
if res.err == nil || !strings.Contains(res.err.Error(), test.ExpectedClientErr) {
572+
t.Fatalf("TestPSK: Client expected(%v) actual(%v)", test.ExpectedClientErr, res.err)
573+
}
574+
return
575+
}
536576
if err != nil {
537577
t.Fatalf("TestPSK: Server failed(%v)", err)
538578
}
@@ -788,6 +828,29 @@ func TestClientCertificate(t *testing.T) {
788828
ClientCAs: caPool,
789829
},
790830
},
831+
"NoClientCert_ServerVerifyConnectionFails": {
832+
clientCfg: &Config{RootCAs: srvCAPool},
833+
serverCfg: &Config{
834+
Certificates: []tls.Certificate{srvCert},
835+
ClientAuth: NoClientCert,
836+
ClientCAs: caPool,
837+
VerifyConnection: func(s *State) error {
838+
return errExample
839+
},
840+
},
841+
wantErr: true,
842+
},
843+
"NoClientCert_ClientVerifyConnectionFails": {
844+
clientCfg: &Config{RootCAs: srvCAPool, VerifyConnection: func(s *State) error {
845+
return errExample
846+
}},
847+
serverCfg: &Config{
848+
Certificates: []tls.Certificate{srvCert},
849+
ClientAuth: NoClientCert,
850+
ClientCAs: caPool,
851+
},
852+
wantErr: true,
853+
},
791854
"NoClientCert_cert": {
792855
clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}},
793856
serverCfg: &Config{
@@ -850,11 +913,22 @@ func TestClientCertificate(t *testing.T) {
850913
wantErr: true,
851914
},
852915
"RequireAndVerifyClientCert": {
853-
clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}},
916+
clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}, VerifyConnection: func(s *State) error {
917+
if ok := bytes.Equal(s.PeerCertificates[0], srvCertificate.Raw); !ok {
918+
return errExample
919+
}
920+
return nil
921+
}},
854922
serverCfg: &Config{
855923
Certificates: []tls.Certificate{srvCert},
856924
ClientAuth: RequireAndVerifyClientCert,
857925
ClientCAs: caPool,
926+
VerifyConnection: func(s *State) error {
927+
if ok := bytes.Equal(s.PeerCertificates[0], certificate.Raw); !ok {
928+
return errExample
929+
}
930+
return nil
931+
},
858932
},
859933
},
860934
}

‎flight4handler.go

+12-2
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handsh
9191
state.peerCertificatesVerified = verified
9292
} else if state.PeerCertificates != nil {
9393
// A certificate was received, but we haven't seen a CertificateVerify
94-
// keep reading until we receieve one
94+
// keep reading until we receive one
9595
return 0, nil, nil
9696
}
9797

@@ -178,6 +178,11 @@ func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handsh
178178
}
179179

180180
if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeAnonymous {
181+
if cfg.verifyConnection != nil {
182+
if err := cfg.verifyConnection(state.clone()); err != nil {
183+
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
184+
}
185+
}
181186
return flight6, nil, nil
182187
}
183188

@@ -198,7 +203,12 @@ func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handsh
198203
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, errClientCertificateNotVerified
199204
}
200205
case NoClientCert, RequestClientCert:
201-
return flight6, nil, nil
206+
// go to flight6
207+
}
208+
if cfg.verifyConnection != nil {
209+
if err := cfg.verifyConnection(state.clone()); err != nil {
210+
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
211+
}
202212
}
203213

204214
return flight6, nil, nil

‎flight5handler.go

+5
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,11 @@ func initalizeCipherSuite(state *State, cache *handshakeCache, cfg *handshakeCon
328328
}
329329
}
330330
}
331+
if cfg.verifyConnection != nil {
332+
if err = cfg.verifyConnection(state.clone()); err != nil {
333+
return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
334+
}
335+
}
331336

332337
if err = state.cipherSuite.Init(state.masterSecret, clientRandom[:], serverRandom[:], true); err != nil {
333338
return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err

‎handshaker.go

+1
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ type handshakeConfig struct {
102102
nameToCertificate map[string]*tls.Certificate
103103
insecureSkipVerify bool
104104
verifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
105+
verifyConnection func(*State) error
105106
sessionStore SessionStore
106107
rootCAs *x509.CertPool
107108
clientCAs *x509.CertPool

0 commit comments

Comments
 (0)
Please sign in to comment.