From 36b74daf8e76d000cae9a5490e0a77e894363efe Mon Sep 17 00:00:00 2001 From: "R.I.Pienaar" Date: Fri, 10 Jun 2022 11:23:33 +0200 Subject: [PATCH] Expose TLS connection state whne possible I would like to be able to report things like ciphers, tls or not, verified or not etc in nats account info where its hows other connection properties but had no way to get at this information Signed-off-by: R.I.Pienaar --- nats.go | 19 +++++++++++++++++++ test/basic_test.go | 13 +++++++++++++ test/conn_test.go | 8 ++++++++ 3 files changed, 40 insertions(+) diff --git a/nats.go b/nats.go index d451bf529..f91445842 100644 --- a/nats.go +++ b/nats.go @@ -168,6 +168,7 @@ var ( ErrStreamNameAlreadyInUse = errors.New("nats: stream name already in use") ErrMaxConnectionsExceeded = errors.New("nats: server maximum connections exceeded") ErrBadRequest = errors.New("nats: bad request") + ErrConnectionNotTLS = errors.New("nats: connection is not tls") ) func init() { @@ -1807,6 +1808,24 @@ func (nc *Conn) makeTLSConn() error { return nil } +// TLSConnectionState retrieves the state of the TLS connection to the server +func (nc *Conn) TLSConnectionState() (tls.ConnectionState, error) { + if !nc.isConnected() { + return tls.ConnectionState{}, ErrDisconnected + } + + nc.mu.RLock() + conn := nc.conn + nc.mu.RUnlock() + + tc, ok := conn.(*tls.Conn) + if !ok { + return tls.ConnectionState{}, ErrConnectionNotTLS + } + + return tc.ConnectionState(), nil +} + // waitForExits will wait for all socket watcher Go routines to // be shutdown before proceeding. func (nc *Conn) waitForExits() { diff --git a/test/basic_test.go b/test/basic_test.go index f39c8f150..d6eaaa36f 100644 --- a/test/basic_test.go +++ b/test/basic_test.go @@ -110,6 +110,19 @@ func TestLeakingGoRoutinesOnFailedConnect(t *testing.T) { checkNoGoroutineLeak(t, base, "failed connect") } +func TestTLSConnectionStateNonTLS(t *testing.T) { + s := RunDefaultServer() + defer s.Shutdown() + + nc := NewDefaultConnection(t) + defer nc.Close() + + _, err := nc.TLSConnectionState() + if err != nats.ErrConnectionNotTLS { + t.Fatalf("Expected a not tls error, got: %v", err) + } +} + func TestConnectedServer(t *testing.T) { s := RunDefaultServer() defer s.Shutdown() diff --git a/test/conn_test.go b/test/conn_test.go index 4818be234..3d426b4d2 100644 --- a/test/conn_test.go +++ b/test/conn_test.go @@ -161,6 +161,14 @@ func TestServerSecureConnections(t *testing.T) { } nc.Flush() + state, err := nc.TLSConnectionState() + if err != nil { + t.Fatalf("Expected connection state: %v", err) + } + if !state.HandshakeComplete { + t.Fatalf("Expected valid connection state") + } + if err := Wait(checkRecv); err != nil { t.Fatal("Failed receiving message") }