Skip to content

Commit

Permalink
Fix data-race in metric without code and method but with `WithLab…
Browse files Browse the repository at this point in the history
…elFromCtx`

This commit fixes a data race that exists when the metric used in any
`promhttp` middleware doesn't collect the `code` and `method` but uses
`WithLabelFromCtx` to collect values from context.

The problem happens because when no `code` and `method` tags are
collected, the `labels` function returns a pre-initialized map
`emptyLabels` for every request.

When one or multipe `WithLabelFromCtx` options are configured, the
returned map from the `labels` function call is used to collect the
metrics from context which creates a multi-write data race.

Signed-off-by: Tiago Silva <tiago.silva@goteleport.com>
  • Loading branch information
tigrato committed Jul 30, 2023
1 parent 7f2db5f commit f488fe2
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 12 deletions.
6 changes: 4 additions & 2 deletions prometheus/promhttp/instrument_client.go
Expand Up @@ -70,11 +70,12 @@ func InstrumentRoundTripperCounter(counter *prometheus.CounterVec, next http.Rou

// Curry the counter with dynamic labels before checking the remaining labels.
code, method := checkLabels(counter.MustCurryWith(rtOpts.emptyDynamicLabels()))
hasExtraLabels := len(rtOpts.extraLabelsFromCtx) > 0

return func(r *http.Request) (*http.Response, error) {
resp, err := next.RoundTrip(r)
if err == nil {
l := labels(code, method, r.Method, resp.StatusCode, rtOpts.extraMethods...)
l := labels(code, method, hasExtraLabels, r.Method, resp.StatusCode, rtOpts.extraMethods...)
for label, resolve := range rtOpts.extraLabelsFromCtx {
l[label] = resolve(resp.Request.Context())
}
Expand Down Expand Up @@ -113,12 +114,13 @@ func InstrumentRoundTripperDuration(obs prometheus.ObserverVec, next http.RoundT

// Curry the observer with dynamic labels before checking the remaining labels.
code, method := checkLabels(obs.MustCurryWith(rtOpts.emptyDynamicLabels()))
hasExtraLabels := len(rtOpts.extraLabelsFromCtx) > 0

return func(r *http.Request) (*http.Response, error) {
start := time.Now()
resp, err := next.RoundTrip(r)
if err == nil {
l := labels(code, method, r.Method, resp.StatusCode, rtOpts.extraMethods...)
l := labels(code, method, hasExtraLabels, r.Method, resp.StatusCode, rtOpts.extraMethods...)
for label, resolve := range rtOpts.extraLabelsFromCtx {
l[label] = resolve(resp.Request.Context())
}
Expand Down
26 changes: 17 additions & 9 deletions prometheus/promhttp/instrument_server.go
Expand Up @@ -89,14 +89,15 @@ func InstrumentHandlerDuration(obs prometheus.ObserverVec, next http.Handler, op

// Curry the observer with dynamic labels before checking the remaining labels.
code, method := checkLabels(obs.MustCurryWith(hOpts.emptyDynamicLabels()))
hasExtraLabels := len(hOpts.extraLabelsFromCtx) > 0

if code {
return func(w http.ResponseWriter, r *http.Request) {
now := time.Now()
d := newDelegator(w, nil)
next.ServeHTTP(d, r)

l := labels(code, method, r.Method, d.Status(), hOpts.extraMethods...)
l := labels(code, method, hasExtraLabels, r.Method, d.Status(), hOpts.extraMethods...)
for label, resolve := range hOpts.extraLabelsFromCtx {
l[label] = resolve(r.Context())
}
Expand All @@ -107,7 +108,7 @@ func InstrumentHandlerDuration(obs prometheus.ObserverVec, next http.Handler, op
return func(w http.ResponseWriter, r *http.Request) {
now := time.Now()
next.ServeHTTP(w, r)
l := labels(code, method, r.Method, 0, hOpts.extraMethods...)
l := labels(code, method, hasExtraLabels, r.Method, 0, hOpts.extraMethods...)
for label, resolve := range hOpts.extraLabelsFromCtx {
l[label] = resolve(r.Context())
}
Expand Down Expand Up @@ -140,13 +141,14 @@ func InstrumentHandlerCounter(counter *prometheus.CounterVec, next http.Handler,

// Curry the counter with dynamic labels before checking the remaining labels.
code, method := checkLabels(counter.MustCurryWith(hOpts.emptyDynamicLabels()))
hasExtraLabels := len(hOpts.extraLabelsFromCtx) > 0

if code {
return func(w http.ResponseWriter, r *http.Request) {
d := newDelegator(w, nil)
next.ServeHTTP(d, r)

l := labels(code, method, r.Method, d.Status(), hOpts.extraMethods...)
l := labels(code, method, hasExtraLabels, r.Method, d.Status(), hOpts.extraMethods...)
for label, resolve := range hOpts.extraLabelsFromCtx {
l[label] = resolve(r.Context())
}
Expand All @@ -157,7 +159,7 @@ func InstrumentHandlerCounter(counter *prometheus.CounterVec, next http.Handler,
return func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, r)

l := labels(code, method, r.Method, 0, hOpts.extraMethods...)
l := labels(code, method, hasExtraLabels, r.Method, 0, hOpts.extraMethods...)
for label, resolve := range hOpts.extraLabelsFromCtx {
l[label] = resolve(r.Context())
}
Expand Down Expand Up @@ -195,11 +197,12 @@ func InstrumentHandlerTimeToWriteHeader(obs prometheus.ObserverVec, next http.Ha

// Curry the observer with dynamic labels before checking the remaining labels.
code, method := checkLabels(obs.MustCurryWith(hOpts.emptyDynamicLabels()))
hasExtraLabels := len(hOpts.extraLabelsFromCtx) > 0

return func(w http.ResponseWriter, r *http.Request) {
now := time.Now()
d := newDelegator(w, func(status int) {
l := labels(code, method, r.Method, status, hOpts.extraMethods...)
l := labels(code, method, hasExtraLabels, r.Method, status, hOpts.extraMethods...)
for label, resolve := range hOpts.extraLabelsFromCtx {
l[label] = resolve(r.Context())
}
Expand Down Expand Up @@ -236,14 +239,15 @@ func InstrumentHandlerRequestSize(obs prometheus.ObserverVec, next http.Handler,

// Curry the observer with dynamic labels before checking the remaining labels.
code, method := checkLabels(obs.MustCurryWith(hOpts.emptyDynamicLabels()))
hasExtraLabels := len(hOpts.extraLabelsFromCtx) > 0

if code {
return func(w http.ResponseWriter, r *http.Request) {
d := newDelegator(w, nil)
next.ServeHTTP(d, r)
size := computeApproximateRequestSize(r)

l := labels(code, method, r.Method, d.Status(), hOpts.extraMethods...)
l := labels(code, method, hasExtraLabels, r.Method, d.Status(), hOpts.extraMethods...)
for label, resolve := range hOpts.extraLabelsFromCtx {
l[label] = resolve(r.Context())
}
Expand All @@ -255,7 +259,7 @@ func InstrumentHandlerRequestSize(obs prometheus.ObserverVec, next http.Handler,
next.ServeHTTP(w, r)
size := computeApproximateRequestSize(r)

l := labels(code, method, r.Method, 0, hOpts.extraMethods...)
l := labels(code, method, hasExtraLabels, r.Method, 0, hOpts.extraMethods...)
for label, resolve := range hOpts.extraLabelsFromCtx {
l[label] = resolve(r.Context())
}
Expand Down Expand Up @@ -290,12 +294,13 @@ func InstrumentHandlerResponseSize(obs prometheus.ObserverVec, next http.Handler

// Curry the observer with dynamic labels before checking the remaining labels.
code, method := checkLabels(obs.MustCurryWith(hOpts.emptyDynamicLabels()))
hasExtraLabels := len(hOpts.extraLabelsFromCtx) > 0

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
d := newDelegator(w, nil)
next.ServeHTTP(d, r)

l := labels(code, method, r.Method, d.Status(), hOpts.extraMethods...)
l := labels(code, method, hasExtraLabels, r.Method, d.Status(), hOpts.extraMethods...)
for label, resolve := range hOpts.extraLabelsFromCtx {
l[label] = resolve(r.Context())
}
Expand Down Expand Up @@ -393,8 +398,11 @@ func isLabelCurried(c prometheus.Collector, label string) bool {
// unnecessary allocations on each request.
var emptyLabels = prometheus.Labels{}

func labels(code, method bool, reqMethod string, status int, extraMethods ...string) prometheus.Labels {
func labels(code, method, hasExtraLabelsCtx bool, reqMethod string, status int, extraMethods ...string) prometheus.Labels {
if !(code || method) {
if hasExtraLabelsCtx {
return prometheus.Labels{}
}
return emptyLabels
}
labels := prometheus.Labels{}
Expand Down
2 changes: 1 addition & 1 deletion prometheus/promhttp/instrument_server_test.go
Expand Up @@ -346,7 +346,7 @@ func TestLabels(t *testing.T) {
t.Run(name, func(t *testing.T) {
if sc.ok {
gotCode, gotMethod := checkLabels(sc.varLabels)
gotLabels := labels(gotCode, gotMethod, sc.reqMethod, sc.respStatus, sc.extraMethods...)
gotLabels := labels(gotCode, gotMethod, false, sc.reqMethod, sc.respStatus, sc.extraMethods...)
if !equalLabels(gotLabels, sc.wantLabels) {
t.Errorf("wanted labels=%v for counter, got code=%v", sc.wantLabels, gotLabels)
}
Expand Down

0 comments on commit f488fe2

Please sign in to comment.