Skip to content

Commit

Permalink
[FIXED] Server reload with highly active accounts with service import…
Browse files Browse the repository at this point in the history
…s could cause panic or dataloss (#4327)

When service imports were reloaded on active accounts with lots of
traffic the server could panic or lose data.

Signed-off-by: Derek Collison <derek@nats.io>
  • Loading branch information
derekcollison committed Jul 20, 2023
2 parents 0347f27 + 7477ce8 commit 6c9fb6a
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 40 deletions.
7 changes: 1 addition & 6 deletions .travis.yml
@@ -1,17 +1,12 @@
os: linux
dist: focal

branches:
only:
- main
- dev

vm:
size: 2x-large

language: go
go:
- 1.19.11
- '1.19.11'
go_import_path: github.com/nats-io/nats-server

addons:
Expand Down
74 changes: 73 additions & 1 deletion server/accounts_test.go
@@ -1,4 +1,4 @@
// Copyright 2018-2022 The NATS Authors
// Copyright 2018-2023 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
Expand Down Expand Up @@ -3683,3 +3683,75 @@ func TestAccountImportDuplicateResponseDeliveryWithLeafnodes(t *testing.T) {
t.Fatalf("Expected only 1 response, got %d", n)
}
}

func TestAccountReloadServiceImportPanic(t *testing.T) {
conf := createConfFile(t, []byte(`
listen: 127.0.0.1:-1
accounts {
A {
users = [ { user: "a", pass: "p" } ]
exports [ { service: "HELP" } ]
}
B {
users = [ { user: "b", pass: "p" } ]
imports [ { service: { account: A, subject: "HELP"} } ]
}
$SYS { users = [ { user: "admin", pass: "s3cr3t!" } ] }
}
`))

s, _ := RunServerWithConfig(conf)
defer s.Shutdown()

// Now connect up the subscriber for HELP. No-op for this test.
nc, _ := jsClientConnect(t, s, nats.UserInfo("a", "p"))
_, err := nc.Subscribe("HELP", func(m *nats.Msg) { m.Respond([]byte("OK")) })
require_NoError(t, err)
defer nc.Close()

// Now create connection to account b where we will publish to HELP.
nc, _ = jsClientConnect(t, s, nats.UserInfo("b", "p"))
require_NoError(t, err)
defer nc.Close()

// We want to continually be publishing messages that will trigger the service import while calling reload.
done := make(chan bool)
var wg sync.WaitGroup
wg.Add(1)

var requests, responses atomic.Uint64
reply := nats.NewInbox()
_, err = nc.Subscribe(reply, func(m *nats.Msg) { responses.Add(1) })
require_NoError(t, err)

go func() {
defer wg.Done()
for {
select {
case <-done:
return
default:
nc.PublishRequest("HELP", reply, []byte("HELP"))
requests.Add(1)
}
}
}()

// Perform a bunch of reloads.
for i := 0; i < 1000; i++ {
err := s.Reload()
require_NoError(t, err)
}

close(done)
wg.Wait()

totalRequests := requests.Load()
checkFor(t, 10*time.Second, 250*time.Millisecond, func() error {
resp := responses.Load()
if resp == totalRequests {
return nil
}
return fmt.Errorf("Have not received all responses, want %d got %d", totalRequests, resp)
})
}
54 changes: 31 additions & 23 deletions server/client.go
Expand Up @@ -789,15 +789,16 @@ func (c *client) subsAtLimit() bool {
}

func minLimit(value *int32, limit int32) bool {
if *value != jwt.NoLimit {
v := atomic.LoadInt32(value)
if v != jwt.NoLimit {
if limit != jwt.NoLimit {
if limit < *value {
*value = limit
if limit < v {
atomic.StoreInt32(value, limit)
return true
}
}
} else if limit != jwt.NoLimit {
*value = limit
atomic.StoreInt32(value, limit)
return true
}
return false
Expand All @@ -810,7 +811,7 @@ func (c *client) applyAccountLimits() {
if c.acc == nil || (c.kind != CLIENT && c.kind != LEAF) {
return
}
c.mpay = jwt.NoLimit
atomic.StoreInt32(&c.mpay, jwt.NoLimit)
c.msubs = jwt.NoLimit
if c.opts.JWT != _EMPTY_ { // user jwt implies account
if uc, _ := jwt.DecodeUserClaims(c.opts.JWT); uc != nil {
Expand Down Expand Up @@ -3576,15 +3577,21 @@ func (c *client) processInboundClientMsg(msg []byte) (bool, bool) {
}

// Mostly under testing scenarios.
c.mu.Lock()
if c.srv == nil || c.acc == nil {
c.mu.Unlock()
return false, false
}
acc := c.acc
genidAddr := &acc.sl.genid

// Check pub permissions
if c.perms != nil && (c.perms.pub.allow != nil || c.perms.pub.deny != nil) && !c.pubAllowed(string(c.pa.subject)) {
if c.perms != nil && (c.perms.pub.allow != nil || c.perms.pub.deny != nil) && !c.pubAllowedFullCheck(string(c.pa.subject), true, true) {
c.mu.Unlock()
c.pubPermissionViolation(c.pa.subject)
return false, true
}
c.mu.Unlock()

// Now check for reserved replies. These are used for service imports.
if c.kind == CLIENT && len(c.pa.reply) > 0 && isReservedReply(c.pa.reply) {
Expand All @@ -3605,10 +3612,10 @@ func (c *client) processInboundClientMsg(msg []byte) (bool, bool) {
// performance impact reported in our bench)
var isGWRouted bool
if c.kind != CLIENT {
if atomic.LoadInt32(&c.acc.gwReplyMapping.check) > 0 {
c.acc.mu.RLock()
c.pa.subject, isGWRouted = c.acc.gwReplyMapping.get(c.pa.subject)
c.acc.mu.RUnlock()
if atomic.LoadInt32(&acc.gwReplyMapping.check) > 0 {
acc.mu.RLock()
c.pa.subject, isGWRouted = acc.gwReplyMapping.get(c.pa.subject)
acc.mu.RUnlock()
}
} else if atomic.LoadInt32(&c.gwReplyMapping.check) > 0 {
c.mu.Lock()
Expand Down Expand Up @@ -3651,7 +3658,7 @@ func (c *client) processInboundClientMsg(msg []byte) (bool, bool) {
var r *SublistResult
var ok bool

genid := atomic.LoadUint64(&c.acc.sl.genid)
genid := atomic.LoadUint64(genidAddr)
if genid == c.in.genid && c.in.results != nil {
r, ok = c.in.results[string(c.pa.subject)]
} else {
Expand All @@ -3662,7 +3669,7 @@ func (c *client) processInboundClientMsg(msg []byte) (bool, bool) {

// Go back to the sublist data structure.
if !ok {
r = c.acc.sl.Match(string(c.pa.subject))
r = acc.sl.Match(string(c.pa.subject))
c.in.results[string(c.pa.subject)] = r
// Prune the results cache. Keeps us from unbounded growth. Random delete.
if len(c.in.results) > maxResultCacheSize {
Expand Down Expand Up @@ -3693,7 +3700,7 @@ func (c *client) processInboundClientMsg(msg []byte) (bool, bool) {
atomic.LoadInt64(&c.srv.gateway.totalQSubs) > 0 {
flag |= pmrCollectQueueNames
}
didDeliver, qnames = c.processMsgResults(c.acc, r, msg, c.pa.deliver, c.pa.subject, c.pa.reply, flag)
didDeliver, qnames = c.processMsgResults(acc, r, msg, c.pa.deliver, c.pa.subject, c.pa.reply, flag)
}

// Now deal with gateways
Expand All @@ -3703,7 +3710,7 @@ func (c *client) processInboundClientMsg(msg []byte) (bool, bool) {
reply = append(reply, '@')
reply = append(reply, c.pa.deliver...)
}
didDeliver = c.sendMsgToGateways(c.acc, msg, c.pa.subject, reply, qnames) || didDeliver
didDeliver = c.sendMsgToGateways(acc, msg, c.pa.subject, reply, qnames) || didDeliver
}

// Check to see if we did not deliver to anyone and the client has a reply subject set
Expand Down Expand Up @@ -3909,6 +3916,7 @@ func (c *client) processServiceImport(si *serviceImport, acc *Account, msg []byt
checkJS = true
}
}
siAcc := si.acc
acc.mu.RUnlock()

// We have a special case where JetStream pulls in all service imports through one export.
Expand Down Expand Up @@ -3939,7 +3947,7 @@ func (c *client) processServiceImport(si *serviceImport, acc *Account, msg []byt
}
} else if !isResponse && si.latency != nil && tracking {
// Check to see if this was a bad request with no reply and we were supposed to be tracking.
si.acc.sendBadRequestTrackingLatency(si, c, headers)
siAcc.sendBadRequestTrackingLatency(si, c, headers)
}

// Send tracking info here if we are tracking this response.
Expand Down Expand Up @@ -3967,7 +3975,7 @@ func (c *client) processServiceImport(si *serviceImport, acc *Account, msg []byt
// Now check to see if this account has mappings that could affect the service import.
// Can't use non-locked trick like in processInboundClientMsg, so just call into selectMappedSubject
// so we only lock once.
nsubj, changed := si.acc.selectMappedSubject(to)
nsubj, changed := siAcc.selectMappedSubject(to)
if changed {
c.pa.mapped = []byte(to)
to = nsubj
Expand All @@ -3984,7 +3992,7 @@ func (c *client) processServiceImport(si *serviceImport, acc *Account, msg []byt
// Place our client info for the request in the original message.
// This will survive going across routes, etc.
if !isResponse {
isSysImport := si.acc == c.srv.SystemAccount()
isSysImport := siAcc == c.srv.SystemAccount()
var ci *ClientInfo
if hadPrevSi && c.pa.hdr >= 0 {
var cis ClientInfo
Expand Down Expand Up @@ -4025,11 +4033,11 @@ func (c *client) processServiceImport(si *serviceImport, acc *Account, msg []byt
c.pa.reply = nrr

if changed && c.isMqtt() && c.pa.hdr > 0 {
c.srv.mqttStoreQoS1MsgForAccountOnNewSubject(c.pa.hdr, msg, si.acc.GetName(), to)
c.srv.mqttStoreQoS1MsgForAccountOnNewSubject(c.pa.hdr, msg, siAcc.GetName(), to)
}

// FIXME(dlc) - Do L1 cache trick like normal client?
rr := si.acc.sl.Match(to)
rr := siAcc.sl.Match(to)

// If we are a route or gateway or leafnode and this message is flipped to a queue subscriber we
// need to handle that since the processMsgResults will want a queue filter.
Expand All @@ -4054,10 +4062,10 @@ func (c *client) processServiceImport(si *serviceImport, acc *Account, msg []byt
if c.srv.gateway.enabled {
flags |= pmrCollectQueueNames
var queues [][]byte
didDeliver, queues = c.processMsgResults(si.acc, rr, msg, c.pa.deliver, []byte(to), nrr, flags)
didDeliver = c.sendMsgToGateways(si.acc, msg, []byte(to), nrr, queues) || didDeliver
didDeliver, queues = c.processMsgResults(siAcc, rr, msg, c.pa.deliver, []byte(to), nrr, flags)
didDeliver = c.sendMsgToGateways(siAcc, msg, []byte(to), nrr, queues) || didDeliver
} else {
didDeliver, _ = c.processMsgResults(si.acc, rr, msg, c.pa.deliver, []byte(to), nrr, flags)
didDeliver, _ = c.processMsgResults(siAcc, rr, msg, c.pa.deliver, []byte(to), nrr, flags)
}

// Restore to original values.
Expand Down Expand Up @@ -4090,7 +4098,7 @@ func (c *client) processServiceImport(si *serviceImport, acc *Account, msg []byt
} else {
// This is a main import and since we could not even deliver to the exporting account
// go ahead and remove the respServiceImport we created above.
si.acc.removeRespServiceImport(rsi, reason)
siAcc.removeRespServiceImport(rsi, reason)
}
}
}
Expand Down
33 changes: 23 additions & 10 deletions server/server.go
Expand Up @@ -753,6 +753,12 @@ func (s *Server) configureAccounts(reloading bool) (map[string]struct{}, error)

opts := s.getOpts()

// We need to track service imports since we can not swap them out (unsub and re-sub)
// until the proper server struct accounts have been swapped in properly. Doing it in
// place could lead to data loss or server panic since account under new si has no real
// account and hence no sublist, so will panic on inbound message.
siMap := make(map[*Account][][]byte)

// Check opts and walk through them. We need to copy them here
// so that we do not keep a real one sitting in the options.
for _, acc := range opts.Accounts {
Expand All @@ -773,12 +779,16 @@ func (s *Server) configureAccounts(reloading bool) (map[string]struct{}, error)
// Collect the sids for the service imports since we are going to
// replace with new ones.
var sids [][]byte
c := a.ic
for _, si := range a.imports.services {
if c != nil && si.sid != nil {
if si.sid != nil {
sids = append(sids, si.sid)
}
}
// Setup to process later if needed.
if len(sids) > 0 || len(acc.imports.services) > 0 {
siMap[a] = sids
}

// Now reset all export/imports fields since they are going to be
// filled in shallowCopy()
a.imports.streams, a.imports.services = nil, nil
Expand All @@ -787,14 +797,6 @@ func (s *Server) configureAccounts(reloading bool) (map[string]struct{}, error)
// and pass `a` (our existing account) to get it updated.
acc.shallowCopy(a)
a.mu.Unlock()
// Need to release the lock for this.
s.mu.Unlock()
for _, sid := range sids {
c.processUnsub(sid)
}
// Add subscriptions for existing service imports.
a.addAllServiceImportSubs()
s.mu.Lock()
create = false
}
}
Expand Down Expand Up @@ -862,6 +864,7 @@ func (s *Server) configureAccounts(reloading bool) (map[string]struct{}, error)
for _, si := range acc.imports.services {
if v, ok := s.accounts.Load(si.acc.Name); ok {
si.acc = v.(*Account)

// It is possible to allow for latency tracking inside your
// own account, so lock only when not the same account.
if si.acc == acc {
Expand Down Expand Up @@ -889,6 +892,16 @@ func (s *Server) configureAccounts(reloading bool) (map[string]struct{}, error)
return true
})

// Check if we need to process service imports pending from above.
// This processing needs to be after we swap in the real accounts above.
for acc, sids := range siMap {
c := acc.ic
for _, sid := range sids {
c.processUnsub(sid)
}
acc.addAllServiceImportSubs()
}

// Set the system account if it was configured.
// Otherwise create a default one.
if opts.SystemAccount != _EMPTY_ {
Expand Down

0 comments on commit 6c9fb6a

Please sign in to comment.