Skip to content

Commit

Permalink
[FIXED] Protect against out of bounds access on usage updates. (#4164)
Browse files Browse the repository at this point in the history
Signed-off-by: Derek Collison <derek@nats.io>
  • Loading branch information
derekcollison committed May 15, 2023
2 parents fe71ef5 + 832df1c commit 584ea85
Showing 1 changed file with 55 additions and 33 deletions.
88 changes: 55 additions & 33 deletions server/jetstream.go
Expand Up @@ -1714,14 +1714,13 @@ func (a *Account) JetStreamEnabled() bool {
}

func (jsa *jsAccount) remoteUpdateUsage(sub *subscription, c *client, _ *Account, subject, _ string, msg []byte) {
const usageSize = 32

// jsa.js.srv is immutable and guaranteed to no be nil, so no lock needed.
s := jsa.js.srv

jsa.usageMu.Lock()
if len(msg) < usageSize {
jsa.usageMu.Unlock()
defer jsa.usageMu.Unlock()

if len(msg) < minUsageUpdateLen {
s.Warnf("Ignoring remote usage update with size too short")
return
}
Expand All @@ -1730,7 +1729,6 @@ func (jsa *jsAccount) remoteUpdateUsage(sub *subscription, c *client, _ *Account
rnode = subject[li+1:]
}
if rnode == _EMPTY_ {
jsa.usageMu.Unlock()
s.Warnf("Received remote usage update with no remote node")
return
}
Expand Down Expand Up @@ -1765,21 +1763,31 @@ func (jsa *jsAccount) remoteUpdateUsage(sub *subscription, c *client, _ *Account
apiTotal, apiErrors := le.Uint64(msg[16:]), le.Uint64(msg[24:])
memUsed, storeUsed := int64(le.Uint64(msg[0:])), int64(le.Uint64(msg[8:]))

// we later extended the data structure to support multiple tiers
excessRecordCnt := uint32(0)
tierName := _EMPTY_
if len(msg) >= 44 {
excessRecordCnt = le.Uint32(msg[32:])
length := le.Uint64(msg[36:])
tierName = string(msg[44 : 44+length])
msg = msg[44+length:]
// We later extended the data structure to support multiple tiers
var excessRecordCnt uint32
var tierName string

if len(msg) >= usageMultiTiersLen {
excessRecordCnt = le.Uint32(msg[minUsageUpdateLen:])
length := le.Uint64(msg[minUsageUpdateLen+4:])
// Need to protect past this point in case this is wrong.
if uint64(len(msg)) < usageMultiTiersLen+length {
s.Warnf("Received corrupt remote usage update")
return
}
tierName = string(msg[usageMultiTiersLen : usageMultiTiersLen+length])
msg = msg[usageMultiTiersLen+length:]
}
updateTotal(tierName, memUsed, storeUsed)
for ; excessRecordCnt > 0 && len(msg) >= 24; excessRecordCnt-- {
for ; excessRecordCnt > 0 && len(msg) >= usageRecordLen; excessRecordCnt-- {
memUsed, storeUsed := int64(le.Uint64(msg[0:])), int64(le.Uint64(msg[8:]))
length := le.Uint64(msg[16:])
tierName = string(msg[24 : 24+length])
msg = msg[24+length:]
if uint64(len(msg)) < usageRecordLen+length {
s.Warnf("Received corrupt remote usage update on excess record")
return
}
tierName = string(msg[usageRecordLen : usageRecordLen+length])
msg = msg[usageRecordLen+length:]
updateTotal(tierName, memUsed, storeUsed)
}
jsa.apiTotal -= rUsage.api
Expand All @@ -1788,7 +1796,6 @@ func (jsa *jsAccount) remoteUpdateUsage(sub *subscription, c *client, _ *Account
rUsage.err = apiErrors
jsa.apiTotal += apiTotal
jsa.apiErrors += apiErrors
jsa.usageMu.Unlock()
}

// When we detect a skew of some sort this will verify the usage reporting is correct.
Expand Down Expand Up @@ -1906,12 +1913,22 @@ func (jsa *jsAccount) sendClusterUsageUpdateTimer() {
}
}

// For usage fields.
const (
minUsageUpdateLen = 32
stackUsageUpdate = 72
usageRecordLen = 24
usageMultiTiersLen = 44
apiStatsAndNumTiers = 20
minUsageUpdateWindow = 250 * time.Millisecond
)

// Send updates to our account usage for this server.
// jsa.usageMu lock should be held.
func (jsa *jsAccount) sendClusterUsageUpdate() {
// These values are absolute so we can limit send rates.
now := time.Now()
if now.Sub(jsa.lupdate) < 250*time.Millisecond {
if now.Sub(jsa.lupdate) < minUsageUpdateWindow {
return
}
jsa.lupdate = now
Expand All @@ -1921,32 +1938,37 @@ func (jsa *jsAccount) sendClusterUsageUpdate() {
return
}
// every base record contains mem/store/len(tier) as well as the tier name
l := 24 * lenUsage
l := usageRecordLen * lenUsage
for tier := range jsa.usage {
l += len(tier)
}
if lenUsage > 0 {
// first record contains api/usage errors as well as count for extra base records
l += 20
// first record contains api/usage errors as well as count for extra base records
l += apiStatsAndNumTiers

var raw [stackUsageUpdate]byte
var b []byte
if l > stackUsageUpdate {
b = make([]byte, l)
} else {
b = raw[:l]
}
var le = binary.LittleEndian
b := make([]byte, l)
i := 0

var i int
var le = binary.LittleEndian
for tier, usage := range jsa.usage {
le.PutUint64(b[i+0:], uint64(usage.local.mem))
le.PutUint64(b[i+8:], uint64(usage.local.store))
if i == 0 {
le.PutUint64(b[i+16:], jsa.usageApi)
le.PutUint64(b[i+24:], jsa.usageErr)
le.PutUint32(b[i+32:], uint32(len(jsa.usage)-1))
le.PutUint64(b[i+36:], uint64(len(tier)))
copy(b[i+44:], tier)
i += 44 + len(tier)
le.PutUint64(b[16:], jsa.usageApi)
le.PutUint64(b[24:], jsa.usageErr)
le.PutUint32(b[32:], uint32(len(jsa.usage)-1))
le.PutUint64(b[36:], uint64(len(tier)))
copy(b[usageMultiTiersLen:], tier)
i = usageMultiTiersLen + len(tier)
} else {
le.PutUint64(b[i+16:], uint64(len(tier)))
copy(b[i+24:], tier)
i += 24 + len(tier)
copy(b[i+usageRecordLen:], tier)
i += usageRecordLen + len(tier)
}
}
jsa.sendq.push(newPubMsg(nil, jsa.updatesPub, _EMPTY_, nil, nil, b, noCompression, false, false))
Expand Down

0 comments on commit 584ea85

Please sign in to comment.