Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Per-subject limits for MQTT retained messages #4199

Merged
merged 4 commits into from Jun 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
134 changes: 126 additions & 8 deletions server/mqtt.go
Expand Up @@ -112,7 +112,7 @@ const (

// Stream name for MQTT retained messages on a given account
mqttRetainedMsgsStreamName = mqttStreamNamePrefix + "rmsgs"
mqttRetainedMsgsStreamSubject = "$MQTT.rmsgs"
mqttRetainedMsgsStreamSubject = "$MQTT.rmsgs."

// Stream name for MQTT sessions on a given account
mqttSessStreamName = mqttStreamNamePrefix + "sess"
Expand Down Expand Up @@ -145,6 +145,7 @@ const (
mqttJSAIdTokenPos = 3
mqttJSATokenPos = 4
mqttJSAStreamCreate = "SC"
mqttJSAStreamUpdate = "SU"
mqttJSAStreamLookup = "SL"
mqttJSAStreamDel = "SD"
mqttJSAConsumerCreate = "CC"
Expand Down Expand Up @@ -1150,11 +1151,12 @@ func (s *Server) mqttCreateAccountSessionManager(acc *Account, quitCh chan struc
} else if si == nil {
// Create the stream for retained messages.
cfg := &StreamConfig{
Name: mqttRetainedMsgsStreamName,
Subjects: []string{mqttRetainedMsgsStreamSubject},
Storage: FileStorage,
Retention: LimitsPolicy,
Replicas: replicas,
Name: mqttRetainedMsgsStreamName,
Subjects: []string{mqttRetainedMsgsStreamSubject + as.domainTk + ">"},
Storage: FileStorage,
Retention: LimitsPolicy,
Replicas: replicas,
MaxMsgsPer: 1,
}
// We will need "si" outside of this block.
si, _, err = jsa.createStream(cfg)
Expand All @@ -1170,6 +1172,39 @@ func (s *Server) mqttCreateAccountSessionManager(acc *Account, quitCh chan struc
}
}
}
// Doing this check outside of above if/else due to possible race when
// creating the stream.
wantedSubj := mqttRetainedMsgsStreamSubject + as.domainTk + ">"
if len(si.Config.Subjects) != 1 || si.Config.Subjects[0] != wantedSubj {
// Update only the Subjects at this stage, not MaxMsgsPer yet.
si.Config.Subjects = []string{wantedSubj}
if si, err = jsa.updateStream(&si.Config); err != nil {
return nil, fmt.Errorf("failed to update stream config: %w", err)
}
}
// Try to transfer regardless if we have already updated the stream or not
// in case not all messages were transferred and the server was restarted.
if as.transferRetainedToPerKeySubjectStream(s) {
// We need another lookup to have up-to-date si.State values in order
// to load all retained messages.
si, err = lookupStream(mqttRetainedMsgsStreamName, "retained messages")
if err != nil {
return nil, err
}
}
// Now, if the stream does not have MaxMsgsPer set to 1, and there are no
// more messages on the single $MQTT.rmsgs subject, update the stream again.
if si.Config.MaxMsgsPer != 1 {
_, err := jsa.loadNextMsgFor(mqttRetainedMsgsStreamName, "$MQTT.rmsgs")
// Looking for an error indicated that there is no such message.
if err != nil && IsNatsErr(err, JSNoMessageFoundErr) {
si.Config.MaxMsgsPer = 1
// We will need an up-to-date si, so don't use local variable here.
if si, err = jsa.updateStream(&si.Config); err != nil {
return nil, fmt.Errorf("failed to update stream config: %w", err)
}
}
}

var lastSeq uint64
var rmDoneCh chan struct{}
Expand Down Expand Up @@ -1199,7 +1234,7 @@ func (s *Server) mqttCreateAccountSessionManager(acc *Account, quitCh chan struc
Stream: mqttRetainedMsgsStreamName,
Config: ConsumerConfig{
Durable: rmDurName,
FilterSubject: mqttRetainedMsgsStreamSubject,
FilterSubject: mqttRetainedMsgsStreamSubject + as.domainTk + ">",
DeliverSubject: rmsubj,
ReplayPolicy: ReplayInstant,
AckPolicy: AckNone,
Expand Down Expand Up @@ -1353,6 +1388,19 @@ func (jsa *mqttJSA) createStream(cfg *StreamConfig) (*StreamInfo, bool, error) {
return scr.StreamInfo, scr.DidCreate, scr.ToError()
}

func (jsa *mqttJSA) updateStream(cfg *StreamConfig) (*StreamInfo, error) {
cfgb, err := json.Marshal(cfg)
if err != nil {
return nil, err
}
scri, err := jsa.newRequest(mqttJSAStreamUpdate, fmt.Sprintf(JSApiStreamUpdateT, cfg.Name), 0, cfgb)
if err != nil {
return nil, err
}
scr := scri.(*JSApiStreamUpdateResponse)
return scr.StreamInfo, scr.ToError()
}

func (jsa *mqttJSA) lookupStream(name string) (*StreamInfo, error) {
slri, err := jsa.newRequest(mqttJSAStreamLookup, fmt.Sprintf(JSApiStreamInfoT, name), 0, nil)
if err != nil {
Expand Down Expand Up @@ -1385,6 +1433,20 @@ func (jsa *mqttJSA) loadLastMsgFor(streamName string, subject string) (*StoredMs
return lmr.Message, lmr.ToError()
}

func (jsa *mqttJSA) loadNextMsgFor(streamName string, subject string) (*StoredMsg, error) {
mreq := &JSApiMsgGetRequest{NextFor: subject}
req, err := json.Marshal(mreq)
if err != nil {
return nil, err
}
lmri, err := jsa.newRequest(mqttJSAMsgLoad, fmt.Sprintf(JSApiMsgGetT, streamName), 0, req)
if err != nil {
return nil, err
}
lmr := lmri.(*JSApiMsgGetResponse)
return lmr.Message, lmr.ToError()
}

func (jsa *mqttJSA) loadMsg(streamName string, seq uint64) (*StoredMsg, error) {
mreq := &JSApiMsgGetRequest{Seq: seq}
req, err := json.Marshal(mreq)
Expand Down Expand Up @@ -1464,6 +1526,12 @@ func (as *mqttAccountSessionManager) processJSAPIReplies(_ *subscription, pc *cl
resp.Error = NewJSInvalidJSONError()
}
ch <- resp
case mqttJSAStreamUpdate:
var resp = &JSApiStreamUpdateResponse{}
if err := json.Unmarshal(msg, resp); err != nil {
resp.Error = NewJSInvalidJSONError()
}
ch <- resp
case mqttJSAStreamLookup:
var resp = &JSApiStreamInfoResponse{}
if err := json.Unmarshal(msg, &resp); err != nil {
Expand Down Expand Up @@ -2260,6 +2328,56 @@ func (as *mqttAccountSessionManager) transferUniqueSessStreamsToMuxed(log *Serve
retry = false
}

func (as *mqttAccountSessionManager) transferRetainedToPerKeySubjectStream(log *Server) bool {
jsa := &as.jsa
var count, errors int

for {
// Try and look up messages on the original undivided "$MQTT.rmsgs" subject.
// If nothing is returned here, we assume to have migrated all old messages.
smsg, err := jsa.loadNextMsgFor(mqttRetainedMsgsStreamName, "$MQTT.rmsgs")
if err != nil {
if IsNatsErr(err, JSNoMessageFoundErr) {
// We've ran out of messages to transfer so give up.
break
}
log.Warnf(" Unable to load retained message with sequence %d: %s", smsg.Sequence, err)
errors++
break
}
// Unmarshal the message so that we can obtain the subject name.
var rmsg mqttRetainedMsg
if err := json.Unmarshal(smsg.Data, &rmsg); err != nil {
log.Warnf(" Unable to unmarshal retained message with sequence %d, skipping", smsg.Sequence)
errors++
continue
}
// Store the message again, this time with the new per-key subject.
subject := mqttRetainedMsgsStreamSubject + as.domainTk + rmsg.Subject
if _, err := jsa.storeMsgWithKind(mqttJSASessPersist, subject, 0, smsg.Data); err != nil {
log.Errorf(" Unable to transfer the retained message with sequence %d: %v", smsg.Sequence, err)
errors++
continue
}
// Delete the original message.
if err := jsa.deleteMsg(mqttRetainedMsgsStreamName, smsg.Sequence, true); err != nil {
log.Errorf(" Unable to clean up the retained message with sequence %d: %v", smsg.Sequence, err)
errors++
continue
}
count++
}
if errors > 0 {
next := mqttDefaultTransferRetry
log.Warnf("Failed to transfer %d MQTT retained messages, will try again in %v", errors, next)
time.AfterFunc(next, func() { as.transferRetainedToPerKeySubjectStream(log) })
} else if count > 0 {
log.Noticef("Transfer of %d MQTT retained messages done!", count)
}
// Signal if there was any activity (either some transferred or some errors)
return errors > 0 || count > 0
}

//////////////////////////////////////////////////////////////////////////////
//
// MQTT session related functions
Expand Down Expand Up @@ -3092,7 +3210,7 @@ func (c *client) mqttHandlePubRetain() {
Source: c.opts.Username,
}
rmBytes, _ := json.Marshal(rm)
smr, err := asm.jsa.storeMsg(mqttRetainedMsgsStreamSubject, -1, rmBytes)
smr, err := asm.jsa.storeMsg(mqttRetainedMsgsStreamSubject+asm.domainTk+key, -1, rmBytes)
if err == nil {
// Update the new sequence
rm.sseq = smr.Sequence
Expand Down
91 changes: 89 additions & 2 deletions server/mqtt_test.go
Expand Up @@ -1979,6 +1979,11 @@ func testMQTTCheckPubMsgNoAck(t testing.TB, c net.Conn, r *mqttReader, topic str
}

func testMQTTGetPubMsg(t testing.TB, c net.Conn, r *mqttReader, topic string, payload []byte) (byte, uint16) {
flags, pi, _ := testMQTTGetPubMsgEx(t, c, r, topic, payload)
return flags, pi
}

func testMQTTGetPubMsgEx(t testing.TB, c net.Conn, r *mqttReader, topic string, payload []byte) (byte, uint16, string) {
t.Helper()
b, pl := testMQTTReadPacket(t, r)
if pt := b & mqttPacketMask; pt != mqttPacketPub {
Expand All @@ -1991,7 +1996,7 @@ func testMQTTGetPubMsg(t testing.TB, c net.Conn, r *mqttReader, topic string, pa
if err != nil {
t.Fatal(err)
}
if ptopic != topic {
if topic != _EMPTY_ && ptopic != topic {
t.Fatalf("Expected topic %q, got %q", topic, ptopic)
}
var pi uint16
Expand All @@ -2011,7 +2016,7 @@ func testMQTTGetPubMsg(t testing.TB, c net.Conn, r *mqttReader, topic string, pa
t.Fatalf("Expected payload %q, got %q", payload, ppayload)
}
r.pos += msgLen
return pflags, pi
return pflags, pi, ptopic
}

func testMQTTSendPubAck(t testing.TB, c net.Conn, pi uint16) {
Expand Down Expand Up @@ -2993,6 +2998,88 @@ func TestMQTTRetainedMsgNetworkUpdates(t *testing.T) {
}
}

func TestMQTTRetainedMsgMigration(t *testing.T) {
o := testMQTTDefaultOptions()
s := testMQTTRunServer(t, o)
defer testMQTTShutdownServer(s)

nc, js := jsClientConnect(t, s)
defer nc.Close()

// Create the retained messages stream to listen on the old subject first.
// The server will correct this when the migration takes place.
_, err := js.AddStream(&nats.StreamConfig{
Name: mqttRetainedMsgsStreamName,
Subjects: []string{`$MQTT.rmsgs`},
Storage: nats.FileStorage,
Retention: nats.LimitsPolicy,
Replicas: 1,
})
require_NoError(t, err)

// Publish some retained messages on the old "$MQTT.rmsgs" subject.
for i := 0; i < 100; i++ {
msg := fmt.Sprintf(
`{"origin":"b5IQZNtG","subject":"test%d","topic":"test%d","msg":"YmFy","flags":1}`, i, i,
)
_, err := js.Publish(`$MQTT.rmsgs`, []byte(msg))
require_NoError(t, err)
}

// Check that the old subject looks right.
si, err := js.StreamInfo(mqttRetainedMsgsStreamName, &nats.StreamInfoRequest{
SubjectsFilter: `$MQTT.>`,
})
require_NoError(t, err)
if si.State.NumSubjects != 1 {
t.Fatalf("expected 1 subject, got %d", si.State.NumSubjects)
}
if n := si.State.Subjects[`$MQTT.rmsgs`]; n != 100 {
t.Fatalf("expected to find 100 messages on the original subject but found %d", n)
}

// Create an MQTT client, this will cause a migration to take place.
mc, rc := testMQTTConnect(t, &mqttConnInfo{clientID: "sub", cleanSess: true}, o.MQTT.Host, o.MQTT.Port)
defer mc.Close()
testMQTTCheckConnAck(t, rc, mqttConnAckRCConnectionAccepted, false)

testMQTTSub(t, 1, mc, rc, []*mqttFilter{{filter: "+", qos: 0}}, []byte{0})
topics := map[string]struct{}{}
for i := 0; i < 100; i++ {
_, _, topic := testMQTTGetPubMsgEx(t, mc, rc, _EMPTY_, []byte("bar"))
topics[topic] = struct{}{}
}
if len(topics) != 100 {
t.Fatalf("Unexpected topics: %v", topics)
}

// Now look at the stream, there should be 100 messages on the new
// divided subjects and none on the old undivided subject.
si, err = js.StreamInfo(mqttRetainedMsgsStreamName, &nats.StreamInfoRequest{
SubjectsFilter: `$MQTT.>`,
})
require_NoError(t, err)
if si.State.NumSubjects != 100 {
t.Fatalf("expected 100 subjects, got %d", si.State.NumSubjects)
}
if n := si.State.Subjects[`$MQTT.rmsgs`]; n > 0 {
t.Fatalf("expected to find no messages on the original subject but found %d", n)
}

// Check that the message counts look right. There should be one
// retained message per key.
for i := 0; i < 100; i++ {
expected := fmt.Sprintf(`$MQTT.rmsgs.test%d`, i)
n, ok := si.State.Subjects[expected]
if !ok {
t.Fatalf("expected to find %q but didn't", expected)
}
if n != 1 {
t.Fatalf("expected %q to have 1 message but had %d", expected, n)
}
}
}

func TestMQTTClusterReplicasCount(t *testing.T) {
for _, test := range []struct {
size int
Expand Down