diff --git a/nats.go b/nats.go index d451bf529..f91445842 100644 --- a/nats.go +++ b/nats.go @@ -168,6 +168,7 @@ var ( ErrStreamNameAlreadyInUse = errors.New("nats: stream name already in use") ErrMaxConnectionsExceeded = errors.New("nats: server maximum connections exceeded") ErrBadRequest = errors.New("nats: bad request") + ErrConnectionNotTLS = errors.New("nats: connection is not tls") ) func init() { @@ -1807,6 +1808,24 @@ func (nc *Conn) makeTLSConn() error { return nil } +// TLSConnectionState retrieves the state of the TLS connection to the server +func (nc *Conn) TLSConnectionState() (tls.ConnectionState, error) { + if !nc.isConnected() { + return tls.ConnectionState{}, ErrDisconnected + } + + nc.mu.RLock() + conn := nc.conn + nc.mu.RUnlock() + + tc, ok := conn.(*tls.Conn) + if !ok { + return tls.ConnectionState{}, ErrConnectionNotTLS + } + + return tc.ConnectionState(), nil +} + // waitForExits will wait for all socket watcher Go routines to // be shutdown before proceeding. func (nc *Conn) waitForExits() { diff --git a/test/basic_test.go b/test/basic_test.go index f39c8f150..d6eaaa36f 100644 --- a/test/basic_test.go +++ b/test/basic_test.go @@ -110,6 +110,19 @@ func TestLeakingGoRoutinesOnFailedConnect(t *testing.T) { checkNoGoroutineLeak(t, base, "failed connect") } +func TestTLSConnectionStateNonTLS(t *testing.T) { + s := RunDefaultServer() + defer s.Shutdown() + + nc := NewDefaultConnection(t) + defer nc.Close() + + _, err := nc.TLSConnectionState() + if err != nats.ErrConnectionNotTLS { + t.Fatalf("Expected a not tls error, got: %v", err) + } +} + func TestConnectedServer(t *testing.T) { s := RunDefaultServer() defer s.Shutdown() diff --git a/test/conn_test.go b/test/conn_test.go index 4818be234..3d426b4d2 100644 --- a/test/conn_test.go +++ b/test/conn_test.go @@ -161,6 +161,14 @@ func TestServerSecureConnections(t *testing.T) { } nc.Flush() + state, err := nc.TLSConnectionState() + if err != nil { + t.Fatalf("Expected connection state: %v", err) + } + if !state.HandshakeComplete { + t.Fatalf("Expected valid connection state") + } + if err := Wait(checkRecv); err != nil { t.Fatal("Failed receiving message") }