Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update api.Client for isolated read-after-write support #1188

Merged
merged 12 commits into from
Oct 14, 2021
112 changes: 102 additions & 10 deletions vault/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,19 @@ func setupServer(t *testing.T, handler http.Handler) (*ClientFactory, *http.Serv
return w, server
}

func validateResponseHeader(client *api.Client, headerValue string) error {
resp, err := sendRequest(client, headerValue)
vinay-gopalan marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return err
}
// validate that the server provided a valid header value in its response
actual := resp.Header.Get(indexHeaderName)
if actual != headerValue {
return fmt.Errorf("expected header value %v, actual %v", headerValue, actual)
}
return nil
}

func TestClientFactory_Client(t *testing.T) {
b64enc := func(s string) string {
return base64.StdEncoding.EncodeToString([]byte(s))
Expand Down Expand Up @@ -289,16 +302,7 @@ func TestClientFactory_Client(t *testing.T) {
wg.Add(1)
go func(expected string) {
defer wg.Done()

resp, err := sendRequest(client, expected)
if err != nil {
t.Fatal(err)
}
// validate that the server provided a valid header value in its response
actual := resp.Header.Get(indexHeaderName)
if actual != expected {
t.Errorf("expected header value %v, actual %v", expected, actual)
}
validateResponseHeader(client, expected)
vinay-gopalan marked this conversation as resolved.
Show resolved Hide resolved
vinay-gopalan marked this conversation as resolved.
Show resolved Hide resolved
}(expected)
}
wg.Wait()
Expand All @@ -309,3 +313,91 @@ func TestClientFactory_Client(t *testing.T) {
})
}
}

func TestClientFactory_Clone(t *testing.T) {
b64enc := func(s string) string {
return base64.StdEncoding.EncodeToString([]byte(s))
}

handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set(indexHeaderName, strings.TrimLeft(req.URL.Path, "/"))
})

tests := []struct {
name string
handler http.Handler
wantStates []string
states []struct {
h1 string
h2 string
}
}{
{
name: "basic",
handler: handler,
wantStates: []string{
b64enc("v1:cid:0:1:"),
b64enc("v1:cid:1:0:"),
},
states: []struct {
h1 string
h2 string
}{
{
h1: b64enc("v1:cid:0:1:"),
h2: b64enc("v1:cid:1:0:"),
},
},
},
{
name: "multiple",
handler: handler,
wantStates: []string{
b64enc("v1:cid:0:2:"),
b64enc("v1:cid:2:0:"),
},
states: []struct {
h1 string
h2 string
}{
{
h1: b64enc("v1:cid:0:1:"),
h2: b64enc("v1:cid:1:0:"),
},
{
h1: b64enc("v1:cid:0:2:"),
h2: b64enc("v1:cid:2:0:"),
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

w, server := setupServer(t, tt.handler)
defer server.Close()

c1 := w.Client()
c2, err := w.Clone()
if err != nil {
t.Fatal(err)
}

var wg sync.WaitGroup
for _, headers := range tt.states {
wg.Add(1)
go func(headerVal string) {
defer wg.Done()
validateResponseHeader(c1, headerVal)
vinay-gopalan marked this conversation as resolved.
Show resolved Hide resolved

}(headers.h1)
validateResponseHeader(c2, headers.h2)
}
wg.Wait()

if !reflect.DeepEqual(tt.wantStates, w.states) {
t.Errorf("RawRequestWithContext(): expected states %v, actual %v", tt.wantStates, w.states)
}
})
}
}