Skip to content

Commit

Permalink
wip: MQTT: receive QoS2 messages
Browse files Browse the repository at this point in the history
  • Loading branch information
levb committed Jul 31, 2023
1 parent b22cdf1 commit 1714e4a
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 71 deletions.
125 changes: 107 additions & 18 deletions server/mqtt.go
Expand Up @@ -277,6 +277,12 @@ type mqttSession struct {
tmaxack int
clean bool
domainTk string

// Spec [MQTT-4.3.3-2]. When we receive a QOS2 message, we acknowledge it
// with a PUBREC and ensure its PI is in this map. If it was not in the map
// before, delivery is initiated, just like for QOS1. When we receive a
// PUBREL, we remove the PI from the map, send a PUBCOMP to the sender.
pendingQOS2Received map[uint16]struct{}
}

type mqttPersistedSession struct {
Expand Down Expand Up @@ -678,23 +684,28 @@ func (c *client) mqttParse(buf []byte) error {
}
}
if err == nil {
err = s.mqttProcessPub(c, pp)
}
if err == nil && pp.pi > 0 {
c.mqttEnqueuePubAck(pp.pi)
if trace {
c.traceOutOp("PUBACK", []byte(fmt.Sprintf("pi=%v", pp.pi)))
}
err = s.mqttProcessPub(c, pp, trace)
}

case mqttPacketPubAck:
var pi uint16
pi, err = mqttParsePubAck(r, pl)
pi, err = mqttParsePIPacket(r, pl)
if trace {
c.traceInOp("PUBACK", errOrTrace(err, fmt.Sprintf("pi=%v", pi)))
}
if err == nil {
c.mqttProcessPubAck(pi)
}

case mqttPacketPubRel:
var pi uint16
pi, err = mqttParsePIPacket(r, pl)
if trace {
c.traceInOp("PUBREL", errOrTrace(err, fmt.Sprintf("pi=%v", pi)))
}
if err == nil {
c.mqttProcessPubRel(pi, trace)
}
case mqttPacketSub:
var pi uint16 // packet identifier
var filters []*mqttFilter
Expand Down Expand Up @@ -784,8 +795,6 @@ func (c *client) mqttParse(buf []byte) error {
s.mqttHandleClosedClient(c)
c.closeConnection(ClientClosed)
return nil
case mqttPacketPubRec, mqttPacketPubRel, mqttPacketPubComp:
err = fmt.Errorf("protocol %d not supported", pt>>4)
default:
err = fmt.Errorf("received unknown packet type %d", pt>>4)
}
Expand Down Expand Up @@ -2454,6 +2463,7 @@ func (sess *mqttSession) clear() error {
}
}
sess.subs, sess.pending, sess.cpending, sess.seq, sess.tmaxack = nil, nil, nil, 0, 0
sess.pendingQOS2Received = nil
for _, dur := range durs {
if _, err := sess.jsa.deleteConsumer(mqttStreamName, dur); isErrorOtherThan(err, JSConsumerNotFoundErr) {
sess.mu.Unlock()
Expand Down Expand Up @@ -3050,7 +3060,7 @@ func (s *Server) mqttHandleWill(c *client) {
pp.flags |= mqttPubFlagRetain
}
c.mu.Unlock()
s.mqttProcessPub(c, pp)
s.mqttForwardPub(c, pp)
c.flushClients(0)
}

Expand All @@ -3062,8 +3072,8 @@ func (s *Server) mqttHandleWill(c *client) {

func (c *client) mqttParsePub(r *mqttReader, pl int, pp *mqttPublish, hasMappings bool) error {
qos := mqttGetQoS(pp.flags)
if qos > 1 {
return fmt.Errorf("publish QoS=%v not supported", qos)
if qos > 2 {
return fmt.Errorf("QoS=%v is invalid in MQTT", qos)
}
// Keep track of where we are when starting to read the variable header
start := r.pos
Expand Down Expand Up @@ -3144,7 +3154,52 @@ func mqttPubTrace(pp *mqttPublish) string {
//
// Runs from the client's readLoop.
// No lock held on entry.
func (s *Server) mqttProcessPub(c *client, pp *mqttPublish) error {
func (s *Server) mqttProcessPub(c *client, pp *mqttPublish, trace bool) error {
qos := mqttGetQoS(pp.flags)

// Forward the PUBLISH packet on, unless it is a QoS2 message that is not complete yet.
drop := false
if qos == 2 {
c.mqtt.sess.mu.Lock()
if len(c.mqtt.sess.pendingQOS2Received) > 0 {
_, drop = c.mqtt.sess.pendingQOS2Received[pp.pi]
}
c.mqtt.sess.mu.Unlock()
}

if !drop {
if err := s.mqttForwardPub(c, pp); err != nil {
return err
}
}

switch mqttGetQoS(pp.flags) {
case 1:
c.mqttEnqueuePubResponse(mqttPacketPubAck, pp.pi)
if trace {
c.traceOutOp("PUBACK", []byte(fmt.Sprintf("pi=%v", pp.pi)))
}

case 2:
c.mqttEnqueuePubResponse(mqttPacketPubRec, pp.pi)
if trace {
c.traceOutOp("PUBREC", []byte(fmt.Sprintf("pi=%v", pp.pi)))
}

if !drop {
c.mqtt.sess.mu.Lock()
if c.mqtt.sess.pendingQOS2Received == nil {
c.mqtt.sess.pendingQOS2Received = make(map[uint16]struct{})
}
c.mqtt.sess.pendingQOS2Received[pp.pi] = struct{}{}
c.mqtt.sess.mu.Unlock()
}
}

return nil
}

func (s *Server) mqttForwardPub(c *client, pp *mqttPublish) error {
c.pa.subject, c.pa.mapped, c.pa.hdr, c.pa.size, c.pa.reply = pp.subject, pp.mapped, -1, pp.sz, nil

bb := bytes.Buffer{}
Expand Down Expand Up @@ -3395,16 +3450,16 @@ func mqttWritePublish(w *mqttWriter, qos byte, dup, retain bool, subject string,
w.Write([]byte(payload))
}

func (c *client) mqttEnqueuePubAck(pi uint16) {
proto := [4]byte{mqttPacketPubAck, 0x2, 0, 0}
func (c *client) mqttEnqueuePubResponse(packetType byte, pi uint16) {
proto := [4]byte{packetType, 0x2, 0, 0}
proto[2] = byte(pi >> 8)
proto[3] = byte(pi)
c.mu.Lock()
c.enqueueProto(proto[:4])
c.mu.Unlock()
}

func mqttParsePubAck(r *mqttReader, pl int) (uint16, error) {
func mqttParsePIPacket(r *mqttReader, pl int) (uint16, error) {
pi, err := r.readUint16("packet identifier")
if err != nil {
return 0, err
Expand All @@ -3415,7 +3470,7 @@ func mqttParsePubAck(r *mqttReader, pl int) (uint16, error) {
return pi, nil
}

// Process a PUBACK packet.
// Process a PUBACK packet (QoS1, acting as Sender).
// Updates the session's pending list and sends an ACK to JS.
//
// Runs from the client's readLoop.
Expand Down Expand Up @@ -3451,6 +3506,40 @@ func (c *client) mqttProcessPubAck(pi uint16) {
sess.mu.Unlock()
}

// Process a PUBREL packet (QoS2, acting as Receiver).
// Updates the session's pending list and sends an ACK to JS.
//
// Runs from the client's readLoop.
// No lock held on entry.
func (c *client) mqttProcessPubRel(pi uint16, trace bool) {
sess := c.mqtt.sess
if sess == nil {
return
}

isPending := false
sess.mu.Lock()
if sess.c != c {
sess.mu.Unlock()
return
}

if len(sess.pendingQOS2Received) > 0 {
_, isPending = sess.pendingQOS2Received[pi]
}
if isPending {
delete(sess.pendingQOS2Received, pi)
}
sess.mu.Unlock()

if isPending {
c.mqttEnqueuePubResponse(mqttPacketPubComp, pi)
if trace {
c.traceOutOp("PUBCOMP", []byte(fmt.Sprintf("pi=%v", pi)))
}
}
}

// Return the QoS from the given PUBLISH protocol's flags
func mqttGetQoS(flags byte) byte {
return flags & mqttPubFlagQoS >> 1
Expand Down

0 comments on commit 1714e4a

Please sign in to comment.