diff --git a/HCPV_badge.png b/HCPV_badge.png new file mode 100644 index 0000000000000..243dc737dcd43 Binary files /dev/null and b/HCPV_badge.png differ diff --git a/api/api_test.go b/api/api_test.go index b2b851df6e89f..e4ba3153203eb 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -9,8 +9,7 @@ import ( // testHTTPServer creates a test HTTP server that handles requests until // the listener returned is closed. -func testHTTPServer( - t *testing.T, handler http.Handler) (*Config, net.Listener) { +func testHTTPServer(t *testing.T, handler http.Handler) (*Config, net.Listener) { ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %s", err) diff --git a/api/logical.go b/api/logical.go index cd950a2b78968..f8f8bc5376612 100644 --- a/api/logical.go +++ b/api/logical.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "io" + "net/http" "net/url" "os" @@ -130,24 +131,37 @@ func (c *Logical) List(path string) (*Secret, error) { } func (c *Logical) Write(path string, data map[string]interface{}) (*Secret, error) { + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + r := c.c.NewRequest("PUT", "/v1/"+path) if err := r.SetJSONBody(data); err != nil { return nil, err } - return c.write(path, r) + return c.write(ctx, path, r) +} + +func (c *Logical) JSONMergePatch(ctx context.Context, path string, data map[string]interface{}) (*Secret, error) { + r := c.c.NewRequest("PATCH", "/v1/"+path) + r.Headers = http.Header{ + "Content-Type": []string{"application/merge-patch+json"}, + } + if err := r.SetJSONBody(data); err != nil { + return nil, err + } + + return c.write(ctx, path, r) } func (c *Logical) WriteBytes(path string, data []byte) (*Secret, error) { r := c.c.NewRequest("PUT", "/v1/"+path) r.BodyBytes = data - return c.write(path, r) + return c.write(context.Background(), path, r) } -func (c *Logical) write(path string, request *Request) (*Secret, error) { - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() +func (c *Logical) write(ctx context.Context, path string, request *Request) (*Secret, error) { resp, err := c.c.RawRequestWithContext(ctx, request) if resp != nil { defer resp.Body.Close() diff --git a/changelog/12485.txt b/changelog/12485.txt new file mode 100644 index 0000000000000..6c8a87cd21cac --- /dev/null +++ b/changelog/12485.txt @@ -0,0 +1,3 @@ +```release-note:feature +**Customizable HTTP Headers**: Add support to define custom HTTP headers for root path (`/`) and also on API endpoints (`/v1/*`) +``` diff --git a/changelog/12668.txt b/changelog/12668.txt new file mode 100644 index 0000000000000..7006da572c44a --- /dev/null +++ b/changelog/12668.txt @@ -0,0 +1,3 @@ +```release-note:improvement +sdk/framework: The '+' wildcard is now supported for parameterizing unauthenticated paths. +``` diff --git a/changelog/12687.txt b/changelog/12687.txt new file mode 100644 index 0000000000000..f5998deeda706 --- /dev/null +++ b/changelog/12687.txt @@ -0,0 +1,5 @@ +```release-note:feature +**KV patch**: Add partial update support the for the `//data/:path` kv-v2 +endpoint through HTTP `PATCH`. A new `patch` ACL capability has been added and +is required to make such requests. +``` diff --git a/changelog/12800.txt b/changelog/12800.txt new file mode 100644 index 0000000000000..38aadc0170803 --- /dev/null +++ b/changelog/12800.txt @@ -0,0 +1,3 @@ +```release-note:feature +**OIDC Authorization Code Flow**: The Vault UI now supports OIDC Authorization Code Flow +``` diff --git a/command/agent/config/config.go b/command/agent/config/config.go index 9438bd327444d..502d512d15a24 100644 --- a/command/agent/config/config.go +++ b/command/agent/config/config.go @@ -35,6 +35,7 @@ func (c *Config) Prune() { l.RawConfig = nil l.Profiling.UnusedKeys = nil l.Telemetry.UnusedKeys = nil + l.CustomResponseHeaders = nil } c.FoundKeys = nil c.UnusedKeys = nil @@ -172,6 +173,12 @@ func LoadConfig(path string) (*Config, error) { if err != nil { return nil, err } + + // Pruning custom headers for Agent for now + for _, ln := range sharedConfig.Listeners { + ln.CustomResponseHeaders = nil + } + result.SharedConfig = sharedConfig list, ok := obj.Node.(*ast.ObjectList) diff --git a/command/agent/config/config_test.go b/command/agent/config/config_test.go index 0db8cf91954f7..252461236c89b 100644 --- a/command/agent/config/config_test.go +++ b/command/agent/config/config_test.go @@ -536,7 +536,6 @@ func TestLoadConfigFile_AgentCache_PersistMissingType(t *testing.T) { } func TestLoadConfigFile_TemplateConfig(t *testing.T) { - testCases := map[string]struct { fixturePath string expectedTemplateConfig TemplateConfig @@ -586,7 +585,6 @@ func TestLoadConfigFile_TemplateConfig(t *testing.T) { } }) } - } // TestLoadConfigFile_Template tests template definitions in Vault Agent diff --git a/command/kv_patch.go b/command/kv_patch.go index 05e759793ab4f..d05ff5eed89cd 100644 --- a/command/kv_patch.go +++ b/command/kv_patch.go @@ -1,11 +1,13 @@ package command import ( + "context" "fmt" "io" "os" "strings" + "github.com/hashicorp/vault/api" "github.com/mitchellh/cli" "github.com/posener/complete" ) @@ -18,7 +20,9 @@ var ( type KVPatchCommand struct { *BaseCommand - testStdin io.Reader // for tests + flagCAS int + flagMethod string + testStdin io.Reader // for tests } func (c *KVPatchCommand) Synopsis() string { @@ -45,6 +49,25 @@ Usage: vault kv patch [options] KEY [DATA] $ echo "abcd1234" | vault kv patch secret/foo bar=- + To perform a Check-And-Set operation, specify the -cas flag with the + appropriate version number corresponding to the key you want to perform + the CAS operation on: + + $ vault kv patch -cas=1 secret/foo bar=baz + + By default, this operation will attempt an HTTP PATCH operation. If your + policy does not allow that, it will fall back to a read/local update/write approach. + If you wish to specify which method this command should use, you may do so + with the -method flag. When -method=patch is specified, only an HTTP PATCH + operation will be tried. If it fails, the entire command will fail. + + $ vault kv patch -method=patch secret/foo bar=baz + + When -method=rw is specified, only a read/local update/write approach will be tried. + This was the default behavior previous to Vault 1.9. + + $ vault kv patch -method=rw secret/foo bar=baz + Additional flags and more advanced use cases are detailed below. ` + c.Flags().Help() @@ -54,6 +77,27 @@ Usage: vault kv patch [options] KEY [DATA] func (c *KVPatchCommand) Flags() *FlagSets { set := c.flagSet(FlagSetHTTP | FlagSetOutputField | FlagSetOutputFormat) + // Patch specific options + f := set.NewFlagSet("Common Options") + + f.IntVar(&IntVar{ + Name: "cas", + Target: &c.flagCAS, + Default: 0, + Usage: `Specifies to use a Check-And-Set operation. If set to 0 or not + set, the patch will be allowed. If the index is non-zero the patch will + only be allowed if the key’s current version matches the version + specified in the cas parameter.`, + }) + + f.StringVar(&StringVar{ + Name: "method", + Target: &c.flagMethod, + Usage: `Specifies which method of patching to use. If set to "patch", then + an HTTP PATCH request will be issued. If set to "rw", then a read will be + performed, then a local update, followed by a remote update.`, + }) + return set } @@ -121,6 +165,30 @@ func (c *KVPatchCommand) Run(args []string) int { return 2 } + // Check the method and behave accordingly + var secret *api.Secret + var code int + + switch c.flagMethod { + case "rw": + secret, code = c.readThenWrite(client, path, newData) + case "patch": + secret, code = c.mergePatch(client, path, newData, false) + case "": + secret, code = c.mergePatch(client, path, newData, true) + default: + c.UI.Error(fmt.Sprintf("Unsupported method provided to -method flag: %s", c.flagMethod)) + return 2 + } + + if code != 0 { + return code + } + + return OutputSecret(c.UI, secret) +} + +func (c *KVPatchCommand) readThenWrite(client *api.Client, path string, newData map[string]interface{}) (*api.Secret, int) { // First, do a read. // Note that we don't want to see curl output for the read request. curOutputCurl := client.OutputCurlString() @@ -129,45 +197,45 @@ func (c *KVPatchCommand) Run(args []string) int { client.SetOutputCurlString(curOutputCurl) if err != nil { c.UI.Error(fmt.Sprintf("Error doing pre-read at %s: %s", path, err)) - return 2 + return nil, 2 } // Make sure a value already exists if secret == nil || secret.Data == nil { c.UI.Error(fmt.Sprintf("No value found at %s", path)) - return 2 + return nil, 2 } // Verify metadata found rawMeta, ok := secret.Data["metadata"] if !ok || rawMeta == nil { c.UI.Error(fmt.Sprintf("No metadata found at %s; patch only works on existing data", path)) - return 2 + return nil, 2 } meta, ok := rawMeta.(map[string]interface{}) if !ok { c.UI.Error(fmt.Sprintf("Metadata found at %s is not the expected type (JSON object)", path)) - return 2 + return nil, 2 } if meta == nil { c.UI.Error(fmt.Sprintf("No metadata found at %s; patch only works on existing data", path)) - return 2 + return nil, 2 } // Verify old data found rawData, ok := secret.Data["data"] if !ok || rawData == nil { c.UI.Error(fmt.Sprintf("No data found at %s; patch only works on existing data", path)) - return 2 + return nil, 2 } data, ok := rawData.(map[string]interface{}) if !ok { c.UI.Error(fmt.Sprintf("Data found at %s is not the expected type (JSON object)", path)) - return 2 + return nil, 2 } if data == nil { c.UI.Error(fmt.Sprintf("No data found at %s; patch only works on existing data", path)) - return 2 + return nil, 2 } // Copy new data over @@ -183,19 +251,58 @@ func (c *KVPatchCommand) Run(args []string) int { }) if err != nil { c.UI.Error(fmt.Sprintf("Error writing data to %s: %s", path, err)) - return 2 + return nil, 2 } + if secret == nil { // Don't output anything unless using the "table" format if Format(c.UI) == "table" { c.UI.Info(fmt.Sprintf("Success! Data written to: %s", path)) } - return 0 + return nil, 0 } if c.flagField != "" { - return PrintRawField(c.UI, secret, c.flagField) + return nil, PrintRawField(c.UI, secret, c.flagField) } - return OutputSecret(c.UI, secret) + return secret, 0 +} + +func (c *KVPatchCommand) mergePatch(client *api.Client, path string, newData map[string]interface{}, rwFallback bool) (*api.Secret, int) { + data := map[string]interface{}{ + "data": newData, + "options": map[string]interface{}{}, + } + + if c.flagCAS > 0 { + data["options"].(map[string]interface{})["cas"] = c.flagCAS + } + + secret, err := client.Logical().JSONMergePatch(context.Background(), path, data) + if err != nil { + // If it's a 403, that probably means they don't have the patch capability in their policy. Fall back to + // the old way of doing it if the user didn't specify a -method. If they did, and it was "patch", then just error. + if re, ok := err.(*api.ResponseError); ok && re.StatusCode == 403 && rwFallback { + c.UI.Warn(fmt.Sprintf("Data was written to %s but we recommend that you add the \"patch\" capability to your ACL policy in order to use HTTP PATCH in the future.", path)) + return c.readThenWrite(client, path, newData) + } + + c.UI.Error(fmt.Sprintf("Error writing data to %s: %s", path, err)) + return nil, 2 + } + + if secret == nil { + // Don't output anything unless using the "table" format + if Format(c.UI) == "table" { + c.UI.Info(fmt.Sprintf("Success! Data written to: %s", path)) + } + return nil, 0 + } + + if c.flagField != "" { + return nil, PrintRawField(c.UI, secret, c.flagField) + } + + return secret, 0 } diff --git a/command/kv_test.go b/command/kv_test.go index 7e93a4cb8beff..a73fc01356664 100644 --- a/command/kv_test.go +++ b/command/kv_test.go @@ -1,6 +1,7 @@ package command import ( + "fmt" "io" "strings" "testing" @@ -552,3 +553,606 @@ func TestKVMetadataGetCommand(t *testing.T) { assertNoTabs(t, cmd) }) } + +func testKVPatchCommand(tb testing.TB) (*cli.MockUi, *KVPatchCommand) { + tb.Helper() + + ui := cli.NewMockUi() + return ui, &KVPatchCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func TestKVPatchCommand_ArgValidation(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "not_enough_args", + []string{}, + "Not enough arguments", + 1, + }, + { + "empty_kvs", + []string{"kv/patch/foo"}, + "Must supply data", + 1, + }, + { + "kvs_no_value", + []string{"kv/patch/foo", "foo"}, + "Failed to parse K=V data", + 1, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + if err := client.Sys().Mount("kv/", &api.MountInput{ + Type: "kv-v2", + }); err != nil { + t.Fatalf("kv-v2 mount attempt failed - err: %#v\n", err) + } + + ui, cmd := testKVPatchCommand(t) + cmd.client = client + + code := cmd.Run(tc.args) + + if code != tc.code { + t.Fatalf("expected code to be %d but was %d for cmd %#v with args %#v\n", tc.code, code, cmd, tc.args) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + + if !strings.Contains(combined, tc.out) { + t.Fatalf("expected output to be %q but was %q for cmd %#v with args %#v\n", tc.out, combined, cmd, tc.args) + } + }) + } +} + +func TestKvPatchCommand_StdinFull(t *testing.T) { + client, closer := testVaultServer(t) + defer closer() + + if err := client.Sys().Mount("kv/", &api.MountInput{ + Type: "kv-v2", + }); err != nil { + t.Fatalf("kv-v2 mount attempt failed - err: %#v\n", err) + } + + if _, err := client.Logical().Write("kv/data/patch/foo", map[string]interface{}{ + "data": map[string]interface{}{ + "foo": "a", + }, + }); err != nil { + t.Fatalf("write failed, err: %#v\n", err) + } + + stdinR, stdinW := io.Pipe() + go func() { + stdinW.Write([]byte(`{"foo":"bar"}`)) + stdinW.Close() + }() + + _, cmd := testKVPatchCommand(t) + cmd.client = client + + cmd.testStdin = stdinR + + args := []string{"kv/patch/foo", "-"} + code := cmd.Run(args) + if code != 0 { + t.Fatalf("expected code to be 0 but was %d for cmd %#v with args %#v\n", code, cmd, args) + } + + secret, err := client.Logical().Read("kv/data/patch/foo") + if err != nil { + t.Fatalf("read failed, err: %#v\n", err) + } + + if secret == nil || secret.Data == nil { + t.Fatal("expected secret to have data") + } + + secretDataRaw, ok := secret.Data["data"] + + if !ok { + t.Fatalf("expected secret to have nested data key, data: %#v", secret.Data) + } + + secretData := secretDataRaw.(map[string]interface{}) + foo, ok := secretData["foo"].(string) + if !ok { + t.Fatal("expected foo to be a string but it wasn't") + } + + if exp, act := "bar", foo; exp != act { + t.Fatalf("expected %q to be %q, data: %#v\n", act, exp, secret.Data) + } +} + +func TestKvPatchCommand_StdinValue(t *testing.T) { + client, closer := testVaultServer(t) + defer closer() + + if err := client.Sys().Mount("kv/", &api.MountInput{ + Type: "kv-v2", + }); err != nil { + t.Fatalf("kv-v2 mount attempt failed - err: %#v\n", err) + } + + if _, err := client.Logical().Write("kv/data/patch/foo", map[string]interface{}{ + "data": map[string]interface{}{ + "foo": "a", + }, + }); err != nil { + t.Fatalf("write failed, err: %#v\n", err) + } + + stdinR, stdinW := io.Pipe() + go func() { + stdinW.Write([]byte("bar")) + stdinW.Close() + }() + + _, cmd := testKVPatchCommand(t) + cmd.client = client + + cmd.testStdin = stdinR + + args := []string{"kv/patch/foo", "foo=-"} + code := cmd.Run(args) + if code != 0 { + t.Fatalf("expected code to be 0 but was %d for cmd %#v with args %#v\n", code, cmd, args) + } + + secret, err := client.Logical().Read("kv/data/patch/foo") + if err != nil { + t.Fatalf("read failed, err: %#v\n", err) + } + + if secret == nil || secret.Data == nil { + t.Fatal("expected secret to have data") + } + + secretDataRaw, ok := secret.Data["data"] + + if !ok { + t.Fatalf("expected secret to have nested data key, data: %#v\n", secret.Data) + } + + secretData := secretDataRaw.(map[string]interface{}) + + if exp, act := "bar", secretData["foo"].(string); exp != act { + t.Fatalf("expected %q to be %q, data: %#v\n", act, exp, secret.Data) + } +} + +func TestKVPatchCommand_RWMethodNotExists(t *testing.T) { + client, closer := testVaultServer(t) + defer closer() + + if err := client.Sys().Mount("kv/", &api.MountInput{ + Type: "kv-v2", + }); err != nil { + t.Fatalf("kv-v2 mount attempt failed - err: %#v\n", err) + } + + ui, cmd := testKVPatchCommand(t) + cmd.client = client + + args := []string{"-method", "rw", "kv/patch/foo", "foo=a"} + code := cmd.Run(args) + if code != 2 { + t.Fatalf("expected code to be 2 but was %d for cmd %#v with args %#v\n", code, cmd, args) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + + expectedOutputSubstr := "No value found" + if !strings.Contains(combined, expectedOutputSubstr) { + t.Fatalf("expected output %q to contain %q for cmd %#v with args %#v\n", combined, expectedOutputSubstr, cmd, args) + } +} + +func TestKVPatchCommand_RWMethodSucceeds(t *testing.T) { + client, closer := testVaultServer(t) + defer closer() + + if err := client.Sys().Mount("kv/", &api.MountInput{ + Type: "kv-v2", + }); err != nil { + t.Fatalf("kv-v2 mount attempt failed - err: %#v\n", err) + } + + if _, err := client.Logical().Write("kv/data/patch/foo", map[string]interface{}{ + "data": map[string]interface{}{ + "foo": "a", + "bar": "b", + }, + }); err != nil { + t.Fatalf("write failed, err: %#v\n", err) + } + + ui, cmd := testKVPatchCommand(t) + cmd.client = client + + // Test single value + args := []string{"-method", "rw", "kv/patch/foo", "foo=aa"} + code := cmd.Run(args) + if code != 0 { + t.Fatalf("expected code to be 0 but was %d for cmd %#v with args %#v\n", code, cmd, args) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + + expectedOutputSubstr := "created_time" + if !strings.Contains(combined, expectedOutputSubstr) { + t.Fatalf("expected output %q to contain %q for cmd %#v with args %#v\n", combined, expectedOutputSubstr, cmd, args) + } + + // Test multi value + ui, cmd = testKVPatchCommand(t) + cmd.client = client + + args = []string{"-method", "rw", "kv/patch/foo", "foo=aaa", "bar=bbb"} + code = cmd.Run(args) + if code != 0 { + t.Fatalf("expected code to be 0 but was %d for cmd %#v with args %#v\n", code, cmd, args) + } + + combined = ui.OutputWriter.String() + ui.ErrorWriter.String() + + if !strings.Contains(combined, expectedOutputSubstr) { + t.Fatalf("expected output %q to contain %q for cmd %#v with args %#v\n", combined, expectedOutputSubstr, cmd, args) + } +} + +func TestKVPatchCommand_CAS(t *testing.T) { + cases := []struct { + name string + args []string + expected string + out string + code int + }{ + { + "right version", + []string{"-cas", "1", "kv/foo", "bar=quux"}, + "quux", + "", + 0, + }, + { + "wrong version", + []string{"-cas", "2", "kv/foo", "bar=wibble"}, + "baz", + "check-and-set parameter did not match the current version", + 2, + }, + } + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + if err := client.Sys().Mount("kv/", &api.MountInput{ + Type: "kv-v2", + }); err != nil { + t.Fatalf("kv-v2 mount attempt failed - err: %#v\n", err) + } + + // create a policy with patch capability + policy := `path "kv/*" { capabilities = ["create", "update", "read", "patch"] }` + secretAuth, err := createTokenForPolicy(t, client, policy) + if err != nil { + t.Fatalf("policy/token creation failed for policy %s, err: %#v\n", policy, err) + } + + kvClient, err := client.Clone() + if err != nil { + t.Fatal(err) + } + + kvClient.SetToken(secretAuth.ClientToken) + + _, err = kvClient.Logical().Write("kv/data/foo", map[string]interface{}{"data": map[string]interface{}{"bar": "baz"}}) + if err != nil { + t.Fatal(err) + } + + ui, cmd := testKVPatchCommand(t) + cmd.client = kvClient + + code := cmd.Run(tc.args) + + if code != tc.code { + t.Fatalf("expected code to be %d but was %d", tc.code, code) + } + + if tc.out != "" { + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + } + + secret, err := kvClient.Logical().Read("kv/data/foo") + bar := secret.Data["data"].(map[string]interface{})["bar"] + if bar != tc.expected { + t.Fatalf("expected bar to be %q but it was %q", tc.expected, bar) + } + }) + } +} + +func TestKVPatchCommand_Methods(t *testing.T) { + cases := []struct { + name string + args []string + expected string + code int + }{ + { + "rw", + []string{"-method", "rw", "kv/foo", "bar=quux"}, + "quux", + 0, + }, + { + "patch", + []string{"-method", "patch", "kv/foo", "bar=wibble"}, + "wibble", + 0, + }, + } + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + if err := client.Sys().Mount("kv/", &api.MountInput{ + Type: "kv-v2", + }); err != nil { + t.Fatalf("kv-v2 mount attempt failed - err: %#v\n", err) + } + + // create a policy with patch capability + policy := `path "kv/*" { capabilities = ["create", "update", "read", "patch"] }` + secretAuth, err := createTokenForPolicy(t, client, policy) + if err != nil { + t.Fatalf("policy/token creation failed for policy %s, err: %#v\n", policy, err) + } + + kvClient, err := client.Clone() + if err != nil { + t.Fatal(err) + } + + kvClient.SetToken(secretAuth.ClientToken) + + _, err = kvClient.Logical().Write("kv/data/foo", map[string]interface{}{"data": map[string]interface{}{"bar": "baz"}}) + if err != nil { + t.Fatal(err) + } + + _, cmd := testKVPatchCommand(t) + cmd.client = kvClient + + code := cmd.Run(tc.args) + + if code != tc.code { + t.Fatalf("expected code to be %d but was %d", tc.code, code) + } + + secret, err := kvClient.Logical().Read("kv/data/foo") + bar := secret.Data["data"].(map[string]interface{})["bar"] + if bar != tc.expected { + t.Fatalf("expected bar to be %q but it was %q", tc.expected, bar) + } + }) + } +} + +func TestKVPatchCommand_403Fallback(t *testing.T) { + cases := []struct { + name string + args []string + expected string + code int + }{ + // if no -method is specified, and patch fails, it should fall back to rw and succeed + { + "unspecified", + []string{"kv/foo", "bar=quux"}, + `add the "patch" capability to your ACL policy`, + 0, + }, + // if -method=patch is specified, and patch fails, it should not fall back, and just error + { + "specifying patch", + []string{"-method", "patch", "kv/foo", "bar=quux"}, + "permission denied", + 2, + }, + } + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + client, closer := testVaultServer(t) + defer closer() + + if err := client.Sys().Mount("kv/", &api.MountInput{ + Type: "kv-v2", + }); err != nil { + t.Fatalf("kv-v2 mount attempt failed - err: %#v\n", err) + } + + // create a policy without patch capability + policy := `path "kv/*" { capabilities = ["create", "update", "read"] }` + secretAuth, err := createTokenForPolicy(t, client, policy) + if err != nil { + t.Fatalf("policy/token creation failed for policy %s, err: %#v\n", policy, err) + } + + kvClient, err := client.Clone() + if err != nil { + t.Fatal(err) + } + + kvClient.SetToken(secretAuth.ClientToken) + + // Write a value then attempt to patch it + _, err = kvClient.Logical().Write("kv/data/foo", map[string]interface{}{"data": map[string]interface{}{"bar": "baz"}}) + if err != nil { + t.Fatal(err) + } + + ui, cmd := testKVPatchCommand(t) + cmd.client = kvClient + code := cmd.Run(tc.args) + + if code != tc.code { + t.Fatalf("expected code to be %d but was %d", tc.code, code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.expected) { + t.Errorf("expected %q to contain %q", combined, tc.expected) + } + }) + } +} + +func createTokenForPolicy(t *testing.T, client *api.Client, policy string) (*api.SecretAuth, error) { + t.Helper() + + if err := client.Sys().PutPolicy("policy", policy); err != nil { + return nil, err + } + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: []string{"policy"}, + TTL: "30m", + }) + if err != nil { + return nil, err + } + + if secret == nil || secret.Auth == nil || secret.Auth.ClientToken == "" { + return nil, fmt.Errorf("missing auth data: %#v", secret) + } + + return secret.Auth, err +} + +func TestKVPatchCommand_RWMethodPolicyVariations(t *testing.T) { + cases := []struct { + name string + args []string + policy string + expected string + code int + }{ + // if the policy doesn't have read capability and -method=rw is specified, it fails + { + "no read", + []string{"-method", "rw", "kv/foo", "bar=quux"}, + `path "kv/*" { capabilities = ["create", "update"] }`, + "permission denied", + 2, + }, + // if the policy doesn't have update capability and -method=rw is specified, it fails + { + "no update", + []string{"-method", "rw", "kv/foo", "bar=quux"}, + `path "kv/*" { capabilities = ["create", "read"] }`, + "permission denied", + 2, + }, + // if the policy has both read and update and -method=rw is specified, it succeeds + { + "read and update", + []string{"-method", "rw", "kv/foo", "bar=quux"}, + `path "kv/*" { capabilities = ["create", "read", "update"] }`, + "", + 0, + }, + } + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + client, closer := testVaultServer(t) + defer closer() + + if err := client.Sys().Mount("kv/", &api.MountInput{ + Type: "kv-v2", + }); err != nil { + t.Fatalf("kv-v2 mount attempt failed - err: %#v\n", err) + } + + secretAuth, err := createTokenForPolicy(t, client, tc.policy) + if err != nil { + t.Fatalf("policy/token creation failed for policy %s, err: %#v\n", tc.policy, err) + } + + client.SetToken(secretAuth.ClientToken) + + if _, err := client.Logical().Write("kv/data/foo", map[string]interface{}{ + "data": map[string]interface{}{ + "foo": "bar", + "bar": "baz", + }, + }); err != nil { + t.Fatalf("write failed, err: %#v\n", err) + } + + ui, cmd := testKVPatchCommand(t) + cmd.client = client + + code := cmd.Run(tc.args) + + if code != tc.code { + t.Fatalf("expected code to be %d but was %d for cmd %#v with args %#v\n", tc.code, code, cmd, tc.args) + } + + if code != 0 { + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.expected) { + t.Fatalf("expected output %q to contain %q for cmd %#v with args %#v\n", combined, tc.expected, cmd, tc.args) + } + } + }) + } +} diff --git a/command/server.go b/command/server.go index 84814c4102562..718009b8cf419 100644 --- a/command/server.go +++ b/command/server.go @@ -1541,6 +1541,12 @@ func (c *ServerCommand) Run(args []string) int { core.SetConfig(config) + // reloading custom response headers to make sure we have + // the most up to date headers after reloading the config file + if err = core.ReloadCustomResponseHeaders(); err != nil { + c.logger.Error(err.Error()) + } + if config.LogLevel != "" { configLogLevel := strings.ToLower(strings.TrimSpace(config.LogLevel)) switch configLogLevel { diff --git a/command/server/config_custom_response_headers_test.go b/command/server/config_custom_response_headers_test.go new file mode 100644 index 0000000000000..5380568c25107 --- /dev/null +++ b/command/server/config_custom_response_headers_test.go @@ -0,0 +1,109 @@ +package server + +import ( + "fmt" + "testing" + + "github.com/go-test/deep" +) + +var defaultCustomHeaders = map[string]string{ + "Strict-Transport-Security": "max-age=1; domains", + "Content-Security-Policy": "default-src 'others'", + "X-Vault-Ignored": "ignored", + "X-Custom-Header": "Custom header value default", +} + +var customHeaders307 = map[string]string{ + "X-Custom-Header": "Custom header value 307", +} + +var customHeader3xx = map[string]string{ + "X-Vault-Ignored-3xx": "Ignored 3xx", + "X-Custom-Header": "Custom header value 3xx", +} + +var customHeaders200 = map[string]string{ + "Someheader-200": "200", + "X-Custom-Header": "Custom header value 200", +} + +var customHeader2xx = map[string]string{ + "X-Custom-Header": "Custom header value 2xx", +} + +var customHeader400 = map[string]string{ + "Someheader-400": "400", +} + +var defaultCustomHeadersMultiListener = map[string]string{ + "Strict-Transport-Security": "max-age=31536000; includeSubDomains", + "Content-Security-Policy": "default-src 'others'", + "X-Vault-Ignored": "ignored", + "X-Custom-Header": "Custom header value default", +} + +var defaultSTS = map[string]string{ + "Strict-Transport-Security": "max-age=31536000; includeSubDomains", +} + +func TestCustomResponseHeadersConfigs(t *testing.T) { + expectedCustomResponseHeader := map[string]map[string]string{ + "default": defaultCustomHeaders, + "307": customHeaders307, + "3xx": customHeader3xx, + "200": customHeaders200, + "2xx": customHeader2xx, + "400": customHeader400, + } + + config, err := LoadConfigFile("./test-fixtures/config_custom_response_headers_1.hcl") + if err != nil { + t.Fatalf("Error encountered when loading config %+v", err) + } + if diff := deep.Equal(expectedCustomResponseHeader, config.Listeners[0].CustomResponseHeaders); diff != nil { + t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff)) + } +} + +func TestCustomResponseHeadersConfigsMultipleListeners(t *testing.T) { + expectedCustomResponseHeader := map[string]map[string]string{ + "default": defaultCustomHeadersMultiListener, + "307": customHeaders307, + "3xx": customHeader3xx, + "200": customHeaders200, + "2xx": customHeader2xx, + "400": customHeader400, + } + + config, err := LoadConfigFile("./test-fixtures/config_custom_response_headers_multiple_listeners.hcl") + if err != nil { + t.Fatalf("Error encountered when loading config %+v", err) + } + if diff := deep.Equal(expectedCustomResponseHeader, config.Listeners[0].CustomResponseHeaders); diff != nil { + t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff)) + } + + if diff := deep.Equal(expectedCustomResponseHeader, config.Listeners[1].CustomResponseHeaders); diff == nil { + t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff)) + } + if diff := deep.Equal(expectedCustomResponseHeader["default"], config.Listeners[1].CustomResponseHeaders["default"]); diff != nil { + t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff)) + } + + if diff := deep.Equal(expectedCustomResponseHeader, config.Listeners[2].CustomResponseHeaders); diff == nil { + t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff)) + } + + if diff := deep.Equal(defaultSTS, config.Listeners[2].CustomResponseHeaders["default"]); diff != nil { + t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff)) + } + + if diff := deep.Equal(expectedCustomResponseHeader, config.Listeners[3].CustomResponseHeaders); diff == nil { + t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff)) + } + + if diff := deep.Equal(defaultSTS, config.Listeners[3].CustomResponseHeaders["default"]); diff != nil { + t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff)) + } +} diff --git a/command/server/config_test_helpers.go b/command/server/config_test_helpers.go index e40f6c8368a6e..8936a0244090c 100644 --- a/command/server/config_test_helpers.go +++ b/command/server/config_test_helpers.go @@ -16,6 +16,12 @@ import ( "github.com/hashicorp/vault/internalshared/configutil" ) +var DefaultCustomHeaders = map[string]map[string]string { + "default": { + "Strict-Transport-Security": configutil.StrictTransportSecurity, + }, +} + func boolPointer(x bool) *bool { return &x } @@ -32,6 +38,7 @@ func testConfigRaftRetryJoin(t *testing.T) { { Type: "tcp", Address: "127.0.0.1:8200", + CustomResponseHeaders: DefaultCustomHeaders, }, }, DisableMlock: true, @@ -64,6 +71,7 @@ func testLoadConfigFile_topLevel(t *testing.T, entropy *configutil.Entropy) { { Type: "tcp", Address: "127.0.0.1:443", + CustomResponseHeaders: DefaultCustomHeaders, }, }, @@ -174,10 +182,12 @@ func testLoadConfigFile_json2(t *testing.T, entropy *configutil.Entropy) { { Type: "tcp", Address: "127.0.0.1:443", + CustomResponseHeaders: DefaultCustomHeaders, }, { Type: "tcp", Address: "127.0.0.1:444", + CustomResponseHeaders: DefaultCustomHeaders, }, }, @@ -336,6 +346,7 @@ func testLoadConfigFileIntegerAndBooleanValuesCommon(t *testing.T, path string) { Type: "tcp", Address: "127.0.0.1:8200", + CustomResponseHeaders: DefaultCustomHeaders, }, }, DisableMlock: true, @@ -379,6 +390,7 @@ func testLoadConfigFile(t *testing.T) { { Type: "tcp", Address: "127.0.0.1:443", + CustomResponseHeaders: DefaultCustomHeaders, }, }, @@ -486,7 +498,7 @@ func testUnknownFieldValidation(t *testing.T) { for _, er1 := range errors { found := false if strings.Contains(er1.String(), "sentinel") { - //This happens on OSS, and is fine + // This happens on OSS, and is fine continue } for _, ex := range expected { @@ -525,6 +537,7 @@ func testLoadConfigFile_json(t *testing.T) { { Type: "tcp", Address: "127.0.0.1:443", + CustomResponseHeaders: DefaultCustomHeaders, }, }, @@ -610,6 +623,7 @@ func testLoadConfigDir(t *testing.T) { { Type: "tcp", Address: "127.0.0.1:443", + CustomResponseHeaders: DefaultCustomHeaders, }, }, @@ -818,6 +832,7 @@ listener "tcp" { Profiling: configutil.ListenerProfiling{ UnauthenticatedPProfAccess: true, }, + CustomResponseHeaders: DefaultCustomHeaders, }, }, }, @@ -845,6 +860,7 @@ func testParseSeals(t *testing.T) { { Type: "tcp", Address: "127.0.0.1:443", + CustomResponseHeaders: DefaultCustomHeaders, }, }, Seals: []*configutil.KMS{ @@ -898,6 +914,7 @@ func testLoadConfigFileLeaseMetrics(t *testing.T) { { Type: "tcp", Address: "127.0.0.1:443", + CustomResponseHeaders: DefaultCustomHeaders, }, }, diff --git a/command/server/test-fixtures/config_custom_response_headers_1.hcl b/command/server/test-fixtures/config_custom_response_headers_1.hcl new file mode 100644 index 0000000000000..c2f868c2f146c --- /dev/null +++ b/command/server/test-fixtures/config_custom_response_headers_1.hcl @@ -0,0 +1,31 @@ +storage "inmem" {} +listener "tcp" { + address = "127.0.0.1:8200" + tls_disable = true + custom_response_headers { + "default" = { + "Strict-Transport-Security" = ["max-age=1","domains"], + "Content-Security-Policy" = ["default-src 'others'"], + "X-Vault-Ignored" = ["ignored"], + "X-Custom-Header" = ["Custom header value default"], + } + "307" = { + "X-Custom-Header" = ["Custom header value 307"], + } + "3xx" = { + "X-Vault-Ignored-3xx" = ["Ignored 3xx"], + "X-Custom-Header" = ["Custom header value 3xx"] + } + "200" = { + "someheader-200" = ["200"], + "X-Custom-Header" = ["Custom header value 200"] + } + "2xx" = { + "X-Custom-Header" = ["Custom header value 2xx"] + } + "400" = { + "someheader-400" = ["400"] + } + } +} +disable_mlock = true diff --git a/command/server/test-fixtures/config_custom_response_headers_multiple_listeners.hcl b/command/server/test-fixtures/config_custom_response_headers_multiple_listeners.hcl new file mode 100644 index 0000000000000..11aa099232f91 --- /dev/null +++ b/command/server/test-fixtures/config_custom_response_headers_multiple_listeners.hcl @@ -0,0 +1,56 @@ +storage "inmem" {} +listener "tcp" { + address = "127.0.0.1:8200" + tls_disable = true + custom_response_headers { + "default" = { + "Content-Security-Policy" = ["default-src 'others'"], + "X-Vault-Ignored" = ["ignored"], + "X-Custom-Header" = ["Custom header value default"], + } + "307" = { + "X-Custom-Header" = ["Custom header value 307"], + } + "3xx" = { + "X-Vault-Ignored-3xx" = ["Ignored 3xx"], + "X-Custom-Header" = ["Custom header value 3xx"] + } + "200" = { + "someheader-200" = ["200"], + "X-Custom-Header" = ["Custom header value 200"] + } + "2xx" = { + "X-Custom-Header" = ["Custom header value 2xx"] + } + "400" = { + "someheader-400" = ["400"] + } + } +} +listener "tcp" { + address = "127.0.0.2:8200" + tls_disable = true + custom_response_headers { + "default" = { + "Content-Security-Policy" = ["default-src 'others'"], + "X-Vault-Ignored" = ["ignored"], + "X-Custom-Header" = ["Custom header value default"], + } + } +} +listener "tcp" { + address = "127.0.0.3:8200" + tls_disable = true + custom_response_headers { + "2xx" = { + "X-Custom-Header" = ["Custom header value 2xx"] + } + } +} +listener "tcp" { + address = "127.0.0.4:8200" + tls_disable = true +} + + +disable_mlock = true diff --git a/go.mod b/go.mod index affcd85ccf092..b5ab8ff8b72e2 100644 --- a/go.mod +++ b/go.mod @@ -113,13 +113,13 @@ require ( github.com/hashicorp/vault-plugin-secrets-azure v0.6.3-0.20210924190759-58a034528e35 github.com/hashicorp/vault-plugin-secrets-gcp v0.10.2 github.com/hashicorp/vault-plugin-secrets-gcpkms v0.9.0 - github.com/hashicorp/vault-plugin-secrets-kv v0.5.7-0.20210811133805-e060c2307b24 + github.com/hashicorp/vault-plugin-secrets-kv v0.5.7-0.20211013154503-eec8a1c892fb github.com/hashicorp/vault-plugin-secrets-mongodbatlas v0.4.0 github.com/hashicorp/vault-plugin-secrets-openldap v0.4.1-0.20210921171411-e86105e4986d github.com/hashicorp/vault-plugin-secrets-terraform v0.1.1-0.20210715043003-e02ca8f6408e github.com/hashicorp/vault-testing-stepwise v0.1.1 github.com/hashicorp/vault/api v1.1.1 - github.com/hashicorp/vault/sdk v0.2.1 + github.com/hashicorp/vault/sdk v0.2.2-0.20211004171540-a8c7e135dd6a github.com/influxdata/influxdb v0.0.0-20190411212539-d24b7ba8c4c4 github.com/jcmturner/gokrb5/v8 v8.0.0 github.com/jefferai/isbadcipher v0.0.0-20190226160619-51d2077c035f @@ -199,6 +199,7 @@ require ( google.golang.org/grpc v1.41.0 google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0 google.golang.org/protobuf v1.27.1 + gopkg.in/evanphx/json-patch.v4 v4.11.0 // indirect gopkg.in/mgo.v2 v2.0.0-20180705113604-9856a29383ce gopkg.in/ory-am/dockertest.v3 v3.3.4 gopkg.in/square/go-jose.v2 v2.5.1 diff --git a/go.sum b/go.sum index 38b8e21486377..b2490092b2ec9 100644 --- a/go.sum +++ b/go.sum @@ -350,6 +350,8 @@ github.com/envoyproxy/go-control-plane v0.9.9-0.20210512163311-63b5d3c536b0/go.m github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021/go.mod h1:AFq3mo9L8Lqqiid3OhADV3RfLJnjiw63cSpi+fDTRC0= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/evanphx/json-patch v0.0.0-20190203023257-5858425f7550/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= +github.com/evanphx/json-patch v0.5.2/go.mod h1:ZWS5hhDbVDyob71nXKNL0+PWn6ToqBHMikGIFbs31qQ= +github.com/evanphx/json-patch v4.2.0+incompatible h1:fUDGZCv/7iAN7u0puUVhvKCcsR6vRfwrJatElLBEf0I= github.com/evanphx/json-patch v4.2.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU= @@ -748,8 +750,10 @@ github.com/hashicorp/vault-plugin-secrets-gcp v0.10.2 h1:+DtlYJTsrFRInQpAo09KkYN github.com/hashicorp/vault-plugin-secrets-gcp v0.10.2/go.mod h1:psRQ/dm5XatoUKLDUeWrpP9icMJNtu/jmscUr37YGK4= github.com/hashicorp/vault-plugin-secrets-gcpkms v0.9.0 h1:7a0iWuFA/YNinQ1xXogyZHStolxMVtLV+sy1LpEHaZs= github.com/hashicorp/vault-plugin-secrets-gcpkms v0.9.0/go.mod h1:hhwps56f2ATeC4Smgghrc5JH9dXR31b4ehSf1HblP5Q= -github.com/hashicorp/vault-plugin-secrets-kv v0.5.7-0.20210811133805-e060c2307b24 h1:uqPKQzkmO5vybOqk2aOdviXXi5088bcl2MrE0D1MhjM= -github.com/hashicorp/vault-plugin-secrets-kv v0.5.7-0.20210811133805-e060c2307b24/go.mod h1:4j2pZrSynPuUAAYrZQVgSSHD0A9xj7GK9Ji1sWtnO4s= +github.com/hashicorp/vault-plugin-secrets-kv v0.5.7-0.20211007143158-2d15a6fec12b h1:1GJj7AjgI0Td95haW8EK5on3Usuox78wmzLj+J9vcm4= +github.com/hashicorp/vault-plugin-secrets-kv v0.5.7-0.20211007143158-2d15a6fec12b/go.mod h1:iEKCVaKBQzzYxzb778O6VGLdd+8gA40ZI14bo+8tQjs= +github.com/hashicorp/vault-plugin-secrets-kv v0.5.7-0.20211013154503-eec8a1c892fb h1:nZ2a4a1G0ALLAzKOWQbLzD5oljKo+pjMarbq3BwU0pM= +github.com/hashicorp/vault-plugin-secrets-kv v0.5.7-0.20211013154503-eec8a1c892fb/go.mod h1:D/FQJ7zU5pD6FNJVUwaVtxr75ZsxIIqaG/Nh6RHt/xo= github.com/hashicorp/vault-plugin-secrets-mongodbatlas v0.4.0 h1:6ve+7hZmGn7OpML81iZUxYj2AaJptwys323S5XsvVas= github.com/hashicorp/vault-plugin-secrets-mongodbatlas v0.4.0/go.mod h1:4mdgPqlkO+vfFX1cFAWcxkeqz6JAtZgKxL/67q/58Oo= github.com/hashicorp/vault-plugin-secrets-openldap v0.4.1-0.20210921171411-e86105e4986d h1:o5Z9B1FztTYSnTQNzFr+iZJHPM8ZD23uV5A8gMxm2g0= @@ -803,6 +807,7 @@ github.com/jefferai/isbadcipher v0.0.0-20190226160619-51d2077c035f h1:E87tDTVS5W github.com/jefferai/isbadcipher v0.0.0-20190226160619-51d2077c035f/go.mod h1:3J2qVK16Lq8V+wfiL2lPeDZ7UWMxk5LemerHa1p6N00= github.com/jefferai/jsonx v1.0.0 h1:Xoz0ZbmkpBvED5W9W1B5B/zc3Oiq7oXqiW7iRV3B6EI= github.com/jefferai/jsonx v1.0.0/go.mod h1:OGmqmi2tTeI/PS+qQfBDToLHHJIy/RMp24fPo8vFvoQ= +github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/jmespath/go-jmespath v0.3.0/go.mod h1:9QtRXoHjLGCJ5IBSaohpXITPlowMeeYCZ7fLUTSywik= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= @@ -1672,6 +1677,8 @@ gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8X gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/errgo.v2 v2.1.0 h1:0vLT13EuvQ0hNvakwLuFZ/jYrLp5F3kcWHXdRggjCE8= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/evanphx/json-patch.v4 v4.11.0 h1:+kbwxm5IBGIiNYVhss+hM3Nv4ck+HnPSNscCNbD1cT0= +gopkg.in/evanphx/json-patch.v4 v4.11.0/go.mod h1:p8EYWUEYMpynmqDbY58zCKCFZw8pRWMG4EsWvDvM72M= gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/gemnasium/logrus-airbrake-hook.v2 v2.1.2/go.mod h1:Xk6kEKp8OKb+X14hQBKWaSkCsqBpgog8nAV2xsGOxlo= diff --git a/http/custom_header_test.go b/http/custom_header_test.go new file mode 100644 index 0000000000000..5125050ad570d --- /dev/null +++ b/http/custom_header_test.go @@ -0,0 +1,128 @@ +package http + +import ( + "testing" + + "github.com/hashicorp/vault/vault" +) + +var defaultCustomHeaders = map[string]string { + "Strict-Transport-Security": "max-age=1; domains", + "Content-Security-Policy": "default-src 'others'", + "X-Custom-Header": "Custom header value default", + "X-Frame-Options": "Deny", + "X-Content-Type-Options": "nosniff", + "Content-Type": "application/json", + "X-XSS-Protection": "1; mode=block", +} + +var customHeader2xx = map[string]string { + "X-Custom-Header": "Custom header value 2xx", +} + +var customHeader200 = map[string]string { + "Someheader-200": "200", + "X-Custom-Header": "Custom header value 200", +} + +var customHeader4xx = map[string]string { + "Someheader-4xx": "4xx", +} + +var customHeader400 = map[string]string { + "Someheader-400": "400", +} + +var customHeader405 = map[string]string { + "Someheader-405": "405", +} + +var CustomResponseHeaders = map[string]map[string]string{ + "default": defaultCustomHeaders, + "307": {"X-Custom-Header": "Custom header value 307"}, + "3xx": { + "X-Custom-Header": "Custom header value 3xx", + "X-Vault-Ignored-3xx": "Ignored 3xx", + }, + "200": customHeader200, + "2xx": customHeader2xx, + "400": customHeader400, + "405": customHeader405, + "4xx": customHeader4xx, +} + +func TestCustomResponseHeaders(t *testing.T) { + core, _, token := vault.TestCoreWithCustomResponseHeaderAndUI(t, CustomResponseHeaders, true) + ln, addr := TestServer(t, core) + defer ln.Close() + TestServerAuth(t, addr, token) + + resp := testHttpGet(t, token, addr+"/v1/sys/raw/") + testResponseStatus(t, resp, 404) + testResponseHeader(t, resp, defaultCustomHeaders) + testResponseHeader(t, resp, customHeader4xx) + + resp = testHttpGet(t, token, addr+"/v1/sys/seal") + testResponseStatus(t, resp, 405) + testResponseHeader(t, resp, defaultCustomHeaders) + testResponseHeader(t, resp, customHeader4xx) + testResponseHeader(t, resp, customHeader405) + + resp = testHttpGet(t, token, addr+"/v1/sys/leader") + testResponseStatus(t, resp, 200) + testResponseHeader(t, resp, customHeader200) + + resp = testHttpGet(t, token, addr+"/v1/sys/health") + testResponseStatus(t, resp, 200) + testResponseHeader(t, resp, customHeader200) + + resp = testHttpGet(t, token, addr+"/v1/sys/generate-root/attempt") + testResponseStatus(t, resp, 200) + testResponseHeader(t, resp, customHeader200) + + resp = testHttpGet(t, token, addr+"/v1/sys/generate-root/update") + testResponseStatus(t, resp, 400) + testResponseHeader(t, resp, defaultCustomHeaders) + testResponseHeader(t, resp, customHeader4xx) + testResponseHeader(t, resp, customHeader400) + + resp = testHttpGet(t, token, addr+"/v1/sys/") + testResponseStatus(t, resp, 404) + testResponseHeader(t, resp, defaultCustomHeaders) + testResponseHeader(t, resp, customHeader4xx) + + resp = testHttpGet(t, token, addr+"/v1/sys") + testResponseStatus(t, resp, 404) + testResponseHeader(t, resp, defaultCustomHeaders) + testResponseHeader(t, resp, customHeader4xx) + + resp = testHttpGet(t, token, addr+"/v1/") + testResponseStatus(t, resp, 404) + testResponseHeader(t, resp, defaultCustomHeaders) + testResponseHeader(t, resp, customHeader4xx) + + resp = testHttpGet(t, token, addr+"/v1") + testResponseStatus(t, resp, 404) + testResponseHeader(t, resp, defaultCustomHeaders) + testResponseHeader(t, resp, customHeader4xx) + + resp = testHttpGet(t, token, addr+"/") + testResponseStatus(t, resp, 200) + testResponseHeader(t, resp, customHeader200) + + resp = testHttpGet(t, token, addr+"/ui") + testResponseStatus(t, resp, 200) + testResponseHeader(t, resp, customHeader200) + + resp = testHttpGet(t, token, addr+"/ui/") + testResponseStatus(t, resp, 200) + testResponseHeader(t, resp, customHeader200) + + resp = testHttpPost(t, token, addr+"/v1/sys/auth/foo", map[string]interface{}{ + "type": "noop", + "description": "foo", + }) + testResponseStatus(t, resp, 204) + testResponseHeader(t, resp, customHeader2xx) + +} \ No newline at end of file diff --git a/http/handler.go b/http/handler.go index 7d48f97aee8fe..22aab6ccbd8b4 100644 --- a/http/handler.go +++ b/http/handler.go @@ -16,12 +16,14 @@ import ( "net/textproto" "net/url" "os" + "strconv" "strings" "time" "github.com/NYTimes/gziphandler" "github.com/hashicorp/errwrap" "github.com/hashicorp/go-cleanhttp" + log "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-secure-stdlib/parseutil" "github.com/hashicorp/go-sockaddr" "github.com/hashicorp/vault/helper/namespace" @@ -210,6 +212,90 @@ func Handler(props *vault.HandlerProperties) http.Handler { return printablePathCheckHandler } +type WrappingResponseWriter interface { + http.ResponseWriter + Wrapped() http.ResponseWriter +} + +type statusHeaderResponseWriter struct { + wrapped http.ResponseWriter + logger log.Logger + wroteHeader bool + statusCode int + headers map[string][]*vault.CustomHeader +} + +func (w *statusHeaderResponseWriter) Wrapped() http.ResponseWriter { + return w.wrapped +} + +func (w *statusHeaderResponseWriter) Header() http.Header { + return w.wrapped.Header() +} + +func (w *statusHeaderResponseWriter) Write(buf []byte) (int, error) { + // It is allowed to only call ResponseWriter.Write and skip + // ResponseWriter.WriteHeader. An example of such a situation is + // "handleUIStub". The Write function will internally set the status code + // 200 for the response for which that call might invoke other + // implementations of the WriteHeader function. So, we still need to set + // the custom headers. In cases where both WriteHeader and Write of + // statusHeaderResponseWriter struct are called the internal call to the + // WriterHeader invoked from inside Write method won't change the headers. + if !w.wroteHeader { + w.setCustomResponseHeaders(w.statusCode) + } + + return w.wrapped.Write(buf) +} + +func (w *statusHeaderResponseWriter) WriteHeader(statusCode int) { + w.setCustomResponseHeaders(statusCode) + w.wrapped.WriteHeader(statusCode) + w.statusCode = statusCode + // in cases where Write is called after WriteHeader, let's prevent setting + // ResponseWriter headers twice + w.wroteHeader = true +} + +func (w *statusHeaderResponseWriter) setCustomResponseHeaders(status int) { + sch := w.headers + if sch == nil { + w.logger.Warn("status code header map not configured") + return + } + + // Checking the validity of the status code + if status >= 600 || status < 100 { + return + } + + // setter function to set the headers + setter := func(hvl []*vault.CustomHeader) { + for _, hv := range hvl { + w.Header().Set(hv.Name, hv.Value) + } + } + + // Setting the default headers first + setter(sch["default"]) + + // setting the Xyy pattern first + d := fmt.Sprintf("%vxx", status/100) + if val, ok := sch[d]; ok { + setter(val) + } + + // Setting the specific headers + if val, ok := sch[strconv.Itoa(status)]; ok { + setter(val) + } + + return +} + +var _ WrappingResponseWriter = &statusHeaderResponseWriter{} + type copyResponseWriter struct { wrapped http.ResponseWriter statusCode int @@ -300,6 +386,22 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr hostname, _ := os.Hostname() return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // This block needs to be here so that upon sending SIGHUP, custom response + // headers are also reloaded into the handlers. + if props.ListenerConfig != nil { + la := props.ListenerConfig.Address + listenerCustomHeaders := core.GetListenerCustomResponseHeaders(la) + if listenerCustomHeaders != nil { + w = &statusHeaderResponseWriter{ + wrapped: w, + logger: core.Logger(), + wroteHeader: false, + statusCode: 200, + headers: listenerCustomHeaders.StatusCodeHeaderMap, + } + } + } + // Set the Cache-Control header for all the responses returned // by Vault w.Header().Set("Cache-Control", "no-store") @@ -632,7 +734,15 @@ func parseJSONRequest(perfStandby bool, r *http.Request, w http.ResponseWriter, return nil, errors.New("could not parse max_request_size from request context") } if max > 0 { - reader = http.MaxBytesReader(w, r.Body, max) + // MaxBytesReader won't do all the internal stuff it must unless it's + // given a ResponseWriter that implements the internal http interface + // requestTooLarger. So we let it have access to the underlying + // ResponseWriter. + inw := w + if myw, ok := inw.(WrappingResponseWriter); ok { + inw = myw.Wrapped() + } + reader = http.MaxBytesReader(inw, r.Body, max) } } var origBody io.ReadWriter diff --git a/http/handler_test.go b/http/handler_test.go index c228629ea8dce..87a79d3143394 100644 --- a/http/handler_test.go +++ b/http/handler_test.go @@ -16,6 +16,10 @@ import ( "github.com/go-test/deep" "github.com/hashicorp/go-cleanhttp" + kv "github.com/hashicorp/vault-plugin-secrets-kv" + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/audit" + auditFile "github.com/hashicorp/vault/builtin/audit/file" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/logical" @@ -818,3 +822,198 @@ func TestHandler_Parse_Form(t *testing.T) { t.Fatal(diff) } } + +func TestHandler_Patch_BadContentTypeHeader(t *testing.T) { + coreConfig := &vault.CoreConfig{ + LogicalBackends: map[string]logical.Factory{ + "kv": kv.VersionedKVFactory, + }, + } + + cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + HandlerFunc: Handler, + }) + + cluster.Start() + defer cluster.Cleanup() + + cores := cluster.Cores + + core := cores[0].Core + c := cluster.Cores[0].Client + vault.TestWaitActive(t, core) + + // Mount a KVv2 backend + err := c.Sys().Mount("kv", &api.MountInput{ + Type: "kv-v2", + }) + if err != nil { + t.Fatal(err) + } + + kvData := map[string]interface{}{ + "data": map[string]interface{}{ + "bar": "a", + }, + } + + resp, err := c.Logical().Write("kv/data/foo", kvData) + if err != nil { + t.Fatalf("write failed - err :%#v, resp: %#v\n", err, resp) + } + + resp, err = c.Logical().Read("kv/data/foo") + if err != nil { + t.Fatalf("read failed - err :%#v, resp: %#v\n", err, resp) + } + + req := c.NewRequest("PATCH", "/v1/kv/data/foo") + req.Headers = http.Header{ + "Content-Type": []string{"application/json"}, + } + + if err := req.SetJSONBody(kvData); err != nil { + t.Fatal(err) + } + + apiResp, err := c.RawRequestWithContext(context.Background(), req) + if err == nil || apiResp.StatusCode != http.StatusUnsupportedMediaType { + t.Fatalf("expected PATCH request to fail with %d status code - err :%#v, resp: %#v\n", http.StatusUnsupportedMediaType, err, apiResp) + } +} + +func TestHandler_Patch_NotFound(t *testing.T) { + coreConfig := &vault.CoreConfig{ + LogicalBackends: map[string]logical.Factory{ + "kv": kv.VersionedKVFactory, + }, + } + + cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + HandlerFunc: Handler, + }) + + cluster.Start() + defer cluster.Cleanup() + + cores := cluster.Cores + + core := cores[0].Core + c := cluster.Cores[0].Client + vault.TestWaitActive(t, core) + + // Mount a KVv2 backend + err := c.Sys().Mount("kv", &api.MountInput{ + Type: "kv-v2", + }) + if err != nil { + t.Fatal(err) + } + + kvData := map[string]interface{}{ + "data": map[string]interface{}{ + "bar": "a", + }, + } + + resp, err := c.Logical().JSONMergePatch(context.Background(), "kv/data/foo", kvData) + if err == nil { + t.Fatalf("expected PATCH request to fail, resp: %#v\n", resp) + } + + responseError := err.(*api.ResponseError) + if responseError.StatusCode != http.StatusNotFound { + t.Fatalf("expected PATCH request to fail with %d status code - err: %#v, resp: %#v\n", http.StatusNotFound, responseError, resp) + } +} + +func TestHandler_Patch_Audit(t *testing.T) { + coreConfig := &vault.CoreConfig{ + LogicalBackends: map[string]logical.Factory{ + "kv": kv.VersionedKVFactory, + }, + AuditBackends: map[string]audit.Factory{ + "file": auditFile.Factory, + }, + } + + cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + HandlerFunc: Handler, + }) + + cluster.Start() + defer cluster.Cleanup() + + cores := cluster.Cores + + core := cores[0].Core + c := cluster.Cores[0].Client + vault.TestWaitActive(t, core) + + if err := c.Sys().Mount("kv/", &api.MountInput{ + Type: "kv-v2", + }); err != nil { + t.Fatalf("kv-v2 mount attempt failed - err: %#v\n", err) + } + + auditLogFile, err := ioutil.TempFile("", "httppatch") + if err != nil { + t.Fatal(err) + } + + err = c.Sys().EnableAuditWithOptions("file", &api.EnableAuditOptions{ + Type: "file", + Options: map[string]string{ + "file_path": auditLogFile.Name(), + }, + }) + + writeData := map[string]interface{}{ + "data": map[string]interface{}{ + "bar": "a", + }, + } + + resp, err := c.Logical().Write("kv/data/foo", writeData) + if err != nil { + t.Fatalf("write request failed, err: %#v, resp: %#v\n", err, resp) + } + + patchData := map[string]interface{}{ + "data": map[string]interface{}{ + "baz": "b", + }, + } + + resp, err = c.Logical().JSONMergePatch(context.Background(), "kv/data/foo", patchData) + if err != nil { + t.Fatalf("patch request failed, err: %#v, resp: %#v\n", err, resp) + } + + patchRequestLogCount := 0 + patchResponseLogCount := 0 + decoder := json.NewDecoder(auditLogFile) + + var auditRecord map[string]interface{} + for decoder.Decode(&auditRecord) == nil { + auditRequest := map[string]interface{}{} + + if req, ok := auditRecord["request"]; ok { + auditRequest = req.(map[string]interface{}) + } + + if auditRequest["operation"] == "patch" && auditRecord["type"] == "request" { + patchRequestLogCount += 1 + } else if auditRequest["operation"] == "patch" && auditRecord["type"] == "response" { + patchResponseLogCount += 1 + } + } + + if patchRequestLogCount != 1 { + t.Fatalf("expected 1 patch request audit log record, saw %d\n", patchRequestLogCount) + } + + if patchResponseLogCount != 1 { + t.Fatalf("expected 1 patch response audit log record, saw %d\n", patchResponseLogCount) + } +} diff --git a/http/http_test.go b/http/http_test.go index e37b9c3d7693e..692aef0d82877 100644 --- a/http/http_test.go +++ b/http/http_test.go @@ -125,6 +125,16 @@ func testResponseStatus(t *testing.T, resp *http.Response, code int) { } } +func testResponseHeader(t *testing.T, resp *http.Response, expectedHeaders map[string]string) { + t.Helper() + for k, v := range expectedHeaders { + hv := resp.Header.Get(k) + if v != hv { + t.Fatalf("expected header value %v=%v, got %v=%v", k, v, k, hv) + } + } +} + func testResponseBody(t *testing.T, resp *http.Response, out interface{}) { defer resp.Body.Close() diff --git a/http/logical.go b/http/logical.go index dd9abce34dfdb..7984f8ac06ba3 100644 --- a/http/logical.go +++ b/http/logical.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "io" + "mime" "net" "net/http" "strconv" @@ -38,6 +39,8 @@ func (b *bufferedReader) Close() error { return b.rOrig.Close() } +const MergePatchContentTypeHeader = "application/merge-patch+json" + func buildLogicalRequestNoAuth(perfStandby bool, w http.ResponseWriter, r *http.Request) (*logical.Request, io.ReadCloser, int, error) { ns, err := namespace.FromContext(r.Context()) if err != nil { @@ -139,6 +142,34 @@ func buildLogicalRequestNoAuth(perfStandby bool, w http.ResponseWriter, r *http. } } + case "PATCH": + op = logical.PatchOperation + + contentTypeHeader := r.Header.Get("Content-Type") + contentType, _, err := mime.ParseMediaType(contentTypeHeader) + if err != nil { + status := http.StatusBadRequest + logical.AdjustErrorStatusCode(&status, err) + return nil, nil, status, err + } + + if contentType != MergePatchContentTypeHeader { + return nil, nil, http.StatusUnsupportedMediaType, fmt.Errorf("PATCH requires Content-Type of %s, provided %s", MergePatchContentTypeHeader, contentType) + } + + origBody, err = parseJSONRequest(perfStandby, r, w, &data) + + if err == io.EOF { + data = nil + err = nil + } + + if err != nil { + status := http.StatusBadRequest + logical.AdjustErrorStatusCode(&status, err) + return nil, nil, status, fmt.Errorf("error parsing JSON") + } + case "LIST": op = logical.ListOperation if !strings.HasSuffix(path, "/") { diff --git a/http/sys_metrics.go b/http/sys_metrics.go index 0e58be3ea262d..012417282e5f5 100644 --- a/http/sys_metrics.go +++ b/http/sys_metrics.go @@ -35,12 +35,14 @@ func handleMetricsUnauthenticated(core *vault.Core) http.Handler { resp := core.MetricsHelper().ResponseForFormat(format) // Manually extract the logical response and send back the information - w.WriteHeader(resp.Data[logical.HTTPStatusCode].(int)) + status := resp.Data[logical.HTTPStatusCode].(int) w.Header().Set("Content-Type", resp.Data[logical.HTTPContentType].(string)) switch v := resp.Data[logical.HTTPRawBody].(type) { case string: + w.WriteHeader(status) w.Write([]byte(v)) case []byte: + w.WriteHeader(status) w.Write(v) default: respondError(w, http.StatusInternalServerError, fmt.Errorf("wrong response returned")) diff --git a/http/testing.go b/http/testing.go index be9569dc9684c..84ab73fc08006 100644 --- a/http/testing.go +++ b/http/testing.go @@ -6,6 +6,7 @@ import ( "net/http" "testing" + "github.com/hashicorp/vault/internalshared/configutil" "github.com/hashicorp/vault/vault" ) @@ -41,10 +42,16 @@ func TestServerWithListenerAndProperties(tb testing.TB, ln net.Listener, addr st } func TestServerWithListener(tb testing.TB, ln net.Listener, addr string, core *vault.Core) { + ip, _, _ := net.SplitHostPort(ln.Addr().String()) + // Create a muxer to handle our requests so that we can authenticate // for tests. props := &vault.HandlerProperties{ Core: core, + // This is needed for testing custom response headers + ListenerConfig: &configutil.Listener { + Address: ip, + }, } TestServerWithListenerAndProperties(tb, ln, addr, core, props) } diff --git a/internalshared/configutil/http_response_headers.go b/internalshared/configutil/http_response_headers.go new file mode 100644 index 0000000000000..2db3034e588b6 --- /dev/null +++ b/internalshared/configutil/http_response_headers.go @@ -0,0 +1,129 @@ +package configutil + +import ( + "fmt" + "net/textproto" + "strconv" + "strings" + + "github.com/hashicorp/go-secure-stdlib/strutil" +) + +var ValidCustomStatusCodeCollection = []string{ + "default", + "1xx", + "2xx", + "3xx", + "4xx", + "5xx", +} + +const StrictTransportSecurity = "max-age=31536000; includeSubDomains" + +// ParseCustomResponseHeaders takes a raw config values for the +// "custom_response_headers". It makes sure the config entry is passed in +// as a map of status code to a map of header name and header values. It +// verifies the validity of the status codes, and header values. It also +// adds the default headers values. +func ParseCustomResponseHeaders(responseHeaders interface{}) (map[string]map[string]string, error) { + h := make(map[string]map[string]string) + // if r is nil, we still should set the default custom headers + if responseHeaders == nil { + h["default"] = map[string]string{"Strict-Transport-Security": StrictTransportSecurity} + return h, nil + } + + customResponseHeader, ok := responseHeaders.([]map[string]interface{}) + if !ok { + return nil, fmt.Errorf("response headers were not configured correctly. please make sure they're in a slice of maps") + } + + for _, crh := range customResponseHeader { + for statusCode, responseHeader := range crh { + headerValList, ok := responseHeader.([]map[string]interface{}) + if !ok { + return nil, fmt.Errorf("response headers were not configured correctly. please make sure they're in a slice of maps") + } + + if !IsValidStatusCode(statusCode) { + return nil, fmt.Errorf("invalid status code found in the server configuration: %v", statusCode) + } + + if len(headerValList) != 1 { + return nil, fmt.Errorf("invalid number of response headers exist") + } + headerValMap := headerValList[0] + headerVal, err := parseHeaders(headerValMap) + if err != nil { + return nil, err + } + + h[statusCode] = headerVal + } + } + + // setting Strict-Transport-Security as a default header + if h["default"] == nil { + h["default"] = make(map[string]string) + } + if _, ok := h["default"]["Strict-Transport-Security"]; !ok { + h["default"]["Strict-Transport-Security"] = StrictTransportSecurity + } + + return h, nil +} + +// IsValidStatusCode checking for status codes outside the boundary +func IsValidStatusCode(sc string) bool { + if strutil.StrListContains(ValidCustomStatusCodeCollection, sc) { + return true + } + + i, err := strconv.Atoi(sc) + if err != nil { + return false + } + + if i >= 600 || i < 100 { + return false + } + + return true +} + +func parseHeaders(in map[string]interface{}) (map[string]string, error) { + hvMap := make(map[string]string) + for k, v := range in { + // parsing header name + headerName := textproto.CanonicalMIMEHeaderKey(k) + // parsing header values + s, err := parseHeaderValues(v) + if err != nil { + return nil, err + } + hvMap[headerName] = s + } + return hvMap, nil +} + +func parseHeaderValues(header interface{}) (string, error) { + var sl []string + if _, ok := header.([]interface{}); !ok { + return "", fmt.Errorf("headers must be given in a list of strings") + } + headerValList := header.([]interface{}) + for _, vh := range headerValList { + if _, ok := vh.(string); !ok { + return "", fmt.Errorf("found a non-string header value: %v", vh) + } + headerVal := strings.TrimSpace(vh.(string)) + if headerVal == "" { + continue + } + sl = append(sl, headerVal) + + } + s := strings.Join(sl, "; ") + + return s, nil +} diff --git a/internalshared/configutil/listener.go b/internalshared/configutil/listener.go index 98199082895a4..677dbf9df8b3a 100644 --- a/internalshared/configutil/listener.go +++ b/internalshared/configutil/listener.go @@ -99,6 +99,10 @@ type Listener struct { CorsAllowedOrigins []string `hcl:"cors_allowed_origins"` CorsAllowedHeaders []string `hcl:"-"` CorsAllowedHeadersRaw []string `hcl:"cors_allowed_headers,alias:cors_allowed_headers"` + + // Custom Http response headers + CustomResponseHeaders map[string]map[string]string `hcl:"-"` + CustomResponseHeadersRaw interface{} `hcl:"custom_response_headers"` } func (l *Listener) GoString() string { @@ -361,6 +365,17 @@ func ParseListeners(result *SharedConfig, list *ast.ObjectList) error { } } + // HTTP Headers + { + // if CustomResponseHeadersRaw is nil, we still need to set the default headers + customHeadersMap, err := ParseCustomResponseHeaders(l.CustomResponseHeadersRaw) + if err != nil { + return multierror.Prefix(fmt.Errorf("failed to parse custom_response_headers: %w", err), fmt.Sprintf("listeners.%d", i)) + } + l.CustomResponseHeaders = customHeadersMap + l.CustomResponseHeadersRaw = nil + } + result.Listeners = append(result.Listeners, &l) } diff --git a/sdk/framework/backend.go b/sdk/framework/backend.go index c2c3f1810008b..673351a47e374 100644 --- a/sdk/framework/backend.go +++ b/sdk/framework/backend.go @@ -3,6 +3,7 @@ package framework import ( "context" "crypto/rand" + "encoding/json" "fmt" "io" "io/ioutil" @@ -13,6 +14,7 @@ import ( "sync" "time" + jsonpatch "github.com/evanphx/json-patch" "github.com/hashicorp/errwrap" log "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-kms-wrapping/entropy" @@ -41,10 +43,8 @@ type Backend struct { // paths, including adding or removing, is not allowed once the // backend is in use). // - // PathsSpecial is the list of path patterns that denote the - // paths above that require special privileges. These can't be - // regular expressions, it is either exact match or prefix match. - // For prefix match, append '*' as a suffix. + // PathsSpecial is the list of path patterns that denote the paths above + // that require special privileges. Paths []*Path PathsSpecial *logical.Paths @@ -120,6 +120,10 @@ type InvalidateFunc func(context.Context, string) // Initialize() just after a plugin has been mounted. type InitializeFunc func(context.Context, *logical.InitializationRequest) error +// PatchPreprocessorFunc is used by HandlePatchOperation in order to shape +// the input as defined by request handler prior to JSON marshaling +type PatchPreprocessorFunc func(map[string]interface{}) (map[string]interface{}, error) + // Initialize is the logical.Backend implementation. func (b *Backend) Initialize(ctx context.Context, req *logical.InitializationRequest) error { if b.InitializeFunc != nil { @@ -274,6 +278,57 @@ func (b *Backend) HandleRequest(ctx context.Context, req *logical.Request) (*log return callback(ctx, req, &fd) } +// HandlePatchOperation acts as an abstraction for performing JSON merge patch +// operations (see https://datatracker.ietf.org/doc/html/rfc7396) for HTTP +// PATCH requests. It is responsible for properly processing and marshalling +// the input and existing resource prior to performing the JSON merge operation +// using the MergePatch function from the json-patch library. The preprocessor +// is an arbitrary func that can be provided to further process the input. The +// MergePatch function accepts and returns byte arrays. +func HandlePatchOperation(input *FieldData, resource map[string]interface{}, preprocessor PatchPreprocessorFunc) ([]byte, error) { + var err error + + if resource == nil { + return nil, fmt.Errorf("resource does not exist") + } + + inputMap := map[string]interface{}{} + + // Parse all fields to ensure data types are handled properly according to the FieldSchema + for key := range input.Raw { + val, ok := input.GetOk(key) + + // Only accept fields in the schema + if ok { + inputMap[key] = val + } + } + + if preprocessor != nil { + inputMap, err = preprocessor(inputMap) + if err != nil { + return nil, err + } + } + + marshaledResource, err := json.Marshal(resource) + if err != nil { + return nil, err + } + + marshaledInput, err := json.Marshal(inputMap) + if err != nil { + return nil, err + } + + modified, err := jsonpatch.MergePatch(marshaledResource, marshaledInput) + if err != nil { + return nil, err + } + + return modified, nil +} + // SpecialPaths is the logical.Backend implementation. func (b *Backend) SpecialPaths() *logical.Paths { return b.PathsSpecial diff --git a/sdk/logical/logical.go b/sdk/logical/logical.go index db8831535a8e8..cec2d19c0e6ed 100644 --- a/sdk/logical/logical.go +++ b/sdk/logical/logical.go @@ -117,6 +117,10 @@ type Paths struct { Root []string // Unauthenticated are the paths that can be accessed without any auth. + // These can't be regular expressions, it is either exact match, a prefix + // match and/or a wildcard match. For prefix match, append '*' as a suffix. + // For a wildcard match, use '+' in the segment to match any identifier + // (e.g. 'foo/+/bar'). Note that '+' can't be adjacent to a non-slash. Unauthenticated []string // LocalStorage are paths (prefixes) that are local to this instance; this diff --git a/sdk/logical/request.go b/sdk/logical/request.go index b88aabce2df37..e683217a6efc5 100644 --- a/sdk/logical/request.go +++ b/sdk/logical/request.go @@ -350,6 +350,7 @@ const ( CreateOperation Operation = "create" ReadOperation = "read" UpdateOperation = "update" + PatchOperation = "patch" DeleteOperation = "delete" ListOperation = "list" HelpOperation = "help" diff --git a/sdk/logical/response_util.go b/sdk/logical/response_util.go index 6ae3005b735f1..5244069379864 100644 --- a/sdk/logical/response_util.go +++ b/sdk/logical/response_util.go @@ -17,7 +17,7 @@ import ( func RespondErrorCommon(req *Request, resp *Response, err error) (int, error) { if err == nil && (resp == nil || !resp.IsError()) { switch { - case req.Operation == ReadOperation: + case req.Operation == ReadOperation, req.Operation == PatchOperation: if resp == nil { return http.StatusNotFound, nil } diff --git a/ui/.gitignore b/ui/.gitignore index 944ab0c5121f9..1eba99c9116da 100644 --- a/ui/.gitignore +++ b/ui/.gitignore @@ -18,6 +18,7 @@ /testem.log /yarn-error.log /.storybook/preview-head.html +package-lock.json # ember-try /.node_modules.ember-try/ diff --git a/ui/app/components/nav-header.js b/ui/app/components/nav-header.js index 509b88ed4b6d8..ce5f5638f62d8 100644 --- a/ui/app/components/nav-header.js +++ b/ui/app/components/nav-header.js @@ -1,10 +1,21 @@ import Component from '@ember/component'; +import { inject as service } from '@ember/service'; +import { computed } from '@ember/object'; + export default Component.extend({ + router: service(), 'data-test-navheader': true, classNameBindings: 'consoleFullscreen:panel-fullscreen', tagName: 'header', navDrawerOpen: false, consoleFullscreen: false, + hideLinks: computed('router.currentRouteName', function() { + let currentRoute = this.router.currentRouteName; + if ('vault.cluster.identity.oidc-provider' === currentRoute) { + return true; + } + return false; + }), actions: { toggleNavDrawer(isOpen) { if (isOpen !== undefined) { diff --git a/ui/app/components/oidc-consent-block.js b/ui/app/components/oidc-consent-block.js new file mode 100644 index 0000000000000..b5a9e6fbddbb1 --- /dev/null +++ b/ui/app/components/oidc-consent-block.js @@ -0,0 +1,59 @@ +/** + * @module OidcConsentBlock + * OidcConsentBlock components are used to show the consent form for the OIDC Authorization Code Flow + * + * @example + * ```js + * + * ``` + * @param {string} redirect - redirect is the URL where successful consent will redirect to + * @param {string} code - code is the string required to pass back to redirect on successful OIDC auth + * @param {string} [state] - state is a string which is required to return on redirect if provided, but optional generally + */ + +import Ember from 'ember'; +import Component from '@glimmer/component'; +import { action } from '@ember/object'; +import { tracked } from '@glimmer/tracking'; + +const validParameters = ['code', 'state']; +export default class OidcConsentBlockComponent extends Component { + @tracked didCancel = false; + + get win() { + return this.window || window; + } + + buildUrl(urlString, params) { + try { + let url = new URL(urlString); + Object.keys(params).forEach(key => { + if (params[key] && validParameters.includes(key)) { + url.searchParams.append(key, params[key]); + } + }); + return url; + } catch (e) { + console.debug('DEBUG: parsing url failed for', urlString); + throw new Error('Invalid URL'); + } + } + + @action + handleSubmit(evt) { + evt.preventDefault(); + let { redirect, ...params } = this.args; + let redirectUrl = this.buildUrl(redirect, params); + if (Ember.testing) { + this.args.testRedirect(redirectUrl.toString()); + } else { + this.win.location.replace(redirectUrl); + } + } + + @action + handleCancel(evt) { + evt.preventDefault(); + this.didCancel = true; + } +} diff --git a/ui/app/components/token-expire-warning.js b/ui/app/components/token-expire-warning.js index 4798652642ba8..eb2884b5043c1 100644 --- a/ui/app/components/token-expire-warning.js +++ b/ui/app/components/token-expire-warning.js @@ -1,5 +1,14 @@ -import Component from '@ember/component'; +import Component from '@glimmer/component'; +import { inject as service } from '@ember/service'; -export default Component.extend({ - tagName: '', -}); +export default class TokenExpireWarning extends Component { + @service router; + + get showWarning() { + let currentRoute = this.router.currentRouteName; + if ('vault.cluster.identity.oidc-provider' === currentRoute) { + return false; + } + return !!this.args.expirationDate; + } +} diff --git a/ui/app/controllers/vault/cluster/identity/oidc-provider.js b/ui/app/controllers/vault/cluster/identity/oidc-provider.js new file mode 100644 index 0000000000000..8b656eca92c95 --- /dev/null +++ b/ui/app/controllers/vault/cluster/identity/oidc-provider.js @@ -0,0 +1,24 @@ +import Controller from '@ember/controller'; + +export default class VaultClusterIdentityOidcProviderController extends Controller { + queryParams = [ + 'scope', // * + 'response_type', // * + 'client_id', // * + 'redirect_uri', // * + 'state', // * + 'nonce', // * + 'display', + 'prompt', + 'max_age', + ]; + scope = null; + response_type = null; + client_id = null; + redirect_uri = null; + state = null; + nonce = null; + display = null; + prompt = null; + max_age = null; +} diff --git a/ui/app/mixins/cluster-route.js b/ui/app/mixins/cluster-route.js index 67dae10860cfc..66127b89b17e3 100644 --- a/ui/app/mixins/cluster-route.js +++ b/ui/app/mixins/cluster-route.js @@ -7,6 +7,7 @@ const AUTH = 'vault.cluster.auth'; const CLUSTER = 'vault.cluster'; const CLUSTER_INDEX = 'vault.cluster.index'; const OIDC_CALLBACK = 'vault.cluster.oidc-callback'; +const OIDC_PROVIDER = 'vault.cluster.identity.oidc-provider'; const DR_REPLICATION_SECONDARY = 'vault.cluster.replication-dr-promote'; const DR_REPLICATION_SECONDARY_DETAILS = 'vault.cluster.replication-dr-promote.details'; const EXCLUDED_REDIRECT_URLS = ['/vault/logout']; @@ -20,7 +21,9 @@ export default Mixin.create({ transitionToTargetRoute(transition = {}) { const targetRoute = this.targetRouteName(transition); - + if (OIDC_PROVIDER === this.router.currentRouteName || OIDC_PROVIDER === transition?.to?.name) { + return RSVP.resolve(); + } if ( targetRoute && targetRoute !== this.routeName && diff --git a/ui/app/router.js b/ui/app/router.js index 090def1fa2acc..95a6e83e5f8e9 100644 --- a/ui/app/router.js +++ b/ui/app/router.js @@ -139,6 +139,10 @@ Router.map(function() { } this.route('not-found', { path: '/*path' }); + + this.route('identity', function() { + this.route('oidc-provider', { path: '/oidc/provider/:oidc_name/authorize' }); + }); }); this.route('not-found', { path: '/*path' }); }); diff --git a/ui/app/routes/vault/cluster/identity/oidc-provider.js b/ui/app/routes/vault/cluster/identity/oidc-provider.js new file mode 100644 index 0000000000000..7f5a4c66ac4d4 --- /dev/null +++ b/ui/app/routes/vault/cluster/identity/oidc-provider.js @@ -0,0 +1,115 @@ +import Route from '@ember/routing/route'; +import { inject as service } from '@ember/service'; + +const AUTH = 'vault.cluster.auth'; +const PROVIDER = 'vault.cluster.identity.oidc-provider'; + +export default class VaultClusterIdentityOidcProviderRoute extends Route { + @service auth; + @service router; + + get win() { + return this.window || window; + } + + _redirect(url, params) { + let redir = this._buildUrl(url, params); + this.win.location.replace(redir); + } + + beforeModel(transition) { + const currentToken = this.auth.get('currentTokenName'); + let { redirect_to, ...qp } = transition.to.queryParams; + console.debug('DEBUG: removing redirect_to', redirect_to); + if (!currentToken && 'none' === qp.prompt?.toLowerCase()) { + this._redirect(qp.redirect_uri, { + state: qp.state, + error: 'login_required', + }); + } else if (!currentToken || 'login' === qp.prompt?.toLowerCase()) { + if ('login' === qp.prompt?.toLowerCase()) { + this.auth.deleteCurrentToken(); + qp.prompt = null; + } + let { cluster_name } = this.paramsFor('vault.cluster'); + let url = this.router.urlFor(transition.to.name, transition.to.params, { queryParams: qp }); + return this.transitionTo(AUTH, cluster_name, { queryParams: { redirect_to: url } }); + } + } + + _redirectToAuth(oidcName, queryParams, logout = false) { + let { cluster_name } = this.paramsFor('vault.cluster'); + let currentRoute = this.router.urlFor(PROVIDER, oidcName, { queryParams }); + if (logout) { + this.auth.deleteCurrentToken(); + } + return this.transitionTo(AUTH, cluster_name, { queryParams: { redirect_to: currentRoute } }); + } + + _buildUrl(urlString, params) { + try { + let url = new URL(urlString); + Object.keys(params).forEach(key => { + if (params[key]) { + url.searchParams.append(key, params[key]); + } + }); + return url; + } catch (e) { + console.debug('DEBUG: parsing url failed for', urlString); + throw new Error('Invalid URL'); + } + } + + _handleSuccess(response, baseUrl, state) { + const { code } = response; + let redirectUrl = this._buildUrl(baseUrl, { code, state }); + this.win.location.replace(redirectUrl); + } + _handleError(errorResp, baseUrl) { + let redirectUrl = this._buildUrl(baseUrl, { ...errorResp }); + this.win.location.replace(redirectUrl); + } + + async model(params) { + let { oidc_name, ...qp } = params; + let decodedRedirect = decodeURI(qp.redirect_uri); + let url = this._buildUrl(`${this.win.origin}/v1/identity/oidc/provider/${oidc_name}/authorize`, qp); + try { + const response = await this.auth.ajax(url, 'GET', {}); + if ('consent' === qp.prompt?.toLowerCase()) { + return { + consent: { + code: response.code, + redirect: decodedRedirect, + state: qp.state, + }, + }; + } + this._handleSuccess(response, decodedRedirect, qp.state); + } catch (errorRes) { + let resp = await errorRes.json(); + let code = resp.error; + if (code === 'max_age_violation') { + this._redirectToAuth(oidc_name, qp, true); + } else if (code === 'invalid_redirect_uri') { + return { + error: { + title: 'Redirect URI mismatch', + message: + 'The provided redirect_uri is not in the list of allowed redirect URIs. Please make sure you are sending a valid redirect URI from your application.', + }, + }; + } else if (code === 'invalid_client_id') { + return { + error: { + title: 'Invalid client ID', + message: 'Your client ID is invalid. Please update your configuration and try again.', + }, + }; + } else { + this._handleError(resp, decodedRedirect); + } + } + } +} diff --git a/ui/app/services/auth.js b/ui/app/services/auth.js index 9b90965d38556..d8ceb145516a2 100644 --- a/ui/app/services/auth.js +++ b/ui/app/services/auth.js @@ -97,7 +97,7 @@ export default Service.extend({ } else if (response.status >= 200 && response.status < 300) { return resolve(response.json()); } else { - return reject(); + return reject(response); } }); }, diff --git a/ui/app/templates/components/nav-header.hbs b/ui/app/templates/components/nav-header.hbs index c6ad8738cf3cf..8a60fecdce2f0 100644 --- a/ui/app/templates/components/nav-header.hbs +++ b/ui/app/templates/components/nav-header.hbs @@ -9,30 +9,32 @@ {{/unless}} -