From 18e6a1b415d919d4f3ec56537188381a854ba41c Mon Sep 17 00:00:00 2001 From: "R.I.Pienaar" Date: Fri, 10 Jun 2022 11:23:33 +0200 Subject: [PATCH] Expose TLS connection state whne possible I would like to be able to report things like ciphers, tls or not, verified or not etc in nats account info where its hows other connection properties but had no way to get at this information Signed-off-by: R.I.Pienaar --- nats.go | 19 +++++++++++++++++++ test/basic_test.go | 13 +++++++++++++ test/conn_test.go | 8 ++++++++ 3 files changed, 40 insertions(+) diff --git a/nats.go b/nats.go index d451bf529..f91445842 100644 --- a/nats.go +++ b/nats.go @@ -168,6 +168,7 @@ var ( ErrStreamNameAlreadyInUse = errors.New("nats: stream name already in use") ErrMaxConnectionsExceeded = errors.New("nats: server maximum connections exceeded") ErrBadRequest = errors.New("nats: bad request") + ErrConnectionNotTLS = errors.New("nats: connection is not tls") ) func init() { @@ -1807,6 +1808,24 @@ func (nc *Conn) makeTLSConn() error { return nil } +// TLSConnectionState retrieves the state of the TLS connection to the server +func (nc *Conn) TLSConnectionState() (tls.ConnectionState, error) { + if !nc.isConnected() { + return tls.ConnectionState{}, ErrDisconnected + } + + nc.mu.RLock() + conn := nc.conn + nc.mu.RUnlock() + + tc, ok := conn.(*tls.Conn) + if !ok { + return tls.ConnectionState{}, ErrConnectionNotTLS + } + + return tc.ConnectionState(), nil +} + // waitForExits will wait for all socket watcher Go routines to // be shutdown before proceeding. func (nc *Conn) waitForExits() { diff --git a/test/basic_test.go b/test/basic_test.go index f39c8f150..e96d289fc 100644 --- a/test/basic_test.go +++ b/test/basic_test.go @@ -110,6 +110,19 @@ func TestLeakingGoRoutinesOnFailedConnect(t *testing.T) { checkNoGoroutineLeak(t, base, "failed connect") } +func TestTLSConnectionStateNonTLS(t *testing.T) { + s := RunDefaultServer() + defer s.Shutdown() + + nc := NewDefaultConnection(t) + defer nc.Close() + + _, err := nc.TLSConnectionState() + if err == nil { + t.Fatalf("Expected an error, got none") + } +} + func TestConnectedServer(t *testing.T) { s := RunDefaultServer() defer s.Shutdown() diff --git a/test/conn_test.go b/test/conn_test.go index 4818be234..3d426b4d2 100644 --- a/test/conn_test.go +++ b/test/conn_test.go @@ -161,6 +161,14 @@ func TestServerSecureConnections(t *testing.T) { } nc.Flush() + state, err := nc.TLSConnectionState() + if err != nil { + t.Fatalf("Expected connection state: %v", err) + } + if !state.HandshakeComplete { + t.Fatalf("Expected valid connection state") + } + if err := Wait(checkRecv); err != nil { t.Fatal("Failed receiving message") }