Skip to content

Commit

Permalink
api.Client: support isolated read-after-write (#12814)
Browse files Browse the repository at this point in the history
- add new configuration option, ReadYourWrites, which enables a Client
  to provide cluster replication states to every request. A curated set
  of cluster replication states are stored in the replicationStateStore,
  and is shared across clones.
  • Loading branch information
benashz committed Oct 14, 2021
1 parent 6b42f5d commit e24037f
Show file tree
Hide file tree
Showing 3 changed files with 451 additions and 17 deletions.
121 changes: 106 additions & 15 deletions api/client.go
Expand Up @@ -24,11 +24,12 @@ import (
retryablehttp "github.com/hashicorp/go-retryablehttp"
rootcerts "github.com/hashicorp/go-rootcerts"
"github.com/hashicorp/go-secure-stdlib/parseutil"
"golang.org/x/net/http2"
"golang.org/x/time/rate"

"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/strutil"
"github.com/hashicorp/vault/sdk/logical"
"golang.org/x/net/http2"
"golang.org/x/time/rate"
)

const (
Expand All @@ -49,6 +50,7 @@ const (
EnvVaultMFA = "VAULT_MFA"
EnvRateLimit = "VAULT_RATE_LIMIT"
EnvHTTPProxy = "VAULT_HTTP_PROXY"
HeaderIndex = "X-Vault-Index"
)

// Deprecated values
Expand Down Expand Up @@ -133,8 +135,18 @@ type Config struct {
// SRVLookup enables the client to lookup the host through DNS SRV lookup
SRVLookup bool

// CloneHeaders ensures that the source client's headers are copied to its clone.
// CloneHeaders ensures that the source client's headers are copied to
// its clone.
CloneHeaders bool

// ReadYourWrites ensures isolated read-after-write semantics by
// providing discovered cluster replication states in each request.
// The shared state is automatically propagated to all Client clones.
//
// Note: Careful consideration should be made prior to enabling this setting
// since there will be a performance penalty paid upon each request.
// This feature requires Enterprise server-side.
ReadYourWrites bool
}

// TLSConfig contains the parameters needed to configure TLS on the HTTP client
Expand Down Expand Up @@ -415,16 +427,17 @@ func parseRateLimit(val string) (rate float64, burst int, err error) {

// Client is the client to the Vault API. Create a client with NewClient.
type Client struct {
modifyLock sync.RWMutex
addr *url.URL
config *Config
token string
headers http.Header
wrappingLookupFunc WrappingLookupFunc
mfaCreds []string
policyOverride bool
requestCallbacks []RequestCallback
responseCallbacks []ResponseCallback
modifyLock sync.RWMutex
addr *url.URL
config *Config
token string
headers http.Header
wrappingLookupFunc WrappingLookupFunc
mfaCreds []string
policyOverride bool
requestCallbacks []RequestCallback
responseCallbacks []ResponseCallback
replicationStateStore *replicationStateStore
}

// NewClient returns a new client for the given configuration.
Expand Down Expand Up @@ -498,6 +511,10 @@ func NewClient(c *Config) (*Client, error) {
headers: make(http.Header),
}

if c.ReadYourWrites {
client.replicationStateStore = &replicationStateStore{}
}

// Add the VaultRequest SSRF protection header
client.headers[consts.RequestHeaderName] = []string{"true"}

Expand Down Expand Up @@ -530,6 +547,7 @@ func (c *Client) CloneConfig() *Config {
newConfig.OutputCurlString = c.config.OutputCurlString
newConfig.SRVLookup = c.config.SRVLookup
newConfig.CloneHeaders = c.config.CloneHeaders
newConfig.ReadYourWrites = c.config.ReadYourWrites

// we specifically want a _copy_ of the client here, not a pointer to the original one
newClient := *c.config.HttpClient
Expand Down Expand Up @@ -855,6 +873,32 @@ func (c *Client) CloneHeaders() bool {
return c.config.CloneHeaders
}

// SetReadYourWrites to prevent reading stale cluster replication state.
func (c *Client) SetReadYourWrites(preventStaleReads bool) {
c.modifyLock.Lock()
defer c.modifyLock.Unlock()
c.config.modifyLock.Lock()
defer c.config.modifyLock.Unlock()

if preventStaleReads && c.replicationStateStore == nil {
c.replicationStateStore = &replicationStateStore{}
} else {
c.replicationStateStore = nil
}

c.config.ReadYourWrites = preventStaleReads
}

// ReadYourWrites gets the configured value of ReadYourWrites
func (c *Client) ReadYourWrites() bool {
c.modifyLock.RLock()
defer c.modifyLock.RUnlock()
c.config.modifyLock.RLock()
defer c.config.modifyLock.RUnlock()

return c.config.ReadYourWrites
}

// Clone creates a new client with the same configuration. Note that the same
// underlying http.Client is used; modifying the client from more than one
// goroutine at once may not be safe, so modify the client as needed and then
Expand Down Expand Up @@ -886,6 +930,7 @@ func (c *Client) Clone() (*Client, error) {
AgentAddress: config.AgentAddress,
SRVLookup: config.SRVLookup,
CloneHeaders: config.CloneHeaders,
ReadYourWrites: config.ReadYourWrites,
}
client, err := NewClient(newConfig)
if err != nil {
Expand All @@ -896,6 +941,8 @@ func (c *Client) Clone() (*Client, error) {
client.SetHeaders(c.Headers().Clone())
}

client.replicationStateStore = c.replicationStateStore

return client, nil
}

Expand Down Expand Up @@ -1001,6 +1048,10 @@ func (c *Client) RawRequestWithContext(ctx context.Context, r *Request) (*Respon
cb(r)
}

if c.config.ReadYourWrites {
c.replicationStateStore.requireState(r)
}

if limiter != nil {
limiter.Wait(ctx)
}
Expand Down Expand Up @@ -1111,6 +1162,10 @@ START:
for _, cb := range c.responseCallbacks {
cb(result)
}

if c.config.ReadYourWrites {
c.replicationStateStore.recordState(result)
}
}
if err := result.Error(); err != nil {
return result, err
Expand Down Expand Up @@ -1152,7 +1207,7 @@ func (c *Client) WithResponseCallbacks(callbacks ...ResponseCallback) *Client {
// by Vault in a response header.
func RecordState(state *string) ResponseCallback {
return func(resp *Response) {
*state = resp.Header.Get("X-Vault-Index")
*state = resp.Header.Get(HeaderIndex)
}
}

Expand All @@ -1162,7 +1217,7 @@ func RecordState(state *string) ResponseCallback {
func RequireState(states ...string) RequestCallback {
return func(req *Request) {
for _, s := range states {
req.Headers.Add("X-Vault-Index", s)
req.Headers.Add(HeaderIndex, s)
}
}
}
Expand Down Expand Up @@ -1300,3 +1355,39 @@ func DefaultRetryPolicy(ctx context.Context, resp *http.Response, err error) (bo
}
return false, nil
}

// replicationStateStore is used to track cluster replication states
// in order to ensure proper read-after-write semantics for a Client.
type replicationStateStore struct {
m sync.RWMutex
store []string
}

// recordState updates the store's replication states with the merger of all
// states.
func (w *replicationStateStore) recordState(resp *Response) {
w.m.Lock()
defer w.m.Unlock()
newState := resp.Header.Get(HeaderIndex)
if newState != "" {
w.store = MergeReplicationStates(w.store, newState)
}
}

// requireState updates the Request with the store's current replication states.
func (w *replicationStateStore) requireState(req *Request) {
w.m.RLock()
defer w.m.RUnlock()
for _, s := range w.store {
req.Headers.Add(HeaderIndex, s)
}
}

// states currently stored.
func (w *replicationStateStore) states() []string {
w.m.RLock()
defer w.m.RUnlock()
c := make([]string, len(w.store))
copy(c, w.store)
return c
}

0 comments on commit e24037f

Please sign in to comment.