Skip to content

Commit

Permalink
Merge pull request #502 from ChIoT-Tech/master
Browse files Browse the repository at this point in the history
Handle connection loss during call to Disconnect() (including test)
  • Loading branch information
MattBrittan committed Apr 29, 2021
2 parents 4d373b3 + 222d3c1 commit 8e87e5f
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 12 deletions.
18 changes: 14 additions & 4 deletions client.go
Expand Up @@ -439,12 +439,22 @@ func (c *client) Disconnect(quiesce uint) {

dm := packets.NewControlPacket(packets.Disconnect).(*packets.DisconnectPacket)
dt := newToken(packets.Disconnect)
c.oboundP <- &PacketAndToken{p: dm, t: dt}
disconnectSent := false
select {
case c.oboundP <- &PacketAndToken{p: dm, t: dt}:
disconnectSent = true
case <-c.commsStopped:
WARN.Println("Disconnect packet could not be sent because comms stopped")
case <-time.After(time.Duration(quiesce) * time.Millisecond):
WARN.Println("Disconnect packet not sent due to timeout")
}

// wait for work to finish, or quiesce time consumed
DEBUG.Println(CLI, "calling WaitTimeout")
dt.WaitTimeout(time.Duration(quiesce) * time.Millisecond)
DEBUG.Println(CLI, "WaitTimeout done")
if disconnectSent {
DEBUG.Println(CLI, "calling WaitTimeout")
dt.WaitTimeout(time.Duration(quiesce) * time.Millisecond)
DEBUG.Println(CLI, "WaitTimeout done")
}
} else {
WARN.Println(CLI, "Disconnect() called but not connected (disconnected/reconnecting)")
c.setConnected(disconnected)
Expand Down
42 changes: 41 additions & 1 deletion fvt_client_test.go
Expand Up @@ -31,7 +31,19 @@ func Test_Start(t *testing.T) {
t.Fatalf("Error on Client.Connect(): %v", token.Error())
}

c.Disconnect(250)
// Disconnect should return within 250ms and calling a second time should not block
disconnectC := make(chan struct{}, 1)
go func() {
c.Disconnect(250)
c.Disconnect(5)
close(disconnectC)
}()

select {
case <-time.After(time.Millisecond * 300):
t.Errorf("disconnect did not finnish within 300ms")
case <-disconnectC:
}
}

/* uncomment this if you have connection policy disallowing FailClientID
Expand Down Expand Up @@ -90,6 +102,34 @@ func Test_Start(t *testing.T) {
}
*/

// Disconnect should not block under any circumstance
// This is triggered by issue #501; there is a very slight chance that Disconnect could get through the
// `status == connected` check and then the connection drops...
func Test_Disconnect(t *testing.T) {
ops := NewClientOptions().SetClientID("Disconnect").AddBroker(FVTTCP)
c := NewClient(ops)

if token := c.Connect(); token.Wait() && token.Error() != nil {
t.Fatalf("Error on Client.Connect(): %v", token.Error())
}

// Attempt to disconnect twice simultaneously and ensure this does not block
disconnectC := make(chan struct{}, 1)
go func() {
c.Disconnect(250)
cli := c.(*client)
cli.status = connected
c.Disconnect(250)
close(disconnectC)
}()

select {
case <-time.After(time.Millisecond * 300):
t.Errorf("disconnect did not finnish within 300ms")
case <-disconnectC:
}
}

func Test_Publish_1(t *testing.T) {
ops := NewClientOptions()
ops.AddBroker(FVTTCP)
Expand Down
10 changes: 5 additions & 5 deletions unit_client_test.go
Expand Up @@ -18,15 +18,15 @@ import (
"log"
"net/http"
_ "net/http/pprof"
"os"
"testing"
)

func init() {
DEBUG = log.New(os.Stderr, "DEBUG ", log.Ltime)
WARN = log.New(os.Stderr, "WARNING ", log.Ltime)
CRITICAL = log.New(os.Stderr, "CRITICAL ", log.Ltime)
ERROR = log.New(os.Stderr, "ERROR ", log.Ltime)
// Logging is off by default as this makes things simpler when you just want to confirm that tests pass
// DEBUG = log.New(os.Stderr, "DEBUG ", log.Ltime)
// WARN = log.New(os.Stderr, "WARNING ", log.Ltime)
// CRITICAL = log.New(os.Stderr, "CRITICAL ", log.Ltime)
// ERROR = log.New(os.Stderr, "ERROR ", log.Ltime)

go func() {
log.Println(http.ListenAndServe("localhost:6060", nil))
Expand Down
3 changes: 1 addition & 2 deletions unit_messageids_test.go
Expand Up @@ -16,7 +16,6 @@ package mqtt

import (
"fmt"
"log"
"testing"
)

Expand Down Expand Up @@ -63,7 +62,7 @@ func Test_noFreeID(t *testing.T) {
mids := &messageIds{index: make(map[uint16]tokenCompletor)}

for i := midMin; i != 0; i++ {
log.Println(i)
// Uncomment to see all message IDS log.Println(i)
mids.index[i] = &d
}

Expand Down

0 comments on commit 8e87e5f

Please sign in to comment.