@@ -15,6 +15,7 @@ import (
15
15
"fmt"
16
16
"io"
17
17
"net"
18
+ "strings"
18
19
"sync"
19
20
"sync/atomic"
20
21
"testing"
@@ -466,15 +467,42 @@ func TestPSK(t *testing.T) {
466
467
defer report ()
467
468
468
469
for _ , test := range []struct {
469
- Name string
470
- ServerIdentity []byte
471
- CipherSuites []CipherSuiteID
470
+ Name string
471
+ ServerIdentity []byte
472
+ CipherSuites []CipherSuiteID
473
+ ClientVerifyConnection func (* State ) error
474
+ ServerVerifyConnection func (* State ) error
475
+ WantFail bool
476
+ ExpectedServerErr string
477
+ ExpectedClientErr string
472
478
}{
473
479
{
474
480
Name : "Server identity specified" ,
475
481
ServerIdentity : []byte ("Test Identity" ),
476
482
CipherSuites : []CipherSuiteID {TLS_PSK_WITH_AES_128_CCM_8 },
477
483
},
484
+ {
485
+ Name : "Server identity specified - Server verify connection fails" ,
486
+ ServerIdentity : []byte ("Test Identity" ),
487
+ CipherSuites : []CipherSuiteID {TLS_PSK_WITH_AES_128_CCM_8 },
488
+ ServerVerifyConnection : func (s * State ) error {
489
+ return errExample
490
+ },
491
+ WantFail : true ,
492
+ ExpectedServerErr : errExample .Error (),
493
+ ExpectedClientErr : alert .BadCertificate .String (),
494
+ },
495
+ {
496
+ Name : "Server identity specified - Client verify connection fails" ,
497
+ ServerIdentity : []byte ("Test Identity" ),
498
+ CipherSuites : []CipherSuiteID {TLS_PSK_WITH_AES_128_CCM_8 },
499
+ ClientVerifyConnection : func (s * State ) error {
500
+ return errExample
501
+ },
502
+ WantFail : true ,
503
+ ExpectedServerErr : alert .BadCertificate .String (),
504
+ ExpectedClientErr : errExample .Error (),
505
+ },
478
506
{
479
507
Name : "Server identity nil" ,
480
508
ServerIdentity : nil ,
@@ -513,8 +541,9 @@ func TestPSK(t *testing.T) {
513
541
514
542
return []byte {0xAB , 0xC1 , 0x23 }, nil
515
543
},
516
- PSKIdentityHint : clientIdentity ,
517
- CipherSuites : test .CipherSuites ,
544
+ PSKIdentityHint : clientIdentity ,
545
+ CipherSuites : test .CipherSuites ,
546
+ VerifyConnection : test .ClientVerifyConnection ,
518
547
}
519
548
520
549
c , err := testClient (ctx , ca , conf , false )
@@ -528,11 +557,22 @@ func TestPSK(t *testing.T) {
528
557
}
529
558
return []byte {0xAB , 0xC1 , 0x23 }, nil
530
559
},
531
- PSKIdentityHint : test .ServerIdentity ,
532
- CipherSuites : test .CipherSuites ,
560
+ PSKIdentityHint : test .ServerIdentity ,
561
+ CipherSuites : test .CipherSuites ,
562
+ VerifyConnection : test .ServerVerifyConnection ,
533
563
}
534
564
535
565
server , err := testServer (ctx , cb , config , false )
566
+ if test .WantFail {
567
+ res := <- clientRes
568
+ if err == nil || ! strings .Contains (err .Error (), test .ExpectedServerErr ) {
569
+ t .Fatalf ("TestPSK: Server expected(%v) actual(%v)" , test .ExpectedServerErr , err )
570
+ }
571
+ if res .err == nil || ! strings .Contains (res .err .Error (), test .ExpectedClientErr ) {
572
+ t .Fatalf ("TestPSK: Client expected(%v) actual(%v)" , test .ExpectedClientErr , res .err )
573
+ }
574
+ return
575
+ }
536
576
if err != nil {
537
577
t .Fatalf ("TestPSK: Server failed(%v)" , err )
538
578
}
@@ -788,6 +828,29 @@ func TestClientCertificate(t *testing.T) {
788
828
ClientCAs : caPool ,
789
829
},
790
830
},
831
+ "NoClientCert_ServerVerifyConnectionFails" : {
832
+ clientCfg : & Config {RootCAs : srvCAPool },
833
+ serverCfg : & Config {
834
+ Certificates : []tls.Certificate {srvCert },
835
+ ClientAuth : NoClientCert ,
836
+ ClientCAs : caPool ,
837
+ VerifyConnection : func (s * State ) error {
838
+ return errExample
839
+ },
840
+ },
841
+ wantErr : true ,
842
+ },
843
+ "NoClientCert_ClientVerifyConnectionFails" : {
844
+ clientCfg : & Config {RootCAs : srvCAPool , VerifyConnection : func (s * State ) error {
845
+ return errExample
846
+ }},
847
+ serverCfg : & Config {
848
+ Certificates : []tls.Certificate {srvCert },
849
+ ClientAuth : NoClientCert ,
850
+ ClientCAs : caPool ,
851
+ },
852
+ wantErr : true ,
853
+ },
791
854
"NoClientCert_cert" : {
792
855
clientCfg : & Config {RootCAs : srvCAPool , Certificates : []tls.Certificate {cert }},
793
856
serverCfg : & Config {
@@ -850,11 +913,22 @@ func TestClientCertificate(t *testing.T) {
850
913
wantErr : true ,
851
914
},
852
915
"RequireAndVerifyClientCert" : {
853
- clientCfg : & Config {RootCAs : srvCAPool , Certificates : []tls.Certificate {cert }},
916
+ clientCfg : & Config {RootCAs : srvCAPool , Certificates : []tls.Certificate {cert }, VerifyConnection : func (s * State ) error {
917
+ if ok := bytes .Equal (s .PeerCertificates [0 ], srvCertificate .Raw ); ! ok {
918
+ return errExample
919
+ }
920
+ return nil
921
+ }},
854
922
serverCfg : & Config {
855
923
Certificates : []tls.Certificate {srvCert },
856
924
ClientAuth : RequireAndVerifyClientCert ,
857
925
ClientCAs : caPool ,
926
+ VerifyConnection : func (s * State ) error {
927
+ if ok := bytes .Equal (s .PeerCertificates [0 ], certificate .Raw ); ! ok {
928
+ return errExample
929
+ }
930
+ return nil
931
+ },
858
932
},
859
933
},
860
934
}
0 commit comments