diff --git a/plugin/ochttp/propagation/tracecontext/propagation.go b/plugin/ochttp/propagation/tracecontext/propagation.go index 65ab1e996..48d815249 100644 --- a/plugin/ochttp/propagation/tracecontext/propagation.go +++ b/plugin/ochttp/propagation/tracecontext/propagation.go @@ -47,11 +47,17 @@ type HTTPFormat struct{} // SpanContextFromRequest extracts a span context from incoming requests. func (f *HTTPFormat) SpanContextFromRequest(req *http.Request) (sc trace.SpanContext, ok bool) { - h, ok := getRequestHeader(req, traceparentHeader, false) - if !ok { + tp, _ := getRequestHeader(req, traceparentHeader, false) + ts, _ := getRequestHeader(req, tracestateHeader, true) + return f.SpanContextFromHeaders(tp, ts) +} + +// SpanContextFromHeaders extracts a span context from provided header values. +func (f *HTTPFormat) SpanContextFromHeaders(tp string, ts string) (sc trace.SpanContext, ok bool) { + if tp == "" { return trace.SpanContext{}, false } - sections := strings.Split(h, "-") + sections := strings.Split(tp, "-") if len(sections) < 4 { return trace.SpanContext{}, false } @@ -101,7 +107,7 @@ func (f *HTTPFormat) SpanContextFromRequest(req *http.Request) (sc trace.SpanCon return trace.SpanContext{}, false } - sc.Tracestate = tracestateFromRequest(req) + sc.Tracestate = tracestateFromHeader(ts) return sc, true } @@ -128,14 +134,13 @@ func getRequestHeader(req *http.Request, name string, commaSeparated bool) (hdr // are resolved. // https://github.com/w3c/distributed-tracing/issues/172 // https://github.com/w3c/distributed-tracing/issues/175 -func tracestateFromRequest(req *http.Request) *tracestate.Tracestate { - h, _ := getRequestHeader(req, tracestateHeader, true) - if h == "" { +func tracestateFromHeader(ts string) *tracestate.Tracestate { + if ts == "" { return nil } var entries []tracestate.Entry - pairs := strings.Split(h, ",") + pairs := strings.Split(ts, ",") hdrLenWithoutOWS := len(pairs) - 1 // Number of commas for _, pair := range pairs { matches := trimOWSRegExp.FindStringSubmatch(pair) @@ -153,15 +158,15 @@ func tracestateFromRequest(req *http.Request) *tracestate.Tracestate { } entries = append(entries, tracestate.Entry{Key: kv[0], Value: kv[1]}) } - ts, err := tracestate.New(nil, entries...) + tsParsed, err := tracestate.New(nil, entries...) if err != nil { return nil } - return ts + return tsParsed } -func tracestateToRequest(sc trace.SpanContext, req *http.Request) { +func tracestateToHeader(sc trace.SpanContext) string { var pairs = make([]string, 0, len(sc.Tracestate.Entries())) if sc.Tracestate != nil { for _, entry := range sc.Tracestate.Entries() { @@ -170,18 +175,28 @@ func tracestateToRequest(sc trace.SpanContext, req *http.Request) { h := strings.Join(pairs, ",") if h != "" && len(h) <= maxTracestateLen { - req.Header.Set(tracestateHeader, h) + return h } } + return "" } -// SpanContextToRequest modifies the given request to include traceparent and tracestate headers. -func (f *HTTPFormat) SpanContextToRequest(sc trace.SpanContext, req *http.Request) { - h := fmt.Sprintf("%x-%x-%x-%x", +// SpanContextToHeaders serialize the SpanContext to traceparent and tracestate headers. +func (f *HTTPFormat) SpanContextToHeaders(sc trace.SpanContext) (tp string, ts string) { + tp = fmt.Sprintf("%x-%x-%x-%x", []byte{supportedVersion}, sc.TraceID[:], sc.SpanID[:], []byte{byte(sc.TraceOptions)}) - req.Header.Set(traceparentHeader, h) - tracestateToRequest(sc, req) + ts = tracestateToHeader(sc) + return +} + +// SpanContextToRequest modifies the given request to include traceparent and tracestate headers. +func (f *HTTPFormat) SpanContextToRequest(sc trace.SpanContext, req *http.Request) { + tp, ts := f.SpanContextToHeaders(sc) + req.Header.Set(traceparentHeader, tp) + if ts != "" { + req.Header.Set(tracestateHeader, ts) + } } diff --git a/plugin/ochttp/propagation/tracecontext/propagation_test.go b/plugin/ochttp/propagation/tracecontext/propagation_test.go index 996cfa883..5489ba9dc 100644 --- a/plugin/ochttp/propagation/tracecontext/propagation_test.go +++ b/plugin/ochttp/propagation/tracecontext/propagation_test.go @@ -100,6 +100,14 @@ func TestHTTPFormat_FromRequest(t *testing.T) { if gotOk != tt.wantOk { t.Errorf("HTTPFormat.FromRequest() gotOk = %v, want %v", gotOk, tt.wantOk) } + + gotSc, gotOk = f.SpanContextFromHeaders(tt.header, "") + if !reflect.DeepEqual(gotSc, tt.wantSc) { + t.Errorf("HTTPFormat.SpanContextFromHeaders() gotTs = %v, want %v", gotSc.Tracestate, tt.wantSc.Tracestate) + } + if gotOk != tt.wantOk { + t.Errorf("HTTPFormat.SpanContextFromHeaders() gotOk = %v, want %v", gotOk, tt.wantOk) + } }) } } @@ -128,6 +136,11 @@ func TestHTTPFormat_ToRequest(t *testing.T) { if got, want := h, tt.wantHeader; got != want { t.Errorf("HTTPFormat.ToRequest() header = %v, want %v", got, want) } + + gotTp, _ := f.SpanContextToHeaders(tt.sc) + if gotTp != tt.wantHeader { + t.Errorf("HTTPFormat.SpanContextToHeaders() tracestate header = %v, want %v", gotTp, tt.wantHeader) + } }) } } @@ -212,6 +225,14 @@ func TestHTTPFormatTracestate_FromRequest(t *testing.T) { if gotOk != tt.wantOk { t.Errorf("HTTPFormat.FromRequest() gotOk = %v, want %v", gotOk, tt.wantOk) } + + gotSc, gotOk = f.SpanContextFromHeaders(tt.tpHeader, tt.tsHeader) + if !reflect.DeepEqual(gotSc, tt.wantSc) { + t.Errorf("HTTPFormat.SpanContextFromHeaders() gotTs = %v, want %v", gotSc.Tracestate, tt.wantSc.Tracestate) + } + if gotOk != tt.wantOk { + t.Errorf("HTTPFormat.SpanContextFromHeaders() gotOk = %v, want %v", gotOk, tt.wantOk) + } }) } } @@ -262,6 +283,11 @@ func TestHTTPFormatTracestate_ToRequest(t *testing.T) { if got, want := h, tt.wantHeader; got != want { t.Errorf("HTTPFormat.ToRequest() tracestate header = %v, want %v", got, want) } + + _, gotTs := f.SpanContextToHeaders(tt.sc) + if gotTs != tt.wantHeader { + t.Errorf("HTTPFormat.SpanContextToHeaders() tracestate header = %v, want %v", gotTs, tt.wantHeader) + } }) } }