Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

api.Client: support isolated read-after-write #12814

Merged
merged 8 commits into from Oct 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
121 changes: 106 additions & 15 deletions api/client.go
Expand Up @@ -24,11 +24,12 @@ import (
retryablehttp "github.com/hashicorp/go-retryablehttp"
rootcerts "github.com/hashicorp/go-rootcerts"
"github.com/hashicorp/go-secure-stdlib/parseutil"
"golang.org/x/net/http2"
"golang.org/x/time/rate"

"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/strutil"
"github.com/hashicorp/vault/sdk/logical"
"golang.org/x/net/http2"
"golang.org/x/time/rate"
)

const (
Expand All @@ -49,6 +50,7 @@ const (
EnvVaultMFA = "VAULT_MFA"
EnvRateLimit = "VAULT_RATE_LIMIT"
EnvHTTPProxy = "VAULT_HTTP_PROXY"
HeaderIndex = "X-Vault-Index"
)

// Deprecated values
Expand Down Expand Up @@ -133,8 +135,18 @@ type Config struct {
// SRVLookup enables the client to lookup the host through DNS SRV lookup
SRVLookup bool

// CloneHeaders ensures that the source client's headers are copied to its clone.
// CloneHeaders ensures that the source client's headers are copied to
// its clone.
CloneHeaders bool

benashz marked this conversation as resolved.
Show resolved Hide resolved
// ReadYourWrites ensures isolated read-after-write semantics by
// providing discovered cluster replication states in each request.
// The shared state is automatically propagated to all Client clones.
//
// Note: Careful consideration should be made prior to enabling this setting
// since there will be a performance penalty paid upon each request.
// This feature requires Enterprise server-side.
ReadYourWrites bool
}

// TLSConfig contains the parameters needed to configure TLS on the HTTP client
Expand Down Expand Up @@ -415,16 +427,17 @@ func parseRateLimit(val string) (rate float64, burst int, err error) {

// Client is the client to the Vault API. Create a client with NewClient.
type Client struct {
modifyLock sync.RWMutex
addr *url.URL
config *Config
token string
headers http.Header
wrappingLookupFunc WrappingLookupFunc
mfaCreds []string
policyOverride bool
requestCallbacks []RequestCallback
responseCallbacks []ResponseCallback
modifyLock sync.RWMutex
addr *url.URL
config *Config
token string
headers http.Header
wrappingLookupFunc WrappingLookupFunc
mfaCreds []string
policyOverride bool
requestCallbacks []RequestCallback
responseCallbacks []ResponseCallback
replicationStateStore *replicationStateStore
}

// NewClient returns a new client for the given configuration.
Expand Down Expand Up @@ -498,6 +511,10 @@ func NewClient(c *Config) (*Client, error) {
headers: make(http.Header),
}

if c.ReadYourWrites {
client.replicationStateStore = &replicationStateStore{}
}

// Add the VaultRequest SSRF protection header
client.headers[consts.RequestHeaderName] = []string{"true"}

Expand Down Expand Up @@ -530,6 +547,7 @@ func (c *Client) CloneConfig() *Config {
newConfig.OutputCurlString = c.config.OutputCurlString
newConfig.SRVLookup = c.config.SRVLookup
newConfig.CloneHeaders = c.config.CloneHeaders
newConfig.ReadYourWrites = c.config.ReadYourWrites

// we specifically want a _copy_ of the client here, not a pointer to the original one
newClient := *c.config.HttpClient
Expand Down Expand Up @@ -855,6 +873,32 @@ func (c *Client) CloneHeaders() bool {
return c.config.CloneHeaders
}

// SetReadYourWrites to prevent reading stale cluster replication state.
func (c *Client) SetReadYourWrites(preventStaleReads bool) {
c.modifyLock.Lock()
defer c.modifyLock.Unlock()
c.config.modifyLock.Lock()
defer c.config.modifyLock.Unlock()

if preventStaleReads && c.replicationStateStore == nil {
c.replicationStateStore = &replicationStateStore{}
} else {
c.replicationStateStore = nil
}

c.config.ReadYourWrites = preventStaleReads
}

// ReadYourWrites gets the configured value of ReadYourWrites
func (c *Client) ReadYourWrites() bool {
c.modifyLock.RLock()
defer c.modifyLock.RUnlock()
c.config.modifyLock.RLock()
defer c.config.modifyLock.RUnlock()

return c.config.ReadYourWrites
}

// Clone creates a new client with the same configuration. Note that the same
// underlying http.Client is used; modifying the client from more than one
// goroutine at once may not be safe, so modify the client as needed and then
Expand Down Expand Up @@ -886,6 +930,7 @@ func (c *Client) Clone() (*Client, error) {
AgentAddress: config.AgentAddress,
SRVLookup: config.SRVLookup,
CloneHeaders: config.CloneHeaders,
ReadYourWrites: config.ReadYourWrites,
}
client, err := NewClient(newConfig)
if err != nil {
Expand All @@ -896,6 +941,8 @@ func (c *Client) Clone() (*Client, error) {
client.SetHeaders(c.Headers().Clone())
}

client.replicationStateStore = c.replicationStateStore

return client, nil
}

Expand Down Expand Up @@ -1001,6 +1048,10 @@ func (c *Client) RawRequestWithContext(ctx context.Context, r *Request) (*Respon
cb(r)
}

if c.config.ReadYourWrites {
c.replicationStateStore.requireState(r)
}

if limiter != nil {
limiter.Wait(ctx)
}
Expand Down Expand Up @@ -1111,6 +1162,10 @@ START:
for _, cb := range c.responseCallbacks {
cb(result)
}

if c.config.ReadYourWrites {
c.replicationStateStore.recordState(result)
}
}
if err := result.Error(); err != nil {
return result, err
Expand Down Expand Up @@ -1152,7 +1207,7 @@ func (c *Client) WithResponseCallbacks(callbacks ...ResponseCallback) *Client {
// by Vault in a response header.
func RecordState(state *string) ResponseCallback {
return func(resp *Response) {
*state = resp.Header.Get("X-Vault-Index")
*state = resp.Header.Get(HeaderIndex)
}
}

Expand All @@ -1162,7 +1217,7 @@ func RecordState(state *string) ResponseCallback {
func RequireState(states ...string) RequestCallback {
return func(req *Request) {
for _, s := range states {
req.Headers.Add("X-Vault-Index", s)
req.Headers.Add(HeaderIndex, s)
}
}
}
Expand Down Expand Up @@ -1300,3 +1355,39 @@ func DefaultRetryPolicy(ctx context.Context, resp *http.Response, err error) (bo
}
return false, nil
}

// replicationStateStore is used to track cluster replication states
// in order to ensure proper read-after-write semantics for a Client.
type replicationStateStore struct {
m sync.RWMutex
store []string
}

// recordState updates the store's replication states with the merger of all
// states.
func (w *replicationStateStore) recordState(resp *Response) {
w.m.Lock()
defer w.m.Unlock()
newState := resp.Header.Get(HeaderIndex)
if newState != "" {
w.store = MergeReplicationStates(w.store, newState)
}
}

// requireState updates the Request with the store's current replication states.
func (w *replicationStateStore) requireState(req *Request) {
w.m.RLock()
defer w.m.RUnlock()
for _, s := range w.store {
req.Headers.Add(HeaderIndex, s)
}
}

// states currently stored.
func (w *replicationStateStore) states() []string {
w.m.RLock()
defer w.m.RUnlock()
c := make([]string, len(w.store))
copy(c, w.store)
return c
}