Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle connection loss during call to Disconnect() (including test) #502

Merged
merged 1 commit into from Apr 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 14 additions & 4 deletions client.go
Expand Up @@ -439,12 +439,22 @@ func (c *client) Disconnect(quiesce uint) {

dm := packets.NewControlPacket(packets.Disconnect).(*packets.DisconnectPacket)
dt := newToken(packets.Disconnect)
c.oboundP <- &PacketAndToken{p: dm, t: dt}
disconnectSent := false
select {
case c.oboundP <- &PacketAndToken{p: dm, t: dt}:
disconnectSent = true
case <-c.commsStopped:
WARN.Println("Disconnect packet could not be sent because comms stopped")
case <-time.After(time.Duration(quiesce) * time.Millisecond):
WARN.Println("Disconnect packet not sent due to timeout")
}

// wait for work to finish, or quiesce time consumed
DEBUG.Println(CLI, "calling WaitTimeout")
dt.WaitTimeout(time.Duration(quiesce) * time.Millisecond)
DEBUG.Println(CLI, "WaitTimeout done")
if disconnectSent {
DEBUG.Println(CLI, "calling WaitTimeout")
dt.WaitTimeout(time.Duration(quiesce) * time.Millisecond)
DEBUG.Println(CLI, "WaitTimeout done")
}
} else {
WARN.Println(CLI, "Disconnect() called but not connected (disconnected/reconnecting)")
c.setConnected(disconnected)
Expand Down
42 changes: 41 additions & 1 deletion fvt_client_test.go
Expand Up @@ -31,7 +31,19 @@ func Test_Start(t *testing.T) {
t.Fatalf("Error on Client.Connect(): %v", token.Error())
}

c.Disconnect(250)
// Disconnect should return within 250ms and calling a second time should not block
disconnectC := make(chan struct{}, 1)
go func() {
c.Disconnect(250)
c.Disconnect(5)
close(disconnectC)
}()

select {
case <-time.After(time.Millisecond * 300):
t.Errorf("disconnect did not finnish within 300ms")
case <-disconnectC:
}
}

/* uncomment this if you have connection policy disallowing FailClientID
Expand Down Expand Up @@ -90,6 +102,34 @@ func Test_Start(t *testing.T) {
}
*/

// Disconnect should not block under any circumstance
// This is triggered by issue #501; there is a very slight chance that Disconnect could get through the
// `status == connected` check and then the connection drops...
func Test_Disconnect(t *testing.T) {
ops := NewClientOptions().SetClientID("Disconnect").AddBroker(FVTTCP)
c := NewClient(ops)

if token := c.Connect(); token.Wait() && token.Error() != nil {
t.Fatalf("Error on Client.Connect(): %v", token.Error())
}

// Attempt to disconnect twice simultaneously and ensure this does not block
disconnectC := make(chan struct{}, 1)
go func() {
c.Disconnect(250)
cli := c.(*client)
cli.status = connected
c.Disconnect(250)
close(disconnectC)
}()

select {
case <-time.After(time.Millisecond * 300):
t.Errorf("disconnect did not finnish within 300ms")
case <-disconnectC:
}
}

func Test_Publish_1(t *testing.T) {
ops := NewClientOptions()
ops.AddBroker(FVTTCP)
Expand Down
10 changes: 5 additions & 5 deletions unit_client_test.go
Expand Up @@ -18,15 +18,15 @@ import (
"log"
"net/http"
_ "net/http/pprof"
"os"
"testing"
)

func init() {
DEBUG = log.New(os.Stderr, "DEBUG ", log.Ltime)
WARN = log.New(os.Stderr, "WARNING ", log.Ltime)
CRITICAL = log.New(os.Stderr, "CRITICAL ", log.Ltime)
ERROR = log.New(os.Stderr, "ERROR ", log.Ltime)
// Logging is off by default as this makes things simpler when you just want to confirm that tests pass
// DEBUG = log.New(os.Stderr, "DEBUG ", log.Ltime)
// WARN = log.New(os.Stderr, "WARNING ", log.Ltime)
// CRITICAL = log.New(os.Stderr, "CRITICAL ", log.Ltime)
// ERROR = log.New(os.Stderr, "ERROR ", log.Ltime)

go func() {
log.Println(http.ListenAndServe("localhost:6060", nil))
Expand Down
3 changes: 1 addition & 2 deletions unit_messageids_test.go
Expand Up @@ -16,7 +16,6 @@ package mqtt

import (
"fmt"
"log"
"testing"
)

Expand Down Expand Up @@ -63,7 +62,7 @@ func Test_noFreeID(t *testing.T) {
mids := &messageIds{index: make(map[uint16]tokenCompletor)}

for i := midMin; i != 0; i++ {
log.Println(i)
// Uncomment to see all message IDS log.Println(i)
mids.index[i] = &d
}

Expand Down