Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add code to api.RaftSnapshot to detect incomplete snapshots #12388

Merged
merged 3 commits into from Sep 7, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
61 changes: 57 additions & 4 deletions api/sys_raft.go
@@ -1,21 +1,25 @@
package api

import (
"archive/tar"
"compress/gzip"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"sync"
"time"

"github.com/hashicorp/go-secure-stdlib/parseutil"

"github.com/mitchellh/mapstructure"

"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/mitchellh/mapstructure"
)

var ErrIncompleteSnapshot = errors.New("incomplete snapshot, unable to read SHA256SUMS.sealed file")

// RaftJoinResponse represents the response of the raft join API
type RaftJoinResponse struct {
Joined bool `json:"joined"`
Expand Down Expand Up @@ -210,11 +214,60 @@ func (c *Sys) RaftSnapshot(snapWriter io.Writer) error {
return err
}

_, err = io.Copy(snapWriter, resp.Body)
// Make sure that the last file in the archive, SHA256SUMS.sealed, is present
// and non-empty. This is to catch cases where the snapshot failed midstream,
// e.g. due to a problem with the seal that prevented encryption of that file.
var wg sync.WaitGroup
wg.Add(1)
var verified bool

rPipe, wPipe := io.Pipe()
dup := io.TeeReader(resp.Body, wPipe)
go func() {
defer func() {
io.Copy(ioutil.Discard, rPipe)
rPipe.Close()
wg.Done()
}()

uncompressed, err := gzip.NewReader(rPipe)
if err != nil {
return
ncabatoff marked this conversation as resolved.
Show resolved Hide resolved
}

t := tar.NewReader(uncompressed)
var h *tar.Header
for {
h, err = t.Next()
if err != nil {
return
}
if h.Name != "SHA256SUMS.sealed" {
continue
}
var b []byte
b, err = ioutil.ReadAll(t)
if err != nil || len(b) == 0 {
return
}
verified = true
return
}
}()

// Copy bytes from dup to snapWriter. This will have a side effect that
// everything read from dup will be written to wPipe.
_, err = io.Copy(snapWriter, dup)
wPipe.Close()
if err != nil {
rPipe.CloseWithError(err)
return err
}
wg.Wait()

if !verified {
return ErrIncompleteSnapshot
}
return nil
}

Expand Down
3 changes: 3 additions & 0 deletions changelog/12388.txt
@@ -0,0 +1,3 @@
```release-note:bug
storage/raft: Detect incomplete raft snapshots in api.RaftSnapshot(), and thereby in `vault operator raft snapshot save`.
```
61 changes: 61 additions & 0 deletions vault/external_tests/raft/raft_test.go
Expand Up @@ -4,14 +4,19 @@ import (
"bytes"
"context"
"crypto/md5"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"strings"
"sync"
"sync/atomic"
"testing"
"time"

vaultseal "github.com/hashicorp/vault/vault/seal"

"github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/api"
Expand All @@ -36,6 +41,8 @@ type RaftClusterOpts struct {
PhysicalFactoryConfig map[string]interface{}
DisablePerfStandby bool
EnableResponseHeaderRaftNodeID bool
NumCores int
Seal vault.Seal
}

func raftCluster(t testing.TB, ropts *RaftClusterOpts) *vault.TestCluster {
Expand All @@ -49,6 +56,7 @@ func raftCluster(t testing.TB, ropts *RaftClusterOpts) *vault.TestCluster {
},
DisableAutopilot: !ropts.EnableAutopilot,
EnableResponseHeaderRaftNodeID: ropts.EnableResponseHeaderRaftNodeID,
Seal: ropts.Seal,
}

opts := vault.TestClusterOptions{
Expand All @@ -57,6 +65,7 @@ func raftCluster(t testing.TB, ropts *RaftClusterOpts) *vault.TestCluster {
opts.InmemClusterLayers = ropts.InmemCluster
opts.PhysicalFactoryConfig = ropts.PhysicalFactoryConfig
conf.DisablePerformanceStandby = ropts.DisablePerfStandby
opts.NumCores = ropts.NumCores

teststorage.RaftBackendSetup(conf, &opts)

Expand Down Expand Up @@ -542,6 +551,58 @@ func TestRaft_SnapshotAPI(t *testing.T) {
}
}

func TestRaft_SnapshotAPI_MidstreamFailure(t *testing.T) {
// defer goleak.VerifyNone(t)
t.Parallel()

seal, errptr := vaultseal.NewToggleableTestSeal(nil)
cluster := raftCluster(t, &RaftClusterOpts{
NumCores: 1,
Seal: vault.NewAutoSeal(seal),
})
defer cluster.Cleanup()

leaderClient := cluster.Cores[0].Client

// Write a bunch of keys; if too few, the detection code in api.RaftSnapshot
// will never make it into the tar part, it'll fail merely when trying to
// decompress the stream.
for i := 0; i < 1000; i++ {
_, err := leaderClient.Logical().Write(fmt.Sprintf("secret/%d", i), map[string]interface{}{
"test": "data",
})
if err != nil {
t.Fatal(err)
}
}

r, w := io.Pipe()
var snap []byte
var wg sync.WaitGroup
wg.Add(1)

var readErr error
go func() {
snap, readErr = ioutil.ReadAll(r)
wg.Done()
}()

*errptr = errors.New("seal failure")
// Take a snapshot
err := leaderClient.Sys().RaftSnapshot(w)
w.Close()
if err == nil || err != api.ErrIncompleteSnapshot {
t.Fatalf("expected err=%v, got: %v", api.ErrIncompleteSnapshot, err)
}
wg.Wait()
if len(snap) == 0 && readErr == nil {
readErr = errors.New("no bytes read")
}
if readErr != nil {
t.Fatal(readErr)
}
}

func TestRaft_SnapshotAPI_RekeyRotate_Backward(t *testing.T) {
type testCase struct {
Name string
Expand Down
35 changes: 35 additions & 0 deletions vault/seal/seal_testing.go
@@ -1,6 +1,8 @@
package seal

import (
"context"

"github.com/hashicorp/go-hclog"
wrapping "github.com/hashicorp/go-kms-wrapping"
)
Expand All @@ -22,3 +24,36 @@ func NewTestSeal(opts *TestSealOpts) *Access {
OverriddenType: opts.Name,
}
}

func NewToggleableTestSeal(opts *TestSealOpts) (*Access, *error) {
if opts == nil {
opts = new(TestSealOpts)
}

w := &ToggleableWrapper{Wrapper: wrapping.NewTestWrapper(opts.Secret)}
return &Access{
Wrapper: w,
OverriddenType: opts.Name,
}, &w.Error
}

type ToggleableWrapper struct {
wrapping.Wrapper
Error error
}

func (t ToggleableWrapper) Encrypt(ctx context.Context, bytes []byte, bytes2 []byte) (*wrapping.EncryptedBlobInfo, error) {
if t.Error != nil {
return nil, t.Error
}
return t.Wrapper.Encrypt(ctx, bytes, bytes2)
}

func (t ToggleableWrapper) Decrypt(ctx context.Context, info *wrapping.EncryptedBlobInfo, bytes []byte) ([]byte, error) {
if t.Error != nil {
return nil, t.Error
}
return t.Wrapper.Decrypt(ctx, info, bytes)
}

var _ wrapping.Wrapper = &ToggleableWrapper{}