From 5bf301c17800829c854f9d811932f4fc43676acc Mon Sep 17 00:00:00 2001 From: hamid ghaf Date: Thu, 2 Sep 2021 16:37:54 -0700 Subject: [PATCH 01/25] Customizing HTTP headers in the config file --- command/agent.go | 7 +- command/agent/cache/handler.go | 17 +- command/agent/cache/lease_cache.go | 14 +- command/server.go | 12 ++ http/cors.go | 14 +- http/handler.go | 150 +++++++++++---- http/help.go | 12 +- http/logical.go | 41 ++-- http/sys_feature_flags.go | 13 +- http/sys_generate_root.go | 71 +++++-- http/sys_health.go | 27 ++- http/sys_init.go | 33 +++- http/sys_leader.go | 18 +- http/sys_metrics.go | 17 +- http/sys_raft.go | 38 +++- http/sys_rekey.go | 139 ++++++++++---- http/sys_seal.go | 72 +++++-- http/testing.go | 2 +- http/util.go | 16 +- .../configutil/http_response_headers.go | 176 ++++++++++++++++++ internalshared/configutil/listener.go | 28 +++ .../listenerutil/response_headers.go | 75 ++++++++ sdk/logical/response_util.go | 2 +- vault/core.go | 80 ++++++++ vault/external_tests/raft/raft_test.go | 80 ++++++++ vault/logical_system.go | 15 +- 26 files changed, 990 insertions(+), 179 deletions(-) create mode 100644 internalshared/configutil/http_response_headers.go create mode 100644 internalshared/listenerutil/response_headers.go diff --git a/command/agent.go b/command/agent.go index cbbcba5757b4a..9cbb6245a01c3 100644 --- a/command/agent.go +++ b/command/agent.go @@ -878,9 +878,10 @@ func (c *AgentCommand) Run(args []string) int { func verifyRequestHeader(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if val, ok := r.Header[consts.RequestHeaderName]; !ok || len(val) != 1 || val[0] != "true" { - logical.RespondError(w, - http.StatusPreconditionFailed, - errors.New(fmt.Sprintf("missing '%s' header", consts.RequestHeaderName))) + err := errors.New(fmt.Sprintf("missing '%s' header", consts.RequestHeaderName)) + status := http.StatusPreconditionFailed + logical.AdjustErrorStatusCode(&status, err) + logical.RespondError(w, status, err) return } diff --git a/command/agent/cache/handler.go b/command/agent/cache/handler.go index 73062df41fbd0..e33c9beffab86 100644 --- a/command/agent/cache/handler.go +++ b/command/agent/cache/handler.go @@ -37,8 +37,11 @@ func Handler(ctx context.Context, logger hclog.Logger, proxier Proxier, inmemSin // Parse and reset body. reqBody, err := ioutil.ReadAll(r.Body) if err != nil { - logger.Error("failed to read request body") - logical.RespondError(w, http.StatusInternalServerError, errors.New("failed to read request body")) + errRet := errors.New("failed to read request body") + logger.Error(errRet.Error()) + status := http.StatusInternalServerError + logical.AdjustErrorStatusCode(&status, errRet) + logical.RespondError(w, status, errRet) return } if r.Body != nil { @@ -59,14 +62,20 @@ func Handler(ctx context.Context, logger hclog.Logger, proxier Proxier, inmemSin w.WriteHeader(resp.Response.StatusCode) io.Copy(w, resp.Response.Body) } else { - logical.RespondError(w, http.StatusInternalServerError, fmt.Errorf("failed to get the response: %w", err)) + status := http.StatusInternalServerError + errNew := fmt.Errorf("failed to get the response: %w", err) + logical.AdjustErrorStatusCode(&status, errNew) + logical.RespondError(w, status, errNew) } return } err = processTokenLookupResponse(ctx, logger, inmemSink, req, resp) if err != nil { - logical.RespondError(w, http.StatusInternalServerError, fmt.Errorf("failed to process token lookup response: %w", err)) + status := http.StatusInternalServerError + errNew := fmt.Errorf("failed to process token lookup response: %w", err) + logical.AdjustErrorStatusCode(&status, errNew) + logical.RespondError(w, status, errNew) return } diff --git a/command/agent/cache/lease_cache.go b/command/agent/cache/lease_cache.go index a8b2d4bd88cea..a4e739378a27b 100644 --- a/command/agent/cache/lease_cache.go +++ b/command/agent/cache/lease_cache.go @@ -576,7 +576,10 @@ func (c *LeaseCache) HandleCacheClear(ctx context.Context) http.Handler { if err == io.EOF { err = errors.New("empty JSON provided") } - logical.RespondError(w, http.StatusBadRequest, fmt.Errorf("failed to parse JSON input: %w", err)) + status := http.StatusBadRequest + errNew := fmt.Errorf("failed to parse JSON input: %w", err) + logical.AdjustErrorStatusCode(&status, errNew) + logical.RespondError(w, status, errNew) return } @@ -585,7 +588,10 @@ func (c *LeaseCache) HandleCacheClear(ctx context.Context) http.Handler { in, err := parseCacheClearInput(req) if err != nil { c.logger.Error("unable to parse clear input", "error", err) - logical.RespondError(w, http.StatusBadRequest, fmt.Errorf("failed to parse clear input: %w", err)) + status := http.StatusBadRequest + errNew := fmt.Errorf("failed to parse clear input: %w", err) + logical.AdjustErrorStatusCode(&status, errNew) + logical.RespondError(w, status, errNew) return } @@ -596,7 +602,9 @@ func (c *LeaseCache) HandleCacheClear(ctx context.Context) http.Handler { if err == errInvalidType { httpStatus = http.StatusBadRequest } - logical.RespondError(w, httpStatus, fmt.Errorf("failed to clear cache: %w", err)) + errNew := fmt.Errorf("failed to clear cache: %w", err) + logical.AdjustErrorStatusCode(&httpStatus, errNew) + logical.RespondError(w, httpStatus, errNew) return } diff --git a/command/server.go b/command/server.go index dedd009de6766..182d1310fd459 100644 --- a/command/server.go +++ b/command/server.go @@ -1341,6 +1341,9 @@ func (c *ServerCommand) Run(args []string) int { } } + // Sanitizing listener config from invalid custom headers + core.SanitizedCustomResponseHeader(config) + status, lns, clusterAddrs, errMsg := c.InitListeners(config, disableClustering, &infoKeys, &info) if status != 0 { @@ -1571,6 +1574,15 @@ func (c *ServerCommand) Run(args []string) int { c.UI.Error(err.Error()) } + // Reload Custom headers + if err = core.ReloadCustomHeadersListenerConf(); err != nil { + c.UI.Error(err.Error()) + } + // Sanitizing listener config from invalid patterns + core.SanitizedCustomResponseHeader(config) + core.Logger().Info("**** Listern config after sanitization", "Li", config.Listeners) + + select { case c.licenseReloadedCh <- err: default: diff --git a/http/cors.go b/http/cors.go index 74cfeeaef072e..685108ff4b03e 100644 --- a/http/cors.go +++ b/http/cors.go @@ -2,6 +2,7 @@ package http import ( "fmt" + "github.com/hashicorp/vault/internalshared/listenerutil" "net/http" "strings" @@ -37,15 +38,22 @@ func wrapCORSHandler(h http.Handler, core *vault.Core) http.Handler { h.ServeHTTP(w, req) return } - + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, err := core.GetCustomResponseHeaders(la) + if err != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } // Return a 403 if the origin is not allowed to make cross-origin requests. if !corsConf.IsValidOrigin(origin) { - respondError(w, http.StatusForbidden, fmt.Errorf("origin not allowed")) + respondError(w, http.StatusForbidden, fmt.Errorf("origin not allowed"), lc) return } if req.Method == http.MethodOptions && !strutil.StrListContains(allowedMethods, requestMethod) { - w.WriteHeader(http.StatusMethodNotAllowed) + status := http.StatusMethodNotAllowed + listenerutil.SetCustomResponseHeaders(lc, w, status) + w.WriteHeader(status) return } diff --git a/http/handler.go b/http/handler.go index 7d48f97aee8fe..0a0b52661733b 100644 --- a/http/handler.go +++ b/http/handler.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "github.com/hashicorp/vault/internalshared/listenerutil" "io" "io/fs" "io/ioutil" @@ -27,6 +28,7 @@ import ( "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/internalshared/configutil" "github.com/hashicorp/vault/sdk/helper/consts" + "github.com/hashicorp/vault/sdk/helper/headerutil" "github.com/hashicorp/vault/sdk/helper/jsonutil" "github.com/hashicorp/vault/sdk/helper/pathmanager" "github.com/hashicorp/vault/sdk/logical" @@ -166,8 +168,8 @@ func Handler(props *vault.HandlerProperties) http.Handler { } else { mux.Handle("/ui/", handleUIHeaders(core, handleUIStub())) } - mux.Handle("/ui", handleUIRedirect()) - mux.Handle("/", handleUIRedirect()) + mux.Handle("/ui", handleUIRedirect(core)) + mux.Handle("/", handleUIRedirect(core)) } @@ -245,9 +247,16 @@ func handleAuditNonLogical(core *vault.Core, h http.Handler) http.Handler { origBody := new(bytes.Buffer) reader := ioutil.NopCloser(io.TeeReader(r.Body, origBody)) r.Body = reader + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, err := core.GetCustomResponseHeaders(la) + if err != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } + req, _, status, err := buildLogicalRequestNoAuth(core.PerfStandby(), w, r) if err != nil || status != 0 { - respondError(w, status, err) + respondError(w, status, err, lc) return } if origBody != nil { @@ -258,7 +267,7 @@ func handleAuditNonLogical(core *vault.Core, h http.Handler) http.Handler { } err = core.AuditLogger().AuditRequest(r.Context(), input) if err != nil { - respondError(w, status, err) + respondError(w, status, err, lc) return } cw := newCopyResponseWriter(w) @@ -272,7 +281,7 @@ func handleAuditNonLogical(core *vault.Core, h http.Handler) http.Handler { input.Response = logical.HTTPResponseToLogicalResponse(httpResp) err = core.AuditLogger().AuditResponse(r.Context(), input) if err != nil { - respondError(w, status, err) + respondError(w, status, err, lc) } return }) @@ -334,11 +343,20 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr w.Header().Set("X-Vault-Hostname", hostname) } + // Setting listener address so that we could get the config from core + w.Header().Set("X-Vault-Listener-Add", props.ListenerConfig.Address) + + // Getting custom headers from listener's config + lc, err := core.GetCustomResponseHeaders(props.ListenerConfig.Address) + if err != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } + switch { case strings.HasPrefix(r.URL.Path, "/v1/"): newR, status := adjustRequest(core, r) if status != 0 { - respondError(w, status, nil) + respondError(w, status, nil, lc) cancelFunc() return } @@ -346,7 +364,7 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr case strings.HasPrefix(r.URL.Path, "/ui"), r.URL.Path == "/robots.txt", r.URL.Path == "/": default: - respondError(w, http.StatusNotFound, nil) + respondError(w, http.StatusNotFound, nil, lc) cancelFunc() return } @@ -376,7 +394,7 @@ func WrapForwardedForHandler(h http.Handler, l *configutil.Listener) http.Handle h.ServeHTTP(w, r) return } - respondError(w, http.StatusBadRequest, fmt.Errorf("missing x-forwarded-for header and configured to reject when not present")) + respondError(w, http.StatusBadRequest, fmt.Errorf("missing x-forwarded-for header and configured to reject when not present"), l.CustomResponseHeaders) return } @@ -389,7 +407,7 @@ func WrapForwardedForHandler(h http.Handler, l *configutil.Listener) http.Handle h.ServeHTTP(w, r) return } - respondError(w, http.StatusBadRequest, fmt.Errorf("error parsing client hostport: %w", err)) + respondError(w, http.StatusBadRequest, fmt.Errorf("error parsing client hostport: %w", err), l.CustomResponseHeaders) return } @@ -400,7 +418,7 @@ func WrapForwardedForHandler(h http.Handler, l *configutil.Listener) http.Handle h.ServeHTTP(w, r) return } - respondError(w, http.StatusBadRequest, fmt.Errorf("error parsing client address: %w", err)) + respondError(w, http.StatusBadRequest, fmt.Errorf("error parsing client address: %w", err), l.CustomResponseHeaders) return } @@ -418,7 +436,7 @@ func WrapForwardedForHandler(h http.Handler, l *configutil.Listener) http.Handle h.ServeHTTP(w, r) return } - respondError(w, http.StatusBadRequest, fmt.Errorf("client address not authorized for x-forwarded-for and configured to reject connection")) + respondError(w, http.StatusBadRequest, fmt.Errorf("client address not authorized for x-forwarded-for and configured to reject connection"), l.CustomResponseHeaders) return } @@ -446,7 +464,7 @@ func WrapForwardedForHandler(h http.Handler, l *configutil.Listener) http.Handle // authorized (or we've turned off explicit rejection) and we // should assume that what comes in should be properly // formatted. - respondError(w, http.StatusBadRequest, fmt.Errorf("malformed x-forwarded-for configuration or request, hops to skip (%d) would skip before earliest chain link (chain length %d)", hopSkips, len(headers))) + respondError(w, http.StatusBadRequest, fmt.Errorf("malformed x-forwarded-for configuration or request, hops to skip (%d) would skip before earliest chain link (chain length %d)", hopSkips, len(headers)), l.CustomResponseHeaders) return } @@ -475,9 +493,16 @@ func handleUIHeaders(core *vault.Core, h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { header := w.Header() + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, err := core.GetCustomResponseHeaders(la) + if err != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } + userHeaders, err := core.UIHeaders() if err != nil { - respondError(w, http.StatusInternalServerError, err) + respondError(w, http.StatusInternalServerError, err, lc) return } if userHeaders != nil { @@ -486,6 +511,12 @@ func handleUIHeaders(core *vault.Core, h http.Handler) http.Handler { header.Set(k, v) } } + + // just setting all default headers in the config file. + // This might overwrite some headers there already set + listenerutil.SetCustomResponseHeaders(lc, w, headerutil.DefaultStatus) + // TODO: Setting 200 series as well + h.ServeHTTP(w, req) }) } @@ -573,9 +604,18 @@ func handleUIStub() http.Handler { }) } -func handleUIRedirect() http.Handler { +func handleUIRedirect(core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - http.Redirect(w, req, "/ui/", 307) + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, err := core.GetCustomResponseHeaders(la) + if err != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } + + status := 307 + listenerutil.SetCustomResponseHeaders(lc, w, status) + http.Redirect(w, req, "/ui/", status) return }) } @@ -727,10 +767,16 @@ func forwardBasedOnHeaders(core *vault.Core, r *http.Request) (bool, error) { // falling back on the older behavior of redirecting the client func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, err := core.GetCustomResponseHeaders(la) + if err != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } // Note if the client requested forwarding shouldForward, err := forwardBasedOnHeaders(core, r) if err != nil { - respondError(w, http.StatusBadRequest, err) + respondError(w, http.StatusBadRequest, err, lc) return } @@ -738,7 +784,7 @@ func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handle if core.PerfStandby() && !shouldForward { ns, err := namespace.FromContext(r.Context()) if err != nil { - respondError(w, http.StatusBadRequest, err) + respondError(w, http.StatusBadRequest, err, lc) return } path := ns.TrimmedPath(r.URL.Path[len("/v1/"):]) @@ -766,7 +812,7 @@ func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handle return } // Some internal error occurred - respondError(w, http.StatusInternalServerError, err) + respondError(w, http.StatusInternalServerError, err, lc) return } if isLeader { @@ -775,7 +821,7 @@ func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handle return } if leaderAddr == "" { - respondError(w, http.StatusInternalServerError, fmt.Errorf("local node not active but active cluster node not found")) + respondError(w, http.StatusInternalServerError, fmt.Errorf("local node not active but active cluster node not found"), lc) return } @@ -785,26 +831,33 @@ func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handle } func forwardRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) { + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, err := core.GetCustomResponseHeaders(la) + if err != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } + if r.Header.Get(vault.IntNoForwardingHeaderName) != "" { - respondStandby(core, w, r.URL) + respondStandby(core, w, r) return } if r.Header.Get(NoRequestForwardingHeaderName) != "" { // Forwarding explicitly disabled, fall back to previous behavior core.Logger().Debug("handleRequestForwarding: forwarding disabled by client request") - respondStandby(core, w, r.URL) + respondStandby(core, w, r) return } ns, err := namespace.FromContext(r.Context()) if err != nil { - respondError(w, http.StatusBadRequest, err) + respondError(w, http.StatusBadRequest, err, lc) return } path := ns.TrimmedPath(r.URL.Path[len("/v1/"):]) if alwaysRedirectPaths.HasPath(path) { - respondStandby(core, w, r.URL) + respondStandby(core, w, r) return } @@ -820,7 +873,7 @@ func forwardRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) { } // Fall back to redirection - respondStandby(core, w, r.URL) + respondStandby(core, w, r) return } @@ -830,6 +883,7 @@ func forwardRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) { } } + listenerutil.SetCustomResponseHeaders(lc, w, statusCode) w.WriteHeader(statusCode) w.Write(retBytes) } @@ -845,7 +899,7 @@ func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *l resp.AddWarning("Timeout hit while waiting for local replicated cluster to apply primary's write; this client may encounter stale reads of values written during this operation.") } if errwrap.Contains(err, consts.ErrStandby.Error()) { - respondStandby(core, w, rawReq.URL) + respondStandby(core, w, rawReq) return resp, false, false } if err != nil && errwrap.Contains(err, logical.ErrPerfStandbyPleaseForward.Error()) { @@ -886,7 +940,13 @@ func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *l return nil, true, false } - if respondErrorCommon(w, r, resp, err) { + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, err := core.GetCustomResponseHeaders(la) + if err != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } + if respondErrorCommon(w, r, resp, err, lc) { return resp, false, false } @@ -894,32 +954,40 @@ func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *l } // respondStandby is used to trigger a redirect in the case that this Vault is currently a hot standby -func respondStandby(core *vault.Core, w http.ResponseWriter, reqURL *url.URL) { +func respondStandby(core *vault.Core, w http.ResponseWriter, req *http.Request) { + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, err := core.GetCustomResponseHeaders(la) + if err != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } + reqURL := req.URL // Request the leader address _, redirectAddr, _, err := core.Leader() if err != nil { if err == vault.ErrHANotEnabled { // Standalone node, serve 503 err = errors.New("node is not active") - respondError(w, http.StatusServiceUnavailable, err) + // TODO: set headers before all these responseError + respondError(w, http.StatusServiceUnavailable, err, lc) return } - respondError(w, http.StatusInternalServerError, err) + respondError(w, http.StatusInternalServerError, err, lc) return } // If there is no leader, generate a 503 error if redirectAddr == "" { err = errors.New("no active Vault instance found") - respondError(w, http.StatusServiceUnavailable, err) + respondError(w, http.StatusServiceUnavailable, err, lc) return } // Parse the redirect location redirectURL, err := url.Parse(redirectAddr) if err != nil { - respondError(w, http.StatusInternalServerError, err) + respondError(w, http.StatusInternalServerError, err, lc) return } @@ -940,6 +1008,7 @@ func respondStandby(core *vault.Core, w http.ResponseWriter, reqURL *url.URL) { // because we don't actually know if its permanent and // the request method should be preserved. w.Header().Set("Location", finalURL.String()) + listenerutil.SetCustomResponseHeaders(lc, w, 307) w.WriteHeader(307) } @@ -1104,27 +1173,34 @@ func isForm(head []byte, contentType string) bool { return true } -func respondError(w http.ResponseWriter, status int, err error) { +func respondError(w http.ResponseWriter, status int, err error, h map[string]map[string]string) { + logical.AdjustErrorStatusCode(&status, err) + listenerutil.SetCustomResponseHeaders(h, w, status) logical.RespondError(w, status, err) } -func respondErrorCommon(w http.ResponseWriter, req *logical.Request, resp *logical.Response, err error) bool { +func respondErrorCommon(w http.ResponseWriter, req *logical.Request, resp *logical.Response, err error, h map[string]map[string]string) bool { statusCode, newErr := logical.RespondErrorCommon(req, resp, err) if newErr == nil && statusCode == 0 { return false } - respondError(w, statusCode, newErr) + respondError(w, statusCode, newErr, h) return true } -func respondOk(w http.ResponseWriter, body interface{}) { +func respondOk(w http.ResponseWriter, body interface{}, h map[string]map[string]string) { w.Header().Set("Content-Type", "application/json") + var status int if body == nil { - w.WriteHeader(http.StatusNoContent) + status = http.StatusNoContent + listenerutil.SetCustomResponseHeaders(h, w, status) + w.WriteHeader(status) } else { - w.WriteHeader(http.StatusOK) + status = http.StatusOK + listenerutil.SetCustomResponseHeaders(h, w, status) + w.WriteHeader(status) enc := json.NewEncoder(w) enc.Encode(body) } diff --git a/http/help.go b/http/help.go index 45099bd7b67f5..c57675dcf84d2 100644 --- a/http/help.go +++ b/http/help.go @@ -26,9 +26,15 @@ func wrapHelpHandler(h http.Handler, core *vault.Core) http.Handler { } func handleHelp(core *vault.Core, w http.ResponseWriter, r *http.Request) { + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, err := core.GetCustomResponseHeaders(la) + if err != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } ns, err := namespace.FromContext(r.Context()) if err != nil { - respondError(w, http.StatusBadRequest, nil) + respondError(w, http.StatusBadRequest, nil, lc) return } path := ns.TrimmedPath(r.URL.Path[len("/v1/"):]) @@ -42,9 +48,9 @@ func handleHelp(core *vault.Core, w http.ResponseWriter, r *http.Request) { resp, err := core.HandleRequest(r.Context(), req) if err != nil { - respondErrorCommon(w, req, resp, err) + respondErrorCommon(w, req, resp, err, lc) return } - respondOk(w, resp.Data) + respondOk(w, resp.Data, lc) } diff --git a/http/logical.go b/http/logical.go index dd9abce34dfdb..2f2912ab904f2 100644 --- a/http/logical.go +++ b/http/logical.go @@ -5,6 +5,7 @@ import ( "encoding/base64" "encoding/json" "fmt" + "github.com/hashicorp/vault/internalshared/listenerutil" "io" "net" "net/http" @@ -268,17 +269,17 @@ func handleLogicalRecovery(raw *vault.RawBackend, token *atomic.String) http.Han return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { req, _, statusCode, err := buildLogicalRequestNoAuth(false, w, r) if err != nil || statusCode != 0 { - respondError(w, statusCode, err) + respondError(w, statusCode, err, nil) return } reqToken := r.Header.Get(consts.AuthHeaderName) if reqToken == "" || token.Load() == "" || reqToken != token.Load() { - respondError(w, http.StatusForbidden, nil) + respondError(w, http.StatusForbidden, nil, nil) return } resp, err := raw.HandleRequest(r.Context(), req) - if respondErrorCommon(w, req, resp, err) { + if respondErrorCommon(w, req, resp, err, nil) { return } @@ -287,7 +288,7 @@ func handleLogicalRecovery(raw *vault.RawBackend, token *atomic.String) http.Han httpResp = logical.LogicalResponseToHTTPResponse(resp) httpResp.RequestID = req.ID } - respondOk(w, httpResp) + respondOk(w, httpResp, nil) }) } @@ -296,9 +297,15 @@ func handleLogicalRecovery(raw *vault.RawBackend, token *atomic.String) http.Han // toggles. Refer to usage on functions for possible behaviors. func handleLogicalInternal(core *vault.Core, injectDataIntoTopLevel bool, noForward bool) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, err := core.GetCustomResponseHeaders(la) + if err != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } req, origBody, statusCode, err := buildLogicalRequest(core, w, r) if err != nil || statusCode != 0 { - respondError(w, statusCode, err) + respondError(w, statusCode, err, lc) return } @@ -310,7 +317,7 @@ func handleLogicalInternal(core *vault.Core, injectDataIntoTopLevel bool, noForw resp, ok, needsForward := request(core, w, r, req) switch { case needsForward && noForward: - respondError(w, http.StatusBadRequest, vault.ErrCannotForwardLocalOnly) + respondError(w, http.StatusBadRequest, vault.ErrCannotForwardLocalOnly, lc) return case needsForward && !noForward: if origBody != nil { @@ -341,17 +348,26 @@ func respondLogical(core *vault.Core, w http.ResponseWriter, r *http.Request, re return } + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, err := core.GetCustomResponseHeaders(la) + if err != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } if resp != nil { if resp.Redirect != "" { // If we have a redirect, redirect! We use a 307 code // because we don't actually know if its permanent. - http.Redirect(w, r, resp.Redirect, 307) + // TODO: need to set custom headers before calling the redirect + status := 307 + listenerutil.SetCustomResponseHeaders(lc, w, status) + http.Redirect(w, r, resp.Redirect, status) return } // Check if this is a raw response if _, ok := resp.Data[logical.HTTPStatusCode]; ok { - respondRaw(w, r, resp) + respondRaw(w, r, resp, lc) return } @@ -384,17 +400,19 @@ func respondLogical(core *vault.Core, w http.ResponseWriter, r *http.Request, re adjustResponse(core, w, req) // Respond - respondOk(w, ret) + respondOk(w, ret, lc) return } // respondRaw is used when the response is using HTTPContentType and HTTPRawBody // to change the default response handling. This is only used for specific things like // returning the CRL information on the PKI backends. -func respondRaw(w http.ResponseWriter, r *http.Request, resp *logical.Response) { +func respondRaw(w http.ResponseWriter, r *http.Request, resp *logical.Response, h map[string]map[string]string) { retErr := func(w http.ResponseWriter, err string) { w.Header().Set("X-Vault-Raw-Error", err) - w.WriteHeader(http.StatusInternalServerError) + code := http.StatusInternalServerError + listenerutil.SetCustomResponseHeaders(h, w, code) + w.WriteHeader(code) w.Write(nil) } @@ -483,6 +501,7 @@ WRITE_RESPONSE: w.Header().Set("Cache-Control", cacheControl) } + listenerutil.SetCustomResponseHeaders(h, w, status) w.WriteHeader(status) w.Write(body) } diff --git a/http/sys_feature_flags.go b/http/sys_feature_flags.go index 11ece32795b77..468e76e159d32 100644 --- a/http/sys_feature_flags.go +++ b/http/sys_feature_flags.go @@ -2,6 +2,7 @@ package http import ( "encoding/json" + "github.com/hashicorp/vault/internalshared/listenerutil" "net/http" "os" @@ -27,11 +28,17 @@ func featureFlagIsSet(name string) bool { func handleSysInternalFeatureFlags(core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, err := core.GetCustomResponseHeaders(la) + if err != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } switch r.Method { case "GET": break default: - respondError(w, http.StatusMethodNotAllowed, nil) + respondError(w, http.StatusMethodNotAllowed, nil, lc) } response := &FeatureFlagsResponse{} @@ -43,7 +50,9 @@ func handleSysInternalFeatureFlags(core *vault.Core) http.Handler { } w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) + status := http.StatusOK + listenerutil.SetCustomResponseHeaders(lc, w, status) + w.WriteHeader(status) // Generate the response enc := json.NewEncoder(w) diff --git a/http/sys_generate_root.go b/http/sys_generate_root.go index 4ac3015077447..8cea12e474f1d 100644 --- a/http/sys_generate_root.go +++ b/http/sys_generate_root.go @@ -14,6 +14,12 @@ import ( func handleSysGenerateRootAttempt(core *vault.Core, generateStrategy vault.GenerateRootStrategy) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, err := core.GetCustomResponseHeaders(la) + if err != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } switch r.Method { case "GET": handleSysGenerateRootAttemptGet(core, w, r, "") @@ -22,7 +28,7 @@ func handleSysGenerateRootAttempt(core *vault.Core, generateStrategy vault.Gener case "DELETE": handleSysGenerateRootAttemptDelete(core, w, r) default: - respondError(w, http.StatusMethodNotAllowed, nil) + respondError(w, http.StatusMethodNotAllowed, nil, lc) } }) } @@ -31,14 +37,21 @@ func handleSysGenerateRootAttemptGet(core *vault.Core, w http.ResponseWriter, r ctx, cancel := core.GetContext() defer cancel() + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, err := core.GetCustomResponseHeaders(la) + if err != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } + // Get the current seal configuration barrierConfig, err := core.SealAccess().BarrierConfig(ctx) if err != nil { - respondError(w, http.StatusInternalServerError, err) + respondError(w, http.StatusInternalServerError, err, lc) return } if barrierConfig == nil { - respondError(w, http.StatusBadRequest, fmt.Errorf("server is not yet initialized")) + respondError(w, http.StatusBadRequest, fmt.Errorf("server is not yet initialized"), lc) return } @@ -46,7 +59,7 @@ func handleSysGenerateRootAttemptGet(core *vault.Core, w http.ResponseWriter, r if core.SealAccess().RecoveryKeySupported() { sealConfig, err = core.SealAccess().RecoveryConfig(ctx) if err != nil { - respondError(w, http.StatusInternalServerError, err) + respondError(w, http.StatusInternalServerError, err, lc) return } } @@ -54,14 +67,14 @@ func handleSysGenerateRootAttemptGet(core *vault.Core, w http.ResponseWriter, r // Get the generation configuration generationConfig, err := core.GenerateRootConfiguration() if err != nil { - respondError(w, http.StatusInternalServerError, err) + respondError(w, http.StatusInternalServerError, err, lc) return } // Get the progress progress, err := core.GenerateRootProgress() if err != nil { - respondError(w, http.StatusInternalServerError, err) + respondError(w, http.StatusInternalServerError, err, lc) return } @@ -80,14 +93,21 @@ func handleSysGenerateRootAttemptGet(core *vault.Core, w http.ResponseWriter, r status.PGPFingerprint = generationConfig.PGPFingerprint } - respondOk(w, status) + respondOk(w, status, lc) } func handleSysGenerateRootAttemptPut(core *vault.Core, w http.ResponseWriter, r *http.Request, generateStrategy vault.GenerateRootStrategy) { + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, errNew := core.GetCustomResponseHeaders(la) + if errNew != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } + // Parse the request var req GenerateRootInitRequest if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil && err != io.EOF { - respondError(w, http.StatusBadRequest, err) + respondError(w, http.StatusBadRequest, err, lc) return } @@ -100,14 +120,14 @@ func handleSysGenerateRootAttemptPut(core *vault.Core, w http.ResponseWriter, r genned = true req.OTP, err = base62.Random(vault.TokenLength + 2) if err != nil { - respondError(w, http.StatusInternalServerError, err) + respondError(w, http.StatusInternalServerError, err, lc) return } } // Attemptialize the generation if err := core.GenerateRootInit(req.OTP, req.PGPKey, generateStrategy); err != nil { - respondError(w, http.StatusBadRequest, err) + respondError(w, http.StatusBadRequest, err, lc) return } @@ -120,26 +140,40 @@ func handleSysGenerateRootAttemptPut(core *vault.Core, w http.ResponseWriter, r } func handleSysGenerateRootAttemptDelete(core *vault.Core, w http.ResponseWriter, r *http.Request) { - err := core.GenerateRootCancel() + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, err := core.GetCustomResponseHeaders(la) if err != nil { - respondError(w, http.StatusInternalServerError, err) + core.Logger().Debug("failed to get custom headers from listener config") + } + + errNew := core.GenerateRootCancel() + if errNew != nil { + respondError(w, http.StatusInternalServerError, errNew, lc) return } - respondOk(w, nil) + respondOk(w, nil, lc) } func handleSysGenerateRootUpdate(core *vault.Core, generateStrategy vault.GenerateRootStrategy) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, errNew := core.GetCustomResponseHeaders(la) + if errNew != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } // Parse the request var req GenerateRootUpdateRequest if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil { - respondError(w, http.StatusBadRequest, err) + respondError(w, http.StatusBadRequest, err, lc) return } if req.Key == "" { respondError( w, http.StatusBadRequest, - errors.New("'key' must be specified in request body as JSON")) + errors.New("'key' must be specified in request body as JSON"), + lc) return } @@ -154,7 +188,8 @@ func handleSysGenerateRootUpdate(core *vault.Core, generateStrategy vault.Genera if err != nil { respondError( w, http.StatusBadRequest, - errors.New("'key' must be a valid hex or base64 string")) + errors.New("'key' must be a valid hex or base64 string"), + lc) return } } @@ -165,7 +200,7 @@ func handleSysGenerateRootUpdate(core *vault.Core, generateStrategy vault.Genera // Use the key to make progress on root generation result, err := core.GenerateRootUpdate(ctx, key, req.Nonce, generateStrategy) if err != nil { - respondError(w, http.StatusBadRequest, err) + respondError(w, http.StatusBadRequest, err, lc) return } @@ -183,7 +218,7 @@ func handleSysGenerateRootUpdate(core *vault.Core, generateStrategy vault.Genera resp.EncodedRootToken = result.EncodedToken } - respondOk(w, resp) + respondOk(w, resp, lc) }) } diff --git a/http/sys_health.go b/http/sys_health.go index fcaf4e1590999..9e310d17e7b0b 100644 --- a/http/sys_health.go +++ b/http/sys_health.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "github.com/hashicorp/vault/internalshared/listenerutil" "net/http" "strconv" "time" @@ -16,13 +17,19 @@ import ( func handleSysHealth(core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, err := core.GetCustomResponseHeaders(la) + if err != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } switch r.Method { case "GET": handleSysHealthGet(core, w, r) case "HEAD": handleSysHealthHead(core, w, r) default: - respondError(w, http.StatusMethodNotAllowed, nil) + respondError(w, http.StatusMethodNotAllowed, nil, lc) } }) } @@ -41,19 +48,26 @@ func fetchStatusCode(r *http.Request, field string) (int, bool, bool) { } func handleSysHealthGet(core *vault.Core, w http.ResponseWriter, r *http.Request) { + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, err := core.GetCustomResponseHeaders(la) + if err != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } code, body, err := getSysHealth(core, r) if err != nil { core.Logger().Error("error checking health", "error", err) - respondError(w, code, nil) + respondError(w, code, nil, lc) return } if body == nil { - respondError(w, code, nil) + respondError(w, code, nil, lc) return } w.Header().Set("Content-Type", "application/json") + listenerutil.SetCustomResponseHeaders(lc, w, code) w.WriteHeader(code) // Generate the response @@ -67,6 +81,13 @@ func handleSysHealthHead(core *vault.Core, w http.ResponseWriter, r *http.Reques if body != nil { w.Header().Set("Content-Type", "application/json") } + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, err := core.GetCustomResponseHeaders(la) + if err != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } + listenerutil.SetCustomResponseHeaders(lc, w, code) w.WriteHeader(code) } diff --git a/http/sys_init.go b/http/sys_init.go index b21e5363ea020..4f9436eb4c82a 100644 --- a/http/sys_init.go +++ b/http/sys_init.go @@ -11,36 +11,55 @@ import ( func handleSysInit(core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, err := core.GetCustomResponseHeaders(la) + if err != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } switch r.Method { case "GET": handleSysInitGet(core, w, r) case "PUT", "POST": handleSysInitPut(core, w, r) default: - respondError(w, http.StatusMethodNotAllowed, nil) + respondError(w, http.StatusMethodNotAllowed, nil, lc) } }) } func handleSysInitGet(core *vault.Core, w http.ResponseWriter, r *http.Request) { + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, err := core.GetCustomResponseHeaders(la) + if err != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } init, err := core.Initialized(context.Background()) if err != nil { - respondError(w, http.StatusInternalServerError, err) + respondError(w, http.StatusInternalServerError, err, lc) return } respondOk(w, &InitStatusResponse{ Initialized: init, - }) + }, lc) } func handleSysInitPut(core *vault.Core, w http.ResponseWriter, r *http.Request) { + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, err := core.GetCustomResponseHeaders(la) + if err != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } + ctx := context.Background() // Parse the request var req InitRequest if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil { - respondError(w, http.StatusBadRequest, err) + respondError(w, http.StatusBadRequest, err, lc) return } @@ -67,7 +86,7 @@ func handleSysInitPut(core *vault.Core, w http.ResponseWriter, r *http.Request) result, initErr := core.Initialize(ctx, initParams) if initErr != nil { if vault.IsFatalError(initErr) { - respondError(w, http.StatusBadRequest, initErr) + respondError(w, http.StatusBadRequest, initErr, lc) return } else { // Add a warnings field? The error will be logged in the vault log @@ -99,11 +118,11 @@ func handleSysInitPut(core *vault.Core, w http.ResponseWriter, r *http.Request) } if err := core.UnsealWithStoredKeys(ctx); err != nil { - respondError(w, http.StatusInternalServerError, err) + respondError(w, http.StatusInternalServerError, err, lc) return } - respondOk(w, resp) + respondOk(w, resp, lc) } type InitRequest struct { diff --git a/http/sys_leader.go b/http/sys_leader.go index 8c2ce21e5001d..4d1754946530b 100644 --- a/http/sys_leader.go +++ b/http/sys_leader.go @@ -10,20 +10,32 @@ import ( // or becomes the leader. func handleSysLeader(core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, err := core.GetCustomResponseHeaders(la) + if err != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } switch r.Method { case "GET": handleSysLeaderGet(core, w, r) default: - respondError(w, http.StatusMethodNotAllowed, nil) + respondError(w, http.StatusMethodNotAllowed, nil, lc) } }) } func handleSysLeaderGet(core *vault.Core, w http.ResponseWriter, r *http.Request) { + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, err := core.GetCustomResponseHeaders(la) + if err != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } resp, err := core.GetLeaderStatus() if err != nil { - respondError(w, http.StatusInternalServerError, err) + respondError(w, http.StatusInternalServerError, err, lc) return } - respondOk(w, resp) + respondOk(w, resp, lc) } diff --git a/http/sys_metrics.go b/http/sys_metrics.go index 0e58be3ea262d..4e2ffc9bff18d 100644 --- a/http/sys_metrics.go +++ b/http/sys_metrics.go @@ -2,6 +2,7 @@ package http import ( "fmt" + "github.com/hashicorp/vault/internalshared/listenerutil" "net/http" "github.com/hashicorp/vault/helper/metricsutil" @@ -11,18 +12,24 @@ import ( func handleMetricsUnauthenticated(core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, err := core.GetCustomResponseHeaders(la) + if err != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } req := &logical.Request{Headers: r.Header} switch r.Method { case "GET": default: - respondError(w, http.StatusMethodNotAllowed, nil) + respondError(w, http.StatusMethodNotAllowed, nil, lc) return } // Parse form if err := r.ParseForm(); err != nil { - respondError(w, http.StatusBadRequest, err) + respondError(w, http.StatusBadRequest, err, lc) return } @@ -35,7 +42,9 @@ func handleMetricsUnauthenticated(core *vault.Core) http.Handler { resp := core.MetricsHelper().ResponseForFormat(format) // Manually extract the logical response and send back the information - w.WriteHeader(resp.Data[logical.HTTPStatusCode].(int)) + status := resp.Data[logical.HTTPStatusCode].(int) + listenerutil.SetCustomResponseHeaders(lc, w, status) + w.WriteHeader(status) w.Header().Set("Content-Type", resp.Data[logical.HTTPContentType].(string)) switch v := resp.Data[logical.HTTPRawBody].(type) { case string: @@ -43,7 +52,7 @@ func handleMetricsUnauthenticated(core *vault.Core) http.Handler { case []byte: w.Write(v) default: - respondError(w, http.StatusInternalServerError, fmt.Errorf("wrong response returned")) + respondError(w, http.StatusInternalServerError, fmt.Errorf("wrong response returned"), lc) } }) } diff --git a/http/sys_raft.go b/http/sys_raft.go index 5db1a80fb78f6..a4590398660b8 100644 --- a/http/sys_raft.go +++ b/http/sys_raft.go @@ -15,19 +15,25 @@ import ( func handleSysRaftBootstrap(core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, err := core.GetCustomResponseHeaders(la) + if err != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } switch r.Method { case "POST", "PUT": if core.Sealed() { - respondError(w, http.StatusBadRequest, errors.New("node must be unsealed to bootstrap")) + respondError(w, http.StatusBadRequest, errors.New("node must be unsealed to bootstrap"), lc) } if err := core.RaftBootstrap(context.Background(), false); err != nil { - respondError(w, http.StatusInternalServerError, err) + respondError(w, http.StatusInternalServerError, err, lc) return } default: - respondError(w, http.StatusBadRequest, nil) + respondError(w, http.StatusBadRequest, nil, lc) } }) } @@ -38,21 +44,33 @@ func handleSysRaftJoin(core *vault.Core) http.Handler { case "POST", "PUT": handleSysRaftJoinPost(core, w, r) default: - respondError(w, http.StatusMethodNotAllowed, nil) + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, err := core.GetCustomResponseHeaders(la) + if err != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } + respondError(w, http.StatusMethodNotAllowed, nil, lc) } }) } func handleSysRaftJoinPost(core *vault.Core, w http.ResponseWriter, r *http.Request) { + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, errNew := core.GetCustomResponseHeaders(la) + if errNew != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } // Parse the request var req JoinRequest if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil && err != io.EOF { - respondError(w, http.StatusBadRequest, err) + respondError(w, http.StatusBadRequest, err, lc) return } if req.NonVoter && !nonVotersAllowed { - respondError(w, http.StatusBadRequest, errors.New("non-voting nodes not allowed")) + respondError(w, http.StatusBadRequest, errors.New("non-voting nodes not allowed"), lc) return } @@ -61,14 +79,14 @@ func handleSysRaftJoinPost(core *vault.Core, w http.ResponseWriter, r *http.Requ if len(req.LeaderCACert) != 0 || len(req.LeaderClientCert) != 0 || len(req.LeaderClientKey) != 0 { tlsConfig, err = tlsutil.ClientTLSConfig([]byte(req.LeaderCACert), []byte(req.LeaderClientCert), []byte(req.LeaderClientKey)) if err != nil { - respondError(w, http.StatusBadRequest, err) + respondError(w, http.StatusBadRequest, err, lc) return } tlsConfig.ServerName = req.LeaderTLSServerName } if req.AutoJoinScheme != "" && (req.AutoJoinScheme != "http" && req.AutoJoinScheme != "https") { - respondError(w, http.StatusBadRequest, fmt.Errorf("invalid scheme '%s'; must either be http or https", req.AutoJoinScheme)) + respondError(w, http.StatusBadRequest, fmt.Errorf("invalid scheme '%s'; must either be http or https", req.AutoJoinScheme), lc) return } @@ -85,14 +103,14 @@ func handleSysRaftJoinPost(core *vault.Core, w http.ResponseWriter, r *http.Requ joined, err := core.JoinRaftCluster(context.Background(), leaderInfos, req.NonVoter) if err != nil { - respondError(w, http.StatusInternalServerError, err) + respondError(w, http.StatusInternalServerError, err, lc) return } resp := JoinResponse{ Joined: joined, } - respondOk(w, resp) + respondOk(w, resp, lc) } type JoinResponse struct { diff --git a/http/sys_rekey.go b/http/sys_rekey.go index d1cec653a6283..5679a5923528e 100644 --- a/http/sys_rekey.go +++ b/http/sys_rekey.go @@ -17,14 +17,20 @@ func handleSysRekeyInit(core *vault.Core, recovery bool) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { standby, _ := core.Standby() if standby { - respondStandby(core, w, r.URL) + respondStandby(core, w, r) return } - + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, err := core.GetCustomResponseHeaders(la) + if err != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } repState := core.ReplicationState() if repState.HasState(consts.ReplicationPerformanceSecondary) { respondError(w, http.StatusBadRequest, - fmt.Errorf("rekeying can only be performed on the primary cluster when replication is activated")) + fmt.Errorf("rekeying can only be performed on the primary cluster when replication is activated"), + lc) return } @@ -33,7 +39,7 @@ func handleSysRekeyInit(core *vault.Core, recovery bool) http.Handler { switch { case recovery && !core.SealAccess().RecoveryKeySupported(): - respondError(w, http.StatusBadRequest, fmt.Errorf("recovery rekeying not supported")) + respondError(w, http.StatusBadRequest, fmt.Errorf("recovery rekeying not supported"), lc) case r.Method == "GET": handleSysRekeyInitGet(ctx, core, recovery, w, r) case r.Method == "POST" || r.Method == "PUT": @@ -41,32 +47,38 @@ func handleSysRekeyInit(core *vault.Core, recovery bool) http.Handler { case r.Method == "DELETE": handleSysRekeyInitDelete(ctx, core, recovery, w, r) default: - respondError(w, http.StatusMethodNotAllowed, nil) + respondError(w, http.StatusMethodNotAllowed, nil, lc) } }) } func handleSysRekeyInitGet(ctx context.Context, core *vault.Core, recovery bool, w http.ResponseWriter, r *http.Request) { + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, errNew := core.GetCustomResponseHeaders(la) + if errNew != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } barrierConfig, barrierConfErr := core.SealAccess().BarrierConfig(ctx) if barrierConfErr != nil { - respondError(w, http.StatusInternalServerError, barrierConfErr) + respondError(w, http.StatusInternalServerError, barrierConfErr, lc) return } if barrierConfig == nil { - respondError(w, http.StatusBadRequest, fmt.Errorf("server is not yet initialized")) + respondError(w, http.StatusBadRequest, fmt.Errorf("server is not yet initialized"), lc) return } // Get the rekey configuration rekeyConf, err := core.RekeyConfig(recovery) if err != nil { - respondError(w, err.Code(), err) + respondError(w, err.Code(), err, lc) return } sealThreshold, err := core.RekeyThreshold(ctx, recovery) if err != nil { - respondError(w, err.Code(), err) + respondError(w, err.Code(), err, lc) return } @@ -81,7 +93,7 @@ func handleSysRekeyInitGet(ctx context.Context, core *vault.Core, recovery bool, // Get the progress started, progress, err := core.RekeyProgress(recovery, false) if err != nil { - respondError(w, err.Code(), err) + respondError(w, err.Code(), err, lc) return } @@ -95,31 +107,37 @@ func handleSysRekeyInitGet(ctx context.Context, core *vault.Core, recovery bool, if rekeyConf.PGPKeys != nil && len(rekeyConf.PGPKeys) != 0 { pgpFingerprints, err := pgpkeys.GetFingerprints(rekeyConf.PGPKeys, nil) if err != nil { - respondError(w, http.StatusInternalServerError, err) + respondError(w, http.StatusInternalServerError, err, lc) return } status.PGPFingerprints = pgpFingerprints status.Backup = rekeyConf.Backup } } - respondOk(w, status) + respondOk(w, status, lc) } func handleSysRekeyInitPut(ctx context.Context, core *vault.Core, recovery bool, w http.ResponseWriter, r *http.Request) { + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, errNew := core.GetCustomResponseHeaders(la) + if errNew != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } // Parse the request var req RekeyRequest if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil { - respondError(w, http.StatusBadRequest, err) + respondError(w, http.StatusBadRequest, err, lc) return } if req.Backup && len(req.PGPKeys) == 0 { - respondError(w, http.StatusBadRequest, fmt.Errorf("cannot request a backup of the new keys without providing PGP keys for encryption")) + respondError(w, http.StatusBadRequest, fmt.Errorf("cannot request a backup of the new keys without providing PGP keys for encryption"), lc) return } if len(req.PGPKeys) > 0 && len(req.PGPKeys) != req.SecretShares { - respondError(w, http.StatusBadRequest, fmt.Errorf("incorrect number of PGP keys for rekey")) + respondError(w, http.StatusBadRequest, fmt.Errorf("incorrect number of PGP keys for rekey"), lc) return } @@ -133,7 +151,7 @@ func handleSysRekeyInitPut(ctx context.Context, core *vault.Core, recovery bool, VerificationRequired: req.RequireVerification, }, recovery) if err != nil { - respondError(w, err.Code(), err) + respondError(w, err.Code(), err, lc) return } @@ -141,31 +159,44 @@ func handleSysRekeyInitPut(ctx context.Context, core *vault.Core, recovery bool, } func handleSysRekeyInitDelete(ctx context.Context, core *vault.Core, recovery bool, w http.ResponseWriter, r *http.Request) { + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, err := core.GetCustomResponseHeaders(la) + if err != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } if err := core.RekeyCancel(recovery); err != nil { - respondError(w, err.Code(), err) + respondError(w, err.Code(), err, lc) return } - respondOk(w, nil) + respondOk(w, nil, lc) } func handleSysRekeyUpdate(core *vault.Core, recovery bool) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, errNew := core.GetCustomResponseHeaders(la) + if errNew != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } standby, _ := core.Standby() if standby { - respondStandby(core, w, r.URL) + respondStandby(core, w, r) return } // Parse the request var req RekeyUpdateRequest if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil { - respondError(w, http.StatusBadRequest, err) + respondError(w, http.StatusBadRequest, err, lc) return } if req.Key == "" { respondError( w, http.StatusBadRequest, - errors.New("'key' must be specified in request body as JSON")) + errors.New("'key' must be specified in request body as JSON"), + lc) return } @@ -180,7 +211,8 @@ func handleSysRekeyUpdate(core *vault.Core, recovery bool) http.Handler { if err != nil { respondError( w, http.StatusBadRequest, - errors.New("'key' must be a valid hex or base64 string")) + errors.New("'key' must be a valid hex or base64 string"), + lc) return } } @@ -191,7 +223,7 @@ func handleSysRekeyUpdate(core *vault.Core, recovery bool) http.Handler { // Use the key to make progress on rekey result, rekeyErr := core.RekeyUpdate(ctx, key, req.Nonce, recovery) if rekeyErr != nil { - respondError(w, rekeyErr.Code(), rekeyErr) + respondError(w, rekeyErr.Code(), rekeyErr, lc) return } @@ -214,7 +246,7 @@ func handleSysRekeyUpdate(core *vault.Core, recovery bool) http.Handler { } resp.Keys = keys resp.KeysB64 = keysB64 - respondOk(w, resp) + respondOk(w, resp, lc) } else { handleSysRekeyInitGet(ctx, core, recovery, w, r) } @@ -223,16 +255,23 @@ func handleSysRekeyUpdate(core *vault.Core, recovery bool) http.Handler { func handleSysRekeyVerify(core *vault.Core, recovery bool) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, err := core.GetCustomResponseHeaders(la) + if err != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } standby, _ := core.Standby() if standby { - respondStandby(core, w, r.URL) + respondStandby(core, w, r) return } repState := core.ReplicationState() if repState.HasState(consts.ReplicationPerformanceSecondary) { respondError(w, http.StatusBadRequest, - fmt.Errorf("rekeying can only be performed on the primary cluster when replication is activated")) + fmt.Errorf("rekeying can only be performed on the primary cluster when replication is activated"), + lc) return } @@ -241,7 +280,7 @@ func handleSysRekeyVerify(core *vault.Core, recovery bool) http.Handler { switch { case recovery && !core.SealAccess().RecoveryKeySupported(): - respondError(w, http.StatusBadRequest, fmt.Errorf("recovery rekeying not supported")) + respondError(w, http.StatusBadRequest, fmt.Errorf("recovery rekeying not supported"), lc) case r.Method == "GET": handleSysRekeyVerifyGet(ctx, core, recovery, w, r) case r.Method == "POST" || r.Method == "PUT": @@ -249,37 +288,43 @@ func handleSysRekeyVerify(core *vault.Core, recovery bool) http.Handler { case r.Method == "DELETE": handleSysRekeyVerifyDelete(ctx, core, recovery, w, r) default: - respondError(w, http.StatusMethodNotAllowed, nil) + respondError(w, http.StatusMethodNotAllowed, nil, lc) } }) } func handleSysRekeyVerifyGet(ctx context.Context, core *vault.Core, recovery bool, w http.ResponseWriter, r *http.Request) { + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, errNew := core.GetCustomResponseHeaders(la) + if errNew != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } barrierConfig, barrierConfErr := core.SealAccess().BarrierConfig(ctx) if barrierConfErr != nil { - respondError(w, http.StatusInternalServerError, barrierConfErr) + respondError(w, http.StatusInternalServerError, barrierConfErr, lc) return } if barrierConfig == nil { - respondError(w, http.StatusBadRequest, fmt.Errorf("server is not yet initialized")) + respondError(w, http.StatusBadRequest, fmt.Errorf("server is not yet initialized"), lc) return } // Get the rekey configuration rekeyConf, err := core.RekeyConfig(recovery) if err != nil { - respondError(w, err.Code(), err) + respondError(w, err.Code(), err, lc) return } if rekeyConf == nil { - respondError(w, http.StatusBadRequest, errors.New("no rekey configuration found")) + respondError(w, http.StatusBadRequest, errors.New("no rekey configuration found"), lc) return } // Get the progress started, progress, err := core.RekeyProgress(recovery, true) if err != nil { - respondError(w, err.Code(), err) + respondError(w, err.Code(), err, lc) return } @@ -291,12 +336,18 @@ func handleSysRekeyVerifyGet(ctx context.Context, core *vault.Core, recovery boo N: rekeyConf.SecretShares, Progress: progress, } - respondOk(w, status) + respondOk(w, status, lc) } func handleSysRekeyVerifyDelete(ctx context.Context, core *vault.Core, recovery bool, w http.ResponseWriter, r *http.Request) { + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, errNew := core.GetCustomResponseHeaders(la) + if errNew != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } if err := core.RekeyVerifyRestart(recovery); err != nil { - respondError(w, err.Code(), err) + respondError(w, err.Code(), err, lc) return } @@ -304,16 +355,23 @@ func handleSysRekeyVerifyDelete(ctx context.Context, core *vault.Core, recovery } func handleSysRekeyVerifyPut(ctx context.Context, core *vault.Core, recovery bool, w http.ResponseWriter, r *http.Request) { + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, err := core.GetCustomResponseHeaders(la) + if err != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } // Parse the request var req RekeyVerificationUpdateRequest if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil { - respondError(w, http.StatusBadRequest, err) + respondError(w, http.StatusBadRequest, err, lc) return } if req.Key == "" { respondError( w, http.StatusBadRequest, - errors.New("'key' must be specified in request body as JSON")) + errors.New("'key' must be specified in request body as JSON"), + lc) return } @@ -328,7 +386,8 @@ func handleSysRekeyVerifyPut(ctx context.Context, core *vault.Core, recovery boo if err != nil { respondError( w, http.StatusBadRequest, - errors.New("'key' must be a valid hex or base64 string")) + errors.New("'key' must be a valid hex or base64 string"), + lc) return } } @@ -339,7 +398,7 @@ func handleSysRekeyVerifyPut(ctx context.Context, core *vault.Core, recovery boo // Use the key to make progress on rekey result, rekeyErr := core.RekeyVerify(ctx, key, req.Nonce, recovery) if rekeyErr != nil { - respondError(w, rekeyErr.Code(), rekeyErr) + respondError(w, rekeyErr.Code(), rekeyErr, lc) return } @@ -348,7 +407,7 @@ func handleSysRekeyVerifyPut(ctx context.Context, core *vault.Core, recovery boo if result != nil { resp.Complete = true resp.Nonce = result.Nonce - respondOk(w, resp) + respondOk(w, resp, lc) } else { handleSysRekeyVerifyGet(ctx, core, recovery, w, r) } diff --git a/http/sys_seal.go b/http/sys_seal.go index 24f491b65d1d6..9108c0e495e7a 100644 --- a/http/sys_seal.go +++ b/http/sys_seal.go @@ -15,16 +15,22 @@ import ( func handleSysSeal(core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, err := core.GetCustomResponseHeaders(la) + if err != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } req, _, statusCode, err := buildLogicalRequest(core, w, r) if err != nil || statusCode != 0 { - respondError(w, statusCode, err) + respondError(w, statusCode, err, lc) return } switch req.Operation { case logical.UpdateOperation: default: - respondError(w, http.StatusMethodNotAllowed, nil) + respondError(w, http.StatusMethodNotAllowed, nil, lc) return } @@ -32,66 +38,78 @@ func handleSysSeal(core *vault.Core) http.Handler { // We use context.Background since there won't be a request context if the node isn't active if err := core.SealWithRequest(r.Context(), req); err != nil { if errwrap.Contains(err, logical.ErrPermissionDenied.Error()) { - respondError(w, http.StatusForbidden, err) + respondError(w, http.StatusForbidden, err, lc) return } - respondError(w, http.StatusInternalServerError, err) + respondError(w, http.StatusInternalServerError, err, lc) return } - respondOk(w, nil) + respondOk(w, nil, lc) }) } func handleSysStepDown(core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, err := core.GetCustomResponseHeaders(la) + if err != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } req, _, statusCode, err := buildLogicalRequest(core, w, r) if err != nil || statusCode != 0 { - respondError(w, statusCode, err) + respondError(w, statusCode, err, lc) return } switch req.Operation { case logical.UpdateOperation: default: - respondError(w, http.StatusMethodNotAllowed, nil) + respondError(w, http.StatusMethodNotAllowed, nil, lc) return } // Seal with the token above if err := core.StepDown(r.Context(), req); err != nil { if errwrap.Contains(err, logical.ErrPermissionDenied.Error()) { - respondError(w, http.StatusForbidden, err) + respondError(w, http.StatusForbidden, err, lc) return } - respondError(w, http.StatusInternalServerError, err) + respondError(w, http.StatusInternalServerError, err, lc) return } - respondOk(w, nil) + respondOk(w, nil, lc) }) } func handleSysUnseal(core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, err := core.GetCustomResponseHeaders(la) + if err != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } switch r.Method { case "PUT": case "POST": default: - respondError(w, http.StatusMethodNotAllowed, nil) + respondError(w, http.StatusMethodNotAllowed, nil, lc) return } // Parse the request var req UnsealRequest if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil { - respondError(w, http.StatusBadRequest, err) + respondError(w, http.StatusBadRequest, err, lc) return } if req.Reset { if !core.Sealed() { - respondError(w, http.StatusBadRequest, errors.New("vault is unsealed")) + respondError(w, http.StatusBadRequest, errors.New("vault is unsealed"), lc) return } core.ResetUnsealProcess() @@ -102,7 +120,8 @@ func handleSysUnseal(core *vault.Core) http.Handler { if req.Key == "" { respondError( w, http.StatusBadRequest, - errors.New("'key' must be specified in request body as JSON, or 'reset' set to true")) + errors.New("'key' must be specified in request body as JSON, or 'reset' set to true"), + lc) return } @@ -117,7 +136,8 @@ func handleSysUnseal(core *vault.Core) http.Handler { if err != nil { respondError( w, http.StatusBadRequest, - errors.New("'key' must be a valid hex or base64 string")) + errors.New("'key' must be a valid hex or base64 string"), + lc) return } } @@ -137,10 +157,10 @@ func handleSysUnseal(core *vault.Core) http.Handler { case errwrap.Contains(err, vault.ErrBarrierSealed.Error()): case errwrap.Contains(err, consts.ErrStandby.Error()): default: - respondError(w, http.StatusInternalServerError, err) + respondError(w, http.StatusInternalServerError, err, lc) return } - respondError(w, http.StatusBadRequest, err) + respondError(w, http.StatusBadRequest, err, lc) return } @@ -151,8 +171,14 @@ func handleSysUnseal(core *vault.Core) http.Handler { func handleSysSealStatus(core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, err := core.GetCustomResponseHeaders(la) + if err != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } if r.Method != "GET" { - respondError(w, http.StatusMethodNotAllowed, nil) + respondError(w, http.StatusMethodNotAllowed, nil, lc) return } @@ -161,14 +187,20 @@ func handleSysSealStatus(core *vault.Core) http.Handler { } func handleSysSealStatusRaw(core *vault.Core, w http.ResponseWriter, r *http.Request) { + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, err := core.GetCustomResponseHeaders(la) + if err != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } ctx := context.Background() status, err := core.GetSealStatus(ctx) if err != nil { - respondError(w, http.StatusInternalServerError, err) + respondError(w, http.StatusInternalServerError, err, lc) return } - respondOk(w, status) + respondOk(w, status, lc) } // Note: because we didn't provide explicit tagging in the past we can't do it diff --git a/http/testing.go b/http/testing.go index be9569dc9684c..53f7fca04a249 100644 --- a/http/testing.go +++ b/http/testing.go @@ -62,5 +62,5 @@ func TestServerAuth(tb testing.TB, addr string, token string) { } func testHandleAuth(w http.ResponseWriter, req *http.Request) { - respondOk(w, nil) + respondOk(w, nil, nil) } diff --git a/http/util.go b/http/util.go index 0550a93c7e66e..da1c921c84b6a 100644 --- a/http/util.go +++ b/http/util.go @@ -33,9 +33,15 @@ var ( func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Getting custom headers from listener's config + la := w.Header().Get("X-Vault-Listener-Add") + lc, err := core.GetCustomResponseHeaders(la) + if err != nil { + core.Logger().Debug("failed to get custom headers from listener config") + } ns, err := namespace.FromContext(r.Context()) if err != nil { - respondError(w, http.StatusInternalServerError, err) + respondError(w, http.StatusInternalServerError, err, lc) return } @@ -44,7 +50,7 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler // again, which is not desired. path, status, err := buildLogicalPath(r) if err != nil || status != 0 { - respondError(w, status, err) + respondError(w, status, err, lc) return } @@ -57,7 +63,7 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler }) if err != nil { core.Logger().Error("failed to apply quota", "path", path, "error", err) - respondError(w, http.StatusUnprocessableEntity, err) + respondError(w, http.StatusUnprocessableEntity, err, lc) return } @@ -69,7 +75,7 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler if !quotaResp.Allowed { quotaErr := fmt.Errorf("request path %q: %w", path, quotas.ErrRateLimitQuotaExceeded) - respondError(w, http.StatusTooManyRequests, quotaErr) + respondError(w, http.StatusTooManyRequests, quotaErr, lc) if core.Logger().IsTrace() { core.Logger().Trace("request rejected due to rate limit quota violation", "request_path", path) @@ -78,7 +84,7 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler if core.RateLimitAuditLoggingEnabled() { req, _, status, err := buildLogicalRequestNoAuth(core.PerfStandby(), w, r) if err != nil || status != 0 { - respondError(w, status, err) + respondError(w, status, err, lc) return } diff --git a/internalshared/configutil/http_response_headers.go b/internalshared/configutil/http_response_headers.go new file mode 100644 index 0000000000000..779391ee87edb --- /dev/null +++ b/internalshared/configutil/http_response_headers.go @@ -0,0 +1,176 @@ +package configutil + +import ( + "fmt" + "net/textproto" + "strconv" + "strings" +) + +var defaultHeaderNames = []string { + "Content-Security-Policy", + "X-XSS-Protection", + "X-Frame-Options", + "X-Content-Type-Options", + "Strict-Transport-Security", + "Content-Type", +} + +var validStatusCodeCollection = []string { + "default", + "1xx", + "2xx", + "3xx", + "4xx", + "5xx", +} + +const ( + contentSecurityPolicy = "default-src 'none'; connect-src 'self'; img-src 'self' data:; script-src 'self'; style-src 'unsafe-inline' 'self'; form-action 'none'; frame-ancestors 'none'; font-src 'self'" + xXssProtection = "1; mode=block" + xFrameOptions = "Deny" + xContentTypeOptions = "nosniff" + strictTransportSecurity = "max-age=31536000; includeSubDomains" + contentType = "text/plain; charset=utf-8" +) + +func parseDefaultHeaders(h string) string { + switch h { + case "Content-Security-Policy": + return contentSecurityPolicy + case "X-XSS-Protection": + return xXssProtection + case "X-Frame-Options": + return xFrameOptions + case "X-Content-Type-Options": + return xContentTypeOptions + case "Strict-Transport-Security": + return strictTransportSecurity + case "Content-Type": + return contentType + default: + return "" + } +} + +func setDefaultResponseHeaders(c map[string]string) map[string]string { + defaults := make(map[string]string) + // adding all parsed default headers + for k, v := range c { + defaults[k] = v + } + + for _, hn := range defaultHeaderNames { + if _, ok := c[hn]; ok { + continue + } + hv := parseDefaultHeaders(hn) + if hv != "" { + defaults[hn] = hv + } + } + fmt.Printf("Default headers are %v", defaults) + return defaults +} + +func ParseCustomResponseHeaders(r interface{}) (map[string]map[string]string, error) { + if !isValidListDict(r) { + return nil, fmt.Errorf("invalid input type: %T", r) + } + + customResponseHeader := r.([]map[string]interface{}) + h := make(map[string]map[string]string) + + for _, crh := range customResponseHeader { + for sc, rh := range crh { + if !isValidListDict(rh){ + return nil, fmt.Errorf("invalid response header type") + } + + if !isValidStatusCode(sc) { + return nil, fmt.Errorf("invalid status code found in the config file: %v", sc) + } + + hvl := rh.([]map[string]interface{}) + if len(hvl) != 1 { + return nil, fmt.Errorf("invalid number of response headers exist") + } + hvm := hvl[0] + hv, err := parseHeaders(hvm) + if err != nil { + return nil, err + } + + h[sc] = hv + } + } + + // setting default custom headers + de := h["default"] + h["default"] = setDefaultResponseHeaders(de) + + return h, nil +} + +func isValidListDict(in interface{}) bool { + if _, ok := in.([]map[string]interface{}); ok { + return true + } + return false +} + +func isValidList(in interface{}) bool { + if _, ok := in.([]interface{}); ok { + return true + } + return false +} + +// checking for status codes outside the boundary +func isValidStatusCode(sc string) bool { + for _, v := range validStatusCodeCollection { + if sc == v { + return true + } + } + + i, err := strconv.Atoi(sc) + if err != nil { + return false + } + + if i >= 600 || i < 100 { + return false + } + + return true +} + +func parseHeaders(in map[string]interface{}) (map[string]string, error) { + hvMap := make(map[string]string) + for k, v := range in { + // parsing header name + hn := textproto.CanonicalMIMEHeaderKey(k) + // parsing header values + s, err := parseHeaderValues(v) + if err != nil { + return nil, err + } + hvMap[hn] = s + } + return hvMap, nil +} + +func parseHeaderValues(h interface{}) (string, error) { + var sl []string + if !isValidList(h) { + return "", fmt.Errorf("failed to parse custom_response_headers3") + } + vli := h.([]interface{}) + for _, vh := range vli { + sl = append(sl, vh.(string)) + } + s := strings.Join(sl, "; ") + + return s, nil +} \ No newline at end of file diff --git a/internalshared/configutil/listener.go b/internalshared/configutil/listener.go index 98199082895a4..4395bb1520eee 100644 --- a/internalshared/configutil/listener.go +++ b/internalshared/configutil/listener.go @@ -28,6 +28,17 @@ type ListenerProfiling struct { UnauthenticatedPProfAccessRaw interface{} `hcl:"unauthenticated_pprof_access,alias:UnauthenticatedPProfAccessRaw"` } +// TODO: remove this +type CH struct { + X interface{} `hcl:",key,alias:unknown"` + //UnusedKeys UnusedKeyMap `hcl:",unusedKeyPositions"` + Defaults interface{} `hcl:"-"` + DefaultsRaw interface{} `hcl:"default,alias:default"` + R307 map[string]string `hcl:"-"` + R307Raw interface{} `hcl:"307,alias:R307"` + +} + // Listener is the listener configuration for the server. type Listener struct { UnusedKeys UnusedKeyMap `hcl:",unusedKeyPositions"` @@ -99,6 +110,11 @@ type Listener struct { CorsAllowedOrigins []string `hcl:"cors_allowed_origins"` CorsAllowedHeaders []string `hcl:"-"` CorsAllowedHeadersRaw []string `hcl:"cors_allowed_headers,alias:cors_allowed_headers"` + + // Custom Http response headers + CustomResponseHeaders map[string]map[string]string `hcl:"-"` + CustomResponseHeadersRaw interface{} `hcl:"custom_response_headers,alias:custom_response_headers"` + } func (l *Listener) GoString() string { @@ -361,6 +377,18 @@ func ParseListeners(result *SharedConfig, list *ast.ObjectList) error { } } + // HTTP Headers + { + if l.CustomResponseHeadersRaw != nil { + customHeadersMap, err := ParseCustomResponseHeaders(l.CustomResponseHeadersRaw) + if err != nil { + return multierror.Prefix(fmt.Errorf("failed to parse custom_response_headers:%w", err), fmt.Sprintf("listeners.%d", i)) + } + l.CustomResponseHeaders = customHeadersMap + l.CustomResponseHeadersRaw = nil + } + } + result.Listeners = append(result.Listeners, &l) } diff --git a/internalshared/listenerutil/response_headers.go b/internalshared/listenerutil/response_headers.go new file mode 100644 index 0000000000000..8a51a4cb18b61 --- /dev/null +++ b/internalshared/listenerutil/response_headers.go @@ -0,0 +1,75 @@ +package listenerutil + +import ( + "fmt" + "net/http" + "strconv" +) + +// DefaultStatus is used to set default headers early before having a status code, +// for example, for /ui headers +const DefaultStatus = 1 + +func SetCustomResponseHeaders(hm map[string]map[string]string, w http.ResponseWriter, status int) error { + // Removing X-Vault-Listener-Add header from ResponseWriter + // This should be safe as the call to this function is right + // before w.WriteHeader for which the status code is finalized and known + w.Header().Del("X-Vault-Listener-Add") + + if hm == nil { + return nil + } + + // setter function to set the headers + setter := func(hv map[string]string) { + for h, v := range hv { + w.Header().Set(h, v) + } + } + + // Checking the validity of the status code + if status >= 600 || (status < 100 && status != DefaultStatus) { + return fmt.Errorf("invalid status code") + } + + // Setting the default headers first + setter(hm["default"]) + + // for NoStatus, we only set the default headers + if status == DefaultStatus { + return nil + } + + // setting the Xyy pattern first + d := fmt.Sprintf("%vxx", status / 100) + if val, ok := hm[d]; ok { + setter(val) + } + // Setting the specific headers + if val, ok := hm[strconv.Itoa(status)]; ok { + setter(val) + } + + return nil +} + +func FetchCustomResponseHeaderValue(hm map[string]map[string]string, th string, sc int) (string, error) { + if hm == nil { + return "", nil + } + if th == "" { + return "", fmt.Errorf("invalid target header") + } + + var h map[string]string + if sc == DefaultStatus { + h = hm["default"] + }else { + h = hm[strconv.Itoa(sc)] + } + + if v, ok := h[th]; ok { + return v, nil + } + return "", nil +} diff --git a/sdk/logical/response_util.go b/sdk/logical/response_util.go index 6ae3005b735f1..30d9af3e5cb40 100644 --- a/sdk/logical/response_util.go +++ b/sdk/logical/response_util.go @@ -158,7 +158,7 @@ func AdjustErrorStatusCode(status *int, err error) { } func RespondError(w http.ResponseWriter, status int, err error) { - AdjustErrorStatusCode(&status, err) + //AdjustErrorStatusCode(&status, err) w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) diff --git a/vault/core.go b/vault/core.go index 7d629099f70ec..7ae286d567d33 100644 --- a/vault/core.go +++ b/vault/core.go @@ -42,6 +42,7 @@ import ( "github.com/hashicorp/vault/command/server" "github.com/hashicorp/vault/helper/metricsutil" "github.com/hashicorp/vault/helper/namespace" + "github.com/hashicorp/vault/internalshared/configutil" "github.com/hashicorp/vault/physical/raft" "github.com/hashicorp/vault/sdk/helper/certutil" "github.com/hashicorp/vault/sdk/helper/consts" @@ -2632,6 +2633,85 @@ func (c *Core) SetConfig(conf *server.Config) { c.logger.Debug("set config", "sanitized config", string(bz)) } +func (c *Core) GetCustomResponseHeaders(la string) (map[string]map[string]string, error) { + + ln, err := c.GetListenersConf(la) + if err != nil { + c.Logger().Trace(err.Error()) + return nil, fmt.Errorf("listener config with address %v was not found:%w", la, err) + } + // TODO: maybe copy the ln.CustomResponseHeaders and return the copy? + return ln.CustomResponseHeaders, nil +} + +func (c *Core) GetListenersConf(address string) (*configutil.Listener, error) { + conf := c.rawConfig.Load() + if conf == nil { + return nil, errors.New("failed to load config") + } + lns := conf.(*server.Config).Listeners + for _, ln := range lns{ + if ln.Address == address { + return ln, nil + } + } + return nil, errors.New(fmt.Sprintf("no listener with the given address found: %v", address)) +} + +func (c *Core) ReloadCustomHeadersListenerConf() error { + conf := c.rawConfig.Load() + if conf == nil { + return fmt.Errorf("failed to Reload config") + } + lns := conf.(*server.Config).Listeners + for _, ln := range lns{ + if ln.CustomResponseHeadersRaw != nil { + customHeadersMap, err := configutil.ParseCustomResponseHeaders(ln.CustomResponseHeadersRaw) + if err != nil { + return fmt.Errorf("failed to parse custom_response_headers:%w", err) + } + ln.CustomResponseHeaders = customHeadersMap + ln.CustomResponseHeadersRaw = nil + } + } + return nil +} + + +// SanitizedCustomResponseHeader sanitizes listener config from invalid custom headers +func (c *Core) SanitizedCustomResponseHeader(conf *server.Config) { + hm := make(map[string]map[string]string) + userHeaders, err := c.UIHeaders() + if err != nil { + c.Logger().Trace("failed to get ui headers", "error:", err.Error()) + } + + for _, ln := range conf.Listeners { + for sc, ch := range ln.CustomResponseHeaders { + hv := make(map[string]string) + for h, v := range ch { + // X-Vault- prefix is reserved for Vault internal processes + if strings.HasPrefix(h, "X-Vault-") { + c.Logger().Error("Custom headers starting with X-Vault are not valid", "header", h) + continue + } + + // Checking for UI headers, if any common header exist, HCL headers take precedence + if userHeaders != nil { + exist := userHeaders.Get(h) + if exist != "" { + c.Logger().Error("found a duplicate header in UI, note that config file headers take precedence.", "header:", h) + } + } + hv[h] = v + } + hm[sc] = hv + } + ln.CustomResponseHeaders = hm + } + +} + // SanitizedConfig returns a sanitized version of the current config. // See server.Config.Sanitized for specific values omitted. func (c *Core) SanitizedConfig() map[string]interface{} { diff --git a/vault/external_tests/raft/raft_test.go b/vault/external_tests/raft/raft_test.go index f98b575589fd8..040d57e5542a5 100644 --- a/vault/external_tests/raft/raft_test.go +++ b/vault/external_tests/raft/raft_test.go @@ -1205,3 +1205,83 @@ func TestRaft_Join_InitStatus(t *testing.T) { verifyInitStatus(i, true) } } + + +func TestRaft_SnapshotRestoreOnStandby(t *testing.T) { + t.Parallel() + cluster := raftCluster(t, nil) + defer cluster.Cleanup() + + leaderClient := cluster.Cores[0].Client + + // Write a few keys + for i := 0; i < 10; i++ { + _, err := leaderClient.Logical().Write(fmt.Sprintf("secret/%d", i), map[string]interface{}{ + "test": "data", + }) + if err != nil { + t.Fatal(err) + } + } + + transport := cleanhttp.DefaultPooledTransport() + transport.TLSClientConfig = cluster.Cores[0].TLSConfig.Clone() + if err := http2.ConfigureTransport(transport); err != nil { + t.Fatal(err) + } + client := &http.Client{ + Transport: transport, + } + + // Take a snapshot + req := leaderClient.NewRequest("GET", "/v1/sys/storage/raft/snapshot") + httpReq, err := req.ToHTTP() + if err != nil { + t.Fatal(err) + } + resp, err := client.Do(httpReq) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + snap, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + if len(snap) == 0 { + t.Fatal("no snapshot returned") + } + + // Write a few more keys + for i := 10; i < 20; i++ { + _, err := leaderClient.Logical().Write(fmt.Sprintf("secret/%d", i), map[string]interface{}{ + "test": "data", + }) + if err != nil { + t.Fatal(err) + } + } + + // Restore snapshot + req = leaderClient.NewRequest("POST", "/v1/sys/storage/raft/snapshot") + req.Body = bytes.NewBuffer(snap) + httpReq, err = req.ToHTTP() + if err != nil { + t.Fatal(err) + } + resp, err = client.Do(httpReq) + if err != nil { + t.Fatal(err) + } + + // List kv to make sure we removed the extra keys + secret, err := leaderClient.Logical().List("secret/") + if err != nil { + t.Fatal(err) + } + + if len(secret.Data["keys"].([]interface{})) != 10 { + t.Fatal("snapshot didn't apply correctly") + } +} diff --git a/vault/logical_system.go b/vault/logical_system.go index 6179287818861..39e222f78d5b6 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -9,6 +9,8 @@ import ( "encoding/json" "errors" "fmt" + "github.com/hashicorp/vault/internalshared/listenerutil" + "github.com/hashicorp/vault/sdk/helper/headerutil" "hash" "net/http" "path" @@ -2620,12 +2622,23 @@ func (b *SystemBackend) handleConfigUIHeadersUpdate(ctx context.Context, req *lo return logical.ErrorResponse("X-Vault headers cannot be set"), logical.ErrInvalidRequest } + // Getting custom headers from listener's config + la := req.ResponseWriter.Header().Get("X-Vault-Listener-Add") + lc, err := b.Core.GetCustomResponseHeaders(la) + if err != nil { + b.Core.Logger().Debug("failed to get custom headers from listener config") + } + // Translate the list of values to the valid header string value := http.Header{} for _, v := range values { + chv, _ := listenerutil.FetchCustomResponseHeaderValue(lc, header, headerutil.DefaultStatus) + if chv != "" { + return logical.ErrorResponse("header already exist in server configuration file"), logical.ErrInvalidRequest + } value.Add(header, v) } - err := b.Core.uiConfig.SetHeader(ctx, header, value.Values(header)) + err = b.Core.uiConfig.SetHeader(ctx, header, value.Values(header)) if err != nil { return nil, err } From 9dfa68904e908a33b8ea10e818ff90111a5273f8 Mon Sep 17 00:00:00 2001 From: hamid ghaf Date: Thu, 2 Sep 2021 17:04:31 -0700 Subject: [PATCH 02/25] Add changelog, fix bad imports --- changelog/12485.txt | 3 +++ http/handler.go | 3 +-- vault/logical_system.go | 3 +-- 3 files changed, 5 insertions(+), 4 deletions(-) create mode 100644 changelog/12485.txt diff --git a/changelog/12485.txt b/changelog/12485.txt new file mode 100644 index 0000000000000..6ccb4432d67d9 --- /dev/null +++ b/changelog/12485.txt @@ -0,0 +1,3 @@ +```release-note:feature +http: Enable users to customize HTTP response headers +``` diff --git a/http/handler.go b/http/handler.go index 0a0b52661733b..96103be30b940 100644 --- a/http/handler.go +++ b/http/handler.go @@ -28,7 +28,6 @@ import ( "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/internalshared/configutil" "github.com/hashicorp/vault/sdk/helper/consts" - "github.com/hashicorp/vault/sdk/helper/headerutil" "github.com/hashicorp/vault/sdk/helper/jsonutil" "github.com/hashicorp/vault/sdk/helper/pathmanager" "github.com/hashicorp/vault/sdk/logical" @@ -514,7 +513,7 @@ func handleUIHeaders(core *vault.Core, h http.Handler) http.Handler { // just setting all default headers in the config file. // This might overwrite some headers there already set - listenerutil.SetCustomResponseHeaders(lc, w, headerutil.DefaultStatus) + listenerutil.SetCustomResponseHeaders(lc, w, listenerutil.DefaultStatus) // TODO: Setting 200 series as well h.ServeHTTP(w, req) diff --git a/vault/logical_system.go b/vault/logical_system.go index 39e222f78d5b6..29526d971c058 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -10,7 +10,6 @@ import ( "errors" "fmt" "github.com/hashicorp/vault/internalshared/listenerutil" - "github.com/hashicorp/vault/sdk/helper/headerutil" "hash" "net/http" "path" @@ -2632,7 +2631,7 @@ func (b *SystemBackend) handleConfigUIHeadersUpdate(ctx context.Context, req *lo // Translate the list of values to the valid header string value := http.Header{} for _, v := range values { - chv, _ := listenerutil.FetchCustomResponseHeaderValue(lc, header, headerutil.DefaultStatus) + chv, _ := listenerutil.FetchCustomResponseHeaderValue(lc, header, listenerutil.DefaultStatus) if chv != "" { return logical.ErrorResponse("header already exist in server configuration file"), logical.ErrInvalidRequest } From 65f3d3a641b064ff4f5368c398caecd86700c492 Mon Sep 17 00:00:00 2001 From: hamid ghaf Date: Fri, 3 Sep 2021 10:55:00 -0700 Subject: [PATCH 03/25] fixing some bugs --- command/server.go | 11 ++--------- http/handler.go | 16 +++++++++++----- http/handler_test.go | 6 +++--- vault/core.go | 33 +++++++-------------------------- 4 files changed, 23 insertions(+), 43 deletions(-) diff --git a/command/server.go b/command/server.go index 182d1310fd459..4c50a359ba0e4 100644 --- a/command/server.go +++ b/command/server.go @@ -1543,6 +1543,8 @@ func (c *ServerCommand) Run(args []string) int { } core.SetConfig(config) + // Sanitizing listener config from invalid patterns + core.SanitizedCustomResponseHeader(config) if config.LogLevel != "" { configLogLevel := strings.ToLower(strings.TrimSpace(config.LogLevel)) @@ -1574,15 +1576,6 @@ func (c *ServerCommand) Run(args []string) int { c.UI.Error(err.Error()) } - // Reload Custom headers - if err = core.ReloadCustomHeadersListenerConf(); err != nil { - c.UI.Error(err.Error()) - } - // Sanitizing listener config from invalid patterns - core.SanitizedCustomResponseHeader(config) - core.Logger().Info("**** Listern config after sanitization", "Li", config.Listeners) - - select { case c.licenseReloadedCh <- err: default: diff --git a/http/handler.go b/http/handler.go index 96103be30b940..20de3a74de367 100644 --- a/http/handler.go +++ b/http/handler.go @@ -343,14 +343,20 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr } // Setting listener address so that we could get the config from core - w.Header().Set("X-Vault-Listener-Add", props.ListenerConfig.Address) - + var la string + if props.ListenerConfig != nil { + la = props.ListenerConfig.Address + } + // Setting a header so that we could set customized headers + // configured in the corresponding listener stanza + w.Header().Set("X-Vault-Listener-Add", la) // Getting custom headers from listener's config - lc, err := core.GetCustomResponseHeaders(props.ListenerConfig.Address) + lc, err := core.GetCustomResponseHeaders(la) if err != nil { core.Logger().Debug("failed to get custom headers from listener config") } + switch { case strings.HasPrefix(r.URL.Path, "/v1/"): newR, status := adjustRequest(core, r) @@ -941,8 +947,8 @@ func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *l // Getting custom headers from listener's config la := w.Header().Get("X-Vault-Listener-Add") - lc, err := core.GetCustomResponseHeaders(la) - if err != nil { + lc, errNew := core.GetCustomResponseHeaders(la) + if errNew != nil { core.Logger().Debug("failed to get custom headers from listener config") } if respondErrorCommon(w, r, resp, err, lc) { diff --git a/http/handler_test.go b/http/handler_test.go index c228629ea8dce..01d1866eb7b42 100644 --- a/http/handler_test.go +++ b/http/handler_test.go @@ -599,7 +599,7 @@ func TestHandler_ui_enabled(t *testing.T) { func TestHandler_error(t *testing.T) { w := httptest.NewRecorder() - respondError(w, 500, errors.New("test Error")) + respondError(w, 500, errors.New("test Error"), nil) if w.Code != 500 { t.Fatalf("expected 500, got %d", w.Code) @@ -610,7 +610,7 @@ func TestHandler_error(t *testing.T) { w2 := httptest.NewRecorder() e := logical.CodedError(403, "error text") - respondError(w2, 500, e) + respondError(w2, 500, e, nil) if w2.Code != 403 { t.Fatalf("expected 403, got %d", w2.Code) @@ -619,7 +619,7 @@ func TestHandler_error(t *testing.T) { // vault.ErrSealed is a special case w3 := httptest.NewRecorder() - respondError(w3, 400, consts.ErrSealed) + respondError(w3, 400, consts.ErrSealed, nil) if w3.Code != 503 { t.Fatalf("expected 503, got %d", w3.Code) diff --git a/vault/core.go b/vault/core.go index 7ae286d567d33..391b646e25a00 100644 --- a/vault/core.go +++ b/vault/core.go @@ -2634,11 +2634,12 @@ func (c *Core) SetConfig(conf *server.Config) { } func (c *Core) GetCustomResponseHeaders(la string) (map[string]map[string]string, error) { - + if la == "" { + return nil, nil + } ln, err := c.GetListenersConf(la) - if err != nil { - c.Logger().Trace(err.Error()) - return nil, fmt.Errorf("listener config with address %v was not found:%w", la, err) + if err != nil || ln == nil { + return nil, err } // TODO: maybe copy the ln.CustomResponseHeaders and return the copy? return ln.CustomResponseHeaders, nil @@ -2647,7 +2648,7 @@ func (c *Core) GetCustomResponseHeaders(la string) (map[string]map[string]string func (c *Core) GetListenersConf(address string) (*configutil.Listener, error) { conf := c.rawConfig.Load() if conf == nil { - return nil, errors.New("failed to load config") + return nil, fmt.Errorf("failed to load config") } lns := conf.(*server.Config).Listeners for _, ln := range lns{ @@ -2655,29 +2656,9 @@ func (c *Core) GetListenersConf(address string) (*configutil.Listener, error) { return ln, nil } } - return nil, errors.New(fmt.Sprintf("no listener with the given address found: %v", address)) -} - -func (c *Core) ReloadCustomHeadersListenerConf() error { - conf := c.rawConfig.Load() - if conf == nil { - return fmt.Errorf("failed to Reload config") - } - lns := conf.(*server.Config).Listeners - for _, ln := range lns{ - if ln.CustomResponseHeadersRaw != nil { - customHeadersMap, err := configutil.ParseCustomResponseHeaders(ln.CustomResponseHeadersRaw) - if err != nil { - return fmt.Errorf("failed to parse custom_response_headers:%w", err) - } - ln.CustomResponseHeaders = customHeadersMap - ln.CustomResponseHeadersRaw = nil - } - } - return nil + return nil, fmt.Errorf("failed to find listener config with address %v", address) } - // SanitizedCustomResponseHeader sanitizes listener config from invalid custom headers func (c *Core) SanitizedCustomResponseHeader(conf *server.Config) { hm := make(map[string]map[string]string) From 61c12ebf2090a6a705338710ba19efdd2d7ea2b8 Mon Sep 17 00:00:00 2001 From: hamid ghaf Date: Fri, 3 Sep 2021 13:59:28 -0700 Subject: [PATCH 04/25] fixing interaction of custom headers and /ui --- http/handler.go | 10 ++++----- http/logical.go | 6 +++++- .../listenerutil/response_headers.go | 21 +++++++++++++++++-- sdk/logical/response_util.go | 1 - vault/logical_system.go | 8 +++++-- 5 files changed, 35 insertions(+), 11 deletions(-) diff --git a/http/handler.go b/http/handler.go index 20de3a74de367..60a2af097508a 100644 --- a/http/handler.go +++ b/http/handler.go @@ -517,10 +517,11 @@ func handleUIHeaders(core *vault.Core, h http.Handler) http.Handler { } } - // just setting all default headers in the config file. - // This might overwrite some headers there already set - listenerutil.SetCustomResponseHeaders(lc, w, listenerutil.DefaultStatus) - // TODO: Setting 200 series as well + // This function wraps handleUI and handleUIStub which do not set the + // status code specifically, instead, a call to w.Write is called which + // internally also sets the status code to 200. + // Just setting the headers for status code 200. + listenerutil.SetCustomResponseHeaders(lc, w, 200) h.ServeHTTP(w, req) }) @@ -973,7 +974,6 @@ func respondStandby(core *vault.Core, w http.ResponseWriter, req *http.Request) if err == vault.ErrHANotEnabled { // Standalone node, serve 503 err = errors.New("node is not active") - // TODO: set headers before all these responseError respondError(w, http.StatusServiceUnavailable, err, lc) return } diff --git a/http/logical.go b/http/logical.go index 2f2912ab904f2..3b17703fa4b20 100644 --- a/http/logical.go +++ b/http/logical.go @@ -99,6 +99,11 @@ func buildLogicalRequestNoAuth(perfStandby bool, w http.ResponseWriter, r *http. bufferedBody := newBufferedReader(r.Body) r.Body = bufferedBody + // response writer is needed when updating ui headers to make sure it + // does not interfere with custom response headers set in the configuration file + if strings.HasPrefix(path,"sys/config/ui") { + responseWriter = w + } // If we are uploading a snapshot we don't want to parse it. Instead // we will simply add the HTTP request to the logical request object // for later consumption. @@ -358,7 +363,6 @@ func respondLogical(core *vault.Core, w http.ResponseWriter, r *http.Request, re if resp.Redirect != "" { // If we have a redirect, redirect! We use a 307 code // because we don't actually know if its permanent. - // TODO: need to set custom headers before calling the redirect status := 307 listenerutil.SetCustomResponseHeaders(lc, w, status) http.Redirect(w, r, resp.Redirect, status) diff --git a/internalshared/listenerutil/response_headers.go b/internalshared/listenerutil/response_headers.go index 8a51a4cb18b61..d3cff70894e1e 100644 --- a/internalshared/listenerutil/response_headers.go +++ b/internalshared/listenerutil/response_headers.go @@ -3,6 +3,7 @@ package listenerutil import ( "fmt" "net/http" + "net/textproto" "strconv" ) @@ -35,7 +36,7 @@ func SetCustomResponseHeaders(hm map[string]map[string]string, w http.ResponseWr // Setting the default headers first setter(hm["default"]) - // for NoStatus, we only set the default headers + // for DefaultStatus, we only set the default headers if status == DefaultStatus { return nil } @@ -68,8 +69,24 @@ func FetchCustomResponseHeaderValue(hm map[string]map[string]string, th string, h = hm[strconv.Itoa(sc)] } - if v, ok := h[th]; ok { + hn := textproto.CanonicalMIMEHeaderKey(th) + if v, ok := h[hn]; ok { return v, nil } return "", nil } + +func ExistHeader(hm map[string]map[string]string, th string, sl []int) bool { + if len(sl) == 0 { + return false + } + + for _, s := range sl { + chv, _ := FetchCustomResponseHeaderValue(hm, th, s) + if chv != "" { + return true + } + } + + return false +} \ No newline at end of file diff --git a/sdk/logical/response_util.go b/sdk/logical/response_util.go index 30d9af3e5cb40..a570b7d602227 100644 --- a/sdk/logical/response_util.go +++ b/sdk/logical/response_util.go @@ -158,7 +158,6 @@ func AdjustErrorStatusCode(status *int, err error) { } func RespondError(w http.ResponseWriter, status int, err error) { - //AdjustErrorStatusCode(&status, err) w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) diff --git a/vault/logical_system.go b/vault/logical_system.go index 29526d971c058..6cd479448c0d9 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -2622,6 +2622,9 @@ func (b *SystemBackend) handleConfigUIHeadersUpdate(ctx context.Context, req *lo } // Getting custom headers from listener's config + if req.ResponseWriter == nil { + return logical.ErrorResponse("no ResponseWriter in the request"), logical.ErrInvalidRequest + } la := req.ResponseWriter.Header().Get("X-Vault-Listener-Add") lc, err := b.Core.GetCustomResponseHeaders(la) if err != nil { @@ -2631,8 +2634,9 @@ func (b *SystemBackend) handleConfigUIHeadersUpdate(ctx context.Context, req *lo // Translate the list of values to the valid header string value := http.Header{} for _, v := range values { - chv, _ := listenerutil.FetchCustomResponseHeaderValue(lc, header, listenerutil.DefaultStatus) - if chv != "" { + // check if the header exist in "default" and 200 status code maps of custom response headers + sl := []int{listenerutil.DefaultStatus, 200} + if listenerutil.ExistHeader(lc, header, sl) { return logical.ErrorResponse("header already exist in server configuration file"), logical.ErrInvalidRequest } value.Add(header, v) From f4232cddd864178c9b2123326cc0fdf6221f4d47 Mon Sep 17 00:00:00 2001 From: hamid ghaf Date: Sat, 4 Sep 2021 17:31:11 -0700 Subject: [PATCH 05/25] Defining a member in core to set custom response headers --- command/server.go | 10 +- http/cors.go | 12 +- http/forwarded_for_test.go | 12 +- http/handler.go | 143 ++++++------------ http/help.go | 12 +- http/logical.go | 38 ++--- http/sys_feature_flags.go | 11 +- http/sys_generate_root.go | 65 +++----- http/sys_health.go | 30 +--- http/sys_init.go | 32 +--- http/sys_leader.go | 18 +-- http/sys_metrics.go | 19 +-- http/sys_raft.go | 38 ++--- http/sys_rekey.go | 133 +++++----------- http/sys_seal.go | 70 +++------ http/util.go | 16 +- .../configutil/http_response_headers.go | 12 +- internalshared/configutil/listener.go | 11 -- .../listenerutil/response_headers.go | 92 ----------- vault/core.go | 97 ++++++------ vault/external_tests/raft/raft_test.go | 80 ---------- vault/logical_system.go | 12 +- 22 files changed, 259 insertions(+), 704 deletions(-) delete mode 100644 internalshared/listenerutil/response_headers.go diff --git a/command/server.go b/command/server.go index 4c50a359ba0e4..c063778450cd9 100644 --- a/command/server.go +++ b/command/server.go @@ -1341,9 +1341,6 @@ func (c *ServerCommand) Run(args []string) int { } } - // Sanitizing listener config from invalid custom headers - core.SanitizedCustomResponseHeader(config) - status, lns, clusterAddrs, errMsg := c.InitListeners(config, disableClustering, &infoKeys, &info) if status != 0 { @@ -1543,8 +1540,9 @@ func (c *ServerCommand) Run(args []string) int { } core.SetConfig(config) - // Sanitizing listener config from invalid patterns - core.SanitizedCustomResponseHeader(config) + if err = core.ReloadCustomListenerHeader(); err != nil { + c.UI.Error(err.Error()) + } if config.LogLevel != "" { configLogLevel := strings.ToLower(strings.TrimSpace(config.LogLevel)) @@ -2636,7 +2634,7 @@ func startHttpServers(c *ServerCommand, core *vault.Core, config *server.Config, }) if len(ln.Config.XForwardedForAuthorizedAddrs) > 0 { - handler = vaulthttp.WrapForwardedForHandler(handler, ln.Config) + handler = vaulthttp.WrapForwardedForHandler(handler, ln.Config, core.SetCustomResponseHeaders) } // server defaults diff --git a/http/cors.go b/http/cors.go index 685108ff4b03e..9a8b57c9b9b6f 100644 --- a/http/cors.go +++ b/http/cors.go @@ -2,7 +2,6 @@ package http import ( "fmt" - "github.com/hashicorp/vault/internalshared/listenerutil" "net/http" "strings" @@ -38,21 +37,16 @@ func wrapCORSHandler(h http.Handler, core *vault.Core) http.Handler { h.ServeHTTP(w, req) return } - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, err := core.GetCustomResponseHeaders(la) - if err != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } + // Return a 403 if the origin is not allowed to make cross-origin requests. if !corsConf.IsValidOrigin(origin) { - respondError(w, http.StatusForbidden, fmt.Errorf("origin not allowed"), lc) + respondError(w, http.StatusForbidden, fmt.Errorf("origin not allowed"), core.SetCustomResponseHeaders) return } if req.Method == http.MethodOptions && !strutil.StrListContains(allowedMethods, requestMethod) { status := http.StatusMethodNotAllowed - listenerutil.SetCustomResponseHeaders(lc, w, status) + core.SetCustomResponseHeaders(w, status) w.WriteHeader(status) return } diff --git a/http/forwarded_for_test.go b/http/forwarded_for_test.go index 9323f5bf1c728..40ea8289c7c71 100644 --- a/http/forwarded_for_test.go +++ b/http/forwarded_for_test.go @@ -42,7 +42,7 @@ func TestHandler_XForwardedFor(t *testing.T) { }) listenerConfig := getListenerConfigForMarshalerTest(goodAddr) listenerConfig.XForwardedForRejectNotPresent = true - return WrapForwardedForHandler(origHandler, listenerConfig) + return WrapForwardedForHandler(origHandler, listenerConfig, nil) } cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ @@ -85,7 +85,7 @@ func TestHandler_XForwardedFor(t *testing.T) { }) listenerConfig := getListenerConfigForMarshalerTest(badAddr) listenerConfig.XForwardedForRejectNotPresent = true - return WrapForwardedForHandler(origHandler, listenerConfig) + return WrapForwardedForHandler(origHandler, listenerConfig, nil) } cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ @@ -121,7 +121,7 @@ func TestHandler_XForwardedFor(t *testing.T) { listenerConfig := getListenerConfigForMarshalerTest(badAddr) listenerConfig.XForwardedForRejectNotPresent = true listenerConfig.XForwardedForRejectNotAuthorized = true - return WrapForwardedForHandler(origHandler, listenerConfig) + return WrapForwardedForHandler(origHandler, listenerConfig, nil) } cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ @@ -155,7 +155,7 @@ func TestHandler_XForwardedFor(t *testing.T) { listenerConfig.XForwardedForRejectNotPresent = true listenerConfig.XForwardedForRejectNotAuthorized = true listenerConfig.XForwardedForHopSkips = 4 - return WrapForwardedForHandler(origHandler, listenerConfig) + return WrapForwardedForHandler(origHandler, listenerConfig, nil) } cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ @@ -189,7 +189,7 @@ func TestHandler_XForwardedFor(t *testing.T) { listenerConfig.XForwardedForRejectNotPresent = true listenerConfig.XForwardedForRejectNotAuthorized = true listenerConfig.XForwardedForHopSkips = 1 - return WrapForwardedForHandler(origHandler, listenerConfig) + return WrapForwardedForHandler(origHandler, listenerConfig, nil) } cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ @@ -226,7 +226,7 @@ func TestHandler_XForwardedFor(t *testing.T) { listenerConfig.XForwardedForRejectNotPresent = true listenerConfig.XForwardedForRejectNotAuthorized = true listenerConfig.XForwardedForHopSkips = 1 - return WrapForwardedForHandler(origHandler, listenerConfig) + return WrapForwardedForHandler(origHandler, listenerConfig, nil) } cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ diff --git a/http/handler.go b/http/handler.go index 60a2af097508a..45ab0ed425f69 100644 --- a/http/handler.go +++ b/http/handler.go @@ -6,7 +6,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/hashicorp/vault/internalshared/listenerutil" "io" "io/fs" "io/ioutil" @@ -246,16 +245,10 @@ func handleAuditNonLogical(core *vault.Core, h http.Handler) http.Handler { origBody := new(bytes.Buffer) reader := ioutil.NopCloser(io.TeeReader(r.Body, origBody)) r.Body = reader - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, err := core.GetCustomResponseHeaders(la) - if err != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } req, _, status, err := buildLogicalRequestNoAuth(core.PerfStandby(), w, r) if err != nil || status != 0 { - respondError(w, status, err, lc) + respondError(w, status, err, core.SetCustomResponseHeaders) return } if origBody != nil { @@ -266,7 +259,7 @@ func handleAuditNonLogical(core *vault.Core, h http.Handler) http.Handler { } err = core.AuditLogger().AuditRequest(r.Context(), input) if err != nil { - respondError(w, status, err, lc) + respondError(w, status, err, core.SetCustomResponseHeaders) return } cw := newCopyResponseWriter(w) @@ -280,7 +273,7 @@ func handleAuditNonLogical(core *vault.Core, h http.Handler) http.Handler { input.Response = logical.HTTPResponseToLogicalResponse(httpResp) err = core.AuditLogger().AuditResponse(r.Context(), input) if err != nil { - respondError(w, status, err, lc) + respondError(w, status, err, core.SetCustomResponseHeaders) } return }) @@ -350,18 +343,12 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr // Setting a header so that we could set customized headers // configured in the corresponding listener stanza w.Header().Set("X-Vault-Listener-Add", la) - // Getting custom headers from listener's config - lc, err := core.GetCustomResponseHeaders(la) - if err != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } - switch { case strings.HasPrefix(r.URL.Path, "/v1/"): newR, status := adjustRequest(core, r) if status != 0 { - respondError(w, status, nil, lc) + respondError(w, status, nil, core.SetCustomResponseHeaders) cancelFunc() return } @@ -369,7 +356,7 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr case strings.HasPrefix(r.URL.Path, "/ui"), r.URL.Path == "/robots.txt", r.URL.Path == "/": default: - respondError(w, http.StatusNotFound, nil, lc) + respondError(w, http.StatusNotFound, nil, core.SetCustomResponseHeaders) cancelFunc() return } @@ -387,7 +374,7 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr }) } -func WrapForwardedForHandler(h http.Handler, l *configutil.Listener) http.Handler { +func WrapForwardedForHandler(h http.Handler, l *configutil.Listener, hs customResponseHeaderSetter) http.Handler { rejectNotPresent := l.XForwardedForRejectNotPresent hopSkips := l.XForwardedForHopSkips authorizedAddrs := l.XForwardedForAuthorizedAddrs @@ -399,7 +386,7 @@ func WrapForwardedForHandler(h http.Handler, l *configutil.Listener) http.Handle h.ServeHTTP(w, r) return } - respondError(w, http.StatusBadRequest, fmt.Errorf("missing x-forwarded-for header and configured to reject when not present"), l.CustomResponseHeaders) + respondError(w, http.StatusBadRequest, fmt.Errorf("missing x-forwarded-for header and configured to reject when not present"), hs) return } @@ -412,7 +399,7 @@ func WrapForwardedForHandler(h http.Handler, l *configutil.Listener) http.Handle h.ServeHTTP(w, r) return } - respondError(w, http.StatusBadRequest, fmt.Errorf("error parsing client hostport: %w", err), l.CustomResponseHeaders) + respondError(w, http.StatusBadRequest, fmt.Errorf("error parsing client hostport: %w", err), hs) return } @@ -423,7 +410,7 @@ func WrapForwardedForHandler(h http.Handler, l *configutil.Listener) http.Handle h.ServeHTTP(w, r) return } - respondError(w, http.StatusBadRequest, fmt.Errorf("error parsing client address: %w", err), l.CustomResponseHeaders) + respondError(w, http.StatusBadRequest, fmt.Errorf("error parsing client address: %w", err), hs) return } @@ -441,7 +428,7 @@ func WrapForwardedForHandler(h http.Handler, l *configutil.Listener) http.Handle h.ServeHTTP(w, r) return } - respondError(w, http.StatusBadRequest, fmt.Errorf("client address not authorized for x-forwarded-for and configured to reject connection"), l.CustomResponseHeaders) + respondError(w, http.StatusBadRequest, fmt.Errorf("client address not authorized for x-forwarded-for and configured to reject connection"), hs) return } @@ -469,7 +456,7 @@ func WrapForwardedForHandler(h http.Handler, l *configutil.Listener) http.Handle // authorized (or we've turned off explicit rejection) and we // should assume that what comes in should be properly // formatted. - respondError(w, http.StatusBadRequest, fmt.Errorf("malformed x-forwarded-for configuration or request, hops to skip (%d) would skip before earliest chain link (chain length %d)", hopSkips, len(headers)), l.CustomResponseHeaders) + respondError(w, http.StatusBadRequest, fmt.Errorf("malformed x-forwarded-for configuration or request, hops to skip (%d) would skip before earliest chain link (chain length %d)", hopSkips, len(headers)), hs) return } @@ -498,16 +485,9 @@ func handleUIHeaders(core *vault.Core, h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { header := w.Header() - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, err := core.GetCustomResponseHeaders(la) - if err != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } - userHeaders, err := core.UIHeaders() if err != nil { - respondError(w, http.StatusInternalServerError, err, lc) + respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) return } if userHeaders != nil { @@ -521,7 +501,7 @@ func handleUIHeaders(core *vault.Core, h http.Handler) http.Handler { // status code specifically, instead, a call to w.Write is called which // internally also sets the status code to 200. // Just setting the headers for status code 200. - listenerutil.SetCustomResponseHeaders(lc, w, 200) + core.SetCustomResponseHeaders(w, 200) h.ServeHTTP(w, req) }) @@ -612,15 +592,8 @@ func handleUIStub() http.Handler { func handleUIRedirect(core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, err := core.GetCustomResponseHeaders(la) - if err != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } - status := 307 - listenerutil.SetCustomResponseHeaders(lc, w, status) + core.SetCustomResponseHeaders(w, status) http.Redirect(w, req, "/ui/", status) return }) @@ -773,16 +746,10 @@ func forwardBasedOnHeaders(core *vault.Core, r *http.Request) (bool, error) { // falling back on the older behavior of redirecting the client func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, err := core.GetCustomResponseHeaders(la) - if err != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } // Note if the client requested forwarding shouldForward, err := forwardBasedOnHeaders(core, r) if err != nil { - respondError(w, http.StatusBadRequest, err, lc) + respondError(w, http.StatusBadRequest, err, core.SetCustomResponseHeaders) return } @@ -790,7 +757,7 @@ func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handle if core.PerfStandby() && !shouldForward { ns, err := namespace.FromContext(r.Context()) if err != nil { - respondError(w, http.StatusBadRequest, err, lc) + respondError(w, http.StatusBadRequest, err, core.SetCustomResponseHeaders) return } path := ns.TrimmedPath(r.URL.Path[len("/v1/"):]) @@ -818,7 +785,7 @@ func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handle return } // Some internal error occurred - respondError(w, http.StatusInternalServerError, err, lc) + respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) return } if isLeader { @@ -827,7 +794,7 @@ func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handle return } if leaderAddr == "" { - respondError(w, http.StatusInternalServerError, fmt.Errorf("local node not active but active cluster node not found"), lc) + respondError(w, http.StatusInternalServerError, fmt.Errorf("local node not active but active cluster node not found"), core.SetCustomResponseHeaders) return } @@ -837,33 +804,27 @@ func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handle } func forwardRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) { - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, err := core.GetCustomResponseHeaders(la) - if err != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } if r.Header.Get(vault.IntNoForwardingHeaderName) != "" { - respondStandby(core, w, r) + respondStandby(core, w, r.URL) return } if r.Header.Get(NoRequestForwardingHeaderName) != "" { // Forwarding explicitly disabled, fall back to previous behavior core.Logger().Debug("handleRequestForwarding: forwarding disabled by client request") - respondStandby(core, w, r) + respondStandby(core, w, r.URL) return } ns, err := namespace.FromContext(r.Context()) if err != nil { - respondError(w, http.StatusBadRequest, err, lc) + respondError(w, http.StatusBadRequest, err, core.SetCustomResponseHeaders) return } path := ns.TrimmedPath(r.URL.Path[len("/v1/"):]) if alwaysRedirectPaths.HasPath(path) { - respondStandby(core, w, r) + respondStandby(core, w, r.URL) return } @@ -879,7 +840,7 @@ func forwardRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) { } // Fall back to redirection - respondStandby(core, w, r) + respondStandby(core, w, r.URL) return } @@ -889,7 +850,7 @@ func forwardRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) { } } - listenerutil.SetCustomResponseHeaders(lc, w, statusCode) + core.SetCustomResponseHeaders(w, statusCode) w.WriteHeader(statusCode) w.Write(retBytes) } @@ -905,7 +866,7 @@ func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *l resp.AddWarning("Timeout hit while waiting for local replicated cluster to apply primary's write; this client may encounter stale reads of values written during this operation.") } if errwrap.Contains(err, consts.ErrStandby.Error()) { - respondStandby(core, w, rawReq) + respondStandby(core, w, rawReq.URL) return resp, false, false } if err != nil && errwrap.Contains(err, logical.ErrPerfStandbyPleaseForward.Error()) { @@ -946,13 +907,7 @@ func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *l return nil, true, false } - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, errNew := core.GetCustomResponseHeaders(la) - if errNew != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } - if respondErrorCommon(w, r, resp, err, lc) { + if respondErrorCommon(w, r, resp, err, core.SetCustomResponseHeaders) { return resp, false, false } @@ -960,39 +915,33 @@ func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *l } // respondStandby is used to trigger a redirect in the case that this Vault is currently a hot standby -func respondStandby(core *vault.Core, w http.ResponseWriter, req *http.Request) { - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, err := core.GetCustomResponseHeaders(la) - if err != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } - reqURL := req.URL +func respondStandby(core *vault.Core, w http.ResponseWriter, reqURL *url.URL) { + // Request the leader address _, redirectAddr, _, err := core.Leader() if err != nil { if err == vault.ErrHANotEnabled { // Standalone node, serve 503 err = errors.New("node is not active") - respondError(w, http.StatusServiceUnavailable, err, lc) + respondError(w, http.StatusServiceUnavailable, err, core.SetCustomResponseHeaders) return } - respondError(w, http.StatusInternalServerError, err, lc) + respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) return } // If there is no leader, generate a 503 error if redirectAddr == "" { err = errors.New("no active Vault instance found") - respondError(w, http.StatusServiceUnavailable, err, lc) + respondError(w, http.StatusServiceUnavailable, err, core.SetCustomResponseHeaders) return } // Parse the redirect location redirectURL, err := url.Parse(redirectAddr) if err != nil { - respondError(w, http.StatusInternalServerError, err, lc) + respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) return } @@ -1013,7 +962,7 @@ func respondStandby(core *vault.Core, w http.ResponseWriter, req *http.Request) // because we don't actually know if its permanent and // the request method should be preserved. w.Header().Set("Location", finalURL.String()) - listenerutil.SetCustomResponseHeaders(lc, w, 307) + core.SetCustomResponseHeaders(w, 307) w.WriteHeader(307) } @@ -1178,34 +1127,42 @@ func isForm(head []byte, contentType string) bool { return true } -func respondError(w http.ResponseWriter, status int, err error, h map[string]map[string]string) { +type customResponseHeaderSetter func(w http.ResponseWriter, status int) + +func respondError(w http.ResponseWriter, status int, err error, hs customResponseHeaderSetter) { logical.AdjustErrorStatusCode(&status, err) - listenerutil.SetCustomResponseHeaders(h, w, status) + if hs != nil { + hs(w, status) + } logical.RespondError(w, status, err) } -func respondErrorCommon(w http.ResponseWriter, req *logical.Request, resp *logical.Response, err error, h map[string]map[string]string) bool { +func respondErrorCommon(w http.ResponseWriter, req *logical.Request, resp *logical.Response, err error, hs customResponseHeaderSetter) bool { statusCode, newErr := logical.RespondErrorCommon(req, resp, err) if newErr == nil && statusCode == 0 { return false } - respondError(w, statusCode, newErr, h) + respondError(w, statusCode, newErr, hs) return true } -func respondOk(w http.ResponseWriter, body interface{}, h map[string]map[string]string) { +func respondOk(w http.ResponseWriter, body interface{}, hs customResponseHeaderSetter) { w.Header().Set("Content-Type", "application/json") var status int if body == nil { status = http.StatusNoContent - listenerutil.SetCustomResponseHeaders(h, w, status) - w.WriteHeader(status) } else { status = http.StatusOK - listenerutil.SetCustomResponseHeaders(h, w, status) - w.WriteHeader(status) + } + + if hs != nil { + hs(w, status) + } + w.WriteHeader(status) + + if body != nil { enc := json.NewEncoder(w) enc.Encode(body) } diff --git a/http/help.go b/http/help.go index c57675dcf84d2..e4c93c712d427 100644 --- a/http/help.go +++ b/http/help.go @@ -26,15 +26,9 @@ func wrapHelpHandler(h http.Handler, core *vault.Core) http.Handler { } func handleHelp(core *vault.Core, w http.ResponseWriter, r *http.Request) { - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, err := core.GetCustomResponseHeaders(la) - if err != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } ns, err := namespace.FromContext(r.Context()) if err != nil { - respondError(w, http.StatusBadRequest, nil, lc) + respondError(w, http.StatusBadRequest, nil, core.SetCustomResponseHeaders) return } path := ns.TrimmedPath(r.URL.Path[len("/v1/"):]) @@ -48,9 +42,9 @@ func handleHelp(core *vault.Core, w http.ResponseWriter, r *http.Request) { resp, err := core.HandleRequest(r.Context(), req) if err != nil { - respondErrorCommon(w, req, resp, err, lc) + respondErrorCommon(w, req, resp, err, core.SetCustomResponseHeaders) return } - respondOk(w, resp.Data, lc) + respondOk(w, resp.Data, core.SetCustomResponseHeaders) } diff --git a/http/logical.go b/http/logical.go index 3b17703fa4b20..79b2504a0ede0 100644 --- a/http/logical.go +++ b/http/logical.go @@ -5,7 +5,6 @@ import ( "encoding/base64" "encoding/json" "fmt" - "github.com/hashicorp/vault/internalshared/listenerutil" "io" "net" "net/http" @@ -302,15 +301,9 @@ func handleLogicalRecovery(raw *vault.RawBackend, token *atomic.String) http.Han // toggles. Refer to usage on functions for possible behaviors. func handleLogicalInternal(core *vault.Core, injectDataIntoTopLevel bool, noForward bool) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, err := core.GetCustomResponseHeaders(la) - if err != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } req, origBody, statusCode, err := buildLogicalRequest(core, w, r) if err != nil || statusCode != 0 { - respondError(w, statusCode, err, lc) + respondError(w, statusCode, err, core.SetCustomResponseHeaders) return } @@ -322,7 +315,7 @@ func handleLogicalInternal(core *vault.Core, injectDataIntoTopLevel bool, noForw resp, ok, needsForward := request(core, w, r, req) switch { case needsForward && noForward: - respondError(w, http.StatusBadRequest, vault.ErrCannotForwardLocalOnly, lc) + respondError(w, http.StatusBadRequest, vault.ErrCannotForwardLocalOnly, core.SetCustomResponseHeaders) return case needsForward && !noForward: if origBody != nil { @@ -353,25 +346,23 @@ func respondLogical(core *vault.Core, w http.ResponseWriter, r *http.Request, re return } - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, err := core.GetCustomResponseHeaders(la) - if err != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } if resp != nil { if resp.Redirect != "" { // If we have a redirect, redirect! We use a 307 code // because we don't actually know if its permanent. status := 307 - listenerutil.SetCustomResponseHeaders(lc, w, status) + core.SetCustomResponseHeaders(w, status) http.Redirect(w, r, resp.Redirect, status) return } // Check if this is a raw response if _, ok := resp.Data[logical.HTTPStatusCode]; ok { - respondRaw(w, r, resp, lc) + var hs customResponseHeaderSetter + if core != nil { + hs = core.SetCustomResponseHeaders + } + respondRaw(w, r, resp, hs) return } @@ -404,18 +395,20 @@ func respondLogical(core *vault.Core, w http.ResponseWriter, r *http.Request, re adjustResponse(core, w, req) // Respond - respondOk(w, ret, lc) + respondOk(w, ret, core.SetCustomResponseHeaders) return } // respondRaw is used when the response is using HTTPContentType and HTTPRawBody // to change the default response handling. This is only used for specific things like // returning the CRL information on the PKI backends. -func respondRaw(w http.ResponseWriter, r *http.Request, resp *logical.Response, h map[string]map[string]string) { +func respondRaw(w http.ResponseWriter, r *http.Request, resp *logical.Response, hs customResponseHeaderSetter) { retErr := func(w http.ResponseWriter, err string) { w.Header().Set("X-Vault-Raw-Error", err) code := http.StatusInternalServerError - listenerutil.SetCustomResponseHeaders(h, w, code) + if hs != nil { + hs(w, code) + } w.WriteHeader(code) w.Write(nil) } @@ -504,8 +497,9 @@ WRITE_RESPONSE: if cacheControl, ok := resp.Data[logical.HTTPRawCacheControl].(string); ok { w.Header().Set("Cache-Control", cacheControl) } - - listenerutil.SetCustomResponseHeaders(h, w, status) + if hs != nil { + hs(w, status) + } w.WriteHeader(status) w.Write(body) } diff --git a/http/sys_feature_flags.go b/http/sys_feature_flags.go index 468e76e159d32..512dfdfb834a5 100644 --- a/http/sys_feature_flags.go +++ b/http/sys_feature_flags.go @@ -2,7 +2,6 @@ package http import ( "encoding/json" - "github.com/hashicorp/vault/internalshared/listenerutil" "net/http" "os" @@ -28,17 +27,11 @@ func featureFlagIsSet(name string) bool { func handleSysInternalFeatureFlags(core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, err := core.GetCustomResponseHeaders(la) - if err != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } switch r.Method { case "GET": break default: - respondError(w, http.StatusMethodNotAllowed, nil, lc) + respondError(w, http.StatusMethodNotAllowed, nil, core.SetCustomResponseHeaders) } response := &FeatureFlagsResponse{} @@ -51,7 +44,7 @@ func handleSysInternalFeatureFlags(core *vault.Core) http.Handler { w.Header().Set("Content-Type", "application/json") status := http.StatusOK - listenerutil.SetCustomResponseHeaders(lc, w, status) + core.SetCustomResponseHeaders(w, status) w.WriteHeader(status) // Generate the response diff --git a/http/sys_generate_root.go b/http/sys_generate_root.go index 8cea12e474f1d..2329de6a1282b 100644 --- a/http/sys_generate_root.go +++ b/http/sys_generate_root.go @@ -14,12 +14,6 @@ import ( func handleSysGenerateRootAttempt(core *vault.Core, generateStrategy vault.GenerateRootStrategy) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, err := core.GetCustomResponseHeaders(la) - if err != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } switch r.Method { case "GET": handleSysGenerateRootAttemptGet(core, w, r, "") @@ -28,7 +22,7 @@ func handleSysGenerateRootAttempt(core *vault.Core, generateStrategy vault.Gener case "DELETE": handleSysGenerateRootAttemptDelete(core, w, r) default: - respondError(w, http.StatusMethodNotAllowed, nil, lc) + respondError(w, http.StatusMethodNotAllowed, nil, core.SetCustomResponseHeaders) } }) } @@ -37,21 +31,14 @@ func handleSysGenerateRootAttemptGet(core *vault.Core, w http.ResponseWriter, r ctx, cancel := core.GetContext() defer cancel() - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, err := core.GetCustomResponseHeaders(la) - if err != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } - // Get the current seal configuration barrierConfig, err := core.SealAccess().BarrierConfig(ctx) if err != nil { - respondError(w, http.StatusInternalServerError, err, lc) + respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) return } if barrierConfig == nil { - respondError(w, http.StatusBadRequest, fmt.Errorf("server is not yet initialized"), lc) + respondError(w, http.StatusBadRequest, fmt.Errorf("server is not yet initialized"), core.SetCustomResponseHeaders) return } @@ -59,7 +46,7 @@ func handleSysGenerateRootAttemptGet(core *vault.Core, w http.ResponseWriter, r if core.SealAccess().RecoveryKeySupported() { sealConfig, err = core.SealAccess().RecoveryConfig(ctx) if err != nil { - respondError(w, http.StatusInternalServerError, err, lc) + respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) return } } @@ -67,14 +54,14 @@ func handleSysGenerateRootAttemptGet(core *vault.Core, w http.ResponseWriter, r // Get the generation configuration generationConfig, err := core.GenerateRootConfiguration() if err != nil { - respondError(w, http.StatusInternalServerError, err, lc) + respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) return } // Get the progress progress, err := core.GenerateRootProgress() if err != nil { - respondError(w, http.StatusInternalServerError, err, lc) + respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) return } @@ -93,21 +80,15 @@ func handleSysGenerateRootAttemptGet(core *vault.Core, w http.ResponseWriter, r status.PGPFingerprint = generationConfig.PGPFingerprint } - respondOk(w, status, lc) + respondOk(w, status, core.SetCustomResponseHeaders) } func handleSysGenerateRootAttemptPut(core *vault.Core, w http.ResponseWriter, r *http.Request, generateStrategy vault.GenerateRootStrategy) { - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, errNew := core.GetCustomResponseHeaders(la) - if errNew != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } // Parse the request var req GenerateRootInitRequest if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil && err != io.EOF { - respondError(w, http.StatusBadRequest, err, lc) + respondError(w, http.StatusBadRequest, err, core.SetCustomResponseHeaders) return } @@ -120,14 +101,14 @@ func handleSysGenerateRootAttemptPut(core *vault.Core, w http.ResponseWriter, r genned = true req.OTP, err = base62.Random(vault.TokenLength + 2) if err != nil { - respondError(w, http.StatusInternalServerError, err, lc) + respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) return } } // Attemptialize the generation if err := core.GenerateRootInit(req.OTP, req.PGPKey, generateStrategy); err != nil { - respondError(w, http.StatusBadRequest, err, lc) + respondError(w, http.StatusBadRequest, err, core.SetCustomResponseHeaders) return } @@ -140,40 +121,28 @@ func handleSysGenerateRootAttemptPut(core *vault.Core, w http.ResponseWriter, r } func handleSysGenerateRootAttemptDelete(core *vault.Core, w http.ResponseWriter, r *http.Request) { - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, err := core.GetCustomResponseHeaders(la) - if err != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } errNew := core.GenerateRootCancel() if errNew != nil { - respondError(w, http.StatusInternalServerError, errNew, lc) + respondError(w, http.StatusInternalServerError, errNew, core.SetCustomResponseHeaders) return } - respondOk(w, nil, lc) + respondOk(w, nil, core.SetCustomResponseHeaders) } func handleSysGenerateRootUpdate(core *vault.Core, generateStrategy vault.GenerateRootStrategy) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, errNew := core.GetCustomResponseHeaders(la) - if errNew != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } // Parse the request var req GenerateRootUpdateRequest if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil { - respondError(w, http.StatusBadRequest, err, lc) + respondError(w, http.StatusBadRequest, err, core.SetCustomResponseHeaders) return } if req.Key == "" { respondError( w, http.StatusBadRequest, errors.New("'key' must be specified in request body as JSON"), - lc) + core.SetCustomResponseHeaders) return } @@ -189,7 +158,7 @@ func handleSysGenerateRootUpdate(core *vault.Core, generateStrategy vault.Genera respondError( w, http.StatusBadRequest, errors.New("'key' must be a valid hex or base64 string"), - lc) + core.SetCustomResponseHeaders) return } } @@ -200,7 +169,7 @@ func handleSysGenerateRootUpdate(core *vault.Core, generateStrategy vault.Genera // Use the key to make progress on root generation result, err := core.GenerateRootUpdate(ctx, key, req.Nonce, generateStrategy) if err != nil { - respondError(w, http.StatusBadRequest, err, lc) + respondError(w, http.StatusBadRequest, err, core.SetCustomResponseHeaders) return } @@ -218,7 +187,7 @@ func handleSysGenerateRootUpdate(core *vault.Core, generateStrategy vault.Genera resp.EncodedRootToken = result.EncodedToken } - respondOk(w, resp, lc) + respondOk(w, resp, core.SetCustomResponseHeaders) }) } diff --git a/http/sys_health.go b/http/sys_health.go index 9e310d17e7b0b..219ebdf8e6ee7 100644 --- a/http/sys_health.go +++ b/http/sys_health.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "fmt" - "github.com/hashicorp/vault/internalshared/listenerutil" "net/http" "strconv" "time" @@ -17,19 +16,13 @@ import ( func handleSysHealth(core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, err := core.GetCustomResponseHeaders(la) - if err != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } switch r.Method { case "GET": handleSysHealthGet(core, w, r) case "HEAD": handleSysHealthHead(core, w, r) default: - respondError(w, http.StatusMethodNotAllowed, nil, lc) + respondError(w, http.StatusMethodNotAllowed, nil, core.SetCustomResponseHeaders) } }) } @@ -48,26 +41,20 @@ func fetchStatusCode(r *http.Request, field string) (int, bool, bool) { } func handleSysHealthGet(core *vault.Core, w http.ResponseWriter, r *http.Request) { - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, err := core.GetCustomResponseHeaders(la) - if err != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } code, body, err := getSysHealth(core, r) if err != nil { core.Logger().Error("error checking health", "error", err) - respondError(w, code, nil, lc) + respondError(w, code, nil, core.SetCustomResponseHeaders) return } if body == nil { - respondError(w, code, nil, lc) + respondError(w, code, nil, core.SetCustomResponseHeaders) return } w.Header().Set("Content-Type", "application/json") - listenerutil.SetCustomResponseHeaders(lc, w, code) + core.SetCustomResponseHeaders(w, code) w.WriteHeader(code) // Generate the response @@ -81,13 +68,8 @@ func handleSysHealthHead(core *vault.Core, w http.ResponseWriter, r *http.Reques if body != nil { w.Header().Set("Content-Type", "application/json") } - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, err := core.GetCustomResponseHeaders(la) - if err != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } - listenerutil.SetCustomResponseHeaders(lc, w, code) + + core.SetCustomResponseHeaders(w, code) w.WriteHeader(code) } diff --git a/http/sys_init.go b/http/sys_init.go index 4f9436eb4c82a..3a224f804ae4b 100644 --- a/http/sys_init.go +++ b/http/sys_init.go @@ -11,55 +11,37 @@ import ( func handleSysInit(core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, err := core.GetCustomResponseHeaders(la) - if err != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } switch r.Method { case "GET": handleSysInitGet(core, w, r) case "PUT", "POST": handleSysInitPut(core, w, r) default: - respondError(w, http.StatusMethodNotAllowed, nil, lc) + respondError(w, http.StatusMethodNotAllowed, nil, core.SetCustomResponseHeaders) } }) } func handleSysInitGet(core *vault.Core, w http.ResponseWriter, r *http.Request) { - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, err := core.GetCustomResponseHeaders(la) - if err != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } init, err := core.Initialized(context.Background()) if err != nil { - respondError(w, http.StatusInternalServerError, err, lc) + respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) return } respondOk(w, &InitStatusResponse{ Initialized: init, - }, lc) + }, core.SetCustomResponseHeaders) } func handleSysInitPut(core *vault.Core, w http.ResponseWriter, r *http.Request) { - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, err := core.GetCustomResponseHeaders(la) - if err != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } ctx := context.Background() // Parse the request var req InitRequest if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil { - respondError(w, http.StatusBadRequest, err, lc) + respondError(w, http.StatusBadRequest, err, core.SetCustomResponseHeaders) return } @@ -86,7 +68,7 @@ func handleSysInitPut(core *vault.Core, w http.ResponseWriter, r *http.Request) result, initErr := core.Initialize(ctx, initParams) if initErr != nil { if vault.IsFatalError(initErr) { - respondError(w, http.StatusBadRequest, initErr, lc) + respondError(w, http.StatusBadRequest, initErr, core.SetCustomResponseHeaders) return } else { // Add a warnings field? The error will be logged in the vault log @@ -118,11 +100,11 @@ func handleSysInitPut(core *vault.Core, w http.ResponseWriter, r *http.Request) } if err := core.UnsealWithStoredKeys(ctx); err != nil { - respondError(w, http.StatusInternalServerError, err, lc) + respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) return } - respondOk(w, resp, lc) + respondOk(w, resp, core.SetCustomResponseHeaders) } type InitRequest struct { diff --git a/http/sys_leader.go b/http/sys_leader.go index 4d1754946530b..b31da9ada99f1 100644 --- a/http/sys_leader.go +++ b/http/sys_leader.go @@ -10,32 +10,20 @@ import ( // or becomes the leader. func handleSysLeader(core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, err := core.GetCustomResponseHeaders(la) - if err != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } switch r.Method { case "GET": handleSysLeaderGet(core, w, r) default: - respondError(w, http.StatusMethodNotAllowed, nil, lc) + respondError(w, http.StatusMethodNotAllowed, nil, core.SetCustomResponseHeaders) } }) } func handleSysLeaderGet(core *vault.Core, w http.ResponseWriter, r *http.Request) { - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, err := core.GetCustomResponseHeaders(la) - if err != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } resp, err := core.GetLeaderStatus() if err != nil { - respondError(w, http.StatusInternalServerError, err, lc) + respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) return } - respondOk(w, resp, lc) + respondOk(w, resp, core.SetCustomResponseHeaders) } diff --git a/http/sys_metrics.go b/http/sys_metrics.go index 4e2ffc9bff18d..840d75d87be5d 100644 --- a/http/sys_metrics.go +++ b/http/sys_metrics.go @@ -2,7 +2,6 @@ package http import ( "fmt" - "github.com/hashicorp/vault/internalshared/listenerutil" "net/http" "github.com/hashicorp/vault/helper/metricsutil" @@ -12,24 +11,19 @@ import ( func handleMetricsUnauthenticated(core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, err := core.GetCustomResponseHeaders(la) - if err != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } + req := &logical.Request{Headers: r.Header} switch r.Method { case "GET": default: - respondError(w, http.StatusMethodNotAllowed, nil, lc) + respondError(w, http.StatusMethodNotAllowed, nil, core.SetCustomResponseHeaders) return } // Parse form if err := r.ParseForm(); err != nil { - respondError(w, http.StatusBadRequest, err, lc) + respondError(w, http.StatusBadRequest, err, core.SetCustomResponseHeaders) return } @@ -43,16 +37,17 @@ func handleMetricsUnauthenticated(core *vault.Core) http.Handler { // Manually extract the logical response and send back the information status := resp.Data[logical.HTTPStatusCode].(int) - listenerutil.SetCustomResponseHeaders(lc, w, status) - w.WriteHeader(status) + core.SetCustomResponseHeaders(w, status) w.Header().Set("Content-Type", resp.Data[logical.HTTPContentType].(string)) switch v := resp.Data[logical.HTTPRawBody].(type) { case string: + w.WriteHeader(status) w.Write([]byte(v)) case []byte: + w.WriteHeader(status) w.Write(v) default: - respondError(w, http.StatusInternalServerError, fmt.Errorf("wrong response returned"), lc) + respondError(w, http.StatusInternalServerError, fmt.Errorf("wrong response returned"), core.SetCustomResponseHeaders) } }) } diff --git a/http/sys_raft.go b/http/sys_raft.go index a4590398660b8..1659a8da86cd9 100644 --- a/http/sys_raft.go +++ b/http/sys_raft.go @@ -15,25 +15,19 @@ import ( func handleSysRaftBootstrap(core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, err := core.GetCustomResponseHeaders(la) - if err != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } switch r.Method { case "POST", "PUT": if core.Sealed() { - respondError(w, http.StatusBadRequest, errors.New("node must be unsealed to bootstrap"), lc) + respondError(w, http.StatusBadRequest, errors.New("node must be unsealed to bootstrap"), core.SetCustomResponseHeaders) } if err := core.RaftBootstrap(context.Background(), false); err != nil { - respondError(w, http.StatusInternalServerError, err, lc) + respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) return } default: - respondError(w, http.StatusBadRequest, nil, lc) + respondError(w, http.StatusBadRequest, nil, core.SetCustomResponseHeaders) } }) } @@ -44,33 +38,21 @@ func handleSysRaftJoin(core *vault.Core) http.Handler { case "POST", "PUT": handleSysRaftJoinPost(core, w, r) default: - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, err := core.GetCustomResponseHeaders(la) - if err != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } - respondError(w, http.StatusMethodNotAllowed, nil, lc) + respondError(w, http.StatusMethodNotAllowed, nil, core.SetCustomResponseHeaders) } }) } func handleSysRaftJoinPost(core *vault.Core, w http.ResponseWriter, r *http.Request) { - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, errNew := core.GetCustomResponseHeaders(la) - if errNew != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } // Parse the request var req JoinRequest if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil && err != io.EOF { - respondError(w, http.StatusBadRequest, err, lc) + respondError(w, http.StatusBadRequest, err, core.SetCustomResponseHeaders) return } if req.NonVoter && !nonVotersAllowed { - respondError(w, http.StatusBadRequest, errors.New("non-voting nodes not allowed"), lc) + respondError(w, http.StatusBadRequest, errors.New("non-voting nodes not allowed"), core.SetCustomResponseHeaders) return } @@ -79,14 +61,14 @@ func handleSysRaftJoinPost(core *vault.Core, w http.ResponseWriter, r *http.Requ if len(req.LeaderCACert) != 0 || len(req.LeaderClientCert) != 0 || len(req.LeaderClientKey) != 0 { tlsConfig, err = tlsutil.ClientTLSConfig([]byte(req.LeaderCACert), []byte(req.LeaderClientCert), []byte(req.LeaderClientKey)) if err != nil { - respondError(w, http.StatusBadRequest, err, lc) + respondError(w, http.StatusBadRequest, err, core.SetCustomResponseHeaders) return } tlsConfig.ServerName = req.LeaderTLSServerName } if req.AutoJoinScheme != "" && (req.AutoJoinScheme != "http" && req.AutoJoinScheme != "https") { - respondError(w, http.StatusBadRequest, fmt.Errorf("invalid scheme '%s'; must either be http or https", req.AutoJoinScheme), lc) + respondError(w, http.StatusBadRequest, fmt.Errorf("invalid scheme '%s'; must either be http or https", req.AutoJoinScheme), core.SetCustomResponseHeaders) return } @@ -103,14 +85,14 @@ func handleSysRaftJoinPost(core *vault.Core, w http.ResponseWriter, r *http.Requ joined, err := core.JoinRaftCluster(context.Background(), leaderInfos, req.NonVoter) if err != nil { - respondError(w, http.StatusInternalServerError, err, lc) + respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) return } resp := JoinResponse{ Joined: joined, } - respondOk(w, resp, lc) + respondOk(w, resp, core.SetCustomResponseHeaders) } type JoinResponse struct { diff --git a/http/sys_rekey.go b/http/sys_rekey.go index 5679a5923528e..f93b8370789de 100644 --- a/http/sys_rekey.go +++ b/http/sys_rekey.go @@ -17,20 +17,15 @@ func handleSysRekeyInit(core *vault.Core, recovery bool) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { standby, _ := core.Standby() if standby { - respondStandby(core, w, r) + respondStandby(core, w, r.URL) return } - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, err := core.GetCustomResponseHeaders(la) - if err != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } + repState := core.ReplicationState() if repState.HasState(consts.ReplicationPerformanceSecondary) { respondError(w, http.StatusBadRequest, fmt.Errorf("rekeying can only be performed on the primary cluster when replication is activated"), - lc) + core.SetCustomResponseHeaders) return } @@ -39,7 +34,7 @@ func handleSysRekeyInit(core *vault.Core, recovery bool) http.Handler { switch { case recovery && !core.SealAccess().RecoveryKeySupported(): - respondError(w, http.StatusBadRequest, fmt.Errorf("recovery rekeying not supported"), lc) + respondError(w, http.StatusBadRequest, fmt.Errorf("recovery rekeying not supported"), core.SetCustomResponseHeaders) case r.Method == "GET": handleSysRekeyInitGet(ctx, core, recovery, w, r) case r.Method == "POST" || r.Method == "PUT": @@ -47,38 +42,32 @@ func handleSysRekeyInit(core *vault.Core, recovery bool) http.Handler { case r.Method == "DELETE": handleSysRekeyInitDelete(ctx, core, recovery, w, r) default: - respondError(w, http.StatusMethodNotAllowed, nil, lc) + respondError(w, http.StatusMethodNotAllowed, nil, core.SetCustomResponseHeaders) } }) } func handleSysRekeyInitGet(ctx context.Context, core *vault.Core, recovery bool, w http.ResponseWriter, r *http.Request) { - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, errNew := core.GetCustomResponseHeaders(la) - if errNew != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } barrierConfig, barrierConfErr := core.SealAccess().BarrierConfig(ctx) if barrierConfErr != nil { - respondError(w, http.StatusInternalServerError, barrierConfErr, lc) + respondError(w, http.StatusInternalServerError, barrierConfErr, core.SetCustomResponseHeaders) return } if barrierConfig == nil { - respondError(w, http.StatusBadRequest, fmt.Errorf("server is not yet initialized"), lc) + respondError(w, http.StatusBadRequest, fmt.Errorf("server is not yet initialized"), core.SetCustomResponseHeaders) return } // Get the rekey configuration rekeyConf, err := core.RekeyConfig(recovery) if err != nil { - respondError(w, err.Code(), err, lc) + respondError(w, err.Code(), err, core.SetCustomResponseHeaders) return } sealThreshold, err := core.RekeyThreshold(ctx, recovery) if err != nil { - respondError(w, err.Code(), err, lc) + respondError(w, err.Code(), err, core.SetCustomResponseHeaders) return } @@ -93,7 +82,7 @@ func handleSysRekeyInitGet(ctx context.Context, core *vault.Core, recovery bool, // Get the progress started, progress, err := core.RekeyProgress(recovery, false) if err != nil { - respondError(w, err.Code(), err, lc) + respondError(w, err.Code(), err, core.SetCustomResponseHeaders) return } @@ -107,37 +96,31 @@ func handleSysRekeyInitGet(ctx context.Context, core *vault.Core, recovery bool, if rekeyConf.PGPKeys != nil && len(rekeyConf.PGPKeys) != 0 { pgpFingerprints, err := pgpkeys.GetFingerprints(rekeyConf.PGPKeys, nil) if err != nil { - respondError(w, http.StatusInternalServerError, err, lc) + respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) return } status.PGPFingerprints = pgpFingerprints status.Backup = rekeyConf.Backup } } - respondOk(w, status, lc) + respondOk(w, status, core.SetCustomResponseHeaders) } func handleSysRekeyInitPut(ctx context.Context, core *vault.Core, recovery bool, w http.ResponseWriter, r *http.Request) { - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, errNew := core.GetCustomResponseHeaders(la) - if errNew != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } // Parse the request var req RekeyRequest if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil { - respondError(w, http.StatusBadRequest, err, lc) + respondError(w, http.StatusBadRequest, err, core.SetCustomResponseHeaders) return } if req.Backup && len(req.PGPKeys) == 0 { - respondError(w, http.StatusBadRequest, fmt.Errorf("cannot request a backup of the new keys without providing PGP keys for encryption"), lc) + respondError(w, http.StatusBadRequest, fmt.Errorf("cannot request a backup of the new keys without providing PGP keys for encryption"), core.SetCustomResponseHeaders) return } if len(req.PGPKeys) > 0 && len(req.PGPKeys) != req.SecretShares { - respondError(w, http.StatusBadRequest, fmt.Errorf("incorrect number of PGP keys for rekey"), lc) + respondError(w, http.StatusBadRequest, fmt.Errorf("incorrect number of PGP keys for rekey"), core.SetCustomResponseHeaders) return } @@ -151,7 +134,7 @@ func handleSysRekeyInitPut(ctx context.Context, core *vault.Core, recovery bool, VerificationRequired: req.RequireVerification, }, recovery) if err != nil { - respondError(w, err.Code(), err, lc) + respondError(w, err.Code(), err, core.SetCustomResponseHeaders) return } @@ -159,44 +142,32 @@ func handleSysRekeyInitPut(ctx context.Context, core *vault.Core, recovery bool, } func handleSysRekeyInitDelete(ctx context.Context, core *vault.Core, recovery bool, w http.ResponseWriter, r *http.Request) { - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, err := core.GetCustomResponseHeaders(la) - if err != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } if err := core.RekeyCancel(recovery); err != nil { - respondError(w, err.Code(), err, lc) + respondError(w, err.Code(), err, core.SetCustomResponseHeaders) return } - respondOk(w, nil, lc) + respondOk(w, nil, core.SetCustomResponseHeaders) } func handleSysRekeyUpdate(core *vault.Core, recovery bool) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, errNew := core.GetCustomResponseHeaders(la) - if errNew != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } standby, _ := core.Standby() if standby { - respondStandby(core, w, r) + respondStandby(core, w, r.URL) return } // Parse the request var req RekeyUpdateRequest if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil { - respondError(w, http.StatusBadRequest, err, lc) + respondError(w, http.StatusBadRequest, err, core.SetCustomResponseHeaders) return } if req.Key == "" { respondError( w, http.StatusBadRequest, errors.New("'key' must be specified in request body as JSON"), - lc) + core.SetCustomResponseHeaders) return } @@ -212,7 +183,7 @@ func handleSysRekeyUpdate(core *vault.Core, recovery bool) http.Handler { respondError( w, http.StatusBadRequest, errors.New("'key' must be a valid hex or base64 string"), - lc) + core.SetCustomResponseHeaders) return } } @@ -223,7 +194,7 @@ func handleSysRekeyUpdate(core *vault.Core, recovery bool) http.Handler { // Use the key to make progress on rekey result, rekeyErr := core.RekeyUpdate(ctx, key, req.Nonce, recovery) if rekeyErr != nil { - respondError(w, rekeyErr.Code(), rekeyErr, lc) + respondError(w, rekeyErr.Code(), rekeyErr, core.SetCustomResponseHeaders) return } @@ -246,7 +217,7 @@ func handleSysRekeyUpdate(core *vault.Core, recovery bool) http.Handler { } resp.Keys = keys resp.KeysB64 = keysB64 - respondOk(w, resp, lc) + respondOk(w, resp, core.SetCustomResponseHeaders) } else { handleSysRekeyInitGet(ctx, core, recovery, w, r) } @@ -255,15 +226,9 @@ func handleSysRekeyUpdate(core *vault.Core, recovery bool) http.Handler { func handleSysRekeyVerify(core *vault.Core, recovery bool) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, err := core.GetCustomResponseHeaders(la) - if err != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } standby, _ := core.Standby() if standby { - respondStandby(core, w, r) + respondStandby(core, w, r.URL) return } @@ -271,7 +236,7 @@ func handleSysRekeyVerify(core *vault.Core, recovery bool) http.Handler { if repState.HasState(consts.ReplicationPerformanceSecondary) { respondError(w, http.StatusBadRequest, fmt.Errorf("rekeying can only be performed on the primary cluster when replication is activated"), - lc) + core.SetCustomResponseHeaders) return } @@ -280,7 +245,7 @@ func handleSysRekeyVerify(core *vault.Core, recovery bool) http.Handler { switch { case recovery && !core.SealAccess().RecoveryKeySupported(): - respondError(w, http.StatusBadRequest, fmt.Errorf("recovery rekeying not supported"), lc) + respondError(w, http.StatusBadRequest, fmt.Errorf("recovery rekeying not supported"), core.SetCustomResponseHeaders) case r.Method == "GET": handleSysRekeyVerifyGet(ctx, core, recovery, w, r) case r.Method == "POST" || r.Method == "PUT": @@ -288,43 +253,37 @@ func handleSysRekeyVerify(core *vault.Core, recovery bool) http.Handler { case r.Method == "DELETE": handleSysRekeyVerifyDelete(ctx, core, recovery, w, r) default: - respondError(w, http.StatusMethodNotAllowed, nil, lc) + respondError(w, http.StatusMethodNotAllowed, nil, core.SetCustomResponseHeaders) } }) } func handleSysRekeyVerifyGet(ctx context.Context, core *vault.Core, recovery bool, w http.ResponseWriter, r *http.Request) { - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, errNew := core.GetCustomResponseHeaders(la) - if errNew != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } barrierConfig, barrierConfErr := core.SealAccess().BarrierConfig(ctx) if barrierConfErr != nil { - respondError(w, http.StatusInternalServerError, barrierConfErr, lc) + respondError(w, http.StatusInternalServerError, barrierConfErr, core.SetCustomResponseHeaders) return } if barrierConfig == nil { - respondError(w, http.StatusBadRequest, fmt.Errorf("server is not yet initialized"), lc) + respondError(w, http.StatusBadRequest, fmt.Errorf("server is not yet initialized"), core.SetCustomResponseHeaders) return } // Get the rekey configuration rekeyConf, err := core.RekeyConfig(recovery) if err != nil { - respondError(w, err.Code(), err, lc) + respondError(w, err.Code(), err, core.SetCustomResponseHeaders) return } if rekeyConf == nil { - respondError(w, http.StatusBadRequest, errors.New("no rekey configuration found"), lc) + respondError(w, http.StatusBadRequest, errors.New("no rekey configuration found"), core.SetCustomResponseHeaders) return } // Get the progress started, progress, err := core.RekeyProgress(recovery, true) if err != nil { - respondError(w, err.Code(), err, lc) + respondError(w, err.Code(), err, core.SetCustomResponseHeaders) return } @@ -336,18 +295,12 @@ func handleSysRekeyVerifyGet(ctx context.Context, core *vault.Core, recovery boo N: rekeyConf.SecretShares, Progress: progress, } - respondOk(w, status, lc) + respondOk(w, status, core.SetCustomResponseHeaders) } func handleSysRekeyVerifyDelete(ctx context.Context, core *vault.Core, recovery bool, w http.ResponseWriter, r *http.Request) { - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, errNew := core.GetCustomResponseHeaders(la) - if errNew != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } if err := core.RekeyVerifyRestart(recovery); err != nil { - respondError(w, err.Code(), err, lc) + respondError(w, err.Code(), err, core.SetCustomResponseHeaders) return } @@ -355,23 +308,17 @@ func handleSysRekeyVerifyDelete(ctx context.Context, core *vault.Core, recovery } func handleSysRekeyVerifyPut(ctx context.Context, core *vault.Core, recovery bool, w http.ResponseWriter, r *http.Request) { - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, err := core.GetCustomResponseHeaders(la) - if err != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } // Parse the request var req RekeyVerificationUpdateRequest if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil { - respondError(w, http.StatusBadRequest, err, lc) + respondError(w, http.StatusBadRequest, err, core.SetCustomResponseHeaders) return } if req.Key == "" { respondError( w, http.StatusBadRequest, errors.New("'key' must be specified in request body as JSON"), - lc) + core.SetCustomResponseHeaders) return } @@ -387,7 +334,7 @@ func handleSysRekeyVerifyPut(ctx context.Context, core *vault.Core, recovery boo respondError( w, http.StatusBadRequest, errors.New("'key' must be a valid hex or base64 string"), - lc) + core.SetCustomResponseHeaders) return } } @@ -398,7 +345,7 @@ func handleSysRekeyVerifyPut(ctx context.Context, core *vault.Core, recovery boo // Use the key to make progress on rekey result, rekeyErr := core.RekeyVerify(ctx, key, req.Nonce, recovery) if rekeyErr != nil { - respondError(w, rekeyErr.Code(), rekeyErr, lc) + respondError(w, rekeyErr.Code(), rekeyErr, core.SetCustomResponseHeaders) return } @@ -407,7 +354,7 @@ func handleSysRekeyVerifyPut(ctx context.Context, core *vault.Core, recovery boo if result != nil { resp.Complete = true resp.Nonce = result.Nonce - respondOk(w, resp, lc) + respondOk(w, resp, core.SetCustomResponseHeaders) } else { handleSysRekeyVerifyGet(ctx, core, recovery, w, r) } diff --git a/http/sys_seal.go b/http/sys_seal.go index 9108c0e495e7a..effcca7221d74 100644 --- a/http/sys_seal.go +++ b/http/sys_seal.go @@ -15,22 +15,16 @@ import ( func handleSysSeal(core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, err := core.GetCustomResponseHeaders(la) - if err != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } req, _, statusCode, err := buildLogicalRequest(core, w, r) if err != nil || statusCode != 0 { - respondError(w, statusCode, err, lc) + respondError(w, statusCode, err, core.SetCustomResponseHeaders) return } switch req.Operation { case logical.UpdateOperation: default: - respondError(w, http.StatusMethodNotAllowed, nil, lc) + respondError(w, http.StatusMethodNotAllowed, nil, core.SetCustomResponseHeaders) return } @@ -38,78 +32,66 @@ func handleSysSeal(core *vault.Core) http.Handler { // We use context.Background since there won't be a request context if the node isn't active if err := core.SealWithRequest(r.Context(), req); err != nil { if errwrap.Contains(err, logical.ErrPermissionDenied.Error()) { - respondError(w, http.StatusForbidden, err, lc) + respondError(w, http.StatusForbidden, err, core.SetCustomResponseHeaders) return } - respondError(w, http.StatusInternalServerError, err, lc) + respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) return } - respondOk(w, nil, lc) + respondOk(w, nil, core.SetCustomResponseHeaders) }) } func handleSysStepDown(core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, err := core.GetCustomResponseHeaders(la) - if err != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } req, _, statusCode, err := buildLogicalRequest(core, w, r) if err != nil || statusCode != 0 { - respondError(w, statusCode, err, lc) + respondError(w, statusCode, err, core.SetCustomResponseHeaders) return } switch req.Operation { case logical.UpdateOperation: default: - respondError(w, http.StatusMethodNotAllowed, nil, lc) + respondError(w, http.StatusMethodNotAllowed, nil, core.SetCustomResponseHeaders) return } // Seal with the token above if err := core.StepDown(r.Context(), req); err != nil { if errwrap.Contains(err, logical.ErrPermissionDenied.Error()) { - respondError(w, http.StatusForbidden, err, lc) + respondError(w, http.StatusForbidden, err, core.SetCustomResponseHeaders) return } - respondError(w, http.StatusInternalServerError, err, lc) + respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) return } - respondOk(w, nil, lc) + respondOk(w, nil, core.SetCustomResponseHeaders) }) } func handleSysUnseal(core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, err := core.GetCustomResponseHeaders(la) - if err != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } switch r.Method { case "PUT": case "POST": default: - respondError(w, http.StatusMethodNotAllowed, nil, lc) + respondError(w, http.StatusMethodNotAllowed, nil, core.SetCustomResponseHeaders) return } // Parse the request var req UnsealRequest if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil { - respondError(w, http.StatusBadRequest, err, lc) + respondError(w, http.StatusBadRequest, err, core.SetCustomResponseHeaders) return } if req.Reset { if !core.Sealed() { - respondError(w, http.StatusBadRequest, errors.New("vault is unsealed"), lc) + respondError(w, http.StatusBadRequest, errors.New("vault is unsealed"), core.SetCustomResponseHeaders) return } core.ResetUnsealProcess() @@ -121,7 +103,7 @@ func handleSysUnseal(core *vault.Core) http.Handler { respondError( w, http.StatusBadRequest, errors.New("'key' must be specified in request body as JSON, or 'reset' set to true"), - lc) + core.SetCustomResponseHeaders) return } @@ -137,7 +119,7 @@ func handleSysUnseal(core *vault.Core) http.Handler { respondError( w, http.StatusBadRequest, errors.New("'key' must be a valid hex or base64 string"), - lc) + core.SetCustomResponseHeaders) return } } @@ -157,10 +139,10 @@ func handleSysUnseal(core *vault.Core) http.Handler { case errwrap.Contains(err, vault.ErrBarrierSealed.Error()): case errwrap.Contains(err, consts.ErrStandby.Error()): default: - respondError(w, http.StatusInternalServerError, err, lc) + respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) return } - respondError(w, http.StatusBadRequest, err, lc) + respondError(w, http.StatusBadRequest, err, core.SetCustomResponseHeaders) return } @@ -171,14 +153,8 @@ func handleSysUnseal(core *vault.Core) http.Handler { func handleSysSealStatus(core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, err := core.GetCustomResponseHeaders(la) - if err != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } if r.Method != "GET" { - respondError(w, http.StatusMethodNotAllowed, nil, lc) + respondError(w, http.StatusMethodNotAllowed, nil, core.SetCustomResponseHeaders) return } @@ -187,20 +163,14 @@ func handleSysSealStatus(core *vault.Core) http.Handler { } func handleSysSealStatusRaw(core *vault.Core, w http.ResponseWriter, r *http.Request) { - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, err := core.GetCustomResponseHeaders(la) - if err != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } ctx := context.Background() status, err := core.GetSealStatus(ctx) if err != nil { - respondError(w, http.StatusInternalServerError, err, lc) + respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) return } - respondOk(w, status, lc) + respondOk(w, status, core.SetCustomResponseHeaders) } // Note: because we didn't provide explicit tagging in the past we can't do it diff --git a/http/util.go b/http/util.go index da1c921c84b6a..a27405d571ed1 100644 --- a/http/util.go +++ b/http/util.go @@ -33,15 +33,9 @@ var ( func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Getting custom headers from listener's config - la := w.Header().Get("X-Vault-Listener-Add") - lc, err := core.GetCustomResponseHeaders(la) - if err != nil { - core.Logger().Debug("failed to get custom headers from listener config") - } ns, err := namespace.FromContext(r.Context()) if err != nil { - respondError(w, http.StatusInternalServerError, err, lc) + respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) return } @@ -50,7 +44,7 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler // again, which is not desired. path, status, err := buildLogicalPath(r) if err != nil || status != 0 { - respondError(w, status, err, lc) + respondError(w, status, err, core.SetCustomResponseHeaders) return } @@ -63,7 +57,7 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler }) if err != nil { core.Logger().Error("failed to apply quota", "path", path, "error", err) - respondError(w, http.StatusUnprocessableEntity, err, lc) + respondError(w, http.StatusUnprocessableEntity, err, core.SetCustomResponseHeaders) return } @@ -75,7 +69,7 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler if !quotaResp.Allowed { quotaErr := fmt.Errorf("request path %q: %w", path, quotas.ErrRateLimitQuotaExceeded) - respondError(w, http.StatusTooManyRequests, quotaErr, lc) + respondError(w, http.StatusTooManyRequests, quotaErr, core.SetCustomResponseHeaders) if core.Logger().IsTrace() { core.Logger().Trace("request rejected due to rate limit quota violation", "request_path", path) @@ -84,7 +78,7 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler if core.RateLimitAuditLoggingEnabled() { req, _, status, err := buildLogicalRequestNoAuth(core.PerfStandby(), w, r) if err != nil || status != 0 { - respondError(w, status, err, lc) + respondError(w, status, err, core.SetCustomResponseHeaders) return } diff --git a/internalshared/configutil/http_response_headers.go b/internalshared/configutil/http_response_headers.go index 779391ee87edb..5ae858743bea0 100644 --- a/internalshared/configutil/http_response_headers.go +++ b/internalshared/configutil/http_response_headers.go @@ -7,7 +7,7 @@ import ( "strings" ) -var defaultHeaderNames = []string { +var DefaultHeaderNames = []string { "Content-Security-Policy", "X-XSS-Protection", "X-Frame-Options", @@ -16,7 +16,7 @@ var defaultHeaderNames = []string { "Content-Type", } -var validStatusCodeCollection = []string { +var ValidCustomStatusCodeCollection = []string { "default", "1xx", "2xx", @@ -34,7 +34,7 @@ const ( contentType = "text/plain; charset=utf-8" ) -func parseDefaultHeaders(h string) string { +func ParseDefaultHeaders(h string) string { switch h { case "Content-Security-Policy": return contentSecurityPolicy @@ -60,11 +60,11 @@ func setDefaultResponseHeaders(c map[string]string) map[string]string { defaults[k] = v } - for _, hn := range defaultHeaderNames { + for _, hn := range DefaultHeaderNames { if _, ok := c[hn]; ok { continue } - hv := parseDefaultHeaders(hn) + hv := ParseDefaultHeaders(hn) if hv != "" { defaults[hn] = hv } @@ -128,7 +128,7 @@ func isValidList(in interface{}) bool { // checking for status codes outside the boundary func isValidStatusCode(sc string) bool { - for _, v := range validStatusCodeCollection { + for _, v := range ValidCustomStatusCodeCollection { if sc == v { return true } diff --git a/internalshared/configutil/listener.go b/internalshared/configutil/listener.go index 4395bb1520eee..107045e1d91c7 100644 --- a/internalshared/configutil/listener.go +++ b/internalshared/configutil/listener.go @@ -28,17 +28,6 @@ type ListenerProfiling struct { UnauthenticatedPProfAccessRaw interface{} `hcl:"unauthenticated_pprof_access,alias:UnauthenticatedPProfAccessRaw"` } -// TODO: remove this -type CH struct { - X interface{} `hcl:",key,alias:unknown"` - //UnusedKeys UnusedKeyMap `hcl:",unusedKeyPositions"` - Defaults interface{} `hcl:"-"` - DefaultsRaw interface{} `hcl:"default,alias:default"` - R307 map[string]string `hcl:"-"` - R307Raw interface{} `hcl:"307,alias:R307"` - -} - // Listener is the listener configuration for the server. type Listener struct { UnusedKeys UnusedKeyMap `hcl:",unusedKeyPositions"` diff --git a/internalshared/listenerutil/response_headers.go b/internalshared/listenerutil/response_headers.go deleted file mode 100644 index d3cff70894e1e..0000000000000 --- a/internalshared/listenerutil/response_headers.go +++ /dev/null @@ -1,92 +0,0 @@ -package listenerutil - -import ( - "fmt" - "net/http" - "net/textproto" - "strconv" -) - -// DefaultStatus is used to set default headers early before having a status code, -// for example, for /ui headers -const DefaultStatus = 1 - -func SetCustomResponseHeaders(hm map[string]map[string]string, w http.ResponseWriter, status int) error { - // Removing X-Vault-Listener-Add header from ResponseWriter - // This should be safe as the call to this function is right - // before w.WriteHeader for which the status code is finalized and known - w.Header().Del("X-Vault-Listener-Add") - - if hm == nil { - return nil - } - - // setter function to set the headers - setter := func(hv map[string]string) { - for h, v := range hv { - w.Header().Set(h, v) - } - } - - // Checking the validity of the status code - if status >= 600 || (status < 100 && status != DefaultStatus) { - return fmt.Errorf("invalid status code") - } - - // Setting the default headers first - setter(hm["default"]) - - // for DefaultStatus, we only set the default headers - if status == DefaultStatus { - return nil - } - - // setting the Xyy pattern first - d := fmt.Sprintf("%vxx", status / 100) - if val, ok := hm[d]; ok { - setter(val) - } - // Setting the specific headers - if val, ok := hm[strconv.Itoa(status)]; ok { - setter(val) - } - - return nil -} - -func FetchCustomResponseHeaderValue(hm map[string]map[string]string, th string, sc int) (string, error) { - if hm == nil { - return "", nil - } - if th == "" { - return "", fmt.Errorf("invalid target header") - } - - var h map[string]string - if sc == DefaultStatus { - h = hm["default"] - }else { - h = hm[strconv.Itoa(sc)] - } - - hn := textproto.CanonicalMIMEHeaderKey(th) - if v, ok := h[hn]; ok { - return v, nil - } - return "", nil -} - -func ExistHeader(hm map[string]map[string]string, th string, sl []int) bool { - if len(sl) == 0 { - return false - } - - for _, s := range sl { - chv, _ := FetchCustomResponseHeaderValue(hm, th, s) - if chv != "" { - return true - } - } - - return false -} \ No newline at end of file diff --git a/vault/core.go b/vault/core.go index 391b646e25a00..957e3bddb57e2 100644 --- a/vault/core.go +++ b/vault/core.go @@ -42,7 +42,6 @@ import ( "github.com/hashicorp/vault/command/server" "github.com/hashicorp/vault/helper/metricsutil" "github.com/hashicorp/vault/helper/namespace" - "github.com/hashicorp/vault/internalshared/configutil" "github.com/hashicorp/vault/physical/raft" "github.com/hashicorp/vault/sdk/helper/certutil" "github.com/hashicorp/vault/sdk/helper/consts" @@ -518,6 +517,8 @@ type Core struct { // clusterListener starts up and manages connections on the cluster ports clusterListener *atomic.Value + customListenerHeader *ListenersCustomHeaderList + // Telemetry objects metricsHelper *metricsutil.MetricsHelper @@ -1002,6 +1003,11 @@ func NewCore(conf *CoreConfig) (*Core, error) { c.clusterListener.Store((*cluster.Listener)(nil)) + uiHeaders, _ := c.UIHeaders() + customHeaderLogger := conf.Logger.Named("customHeader") + c.allLoggers = append(c.allLoggers, customHeaderLogger) + c.customListenerHeader = NewListenerCustomHeader(conf.RawConfig.Listeners, customHeaderLogger, uiHeaders) + quotasLogger := conf.Logger.Named("quotas") c.allLoggers = append(c.allLoggers, quotasLogger) c.quotaManager, err = quotas.NewManager(quotasLogger, c.quotaLeaseWalker, c.metricSink) @@ -2621,6 +2627,18 @@ func (c *Core) SetLogLevel(level log.Level) { } } +func (c *Core) GetLogger(name string) log.Logger { + c.allLoggersLock.Lock() + defer c.allLoggersLock.Unlock() + for _, logger := range c.allLoggers { + ln := logger.Name() + if ln == name { + return logger + } + } + return nil +} + // SetConfig sets core's config object to the newly provided config. func (c *Core) SetConfig(conf *server.Config) { c.rawConfig.Store(conf) @@ -2633,64 +2651,51 @@ func (c *Core) SetConfig(conf *server.Config) { c.logger.Debug("set config", "sanitized config", string(bz)) } -func (c *Core) GetCustomResponseHeaders(la string) (map[string]map[string]string, error) { - if la == "" { - return nil, nil +func (c *Core) SetCustomResponseHeaders(w http.ResponseWriter, status int) { + if c.customListenerHeader == nil { + c.logger.Debug("custom response headers not configured") + return } - ln, err := c.GetListenersConf(la) - if err != nil || ln == nil { - return nil, err + + c.customListenerHeader.SetCustomResponseHeaders(w, status) +} + +func (c *Core) ExistCustomResponseHeader(header string, statusCodeList []int, la string) bool { + if c.customListenerHeader == nil { + c.logger.Debug("custom response headers not configured") + return false } - // TODO: maybe copy the ln.CustomResponseHeaders and return the copy? - return ln.CustomResponseHeaders, nil + + return c.customListenerHeader.ExistHeader(header, statusCodeList, la) } -func (c *Core) GetListenersConf(address string) (*configutil.Listener, error) { +func (c *Core) ReloadCustomListenerHeader() error { + conf := c.rawConfig.Load() if conf == nil { - return nil, fmt.Errorf("failed to load config") - } - lns := conf.(*server.Config).Listeners - for _, ln := range lns{ - if ln.Address == address { - return ln, nil - } + return fmt.Errorf("failed to load core raw config") } - return nil, fmt.Errorf("failed to find listener config with address %v", address) -} -// SanitizedCustomResponseHeader sanitizes listener config from invalid custom headers -func (c *Core) SanitizedCustomResponseHeader(conf *server.Config) { - hm := make(map[string]map[string]string) - userHeaders, err := c.UIHeaders() - if err != nil { - c.Logger().Trace("failed to get ui headers", "error:", err.Error()) + tempLH := c.customListenerHeader + c.customListenerHeader = nil + + uiHeaders, _ := c.UIHeaders() + + customHeaderLogger := c.GetLogger("customHeader") + if customHeaderLogger == nil { + customHeaderLogger = c.Logger().Named("customHeader") + c.AddLogger(customHeaderLogger) } - for _, ln := range conf.Listeners { - for sc, ch := range ln.CustomResponseHeaders { - hv := make(map[string]string) - for h, v := range ch { - // X-Vault- prefix is reserved for Vault internal processes - if strings.HasPrefix(h, "X-Vault-") { - c.Logger().Error("Custom headers starting with X-Vault are not valid", "header", h) - continue - } + lns := conf.(*server.Config).Listeners + c.customListenerHeader = NewListenerCustomHeader(lns, customHeaderLogger, uiHeaders) - // Checking for UI headers, if any common header exist, HCL headers take precedence - if userHeaders != nil { - exist := userHeaders.Get(h) - if exist != "" { - c.Logger().Error("found a duplicate header in UI, note that config file headers take precedence.", "header:", h) - } - } - hv[h] = v - } - hm[sc] = hv - } - ln.CustomResponseHeaders = hm + if c.customListenerHeader == nil { + c.logger.Error("failed to reload custom headers, reverting back the old configuration") + c.customListenerHeader = tempLH } + return nil } // SanitizedConfig returns a sanitized version of the current config. diff --git a/vault/external_tests/raft/raft_test.go b/vault/external_tests/raft/raft_test.go index 040d57e5542a5..f98b575589fd8 100644 --- a/vault/external_tests/raft/raft_test.go +++ b/vault/external_tests/raft/raft_test.go @@ -1205,83 +1205,3 @@ func TestRaft_Join_InitStatus(t *testing.T) { verifyInitStatus(i, true) } } - - -func TestRaft_SnapshotRestoreOnStandby(t *testing.T) { - t.Parallel() - cluster := raftCluster(t, nil) - defer cluster.Cleanup() - - leaderClient := cluster.Cores[0].Client - - // Write a few keys - for i := 0; i < 10; i++ { - _, err := leaderClient.Logical().Write(fmt.Sprintf("secret/%d", i), map[string]interface{}{ - "test": "data", - }) - if err != nil { - t.Fatal(err) - } - } - - transport := cleanhttp.DefaultPooledTransport() - transport.TLSClientConfig = cluster.Cores[0].TLSConfig.Clone() - if err := http2.ConfigureTransport(transport); err != nil { - t.Fatal(err) - } - client := &http.Client{ - Transport: transport, - } - - // Take a snapshot - req := leaderClient.NewRequest("GET", "/v1/sys/storage/raft/snapshot") - httpReq, err := req.ToHTTP() - if err != nil { - t.Fatal(err) - } - resp, err := client.Do(httpReq) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - snap, err := ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - if len(snap) == 0 { - t.Fatal("no snapshot returned") - } - - // Write a few more keys - for i := 10; i < 20; i++ { - _, err := leaderClient.Logical().Write(fmt.Sprintf("secret/%d", i), map[string]interface{}{ - "test": "data", - }) - if err != nil { - t.Fatal(err) - } - } - - // Restore snapshot - req = leaderClient.NewRequest("POST", "/v1/sys/storage/raft/snapshot") - req.Body = bytes.NewBuffer(snap) - httpReq, err = req.ToHTTP() - if err != nil { - t.Fatal(err) - } - resp, err = client.Do(httpReq) - if err != nil { - t.Fatal(err) - } - - // List kv to make sure we removed the extra keys - secret, err := leaderClient.Logical().List("secret/") - if err != nil { - t.Fatal(err) - } - - if len(secret.Data["keys"].([]interface{})) != 10 { - t.Fatal("snapshot didn't apply correctly") - } -} diff --git a/vault/logical_system.go b/vault/logical_system.go index 6cd479448c0d9..7e4b21347c70c 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -9,7 +9,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/hashicorp/vault/internalshared/listenerutil" "hash" "net/http" "path" @@ -2621,27 +2620,22 @@ func (b *SystemBackend) handleConfigUIHeadersUpdate(ctx context.Context, req *lo return logical.ErrorResponse("X-Vault headers cannot be set"), logical.ErrInvalidRequest } - // Getting custom headers from listener's config if req.ResponseWriter == nil { return logical.ErrorResponse("no ResponseWriter in the request"), logical.ErrInvalidRequest } la := req.ResponseWriter.Header().Get("X-Vault-Listener-Add") - lc, err := b.Core.GetCustomResponseHeaders(la) - if err != nil { - b.Core.Logger().Debug("failed to get custom headers from listener config") - } // Translate the list of values to the valid header string value := http.Header{} for _, v := range values { // check if the header exist in "default" and 200 status code maps of custom response headers - sl := []int{listenerutil.DefaultStatus, 200} - if listenerutil.ExistHeader(lc, header, sl) { + sl := []int{DefaultCustomResponseStatus, 200} + if b.Core.ExistCustomResponseHeader(header, sl, la) { return logical.ErrorResponse("header already exist in server configuration file"), logical.ErrInvalidRequest } value.Add(header, v) } - err = b.Core.uiConfig.SetHeader(ctx, header, value.Values(header)) + err := b.Core.uiConfig.SetHeader(ctx, header, value.Values(header)) if err != nil { return nil, err } From 5ec251078fe7efc3686fef2be487e33519ed643d Mon Sep 17 00:00:00 2001 From: hamid ghaf Date: Sat, 4 Sep 2021 17:33:54 -0700 Subject: [PATCH 06/25] missing additional file --- vault/custom_response_headers.go | 238 +++++++++++++++++++++++++++++++ 1 file changed, 238 insertions(+) create mode 100644 vault/custom_response_headers.go diff --git a/vault/custom_response_headers.go b/vault/custom_response_headers.go new file mode 100644 index 0000000000000..d9c974e439dd8 --- /dev/null +++ b/vault/custom_response_headers.go @@ -0,0 +1,238 @@ +package vault + +import ( + "fmt" + log "github.com/hashicorp/go-hclog" + "net/http" + "net/textproto" + "strconv" + "strings" + + "github.com/hashicorp/vault/internalshared/configutil" +) + +// DefaultCustomResponseStatus is used to set default headers early before having a status code, +// for example, for /ui headers +const DefaultCustomResponseStatus = 1 + +type ListenersCustomHeaderList struct { + logger log.Logger + CustomHeadersList []*ListenerCustomHeaders +} + +type ListenerCustomHeaders struct { + Address string + StatusCodeHeaderMap map[string][]*CustomHeader +} + +type CustomHeader struct { + Name string + Value string +} + +func NewListenerCustomHeader(ln []*configutil.Listener, logger log.Logger, uiHeaders http.Header) *ListenersCustomHeaderList { + + if ln == nil { + return nil + } + + ll := &ListenersCustomHeaderList{ + logger: logger, + } + + for _, l := range ln { + lc := &ListenerCustomHeaders{ + Address: l.Address, + } + lc.StatusCodeHeaderMap = make(map[string][]*CustomHeader) + for sc, hv := range l.CustomResponseHeaders { + var chl []*CustomHeader + for h, v := range hv { + + // X-Vault- prefix is reserved for Vault internal processes + if strings.HasPrefix(h, "X-Vault-") { + logger.Error("Custom headers starting with X-Vault are not valid", "header", h) + continue + } + + // Checking for UI headers, if any common header exist, HCL headers take precedence + if uiHeaders != nil { + exist := uiHeaders.Get(h) + if exist != "" { + logger.Error("found a duplicate header in UI, note that config file headers take precedence.", "header:", h) + } + } + + ch := &CustomHeader{ + Name: h, + Value: v, + } + + chl = append(chl, ch) + } + lc.StatusCodeHeaderMap[sc] = chl + } + ll.CustomHeadersList = append(ll.CustomHeadersList, lc) + } + + return ll +} + +func (c *ListenersCustomHeaderList) SetCustomResponseHeaders(w http.ResponseWriter, status int) { + if w == nil { + c.logger.Error("No ResponseWriter provided") + } + + // Getting the listener address to set its corresponding custom headers + la := w.Header().Get("X-Vault-Listener-Add") + if la == "" { + c.logger.Error("X-Vault-Listener-Add was not set in the ResponseWriter") + return + } + + // Removing X-Vault-Listener-Add header from ResponseWriter + // This should be safe as the call to this function is right + // before w.WriteHeader for which the status code is finalized and known + w.Header().Del("X-Vault-Listener-Add") + + lch := c.getListenerMap(la) + if lch == nil { + c.logger.Warn("no listener config found") + return + } + + // setter function to set the headers + setter := func(hvl []*CustomHeader) { + for _, hv := range hvl { + w.Header().Set(hv.Name, hv.Value) + } + } + + // Checking the validity of the status code + if status >= 600 || (status < 100 && status != DefaultCustomResponseStatus) { + c.logger.Error("invalid status code") + return + } + + // Setting the default headers first + setter(lch["default"]) + + // for DefaultCustomResponseStatus, we only set the default headers + if status == DefaultCustomResponseStatus { + return + } + + // setting the Xyy pattern first + d := fmt.Sprintf("%vxx", status / 100) + if val, ok := lch[d]; ok { + setter(val) + } + // Setting the specific headers + if val, ok := lch[strconv.Itoa(status)]; ok { + setter(val) + } + + return +} + +func (c *ListenersCustomHeaderList) getListenerMap(address string) map[string][]*CustomHeader { + if c.CustomHeadersList == nil { + return nil + } + for _, l := range c.CustomHeadersList { + if l.Address == address { + return l.StatusCodeHeaderMap + } + } + return nil +} + +func (c *ListenersCustomHeaderList) findCustomHeaderMatchStatusCode(hm map[string][]*CustomHeader, sc int) ([]*CustomHeader, error) { + + if sc == DefaultCustomResponseStatus { + return hm["default"], nil + } + + if h, ok := hm[strconv.Itoa(sc)]; ok { + return h, nil + } + + d := fmt.Sprintf("%vxx", sc / 100) + for _, s := range configutil.ValidCustomStatusCodeCollection { + if s == d { + if h, ok := hm[s]; ok { + return h, nil + } + } + } + + return nil, fmt.Errorf("failed to find a match for the given status code:%v", sc) +} + +func (c *ListenersCustomHeaderList) FetchCustomResponseHeaderValue(header string, sc int, la string) ([]string, error) { + + if header == "" { + return nil, fmt.Errorf("invalid target header") + } + + getHeader := func(hm map[string][]*CustomHeader) (string, error){ + ch, err := c.findCustomHeaderMatchStatusCode(hm, sc) + if err != nil { + return "", err + } + + if ch == nil { + return "", nil + } + + hn := textproto.CanonicalMIMEHeaderKey(header) + for _, h := range ch { + if h.Name == hn { + return h.Value, nil + } + } + + return "", nil + } + + var lch []*ListenerCustomHeaders + if la == "" { + lch = c.CustomHeadersList + } else { + for _, l := range c.CustomHeadersList { + if l.Address == la { + lch = append(lch, l) + } + } + if len(lch) == 0 { + return nil, fmt.Errorf("no listener found with address:%v", la) + } + } + + var headers []string + var err error + for _, l := range lch { + h, err := getHeader(l.StatusCodeHeaderMap) + if err != nil || h == "" { + continue + } + headers = append(headers, h) + } + + return headers, err +} + +func(c *ListenersCustomHeaderList) ExistHeader(th string, sl []int, la string) bool { + if len(sl) == 0 { + return false + } + + for _, s := range sl { + chv, _ := c.FetchCustomResponseHeaderValue(th, s, la) + if chv != nil { + return true + } + } + + return false +} From 6982b8a04996d7e8b6f0cbb14afe027f172ca52a Mon Sep 17 00:00:00 2001 From: hamid ghaf Date: Tue, 7 Sep 2021 14:12:40 -0700 Subject: [PATCH 07/25] Some refactoring --- command/server.go | 5 +- http/handler.go | 5 +- .../configutil/http_response_headers.go | 20 +++- internalshared/configutil/listener.go | 2 +- vault/core.go | 8 +- vault/custom_response_headers.go | 108 ++++++++---------- vault/logical_system.go | 6 +- 7 files changed, 77 insertions(+), 77 deletions(-) diff --git a/command/server.go b/command/server.go index c063778450cd9..5e6fd8bdca807 100644 --- a/command/server.go +++ b/command/server.go @@ -1540,7 +1540,10 @@ func (c *ServerCommand) Run(args []string) int { } core.SetConfig(config) - if err = core.ReloadCustomListenerHeader(); err != nil { + + // reloading custom response headers to make sure we have + // the most up to date headers after reloading the config file + if err = core.ReloadCustomResponseHeaders(); err != nil { c.UI.Error(err.Error()) } diff --git a/http/handler.go b/http/handler.go index 45ab0ed425f69..dc9ff3841d081 100644 --- a/http/handler.go +++ b/http/handler.go @@ -335,13 +335,12 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr w.Header().Set("X-Vault-Hostname", hostname) } - // Setting listener address so that we could get the config from core + // Setting the listener address as a header so that we could set customized + // headers configured in the corresponding listener stanza var la string if props.ListenerConfig != nil { la = props.ListenerConfig.Address } - // Setting a header so that we could set customized headers - // configured in the corresponding listener stanza w.Header().Set("X-Vault-Listener-Add", la) switch { diff --git a/internalshared/configutil/http_response_headers.go b/internalshared/configutil/http_response_headers.go index 5ae858743bea0..430ae277e6925 100644 --- a/internalshared/configutil/http_response_headers.go +++ b/internalshared/configutil/http_response_headers.go @@ -60,6 +60,8 @@ func setDefaultResponseHeaders(c map[string]string) map[string]string { defaults[k] = v } + // setting all default headers that are not included in the config + // file under the "default" category for _, hn := range DefaultHeaderNames { if _, ok := c[hn]; ok { continue @@ -69,7 +71,7 @@ func setDefaultResponseHeaders(c map[string]string) map[string]string { defaults[hn] = hv } } - fmt.Printf("Default headers are %v", defaults) + return defaults } @@ -87,7 +89,7 @@ func ParseCustomResponseHeaders(r interface{}) (map[string]map[string]string, er return nil, fmt.Errorf("invalid response header type") } - if !isValidStatusCode(sc) { + if !IsValidStatusCode(sc) { return nil, fmt.Errorf("invalid status code found in the config file: %v", sc) } @@ -126,14 +128,22 @@ func isValidList(in interface{}) bool { return false } -// checking for status codes outside the boundary -func isValidStatusCode(sc string) bool { +func IsValidStatusCodeCollection(sc string) bool { for _, v := range ValidCustomStatusCodeCollection { if sc == v { return true } } + return false +} + +// IsValidStatusCode checking for status codes outside the boundary +func IsValidStatusCode(sc string) bool { + if IsValidStatusCodeCollection(sc) { + return true + } + i, err := strconv.Atoi(sc) if err != nil { return false @@ -164,7 +174,7 @@ func parseHeaders(in map[string]interface{}) (map[string]string, error) { func parseHeaderValues(h interface{}) (string, error) { var sl []string if !isValidList(h) { - return "", fmt.Errorf("failed to parse custom_response_headers3") + return "", fmt.Errorf("failed to parse header values") } vli := h.([]interface{}) for _, vh := range vli { diff --git a/internalshared/configutil/listener.go b/internalshared/configutil/listener.go index 107045e1d91c7..77cbbf93b10d1 100644 --- a/internalshared/configutil/listener.go +++ b/internalshared/configutil/listener.go @@ -102,7 +102,7 @@ type Listener struct { // Custom Http response headers CustomResponseHeaders map[string]map[string]string `hcl:"-"` - CustomResponseHeadersRaw interface{} `hcl:"custom_response_headers,alias:custom_response_headers"` + CustomResponseHeadersRaw interface{} `hcl:"custom_response_headers,alias:custom_response_headers"` } diff --git a/vault/core.go b/vault/core.go index 957e3bddb57e2..110488db7047b 100644 --- a/vault/core.go +++ b/vault/core.go @@ -517,7 +517,7 @@ type Core struct { // clusterListener starts up and manages connections on the cluster ports clusterListener *atomic.Value - customListenerHeader *ListenersCustomHeaderList + customListenerHeader *ListenersCustomResponseHeadersList // Telemetry objects metricsHelper *metricsutil.MetricsHelper @@ -2660,16 +2660,16 @@ func (c *Core) SetCustomResponseHeaders(w http.ResponseWriter, status int) { c.customListenerHeader.SetCustomResponseHeaders(w, status) } -func (c *Core) ExistCustomResponseHeader(header string, statusCodeList []int, la string) bool { +func (c *Core) ExistCustomResponseHeader(header string, statusCode int, la string) bool { if c.customListenerHeader == nil { c.logger.Debug("custom response headers not configured") return false } - return c.customListenerHeader.ExistHeader(header, statusCodeList, la) + return c.customListenerHeader.ExistHeader(header, statusCode, la) } -func (c *Core) ReloadCustomListenerHeader() error { +func (c *Core) ReloadCustomResponseHeaders() error { conf := c.rawConfig.Load() if conf == nil { diff --git a/vault/custom_response_headers.go b/vault/custom_response_headers.go index d9c974e439dd8..2601239e9afd0 100644 --- a/vault/custom_response_headers.go +++ b/vault/custom_response_headers.go @@ -11,11 +11,7 @@ import ( "github.com/hashicorp/vault/internalshared/configutil" ) -// DefaultCustomResponseStatus is used to set default headers early before having a status code, -// for example, for /ui headers -const DefaultCustomResponseStatus = 1 - -type ListenersCustomHeaderList struct { +type ListenersCustomResponseHeadersList struct { logger log.Logger CustomHeadersList []*ListenerCustomHeaders } @@ -30,13 +26,13 @@ type CustomHeader struct { Value string } -func NewListenerCustomHeader(ln []*configutil.Listener, logger log.Logger, uiHeaders http.Header) *ListenersCustomHeaderList { +func NewListenerCustomHeader(ln []*configutil.Listener, logger log.Logger, uiHeaders http.Header) *ListenersCustomResponseHeadersList { if ln == nil { return nil } - ll := &ListenersCustomHeaderList{ + ll := &ListenersCustomResponseHeadersList{ logger: logger, } @@ -48,14 +44,14 @@ func NewListenerCustomHeader(ln []*configutil.Listener, logger log.Logger, uiHea for sc, hv := range l.CustomResponseHeaders { var chl []*CustomHeader for h, v := range hv { - + // Sanitizing custom headers // X-Vault- prefix is reserved for Vault internal processes if strings.HasPrefix(h, "X-Vault-") { logger.Error("Custom headers starting with X-Vault are not valid", "header", h) continue } - // Checking for UI headers, if any common header exist, HCL headers take precedence + // Checking for UI headers, if any common header exists, we just log an error if uiHeaders != nil { exist := uiHeaders.Get(h) if exist != "" { @@ -78,7 +74,7 @@ func NewListenerCustomHeader(ln []*configutil.Listener, logger log.Logger, uiHea return ll } -func (c *ListenersCustomHeaderList) SetCustomResponseHeaders(w http.ResponseWriter, status int) { +func (c *ListenersCustomResponseHeadersList) SetCustomResponseHeaders(w http.ResponseWriter, status int) { if w == nil { c.logger.Error("No ResponseWriter provided") } @@ -97,7 +93,7 @@ func (c *ListenersCustomHeaderList) SetCustomResponseHeaders(w http.ResponseWrit lch := c.getListenerMap(la) if lch == nil { - c.logger.Warn("no listener config found") + c.logger.Warn("no listener config found", "address", la) return } @@ -109,7 +105,7 @@ func (c *ListenersCustomHeaderList) SetCustomResponseHeaders(w http.ResponseWrit } // Checking the validity of the status code - if status >= 600 || (status < 100 && status != DefaultCustomResponseStatus) { + if status >= 600 || status < 100 { c.logger.Error("invalid status code") return } @@ -117,11 +113,6 @@ func (c *ListenersCustomHeaderList) SetCustomResponseHeaders(w http.ResponseWrit // Setting the default headers first setter(lch["default"]) - // for DefaultCustomResponseStatus, we only set the default headers - if status == DefaultCustomResponseStatus { - return - } - // setting the Xyy pattern first d := fmt.Sprintf("%vxx", status / 100) if val, ok := lch[d]; ok { @@ -135,7 +126,7 @@ func (c *ListenersCustomHeaderList) SetCustomResponseHeaders(w http.ResponseWrit return } -func (c *ListenersCustomHeaderList) getListenerMap(address string) map[string][]*CustomHeader { +func (c *ListenersCustomResponseHeadersList) getListenerMap(address string) map[string][]*CustomHeader { if c.CustomHeadersList == nil { return nil } @@ -147,54 +138,53 @@ func (c *ListenersCustomHeaderList) getListenerMap(address string) map[string][] return nil } -func (c *ListenersCustomHeaderList) findCustomHeaderMatchStatusCode(hm map[string][]*CustomHeader, sc int) ([]*CustomHeader, error) { +func (c *ListenersCustomResponseHeadersList) findCustomHeaderMatchStatusCode(hm map[string][]*CustomHeader, sc int, hn string) string { - if sc == DefaultCustomResponseStatus { - return hm["default"], nil + getHeader := func(ch []*CustomHeader) string { + for _, h := range ch { + if h.Name == hn { + return h.Value + } + } + return "" } - if h, ok := hm[strconv.Itoa(sc)]; ok { - return h, nil + // starting with the most specific status code + if ch, ok := hm[strconv.Itoa(sc)]; ok { + h := getHeader(ch) + if h != "" { + return h + } } - d := fmt.Sprintf("%vxx", sc / 100) - for _, s := range configutil.ValidCustomStatusCodeCollection { - if s == d { - if h, ok := hm[s]; ok { - return h, nil + s := fmt.Sprintf("%vxx", sc/100) + if configutil.IsValidStatusCodeCollection(s) { + if ch, ok := hm[s]; ok { + h := getHeader(ch) + if h != "" { + return h } } } - return nil, fmt.Errorf("failed to find a match for the given status code:%v", sc) + // At this point, we could not find a match for the given status code in the config file + // so, we just return the "default" ones + h := getHeader(hm["default"]) + if h != ""{ + return h + } + + return "" } -func (c *ListenersCustomHeaderList) FetchCustomResponseHeaderValue(header string, sc int, la string) ([]string, error) { +func (c *ListenersCustomResponseHeadersList) FetchCustomResponseHeaderValue(header string, sc int, la string) ([]string, error) { if header == "" { return nil, fmt.Errorf("invalid target header") } - getHeader := func(hm map[string][]*CustomHeader) (string, error){ - ch, err := c.findCustomHeaderMatchStatusCode(hm, sc) - if err != nil { - return "", err - } - - if ch == nil { - return "", nil - } - - hn := textproto.CanonicalMIMEHeaderKey(header) - for _, h := range ch { - if h.Name == hn { - return h.Value, nil - } - } - - return "", nil - } - + // either looking for a specific listener, or if listener address isn't given, + // checking for all available listeners var lch []*ListenerCustomHeaders if la == "" { lch = c.CustomHeadersList @@ -211,9 +201,10 @@ func (c *ListenersCustomHeaderList) FetchCustomResponseHeaderValue(header string var headers []string var err error + hn := textproto.CanonicalMIMEHeaderKey(header) for _, l := range lch { - h, err := getHeader(l.StatusCodeHeaderMap) - if err != nil || h == "" { + h := c.findCustomHeaderMatchStatusCode(l.StatusCodeHeaderMap, sc, hn) + if h == "" { continue } headers = append(headers, h) @@ -222,16 +213,15 @@ func (c *ListenersCustomHeaderList) FetchCustomResponseHeaderValue(header string return headers, err } -func(c *ListenersCustomHeaderList) ExistHeader(th string, sl []int, la string) bool { - if len(sl) == 0 { +func(c *ListenersCustomResponseHeadersList) ExistHeader(th string, sc int, la string) bool { + if !configutil.IsValidStatusCode(strconv.Itoa(sc)) { + c.logger.Error("failed to check if a header exist in config file due to invalid status code") return false } - for _, s := range sl { - chv, _ := c.FetchCustomResponseHeaderValue(th, s, la) - if chv != nil { - return true - } + chv, _ := c.FetchCustomResponseHeaderValue(th, sc, la) + if chv != nil { + return true } return false diff --git a/vault/logical_system.go b/vault/logical_system.go index 7e4b21347c70c..a59cb1b5e1dd9 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -2623,14 +2623,12 @@ func (b *SystemBackend) handleConfigUIHeadersUpdate(ctx context.Context, req *lo if req.ResponseWriter == nil { return logical.ErrorResponse("no ResponseWriter in the request"), logical.ErrInvalidRequest } - la := req.ResponseWriter.Header().Get("X-Vault-Listener-Add") // Translate the list of values to the valid header string value := http.Header{} for _, v := range values { - // check if the header exist in "default" and 200 status code maps of custom response headers - sl := []int{DefaultCustomResponseStatus, 200} - if b.Core.ExistCustomResponseHeader(header, sl, la) { + la := req.ResponseWriter.Header().Get("X-Vault-Listener-Add") + if b.Core.ExistCustomResponseHeader(header, 200, la) { return logical.ErrorResponse("header already exist in server configuration file"), logical.ErrInvalidRequest } value.Add(header, v) From f804b261225105c112c0de09bf9eb06718a8fb24 Mon Sep 17 00:00:00 2001 From: hamid ghaf Date: Wed, 8 Sep 2021 16:18:06 -0700 Subject: [PATCH 08/25] Adding automated tests for the feature --- .../config_custom_response_headers_test.go | 60 +++++ .../config_custom_response_headers_1.hcl | 31 +++ vault/core.go | 4 +- vault/custom_response_headers.go | 118 +++++++--- vault/custom_response_headers_test.go | 207 ++++++++++++++++++ vault/logical_system.go | 2 +- 6 files changed, 385 insertions(+), 37 deletions(-) create mode 100644 command/server/config_custom_response_headers_test.go create mode 100644 command/server/test-fixtures/config_custom_response_headers_1.hcl create mode 100644 vault/custom_response_headers_test.go diff --git a/command/server/config_custom_response_headers_test.go b/command/server/config_custom_response_headers_test.go new file mode 100644 index 0000000000000..1b6c395cddd24 --- /dev/null +++ b/command/server/config_custom_response_headers_test.go @@ -0,0 +1,60 @@ +package server + +import ( + "github.com/go-test/deep" + "testing" +) + +var defaultCustomHeaders = map[string]string { + "Strict-Transport-Security": "max-age=1; domains", + "Content-Security-Policy": "default-src 'others'", + "X-Vault-Ignored": "ignored", + "X-Custom-Header": "Custom header value default", + "X-Frame-Options": "Deny", + "X-Content-Type-Options": "nosniff", + "Content-Type": "text/plain; charset=utf-8", + "X-XSS-Protection": "1; mode=block", +} + +var customHeaders307 = map[string]string { + "X-Custom-Header": "Custom header value 307", +} + +var customHeader3xx = map[string]string { + "X-Vault-Ignored-3xx": "Ignored 3xx", + "X-Custom-Header": "Custom header value 3xx", +} + +var customHeaders200 = map[string]string { + "Someheader-200": "200", + "X-Custom-Header": "Custom header value 200", +} + +var customHeader2xx = map[string]string { + "X-Custom-Header": "Custom header value 2xx", +} + +var customHeader400 = map[string]string { + "Someheader-400": "400", +} + +func TestCustomResponseHeadersConfigs(t *testing.T) { + expectedCustomResponseHeader := map[string]map[string]string { + "default": defaultCustomHeaders, + "307": customHeaders307, + "3xx": customHeader3xx, + "200": customHeaders200, + "2xx": customHeader2xx, + "400": customHeader400, + } + + config, err := LoadConfigFile("./test-fixtures/config_custom_response_headers_1.hcl") + + if err != nil { + t.Fatalf("Error encountered when loading config %+v", err) + } + if diff := deep.Equal(expectedCustomResponseHeader, config.Listeners[0].CustomResponseHeaders); diff != nil { + t.Fatalf("parsed custom headers do not match the expected ones") + } +} + diff --git a/command/server/test-fixtures/config_custom_response_headers_1.hcl b/command/server/test-fixtures/config_custom_response_headers_1.hcl new file mode 100644 index 0000000000000..c2f868c2f146c --- /dev/null +++ b/command/server/test-fixtures/config_custom_response_headers_1.hcl @@ -0,0 +1,31 @@ +storage "inmem" {} +listener "tcp" { + address = "127.0.0.1:8200" + tls_disable = true + custom_response_headers { + "default" = { + "Strict-Transport-Security" = ["max-age=1","domains"], + "Content-Security-Policy" = ["default-src 'others'"], + "X-Vault-Ignored" = ["ignored"], + "X-Custom-Header" = ["Custom header value default"], + } + "307" = { + "X-Custom-Header" = ["Custom header value 307"], + } + "3xx" = { + "X-Vault-Ignored-3xx" = ["Ignored 3xx"], + "X-Custom-Header" = ["Custom header value 3xx"] + } + "200" = { + "someheader-200" = ["200"], + "X-Custom-Header" = ["Custom header value 200"] + } + "2xx" = { + "X-Custom-Header" = ["Custom header value 2xx"] + } + "400" = { + "someheader-400" = ["400"] + } + } +} +disable_mlock = true diff --git a/vault/core.go b/vault/core.go index 110488db7047b..d1aa41e4e9fdd 100644 --- a/vault/core.go +++ b/vault/core.go @@ -2660,13 +2660,13 @@ func (c *Core) SetCustomResponseHeaders(w http.ResponseWriter, status int) { c.customListenerHeader.SetCustomResponseHeaders(w, status) } -func (c *Core) ExistCustomResponseHeader(header string, statusCode int, la string) bool { +func (c *Core) ExistCustomResponseHeader(header string, la string) bool { if c.customListenerHeader == nil { c.logger.Debug("custom response headers not configured") return false } - return c.customListenerHeader.ExistHeader(header, statusCode, la) + return c.customListenerHeader.ExistCustomResponseHeader(header, la) } func (c *Core) ReloadCustomResponseHeaders() error { diff --git a/vault/custom_response_headers.go b/vault/custom_response_headers.go index 2601239e9afd0..8a62b6200730f 100644 --- a/vault/custom_response_headers.go +++ b/vault/custom_response_headers.go @@ -19,6 +19,9 @@ type ListenersCustomResponseHeadersList struct { type ListenerCustomHeaders struct { Address string StatusCodeHeaderMap map[string][]*CustomHeader + // ConfiguredHeadersStatusCodeMap field is introduced so that we would not need to loop through + // StatusCodeHeaderMap to see if a header exists, the key for this map is the headers names + ConfiguredHeadersStatusCodeMap map[string][]string } type CustomHeader struct { @@ -41,6 +44,7 @@ func NewListenerCustomHeader(ln []*configutil.Listener, logger log.Logger, uiHea Address: l.Address, } lc.StatusCodeHeaderMap = make(map[string][]*CustomHeader) + lc.ConfiguredHeadersStatusCodeMap = make(map[string][]string) for sc, hv := range l.CustomResponseHeaders { var chl []*CustomHeader for h, v := range hv { @@ -65,6 +69,9 @@ func NewListenerCustomHeader(ln []*configutil.Listener, logger log.Logger, uiHea } chl = append(chl, ch) + + // setting up the reverse map of header to status code for easy lookups + lc.ConfiguredHeadersStatusCodeMap[h] = append(lc.ConfiguredHeadersStatusCodeMap[h], sc) } lc.StatusCodeHeaderMap[sc] = chl } @@ -77,6 +84,7 @@ func NewListenerCustomHeader(ln []*configutil.Listener, logger log.Logger, uiHea func (c *ListenersCustomResponseHeadersList) SetCustomResponseHeaders(w http.ResponseWriter, status int) { if w == nil { c.logger.Error("No ResponseWriter provided") + return } // Getting the listener address to set its corresponding custom headers @@ -96,6 +104,15 @@ func (c *ListenersCustomResponseHeadersList) SetCustomResponseHeaders(w http.Res c.logger.Warn("no listener config found", "address", la) return } + if len(lch) != 1 { + c.logger.Warn("multiple listeners with the same address configured") + return + } + sch := lch[0].StatusCodeHeaderMap + if sch == nil { + c.logger.Warn("status code header map not configured") + return + } // setter function to set the headers setter := func(hvl []*CustomHeader) { @@ -111,34 +128,45 @@ func (c *ListenersCustomResponseHeadersList) SetCustomResponseHeaders(w http.Res } // Setting the default headers first - setter(lch["default"]) + setter(sch["default"]) // setting the Xyy pattern first d := fmt.Sprintf("%vxx", status / 100) - if val, ok := lch[d]; ok { + if val, ok := sch[d]; ok { setter(val) } // Setting the specific headers - if val, ok := lch[strconv.Itoa(status)]; ok { + if val, ok := sch[strconv.Itoa(status)]; ok { setter(val) } return } -func (c *ListenersCustomResponseHeadersList) getListenerMap(address string) map[string][]*CustomHeader { +func (c *ListenersCustomResponseHeadersList) getListenerMap(address string) []*ListenerCustomHeaders { if c.CustomHeadersList == nil { return nil } - for _, l := range c.CustomHeadersList { - if l.Address == address { - return l.StatusCodeHeaderMap + + // either looking for a specific listener, or if listener address isn't given, + // checking for all available listeners + var lch []*ListenerCustomHeaders + if address == "" { + lch = c.CustomHeadersList + } else { + for _, l := range c.CustomHeadersList { + if l.Address == address { + lch = append(lch, l) + } + } + if len(lch) == 0 { + return nil } } - return nil + return lch } -func (c *ListenersCustomResponseHeadersList) findCustomHeaderMatchStatusCode(hm map[string][]*CustomHeader, sc int, hn string) string { +func (c *ListenersCustomResponseHeadersList) findCustomHeaderMatchStatusCode(hm map[string][]*CustomHeader, sc string, hn string) string { getHeader := func(ch []*CustomHeader) string { for _, h := range ch { @@ -150,19 +178,26 @@ func (c *ListenersCustomResponseHeadersList) findCustomHeaderMatchStatusCode(hm } // starting with the most specific status code - if ch, ok := hm[strconv.Itoa(sc)]; ok { + if ch, ok := hm[sc]; ok { h := getHeader(ch) if h != "" { return h } } - s := fmt.Sprintf("%vxx", sc/100) - if configutil.IsValidStatusCodeCollection(s) { - if ch, ok := hm[s]; ok { - h := getHeader(ch) - if h != "" { - return h + // Checking for the Yxx pattern + var firstDig string + if len(sc) == 3 { + firstDig = strings.Split(sc, "")[0] + } + if firstDig != "" { + s := fmt.Sprintf("%vxx", firstDig) + if configutil.IsValidStatusCodeCollection(s) { + if ch, ok := hm[s]; ok { + h := getHeader(ch) + if h != "" { + return h + } } } } @@ -177,26 +212,19 @@ func (c *ListenersCustomResponseHeadersList) findCustomHeaderMatchStatusCode(hm return "" } -func (c *ListenersCustomResponseHeadersList) FetchCustomResponseHeaderValue(header string, sc int, la string) ([]string, error) { +func (c *ListenersCustomResponseHeadersList) FetchCustomResponseHeaderValue(header string, sc string, la string) ([]string, error) { if header == "" { return nil, fmt.Errorf("invalid target header") } - // either looking for a specific listener, or if listener address isn't given, - // checking for all available listeners - var lch []*ListenerCustomHeaders - if la == "" { - lch = c.CustomHeadersList - } else { - for _, l := range c.CustomHeadersList { - if l.Address == la { - lch = append(lch, l) - } - } - if len(lch) == 0 { - return nil, fmt.Errorf("no listener found with address:%v", la) - } + if c.CustomHeadersList == nil { + return nil, fmt.Errorf("core custom headers not configured") + } + + lch := c.getListenerMap(la) + if lch == nil { + return nil, fmt.Errorf("no listener found with address:%v", la) } var headers []string @@ -213,16 +241,38 @@ func (c *ListenersCustomResponseHeadersList) FetchCustomResponseHeaderValue(head return headers, err } -func(c *ListenersCustomResponseHeadersList) ExistHeader(th string, sc int, la string) bool { - if !configutil.IsValidStatusCode(strconv.Itoa(sc)) { +func(c *ListenersCustomResponseHeadersList) FetchHeaderForStausCode(header, sc, la string) bool { + + if !configutil.IsValidStatusCode(sc) { c.logger.Error("failed to check if a header exist in config file due to invalid status code") return false } - chv, _ := c.FetchCustomResponseHeaderValue(th, sc, la) + chv, _ := c.FetchCustomResponseHeaderValue(header, sc, la) if chv != nil { return true } return false } + +func (c *ListenersCustomResponseHeadersList) ExistCustomResponseHeader(header, la string) bool { + + lch := c.getListenerMap(la) + if lch == nil { + return false + } + if len(lch) != 1 { + c.logger.Warn("multiple listeners with the same address configured, checking all listeners for the custom header") + } + + hn := textproto.CanonicalMIMEHeaderKey(header) + for _, chs := range lch { + hs := chs.ConfiguredHeadersStatusCodeMap + if _, ok := hs[hn]; ok { + return true + } + } + + return false +} \ No newline at end of file diff --git a/vault/custom_response_headers_test.go b/vault/custom_response_headers_test.go new file mode 100644 index 0000000000000..eda26cf355b41 --- /dev/null +++ b/vault/custom_response_headers_test.go @@ -0,0 +1,207 @@ +package vault + +import ( + "context" + "net/http/httptest" + "strings" + "testing" + + log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/helper/namespace" + "github.com/hashicorp/vault/internalshared/configutil" + "github.com/hashicorp/vault/sdk/helper/logging" + "github.com/hashicorp/vault/sdk/logical" + "github.com/hashicorp/vault/sdk/physical/inmem" +) + + +var defaultCustomHeaders = map[string]string { + "Strict-Transport-Security": "max-age=1; domains", + "Content-Security-Policy": "default-src 'others'", + "X-Vault-Ignored": "ignored", + "X-Custom-Header": "Custom header value default", + "X-Frame-Options": "Deny", + "X-Content-Type-Options": "nosniff", + "Content-Type": "text/plain; charset=utf-8", + "X-XSS-Protection": "1; mode=block", +} + +var customHeaders307 = map[string]string { + "X-Custom-Header": "Custom header value 307", +} + +var customHeader3xx = map[string]string { + "X-Vault-Ignored-3xx": "Ignored 3xx", + "X-Custom-Header": "Custom header value 3xx", +} + +var customHeaders200 = map[string]string { + "Someheader-200": "200", + "X-Custom-Header": "Custom header value 200", +} + +var customHeader2xx = map[string]string { + "X-Custom-Header": "Custom header value 2xx", +} + +var customHeader400 = map[string]string { + "Someheader-400": "400", +} + +func TestConfigCustomHeaders(t *testing.T) { + logger := logging.NewVaultLogger(log.Trace) + phys, err := inmem.NewTransactionalInmem(nil, logger) + if err != nil { + t.Fatal(err) + } + logl := &logical.InmemStorage{} + uiConfig := NewUIConfig(true, phys, logl) + + rawListenerConfig := []*configutil.Listener { + { + Type: "tcp", + Address: "127.0.0.1:443", + CustomResponseHeaders: map[string]map[string]string{ + "default": defaultCustomHeaders, + "307": customHeaders307, + "3xx": customHeader3xx, + "200": customHeaders200, + "2xx": customHeader2xx, + "400": customHeader400, + }, + }, + } + + uiHeaders, err := uiConfig.Headers(context.Background()) + customListenerHeader := NewListenerCustomHeader(rawListenerConfig, logger, uiHeaders) + if customListenerHeader == nil { + t.Fatalf("custom header config should be configured") + } + + if customListenerHeader.ExistCustomResponseHeader("X-Vault-Ignored-307", "127.0.0.1:443") { + t.Fatalf("header name with X-Vault prefix is not valid") + } + if customListenerHeader.ExistCustomResponseHeader("X-Vault-Ignored-3xx", "127.0.0.1:443") { + t.Fatalf("header name with X-Vault prefix is not valid") + } + + if !customListenerHeader.ExistCustomResponseHeader("X-Custom-Header", "127.0.0.1:443") { + t.Fatalf("header name with X-Vault prefix is not valid") + } + + commonDefaultUiHeader := uiHeaders["Content-Security-Policy"] + commonDefaultResponseHeader, _ := customListenerHeader.FetchCustomResponseHeaderValue("Content-Security-Policy", "default", "127.0.0.1:443") + + if commonDefaultUiHeader[0] == commonDefaultResponseHeader[0] { + t.Fatalf("default haeder ") + } + + w := httptest.NewRecorder() + w.Header().Set("X-Vault-Listener-Add", rawListenerConfig[0].Address) + + customListenerHeader.SetCustomResponseHeaders(w, 200) + if w.Header().Get("Someheader-200") != "200" || w.Header().Get("X-Custom-Header") != "Custom header value 200"{ + t.Fatalf("response headers related to status code %v did not set properly", 200) + } + + w = httptest.NewRecorder() + w.Header().Set("X-Vault-Listener-Add", rawListenerConfig[0].Address) + customListenerHeader.SetCustomResponseHeaders(w, 204) + if w.Header().Get("Someheader-200") == "200" || w.Header().Get("X-Custom-Header") != "Custom header value 2xx" { + t.Fatalf("response headers related to status code %v did not set properly", "2xx") + } + + w = httptest.NewRecorder() + w.Header().Set("X-Vault-Listener-Add", rawListenerConfig[0].Address) + customListenerHeader.SetCustomResponseHeaders(w, 500) + for h, v := range defaultCustomHeaders { + if h != "X-Vault-Ignored" && w.Header().Get(h) != v { + t.Fatalf("response headers related to status code %v did not set properly", 500) + } + } + if w.Header().Get("X-Vault-Ignored") != "" { + t.Fatalf("response headers contains a header with pattern X-Vault") + } +} + + +func TestCustomResponseHeadersConfigInteractUiConfig(t *testing.T) { + b := testSystemBackend(t) + _, barrier, _ := mockBarrier(t) + view := NewBarrierView(barrier, "") + b.(*SystemBackend).Core.systemBarrierView = view + + logger := logging.NewVaultLogger(log.Trace) + rawListenerConfig := []*configutil.Listener { + { + Type: "tcp", + Address: "127.0.0.1:443", + CustomResponseHeaders: map[string]map[string]string{ + "default": defaultCustomHeaders, + "307": customHeaders307, + "3xx": customHeader3xx, + "200": customHeaders200, + "2xx": customHeader2xx, + "400": customHeader400, + }, + }, + } + uiHeaders, err := b.(*SystemBackend).Core.uiConfig.Headers(context.Background()) + if err != nil { + t.Fatalf("failed to get headers from ui config") + } + customListenerHeader := NewListenerCustomHeader(rawListenerConfig, logger, uiHeaders) + if customListenerHeader == nil { + t.Fatalf("custom header config should be configured") + } + b.(*SystemBackend).Core.customListenerHeader = customListenerHeader + clh := b.(*SystemBackend).Core.customListenerHeader + if clh == nil { + t.Fatalf("custom header config should be configured in core") + } + + w := httptest.NewRecorder() + w.Header().Set("X-Vault-Listener-Add", "127.0.0.1:443") + hw := logical.NewHTTPResponseWriter(w) + + // setting a header that already exist in custom headers + req := logical.TestRequest(t, logical.UpdateOperation, "config/ui/headers/X-Custom-Header") + req.Data["values"] = []string{"UI Custom Header"} + req.ResponseWriter = hw + + resp, err := b.HandleRequest(namespace.RootContext(nil), req) + if err == nil { + t.Fatal("request did not fail on setting a header that is present in custom response headers") + } + if !strings.Contains(resp.Data["error"].(string), "header already exist in server configuration file") { + t.Fatalf("failed to get the expected error") + } + + // setting a header that already exist in custom headers + req = logical.TestRequest(t, logical.UpdateOperation, "config/ui/headers/Someheader-400") + req.Data["values"] = []string{"400"} + req.ResponseWriter = hw + + _, err = b.HandleRequest(namespace.RootContext(nil), req) + if err == nil { + t.Fatal("request did not fail on setting a header that is present in custom response headers") + } + h, err := b.(*SystemBackend).Core.uiConfig.Headers(context.Background()) + if h.Get("Someheader-400") == "400" { + t.Fatalf("should not be able to set a header that is in custom response headers") + } + + req = logical.TestRequest(t, logical.UpdateOperation, "config/ui/headers/X-CustomUiHeader") + req.Data["values"] = []string{"Ui header value"} + req.ResponseWriter = hw + + _, err = b.HandleRequest(namespace.RootContext(nil), req) + if err != nil { + t.Fatal("request did not fail on setting a header that is present in custom response headers") + } + + h, err = b.(*SystemBackend).Core.uiConfig.Headers(context.Background()) + if h.Get("X-CustomUiHeader") != "Ui header value" { + t.Fatalf("failed to sett a header that is not in custom response headers") + } +} diff --git a/vault/logical_system.go b/vault/logical_system.go index a59cb1b5e1dd9..b21b30db78481 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -2628,7 +2628,7 @@ func (b *SystemBackend) handleConfigUIHeadersUpdate(ctx context.Context, req *lo value := http.Header{} for _, v := range values { la := req.ResponseWriter.Header().Get("X-Vault-Listener-Add") - if b.Core.ExistCustomResponseHeader(header, 200, la) { + if b.Core.ExistCustomResponseHeader(header, la) { return logical.ErrorResponse("header already exist in server configuration file"), logical.ErrInvalidRequest } value.Add(header, v) From e52db45ec51ef0b26091d0574c179737dbb72365 Mon Sep 17 00:00:00 2001 From: hamid ghaf Date: Thu, 9 Sep 2021 08:11:59 -0700 Subject: [PATCH 09/25] Changing some error messages based on some recommendations --- internalshared/configutil/http_response_headers.go | 4 ++-- vault/core.go | 4 ++-- vault/custom_response_headers.go | 2 +- vault/logical_system.go | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/internalshared/configutil/http_response_headers.go b/internalshared/configutil/http_response_headers.go index 430ae277e6925..a3c44a3c2b500 100644 --- a/internalshared/configutil/http_response_headers.go +++ b/internalshared/configutil/http_response_headers.go @@ -77,7 +77,7 @@ func setDefaultResponseHeaders(c map[string]string) map[string]string { func ParseCustomResponseHeaders(r interface{}) (map[string]map[string]string, error) { if !isValidListDict(r) { - return nil, fmt.Errorf("invalid input type: %T", r) + return nil, fmt.Errorf("response headers were not configured correctly. please make sure they're in a map") } customResponseHeader := r.([]map[string]interface{}) @@ -174,7 +174,7 @@ func parseHeaders(in map[string]interface{}) (map[string]string, error) { func parseHeaderValues(h interface{}) (string, error) { var sl []string if !isValidList(h) { - return "", fmt.Errorf("failed to parse header values") + return "", fmt.Errorf("headers must be given in a list of strings") } vli := h.([]interface{}) for _, vh := range vli { diff --git a/vault/core.go b/vault/core.go index d1aa41e4e9fdd..255a98c59a04b 100644 --- a/vault/core.go +++ b/vault/core.go @@ -2653,7 +2653,7 @@ func (c *Core) SetConfig(conf *server.Config) { func (c *Core) SetCustomResponseHeaders(w http.ResponseWriter, status int) { if c.customListenerHeader == nil { - c.logger.Debug("custom response headers not configured") + c.logger.Debug("failed to find the custom response headers configuration") return } @@ -2662,7 +2662,7 @@ func (c *Core) SetCustomResponseHeaders(w http.ResponseWriter, status int) { func (c *Core) ExistCustomResponseHeader(header string, la string) bool { if c.customListenerHeader == nil { - c.logger.Debug("custom response headers not configured") + c.logger.Debug("failed to find the custom response headers configuration") return false } diff --git a/vault/custom_response_headers.go b/vault/custom_response_headers.go index 8a62b6200730f..b57eefa979a45 100644 --- a/vault/custom_response_headers.go +++ b/vault/custom_response_headers.go @@ -59,7 +59,7 @@ func NewListenerCustomHeader(ln []*configutil.Listener, logger log.Logger, uiHea if uiHeaders != nil { exist := uiHeaders.Get(h) if exist != "" { - logger.Error("found a duplicate header in UI, note that config file headers take precedence.", "header:", h) + logger.Warn("found a duplicate header in UI, note that config file headers take precedence.", "header:", h) } } diff --git a/vault/logical_system.go b/vault/logical_system.go index b21b30db78481..b23e8cdead191 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -2629,7 +2629,7 @@ func (b *SystemBackend) handleConfigUIHeadersUpdate(ctx context.Context, req *lo for _, v := range values { la := req.ResponseWriter.Header().Get("X-Vault-Listener-Add") if b.Core.ExistCustomResponseHeader(header, la) { - return logical.ErrorResponse("header already exist in server configuration file"), logical.ErrInvalidRequest + return logical.ErrorResponse(fmt.Sprintf("header already exist in server configuration file: %v", header)), logical.ErrInvalidRequest } value.Add(header, v) } From 06ef62f0b2a9791ae1916ed45c76e976d794772a Mon Sep 17 00:00:00 2001 From: hamid ghaf Date: Tue, 14 Sep 2021 11:02:54 -0700 Subject: [PATCH 10/25] Incorporating custom response headers struct into the request context --- command/agent.go | 3 +- command/server.go | 2 +- http/cors.go | 4 +- http/forwarded_for_test.go | 12 +- http/handler.go | 138 ++++++----- http/help.go | 6 +- http/http_test.go | 10 + http/logical.go | 33 ++- http/sys_feature_flags.go | 4 +- http/sys_generate_root.go | 34 +-- http/sys_health.go | 10 +- http/sys_init.go | 14 +- http/sys_leader.go | 6 +- http/sys_metrics.go | 8 +- http/sys_raft.go | 20 +- http/sys_rekey.go | 78 +++---- http/sys_seal.go | 40 ++-- http/util.go | 10 +- .../configutil/http_response_headers.go | 41 ++-- internalshared/configutil/listener.go | 4 +- .../listenerutil/custom_response_headers.go | 160 +++++++++++++ vault/core.go | 72 +++--- vault/custom_response_headers.go | 217 ++---------------- vault/custom_response_headers_test.go | 23 +- vault/logical_system.go | 5 +- 25 files changed, 488 insertions(+), 466 deletions(-) create mode 100644 internalshared/listenerutil/custom_response_headers.go diff --git a/command/agent.go b/command/agent.go index 9cbb6245a01c3..2cd7b0e94b420 100644 --- a/command/agent.go +++ b/command/agent.go @@ -2,7 +2,6 @@ package command import ( "context" - "errors" "flag" "fmt" "io" @@ -878,7 +877,7 @@ func (c *AgentCommand) Run(args []string) int { func verifyRequestHeader(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if val, ok := r.Header[consts.RequestHeaderName]; !ok || len(val) != 1 || val[0] != "true" { - err := errors.New(fmt.Sprintf("missing '%s' header", consts.RequestHeaderName)) + err := fmt.Errorf("missing '%s' header", consts.RequestHeaderName) status := http.StatusPreconditionFailed logical.AdjustErrorStatusCode(&status, err) logical.RespondError(w, status, err) diff --git a/command/server.go b/command/server.go index 5e6fd8bdca807..768e907e2b152 100644 --- a/command/server.go +++ b/command/server.go @@ -2637,7 +2637,7 @@ func startHttpServers(c *ServerCommand, core *vault.Core, config *server.Config, }) if len(ln.Config.XForwardedForAuthorizedAddrs) > 0 { - handler = vaulthttp.WrapForwardedForHandler(handler, ln.Config, core.SetCustomResponseHeaders) + handler = vaulthttp.WrapForwardedForHandler(handler, ln.Config) } // server defaults diff --git a/http/cors.go b/http/cors.go index 9a8b57c9b9b6f..ed48b31228a14 100644 --- a/http/cors.go +++ b/http/cors.go @@ -40,13 +40,13 @@ func wrapCORSHandler(h http.Handler, core *vault.Core) http.Handler { // Return a 403 if the origin is not allowed to make cross-origin requests. if !corsConf.IsValidOrigin(origin) { - respondError(w, http.StatusForbidden, fmt.Errorf("origin not allowed"), core.SetCustomResponseHeaders) + respondError(w, http.StatusForbidden, fmt.Errorf("origin not allowed"), req) return } if req.Method == http.MethodOptions && !strutil.StrListContains(allowedMethods, requestMethod) { status := http.StatusMethodNotAllowed - core.SetCustomResponseHeaders(w, status) + SetCustomResponseHeaders(w, status, req) w.WriteHeader(status) return } diff --git a/http/forwarded_for_test.go b/http/forwarded_for_test.go index 40ea8289c7c71..9323f5bf1c728 100644 --- a/http/forwarded_for_test.go +++ b/http/forwarded_for_test.go @@ -42,7 +42,7 @@ func TestHandler_XForwardedFor(t *testing.T) { }) listenerConfig := getListenerConfigForMarshalerTest(goodAddr) listenerConfig.XForwardedForRejectNotPresent = true - return WrapForwardedForHandler(origHandler, listenerConfig, nil) + return WrapForwardedForHandler(origHandler, listenerConfig) } cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ @@ -85,7 +85,7 @@ func TestHandler_XForwardedFor(t *testing.T) { }) listenerConfig := getListenerConfigForMarshalerTest(badAddr) listenerConfig.XForwardedForRejectNotPresent = true - return WrapForwardedForHandler(origHandler, listenerConfig, nil) + return WrapForwardedForHandler(origHandler, listenerConfig) } cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ @@ -121,7 +121,7 @@ func TestHandler_XForwardedFor(t *testing.T) { listenerConfig := getListenerConfigForMarshalerTest(badAddr) listenerConfig.XForwardedForRejectNotPresent = true listenerConfig.XForwardedForRejectNotAuthorized = true - return WrapForwardedForHandler(origHandler, listenerConfig, nil) + return WrapForwardedForHandler(origHandler, listenerConfig) } cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ @@ -155,7 +155,7 @@ func TestHandler_XForwardedFor(t *testing.T) { listenerConfig.XForwardedForRejectNotPresent = true listenerConfig.XForwardedForRejectNotAuthorized = true listenerConfig.XForwardedForHopSkips = 4 - return WrapForwardedForHandler(origHandler, listenerConfig, nil) + return WrapForwardedForHandler(origHandler, listenerConfig) } cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ @@ -189,7 +189,7 @@ func TestHandler_XForwardedFor(t *testing.T) { listenerConfig.XForwardedForRejectNotPresent = true listenerConfig.XForwardedForRejectNotAuthorized = true listenerConfig.XForwardedForHopSkips = 1 - return WrapForwardedForHandler(origHandler, listenerConfig, nil) + return WrapForwardedForHandler(origHandler, listenerConfig) } cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ @@ -226,7 +226,7 @@ func TestHandler_XForwardedFor(t *testing.T) { listenerConfig.XForwardedForRejectNotPresent = true listenerConfig.XForwardedForRejectNotAuthorized = true listenerConfig.XForwardedForHopSkips = 1 - return WrapForwardedForHandler(origHandler, listenerConfig, nil) + return WrapForwardedForHandler(origHandler, listenerConfig) } cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ diff --git a/http/handler.go b/http/handler.go index dc9ff3841d081..b7a4ba8c8a87c 100644 --- a/http/handler.go +++ b/http/handler.go @@ -26,6 +26,7 @@ import ( "github.com/hashicorp/go-sockaddr" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/internalshared/configutil" + "github.com/hashicorp/vault/internalshared/listenerutil" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/jsonutil" "github.com/hashicorp/vault/sdk/helper/pathmanager" @@ -195,7 +196,8 @@ func Handler(props *vault.HandlerProperties) http.Handler { } // Wrap the handler in another handler to trigger all help paths. - helpWrappedHandler := wrapHelpHandler(mux, core) + unregisteredPathsHandler := wrapUnregisteredPathsHandler(mux, core) + helpWrappedHandler := wrapHelpHandler(unregisteredPathsHandler, core) corsWrappedHandler := wrapCORSHandler(helpWrappedHandler, core) quotaWrappedHandler := rateLimitQuotaWrapping(corsWrappedHandler, core) genericWrappedHandler := genericWrapping(core, quotaWrappedHandler, props) @@ -248,7 +250,7 @@ func handleAuditNonLogical(core *vault.Core, h http.Handler) http.Handler { req, _, status, err := buildLogicalRequestNoAuth(core.PerfStandby(), w, r) if err != nil || status != 0 { - respondError(w, status, err, core.SetCustomResponseHeaders) + respondError(w, status, err, r) return } if origBody != nil { @@ -259,7 +261,7 @@ func handleAuditNonLogical(core *vault.Core, h http.Handler) http.Handler { } err = core.AuditLogger().AuditRequest(r.Context(), input) if err != nil { - respondError(w, status, err, core.SetCustomResponseHeaders) + respondError(w, status, err, r) return } cw := newCopyResponseWriter(w) @@ -273,12 +275,30 @@ func handleAuditNonLogical(core *vault.Core, h http.Handler) http.Handler { input.Response = logical.HTTPResponseToLogicalResponse(httpResp) err = core.AuditLogger().AuditResponse(r.Context(), input) if err != nil { - respondError(w, status, err, core.SetCustomResponseHeaders) + respondError(w, status, err, r) } return }) } +// wrapUnregisteredPathsHandler is the last layer before the endpoint direct handlers +// This is to prevent response headers being overwritten unintentionally +func wrapUnregisteredPathsHandler(h http.Handler, core *vault.Core) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // According to the net/http package, if a subtree has been registered and a request + // is received naming the subtree root without its trailing slash, ServeMux + // redirects that request to the subtree root (adding the trailing slash). + // Calling "/v1/sys" endpoint will be redirected to "/v1/sys/" with status + // 301 (Moved Permanently), however, the status code is set inside the net/http + // package. So, we set the custom response headers here instead. + if r.URL.Path == "/v1/sys" { + SetCustomResponseHeaders(w, 301, r) + } + h.ServeHTTP(w, r) + return + }) +} + // wrapGenericHandler wraps the handler with an extra layer of handler where // tasks that should be commonly handled for all the requests and/or responses // are performed. @@ -296,6 +316,12 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr maxRequestSize = DefaultMaxRequestSize } + var listenerCustomHeaders *listenerutil.ListenerCustomHeaders + if props.ListenerConfig != nil { + la := props.ListenerConfig.Address + listenerCustomHeaders = core.GetListenerCustomResponseHeaders(la) + } + // Swallow this error since we don't want to pollute the logs and we also don't want to // return an HTTP error here. This information is best effort. hostname, _ := os.Hostname() @@ -320,6 +346,11 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr ctx = context.WithValue(ctx, "max_request_size", maxRequestSize) } ctx = context.WithValue(ctx, "original_request_path", r.URL.Path) + + if listenerCustomHeaders != nil { + ctx = context.WithValue(ctx, "X-Vault-Listener-Custom-Headers-Struct", listenerCustomHeaders) + } + r = r.WithContext(ctx) r = r.WithContext(namespace.ContextWithNamespace(r.Context(), namespace.RootNamespace)) @@ -335,19 +366,11 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr w.Header().Set("X-Vault-Hostname", hostname) } - // Setting the listener address as a header so that we could set customized - // headers configured in the corresponding listener stanza - var la string - if props.ListenerConfig != nil { - la = props.ListenerConfig.Address - } - w.Header().Set("X-Vault-Listener-Add", la) - switch { case strings.HasPrefix(r.URL.Path, "/v1/"): newR, status := adjustRequest(core, r) if status != 0 { - respondError(w, status, nil, core.SetCustomResponseHeaders) + respondError(w, status, nil, r) cancelFunc() return } @@ -355,7 +378,7 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr case strings.HasPrefix(r.URL.Path, "/ui"), r.URL.Path == "/robots.txt", r.URL.Path == "/": default: - respondError(w, http.StatusNotFound, nil, core.SetCustomResponseHeaders) + respondError(w, http.StatusNotFound, nil, r) cancelFunc() return } @@ -373,7 +396,7 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr }) } -func WrapForwardedForHandler(h http.Handler, l *configutil.Listener, hs customResponseHeaderSetter) http.Handler { +func WrapForwardedForHandler(h http.Handler, l *configutil.Listener) http.Handler { rejectNotPresent := l.XForwardedForRejectNotPresent hopSkips := l.XForwardedForHopSkips authorizedAddrs := l.XForwardedForAuthorizedAddrs @@ -385,7 +408,7 @@ func WrapForwardedForHandler(h http.Handler, l *configutil.Listener, hs customRe h.ServeHTTP(w, r) return } - respondError(w, http.StatusBadRequest, fmt.Errorf("missing x-forwarded-for header and configured to reject when not present"), hs) + respondError(w, http.StatusBadRequest, fmt.Errorf("missing x-forwarded-for header and configured to reject when not present"), r) return } @@ -398,7 +421,7 @@ func WrapForwardedForHandler(h http.Handler, l *configutil.Listener, hs customRe h.ServeHTTP(w, r) return } - respondError(w, http.StatusBadRequest, fmt.Errorf("error parsing client hostport: %w", err), hs) + respondError(w, http.StatusBadRequest, fmt.Errorf("error parsing client hostport: %w", err), r) return } @@ -409,7 +432,7 @@ func WrapForwardedForHandler(h http.Handler, l *configutil.Listener, hs customRe h.ServeHTTP(w, r) return } - respondError(w, http.StatusBadRequest, fmt.Errorf("error parsing client address: %w", err), hs) + respondError(w, http.StatusBadRequest, fmt.Errorf("error parsing client address: %w", err), r) return } @@ -427,7 +450,7 @@ func WrapForwardedForHandler(h http.Handler, l *configutil.Listener, hs customRe h.ServeHTTP(w, r) return } - respondError(w, http.StatusBadRequest, fmt.Errorf("client address not authorized for x-forwarded-for and configured to reject connection"), hs) + respondError(w, http.StatusBadRequest, fmt.Errorf("client address not authorized for x-forwarded-for and configured to reject connection"), r) return } @@ -455,7 +478,7 @@ func WrapForwardedForHandler(h http.Handler, l *configutil.Listener, hs customRe // authorized (or we've turned off explicit rejection) and we // should assume that what comes in should be properly // formatted. - respondError(w, http.StatusBadRequest, fmt.Errorf("malformed x-forwarded-for configuration or request, hops to skip (%d) would skip before earliest chain link (chain length %d)", hopSkips, len(headers)), hs) + respondError(w, http.StatusBadRequest, fmt.Errorf("malformed x-forwarded-for configuration or request, hops to skip (%d) would skip before earliest chain link (chain length %d)", hopSkips, len(headers)), r) return } @@ -486,7 +509,7 @@ func handleUIHeaders(core *vault.Core, h http.Handler) http.Handler { userHeaders, err := core.UIHeaders() if err != nil { - respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusInternalServerError, err, req) return } if userHeaders != nil { @@ -500,7 +523,7 @@ func handleUIHeaders(core *vault.Core, h http.Handler) http.Handler { // status code specifically, instead, a call to w.Write is called which // internally also sets the status code to 200. // Just setting the headers for status code 200. - core.SetCustomResponseHeaders(w, 200) + SetCustomResponseHeaders(w, 200, req) h.ServeHTTP(w, req) }) @@ -592,7 +615,7 @@ func handleUIStub() http.Handler { func handleUIRedirect(core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { status := 307 - core.SetCustomResponseHeaders(w, status) + SetCustomResponseHeaders(w, status, req) http.Redirect(w, req, "/ui/", status) return }) @@ -748,7 +771,7 @@ func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handle // Note if the client requested forwarding shouldForward, err := forwardBasedOnHeaders(core, r) if err != nil { - respondError(w, http.StatusBadRequest, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusBadRequest, err, r) return } @@ -756,7 +779,7 @@ func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handle if core.PerfStandby() && !shouldForward { ns, err := namespace.FromContext(r.Context()) if err != nil { - respondError(w, http.StatusBadRequest, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusBadRequest, err, r) return } path := ns.TrimmedPath(r.URL.Path[len("/v1/"):]) @@ -784,7 +807,7 @@ func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handle return } // Some internal error occurred - respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusInternalServerError, err, r) return } if isLeader { @@ -793,7 +816,7 @@ func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handle return } if leaderAddr == "" { - respondError(w, http.StatusInternalServerError, fmt.Errorf("local node not active but active cluster node not found"), core.SetCustomResponseHeaders) + respondError(w, http.StatusInternalServerError, fmt.Errorf("local node not active but active cluster node not found"), r) return } @@ -805,25 +828,25 @@ func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handle func forwardRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) { if r.Header.Get(vault.IntNoForwardingHeaderName) != "" { - respondStandby(core, w, r.URL) + respondStandby(core, w, r) return } if r.Header.Get(NoRequestForwardingHeaderName) != "" { // Forwarding explicitly disabled, fall back to previous behavior core.Logger().Debug("handleRequestForwarding: forwarding disabled by client request") - respondStandby(core, w, r.URL) + respondStandby(core, w, r) return } ns, err := namespace.FromContext(r.Context()) if err != nil { - respondError(w, http.StatusBadRequest, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusBadRequest, err, r) return } path := ns.TrimmedPath(r.URL.Path[len("/v1/"):]) if alwaysRedirectPaths.HasPath(path) { - respondStandby(core, w, r.URL) + respondStandby(core, w, r) return } @@ -839,7 +862,7 @@ func forwardRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) { } // Fall back to redirection - respondStandby(core, w, r.URL) + respondStandby(core, w, r) return } @@ -849,7 +872,7 @@ func forwardRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) { } } - core.SetCustomResponseHeaders(w, statusCode) + SetCustomResponseHeaders(w, statusCode, r) w.WriteHeader(statusCode) w.Write(retBytes) } @@ -865,7 +888,7 @@ func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *l resp.AddWarning("Timeout hit while waiting for local replicated cluster to apply primary's write; this client may encounter stale reads of values written during this operation.") } if errwrap.Contains(err, consts.ErrStandby.Error()) { - respondStandby(core, w, rawReq.URL) + respondStandby(core, w, rawReq) return resp, false, false } if err != nil && errwrap.Contains(err, logical.ErrPerfStandbyPleaseForward.Error()) { @@ -906,7 +929,7 @@ func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *l return nil, true, false } - if respondErrorCommon(w, r, resp, err, core.SetCustomResponseHeaders) { + if respondErrorCommon(w, r, resp, err, rawReq) { return resp, false, false } @@ -914,33 +937,34 @@ func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *l } // respondStandby is used to trigger a redirect in the case that this Vault is currently a hot standby -func respondStandby(core *vault.Core, w http.ResponseWriter, reqURL *url.URL) { +func respondStandby(core *vault.Core, w http.ResponseWriter, req *http.Request) { + reqURL := req.URL // Request the leader address _, redirectAddr, _, err := core.Leader() if err != nil { if err == vault.ErrHANotEnabled { // Standalone node, serve 503 err = errors.New("node is not active") - respondError(w, http.StatusServiceUnavailable, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusServiceUnavailable, err, req) return } - respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusInternalServerError, err, req) return } // If there is no leader, generate a 503 error if redirectAddr == "" { err = errors.New("no active Vault instance found") - respondError(w, http.StatusServiceUnavailable, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusServiceUnavailable, err, req) return } // Parse the redirect location redirectURL, err := url.Parse(redirectAddr) if err != nil { - respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusInternalServerError, err, req) return } @@ -961,7 +985,7 @@ func respondStandby(core *vault.Core, w http.ResponseWriter, reqURL *url.URL) { // because we don't actually know if its permanent and // the request method should be preserved. w.Header().Set("Location", finalURL.String()) - core.SetCustomResponseHeaders(w, 307) + SetCustomResponseHeaders(w, 307, req) w.WriteHeader(307) } @@ -1126,27 +1150,37 @@ func isForm(head []byte, contentType string) bool { return true } -type customResponseHeaderSetter func(w http.ResponseWriter, status int) - -func respondError(w http.ResponseWriter, status int, err error, hs customResponseHeaderSetter) { +func respondError(w http.ResponseWriter, status int, err error, r *http.Request) { logical.AdjustErrorStatusCode(&status, err) - if hs != nil { - hs(w, status) - } + SetCustomResponseHeaders(w, status, r) logical.RespondError(w, status, err) } -func respondErrorCommon(w http.ResponseWriter, req *logical.Request, resp *logical.Response, err error, hs customResponseHeaderSetter) bool { +func SetCustomResponseHeaders(w http.ResponseWriter, status int, r *http.Request) { + if r == nil { + return + } + ctx := r.Context() + listenerCustomHeaders := ctx.Value("X-Vault-Listener-Custom-Headers-Struct") + if listenerCustomHeaders != nil { + lc := listenerCustomHeaders.(*listenerutil.ListenerCustomHeaders) + if lc != nil { + lc.SetCustomResponseHeaders(w, status) + } + } +} + +func respondErrorCommon(w http.ResponseWriter, req *logical.Request, resp *logical.Response, err error, r *http.Request) bool { statusCode, newErr := logical.RespondErrorCommon(req, resp, err) if newErr == nil && statusCode == 0 { return false } - respondError(w, statusCode, newErr, hs) + respondError(w, statusCode, newErr, r) return true } -func respondOk(w http.ResponseWriter, body interface{}, hs customResponseHeaderSetter) { +func respondOk(w http.ResponseWriter, body interface{}, r *http.Request) { w.Header().Set("Content-Type", "application/json") var status int @@ -1156,9 +1190,7 @@ func respondOk(w http.ResponseWriter, body interface{}, hs customResponseHeaderS status = http.StatusOK } - if hs != nil { - hs(w, status) - } + SetCustomResponseHeaders(w, status, r) w.WriteHeader(status) if body != nil { diff --git a/http/help.go b/http/help.go index e4c93c712d427..6e2903d5faef1 100644 --- a/http/help.go +++ b/http/help.go @@ -28,7 +28,7 @@ func wrapHelpHandler(h http.Handler, core *vault.Core) http.Handler { func handleHelp(core *vault.Core, w http.ResponseWriter, r *http.Request) { ns, err := namespace.FromContext(r.Context()) if err != nil { - respondError(w, http.StatusBadRequest, nil, core.SetCustomResponseHeaders) + respondError(w, http.StatusBadRequest, nil, r) return } path := ns.TrimmedPath(r.URL.Path[len("/v1/"):]) @@ -42,9 +42,9 @@ func handleHelp(core *vault.Core, w http.ResponseWriter, r *http.Request) { resp, err := core.HandleRequest(r.Context(), req) if err != nil { - respondErrorCommon(w, req, resp, err, core.SetCustomResponseHeaders) + respondErrorCommon(w, req, resp, err, r) return } - respondOk(w, resp.Data, core.SetCustomResponseHeaders) + respondOk(w, resp.Data, r) } diff --git a/http/http_test.go b/http/http_test.go index e37b9c3d7693e..692aef0d82877 100644 --- a/http/http_test.go +++ b/http/http_test.go @@ -125,6 +125,16 @@ func testResponseStatus(t *testing.T, resp *http.Response, code int) { } } +func testResponseHeader(t *testing.T, resp *http.Response, expectedHeaders map[string]string) { + t.Helper() + for k, v := range expectedHeaders { + hv := resp.Header.Get(k) + if v != hv { + t.Fatalf("expected header value %v=%v, got %v=%v", k, v, k, hv) + } + } +} + func testResponseBody(t *testing.T, resp *http.Response, out interface{}) { defer resp.Body.Close() diff --git a/http/logical.go b/http/logical.go index 79b2504a0ede0..f6d744a874a72 100644 --- a/http/logical.go +++ b/http/logical.go @@ -273,17 +273,17 @@ func handleLogicalRecovery(raw *vault.RawBackend, token *atomic.String) http.Han return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { req, _, statusCode, err := buildLogicalRequestNoAuth(false, w, r) if err != nil || statusCode != 0 { - respondError(w, statusCode, err, nil) + respondError(w, statusCode, err, r) return } reqToken := r.Header.Get(consts.AuthHeaderName) if reqToken == "" || token.Load() == "" || reqToken != token.Load() { - respondError(w, http.StatusForbidden, nil, nil) + respondError(w, http.StatusForbidden, nil, r) return } resp, err := raw.HandleRequest(r.Context(), req) - if respondErrorCommon(w, req, resp, err, nil) { + if respondErrorCommon(w, req, resp, err, r) { return } @@ -292,7 +292,7 @@ func handleLogicalRecovery(raw *vault.RawBackend, token *atomic.String) http.Han httpResp = logical.LogicalResponseToHTTPResponse(resp) httpResp.RequestID = req.ID } - respondOk(w, httpResp, nil) + respondOk(w, httpResp, r) }) } @@ -303,7 +303,7 @@ func handleLogicalInternal(core *vault.Core, injectDataIntoTopLevel bool, noForw return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { req, origBody, statusCode, err := buildLogicalRequest(core, w, r) if err != nil || statusCode != 0 { - respondError(w, statusCode, err, core.SetCustomResponseHeaders) + respondError(w, statusCode, err, r) return } @@ -315,7 +315,7 @@ func handleLogicalInternal(core *vault.Core, injectDataIntoTopLevel bool, noForw resp, ok, needsForward := request(core, w, r, req) switch { case needsForward && noForward: - respondError(w, http.StatusBadRequest, vault.ErrCannotForwardLocalOnly, core.SetCustomResponseHeaders) + respondError(w, http.StatusBadRequest, vault.ErrCannotForwardLocalOnly, r) return case needsForward && !noForward: if origBody != nil { @@ -351,18 +351,14 @@ func respondLogical(core *vault.Core, w http.ResponseWriter, r *http.Request, re // If we have a redirect, redirect! We use a 307 code // because we don't actually know if its permanent. status := 307 - core.SetCustomResponseHeaders(w, status) + SetCustomResponseHeaders(w, status, r) http.Redirect(w, r, resp.Redirect, status) return } // Check if this is a raw response if _, ok := resp.Data[logical.HTTPStatusCode]; ok { - var hs customResponseHeaderSetter - if core != nil { - hs = core.SetCustomResponseHeaders - } - respondRaw(w, r, resp, hs) + respondRaw(w, r, resp) return } @@ -395,20 +391,18 @@ func respondLogical(core *vault.Core, w http.ResponseWriter, r *http.Request, re adjustResponse(core, w, req) // Respond - respondOk(w, ret, core.SetCustomResponseHeaders) + respondOk(w, ret, r) return } // respondRaw is used when the response is using HTTPContentType and HTTPRawBody // to change the default response handling. This is only used for specific things like // returning the CRL information on the PKI backends. -func respondRaw(w http.ResponseWriter, r *http.Request, resp *logical.Response, hs customResponseHeaderSetter) { +func respondRaw(w http.ResponseWriter, r *http.Request, resp *logical.Response) { retErr := func(w http.ResponseWriter, err string) { w.Header().Set("X-Vault-Raw-Error", err) code := http.StatusInternalServerError - if hs != nil { - hs(w, code) - } + SetCustomResponseHeaders(w, code, r) w.WriteHeader(code) w.Write(nil) } @@ -497,9 +491,8 @@ WRITE_RESPONSE: if cacheControl, ok := resp.Data[logical.HTTPRawCacheControl].(string); ok { w.Header().Set("Cache-Control", cacheControl) } - if hs != nil { - hs(w, status) - } + + SetCustomResponseHeaders(w, status, r) w.WriteHeader(status) w.Write(body) } diff --git a/http/sys_feature_flags.go b/http/sys_feature_flags.go index 512dfdfb834a5..9f115d6e73d7d 100644 --- a/http/sys_feature_flags.go +++ b/http/sys_feature_flags.go @@ -31,7 +31,7 @@ func handleSysInternalFeatureFlags(core *vault.Core) http.Handler { case "GET": break default: - respondError(w, http.StatusMethodNotAllowed, nil, core.SetCustomResponseHeaders) + respondError(w, http.StatusMethodNotAllowed, nil, r) } response := &FeatureFlagsResponse{} @@ -44,7 +44,7 @@ func handleSysInternalFeatureFlags(core *vault.Core) http.Handler { w.Header().Set("Content-Type", "application/json") status := http.StatusOK - core.SetCustomResponseHeaders(w, status) + SetCustomResponseHeaders(w, status, r) w.WriteHeader(status) // Generate the response diff --git a/http/sys_generate_root.go b/http/sys_generate_root.go index 2329de6a1282b..b91e94dd651c5 100644 --- a/http/sys_generate_root.go +++ b/http/sys_generate_root.go @@ -22,7 +22,7 @@ func handleSysGenerateRootAttempt(core *vault.Core, generateStrategy vault.Gener case "DELETE": handleSysGenerateRootAttemptDelete(core, w, r) default: - respondError(w, http.StatusMethodNotAllowed, nil, core.SetCustomResponseHeaders) + respondError(w, http.StatusMethodNotAllowed, nil, r) } }) } @@ -34,11 +34,11 @@ func handleSysGenerateRootAttemptGet(core *vault.Core, w http.ResponseWriter, r // Get the current seal configuration barrierConfig, err := core.SealAccess().BarrierConfig(ctx) if err != nil { - respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusInternalServerError, err, r) return } if barrierConfig == nil { - respondError(w, http.StatusBadRequest, fmt.Errorf("server is not yet initialized"), core.SetCustomResponseHeaders) + respondError(w, http.StatusBadRequest, fmt.Errorf("server is not yet initialized"), r) return } @@ -46,7 +46,7 @@ func handleSysGenerateRootAttemptGet(core *vault.Core, w http.ResponseWriter, r if core.SealAccess().RecoveryKeySupported() { sealConfig, err = core.SealAccess().RecoveryConfig(ctx) if err != nil { - respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusInternalServerError, err, r) return } } @@ -54,14 +54,14 @@ func handleSysGenerateRootAttemptGet(core *vault.Core, w http.ResponseWriter, r // Get the generation configuration generationConfig, err := core.GenerateRootConfiguration() if err != nil { - respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusInternalServerError, err, r) return } // Get the progress progress, err := core.GenerateRootProgress() if err != nil { - respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusInternalServerError, err, r) return } @@ -80,7 +80,7 @@ func handleSysGenerateRootAttemptGet(core *vault.Core, w http.ResponseWriter, r status.PGPFingerprint = generationConfig.PGPFingerprint } - respondOk(w, status, core.SetCustomResponseHeaders) + respondOk(w, status, r) } func handleSysGenerateRootAttemptPut(core *vault.Core, w http.ResponseWriter, r *http.Request, generateStrategy vault.GenerateRootStrategy) { @@ -88,7 +88,7 @@ func handleSysGenerateRootAttemptPut(core *vault.Core, w http.ResponseWriter, r // Parse the request var req GenerateRootInitRequest if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil && err != io.EOF { - respondError(w, http.StatusBadRequest, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusBadRequest, err, r) return } @@ -101,14 +101,14 @@ func handleSysGenerateRootAttemptPut(core *vault.Core, w http.ResponseWriter, r genned = true req.OTP, err = base62.Random(vault.TokenLength + 2) if err != nil { - respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusInternalServerError, err, r) return } } // Attemptialize the generation if err := core.GenerateRootInit(req.OTP, req.PGPKey, generateStrategy); err != nil { - respondError(w, http.StatusBadRequest, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusBadRequest, err, r) return } @@ -124,10 +124,10 @@ func handleSysGenerateRootAttemptDelete(core *vault.Core, w http.ResponseWriter, errNew := core.GenerateRootCancel() if errNew != nil { - respondError(w, http.StatusInternalServerError, errNew, core.SetCustomResponseHeaders) + respondError(w, http.StatusInternalServerError, errNew, r) return } - respondOk(w, nil, core.SetCustomResponseHeaders) + respondOk(w, nil, r) } func handleSysGenerateRootUpdate(core *vault.Core, generateStrategy vault.GenerateRootStrategy) http.Handler { @@ -135,14 +135,14 @@ func handleSysGenerateRootUpdate(core *vault.Core, generateStrategy vault.Genera // Parse the request var req GenerateRootUpdateRequest if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil { - respondError(w, http.StatusBadRequest, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusBadRequest, err, r) return } if req.Key == "" { respondError( w, http.StatusBadRequest, errors.New("'key' must be specified in request body as JSON"), - core.SetCustomResponseHeaders) + r) return } @@ -158,7 +158,7 @@ func handleSysGenerateRootUpdate(core *vault.Core, generateStrategy vault.Genera respondError( w, http.StatusBadRequest, errors.New("'key' must be a valid hex or base64 string"), - core.SetCustomResponseHeaders) + r) return } } @@ -169,7 +169,7 @@ func handleSysGenerateRootUpdate(core *vault.Core, generateStrategy vault.Genera // Use the key to make progress on root generation result, err := core.GenerateRootUpdate(ctx, key, req.Nonce, generateStrategy) if err != nil { - respondError(w, http.StatusBadRequest, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusBadRequest, err, r) return } @@ -187,7 +187,7 @@ func handleSysGenerateRootUpdate(core *vault.Core, generateStrategy vault.Genera resp.EncodedRootToken = result.EncodedToken } - respondOk(w, resp, core.SetCustomResponseHeaders) + respondOk(w, resp, r) }) } diff --git a/http/sys_health.go b/http/sys_health.go index 219ebdf8e6ee7..0f1865db474a5 100644 --- a/http/sys_health.go +++ b/http/sys_health.go @@ -22,7 +22,7 @@ func handleSysHealth(core *vault.Core) http.Handler { case "HEAD": handleSysHealthHead(core, w, r) default: - respondError(w, http.StatusMethodNotAllowed, nil, core.SetCustomResponseHeaders) + respondError(w, http.StatusMethodNotAllowed, nil, r) } }) } @@ -44,17 +44,17 @@ func handleSysHealthGet(core *vault.Core, w http.ResponseWriter, r *http.Request code, body, err := getSysHealth(core, r) if err != nil { core.Logger().Error("error checking health", "error", err) - respondError(w, code, nil, core.SetCustomResponseHeaders) + respondError(w, code, nil, r) return } if body == nil { - respondError(w, code, nil, core.SetCustomResponseHeaders) + respondError(w, code, nil, r) return } w.Header().Set("Content-Type", "application/json") - core.SetCustomResponseHeaders(w, code) + SetCustomResponseHeaders(w, code, r) w.WriteHeader(code) // Generate the response @@ -69,7 +69,7 @@ func handleSysHealthHead(core *vault.Core, w http.ResponseWriter, r *http.Reques w.Header().Set("Content-Type", "application/json") } - core.SetCustomResponseHeaders(w, code) + SetCustomResponseHeaders(w, code, r) w.WriteHeader(code) } diff --git a/http/sys_init.go b/http/sys_init.go index 3a224f804ae4b..c7dc5fcd50cf8 100644 --- a/http/sys_init.go +++ b/http/sys_init.go @@ -17,7 +17,7 @@ func handleSysInit(core *vault.Core) http.Handler { case "PUT", "POST": handleSysInitPut(core, w, r) default: - respondError(w, http.StatusMethodNotAllowed, nil, core.SetCustomResponseHeaders) + respondError(w, http.StatusMethodNotAllowed, nil, r) } }) } @@ -25,13 +25,13 @@ func handleSysInit(core *vault.Core) http.Handler { func handleSysInitGet(core *vault.Core, w http.ResponseWriter, r *http.Request) { init, err := core.Initialized(context.Background()) if err != nil { - respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusInternalServerError, err, r) return } respondOk(w, &InitStatusResponse{ Initialized: init, - }, core.SetCustomResponseHeaders) + }, r) } func handleSysInitPut(core *vault.Core, w http.ResponseWriter, r *http.Request) { @@ -41,7 +41,7 @@ func handleSysInitPut(core *vault.Core, w http.ResponseWriter, r *http.Request) // Parse the request var req InitRequest if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil { - respondError(w, http.StatusBadRequest, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusBadRequest, err, r) return } @@ -68,7 +68,7 @@ func handleSysInitPut(core *vault.Core, w http.ResponseWriter, r *http.Request) result, initErr := core.Initialize(ctx, initParams) if initErr != nil { if vault.IsFatalError(initErr) { - respondError(w, http.StatusBadRequest, initErr, core.SetCustomResponseHeaders) + respondError(w, http.StatusBadRequest, initErr, r) return } else { // Add a warnings field? The error will be logged in the vault log @@ -100,11 +100,11 @@ func handleSysInitPut(core *vault.Core, w http.ResponseWriter, r *http.Request) } if err := core.UnsealWithStoredKeys(ctx); err != nil { - respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusInternalServerError, err, r) return } - respondOk(w, resp, core.SetCustomResponseHeaders) + respondOk(w, resp, r) } type InitRequest struct { diff --git a/http/sys_leader.go b/http/sys_leader.go index b31da9ada99f1..71ed9c796c6e7 100644 --- a/http/sys_leader.go +++ b/http/sys_leader.go @@ -14,7 +14,7 @@ func handleSysLeader(core *vault.Core) http.Handler { case "GET": handleSysLeaderGet(core, w, r) default: - respondError(w, http.StatusMethodNotAllowed, nil, core.SetCustomResponseHeaders) + respondError(w, http.StatusMethodNotAllowed, nil, r) } }) } @@ -22,8 +22,8 @@ func handleSysLeader(core *vault.Core) http.Handler { func handleSysLeaderGet(core *vault.Core, w http.ResponseWriter, r *http.Request) { resp, err := core.GetLeaderStatus() if err != nil { - respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusInternalServerError, err, r) return } - respondOk(w, resp, core.SetCustomResponseHeaders) + respondOk(w, resp, r) } diff --git a/http/sys_metrics.go b/http/sys_metrics.go index 840d75d87be5d..65b69d3df3f11 100644 --- a/http/sys_metrics.go +++ b/http/sys_metrics.go @@ -17,13 +17,13 @@ func handleMetricsUnauthenticated(core *vault.Core) http.Handler { switch r.Method { case "GET": default: - respondError(w, http.StatusMethodNotAllowed, nil, core.SetCustomResponseHeaders) + respondError(w, http.StatusMethodNotAllowed, nil, r) return } // Parse form if err := r.ParseForm(); err != nil { - respondError(w, http.StatusBadRequest, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusBadRequest, err, r) return } @@ -37,7 +37,7 @@ func handleMetricsUnauthenticated(core *vault.Core) http.Handler { // Manually extract the logical response and send back the information status := resp.Data[logical.HTTPStatusCode].(int) - core.SetCustomResponseHeaders(w, status) + SetCustomResponseHeaders(w, status, r) w.Header().Set("Content-Type", resp.Data[logical.HTTPContentType].(string)) switch v := resp.Data[logical.HTTPRawBody].(type) { case string: @@ -47,7 +47,7 @@ func handleMetricsUnauthenticated(core *vault.Core) http.Handler { w.WriteHeader(status) w.Write(v) default: - respondError(w, http.StatusInternalServerError, fmt.Errorf("wrong response returned"), core.SetCustomResponseHeaders) + respondError(w, http.StatusInternalServerError, fmt.Errorf("wrong response returned"), r) } }) } diff --git a/http/sys_raft.go b/http/sys_raft.go index 1659a8da86cd9..93612b8d1fb13 100644 --- a/http/sys_raft.go +++ b/http/sys_raft.go @@ -18,16 +18,16 @@ func handleSysRaftBootstrap(core *vault.Core) http.Handler { switch r.Method { case "POST", "PUT": if core.Sealed() { - respondError(w, http.StatusBadRequest, errors.New("node must be unsealed to bootstrap"), core.SetCustomResponseHeaders) + respondError(w, http.StatusBadRequest, errors.New("node must be unsealed to bootstrap"), r) } if err := core.RaftBootstrap(context.Background(), false); err != nil { - respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusInternalServerError, err, r) return } default: - respondError(w, http.StatusBadRequest, nil, core.SetCustomResponseHeaders) + respondError(w, http.StatusBadRequest, nil, r) } }) } @@ -38,7 +38,7 @@ func handleSysRaftJoin(core *vault.Core) http.Handler { case "POST", "PUT": handleSysRaftJoinPost(core, w, r) default: - respondError(w, http.StatusMethodNotAllowed, nil, core.SetCustomResponseHeaders) + respondError(w, http.StatusMethodNotAllowed, nil, r) } }) } @@ -47,12 +47,12 @@ func handleSysRaftJoinPost(core *vault.Core, w http.ResponseWriter, r *http.Requ // Parse the request var req JoinRequest if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil && err != io.EOF { - respondError(w, http.StatusBadRequest, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusBadRequest, err, r) return } if req.NonVoter && !nonVotersAllowed { - respondError(w, http.StatusBadRequest, errors.New("non-voting nodes not allowed"), core.SetCustomResponseHeaders) + respondError(w, http.StatusBadRequest, errors.New("non-voting nodes not allowed"), r) return } @@ -61,14 +61,14 @@ func handleSysRaftJoinPost(core *vault.Core, w http.ResponseWriter, r *http.Requ if len(req.LeaderCACert) != 0 || len(req.LeaderClientCert) != 0 || len(req.LeaderClientKey) != 0 { tlsConfig, err = tlsutil.ClientTLSConfig([]byte(req.LeaderCACert), []byte(req.LeaderClientCert), []byte(req.LeaderClientKey)) if err != nil { - respondError(w, http.StatusBadRequest, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusBadRequest, err, r) return } tlsConfig.ServerName = req.LeaderTLSServerName } if req.AutoJoinScheme != "" && (req.AutoJoinScheme != "http" && req.AutoJoinScheme != "https") { - respondError(w, http.StatusBadRequest, fmt.Errorf("invalid scheme '%s'; must either be http or https", req.AutoJoinScheme), core.SetCustomResponseHeaders) + respondError(w, http.StatusBadRequest, fmt.Errorf("invalid scheme '%s'; must either be http or https", req.AutoJoinScheme), r) return } @@ -85,14 +85,14 @@ func handleSysRaftJoinPost(core *vault.Core, w http.ResponseWriter, r *http.Requ joined, err := core.JoinRaftCluster(context.Background(), leaderInfos, req.NonVoter) if err != nil { - respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusInternalServerError, err, r) return } resp := JoinResponse{ Joined: joined, } - respondOk(w, resp, core.SetCustomResponseHeaders) + respondOk(w, resp, r) } type JoinResponse struct { diff --git a/http/sys_rekey.go b/http/sys_rekey.go index f93b8370789de..41dd49e5126fc 100644 --- a/http/sys_rekey.go +++ b/http/sys_rekey.go @@ -17,7 +17,7 @@ func handleSysRekeyInit(core *vault.Core, recovery bool) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { standby, _ := core.Standby() if standby { - respondStandby(core, w, r.URL) + respondStandby(core, w, r) return } @@ -25,7 +25,7 @@ func handleSysRekeyInit(core *vault.Core, recovery bool) http.Handler { if repState.HasState(consts.ReplicationPerformanceSecondary) { respondError(w, http.StatusBadRequest, fmt.Errorf("rekeying can only be performed on the primary cluster when replication is activated"), - core.SetCustomResponseHeaders) + r) return } @@ -34,7 +34,7 @@ func handleSysRekeyInit(core *vault.Core, recovery bool) http.Handler { switch { case recovery && !core.SealAccess().RecoveryKeySupported(): - respondError(w, http.StatusBadRequest, fmt.Errorf("recovery rekeying not supported"), core.SetCustomResponseHeaders) + respondError(w, http.StatusBadRequest, fmt.Errorf("recovery rekeying not supported"), r) case r.Method == "GET": handleSysRekeyInitGet(ctx, core, recovery, w, r) case r.Method == "POST" || r.Method == "PUT": @@ -42,7 +42,7 @@ func handleSysRekeyInit(core *vault.Core, recovery bool) http.Handler { case r.Method == "DELETE": handleSysRekeyInitDelete(ctx, core, recovery, w, r) default: - respondError(w, http.StatusMethodNotAllowed, nil, core.SetCustomResponseHeaders) + respondError(w, http.StatusMethodNotAllowed, nil, r) } }) } @@ -50,24 +50,24 @@ func handleSysRekeyInit(core *vault.Core, recovery bool) http.Handler { func handleSysRekeyInitGet(ctx context.Context, core *vault.Core, recovery bool, w http.ResponseWriter, r *http.Request) { barrierConfig, barrierConfErr := core.SealAccess().BarrierConfig(ctx) if barrierConfErr != nil { - respondError(w, http.StatusInternalServerError, barrierConfErr, core.SetCustomResponseHeaders) + respondError(w, http.StatusInternalServerError, barrierConfErr, r) return } if barrierConfig == nil { - respondError(w, http.StatusBadRequest, fmt.Errorf("server is not yet initialized"), core.SetCustomResponseHeaders) + respondError(w, http.StatusBadRequest, fmt.Errorf("server is not yet initialized"), r) return } // Get the rekey configuration rekeyConf, err := core.RekeyConfig(recovery) if err != nil { - respondError(w, err.Code(), err, core.SetCustomResponseHeaders) + respondError(w, err.Code(), err, r) return } sealThreshold, err := core.RekeyThreshold(ctx, recovery) if err != nil { - respondError(w, err.Code(), err, core.SetCustomResponseHeaders) + respondError(w, err.Code(), err, r) return } @@ -82,7 +82,7 @@ func handleSysRekeyInitGet(ctx context.Context, core *vault.Core, recovery bool, // Get the progress started, progress, err := core.RekeyProgress(recovery, false) if err != nil { - respondError(w, err.Code(), err, core.SetCustomResponseHeaders) + respondError(w, err.Code(), err, r) return } @@ -96,31 +96,31 @@ func handleSysRekeyInitGet(ctx context.Context, core *vault.Core, recovery bool, if rekeyConf.PGPKeys != nil && len(rekeyConf.PGPKeys) != 0 { pgpFingerprints, err := pgpkeys.GetFingerprints(rekeyConf.PGPKeys, nil) if err != nil { - respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusInternalServerError, err, r) return } status.PGPFingerprints = pgpFingerprints status.Backup = rekeyConf.Backup } } - respondOk(w, status, core.SetCustomResponseHeaders) + respondOk(w, status, r) } func handleSysRekeyInitPut(ctx context.Context, core *vault.Core, recovery bool, w http.ResponseWriter, r *http.Request) { // Parse the request var req RekeyRequest if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil { - respondError(w, http.StatusBadRequest, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusBadRequest, err, r) return } if req.Backup && len(req.PGPKeys) == 0 { - respondError(w, http.StatusBadRequest, fmt.Errorf("cannot request a backup of the new keys without providing PGP keys for encryption"), core.SetCustomResponseHeaders) + respondError(w, http.StatusBadRequest, fmt.Errorf("cannot request a backup of the new keys without providing PGP keys for encryption"), r) return } if len(req.PGPKeys) > 0 && len(req.PGPKeys) != req.SecretShares { - respondError(w, http.StatusBadRequest, fmt.Errorf("incorrect number of PGP keys for rekey"), core.SetCustomResponseHeaders) + respondError(w, http.StatusBadRequest, fmt.Errorf("incorrect number of PGP keys for rekey"), r) return } @@ -134,7 +134,7 @@ func handleSysRekeyInitPut(ctx context.Context, core *vault.Core, recovery bool, VerificationRequired: req.RequireVerification, }, recovery) if err != nil { - respondError(w, err.Code(), err, core.SetCustomResponseHeaders) + respondError(w, err.Code(), err, r) return } @@ -143,31 +143,31 @@ func handleSysRekeyInitPut(ctx context.Context, core *vault.Core, recovery bool, func handleSysRekeyInitDelete(ctx context.Context, core *vault.Core, recovery bool, w http.ResponseWriter, r *http.Request) { if err := core.RekeyCancel(recovery); err != nil { - respondError(w, err.Code(), err, core.SetCustomResponseHeaders) + respondError(w, err.Code(), err, r) return } - respondOk(w, nil, core.SetCustomResponseHeaders) + respondOk(w, nil, r) } func handleSysRekeyUpdate(core *vault.Core, recovery bool) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { standby, _ := core.Standby() if standby { - respondStandby(core, w, r.URL) + respondStandby(core, w, r) return } // Parse the request var req RekeyUpdateRequest if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil { - respondError(w, http.StatusBadRequest, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusBadRequest, err, r) return } if req.Key == "" { respondError( w, http.StatusBadRequest, errors.New("'key' must be specified in request body as JSON"), - core.SetCustomResponseHeaders) + r) return } @@ -183,7 +183,7 @@ func handleSysRekeyUpdate(core *vault.Core, recovery bool) http.Handler { respondError( w, http.StatusBadRequest, errors.New("'key' must be a valid hex or base64 string"), - core.SetCustomResponseHeaders) + r) return } } @@ -194,7 +194,7 @@ func handleSysRekeyUpdate(core *vault.Core, recovery bool) http.Handler { // Use the key to make progress on rekey result, rekeyErr := core.RekeyUpdate(ctx, key, req.Nonce, recovery) if rekeyErr != nil { - respondError(w, rekeyErr.Code(), rekeyErr, core.SetCustomResponseHeaders) + respondError(w, rekeyErr.Code(), rekeyErr, r) return } @@ -217,7 +217,7 @@ func handleSysRekeyUpdate(core *vault.Core, recovery bool) http.Handler { } resp.Keys = keys resp.KeysB64 = keysB64 - respondOk(w, resp, core.SetCustomResponseHeaders) + respondOk(w, resp, r) } else { handleSysRekeyInitGet(ctx, core, recovery, w, r) } @@ -228,7 +228,7 @@ func handleSysRekeyVerify(core *vault.Core, recovery bool) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { standby, _ := core.Standby() if standby { - respondStandby(core, w, r.URL) + respondStandby(core, w, r) return } @@ -236,7 +236,7 @@ func handleSysRekeyVerify(core *vault.Core, recovery bool) http.Handler { if repState.HasState(consts.ReplicationPerformanceSecondary) { respondError(w, http.StatusBadRequest, fmt.Errorf("rekeying can only be performed on the primary cluster when replication is activated"), - core.SetCustomResponseHeaders) + r) return } @@ -245,7 +245,7 @@ func handleSysRekeyVerify(core *vault.Core, recovery bool) http.Handler { switch { case recovery && !core.SealAccess().RecoveryKeySupported(): - respondError(w, http.StatusBadRequest, fmt.Errorf("recovery rekeying not supported"), core.SetCustomResponseHeaders) + respondError(w, http.StatusBadRequest, fmt.Errorf("recovery rekeying not supported"), r) case r.Method == "GET": handleSysRekeyVerifyGet(ctx, core, recovery, w, r) case r.Method == "POST" || r.Method == "PUT": @@ -253,7 +253,7 @@ func handleSysRekeyVerify(core *vault.Core, recovery bool) http.Handler { case r.Method == "DELETE": handleSysRekeyVerifyDelete(ctx, core, recovery, w, r) default: - respondError(w, http.StatusMethodNotAllowed, nil, core.SetCustomResponseHeaders) + respondError(w, http.StatusMethodNotAllowed, nil, r) } }) } @@ -261,29 +261,29 @@ func handleSysRekeyVerify(core *vault.Core, recovery bool) http.Handler { func handleSysRekeyVerifyGet(ctx context.Context, core *vault.Core, recovery bool, w http.ResponseWriter, r *http.Request) { barrierConfig, barrierConfErr := core.SealAccess().BarrierConfig(ctx) if barrierConfErr != nil { - respondError(w, http.StatusInternalServerError, barrierConfErr, core.SetCustomResponseHeaders) + respondError(w, http.StatusInternalServerError, barrierConfErr, r) return } if barrierConfig == nil { - respondError(w, http.StatusBadRequest, fmt.Errorf("server is not yet initialized"), core.SetCustomResponseHeaders) + respondError(w, http.StatusBadRequest, fmt.Errorf("server is not yet initialized"), r) return } // Get the rekey configuration rekeyConf, err := core.RekeyConfig(recovery) if err != nil { - respondError(w, err.Code(), err, core.SetCustomResponseHeaders) + respondError(w, err.Code(), err, r) return } if rekeyConf == nil { - respondError(w, http.StatusBadRequest, errors.New("no rekey configuration found"), core.SetCustomResponseHeaders) + respondError(w, http.StatusBadRequest, errors.New("no rekey configuration found"), r) return } // Get the progress started, progress, err := core.RekeyProgress(recovery, true) if err != nil { - respondError(w, err.Code(), err, core.SetCustomResponseHeaders) + respondError(w, err.Code(), err, r) return } @@ -295,12 +295,12 @@ func handleSysRekeyVerifyGet(ctx context.Context, core *vault.Core, recovery boo N: rekeyConf.SecretShares, Progress: progress, } - respondOk(w, status, core.SetCustomResponseHeaders) + respondOk(w, status, r) } func handleSysRekeyVerifyDelete(ctx context.Context, core *vault.Core, recovery bool, w http.ResponseWriter, r *http.Request) { if err := core.RekeyVerifyRestart(recovery); err != nil { - respondError(w, err.Code(), err, core.SetCustomResponseHeaders) + respondError(w, err.Code(), err, r) return } @@ -311,14 +311,14 @@ func handleSysRekeyVerifyPut(ctx context.Context, core *vault.Core, recovery boo // Parse the request var req RekeyVerificationUpdateRequest if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil { - respondError(w, http.StatusBadRequest, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusBadRequest, err, r) return } if req.Key == "" { respondError( w, http.StatusBadRequest, errors.New("'key' must be specified in request body as JSON"), - core.SetCustomResponseHeaders) + r) return } @@ -334,7 +334,7 @@ func handleSysRekeyVerifyPut(ctx context.Context, core *vault.Core, recovery boo respondError( w, http.StatusBadRequest, errors.New("'key' must be a valid hex or base64 string"), - core.SetCustomResponseHeaders) + r) return } } @@ -345,7 +345,7 @@ func handleSysRekeyVerifyPut(ctx context.Context, core *vault.Core, recovery boo // Use the key to make progress on rekey result, rekeyErr := core.RekeyVerify(ctx, key, req.Nonce, recovery) if rekeyErr != nil { - respondError(w, rekeyErr.Code(), rekeyErr, core.SetCustomResponseHeaders) + respondError(w, rekeyErr.Code(), rekeyErr, r) return } @@ -354,7 +354,7 @@ func handleSysRekeyVerifyPut(ctx context.Context, core *vault.Core, recovery boo if result != nil { resp.Complete = true resp.Nonce = result.Nonce - respondOk(w, resp, core.SetCustomResponseHeaders) + respondOk(w, resp, r) } else { handleSysRekeyVerifyGet(ctx, core, recovery, w, r) } diff --git a/http/sys_seal.go b/http/sys_seal.go index effcca7221d74..ef4652779a5b3 100644 --- a/http/sys_seal.go +++ b/http/sys_seal.go @@ -17,14 +17,14 @@ func handleSysSeal(core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { req, _, statusCode, err := buildLogicalRequest(core, w, r) if err != nil || statusCode != 0 { - respondError(w, statusCode, err, core.SetCustomResponseHeaders) + respondError(w, statusCode, err, r) return } switch req.Operation { case logical.UpdateOperation: default: - respondError(w, http.StatusMethodNotAllowed, nil, core.SetCustomResponseHeaders) + respondError(w, http.StatusMethodNotAllowed, nil, r) return } @@ -32,14 +32,14 @@ func handleSysSeal(core *vault.Core) http.Handler { // We use context.Background since there won't be a request context if the node isn't active if err := core.SealWithRequest(r.Context(), req); err != nil { if errwrap.Contains(err, logical.ErrPermissionDenied.Error()) { - respondError(w, http.StatusForbidden, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusForbidden, err, r) return } - respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusInternalServerError, err, r) return } - respondOk(w, nil, core.SetCustomResponseHeaders) + respondOk(w, nil, r) }) } @@ -47,28 +47,28 @@ func handleSysStepDown(core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { req, _, statusCode, err := buildLogicalRequest(core, w, r) if err != nil || statusCode != 0 { - respondError(w, statusCode, err, core.SetCustomResponseHeaders) + respondError(w, statusCode, err, r) return } switch req.Operation { case logical.UpdateOperation: default: - respondError(w, http.StatusMethodNotAllowed, nil, core.SetCustomResponseHeaders) + respondError(w, http.StatusMethodNotAllowed, nil, r) return } // Seal with the token above if err := core.StepDown(r.Context(), req); err != nil { if errwrap.Contains(err, logical.ErrPermissionDenied.Error()) { - respondError(w, http.StatusForbidden, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusForbidden, err, r) return } - respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusInternalServerError, err, r) return } - respondOk(w, nil, core.SetCustomResponseHeaders) + respondOk(w, nil, r) }) } @@ -78,20 +78,20 @@ func handleSysUnseal(core *vault.Core) http.Handler { case "PUT": case "POST": default: - respondError(w, http.StatusMethodNotAllowed, nil, core.SetCustomResponseHeaders) + respondError(w, http.StatusMethodNotAllowed, nil, r) return } // Parse the request var req UnsealRequest if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil { - respondError(w, http.StatusBadRequest, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusBadRequest, err, r) return } if req.Reset { if !core.Sealed() { - respondError(w, http.StatusBadRequest, errors.New("vault is unsealed"), core.SetCustomResponseHeaders) + respondError(w, http.StatusBadRequest, errors.New("vault is unsealed"), r) return } core.ResetUnsealProcess() @@ -103,7 +103,7 @@ func handleSysUnseal(core *vault.Core) http.Handler { respondError( w, http.StatusBadRequest, errors.New("'key' must be specified in request body as JSON, or 'reset' set to true"), - core.SetCustomResponseHeaders) + r) return } @@ -119,7 +119,7 @@ func handleSysUnseal(core *vault.Core) http.Handler { respondError( w, http.StatusBadRequest, errors.New("'key' must be a valid hex or base64 string"), - core.SetCustomResponseHeaders) + r) return } } @@ -139,10 +139,10 @@ func handleSysUnseal(core *vault.Core) http.Handler { case errwrap.Contains(err, vault.ErrBarrierSealed.Error()): case errwrap.Contains(err, consts.ErrStandby.Error()): default: - respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusInternalServerError, err, r) return } - respondError(w, http.StatusBadRequest, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusBadRequest, err, r) return } @@ -154,7 +154,7 @@ func handleSysUnseal(core *vault.Core) http.Handler { func handleSysSealStatus(core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != "GET" { - respondError(w, http.StatusMethodNotAllowed, nil, core.SetCustomResponseHeaders) + respondError(w, http.StatusMethodNotAllowed, nil, r) return } @@ -166,11 +166,11 @@ func handleSysSealStatusRaw(core *vault.Core, w http.ResponseWriter, r *http.Req ctx := context.Background() status, err := core.GetSealStatus(ctx) if err != nil { - respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusInternalServerError, err, r) return } - respondOk(w, status, core.SetCustomResponseHeaders) + respondOk(w, status, r) } // Note: because we didn't provide explicit tagging in the past we can't do it diff --git a/http/util.go b/http/util.go index a27405d571ed1..195c1e9ba137e 100644 --- a/http/util.go +++ b/http/util.go @@ -35,7 +35,7 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ns, err := namespace.FromContext(r.Context()) if err != nil { - respondError(w, http.StatusInternalServerError, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusInternalServerError, err, r) return } @@ -44,7 +44,7 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler // again, which is not desired. path, status, err := buildLogicalPath(r) if err != nil || status != 0 { - respondError(w, status, err, core.SetCustomResponseHeaders) + respondError(w, status, err, r) return } @@ -57,7 +57,7 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler }) if err != nil { core.Logger().Error("failed to apply quota", "path", path, "error", err) - respondError(w, http.StatusUnprocessableEntity, err, core.SetCustomResponseHeaders) + respondError(w, http.StatusUnprocessableEntity, err, r) return } @@ -69,7 +69,7 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler if !quotaResp.Allowed { quotaErr := fmt.Errorf("request path %q: %w", path, quotas.ErrRateLimitQuotaExceeded) - respondError(w, http.StatusTooManyRequests, quotaErr, core.SetCustomResponseHeaders) + respondError(w, http.StatusTooManyRequests, quotaErr, r) if core.Logger().IsTrace() { core.Logger().Trace("request rejected due to rate limit quota violation", "request_path", path) @@ -78,7 +78,7 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler if core.RateLimitAuditLoggingEnabled() { req, _, status, err := buildLogicalRequestNoAuth(core.PerfStandby(), w, r) if err != nil || status != 0 { - respondError(w, status, err, core.SetCustomResponseHeaders) + respondError(w, status, err, r) return } diff --git a/internalshared/configutil/http_response_headers.go b/internalshared/configutil/http_response_headers.go index a3c44a3c2b500..12de78be4faa7 100644 --- a/internalshared/configutil/http_response_headers.go +++ b/internalshared/configutil/http_response_headers.go @@ -31,10 +31,10 @@ const ( xFrameOptions = "Deny" xContentTypeOptions = "nosniff" strictTransportSecurity = "max-age=31536000; includeSubDomains" - contentType = "text/plain; charset=utf-8" + contentType = "application/json" ) -func ParseDefaultHeaders(h string) string { +func GetDefaultHeaderValue(h string) string { switch h { case "Content-Security-Policy": return contentSecurityPolicy @@ -66,7 +66,7 @@ func setDefaultResponseHeaders(c map[string]string) map[string]string { if _, ok := c[hn]; ok { continue } - hv := ParseDefaultHeaders(hn) + hv := GetDefaultHeaderValue(hn) if hv != "" { defaults[hn] = hv } @@ -76,7 +76,7 @@ func setDefaultResponseHeaders(c map[string]string) map[string]string { } func ParseCustomResponseHeaders(r interface{}) (map[string]map[string]string, error) { - if !isValidListDict(r) { + if _, ok := r.([]map[string]interface{}); !ok { return nil, fmt.Errorf("response headers were not configured correctly. please make sure they're in a map") } @@ -84,16 +84,16 @@ func ParseCustomResponseHeaders(r interface{}) (map[string]map[string]string, er h := make(map[string]map[string]string) for _, crh := range customResponseHeader { - for sc, rh := range crh { - if !isValidListDict(rh){ - return nil, fmt.Errorf("invalid response header type") + for statusCode, responseHeader := range crh { + if _, ok := responseHeader.([]map[string]interface{}); !ok { + return nil, fmt.Errorf("response headers were not configured correctly. please make sure they're in a map") } - if !IsValidStatusCode(sc) { - return nil, fmt.Errorf("invalid status code found in the config file: %v", sc) + if !IsValidStatusCode(statusCode) { + return nil, fmt.Errorf("invalid status code found in the config file: %v", statusCode) } - hvl := rh.([]map[string]interface{}) + hvl := responseHeader.([]map[string]interface{}) if len(hvl) != 1 { return nil, fmt.Errorf("invalid number of response headers exist") } @@ -103,7 +103,7 @@ func ParseCustomResponseHeaders(r interface{}) (map[string]map[string]string, er return nil, err } - h[sc] = hv + h[statusCode] = hv } } @@ -114,20 +114,6 @@ func ParseCustomResponseHeaders(r interface{}) (map[string]map[string]string, er return h, nil } -func isValidListDict(in interface{}) bool { - if _, ok := in.([]map[string]interface{}); ok { - return true - } - return false -} - -func isValidList(in interface{}) bool { - if _, ok := in.([]interface{}); ok { - return true - } - return false -} - func IsValidStatusCodeCollection(sc string) bool { for _, v := range ValidCustomStatusCodeCollection { if sc == v { @@ -173,11 +159,14 @@ func parseHeaders(in map[string]interface{}) (map[string]string, error) { func parseHeaderValues(h interface{}) (string, error) { var sl []string - if !isValidList(h) { + if _, ok := h.([]interface{}); !ok { return "", fmt.Errorf("headers must be given in a list of strings") } vli := h.([]interface{}) for _, vh := range vli { + if vh.(string) == "" { + continue + } sl = append(sl, vh.(string)) } s := strings.Join(sl, "; ") diff --git a/internalshared/configutil/listener.go b/internalshared/configutil/listener.go index 77cbbf93b10d1..858b6678cc1ec 100644 --- a/internalshared/configutil/listener.go +++ b/internalshared/configutil/listener.go @@ -102,7 +102,7 @@ type Listener struct { // Custom Http response headers CustomResponseHeaders map[string]map[string]string `hcl:"-"` - CustomResponseHeadersRaw interface{} `hcl:"custom_response_headers,alias:custom_response_headers"` + CustomResponseHeadersRaw interface{} `hcl:"custom_response_headers"` } @@ -371,7 +371,7 @@ func ParseListeners(result *SharedConfig, list *ast.ObjectList) error { if l.CustomResponseHeadersRaw != nil { customHeadersMap, err := ParseCustomResponseHeaders(l.CustomResponseHeadersRaw) if err != nil { - return multierror.Prefix(fmt.Errorf("failed to parse custom_response_headers:%w", err), fmt.Sprintf("listeners.%d", i)) + return multierror.Prefix(fmt.Errorf("failed to parse custom_response_headers: %w", err), fmt.Sprintf("listeners.%d", i)) } l.CustomResponseHeaders = customHeadersMap l.CustomResponseHeadersRaw = nil diff --git a/internalshared/listenerutil/custom_response_headers.go b/internalshared/listenerutil/custom_response_headers.go new file mode 100644 index 0000000000000..9b01afb168e31 --- /dev/null +++ b/internalshared/listenerutil/custom_response_headers.go @@ -0,0 +1,160 @@ +package listenerutil + +import ( + "fmt" + "net/http" + "net/textproto" + "strconv" + "strings" + + "github.com/hashicorp/vault/internalshared/configutil" +) + +type ListenerCustomHeaders struct { + Address string + StatusCodeHeaderMap map[string][]*CustomHeader + // ConfiguredHeadersStatusCodeMap field is introduced so that we would not need to loop through + // StatusCodeHeaderMap to see if a header exists, the key for this map is the headers names + ConfiguredHeadersStatusCodeMap map[string][]string +} + +type CustomHeader struct { + Name string + Value string +} + +// ChangeListenerAddress is used for tests where the listener address (at least the port) +// is chosen at random +func (l *ListenerCustomHeaders) ChangeListenerAddress(la string) { + l.Address = la + return +} + +func (l *ListenerCustomHeaders) SetCustomResponseHeaders(w http.ResponseWriter, status int) { + if w == nil { + fmt.Println("No ResponseWriter provided") + return + } + + sch := l.StatusCodeHeaderMap + if sch == nil { + fmt.Println("status code header map not configured") + return + } + + // setter function to set the headers + setter := func(hvl []*CustomHeader) { + for _, hv := range hvl { + w.Header().Set(hv.Name, hv.Value) + } + } + + // Checking the validity of the status code + if status >= 600 || status < 100 { + fmt.Println("invalid status code") + return + } + + // Setting the default headers first + setter(sch["default"]) + + // setting the Xyy pattern first + d := fmt.Sprintf("%vxx", status / 100) + if val, ok := sch[d]; ok { + setter(val) + } + + // Setting the specific headers + if val, ok := sch[strconv.Itoa(status)]; ok { + setter(val) + } + + return +} + + +func (l *ListenerCustomHeaders) findCustomHeaderMatchStatusCode(sc string, hn string) string { + + getHeader := func(ch []*CustomHeader) string { + for _, h := range ch { + if h.Name == hn { + return h.Value + } + } + return "" + } + + hm := l.StatusCodeHeaderMap + + // starting with the most specific status code + if ch, ok := hm[sc]; ok { + h := getHeader(ch) + if h != "" { + return h + } + } + + // Checking for the Yxx pattern + var firstDig string + if len(sc) == 3 { + firstDig = strings.Split(sc, "")[0] + } + if firstDig != "" { + s := fmt.Sprintf("%vxx", firstDig) + if configutil.IsValidStatusCodeCollection(s) { + if ch, ok := hm[s]; ok { + h := getHeader(ch) + if h != "" { + return h + } + } + } + } + + // At this point, we could not find a match for the given status code in the config file + // so, we just return the "default" ones + h := getHeader(hm["default"]) + if h != ""{ + return h + } + + return "" +} + +func(l *ListenerCustomHeaders) FetchHeaderForStatusCode(header, sc string) (string, error) { + + if header == "" { + return "", fmt.Errorf("invalid target header") + } + + if l.StatusCodeHeaderMap == nil { + return "", fmt.Errorf("custom headers not configured") + } + + if !configutil.IsValidStatusCode(sc) { + return "", fmt.Errorf("failed to check if a header exist in config file due to invalid status code") + } + + hn := textproto.CanonicalMIMEHeaderKey(header) + + h := l.findCustomHeaderMatchStatusCode(sc, hn) + + return h, nil +} + +func (l *ListenerCustomHeaders) ExistCustomResponseHeader(header string) bool { + + if header == "" { + return false + } + + if l.StatusCodeHeaderMap == nil { + return false + } + + hn := textproto.CanonicalMIMEHeaderKey(header) + + hs := l.ConfiguredHeadersStatusCodeMap + _, ok := hs[hn] + return ok +} \ No newline at end of file diff --git a/vault/core.go b/vault/core.go index 255a98c59a04b..54125ebeec70e 100644 --- a/vault/core.go +++ b/vault/core.go @@ -42,6 +42,7 @@ import ( "github.com/hashicorp/vault/command/server" "github.com/hashicorp/vault/helper/metricsutil" "github.com/hashicorp/vault/helper/namespace" + "github.com/hashicorp/vault/internalshared/listenerutil" "github.com/hashicorp/vault/physical/raft" "github.com/hashicorp/vault/sdk/helper/certutil" "github.com/hashicorp/vault/sdk/helper/consts" @@ -1004,9 +1005,7 @@ func NewCore(conf *CoreConfig) (*Core, error) { c.clusterListener.Store((*cluster.Listener)(nil)) uiHeaders, _ := c.UIHeaders() - customHeaderLogger := conf.Logger.Named("customHeader") - c.allLoggers = append(c.allLoggers, customHeaderLogger) - c.customListenerHeader = NewListenerCustomHeader(conf.RawConfig.Listeners, customHeaderLogger, uiHeaders) + c.customListenerHeader = NewListenerCustomHeader(conf.RawConfig.Listeners, c.logger, uiHeaders) quotasLogger := conf.Logger.Named("quotas") c.allLoggers = append(c.allLoggers, quotasLogger) @@ -2627,18 +2626,6 @@ func (c *Core) SetLogLevel(level log.Level) { } } -func (c *Core) GetLogger(name string) log.Logger { - c.allLoggersLock.Lock() - defer c.allLoggersLock.Unlock() - for _, logger := range c.allLoggers { - ln := logger.Name() - if ln == name { - return logger - } - } - return nil -} - // SetConfig sets core's config object to the newly provided config. func (c *Core) SetConfig(conf *server.Config) { c.rawConfig.Store(conf) @@ -2651,13 +2638,36 @@ func (c *Core) SetConfig(conf *server.Config) { c.logger.Debug("set config", "sanitized config", string(bz)) } -func (c *Core) SetCustomResponseHeaders(w http.ResponseWriter, status int) { +func (c *Core) getCustomResponseHeaders(la string) []*listenerutil.ListenerCustomHeaders { if c.customListenerHeader == nil { c.logger.Debug("failed to find the custom response headers configuration") - return + return nil + } + + lch := c.customListenerHeader.getListenerMap(la) + if lch == nil { + c.logger.Warn("no listener config found", "address", la) + return nil } - c.customListenerHeader.SetCustomResponseHeaders(w, status) + return lch +} + +func (c *Core) GetListenerCustomResponseHeaders(la string) *listenerutil.ListenerCustomHeaders { + + if la == "" { + return nil + } + lch := c.getCustomResponseHeaders(la) + if lch == nil { + return nil + } + if len(lch) != 1 { + c.logger.Warn("multiple listeners with the same address configured") + return nil + } + + return lch[0] } func (c *Core) ExistCustomResponseHeader(header string, la string) bool { @@ -2666,7 +2676,21 @@ func (c *Core) ExistCustomResponseHeader(header string, la string) bool { return false } - return c.customListenerHeader.ExistCustomResponseHeader(header, la) + lch := c.getCustomResponseHeaders(la) + if lch == nil { + c.logger.Warn("no listener config found", "address", la) + return false + } + + exist := false + for _, l := range lch { + exist = l.ExistCustomResponseHeader(header) + if exist { + return true + } + } + + return exist } func (c *Core) ReloadCustomResponseHeaders() error { @@ -2681,17 +2705,11 @@ func (c *Core) ReloadCustomResponseHeaders() error { uiHeaders, _ := c.UIHeaders() - customHeaderLogger := c.GetLogger("customHeader") - if customHeaderLogger == nil { - customHeaderLogger = c.Logger().Named("customHeader") - c.AddLogger(customHeaderLogger) - } - lns := conf.(*server.Config).Listeners - c.customListenerHeader = NewListenerCustomHeader(lns, customHeaderLogger, uiHeaders) + c.customListenerHeader = NewListenerCustomHeader(lns, c.logger, uiHeaders) if c.customListenerHeader == nil { - c.logger.Error("failed to reload custom headers, reverting back the old configuration") + c.logger.Error("failed to reload custom headers. the previous configuration will be used") c.customListenerHeader = tempLH } diff --git a/vault/custom_response_headers.go b/vault/custom_response_headers.go index b57eefa979a45..44824c9ebc9a7 100644 --- a/vault/custom_response_headers.go +++ b/vault/custom_response_headers.go @@ -1,32 +1,16 @@ package vault import ( - "fmt" log "github.com/hashicorp/go-hclog" "net/http" - "net/textproto" - "strconv" "strings" "github.com/hashicorp/vault/internalshared/configutil" + "github.com/hashicorp/vault/internalshared/listenerutil" ) type ListenersCustomResponseHeadersList struct { - logger log.Logger - CustomHeadersList []*ListenerCustomHeaders -} - -type ListenerCustomHeaders struct { - Address string - StatusCodeHeaderMap map[string][]*CustomHeader - // ConfiguredHeadersStatusCodeMap field is introduced so that we would not need to loop through - // StatusCodeHeaderMap to see if a header exists, the key for this map is the headers names - ConfiguredHeadersStatusCodeMap map[string][]string -} - -type CustomHeader struct { - Name string - Value string + CustomHeadersList []*listenerutil.ListenerCustomHeaders } func NewListenerCustomHeader(ln []*configutil.Listener, logger log.Logger, uiHeaders http.Header) *ListenersCustomResponseHeadersList { @@ -35,23 +19,21 @@ func NewListenerCustomHeader(ln []*configutil.Listener, logger log.Logger, uiHea return nil } - ll := &ListenersCustomResponseHeadersList{ - logger: logger, - } + ll := &ListenersCustomResponseHeadersList{} for _, l := range ln { - lc := &ListenerCustomHeaders{ + lc := &listenerutil.ListenerCustomHeaders{ Address: l.Address, } - lc.StatusCodeHeaderMap = make(map[string][]*CustomHeader) + lc.StatusCodeHeaderMap = make(map[string][]*listenerutil.CustomHeader) lc.ConfiguredHeadersStatusCodeMap = make(map[string][]string) for sc, hv := range l.CustomResponseHeaders { - var chl []*CustomHeader + var chl []*listenerutil.CustomHeader for h, v := range hv { // Sanitizing custom headers // X-Vault- prefix is reserved for Vault internal processes if strings.HasPrefix(h, "X-Vault-") { - logger.Error("Custom headers starting with X-Vault are not valid", "header", h) + logger.Warn("Custom headers starting with X-Vault are not valid", "header", h) continue } @@ -63,7 +45,13 @@ func NewListenerCustomHeader(ln []*configutil.Listener, logger log.Logger, uiHea } } - ch := &CustomHeader{ + // Checking if the header value is not an empty string + if v == "" { + logger.Warn("header value is an empty string", "header", h, "value", v) + continue + } + + ch := &listenerutil.CustomHeader{ Name: h, Value: v, } @@ -81,76 +69,14 @@ func NewListenerCustomHeader(ln []*configutil.Listener, logger log.Logger, uiHea return ll } -func (c *ListenersCustomResponseHeadersList) SetCustomResponseHeaders(w http.ResponseWriter, status int) { - if w == nil { - c.logger.Error("No ResponseWriter provided") - return - } - - // Getting the listener address to set its corresponding custom headers - la := w.Header().Get("X-Vault-Listener-Add") - if la == "" { - c.logger.Error("X-Vault-Listener-Add was not set in the ResponseWriter") - return - } - - // Removing X-Vault-Listener-Add header from ResponseWriter - // This should be safe as the call to this function is right - // before w.WriteHeader for which the status code is finalized and known - w.Header().Del("X-Vault-Listener-Add") - - lch := c.getListenerMap(la) - if lch == nil { - c.logger.Warn("no listener config found", "address", la) - return - } - if len(lch) != 1 { - c.logger.Warn("multiple listeners with the same address configured") - return - } - sch := lch[0].StatusCodeHeaderMap - if sch == nil { - c.logger.Warn("status code header map not configured") - return - } - - // setter function to set the headers - setter := func(hvl []*CustomHeader) { - for _, hv := range hvl { - w.Header().Set(hv.Name, hv.Value) - } - } - - // Checking the validity of the status code - if status >= 600 || status < 100 { - c.logger.Error("invalid status code") - return - } - - // Setting the default headers first - setter(sch["default"]) - - // setting the Xyy pattern first - d := fmt.Sprintf("%vxx", status / 100) - if val, ok := sch[d]; ok { - setter(val) - } - // Setting the specific headers - if val, ok := sch[strconv.Itoa(status)]; ok { - setter(val) - } - - return -} - -func (c *ListenersCustomResponseHeadersList) getListenerMap(address string) []*ListenerCustomHeaders { +func (c *ListenersCustomResponseHeadersList) getListenerMap(address string) []*listenerutil.ListenerCustomHeaders { if c.CustomHeadersList == nil { return nil } // either looking for a specific listener, or if listener address isn't given, // checking for all available listeners - var lch []*ListenerCustomHeaders + var lch []*listenerutil.ListenerCustomHeaders if address == "" { lch = c.CustomHeadersList } else { @@ -165,114 +91,3 @@ func (c *ListenersCustomResponseHeadersList) getListenerMap(address string) []*L } return lch } - -func (c *ListenersCustomResponseHeadersList) findCustomHeaderMatchStatusCode(hm map[string][]*CustomHeader, sc string, hn string) string { - - getHeader := func(ch []*CustomHeader) string { - for _, h := range ch { - if h.Name == hn { - return h.Value - } - } - return "" - } - - // starting with the most specific status code - if ch, ok := hm[sc]; ok { - h := getHeader(ch) - if h != "" { - return h - } - } - - // Checking for the Yxx pattern - var firstDig string - if len(sc) == 3 { - firstDig = strings.Split(sc, "")[0] - } - if firstDig != "" { - s := fmt.Sprintf("%vxx", firstDig) - if configutil.IsValidStatusCodeCollection(s) { - if ch, ok := hm[s]; ok { - h := getHeader(ch) - if h != "" { - return h - } - } - } - } - - // At this point, we could not find a match for the given status code in the config file - // so, we just return the "default" ones - h := getHeader(hm["default"]) - if h != ""{ - return h - } - - return "" -} - -func (c *ListenersCustomResponseHeadersList) FetchCustomResponseHeaderValue(header string, sc string, la string) ([]string, error) { - - if header == "" { - return nil, fmt.Errorf("invalid target header") - } - - if c.CustomHeadersList == nil { - return nil, fmt.Errorf("core custom headers not configured") - } - - lch := c.getListenerMap(la) - if lch == nil { - return nil, fmt.Errorf("no listener found with address:%v", la) - } - - var headers []string - var err error - hn := textproto.CanonicalMIMEHeaderKey(header) - for _, l := range lch { - h := c.findCustomHeaderMatchStatusCode(l.StatusCodeHeaderMap, sc, hn) - if h == "" { - continue - } - headers = append(headers, h) - } - - return headers, err -} - -func(c *ListenersCustomResponseHeadersList) FetchHeaderForStausCode(header, sc, la string) bool { - - if !configutil.IsValidStatusCode(sc) { - c.logger.Error("failed to check if a header exist in config file due to invalid status code") - return false - } - - chv, _ := c.FetchCustomResponseHeaderValue(header, sc, la) - if chv != nil { - return true - } - - return false -} - -func (c *ListenersCustomResponseHeadersList) ExistCustomResponseHeader(header, la string) bool { - - lch := c.getListenerMap(la) - if lch == nil { - return false - } - if len(lch) != 1 { - c.logger.Warn("multiple listeners with the same address configured, checking all listeners for the custom header") - } - - hn := textproto.CanonicalMIMEHeaderKey(header) - for _, chs := range lch { - hs := chs.ConfiguredHeadersStatusCodeMap - if _, ok := hs[hn]; ok { - return true - } - } - - return false -} \ No newline at end of file diff --git a/vault/custom_response_headers_test.go b/vault/custom_response_headers_test.go index eda26cf355b41..546916203413d 100644 --- a/vault/custom_response_headers_test.go +++ b/vault/custom_response_headers_test.go @@ -77,43 +77,49 @@ func TestConfigCustomHeaders(t *testing.T) { if customListenerHeader == nil { t.Fatalf("custom header config should be configured") } + listenerCustomHeaders := customListenerHeader.getListenerMap("127.0.0.1:443") + if listenerCustomHeaders == nil || len(listenerCustomHeaders) != 1 { + t.Fatalf("failed to find listener specific custom header") + } + + lch := listenerCustomHeaders[0] - if customListenerHeader.ExistCustomResponseHeader("X-Vault-Ignored-307", "127.0.0.1:443") { + if lch.ExistCustomResponseHeader("X-Vault-Ignored-307") { t.Fatalf("header name with X-Vault prefix is not valid") } - if customListenerHeader.ExistCustomResponseHeader("X-Vault-Ignored-3xx", "127.0.0.1:443") { + if lch.ExistCustomResponseHeader("X-Vault-Ignored-3xx") { t.Fatalf("header name with X-Vault prefix is not valid") } - if !customListenerHeader.ExistCustomResponseHeader("X-Custom-Header", "127.0.0.1:443") { + if !lch.ExistCustomResponseHeader("X-Custom-Header") { t.Fatalf("header name with X-Vault prefix is not valid") } commonDefaultUiHeader := uiHeaders["Content-Security-Policy"] - commonDefaultResponseHeader, _ := customListenerHeader.FetchCustomResponseHeaderValue("Content-Security-Policy", "default", "127.0.0.1:443") + commonDefaultResponseHeader, _ := lch.FetchHeaderForStatusCode("Content-Security-Policy", "default") - if commonDefaultUiHeader[0] == commonDefaultResponseHeader[0] { + if commonDefaultUiHeader[0] == commonDefaultResponseHeader { t.Fatalf("default haeder ") } w := httptest.NewRecorder() w.Header().Set("X-Vault-Listener-Add", rawListenerConfig[0].Address) - customListenerHeader.SetCustomResponseHeaders(w, 200) + lch.SetCustomResponseHeaders(w, 200) if w.Header().Get("Someheader-200") != "200" || w.Header().Get("X-Custom-Header") != "Custom header value 200"{ t.Fatalf("response headers related to status code %v did not set properly", 200) } w = httptest.NewRecorder() w.Header().Set("X-Vault-Listener-Add", rawListenerConfig[0].Address) - customListenerHeader.SetCustomResponseHeaders(w, 204) + lch.SetCustomResponseHeaders(w, 204) if w.Header().Get("Someheader-200") == "200" || w.Header().Get("X-Custom-Header") != "Custom header value 2xx" { t.Fatalf("response headers related to status code %v did not set properly", "2xx") } w = httptest.NewRecorder() w.Header().Set("X-Vault-Listener-Add", rawListenerConfig[0].Address) - customListenerHeader.SetCustomResponseHeaders(w, 500) + lch.SetCustomResponseHeaders(w, 500) for h, v := range defaultCustomHeaders { if h != "X-Vault-Ignored" && w.Header().Get(h) != v { t.Fatalf("response headers related to status code %v did not set properly", 500) @@ -124,7 +130,6 @@ func TestConfigCustomHeaders(t *testing.T) { } } - func TestCustomResponseHeadersConfigInteractUiConfig(t *testing.T) { b := testSystemBackend(t) _, barrier, _ := mockBarrier(t) diff --git a/vault/logical_system.go b/vault/logical_system.go index b23e8cdead191..c3d5e2bda0c85 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -2627,8 +2627,9 @@ func (b *SystemBackend) handleConfigUIHeadersUpdate(ctx context.Context, req *lo // Translate the list of values to the valid header string value := http.Header{} for _, v := range values { - la := req.ResponseWriter.Header().Get("X-Vault-Listener-Add") - if b.Core.ExistCustomResponseHeader(header, la) { + // ExistCustomResponseHeader support checking for custom headers per listener address + // If address is an empty string, it checks all listeners for the header. + if b.Core.ExistCustomResponseHeader(header, "") { return logical.ErrorResponse(fmt.Sprintf("header already exist in server configuration file: %v", header)), logical.ErrInvalidRequest } value.Add(header, v) From ad93253c21df1fbbae1ba760c1bd13615f1a568b Mon Sep 17 00:00:00 2001 From: hamid ghaf Date: Tue, 14 Sep 2021 11:26:26 -0700 Subject: [PATCH 11/25] removing some unused references --- http/handler.go | 10 +++++----- http/logical.go | 5 ----- http/sys_generate_root.go | 6 +++--- vault/logical_system.go | 4 ---- 4 files changed, 8 insertions(+), 17 deletions(-) diff --git a/http/handler.go b/http/handler.go index b7a4ba8c8a87c..a385385dfe87a 100644 --- a/http/handler.go +++ b/http/handler.go @@ -167,8 +167,8 @@ func Handler(props *vault.HandlerProperties) http.Handler { } else { mux.Handle("/ui/", handleUIHeaders(core, handleUIStub())) } - mux.Handle("/ui", handleUIRedirect(core)) - mux.Handle("/", handleUIRedirect(core)) + mux.Handle("/ui", handleUIRedirect()) + mux.Handle("/", handleUIRedirect()) } @@ -348,7 +348,7 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr ctx = context.WithValue(ctx, "original_request_path", r.URL.Path) if listenerCustomHeaders != nil { - ctx = context.WithValue(ctx, "X-Vault-Listener-Custom-Headers-Struct", listenerCustomHeaders) + ctx = context.WithValue(ctx, "listener_custom_headers_struct", listenerCustomHeaders) } r = r.WithContext(ctx) @@ -612,7 +612,7 @@ func handleUIStub() http.Handler { }) } -func handleUIRedirect(core *vault.Core) http.Handler { +func handleUIRedirect() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { status := 307 SetCustomResponseHeaders(w, status, req) @@ -1161,7 +1161,7 @@ func SetCustomResponseHeaders(w http.ResponseWriter, status int, r *http.Request return } ctx := r.Context() - listenerCustomHeaders := ctx.Value("X-Vault-Listener-Custom-Headers-Struct") + listenerCustomHeaders := ctx.Value("listener_custom_headers_struct") if listenerCustomHeaders != nil { lc := listenerCustomHeaders.(*listenerutil.ListenerCustomHeaders) if lc != nil { diff --git a/http/logical.go b/http/logical.go index f6d744a874a72..3db2bd4a2accb 100644 --- a/http/logical.go +++ b/http/logical.go @@ -98,11 +98,6 @@ func buildLogicalRequestNoAuth(perfStandby bool, w http.ResponseWriter, r *http. bufferedBody := newBufferedReader(r.Body) r.Body = bufferedBody - // response writer is needed when updating ui headers to make sure it - // does not interfere with custom response headers set in the configuration file - if strings.HasPrefix(path,"sys/config/ui") { - responseWriter = w - } // If we are uploading a snapshot we don't want to parse it. Instead // we will simply add the HTTP request to the logical request object // for later consumption. diff --git a/http/sys_generate_root.go b/http/sys_generate_root.go index b91e94dd651c5..881af4cfb7145 100644 --- a/http/sys_generate_root.go +++ b/http/sys_generate_root.go @@ -122,9 +122,9 @@ func handleSysGenerateRootAttemptPut(core *vault.Core, w http.ResponseWriter, r func handleSysGenerateRootAttemptDelete(core *vault.Core, w http.ResponseWriter, r *http.Request) { - errNew := core.GenerateRootCancel() - if errNew != nil { - respondError(w, http.StatusInternalServerError, errNew, r) + err := core.GenerateRootCancel() + if err != nil { + respondError(w, http.StatusInternalServerError, err, r) return } respondOk(w, nil, r) diff --git a/vault/logical_system.go b/vault/logical_system.go index c3d5e2bda0c85..3a9954f4761de 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -2620,10 +2620,6 @@ func (b *SystemBackend) handleConfigUIHeadersUpdate(ctx context.Context, req *lo return logical.ErrorResponse("X-Vault headers cannot be set"), logical.ErrInvalidRequest } - if req.ResponseWriter == nil { - return logical.ErrorResponse("no ResponseWriter in the request"), logical.ErrInvalidRequest - } - // Translate the list of values to the valid header string value := http.Header{} for _, v := range values { From 0905ab2a920793d29c700e52b48dd958d4efc45b Mon Sep 17 00:00:00 2001 From: hamid ghaf Date: Tue, 14 Sep 2021 12:45:56 -0700 Subject: [PATCH 12/25] fixing a test --- command/server/config_custom_response_headers_test.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/command/server/config_custom_response_headers_test.go b/command/server/config_custom_response_headers_test.go index 1b6c395cddd24..9a881f34eb0f9 100644 --- a/command/server/config_custom_response_headers_test.go +++ b/command/server/config_custom_response_headers_test.go @@ -1,6 +1,7 @@ package server import ( + "fmt" "github.com/go-test/deep" "testing" ) @@ -12,7 +13,7 @@ var defaultCustomHeaders = map[string]string { "X-Custom-Header": "Custom header value default", "X-Frame-Options": "Deny", "X-Content-Type-Options": "nosniff", - "Content-Type": "text/plain; charset=utf-8", + "Content-Type": "application/json", "X-XSS-Protection": "1; mode=block", } @@ -54,7 +55,7 @@ func TestCustomResponseHeadersConfigs(t *testing.T) { t.Fatalf("Error encountered when loading config %+v", err) } if diff := deep.Equal(expectedCustomResponseHeader, config.Listeners[0].CustomResponseHeaders); diff != nil { - t.Fatalf("parsed custom headers do not match the expected ones") + t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff)) } } From 33a4aa53c78b588343ef0806a4607a82153247a8 Mon Sep 17 00:00:00 2001 From: hamid ghaf Date: Tue, 14 Sep 2021 15:29:41 -0700 Subject: [PATCH 13/25] changing some error messages, removing a default header value from /ui --- internalshared/configutil/http_response_headers.go | 2 +- vault/custom_response_headers.go | 4 ++-- vault/logical_system.go | 2 +- vault/ui.go | 1 - 4 files changed, 4 insertions(+), 5 deletions(-) diff --git a/internalshared/configutil/http_response_headers.go b/internalshared/configutil/http_response_headers.go index 12de78be4faa7..3a18561b474d6 100644 --- a/internalshared/configutil/http_response_headers.go +++ b/internalshared/configutil/http_response_headers.go @@ -90,7 +90,7 @@ func ParseCustomResponseHeaders(r interface{}) (map[string]map[string]string, er } if !IsValidStatusCode(statusCode) { - return nil, fmt.Errorf("invalid status code found in the config file: %v", statusCode) + return nil, fmt.Errorf("invalid status code found in the server configuration: %v", statusCode) } hvl := responseHeader.([]map[string]interface{}) diff --git a/vault/custom_response_headers.go b/vault/custom_response_headers.go index 44824c9ebc9a7..7acd3a9723669 100644 --- a/vault/custom_response_headers.go +++ b/vault/custom_response_headers.go @@ -33,7 +33,7 @@ func NewListenerCustomHeader(ln []*configutil.Listener, logger log.Logger, uiHea // Sanitizing custom headers // X-Vault- prefix is reserved for Vault internal processes if strings.HasPrefix(h, "X-Vault-") { - logger.Warn("Custom headers starting with X-Vault are not valid", "header", h) + logger.Warn("custom headers starting with X-Vault are not valid", "header", h) continue } @@ -41,7 +41,7 @@ func NewListenerCustomHeader(ln []*configutil.Listener, logger log.Logger, uiHea if uiHeaders != nil { exist := uiHeaders.Get(h) if exist != "" { - logger.Warn("found a duplicate header in UI, note that config file headers take precedence.", "header:", h) + logger.Warn("found a duplicate header in UI. Headers defined in the server configuration take precedence.", "header:", h) } } diff --git a/vault/logical_system.go b/vault/logical_system.go index 3a9954f4761de..e554c38f1f7e0 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -2626,7 +2626,7 @@ func (b *SystemBackend) handleConfigUIHeadersUpdate(ctx context.Context, req *lo // ExistCustomResponseHeader support checking for custom headers per listener address // If address is an empty string, it checks all listeners for the header. if b.Core.ExistCustomResponseHeader(header, "") { - return logical.ErrorResponse(fmt.Sprintf("header already exist in server configuration file: %v", header)), logical.ErrInvalidRequest + return logical.ErrorResponse(fmt.Sprintf("This header already exist in server configuration. %v", header)), logical.ErrInvalidRequest } value.Add(header, v) } diff --git a/vault/ui.go b/vault/ui.go index c36a247af304a..e845420864f33 100644 --- a/vault/ui.go +++ b/vault/ui.go @@ -32,7 +32,6 @@ type UIConfig struct { // NewUIConfig creates a new UI config func NewUIConfig(enabled bool, physicalStorage physical.Backend, barrierStorage logical.Storage) *UIConfig { defaultHeaders := http.Header{} - defaultHeaders.Set("Content-Security-Policy", "default-src 'none'; connect-src 'self'; img-src 'self' data:; script-src 'self'; style-src 'unsafe-inline' 'self'; form-action 'none'; frame-ancestors 'none'; font-src 'self'") defaultHeaders.Set("Service-Worker-Allowed", "/") return &UIConfig{ From 92867b478c9825ce6bb410cf945f0725f62298b3 Mon Sep 17 00:00:00 2001 From: hamid ghaf Date: Tue, 14 Sep 2021 16:05:25 -0700 Subject: [PATCH 14/25] fixing a test --- vault/custom_response_headers_test.go | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/vault/custom_response_headers_test.go b/vault/custom_response_headers_test.go index 546916203413d..de8bf6b273d55 100644 --- a/vault/custom_response_headers_test.go +++ b/vault/custom_response_headers_test.go @@ -2,6 +2,7 @@ package vault import ( "context" + "fmt" "net/http/httptest" "strings" "testing" @@ -95,30 +96,19 @@ func TestConfigCustomHeaders(t *testing.T) { t.Fatalf("header name with X-Vault prefix is not valid") } - commonDefaultUiHeader := uiHeaders["Content-Security-Policy"] - commonDefaultResponseHeader, _ := lch.FetchHeaderForStatusCode("Content-Security-Policy", "default") - - if commonDefaultUiHeader[0] == commonDefaultResponseHeader { - t.Fatalf("default haeder ") - } - w := httptest.NewRecorder() - w.Header().Set("X-Vault-Listener-Add", rawListenerConfig[0].Address) - lch.SetCustomResponseHeaders(w, 200) if w.Header().Get("Someheader-200") != "200" || w.Header().Get("X-Custom-Header") != "Custom header value 200"{ t.Fatalf("response headers related to status code %v did not set properly", 200) } w = httptest.NewRecorder() - w.Header().Set("X-Vault-Listener-Add", rawListenerConfig[0].Address) lch.SetCustomResponseHeaders(w, 204) if w.Header().Get("Someheader-200") == "200" || w.Header().Get("X-Custom-Header") != "Custom header value 2xx" { t.Fatalf("response headers related to status code %v did not set properly", "2xx") } w = httptest.NewRecorder() - w.Header().Set("X-Vault-Listener-Add", rawListenerConfig[0].Address) lch.SetCustomResponseHeaders(w, 500) for h, v := range defaultCustomHeaders { if h != "X-Vault-Ignored" && w.Header().Get(h) != v { @@ -166,7 +156,6 @@ func TestCustomResponseHeadersConfigInteractUiConfig(t *testing.T) { } w := httptest.NewRecorder() - w.Header().Set("X-Vault-Listener-Add", "127.0.0.1:443") hw := logical.NewHTTPResponseWriter(w) // setting a header that already exist in custom headers @@ -178,7 +167,7 @@ func TestCustomResponseHeadersConfigInteractUiConfig(t *testing.T) { if err == nil { t.Fatal("request did not fail on setting a header that is present in custom response headers") } - if !strings.Contains(resp.Data["error"].(string), "header already exist in server configuration file") { + if !strings.Contains(resp.Data["error"].(string), fmt.Sprintf("This header already exist in server configuration. %v", "X-Custom-Header")) { t.Fatalf("failed to get the expected error") } From 2ae11ea5ab6f2bcbc1338caafe1baaf7d2f968df Mon Sep 17 00:00:00 2001 From: hamid ghaf Date: Wed, 15 Sep 2021 17:49:06 -0700 Subject: [PATCH 15/25] wrapping ResponseWriter to set the custom headers --- command/agent.go | 8 +- command/agent/cache/handler.go | 17 +- command/agent/cache/lease_cache.go | 14 +- http/cors.go | 6 +- http/handler.go | 227 ++++++++++++++++---------- http/handler_test.go | 6 +- http/help.go | 6 +- http/logical.go | 23 +-- http/sys_feature_flags.go | 6 +- http/sys_generate_root.go | 34 ++-- http/sys_health.go | 8 +- http/sys_init.go | 14 +- http/sys_leader.go | 6 +- http/sys_metrics.go | 7 +- http/sys_raft.go | 20 +-- http/sys_rekey.go | 72 ++++---- http/sys_seal.go | 40 ++--- http/testing.go | 9 +- http/util.go | 10 +- sdk/logical/response_util.go | 1 + vault/core.go | 5 +- vault/custom_response_headers.go | 118 ++++++++++++- vault/custom_response_headers_test.go | 22 --- vault/testing.go | 55 +++++++ 24 files changed, 452 insertions(+), 282 deletions(-) diff --git a/command/agent.go b/command/agent.go index 2cd7b0e94b420..cbbcba5757b4a 100644 --- a/command/agent.go +++ b/command/agent.go @@ -2,6 +2,7 @@ package command import ( "context" + "errors" "flag" "fmt" "io" @@ -877,10 +878,9 @@ func (c *AgentCommand) Run(args []string) int { func verifyRequestHeader(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if val, ok := r.Header[consts.RequestHeaderName]; !ok || len(val) != 1 || val[0] != "true" { - err := fmt.Errorf("missing '%s' header", consts.RequestHeaderName) - status := http.StatusPreconditionFailed - logical.AdjustErrorStatusCode(&status, err) - logical.RespondError(w, status, err) + logical.RespondError(w, + http.StatusPreconditionFailed, + errors.New(fmt.Sprintf("missing '%s' header", consts.RequestHeaderName))) return } diff --git a/command/agent/cache/handler.go b/command/agent/cache/handler.go index e33c9beffab86..73062df41fbd0 100644 --- a/command/agent/cache/handler.go +++ b/command/agent/cache/handler.go @@ -37,11 +37,8 @@ func Handler(ctx context.Context, logger hclog.Logger, proxier Proxier, inmemSin // Parse and reset body. reqBody, err := ioutil.ReadAll(r.Body) if err != nil { - errRet := errors.New("failed to read request body") - logger.Error(errRet.Error()) - status := http.StatusInternalServerError - logical.AdjustErrorStatusCode(&status, errRet) - logical.RespondError(w, status, errRet) + logger.Error("failed to read request body") + logical.RespondError(w, http.StatusInternalServerError, errors.New("failed to read request body")) return } if r.Body != nil { @@ -62,20 +59,14 @@ func Handler(ctx context.Context, logger hclog.Logger, proxier Proxier, inmemSin w.WriteHeader(resp.Response.StatusCode) io.Copy(w, resp.Response.Body) } else { - status := http.StatusInternalServerError - errNew := fmt.Errorf("failed to get the response: %w", err) - logical.AdjustErrorStatusCode(&status, errNew) - logical.RespondError(w, status, errNew) + logical.RespondError(w, http.StatusInternalServerError, fmt.Errorf("failed to get the response: %w", err)) } return } err = processTokenLookupResponse(ctx, logger, inmemSink, req, resp) if err != nil { - status := http.StatusInternalServerError - errNew := fmt.Errorf("failed to process token lookup response: %w", err) - logical.AdjustErrorStatusCode(&status, errNew) - logical.RespondError(w, status, errNew) + logical.RespondError(w, http.StatusInternalServerError, fmt.Errorf("failed to process token lookup response: %w", err)) return } diff --git a/command/agent/cache/lease_cache.go b/command/agent/cache/lease_cache.go index a4e739378a27b..a8b2d4bd88cea 100644 --- a/command/agent/cache/lease_cache.go +++ b/command/agent/cache/lease_cache.go @@ -576,10 +576,7 @@ func (c *LeaseCache) HandleCacheClear(ctx context.Context) http.Handler { if err == io.EOF { err = errors.New("empty JSON provided") } - status := http.StatusBadRequest - errNew := fmt.Errorf("failed to parse JSON input: %w", err) - logical.AdjustErrorStatusCode(&status, errNew) - logical.RespondError(w, status, errNew) + logical.RespondError(w, http.StatusBadRequest, fmt.Errorf("failed to parse JSON input: %w", err)) return } @@ -588,10 +585,7 @@ func (c *LeaseCache) HandleCacheClear(ctx context.Context) http.Handler { in, err := parseCacheClearInput(req) if err != nil { c.logger.Error("unable to parse clear input", "error", err) - status := http.StatusBadRequest - errNew := fmt.Errorf("failed to parse clear input: %w", err) - logical.AdjustErrorStatusCode(&status, errNew) - logical.RespondError(w, status, errNew) + logical.RespondError(w, http.StatusBadRequest, fmt.Errorf("failed to parse clear input: %w", err)) return } @@ -602,9 +596,7 @@ func (c *LeaseCache) HandleCacheClear(ctx context.Context) http.Handler { if err == errInvalidType { httpStatus = http.StatusBadRequest } - errNew := fmt.Errorf("failed to clear cache: %w", err) - logical.AdjustErrorStatusCode(&httpStatus, errNew) - logical.RespondError(w, httpStatus, errNew) + logical.RespondError(w, httpStatus, fmt.Errorf("failed to clear cache: %w", err)) return } diff --git a/http/cors.go b/http/cors.go index ed48b31228a14..74cfeeaef072e 100644 --- a/http/cors.go +++ b/http/cors.go @@ -40,14 +40,12 @@ func wrapCORSHandler(h http.Handler, core *vault.Core) http.Handler { // Return a 403 if the origin is not allowed to make cross-origin requests. if !corsConf.IsValidOrigin(origin) { - respondError(w, http.StatusForbidden, fmt.Errorf("origin not allowed"), req) + respondError(w, http.StatusForbidden, fmt.Errorf("origin not allowed")) return } if req.Method == http.MethodOptions && !strutil.StrListContains(allowedMethods, requestMethod) { - status := http.StatusMethodNotAllowed - SetCustomResponseHeaders(w, status, req) - w.WriteHeader(status) + w.WriteHeader(http.StatusMethodNotAllowed) return } diff --git a/http/handler.go b/http/handler.go index a385385dfe87a..c6d363c9a0b69 100644 --- a/http/handler.go +++ b/http/handler.go @@ -16,17 +16,18 @@ import ( "net/textproto" "net/url" "os" + "strconv" "strings" "time" "github.com/NYTimes/gziphandler" "github.com/hashicorp/errwrap" "github.com/hashicorp/go-cleanhttp" + log "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-secure-stdlib/parseutil" "github.com/hashicorp/go-sockaddr" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/internalshared/configutil" - "github.com/hashicorp/vault/internalshared/listenerutil" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/jsonutil" "github.com/hashicorp/vault/sdk/helper/pathmanager" @@ -196,8 +197,7 @@ func Handler(props *vault.HandlerProperties) http.Handler { } // Wrap the handler in another handler to trigger all help paths. - unregisteredPathsHandler := wrapUnregisteredPathsHandler(mux, core) - helpWrappedHandler := wrapHelpHandler(unregisteredPathsHandler, core) + helpWrappedHandler := wrapHelpHandler(mux, core) corsWrappedHandler := wrapCORSHandler(helpWrappedHandler, core) quotaWrappedHandler := rateLimitQuotaWrapping(corsWrappedHandler, core) genericWrappedHandler := genericWrapping(core, quotaWrappedHandler, props) @@ -212,6 +212,93 @@ func Handler(props *vault.HandlerProperties) http.Handler { return printablePathCheckHandler } +type WrappingResponseWriter interface { + http.ResponseWriter + Wrapped() http.ResponseWriter +} + +type statusHeaderResponseWriter struct { + wrapped http.ResponseWriter + logger log.Logger + wroteHeader bool + statusCode int + headers map[string][]*vault.CustomHeader +} + +func (w statusHeaderResponseWriter) Wrapped() http.ResponseWriter { + return w.wrapped +} + +func (w statusHeaderResponseWriter) Header() http.Header { + return w.wrapped.Header() +} + +func (w statusHeaderResponseWriter) Write(buf []byte) (int, error) { + // It is allowed to only call ResponseWriter.Write and skip ResponseWriter.WriteHeader + // An example of such a situation is "handleUIStub". The Write function will internally + // set the status code 200 for the response, and that call might invoke other implementations + // the WriteHeader function normally for some sort of IO writer. + // So, we still need to set the custom headers. + // In cases where both WriteHeader and Write of statusHeaderResponseWriter struct are called + // the internal call to the WriterHeader invoked from inside Write method won't change + // the headers. + if !w.wroteHeader { + w.setCustomResponseHeaders(w.statusCode) + } + + return w.wrapped.Write(buf) +} + +func (w statusHeaderResponseWriter) WriteHeader(statusCode int) { + w.setCustomResponseHeaders(statusCode) + w.wrapped.WriteHeader(statusCode) + w.statusCode = statusCode + // in cases where Write is called after WriteHeader, let's prevent setting + // ResponseWriter headers twice + if !w.wroteHeader { + w.wroteHeader = true + } +} + +func (w statusHeaderResponseWriter) setCustomResponseHeaders(status int) { + + sch := w.headers + if sch == nil { + w.logger.Warn("status code header map not configured") + return + } + + // setter function to set the headers + setter := func(hvl []*vault.CustomHeader) { + for _, hv := range hvl { + w.Header().Set(hv.Name, hv.Value) + } + } + + // Checking the validity of the status code + if status >= 600 || status < 100 { + return + } + + // Setting the default headers first + setter(sch["default"]) + + // setting the Xyy pattern first + d := fmt.Sprintf("%vxx", status / 100) + if val, ok := sch[d]; ok { + setter(val) + } + + // Setting the specific headers + if val, ok := sch[strconv.Itoa(status)]; ok { + setter(val) + } + + return +} + +var _ WrappingResponseWriter = &statusHeaderResponseWriter{} + type copyResponseWriter struct { wrapped http.ResponseWriter statusCode int @@ -250,7 +337,7 @@ func handleAuditNonLogical(core *vault.Core, h http.Handler) http.Handler { req, _, status, err := buildLogicalRequestNoAuth(core.PerfStandby(), w, r) if err != nil || status != 0 { - respondError(w, status, err, r) + respondError(w, status, err) return } if origBody != nil { @@ -261,7 +348,7 @@ func handleAuditNonLogical(core *vault.Core, h http.Handler) http.Handler { } err = core.AuditLogger().AuditRequest(r.Context(), input) if err != nil { - respondError(w, status, err, r) + respondError(w, status, err) return } cw := newCopyResponseWriter(w) @@ -275,26 +362,8 @@ func handleAuditNonLogical(core *vault.Core, h http.Handler) http.Handler { input.Response = logical.HTTPResponseToLogicalResponse(httpResp) err = core.AuditLogger().AuditResponse(r.Context(), input) if err != nil { - respondError(w, status, err, r) - } - return - }) -} - -// wrapUnregisteredPathsHandler is the last layer before the endpoint direct handlers -// This is to prevent response headers being overwritten unintentionally -func wrapUnregisteredPathsHandler(h http.Handler, core *vault.Core) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // According to the net/http package, if a subtree has been registered and a request - // is received naming the subtree root without its trailing slash, ServeMux - // redirects that request to the subtree root (adding the trailing slash). - // Calling "/v1/sys" endpoint will be redirected to "/v1/sys/" with status - // 301 (Moved Permanently), however, the status code is set inside the net/http - // package. So, we set the custom response headers here instead. - if r.URL.Path == "/v1/sys" { - SetCustomResponseHeaders(w, 301, r) + respondError(w, status, err) } - h.ServeHTTP(w, r) return }) } @@ -316,17 +385,29 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr maxRequestSize = DefaultMaxRequestSize } - var listenerCustomHeaders *listenerutil.ListenerCustomHeaders - if props.ListenerConfig != nil { - la := props.ListenerConfig.Address - listenerCustomHeaders = core.GetListenerCustomResponseHeaders(la) - } // Swallow this error since we don't want to pollute the logs and we also don't want to // return an HTTP error here. This information is best effort. hostname, _ := os.Hostname() return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + // This block needs to be here so that upon sending SIGHUP, custom response + // headers are also reloaded into the handlers. + if props.ListenerConfig != nil { + la := props.ListenerConfig.Address + listenerCustomHeaders := core.GetListenerCustomResponseHeaders(la) + if listenerCustomHeaders != nil { + w = &statusHeaderResponseWriter{ + wrapped: w, + logger: core.Logger(), + wroteHeader: false, + statusCode: 200, + headers: listenerCustomHeaders.StatusCodeHeaderMap, + } + } + } + // Set the Cache-Control header for all the responses returned // by Vault w.Header().Set("Cache-Control", "no-store") @@ -347,10 +428,6 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr } ctx = context.WithValue(ctx, "original_request_path", r.URL.Path) - if listenerCustomHeaders != nil { - ctx = context.WithValue(ctx, "listener_custom_headers_struct", listenerCustomHeaders) - } - r = r.WithContext(ctx) r = r.WithContext(namespace.ContextWithNamespace(r.Context(), namespace.RootNamespace)) @@ -370,7 +447,7 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr case strings.HasPrefix(r.URL.Path, "/v1/"): newR, status := adjustRequest(core, r) if status != 0 { - respondError(w, status, nil, r) + respondError(w, status, nil) cancelFunc() return } @@ -378,7 +455,7 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr case strings.HasPrefix(r.URL.Path, "/ui"), r.URL.Path == "/robots.txt", r.URL.Path == "/": default: - respondError(w, http.StatusNotFound, nil, r) + respondError(w, http.StatusNotFound, nil) cancelFunc() return } @@ -408,7 +485,7 @@ func WrapForwardedForHandler(h http.Handler, l *configutil.Listener) http.Handle h.ServeHTTP(w, r) return } - respondError(w, http.StatusBadRequest, fmt.Errorf("missing x-forwarded-for header and configured to reject when not present"), r) + respondError(w, http.StatusBadRequest, fmt.Errorf("missing x-forwarded-for header and configured to reject when not present")) return } @@ -421,7 +498,7 @@ func WrapForwardedForHandler(h http.Handler, l *configutil.Listener) http.Handle h.ServeHTTP(w, r) return } - respondError(w, http.StatusBadRequest, fmt.Errorf("error parsing client hostport: %w", err), r) + respondError(w, http.StatusBadRequest, fmt.Errorf("error parsing client hostport: %w", err)) return } @@ -432,7 +509,7 @@ func WrapForwardedForHandler(h http.Handler, l *configutil.Listener) http.Handle h.ServeHTTP(w, r) return } - respondError(w, http.StatusBadRequest, fmt.Errorf("error parsing client address: %w", err), r) + respondError(w, http.StatusBadRequest, fmt.Errorf("error parsing client address: %w", err)) return } @@ -450,7 +527,7 @@ func WrapForwardedForHandler(h http.Handler, l *configutil.Listener) http.Handle h.ServeHTTP(w, r) return } - respondError(w, http.StatusBadRequest, fmt.Errorf("client address not authorized for x-forwarded-for and configured to reject connection"), r) + respondError(w, http.StatusBadRequest, fmt.Errorf("client address not authorized for x-forwarded-for and configured to reject connection")) return } @@ -478,7 +555,7 @@ func WrapForwardedForHandler(h http.Handler, l *configutil.Listener) http.Handle // authorized (or we've turned off explicit rejection) and we // should assume that what comes in should be properly // formatted. - respondError(w, http.StatusBadRequest, fmt.Errorf("malformed x-forwarded-for configuration or request, hops to skip (%d) would skip before earliest chain link (chain length %d)", hopSkips, len(headers)), r) + respondError(w, http.StatusBadRequest, fmt.Errorf("malformed x-forwarded-for configuration or request, hops to skip (%d) would skip before earliest chain link (chain length %d)", hopSkips, len(headers))) return } @@ -509,7 +586,7 @@ func handleUIHeaders(core *vault.Core, h http.Handler) http.Handler { userHeaders, err := core.UIHeaders() if err != nil { - respondError(w, http.StatusInternalServerError, err, req) + respondError(w, http.StatusInternalServerError, err) return } if userHeaders != nil { @@ -519,12 +596,6 @@ func handleUIHeaders(core *vault.Core, h http.Handler) http.Handler { } } - // This function wraps handleUI and handleUIStub which do not set the - // status code specifically, instead, a call to w.Write is called which - // internally also sets the status code to 200. - // Just setting the headers for status code 200. - SetCustomResponseHeaders(w, 200, req) - h.ServeHTTP(w, req) }) } @@ -614,9 +685,7 @@ func handleUIStub() http.Handler { func handleUIRedirect() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - status := 307 - SetCustomResponseHeaders(w, status, req) - http.Redirect(w, req, "/ui/", status) + http.Redirect(w, req, "/ui/", 307) return }) } @@ -673,7 +742,15 @@ func parseJSONRequest(perfStandby bool, r *http.Request, w http.ResponseWriter, return nil, errors.New("could not parse max_request_size from request context") } if max > 0 { - reader = http.MaxBytesReader(w, r.Body, max) + // MaxBytesReader won't do all the internal stuff it must unless it's + // given a ResponseWriter that implements the internal http interface + // requestTooLarger. So we let it have access to the underlying + // ResponseWriter. + inw := w + if myw, ok := inw.(WrappingResponseWriter); ok { + inw = myw.Wrapped() + } + reader = http.MaxBytesReader(inw, r.Body, max) } } var origBody io.ReadWriter @@ -771,7 +848,7 @@ func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handle // Note if the client requested forwarding shouldForward, err := forwardBasedOnHeaders(core, r) if err != nil { - respondError(w, http.StatusBadRequest, err, r) + respondError(w, http.StatusBadRequest, err) return } @@ -779,7 +856,7 @@ func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handle if core.PerfStandby() && !shouldForward { ns, err := namespace.FromContext(r.Context()) if err != nil { - respondError(w, http.StatusBadRequest, err, r) + respondError(w, http.StatusBadRequest, err) return } path := ns.TrimmedPath(r.URL.Path[len("/v1/"):]) @@ -807,7 +884,7 @@ func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handle return } // Some internal error occurred - respondError(w, http.StatusInternalServerError, err, r) + respondError(w, http.StatusInternalServerError, err) return } if isLeader { @@ -816,7 +893,7 @@ func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handle return } if leaderAddr == "" { - respondError(w, http.StatusInternalServerError, fmt.Errorf("local node not active but active cluster node not found"), r) + respondError(w, http.StatusInternalServerError, fmt.Errorf("local node not active but active cluster node not found")) return } @@ -841,7 +918,7 @@ func forwardRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) { ns, err := namespace.FromContext(r.Context()) if err != nil { - respondError(w, http.StatusBadRequest, err, r) + respondError(w, http.StatusBadRequest, err) return } path := ns.TrimmedPath(r.URL.Path[len("/v1/"):]) @@ -872,7 +949,6 @@ func forwardRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) { } } - SetCustomResponseHeaders(w, statusCode, r) w.WriteHeader(statusCode) w.Write(retBytes) } @@ -929,7 +1005,7 @@ func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *l return nil, true, false } - if respondErrorCommon(w, r, resp, err, rawReq) { + if respondErrorCommon(w, r, resp, err) { return resp, false, false } @@ -946,25 +1022,25 @@ func respondStandby(core *vault.Core, w http.ResponseWriter, req *http.Request) if err == vault.ErrHANotEnabled { // Standalone node, serve 503 err = errors.New("node is not active") - respondError(w, http.StatusServiceUnavailable, err, req) + respondError(w, http.StatusServiceUnavailable, err) return } - respondError(w, http.StatusInternalServerError, err, req) + respondError(w, http.StatusInternalServerError, err) return } // If there is no leader, generate a 503 error if redirectAddr == "" { err = errors.New("no active Vault instance found") - respondError(w, http.StatusServiceUnavailable, err, req) + respondError(w, http.StatusServiceUnavailable, err) return } // Parse the redirect location redirectURL, err := url.Parse(redirectAddr) if err != nil { - respondError(w, http.StatusInternalServerError, err, req) + respondError(w, http.StatusInternalServerError, err) return } @@ -985,7 +1061,6 @@ func respondStandby(core *vault.Core, w http.ResponseWriter, req *http.Request) // because we don't actually know if its permanent and // the request method should be preserved. w.Header().Set("Location", finalURL.String()) - SetCustomResponseHeaders(w, 307, req) w.WriteHeader(307) } @@ -1150,37 +1225,21 @@ func isForm(head []byte, contentType string) bool { return true } -func respondError(w http.ResponseWriter, status int, err error, r *http.Request) { - logical.AdjustErrorStatusCode(&status, err) - SetCustomResponseHeaders(w, status, r) +func respondError(w http.ResponseWriter, status int, err error) { logical.RespondError(w, status, err) } -func SetCustomResponseHeaders(w http.ResponseWriter, status int, r *http.Request) { - if r == nil { - return - } - ctx := r.Context() - listenerCustomHeaders := ctx.Value("listener_custom_headers_struct") - if listenerCustomHeaders != nil { - lc := listenerCustomHeaders.(*listenerutil.ListenerCustomHeaders) - if lc != nil { - lc.SetCustomResponseHeaders(w, status) - } - } -} - -func respondErrorCommon(w http.ResponseWriter, req *logical.Request, resp *logical.Response, err error, r *http.Request) bool { +func respondErrorCommon(w http.ResponseWriter, req *logical.Request, resp *logical.Response, err error) bool { statusCode, newErr := logical.RespondErrorCommon(req, resp, err) if newErr == nil && statusCode == 0 { return false } - respondError(w, statusCode, newErr, r) + respondError(w, statusCode, newErr) return true } -func respondOk(w http.ResponseWriter, body interface{}, r *http.Request) { +func respondOk(w http.ResponseWriter, body interface{}) { w.Header().Set("Content-Type", "application/json") var status int @@ -1190,10 +1249,10 @@ func respondOk(w http.ResponseWriter, body interface{}, r *http.Request) { status = http.StatusOK } - SetCustomResponseHeaders(w, status, r) w.WriteHeader(status) if body != nil { + enc := json.NewEncoder(w) enc.Encode(body) } diff --git a/http/handler_test.go b/http/handler_test.go index 01d1866eb7b42..c228629ea8dce 100644 --- a/http/handler_test.go +++ b/http/handler_test.go @@ -599,7 +599,7 @@ func TestHandler_ui_enabled(t *testing.T) { func TestHandler_error(t *testing.T) { w := httptest.NewRecorder() - respondError(w, 500, errors.New("test Error"), nil) + respondError(w, 500, errors.New("test Error")) if w.Code != 500 { t.Fatalf("expected 500, got %d", w.Code) @@ -610,7 +610,7 @@ func TestHandler_error(t *testing.T) { w2 := httptest.NewRecorder() e := logical.CodedError(403, "error text") - respondError(w2, 500, e, nil) + respondError(w2, 500, e) if w2.Code != 403 { t.Fatalf("expected 403, got %d", w2.Code) @@ -619,7 +619,7 @@ func TestHandler_error(t *testing.T) { // vault.ErrSealed is a special case w3 := httptest.NewRecorder() - respondError(w3, 400, consts.ErrSealed, nil) + respondError(w3, 400, consts.ErrSealed) if w3.Code != 503 { t.Fatalf("expected 503, got %d", w3.Code) diff --git a/http/help.go b/http/help.go index 6e2903d5faef1..45099bd7b67f5 100644 --- a/http/help.go +++ b/http/help.go @@ -28,7 +28,7 @@ func wrapHelpHandler(h http.Handler, core *vault.Core) http.Handler { func handleHelp(core *vault.Core, w http.ResponseWriter, r *http.Request) { ns, err := namespace.FromContext(r.Context()) if err != nil { - respondError(w, http.StatusBadRequest, nil, r) + respondError(w, http.StatusBadRequest, nil) return } path := ns.TrimmedPath(r.URL.Path[len("/v1/"):]) @@ -42,9 +42,9 @@ func handleHelp(core *vault.Core, w http.ResponseWriter, r *http.Request) { resp, err := core.HandleRequest(r.Context(), req) if err != nil { - respondErrorCommon(w, req, resp, err, r) + respondErrorCommon(w, req, resp, err) return } - respondOk(w, resp.Data, r) + respondOk(w, resp.Data) } diff --git a/http/logical.go b/http/logical.go index 3db2bd4a2accb..dd9abce34dfdb 100644 --- a/http/logical.go +++ b/http/logical.go @@ -268,17 +268,17 @@ func handleLogicalRecovery(raw *vault.RawBackend, token *atomic.String) http.Han return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { req, _, statusCode, err := buildLogicalRequestNoAuth(false, w, r) if err != nil || statusCode != 0 { - respondError(w, statusCode, err, r) + respondError(w, statusCode, err) return } reqToken := r.Header.Get(consts.AuthHeaderName) if reqToken == "" || token.Load() == "" || reqToken != token.Load() { - respondError(w, http.StatusForbidden, nil, r) + respondError(w, http.StatusForbidden, nil) return } resp, err := raw.HandleRequest(r.Context(), req) - if respondErrorCommon(w, req, resp, err, r) { + if respondErrorCommon(w, req, resp, err) { return } @@ -287,7 +287,7 @@ func handleLogicalRecovery(raw *vault.RawBackend, token *atomic.String) http.Han httpResp = logical.LogicalResponseToHTTPResponse(resp) httpResp.RequestID = req.ID } - respondOk(w, httpResp, r) + respondOk(w, httpResp) }) } @@ -298,7 +298,7 @@ func handleLogicalInternal(core *vault.Core, injectDataIntoTopLevel bool, noForw return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { req, origBody, statusCode, err := buildLogicalRequest(core, w, r) if err != nil || statusCode != 0 { - respondError(w, statusCode, err, r) + respondError(w, statusCode, err) return } @@ -310,7 +310,7 @@ func handleLogicalInternal(core *vault.Core, injectDataIntoTopLevel bool, noForw resp, ok, needsForward := request(core, w, r, req) switch { case needsForward && noForward: - respondError(w, http.StatusBadRequest, vault.ErrCannotForwardLocalOnly, r) + respondError(w, http.StatusBadRequest, vault.ErrCannotForwardLocalOnly) return case needsForward && !noForward: if origBody != nil { @@ -345,9 +345,7 @@ func respondLogical(core *vault.Core, w http.ResponseWriter, r *http.Request, re if resp.Redirect != "" { // If we have a redirect, redirect! We use a 307 code // because we don't actually know if its permanent. - status := 307 - SetCustomResponseHeaders(w, status, r) - http.Redirect(w, r, resp.Redirect, status) + http.Redirect(w, r, resp.Redirect, 307) return } @@ -386,7 +384,7 @@ func respondLogical(core *vault.Core, w http.ResponseWriter, r *http.Request, re adjustResponse(core, w, req) // Respond - respondOk(w, ret, r) + respondOk(w, ret) return } @@ -396,9 +394,7 @@ func respondLogical(core *vault.Core, w http.ResponseWriter, r *http.Request, re func respondRaw(w http.ResponseWriter, r *http.Request, resp *logical.Response) { retErr := func(w http.ResponseWriter, err string) { w.Header().Set("X-Vault-Raw-Error", err) - code := http.StatusInternalServerError - SetCustomResponseHeaders(w, code, r) - w.WriteHeader(code) + w.WriteHeader(http.StatusInternalServerError) w.Write(nil) } @@ -487,7 +483,6 @@ WRITE_RESPONSE: w.Header().Set("Cache-Control", cacheControl) } - SetCustomResponseHeaders(w, status, r) w.WriteHeader(status) w.Write(body) } diff --git a/http/sys_feature_flags.go b/http/sys_feature_flags.go index 9f115d6e73d7d..11ece32795b77 100644 --- a/http/sys_feature_flags.go +++ b/http/sys_feature_flags.go @@ -31,7 +31,7 @@ func handleSysInternalFeatureFlags(core *vault.Core) http.Handler { case "GET": break default: - respondError(w, http.StatusMethodNotAllowed, nil, r) + respondError(w, http.StatusMethodNotAllowed, nil) } response := &FeatureFlagsResponse{} @@ -43,9 +43,7 @@ func handleSysInternalFeatureFlags(core *vault.Core) http.Handler { } w.Header().Set("Content-Type", "application/json") - status := http.StatusOK - SetCustomResponseHeaders(w, status, r) - w.WriteHeader(status) + w.WriteHeader(http.StatusOK) // Generate the response enc := json.NewEncoder(w) diff --git a/http/sys_generate_root.go b/http/sys_generate_root.go index 881af4cfb7145..3aeb5c395e992 100644 --- a/http/sys_generate_root.go +++ b/http/sys_generate_root.go @@ -22,7 +22,7 @@ func handleSysGenerateRootAttempt(core *vault.Core, generateStrategy vault.Gener case "DELETE": handleSysGenerateRootAttemptDelete(core, w, r) default: - respondError(w, http.StatusMethodNotAllowed, nil, r) + respondError(w, http.StatusMethodNotAllowed, nil) } }) } @@ -34,11 +34,11 @@ func handleSysGenerateRootAttemptGet(core *vault.Core, w http.ResponseWriter, r // Get the current seal configuration barrierConfig, err := core.SealAccess().BarrierConfig(ctx) if err != nil { - respondError(w, http.StatusInternalServerError, err, r) + respondError(w, http.StatusInternalServerError, err) return } if barrierConfig == nil { - respondError(w, http.StatusBadRequest, fmt.Errorf("server is not yet initialized"), r) + respondError(w, http.StatusBadRequest, fmt.Errorf("server is not yet initialized")) return } @@ -46,7 +46,7 @@ func handleSysGenerateRootAttemptGet(core *vault.Core, w http.ResponseWriter, r if core.SealAccess().RecoveryKeySupported() { sealConfig, err = core.SealAccess().RecoveryConfig(ctx) if err != nil { - respondError(w, http.StatusInternalServerError, err, r) + respondError(w, http.StatusInternalServerError, err) return } } @@ -54,14 +54,14 @@ func handleSysGenerateRootAttemptGet(core *vault.Core, w http.ResponseWriter, r // Get the generation configuration generationConfig, err := core.GenerateRootConfiguration() if err != nil { - respondError(w, http.StatusInternalServerError, err, r) + respondError(w, http.StatusInternalServerError, err) return } // Get the progress progress, err := core.GenerateRootProgress() if err != nil { - respondError(w, http.StatusInternalServerError, err, r) + respondError(w, http.StatusInternalServerError, err) return } @@ -80,7 +80,7 @@ func handleSysGenerateRootAttemptGet(core *vault.Core, w http.ResponseWriter, r status.PGPFingerprint = generationConfig.PGPFingerprint } - respondOk(w, status, r) + respondOk(w, status) } func handleSysGenerateRootAttemptPut(core *vault.Core, w http.ResponseWriter, r *http.Request, generateStrategy vault.GenerateRootStrategy) { @@ -88,7 +88,7 @@ func handleSysGenerateRootAttemptPut(core *vault.Core, w http.ResponseWriter, r // Parse the request var req GenerateRootInitRequest if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil && err != io.EOF { - respondError(w, http.StatusBadRequest, err, r) + respondError(w, http.StatusBadRequest, err) return } @@ -101,14 +101,14 @@ func handleSysGenerateRootAttemptPut(core *vault.Core, w http.ResponseWriter, r genned = true req.OTP, err = base62.Random(vault.TokenLength + 2) if err != nil { - respondError(w, http.StatusInternalServerError, err, r) + respondError(w, http.StatusInternalServerError, err) return } } // Attemptialize the generation if err := core.GenerateRootInit(req.OTP, req.PGPKey, generateStrategy); err != nil { - respondError(w, http.StatusBadRequest, err, r) + respondError(w, http.StatusBadRequest, err) return } @@ -124,10 +124,10 @@ func handleSysGenerateRootAttemptDelete(core *vault.Core, w http.ResponseWriter, err := core.GenerateRootCancel() if err != nil { - respondError(w, http.StatusInternalServerError, err, r) + respondError(w, http.StatusInternalServerError, err) return } - respondOk(w, nil, r) + respondOk(w, nil) } func handleSysGenerateRootUpdate(core *vault.Core, generateStrategy vault.GenerateRootStrategy) http.Handler { @@ -135,14 +135,14 @@ func handleSysGenerateRootUpdate(core *vault.Core, generateStrategy vault.Genera // Parse the request var req GenerateRootUpdateRequest if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil { - respondError(w, http.StatusBadRequest, err, r) + respondError(w, http.StatusBadRequest, err) return } if req.Key == "" { respondError( w, http.StatusBadRequest, errors.New("'key' must be specified in request body as JSON"), - r) + ) return } @@ -158,7 +158,7 @@ func handleSysGenerateRootUpdate(core *vault.Core, generateStrategy vault.Genera respondError( w, http.StatusBadRequest, errors.New("'key' must be a valid hex or base64 string"), - r) + ) return } } @@ -169,7 +169,7 @@ func handleSysGenerateRootUpdate(core *vault.Core, generateStrategy vault.Genera // Use the key to make progress on root generation result, err := core.GenerateRootUpdate(ctx, key, req.Nonce, generateStrategy) if err != nil { - respondError(w, http.StatusBadRequest, err, r) + respondError(w, http.StatusBadRequest, err) return } @@ -187,7 +187,7 @@ func handleSysGenerateRootUpdate(core *vault.Core, generateStrategy vault.Genera resp.EncodedRootToken = result.EncodedToken } - respondOk(w, resp, r) + respondOk(w, resp) }) } diff --git a/http/sys_health.go b/http/sys_health.go index 0f1865db474a5..a37e9dee522e4 100644 --- a/http/sys_health.go +++ b/http/sys_health.go @@ -22,7 +22,7 @@ func handleSysHealth(core *vault.Core) http.Handler { case "HEAD": handleSysHealthHead(core, w, r) default: - respondError(w, http.StatusMethodNotAllowed, nil, r) + respondError(w, http.StatusMethodNotAllowed, nil) } }) } @@ -44,17 +44,16 @@ func handleSysHealthGet(core *vault.Core, w http.ResponseWriter, r *http.Request code, body, err := getSysHealth(core, r) if err != nil { core.Logger().Error("error checking health", "error", err) - respondError(w, code, nil, r) + respondError(w, code, nil) return } if body == nil { - respondError(w, code, nil, r) + respondError(w, code, nil) return } w.Header().Set("Content-Type", "application/json") - SetCustomResponseHeaders(w, code, r) w.WriteHeader(code) // Generate the response @@ -69,7 +68,6 @@ func handleSysHealthHead(core *vault.Core, w http.ResponseWriter, r *http.Reques w.Header().Set("Content-Type", "application/json") } - SetCustomResponseHeaders(w, code, r) w.WriteHeader(code) } diff --git a/http/sys_init.go b/http/sys_init.go index c7dc5fcd50cf8..dd8605c4c08f3 100644 --- a/http/sys_init.go +++ b/http/sys_init.go @@ -17,7 +17,7 @@ func handleSysInit(core *vault.Core) http.Handler { case "PUT", "POST": handleSysInitPut(core, w, r) default: - respondError(w, http.StatusMethodNotAllowed, nil, r) + respondError(w, http.StatusMethodNotAllowed, nil) } }) } @@ -25,13 +25,13 @@ func handleSysInit(core *vault.Core) http.Handler { func handleSysInitGet(core *vault.Core, w http.ResponseWriter, r *http.Request) { init, err := core.Initialized(context.Background()) if err != nil { - respondError(w, http.StatusInternalServerError, err, r) + respondError(w, http.StatusInternalServerError, err) return } respondOk(w, &InitStatusResponse{ Initialized: init, - }, r) + }) } func handleSysInitPut(core *vault.Core, w http.ResponseWriter, r *http.Request) { @@ -41,7 +41,7 @@ func handleSysInitPut(core *vault.Core, w http.ResponseWriter, r *http.Request) // Parse the request var req InitRequest if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil { - respondError(w, http.StatusBadRequest, err, r) + respondError(w, http.StatusBadRequest, err) return } @@ -68,7 +68,7 @@ func handleSysInitPut(core *vault.Core, w http.ResponseWriter, r *http.Request) result, initErr := core.Initialize(ctx, initParams) if initErr != nil { if vault.IsFatalError(initErr) { - respondError(w, http.StatusBadRequest, initErr, r) + respondError(w, http.StatusBadRequest, initErr) return } else { // Add a warnings field? The error will be logged in the vault log @@ -100,11 +100,11 @@ func handleSysInitPut(core *vault.Core, w http.ResponseWriter, r *http.Request) } if err := core.UnsealWithStoredKeys(ctx); err != nil { - respondError(w, http.StatusInternalServerError, err, r) + respondError(w, http.StatusInternalServerError, err) return } - respondOk(w, resp, r) + respondOk(w, resp) } type InitRequest struct { diff --git a/http/sys_leader.go b/http/sys_leader.go index 71ed9c796c6e7..8c2ce21e5001d 100644 --- a/http/sys_leader.go +++ b/http/sys_leader.go @@ -14,7 +14,7 @@ func handleSysLeader(core *vault.Core) http.Handler { case "GET": handleSysLeaderGet(core, w, r) default: - respondError(w, http.StatusMethodNotAllowed, nil, r) + respondError(w, http.StatusMethodNotAllowed, nil) } }) } @@ -22,8 +22,8 @@ func handleSysLeader(core *vault.Core) http.Handler { func handleSysLeaderGet(core *vault.Core, w http.ResponseWriter, r *http.Request) { resp, err := core.GetLeaderStatus() if err != nil { - respondError(w, http.StatusInternalServerError, err, r) + respondError(w, http.StatusInternalServerError, err) return } - respondOk(w, resp, r) + respondOk(w, resp) } diff --git a/http/sys_metrics.go b/http/sys_metrics.go index 65b69d3df3f11..9808fa2222c7a 100644 --- a/http/sys_metrics.go +++ b/http/sys_metrics.go @@ -17,13 +17,13 @@ func handleMetricsUnauthenticated(core *vault.Core) http.Handler { switch r.Method { case "GET": default: - respondError(w, http.StatusMethodNotAllowed, nil, r) + respondError(w, http.StatusMethodNotAllowed, nil) return } // Parse form if err := r.ParseForm(); err != nil { - respondError(w, http.StatusBadRequest, err, r) + respondError(w, http.StatusBadRequest, err) return } @@ -37,7 +37,6 @@ func handleMetricsUnauthenticated(core *vault.Core) http.Handler { // Manually extract the logical response and send back the information status := resp.Data[logical.HTTPStatusCode].(int) - SetCustomResponseHeaders(w, status, r) w.Header().Set("Content-Type", resp.Data[logical.HTTPContentType].(string)) switch v := resp.Data[logical.HTTPRawBody].(type) { case string: @@ -47,7 +46,7 @@ func handleMetricsUnauthenticated(core *vault.Core) http.Handler { w.WriteHeader(status) w.Write(v) default: - respondError(w, http.StatusInternalServerError, fmt.Errorf("wrong response returned"), r) + respondError(w, http.StatusInternalServerError, fmt.Errorf("wrong response returned")) } }) } diff --git a/http/sys_raft.go b/http/sys_raft.go index 93612b8d1fb13..5db1a80fb78f6 100644 --- a/http/sys_raft.go +++ b/http/sys_raft.go @@ -18,16 +18,16 @@ func handleSysRaftBootstrap(core *vault.Core) http.Handler { switch r.Method { case "POST", "PUT": if core.Sealed() { - respondError(w, http.StatusBadRequest, errors.New("node must be unsealed to bootstrap"), r) + respondError(w, http.StatusBadRequest, errors.New("node must be unsealed to bootstrap")) } if err := core.RaftBootstrap(context.Background(), false); err != nil { - respondError(w, http.StatusInternalServerError, err, r) + respondError(w, http.StatusInternalServerError, err) return } default: - respondError(w, http.StatusBadRequest, nil, r) + respondError(w, http.StatusBadRequest, nil) } }) } @@ -38,7 +38,7 @@ func handleSysRaftJoin(core *vault.Core) http.Handler { case "POST", "PUT": handleSysRaftJoinPost(core, w, r) default: - respondError(w, http.StatusMethodNotAllowed, nil, r) + respondError(w, http.StatusMethodNotAllowed, nil) } }) } @@ -47,12 +47,12 @@ func handleSysRaftJoinPost(core *vault.Core, w http.ResponseWriter, r *http.Requ // Parse the request var req JoinRequest if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil && err != io.EOF { - respondError(w, http.StatusBadRequest, err, r) + respondError(w, http.StatusBadRequest, err) return } if req.NonVoter && !nonVotersAllowed { - respondError(w, http.StatusBadRequest, errors.New("non-voting nodes not allowed"), r) + respondError(w, http.StatusBadRequest, errors.New("non-voting nodes not allowed")) return } @@ -61,14 +61,14 @@ func handleSysRaftJoinPost(core *vault.Core, w http.ResponseWriter, r *http.Requ if len(req.LeaderCACert) != 0 || len(req.LeaderClientCert) != 0 || len(req.LeaderClientKey) != 0 { tlsConfig, err = tlsutil.ClientTLSConfig([]byte(req.LeaderCACert), []byte(req.LeaderClientCert), []byte(req.LeaderClientKey)) if err != nil { - respondError(w, http.StatusBadRequest, err, r) + respondError(w, http.StatusBadRequest, err) return } tlsConfig.ServerName = req.LeaderTLSServerName } if req.AutoJoinScheme != "" && (req.AutoJoinScheme != "http" && req.AutoJoinScheme != "https") { - respondError(w, http.StatusBadRequest, fmt.Errorf("invalid scheme '%s'; must either be http or https", req.AutoJoinScheme), r) + respondError(w, http.StatusBadRequest, fmt.Errorf("invalid scheme '%s'; must either be http or https", req.AutoJoinScheme)) return } @@ -85,14 +85,14 @@ func handleSysRaftJoinPost(core *vault.Core, w http.ResponseWriter, r *http.Requ joined, err := core.JoinRaftCluster(context.Background(), leaderInfos, req.NonVoter) if err != nil { - respondError(w, http.StatusInternalServerError, err, r) + respondError(w, http.StatusInternalServerError, err) return } resp := JoinResponse{ Joined: joined, } - respondOk(w, resp, r) + respondOk(w, resp) } type JoinResponse struct { diff --git a/http/sys_rekey.go b/http/sys_rekey.go index 41dd49e5126fc..71d3eed21fe5a 100644 --- a/http/sys_rekey.go +++ b/http/sys_rekey.go @@ -25,7 +25,7 @@ func handleSysRekeyInit(core *vault.Core, recovery bool) http.Handler { if repState.HasState(consts.ReplicationPerformanceSecondary) { respondError(w, http.StatusBadRequest, fmt.Errorf("rekeying can only be performed on the primary cluster when replication is activated"), - r) + ) return } @@ -34,7 +34,7 @@ func handleSysRekeyInit(core *vault.Core, recovery bool) http.Handler { switch { case recovery && !core.SealAccess().RecoveryKeySupported(): - respondError(w, http.StatusBadRequest, fmt.Errorf("recovery rekeying not supported"), r) + respondError(w, http.StatusBadRequest, fmt.Errorf("recovery rekeying not supported")) case r.Method == "GET": handleSysRekeyInitGet(ctx, core, recovery, w, r) case r.Method == "POST" || r.Method == "PUT": @@ -42,7 +42,7 @@ func handleSysRekeyInit(core *vault.Core, recovery bool) http.Handler { case r.Method == "DELETE": handleSysRekeyInitDelete(ctx, core, recovery, w, r) default: - respondError(w, http.StatusMethodNotAllowed, nil, r) + respondError(w, http.StatusMethodNotAllowed, nil) } }) } @@ -50,24 +50,24 @@ func handleSysRekeyInit(core *vault.Core, recovery bool) http.Handler { func handleSysRekeyInitGet(ctx context.Context, core *vault.Core, recovery bool, w http.ResponseWriter, r *http.Request) { barrierConfig, barrierConfErr := core.SealAccess().BarrierConfig(ctx) if barrierConfErr != nil { - respondError(w, http.StatusInternalServerError, barrierConfErr, r) + respondError(w, http.StatusInternalServerError, barrierConfErr) return } if barrierConfig == nil { - respondError(w, http.StatusBadRequest, fmt.Errorf("server is not yet initialized"), r) + respondError(w, http.StatusBadRequest, fmt.Errorf("server is not yet initialized")) return } // Get the rekey configuration rekeyConf, err := core.RekeyConfig(recovery) if err != nil { - respondError(w, err.Code(), err, r) + respondError(w, err.Code(), err) return } sealThreshold, err := core.RekeyThreshold(ctx, recovery) if err != nil { - respondError(w, err.Code(), err, r) + respondError(w, err.Code(), err) return } @@ -82,7 +82,7 @@ func handleSysRekeyInitGet(ctx context.Context, core *vault.Core, recovery bool, // Get the progress started, progress, err := core.RekeyProgress(recovery, false) if err != nil { - respondError(w, err.Code(), err, r) + respondError(w, err.Code(), err) return } @@ -96,31 +96,31 @@ func handleSysRekeyInitGet(ctx context.Context, core *vault.Core, recovery bool, if rekeyConf.PGPKeys != nil && len(rekeyConf.PGPKeys) != 0 { pgpFingerprints, err := pgpkeys.GetFingerprints(rekeyConf.PGPKeys, nil) if err != nil { - respondError(w, http.StatusInternalServerError, err, r) + respondError(w, http.StatusInternalServerError, err) return } status.PGPFingerprints = pgpFingerprints status.Backup = rekeyConf.Backup } } - respondOk(w, status, r) + respondOk(w, status) } func handleSysRekeyInitPut(ctx context.Context, core *vault.Core, recovery bool, w http.ResponseWriter, r *http.Request) { // Parse the request var req RekeyRequest if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil { - respondError(w, http.StatusBadRequest, err, r) + respondError(w, http.StatusBadRequest, err) return } if req.Backup && len(req.PGPKeys) == 0 { - respondError(w, http.StatusBadRequest, fmt.Errorf("cannot request a backup of the new keys without providing PGP keys for encryption"), r) + respondError(w, http.StatusBadRequest, fmt.Errorf("cannot request a backup of the new keys without providing PGP keys for encryption")) return } if len(req.PGPKeys) > 0 && len(req.PGPKeys) != req.SecretShares { - respondError(w, http.StatusBadRequest, fmt.Errorf("incorrect number of PGP keys for rekey"), r) + respondError(w, http.StatusBadRequest, fmt.Errorf("incorrect number of PGP keys for rekey")) return } @@ -134,7 +134,7 @@ func handleSysRekeyInitPut(ctx context.Context, core *vault.Core, recovery bool, VerificationRequired: req.RequireVerification, }, recovery) if err != nil { - respondError(w, err.Code(), err, r) + respondError(w, err.Code(), err) return } @@ -143,10 +143,10 @@ func handleSysRekeyInitPut(ctx context.Context, core *vault.Core, recovery bool, func handleSysRekeyInitDelete(ctx context.Context, core *vault.Core, recovery bool, w http.ResponseWriter, r *http.Request) { if err := core.RekeyCancel(recovery); err != nil { - respondError(w, err.Code(), err, r) + respondError(w, err.Code(), err) return } - respondOk(w, nil, r) + respondOk(w, nil) } func handleSysRekeyUpdate(core *vault.Core, recovery bool) http.Handler { @@ -160,14 +160,14 @@ func handleSysRekeyUpdate(core *vault.Core, recovery bool) http.Handler { // Parse the request var req RekeyUpdateRequest if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil { - respondError(w, http.StatusBadRequest, err, r) + respondError(w, http.StatusBadRequest, err) return } if req.Key == "" { respondError( w, http.StatusBadRequest, errors.New("'key' must be specified in request body as JSON"), - r) + ) return } @@ -183,7 +183,7 @@ func handleSysRekeyUpdate(core *vault.Core, recovery bool) http.Handler { respondError( w, http.StatusBadRequest, errors.New("'key' must be a valid hex or base64 string"), - r) + ) return } } @@ -194,7 +194,7 @@ func handleSysRekeyUpdate(core *vault.Core, recovery bool) http.Handler { // Use the key to make progress on rekey result, rekeyErr := core.RekeyUpdate(ctx, key, req.Nonce, recovery) if rekeyErr != nil { - respondError(w, rekeyErr.Code(), rekeyErr, r) + respondError(w, rekeyErr.Code(), rekeyErr) return } @@ -217,7 +217,7 @@ func handleSysRekeyUpdate(core *vault.Core, recovery bool) http.Handler { } resp.Keys = keys resp.KeysB64 = keysB64 - respondOk(w, resp, r) + respondOk(w, resp) } else { handleSysRekeyInitGet(ctx, core, recovery, w, r) } @@ -236,7 +236,7 @@ func handleSysRekeyVerify(core *vault.Core, recovery bool) http.Handler { if repState.HasState(consts.ReplicationPerformanceSecondary) { respondError(w, http.StatusBadRequest, fmt.Errorf("rekeying can only be performed on the primary cluster when replication is activated"), - r) + ) return } @@ -245,7 +245,7 @@ func handleSysRekeyVerify(core *vault.Core, recovery bool) http.Handler { switch { case recovery && !core.SealAccess().RecoveryKeySupported(): - respondError(w, http.StatusBadRequest, fmt.Errorf("recovery rekeying not supported"), r) + respondError(w, http.StatusBadRequest, fmt.Errorf("recovery rekeying not supported")) case r.Method == "GET": handleSysRekeyVerifyGet(ctx, core, recovery, w, r) case r.Method == "POST" || r.Method == "PUT": @@ -253,7 +253,7 @@ func handleSysRekeyVerify(core *vault.Core, recovery bool) http.Handler { case r.Method == "DELETE": handleSysRekeyVerifyDelete(ctx, core, recovery, w, r) default: - respondError(w, http.StatusMethodNotAllowed, nil, r) + respondError(w, http.StatusMethodNotAllowed, nil) } }) } @@ -261,29 +261,29 @@ func handleSysRekeyVerify(core *vault.Core, recovery bool) http.Handler { func handleSysRekeyVerifyGet(ctx context.Context, core *vault.Core, recovery bool, w http.ResponseWriter, r *http.Request) { barrierConfig, barrierConfErr := core.SealAccess().BarrierConfig(ctx) if barrierConfErr != nil { - respondError(w, http.StatusInternalServerError, barrierConfErr, r) + respondError(w, http.StatusInternalServerError, barrierConfErr) return } if barrierConfig == nil { - respondError(w, http.StatusBadRequest, fmt.Errorf("server is not yet initialized"), r) + respondError(w, http.StatusBadRequest, fmt.Errorf("server is not yet initialized")) return } // Get the rekey configuration rekeyConf, err := core.RekeyConfig(recovery) if err != nil { - respondError(w, err.Code(), err, r) + respondError(w, err.Code(), err) return } if rekeyConf == nil { - respondError(w, http.StatusBadRequest, errors.New("no rekey configuration found"), r) + respondError(w, http.StatusBadRequest, errors.New("no rekey configuration found")) return } // Get the progress started, progress, err := core.RekeyProgress(recovery, true) if err != nil { - respondError(w, err.Code(), err, r) + respondError(w, err.Code(), err) return } @@ -295,12 +295,12 @@ func handleSysRekeyVerifyGet(ctx context.Context, core *vault.Core, recovery boo N: rekeyConf.SecretShares, Progress: progress, } - respondOk(w, status, r) + respondOk(w, status) } func handleSysRekeyVerifyDelete(ctx context.Context, core *vault.Core, recovery bool, w http.ResponseWriter, r *http.Request) { if err := core.RekeyVerifyRestart(recovery); err != nil { - respondError(w, err.Code(), err, r) + respondError(w, err.Code(), err) return } @@ -311,14 +311,14 @@ func handleSysRekeyVerifyPut(ctx context.Context, core *vault.Core, recovery boo // Parse the request var req RekeyVerificationUpdateRequest if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil { - respondError(w, http.StatusBadRequest, err, r) + respondError(w, http.StatusBadRequest, err) return } if req.Key == "" { respondError( w, http.StatusBadRequest, errors.New("'key' must be specified in request body as JSON"), - r) + ) return } @@ -334,7 +334,7 @@ func handleSysRekeyVerifyPut(ctx context.Context, core *vault.Core, recovery boo respondError( w, http.StatusBadRequest, errors.New("'key' must be a valid hex or base64 string"), - r) + ) return } } @@ -345,7 +345,7 @@ func handleSysRekeyVerifyPut(ctx context.Context, core *vault.Core, recovery boo // Use the key to make progress on rekey result, rekeyErr := core.RekeyVerify(ctx, key, req.Nonce, recovery) if rekeyErr != nil { - respondError(w, rekeyErr.Code(), rekeyErr, r) + respondError(w, rekeyErr.Code(), rekeyErr) return } @@ -354,7 +354,7 @@ func handleSysRekeyVerifyPut(ctx context.Context, core *vault.Core, recovery boo if result != nil { resp.Complete = true resp.Nonce = result.Nonce - respondOk(w, resp, r) + respondOk(w, resp) } else { handleSysRekeyVerifyGet(ctx, core, recovery, w, r) } diff --git a/http/sys_seal.go b/http/sys_seal.go index ef4652779a5b3..3696b9d71e553 100644 --- a/http/sys_seal.go +++ b/http/sys_seal.go @@ -17,14 +17,14 @@ func handleSysSeal(core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { req, _, statusCode, err := buildLogicalRequest(core, w, r) if err != nil || statusCode != 0 { - respondError(w, statusCode, err, r) + respondError(w, statusCode, err) return } switch req.Operation { case logical.UpdateOperation: default: - respondError(w, http.StatusMethodNotAllowed, nil, r) + respondError(w, http.StatusMethodNotAllowed, nil) return } @@ -32,14 +32,14 @@ func handleSysSeal(core *vault.Core) http.Handler { // We use context.Background since there won't be a request context if the node isn't active if err := core.SealWithRequest(r.Context(), req); err != nil { if errwrap.Contains(err, logical.ErrPermissionDenied.Error()) { - respondError(w, http.StatusForbidden, err, r) + respondError(w, http.StatusForbidden, err) return } - respondError(w, http.StatusInternalServerError, err, r) + respondError(w, http.StatusInternalServerError, err) return } - respondOk(w, nil, r) + respondOk(w, nil) }) } @@ -47,28 +47,28 @@ func handleSysStepDown(core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { req, _, statusCode, err := buildLogicalRequest(core, w, r) if err != nil || statusCode != 0 { - respondError(w, statusCode, err, r) + respondError(w, statusCode, err) return } switch req.Operation { case logical.UpdateOperation: default: - respondError(w, http.StatusMethodNotAllowed, nil, r) + respondError(w, http.StatusMethodNotAllowed, nil) return } // Seal with the token above if err := core.StepDown(r.Context(), req); err != nil { if errwrap.Contains(err, logical.ErrPermissionDenied.Error()) { - respondError(w, http.StatusForbidden, err, r) + respondError(w, http.StatusForbidden, err) return } - respondError(w, http.StatusInternalServerError, err, r) + respondError(w, http.StatusInternalServerError, err) return } - respondOk(w, nil, r) + respondOk(w, nil) }) } @@ -78,20 +78,20 @@ func handleSysUnseal(core *vault.Core) http.Handler { case "PUT": case "POST": default: - respondError(w, http.StatusMethodNotAllowed, nil, r) + respondError(w, http.StatusMethodNotAllowed, nil) return } // Parse the request var req UnsealRequest if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil { - respondError(w, http.StatusBadRequest, err, r) + respondError(w, http.StatusBadRequest, err) return } if req.Reset { if !core.Sealed() { - respondError(w, http.StatusBadRequest, errors.New("vault is unsealed"), r) + respondError(w, http.StatusBadRequest, errors.New("vault is unsealed")) return } core.ResetUnsealProcess() @@ -103,7 +103,7 @@ func handleSysUnseal(core *vault.Core) http.Handler { respondError( w, http.StatusBadRequest, errors.New("'key' must be specified in request body as JSON, or 'reset' set to true"), - r) + ) return } @@ -119,7 +119,7 @@ func handleSysUnseal(core *vault.Core) http.Handler { respondError( w, http.StatusBadRequest, errors.New("'key' must be a valid hex or base64 string"), - r) + ) return } } @@ -139,10 +139,10 @@ func handleSysUnseal(core *vault.Core) http.Handler { case errwrap.Contains(err, vault.ErrBarrierSealed.Error()): case errwrap.Contains(err, consts.ErrStandby.Error()): default: - respondError(w, http.StatusInternalServerError, err, r) + respondError(w, http.StatusInternalServerError, err) return } - respondError(w, http.StatusBadRequest, err, r) + respondError(w, http.StatusBadRequest, err) return } @@ -154,7 +154,7 @@ func handleSysUnseal(core *vault.Core) http.Handler { func handleSysSealStatus(core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != "GET" { - respondError(w, http.StatusMethodNotAllowed, nil, r) + respondError(w, http.StatusMethodNotAllowed, nil) return } @@ -166,11 +166,11 @@ func handleSysSealStatusRaw(core *vault.Core, w http.ResponseWriter, r *http.Req ctx := context.Background() status, err := core.GetSealStatus(ctx) if err != nil { - respondError(w, http.StatusInternalServerError, err, r) + respondError(w, http.StatusInternalServerError, err) return } - respondOk(w, status, r) + respondOk(w, status) } // Note: because we didn't provide explicit tagging in the past we can't do it diff --git a/http/testing.go b/http/testing.go index 53f7fca04a249..84ab73fc08006 100644 --- a/http/testing.go +++ b/http/testing.go @@ -6,6 +6,7 @@ import ( "net/http" "testing" + "github.com/hashicorp/vault/internalshared/configutil" "github.com/hashicorp/vault/vault" ) @@ -41,10 +42,16 @@ func TestServerWithListenerAndProperties(tb testing.TB, ln net.Listener, addr st } func TestServerWithListener(tb testing.TB, ln net.Listener, addr string, core *vault.Core) { + ip, _, _ := net.SplitHostPort(ln.Addr().String()) + // Create a muxer to handle our requests so that we can authenticate // for tests. props := &vault.HandlerProperties{ Core: core, + // This is needed for testing custom response headers + ListenerConfig: &configutil.Listener { + Address: ip, + }, } TestServerWithListenerAndProperties(tb, ln, addr, core, props) } @@ -62,5 +69,5 @@ func TestServerAuth(tb testing.TB, addr string, token string) { } func testHandleAuth(w http.ResponseWriter, req *http.Request) { - respondOk(w, nil, nil) + respondOk(w, nil) } diff --git a/http/util.go b/http/util.go index 195c1e9ba137e..0550a93c7e66e 100644 --- a/http/util.go +++ b/http/util.go @@ -35,7 +35,7 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ns, err := namespace.FromContext(r.Context()) if err != nil { - respondError(w, http.StatusInternalServerError, err, r) + respondError(w, http.StatusInternalServerError, err) return } @@ -44,7 +44,7 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler // again, which is not desired. path, status, err := buildLogicalPath(r) if err != nil || status != 0 { - respondError(w, status, err, r) + respondError(w, status, err) return } @@ -57,7 +57,7 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler }) if err != nil { core.Logger().Error("failed to apply quota", "path", path, "error", err) - respondError(w, http.StatusUnprocessableEntity, err, r) + respondError(w, http.StatusUnprocessableEntity, err) return } @@ -69,7 +69,7 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler if !quotaResp.Allowed { quotaErr := fmt.Errorf("request path %q: %w", path, quotas.ErrRateLimitQuotaExceeded) - respondError(w, http.StatusTooManyRequests, quotaErr, r) + respondError(w, http.StatusTooManyRequests, quotaErr) if core.Logger().IsTrace() { core.Logger().Trace("request rejected due to rate limit quota violation", "request_path", path) @@ -78,7 +78,7 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler if core.RateLimitAuditLoggingEnabled() { req, _, status, err := buildLogicalRequestNoAuth(core.PerfStandby(), w, r) if err != nil || status != 0 { - respondError(w, status, err, r) + respondError(w, status, err) return } diff --git a/sdk/logical/response_util.go b/sdk/logical/response_util.go index a570b7d602227..6ae3005b735f1 100644 --- a/sdk/logical/response_util.go +++ b/sdk/logical/response_util.go @@ -158,6 +158,7 @@ func AdjustErrorStatusCode(status *int, err error) { } func RespondError(w http.ResponseWriter, status int, err error) { + AdjustErrorStatusCode(&status, err) w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) diff --git a/vault/core.go b/vault/core.go index 54125ebeec70e..a3d315161ac11 100644 --- a/vault/core.go +++ b/vault/core.go @@ -42,7 +42,6 @@ import ( "github.com/hashicorp/vault/command/server" "github.com/hashicorp/vault/helper/metricsutil" "github.com/hashicorp/vault/helper/namespace" - "github.com/hashicorp/vault/internalshared/listenerutil" "github.com/hashicorp/vault/physical/raft" "github.com/hashicorp/vault/sdk/helper/certutil" "github.com/hashicorp/vault/sdk/helper/consts" @@ -2638,7 +2637,7 @@ func (c *Core) SetConfig(conf *server.Config) { c.logger.Debug("set config", "sanitized config", string(bz)) } -func (c *Core) getCustomResponseHeaders(la string) []*listenerutil.ListenerCustomHeaders { +func (c *Core) getCustomResponseHeaders(la string) []*ListenerCustomHeaders { if c.customListenerHeader == nil { c.logger.Debug("failed to find the custom response headers configuration") return nil @@ -2653,7 +2652,7 @@ func (c *Core) getCustomResponseHeaders(la string) []*listenerutil.ListenerCusto return lch } -func (c *Core) GetListenerCustomResponseHeaders(la string) *listenerutil.ListenerCustomHeaders { +func (c *Core) GetListenerCustomResponseHeaders(la string) *ListenerCustomHeaders { if la == "" { return nil diff --git a/vault/custom_response_headers.go b/vault/custom_response_headers.go index 7acd3a9723669..6ff46c325eaa9 100644 --- a/vault/custom_response_headers.go +++ b/vault/custom_response_headers.go @@ -1,16 +1,30 @@ package vault import ( + "fmt" log "github.com/hashicorp/go-hclog" "net/http" + "net/textproto" "strings" "github.com/hashicorp/vault/internalshared/configutil" - "github.com/hashicorp/vault/internalshared/listenerutil" ) type ListenersCustomResponseHeadersList struct { - CustomHeadersList []*listenerutil.ListenerCustomHeaders + CustomHeadersList []*ListenerCustomHeaders +} + +type ListenerCustomHeaders struct { + Address string + StatusCodeHeaderMap map[string][]*CustomHeader + // ConfiguredHeadersStatusCodeMap field is introduced so that we would not need to loop through + // StatusCodeHeaderMap to see if a header exists, the key for this map is the headers names + ConfiguredHeadersStatusCodeMap map[string][]string +} + +type CustomHeader struct { + Name string + Value string } func NewListenerCustomHeader(ln []*configutil.Listener, logger log.Logger, uiHeaders http.Header) *ListenersCustomResponseHeadersList { @@ -22,13 +36,13 @@ func NewListenerCustomHeader(ln []*configutil.Listener, logger log.Logger, uiHea ll := &ListenersCustomResponseHeadersList{} for _, l := range ln { - lc := &listenerutil.ListenerCustomHeaders{ + lc := &ListenerCustomHeaders{ Address: l.Address, } - lc.StatusCodeHeaderMap = make(map[string][]*listenerutil.CustomHeader) + lc.StatusCodeHeaderMap = make(map[string][]*CustomHeader) lc.ConfiguredHeadersStatusCodeMap = make(map[string][]string) for sc, hv := range l.CustomResponseHeaders { - var chl []*listenerutil.CustomHeader + var chl []*CustomHeader for h, v := range hv { // Sanitizing custom headers // X-Vault- prefix is reserved for Vault internal processes @@ -41,7 +55,7 @@ func NewListenerCustomHeader(ln []*configutil.Listener, logger log.Logger, uiHea if uiHeaders != nil { exist := uiHeaders.Get(h) if exist != "" { - logger.Warn("found a duplicate header in UI. Headers defined in the server configuration take precedence.", "header:", h) + logger.Warn("found a duplicate header in UI", "header:", h, "Headers defined in the server configuration take precedence.") } } @@ -51,7 +65,7 @@ func NewListenerCustomHeader(ln []*configutil.Listener, logger log.Logger, uiHea continue } - ch := &listenerutil.CustomHeader{ + ch := &CustomHeader{ Name: h, Value: v, } @@ -69,14 +83,14 @@ func NewListenerCustomHeader(ln []*configutil.Listener, logger log.Logger, uiHea return ll } -func (c *ListenersCustomResponseHeadersList) getListenerMap(address string) []*listenerutil.ListenerCustomHeaders { +func (c *ListenersCustomResponseHeadersList) getListenerMap(address string) []*ListenerCustomHeaders { if c.CustomHeadersList == nil { return nil } // either looking for a specific listener, or if listener address isn't given, // checking for all available listeners - var lch []*listenerutil.ListenerCustomHeaders + var lch []*ListenerCustomHeaders if address == "" { lch = c.CustomHeadersList } else { @@ -91,3 +105,89 @@ func (c *ListenersCustomResponseHeadersList) getListenerMap(address string) []*l } return lch } + +func (l *ListenerCustomHeaders) findCustomHeaderMatchStatusCode(sc string, hn string) string { + + getHeader := func(ch []*CustomHeader) string { + for _, h := range ch { + if h.Name == hn { + return h.Value + } + } + return "" + } + + hm := l.StatusCodeHeaderMap + + // starting with the most specific status code + if ch, ok := hm[sc]; ok { + h := getHeader(ch) + if h != "" { + return h + } + } + + // Checking for the Yxx pattern + var firstDig string + if len(sc) == 3 { + firstDig = strings.Split(sc, "")[0] + } + if firstDig != "" { + s := fmt.Sprintf("%vxx", firstDig) + if configutil.IsValidStatusCodeCollection(s) { + if ch, ok := hm[s]; ok { + h := getHeader(ch) + if h != "" { + return h + } + } + } + } + + // At this point, we could not find a match for the given status code in the config file + // so, we just return the "default" ones + h := getHeader(hm["default"]) + if h != ""{ + return h + } + + return "" +} + +func(l *ListenerCustomHeaders) FetchHeaderForStatusCode(header, sc string) (string, error) { + + if header == "" { + return "", fmt.Errorf("invalid target header") + } + + if l.StatusCodeHeaderMap == nil { + return "", fmt.Errorf("custom headers not configured") + } + + if !configutil.IsValidStatusCode(sc) { + return "", fmt.Errorf("failed to check if a header exist in config file due to invalid status code") + } + + hn := textproto.CanonicalMIMEHeaderKey(header) + + h := l.findCustomHeaderMatchStatusCode(sc, hn) + + return h, nil +} + +func (l *ListenerCustomHeaders) ExistCustomResponseHeader(header string) bool { + + if header == "" { + return false + } + + if l.StatusCodeHeaderMap == nil { + return false + } + + hn := textproto.CanonicalMIMEHeaderKey(header) + + hs := l.ConfiguredHeadersStatusCodeMap + _, ok := hs[hn] + return ok +} \ No newline at end of file diff --git a/vault/custom_response_headers_test.go b/vault/custom_response_headers_test.go index de8bf6b273d55..e4e4c5f33ab1c 100644 --- a/vault/custom_response_headers_test.go +++ b/vault/custom_response_headers_test.go @@ -96,28 +96,6 @@ func TestConfigCustomHeaders(t *testing.T) { t.Fatalf("header name with X-Vault prefix is not valid") } - w := httptest.NewRecorder() - lch.SetCustomResponseHeaders(w, 200) - if w.Header().Get("Someheader-200") != "200" || w.Header().Get("X-Custom-Header") != "Custom header value 200"{ - t.Fatalf("response headers related to status code %v did not set properly", 200) - } - - w = httptest.NewRecorder() - lch.SetCustomResponseHeaders(w, 204) - if w.Header().Get("Someheader-200") == "200" || w.Header().Get("X-Custom-Header") != "Custom header value 2xx" { - t.Fatalf("response headers related to status code %v did not set properly", "2xx") - } - - w = httptest.NewRecorder() - lch.SetCustomResponseHeaders(w, 500) - for h, v := range defaultCustomHeaders { - if h != "X-Vault-Ignored" && w.Header().Get(h) != v { - t.Fatalf("response headers related to status code %v did not set properly", 500) - } - } - if w.Header().Get("X-Vault-Ignored") != "" { - t.Fatalf("response headers contains a header with pattern X-Vault") - } } func TestCustomResponseHeadersConfigInteractUiConfig(t *testing.T) { diff --git a/vault/testing.go b/vault/testing.go index 51849d6d39a89..7d178d1c32252 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -128,6 +128,61 @@ func TestCoreWithSeal(t testing.T, testSeal Seal, enableRaw bool) *Core { return TestCoreWithSealAndUI(t, conf) } +func TestCoreWithCustomResponseHeaderAndUI(t testing.T, enableUI bool) (*Core, [][]byte, string) { + confRaw := &server.Config{ + SharedConfig: &configutil.SharedConfig{ + Listeners: []*configutil.Listener{ + { + Type: "tcp", + Address: "127.0.0.1", + CustomResponseHeaders: map[string]map[string]string{ + "default": { + "Strict-Transport-Security": "max-age=1; domains", + "Content-Security-Policy": "default-src 'others'", + "X-Vault-Ignored": "ignored", + "X-Custom-Header": "Custom header value default", + "X-Frame-Options": "Deny", + "X-Content-Type-Options": "nosniff", + "Content-Type": "application/json", + "X-XSS-Protection": "1; mode=block", + }, + "307": {"X-Custom-Header": "Custom header value 307"}, + "3xx": { + "X-Custom-Header": "Custom header value 3xx", + "X-Vault-Ignored-3xx": "Ignored 3xx", + }, + "200": { + "Someheader-200": "200", + "X-Custom-Header": "Custom header value 200", + }, + "2xx": { + "X-Custom-Header": "Custom header value 2xx", + }, + "400": { + "Someheader-400": "400", + }, + "405":{ + "Someheader-405": "405", + }, + "4xx": { + "Someheader-4xx": "4xx", + }, + }, + }, + }, + DisableMlock: true, + }, + } + conf := &CoreConfig{ + RawConfig: confRaw, + EnableUI: enableUI, + EnableRaw: true, + BuiltinRegistry: NewMockBuiltinRegistry(), + } + core := TestCoreWithSealAndUI(t, conf) + return testCoreUnsealed(t, core) +} + func TestCoreUI(t testing.T, enableUI bool) *Core { conf := &CoreConfig{ EnableUI: enableUI, From 434d8cb41034020da472625dc6201a54f8cd8e05 Mon Sep 17 00:00:00 2001 From: hamid ghaf Date: Wed, 15 Sep 2021 17:49:20 -0700 Subject: [PATCH 16/25] adding a new test --- http/cutom_header_test.go | 173 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 173 insertions(+) create mode 100644 http/cutom_header_test.go diff --git a/http/cutom_header_test.go b/http/cutom_header_test.go new file mode 100644 index 0000000000000..0383ae92dbbaf --- /dev/null +++ b/http/cutom_header_test.go @@ -0,0 +1,173 @@ +package http + +import ( + "testing" + + "github.com/hashicorp/vault/vault" +) + +var defaultCustomHeaders = map[string]string { + "Strict-Transport-Security": "max-age=1; domains", + "Content-Security-Policy": "default-src 'others'", + "X-Custom-Header": "Custom header value default", + "X-Frame-Options": "Deny", + "X-Content-Type-Options": "nosniff", + "Content-Type": "application/json", + "X-XSS-Protection": "1; mode=block", +} + +var customHeader2xx = map[string]string { + "X-Custom-Header": "Custom header value 2xx", +} + +var customHeader200 = map[string]string { + "Someheader-200": "200", + "X-Custom-Header": "Custom header value 200", +} + +var customHeader4xx = map[string]string { + "Someheader-4xx": "4xx", +} + +var customHeader400 = map[string]string { + "Someheader-400": "400", +} + +var customHeader405 = map[string]string { + "Someheader-405": "405", +} + +func TestCustomResponseHeaders(t *testing.T) { + core, _, token := vault.TestCoreWithCustomResponseHeaderAndUI(t, true) + ln, addr := TestServer(t, core) + defer ln.Close() + TestServerAuth(t, addr, token) + + resp := testHttpGet(t, token, addr+"/v1/sys/raw/") + testResponseStatus(t, resp, 404) + testResponseHeader(t, resp, defaultCustomHeaders) + testResponseHeader(t, resp, customHeader4xx) + + resp = testHttpGet(t, token, addr+"/v1/sys/generate-recovery-token/attempt") + testResponseStatus(t, resp, 404) + testResponseHeader(t, resp, defaultCustomHeaders) + testResponseHeader(t, resp, customHeader4xx) + + resp = testHttpGet(t, token, addr+"/v1/sys/generate-recovery-token/update") + testResponseStatus(t, resp, 404) + testResponseHeader(t, resp, defaultCustomHeaders) + testResponseHeader(t, resp, customHeader4xx) + + resp = testHttpGet(t, token, addr+"/v1/sys/config/state/") + testResponseStatus(t, resp, 404) + testResponseHeader(t, resp, defaultCustomHeaders) + testResponseHeader(t, resp, customHeader4xx) + + resp = testHttpGet(t, token, addr+"/v1/sys/seal") + testResponseStatus(t, resp, 405) + testResponseHeader(t, resp, defaultCustomHeaders) + testResponseHeader(t, resp, customHeader4xx) + testResponseHeader(t, resp, customHeader405) + + resp = testHttpGet(t, token, addr+"/v1/sys/step-down") + testResponseStatus(t, resp, 405) + testResponseHeader(t, resp, defaultCustomHeaders) + testResponseHeader(t, resp, customHeader4xx) + testResponseHeader(t, resp, customHeader405) + + resp = testHttpGet(t, token, addr+"/v1/sys/unseal") + testResponseStatus(t, resp, 405) + testResponseHeader(t, resp, defaultCustomHeaders) + testResponseHeader(t, resp, customHeader4xx) + testResponseHeader(t, resp, customHeader405) + + resp = testHttpGet(t, token, addr+"/v1/sys/leader") + testResponseStatus(t, resp, 200) + testResponseHeader(t, resp, customHeader200) + + resp = testHttpGet(t, token, addr+"/v1/sys/health") + testResponseStatus(t, resp, 200) + testResponseHeader(t, resp, customHeader200) + + resp = testHttpGet(t, token, addr+"/v1/sys/generate-root/attempt") + testResponseStatus(t, resp, 200) + testResponseHeader(t, resp, customHeader200) + + resp = testHttpGet(t, token, addr+"/v1/sys/generate-root/update") + testResponseStatus(t, resp, 400) + testResponseHeader(t, resp, defaultCustomHeaders) + testResponseHeader(t, resp, customHeader4xx) + testResponseHeader(t, resp, customHeader400) + + resp = testHttpGet(t, token, addr+"/v1/sys/rekey/init") + testResponseStatus(t, resp, 200) + testResponseHeader(t, resp, customHeader200) + + resp = testHttpGet(t, token, addr+"/v1/sys/rekey/update") + testResponseStatus(t, resp, 400) + testResponseHeader(t, resp, defaultCustomHeaders) + testResponseHeader(t, resp, customHeader4xx) + testResponseHeader(t, resp, customHeader400) + + resp = testHttpGet(t, token, addr+"/v1/sys/rekey/verify") + testResponseStatus(t, resp, 400) + testResponseHeader(t, resp, defaultCustomHeaders) + testResponseHeader(t, resp, customHeader4xx) + testResponseHeader(t, resp, customHeader400) + + resp = testHttpGet(t, token, addr+"/v1/sys/") + testResponseStatus(t, resp, 404) + testResponseHeader(t, resp, defaultCustomHeaders) + testResponseHeader(t, resp, customHeader4xx) + + resp = testHttpGet(t, token, addr+"/v1/sys") + testResponseStatus(t, resp, 404) + testResponseHeader(t, resp, defaultCustomHeaders) + testResponseHeader(t, resp, customHeader4xx) + + resp = testHttpGet(t, token, addr+"/v1/") + testResponseStatus(t, resp, 404) + testResponseHeader(t, resp, defaultCustomHeaders) + testResponseHeader(t, resp, customHeader4xx) + + resp = testHttpGet(t, token, addr+"/v1") + testResponseStatus(t, resp, 404) + testResponseHeader(t, resp, defaultCustomHeaders) + testResponseHeader(t, resp, customHeader4xx) + + resp = testHttpGet(t, token, addr+"/") + testResponseStatus(t, resp, 200) + testResponseHeader(t, resp, customHeader200) + + resp = testHttpGet(t, token, addr+"/v1/sys/host-info") + testResponseStatus(t, resp, 200) + testResponseHeader(t, resp, customHeader200) + + resp = testHttpGet(t, token, addr+"/v1/sys/init") + testResponseStatus(t, resp, 200) + testResponseHeader(t, resp, customHeader200) + + resp = testHttpGet(t, token, addr+"/v1/sys/seal-status") + testResponseStatus(t, resp, 200) + testResponseHeader(t, resp, customHeader200) + + resp = testHttpGet(t, token, addr+"/v1/sys/auth") + testResponseStatus(t, resp, 200) + testResponseHeader(t, resp, customHeader200) + + resp = testHttpGet(t, token, addr+"/ui") + testResponseStatus(t, resp, 200) + testResponseHeader(t, resp, customHeader200) + + resp = testHttpGet(t, token, addr+"/ui/") + testResponseStatus(t, resp, 200) + testResponseHeader(t, resp, customHeader200) + + resp = testHttpPost(t, token, addr+"/v1/sys/auth/foo", map[string]interface{}{ + "type": "noop", + "description": "foo", + }) + testResponseStatus(t, resp, 204) + testResponseHeader(t, resp, customHeader2xx) + +} \ No newline at end of file From 6a106f01831f6e08fded1db769384b0ff9f4792b Mon Sep 17 00:00:00 2001 From: hamid ghaf Date: Wed, 15 Sep 2021 17:55:57 -0700 Subject: [PATCH 17/25] some cleanup --- http/handler.go | 13 +- http/sys_rekey.go | 6 +- .../listenerutil/custom_response_headers.go | 160 ------------------ 3 files changed, 9 insertions(+), 170 deletions(-) delete mode 100644 internalshared/listenerutil/custom_response_headers.go diff --git a/http/handler.go b/http/handler.go index c6d363c9a0b69..296e4fa5912dd 100644 --- a/http/handler.go +++ b/http/handler.go @@ -905,14 +905,14 @@ func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handle func forwardRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) { if r.Header.Get(vault.IntNoForwardingHeaderName) != "" { - respondStandby(core, w, r) + respondStandby(core, w, r.URL) return } if r.Header.Get(NoRequestForwardingHeaderName) != "" { // Forwarding explicitly disabled, fall back to previous behavior core.Logger().Debug("handleRequestForwarding: forwarding disabled by client request") - respondStandby(core, w, r) + respondStandby(core, w, r.URL) return } @@ -923,7 +923,7 @@ func forwardRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) { } path := ns.TrimmedPath(r.URL.Path[len("/v1/"):]) if alwaysRedirectPaths.HasPath(path) { - respondStandby(core, w, r) + respondStandby(core, w, r.URL) return } @@ -939,7 +939,7 @@ func forwardRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) { } // Fall back to redirection - respondStandby(core, w, r) + respondStandby(core, w, r.URL) return } @@ -964,7 +964,7 @@ func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *l resp.AddWarning("Timeout hit while waiting for local replicated cluster to apply primary's write; this client may encounter stale reads of values written during this operation.") } if errwrap.Contains(err, consts.ErrStandby.Error()) { - respondStandby(core, w, rawReq) + respondStandby(core, w, rawReq.URL) return resp, false, false } if err != nil && errwrap.Contains(err, logical.ErrPerfStandbyPleaseForward.Error()) { @@ -1013,9 +1013,8 @@ func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *l } // respondStandby is used to trigger a redirect in the case that this Vault is currently a hot standby -func respondStandby(core *vault.Core, w http.ResponseWriter, req *http.Request) { +func respondStandby(core *vault.Core, w http.ResponseWriter, reqURL *url.URL) { - reqURL := req.URL // Request the leader address _, redirectAddr, _, err := core.Leader() if err != nil { diff --git a/http/sys_rekey.go b/http/sys_rekey.go index 71d3eed21fe5a..b6a703237084b 100644 --- a/http/sys_rekey.go +++ b/http/sys_rekey.go @@ -17,7 +17,7 @@ func handleSysRekeyInit(core *vault.Core, recovery bool) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { standby, _ := core.Standby() if standby { - respondStandby(core, w, r) + respondStandby(core, w, r.URL) return } @@ -153,7 +153,7 @@ func handleSysRekeyUpdate(core *vault.Core, recovery bool) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { standby, _ := core.Standby() if standby { - respondStandby(core, w, r) + respondStandby(core, w, r.URL) return } @@ -228,7 +228,7 @@ func handleSysRekeyVerify(core *vault.Core, recovery bool) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { standby, _ := core.Standby() if standby { - respondStandby(core, w, r) + respondStandby(core, w, r.URL) return } diff --git a/internalshared/listenerutil/custom_response_headers.go b/internalshared/listenerutil/custom_response_headers.go deleted file mode 100644 index 9b01afb168e31..0000000000000 --- a/internalshared/listenerutil/custom_response_headers.go +++ /dev/null @@ -1,160 +0,0 @@ -package listenerutil - -import ( - "fmt" - "net/http" - "net/textproto" - "strconv" - "strings" - - "github.com/hashicorp/vault/internalshared/configutil" -) - -type ListenerCustomHeaders struct { - Address string - StatusCodeHeaderMap map[string][]*CustomHeader - // ConfiguredHeadersStatusCodeMap field is introduced so that we would not need to loop through - // StatusCodeHeaderMap to see if a header exists, the key for this map is the headers names - ConfiguredHeadersStatusCodeMap map[string][]string -} - -type CustomHeader struct { - Name string - Value string -} - -// ChangeListenerAddress is used for tests where the listener address (at least the port) -// is chosen at random -func (l *ListenerCustomHeaders) ChangeListenerAddress(la string) { - l.Address = la - return -} - -func (l *ListenerCustomHeaders) SetCustomResponseHeaders(w http.ResponseWriter, status int) { - if w == nil { - fmt.Println("No ResponseWriter provided") - return - } - - sch := l.StatusCodeHeaderMap - if sch == nil { - fmt.Println("status code header map not configured") - return - } - - // setter function to set the headers - setter := func(hvl []*CustomHeader) { - for _, hv := range hvl { - w.Header().Set(hv.Name, hv.Value) - } - } - - // Checking the validity of the status code - if status >= 600 || status < 100 { - fmt.Println("invalid status code") - return - } - - // Setting the default headers first - setter(sch["default"]) - - // setting the Xyy pattern first - d := fmt.Sprintf("%vxx", status / 100) - if val, ok := sch[d]; ok { - setter(val) - } - - // Setting the specific headers - if val, ok := sch[strconv.Itoa(status)]; ok { - setter(val) - } - - return -} - - -func (l *ListenerCustomHeaders) findCustomHeaderMatchStatusCode(sc string, hn string) string { - - getHeader := func(ch []*CustomHeader) string { - for _, h := range ch { - if h.Name == hn { - return h.Value - } - } - return "" - } - - hm := l.StatusCodeHeaderMap - - // starting with the most specific status code - if ch, ok := hm[sc]; ok { - h := getHeader(ch) - if h != "" { - return h - } - } - - // Checking for the Yxx pattern - var firstDig string - if len(sc) == 3 { - firstDig = strings.Split(sc, "")[0] - } - if firstDig != "" { - s := fmt.Sprintf("%vxx", firstDig) - if configutil.IsValidStatusCodeCollection(s) { - if ch, ok := hm[s]; ok { - h := getHeader(ch) - if h != "" { - return h - } - } - } - } - - // At this point, we could not find a match for the given status code in the config file - // so, we just return the "default" ones - h := getHeader(hm["default"]) - if h != ""{ - return h - } - - return "" -} - -func(l *ListenerCustomHeaders) FetchHeaderForStatusCode(header, sc string) (string, error) { - - if header == "" { - return "", fmt.Errorf("invalid target header") - } - - if l.StatusCodeHeaderMap == nil { - return "", fmt.Errorf("custom headers not configured") - } - - if !configutil.IsValidStatusCode(sc) { - return "", fmt.Errorf("failed to check if a header exist in config file due to invalid status code") - } - - hn := textproto.CanonicalMIMEHeaderKey(header) - - h := l.findCustomHeaderMatchStatusCode(sc, hn) - - return h, nil -} - -func (l *ListenerCustomHeaders) ExistCustomResponseHeader(header string) bool { - - if header == "" { - return false - } - - if l.StatusCodeHeaderMap == nil { - return false - } - - hn := textproto.CanonicalMIMEHeaderKey(header) - - hs := l.ConfiguredHeadersStatusCodeMap - _, ok := hs[hn] - return ok -} \ No newline at end of file From 7881067bc73d3f8723580d1348045fc7d8376a13 Mon Sep 17 00:00:00 2001 From: hamid ghaf Date: Thu, 16 Sep 2021 07:03:39 -0700 Subject: [PATCH 18/25] removing some extra lines --- http/handler.go | 6 ------ http/sys_generate_root.go | 7 ++----- http/sys_health.go | 1 - http/sys_init.go | 1 - http/sys_metrics.go | 1 - http/sys_rekey.go | 15 +++++---------- http/sys_seal.go | 6 ++---- 7 files changed, 9 insertions(+), 28 deletions(-) diff --git a/http/handler.go b/http/handler.go index 296e4fa5912dd..82379c86a1983 100644 --- a/http/handler.go +++ b/http/handler.go @@ -334,7 +334,6 @@ func handleAuditNonLogical(core *vault.Core, h http.Handler) http.Handler { origBody := new(bytes.Buffer) reader := ioutil.NopCloser(io.TeeReader(r.Body, origBody)) r.Body = reader - req, _, status, err := buildLogicalRequestNoAuth(core.PerfStandby(), w, r) if err != nil || status != 0 { respondError(w, status, err) @@ -385,7 +384,6 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr maxRequestSize = DefaultMaxRequestSize } - // Swallow this error since we don't want to pollute the logs and we also don't want to // return an HTTP error here. This information is best effort. hostname, _ := os.Hostname() @@ -427,7 +425,6 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr ctx = context.WithValue(ctx, "max_request_size", maxRequestSize) } ctx = context.WithValue(ctx, "original_request_path", r.URL.Path) - r = r.WithContext(ctx) r = r.WithContext(namespace.ContextWithNamespace(r.Context(), namespace.RootNamespace)) @@ -595,7 +592,6 @@ func handleUIHeaders(core *vault.Core, h http.Handler) http.Handler { header.Set(k, v) } } - h.ServeHTTP(w, req) }) } @@ -903,7 +899,6 @@ func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handle } func forwardRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) { - if r.Header.Get(vault.IntNoForwardingHeaderName) != "" { respondStandby(core, w, r.URL) return @@ -1014,7 +1009,6 @@ func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *l // respondStandby is used to trigger a redirect in the case that this Vault is currently a hot standby func respondStandby(core *vault.Core, w http.ResponseWriter, reqURL *url.URL) { - // Request the leader address _, redirectAddr, _, err := core.Leader() if err != nil { diff --git a/http/sys_generate_root.go b/http/sys_generate_root.go index 3aeb5c395e992..49d592fb33409 100644 --- a/http/sys_generate_root.go +++ b/http/sys_generate_root.go @@ -121,7 +121,6 @@ func handleSysGenerateRootAttemptPut(core *vault.Core, w http.ResponseWriter, r } func handleSysGenerateRootAttemptDelete(core *vault.Core, w http.ResponseWriter, r *http.Request) { - err := core.GenerateRootCancel() if err != nil { respondError(w, http.StatusInternalServerError, err) @@ -141,8 +140,7 @@ func handleSysGenerateRootUpdate(core *vault.Core, generateStrategy vault.Genera if req.Key == "" { respondError( w, http.StatusBadRequest, - errors.New("'key' must be specified in request body as JSON"), - ) + errors.New("'key' must be specified in request body as JSON")) return } @@ -157,8 +155,7 @@ func handleSysGenerateRootUpdate(core *vault.Core, generateStrategy vault.Genera if err != nil { respondError( w, http.StatusBadRequest, - errors.New("'key' must be a valid hex or base64 string"), - ) + errors.New("'key' must be a valid hex or base64 string")) return } } diff --git a/http/sys_health.go b/http/sys_health.go index a37e9dee522e4..fcaf4e1590999 100644 --- a/http/sys_health.go +++ b/http/sys_health.go @@ -67,7 +67,6 @@ func handleSysHealthHead(core *vault.Core, w http.ResponseWriter, r *http.Reques if body != nil { w.Header().Set("Content-Type", "application/json") } - w.WriteHeader(code) } diff --git a/http/sys_init.go b/http/sys_init.go index dd8605c4c08f3..b21e5363ea020 100644 --- a/http/sys_init.go +++ b/http/sys_init.go @@ -35,7 +35,6 @@ func handleSysInitGet(core *vault.Core, w http.ResponseWriter, r *http.Request) } func handleSysInitPut(core *vault.Core, w http.ResponseWriter, r *http.Request) { - ctx := context.Background() // Parse the request diff --git a/http/sys_metrics.go b/http/sys_metrics.go index 9808fa2222c7a..012417282e5f5 100644 --- a/http/sys_metrics.go +++ b/http/sys_metrics.go @@ -11,7 +11,6 @@ import ( func handleMetricsUnauthenticated(core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - req := &logical.Request{Headers: r.Header} switch r.Method { diff --git a/http/sys_rekey.go b/http/sys_rekey.go index b6a703237084b..b3428a39f9620 100644 --- a/http/sys_rekey.go +++ b/http/sys_rekey.go @@ -166,8 +166,7 @@ func handleSysRekeyUpdate(core *vault.Core, recovery bool) http.Handler { if req.Key == "" { respondError( w, http.StatusBadRequest, - errors.New("'key' must be specified in request body as JSON"), - ) + errors.New("'key' must be specified in request body as JSON")) return } @@ -182,8 +181,7 @@ func handleSysRekeyUpdate(core *vault.Core, recovery bool) http.Handler { if err != nil { respondError( w, http.StatusBadRequest, - errors.New("'key' must be a valid hex or base64 string"), - ) + errors.New("'key' must be a valid hex or base64 string")) return } } @@ -235,8 +233,7 @@ func handleSysRekeyVerify(core *vault.Core, recovery bool) http.Handler { repState := core.ReplicationState() if repState.HasState(consts.ReplicationPerformanceSecondary) { respondError(w, http.StatusBadRequest, - fmt.Errorf("rekeying can only be performed on the primary cluster when replication is activated"), - ) + fmt.Errorf("rekeying can only be performed on the primary cluster when replication is activated")) return } @@ -317,8 +314,7 @@ func handleSysRekeyVerifyPut(ctx context.Context, core *vault.Core, recovery boo if req.Key == "" { respondError( w, http.StatusBadRequest, - errors.New("'key' must be specified in request body as JSON"), - ) + errors.New("'key' must be specified in request body as JSON")) return } @@ -333,8 +329,7 @@ func handleSysRekeyVerifyPut(ctx context.Context, core *vault.Core, recovery boo if err != nil { respondError( w, http.StatusBadRequest, - errors.New("'key' must be a valid hex or base64 string"), - ) + errors.New("'key' must be a valid hex or base64 string")) return } } diff --git a/http/sys_seal.go b/http/sys_seal.go index 3696b9d71e553..24f491b65d1d6 100644 --- a/http/sys_seal.go +++ b/http/sys_seal.go @@ -102,8 +102,7 @@ func handleSysUnseal(core *vault.Core) http.Handler { if req.Key == "" { respondError( w, http.StatusBadRequest, - errors.New("'key' must be specified in request body as JSON, or 'reset' set to true"), - ) + errors.New("'key' must be specified in request body as JSON, or 'reset' set to true")) return } @@ -118,8 +117,7 @@ func handleSysUnseal(core *vault.Core) http.Handler { if err != nil { respondError( w, http.StatusBadRequest, - errors.New("'key' must be a valid hex or base64 string"), - ) + errors.New("'key' must be a valid hex or base64 string")) return } } From e266713a57ee68bb9a5b9e3adf51ba9b1c22f467 Mon Sep 17 00:00:00 2001 From: hamid ghaf Date: Mon, 20 Sep 2021 18:04:24 -0700 Subject: [PATCH 19/25] Addressing comments --- command/server.go | 2 +- .../config_custom_response_headers_test.go | 71 +++-- command/server/config_test_helpers.go | 44 ++- ...om_response_headers_multiple_listeners.hcl | 43 +++ ...m_header_test.go => custom_header_test.go} | 16 +- http/handler.go | 32 +- http/sys_generate_root.go | 1 - http/sys_rekey.go | 3 +- .../configutil/http_response_headers.go | 286 +++++++++--------- internalshared/configutil/listener.go | 16 +- vault/core.go | 114 ++++--- vault/custom_response_headers.go | 65 +--- vault/custom_response_headers_test.go | 71 ++--- vault/logical_system.go | 2 +- vault/testing.go | 42 +-- 15 files changed, 440 insertions(+), 368 deletions(-) create mode 100644 command/server/test-fixtures/config_custom_response_headers_multiple_listeners.hcl rename http/{cutom_header_test.go => custom_header_test.go} (93%) diff --git a/command/server.go b/command/server.go index 768e907e2b152..e71a21be06253 100644 --- a/command/server.go +++ b/command/server.go @@ -1544,7 +1544,7 @@ func (c *ServerCommand) Run(args []string) int { // reloading custom response headers to make sure we have // the most up to date headers after reloading the config file if err = core.ReloadCustomResponseHeaders(); err != nil { - c.UI.Error(err.Error()) + c.logger.Error(err.Error()) } if config.LogLevel != "" { diff --git a/command/server/config_custom_response_headers_test.go b/command/server/config_custom_response_headers_test.go index 9a881f34eb0f9..a00b54fc2478d 100644 --- a/command/server/config_custom_response_headers_test.go +++ b/command/server/config_custom_response_headers_test.go @@ -2,55 +2,55 @@ package server import ( "fmt" - "github.com/go-test/deep" "testing" + + "github.com/go-test/deep" ) -var defaultCustomHeaders = map[string]string { +var defaultCustomHeaders = map[string]string{ "Strict-Transport-Security": "max-age=1; domains", - "Content-Security-Policy": "default-src 'others'", - "X-Vault-Ignored": "ignored", - "X-Custom-Header": "Custom header value default", - "X-Frame-Options": "Deny", - "X-Content-Type-Options": "nosniff", - "Content-Type": "application/json", - "X-XSS-Protection": "1; mode=block", + "Content-Security-Policy": "default-src 'others'", + "X-Vault-Ignored": "ignored", + "X-Custom-Header": "Custom header value default", + "X-Frame-Options": "Deny", + "X-Content-Type-Options": "nosniff", + "Content-Type": "application/json", + "X-XSS-Protection": "1; mode=block", } -var customHeaders307 = map[string]string { +var customHeaders307 = map[string]string{ "X-Custom-Header": "Custom header value 307", } -var customHeader3xx = map[string]string { +var customHeader3xx = map[string]string{ "X-Vault-Ignored-3xx": "Ignored 3xx", - "X-Custom-Header": "Custom header value 3xx", + "X-Custom-Header": "Custom header value 3xx", } -var customHeaders200 = map[string]string { - "Someheader-200": "200", +var customHeaders200 = map[string]string{ + "Someheader-200": "200", "X-Custom-Header": "Custom header value 200", } -var customHeader2xx = map[string]string { +var customHeader2xx = map[string]string{ "X-Custom-Header": "Custom header value 2xx", } -var customHeader400 = map[string]string { +var customHeader400 = map[string]string{ "Someheader-400": "400", } func TestCustomResponseHeadersConfigs(t *testing.T) { - expectedCustomResponseHeader := map[string]map[string]string { + expectedCustomResponseHeader := map[string]map[string]string{ "default": defaultCustomHeaders, - "307": customHeaders307, - "3xx": customHeader3xx, - "200": customHeaders200, - "2xx": customHeader2xx, - "400": customHeader400, + "307": customHeaders307, + "3xx": customHeader3xx, + "200": customHeaders200, + "2xx": customHeader2xx, + "400": customHeader400, } config, err := LoadConfigFile("./test-fixtures/config_custom_response_headers_1.hcl") - if err != nil { t.Fatalf("Error encountered when loading config %+v", err) } @@ -59,3 +59,28 @@ func TestCustomResponseHeadersConfigs(t *testing.T) { } } +func TestCustomResponseHeadersConfigsMultipleListeners(t *testing.T) { + expectedCustomResponseHeader := map[string]map[string]string{ + "default": defaultCustomHeaders, + "307": customHeaders307, + "3xx": customHeader3xx, + "200": customHeaders200, + "2xx": customHeader2xx, + "400": customHeader400, + } + + config, err := LoadConfigFile("./test-fixtures/config_custom_response_headers_multiple_listeners.hcl") + if err != nil { + t.Fatalf("Error encountered when loading config %+v", err) + } + if diff := deep.Equal(expectedCustomResponseHeader, config.Listeners[0].CustomResponseHeaders); diff != nil { + t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff)) + } + + if diff := deep.Equal(expectedCustomResponseHeader, config.Listeners[1].CustomResponseHeaders); diff == nil { + t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff)) + } + if diff := deep.Equal(expectedCustomResponseHeader["default"], config.Listeners[0].CustomResponseHeaders["default"]); diff != nil { + t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff)) + } +} diff --git a/command/server/config_test_helpers.go b/command/server/config_test_helpers.go index e40f6c8368a6e..d30726e33e130 100644 --- a/command/server/config_test_helpers.go +++ b/command/server/config_test_helpers.go @@ -16,6 +16,15 @@ import ( "github.com/hashicorp/vault/internalshared/configutil" ) +var DefaultCustomHeaders = map[string]string{ + "Strict-Transport-Security": "max-age=31536000; includeSubDomains", + "Content-Security-Policy": "default-src 'none'; connect-src 'self'; img-src 'self' data:; script-src 'self'; style-src 'unsafe-inline' 'self'; form-action 'none'; frame-ancestors 'none'; font-src 'self'", + "X-Frame-Options": "Deny", + "X-Content-Type-Options": "nosniff", + "Content-Type": "application/json", + "X-XSS-Protection": "1; mode=block", +} + func boolPointer(x bool) *bool { return &x } @@ -32,6 +41,9 @@ func testConfigRaftRetryJoin(t *testing.T) { { Type: "tcp", Address: "127.0.0.1:8200", + CustomResponseHeaders: map[string]map[string]string{ + "default": DefaultCustomHeaders, + }, }, }, DisableMlock: true, @@ -64,6 +76,9 @@ func testLoadConfigFile_topLevel(t *testing.T, entropy *configutil.Entropy) { { Type: "tcp", Address: "127.0.0.1:443", + CustomResponseHeaders: map[string]map[string]string{ + "default": DefaultCustomHeaders, + }, }, }, @@ -174,10 +189,16 @@ func testLoadConfigFile_json2(t *testing.T, entropy *configutil.Entropy) { { Type: "tcp", Address: "127.0.0.1:443", + CustomResponseHeaders: map[string]map[string]string{ + "default": DefaultCustomHeaders, + }, }, { Type: "tcp", Address: "127.0.0.1:444", + CustomResponseHeaders: map[string]map[string]string{ + "default": DefaultCustomHeaders, + }, }, }, @@ -336,6 +357,9 @@ func testLoadConfigFileIntegerAndBooleanValuesCommon(t *testing.T, path string) { Type: "tcp", Address: "127.0.0.1:8200", + CustomResponseHeaders: map[string]map[string]string{ + "default": DefaultCustomHeaders, + }, }, }, DisableMlock: true, @@ -379,6 +403,9 @@ func testLoadConfigFile(t *testing.T) { { Type: "tcp", Address: "127.0.0.1:443", + CustomResponseHeaders: map[string]map[string]string{ + "default": DefaultCustomHeaders, + }, }, }, @@ -486,7 +513,7 @@ func testUnknownFieldValidation(t *testing.T) { for _, er1 := range errors { found := false if strings.Contains(er1.String(), "sentinel") { - //This happens on OSS, and is fine + // This happens on OSS, and is fine continue } for _, ex := range expected { @@ -525,6 +552,9 @@ func testLoadConfigFile_json(t *testing.T) { { Type: "tcp", Address: "127.0.0.1:443", + CustomResponseHeaders: map[string]map[string]string{ + "default": DefaultCustomHeaders, + }, }, }, @@ -610,6 +640,9 @@ func testLoadConfigDir(t *testing.T) { { Type: "tcp", Address: "127.0.0.1:443", + CustomResponseHeaders: map[string]map[string]string{ + "default": DefaultCustomHeaders, + }, }, }, @@ -818,6 +851,9 @@ listener "tcp" { Profiling: configutil.ListenerProfiling{ UnauthenticatedPProfAccess: true, }, + CustomResponseHeaders: map[string]map[string]string{ + "default": DefaultCustomHeaders, + }, }, }, }, @@ -845,6 +881,9 @@ func testParseSeals(t *testing.T) { { Type: "tcp", Address: "127.0.0.1:443", + CustomResponseHeaders: map[string]map[string]string{ + "default": DefaultCustomHeaders, + }, }, }, Seals: []*configutil.KMS{ @@ -898,6 +937,9 @@ func testLoadConfigFileLeaseMetrics(t *testing.T) { { Type: "tcp", Address: "127.0.0.1:443", + CustomResponseHeaders: map[string]map[string]string{ + "default": DefaultCustomHeaders, + }, }, }, diff --git a/command/server/test-fixtures/config_custom_response_headers_multiple_listeners.hcl b/command/server/test-fixtures/config_custom_response_headers_multiple_listeners.hcl new file mode 100644 index 0000000000000..15cd7b49c89b6 --- /dev/null +++ b/command/server/test-fixtures/config_custom_response_headers_multiple_listeners.hcl @@ -0,0 +1,43 @@ +storage "inmem" {} +listener "tcp" { + address = "127.0.0.1:8200" + tls_disable = true + custom_response_headers { + "default" = { + "Strict-Transport-Security" = ["max-age=1","domains"], + "Content-Security-Policy" = ["default-src 'others'"], + "X-Vault-Ignored" = ["ignored"], + "X-Custom-Header" = ["Custom header value default"], + } + "307" = { + "X-Custom-Header" = ["Custom header value 307"], + } + "3xx" = { + "X-Vault-Ignored-3xx" = ["Ignored 3xx"], + "X-Custom-Header" = ["Custom header value 3xx"] + } + "200" = { + "someheader-200" = ["200"], + "X-Custom-Header" = ["Custom header value 200"] + } + "2xx" = { + "X-Custom-Header" = ["Custom header value 2xx"] + } + "400" = { + "someheader-400" = ["400"] + } + } +} +listener "tcp" { + address = "127.0.0.2:8200" + tls_disable = true + custom_response_headers { + "default" = { + "Strict-Transport-Security" = ["max-age=1","domains"], + "Content-Security-Policy" = ["default-src 'others'"], + "X-Vault-Ignored" = ["ignored"], + "X-Custom-Header" = ["Custom header value default"], + } + } +} +disable_mlock = true diff --git a/http/cutom_header_test.go b/http/custom_header_test.go similarity index 93% rename from http/cutom_header_test.go rename to http/custom_header_test.go index 0383ae92dbbaf..6ed977c6f5015 100644 --- a/http/cutom_header_test.go +++ b/http/custom_header_test.go @@ -37,8 +37,22 @@ var customHeader405 = map[string]string { "Someheader-405": "405", } +var CustomResponseHeaders = map[string]map[string]string{ + "default": defaultCustomHeaders, + "307": {"X-Custom-Header": "Custom header value 307"}, + "3xx": { + "X-Custom-Header": "Custom header value 3xx", + "X-Vault-Ignored-3xx": "Ignored 3xx", + }, + "200": customHeader200, + "2xx": customHeader2xx, + "400": customHeader400, + "405": customHeader405, + "4xx": customHeader4xx, +} + func TestCustomResponseHeaders(t *testing.T) { - core, _, token := vault.TestCoreWithCustomResponseHeaderAndUI(t, true) + core, _, token := vault.TestCoreWithCustomResponseHeaderAndUI(t, CustomResponseHeaders, true) ln, addr := TestServer(t, core) defer ln.Close() TestServerAuth(t, addr, token) diff --git a/http/handler.go b/http/handler.go index 82379c86a1983..6c17cd7b0ae3e 100644 --- a/http/handler.go +++ b/http/handler.go @@ -218,11 +218,11 @@ type WrappingResponseWriter interface { } type statusHeaderResponseWriter struct { - wrapped http.ResponseWriter - logger log.Logger + wrapped http.ResponseWriter + logger log.Logger wroteHeader bool - statusCode int - headers map[string][]*vault.CustomHeader + statusCode int + headers map[string][]*vault.CustomHeader } func (w statusHeaderResponseWriter) Wrapped() http.ResponseWriter { @@ -234,14 +234,14 @@ func (w statusHeaderResponseWriter) Header() http.Header { } func (w statusHeaderResponseWriter) Write(buf []byte) (int, error) { - // It is allowed to only call ResponseWriter.Write and skip ResponseWriter.WriteHeader - // An example of such a situation is "handleUIStub". The Write function will internally - // set the status code 200 for the response, and that call might invoke other implementations - // the WriteHeader function normally for some sort of IO writer. - // So, we still need to set the custom headers. - // In cases where both WriteHeader and Write of statusHeaderResponseWriter struct are called - // the internal call to the WriterHeader invoked from inside Write method won't change - // the headers. + // It is allowed to only call ResponseWriter.Write and skip + // ResponseWriter.WriteHeader. An example of such a situation is + // "handleUIStub". The Write function will internally set the status code + // 200 for the response for which that call might invoke other + // implementations of the WriteHeader function. So, we still need to set + // the custom headers. In cases where both WriteHeader and Write of + // statusHeaderResponseWriter struct are called the internal call to the + // WriterHeader invoked from inside Write method won't change the headers. if !w.wroteHeader { w.setCustomResponseHeaders(w.statusCode) } @@ -255,13 +255,10 @@ func (w statusHeaderResponseWriter) WriteHeader(statusCode int) { w.statusCode = statusCode // in cases where Write is called after WriteHeader, let's prevent setting // ResponseWriter headers twice - if !w.wroteHeader { - w.wroteHeader = true - } + w.wroteHeader = true } func (w statusHeaderResponseWriter) setCustomResponseHeaders(status int) { - sch := w.headers if sch == nil { w.logger.Warn("status code header map not configured") @@ -284,7 +281,7 @@ func (w statusHeaderResponseWriter) setCustomResponseHeaders(status int) { setter(sch["default"]) // setting the Xyy pattern first - d := fmt.Sprintf("%vxx", status / 100) + d := fmt.Sprintf("%vxx", status/100) if val, ok := sch[d]; ok { setter(val) } @@ -389,7 +386,6 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr hostname, _ := os.Hostname() return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // This block needs to be here so that upon sending SIGHUP, custom response // headers are also reloaded into the handlers. if props.ListenerConfig != nil { diff --git a/http/sys_generate_root.go b/http/sys_generate_root.go index 49d592fb33409..4ac3015077447 100644 --- a/http/sys_generate_root.go +++ b/http/sys_generate_root.go @@ -84,7 +84,6 @@ func handleSysGenerateRootAttemptGet(core *vault.Core, w http.ResponseWriter, r } func handleSysGenerateRootAttemptPut(core *vault.Core, w http.ResponseWriter, r *http.Request, generateStrategy vault.GenerateRootStrategy) { - // Parse the request var req GenerateRootInitRequest if _, err := parseJSONRequest(core.PerfStandby(), r, w, &req); err != nil && err != io.EOF { diff --git a/http/sys_rekey.go b/http/sys_rekey.go index b3428a39f9620..d1cec653a6283 100644 --- a/http/sys_rekey.go +++ b/http/sys_rekey.go @@ -24,8 +24,7 @@ func handleSysRekeyInit(core *vault.Core, recovery bool) http.Handler { repState := core.ReplicationState() if repState.HasState(consts.ReplicationPerformanceSecondary) { respondError(w, http.StatusBadRequest, - fmt.Errorf("rekeying can only be performed on the primary cluster when replication is activated"), - ) + fmt.Errorf("rekeying can only be performed on the primary cluster when replication is activated")) return } diff --git a/internalshared/configutil/http_response_headers.go b/internalshared/configutil/http_response_headers.go index 3a18561b474d6..26da0aa67a99c 100644 --- a/internalshared/configutil/http_response_headers.go +++ b/internalshared/configutil/http_response_headers.go @@ -1,175 +1,187 @@ package configutil import ( - "fmt" - "net/textproto" - "strconv" - "strings" + "fmt" + "net/textproto" + "strconv" + "strings" ) -var DefaultHeaderNames = []string { - "Content-Security-Policy", - "X-XSS-Protection", - "X-Frame-Options", - "X-Content-Type-Options", - "Strict-Transport-Security", - "Content-Type", +var DefaultHeaderNames = []string{ + "Content-Security-Policy", + "X-XSS-Protection", + "X-Frame-Options", + "X-Content-Type-Options", + "Strict-Transport-Security", + "Content-Type", } -var ValidCustomStatusCodeCollection = []string { - "default", - "1xx", - "2xx", - "3xx", - "4xx", - "5xx", +var ValidCustomStatusCodeCollection = []string{ + "default", + "1xx", + "2xx", + "3xx", + "4xx", + "5xx", } const ( - contentSecurityPolicy = "default-src 'none'; connect-src 'self'; img-src 'self' data:; script-src 'self'; style-src 'unsafe-inline' 'self'; form-action 'none'; frame-ancestors 'none'; font-src 'self'" - xXssProtection = "1; mode=block" - xFrameOptions = "Deny" - xContentTypeOptions = "nosniff" - strictTransportSecurity = "max-age=31536000; includeSubDomains" - contentType = "application/json" + contentSecurityPolicy = "default-src 'none'; connect-src 'self'; img-src 'self' data:; script-src 'self'; style-src 'unsafe-inline' 'self'; form-action 'none'; frame-ancestors 'none'; font-src 'self'" + xXssProtection = "1; mode=block" + xFrameOptions = "Deny" + xContentTypeOptions = "nosniff" + strictTransportSecurity = "max-age=31536000; includeSubDomains" + contentType = "application/json" ) func GetDefaultHeaderValue(h string) string { - switch h { - case "Content-Security-Policy": - return contentSecurityPolicy - case "X-XSS-Protection": - return xXssProtection - case "X-Frame-Options": - return xFrameOptions - case "X-Content-Type-Options": - return xContentTypeOptions - case "Strict-Transport-Security": - return strictTransportSecurity - case "Content-Type": - return contentType - default: - return "" - } + switch h { + case "Content-Security-Policy": + return contentSecurityPolicy + case "X-XSS-Protection": + return xXssProtection + case "X-Frame-Options": + return xFrameOptions + case "X-Content-Type-Options": + return xContentTypeOptions + case "Strict-Transport-Security": + return strictTransportSecurity + case "Content-Type": + return contentType + default: + return "" + } } func setDefaultResponseHeaders(c map[string]string) map[string]string { - defaults := make(map[string]string) - // adding all parsed default headers - for k, v := range c { - defaults[k] = v - } - - // setting all default headers that are not included in the config - // file under the "default" category - for _, hn := range DefaultHeaderNames { - if _, ok := c[hn]; ok { - continue - } - hv := GetDefaultHeaderValue(hn) - if hv != "" { - defaults[hn] = hv - } - } - - return defaults + defaults := make(map[string]string) + // adding all parsed default headers + for k, v := range c { + defaults[k] = v + } + + // setting all default headers that are not included in the config + // file under the "default" category + for _, hn := range DefaultHeaderNames { + if _, ok := c[hn]; ok { + continue + } + hv := GetDefaultHeaderValue(hn) + if hv != "" { + defaults[hn] = hv + } + } + + return defaults } func ParseCustomResponseHeaders(r interface{}) (map[string]map[string]string, error) { - if _, ok := r.([]map[string]interface{}); !ok { - return nil, fmt.Errorf("response headers were not configured correctly. please make sure they're in a map") - } - - customResponseHeader := r.([]map[string]interface{}) - h := make(map[string]map[string]string) - - for _, crh := range customResponseHeader { - for statusCode, responseHeader := range crh { - if _, ok := responseHeader.([]map[string]interface{}); !ok { - return nil, fmt.Errorf("response headers were not configured correctly. please make sure they're in a map") - } - - if !IsValidStatusCode(statusCode) { - return nil, fmt.Errorf("invalid status code found in the server configuration: %v", statusCode) - } - - hvl := responseHeader.([]map[string]interface{}) - if len(hvl) != 1 { - return nil, fmt.Errorf("invalid number of response headers exist") - } - hvm := hvl[0] - hv, err := parseHeaders(hvm) - if err != nil { - return nil, err - } - - h[statusCode] = hv - } - } - - // setting default custom headers - de := h["default"] - h["default"] = setDefaultResponseHeaders(de) + h := make(map[string]map[string]string) + // if r is nil, we still should set the default custom headers + if r == nil { + de := h["default"] + h["default"] = setDefaultResponseHeaders(de) + return h, nil + } + + if _, ok := r.([]map[string]interface{}); !ok { + return nil, fmt.Errorf("response headers were not configured correctly. please make sure they're in a map") + } + + customResponseHeader := r.([]map[string]interface{}) + + for _, crh := range customResponseHeader { + for statusCode, responseHeader := range crh { + if _, ok := responseHeader.([]map[string]interface{}); !ok { + return nil, fmt.Errorf("response headers were not configured correctly. please make sure they're in a map") + } + + if !IsValidStatusCode(statusCode) { + return nil, fmt.Errorf("invalid status code found in the server configuration: %v", statusCode) + } + + hvl := responseHeader.([]map[string]interface{}) + if len(hvl) != 1 { + return nil, fmt.Errorf("invalid number of response headers exist") + } + hvm := hvl[0] + hv, err := parseHeaders(hvm) + if err != nil { + return nil, err + } + + h[statusCode] = hv + } + } + + // setting default custom headers + de := h["default"] + h["default"] = setDefaultResponseHeaders(de) return h, nil } func IsValidStatusCodeCollection(sc string) bool { - for _, v := range ValidCustomStatusCodeCollection { - if sc == v { - return true - } - } + for _, v := range ValidCustomStatusCodeCollection { + if sc == v { + return true + } + } - return false + return false } // IsValidStatusCode checking for status codes outside the boundary func IsValidStatusCode(sc string) bool { - if IsValidStatusCodeCollection(sc) { - return true - } + if IsValidStatusCodeCollection(sc) { + return true + } - i, err := strconv.Atoi(sc) - if err != nil { - return false - } + i, err := strconv.Atoi(sc) + if err != nil { + return false + } - if i >= 600 || i < 100 { - return false - } + if i >= 600 || i < 100 { + return false + } - return true + return true } func parseHeaders(in map[string]interface{}) (map[string]string, error) { - hvMap := make(map[string]string) - for k, v := range in { - // parsing header name - hn := textproto.CanonicalMIMEHeaderKey(k) - // parsing header values - s, err := parseHeaderValues(v) - if err != nil { - return nil, err - } - hvMap[hn] = s - } - return hvMap, nil + hvMap := make(map[string]string) + for k, v := range in { + // parsing header name + hn := textproto.CanonicalMIMEHeaderKey(k) + // parsing header values + s, err := parseHeaderValues(v) + if err != nil { + return nil, err + } + hvMap[hn] = s + } + return hvMap, nil } func parseHeaderValues(h interface{}) (string, error) { - var sl []string - if _, ok := h.([]interface{}); !ok { - return "", fmt.Errorf("headers must be given in a list of strings") - } - vli := h.([]interface{}) - for _, vh := range vli { - if vh.(string) == "" { - continue - } - sl = append(sl, vh.(string)) - } - s := strings.Join(sl, "; ") - - return s, nil -} \ No newline at end of file + var sl []string + if _, ok := h.([]interface{}); !ok { + return "", fmt.Errorf("headers must be given in a list of strings") + } + vli := h.([]interface{}) + for _, vh := range vli { + if _, ok := vh.(string); !ok { + return "", fmt.Errorf("found a non-string header value: %v", vh) + } + headerVal := vh.(string) + if headerVal == "" { + continue + } + sl = append(sl, headerVal) + + } + s := strings.Join(sl, "; ") + + return s, nil +} diff --git a/internalshared/configutil/listener.go b/internalshared/configutil/listener.go index 858b6678cc1ec..677dbf9df8b3a 100644 --- a/internalshared/configutil/listener.go +++ b/internalshared/configutil/listener.go @@ -101,9 +101,8 @@ type Listener struct { CorsAllowedHeadersRaw []string `hcl:"cors_allowed_headers,alias:cors_allowed_headers"` // Custom Http response headers - CustomResponseHeaders map[string]map[string]string `hcl:"-"` + CustomResponseHeaders map[string]map[string]string `hcl:"-"` CustomResponseHeadersRaw interface{} `hcl:"custom_response_headers"` - } func (l *Listener) GoString() string { @@ -368,14 +367,13 @@ func ParseListeners(result *SharedConfig, list *ast.ObjectList) error { // HTTP Headers { - if l.CustomResponseHeadersRaw != nil { - customHeadersMap, err := ParseCustomResponseHeaders(l.CustomResponseHeadersRaw) - if err != nil { - return multierror.Prefix(fmt.Errorf("failed to parse custom_response_headers: %w", err), fmt.Sprintf("listeners.%d", i)) - } - l.CustomResponseHeaders = customHeadersMap - l.CustomResponseHeadersRaw = nil + // if CustomResponseHeadersRaw is nil, we still need to set the default headers + customHeadersMap, err := ParseCustomResponseHeaders(l.CustomResponseHeadersRaw) + if err != nil { + return multierror.Prefix(fmt.Errorf("failed to parse custom_response_headers: %w", err), fmt.Sprintf("listeners.%d", i)) } + l.CustomResponseHeaders = customHeadersMap + l.CustomResponseHeadersRaw = nil } result.Listeners = append(result.Listeners, &l) diff --git a/vault/core.go b/vault/core.go index a3d315161ac11..0e98c168ef066 100644 --- a/vault/core.go +++ b/vault/core.go @@ -517,7 +517,7 @@ type Core struct { // clusterListener starts up and manages connections on the cluster ports clusterListener *atomic.Value - customListenerHeader *ListenersCustomResponseHeadersList + customListenerHeader *atomic.Value // Telemetry objects metricsHelper *metricsutil.MetricsHelper @@ -767,23 +767,24 @@ func CreateCore(conf *CoreConfig) (*Core, error) { // Setup the core c := &Core{ - entCore: entCore{}, - devToken: conf.DevToken, - physical: conf.Physical, - serviceRegistration: conf.GetServiceRegistration(), - underlyingPhysical: conf.Physical, - storageType: conf.StorageType, - redirectAddr: conf.RedirectAddr, - clusterAddr: new(atomic.Value), - clusterListener: new(atomic.Value), - seal: conf.Seal, - router: NewRouter(), - sealed: new(uint32), - sealMigrationDone: new(uint32), - standby: true, - standbyStopCh: new(atomic.Value), - baseLogger: conf.Logger, - logger: conf.Logger.Named("core"), + entCore: entCore{}, + devToken: conf.DevToken, + physical: conf.Physical, + serviceRegistration: conf.GetServiceRegistration(), + underlyingPhysical: conf.Physical, + storageType: conf.StorageType, + redirectAddr: conf.RedirectAddr, + clusterAddr: new(atomic.Value), + clusterListener: new(atomic.Value), + customListenerHeader: new(atomic.Value), + seal: conf.Seal, + router: NewRouter(), + sealed: new(uint32), + sealMigrationDone: new(uint32), + standby: true, + standbyStopCh: new(atomic.Value), + baseLogger: conf.Logger, + logger: conf.Logger.Named("core"), defaultLeaseTTL: conf.DefaultLeaseTTL, maxLeaseTTL: conf.MaxLeaseTTL, @@ -1003,8 +1004,16 @@ func NewCore(conf *CoreConfig) (*Core, error) { c.clusterListener.Store((*cluster.Listener)(nil)) - uiHeaders, _ := c.UIHeaders() - c.customListenerHeader = NewListenerCustomHeader(conf.RawConfig.Listeners, c.logger, uiHeaders) + // for listeners with custom response headers, configuring customListenerHeader + if conf.RawConfig.Listeners != nil { + uiHeaders, err := c.UIHeaders() + if err != nil { + return nil, err + } + c.customListenerHeader.Store(NewListenerCustomHeader(conf.RawConfig.Listeners, c.logger, uiHeaders)) + } else { + c.customListenerHeader.Store(([]*ListenerCustomHeaders)(nil)) + } quotasLogger := conf.Logger.Named("quotas") c.allLoggers = append(c.allLoggers, quotasLogger) @@ -2637,27 +2646,41 @@ func (c *Core) SetConfig(conf *server.Config) { c.logger.Debug("set config", "sanitized config", string(bz)) } -func (c *Core) getCustomResponseHeaders(la string) []*ListenerCustomHeaders { - if c.customListenerHeader == nil { - c.logger.Debug("failed to find the custom response headers configuration") +func (c *Core) getCustomHeadersListenerList(listenerAdd string) []*ListenerCustomHeaders { + customHeaders := c.customListenerHeader.Load() + if customHeaders == nil { return nil } - lch := c.customListenerHeader.getListenerMap(la) - if lch == nil { - c.logger.Warn("no listener config found", "address", la) + customHeadersList := customHeaders.([]*ListenerCustomHeaders) + if customHeadersList == nil { return nil } + // either looking for a specific listener, or if listener address isn't given, + // checking for all available listeners + var lch []*ListenerCustomHeaders + if listenerAdd == "" { + lch = customHeadersList + } else { + for _, l := range customHeadersList { + if l.Address == listenerAdd { + lch = append(lch, l) + break + } + } + if len(lch) == 0 { + return nil + } + } return lch } -func (c *Core) GetListenerCustomResponseHeaders(la string) *ListenerCustomHeaders { - - if la == "" { +func (c *Core) GetListenerCustomResponseHeaders(listenerAdd string) *ListenerCustomHeaders { + if listenerAdd == "" { return nil } - lch := c.getCustomResponseHeaders(la) + lch := c.getCustomHeadersListenerList(listenerAdd) if lch == nil { return nil } @@ -2669,15 +2692,10 @@ func (c *Core) GetListenerCustomResponseHeaders(la string) *ListenerCustomHeader return lch[0] } -func (c *Core) ExistCustomResponseHeader(header string, la string) bool { - if c.customListenerHeader == nil { - c.logger.Debug("failed to find the custom response headers configuration") - return false - } - - lch := c.getCustomResponseHeaders(la) +func (c *Core) ExistCustomResponseHeader(header string, listenerAdd string) bool { + lch := c.getCustomHeadersListenerList(listenerAdd) if lch == nil { - c.logger.Warn("no listener config found", "address", la) + c.logger.Warn("no listener config found", "address", listenerAdd) return false } @@ -2693,24 +2711,22 @@ func (c *Core) ExistCustomResponseHeader(header string, la string) bool { } func (c *Core) ReloadCustomResponseHeaders() error { - conf := c.rawConfig.Load() if conf == nil { return fmt.Errorf("failed to load core raw config") } - - tempLH := c.customListenerHeader - c.customListenerHeader = nil - - uiHeaders, _ := c.UIHeaders() - lns := conf.(*server.Config).Listeners - c.customListenerHeader = NewListenerCustomHeader(lns, c.logger, uiHeaders) + if lns == nil { + return fmt.Errorf("no listener configured") + } - if c.customListenerHeader == nil { - c.logger.Error("failed to reload custom headers. the previous configuration will be used") - c.customListenerHeader = tempLH + c.customListenerHeader.Store(([]*ListenerCustomHeaders)(nil)) + + uiHeaders, err := c.UIHeaders() + if err != nil { + return err } + c.customListenerHeader.Store(NewListenerCustomHeader(lns, c.logger, uiHeaders)) return nil } diff --git a/vault/custom_response_headers.go b/vault/custom_response_headers.go index 6ff46c325eaa9..382b2783c4b69 100644 --- a/vault/custom_response_headers.go +++ b/vault/custom_response_headers.go @@ -2,45 +2,36 @@ package vault import ( "fmt" - log "github.com/hashicorp/go-hclog" "net/http" "net/textproto" "strings" + log "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/internalshared/configutil" ) -type ListenersCustomResponseHeadersList struct { - CustomHeadersList []*ListenerCustomHeaders -} - type ListenerCustomHeaders struct { - Address string + Address string StatusCodeHeaderMap map[string][]*CustomHeader // ConfiguredHeadersStatusCodeMap field is introduced so that we would not need to loop through // StatusCodeHeaderMap to see if a header exists, the key for this map is the headers names - ConfiguredHeadersStatusCodeMap map[string][]string + configuredHeadersStatusCodeMap map[string][]string } type CustomHeader struct { - Name string + Name string Value string } -func NewListenerCustomHeader(ln []*configutil.Listener, logger log.Logger, uiHeaders http.Header) *ListenersCustomResponseHeadersList { - - if ln == nil { - return nil - } - - ll := &ListenersCustomResponseHeadersList{} +func NewListenerCustomHeader(ln []*configutil.Listener, logger log.Logger, uiHeaders http.Header) []*ListenerCustomHeaders { + var ll []*ListenerCustomHeaders for _, l := range ln { lc := &ListenerCustomHeaders{ Address: l.Address, } lc.StatusCodeHeaderMap = make(map[string][]*CustomHeader) - lc.ConfiguredHeadersStatusCodeMap = make(map[string][]string) + lc.configuredHeadersStatusCodeMap = make(map[string][]string) for sc, hv := range l.CustomResponseHeaders { var chl []*CustomHeader for h, v := range hv { @@ -66,48 +57,24 @@ func NewListenerCustomHeader(ln []*configutil.Listener, logger log.Logger, uiHea } ch := &CustomHeader{ - Name: h, + Name: h, Value: v, } chl = append(chl, ch) // setting up the reverse map of header to status code for easy lookups - lc.ConfiguredHeadersStatusCodeMap[h] = append(lc.ConfiguredHeadersStatusCodeMap[h], sc) + lc.configuredHeadersStatusCodeMap[h] = append(lc.configuredHeadersStatusCodeMap[h], sc) } lc.StatusCodeHeaderMap[sc] = chl } - ll.CustomHeadersList = append(ll.CustomHeadersList, lc) + ll = append(ll, lc) } return ll } -func (c *ListenersCustomResponseHeadersList) getListenerMap(address string) []*ListenerCustomHeaders { - if c.CustomHeadersList == nil { - return nil - } - - // either looking for a specific listener, or if listener address isn't given, - // checking for all available listeners - var lch []*ListenerCustomHeaders - if address == "" { - lch = c.CustomHeadersList - } else { - for _, l := range c.CustomHeadersList { - if l.Address == address { - lch = append(lch, l) - } - } - if len(lch) == 0 { - return nil - } - } - return lch -} - func (l *ListenerCustomHeaders) findCustomHeaderMatchStatusCode(sc string, hn string) string { - getHeader := func(ch []*CustomHeader) string { for _, h := range ch { if h.Name == hn { @@ -130,7 +97,7 @@ func (l *ListenerCustomHeaders) findCustomHeaderMatchStatusCode(sc string, hn st // Checking for the Yxx pattern var firstDig string if len(sc) == 3 { - firstDig = strings.Split(sc, "")[0] + firstDig = string(sc[0]) } if firstDig != "" { s := fmt.Sprintf("%vxx", firstDig) @@ -147,15 +114,14 @@ func (l *ListenerCustomHeaders) findCustomHeaderMatchStatusCode(sc string, hn st // At this point, we could not find a match for the given status code in the config file // so, we just return the "default" ones h := getHeader(hm["default"]) - if h != ""{ + if h != "" { return h } return "" } -func(l *ListenerCustomHeaders) FetchHeaderForStatusCode(header, sc string) (string, error) { - +func (l *ListenerCustomHeaders) FetchHeaderForStatusCode(header, sc string) (string, error) { if header == "" { return "", fmt.Errorf("invalid target header") } @@ -176,7 +142,6 @@ func(l *ListenerCustomHeaders) FetchHeaderForStatusCode(header, sc string) (stri } func (l *ListenerCustomHeaders) ExistCustomResponseHeader(header string) bool { - if header == "" { return false } @@ -187,7 +152,7 @@ func (l *ListenerCustomHeaders) ExistCustomResponseHeader(header string) bool { hn := textproto.CanonicalMIMEHeaderKey(header) - hs := l.ConfiguredHeadersStatusCodeMap + hs := l.configuredHeadersStatusCodeMap _, ok := hs[hn] return ok -} \ No newline at end of file +} diff --git a/vault/custom_response_headers_test.go b/vault/custom_response_headers_test.go index e4e4c5f33ab1c..681f37e5ad350 100644 --- a/vault/custom_response_headers_test.go +++ b/vault/custom_response_headers_test.go @@ -15,37 +15,36 @@ import ( "github.com/hashicorp/vault/sdk/physical/inmem" ) - -var defaultCustomHeaders = map[string]string { +var defaultCustomHeaders = map[string]string{ "Strict-Transport-Security": "max-age=1; domains", - "Content-Security-Policy": "default-src 'others'", - "X-Vault-Ignored": "ignored", - "X-Custom-Header": "Custom header value default", - "X-Frame-Options": "Deny", - "X-Content-Type-Options": "nosniff", - "Content-Type": "text/plain; charset=utf-8", - "X-XSS-Protection": "1; mode=block", + "Content-Security-Policy": "default-src 'others'", + "X-Vault-Ignored": "ignored", + "X-Custom-Header": "Custom header value default", + "X-Frame-Options": "Deny", + "X-Content-Type-Options": "nosniff", + "Content-Type": "text/plain; charset=utf-8", + "X-XSS-Protection": "1; mode=block", } -var customHeaders307 = map[string]string { +var customHeaders307 = map[string]string{ "X-Custom-Header": "Custom header value 307", } -var customHeader3xx = map[string]string { +var customHeader3xx = map[string]string{ "X-Vault-Ignored-3xx": "Ignored 3xx", - "X-Custom-Header": "Custom header value 3xx", + "X-Custom-Header": "Custom header value 3xx", } -var customHeaders200 = map[string]string { - "Someheader-200": "200", +var customHeaders200 = map[string]string{ + "Someheader-200": "200", "X-Custom-Header": "Custom header value 200", } -var customHeader2xx = map[string]string { +var customHeader2xx = map[string]string{ "X-Custom-Header": "Custom header value 2xx", } -var customHeader400 = map[string]string { +var customHeader400 = map[string]string{ "Someheader-400": "400", } @@ -58,29 +57,25 @@ func TestConfigCustomHeaders(t *testing.T) { logl := &logical.InmemStorage{} uiConfig := NewUIConfig(true, phys, logl) - rawListenerConfig := []*configutil.Listener { + rawListenerConfig := []*configutil.Listener{ { Type: "tcp", Address: "127.0.0.1:443", CustomResponseHeaders: map[string]map[string]string{ "default": defaultCustomHeaders, - "307": customHeaders307, - "3xx": customHeader3xx, - "200": customHeaders200, - "2xx": customHeader2xx, - "400": customHeader400, + "307": customHeaders307, + "3xx": customHeader3xx, + "200": customHeaders200, + "2xx": customHeader2xx, + "400": customHeader400, }, }, } uiHeaders, err := uiConfig.Headers(context.Background()) - customListenerHeader := NewListenerCustomHeader(rawListenerConfig, logger, uiHeaders) - if customListenerHeader == nil { - t.Fatalf("custom header config should be configured") - } - listenerCustomHeaders := customListenerHeader.getListenerMap("127.0.0.1:443") + listenerCustomHeaders := NewListenerCustomHeader(rawListenerConfig, logger, uiHeaders) if listenerCustomHeaders == nil || len(listenerCustomHeaders) != 1 { - t.Fatalf("failed to find listener specific custom header") + t.Fatalf("failed to get custom header configuration") } lch := listenerCustomHeaders[0] @@ -95,7 +90,6 @@ func TestConfigCustomHeaders(t *testing.T) { if !lch.ExistCustomResponseHeader("X-Custom-Header") { t.Fatalf("header name with X-Vault prefix is not valid") } - } func TestCustomResponseHeadersConfigInteractUiConfig(t *testing.T) { @@ -105,17 +99,17 @@ func TestCustomResponseHeadersConfigInteractUiConfig(t *testing.T) { b.(*SystemBackend).Core.systemBarrierView = view logger := logging.NewVaultLogger(log.Trace) - rawListenerConfig := []*configutil.Listener { + rawListenerConfig := []*configutil.Listener{ { Type: "tcp", Address: "127.0.0.1:443", CustomResponseHeaders: map[string]map[string]string{ "default": defaultCustomHeaders, - "307": customHeaders307, - "3xx": customHeader3xx, - "200": customHeaders200, - "2xx": customHeader2xx, - "400": customHeader400, + "307": customHeaders307, + "3xx": customHeader3xx, + "200": customHeaders200, + "2xx": customHeader2xx, + "400": customHeader400, }, }, } @@ -127,7 +121,7 @@ func TestCustomResponseHeadersConfigInteractUiConfig(t *testing.T) { if customListenerHeader == nil { t.Fatalf("custom header config should be configured") } - b.(*SystemBackend).Core.customListenerHeader = customListenerHeader + b.(*SystemBackend).Core.customListenerHeader.Store(customListenerHeader) clh := b.(*SystemBackend).Core.customListenerHeader if clh == nil { t.Fatalf("custom header config should be configured in core") @@ -145,7 +139,7 @@ func TestCustomResponseHeadersConfigInteractUiConfig(t *testing.T) { if err == nil { t.Fatal("request did not fail on setting a header that is present in custom response headers") } - if !strings.Contains(resp.Data["error"].(string), fmt.Sprintf("This header already exist in server configuration. %v", "X-Custom-Header")) { + if !strings.Contains(resp.Data["error"].(string), fmt.Sprintf("This header already exists in the server configuration and cannot be set in the UI.")) { t.Fatalf("failed to get the expected error") } @@ -163,13 +157,14 @@ func TestCustomResponseHeadersConfigInteractUiConfig(t *testing.T) { t.Fatalf("should not be able to set a header that is in custom response headers") } + // setting an ui specific header req = logical.TestRequest(t, logical.UpdateOperation, "config/ui/headers/X-CustomUiHeader") req.Data["values"] = []string{"Ui header value"} req.ResponseWriter = hw _, err = b.HandleRequest(namespace.RootContext(nil), req) if err != nil { - t.Fatal("request did not fail on setting a header that is present in custom response headers") + t.Fatal("request failed on setting a header that is not present in custom response headers") } h, err = b.(*SystemBackend).Core.uiConfig.Headers(context.Background()) diff --git a/vault/logical_system.go b/vault/logical_system.go index e554c38f1f7e0..43ce5556af102 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -2626,7 +2626,7 @@ func (b *SystemBackend) handleConfigUIHeadersUpdate(ctx context.Context, req *lo // ExistCustomResponseHeader support checking for custom headers per listener address // If address is an empty string, it checks all listeners for the header. if b.Core.ExistCustomResponseHeader(header, "") { - return logical.ErrorResponse(fmt.Sprintf("This header already exist in server configuration. %v", header)), logical.ErrInvalidRequest + return logical.ErrorResponse("This header already exists in the server configuration and cannot be set in the UI."), logical.ErrInvalidRequest } value.Add(header, v) } diff --git a/vault/testing.go b/vault/testing.go index 7d178d1c32252..f6bb28a0124d3 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -128,53 +128,21 @@ func TestCoreWithSeal(t testing.T, testSeal Seal, enableRaw bool) *Core { return TestCoreWithSealAndUI(t, conf) } -func TestCoreWithCustomResponseHeaderAndUI(t testing.T, enableUI bool) (*Core, [][]byte, string) { +func TestCoreWithCustomResponseHeaderAndUI(t testing.T, CustomResponseHeaders map[string]map[string]string, enableUI bool) (*Core, [][]byte, string) { confRaw := &server.Config{ SharedConfig: &configutil.SharedConfig{ Listeners: []*configutil.Listener{ { - Type: "tcp", - Address: "127.0.0.1", - CustomResponseHeaders: map[string]map[string]string{ - "default": { - "Strict-Transport-Security": "max-age=1; domains", - "Content-Security-Policy": "default-src 'others'", - "X-Vault-Ignored": "ignored", - "X-Custom-Header": "Custom header value default", - "X-Frame-Options": "Deny", - "X-Content-Type-Options": "nosniff", - "Content-Type": "application/json", - "X-XSS-Protection": "1; mode=block", - }, - "307": {"X-Custom-Header": "Custom header value 307"}, - "3xx": { - "X-Custom-Header": "Custom header value 3xx", - "X-Vault-Ignored-3xx": "Ignored 3xx", - }, - "200": { - "Someheader-200": "200", - "X-Custom-Header": "Custom header value 200", - }, - "2xx": { - "X-Custom-Header": "Custom header value 2xx", - }, - "400": { - "Someheader-400": "400", - }, - "405":{ - "Someheader-405": "405", - }, - "4xx": { - "Someheader-4xx": "4xx", - }, - }, + Type: "tcp", + Address: "127.0.0.1", + CustomResponseHeaders: CustomResponseHeaders, }, }, DisableMlock: true, }, } conf := &CoreConfig{ - RawConfig: confRaw, + RawConfig: confRaw, EnableUI: enableUI, EnableRaw: true, BuiltinRegistry: NewMockBuiltinRegistry(), From 879d4892db2f7f9d88905e380308c20d94eed3be Mon Sep 17 00:00:00 2001 From: hamid ghaf Date: Mon, 20 Sep 2021 19:42:30 -0700 Subject: [PATCH 20/25] fixing some agent tests --- command/agent/config/config_test.go | 41 ++++++++++++++++++- .../config_custom_response_headers_test.go | 2 +- vault/core.go | 1 + 3 files changed, 41 insertions(+), 3 deletions(-) diff --git a/command/agent/config/config_test.go b/command/agent/config/config_test.go index 0db8cf91954f7..0b57daf417c2e 100644 --- a/command/agent/config/config_test.go +++ b/command/agent/config/config_test.go @@ -11,6 +11,15 @@ import ( "github.com/hashicorp/vault/sdk/helper/pointerutil" ) +var DefaultCustomHeaders = map[string]string{ + "Strict-Transport-Security": "max-age=31536000; includeSubDomains", + "Content-Security-Policy": "default-src 'none'; connect-src 'self'; img-src 'self' data:; script-src 'self'; style-src 'unsafe-inline' 'self'; form-action 'none'; frame-ancestors 'none'; font-src 'self'", + "X-Frame-Options": "Deny", + "X-Content-Type-Options": "nosniff", + "Content-Type": "application/json", + "X-XSS-Protection": "1; mode=block", +} + func TestLoadConfigFile_AgentCache(t *testing.T) { config, err := LoadConfig("./test-fixtures/config-cache.hcl") if err != nil { @@ -28,17 +37,26 @@ func TestLoadConfigFile_AgentCache(t *testing.T) { SocketMode: "configmode", SocketUser: "configuser", SocketGroup: "configgroup", + CustomResponseHeaders: map[string]map[string]string{ + "default": DefaultCustomHeaders, + }, }, { Type: "tcp", Address: "127.0.0.1:8300", TLSDisable: true, + CustomResponseHeaders: map[string]map[string]string{ + "default": DefaultCustomHeaders, + }, }, { Type: "tcp", Address: "127.0.0.1:8400", TLSKeyFile: "/path/to/cakey.pem", TLSCertFile: "/path/to/cacert.pem", + CustomResponseHeaders: map[string]map[string]string{ + "default": DefaultCustomHeaders, + }, }, }, }, @@ -237,6 +255,9 @@ func TestLoadConfigFile_AgentCache_NoAutoAuth(t *testing.T) { Type: "tcp", Address: "127.0.0.1:8300", TLSDisable: true, + CustomResponseHeaders: map[string]map[string]string{ + "default": DefaultCustomHeaders, + }, }, }, }, @@ -315,6 +336,9 @@ func TestLoadConfigFile_AgentCache_AutoAuth_NoSink(t *testing.T) { Type: "tcp", Address: "127.0.0.1:8300", TLSDisable: true, + CustomResponseHeaders: map[string]map[string]string{ + "default": DefaultCustomHeaders, + }, }, }, PidFile: "./pidfile", @@ -359,6 +383,9 @@ func TestLoadConfigFile_AgentCache_AutoAuth_Force(t *testing.T) { Type: "tcp", Address: "127.0.0.1:8300", TLSDisable: true, + CustomResponseHeaders: map[string]map[string]string{ + "default": DefaultCustomHeaders, + }, }, }, PidFile: "./pidfile", @@ -403,6 +430,9 @@ func TestLoadConfigFile_AgentCache_AutoAuth_True(t *testing.T) { Type: "tcp", Address: "127.0.0.1:8300", TLSDisable: true, + CustomResponseHeaders: map[string]map[string]string{ + "default": DefaultCustomHeaders, + }, }, }, PidFile: "./pidfile", @@ -447,6 +477,9 @@ func TestLoadConfigFile_AgentCache_AutoAuth_False(t *testing.T) { Type: "tcp", Address: "127.0.0.1:8300", TLSDisable: true, + CustomResponseHeaders: map[string]map[string]string{ + "default": DefaultCustomHeaders, + }, }, }, PidFile: "./pidfile", @@ -512,6 +545,9 @@ func TestLoadConfigFile_AgentCache_Persist(t *testing.T) { Type: "tcp", Address: "127.0.0.1:8300", TLSDisable: true, + CustomResponseHeaders: map[string]map[string]string{ + "default": DefaultCustomHeaders, + }, }, }, }, @@ -536,7 +572,6 @@ func TestLoadConfigFile_AgentCache_PersistMissingType(t *testing.T) { } func TestLoadConfigFile_TemplateConfig(t *testing.T) { - testCases := map[string]struct { fixturePath string expectedTemplateConfig TemplateConfig @@ -586,7 +621,6 @@ func TestLoadConfigFile_TemplateConfig(t *testing.T) { } }) } - } // TestLoadConfigFile_Template tests template definitions in Vault Agent @@ -904,6 +938,9 @@ func TestLoadConfigFile_EnforceConsistency(t *testing.T) { Type: "tcp", Address: "127.0.0.1:8300", TLSDisable: true, + CustomResponseHeaders: map[string]map[string]string{ + "default": DefaultCustomHeaders, + }, }, }, PidFile: "", diff --git a/command/server/config_custom_response_headers_test.go b/command/server/config_custom_response_headers_test.go index a00b54fc2478d..db2e73b18a78c 100644 --- a/command/server/config_custom_response_headers_test.go +++ b/command/server/config_custom_response_headers_test.go @@ -80,7 +80,7 @@ func TestCustomResponseHeadersConfigsMultipleListeners(t *testing.T) { if diff := deep.Equal(expectedCustomResponseHeader, config.Listeners[1].CustomResponseHeaders); diff == nil { t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff)) } - if diff := deep.Equal(expectedCustomResponseHeader["default"], config.Listeners[0].CustomResponseHeaders["default"]); diff != nil { + if diff := deep.Equal(expectedCustomResponseHeader["default"], config.Listeners[1].CustomResponseHeaders["default"]); diff != nil { t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff)) } } diff --git a/vault/core.go b/vault/core.go index 0e98c168ef066..05caab94d2c09 100644 --- a/vault/core.go +++ b/vault/core.go @@ -517,6 +517,7 @@ type Core struct { // clusterListener starts up and manages connections on the cluster ports clusterListener *atomic.Value + // customListenerHeader holds custom response headers for a listener customListenerHeader *atomic.Value // Telemetry objects From 2b7dd501f5870b8122ee310a94e9f5cf18228283 Mon Sep 17 00:00:00 2001 From: hamid ghaf Date: Tue, 21 Sep 2021 17:14:16 -0700 Subject: [PATCH 21/25] skipping custom headers from agent listener config, removing two of the default headers as they cause issues with Vault in UI mode Adding X-Content-Type-Options to the ui default headers Let Content-Type be set as before --- command/agent/config/config.go | 7 ++++ command/agent/config/config_test.go | 39 ------------------- .../config_custom_response_headers_test.go | 2 - command/server/config_test_helpers.go | 2 - .../configutil/http_response_headers.go | 10 +---- vault/core.go | 12 +++--- vault/custom_response_headers_test.go | 4 +- vault/logical_system.go | 2 - vault/ui.go | 1 + 9 files changed, 17 insertions(+), 62 deletions(-) diff --git a/command/agent/config/config.go b/command/agent/config/config.go index 9438bd327444d..502d512d15a24 100644 --- a/command/agent/config/config.go +++ b/command/agent/config/config.go @@ -35,6 +35,7 @@ func (c *Config) Prune() { l.RawConfig = nil l.Profiling.UnusedKeys = nil l.Telemetry.UnusedKeys = nil + l.CustomResponseHeaders = nil } c.FoundKeys = nil c.UnusedKeys = nil @@ -172,6 +173,12 @@ func LoadConfig(path string) (*Config, error) { if err != nil { return nil, err } + + // Pruning custom headers for Agent for now + for _, ln := range sharedConfig.Listeners { + ln.CustomResponseHeaders = nil + } + result.SharedConfig = sharedConfig list, ok := obj.Node.(*ast.ObjectList) diff --git a/command/agent/config/config_test.go b/command/agent/config/config_test.go index 0b57daf417c2e..252461236c89b 100644 --- a/command/agent/config/config_test.go +++ b/command/agent/config/config_test.go @@ -11,15 +11,6 @@ import ( "github.com/hashicorp/vault/sdk/helper/pointerutil" ) -var DefaultCustomHeaders = map[string]string{ - "Strict-Transport-Security": "max-age=31536000; includeSubDomains", - "Content-Security-Policy": "default-src 'none'; connect-src 'self'; img-src 'self' data:; script-src 'self'; style-src 'unsafe-inline' 'self'; form-action 'none'; frame-ancestors 'none'; font-src 'self'", - "X-Frame-Options": "Deny", - "X-Content-Type-Options": "nosniff", - "Content-Type": "application/json", - "X-XSS-Protection": "1; mode=block", -} - func TestLoadConfigFile_AgentCache(t *testing.T) { config, err := LoadConfig("./test-fixtures/config-cache.hcl") if err != nil { @@ -37,26 +28,17 @@ func TestLoadConfigFile_AgentCache(t *testing.T) { SocketMode: "configmode", SocketUser: "configuser", SocketGroup: "configgroup", - CustomResponseHeaders: map[string]map[string]string{ - "default": DefaultCustomHeaders, - }, }, { Type: "tcp", Address: "127.0.0.1:8300", TLSDisable: true, - CustomResponseHeaders: map[string]map[string]string{ - "default": DefaultCustomHeaders, - }, }, { Type: "tcp", Address: "127.0.0.1:8400", TLSKeyFile: "/path/to/cakey.pem", TLSCertFile: "/path/to/cacert.pem", - CustomResponseHeaders: map[string]map[string]string{ - "default": DefaultCustomHeaders, - }, }, }, }, @@ -255,9 +237,6 @@ func TestLoadConfigFile_AgentCache_NoAutoAuth(t *testing.T) { Type: "tcp", Address: "127.0.0.1:8300", TLSDisable: true, - CustomResponseHeaders: map[string]map[string]string{ - "default": DefaultCustomHeaders, - }, }, }, }, @@ -336,9 +315,6 @@ func TestLoadConfigFile_AgentCache_AutoAuth_NoSink(t *testing.T) { Type: "tcp", Address: "127.0.0.1:8300", TLSDisable: true, - CustomResponseHeaders: map[string]map[string]string{ - "default": DefaultCustomHeaders, - }, }, }, PidFile: "./pidfile", @@ -383,9 +359,6 @@ func TestLoadConfigFile_AgentCache_AutoAuth_Force(t *testing.T) { Type: "tcp", Address: "127.0.0.1:8300", TLSDisable: true, - CustomResponseHeaders: map[string]map[string]string{ - "default": DefaultCustomHeaders, - }, }, }, PidFile: "./pidfile", @@ -430,9 +403,6 @@ func TestLoadConfigFile_AgentCache_AutoAuth_True(t *testing.T) { Type: "tcp", Address: "127.0.0.1:8300", TLSDisable: true, - CustomResponseHeaders: map[string]map[string]string{ - "default": DefaultCustomHeaders, - }, }, }, PidFile: "./pidfile", @@ -477,9 +447,6 @@ func TestLoadConfigFile_AgentCache_AutoAuth_False(t *testing.T) { Type: "tcp", Address: "127.0.0.1:8300", TLSDisable: true, - CustomResponseHeaders: map[string]map[string]string{ - "default": DefaultCustomHeaders, - }, }, }, PidFile: "./pidfile", @@ -545,9 +512,6 @@ func TestLoadConfigFile_AgentCache_Persist(t *testing.T) { Type: "tcp", Address: "127.0.0.1:8300", TLSDisable: true, - CustomResponseHeaders: map[string]map[string]string{ - "default": DefaultCustomHeaders, - }, }, }, }, @@ -938,9 +902,6 @@ func TestLoadConfigFile_EnforceConsistency(t *testing.T) { Type: "tcp", Address: "127.0.0.1:8300", TLSDisable: true, - CustomResponseHeaders: map[string]map[string]string{ - "default": DefaultCustomHeaders, - }, }, }, PidFile: "", diff --git a/command/server/config_custom_response_headers_test.go b/command/server/config_custom_response_headers_test.go index db2e73b18a78c..29e5eb2766ef0 100644 --- a/command/server/config_custom_response_headers_test.go +++ b/command/server/config_custom_response_headers_test.go @@ -13,8 +13,6 @@ var defaultCustomHeaders = map[string]string{ "X-Vault-Ignored": "ignored", "X-Custom-Header": "Custom header value default", "X-Frame-Options": "Deny", - "X-Content-Type-Options": "nosniff", - "Content-Type": "application/json", "X-XSS-Protection": "1; mode=block", } diff --git a/command/server/config_test_helpers.go b/command/server/config_test_helpers.go index d30726e33e130..6fec212025767 100644 --- a/command/server/config_test_helpers.go +++ b/command/server/config_test_helpers.go @@ -20,8 +20,6 @@ var DefaultCustomHeaders = map[string]string{ "Strict-Transport-Security": "max-age=31536000; includeSubDomains", "Content-Security-Policy": "default-src 'none'; connect-src 'self'; img-src 'self' data:; script-src 'self'; style-src 'unsafe-inline' 'self'; form-action 'none'; frame-ancestors 'none'; font-src 'self'", "X-Frame-Options": "Deny", - "X-Content-Type-Options": "nosniff", - "Content-Type": "application/json", "X-XSS-Protection": "1; mode=block", } diff --git a/internalshared/configutil/http_response_headers.go b/internalshared/configutil/http_response_headers.go index 26da0aa67a99c..4008107b28ad9 100644 --- a/internalshared/configutil/http_response_headers.go +++ b/internalshared/configutil/http_response_headers.go @@ -11,9 +11,7 @@ var DefaultHeaderNames = []string{ "Content-Security-Policy", "X-XSS-Protection", "X-Frame-Options", - "X-Content-Type-Options", "Strict-Transport-Security", - "Content-Type", } var ValidCustomStatusCodeCollection = []string{ @@ -29,9 +27,7 @@ const ( contentSecurityPolicy = "default-src 'none'; connect-src 'self'; img-src 'self' data:; script-src 'self'; style-src 'unsafe-inline' 'self'; form-action 'none'; frame-ancestors 'none'; font-src 'self'" xXssProtection = "1; mode=block" xFrameOptions = "Deny" - xContentTypeOptions = "nosniff" strictTransportSecurity = "max-age=31536000; includeSubDomains" - contentType = "application/json" ) func GetDefaultHeaderValue(h string) string { @@ -42,12 +38,8 @@ func GetDefaultHeaderValue(h string) string { return xXssProtection case "X-Frame-Options": return xFrameOptions - case "X-Content-Type-Options": - return xContentTypeOptions case "Strict-Transport-Security": return strictTransportSecurity - case "Content-Type": - return contentType default: return "" } @@ -174,7 +166,7 @@ func parseHeaderValues(h interface{}) (string, error) { if _, ok := vh.(string); !ok { return "", fmt.Errorf("found a non-string header value: %v", vh) } - headerVal := vh.(string) + headerVal := strings.TrimSpace(vh.(string)) if headerVal == "" { continue } diff --git a/vault/core.go b/vault/core.go index 05caab94d2c09..0f13c6b971558 100644 --- a/vault/core.go +++ b/vault/core.go @@ -2662,19 +2662,17 @@ func (c *Core) getCustomHeadersListenerList(listenerAdd string) []*ListenerCusto // checking for all available listeners var lch []*ListenerCustomHeaders if listenerAdd == "" { - lch = customHeadersList + return customHeadersList } else { for _, l := range customHeadersList { if l.Address == listenerAdd { lch = append(lch, l) - break + return lch } } - if len(lch) == 0 { - return nil - } } - return lch + + return nil } func (c *Core) GetListenerCustomResponseHeaders(listenerAdd string) *ListenerCustomHeaders { @@ -2693,6 +2691,8 @@ func (c *Core) GetListenerCustomResponseHeaders(listenerAdd string) *ListenerCus return lch[0] } +// ExistCustomResponseHeader support checking for custom headers per listener address +// If the address is an empty string, it checks all listeners for the header. func (c *Core) ExistCustomResponseHeader(header string, listenerAdd string) bool { lch := c.getCustomHeadersListenerList(listenerAdd) if lch == nil { diff --git a/vault/custom_response_headers_test.go b/vault/custom_response_headers_test.go index 681f37e5ad350..1ea79cf7ee5eb 100644 --- a/vault/custom_response_headers_test.go +++ b/vault/custom_response_headers_test.go @@ -164,11 +164,11 @@ func TestCustomResponseHeadersConfigInteractUiConfig(t *testing.T) { _, err = b.HandleRequest(namespace.RootContext(nil), req) if err != nil { - t.Fatal("request failed on setting a header that is not present in custom response headers") + t.Fatal("request failed on setting a header that is not present in custom response headers.", "error:", err) } h, err = b.(*SystemBackend).Core.uiConfig.Headers(context.Background()) if h.Get("X-CustomUiHeader") != "Ui header value" { - t.Fatalf("failed to sett a header that is not in custom response headers") + t.Fatalf("failed to set a header that is not in custom response headers") } } diff --git a/vault/logical_system.go b/vault/logical_system.go index 43ce5556af102..34f9df785b9e7 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -2623,8 +2623,6 @@ func (b *SystemBackend) handleConfigUIHeadersUpdate(ctx context.Context, req *lo // Translate the list of values to the valid header string value := http.Header{} for _, v := range values { - // ExistCustomResponseHeader support checking for custom headers per listener address - // If address is an empty string, it checks all listeners for the header. if b.Core.ExistCustomResponseHeader(header, "") { return logical.ErrorResponse("This header already exists in the server configuration and cannot be set in the UI."), logical.ErrInvalidRequest } diff --git a/vault/ui.go b/vault/ui.go index e845420864f33..277f62cc8a918 100644 --- a/vault/ui.go +++ b/vault/ui.go @@ -33,6 +33,7 @@ type UIConfig struct { func NewUIConfig(enabled bool, physicalStorage physical.Backend, barrierStorage logical.Storage) *UIConfig { defaultHeaders := http.Header{} defaultHeaders.Set("Service-Worker-Allowed", "/") + defaultHeaders.Set("X-Content-Type-Options", "nosniff") return &UIConfig{ physicalStorage: physicalStorage, From 0964ef82e41ecde359eadd1001126f81740a7d83 Mon Sep 17 00:00:00 2001 From: hamid ghaf Date: Tue, 5 Oct 2021 16:19:27 -0700 Subject: [PATCH 22/25] Removing default custom headers, and renaming some function varibles --- changelog/12485.txt | 2 +- .../config_custom_response_headers_test.go | 31 +++++- command/server/config_test_helpers.go | 53 +++------ ...om_response_headers_multiple_listeners.hcl | 17 ++- .../configutil/http_response_headers.go | 104 ++++++------------ vault/core.go | 4 +- vault/custom_response_headers.go | 70 ++++++------ vault/ui.go | 1 + 8 files changed, 129 insertions(+), 153 deletions(-) diff --git a/changelog/12485.txt b/changelog/12485.txt index 6ccb4432d67d9..6c8a87cd21cac 100644 --- a/changelog/12485.txt +++ b/changelog/12485.txt @@ -1,3 +1,3 @@ ```release-note:feature -http: Enable users to customize HTTP response headers +**Customizable HTTP Headers**: Add support to define custom HTTP headers for root path (`/`) and also on API endpoints (`/v1/*`) ``` diff --git a/command/server/config_custom_response_headers_test.go b/command/server/config_custom_response_headers_test.go index 29e5eb2766ef0..5380568c25107 100644 --- a/command/server/config_custom_response_headers_test.go +++ b/command/server/config_custom_response_headers_test.go @@ -12,8 +12,6 @@ var defaultCustomHeaders = map[string]string{ "Content-Security-Policy": "default-src 'others'", "X-Vault-Ignored": "ignored", "X-Custom-Header": "Custom header value default", - "X-Frame-Options": "Deny", - "X-XSS-Protection": "1; mode=block", } var customHeaders307 = map[string]string{ @@ -38,6 +36,17 @@ var customHeader400 = map[string]string{ "Someheader-400": "400", } +var defaultCustomHeadersMultiListener = map[string]string{ + "Strict-Transport-Security": "max-age=31536000; includeSubDomains", + "Content-Security-Policy": "default-src 'others'", + "X-Vault-Ignored": "ignored", + "X-Custom-Header": "Custom header value default", +} + +var defaultSTS = map[string]string{ + "Strict-Transport-Security": "max-age=31536000; includeSubDomains", +} + func TestCustomResponseHeadersConfigs(t *testing.T) { expectedCustomResponseHeader := map[string]map[string]string{ "default": defaultCustomHeaders, @@ -59,7 +68,7 @@ func TestCustomResponseHeadersConfigs(t *testing.T) { func TestCustomResponseHeadersConfigsMultipleListeners(t *testing.T) { expectedCustomResponseHeader := map[string]map[string]string{ - "default": defaultCustomHeaders, + "default": defaultCustomHeadersMultiListener, "307": customHeaders307, "3xx": customHeader3xx, "200": customHeaders200, @@ -81,4 +90,20 @@ func TestCustomResponseHeadersConfigsMultipleListeners(t *testing.T) { if diff := deep.Equal(expectedCustomResponseHeader["default"], config.Listeners[1].CustomResponseHeaders["default"]); diff != nil { t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff)) } + + if diff := deep.Equal(expectedCustomResponseHeader, config.Listeners[2].CustomResponseHeaders); diff == nil { + t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff)) + } + + if diff := deep.Equal(defaultSTS, config.Listeners[2].CustomResponseHeaders["default"]); diff != nil { + t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff)) + } + + if diff := deep.Equal(expectedCustomResponseHeader, config.Listeners[3].CustomResponseHeaders); diff == nil { + t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff)) + } + + if diff := deep.Equal(defaultSTS, config.Listeners[3].CustomResponseHeaders["default"]); diff != nil { + t.Fatalf(fmt.Sprintf("parsed custom headers do not match the expected ones, difference: %v", diff)) + } } diff --git a/command/server/config_test_helpers.go b/command/server/config_test_helpers.go index 6fec212025767..8936a0244090c 100644 --- a/command/server/config_test_helpers.go +++ b/command/server/config_test_helpers.go @@ -16,11 +16,10 @@ import ( "github.com/hashicorp/vault/internalshared/configutil" ) -var DefaultCustomHeaders = map[string]string{ - "Strict-Transport-Security": "max-age=31536000; includeSubDomains", - "Content-Security-Policy": "default-src 'none'; connect-src 'self'; img-src 'self' data:; script-src 'self'; style-src 'unsafe-inline' 'self'; form-action 'none'; frame-ancestors 'none'; font-src 'self'", - "X-Frame-Options": "Deny", - "X-XSS-Protection": "1; mode=block", +var DefaultCustomHeaders = map[string]map[string]string { + "default": { + "Strict-Transport-Security": configutil.StrictTransportSecurity, + }, } func boolPointer(x bool) *bool { @@ -39,9 +38,7 @@ func testConfigRaftRetryJoin(t *testing.T) { { Type: "tcp", Address: "127.0.0.1:8200", - CustomResponseHeaders: map[string]map[string]string{ - "default": DefaultCustomHeaders, - }, + CustomResponseHeaders: DefaultCustomHeaders, }, }, DisableMlock: true, @@ -74,9 +71,7 @@ func testLoadConfigFile_topLevel(t *testing.T, entropy *configutil.Entropy) { { Type: "tcp", Address: "127.0.0.1:443", - CustomResponseHeaders: map[string]map[string]string{ - "default": DefaultCustomHeaders, - }, + CustomResponseHeaders: DefaultCustomHeaders, }, }, @@ -187,16 +182,12 @@ func testLoadConfigFile_json2(t *testing.T, entropy *configutil.Entropy) { { Type: "tcp", Address: "127.0.0.1:443", - CustomResponseHeaders: map[string]map[string]string{ - "default": DefaultCustomHeaders, - }, + CustomResponseHeaders: DefaultCustomHeaders, }, { Type: "tcp", Address: "127.0.0.1:444", - CustomResponseHeaders: map[string]map[string]string{ - "default": DefaultCustomHeaders, - }, + CustomResponseHeaders: DefaultCustomHeaders, }, }, @@ -355,9 +346,7 @@ func testLoadConfigFileIntegerAndBooleanValuesCommon(t *testing.T, path string) { Type: "tcp", Address: "127.0.0.1:8200", - CustomResponseHeaders: map[string]map[string]string{ - "default": DefaultCustomHeaders, - }, + CustomResponseHeaders: DefaultCustomHeaders, }, }, DisableMlock: true, @@ -401,9 +390,7 @@ func testLoadConfigFile(t *testing.T) { { Type: "tcp", Address: "127.0.0.1:443", - CustomResponseHeaders: map[string]map[string]string{ - "default": DefaultCustomHeaders, - }, + CustomResponseHeaders: DefaultCustomHeaders, }, }, @@ -550,9 +537,7 @@ func testLoadConfigFile_json(t *testing.T) { { Type: "tcp", Address: "127.0.0.1:443", - CustomResponseHeaders: map[string]map[string]string{ - "default": DefaultCustomHeaders, - }, + CustomResponseHeaders: DefaultCustomHeaders, }, }, @@ -638,9 +623,7 @@ func testLoadConfigDir(t *testing.T) { { Type: "tcp", Address: "127.0.0.1:443", - CustomResponseHeaders: map[string]map[string]string{ - "default": DefaultCustomHeaders, - }, + CustomResponseHeaders: DefaultCustomHeaders, }, }, @@ -849,9 +832,7 @@ listener "tcp" { Profiling: configutil.ListenerProfiling{ UnauthenticatedPProfAccess: true, }, - CustomResponseHeaders: map[string]map[string]string{ - "default": DefaultCustomHeaders, - }, + CustomResponseHeaders: DefaultCustomHeaders, }, }, }, @@ -879,9 +860,7 @@ func testParseSeals(t *testing.T) { { Type: "tcp", Address: "127.0.0.1:443", - CustomResponseHeaders: map[string]map[string]string{ - "default": DefaultCustomHeaders, - }, + CustomResponseHeaders: DefaultCustomHeaders, }, }, Seals: []*configutil.KMS{ @@ -935,9 +914,7 @@ func testLoadConfigFileLeaseMetrics(t *testing.T) { { Type: "tcp", Address: "127.0.0.1:443", - CustomResponseHeaders: map[string]map[string]string{ - "default": DefaultCustomHeaders, - }, + CustomResponseHeaders: DefaultCustomHeaders, }, }, diff --git a/command/server/test-fixtures/config_custom_response_headers_multiple_listeners.hcl b/command/server/test-fixtures/config_custom_response_headers_multiple_listeners.hcl index 15cd7b49c89b6..11aa099232f91 100644 --- a/command/server/test-fixtures/config_custom_response_headers_multiple_listeners.hcl +++ b/command/server/test-fixtures/config_custom_response_headers_multiple_listeners.hcl @@ -4,7 +4,6 @@ listener "tcp" { tls_disable = true custom_response_headers { "default" = { - "Strict-Transport-Security" = ["max-age=1","domains"], "Content-Security-Policy" = ["default-src 'others'"], "X-Vault-Ignored" = ["ignored"], "X-Custom-Header" = ["Custom header value default"], @@ -33,11 +32,25 @@ listener "tcp" { tls_disable = true custom_response_headers { "default" = { - "Strict-Transport-Security" = ["max-age=1","domains"], "Content-Security-Policy" = ["default-src 'others'"], "X-Vault-Ignored" = ["ignored"], "X-Custom-Header" = ["Custom header value default"], } } } +listener "tcp" { + address = "127.0.0.3:8200" + tls_disable = true + custom_response_headers { + "2xx" = { + "X-Custom-Header" = ["Custom header value 2xx"] + } + } +} +listener "tcp" { + address = "127.0.0.4:8200" + tls_disable = true +} + + disable_mlock = true diff --git a/internalshared/configutil/http_response_headers.go b/internalshared/configutil/http_response_headers.go index 4008107b28ad9..f0dc99c963f2d 100644 --- a/internalshared/configutil/http_response_headers.go +++ b/internalshared/configutil/http_response_headers.go @@ -7,13 +7,6 @@ import ( "strings" ) -var DefaultHeaderNames = []string{ - "Content-Security-Policy", - "X-XSS-Protection", - "X-Frame-Options", - "Strict-Transport-Security", -} - var ValidCustomStatusCodeCollection = []string{ "default", "1xx", @@ -23,96 +16,63 @@ var ValidCustomStatusCodeCollection = []string{ "5xx", } -const ( - contentSecurityPolicy = "default-src 'none'; connect-src 'self'; img-src 'self' data:; script-src 'self'; style-src 'unsafe-inline' 'self'; form-action 'none'; frame-ancestors 'none'; font-src 'self'" - xXssProtection = "1; mode=block" - xFrameOptions = "Deny" - strictTransportSecurity = "max-age=31536000; includeSubDomains" -) +const StrictTransportSecurity = "max-age=31536000; includeSubDomains" -func GetDefaultHeaderValue(h string) string { - switch h { - case "Content-Security-Policy": - return contentSecurityPolicy - case "X-XSS-Protection": - return xXssProtection - case "X-Frame-Options": - return xFrameOptions - case "Strict-Transport-Security": - return strictTransportSecurity - default: - return "" - } -} - -func setDefaultResponseHeaders(c map[string]string) map[string]string { - defaults := make(map[string]string) - // adding all parsed default headers - for k, v := range c { - defaults[k] = v - } - - // setting all default headers that are not included in the config - // file under the "default" category - for _, hn := range DefaultHeaderNames { - if _, ok := c[hn]; ok { - continue - } - hv := GetDefaultHeaderValue(hn) - if hv != "" { - defaults[hn] = hv - } - } - - return defaults -} - -func ParseCustomResponseHeaders(r interface{}) (map[string]map[string]string, error) { +// ParseCustomResponseHeaders takes a raw config values for the +// "custom_response_headers". It makes sure the config entry is passed in +// as a map of status code to a map of header name and header values. It +// verifies the validity of the status codes, and header values. It also +// adds the default headers values. +func ParseCustomResponseHeaders(responseHeaders interface{}) (map[string]map[string]string, error) { h := make(map[string]map[string]string) // if r is nil, we still should set the default custom headers - if r == nil { - de := h["default"] - h["default"] = setDefaultResponseHeaders(de) + if responseHeaders == nil { + h["default"] = map[string]string{"Strict-Transport-Security": StrictTransportSecurity} return h, nil } - if _, ok := r.([]map[string]interface{}); !ok { - return nil, fmt.Errorf("response headers were not configured correctly. please make sure they're in a map") + if _, ok := responseHeaders.([]map[string]interface{}); !ok { + return nil, fmt.Errorf("response headers were not configured correctly. please make sure they're in a slice of maps") } - customResponseHeader := r.([]map[string]interface{}) + customResponseHeader := responseHeaders.([]map[string]interface{}) for _, crh := range customResponseHeader { for statusCode, responseHeader := range crh { if _, ok := responseHeader.([]map[string]interface{}); !ok { - return nil, fmt.Errorf("response headers were not configured correctly. please make sure they're in a map") + return nil, fmt.Errorf("response headers were not configured correctly. please make sure they're in a slice of maps") } if !IsValidStatusCode(statusCode) { return nil, fmt.Errorf("invalid status code found in the server configuration: %v", statusCode) } - hvl := responseHeader.([]map[string]interface{}) - if len(hvl) != 1 { + headerValList := responseHeader.([]map[string]interface{}) + if len(headerValList) != 1 { return nil, fmt.Errorf("invalid number of response headers exist") } - hvm := hvl[0] - hv, err := parseHeaders(hvm) + headerValMap := headerValList[0] + headerVal, err := parseHeaders(headerValMap) if err != nil { return nil, err } - h[statusCode] = hv + h[statusCode] = headerVal } } - // setting default custom headers - de := h["default"] - h["default"] = setDefaultResponseHeaders(de) + // setting Strict-Transport-Security as a default header + if h["default"] == nil { + h["default"] = make(map[string]string) + } + if _, ok := h["default"]["Strict-Transport-Security"]; !ok { + h["default"]["Strict-Transport-Security"] = StrictTransportSecurity + } return h, nil } +// IsValidStatusCodeCollection checks if the given status code is as expected func IsValidStatusCodeCollection(sc string) bool { for _, v := range ValidCustomStatusCodeCollection { if sc == v { @@ -145,24 +105,24 @@ func parseHeaders(in map[string]interface{}) (map[string]string, error) { hvMap := make(map[string]string) for k, v := range in { // parsing header name - hn := textproto.CanonicalMIMEHeaderKey(k) + headerName := textproto.CanonicalMIMEHeaderKey(k) // parsing header values s, err := parseHeaderValues(v) if err != nil { return nil, err } - hvMap[hn] = s + hvMap[headerName] = s } return hvMap, nil } -func parseHeaderValues(h interface{}) (string, error) { +func parseHeaderValues(header interface{}) (string, error) { var sl []string - if _, ok := h.([]interface{}); !ok { + if _, ok := header.([]interface{}); !ok { return "", fmt.Errorf("headers must be given in a list of strings") } - vli := h.([]interface{}) - for _, vh := range vli { + headerValList := header.([]interface{}) + for _, vh := range headerValList { if _, ok := vh.(string); !ok { return "", fmt.Errorf("found a non-string header value: %v", vh) } diff --git a/vault/core.go b/vault/core.go index 0f13c6b971558..98dfb1c1924bb 100644 --- a/vault/core.go +++ b/vault/core.go @@ -2653,8 +2653,8 @@ func (c *Core) getCustomHeadersListenerList(listenerAdd string) []*ListenerCusto return nil } - customHeadersList := customHeaders.([]*ListenerCustomHeaders) - if customHeadersList == nil { + customHeadersList, ok := customHeaders.([]*ListenerCustomHeaders) + if customHeadersList == nil || !ok { return nil } diff --git a/vault/custom_response_headers.go b/vault/custom_response_headers.go index 382b2783c4b69..4a6608838fcb6 100644 --- a/vault/custom_response_headers.go +++ b/vault/custom_response_headers.go @@ -24,71 +24,71 @@ type CustomHeader struct { } func NewListenerCustomHeader(ln []*configutil.Listener, logger log.Logger, uiHeaders http.Header) []*ListenerCustomHeaders { - var ll []*ListenerCustomHeaders + var listenerCustomHeadersList []*ListenerCustomHeaders for _, l := range ln { - lc := &ListenerCustomHeaders{ + listenerCustomHeaderStruct := &ListenerCustomHeaders{ Address: l.Address, } - lc.StatusCodeHeaderMap = make(map[string][]*CustomHeader) - lc.configuredHeadersStatusCodeMap = make(map[string][]string) - for sc, hv := range l.CustomResponseHeaders { - var chl []*CustomHeader - for h, v := range hv { + listenerCustomHeaderStruct.StatusCodeHeaderMap = make(map[string][]*CustomHeader) + listenerCustomHeaderStruct.configuredHeadersStatusCodeMap = make(map[string][]string) + for statusCode, headerValMap := range l.CustomResponseHeaders { + var customHeaderList []*CustomHeader + for headerName, headerVal := range headerValMap { // Sanitizing custom headers // X-Vault- prefix is reserved for Vault internal processes - if strings.HasPrefix(h, "X-Vault-") { - logger.Warn("custom headers starting with X-Vault are not valid", "header", h) + if strings.HasPrefix(headerName, "X-Vault-") { + logger.Warn("custom headers starting with X-Vault are not valid", "header", headerName) continue } // Checking for UI headers, if any common header exists, we just log an error if uiHeaders != nil { - exist := uiHeaders.Get(h) + exist := uiHeaders.Get(headerName) if exist != "" { - logger.Warn("found a duplicate header in UI", "header:", h, "Headers defined in the server configuration take precedence.") + logger.Warn("found a duplicate header in UI", "header:", headerName, "Headers defined in the server configuration take precedence.") } } // Checking if the header value is not an empty string - if v == "" { - logger.Warn("header value is an empty string", "header", h, "value", v) + if headerVal == "" { + logger.Warn("header value is an empty string", "header", headerName, "value", headerVal) continue } ch := &CustomHeader{ - Name: h, - Value: v, + Name: headerName, + Value: headerVal, } - chl = append(chl, ch) + customHeaderList = append(customHeaderList, ch) // setting up the reverse map of header to status code for easy lookups - lc.configuredHeadersStatusCodeMap[h] = append(lc.configuredHeadersStatusCodeMap[h], sc) + listenerCustomHeaderStruct.configuredHeadersStatusCodeMap[headerName] = append(listenerCustomHeaderStruct.configuredHeadersStatusCodeMap[headerName], statusCode) } - lc.StatusCodeHeaderMap[sc] = chl + listenerCustomHeaderStruct.StatusCodeHeaderMap[statusCode] = customHeaderList } - ll = append(ll, lc) + listenerCustomHeadersList = append(listenerCustomHeadersList, listenerCustomHeaderStruct) } - return ll + return listenerCustomHeadersList } -func (l *ListenerCustomHeaders) findCustomHeaderMatchStatusCode(sc string, hn string) string { +func (l *ListenerCustomHeaders) findCustomHeaderMatchStatusCode(statusCode string, headerName string) string { getHeader := func(ch []*CustomHeader) string { for _, h := range ch { - if h.Name == hn { + if h.Name == headerName { return h.Value } } return "" } - hm := l.StatusCodeHeaderMap + headerMap := l.StatusCodeHeaderMap // starting with the most specific status code - if ch, ok := hm[sc]; ok { - h := getHeader(ch) + if customHeaderList, ok := headerMap[statusCode]; ok { + h := getHeader(customHeaderList) if h != "" { return h } @@ -96,14 +96,14 @@ func (l *ListenerCustomHeaders) findCustomHeaderMatchStatusCode(sc string, hn st // Checking for the Yxx pattern var firstDig string - if len(sc) == 3 { - firstDig = string(sc[0]) + if len(statusCode) == 3 { + firstDig = string(statusCode[0]) } if firstDig != "" { s := fmt.Sprintf("%vxx", firstDig) if configutil.IsValidStatusCodeCollection(s) { - if ch, ok := hm[s]; ok { - h := getHeader(ch) + if customHeaderList, ok := headerMap[s]; ok { + h := getHeader(customHeaderList) if h != "" { return h } @@ -113,7 +113,7 @@ func (l *ListenerCustomHeaders) findCustomHeaderMatchStatusCode(sc string, hn st // At this point, we could not find a match for the given status code in the config file // so, we just return the "default" ones - h := getHeader(hm["default"]) + h := getHeader(headerMap["default"]) if h != "" { return h } @@ -134,9 +134,9 @@ func (l *ListenerCustomHeaders) FetchHeaderForStatusCode(header, sc string) (str return "", fmt.Errorf("failed to check if a header exist in config file due to invalid status code") } - hn := textproto.CanonicalMIMEHeaderKey(header) + headerName := textproto.CanonicalMIMEHeaderKey(header) - h := l.findCustomHeaderMatchStatusCode(sc, hn) + h := l.findCustomHeaderMatchStatusCode(sc, headerName) return h, nil } @@ -150,9 +150,9 @@ func (l *ListenerCustomHeaders) ExistCustomResponseHeader(header string) bool { return false } - hn := textproto.CanonicalMIMEHeaderKey(header) + headerName := textproto.CanonicalMIMEHeaderKey(header) - hs := l.configuredHeadersStatusCodeMap - _, ok := hs[hn] + headerMap := l.configuredHeadersStatusCodeMap + _, ok := headerMap[headerName] return ok } diff --git a/vault/ui.go b/vault/ui.go index 277f62cc8a918..bd1d3c6882147 100644 --- a/vault/ui.go +++ b/vault/ui.go @@ -34,6 +34,7 @@ func NewUIConfig(enabled bool, physicalStorage physical.Backend, barrierStorage defaultHeaders := http.Header{} defaultHeaders.Set("Service-Worker-Allowed", "/") defaultHeaders.Set("X-Content-Type-Options", "nosniff") + defaultHeaders.Set("Content-Security-Policy", "default-src 'none'; connect-src 'self'; img-src 'self' data:; script-src 'self'; style-src 'unsafe-inline' 'self'; form-action 'none'; frame-ancestors 'none'; font-src 'self'") return &UIConfig{ physicalStorage: physicalStorage, From b6eedd1b0ded274b60f5bd9571c4d148a5fd4710 Mon Sep 17 00:00:00 2001 From: hamid ghaf Date: Wed, 6 Oct 2021 18:29:22 -0700 Subject: [PATCH 23/25] some refacotring --- http/handler.go | 10 +++++----- vault/core.go | 5 ++--- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/http/handler.go b/http/handler.go index 6c17cd7b0ae3e..b1eb6224cf436 100644 --- a/http/handler.go +++ b/http/handler.go @@ -265,6 +265,11 @@ func (w statusHeaderResponseWriter) setCustomResponseHeaders(status int) { return } + // Checking the validity of the status code + if status >= 600 || status < 100 { + return + } + // setter function to set the headers setter := func(hvl []*vault.CustomHeader) { for _, hv := range hvl { @@ -272,11 +277,6 @@ func (w statusHeaderResponseWriter) setCustomResponseHeaders(status int) { } } - // Checking the validity of the status code - if status >= 600 || status < 100 { - return - } - // Setting the default headers first setter(sch["default"]) diff --git a/vault/core.go b/vault/core.go index 2d4525774f35c..a55f4f76d1b96 100644 --- a/vault/core.go +++ b/vault/core.go @@ -2708,15 +2708,14 @@ func (c *Core) ExistCustomResponseHeader(header string, listenerAdd string) bool return false } - exist := false for _, l := range lch { - exist = l.ExistCustomResponseHeader(header) + exist := l.ExistCustomResponseHeader(header) if exist { return true } } - return exist + return false } func (c *Core) ReloadCustomResponseHeaders() error { From 08d01574b220e2be0acfa46bd9e59dd96fb7dda7 Mon Sep 17 00:00:00 2001 From: hamid ghaf Date: Fri, 8 Oct 2021 13:09:21 -0700 Subject: [PATCH 24/25] Refactoring and addressing comments --- http/custom_header_test.go | 59 ---------------- http/handler.go | 21 ++---- .../configutil/http_response_headers.go | 19 ++---- vault/core.go | 51 ++++---------- vault/custom_response_headers.go | 68 ------------------- vault/logical_system.go | 2 +- 6 files changed, 30 insertions(+), 190 deletions(-) diff --git a/http/custom_header_test.go b/http/custom_header_test.go index 6ed977c6f5015..5125050ad570d 100644 --- a/http/custom_header_test.go +++ b/http/custom_header_test.go @@ -62,39 +62,12 @@ func TestCustomResponseHeaders(t *testing.T) { testResponseHeader(t, resp, defaultCustomHeaders) testResponseHeader(t, resp, customHeader4xx) - resp = testHttpGet(t, token, addr+"/v1/sys/generate-recovery-token/attempt") - testResponseStatus(t, resp, 404) - testResponseHeader(t, resp, defaultCustomHeaders) - testResponseHeader(t, resp, customHeader4xx) - - resp = testHttpGet(t, token, addr+"/v1/sys/generate-recovery-token/update") - testResponseStatus(t, resp, 404) - testResponseHeader(t, resp, defaultCustomHeaders) - testResponseHeader(t, resp, customHeader4xx) - - resp = testHttpGet(t, token, addr+"/v1/sys/config/state/") - testResponseStatus(t, resp, 404) - testResponseHeader(t, resp, defaultCustomHeaders) - testResponseHeader(t, resp, customHeader4xx) - resp = testHttpGet(t, token, addr+"/v1/sys/seal") testResponseStatus(t, resp, 405) testResponseHeader(t, resp, defaultCustomHeaders) testResponseHeader(t, resp, customHeader4xx) testResponseHeader(t, resp, customHeader405) - resp = testHttpGet(t, token, addr+"/v1/sys/step-down") - testResponseStatus(t, resp, 405) - testResponseHeader(t, resp, defaultCustomHeaders) - testResponseHeader(t, resp, customHeader4xx) - testResponseHeader(t, resp, customHeader405) - - resp = testHttpGet(t, token, addr+"/v1/sys/unseal") - testResponseStatus(t, resp, 405) - testResponseHeader(t, resp, defaultCustomHeaders) - testResponseHeader(t, resp, customHeader4xx) - testResponseHeader(t, resp, customHeader405) - resp = testHttpGet(t, token, addr+"/v1/sys/leader") testResponseStatus(t, resp, 200) testResponseHeader(t, resp, customHeader200) @@ -113,22 +86,6 @@ func TestCustomResponseHeaders(t *testing.T) { testResponseHeader(t, resp, customHeader4xx) testResponseHeader(t, resp, customHeader400) - resp = testHttpGet(t, token, addr+"/v1/sys/rekey/init") - testResponseStatus(t, resp, 200) - testResponseHeader(t, resp, customHeader200) - - resp = testHttpGet(t, token, addr+"/v1/sys/rekey/update") - testResponseStatus(t, resp, 400) - testResponseHeader(t, resp, defaultCustomHeaders) - testResponseHeader(t, resp, customHeader4xx) - testResponseHeader(t, resp, customHeader400) - - resp = testHttpGet(t, token, addr+"/v1/sys/rekey/verify") - testResponseStatus(t, resp, 400) - testResponseHeader(t, resp, defaultCustomHeaders) - testResponseHeader(t, resp, customHeader4xx) - testResponseHeader(t, resp, customHeader400) - resp = testHttpGet(t, token, addr+"/v1/sys/") testResponseStatus(t, resp, 404) testResponseHeader(t, resp, defaultCustomHeaders) @@ -153,22 +110,6 @@ func TestCustomResponseHeaders(t *testing.T) { testResponseStatus(t, resp, 200) testResponseHeader(t, resp, customHeader200) - resp = testHttpGet(t, token, addr+"/v1/sys/host-info") - testResponseStatus(t, resp, 200) - testResponseHeader(t, resp, customHeader200) - - resp = testHttpGet(t, token, addr+"/v1/sys/init") - testResponseStatus(t, resp, 200) - testResponseHeader(t, resp, customHeader200) - - resp = testHttpGet(t, token, addr+"/v1/sys/seal-status") - testResponseStatus(t, resp, 200) - testResponseHeader(t, resp, customHeader200) - - resp = testHttpGet(t, token, addr+"/v1/sys/auth") - testResponseStatus(t, resp, 200) - testResponseHeader(t, resp, customHeader200) - resp = testHttpGet(t, token, addr+"/ui") testResponseStatus(t, resp, 200) testResponseHeader(t, resp, customHeader200) diff --git a/http/handler.go b/http/handler.go index b1eb6224cf436..22aab6ccbd8b4 100644 --- a/http/handler.go +++ b/http/handler.go @@ -225,15 +225,15 @@ type statusHeaderResponseWriter struct { headers map[string][]*vault.CustomHeader } -func (w statusHeaderResponseWriter) Wrapped() http.ResponseWriter { +func (w *statusHeaderResponseWriter) Wrapped() http.ResponseWriter { return w.wrapped } -func (w statusHeaderResponseWriter) Header() http.Header { +func (w *statusHeaderResponseWriter) Header() http.Header { return w.wrapped.Header() } -func (w statusHeaderResponseWriter) Write(buf []byte) (int, error) { +func (w *statusHeaderResponseWriter) Write(buf []byte) (int, error) { // It is allowed to only call ResponseWriter.Write and skip // ResponseWriter.WriteHeader. An example of such a situation is // "handleUIStub". The Write function will internally set the status code @@ -249,7 +249,7 @@ func (w statusHeaderResponseWriter) Write(buf []byte) (int, error) { return w.wrapped.Write(buf) } -func (w statusHeaderResponseWriter) WriteHeader(statusCode int) { +func (w *statusHeaderResponseWriter) WriteHeader(statusCode int) { w.setCustomResponseHeaders(statusCode) w.wrapped.WriteHeader(statusCode) w.statusCode = statusCode @@ -258,7 +258,7 @@ func (w statusHeaderResponseWriter) WriteHeader(statusCode int) { w.wroteHeader = true } -func (w statusHeaderResponseWriter) setCustomResponseHeaders(status int) { +func (w *statusHeaderResponseWriter) setCustomResponseHeaders(status int) { sch := w.headers if sch == nil { w.logger.Warn("status code header map not configured") @@ -1231,17 +1231,10 @@ func respondErrorCommon(w http.ResponseWriter, req *logical.Request, resp *logic func respondOk(w http.ResponseWriter, body interface{}) { w.Header().Set("Content-Type", "application/json") - var status int if body == nil { - status = http.StatusNoContent + w.WriteHeader(http.StatusNoContent) } else { - status = http.StatusOK - } - - w.WriteHeader(status) - - if body != nil { - + w.WriteHeader(http.StatusOK) enc := json.NewEncoder(w) enc.Encode(body) } diff --git a/internalshared/configutil/http_response_headers.go b/internalshared/configutil/http_response_headers.go index f0dc99c963f2d..c6fc3ba92d8b0 100644 --- a/internalshared/configutil/http_response_headers.go +++ b/internalshared/configutil/http_response_headers.go @@ -5,6 +5,8 @@ import ( "net/textproto" "strconv" "strings" + + "github.com/hashicorp/go-secure-stdlib/strutil" ) var ValidCustomStatusCodeCollection = []string{ @@ -31,15 +33,15 @@ func ParseCustomResponseHeaders(responseHeaders interface{}) (map[string]map[str return h, nil } - if _, ok := responseHeaders.([]map[string]interface{}); !ok { + customResponseHeader, ok := responseHeaders.([]map[string]interface{}) + if !ok { return nil, fmt.Errorf("response headers were not configured correctly. please make sure they're in a slice of maps") } - customResponseHeader := responseHeaders.([]map[string]interface{}) - for _, crh := range customResponseHeader { for statusCode, responseHeader := range crh { - if _, ok := responseHeader.([]map[string]interface{}); !ok { + headerValList, ok := responseHeader.([]map[string]interface{}) + if !ok { return nil, fmt.Errorf("response headers were not configured correctly. please make sure they're in a slice of maps") } @@ -47,7 +49,6 @@ func ParseCustomResponseHeaders(responseHeaders interface{}) (map[string]map[str return nil, fmt.Errorf("invalid status code found in the server configuration: %v", statusCode) } - headerValList := responseHeader.([]map[string]interface{}) if len(headerValList) != 1 { return nil, fmt.Errorf("invalid number of response headers exist") } @@ -74,13 +75,7 @@ func ParseCustomResponseHeaders(responseHeaders interface{}) (map[string]map[str // IsValidStatusCodeCollection checks if the given status code is as expected func IsValidStatusCodeCollection(sc string) bool { - for _, v := range ValidCustomStatusCodeCollection { - if sc == v { - return true - } - } - - return false + return strutil.StrListContains(ValidCustomStatusCodeCollection, sc) } // IsValidStatusCode checking for status codes outside the boundary diff --git a/vault/core.go b/vault/core.go index a55f4f76d1b96..cb7c4b1b5c819 100644 --- a/vault/core.go +++ b/vault/core.go @@ -2655,7 +2655,8 @@ func (c *Core) SetConfig(conf *server.Config) { c.logger.Debug("set config", "sanitized config", string(bz)) } -func (c *Core) getCustomHeadersListenerList(listenerAdd string) []*ListenerCustomHeaders { +func (c *Core) GetListenerCustomResponseHeaders(listenerAdd string) *ListenerCustomHeaders { + customHeaders := c.customListenerHeader.Load() if customHeaders == nil { return nil @@ -2666,49 +2667,28 @@ func (c *Core) getCustomHeadersListenerList(listenerAdd string) []*ListenerCusto return nil } - // either looking for a specific listener, or if listener address isn't given, - // checking for all available listeners - var lch []*ListenerCustomHeaders - if listenerAdd == "" { - return customHeadersList - } else { - for _, l := range customHeadersList { - if l.Address == listenerAdd { - lch = append(lch, l) - return lch - } + for _, l := range customHeadersList { + if l.Address == listenerAdd { + return l } } - return nil } -func (c *Core) GetListenerCustomResponseHeaders(listenerAdd string) *ListenerCustomHeaders { - if listenerAdd == "" { - return nil - } - lch := c.getCustomHeadersListenerList(listenerAdd) - if lch == nil { - return nil - } - if len(lch) != 1 { - c.logger.Warn("multiple listeners with the same address configured") - return nil - } - - return lch[0] -} - // ExistCustomResponseHeader support checking for custom headers per listener address // If the address is an empty string, it checks all listeners for the header. -func (c *Core) ExistCustomResponseHeader(header string, listenerAdd string) bool { - lch := c.getCustomHeadersListenerList(listenerAdd) - if lch == nil { - c.logger.Warn("no listener config found", "address", listenerAdd) +func (c *Core) ExistCustomResponseHeader(header string) bool { + customHeaders := c.customListenerHeader.Load() + if customHeaders == nil { return false } - for _, l := range lch { + customHeadersList, ok := customHeaders.([]*ListenerCustomHeaders) + if customHeadersList == nil || !ok { + return false + } + + for _, l := range customHeadersList { exist := l.ExistCustomResponseHeader(header) if exist { return true @@ -2727,8 +2707,7 @@ func (c *Core) ReloadCustomResponseHeaders() error { if lns == nil { return fmt.Errorf("no listener configured") } - - c.customListenerHeader.Store(([]*ListenerCustomHeaders)(nil)) + //c.customListenerHeader.Store(([]*ListenerCustomHeaders)(nil)) uiHeaders, err := c.UIHeaders() if err != nil { diff --git a/vault/custom_response_headers.go b/vault/custom_response_headers.go index 4a6608838fcb6..54df089547fc4 100644 --- a/vault/custom_response_headers.go +++ b/vault/custom_response_headers.go @@ -1,7 +1,6 @@ package vault import ( - "fmt" "net/http" "net/textproto" "strings" @@ -74,73 +73,6 @@ func NewListenerCustomHeader(ln []*configutil.Listener, logger log.Logger, uiHea return listenerCustomHeadersList } -func (l *ListenerCustomHeaders) findCustomHeaderMatchStatusCode(statusCode string, headerName string) string { - getHeader := func(ch []*CustomHeader) string { - for _, h := range ch { - if h.Name == headerName { - return h.Value - } - } - return "" - } - - headerMap := l.StatusCodeHeaderMap - - // starting with the most specific status code - if customHeaderList, ok := headerMap[statusCode]; ok { - h := getHeader(customHeaderList) - if h != "" { - return h - } - } - - // Checking for the Yxx pattern - var firstDig string - if len(statusCode) == 3 { - firstDig = string(statusCode[0]) - } - if firstDig != "" { - s := fmt.Sprintf("%vxx", firstDig) - if configutil.IsValidStatusCodeCollection(s) { - if customHeaderList, ok := headerMap[s]; ok { - h := getHeader(customHeaderList) - if h != "" { - return h - } - } - } - } - - // At this point, we could not find a match for the given status code in the config file - // so, we just return the "default" ones - h := getHeader(headerMap["default"]) - if h != "" { - return h - } - - return "" -} - -func (l *ListenerCustomHeaders) FetchHeaderForStatusCode(header, sc string) (string, error) { - if header == "" { - return "", fmt.Errorf("invalid target header") - } - - if l.StatusCodeHeaderMap == nil { - return "", fmt.Errorf("custom headers not configured") - } - - if !configutil.IsValidStatusCode(sc) { - return "", fmt.Errorf("failed to check if a header exist in config file due to invalid status code") - } - - headerName := textproto.CanonicalMIMEHeaderKey(header) - - h := l.findCustomHeaderMatchStatusCode(sc, headerName) - - return h, nil -} - func (l *ListenerCustomHeaders) ExistCustomResponseHeader(header string) bool { if header == "" { return false diff --git a/vault/logical_system.go b/vault/logical_system.go index 34f9df785b9e7..1675475a422a3 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -2623,7 +2623,7 @@ func (b *SystemBackend) handleConfigUIHeadersUpdate(ctx context.Context, req *lo // Translate the list of values to the valid header string value := http.Header{} for _, v := range values { - if b.Core.ExistCustomResponseHeader(header, "") { + if b.Core.ExistCustomResponseHeader(header) { return logical.ErrorResponse("This header already exists in the server configuration and cannot be set in the UI."), logical.ErrInvalidRequest } value.Add(header, v) From 5734ab10f43ac91d0b1b1d1066f8160de3010879 Mon Sep 17 00:00:00 2001 From: hamid ghaf Date: Tue, 12 Oct 2021 10:30:07 -0700 Subject: [PATCH 25/25] removing a function and fixing comments --- internalshared/configutil/http_response_headers.go | 7 +------ vault/core.go | 5 ++--- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/internalshared/configutil/http_response_headers.go b/internalshared/configutil/http_response_headers.go index c6fc3ba92d8b0..2db3034e588b6 100644 --- a/internalshared/configutil/http_response_headers.go +++ b/internalshared/configutil/http_response_headers.go @@ -73,14 +73,9 @@ func ParseCustomResponseHeaders(responseHeaders interface{}) (map[string]map[str return h, nil } -// IsValidStatusCodeCollection checks if the given status code is as expected -func IsValidStatusCodeCollection(sc string) bool { - return strutil.StrListContains(ValidCustomStatusCodeCollection, sc) -} - // IsValidStatusCode checking for status codes outside the boundary func IsValidStatusCode(sc string) bool { - if IsValidStatusCodeCollection(sc) { + if strutil.StrListContains(ValidCustomStatusCodeCollection, sc) { return true } diff --git a/vault/core.go b/vault/core.go index cb7c4b1b5c819..e85d811d92d7b 100644 --- a/vault/core.go +++ b/vault/core.go @@ -2675,8 +2675,8 @@ func (c *Core) GetListenerCustomResponseHeaders(listenerAdd string) *ListenerCus return nil } -// ExistCustomResponseHeader support checking for custom headers per listener address -// If the address is an empty string, it checks all listeners for the header. +// ExistCustomResponseHeader checks if a custom header is configured in any +// listener's stanza func (c *Core) ExistCustomResponseHeader(header string) bool { customHeaders := c.customListenerHeader.Load() if customHeaders == nil { @@ -2707,7 +2707,6 @@ func (c *Core) ReloadCustomResponseHeaders() error { if lns == nil { return fmt.Errorf("no listener configured") } - //c.customListenerHeader.Store(([]*ListenerCustomHeaders)(nil)) uiHeaders, err := c.UIHeaders() if err != nil {