Skip to content

Commit

Permalink
Merge pull request #996 from ripienaar/tls_conn_state
Browse files Browse the repository at this point in the history
Expose TLS connection state when possible
  • Loading branch information
ripienaar committed Jun 10, 2022
2 parents 1a55cb9 + 36b74da commit dcbb65a
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 0 deletions.
19 changes: 19 additions & 0 deletions nats.go
Expand Up @@ -168,6 +168,7 @@ var (
ErrStreamNameAlreadyInUse = errors.New("nats: stream name already in use")
ErrMaxConnectionsExceeded = errors.New("nats: server maximum connections exceeded")
ErrBadRequest = errors.New("nats: bad request")
ErrConnectionNotTLS = errors.New("nats: connection is not tls")
)

func init() {
Expand Down Expand Up @@ -1807,6 +1808,24 @@ func (nc *Conn) makeTLSConn() error {
return nil
}

// TLSConnectionState retrieves the state of the TLS connection to the server
func (nc *Conn) TLSConnectionState() (tls.ConnectionState, error) {
if !nc.isConnected() {
return tls.ConnectionState{}, ErrDisconnected
}

nc.mu.RLock()
conn := nc.conn
nc.mu.RUnlock()

tc, ok := conn.(*tls.Conn)
if !ok {
return tls.ConnectionState{}, ErrConnectionNotTLS
}

return tc.ConnectionState(), nil
}

// waitForExits will wait for all socket watcher Go routines to
// be shutdown before proceeding.
func (nc *Conn) waitForExits() {
Expand Down
13 changes: 13 additions & 0 deletions test/basic_test.go
Expand Up @@ -110,6 +110,19 @@ func TestLeakingGoRoutinesOnFailedConnect(t *testing.T) {
checkNoGoroutineLeak(t, base, "failed connect")
}

func TestTLSConnectionStateNonTLS(t *testing.T) {
s := RunDefaultServer()
defer s.Shutdown()

nc := NewDefaultConnection(t)
defer nc.Close()

_, err := nc.TLSConnectionState()
if err != nats.ErrConnectionNotTLS {
t.Fatalf("Expected a not tls error, got: %v", err)
}
}

func TestConnectedServer(t *testing.T) {
s := RunDefaultServer()
defer s.Shutdown()
Expand Down
8 changes: 8 additions & 0 deletions test/conn_test.go
Expand Up @@ -161,6 +161,14 @@ func TestServerSecureConnections(t *testing.T) {
}
nc.Flush()

state, err := nc.TLSConnectionState()
if err != nil {
t.Fatalf("Expected connection state: %v", err)
}
if !state.HandshakeComplete {
t.Fatalf("Expected valid connection state")
}

if err := Wait(checkRecv); err != nil {
t.Fatal("Failed receiving message")
}
Expand Down

0 comments on commit dcbb65a

Please sign in to comment.