Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't hold entire MQTT retained messages in memory #4228

Merged
merged 5 commits into from Jun 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
122 changes: 79 additions & 43 deletions server/mqtt.go
Expand Up @@ -23,6 +23,7 @@ import (
"io"
"net"
"net/http"
"sort"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -221,13 +222,13 @@ type mqttSessionManager struct {

type mqttAccountSessionManager struct {
mu sync.RWMutex
sessions map[string]*mqttSession // key is MQTT client ID
sessByHash map[string]*mqttSession // key is MQTT client ID hash
sessLocked map[string]struct{} // key is MQTT client ID and indicate that a session can not be taken by a new client at this time
flappers map[string]int64 // When connection connects with client ID already in use
flapTimer *time.Timer // Timer to perform some cleanup of the flappers map
sl *Sublist // sublist allowing to find retained messages for given subscription
retmsgs map[string]*mqttRetainedMsg // retained messages
sessions map[string]*mqttSession // key is MQTT client ID
sessByHash map[string]*mqttSession // key is MQTT client ID hash
sessLocked map[string]struct{} // key is MQTT client ID and indicate that a session can not be taken by a new client at this time
flappers map[string]int64 // When connection connects with client ID already in use
flapTimer *time.Timer // Timer to perform some cleanup of the flappers map
sl *Sublist // sublist allowing to find retained messages for given subscription
retmsgs map[string]*mqttRetainedMsgRef // retained messages
jsa mqttJSA
rrmLastSeq uint64 // Restore retained messages expected last sequence
rrmDoneCh chan struct{} // To notify the caller that all retained messages have been loaded
Expand Down Expand Up @@ -293,8 +294,9 @@ type mqttRetainedMsg struct {
Msg []byte `json:"msg,omitempty"`
Flags byte `json:"flags,omitempty"`
Source string `json:"source,omitempty"`
}

// non exported
type mqttRetainedMsgRef struct {
sseq uint64
floor uint64
sub *subscription
Expand Down Expand Up @@ -1604,8 +1606,9 @@ func (as *mqttAccountSessionManager) processRetainedMsg(_ *subscription, c *clie
seq, _, _ := ackReplyInfo(reply)

// Handle this retained message
rm.sseq = seq
as.handleRetainedMsg(rm.Subject, rm)
rf := &mqttRetainedMsgRef{}
rf.sseq = seq
as.handleRetainedMsg(rm.Subject, rf)

// If we were recovering (lastSeq > 0), then check if we are done.
if as.rrmLastSeq > 0 && seq >= as.rrmLastSeq {
Expand Down Expand Up @@ -1873,27 +1876,21 @@ func (as *mqttAccountSessionManager) sendJSAPIrequests(s *Server, c *client, acc
// or 0 if the record was added instead of updated.
//
// Lock not held on entry.
func (as *mqttAccountSessionManager) handleRetainedMsg(key string, rm *mqttRetainedMsg) uint64 {
func (as *mqttAccountSessionManager) handleRetainedMsg(key string, rm *mqttRetainedMsgRef) {
as.mu.Lock()
defer as.mu.Unlock()
if as.retmsgs == nil {
as.retmsgs = make(map[string]*mqttRetainedMsg)
as.retmsgs = make(map[string]*mqttRetainedMsgRef)
as.sl = NewSublistWithCache()
} else {
// Check if we already had one. If so, update the existing one.
if erm, exists := as.retmsgs[key]; exists {
// If the new sequence is below the floor or the existing one,
// then ignore the new one.
if rm.sseq <= erm.sseq || rm.sseq <= erm.floor {
return 0
return
}
// Update the existing retained message record with the new rm record.
erm.Origin = rm.Origin
erm.Msg = rm.Msg
erm.Flags = rm.Flags
erm.Source = rm.Source
// Capture existing sequence number so we can return it as the old sequence.
oldSeq := erm.sseq
erm.sseq = rm.sseq
// Clear the floor
erm.floor = 0
Expand All @@ -1903,13 +1900,11 @@ func (as *mqttAccountSessionManager) handleRetainedMsg(key string, rm *mqttRetai
erm.sub = &subscription{subject: []byte(key)}
as.sl.Insert(erm.sub)
}
return oldSeq
}
}
rm.sub = &subscription{subject: []byte(key)}
as.retmsgs[key] = rm
as.sl.Insert(rm.sub)
return 0
}

// Removes the retained message for the given `subject` if present, and returns the
Expand All @@ -1922,7 +1917,7 @@ func (as *mqttAccountSessionManager) handleRetainedMsgDel(subject string, seq ui
var seqToRemove uint64
as.mu.Lock()
if as.retmsgs == nil {
as.retmsgs = make(map[string]*mqttRetainedMsg)
as.retmsgs = make(map[string]*mqttRetainedMsgRef)
as.sl = NewSublistWithCache()
}
if erm, ok := as.retmsgs[subject]; ok {
Expand All @@ -1941,8 +1936,8 @@ func (as *mqttAccountSessionManager) handleRetainedMsgDel(subject string, seq ui
seqToRemove = erm.sseq
}
} else if seq != 0 {
rm := &mqttRetainedMsg{Subject: subject, floor: seq}
as.retmsgs[subject] = rm
rf := &mqttRetainedMsgRef{floor: seq}
as.retmsgs[subject] = rf
}
as.mu.Unlock()
return seqToRemove
Expand Down Expand Up @@ -2193,11 +2188,16 @@ func (as *mqttAccountSessionManager) getRetainedPublishMsgs(subject string, rms
return
}
for _, sub := range result.psubs {
// Since this is a reverse match, the subscription objects here
// contain literals corresponding to the published subjects.
if rm, ok := as.retmsgs[string(sub.subject)]; ok {
*rms = append(*rms, rm)
subj := mqttRetainedMsgsStreamSubject + as.domainTk + string(sub.subject)
jsm, err := as.jsa.loadLastMsgFor(mqttRetainedMsgsStreamName, subj)
if err != nil || jsm == nil {
continue
}
var rm mqttRetainedMsg
if err := json.Unmarshal(jsm.Data, &rm); err != nil {
continue
}
*rms = append(*rms, &rm)
}
}

Expand Down Expand Up @@ -3213,13 +3213,11 @@ func (c *client) mqttHandlePubRetain() {
smr, err := asm.jsa.storeMsg(mqttRetainedMsgsStreamSubject+asm.domainTk+key, -1, rmBytes)
if err == nil {
// Update the new sequence
rm.sseq = smr.Sequence
// Add/update the map
oldSeq := asm.handleRetainedMsg(key, rm)
// If this is a new message on the same subject, delete the old one.
if oldSeq != 0 {
asm.deleteRetainedMsg(oldSeq)
rf := &mqttRetainedMsgRef{
sseq: smr.Sequence,
}
// Add/update the map
asm.handleRetainedMsg(key, rf)
} else {
c.mu.Lock()
acc := c.acc
Expand Down Expand Up @@ -3256,14 +3254,50 @@ func (s *Server) mqttCheckPubRetainedPerms() {
}
s.mu.Unlock()

// First get a list of all of the sessions.
sm.mu.RLock()
defer sm.mu.RUnlock()

asms := make([]*mqttAccountSessionManager, 0, len(sm.sessions))
for _, asm := range sm.sessions {
asms = append(asms, asm)
}
sm.mu.RUnlock()

type retainedMsg struct {
subj string
rmsg *mqttRetainedMsgRef
}

// For each session we will obtain a list of retained messages.
var _rms [128]retainedMsg
rms := _rms[:0]
for _, asm := range asms {
// Get all of the retained messages. Then we will sort them so
// that they are in sequence order, which should help the file
// store to not have to load out-of-order blocks so often.
asm.mu.RLock()
rms = rms[:0] // reuse slice
for subj, rf := range asm.retmsgs {
rms = append(rms, retainedMsg{
subj: subj,
rmsg: rf,
})
}
asm.mu.RUnlock()
sort.Slice(rms, func(i, j int) bool {
return rms[i].rmsg.sseq < rms[j].rmsg.sseq
})

perms := map[string]*perm{}
deletes := map[string]uint64{}
asm.mu.Lock()
for subject, rm := range asm.retmsgs {
for _, rf := range rms {
jsm, err := asm.jsa.loadMsg(mqttRetainedMsgsStreamName, rf.rmsg.sseq)
if err != nil || jsm == nil {
continue
}
var rm mqttRetainedMsg
if err := json.Unmarshal(jsm.Data, &rm); err != nil {
continue
}
if rm.Source == _EMPTY_ {
continue
}
Expand All @@ -3277,20 +3311,22 @@ func (s *Server) mqttCheckPubRetainedPerms() {
}
// If there is permission and no longer allowed to publish in
// the subject, remove the publish retained message from the map.
if p != nil && !pubAllowed(p, subject) {
if p != nil && !pubAllowed(p, rf.subj) {
u = nil
}
}

// Not present or permissions have changed such that the source can't
// publish on that subject anymore: remove it from the map.
if u == nil {
delete(asm.retmsgs, subject)
asm.sl.Remove(rm.sub)
deletes[subject] = rm.sseq
asm.mu.Lock()
delete(asm.retmsgs, rf.subj)
asm.sl.Remove(rf.rmsg.sub)
neilalexander marked this conversation as resolved.
Show resolved Hide resolved
asm.mu.Unlock()
deletes[rf.subj] = rf.rmsg.sseq
}
}
asm.mu.Unlock()

for subject, seq := range deletes {
asm.deleteRetainedMsg(seq)
asm.notifyRetainedMsgDeleted(subject, seq)
Expand Down
8 changes: 4 additions & 4 deletions server/mqtt_test.go
Expand Up @@ -2975,8 +2975,8 @@ func TestMQTTRetainedMsgNetworkUpdates(t *testing.T) {
t.Run(test.subject, func(t *testing.T) {
for _, a := range test.order {
if a.add {
rm := &mqttRetainedMsg{sseq: a.seq}
asm.handleRetainedMsg(test.subject, rm)
rf := &mqttRetainedMsgRef{sseq: a.seq}
asm.handleRetainedMsg(test.subject, rf)
} else {
asm.handleRetainedMsgDel(test.subject, a.seq)
}
Expand All @@ -2988,8 +2988,8 @@ func TestMQTTRetainedMsgNetworkUpdates(t *testing.T) {
for _, subject := range []string{"foo.5", "foo.6"} {
t.Run("clear_"+subject, func(t *testing.T) {
// Now add a new message, which should clear the floor.
rm := &mqttRetainedMsg{sseq: 3}
asm.handleRetainedMsg(subject, rm)
rf := &mqttRetainedMsgRef{sseq: 3}
asm.handleRetainedMsg(subject, rf)
check(t, subject, true, 3, 0)
// Now do a non network delete and make sure it is gone.
asm.handleRetainedMsgDel(subject, 0)
Expand Down