Skip to content

Commit

Permalink
fix logging tests
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyang-hu committed Sep 6, 2023
1 parent 2c7ed05 commit eb7d5d3
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 45 deletions.
9 changes: 9 additions & 0 deletions x/mongo/driver/topology/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -801,10 +801,19 @@ func (s *Server) check() (description.Server, error) {
if s.conn == nil || s.conn.closed() || s.checkWasCancelled() {
// Create a new connection and add it's handshake RTT as a sample.
err = s.setupHeartbeatConnection()
duration = time.Since(start)
if err == nil {
// Use the description from the connection handshake as the value for this check.
s.rttMonitor.addSample(s.conn.helloRTT)
descPtr = &s.conn.desc
if s.conn != nil {
s.publishServerHeartbeatSucceededEvent(s.conn.ID(), duration, s.conn.desc, false)
}
} else {
err = unwrapConnectionError(err)
if s.conn != nil {
s.publishServerHeartbeatFailedEvent(s.conn.ID(), duration, err, false)
}
}
} else {
// An existing connection is being used. Use the server description properties to execute the right heartbeat.
Expand Down
67 changes: 22 additions & 45 deletions x/mongo/driver/topology/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,47 +56,40 @@ type errorQueue struct {
mutex sync.Mutex
}

func (eq *errorQueue) head() (int, error) {
func (eq *errorQueue) head() error {
eq.mutex.Lock()
defer eq.mutex.Unlock()
if l := len(eq.errors); l > 0 {
return l, eq.errors[0]
if len(eq.errors) > 0 {
return eq.errors[0]
}
return 0, nil
return nil
}

func (eq *errorQueue) dequeue() {
func (eq *errorQueue) dequeue() bool {
eq.mutex.Lock()
defer eq.mutex.Unlock()
if len(eq.errors) > 0 {
eq.errors = eq.errors[1:]
return true
}
return false
}

type timeoutConn struct {
net.Conn
errors *errorQueue
ch chan int
}

func (c *timeoutConn) Read(b []byte) (int, error) {
var n int
l, err := c.errors.head()
defer func(l int) {
c.ch <- l
}(l)
n, err := 0, c.errors.head()
if err == nil {
n, err = c.Conn.Read(b)
}
return n, err
}

func (c *timeoutConn) Write(b []byte) (int, error) {
var n int
l, err := c.errors.head()
defer func(l int) {
c.ch <- l
}(l)
n, err := 0, c.errors.head()
if err == nil {
n, err = c.Conn.Write(b)
}
Expand All @@ -106,7 +99,6 @@ func (c *timeoutConn) Write(b []byte) (int, error) {
type timeoutDialer struct {
Dialer
errors *errorQueue
ch chan int
}

func (d *timeoutDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
Expand All @@ -129,7 +121,7 @@ func (d *timeoutDialer) DialContext(ctx context.Context, network, address string
}
c = tls.Client(c, config)
}
return &timeoutConn{c, d.errors, d.ch}, e
return &timeoutConn{c, d.errors}, e
}

// TestServerHeartbeatTimeout tests timeout retry for GODRIVER-2577.
Expand All @@ -145,19 +137,16 @@ func TestServerHeartbeatTimeout(t *testing.T) {
testCases := []struct {
desc string
ioErrors []error
len int
expectPoolCleared bool
}{
{
desc: "one single timeout should not clear the pool",
ioErrors: []error{nil, networkTimeoutError, nil, networkTimeoutError, nil},
len: 0,
expectPoolCleared: false,
},
{
desc: "continuous timeouts should clear the pool",
ioErrors: []error{nil, networkTimeoutError, networkTimeoutError},
len: 1,
ioErrors: []error{nil, networkTimeoutError, networkTimeoutError, nil},
expectPoolCleared: true,
},
}
Expand All @@ -166,9 +155,8 @@ func TestServerHeartbeatTimeout(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()

const heartbeatInterval = 200 * time.Millisecond

c := make(chan int)
var wg sync.WaitGroup
wg.Add(1)

errors := &errorQueue{errors: tc.ioErrors}
tpm := eventtest.NewTestPoolMonitor()
Expand All @@ -182,40 +170,29 @@ func TestServerHeartbeatTimeout(t *testing.T) {
return append(opts,
WithDialer(func(d Dialer) Dialer {
var dialer net.Dialer
return &timeoutDialer{&dialer, errors, c}
return &timeoutDialer{&dialer, errors}
}))
}),
WithServerMonitor(func(*event.ServerMonitor) *event.ServerMonitor {
return &event.ServerMonitor{
ServerHeartbeatSucceeded: func(e *event.ServerHeartbeatSucceededEvent) {
errors.dequeue()
if !errors.dequeue() {
wg.Done()
}
},
ServerHeartbeatFailed: func(e *event.ServerHeartbeatFailedEvent) {
errors.dequeue()
if !errors.dequeue() {
wg.Done()
}
},
}
}),
WithHeartbeatInterval(func(time.Duration) time.Duration {
return heartbeatInterval
return 200 * time.Millisecond
}),
)
require.NoError(t, server.Connect(nil))

timeout := time.After(50 * heartbeatInterval)
var l int
loop:
for {
select {
case l = <-c:
if l == 0 || tpm.IsPoolCleared() {
break loop
}
case <-timeout:
assert.Fail(t, "timeout")
break loop
}
}
assert.Equal(t, tc.len, l, "pool has been cleared unexpectedly")
wg.Wait()
assert.Equal(t, tc.expectPoolCleared, tpm.IsPoolCleared(), "expected pool cleared to be %v but was %v", tc.expectPoolCleared, tpm.IsPoolCleared())
})
}
Expand Down

0 comments on commit eb7d5d3

Please sign in to comment.