From 768ead6020a5a94d05794f52f9a33d6546ade1fc Mon Sep 17 00:00:00 2001 From: Ben Ash Date: Wed, 13 Oct 2021 12:57:33 -0400 Subject: [PATCH] Post review updates - no longer rely in response/request callbacks for updating the replication state store - make setting up the store simpler by adding a configuration directive - address other comments. - revert go version changes go.mod --- api/client.go | 134 +++++++++++++++++++-------------------------- api/client_test.go | 46 ++++------------ api/go.mod | 2 +- api/go.sum | 1 + 4 files changed, 71 insertions(+), 112 deletions(-) diff --git a/api/client.go b/api/client.go index 2a53db835135b..39a2dc7065d59 100644 --- a/api/client.go +++ b/api/client.go @@ -138,9 +138,8 @@ type Config struct { // CloneHeaders ensures that the source client's headers are copied to its clone. CloneHeaders bool - // CloneReplicationStateStore ensures that the source client's ReplicationStateStore - // is registered in the clone. - CloneReplicationStateStore bool + // PreventStaleReads .....TODO update + PreventStaleReads bool } // TLSConfig contains the parameters needed to configure TLS on the HTTP client @@ -431,7 +430,7 @@ type Client struct { policyOverride bool requestCallbacks []RequestCallback responseCallbacks []ResponseCallback - replicationStateStore ReplicationStateStore + replicationStateStore *replicationStateStore } // NewClient returns a new client for the given configuration. @@ -505,6 +504,10 @@ func NewClient(c *Config) (*Client, error) { headers: make(http.Header), } + if c.PreventStaleReads { + client.replicationStateStore = &replicationStateStore{} + } + // Add the VaultRequest SSRF protection header client.headers[consts.RequestHeaderName] = []string{"true"} @@ -519,31 +522,6 @@ func NewClient(c *Config) (*Client, error) { return client, nil } -// RegisterReplicationStateStore for tracking replication states across all requests and responses. -// The ReplicationStateStore will be registered in both the request and response callback registries. -func (c *Client) RegisterReplicationStateStore(store ReplicationStateStore) error { - if c.replicationStateStore != nil { - return fmt.Errorf("replication state store already registered") - } - - c.modifyLock.Lock() - defer c.modifyLock.Unlock() - - c.replicationStateStore = store - - if len(c.requestCallbacks) == 0 { - c.requestCallbacks = []RequestCallback{} - } - c.requestCallbacks = append(c.requestCallbacks, c.replicationStateStore.HandleRequest) - - if len(c.responseCallbacks) == 0 { - c.responseCallbacks = []ResponseCallback{} - } - c.responseCallbacks = append(c.responseCallbacks, c.replicationStateStore.HandleResponse) - - return nil -} - func (c *Client) CloneConfig() *Config { c.modifyLock.RLock() defer c.modifyLock.RUnlock() @@ -562,7 +540,7 @@ func (c *Client) CloneConfig() *Config { newConfig.OutputCurlString = c.config.OutputCurlString newConfig.SRVLookup = c.config.SRVLookup newConfig.CloneHeaders = c.config.CloneHeaders - newConfig.CloneReplicationStateStore = c.config.CloneReplicationStateStore + newConfig.PreventStaleReads = c.config.PreventStaleReads // we specifically want a _copy_ of the client here, not a pointer to the original one newClient := *c.config.HttpClient @@ -888,24 +866,28 @@ func (c *Client) CloneHeaders() bool { return c.config.CloneHeaders } -// SetCloneReplicationStateStore to clone the client's ReplicationStateStore -func (c *Client) SetCloneReplicationStateStore(val bool) { +// SetPreventStalesReads TODO: update +func (c *Client) SetPreventStalesReads(val bool) { c.modifyLock.Lock() defer c.modifyLock.Unlock() c.config.modifyLock.Lock() defer c.config.modifyLock.Unlock() - c.config.CloneReplicationStateStore = val + if c.replicationStateStore == nil { + c.replicationStateStore = &replicationStateStore{} + } + + c.config.PreventStaleReads = val } -// CloneReplicationStateStore gets the configured value. -func (c *Client) CloneReplicationStateStore() bool { +// PreventStaleReads TODO: update +func (c *Client) PreventStaleReads() bool { c.modifyLock.RLock() defer c.modifyLock.RUnlock() c.config.modifyLock.RLock() defer c.config.modifyLock.RUnlock() - return c.config.CloneReplicationStateStore + return c.config.PreventStaleReads } // Clone creates a new client with the same configuration. Note that the same @@ -925,21 +907,21 @@ func (c *Client) Clone() (*Client, error) { defer config.modifyLock.RUnlock() newConfig := &Config{ - Address: config.Address, - HttpClient: config.HttpClient, - MinRetryWait: config.MinRetryWait, - MaxRetryWait: config.MaxRetryWait, - MaxRetries: config.MaxRetries, - Timeout: config.Timeout, - Backoff: config.Backoff, - CheckRetry: config.CheckRetry, - Logger: config.Logger, - Limiter: config.Limiter, - OutputCurlString: config.OutputCurlString, - AgentAddress: config.AgentAddress, - SRVLookup: config.SRVLookup, - CloneHeaders: config.CloneHeaders, - CloneReplicationStateStore: config.CloneReplicationStateStore, + Address: config.Address, + HttpClient: config.HttpClient, + MinRetryWait: config.MinRetryWait, + MaxRetryWait: config.MaxRetryWait, + MaxRetries: config.MaxRetries, + Timeout: config.Timeout, + Backoff: config.Backoff, + CheckRetry: config.CheckRetry, + Logger: config.Logger, + Limiter: config.Limiter, + OutputCurlString: config.OutputCurlString, + AgentAddress: config.AgentAddress, + SRVLookup: config.SRVLookup, + CloneHeaders: config.CloneHeaders, + PreventStaleReads: config.PreventStaleReads, } client, err := NewClient(newConfig) if err != nil { @@ -950,14 +932,8 @@ func (c *Client) Clone() (*Client, error) { client.SetHeaders(c.Headers().Clone()) } - if config.CloneReplicationStateStore { - if c.replicationStateStore != nil { - if err := client.RegisterReplicationStateStore(c.replicationStateStore); err != nil { - return nil, err - } - } else { - c.config.Logger.Warn("Parent has no ReplicationStateStore and CloneReplicationStateStore is specified") - } + if config.PreventStaleReads { + client.replicationStateStore = c.replicationStateStore } return client, nil @@ -1065,6 +1041,10 @@ func (c *Client) RawRequestWithContext(ctx context.Context, r *Request) (*Respon cb(r) } + if c.config.PreventStaleReads { + c.replicationStateStore.requireStates(r) + } + if limiter != nil { limiter.Wait(ctx) } @@ -1175,6 +1155,10 @@ START: for _, cb := range c.responseCallbacks { cb(result) } + + if c.config.PreventStaleReads { + c.replicationStateStore.recordState(result) + } } if err := result.Error(); err != nil { return result, err @@ -1365,45 +1349,41 @@ func DefaultRetryPolicy(ctx context.Context, resp *http.Response, err error) (bo return false, nil } -type ReplicationStateStore interface { - HandleResponse(resp *Response) - HandleRequest(resp *Request) - States() []string -} - -// SharedReplicationStateStore stores replication states by providing +// replicationStateStore stores replication states by providing // ResponseCallback and RequestCallback methods. // It can be used when Client Controlled Consistency (VLT-146) is required. // These methods should be registered in the Client's corresponding callback chains. -type SharedReplicationStateStore struct { - m sync.RWMutex - states []string +type replicationStateStore struct { + m sync.RWMutex + store []string } -// HandleResponse updates the store's replication states with the merger of all states. +// recordState updates the store's replication states with the merger of all states. // It should be registered in a Client's requestCallback chain. -func (w *SharedReplicationStateStore) HandleResponse(resp *Response) { +func (w *replicationStateStore) recordState(resp *Response) { w.m.Lock() defer w.m.Unlock() newState := resp.Header.Get(HeaderIndex) if newState != "" { - w.states = MergeReplicationStates(w.states, newState) + w.store = MergeReplicationStates(w.store, newState) } } -// HandleRequest updates the request with the store's replication states. +// requireStates updates the request with the store's replication states. // It should be registered in a Client's responseCallback chain. -func (w *SharedReplicationStateStore) HandleRequest(req *Request) { +func (w *replicationStateStore) requireStates(req *Request) { w.m.RLock() defer w.m.RUnlock() - for _, s := range w.states { + for _, s := range w.store { req.Headers.Add(HeaderIndex, s) } } // States currently known to the store. -func (w *SharedReplicationStateStore) States() []string { +func (w *replicationStateStore) states() []string { w.m.Lock() defer w.m.Unlock() - return w.states + c := make([]string, len(w.store)) + copy(c, w.store) + return c } diff --git a/api/client_test.go b/api/client_test.go index 696c2f536dfe3..cf5365ae9ae3b 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -744,20 +744,20 @@ func TestSharedReplicationStateStore_HandleResponse(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - w := &SharedReplicationStateStore{} + w := &replicationStateStore{} var wg sync.WaitGroup for _, r := range tt.resp { wg.Add(1) go func(r *Response) { defer wg.Done() - w.HandleResponse(r) + w.recordState(r) }(r) } wg.Wait() - if !reflect.DeepEqual(tt.expected, w.states) { - t.Errorf("HandleResponse(): expected states %v, actual %v", tt.expected, w.states) + if !reflect.DeepEqual(tt.expected, w.store) { + t.Errorf("recordState(): expected states %v, actual %v", tt.expected, w.store) } }) } @@ -799,8 +799,8 @@ func TestSharedReplicationStateStore_HandleRequest(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - store := &SharedReplicationStateStore{ - states: tt.states, + store := &replicationStateStore{ + store: tt.states, } start := make(chan interface{}) @@ -809,7 +809,7 @@ func TestSharedReplicationStateStore_HandleRequest(t *testing.T) { for _, r := range tt.req { go func(r *Request) { <-start - store.HandleRequest(r) + store.requireStates(r) done <- true }(r) } @@ -828,13 +828,13 @@ func TestSharedReplicationStateStore_HandleRequest(t *testing.T) { } sort.Strings(actual) if !reflect.DeepEqual(tt.expected, actual) { - t.Errorf("HandleRequest(): expected states %v, actual %v", tt.expected, actual) + t.Errorf("requireStates(): expected states %v, actual %v", tt.expected, actual) } }) } } -func TestClient_RegisterReplicationStateStore(t *testing.T) { +func TestClient_PreventDirtyReads(t *testing.T) { b64enc := func(s string) string { return base64.StdEncoding.EncodeToString([]byte(s)) } @@ -920,24 +920,6 @@ func TestClient_RegisterReplicationStateStore(t *testing.T) { }, }, }, - { - name: "multiple_duplicates", - clone: false, - handler: handler, - wantStates: []string{ - b64enc("v1:cid:0:4:"), - }, - values: [][]string{ - { - b64enc("v1:cid:0:4:"), - b64enc("v1:cid:0:2:"), - }, - { - b64enc("v1:cid:0:4:"), - b64enc("v1:cid:0:2:"), - }, - }, - }, } for _, tt := range tests { @@ -960,17 +942,13 @@ func TestClient_RegisterReplicationStateStore(t *testing.T) { config, ln := testHTTPServer(t, handler) defer ln.Close() + config.PreventStaleReads = true config.Address = fmt.Sprintf("http://%s", ln.Addr()) parent, err := NewClient(config) if err != nil { t.Fatal(err) } - parent.SetCloneReplicationStateStore(true) - if err := parent.RegisterReplicationStateStore(&SharedReplicationStateStore{}); err != nil { - t.Fatal(err) - } - start := make(chan interface{}) done := make(chan interface{}) @@ -1001,8 +979,8 @@ func TestClient_RegisterReplicationStateStore(t *testing.T) { <-done } - if !reflect.DeepEqual(tt.wantStates, parent.replicationStateStore.States()) { - t.Errorf("expected states %v, actual %v", tt.wantStates, parent.replicationStateStore.States()) + if !reflect.DeepEqual(tt.wantStates, parent.replicationStateStore.states()) { + t.Errorf("expected states %v, actual %v", tt.wantStates, parent.replicationStateStore.states()) } }) } diff --git a/api/go.mod b/api/go.mod index 560543c7f9e56..e3b1824f59654 100644 --- a/api/go.mod +++ b/api/go.mod @@ -1,6 +1,6 @@ module github.com/hashicorp/vault/api -go 1.16 +go 1.13 replace github.com/hashicorp/vault/sdk => ../sdk diff --git a/api/go.sum b/api/go.sum index 0a801b892fd3a..95f434a5f44f4 100644 --- a/api/go.sum +++ b/api/go.sum @@ -280,6 +280,7 @@ google.golang.org/protobuf v1.25.0 h1:Ejskq+SyPohKW+1uil0JJMtmHCgJPJ/qWTxr8qp+R4 google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/square/go-jose.v2 v2.5.1 h1:7odma5RETjNHWJnR32wx8t+Io4djHE1PqxCFx3iiZ2w= gopkg.in/square/go-jose.v2 v2.5.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=