From 6404d83082d683a5c75ccefdb5774fad7e5dad38 Mon Sep 17 00:00:00 2001 From: Angel Garbarino Date: Tue, 12 Oct 2021 18:05:32 -0600 Subject: [PATCH 01/12] fix copy issue (#12810) --- ui/app/templates/components/secret-edit-toolbar.hbs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ui/app/templates/components/secret-edit-toolbar.hbs b/ui/app/templates/components/secret-edit-toolbar.hbs index b281e945ac3e9..914603903f593 100644 --- a/ui/app/templates/components/secret-edit-toolbar.hbs +++ b/ui/app/templates/components/secret-edit-toolbar.hbs @@ -48,7 +48,7 @@
  • From 7640d6a84042d75b40a8d947ac929d4b75e41ec5 Mon Sep 17 00:00:00 2001 From: dr-db <25711615+dr-db@users.noreply.github.com> Date: Tue, 12 Oct 2021 20:50:20 -0500 Subject: [PATCH 02/12] Update index.mdx (#12395) Typo fix. --- website/content/docs/commands/index.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/website/content/docs/commands/index.mdx b/website/content/docs/commands/index.mdx index cefdeeed82997..9af192f048fdd 100644 --- a/website/content/docs/commands/index.mdx +++ b/website/content/docs/commands/index.mdx @@ -333,7 +333,7 @@ overrides any other proxies found in the environment. Format should be There are different CLI flags that are available depending on subcommands. Some flags, such as those used for setting HTTP and output options, are available -globally, while others are specific to a particular subcommand. For a completely +globally, while others are specific to a particular subcommand. For a complete list of available flags, run: ```shell-session From c0abf719cc5380f70c14c42e2a73cfb40f67cb2f Mon Sep 17 00:00:00 2001 From: Victor Rodriguez Date: Wed, 13 Oct 2021 08:58:02 -0400 Subject: [PATCH 03/12] Wait for expiration manager to be out of restore mode while testing. (#12779) --- vault/expiration_test.go | 37 +++++++++++++++++++++++++------------ 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/vault/expiration_test.go b/vault/expiration_test.go index ccac2945dcf40..6fe12f6d6dffd 100644 --- a/vault/expiration_test.go +++ b/vault/expiration_test.go @@ -32,6 +32,17 @@ var testImagePull sync.Once // mockExpiration returns a mock expiration manager func mockExpiration(t testing.TB) *ExpirationManager { c, _, _ := TestCoreUnsealed(t) + + // Wait until the expiration manager is out of restore mode. + // This was added to prevent sporadic failures of TestExpiration_unrecoverableErrorMakesIrrevocable. + timeout := time.Now().Add(time.Second * 10) + for c.expiration.inRestoreMode() { + if time.Now().After(timeout) { + t.Fatal("ExpirationManager is still in restore mode after 10 seconds") + } + time.Sleep(50*time.Millisecond) + } + return c.expiration } @@ -3077,20 +3088,22 @@ func TestExpiration_unrecoverableErrorMakesIrrevocable(t *testing.T) { } for _, tc := range testCases { - tc.job.OnFailure(tc.err) + t.Run(tc.err.Error(), func(t *testing.T) { + tc.job.OnFailure(tc.err) - le, err := exp.loadEntry(ctx, tc.job.leaseID) - if err != nil { - t.Fatalf("could not load leaseID %q: %v", tc.job.leaseID, err) - } - if le == nil { - t.Fatalf("nil lease for leaseID: %q", tc.job.leaseID) - } + le, err := exp.loadEntry(ctx, tc.job.leaseID) + if err != nil { + t.Fatalf("could not load leaseID %q: %v", tc.job.leaseID, err) + } + if le == nil { + t.Fatalf("nil lease for leaseID: %q", tc.job.leaseID) + } - isIrrevocable := le.isIrrevocable() - if isIrrevocable != tc.shouldBeIrrevocable { - t.Errorf("expected irrevocable: %t, got irrevocable: %t", tc.shouldBeIrrevocable, isIrrevocable) - } + isIrrevocable := le.isIrrevocable() + if isIrrevocable != tc.shouldBeIrrevocable { + t.Errorf("expected irrevocable: %t, got irrevocable: %t", tc.shouldBeIrrevocable, isIrrevocable) + } + }) } } From 7d2fa4323ecc89f71fe7bf5a594c8d6166e32169 Mon Sep 17 00:00:00 2001 From: DJCrabhat Date: Wed, 13 Oct 2021 07:45:34 -0700 Subject: [PATCH 04/12] Add `nonce` configuration parameter to agent AWS auto-auth documentation (#10926) * Update aws.mdx Was looking how to give the vault agent with AWS auth-auth the same nonce, but saw it wasn't documented. Dove through the code, found https://github.com/hashicorp/vault/blob/master/command/agent/auth/aws/aws.go#L139 and https://github.com/hashicorp/vault/blob/master/command/agent/auth/aws/aws.go#L215 (tried to call out the importance and point to docs, know setting `nonce` poorly could be very bad!) * add line breaks * Apply suggestions from code review Co-authored-by: Loann Le <84412881+taoism4504@users.noreply.github.com> Co-authored-by: hghaf099 <83242695+hghaf099@users.noreply.github.com> Co-authored-by: Loann Le <84412881+taoism4504@users.noreply.github.com> --- website/content/docs/agent/autoauth/methods/aws.mdx | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/website/content/docs/agent/autoauth/methods/aws.mdx b/website/content/docs/agent/autoauth/methods/aws.mdx index ea0a50f0060d5..95be69a6164f4 100644 --- a/website/content/docs/agent/autoauth/methods/aws.mdx +++ b/website/content/docs/agent/autoauth/methods/aws.mdx @@ -56,6 +56,10 @@ parameters unset in your configuration. - `header_value` `(string: optional)` - If configured in Vault, the value to use for [`iam_server_id_header_value`](/api/auth/aws#iam_server_id_header_value). +- `nonce` `(string: optional)` - If not provided, Vault will generate a new UUID every time `vault agent` runs. + If set, make sure you understand the importance of generating a good, unique `nonce` and protecting it. + See [Client Nonce](/docs/auth/aws#client-nonce) for more information. + ## Learn Refer to the [Vault Agent with From 3aafbd0e8ad71785cb12c65529ce7ff37bfccce8 Mon Sep 17 00:00:00 2001 From: Loann Le <84412881+taoism4504@users.noreply.github.com> Date: Wed, 13 Oct 2021 07:48:00 -0700 Subject: [PATCH 05/12] Vault Documentation: Modified What is Vault description (#12783) * modified vault description * modified paragraph based on feedback * Update what-is-vault.mdx Removed characters that were arbitrarily added. * Update what-is-vault.mdx changed markdown syntax for 'secret's --- website/content/docs/what-is-vault.mdx | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/website/content/docs/what-is-vault.mdx b/website/content/docs/what-is-vault.mdx index a7c05e0056110..112673c066eca 100644 --- a/website/content/docs/what-is-vault.mdx +++ b/website/content/docs/what-is-vault.mdx @@ -20,10 +20,7 @@ available features as well as internals. ## What is Vault? -Vault is a tool for securely accessing _secrets_. A secret is anything that you -want to tightly control access to, such as API keys, passwords, or certificates. -Vault provides a unified interface to any secret, while providing tight access -control and recording a detailed audit log. +Vault is an identity-based **secrets** and encryption management system. A secret is anything that you want to tightly control access to, such as API encryption keys, passwords, or certificates. Vault provides encryption services that are gated by authentication and authorization methods. Using Vault’s UI, CLI, or HTTP API, access to secrets and other sensitive data can be securely stored and managed, tightly controlled (restricted), and auditable. A modern system requires access to a multitude of secrets: database credentials, API keys for external services, credentials for service-oriented architecture From e0bfb73815e54c9ef9691ec6d6ead6f09edf3bbf Mon Sep 17 00:00:00 2001 From: hghaf099 <83242695+hghaf099@users.noreply.github.com> Date: Wed, 13 Oct 2021 11:06:33 -0400 Subject: [PATCH 06/12] Customizing HTTP headers in the config file (#12485) * Customizing HTTP headers in the config file * Add changelog, fix bad imports * fixing some bugs * fixing interaction of custom headers and /ui * Defining a member in core to set custom response headers * missing additional file * Some refactoring * Adding automated tests for the feature * Changing some error messages based on some recommendations * Incorporating custom response headers struct into the request context * removing some unused references * fixing a test * changing some error messages, removing a default header value from /ui * fixing a test * wrapping ResponseWriter to set the custom headers * adding a new test * some cleanup * removing some extra lines * Addressing comments * fixing some agent tests * skipping custom headers from agent listener config, removing two of the default headers as they cause issues with Vault in UI mode Adding X-Content-Type-Options to the ui default headers Let Content-Type be set as before * Removing default custom headers, and renaming some function varibles * some refacotring * Refactoring and addressing comments * removing a function and fixing comments --- changelog/12485.txt | 3 + command/agent/config/config.go | 7 + command/agent/config/config_test.go | 2 - command/server.go | 6 + .../config_custom_response_headers_test.go | 109 +++++++++++ command/server/config_test_helpers.go | 19 +- .../config_custom_response_headers_1.hcl | 31 ++++ ...om_response_headers_multiple_listeners.hcl | 56 ++++++ http/custom_header_test.go | 128 +++++++++++++ http/handler.go | 112 ++++++++++- http/http_test.go | 10 + http/sys_metrics.go | 4 +- http/testing.go | 7 + .../configutil/http_response_headers.go | 129 +++++++++++++ internalshared/configutil/listener.go | 15 ++ vault/core.go | 111 +++++++++-- vault/custom_response_headers.go | 90 +++++++++ vault/custom_response_headers_test.go | 174 ++++++++++++++++++ vault/logical_system.go | 3 + vault/testing.go | 23 +++ vault/ui.go | 3 +- 21 files changed, 1019 insertions(+), 23 deletions(-) create mode 100644 changelog/12485.txt create mode 100644 command/server/config_custom_response_headers_test.go create mode 100644 command/server/test-fixtures/config_custom_response_headers_1.hcl create mode 100644 command/server/test-fixtures/config_custom_response_headers_multiple_listeners.hcl create mode 100644 http/custom_header_test.go create mode 100644 internalshared/configutil/http_response_headers.go create mode 100644 vault/custom_response_headers.go create mode 100644 vault/custom_response_headers_test.go diff --git a/changelog/12485.txt b/changelog/12485.txt new file mode 100644 index 0000000000000..6c8a87cd21cac --- /dev/null +++ b/changelog/12485.txt @@ -0,0 +1,3 @@ +```release-note:feature +**Customizable HTTP Headers**: Add support to define custom HTTP headers for root path (`/`) and also on API endpoints (`/v1/*`) +``` diff --git a/command/agent/config/config.go b/command/agent/config/config.go index 9438bd327444d..502d512d15a24 100644 --- a/command/agent/config/config.go +++ b/command/agent/config/config.go @@ -35,6 +35,7 @@ func (c *Config) Prune() { l.RawConfig = nil l.Profiling.UnusedKeys = nil l.Telemetry.UnusedKeys = nil + l.CustomResponseHeaders = nil } c.FoundKeys = nil c.UnusedKeys = nil @@ -172,6 +173,12 @@ func LoadConfig(path string) (*Config, error) { if err != nil { return nil, err } + + // Pruning custom headers for Agent for now + for _, ln := range sharedConfig.Listeners { + ln.CustomResponseHeaders = nil + } + result.SharedConfig = sharedConfig list, ok := obj.Node.(*ast.ObjectList) diff --git a/command/agent/config/config_test.go b/command/agent/config/config_test.go index 0db8cf91954f7..252461236c89b 100644 --- a/command/agent/config/config_test.go +++ b/command/agent/config/config_test.go @@ -536,7 +536,6 @@ func TestLoadConfigFile_AgentCache_PersistMissingType(t *testing.T) { } func TestLoadConfigFile_TemplateConfig(t *testing.T) { - testCases := map[string]struct { fixturePath string expectedTemplateConfig TemplateConfig @@ -586,7 +585,6 @@ func TestLoadConfigFile_TemplateConfig(t *testing.T) { } }) } - } // TestLoadConfigFile_Template tests template definitions in Vault Agent diff --git a/command/server.go b/command/server.go index 84814c4102562..718009b8cf419 100644 --- a/command/server.go +++ b/command/server.go @@ -1541,6 +1541,12 @@ func (c *ServerCommand) Run(args []string) int { core.SetConfig(config) + // reloading custom response headers to make sure we have + // the most up to date headers after reloading the config file + if err = core.ReloadCustomResponseHeaders(); err != nil { + c.logger.Error(err.Error()) + } + if config.LogLevel != "" { configLogLevel := strings.ToLower(strings.TrimSpace(config.LogLevel)) switch configLogLevel { diff --git a/command/server/config_custom_response_headers_test.go b/command/server/config_custom_response_headers_test.go new file mode 100644 index 0000000000000..5380568c25107 --- /dev/null +++ b/command/server/config_custom_response_headers_test.go @@ -0,0 +1,109 @@ +package server + +import ( + "fmt" + "testing" + + "github.com/go-test/deep" +) + +var defaultCustomHeaders = map[string]string{ + "Strict-Transport-Security": "max-age=1; domains", + "Content-Security-Policy": "default-src 'others'", + "X-Vault-Ignored": "ignored", + "X-Custom-Header": "Custom header value default", +} + +var customHeaders307 = map[string]string{ + "X-Custom-Header": "Custom header value 307", +} + +var customHeader3xx = map[string]string{ + "X-Vault-Ignored-3xx": "Ignored 3xx", + "X-Custom-Header": "Custom header value 3xx", +} + +var customHeaders200 = map[string]string{ + "Someheader-200": "200", + "X-Custom-Header": "Custom header value 200", +} + +var customHeader2xx = map[string]string{ + "X-Custom-Header": "Custom header value 2xx", +} + +var customHeader400 = map[string]string{ + "Someheader-400": "400", +} + +var defaultCustomHeadersMultiListener = map[string]string{ + "Strict-Transport-Security": "max-age=31536000; includeSubDomains", + "Content-Security-Policy": "default-src 'others'", + "X-Vault-Ignored": "ignored", + "X-Custom-Header": "Custom header value default", +} + +var defaultSTS = map[string]string{ + "Strict-Transport-Security": "max-age=31536000; includeSubDomains", +} + +func TestCustomResponseHeadersConfigs(t *testing.T) { + expectedCustomResponseHeader := map[string]map[string]string{ + "default": defaultCustomHeaders, + "307": customHeaders307, + "3xx": customHeader3xx, + "200": customHeaders200, + "2xx": customHeader2xx, + "400": customHeader400, + } + + config, err := LoadConfigFile("./test-fixtures/config_custom_response_headers_1.hcl") + if err != nil { + t.Fatalf("Error encountered when loading config %+v", err) + } + if diff := deep.Equal(expectedCustomResponseHeader, config.Listeners[0].CustomResponseHeaders); diff != nil { + t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff)) + } +} + +func TestCustomResponseHeadersConfigsMultipleListeners(t *testing.T) { + expectedCustomResponseHeader := map[string]map[string]string{ + "default": defaultCustomHeadersMultiListener, + "307": customHeaders307, + "3xx": customHeader3xx, + "200": customHeaders200, + "2xx": customHeader2xx, + "400": customHeader400, + } + + config, err := LoadConfigFile("./test-fixtures/config_custom_response_headers_multiple_listeners.hcl") + if err != nil { + t.Fatalf("Error encountered when loading config %+v", err) + } + if diff := deep.Equal(expectedCustomResponseHeader, config.Listeners[0].CustomResponseHeaders); diff != nil { + t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff)) + } + + if diff := deep.Equal(expectedCustomResponseHeader, config.Listeners[1].CustomResponseHeaders); diff == nil { + t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff)) + } + if diff := deep.Equal(expectedCustomResponseHeader["default"], config.Listeners[1].CustomResponseHeaders["default"]); diff != nil { + t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff)) + } + + if diff := deep.Equal(expectedCustomResponseHeader, config.Listeners[2].CustomResponseHeaders); diff == nil { + t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff)) + } + + if diff := deep.Equal(defaultSTS, config.Listeners[2].CustomResponseHeaders["default"]); diff != nil { + t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff)) + } + + if diff := deep.Equal(expectedCustomResponseHeader, config.Listeners[3].CustomResponseHeaders); diff == nil { + t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff)) + } + + if diff := deep.Equal(defaultSTS, config.Listeners[3].CustomResponseHeaders["default"]); diff != nil { + t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff)) + } +} diff --git a/command/server/config_test_helpers.go b/command/server/config_test_helpers.go index e40f6c8368a6e..8936a0244090c 100644 --- a/command/server/config_test_helpers.go +++ b/command/server/config_test_helpers.go @@ -16,6 +16,12 @@ import ( "github.com/hashicorp/vault/internalshared/configutil" ) +var DefaultCustomHeaders = map[string]map[string]string { + "default": { + "Strict-Transport-Security": configutil.StrictTransportSecurity, + }, +} + func boolPointer(x bool) *bool { return &x } @@ -32,6 +38,7 @@ func testConfigRaftRetryJoin(t *testing.T) { { Type: "tcp", Address: "127.0.0.1:8200", + CustomResponseHeaders: DefaultCustomHeaders, }, }, DisableMlock: true, @@ -64,6 +71,7 @@ func testLoadConfigFile_topLevel(t *testing.T, entropy *configutil.Entropy) { { Type: "tcp", Address: "127.0.0.1:443", + CustomResponseHeaders: DefaultCustomHeaders, }, }, @@ -174,10 +182,12 @@ func testLoadConfigFile_json2(t *testing.T, entropy *configutil.Entropy) { { Type: "tcp", Address: "127.0.0.1:443", + CustomResponseHeaders: DefaultCustomHeaders, }, { Type: "tcp", Address: "127.0.0.1:444", + CustomResponseHeaders: DefaultCustomHeaders, }, }, @@ -336,6 +346,7 @@ func testLoadConfigFileIntegerAndBooleanValuesCommon(t *testing.T, path string) { Type: "tcp", Address: "127.0.0.1:8200", + CustomResponseHeaders: DefaultCustomHeaders, }, }, DisableMlock: true, @@ -379,6 +390,7 @@ func testLoadConfigFile(t *testing.T) { { Type: "tcp", Address: "127.0.0.1:443", + CustomResponseHeaders: DefaultCustomHeaders, }, }, @@ -486,7 +498,7 @@ func testUnknownFieldValidation(t *testing.T) { for _, er1 := range errors { found := false if strings.Contains(er1.String(), "sentinel") { - //This happens on OSS, and is fine + // This happens on OSS, and is fine continue } for _, ex := range expected { @@ -525,6 +537,7 @@ func testLoadConfigFile_json(t *testing.T) { { Type: "tcp", Address: "127.0.0.1:443", + CustomResponseHeaders: DefaultCustomHeaders, }, }, @@ -610,6 +623,7 @@ func testLoadConfigDir(t *testing.T) { { Type: "tcp", Address: "127.0.0.1:443", + CustomResponseHeaders: DefaultCustomHeaders, }, }, @@ -818,6 +832,7 @@ listener "tcp" { Profiling: configutil.ListenerProfiling{ UnauthenticatedPProfAccess: true, }, + CustomResponseHeaders: DefaultCustomHeaders, }, }, }, @@ -845,6 +860,7 @@ func testParseSeals(t *testing.T) { { Type: "tcp", Address: "127.0.0.1:443", + CustomResponseHeaders: DefaultCustomHeaders, }, }, Seals: []*configutil.KMS{ @@ -898,6 +914,7 @@ func testLoadConfigFileLeaseMetrics(t *testing.T) { { Type: "tcp", Address: "127.0.0.1:443", + CustomResponseHeaders: DefaultCustomHeaders, }, }, diff --git a/command/server/test-fixtures/config_custom_response_headers_1.hcl b/command/server/test-fixtures/config_custom_response_headers_1.hcl new file mode 100644 index 0000000000000..c2f868c2f146c --- /dev/null +++ b/command/server/test-fixtures/config_custom_response_headers_1.hcl @@ -0,0 +1,31 @@ +storage "inmem" {} +listener "tcp" { + address = "127.0.0.1:8200" + tls_disable = true + custom_response_headers { + "default" = { + "Strict-Transport-Security" = ["max-age=1","domains"], + "Content-Security-Policy" = ["default-src 'others'"], + "X-Vault-Ignored" = ["ignored"], + "X-Custom-Header" = ["Custom header value default"], + } + "307" = { + "X-Custom-Header" = ["Custom header value 307"], + } + "3xx" = { + "X-Vault-Ignored-3xx" = ["Ignored 3xx"], + "X-Custom-Header" = ["Custom header value 3xx"] + } + "200" = { + "someheader-200" = ["200"], + "X-Custom-Header" = ["Custom header value 200"] + } + "2xx" = { + "X-Custom-Header" = ["Custom header value 2xx"] + } + "400" = { + "someheader-400" = ["400"] + } + } +} +disable_mlock = true diff --git a/command/server/test-fixtures/config_custom_response_headers_multiple_listeners.hcl b/command/server/test-fixtures/config_custom_response_headers_multiple_listeners.hcl new file mode 100644 index 0000000000000..11aa099232f91 --- /dev/null +++ b/command/server/test-fixtures/config_custom_response_headers_multiple_listeners.hcl @@ -0,0 +1,56 @@ +storage "inmem" {} +listener "tcp" { + address = "127.0.0.1:8200" + tls_disable = true + custom_response_headers { + "default" = { + "Content-Security-Policy" = ["default-src 'others'"], + "X-Vault-Ignored" = ["ignored"], + "X-Custom-Header" = ["Custom header value default"], + } + "307" = { + "X-Custom-Header" = ["Custom header value 307"], + } + "3xx" = { + "X-Vault-Ignored-3xx" = ["Ignored 3xx"], + "X-Custom-Header" = ["Custom header value 3xx"] + } + "200" = { + "someheader-200" = ["200"], + "X-Custom-Header" = ["Custom header value 200"] + } + "2xx" = { + "X-Custom-Header" = ["Custom header value 2xx"] + } + "400" = { + "someheader-400" = ["400"] + } + } +} +listener "tcp" { + address = "127.0.0.2:8200" + tls_disable = true + custom_response_headers { + "default" = { + "Content-Security-Policy" = ["default-src 'others'"], + "X-Vault-Ignored" = ["ignored"], + "X-Custom-Header" = ["Custom header value default"], + } + } +} +listener "tcp" { + address = "127.0.0.3:8200" + tls_disable = true + custom_response_headers { + "2xx" = { + "X-Custom-Header" = ["Custom header value 2xx"] + } + } +} +listener "tcp" { + address = "127.0.0.4:8200" + tls_disable = true +} + + +disable_mlock = true diff --git a/http/custom_header_test.go b/http/custom_header_test.go new file mode 100644 index 0000000000000..5125050ad570d --- /dev/null +++ b/http/custom_header_test.go @@ -0,0 +1,128 @@ +package http + +import ( + "testing" + + "github.com/hashicorp/vault/vault" +) + +var defaultCustomHeaders = map[string]string { + "Strict-Transport-Security": "max-age=1; domains", + "Content-Security-Policy": "default-src 'others'", + "X-Custom-Header": "Custom header value default", + "X-Frame-Options": "Deny", + "X-Content-Type-Options": "nosniff", + "Content-Type": "application/json", + "X-XSS-Protection": "1; mode=block", +} + +var customHeader2xx = map[string]string { + "X-Custom-Header": "Custom header value 2xx", +} + +var customHeader200 = map[string]string { + "Someheader-200": "200", + "X-Custom-Header": "Custom header value 200", +} + +var customHeader4xx = map[string]string { + "Someheader-4xx": "4xx", +} + +var customHeader400 = map[string]string { + "Someheader-400": "400", +} + +var customHeader405 = map[string]string { + "Someheader-405": "405", +} + +var CustomResponseHeaders = map[string]map[string]string{ + "default": defaultCustomHeaders, + "307": {"X-Custom-Header": "Custom header value 307"}, + "3xx": { + "X-Custom-Header": "Custom header value 3xx", + "X-Vault-Ignored-3xx": "Ignored 3xx", + }, + "200": customHeader200, + "2xx": customHeader2xx, + "400": customHeader400, + "405": customHeader405, + "4xx": customHeader4xx, +} + +func TestCustomResponseHeaders(t *testing.T) { + core, _, token := vault.TestCoreWithCustomResponseHeaderAndUI(t, CustomResponseHeaders, true) + ln, addr := TestServer(t, core) + defer ln.Close() + TestServerAuth(t, addr, token) + + resp := testHttpGet(t, token, addr+"/v1/sys/raw/") + testResponseStatus(t, resp, 404) + testResponseHeader(t, resp, defaultCustomHeaders) + testResponseHeader(t, resp, customHeader4xx) + + resp = testHttpGet(t, token, addr+"/v1/sys/seal") + testResponseStatus(t, resp, 405) + testResponseHeader(t, resp, defaultCustomHeaders) + testResponseHeader(t, resp, customHeader4xx) + testResponseHeader(t, resp, customHeader405) + + resp = testHttpGet(t, token, addr+"/v1/sys/leader") + testResponseStatus(t, resp, 200) + testResponseHeader(t, resp, customHeader200) + + resp = testHttpGet(t, token, addr+"/v1/sys/health") + testResponseStatus(t, resp, 200) + testResponseHeader(t, resp, customHeader200) + + resp = testHttpGet(t, token, addr+"/v1/sys/generate-root/attempt") + testResponseStatus(t, resp, 200) + testResponseHeader(t, resp, customHeader200) + + resp = testHttpGet(t, token, addr+"/v1/sys/generate-root/update") + testResponseStatus(t, resp, 400) + testResponseHeader(t, resp, defaultCustomHeaders) + testResponseHeader(t, resp, customHeader4xx) + testResponseHeader(t, resp, customHeader400) + + resp = testHttpGet(t, token, addr+"/v1/sys/") + testResponseStatus(t, resp, 404) + testResponseHeader(t, resp, defaultCustomHeaders) + testResponseHeader(t, resp, customHeader4xx) + + resp = testHttpGet(t, token, addr+"/v1/sys") + testResponseStatus(t, resp, 404) + testResponseHeader(t, resp, defaultCustomHeaders) + testResponseHeader(t, resp, customHeader4xx) + + resp = testHttpGet(t, token, addr+"/v1/") + testResponseStatus(t, resp, 404) + testResponseHeader(t, resp, defaultCustomHeaders) + testResponseHeader(t, resp, customHeader4xx) + + resp = testHttpGet(t, token, addr+"/v1") + testResponseStatus(t, resp, 404) + testResponseHeader(t, resp, defaultCustomHeaders) + testResponseHeader(t, resp, customHeader4xx) + + resp = testHttpGet(t, token, addr+"/") + testResponseStatus(t, resp, 200) + testResponseHeader(t, resp, customHeader200) + + resp = testHttpGet(t, token, addr+"/ui") + testResponseStatus(t, resp, 200) + testResponseHeader(t, resp, customHeader200) + + resp = testHttpGet(t, token, addr+"/ui/") + testResponseStatus(t, resp, 200) + testResponseHeader(t, resp, customHeader200) + + resp = testHttpPost(t, token, addr+"/v1/sys/auth/foo", map[string]interface{}{ + "type": "noop", + "description": "foo", + }) + testResponseStatus(t, resp, 204) + testResponseHeader(t, resp, customHeader2xx) + +} \ No newline at end of file diff --git a/http/handler.go b/http/handler.go index 7d48f97aee8fe..22aab6ccbd8b4 100644 --- a/http/handler.go +++ b/http/handler.go @@ -16,12 +16,14 @@ import ( "net/textproto" "net/url" "os" + "strconv" "strings" "time" "github.com/NYTimes/gziphandler" "github.com/hashicorp/errwrap" "github.com/hashicorp/go-cleanhttp" + log "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-secure-stdlib/parseutil" "github.com/hashicorp/go-sockaddr" "github.com/hashicorp/vault/helper/namespace" @@ -210,6 +212,90 @@ func Handler(props *vault.HandlerProperties) http.Handler { return printablePathCheckHandler } +type WrappingResponseWriter interface { + http.ResponseWriter + Wrapped() http.ResponseWriter +} + +type statusHeaderResponseWriter struct { + wrapped http.ResponseWriter + logger log.Logger + wroteHeader bool + statusCode int + headers map[string][]*vault.CustomHeader +} + +func (w *statusHeaderResponseWriter) Wrapped() http.ResponseWriter { + return w.wrapped +} + +func (w *statusHeaderResponseWriter) Header() http.Header { + return w.wrapped.Header() +} + +func (w *statusHeaderResponseWriter) Write(buf []byte) (int, error) { + // It is allowed to only call ResponseWriter.Write and skip + // ResponseWriter.WriteHeader. An example of such a situation is + // "handleUIStub". The Write function will internally set the status code + // 200 for the response for which that call might invoke other + // implementations of the WriteHeader function. So, we still need to set + // the custom headers. In cases where both WriteHeader and Write of + // statusHeaderResponseWriter struct are called the internal call to the + // WriterHeader invoked from inside Write method won't change the headers. + if !w.wroteHeader { + w.setCustomResponseHeaders(w.statusCode) + } + + return w.wrapped.Write(buf) +} + +func (w *statusHeaderResponseWriter) WriteHeader(statusCode int) { + w.setCustomResponseHeaders(statusCode) + w.wrapped.WriteHeader(statusCode) + w.statusCode = statusCode + // in cases where Write is called after WriteHeader, let's prevent setting + // ResponseWriter headers twice + w.wroteHeader = true +} + +func (w *statusHeaderResponseWriter) setCustomResponseHeaders(status int) { + sch := w.headers + if sch == nil { + w.logger.Warn("status code header map not configured") + return + } + + // Checking the validity of the status code + if status >= 600 || status < 100 { + return + } + + // setter function to set the headers + setter := func(hvl []*vault.CustomHeader) { + for _, hv := range hvl { + w.Header().Set(hv.Name, hv.Value) + } + } + + // Setting the default headers first + setter(sch["default"]) + + // setting the Xyy pattern first + d := fmt.Sprintf("%vxx", status/100) + if val, ok := sch[d]; ok { + setter(val) + } + + // Setting the specific headers + if val, ok := sch[strconv.Itoa(status)]; ok { + setter(val) + } + + return +} + +var _ WrappingResponseWriter = &statusHeaderResponseWriter{} + type copyResponseWriter struct { wrapped http.ResponseWriter statusCode int @@ -300,6 +386,22 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr hostname, _ := os.Hostname() return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // This block needs to be here so that upon sending SIGHUP, custom response + // headers are also reloaded into the handlers. + if props.ListenerConfig != nil { + la := props.ListenerConfig.Address + listenerCustomHeaders := core.GetListenerCustomResponseHeaders(la) + if listenerCustomHeaders != nil { + w = &statusHeaderResponseWriter{ + wrapped: w, + logger: core.Logger(), + wroteHeader: false, + statusCode: 200, + headers: listenerCustomHeaders.StatusCodeHeaderMap, + } + } + } + // Set the Cache-Control header for all the responses returned // by Vault w.Header().Set("Cache-Control", "no-store") @@ -632,7 +734,15 @@ func parseJSONRequest(perfStandby bool, r *http.Request, w http.ResponseWriter, return nil, errors.New("could not parse max_request_size from request context") } if max > 0 { - reader = http.MaxBytesReader(w, r.Body, max) + // MaxBytesReader won't do all the internal stuff it must unless it's + // given a ResponseWriter that implements the internal http interface + // requestTooLarger. So we let it have access to the underlying + // ResponseWriter. + inw := w + if myw, ok := inw.(WrappingResponseWriter); ok { + inw = myw.Wrapped() + } + reader = http.MaxBytesReader(inw, r.Body, max) } } var origBody io.ReadWriter diff --git a/http/http_test.go b/http/http_test.go index e37b9c3d7693e..692aef0d82877 100644 --- a/http/http_test.go +++ b/http/http_test.go @@ -125,6 +125,16 @@ func testResponseStatus(t *testing.T, resp *http.Response, code int) { } } +func testResponseHeader(t *testing.T, resp *http.Response, expectedHeaders map[string]string) { + t.Helper() + for k, v := range expectedHeaders { + hv := resp.Header.Get(k) + if v != hv { + t.Fatalf("expected header value %v=%v, got %v=%v", k, v, k, hv) + } + } +} + func testResponseBody(t *testing.T, resp *http.Response, out interface{}) { defer resp.Body.Close() diff --git a/http/sys_metrics.go b/http/sys_metrics.go index 0e58be3ea262d..012417282e5f5 100644 --- a/http/sys_metrics.go +++ b/http/sys_metrics.go @@ -35,12 +35,14 @@ func handleMetricsUnauthenticated(core *vault.Core) http.Handler { resp := core.MetricsHelper().ResponseForFormat(format) // Manually extract the logical response and send back the information - w.WriteHeader(resp.Data[logical.HTTPStatusCode].(int)) + status := resp.Data[logical.HTTPStatusCode].(int) w.Header().Set("Content-Type", resp.Data[logical.HTTPContentType].(string)) switch v := resp.Data[logical.HTTPRawBody].(type) { case string: + w.WriteHeader(status) w.Write([]byte(v)) case []byte: + w.WriteHeader(status) w.Write(v) default: respondError(w, http.StatusInternalServerError, fmt.Errorf("wrong response returned")) diff --git a/http/testing.go b/http/testing.go index be9569dc9684c..84ab73fc08006 100644 --- a/http/testing.go +++ b/http/testing.go @@ -6,6 +6,7 @@ import ( "net/http" "testing" + "github.com/hashicorp/vault/internalshared/configutil" "github.com/hashicorp/vault/vault" ) @@ -41,10 +42,16 @@ func TestServerWithListenerAndProperties(tb testing.TB, ln net.Listener, addr st } func TestServerWithListener(tb testing.TB, ln net.Listener, addr string, core *vault.Core) { + ip, _, _ := net.SplitHostPort(ln.Addr().String()) + // Create a muxer to handle our requests so that we can authenticate // for tests. props := &vault.HandlerProperties{ Core: core, + // This is needed for testing custom response headers + ListenerConfig: &configutil.Listener { + Address: ip, + }, } TestServerWithListenerAndProperties(tb, ln, addr, core, props) } diff --git a/internalshared/configutil/http_response_headers.go b/internalshared/configutil/http_response_headers.go new file mode 100644 index 0000000000000..2db3034e588b6 --- /dev/null +++ b/internalshared/configutil/http_response_headers.go @@ -0,0 +1,129 @@ +package configutil + +import ( + "fmt" + "net/textproto" + "strconv" + "strings" + + "github.com/hashicorp/go-secure-stdlib/strutil" +) + +var ValidCustomStatusCodeCollection = []string{ + "default", + "1xx", + "2xx", + "3xx", + "4xx", + "5xx", +} + +const StrictTransportSecurity = "max-age=31536000; includeSubDomains" + +// ParseCustomResponseHeaders takes a raw config values for the +// "custom_response_headers". It makes sure the config entry is passed in +// as a map of status code to a map of header name and header values. It +// verifies the validity of the status codes, and header values. It also +// adds the default headers values. +func ParseCustomResponseHeaders(responseHeaders interface{}) (map[string]map[string]string, error) { + h := make(map[string]map[string]string) + // if r is nil, we still should set the default custom headers + if responseHeaders == nil { + h["default"] = map[string]string{"Strict-Transport-Security": StrictTransportSecurity} + return h, nil + } + + customResponseHeader, ok := responseHeaders.([]map[string]interface{}) + if !ok { + return nil, fmt.Errorf("response headers were not configured correctly. please make sure they're in a slice of maps") + } + + for _, crh := range customResponseHeader { + for statusCode, responseHeader := range crh { + headerValList, ok := responseHeader.([]map[string]interface{}) + if !ok { + return nil, fmt.Errorf("response headers were not configured correctly. please make sure they're in a slice of maps") + } + + if !IsValidStatusCode(statusCode) { + return nil, fmt.Errorf("invalid status code found in the server configuration: %v", statusCode) + } + + if len(headerValList) != 1 { + return nil, fmt.Errorf("invalid number of response headers exist") + } + headerValMap := headerValList[0] + headerVal, err := parseHeaders(headerValMap) + if err != nil { + return nil, err + } + + h[statusCode] = headerVal + } + } + + // setting Strict-Transport-Security as a default header + if h["default"] == nil { + h["default"] = make(map[string]string) + } + if _, ok := h["default"]["Strict-Transport-Security"]; !ok { + h["default"]["Strict-Transport-Security"] = StrictTransportSecurity + } + + return h, nil +} + +// IsValidStatusCode checking for status codes outside the boundary +func IsValidStatusCode(sc string) bool { + if strutil.StrListContains(ValidCustomStatusCodeCollection, sc) { + return true + } + + i, err := strconv.Atoi(sc) + if err != nil { + return false + } + + if i >= 600 || i < 100 { + return false + } + + return true +} + +func parseHeaders(in map[string]interface{}) (map[string]string, error) { + hvMap := make(map[string]string) + for k, v := range in { + // parsing header name + headerName := textproto.CanonicalMIMEHeaderKey(k) + // parsing header values + s, err := parseHeaderValues(v) + if err != nil { + return nil, err + } + hvMap[headerName] = s + } + return hvMap, nil +} + +func parseHeaderValues(header interface{}) (string, error) { + var sl []string + if _, ok := header.([]interface{}); !ok { + return "", fmt.Errorf("headers must be given in a list of strings") + } + headerValList := header.([]interface{}) + for _, vh := range headerValList { + if _, ok := vh.(string); !ok { + return "", fmt.Errorf("found a non-string header value: %v", vh) + } + headerVal := strings.TrimSpace(vh.(string)) + if headerVal == "" { + continue + } + sl = append(sl, headerVal) + + } + s := strings.Join(sl, "; ") + + return s, nil +} diff --git a/internalshared/configutil/listener.go b/internalshared/configutil/listener.go index 98199082895a4..677dbf9df8b3a 100644 --- a/internalshared/configutil/listener.go +++ b/internalshared/configutil/listener.go @@ -99,6 +99,10 @@ type Listener struct { CorsAllowedOrigins []string `hcl:"cors_allowed_origins"` CorsAllowedHeaders []string `hcl:"-"` CorsAllowedHeadersRaw []string `hcl:"cors_allowed_headers,alias:cors_allowed_headers"` + + // Custom Http response headers + CustomResponseHeaders map[string]map[string]string `hcl:"-"` + CustomResponseHeadersRaw interface{} `hcl:"custom_response_headers"` } func (l *Listener) GoString() string { @@ -361,6 +365,17 @@ func ParseListeners(result *SharedConfig, list *ast.ObjectList) error { } } + // HTTP Headers + { + // if CustomResponseHeadersRaw is nil, we still need to set the default headers + customHeadersMap, err := ParseCustomResponseHeaders(l.CustomResponseHeadersRaw) + if err != nil { + return multierror.Prefix(fmt.Errorf("failed to parse custom_response_headers: %w", err), fmt.Sprintf("listeners.%d", i)) + } + l.CustomResponseHeaders = customHeadersMap + l.CustomResponseHeadersRaw = nil + } + result.Listeners = append(result.Listeners, &l) } diff --git a/vault/core.go b/vault/core.go index ddd2a2cf41e69..39f7b6da772b5 100644 --- a/vault/core.go +++ b/vault/core.go @@ -510,6 +510,9 @@ type Core struct { // clusterListener starts up and manages connections on the cluster ports clusterListener *atomic.Value + // customListenerHeader holds custom response headers for a listener + customListenerHeader *atomic.Value + // Telemetry objects metricsHelper *metricsutil.MetricsHelper @@ -769,23 +772,24 @@ func CreateCore(conf *CoreConfig) (*Core, error) { // Setup the core c := &Core{ - entCore: entCore{}, - devToken: conf.DevToken, - physical: conf.Physical, - serviceRegistration: conf.GetServiceRegistration(), - underlyingPhysical: conf.Physical, - storageType: conf.StorageType, - redirectAddr: conf.RedirectAddr, - clusterAddr: new(atomic.Value), - clusterListener: new(atomic.Value), - seal: conf.Seal, - router: NewRouter(), - sealed: new(uint32), - sealMigrationDone: new(uint32), - standby: true, - standbyStopCh: new(atomic.Value), - baseLogger: conf.Logger, - logger: conf.Logger.Named("core"), + entCore: entCore{}, + devToken: conf.DevToken, + physical: conf.Physical, + serviceRegistration: conf.GetServiceRegistration(), + underlyingPhysical: conf.Physical, + storageType: conf.StorageType, + redirectAddr: conf.RedirectAddr, + clusterAddr: new(atomic.Value), + clusterListener: new(atomic.Value), + customListenerHeader: new(atomic.Value), + seal: conf.Seal, + router: NewRouter(), + sealed: new(uint32), + sealMigrationDone: new(uint32), + standby: true, + standbyStopCh: new(atomic.Value), + baseLogger: conf.Logger, + logger: conf.Logger.Named("core"), defaultLeaseTTL: conf.DefaultLeaseTTL, maxLeaseTTL: conf.MaxLeaseTTL, @@ -1005,6 +1009,17 @@ func NewCore(conf *CoreConfig) (*Core, error) { c.clusterListener.Store((*cluster.Listener)(nil)) + // for listeners with custom response headers, configuring customListenerHeader + if conf.RawConfig.Listeners != nil { + uiHeaders, err := c.UIHeaders() + if err != nil { + return nil, err + } + c.customListenerHeader.Store(NewListenerCustomHeader(conf.RawConfig.Listeners, c.logger, uiHeaders)) + } else { + c.customListenerHeader.Store(([]*ListenerCustomHeaders)(nil)) + } + quotasLogger := conf.Logger.Named("quotas") c.allLoggers = append(c.allLoggers, quotasLogger) c.quotaManager, err = quotas.NewManager(quotasLogger, c.quotaLeaseWalker, c.metricSink) @@ -2641,6 +2656,68 @@ func (c *Core) SetConfig(conf *server.Config) { c.logger.Debug("set config", "sanitized config", string(bz)) } +func (c *Core) GetListenerCustomResponseHeaders(listenerAdd string) *ListenerCustomHeaders { + + customHeaders := c.customListenerHeader.Load() + if customHeaders == nil { + return nil + } + + customHeadersList, ok := customHeaders.([]*ListenerCustomHeaders) + if customHeadersList == nil || !ok { + return nil + } + + for _, l := range customHeadersList { + if l.Address == listenerAdd { + return l + } + } + return nil +} + +// ExistCustomResponseHeader checks if a custom header is configured in any +// listener's stanza +func (c *Core) ExistCustomResponseHeader(header string) bool { + customHeaders := c.customListenerHeader.Load() + if customHeaders == nil { + return false + } + + customHeadersList, ok := customHeaders.([]*ListenerCustomHeaders) + if customHeadersList == nil || !ok { + return false + } + + for _, l := range customHeadersList { + exist := l.ExistCustomResponseHeader(header) + if exist { + return true + } + } + + return false +} + +func (c *Core) ReloadCustomResponseHeaders() error { + conf := c.rawConfig.Load() + if conf == nil { + return fmt.Errorf("failed to load core raw config") + } + lns := conf.(*server.Config).Listeners + if lns == nil { + return fmt.Errorf("no listener configured") + } + + uiHeaders, err := c.UIHeaders() + if err != nil { + return err + } + c.customListenerHeader.Store(NewListenerCustomHeader(lns, c.logger, uiHeaders)) + + return nil +} + // SanitizedConfig returns a sanitized version of the current config. // See server.Config.Sanitized for specific values omitted. func (c *Core) SanitizedConfig() map[string]interface{} { diff --git a/vault/custom_response_headers.go b/vault/custom_response_headers.go new file mode 100644 index 0000000000000..54df089547fc4 --- /dev/null +++ b/vault/custom_response_headers.go @@ -0,0 +1,90 @@ +package vault + +import ( + "net/http" + "net/textproto" + "strings" + + log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/internalshared/configutil" +) + +type ListenerCustomHeaders struct { + Address string + StatusCodeHeaderMap map[string][]*CustomHeader + // ConfiguredHeadersStatusCodeMap field is introduced so that we would not need to loop through + // StatusCodeHeaderMap to see if a header exists, the key for this map is the headers names + configuredHeadersStatusCodeMap map[string][]string +} + +type CustomHeader struct { + Name string + Value string +} + +func NewListenerCustomHeader(ln []*configutil.Listener, logger log.Logger, uiHeaders http.Header) []*ListenerCustomHeaders { + var listenerCustomHeadersList []*ListenerCustomHeaders + + for _, l := range ln { + listenerCustomHeaderStruct := &ListenerCustomHeaders{ + Address: l.Address, + } + listenerCustomHeaderStruct.StatusCodeHeaderMap = make(map[string][]*CustomHeader) + listenerCustomHeaderStruct.configuredHeadersStatusCodeMap = make(map[string][]string) + for statusCode, headerValMap := range l.CustomResponseHeaders { + var customHeaderList []*CustomHeader + for headerName, headerVal := range headerValMap { + // Sanitizing custom headers + // X-Vault- prefix is reserved for Vault internal processes + if strings.HasPrefix(headerName, "X-Vault-") { + logger.Warn("custom headers starting with X-Vault are not valid", "header", headerName) + continue + } + + // Checking for UI headers, if any common header exists, we just log an error + if uiHeaders != nil { + exist := uiHeaders.Get(headerName) + if exist != "" { + logger.Warn("found a duplicate header in UI", "header:", headerName, "Headers defined in the server configuration take precedence.") + } + } + + // Checking if the header value is not an empty string + if headerVal == "" { + logger.Warn("header value is an empty string", "header", headerName, "value", headerVal) + continue + } + + ch := &CustomHeader{ + Name: headerName, + Value: headerVal, + } + + customHeaderList = append(customHeaderList, ch) + + // setting up the reverse map of header to status code for easy lookups + listenerCustomHeaderStruct.configuredHeadersStatusCodeMap[headerName] = append(listenerCustomHeaderStruct.configuredHeadersStatusCodeMap[headerName], statusCode) + } + listenerCustomHeaderStruct.StatusCodeHeaderMap[statusCode] = customHeaderList + } + listenerCustomHeadersList = append(listenerCustomHeadersList, listenerCustomHeaderStruct) + } + + return listenerCustomHeadersList +} + +func (l *ListenerCustomHeaders) ExistCustomResponseHeader(header string) bool { + if header == "" { + return false + } + + if l.StatusCodeHeaderMap == nil { + return false + } + + headerName := textproto.CanonicalMIMEHeaderKey(header) + + headerMap := l.configuredHeadersStatusCodeMap + _, ok := headerMap[headerName] + return ok +} diff --git a/vault/custom_response_headers_test.go b/vault/custom_response_headers_test.go new file mode 100644 index 0000000000000..1ea79cf7ee5eb --- /dev/null +++ b/vault/custom_response_headers_test.go @@ -0,0 +1,174 @@ +package vault + +import ( + "context" + "fmt" + "net/http/httptest" + "strings" + "testing" + + log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/helper/namespace" + "github.com/hashicorp/vault/internalshared/configutil" + "github.com/hashicorp/vault/sdk/helper/logging" + "github.com/hashicorp/vault/sdk/logical" + "github.com/hashicorp/vault/sdk/physical/inmem" +) + +var defaultCustomHeaders = map[string]string{ + "Strict-Transport-Security": "max-age=1; domains", + "Content-Security-Policy": "default-src 'others'", + "X-Vault-Ignored": "ignored", + "X-Custom-Header": "Custom header value default", + "X-Frame-Options": "Deny", + "X-Content-Type-Options": "nosniff", + "Content-Type": "text/plain; charset=utf-8", + "X-XSS-Protection": "1; mode=block", +} + +var customHeaders307 = map[string]string{ + "X-Custom-Header": "Custom header value 307", +} + +var customHeader3xx = map[string]string{ + "X-Vault-Ignored-3xx": "Ignored 3xx", + "X-Custom-Header": "Custom header value 3xx", +} + +var customHeaders200 = map[string]string{ + "Someheader-200": "200", + "X-Custom-Header": "Custom header value 200", +} + +var customHeader2xx = map[string]string{ + "X-Custom-Header": "Custom header value 2xx", +} + +var customHeader400 = map[string]string{ + "Someheader-400": "400", +} + +func TestConfigCustomHeaders(t *testing.T) { + logger := logging.NewVaultLogger(log.Trace) + phys, err := inmem.NewTransactionalInmem(nil, logger) + if err != nil { + t.Fatal(err) + } + logl := &logical.InmemStorage{} + uiConfig := NewUIConfig(true, phys, logl) + + rawListenerConfig := []*configutil.Listener{ + { + Type: "tcp", + Address: "127.0.0.1:443", + CustomResponseHeaders: map[string]map[string]string{ + "default": defaultCustomHeaders, + "307": customHeaders307, + "3xx": customHeader3xx, + "200": customHeaders200, + "2xx": customHeader2xx, + "400": customHeader400, + }, + }, + } + + uiHeaders, err := uiConfig.Headers(context.Background()) + listenerCustomHeaders := NewListenerCustomHeader(rawListenerConfig, logger, uiHeaders) + if listenerCustomHeaders == nil || len(listenerCustomHeaders) != 1 { + t.Fatalf("failed to get custom header configuration") + } + + lch := listenerCustomHeaders[0] + + if lch.ExistCustomResponseHeader("X-Vault-Ignored-307") { + t.Fatalf("header name with X-Vault prefix is not valid") + } + if lch.ExistCustomResponseHeader("X-Vault-Ignored-3xx") { + t.Fatalf("header name with X-Vault prefix is not valid") + } + + if !lch.ExistCustomResponseHeader("X-Custom-Header") { + t.Fatalf("header name with X-Vault prefix is not valid") + } +} + +func TestCustomResponseHeadersConfigInteractUiConfig(t *testing.T) { + b := testSystemBackend(t) + _, barrier, _ := mockBarrier(t) + view := NewBarrierView(barrier, "") + b.(*SystemBackend).Core.systemBarrierView = view + + logger := logging.NewVaultLogger(log.Trace) + rawListenerConfig := []*configutil.Listener{ + { + Type: "tcp", + Address: "127.0.0.1:443", + CustomResponseHeaders: map[string]map[string]string{ + "default": defaultCustomHeaders, + "307": customHeaders307, + "3xx": customHeader3xx, + "200": customHeaders200, + "2xx": customHeader2xx, + "400": customHeader400, + }, + }, + } + uiHeaders, err := b.(*SystemBackend).Core.uiConfig.Headers(context.Background()) + if err != nil { + t.Fatalf("failed to get headers from ui config") + } + customListenerHeader := NewListenerCustomHeader(rawListenerConfig, logger, uiHeaders) + if customListenerHeader == nil { + t.Fatalf("custom header config should be configured") + } + b.(*SystemBackend).Core.customListenerHeader.Store(customListenerHeader) + clh := b.(*SystemBackend).Core.customListenerHeader + if clh == nil { + t.Fatalf("custom header config should be configured in core") + } + + w := httptest.NewRecorder() + hw := logical.NewHTTPResponseWriter(w) + + // setting a header that already exist in custom headers + req := logical.TestRequest(t, logical.UpdateOperation, "config/ui/headers/X-Custom-Header") + req.Data["values"] = []string{"UI Custom Header"} + req.ResponseWriter = hw + + resp, err := b.HandleRequest(namespace.RootContext(nil), req) + if err == nil { + t.Fatal("request did not fail on setting a header that is present in custom response headers") + } + if !strings.Contains(resp.Data["error"].(string), fmt.Sprintf("This header already exists in the server configuration and cannot be set in the UI.")) { + t.Fatalf("failed to get the expected error") + } + + // setting a header that already exist in custom headers + req = logical.TestRequest(t, logical.UpdateOperation, "config/ui/headers/Someheader-400") + req.Data["values"] = []string{"400"} + req.ResponseWriter = hw + + _, err = b.HandleRequest(namespace.RootContext(nil), req) + if err == nil { + t.Fatal("request did not fail on setting a header that is present in custom response headers") + } + h, err := b.(*SystemBackend).Core.uiConfig.Headers(context.Background()) + if h.Get("Someheader-400") == "400" { + t.Fatalf("should not be able to set a header that is in custom response headers") + } + + // setting an ui specific header + req = logical.TestRequest(t, logical.UpdateOperation, "config/ui/headers/X-CustomUiHeader") + req.Data["values"] = []string{"Ui header value"} + req.ResponseWriter = hw + + _, err = b.HandleRequest(namespace.RootContext(nil), req) + if err != nil { + t.Fatal("request failed on setting a header that is not present in custom response headers.", "error:", err) + } + + h, err = b.(*SystemBackend).Core.uiConfig.Headers(context.Background()) + if h.Get("X-CustomUiHeader") != "Ui header value" { + t.Fatalf("failed to set a header that is not in custom response headers") + } +} diff --git a/vault/logical_system.go b/vault/logical_system.go index 6179287818861..1675475a422a3 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -2623,6 +2623,9 @@ func (b *SystemBackend) handleConfigUIHeadersUpdate(ctx context.Context, req *lo // Translate the list of values to the valid header string value := http.Header{} for _, v := range values { + if b.Core.ExistCustomResponseHeader(header) { + return logical.ErrorResponse("This header already exists in the server configuration and cannot be set in the UI."), logical.ErrInvalidRequest + } value.Add(header, v) } err := b.Core.uiConfig.SetHeader(ctx, header, value.Values(header)) diff --git a/vault/testing.go b/vault/testing.go index dbe921969d493..6ba933e85d99f 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -128,6 +128,29 @@ func TestCoreWithSeal(t testing.T, testSeal Seal, enableRaw bool) *Core { return TestCoreWithSealAndUI(t, conf) } +func TestCoreWithCustomResponseHeaderAndUI(t testing.T, CustomResponseHeaders map[string]map[string]string, enableUI bool) (*Core, [][]byte, string) { + confRaw := &server.Config{ + SharedConfig: &configutil.SharedConfig{ + Listeners: []*configutil.Listener{ + { + Type: "tcp", + Address: "127.0.0.1", + CustomResponseHeaders: CustomResponseHeaders, + }, + }, + DisableMlock: true, + }, + } + conf := &CoreConfig{ + RawConfig: confRaw, + EnableUI: enableUI, + EnableRaw: true, + BuiltinRegistry: NewMockBuiltinRegistry(), + } + core := TestCoreWithSealAndUI(t, conf) + return testCoreUnsealed(t, core) +} + func TestCoreUI(t testing.T, enableUI bool) *Core { conf := &CoreConfig{ EnableUI: enableUI, diff --git a/vault/ui.go b/vault/ui.go index c36a247af304a..bd1d3c6882147 100644 --- a/vault/ui.go +++ b/vault/ui.go @@ -32,8 +32,9 @@ type UIConfig struct { // NewUIConfig creates a new UI config func NewUIConfig(enabled bool, physicalStorage physical.Backend, barrierStorage logical.Storage) *UIConfig { defaultHeaders := http.Header{} - defaultHeaders.Set("Content-Security-Policy", "default-src 'none'; connect-src 'self'; img-src 'self' data:; script-src 'self'; style-src 'unsafe-inline' 'self'; form-action 'none'; frame-ancestors 'none'; font-src 'self'") defaultHeaders.Set("Service-Worker-Allowed", "/") + defaultHeaders.Set("X-Content-Type-Options", "nosniff") + defaultHeaders.Set("Content-Security-Policy", "default-src 'none'; connect-src 'self'; img-src 'self' data:; script-src 'self'; style-src 'unsafe-inline' 'self'; form-action 'none'; frame-ancestors 'none'; font-src 'self'") return &UIConfig{ physicalStorage: physicalStorage, From 56c6f3c9d04b0eba220811a34af3720e464e8bf9 Mon Sep 17 00:00:00 2001 From: John-Michael Faircloth Date: Wed, 13 Oct 2021 11:51:20 -0500 Subject: [PATCH 07/12] Add support to parameterize unauthenticated paths (#12668) * store unauthenticated path wildcards in map * working unauthenticated paths with basic unit tests * refactor wildcard logic * add parseUnauthenticatedPaths unit tests * use parseUnauthenticatedPaths when reloading backend * add more wildcard test cases * update special paths doc; add changelog * remove buggy prefix check; add test cases * prevent false positives for prefix matches If we ever encounter a mismatched segment, break and set a flag to prevent false positives for prefix matches. If it is a match we need to do a prefix check. But we should not return unless HasPrefix also evaluates to true. Otherwise we should let the for loop continue to check other possibilities and only return false once all wildcard paths have been evaluated. * refactor switch and add more test cases * remove comment leftover from debug session * add more wildcard path validation and test cases * update changelong; feature -> improvement * simplify wildcard segment matching logic * refactor wildcard matching into func * fix glob matching, add more wildcard validation, refactor * refactor common wildcard errors to func * move doc comment to logical.Paths * optimize wildcard paths storage with pre-split slices * fix comment typo * fix test case after changing wildcard paths storage type * move prefix check to parseUnauthenticatedPaths * tweak regex, remove unneeded array copy, refactor * add test case around wildcard and glob matching --- changelog/12668.txt | 3 + sdk/framework/backend.go | 6 +- sdk/logical/logical.go | 4 ++ vault/identity_store.go | 2 + vault/plugin_reload.go | 6 +- vault/router.go | 139 +++++++++++++++++++++++++++++++++--- vault/router_test.go | 150 +++++++++++++++++++++++++++++++++++++++ 7 files changed, 295 insertions(+), 15 deletions(-) create mode 100644 changelog/12668.txt diff --git a/changelog/12668.txt b/changelog/12668.txt new file mode 100644 index 0000000000000..7006da572c44a --- /dev/null +++ b/changelog/12668.txt @@ -0,0 +1,3 @@ +```release-note:improvement +sdk/framework: The '+' wildcard is now supported for parameterizing unauthenticated paths. +``` diff --git a/sdk/framework/backend.go b/sdk/framework/backend.go index c2c3f1810008b..d498dd4ee09f0 100644 --- a/sdk/framework/backend.go +++ b/sdk/framework/backend.go @@ -41,10 +41,8 @@ type Backend struct { // paths, including adding or removing, is not allowed once the // backend is in use). // - // PathsSpecial is the list of path patterns that denote the - // paths above that require special privileges. These can't be - // regular expressions, it is either exact match or prefix match. - // For prefix match, append '*' as a suffix. + // PathsSpecial is the list of path patterns that denote the paths above + // that require special privileges. Paths []*Path PathsSpecial *logical.Paths diff --git a/sdk/logical/logical.go b/sdk/logical/logical.go index db8831535a8e8..cec2d19c0e6ed 100644 --- a/sdk/logical/logical.go +++ b/sdk/logical/logical.go @@ -117,6 +117,10 @@ type Paths struct { Root []string // Unauthenticated are the paths that can be accessed without any auth. + // These can't be regular expressions, it is either exact match, a prefix + // match and/or a wildcard match. For prefix match, append '*' as a suffix. + // For a wildcard match, use '+' in the segment to match any identifier + // (e.g. 'foo/+/bar'). Note that '+' can't be adjacent to a non-slash. Unauthenticated []string // LocalStorage are paths (prefixes) that are local to this instance; this diff --git a/vault/identity_store.go b/vault/identity_store.go index 98a5350d994a4..18e3a215ac399 100644 --- a/vault/identity_store.go +++ b/vault/identity_store.go @@ -89,6 +89,8 @@ func NewIdentityStore(ctx context.Context, core *Core, config *logical.BackendCo PathsSpecial: &logical.Paths{ Unauthenticated: []string{ "oidc/.well-known/*", + "oidc/provider/+/.well-known/*", + "oidc/provider/+/token", }, }, PeriodicFunc: func(ctx context.Context, req *logical.Request) error { diff --git a/vault/plugin_reload.go b/vault/plugin_reload.go index 732d60bfaa82e..e38b5341aaf67 100644 --- a/vault/plugin_reload.go +++ b/vault/plugin_reload.go @@ -188,7 +188,11 @@ func (c *Core) reloadBackendCommon(ctx context.Context, entry *MountEntry, isAut paths := backend.SpecialPaths() if paths != nil { re.rootPaths.Store(pathsToRadix(paths.Root)) - re.loginPaths.Store(pathsToRadix(paths.Unauthenticated)) + loginPathsEntry, err := parseUnauthenticatedPaths(paths.Unauthenticated) + if err != nil { + return err + } + re.loginPaths.Store(loginPathsEntry) } } diff --git a/vault/router.go b/vault/router.go index 6d61d7f0e4b2d..6426e4eb8f206 100644 --- a/vault/router.go +++ b/vault/router.go @@ -3,6 +3,7 @@ package vault import ( "context" "fmt" + "regexp" "strings" "sync" "sync/atomic" @@ -22,6 +23,9 @@ var deniedPassthroughRequestHeaders = []string{ consts.AuthHeaderName, } +// matches when '+' is next to a non-slash char +var wcAdjacentNonSlashRegEx = regexp.MustCompile(`\+[^/]|[^/]\+`).MatchString + // Router is used to do prefix based routing of a request to a logical backend type Router struct { l sync.RWMutex @@ -59,6 +63,19 @@ type routeEntry struct { l sync.RWMutex } +type wildcardPath struct { + // this sits in the hot path of requests so we are micro-optimizing by + // storing pre-split slices of path segments + segments []string + isPrefix bool +} + +// loginPathsEntry is used to hold the routeEntry loginPaths +type loginPathsEntry struct { + paths *radix.Tree + wildcardPaths []wildcardPath +} + type ValidateMountResponse struct { MountType string `json:"mount_type" structs:"mount_type" mapstructure:"mount_type"` MountAccessor string `json:"mount_accessor" structs:"mount_accessor" mapstructure:"mount_accessor"` @@ -137,7 +154,11 @@ func (r *Router) Mount(backend logical.Backend, prefix string, mountEntry *Mount storageView: storageView, } re.rootPaths.Store(pathsToRadix(paths.Root)) - re.loginPaths.Store(pathsToRadix(paths.Unauthenticated)) + loginPathsEntry, err := parseUnauthenticatedPaths(paths.Unauthenticated) + if err != nil { + return err + } + re.loginPaths.Store(loginPathsEntry) switch { case prefix == "": @@ -782,6 +803,10 @@ func (r *Router) RootPath(ctx context.Context, path string) bool { } // LoginPath checks if the given path is used for logins +// Matching Priority +// 1. prefix +// 2. exact +// 3. wildcard func (r *Router) LoginPath(ctx context.Context, path string) bool { ns, err := namespace.FromContext(ctx) if err != nil { @@ -802,20 +827,114 @@ func (r *Router) LoginPath(ctx context.Context, path string) bool { remain := strings.TrimPrefix(adjustedPath, mount) // Check the loginPaths of this backend - loginPaths := re.loginPaths.Load().(*radix.Tree) - match, raw, ok := loginPaths.LongestPrefix(remain) - if !ok { + pe := re.loginPaths.Load().(*loginPathsEntry) + match, raw, ok := pe.paths.LongestPrefix(remain) + if !ok && len(pe.wildcardPaths) == 0 { + // no match found return false } - prefixMatch := raw.(bool) - // Handle the prefix match case - if prefixMatch { - return strings.HasPrefix(remain, match) + if ok { + prefixMatch := raw.(bool) + if prefixMatch { + // Handle the prefix match case + return strings.HasPrefix(remain, match) + } + if match == remain { + // Handle the exact match case + return true + } } - // Handle the exact match case - return match == remain + // check Login Paths containing wildcards + reqPathParts := strings.Split(remain, "/") + for _, w := range pe.wildcardPaths { + if pathMatchesWildcardPath(reqPathParts, w.segments, w.isPrefix) { + return true + } + } + return false +} + +// pathMatchesWildcardPath returns true if the path made up of the path slice +// matches the given wildcard path slice +func pathMatchesWildcardPath(path, wcPath []string, isPrefix bool) bool { + if len(wcPath) == 0 { + return false + } + + if len(path) < len(wcPath) { + // check if the path coming in is shorter; if so it can't match + return false + } + if !isPrefix && len(wcPath) != len(path) { + // If it's not a prefix we expect the same number of segments + return false + } + + for i, wcPathPart := range wcPath { + switch { + case wcPathPart == "+": + case wcPathPart == path[i]: + case isPrefix && i == len(wcPath)-1 && strings.HasPrefix(path[i], wcPathPart): + default: + // we encountered segments that did not match + return false + } + } + return true +} + +func wildcardError(path, msg string) error { + return fmt.Errorf("path %q: invalid use of wildcards %s", path, msg) +} + +func isValidUnauthenticatedPath(path string) (bool, error) { + switch { + case strings.Count(path, "*") > 1: + return false, wildcardError(path, "(multiple '*' is forbidden)") + case strings.Contains(path, "+*"): + return false, wildcardError(path, "('+*' is forbidden)") + case strings.Contains(path, "*") && path[len(path)-1] != '*': + return false, wildcardError(path, "('*' is only allowed at the end of a path)") + case wcAdjacentNonSlashRegEx(path): + return false, wildcardError(path, "('+' is not allowed next to a non-slash)") + } + return true, nil +} + +// parseUnauthenticatedPaths converts a list of special paths to a +// loginPathsEntry +func parseUnauthenticatedPaths(paths []string) (*loginPathsEntry, error) { + var tempPaths []string + tempWildcardPaths := make([]wildcardPath, 0) + for _, path := range paths { + if ok, err := isValidUnauthenticatedPath(path); !ok { + return nil, err + } + + if strings.Contains(path, "+") { + // Paths with wildcards are not stored in the radix tree because + // the radix tree does not handle wildcards in the middle of strings. + isPrefix := false + if path[len(path)-1] == '*' { + isPrefix = true + path = path[0 : len(path)-1] + } + // We are micro-optimizing by storing pre-split slices of path segments + wcPath := wildcardPath{segments: strings.Split(path, "/"), isPrefix: isPrefix} + tempWildcardPaths = append(tempWildcardPaths, wcPath) + } else { + // accumulate paths that do not contain wildcards + // to be stored in the radix tree + tempPaths = append(tempPaths, path) + } + } + + return &loginPathsEntry{ + paths: pathsToRadix(tempPaths), + wildcardPaths: tempWildcardPaths, + }, nil } // pathsToRadix converts a list of special paths to a radix tree. diff --git a/vault/router_test.go b/vault/router_test.go index b20b69894fe4f..842d75f62f4a6 100644 --- a/vault/router_test.go +++ b/vault/router_test.go @@ -348,6 +348,15 @@ func TestRouter_LoginPath(t *testing.T) { Login: []string{ "login", "oauth/*", + "glob1*", + "+/wildcard/glob2*", + "end1/+", + "end2/+/", + "end3/+/*", + "middle1/+/bar", + "middle2/+/+/bar", + "+/begin", + "+/around/+/", }, } err = r.Mount(n, "auth/foo/", &MountEntry{UUID: meUUID, Accessor: "authfooaccessor", NamespaceID: namespace.RootNamespaceID, namespace: namespace.RootNamespace}, view) @@ -363,8 +372,70 @@ func TestRouter_LoginPath(t *testing.T) { {"random", false}, {"auth/foo/bar", false}, {"auth/foo/login", true}, + {"auth/foo/login/", false}, {"auth/foo/oauth", false}, + {"auth/foo/oauth/", true}, {"auth/foo/oauth/redirect", true}, + {"auth/foo/oauth/redirect/", true}, + {"auth/foo/oauth/redirect/bar", true}, + {"auth/foo/glob1", true}, + {"auth/foo/glob1/", true}, + {"auth/foo/glob1/redirect", true}, + + // Wildcard cases + + // "+/wildcard/glob2*" + {"auth/foo/bar/wildcard/glo", false}, + {"auth/foo/bar/wildcard/glob2", true}, + {"auth/foo/bar/wildcard/glob2222", true}, + {"auth/foo/bar/wildcard/glob2/", true}, + {"auth/foo/bar/wildcard/glob2/baz", true}, + + // "end1/+" + {"auth/foo/end1", false}, + {"auth/foo/end1/", true}, + {"auth/foo/end1/bar", true}, + {"auth/foo/end1/bar/", false}, + {"auth/foo/end1/bar/baz", false}, + // "end2/+/" + {"auth/foo/end2", false}, + {"auth/foo/end2/", false}, + {"auth/foo/end2/bar", false}, + {"auth/foo/end2/bar/", true}, + {"auth/foo/end2/bar/baz", false}, + // "end3/+/*" + {"auth/foo/end3", false}, + {"auth/foo/end3/", false}, + {"auth/foo/end3/bar", false}, + {"auth/foo/end3/bar/", true}, + {"auth/foo/end3/bar/baz", true}, + {"auth/foo/end3/bar/baz/", true}, + {"auth/foo/end3/bar/baz/qux", true}, + {"auth/foo/end3/bar/baz/qux/qoo", true}, + {"auth/foo/end3/bar/baz/qux/qoo/qaa", true}, + // "middle1/+/bar", + {"auth/foo/middle1/bar", false}, + {"auth/foo/middle1/bar/", false}, + {"auth/foo/middle1/bar/qux", false}, + {"auth/foo/middle1/bar/bar", true}, + {"auth/foo/middle1/bar/bar/", false}, + // "middle2/+/+/bar", + {"auth/foo/middle2/bar", false}, + {"auth/foo/middle2/bar/", false}, + {"auth/foo/middle2/bar/baz", false}, + {"auth/foo/middle2/bar/baz/", false}, + {"auth/foo/middle2/bar/baz/bar", true}, + {"auth/foo/middle2/bar/baz/bar/", false}, + // "+/begin" + {"auth/foo/bar/begin", true}, + {"auth/foo/bar/begin/", false}, + {"auth/foo/begin", false}, + // "+/around/+/" + {"auth/foo/bar/around", false}, + {"auth/foo/bar/around/", false}, + {"auth/foo/bar/around/baz", false}, + {"auth/foo/bar/around/baz/", true}, + {"auth/foo/bar/around/baz/qux", false}, } for _, tc := range tcases { @@ -477,3 +548,82 @@ func TestPathsToRadix(t *testing.T) { t.Fatalf("bad: %v (sub/bar)", raw) } } + +func TestParseUnauthenticatedPaths(t *testing.T) { + // inputs + paths := []string{ + "foo", + "foo/*", + "sub/bar*", + } + wildcardPaths := []string{ + "end/+", + "+/begin/*", + "middle/+/bar*", + } + allPaths := append(paths, wildcardPaths...) + + p, err := parseUnauthenticatedPaths(allPaths) + if err != nil { + t.Fatal(err) + } + + // outputs + wildcardPathsEntry := []wildcardPath{ + {segments: []string{"end", "+"}, isPrefix: false}, + {segments: []string{"+", "begin", ""}, isPrefix: true}, + {segments: []string{"middle", "+", "bar"}, isPrefix: true}, + } + expected := &loginPathsEntry{ + paths: pathsToRadix(paths), + wildcardPaths: wildcardPathsEntry, + } + + if !reflect.DeepEqual(expected, p) { + t.Fatalf("expected: %#v\n actual: %#v\n", expected, p) + } +} + +func TestParseUnauthenticatedPaths_Error(t *testing.T) { + type tcase struct { + paths []string + err string + } + tcases := []tcase{ + { + []string{"/foo/+*"}, + "path \"/foo/+*\": invalid use of wildcards ('+*' is forbidden)", + }, + { + []string{"/foo/*/*"}, + "path \"/foo/*/*\": invalid use of wildcards (multiple '*' is forbidden)", + }, + { + []string{"*/foo/*"}, + "path \"*/foo/*\": invalid use of wildcards (multiple '*' is forbidden)", + }, + { + []string{"*/foo/"}, + "path \"*/foo/\": invalid use of wildcards ('*' is only allowed at the end of a path)", + }, + { + []string{"/foo+"}, + "path \"/foo+\": invalid use of wildcards ('+' is not allowed next to a non-slash)", + }, + { + []string{"/+foo"}, + "path \"/+foo\": invalid use of wildcards ('+' is not allowed next to a non-slash)", + }, + { + []string{"/++"}, + "path \"/++\": invalid use of wildcards ('+' is not allowed next to a non-slash)", + }, + } + + for _, tc := range tcases { + _, err := parseUnauthenticatedPaths(tc.paths) + if err == nil || err != nil && !strings.Contains(err.Error(), tc.err) { + t.Fatalf("bad: path: %s expect: %v got %v", tc.paths, tc.err, err) + } + } +} From 1f1459eba2a368e24c267d74bb967cdd7c32e9ee Mon Sep 17 00:00:00 2001 From: claire bontempo <68122737+hellobontempo@users.noreply.github.com> Date: Wed, 13 Oct 2021 10:52:23 -0700 Subject: [PATCH 08/12] UI/InfoTableRow testing (#12811) * updates storybook * adds computed property valueIsEmpty * adds tests to info table row --- .../core/addon/components/info-table-row.js | 17 ++++++ .../templates/components/info-table-row.hbs | 2 +- ui/lib/core/stories/info-table-row.md | 8 +++ ui/lib/core/stories/info-table-row.stories.js | 20 +++++-- .../components/info-table-row-test.js | 53 ++++++++++++++----- 5 files changed, 82 insertions(+), 18 deletions(-) diff --git a/ui/lib/core/addon/components/info-table-row.js b/ui/lib/core/addon/components/info-table-row.js index 2ab9083e82af4..36a81e7d90c98 100644 --- a/ui/lib/core/addon/components/info-table-row.js +++ b/ui/lib/core/addon/components/info-table-row.js @@ -44,4 +44,21 @@ export default Component.extend({ valueIsBoolean: computed('value', function() { return typeOf(this.value) === 'boolean'; }), + + valueIsEmpty: computed('value', function() { + let { value } = this; + if (typeOf(value) === 'array' && value.length === 0) { + return true; + } + switch (value) { + case undefined: + return true; + case null: + return true; + case '': + return true; + default: + return false; + } + }), }); diff --git a/ui/lib/core/addon/templates/components/info-table-row.hbs b/ui/lib/core/addon/templates/components/info-table-row.hbs index 7fcf0801811ac..0f3bfb2ae4107 100644 --- a/ui/lib/core/addon/templates/components/info-table-row.hbs +++ b/ui/lib/core/addon/templates/components/info-table-row.hbs @@ -34,7 +34,7 @@ {{!-- alwaysRender is still true --}} {{else if (and (not value) defaultShown)}} {{defaultShown}} - {{else if (eq value "")}} + {{else if valueIsEmpty}} {{else}} {{#if (eq type 'array')}} diff --git a/ui/lib/core/stories/info-table-row.md b/ui/lib/core/stories/info-table-row.md index d5c0a54074579..07bfc4394eb02 100644 --- a/ui/lib/core/stories/info-table-row.md +++ b/ui/lib/core/stories/info-table-row.md @@ -12,6 +12,14 @@ that the value breaks under the label on smaller viewports. | label | string | null | The display name for the value. | | helperText | string | null | Text to describe the value displayed beneath the label. | | alwaysRender | Boolean | false | Indicates if the component content should be always be rendered. When false, the value of `value` will be used to determine if the component should render. | +| [type] | string | "array" | The type of value being passed in. This is used for when you want to trim an array. For example, if you have an array value that can equal length 15+ this will trim to show 5 and count how many more are there | +| [isLink] | Boolean | true | Passed through to InfoTableItemArray. Indicates if the item should contain a link-to component. Only setup for arrays, but this could be changed if needed. | +| [modelType] | string | null | Passed through to InfoTableItemArray. Tells what model you want data for the allOptions to be returned from. Used in conjunction with the the isLink. | +| [queryParam] | String | | Passed through to InfoTableItemArray. If you want to specific a tab for the View All XX to display to. Ex: role | +| [backend] | String | | Passed through to InfoTableItemArray. To specify secrets backend to point link to Ex: transformation | +| [viewAll] | String | | Passed through to InfoTableItemArray. Specify the word at the end of the link View all. | +| [tooltipText] | String | | Text if a tooltip should display over the value. | +| [defaultShown] | String | | Text that renders as value if alwaysRender=true. Eg. "Vault default" | **Example** diff --git a/ui/lib/core/stories/info-table-row.stories.js b/ui/lib/core/stories/info-table-row.stories.js index c16812e718ed6..b19b836a3bdde 100644 --- a/ui/lib/core/stories/info-table-row.stories.js +++ b/ui/lib/core/stories/info-table-row.stories.js @@ -11,13 +11,19 @@ storiesOf('InfoTable/InfoTableRow', module) () => ({ template: hbs`
    Info Table Row
    - + `, context: { label: text('Label', 'TTL'), - value: text('Value', '30m'), - helperText: text('helperText', 'A short description'), + value: text('Value', '30m (hover to see the tooltip!)'), + helperText: text('helperText', 'This is helperText - for a short description'), alwaysRender: boolean('Always render?', false), + tooltipText: text('tooltipText', 'This is tooltipText'), }, }), { notes } @@ -27,12 +33,16 @@ storiesOf('InfoTable/InfoTableRow', module) () => ({ template: hbs`
    Info Table Row
    - + `, context: { label: 'Local mount?', value: boolean('Value', true), - helperText: text('helperText', 'A short description'), + helperText: text('helperText', 'This is helperText - for a short description'), alwaysRender: boolean('Always render?', true), }, }), diff --git a/ui/tests/integration/components/info-table-row-test.js b/ui/tests/integration/components/info-table-row-test.js index bf8462089e6df..9093109b203c5 100644 --- a/ui/tests/integration/components/info-table-row-test.js +++ b/ui/tests/integration/components/info-table-row-test.js @@ -2,12 +2,13 @@ import { module, test } from 'qunit'; import { resolve } from 'rsvp'; import Service from '@ember/service'; import { setupRenderingTest } from 'ember-qunit'; -import { render } from '@ember/test-helpers'; +import { render, settled, triggerEvent } from '@ember/test-helpers'; import hbs from 'htmlbars-inline-precompile'; -const VALUE = 'testing'; -const LABEL = 'item'; +const VALUE = 'test value'; +const LABEL = 'test label'; const TYPE = 'array'; +const DEFAULT = 'some default value'; const routerService = Service.extend({ transitionTo() { @@ -22,13 +23,14 @@ const routerService = Service.extend({ }, }); -module('Integration | Component | InfoTableItem', function(hooks) { +module('Integration | Component | InfoTableRow', function(hooks) { setupRenderingTest(hooks); hooks.beforeEach(function() { this.set('value', VALUE); this.set('label', LABEL); this.set('type', TYPE); + this.set('default', DEFAULT); this.owner.register('service:router', routerService); this.router = this.owner.lookup('service:router'); }); @@ -38,11 +40,10 @@ module('Integration | Component | InfoTableItem', function(hooks) { }); test('it renders', async function(assert) { - this.set('alwaysRender', true); - await render(hbs``); assert.dom('[data-test-component="info-table-row"]').exists(); @@ -50,7 +51,24 @@ module('Integration | Component | InfoTableItem', function(hooks) { assert.equal(string, VALUE, 'renders value as passed through'); this.set('value', ''); - assert.dom('[data-test-label-div]').doesNotExist('does not render if no value and alwaysRender is false'); + assert + .dom('[data-test-label-div]') + .doesNotExist('does not render if no value and alwaysRender is false (even if default exists)'); + }); + + test('it renders a tooltip', async function(assert) { + this.set('tooltipText', 'Tooltip text!'); + + await render(hbs``); + + await triggerEvent('[data-test-value-div="test label"] .ember-basic-dropdown-trigger', 'mouseenter'); + await settled(); + let tooltip = document.querySelector('div.box').textContent.trim(); + assert.equal(tooltip, 'Tooltip text!', 'renders tooltip text'); }); test('it renders a string with no link if isLink is true and the item type is not an array.', async function(assert) { @@ -76,27 +94,38 @@ module('Integration | Component | InfoTableItem', function(hooks) { assert.dom('[data-test-item="array"]').hasText('valueArray', 'Confirm link with item value exist'); }); - test('it renders a dash (-) if a label and/or value do not exist', async function(assert) { + test('it renders as expected if a label and/or value do not exist', async function(assert) { this.set('value', VALUE); this.set('label', ''); + this.set('default', ''); await render(hbs``); - assert.dom('[data-test-label-div]').exists('renders label div'); - assert.dom('[data-test-value-div]').exists('renders value div'); + assert.dom('div.column span').hasClass('hs-icon-s', 'Renders a dash (-) for the label'); this.set('value', ''); this.set('label', LABEL); - assert.dom('div.column.is-flex span').hasClass('hs-icon-s', 'Renders a dash (-) for the value'); + assert.dom('div.column.is-flex span').hasClass('hs-icon-s', 'Renders a dash (-) for empty string value'); + + this.set('value', null); + assert.dom('div.column.is-flex span').hasClass('hs-icon-s', 'Renders a dash (-) for null value'); + + this.set('value', undefined); + assert.dom('div.column.is-flex span').hasClass('hs-icon-s', 'Renders a dash (-) for undefined value'); + + this.set('default', DEFAULT); + assert.dom('[data-test-value-div]').hasText(DEFAULT, 'Renders default text if value is empty'); this.set('value', ''); this.set('label', ''); + this.set('default', ''); let dashCount = document.querySelectorAll('.hs-icon-s').length; - assert.equal(dashCount, 2, 'Renders dash (-) when both label and value do not exist'); + assert.equal(dashCount, 2, 'Renders dash (-) when both label and value do not exist (and no defaults)'); }); test('block content overrides any passed in value content', async function(assert) { From 6f65a4addcee1b18578fdf8ca0d4122a3674cea9 Mon Sep 17 00:00:00 2001 From: Chris Capurso Date: Wed, 13 Oct 2021 15:24:31 -0400 Subject: [PATCH 09/12] Add HTTP PATCH support to KV (#12687) * handle HTTP PATCH requests as logical.PatchOperation * update go.mod, go.sum * a nil response for logical.PatchOperation should result in 404 * respond with 415 for incorrect MIME type in PATCH Content-Type header * add abstraction to handle PatchOperation requests * add ACLs for patch * Adding JSON Merge support to the API client * add HTTP PATCH tests to check high level response logic * add permission-based 'kv patch' tests in prep to add HTTP PATCH * adding more 'kv patch' CLI command tests * fix TestHandler_Patch_NotFound * Fix TestKvPatchCommand_StdinValue * add audit log test for HTTP PATCH * patch CLI changes * add patch CLI tests * change JSONMergePatch func to accept a ctx * fix TestKVPatchCommand_RWMethodNotExists and TestKVPatchCommand_RWMethodSucceeds to specify -method flag * go fmt * add a test to verify patching works by default with the root token * add changelog entry * get vault-plugin-secrets-kv@add-patch-support * PR feedback * reorder some imports; go fmt * add doc comment for HandlePatchOperation * add json-patch@v5.5.0 to go.mod * remove unnecessary cancelFunc for WriteBytes * remove default for -method * use stable version of json-patch; go mod tidy * more PR feedback * temp go get vault-plugin-secrets-kv@master until official release Co-authored-by: Josh Black --- api/api_test.go | 3 +- api/logical.go | 24 +- changelog/12687.txt | 5 + command/kv_patch.go | 133 +++- command/kv_test.go | 604 ++++++++++++++++++ go.mod | 5 +- go.sum | 11 +- http/handler_test.go | 199 ++++++ http/logical.go | 31 + sdk/framework/backend.go | 57 ++ sdk/logical/request.go | 1 + sdk/logical/response_util.go | 2 +- vault/acl.go | 7 +- vault/acl_test.go | 18 + vault/external_tests/kv/kv_patch_test.go | 57 ++ .../{misc => kv}/kvv2_upgrade_test.go | 2 +- vault/logical_system.go | 6 +- vault/policy.go | 5 +- vault/policy_test.go | 10 + 19 files changed, 1151 insertions(+), 29 deletions(-) create mode 100644 changelog/12687.txt create mode 100644 vault/external_tests/kv/kv_patch_test.go rename vault/external_tests/{misc => kv}/kvv2_upgrade_test.go (99%) diff --git a/api/api_test.go b/api/api_test.go index b2b851df6e89f..e4ba3153203eb 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -9,8 +9,7 @@ import ( // testHTTPServer creates a test HTTP server that handles requests until // the listener returned is closed. -func testHTTPServer( - t *testing.T, handler http.Handler) (*Config, net.Listener) { +func testHTTPServer(t *testing.T, handler http.Handler) (*Config, net.Listener) { ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %s", err) diff --git a/api/logical.go b/api/logical.go index cd950a2b78968..f8f8bc5376612 100644 --- a/api/logical.go +++ b/api/logical.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "io" + "net/http" "net/url" "os" @@ -130,24 +131,37 @@ func (c *Logical) List(path string) (*Secret, error) { } func (c *Logical) Write(path string, data map[string]interface{}) (*Secret, error) { + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + r := c.c.NewRequest("PUT", "/v1/"+path) if err := r.SetJSONBody(data); err != nil { return nil, err } - return c.write(path, r) + return c.write(ctx, path, r) +} + +func (c *Logical) JSONMergePatch(ctx context.Context, path string, data map[string]interface{}) (*Secret, error) { + r := c.c.NewRequest("PATCH", "/v1/"+path) + r.Headers = http.Header{ + "Content-Type": []string{"application/merge-patch+json"}, + } + if err := r.SetJSONBody(data); err != nil { + return nil, err + } + + return c.write(ctx, path, r) } func (c *Logical) WriteBytes(path string, data []byte) (*Secret, error) { r := c.c.NewRequest("PUT", "/v1/"+path) r.BodyBytes = data - return c.write(path, r) + return c.write(context.Background(), path, r) } -func (c *Logical) write(path string, request *Request) (*Secret, error) { - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() +func (c *Logical) write(ctx context.Context, path string, request *Request) (*Secret, error) { resp, err := c.c.RawRequestWithContext(ctx, request) if resp != nil { defer resp.Body.Close() diff --git a/changelog/12687.txt b/changelog/12687.txt new file mode 100644 index 0000000000000..f5998deeda706 --- /dev/null +++ b/changelog/12687.txt @@ -0,0 +1,5 @@ +```release-note:feature +**KV patch**: Add partial update support the for the `//data/:path` kv-v2 +endpoint through HTTP `PATCH`. A new `patch` ACL capability has been added and +is required to make such requests. +``` diff --git a/command/kv_patch.go b/command/kv_patch.go index 05e759793ab4f..d05ff5eed89cd 100644 --- a/command/kv_patch.go +++ b/command/kv_patch.go @@ -1,11 +1,13 @@ package command import ( + "context" "fmt" "io" "os" "strings" + "github.com/hashicorp/vault/api" "github.com/mitchellh/cli" "github.com/posener/complete" ) @@ -18,7 +20,9 @@ var ( type KVPatchCommand struct { *BaseCommand - testStdin io.Reader // for tests + flagCAS int + flagMethod string + testStdin io.Reader // for tests } func (c *KVPatchCommand) Synopsis() string { @@ -45,6 +49,25 @@ Usage: vault kv patch [options] KEY [DATA] $ echo "abcd1234" | vault kv patch secret/foo bar=- + To perform a Check-And-Set operation, specify the -cas flag with the + appropriate version number corresponding to the key you want to perform + the CAS operation on: + + $ vault kv patch -cas=1 secret/foo bar=baz + + By default, this operation will attempt an HTTP PATCH operation. If your + policy does not allow that, it will fall back to a read/local update/write approach. + If you wish to specify which method this command should use, you may do so + with the -method flag. When -method=patch is specified, only an HTTP PATCH + operation will be tried. If it fails, the entire command will fail. + + $ vault kv patch -method=patch secret/foo bar=baz + + When -method=rw is specified, only a read/local update/write approach will be tried. + This was the default behavior previous to Vault 1.9. + + $ vault kv patch -method=rw secret/foo bar=baz + Additional flags and more advanced use cases are detailed below. ` + c.Flags().Help() @@ -54,6 +77,27 @@ Usage: vault kv patch [options] KEY [DATA] func (c *KVPatchCommand) Flags() *FlagSets { set := c.flagSet(FlagSetHTTP | FlagSetOutputField | FlagSetOutputFormat) + // Patch specific options + f := set.NewFlagSet("Common Options") + + f.IntVar(&IntVar{ + Name: "cas", + Target: &c.flagCAS, + Default: 0, + Usage: `Specifies to use a Check-And-Set operation. If set to 0 or not + set, the patch will be allowed. If the index is non-zero the patch will + only be allowed if the key’s current version matches the version + specified in the cas parameter.`, + }) + + f.StringVar(&StringVar{ + Name: "method", + Target: &c.flagMethod, + Usage: `Specifies which method of patching to use. If set to "patch", then + an HTTP PATCH request will be issued. If set to "rw", then a read will be + performed, then a local update, followed by a remote update.`, + }) + return set } @@ -121,6 +165,30 @@ func (c *KVPatchCommand) Run(args []string) int { return 2 } + // Check the method and behave accordingly + var secret *api.Secret + var code int + + switch c.flagMethod { + case "rw": + secret, code = c.readThenWrite(client, path, newData) + case "patch": + secret, code = c.mergePatch(client, path, newData, false) + case "": + secret, code = c.mergePatch(client, path, newData, true) + default: + c.UI.Error(fmt.Sprintf("Unsupported method provided to -method flag: %s", c.flagMethod)) + return 2 + } + + if code != 0 { + return code + } + + return OutputSecret(c.UI, secret) +} + +func (c *KVPatchCommand) readThenWrite(client *api.Client, path string, newData map[string]interface{}) (*api.Secret, int) { // First, do a read. // Note that we don't want to see curl output for the read request. curOutputCurl := client.OutputCurlString() @@ -129,45 +197,45 @@ func (c *KVPatchCommand) Run(args []string) int { client.SetOutputCurlString(curOutputCurl) if err != nil { c.UI.Error(fmt.Sprintf("Error doing pre-read at %s: %s", path, err)) - return 2 + return nil, 2 } // Make sure a value already exists if secret == nil || secret.Data == nil { c.UI.Error(fmt.Sprintf("No value found at %s", path)) - return 2 + return nil, 2 } // Verify metadata found rawMeta, ok := secret.Data["metadata"] if !ok || rawMeta == nil { c.UI.Error(fmt.Sprintf("No metadata found at %s; patch only works on existing data", path)) - return 2 + return nil, 2 } meta, ok := rawMeta.(map[string]interface{}) if !ok { c.UI.Error(fmt.Sprintf("Metadata found at %s is not the expected type (JSON object)", path)) - return 2 + return nil, 2 } if meta == nil { c.UI.Error(fmt.Sprintf("No metadata found at %s; patch only works on existing data", path)) - return 2 + return nil, 2 } // Verify old data found rawData, ok := secret.Data["data"] if !ok || rawData == nil { c.UI.Error(fmt.Sprintf("No data found at %s; patch only works on existing data", path)) - return 2 + return nil, 2 } data, ok := rawData.(map[string]interface{}) if !ok { c.UI.Error(fmt.Sprintf("Data found at %s is not the expected type (JSON object)", path)) - return 2 + return nil, 2 } if data == nil { c.UI.Error(fmt.Sprintf("No data found at %s; patch only works on existing data", path)) - return 2 + return nil, 2 } // Copy new data over @@ -183,19 +251,58 @@ func (c *KVPatchCommand) Run(args []string) int { }) if err != nil { c.UI.Error(fmt.Sprintf("Error writing data to %s: %s", path, err)) - return 2 + return nil, 2 } + if secret == nil { // Don't output anything unless using the "table" format if Format(c.UI) == "table" { c.UI.Info(fmt.Sprintf("Success! Data written to: %s", path)) } - return 0 + return nil, 0 } if c.flagField != "" { - return PrintRawField(c.UI, secret, c.flagField) + return nil, PrintRawField(c.UI, secret, c.flagField) } - return OutputSecret(c.UI, secret) + return secret, 0 +} + +func (c *KVPatchCommand) mergePatch(client *api.Client, path string, newData map[string]interface{}, rwFallback bool) (*api.Secret, int) { + data := map[string]interface{}{ + "data": newData, + "options": map[string]interface{}{}, + } + + if c.flagCAS > 0 { + data["options"].(map[string]interface{})["cas"] = c.flagCAS + } + + secret, err := client.Logical().JSONMergePatch(context.Background(), path, data) + if err != nil { + // If it's a 403, that probably means they don't have the patch capability in their policy. Fall back to + // the old way of doing it if the user didn't specify a -method. If they did, and it was "patch", then just error. + if re, ok := err.(*api.ResponseError); ok && re.StatusCode == 403 && rwFallback { + c.UI.Warn(fmt.Sprintf("Data was written to %s but we recommend that you add the \"patch\" capability to your ACL policy in order to use HTTP PATCH in the future.", path)) + return c.readThenWrite(client, path, newData) + } + + c.UI.Error(fmt.Sprintf("Error writing data to %s: %s", path, err)) + return nil, 2 + } + + if secret == nil { + // Don't output anything unless using the "table" format + if Format(c.UI) == "table" { + c.UI.Info(fmt.Sprintf("Success! Data written to: %s", path)) + } + return nil, 0 + } + + if c.flagField != "" { + return nil, PrintRawField(c.UI, secret, c.flagField) + } + + return secret, 0 } diff --git a/command/kv_test.go b/command/kv_test.go index 7e93a4cb8beff..a73fc01356664 100644 --- a/command/kv_test.go +++ b/command/kv_test.go @@ -1,6 +1,7 @@ package command import ( + "fmt" "io" "strings" "testing" @@ -552,3 +553,606 @@ func TestKVMetadataGetCommand(t *testing.T) { assertNoTabs(t, cmd) }) } + +func testKVPatchCommand(tb testing.TB) (*cli.MockUi, *KVPatchCommand) { + tb.Helper() + + ui := cli.NewMockUi() + return ui, &KVPatchCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func TestKVPatchCommand_ArgValidation(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "not_enough_args", + []string{}, + "Not enough arguments", + 1, + }, + { + "empty_kvs", + []string{"kv/patch/foo"}, + "Must supply data", + 1, + }, + { + "kvs_no_value", + []string{"kv/patch/foo", "foo"}, + "Failed to parse K=V data", + 1, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + if err := client.Sys().Mount("kv/", &api.MountInput{ + Type: "kv-v2", + }); err != nil { + t.Fatalf("kv-v2 mount attempt failed - err: %#v\n", err) + } + + ui, cmd := testKVPatchCommand(t) + cmd.client = client + + code := cmd.Run(tc.args) + + if code != tc.code { + t.Fatalf("expected code to be %d but was %d for cmd %#v with args %#v\n", tc.code, code, cmd, tc.args) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + + if !strings.Contains(combined, tc.out) { + t.Fatalf("expected output to be %q but was %q for cmd %#v with args %#v\n", tc.out, combined, cmd, tc.args) + } + }) + } +} + +func TestKvPatchCommand_StdinFull(t *testing.T) { + client, closer := testVaultServer(t) + defer closer() + + if err := client.Sys().Mount("kv/", &api.MountInput{ + Type: "kv-v2", + }); err != nil { + t.Fatalf("kv-v2 mount attempt failed - err: %#v\n", err) + } + + if _, err := client.Logical().Write("kv/data/patch/foo", map[string]interface{}{ + "data": map[string]interface{}{ + "foo": "a", + }, + }); err != nil { + t.Fatalf("write failed, err: %#v\n", err) + } + + stdinR, stdinW := io.Pipe() + go func() { + stdinW.Write([]byte(`{"foo":"bar"}`)) + stdinW.Close() + }() + + _, cmd := testKVPatchCommand(t) + cmd.client = client + + cmd.testStdin = stdinR + + args := []string{"kv/patch/foo", "-"} + code := cmd.Run(args) + if code != 0 { + t.Fatalf("expected code to be 0 but was %d for cmd %#v with args %#v\n", code, cmd, args) + } + + secret, err := client.Logical().Read("kv/data/patch/foo") + if err != nil { + t.Fatalf("read failed, err: %#v\n", err) + } + + if secret == nil || secret.Data == nil { + t.Fatal("expected secret to have data") + } + + secretDataRaw, ok := secret.Data["data"] + + if !ok { + t.Fatalf("expected secret to have nested data key, data: %#v", secret.Data) + } + + secretData := secretDataRaw.(map[string]interface{}) + foo, ok := secretData["foo"].(string) + if !ok { + t.Fatal("expected foo to be a string but it wasn't") + } + + if exp, act := "bar", foo; exp != act { + t.Fatalf("expected %q to be %q, data: %#v\n", act, exp, secret.Data) + } +} + +func TestKvPatchCommand_StdinValue(t *testing.T) { + client, closer := testVaultServer(t) + defer closer() + + if err := client.Sys().Mount("kv/", &api.MountInput{ + Type: "kv-v2", + }); err != nil { + t.Fatalf("kv-v2 mount attempt failed - err: %#v\n", err) + } + + if _, err := client.Logical().Write("kv/data/patch/foo", map[string]interface{}{ + "data": map[string]interface{}{ + "foo": "a", + }, + }); err != nil { + t.Fatalf("write failed, err: %#v\n", err) + } + + stdinR, stdinW := io.Pipe() + go func() { + stdinW.Write([]byte("bar")) + stdinW.Close() + }() + + _, cmd := testKVPatchCommand(t) + cmd.client = client + + cmd.testStdin = stdinR + + args := []string{"kv/patch/foo", "foo=-"} + code := cmd.Run(args) + if code != 0 { + t.Fatalf("expected code to be 0 but was %d for cmd %#v with args %#v\n", code, cmd, args) + } + + secret, err := client.Logical().Read("kv/data/patch/foo") + if err != nil { + t.Fatalf("read failed, err: %#v\n", err) + } + + if secret == nil || secret.Data == nil { + t.Fatal("expected secret to have data") + } + + secretDataRaw, ok := secret.Data["data"] + + if !ok { + t.Fatalf("expected secret to have nested data key, data: %#v\n", secret.Data) + } + + secretData := secretDataRaw.(map[string]interface{}) + + if exp, act := "bar", secretData["foo"].(string); exp != act { + t.Fatalf("expected %q to be %q, data: %#v\n", act, exp, secret.Data) + } +} + +func TestKVPatchCommand_RWMethodNotExists(t *testing.T) { + client, closer := testVaultServer(t) + defer closer() + + if err := client.Sys().Mount("kv/", &api.MountInput{ + Type: "kv-v2", + }); err != nil { + t.Fatalf("kv-v2 mount attempt failed - err: %#v\n", err) + } + + ui, cmd := testKVPatchCommand(t) + cmd.client = client + + args := []string{"-method", "rw", "kv/patch/foo", "foo=a"} + code := cmd.Run(args) + if code != 2 { + t.Fatalf("expected code to be 2 but was %d for cmd %#v with args %#v\n", code, cmd, args) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + + expectedOutputSubstr := "No value found" + if !strings.Contains(combined, expectedOutputSubstr) { + t.Fatalf("expected output %q to contain %q for cmd %#v with args %#v\n", combined, expectedOutputSubstr, cmd, args) + } +} + +func TestKVPatchCommand_RWMethodSucceeds(t *testing.T) { + client, closer := testVaultServer(t) + defer closer() + + if err := client.Sys().Mount("kv/", &api.MountInput{ + Type: "kv-v2", + }); err != nil { + t.Fatalf("kv-v2 mount attempt failed - err: %#v\n", err) + } + + if _, err := client.Logical().Write("kv/data/patch/foo", map[string]interface{}{ + "data": map[string]interface{}{ + "foo": "a", + "bar": "b", + }, + }); err != nil { + t.Fatalf("write failed, err: %#v\n", err) + } + + ui, cmd := testKVPatchCommand(t) + cmd.client = client + + // Test single value + args := []string{"-method", "rw", "kv/patch/foo", "foo=aa"} + code := cmd.Run(args) + if code != 0 { + t.Fatalf("expected code to be 0 but was %d for cmd %#v with args %#v\n", code, cmd, args) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + + expectedOutputSubstr := "created_time" + if !strings.Contains(combined, expectedOutputSubstr) { + t.Fatalf("expected output %q to contain %q for cmd %#v with args %#v\n", combined, expectedOutputSubstr, cmd, args) + } + + // Test multi value + ui, cmd = testKVPatchCommand(t) + cmd.client = client + + args = []string{"-method", "rw", "kv/patch/foo", "foo=aaa", "bar=bbb"} + code = cmd.Run(args) + if code != 0 { + t.Fatalf("expected code to be 0 but was %d for cmd %#v with args %#v\n", code, cmd, args) + } + + combined = ui.OutputWriter.String() + ui.ErrorWriter.String() + + if !strings.Contains(combined, expectedOutputSubstr) { + t.Fatalf("expected output %q to contain %q for cmd %#v with args %#v\n", combined, expectedOutputSubstr, cmd, args) + } +} + +func TestKVPatchCommand_CAS(t *testing.T) { + cases := []struct { + name string + args []string + expected string + out string + code int + }{ + { + "right version", + []string{"-cas", "1", "kv/foo", "bar=quux"}, + "quux", + "", + 0, + }, + { + "wrong version", + []string{"-cas", "2", "kv/foo", "bar=wibble"}, + "baz", + "check-and-set parameter did not match the current version", + 2, + }, + } + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + if err := client.Sys().Mount("kv/", &api.MountInput{ + Type: "kv-v2", + }); err != nil { + t.Fatalf("kv-v2 mount attempt failed - err: %#v\n", err) + } + + // create a policy with patch capability + policy := `path "kv/*" { capabilities = ["create", "update", "read", "patch"] }` + secretAuth, err := createTokenForPolicy(t, client, policy) + if err != nil { + t.Fatalf("policy/token creation failed for policy %s, err: %#v\n", policy, err) + } + + kvClient, err := client.Clone() + if err != nil { + t.Fatal(err) + } + + kvClient.SetToken(secretAuth.ClientToken) + + _, err = kvClient.Logical().Write("kv/data/foo", map[string]interface{}{"data": map[string]interface{}{"bar": "baz"}}) + if err != nil { + t.Fatal(err) + } + + ui, cmd := testKVPatchCommand(t) + cmd.client = kvClient + + code := cmd.Run(tc.args) + + if code != tc.code { + t.Fatalf("expected code to be %d but was %d", tc.code, code) + } + + if tc.out != "" { + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + } + + secret, err := kvClient.Logical().Read("kv/data/foo") + bar := secret.Data["data"].(map[string]interface{})["bar"] + if bar != tc.expected { + t.Fatalf("expected bar to be %q but it was %q", tc.expected, bar) + } + }) + } +} + +func TestKVPatchCommand_Methods(t *testing.T) { + cases := []struct { + name string + args []string + expected string + code int + }{ + { + "rw", + []string{"-method", "rw", "kv/foo", "bar=quux"}, + "quux", + 0, + }, + { + "patch", + []string{"-method", "patch", "kv/foo", "bar=wibble"}, + "wibble", + 0, + }, + } + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + if err := client.Sys().Mount("kv/", &api.MountInput{ + Type: "kv-v2", + }); err != nil { + t.Fatalf("kv-v2 mount attempt failed - err: %#v\n", err) + } + + // create a policy with patch capability + policy := `path "kv/*" { capabilities = ["create", "update", "read", "patch"] }` + secretAuth, err := createTokenForPolicy(t, client, policy) + if err != nil { + t.Fatalf("policy/token creation failed for policy %s, err: %#v\n", policy, err) + } + + kvClient, err := client.Clone() + if err != nil { + t.Fatal(err) + } + + kvClient.SetToken(secretAuth.ClientToken) + + _, err = kvClient.Logical().Write("kv/data/foo", map[string]interface{}{"data": map[string]interface{}{"bar": "baz"}}) + if err != nil { + t.Fatal(err) + } + + _, cmd := testKVPatchCommand(t) + cmd.client = kvClient + + code := cmd.Run(tc.args) + + if code != tc.code { + t.Fatalf("expected code to be %d but was %d", tc.code, code) + } + + secret, err := kvClient.Logical().Read("kv/data/foo") + bar := secret.Data["data"].(map[string]interface{})["bar"] + if bar != tc.expected { + t.Fatalf("expected bar to be %q but it was %q", tc.expected, bar) + } + }) + } +} + +func TestKVPatchCommand_403Fallback(t *testing.T) { + cases := []struct { + name string + args []string + expected string + code int + }{ + // if no -method is specified, and patch fails, it should fall back to rw and succeed + { + "unspecified", + []string{"kv/foo", "bar=quux"}, + `add the "patch" capability to your ACL policy`, + 0, + }, + // if -method=patch is specified, and patch fails, it should not fall back, and just error + { + "specifying patch", + []string{"-method", "patch", "kv/foo", "bar=quux"}, + "permission denied", + 2, + }, + } + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + client, closer := testVaultServer(t) + defer closer() + + if err := client.Sys().Mount("kv/", &api.MountInput{ + Type: "kv-v2", + }); err != nil { + t.Fatalf("kv-v2 mount attempt failed - err: %#v\n", err) + } + + // create a policy without patch capability + policy := `path "kv/*" { capabilities = ["create", "update", "read"] }` + secretAuth, err := createTokenForPolicy(t, client, policy) + if err != nil { + t.Fatalf("policy/token creation failed for policy %s, err: %#v\n", policy, err) + } + + kvClient, err := client.Clone() + if err != nil { + t.Fatal(err) + } + + kvClient.SetToken(secretAuth.ClientToken) + + // Write a value then attempt to patch it + _, err = kvClient.Logical().Write("kv/data/foo", map[string]interface{}{"data": map[string]interface{}{"bar": "baz"}}) + if err != nil { + t.Fatal(err) + } + + ui, cmd := testKVPatchCommand(t) + cmd.client = kvClient + code := cmd.Run(tc.args) + + if code != tc.code { + t.Fatalf("expected code to be %d but was %d", tc.code, code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.expected) { + t.Errorf("expected %q to contain %q", combined, tc.expected) + } + }) + } +} + +func createTokenForPolicy(t *testing.T, client *api.Client, policy string) (*api.SecretAuth, error) { + t.Helper() + + if err := client.Sys().PutPolicy("policy", policy); err != nil { + return nil, err + } + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: []string{"policy"}, + TTL: "30m", + }) + if err != nil { + return nil, err + } + + if secret == nil || secret.Auth == nil || secret.Auth.ClientToken == "" { + return nil, fmt.Errorf("missing auth data: %#v", secret) + } + + return secret.Auth, err +} + +func TestKVPatchCommand_RWMethodPolicyVariations(t *testing.T) { + cases := []struct { + name string + args []string + policy string + expected string + code int + }{ + // if the policy doesn't have read capability and -method=rw is specified, it fails + { + "no read", + []string{"-method", "rw", "kv/foo", "bar=quux"}, + `path "kv/*" { capabilities = ["create", "update"] }`, + "permission denied", + 2, + }, + // if the policy doesn't have update capability and -method=rw is specified, it fails + { + "no update", + []string{"-method", "rw", "kv/foo", "bar=quux"}, + `path "kv/*" { capabilities = ["create", "read"] }`, + "permission denied", + 2, + }, + // if the policy has both read and update and -method=rw is specified, it succeeds + { + "read and update", + []string{"-method", "rw", "kv/foo", "bar=quux"}, + `path "kv/*" { capabilities = ["create", "read", "update"] }`, + "", + 0, + }, + } + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + client, closer := testVaultServer(t) + defer closer() + + if err := client.Sys().Mount("kv/", &api.MountInput{ + Type: "kv-v2", + }); err != nil { + t.Fatalf("kv-v2 mount attempt failed - err: %#v\n", err) + } + + secretAuth, err := createTokenForPolicy(t, client, tc.policy) + if err != nil { + t.Fatalf("policy/token creation failed for policy %s, err: %#v\n", tc.policy, err) + } + + client.SetToken(secretAuth.ClientToken) + + if _, err := client.Logical().Write("kv/data/foo", map[string]interface{}{ + "data": map[string]interface{}{ + "foo": "bar", + "bar": "baz", + }, + }); err != nil { + t.Fatalf("write failed, err: %#v\n", err) + } + + ui, cmd := testKVPatchCommand(t) + cmd.client = client + + code := cmd.Run(tc.args) + + if code != tc.code { + t.Fatalf("expected code to be %d but was %d for cmd %#v with args %#v\n", tc.code, code, cmd, tc.args) + } + + if code != 0 { + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.expected) { + t.Fatalf("expected output %q to contain %q for cmd %#v with args %#v\n", combined, tc.expected, cmd, tc.args) + } + } + }) + } +} diff --git a/go.mod b/go.mod index affcd85ccf092..b5ab8ff8b72e2 100644 --- a/go.mod +++ b/go.mod @@ -113,13 +113,13 @@ require ( github.com/hashicorp/vault-plugin-secrets-azure v0.6.3-0.20210924190759-58a034528e35 github.com/hashicorp/vault-plugin-secrets-gcp v0.10.2 github.com/hashicorp/vault-plugin-secrets-gcpkms v0.9.0 - github.com/hashicorp/vault-plugin-secrets-kv v0.5.7-0.20210811133805-e060c2307b24 + github.com/hashicorp/vault-plugin-secrets-kv v0.5.7-0.20211013154503-eec8a1c892fb github.com/hashicorp/vault-plugin-secrets-mongodbatlas v0.4.0 github.com/hashicorp/vault-plugin-secrets-openldap v0.4.1-0.20210921171411-e86105e4986d github.com/hashicorp/vault-plugin-secrets-terraform v0.1.1-0.20210715043003-e02ca8f6408e github.com/hashicorp/vault-testing-stepwise v0.1.1 github.com/hashicorp/vault/api v1.1.1 - github.com/hashicorp/vault/sdk v0.2.1 + github.com/hashicorp/vault/sdk v0.2.2-0.20211004171540-a8c7e135dd6a github.com/influxdata/influxdb v0.0.0-20190411212539-d24b7ba8c4c4 github.com/jcmturner/gokrb5/v8 v8.0.0 github.com/jefferai/isbadcipher v0.0.0-20190226160619-51d2077c035f @@ -199,6 +199,7 @@ require ( google.golang.org/grpc v1.41.0 google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0 google.golang.org/protobuf v1.27.1 + gopkg.in/evanphx/json-patch.v4 v4.11.0 // indirect gopkg.in/mgo.v2 v2.0.0-20180705113604-9856a29383ce gopkg.in/ory-am/dockertest.v3 v3.3.4 gopkg.in/square/go-jose.v2 v2.5.1 diff --git a/go.sum b/go.sum index 38b8e21486377..b2490092b2ec9 100644 --- a/go.sum +++ b/go.sum @@ -350,6 +350,8 @@ github.com/envoyproxy/go-control-plane v0.9.9-0.20210512163311-63b5d3c536b0/go.m github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021/go.mod h1:AFq3mo9L8Lqqiid3OhADV3RfLJnjiw63cSpi+fDTRC0= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/evanphx/json-patch v0.0.0-20190203023257-5858425f7550/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= +github.com/evanphx/json-patch v0.5.2/go.mod h1:ZWS5hhDbVDyob71nXKNL0+PWn6ToqBHMikGIFbs31qQ= +github.com/evanphx/json-patch v4.2.0+incompatible h1:fUDGZCv/7iAN7u0puUVhvKCcsR6vRfwrJatElLBEf0I= github.com/evanphx/json-patch v4.2.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU= @@ -748,8 +750,10 @@ github.com/hashicorp/vault-plugin-secrets-gcp v0.10.2 h1:+DtlYJTsrFRInQpAo09KkYN github.com/hashicorp/vault-plugin-secrets-gcp v0.10.2/go.mod h1:psRQ/dm5XatoUKLDUeWrpP9icMJNtu/jmscUr37YGK4= github.com/hashicorp/vault-plugin-secrets-gcpkms v0.9.0 h1:7a0iWuFA/YNinQ1xXogyZHStolxMVtLV+sy1LpEHaZs= github.com/hashicorp/vault-plugin-secrets-gcpkms v0.9.0/go.mod h1:hhwps56f2ATeC4Smgghrc5JH9dXR31b4ehSf1HblP5Q= -github.com/hashicorp/vault-plugin-secrets-kv v0.5.7-0.20210811133805-e060c2307b24 h1:uqPKQzkmO5vybOqk2aOdviXXi5088bcl2MrE0D1MhjM= -github.com/hashicorp/vault-plugin-secrets-kv v0.5.7-0.20210811133805-e060c2307b24/go.mod h1:4j2pZrSynPuUAAYrZQVgSSHD0A9xj7GK9Ji1sWtnO4s= +github.com/hashicorp/vault-plugin-secrets-kv v0.5.7-0.20211007143158-2d15a6fec12b h1:1GJj7AjgI0Td95haW8EK5on3Usuox78wmzLj+J9vcm4= +github.com/hashicorp/vault-plugin-secrets-kv v0.5.7-0.20211007143158-2d15a6fec12b/go.mod h1:iEKCVaKBQzzYxzb778O6VGLdd+8gA40ZI14bo+8tQjs= +github.com/hashicorp/vault-plugin-secrets-kv v0.5.7-0.20211013154503-eec8a1c892fb h1:nZ2a4a1G0ALLAzKOWQbLzD5oljKo+pjMarbq3BwU0pM= +github.com/hashicorp/vault-plugin-secrets-kv v0.5.7-0.20211013154503-eec8a1c892fb/go.mod h1:D/FQJ7zU5pD6FNJVUwaVtxr75ZsxIIqaG/Nh6RHt/xo= github.com/hashicorp/vault-plugin-secrets-mongodbatlas v0.4.0 h1:6ve+7hZmGn7OpML81iZUxYj2AaJptwys323S5XsvVas= github.com/hashicorp/vault-plugin-secrets-mongodbatlas v0.4.0/go.mod h1:4mdgPqlkO+vfFX1cFAWcxkeqz6JAtZgKxL/67q/58Oo= github.com/hashicorp/vault-plugin-secrets-openldap v0.4.1-0.20210921171411-e86105e4986d h1:o5Z9B1FztTYSnTQNzFr+iZJHPM8ZD23uV5A8gMxm2g0= @@ -803,6 +807,7 @@ github.com/jefferai/isbadcipher v0.0.0-20190226160619-51d2077c035f h1:E87tDTVS5W github.com/jefferai/isbadcipher v0.0.0-20190226160619-51d2077c035f/go.mod h1:3J2qVK16Lq8V+wfiL2lPeDZ7UWMxk5LemerHa1p6N00= github.com/jefferai/jsonx v1.0.0 h1:Xoz0ZbmkpBvED5W9W1B5B/zc3Oiq7oXqiW7iRV3B6EI= github.com/jefferai/jsonx v1.0.0/go.mod h1:OGmqmi2tTeI/PS+qQfBDToLHHJIy/RMp24fPo8vFvoQ= +github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/jmespath/go-jmespath v0.3.0/go.mod h1:9QtRXoHjLGCJ5IBSaohpXITPlowMeeYCZ7fLUTSywik= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= @@ -1672,6 +1677,8 @@ gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8X gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/errgo.v2 v2.1.0 h1:0vLT13EuvQ0hNvakwLuFZ/jYrLp5F3kcWHXdRggjCE8= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/evanphx/json-patch.v4 v4.11.0 h1:+kbwxm5IBGIiNYVhss+hM3Nv4ck+HnPSNscCNbD1cT0= +gopkg.in/evanphx/json-patch.v4 v4.11.0/go.mod h1:p8EYWUEYMpynmqDbY58zCKCFZw8pRWMG4EsWvDvM72M= gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/gemnasium/logrus-airbrake-hook.v2 v2.1.2/go.mod h1:Xk6kEKp8OKb+X14hQBKWaSkCsqBpgog8nAV2xsGOxlo= diff --git a/http/handler_test.go b/http/handler_test.go index c228629ea8dce..87a79d3143394 100644 --- a/http/handler_test.go +++ b/http/handler_test.go @@ -16,6 +16,10 @@ import ( "github.com/go-test/deep" "github.com/hashicorp/go-cleanhttp" + kv "github.com/hashicorp/vault-plugin-secrets-kv" + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/audit" + auditFile "github.com/hashicorp/vault/builtin/audit/file" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/logical" @@ -818,3 +822,198 @@ func TestHandler_Parse_Form(t *testing.T) { t.Fatal(diff) } } + +func TestHandler_Patch_BadContentTypeHeader(t *testing.T) { + coreConfig := &vault.CoreConfig{ + LogicalBackends: map[string]logical.Factory{ + "kv": kv.VersionedKVFactory, + }, + } + + cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + HandlerFunc: Handler, + }) + + cluster.Start() + defer cluster.Cleanup() + + cores := cluster.Cores + + core := cores[0].Core + c := cluster.Cores[0].Client + vault.TestWaitActive(t, core) + + // Mount a KVv2 backend + err := c.Sys().Mount("kv", &api.MountInput{ + Type: "kv-v2", + }) + if err != nil { + t.Fatal(err) + } + + kvData := map[string]interface{}{ + "data": map[string]interface{}{ + "bar": "a", + }, + } + + resp, err := c.Logical().Write("kv/data/foo", kvData) + if err != nil { + t.Fatalf("write failed - err :%#v, resp: %#v\n", err, resp) + } + + resp, err = c.Logical().Read("kv/data/foo") + if err != nil { + t.Fatalf("read failed - err :%#v, resp: %#v\n", err, resp) + } + + req := c.NewRequest("PATCH", "/v1/kv/data/foo") + req.Headers = http.Header{ + "Content-Type": []string{"application/json"}, + } + + if err := req.SetJSONBody(kvData); err != nil { + t.Fatal(err) + } + + apiResp, err := c.RawRequestWithContext(context.Background(), req) + if err == nil || apiResp.StatusCode != http.StatusUnsupportedMediaType { + t.Fatalf("expected PATCH request to fail with %d status code - err :%#v, resp: %#v\n", http.StatusUnsupportedMediaType, err, apiResp) + } +} + +func TestHandler_Patch_NotFound(t *testing.T) { + coreConfig := &vault.CoreConfig{ + LogicalBackends: map[string]logical.Factory{ + "kv": kv.VersionedKVFactory, + }, + } + + cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + HandlerFunc: Handler, + }) + + cluster.Start() + defer cluster.Cleanup() + + cores := cluster.Cores + + core := cores[0].Core + c := cluster.Cores[0].Client + vault.TestWaitActive(t, core) + + // Mount a KVv2 backend + err := c.Sys().Mount("kv", &api.MountInput{ + Type: "kv-v2", + }) + if err != nil { + t.Fatal(err) + } + + kvData := map[string]interface{}{ + "data": map[string]interface{}{ + "bar": "a", + }, + } + + resp, err := c.Logical().JSONMergePatch(context.Background(), "kv/data/foo", kvData) + if err == nil { + t.Fatalf("expected PATCH request to fail, resp: %#v\n", resp) + } + + responseError := err.(*api.ResponseError) + if responseError.StatusCode != http.StatusNotFound { + t.Fatalf("expected PATCH request to fail with %d status code - err: %#v, resp: %#v\n", http.StatusNotFound, responseError, resp) + } +} + +func TestHandler_Patch_Audit(t *testing.T) { + coreConfig := &vault.CoreConfig{ + LogicalBackends: map[string]logical.Factory{ + "kv": kv.VersionedKVFactory, + }, + AuditBackends: map[string]audit.Factory{ + "file": auditFile.Factory, + }, + } + + cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + HandlerFunc: Handler, + }) + + cluster.Start() + defer cluster.Cleanup() + + cores := cluster.Cores + + core := cores[0].Core + c := cluster.Cores[0].Client + vault.TestWaitActive(t, core) + + if err := c.Sys().Mount("kv/", &api.MountInput{ + Type: "kv-v2", + }); err != nil { + t.Fatalf("kv-v2 mount attempt failed - err: %#v\n", err) + } + + auditLogFile, err := ioutil.TempFile("", "httppatch") + if err != nil { + t.Fatal(err) + } + + err = c.Sys().EnableAuditWithOptions("file", &api.EnableAuditOptions{ + Type: "file", + Options: map[string]string{ + "file_path": auditLogFile.Name(), + }, + }) + + writeData := map[string]interface{}{ + "data": map[string]interface{}{ + "bar": "a", + }, + } + + resp, err := c.Logical().Write("kv/data/foo", writeData) + if err != nil { + t.Fatalf("write request failed, err: %#v, resp: %#v\n", err, resp) + } + + patchData := map[string]interface{}{ + "data": map[string]interface{}{ + "baz": "b", + }, + } + + resp, err = c.Logical().JSONMergePatch(context.Background(), "kv/data/foo", patchData) + if err != nil { + t.Fatalf("patch request failed, err: %#v, resp: %#v\n", err, resp) + } + + patchRequestLogCount := 0 + patchResponseLogCount := 0 + decoder := json.NewDecoder(auditLogFile) + + var auditRecord map[string]interface{} + for decoder.Decode(&auditRecord) == nil { + auditRequest := map[string]interface{}{} + + if req, ok := auditRecord["request"]; ok { + auditRequest = req.(map[string]interface{}) + } + + if auditRequest["operation"] == "patch" && auditRecord["type"] == "request" { + patchRequestLogCount += 1 + } else if auditRequest["operation"] == "patch" && auditRecord["type"] == "response" { + patchResponseLogCount += 1 + } + } + + if patchRequestLogCount != 1 { + t.Fatalf("expected 1 patch request audit log record, saw %d\n", patchRequestLogCount) + } + + if patchResponseLogCount != 1 { + t.Fatalf("expected 1 patch response audit log record, saw %d\n", patchResponseLogCount) + } +} diff --git a/http/logical.go b/http/logical.go index dd9abce34dfdb..7984f8ac06ba3 100644 --- a/http/logical.go +++ b/http/logical.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "io" + "mime" "net" "net/http" "strconv" @@ -38,6 +39,8 @@ func (b *bufferedReader) Close() error { return b.rOrig.Close() } +const MergePatchContentTypeHeader = "application/merge-patch+json" + func buildLogicalRequestNoAuth(perfStandby bool, w http.ResponseWriter, r *http.Request) (*logical.Request, io.ReadCloser, int, error) { ns, err := namespace.FromContext(r.Context()) if err != nil { @@ -139,6 +142,34 @@ func buildLogicalRequestNoAuth(perfStandby bool, w http.ResponseWriter, r *http. } } + case "PATCH": + op = logical.PatchOperation + + contentTypeHeader := r.Header.Get("Content-Type") + contentType, _, err := mime.ParseMediaType(contentTypeHeader) + if err != nil { + status := http.StatusBadRequest + logical.AdjustErrorStatusCode(&status, err) + return nil, nil, status, err + } + + if contentType != MergePatchContentTypeHeader { + return nil, nil, http.StatusUnsupportedMediaType, fmt.Errorf("PATCH requires Content-Type of %s, provided %s", MergePatchContentTypeHeader, contentType) + } + + origBody, err = parseJSONRequest(perfStandby, r, w, &data) + + if err == io.EOF { + data = nil + err = nil + } + + if err != nil { + status := http.StatusBadRequest + logical.AdjustErrorStatusCode(&status, err) + return nil, nil, status, fmt.Errorf("error parsing JSON") + } + case "LIST": op = logical.ListOperation if !strings.HasSuffix(path, "/") { diff --git a/sdk/framework/backend.go b/sdk/framework/backend.go index d498dd4ee09f0..673351a47e374 100644 --- a/sdk/framework/backend.go +++ b/sdk/framework/backend.go @@ -3,6 +3,7 @@ package framework import ( "context" "crypto/rand" + "encoding/json" "fmt" "io" "io/ioutil" @@ -13,6 +14,7 @@ import ( "sync" "time" + jsonpatch "github.com/evanphx/json-patch" "github.com/hashicorp/errwrap" log "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-kms-wrapping/entropy" @@ -118,6 +120,10 @@ type InvalidateFunc func(context.Context, string) // Initialize() just after a plugin has been mounted. type InitializeFunc func(context.Context, *logical.InitializationRequest) error +// PatchPreprocessorFunc is used by HandlePatchOperation in order to shape +// the input as defined by request handler prior to JSON marshaling +type PatchPreprocessorFunc func(map[string]interface{}) (map[string]interface{}, error) + // Initialize is the logical.Backend implementation. func (b *Backend) Initialize(ctx context.Context, req *logical.InitializationRequest) error { if b.InitializeFunc != nil { @@ -272,6 +278,57 @@ func (b *Backend) HandleRequest(ctx context.Context, req *logical.Request) (*log return callback(ctx, req, &fd) } +// HandlePatchOperation acts as an abstraction for performing JSON merge patch +// operations (see https://datatracker.ietf.org/doc/html/rfc7396) for HTTP +// PATCH requests. It is responsible for properly processing and marshalling +// the input and existing resource prior to performing the JSON merge operation +// using the MergePatch function from the json-patch library. The preprocessor +// is an arbitrary func that can be provided to further process the input. The +// MergePatch function accepts and returns byte arrays. +func HandlePatchOperation(input *FieldData, resource map[string]interface{}, preprocessor PatchPreprocessorFunc) ([]byte, error) { + var err error + + if resource == nil { + return nil, fmt.Errorf("resource does not exist") + } + + inputMap := map[string]interface{}{} + + // Parse all fields to ensure data types are handled properly according to the FieldSchema + for key := range input.Raw { + val, ok := input.GetOk(key) + + // Only accept fields in the schema + if ok { + inputMap[key] = val + } + } + + if preprocessor != nil { + inputMap, err = preprocessor(inputMap) + if err != nil { + return nil, err + } + } + + marshaledResource, err := json.Marshal(resource) + if err != nil { + return nil, err + } + + marshaledInput, err := json.Marshal(inputMap) + if err != nil { + return nil, err + } + + modified, err := jsonpatch.MergePatch(marshaledResource, marshaledInput) + if err != nil { + return nil, err + } + + return modified, nil +} + // SpecialPaths is the logical.Backend implementation. func (b *Backend) SpecialPaths() *logical.Paths { return b.PathsSpecial diff --git a/sdk/logical/request.go b/sdk/logical/request.go index b88aabce2df37..e683217a6efc5 100644 --- a/sdk/logical/request.go +++ b/sdk/logical/request.go @@ -350,6 +350,7 @@ const ( CreateOperation Operation = "create" ReadOperation = "read" UpdateOperation = "update" + PatchOperation = "patch" DeleteOperation = "delete" ListOperation = "list" HelpOperation = "help" diff --git a/sdk/logical/response_util.go b/sdk/logical/response_util.go index 6ae3005b735f1..5244069379864 100644 --- a/sdk/logical/response_util.go +++ b/sdk/logical/response_util.go @@ -17,7 +17,7 @@ import ( func RespondErrorCommon(req *Request, resp *Response, err error) (int, error) { if err == nil && (resp == nil || !resp.IsError()) { switch { - case req.Operation == ReadOperation: + case req.Operation == ReadOperation, req.Operation == PatchOperation: if resp == nil { return http.StatusNotFound, nil } diff --git a/vault/acl.go b/vault/acl.go index 3d07c4089c486..fc9f353aa8afb 100644 --- a/vault/acl.go +++ b/vault/acl.go @@ -305,6 +305,9 @@ func (a *ACL) Capabilities(ctx context.Context, path string) (pathCapabilities [ if capabilities&CreateCapabilityInt > 0 { pathCapabilities = append(pathCapabilities, CreateCapability) } + if capabilities&PatchCapabilityInt > 0 { + pathCapabilities = append(pathCapabilities, PatchCapability) + } // If "deny" is explicitly set or if the path has no capabilities at all, // set the path capabilities to "deny" @@ -406,6 +409,8 @@ CHECK: operationAllowed = capabilities&DeleteCapabilityInt > 0 case logical.CreateOperation: operationAllowed = capabilities&CreateCapabilityInt > 0 + case logical.PatchOperation: + operationAllowed = capabilities&PatchCapabilityInt > 0 // These three re-use UpdateCapabilityInt since that's the most appropriate // capability/operation mapping @@ -440,7 +445,7 @@ CHECK: // Only check parameter permissions for operations that can modify // parameters. - if op == logical.ReadOperation || op == logical.UpdateOperation || op == logical.CreateOperation { + if op == logical.ReadOperation || op == logical.UpdateOperation || op == logical.CreateOperation || op == logical.PatchOperation { for _, parameter := range permissions.RequiredParameters { if _, ok := req.Data[strings.ToLower(parameter)]; !ok { return diff --git a/vault/acl_test.go b/vault/acl_test.go index 29c690e01a882..fe3a33e8aa20d 100644 --- a/vault/acl_test.go +++ b/vault/acl_test.go @@ -238,6 +238,12 @@ func testACLSingle(t *testing.T, ns *namespace.Namespace) { {logical.UpdateOperation, "foo/bar", false, true}, {logical.CreateOperation, "foo/bar", true, true}, + {logical.ReadOperation, "baz/quux", true, false}, + {logical.CreateOperation, "baz/quux", true, false}, + {logical.PatchOperation, "baz/quux", true, false}, + {logical.ListOperation, "baz/quux", false, false}, + {logical.UpdateOperation, "baz/quux", false, false}, + // Path segment wildcards {logical.ReadOperation, "test/foo/bar/segment", false, false}, {logical.ReadOperation, "test/foo/segment", true, false}, @@ -341,6 +347,12 @@ func testLayeredACL(t *testing.T, acl *ACL, ns *namespace.Namespace) { {logical.ListOperation, "foo/bar", false, false}, {logical.UpdateOperation, "foo/bar", false, false}, {logical.CreateOperation, "foo/bar", false, false}, + + {logical.ReadOperation, "baz/quux", false, false}, + {logical.ListOperation, "baz/quux", false, false}, + {logical.UpdateOperation, "baz/quux", false, false}, + {logical.CreateOperation, "baz/quux", false, false}, + {logical.PatchOperation, "baz/quux", false, false}, } for _, tc := range tcases { @@ -864,6 +876,9 @@ path "sys/*" { path "foo/bar" { capabilities = ["read", "create", "sudo"] } +path "baz/quux" { + capabilities = ["read", "create", "patch"] +} path "test/+/segment" { capabilities = ["read"] } @@ -912,6 +927,9 @@ path "sys/seal" { path "foo/bar" { capabilities = ["deny"] } +path "baz/quux" { + capabilities = ["deny"] +} ` // test merging diff --git a/vault/external_tests/kv/kv_patch_test.go b/vault/external_tests/kv/kv_patch_test.go new file mode 100644 index 0000000000000..33d36f4f0b10d --- /dev/null +++ b/vault/external_tests/kv/kv_patch_test.go @@ -0,0 +1,57 @@ +package kv + +import ( + "context" + "testing" + + logicalKv "github.com/hashicorp/vault-plugin-secrets-kv" + "github.com/hashicorp/vault/api" + vaulthttp "github.com/hashicorp/vault/http" + "github.com/hashicorp/vault/sdk/logical" + "github.com/hashicorp/vault/vault" +) + +// Verifies that patching works by default with the root token +func TestKV_Patch_RootToken(t *testing.T) { + coreConfig := &vault.CoreConfig{ + LogicalBackends: map[string]logical.Factory{ + "kv": logicalKv.Factory, + }, + } + cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + cluster.Start() + defer cluster.Cleanup() + + core := cluster.Cores[0] + client := core.Client + + // make sure this client is using the root token + client.SetToken(cluster.RootToken) + + // Enable KVv2 + err := client.Sys().Mount("kv", &api.MountInput{ + Type: "kv-v2", + }) + if err != nil { + t.Fatal(err) + } + + // Write a kv value and patch it + _, err = client.Logical().Write("kv/data/foo", map[string]interface{}{"data": map[string]interface{}{"bar": "baz"}}) + if err != nil { + t.Fatal(err) + } + + _, err = client.Logical().JSONMergePatch(context.Background(), "kv/data/foo", map[string]interface{}{"data": map[string]interface{}{"bar": "quux"}}) + if err != nil { + t.Fatal(err) + } + + secret, err := client.Logical().Read("kv/data/foo") + bar := secret.Data["data"].(map[string]interface{})["bar"] + if bar != "quux" { + t.Fatalf("expected bar to be quux but it was %q", bar) + } +} diff --git a/vault/external_tests/misc/kvv2_upgrade_test.go b/vault/external_tests/kv/kvv2_upgrade_test.go similarity index 99% rename from vault/external_tests/misc/kvv2_upgrade_test.go rename to vault/external_tests/kv/kvv2_upgrade_test.go index 39d747e8ab53e..3d3eb486f2071 100644 --- a/vault/external_tests/misc/kvv2_upgrade_test.go +++ b/vault/external_tests/kv/kvv2_upgrade_test.go @@ -1,4 +1,4 @@ -package misc +package kv import ( "bytes" diff --git a/vault/logical_system.go b/vault/logical_system.go index 1675475a422a3..c41d3cb82278e 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -3398,7 +3398,8 @@ func hasMountAccess(ctx context.Context, acl *ACL, path string) bool { perms.CapabilitiesBitmap&ListCapabilityInt > 0, perms.CapabilitiesBitmap&ReadCapabilityInt > 0, perms.CapabilitiesBitmap&SudoCapabilityInt > 0, - perms.CapabilitiesBitmap&UpdateCapabilityInt > 0: + perms.CapabilitiesBitmap&UpdateCapabilityInt > 0, + perms.CapabilitiesBitmap&PatchCapabilityInt > 0: aclCapabilitiesGiven = true @@ -3684,6 +3685,9 @@ func (b *SystemBackend) pathInternalUIResultantACL(ctx context.Context, req *log if perms.CapabilitiesBitmap&UpdateCapabilityInt > 0 { capabilities = append(capabilities, UpdateCapability) } + if perms.CapabilitiesBitmap&PatchCapabilityInt > 0 { + capabilities = append(capabilities, PatchCapability) + } // If "deny" is explicitly set or if the path has no capabilities at all, // set the path capabilities to "deny" diff --git a/vault/policy.go b/vault/policy.go index a4686b1d81c74..e80d1657e98dd 100644 --- a/vault/policy.go +++ b/vault/policy.go @@ -26,6 +26,7 @@ const ( ListCapability = "list" SudoCapability = "sudo" RootCapability = "root" + PatchCapability = "patch" // Backwards compatibility OldDenyPathPolicy = "deny" @@ -42,6 +43,7 @@ const ( DeleteCapabilityInt ListCapabilityInt SudoCapabilityInt + PatchCapabilityInt ) // Error constants for testing @@ -83,6 +85,7 @@ var cap2Int = map[string]uint32{ DeleteCapability: DeleteCapabilityInt, ListCapability: ListCapabilityInt, SudoCapability: SudoCapabilityInt, + PatchCapability: PatchCapabilityInt, } type egpPath struct { @@ -390,7 +393,7 @@ func parsePaths(result *Policy, list *ast.ObjectList, performTemplating bool, en pc.Capabilities = []string{DenyCapability} pc.Permissions.CapabilitiesBitmap = DenyCapabilityInt goto PathFinished - case CreateCapability, ReadCapability, UpdateCapability, DeleteCapability, ListCapability, SudoCapability: + case CreateCapability, ReadCapability, UpdateCapability, DeleteCapability, ListCapability, SudoCapability, PatchCapability: pc.Permissions.CapabilitiesBitmap |= cap2Int[cap] default: return fmt.Errorf("path %q: invalid capability %q", key, cap) diff --git a/vault/policy_test.go b/vault/policy_test.go index 045fae8655387..7f09d7cb2be75 100644 --- a/vault/policy_test.go +++ b/vault/policy_test.go @@ -83,6 +83,9 @@ path "test/req" { capabilities = ["create", "sudo"] required_parameters = ["foo"] } +path "test/patch" { + capabilities = ["patch"] +} path "test/mfa" { capabilities = ["create", "sudo"] mfa_methods = ["my_totp", "my_totp2"] @@ -244,6 +247,13 @@ func TestPolicy_Parse(t *testing.T) { RequiredParameters: []string{"foo"}, }, }, + { + Path: "test/patch", + Capabilities: []string{"patch"}, + Permissions: &ACLPermissions{ + CapabilitiesBitmap: (PatchCapabilityInt), + }, + }, { Path: "test/mfa", Capabilities: []string{ From 9c6bd51754d6b871ea6c4d5f0315468f78036c74 Mon Sep 17 00:00:00 2001 From: Chelsea Shaw <82459713+hashishaw@users.noreply.github.com> Date: Wed, 13 Oct 2021 15:04:39 -0500 Subject: [PATCH 10/12] UI/OIDC provider (#12800) * Add new route w/ controller oidc-provider * oidc-provider controller has params, template has success message (temporary), model requests correct endpoint * Move oidc-provider route to under identity * Do not redirect after poll if on oidc-provider page * WIP provider -- beforeModel handles prompt, logout, redirect * Auth service fetch method rejects with fetch response if status >= 300 * New component OidcConsentBlock * Fix redirect to/from auth with cluster name, show error and consent form if applicable * Show error and consent form on template * Add component test, update docs * Test for oidc-consent-block component * Add changelog * fix tests * Add authorize to end of router path * Remove unused tests * Update changelog with feature name * Add descriptions for OidcConsentBlock component * glimmerize token-expire-warning and don't override yield if on oidc-provider route * remove text on token-expire-warning * Fix null transition.to on cluster redirect * Hide nav links if oidc-provider route --- changelog/12800.txt | 3 + ui/app/components/nav-header.js | 11 ++ ui/app/components/oidc-consent-block.js | 59 +++++++++ ui/app/components/token-expire-warning.js | 17 ++- .../vault/cluster/identity/oidc-provider.js | 24 ++++ ui/app/mixins/cluster-route.js | 5 +- ui/app/router.js | 4 + .../vault/cluster/identity/oidc-provider.js | 115 ++++++++++++++++++ ui/app/services/auth.js | 2 +- ui/app/templates/components/nav-header.hbs | 46 +++---- .../components/oidc-consent-block.hbs | 23 ++++ .../components/token-expire-warning.hbs | 4 +- .../vault/cluster/identity/oidc-provider.hbs | 25 ++++ .../addon/components/form-save-buttons.js | 1 + .../components/oidc-consent-block-test.js | 105 ++++++++++++++++ 15 files changed, 414 insertions(+), 30 deletions(-) create mode 100644 changelog/12800.txt create mode 100644 ui/app/components/oidc-consent-block.js create mode 100644 ui/app/controllers/vault/cluster/identity/oidc-provider.js create mode 100644 ui/app/routes/vault/cluster/identity/oidc-provider.js create mode 100644 ui/app/templates/components/oidc-consent-block.hbs create mode 100644 ui/app/templates/vault/cluster/identity/oidc-provider.hbs create mode 100644 ui/tests/integration/components/oidc-consent-block-test.js diff --git a/changelog/12800.txt b/changelog/12800.txt new file mode 100644 index 0000000000000..38aadc0170803 --- /dev/null +++ b/changelog/12800.txt @@ -0,0 +1,3 @@ +```release-note:feature +**OIDC Authorization Code Flow**: The Vault UI now supports OIDC Authorization Code Flow +``` diff --git a/ui/app/components/nav-header.js b/ui/app/components/nav-header.js index 509b88ed4b6d8..ce5f5638f62d8 100644 --- a/ui/app/components/nav-header.js +++ b/ui/app/components/nav-header.js @@ -1,10 +1,21 @@ import Component from '@ember/component'; +import { inject as service } from '@ember/service'; +import { computed } from '@ember/object'; + export default Component.extend({ + router: service(), 'data-test-navheader': true, classNameBindings: 'consoleFullscreen:panel-fullscreen', tagName: 'header', navDrawerOpen: false, consoleFullscreen: false, + hideLinks: computed('router.currentRouteName', function() { + let currentRoute = this.router.currentRouteName; + if ('vault.cluster.identity.oidc-provider' === currentRoute) { + return true; + } + return false; + }), actions: { toggleNavDrawer(isOpen) { if (isOpen !== undefined) { diff --git a/ui/app/components/oidc-consent-block.js b/ui/app/components/oidc-consent-block.js new file mode 100644 index 0000000000000..b5a9e6fbddbb1 --- /dev/null +++ b/ui/app/components/oidc-consent-block.js @@ -0,0 +1,59 @@ +/** + * @module OidcConsentBlock + * OidcConsentBlock components are used to show the consent form for the OIDC Authorization Code Flow + * + * @example + * ```js + * + * ``` + * @param {string} redirect - redirect is the URL where successful consent will redirect to + * @param {string} code - code is the string required to pass back to redirect on successful OIDC auth + * @param {string} [state] - state is a string which is required to return on redirect if provided, but optional generally + */ + +import Ember from 'ember'; +import Component from '@glimmer/component'; +import { action } from '@ember/object'; +import { tracked } from '@glimmer/tracking'; + +const validParameters = ['code', 'state']; +export default class OidcConsentBlockComponent extends Component { + @tracked didCancel = false; + + get win() { + return this.window || window; + } + + buildUrl(urlString, params) { + try { + let url = new URL(urlString); + Object.keys(params).forEach(key => { + if (params[key] && validParameters.includes(key)) { + url.searchParams.append(key, params[key]); + } + }); + return url; + } catch (e) { + console.debug('DEBUG: parsing url failed for', urlString); + throw new Error('Invalid URL'); + } + } + + @action + handleSubmit(evt) { + evt.preventDefault(); + let { redirect, ...params } = this.args; + let redirectUrl = this.buildUrl(redirect, params); + if (Ember.testing) { + this.args.testRedirect(redirectUrl.toString()); + } else { + this.win.location.replace(redirectUrl); + } + } + + @action + handleCancel(evt) { + evt.preventDefault(); + this.didCancel = true; + } +} diff --git a/ui/app/components/token-expire-warning.js b/ui/app/components/token-expire-warning.js index 4798652642ba8..eb2884b5043c1 100644 --- a/ui/app/components/token-expire-warning.js +++ b/ui/app/components/token-expire-warning.js @@ -1,5 +1,14 @@ -import Component from '@ember/component'; +import Component from '@glimmer/component'; +import { inject as service } from '@ember/service'; -export default Component.extend({ - tagName: '', -}); +export default class TokenExpireWarning extends Component { + @service router; + + get showWarning() { + let currentRoute = this.router.currentRouteName; + if ('vault.cluster.identity.oidc-provider' === currentRoute) { + return false; + } + return !!this.args.expirationDate; + } +} diff --git a/ui/app/controllers/vault/cluster/identity/oidc-provider.js b/ui/app/controllers/vault/cluster/identity/oidc-provider.js new file mode 100644 index 0000000000000..8b656eca92c95 --- /dev/null +++ b/ui/app/controllers/vault/cluster/identity/oidc-provider.js @@ -0,0 +1,24 @@ +import Controller from '@ember/controller'; + +export default class VaultClusterIdentityOidcProviderController extends Controller { + queryParams = [ + 'scope', // * + 'response_type', // * + 'client_id', // * + 'redirect_uri', // * + 'state', // * + 'nonce', // * + 'display', + 'prompt', + 'max_age', + ]; + scope = null; + response_type = null; + client_id = null; + redirect_uri = null; + state = null; + nonce = null; + display = null; + prompt = null; + max_age = null; +} diff --git a/ui/app/mixins/cluster-route.js b/ui/app/mixins/cluster-route.js index 67dae10860cfc..66127b89b17e3 100644 --- a/ui/app/mixins/cluster-route.js +++ b/ui/app/mixins/cluster-route.js @@ -7,6 +7,7 @@ const AUTH = 'vault.cluster.auth'; const CLUSTER = 'vault.cluster'; const CLUSTER_INDEX = 'vault.cluster.index'; const OIDC_CALLBACK = 'vault.cluster.oidc-callback'; +const OIDC_PROVIDER = 'vault.cluster.identity.oidc-provider'; const DR_REPLICATION_SECONDARY = 'vault.cluster.replication-dr-promote'; const DR_REPLICATION_SECONDARY_DETAILS = 'vault.cluster.replication-dr-promote.details'; const EXCLUDED_REDIRECT_URLS = ['/vault/logout']; @@ -20,7 +21,9 @@ export default Mixin.create({ transitionToTargetRoute(transition = {}) { const targetRoute = this.targetRouteName(transition); - + if (OIDC_PROVIDER === this.router.currentRouteName || OIDC_PROVIDER === transition?.to?.name) { + return RSVP.resolve(); + } if ( targetRoute && targetRoute !== this.routeName && diff --git a/ui/app/router.js b/ui/app/router.js index 090def1fa2acc..95a6e83e5f8e9 100644 --- a/ui/app/router.js +++ b/ui/app/router.js @@ -139,6 +139,10 @@ Router.map(function() { } this.route('not-found', { path: '/*path' }); + + this.route('identity', function() { + this.route('oidc-provider', { path: '/oidc/provider/:oidc_name/authorize' }); + }); }); this.route('not-found', { path: '/*path' }); }); diff --git a/ui/app/routes/vault/cluster/identity/oidc-provider.js b/ui/app/routes/vault/cluster/identity/oidc-provider.js new file mode 100644 index 0000000000000..7f5a4c66ac4d4 --- /dev/null +++ b/ui/app/routes/vault/cluster/identity/oidc-provider.js @@ -0,0 +1,115 @@ +import Route from '@ember/routing/route'; +import { inject as service } from '@ember/service'; + +const AUTH = 'vault.cluster.auth'; +const PROVIDER = 'vault.cluster.identity.oidc-provider'; + +export default class VaultClusterIdentityOidcProviderRoute extends Route { + @service auth; + @service router; + + get win() { + return this.window || window; + } + + _redirect(url, params) { + let redir = this._buildUrl(url, params); + this.win.location.replace(redir); + } + + beforeModel(transition) { + const currentToken = this.auth.get('currentTokenName'); + let { redirect_to, ...qp } = transition.to.queryParams; + console.debug('DEBUG: removing redirect_to', redirect_to); + if (!currentToken && 'none' === qp.prompt?.toLowerCase()) { + this._redirect(qp.redirect_uri, { + state: qp.state, + error: 'login_required', + }); + } else if (!currentToken || 'login' === qp.prompt?.toLowerCase()) { + if ('login' === qp.prompt?.toLowerCase()) { + this.auth.deleteCurrentToken(); + qp.prompt = null; + } + let { cluster_name } = this.paramsFor('vault.cluster'); + let url = this.router.urlFor(transition.to.name, transition.to.params, { queryParams: qp }); + return this.transitionTo(AUTH, cluster_name, { queryParams: { redirect_to: url } }); + } + } + + _redirectToAuth(oidcName, queryParams, logout = false) { + let { cluster_name } = this.paramsFor('vault.cluster'); + let currentRoute = this.router.urlFor(PROVIDER, oidcName, { queryParams }); + if (logout) { + this.auth.deleteCurrentToken(); + } + return this.transitionTo(AUTH, cluster_name, { queryParams: { redirect_to: currentRoute } }); + } + + _buildUrl(urlString, params) { + try { + let url = new URL(urlString); + Object.keys(params).forEach(key => { + if (params[key]) { + url.searchParams.append(key, params[key]); + } + }); + return url; + } catch (e) { + console.debug('DEBUG: parsing url failed for', urlString); + throw new Error('Invalid URL'); + } + } + + _handleSuccess(response, baseUrl, state) { + const { code } = response; + let redirectUrl = this._buildUrl(baseUrl, { code, state }); + this.win.location.replace(redirectUrl); + } + _handleError(errorResp, baseUrl) { + let redirectUrl = this._buildUrl(baseUrl, { ...errorResp }); + this.win.location.replace(redirectUrl); + } + + async model(params) { + let { oidc_name, ...qp } = params; + let decodedRedirect = decodeURI(qp.redirect_uri); + let url = this._buildUrl(`${this.win.origin}/v1/identity/oidc/provider/${oidc_name}/authorize`, qp); + try { + const response = await this.auth.ajax(url, 'GET', {}); + if ('consent' === qp.prompt?.toLowerCase()) { + return { + consent: { + code: response.code, + redirect: decodedRedirect, + state: qp.state, + }, + }; + } + this._handleSuccess(response, decodedRedirect, qp.state); + } catch (errorRes) { + let resp = await errorRes.json(); + let code = resp.error; + if (code === 'max_age_violation') { + this._redirectToAuth(oidc_name, qp, true); + } else if (code === 'invalid_redirect_uri') { + return { + error: { + title: 'Redirect URI mismatch', + message: + 'The provided redirect_uri is not in the list of allowed redirect URIs. Please make sure you are sending a valid redirect URI from your application.', + }, + }; + } else if (code === 'invalid_client_id') { + return { + error: { + title: 'Invalid client ID', + message: 'Your client ID is invalid. Please update your configuration and try again.', + }, + }; + } else { + this._handleError(resp, decodedRedirect); + } + } + } +} diff --git a/ui/app/services/auth.js b/ui/app/services/auth.js index 9b90965d38556..d8ceb145516a2 100644 --- a/ui/app/services/auth.js +++ b/ui/app/services/auth.js @@ -97,7 +97,7 @@ export default Service.extend({ } else if (response.status >= 200 && response.status < 300) { return resolve(response.json()); } else { - return reject(); + return reject(response); } }); }, diff --git a/ui/app/templates/components/nav-header.hbs b/ui/app/templates/components/nav-header.hbs index c6ad8738cf3cf..8a60fecdce2f0 100644 --- a/ui/app/templates/components/nav-header.hbs +++ b/ui/app/templates/components/nav-header.hbs @@ -9,30 +9,32 @@ {{/unless}} -