Skip to content

Commit

Permalink
Post review updates
Browse files Browse the repository at this point in the history
- no longer rely in response/request callbacks for updating the
  replication state store
- make setting up the store simpler by adding a configuration directive
- address other comments.
- revert go version changes go.mod
  • Loading branch information
benashz committed Oct 13, 2021
1 parent 46fe88f commit 8a6dc86
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 121 deletions.
146 changes: 62 additions & 84 deletions api/client.go
Expand Up @@ -138,9 +138,10 @@ type Config struct {
// CloneHeaders ensures that the source client's headers are copied to its clone.
CloneHeaders bool

// CloneReplicationStateStore ensures that the source client's ReplicationStateStore
// is registered in the clone.
CloneReplicationStateStore bool
// PreventStaleReads enables the Client to require discovered cluster replication states
// in every request.
// The shared state is automatically propagated to all Client clones.
PreventStaleReads bool
}

// TLSConfig contains the parameters needed to configure TLS on the HTTP client
Expand Down Expand Up @@ -431,7 +432,7 @@ type Client struct {
policyOverride bool
requestCallbacks []RequestCallback
responseCallbacks []ResponseCallback
replicationStateStore ReplicationStateStore
replicationStateStore *replicationStateStore
}

// NewClient returns a new client for the given configuration.
Expand Down Expand Up @@ -505,6 +506,10 @@ func NewClient(c *Config) (*Client, error) {
headers: make(http.Header),
}

if c.PreventStaleReads {
client.replicationStateStore = &replicationStateStore{}
}

// Add the VaultRequest SSRF protection header
client.headers[consts.RequestHeaderName] = []string{"true"}

Expand All @@ -519,31 +524,6 @@ func NewClient(c *Config) (*Client, error) {
return client, nil
}

// RegisterReplicationStateStore for tracking replication states across all requests and responses.
// The ReplicationStateStore will be registered in both the request and response callback registries.
func (c *Client) RegisterReplicationStateStore(store ReplicationStateStore) error {
if c.replicationStateStore != nil {
return fmt.Errorf("replication state store already registered")
}

c.modifyLock.Lock()
defer c.modifyLock.Unlock()

c.replicationStateStore = store

if len(c.requestCallbacks) == 0 {
c.requestCallbacks = []RequestCallback{}
}
c.requestCallbacks = append(c.requestCallbacks, c.replicationStateStore.HandleRequest)

if len(c.responseCallbacks) == 0 {
c.responseCallbacks = []ResponseCallback{}
}
c.responseCallbacks = append(c.responseCallbacks, c.replicationStateStore.HandleResponse)

return nil
}

func (c *Client) CloneConfig() *Config {
c.modifyLock.RLock()
defer c.modifyLock.RUnlock()
Expand All @@ -562,7 +542,7 @@ func (c *Client) CloneConfig() *Config {
newConfig.OutputCurlString = c.config.OutputCurlString
newConfig.SRVLookup = c.config.SRVLookup
newConfig.CloneHeaders = c.config.CloneHeaders
newConfig.CloneReplicationStateStore = c.config.CloneReplicationStateStore
newConfig.PreventStaleReads = c.config.PreventStaleReads

// we specifically want a _copy_ of the client here, not a pointer to the original one
newClient := *c.config.HttpClient
Expand Down Expand Up @@ -888,24 +868,30 @@ func (c *Client) CloneHeaders() bool {
return c.config.CloneHeaders
}

// SetCloneReplicationStateStore to clone the client's ReplicationStateStore
func (c *Client) SetCloneReplicationStateStore(val bool) {
// SetPreventStaleReads to prevent reading stale cluster replication state.
func (c *Client) SetPreventStaleReads(preventStaleReads bool) {
c.modifyLock.Lock()
defer c.modifyLock.Unlock()
c.config.modifyLock.Lock()
defer c.config.modifyLock.Unlock()

c.config.CloneReplicationStateStore = val
if preventStaleReads && c.replicationStateStore == nil {
c.replicationStateStore = &replicationStateStore{}
} else {
c.replicationStateStore = nil
}

c.config.PreventStaleReads = preventStaleReads
}

// CloneReplicationStateStore gets the configured value.
func (c *Client) CloneReplicationStateStore() bool {
// PreventStaleReads gets the configured value of PreventStaleReads
func (c *Client) PreventStaleReads() bool {
c.modifyLock.RLock()
defer c.modifyLock.RUnlock()
c.config.modifyLock.RLock()
defer c.config.modifyLock.RUnlock()

return c.config.CloneReplicationStateStore
return c.config.PreventStaleReads
}

// Clone creates a new client with the same configuration. Note that the same
Expand All @@ -925,21 +911,21 @@ func (c *Client) Clone() (*Client, error) {
defer config.modifyLock.RUnlock()

newConfig := &Config{
Address: config.Address,
HttpClient: config.HttpClient,
MinRetryWait: config.MinRetryWait,
MaxRetryWait: config.MaxRetryWait,
MaxRetries: config.MaxRetries,
Timeout: config.Timeout,
Backoff: config.Backoff,
CheckRetry: config.CheckRetry,
Logger: config.Logger,
Limiter: config.Limiter,
OutputCurlString: config.OutputCurlString,
AgentAddress: config.AgentAddress,
SRVLookup: config.SRVLookup,
CloneHeaders: config.CloneHeaders,
CloneReplicationStateStore: config.CloneReplicationStateStore,
Address: config.Address,
HttpClient: config.HttpClient,
MinRetryWait: config.MinRetryWait,
MaxRetryWait: config.MaxRetryWait,
MaxRetries: config.MaxRetries,
Timeout: config.Timeout,
Backoff: config.Backoff,
CheckRetry: config.CheckRetry,
Logger: config.Logger,
Limiter: config.Limiter,
OutputCurlString: config.OutputCurlString,
AgentAddress: config.AgentAddress,
SRVLookup: config.SRVLookup,
CloneHeaders: config.CloneHeaders,
PreventStaleReads: config.PreventStaleReads,
}
client, err := NewClient(newConfig)
if err != nil {
Expand All @@ -950,15 +936,7 @@ func (c *Client) Clone() (*Client, error) {
client.SetHeaders(c.Headers().Clone())
}

if config.CloneReplicationStateStore {
if c.replicationStateStore != nil {
if err := client.RegisterReplicationStateStore(c.replicationStateStore); err != nil {
return nil, err
}
} else {
c.config.Logger.Warn("Parent has no ReplicationStateStore and CloneReplicationStateStore is specified")
}
}
client.SetPreventStaleReads(config.PreventStaleReads)

return client, nil
}
Expand Down Expand Up @@ -1065,6 +1043,10 @@ func (c *Client) RawRequestWithContext(ctx context.Context, r *Request) (*Respon
cb(r)
}

if c.config.PreventStaleReads {
c.replicationStateStore.requireStates(r)
}

if limiter != nil {
limiter.Wait(ctx)
}
Expand Down Expand Up @@ -1175,6 +1157,10 @@ START:
for _, cb := range c.responseCallbacks {
cb(result)
}

if c.config.PreventStaleReads {
c.replicationStateStore.recordStates(result)
}
}
if err := result.Error(); err != nil {
return result, err
Expand Down Expand Up @@ -1365,45 +1351,37 @@ func DefaultRetryPolicy(ctx context.Context, resp *http.Response, err error) (bo
return false, nil
}

type ReplicationStateStore interface {
HandleResponse(resp *Response)
HandleRequest(resp *Request)
States() []string
}

// SharedReplicationStateStore stores replication states by providing
// ResponseCallback and RequestCallback methods.
// It can be used when Client Controlled Consistency (VLT-146) is required.
// These methods should be registered in the Client's corresponding callback chains.
type SharedReplicationStateStore struct {
m sync.RWMutex
states []string
// replicationStateStore is used to track cluster replication states
// in order to prevent stale reads.
type replicationStateStore struct {
m sync.RWMutex
store []string
}

// HandleResponse updates the store's replication states with the merger of all states.
// It should be registered in a Client's requestCallback chain.
func (w *SharedReplicationStateStore) HandleResponse(resp *Response) {
// recordStates updates the store's replication states with the merger of all states.
func (w *replicationStateStore) recordStates(resp *Response) {
w.m.Lock()
defer w.m.Unlock()
newState := resp.Header.Get(HeaderIndex)
if newState != "" {
w.states = MergeReplicationStates(w.states, newState)
w.store = MergeReplicationStates(w.store, newState)
}
}

// HandleRequest updates the request with the store's replication states.
// It should be registered in a Client's responseCallback chain.
func (w *SharedReplicationStateStore) HandleRequest(req *Request) {
// requireStates updates the Request with the store's current replication states.
func (w *replicationStateStore) requireStates(req *Request) {
w.m.RLock()
defer w.m.RUnlock()
for _, s := range w.states {
for _, s := range w.store {
req.Headers.Add(HeaderIndex, s)
}
}

// States currently known to the store.
func (w *SharedReplicationStateStore) States() []string {
// states currently stored.
func (w *replicationStateStore) states() []string {
w.m.Lock()
defer w.m.Unlock()
return w.states
c := make([]string, len(w.store))
copy(c, w.store)
return c
}
50 changes: 14 additions & 36 deletions api/client_test.go
Expand Up @@ -649,7 +649,7 @@ func TestMergeReplicationStates(t *testing.T) {
}
}

func TestSharedReplicationStateStore_HandleResponse(t *testing.T) {
func TestReplicationStateStore_recordState(t *testing.T) {
b64enc := func(s string) string {
return base64.StdEncoding.EncodeToString([]byte(s))
}
Expand Down Expand Up @@ -744,26 +744,26 @@ func TestSharedReplicationStateStore_HandleResponse(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := &SharedReplicationStateStore{}
w := &replicationStateStore{}

var wg sync.WaitGroup
for _, r := range tt.resp {
wg.Add(1)
go func(r *Response) {
defer wg.Done()
w.HandleResponse(r)
w.recordStates(r)
}(r)
}
wg.Wait()

if !reflect.DeepEqual(tt.expected, w.states) {
t.Errorf("HandleResponse(): expected states %v, actual %v", tt.expected, w.states)
if !reflect.DeepEqual(tt.expected, w.store) {
t.Errorf("recordStates(): expected states %v, actual %v", tt.expected, w.store)
}
})
}
}

func TestSharedReplicationStateStore_HandleRequest(t *testing.T) {
func TestReplicationStateStore_requireStates(t *testing.T) {
tests := []struct {
name string
states []string
Expand Down Expand Up @@ -799,8 +799,8 @@ func TestSharedReplicationStateStore_HandleRequest(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
store := &SharedReplicationStateStore{
states: tt.states,
store := &replicationStateStore{
store: tt.states,
}

start := make(chan interface{})
Expand All @@ -809,7 +809,7 @@ func TestSharedReplicationStateStore_HandleRequest(t *testing.T) {
for _, r := range tt.req {
go func(r *Request) {
<-start
store.HandleRequest(r)
store.requireStates(r)
done <- true
}(r)
}
Expand All @@ -828,13 +828,13 @@ func TestSharedReplicationStateStore_HandleRequest(t *testing.T) {
}
sort.Strings(actual)
if !reflect.DeepEqual(tt.expected, actual) {
t.Errorf("HandleRequest(): expected states %v, actual %v", tt.expected, actual)
t.Errorf("requireStates(): expected states %v, actual %v", tt.expected, actual)
}
})
}
}

func TestClient_RegisterReplicationStateStore(t *testing.T) {
func TestClient_PreventDirtyReads(t *testing.T) {
b64enc := func(s string) string {
return base64.StdEncoding.EncodeToString([]byte(s))
}
Expand Down Expand Up @@ -920,24 +920,6 @@ func TestClient_RegisterReplicationStateStore(t *testing.T) {
},
},
},
{
name: "multiple_duplicates",
clone: false,
handler: handler,
wantStates: []string{
b64enc("v1:cid:0:4:"),
},
values: [][]string{
{
b64enc("v1:cid:0:4:"),
b64enc("v1:cid:0:2:"),
},
{
b64enc("v1:cid:0:4:"),
b64enc("v1:cid:0:2:"),
},
},
},
}

for _, tt := range tests {
Expand All @@ -960,17 +942,13 @@ func TestClient_RegisterReplicationStateStore(t *testing.T) {
config, ln := testHTTPServer(t, handler)
defer ln.Close()

config.PreventStaleReads = true
config.Address = fmt.Sprintf("http://%s", ln.Addr())
parent, err := NewClient(config)
if err != nil {
t.Fatal(err)
}

parent.SetCloneReplicationStateStore(true)
if err := parent.RegisterReplicationStateStore(&SharedReplicationStateStore{}); err != nil {
t.Fatal(err)
}

start := make(chan interface{})
done := make(chan interface{})

Expand Down Expand Up @@ -1001,8 +979,8 @@ func TestClient_RegisterReplicationStateStore(t *testing.T) {
<-done
}

if !reflect.DeepEqual(tt.wantStates, parent.replicationStateStore.States()) {
t.Errorf("expected states %v, actual %v", tt.wantStates, parent.replicationStateStore.States())
if !reflect.DeepEqual(tt.wantStates, parent.replicationStateStore.states()) {
t.Errorf("expected states %v, actual %v", tt.wantStates, parent.replicationStateStore.states())
}
})
}
Expand Down
2 changes: 1 addition & 1 deletion api/go.mod
@@ -1,6 +1,6 @@
module github.com/hashicorp/vault/api

go 1.16
go 1.13

replace github.com/hashicorp/vault/sdk => ../sdk

Expand Down
1 change: 1 addition & 0 deletions api/go.sum
Expand Up @@ -280,6 +280,7 @@ google.golang.org/protobuf v1.25.0 h1:Ejskq+SyPohKW+1uil0JJMtmHCgJPJ/qWTxr8qp+R4
google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/square/go-jose.v2 v2.5.1 h1:7odma5RETjNHWJnR32wx8t+Io4djHE1PqxCFx3iiZ2w=
gopkg.in/square/go-jose.v2 v2.5.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
Expand Down
Empty file added changelog/12814.txt
Empty file.

0 comments on commit 8a6dc86

Please sign in to comment.