Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIXED] Seqset encode bug that could cause bad stream state snapshots #4348

Merged
merged 6 commits into from Jul 30, 2023
74 changes: 50 additions & 24 deletions server/avl/seqset.go
Expand Up @@ -140,7 +140,7 @@ func (ss *SequenceSet) Heights() (l, r int) {

// Returns min, max and number of set items.
func (ss *SequenceSet) State() (min, max, num uint64) {
if ss.root == nil {
if ss == nil || ss.root == nil {
return 0, 0, 0
}
min, max = ss.MinMax()
Expand Down Expand Up @@ -446,20 +446,16 @@ func (n *node) insert(seq uint64, inserted *bool, nodes *int) *node {
// Don't make a function, impacts performance.
if bf := balanceF(n); bf > 1 {
// Left unbalanced.
if n.l.base+numEntries > seq {
return n.rotateR()
} else {
if balanceF(n.l) < 0 {
n.l = n.l.rotateL()
return n.rotateR()
}
return n.rotateR()
} else if bf < -1 {
// right unbalanced.
if n.r.base+numEntries > seq {
// Right unbalanced.
if balanceF(n.r) > 0 {
n.r = n.r.rotateR()
return n.rotateL()
} else {
return n.rotateL()
}
return n.rotateL()
}
return n
}
Expand Down Expand Up @@ -507,6 +503,9 @@ func balanceF(n *node) int {
}

func maxH(n *node) int {
if n == nil {
return 0
}
var lh, rh int
if n.l != nil {
lh = n.l.h
Expand Down Expand Up @@ -550,36 +549,63 @@ func (n *node) delete(seq uint64, deleted *bool, nodes *int) *node {
n.r = n.r.delete(seq, deleted, nodes)
} else if empty := n.clear(seq, deleted); empty {
*nodes--
if nn := n.l; nn == nil {
if n.l == nil {
n = n.r
} else if nn.r == nil {
nn.r = n.r
n = nn
} else if n.r == nil {
n = n.l
} else {
nn.r.r = n.r
n = nn
// We have both children.
n.r = n.r.insertNodePrev(n.l)
n = n.r
}
}

if n != nil {
n.h = maxH(n) + 1
}

// Check balance.
if bf := balanceF(n); bf > 1 {
// Left unbalanced.
if n.l.base+numEntries > seq {
return n.rotateR()
} else {
if balanceF(n.l) < 0 {
n.l = n.l.rotateL()
return n.rotateR()
}
return n.rotateR()
} else if bf < -1 {
// right unbalanced.
if n.r.base+numEntries > seq {
if balanceF(n.r) > 0 {
n.r = n.r.rotateR()
return n.rotateL()
} else {
return n.rotateL()
}
return n.rotateL()
}

return n
}

// Will insert nn into the node assuming it is less than all other nodes in n.
// Will re-calculate height and balance.
func (n *node) insertNodePrev(nn *node) *node {
if n.l == nil {
n.l = nn
} else {
n.l = n.l.insertNodePrev(nn)
}
n.h = maxH(n) + 1

// Check balance.
if bf := balanceF(n); bf > 1 {
// Left unbalanced.
if balanceF(n.l) < 0 {
n.l = n.l.rotateL()
}
return n.rotateR()
} else if bf < -1 {
// right unbalanced.
if balanceF(n.r) > 0 {
n.r = n.r.rotateR()
}
return n.rotateL()
}
return n
}

Expand Down
27 changes: 26 additions & 1 deletion server/avl/seqset_test.go
Expand Up @@ -134,21 +134,46 @@ func TestSeqSetDelete(t *testing.T) {
require_True(t, !ss.Exists(seq))
}
require_True(t, ss.root == nil)
}

num := 22*numEntries + 22
func TestSeqSetInsertAndDeletePedantic(t *testing.T) {
var ss SequenceSet

num := 50*numEntries + 22
nums := make([]uint64, 0, num)
for i := 0; i < num; i++ {
nums = append(nums, uint64(i))
}
rand.Shuffle(len(nums), func(i, j int) { nums[i], nums[j] = nums[j], nums[i] })

// Make sure always balanced.
testBalanced := func() {
t.Helper()
// Check heights.
ss.root.nodeIter(func(n *node) {
if n != nil && n.h != maxH(n)+1 {
t.Fatalf("Node height is wrong: %+v", n)
}
})
// Check balance factor.
if bf := balanceF(ss.root); bf > 1 || bf < -1 {
t.Fatalf("Unbalanced tree")
}
}

for _, n := range nums {
ss.Insert(n)
testBalanced()
}
require_True(t, ss.root != nil)

for _, n := range nums {
ss.Delete(n)
testBalanced()
require_True(t, !ss.Exists(n))
if ss.Size() > 0 {
require_True(t, ss.root != nil)
}
}
require_True(t, ss.root == nil)
}
Expand Down
18 changes: 9 additions & 9 deletions server/filestore.go
Expand Up @@ -6927,6 +6927,8 @@ func (fs *fileStore) EncodedStreamState(failed uint64) ([]byte, error) {
return nil, err
}
b = append(b, buf...)
default:
return nil, errors.New("no impl")
}
}
}
Expand All @@ -6946,13 +6948,11 @@ func (fs *fileStore) deleteBlocks() DeleteBlocks {
// Detect if we have a gap between these blocks.
if prevLast > 0 && prevLast+1 != mb.first.seq {
// Detect if we need to encode a run length encoding here.
gap := mb.first.seq - prevLast - 1
if gap > rlThresh {
if gap := mb.first.seq - prevLast - 1; gap > rlThresh {
// Check if we have a running adm, if so write that out first, or if contigous update rle params.
if adm != nil && adm.Size() > 0 {
min, max := adm.MinMax()
if min, max, num := adm.State(); num > 0 {
// Check if we are all contingous.
if uint64(adm.Size()) == max-min+1 {
if num == max-min+1 {
prevLast, gap = min-1, mb.first.seq-min
} else {
dbs = append(dbs, adm)
Expand All @@ -6972,13 +6972,13 @@ func (fs *fileStore) deleteBlocks() DeleteBlocks {
}
}
}
if sz := mb.dmap.Size(); sz > 0 {
// Check in case the mb's dmap is contiguous.
min, max := mb.dmap.MinMax()
if uint64(sz) == max-min+1 {
if min, max, num := mb.dmap.State(); num > 0 {
// Check in case the mb's dmap is contiguous and over our threshold.
if num == max-min+1 && num > rlThresh {
// Need to write out adm if it exists.
if adm != nil && adm.Size() > 0 {
dbs = append(dbs, adm)
adm = nil
}
dbs = append(dbs, &DeleteRange{First: min, Num: max - min + 1})
} else {
Expand Down
38 changes: 38 additions & 0 deletions server/filestore_test.go
Expand Up @@ -30,6 +30,7 @@ import (
"os"
"path/filepath"
"reflect"
"strconv"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -5677,6 +5678,7 @@ func TestFileStoreInitialFirstSeq(t *testing.T) {
func TestFileStoreRecaluclateFirstForSubjBug(t *testing.T) {
fs, err := newFileStore(FileStoreConfig{StoreDir: t.TempDir()}, StreamConfig{Name: "zzz", Subjects: []string{"*"}, Storage: FileStorage})
require_NoError(t, err)
defer fs.Stop()

fs.StoreMsg("foo", nil, nil) // 1
fs.StoreMsg("bar", nil, nil) // 2
Expand Down Expand Up @@ -5705,3 +5707,39 @@ func TestFileStoreRecaluclateFirstForSubjBug(t *testing.T) {
// Make sure it was update properly.
require_True(t, *ss == SimpleState{Msgs: 1, First: 3, Last: 3, firstNeedsUpdate: false})
}

func TestFileStoreStreamEncoderDecoder(t *testing.T) {
fs, err := newFileStore(
FileStoreConfig{StoreDir: t.TempDir()},
StreamConfig{Name: "zzz", Subjects: []string{"*"}, MaxMsgsPer: 2, Storage: FileStorage},
)
require_NoError(t, err)
defer fs.Stop()

const seed = 2222222
prand := rand.New(rand.NewSource(seed))

tick := time.NewTicker(time.Second)
defer tick.Stop()
done := time.NewTimer(10 * time.Second)

msg := bytes.Repeat([]byte("ABC"), 33) // ~100bytes

for running := true; running; {
select {
case <-tick.C:
var state StreamState
fs.FastState(&state)
snap, err := fs.EncodedStreamState(0)
require_NoError(t, err)
ss, err := DecodeStreamState(snap)
require_True(t, len(ss.Deleted) > 0)
require_NoError(t, err)
case <-done.C:
running = false
default:
key := strconv.Itoa(prand.Intn(256_000))
fs.StoreMsg(key, nil, msg)
}
}
}
12 changes: 11 additions & 1 deletion server/jetstream_cluster.go
Expand Up @@ -2604,7 +2604,7 @@ func (mset *stream) isMigrating() bool {
return true
}

// resetClusteredState is called when a clustered stream had a sequence mismatch and needs to be reset.
// resetClusteredState is called when a clustered stream had an error (e.g sequence mismatch, bad snapshot) and needs to be reset.
func (mset *stream) resetClusteredState(err error) bool {
mset.mu.RLock()
s, js, jsa, sa, acc, node := mset.srv, mset.js, mset.jsa, mset.sa, mset.acc, mset.node
Expand Down Expand Up @@ -2858,16 +2858,26 @@ func (js *jetStream) applyStreamEntries(mset *stream, ce *CommittedEntry, isReco
// Everything operates on new replicated state. Will convert legacy snapshots to this for processing.
var ss *StreamReplicatedState

onBadState := func(err error) {
// If we are the leader or recovering, meaning we own the snapshot,
// we should stepdown and clear our raft state since our snapshot is bad.
if isRecovering || mset.IsLeader() {
mset.resetClusteredState(err)
}
}

// Check if we are the new binary encoding.
if IsEncodedStreamState(e.Data) {
var err error
ss, err = DecodeStreamState(e.Data)
if err != nil {
onBadState(err)
return err
}
} else {
var snap streamSnapshot
if err := json.Unmarshal(e.Data, &snap); err != nil {
onBadState(err)
return err
}
// Convert over to StreamReplicatedState
Expand Down
4 changes: 2 additions & 2 deletions server/norace_test.go
Expand Up @@ -8730,7 +8730,7 @@ func TestNoRaceBinaryStreamSnapshotEncodingBasic(t *testing.T) {
require_True(t, ss.LastSeq == 3000)
require_True(t, ss.Msgs == 1000)
// We should have collapsed all these into 2 delete blocks.
require_True(t, len(ss.Deleted) == 2)
require_True(t, len(ss.Deleted) <= 2)
require_True(t, ss.Deleted.NumDeleted() == 2000)
}

Expand Down Expand Up @@ -8765,7 +8765,7 @@ func TestNoRaceFilestoreBinaryStreamSnapshotEncodingLargeGaps(t *testing.T) {
require_True(t, ss.FirstSeq == 1)
require_True(t, ss.LastSeq == 20_000)
require_True(t, ss.Msgs == 2)
require_True(t, len(ss.Deleted) == 2)
require_True(t, len(ss.Deleted) <= 2)
require_True(t, ss.Deleted.NumDeleted() == 19_998)
}

Expand Down
17 changes: 16 additions & 1 deletion server/store.go
Expand Up @@ -63,6 +63,8 @@ var (
ErrInvalidSequence = errors.New("invalid sequence")
// ErrSequenceMismatch is returned when storing a raw message and the expected sequence is wrong.
ErrSequenceMismatch = errors.New("expected sequence does not match store")
// ErrCorruptStreamState
ErrCorruptStreamState = errors.New("stream state snapshot is corrupt")
)

// StoreMsg is the stored message format for messages that are retained by the Store layer.
Expand Down Expand Up @@ -237,20 +239,28 @@ func DecodeStreamState(buf []byte) (*StreamReplicatedState, error) {
return num
}

parserFailed := func() bool {
return bi < 0
}

ss.Msgs = readU64()
ss.Bytes = readU64()
ss.FirstSeq = readU64()
ss.LastSeq = readU64()
ss.Failed = readU64()

if parserFailed() {
return nil, ErrCorruptStreamState
}

if numDeleted := readU64(); numDeleted > 0 {
// If we have some deleted blocks.
for l := len(buf); l > bi; {
switch buf[bi] {
case seqSetMagic:
dmap, n, err := avl.Decode(buf[bi:])
if err != nil {
return nil, err
return nil, ErrCorruptStreamState
}
bi += n
ss.Deleted = append(ss.Deleted, dmap)
Expand All @@ -259,7 +269,12 @@ func DecodeStreamState(buf []byte) (*StreamReplicatedState, error) {
var rl DeleteRange
rl.First = readU64()
rl.Num = readU64()
if parserFailed() {
return nil, ErrCorruptStreamState
}
ss.Deleted = append(ss.Deleted, &rl)
default:
return nil, ErrCorruptStreamState
}
}
}
Expand Down