Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Propagate MQTT connect error (fixes #356) #452

Merged
merged 2 commits into from Sep 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion client.go
Expand Up @@ -365,7 +365,7 @@ func (c *client) attemptConnection() (net.Conn, byte, bool, error) {
DEBUG.Println(CLI, "socket connected to broker")

// Now we send the perform the MQTT connection handshake
rc, sessionPresent = ConnectMQTT(conn, cm, protocolVersion)
rc, sessionPresent, err = connectMQTT(conn, cm, protocolVersion)
if rc == packets.Accepted {
break // successfully connected
}
Expand Down
146 changes: 81 additions & 65 deletions net.go
Expand Up @@ -15,6 +15,8 @@
package mqtt

import (
"errors"
"io"
"net"
"reflect"
"strings"
Expand All @@ -26,11 +28,18 @@ import (

const closedNetConnErrorText = "use of closed network connection" // error string for closed conn (https://golang.org/src/net/error_test.go)

// ConnectMQTT takes a connected net.Conn and performs the initial MQTT handshake. Paramaters are:
// ConnectMQTT takes a connected net.Conn and performs the initial MQTT handshake. Parameters are:
// conn - Connected net.Conn
// cm - Connect Packet with everything other than the protocolname/version populated (historical reasons)
// cm - Connect Packet with everything other than the protocol name/version populated (historical reasons)
// protocolVersion - The protocol version to attempt to connect with
//
// Note that, for backward compatibility, ConnectMQTT() suppresses the actual connection error (compare to connectMQTT()).
func ConnectMQTT(conn net.Conn, cm *packets.ConnectPacket, protocolVersion uint) (byte, bool) {
rc, sessionPresent, _ := connectMQTT(conn, cm, protocolVersion)
return rc, sessionPresent
}

func connectMQTT(conn io.ReadWriter, cm *packets.ConnectPacket, protocolVersion uint) (byte, bool, error) {
switch protocolVersion {
case 3:
DEBUG.Println(CLI, "Using MQTT 3.1 protocol")
Expand All @@ -49,43 +58,46 @@ func ConnectMQTT(conn net.Conn, cm *packets.ConnectPacket, protocolVersion uint)
cm.ProtocolName = "MQTT"
cm.ProtocolVersion = 4
}

if err := cm.Write(conn); err != nil {
ERROR.Println(CLI, err)
return packets.ErrNetworkError, false, err
}

rc, sessionPresent := verifyCONNACK(conn)
return rc, sessionPresent
rc, sessionPresent, err := verifyCONNACK(conn)
return rc, sessionPresent, err
}

// This function is only used for receiving a connack
// when the connection is first started.
// This prevents receiving incoming data while resume
// is in progress if clean session is false.
func verifyCONNACK(conn net.Conn) (byte, bool) {
func verifyCONNACK(conn io.Reader) (byte, bool, error) {
DEBUG.Println(NET, "connect started")

ca, err := packets.ReadPacket(conn)
if err != nil {
ERROR.Println(NET, "connect got error", err)
return packets.ErrNetworkError, false
return packets.ErrNetworkError, false, err
}

if ca == nil {
ERROR.Println(NET, "received nil packet")
return packets.ErrNetworkError, false
return packets.ErrNetworkError, false, errors.New("nil CONNACK packet")
}

msg, ok := ca.(*packets.ConnackPacket)
if !ok {
ERROR.Println(NET, "received msg that was not CONNACK")
return packets.ErrNetworkError, false
return packets.ErrNetworkError, false, errors.New("non-CONNACK first packet received")
}

DEBUG.Println(NET, "received connack")
return msg.ReturnCode, msg.SessionPresent
return msg.ReturnCode, msg.SessionPresent, nil
}

// inbound encapuslates the output from startIncoming.
// err - If != nil then an error has occured
// err - If != nil then an error has occurred
// cp - A control packet received over the network link
type inbound struct {
err error
Expand All @@ -95,12 +107,13 @@ type inbound struct {
// startIncoming initiates a goroutine that reads incoming messages off the wire and sends them to the channel (returned).
// If there are any issues with the network connection then the returned cahnnel will be closed and the goroutine will exit
// (so closing the connection will terminate the goroutine)
func startIncoming(conn net.Conn) <-chan inbound {
func startIncoming(conn io.Reader) <-chan inbound {
var err error
var cp packets.ControlPacket
ibound := make(chan inbound)

DEBUG.Println(NET, "incoming started")

go func() {
for {
if cp, err = packets.ReadPacket(conn); err != nil {
Expand All @@ -118,35 +131,36 @@ func startIncoming(conn net.Conn) <-chan inbound {
ibound <- inbound{cp: cp}
}
}()

return ibound
}

// incommingComms encapuslates the possible output of the incommingComms routine. If err != nil then an error has occured and
// incomingComms encapuslates the possible output of the incomingComms routine. If err != nil then an error has occurred and
// the routine will have terminated; otherwise one of the other members should be non-nil
type incommingComms struct {
err error // If non-nil then there has been an error (ignore everything else)
outbound *PacketAndToken // Packet (with token) than needs to be sent out (e.g. an acknowledgement)
incommingPub *packets.PublishPacket // A new publish has been received; this will need to be passed on to our user
type incomingComms struct {
err error // If non-nil then there has been an error (ignore everything else)
outbound *PacketAndToken // Packet (with token) than needs to be sent out (e.g. an acknowledgement)
incomingPub *packets.PublishPacket // A new publish has been received; this will need to be passed on to our user
}

// startIncommingComms initiates incomming communications; this includes starting a goroutine to process incomming
// startIncomingComms initiates incoming communications; this includes starting a goroutine to process incoming
// messages.
// Accepts a channel of inbound messages from the store (persistanced messages); note this must be closed as soon as the
// everything in the store has been sent.
// Returns a channel that will be passed any received packets; this will be closed on a network error (and inboundFromStore closed)
func startIncommingComms(conn net.Conn,
func startIncomingComms(conn io.Reader,
c commsFns,
inboundFromStore <-chan packets.ControlPacket,
) <-chan incommingComms {
) <-chan incomingComms {
ibound := startIncoming(conn) // Start goroutine that reads from network connection
output := make(chan incommingComms)
output := make(chan incomingComms)

DEBUG.Println(NET, "startIncommingComms started")
DEBUG.Println(NET, "startIncomingComms started")
go func() {
for {
if inboundFromStore == nil && ibound == nil {
close(output)
DEBUG.Println(NET, "startIncommingComms goroutine complete")
DEBUG.Println(NET, "startIncomingComms goroutine complete")
return // As soon as ibound is closed we can exit (should have already processed an error)
}
DEBUG.Println(NET, "logic waiting for msg on ibound")
Expand All @@ -156,22 +170,22 @@ func startIncommingComms(conn net.Conn,
select {
case msg, ok = <-inboundFromStore:
if !ok {
DEBUG.Println(NET, "startIncommingComms: inboundFromStore complete")
DEBUG.Println(NET, "startIncomingComms: inboundFromStore complete")
inboundFromStore = nil // should happen quickly as this is only for persisted messages
continue
}
DEBUG.Println(NET, "startIncommingComms: got msg from store")
DEBUG.Println(NET, "startIncomingComms: got msg from store")
case ibMsg, ok := <-ibound:
if !ok {
DEBUG.Println(NET, "startIncommingComms: ibound complete")
DEBUG.Println(NET, "startIncomingComms: ibound complete")
ibound = nil
continue
}
DEBUG.Println(NET, "startIncommingComms: got msg on ibound")
DEBUG.Println(NET, "startIncomingComms: got msg on ibound")
// If the inbound comms routine encounters any issues it will send us an error.
if ibMsg.err != nil {
output <- incommingComms{err: ibMsg.err}
continue // Usually the channel will be closed immediatly after sending an error but safer that we do not assume this
output <- incomingComms{err: ibMsg.err}
continue // Usually the channel will be closed immediately after sending an error but safer that we do not assume this
}
msg = ibMsg.cp

Expand All @@ -181,44 +195,45 @@ func startIncommingComms(conn net.Conn,

switch m := msg.(type) {
case *packets.PingrespPacket:
DEBUG.Println(NET, "startIncommingComms: received pingresp")
DEBUG.Println(NET, "startIncomingComms: received pingresp")
c.pingRespReceived()
case *packets.SubackPacket:
DEBUG.Println(NET, "startIncommingComms: received suback, id:", m.MessageID)
DEBUG.Println(NET, "startIncomingComms: received suback, id:", m.MessageID)
token := c.getToken(m.MessageID)
switch t := token.(type) {
case *SubscribeToken:
DEBUG.Println(NET, "startIncommingComms: granted qoss", m.ReturnCodes)

if t, ok := token.(*SubscribeToken); ok {
DEBUG.Println(NET, "startIncomingComms: granted qoss", m.ReturnCodes)
for i, qos := range m.ReturnCodes {
t.subResult[t.subs[i]] = qos
}
}

token.flowComplete()
c.freeID(m.MessageID)
case *packets.UnsubackPacket:
DEBUG.Println(NET, "startIncommingComms: received unsuback, id:", m.MessageID)
DEBUG.Println(NET, "startIncomingComms: received unsuback, id:", m.MessageID)
c.getToken(m.MessageID).flowComplete()
c.freeID(m.MessageID)
case *packets.PublishPacket:
DEBUG.Println(NET, "startIncommingComms: received publish, msgId:", m.MessageID)
output <- incommingComms{incommingPub: m}
DEBUG.Println(NET, "startIncomingComms: received publish, msgId:", m.MessageID)
output <- incomingComms{incomingPub: m}
case *packets.PubackPacket:
DEBUG.Println(NET, "startIncommingComms: received puback, id:", m.MessageID)
DEBUG.Println(NET, "startIncomingComms: received puback, id:", m.MessageID)
c.getToken(m.MessageID).flowComplete()
c.freeID(m.MessageID)
case *packets.PubrecPacket:
DEBUG.Println(NET, "startIncommingComms: received pubrec, id:", m.MessageID)
DEBUG.Println(NET, "startIncomingComms: received pubrec, id:", m.MessageID)
prel := packets.NewControlPacket(packets.Pubrel).(*packets.PubrelPacket)
prel.MessageID = m.MessageID
output <- incommingComms{outbound: &PacketAndToken{p: prel, t: nil}}
output <- incomingComms{outbound: &PacketAndToken{p: prel, t: nil}}
case *packets.PubrelPacket:
DEBUG.Println(NET, "startIncommingComms: received pubrel, id:", m.MessageID)
DEBUG.Println(NET, "startIncomingComms: received pubrel, id:", m.MessageID)
pc := packets.NewControlPacket(packets.Pubcomp).(*packets.PubcompPacket)
pc.MessageID = m.MessageID
c.persistOutbound(pc)
output <- incommingComms{outbound: &PacketAndToken{p: pc, t: nil}}
output <- incomingComms{outbound: &PacketAndToken{p: pc, t: nil}}
case *packets.PubcompPacket:
DEBUG.Println(NET, "startIncommingComms: received pubcomp, id:", m.MessageID)
DEBUG.Println(NET, "startIncomingComms: received pubcomp, id:", m.MessageID)
c.getToken(m.MessageID).flowComplete()
c.freeID(m.MessageID)
}
Expand All @@ -229,14 +244,14 @@ func startIncommingComms(conn net.Conn,

// startOutgoingComms initiates a go routint to transmit outgoing packets.
// Pass in an open network connection and channels for outbound messages (including those triggered
// directly from incomming comms).
// directly from incoming comms).
// Returns a channel that will receive details of any errors (closed when the goroutine exits)
// This function wil only terminate when all input channels are closed
func startOutgoingComms(conn net.Conn,
c commsFns,
oboundp <-chan *PacketAndToken,
obound <-chan *PacketAndToken,
oboundFromIncomming <-chan *PacketAndToken,
oboundFromIncoming <-chan *PacketAndToken,
) <-chan error {
errChan := make(chan error)
DEBUG.Println(NET, "outgoing started")
Expand All @@ -248,7 +263,7 @@ func startOutgoingComms(conn net.Conn,
// This goroutine will only exits when all of the input channels we receive on have been closed. This approach is taken to avoid any
// deadlocks (if the connection goes down there are limited options as to what we can do with anything waiting on us and
// throwing away the packets seems the best option)
if oboundp == nil && obound == nil && oboundFromIncomming == nil {
if oboundp == nil && obound == nil && oboundFromIncoming == nil {
DEBUG.Println(NET, "outgoing comms stopping")
close(errChan)
return
Expand Down Expand Up @@ -306,22 +321,22 @@ func startOutgoingComms(conn net.Conn,
errChan <- err
continue
}
switch msg.p.(type) {
case *packets.DisconnectPacket:

if _, ok := msg.p.(*packets.DisconnectPacket); ok {
msg.t.(*DisconnectToken).flowComplete()
DEBUG.Println(NET, "outbound wrote disconnect, closing connection")
// As per the MQTT spec "After sending a DISCONNECT Packet the Client MUST close the Network Connection"
// Closing the connection will cause the goroutines to end in sequence (starting with incomming comms)
// Closing the connection will cause the goroutines to end in sequence (starting with incoming comms)
conn.Close()
}
case msg, ok := <-oboundFromIncomming: // message triggered by an inbound message (PubrecPacket or PubrelPacket)
case msg, ok := <-oboundFromIncoming: // message triggered by an inbound message (PubrecPacket or PubrelPacket)
if !ok {
oboundFromIncomming = nil
oboundFromIncoming = nil
continue
}
DEBUG.Println(NET, "obound from incomming msg to write, type", reflect.TypeOf(msg.p), " ID ", msg.p.Details().MessageID)
DEBUG.Println(NET, "obound from incoming msg to write, type", reflect.TypeOf(msg.p), " ID ", msg.p.Details().MessageID)
if err := msg.p.Write(conn); err != nil {
ERROR.Println(NET, "outgoing oboundFromIncomming reporting error", err)
ERROR.Println(NET, "outgoing oboundFromIncoming reporting error", err)
if msg.t != nil {
msg.t.setError(err)
}
Expand All @@ -338,7 +353,7 @@ func startOutgoingComms(conn net.Conn,
// commsFns provide access to the client state (messageids, requesting disconnection and updating timing)
type commsFns interface {
getToken(id uint16) tokenCompletor // Retrieve the token for the specified messageid (if none then a dummy token must be returned)
freeID(id uint16) // Release the specified messageid (clearing out of any persistant store)
freeID(id uint16) // Release the specified messageid (clearing out of any persistent store)
UpdateLastReceived() // Must be called whenever a packet is received
UpdateLastSent() // Must be called whenever a packet is successfully sent
getWriteTimeOut() time.Duration // Return the writetimeout (or 0 if none)
Expand All @@ -348,28 +363,29 @@ type commsFns interface {
}

// startComms initiates goroutines that handles communications over the network connection
// Messages will be stored (via commsFns) and deleted from the store as neccessary
// Messages will be stored (via commsFns) and deleted from the store as necessary
// It returns two channels:
// packets.PublishPacket - Will receive publish packets received over the network. Closed when incomming comms routines exit (on shutdown or if network link closed)
// packets.PublishPacket - Will receive publish packets received over the network.
// Closed when incoming comms routines exit (on shutdown or if network link closed)
// error - Any errors will be sent on this channel. The channel is closed when all comms routines have shut down
//
// Note: The comms routines monitoring oboundp and obound will not shutdown until those channels are both closed. Any messages received between the
// connection being closed and those channels being closed will generate errors (and nothing will be sent). That way the chance of a deadlock is
// minimised.
func startComms(conn net.Conn, // Network connection (must be active)
c commsFns, // getters and setters to enable us to cleanly interact with client
inboundFromStore <-chan packets.ControlPacket, // Inbound packets from the persistance store (should be closed relatively soon after startup)
inboundFromStore <-chan packets.ControlPacket, // Inbound packets from the persistence store (should be closed relatively soon after startup)
oboundp <-chan *PacketAndToken,
obound <-chan *PacketAndToken) (
<-chan *packets.PublishPacket, // Publishpackages received over the network
<-chan error, // Any errors (should generally trigger a disconnect)
) {
// Start inbound comms handler; this needs to be able to transmit messages so we start a go routine to add these to the priority outbound channel
ibound := startIncommingComms(conn, c, inboundFromStore)
outboundFromIncomming := make(chan *PacketAndToken) // Will accept outgoing messages triggered by startIncommingComms (e.g. acknowledgements)
ibound := startIncomingComms(conn, c, inboundFromStore)
outboundFromIncoming := make(chan *PacketAndToken) // Will accept outgoing messages triggered by startIncomingComms (e.g. acknowledgements)

// Start the outgoing handler. It is important to note that output from startIncommingComms is fed into startOutgoingComms (for ACK's)
oboundErr := startOutgoingComms(conn, c, oboundp, obound, outboundFromIncomming)
// Start the outgoing handler. It is important to note that output from startIncomingComms is fed into startOutgoingComms (for ACK's)
oboundErr := startOutgoingComms(conn, c, oboundp, obound, outboundFromIncoming)
DEBUG.Println(NET, "startComms started")

// Run up go routines to handle the output from the above comms functions - these are handled in seperate
Expand All @@ -388,17 +404,17 @@ func startComms(conn net.Conn, // Network connection (must be active)
continue
}
if ic.outbound != nil {
outboundFromIncomming <- ic.outbound
outboundFromIncoming <- ic.outbound
continue
}
if ic.incommingPub != nil {
outPublish <- ic.incommingPub
if ic.incomingPub != nil {
outPublish <- ic.incomingPub
continue
}
ERROR.Println(STR, "startComms received empty incommingComms msg")
ERROR.Println(STR, "startComms received empty incomingComms msg")
}
// Close channels that will not be written to again (allowing other routines to exit)
close(outboundFromIncomming)
close(outboundFromIncoming)
close(outPublish)
wg.Done()
}()
Expand Down