Skip to content

Commit 2ef25ea

Browse files
Aneurysm9MrAliasrghetia
authoredMar 16, 2020
Add filters for othttp plugin (#556)
* Add request filtering capability to othhtp.Handler * Add simple and useful filters for othttp plugin * Add note that all requests are traced in the absence of any filters * Add copyright notice to plugin/othttp/filters/filters_test.go Co-Authored-By: Tyler Yahn <MrAlias@users.noreply.github.com> * Add package docstring for filters package Co-authored-by: Tyler Yahn <MrAlias@users.noreply.github.com> Co-authored-by: Rahul Patel <rahulpa@google.com>
1 parent 217a97d commit 2ef25ea

File tree

4 files changed

+486
-0
lines changed

4 files changed

+486
-0
lines changed
 

‎plugin/othttp/filters/filters.go

+154
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
// Copyright 2020, OpenTelemetry Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
// Package filters provides a set of filters useful with the
16+
// othttp.WithFilter() option to control which inbound requests are traced.
17+
package filters
18+
19+
import (
20+
"net/http"
21+
"strings"
22+
23+
"go.opentelemetry.io/otel/plugin/othttp"
24+
)
25+
26+
// Any takes a list of Filters and returns a Filter that
27+
// returns true if any Filter in the list returns true.
28+
func Any(fs ...othttp.Filter) othttp.Filter {
29+
return func(r *http.Request) bool {
30+
for _, f := range fs {
31+
if f(r) {
32+
return true
33+
}
34+
}
35+
return false
36+
}
37+
}
38+
39+
// All takes a list of Filters and returns a Filter that
40+
// returns true only if all Filters in the list return true.
41+
func All(fs ...othttp.Filter) othttp.Filter {
42+
return func(r *http.Request) bool {
43+
for _, f := range fs {
44+
if !f(r) {
45+
return false
46+
}
47+
}
48+
return true
49+
}
50+
}
51+
52+
// None takes a list of Filters and returns a Filter that returns
53+
// true only if none of the Filters in the list return true.
54+
func None(fs ...othttp.Filter) othttp.Filter {
55+
return func(r *http.Request) bool {
56+
for _, f := range fs {
57+
if f(r) {
58+
return false
59+
}
60+
}
61+
return true
62+
}
63+
}
64+
65+
// Not provides a convenience mechanism for inverting a Filter
66+
func Not(f othttp.Filter) othttp.Filter {
67+
return func(r *http.Request) bool {
68+
return !f(r)
69+
}
70+
}
71+
72+
// Hostname returns a Filter that returns true if the request's
73+
// hostname matches the provided string.
74+
func Hostname(h string) othttp.Filter {
75+
return func(r *http.Request) bool {
76+
return r.URL.Hostname() == h
77+
}
78+
}
79+
80+
// Path returns a Filter that returns true if the request's
81+
// path matches the provided string.
82+
func Path(p string) othttp.Filter {
83+
return func(r *http.Request) bool {
84+
return r.URL.Path == p
85+
}
86+
}
87+
88+
// PathPrefix returns a Filter that returns true if the request's
89+
// path starts with the provided string.
90+
func PathPrefix(p string) othttp.Filter {
91+
return func(r *http.Request) bool {
92+
return strings.HasPrefix(r.URL.Path, p)
93+
}
94+
}
95+
96+
// Query returns a Filter that returns true if the request
97+
// includes a query parameter k with a value equal to v.
98+
func Query(k, v string) othttp.Filter {
99+
return func(r *http.Request) bool {
100+
for _, qv := range r.URL.Query()[k] {
101+
if v == qv {
102+
return true
103+
}
104+
}
105+
return false
106+
}
107+
}
108+
109+
// QueryContains returns a Filter that returns true if the request
110+
// includes a query parameter k with a value that contains v.
111+
func QueryContains(k, v string) othttp.Filter {
112+
return func(r *http.Request) bool {
113+
for _, qv := range r.URL.Query()[k] {
114+
if strings.Contains(qv, v) {
115+
return true
116+
}
117+
}
118+
return false
119+
}
120+
}
121+
122+
// Method returns a Filter that returns true if the request
123+
// method is equal to the provided value.
124+
func Method(m string) othttp.Filter {
125+
return func(r *http.Request) bool {
126+
return m == r.Method
127+
}
128+
}
129+
130+
// Header returns a Filter that returns true if the request
131+
// includes a header k with a value equal to v.
132+
func Header(k, v string) othttp.Filter {
133+
return func(r *http.Request) bool {
134+
for _, hv := range r.Header.Values(k) {
135+
if v == hv {
136+
return true
137+
}
138+
}
139+
return false
140+
}
141+
}
142+
143+
// HeaderContains returns a Filter that returns true if the request
144+
// includes a header k with a value that contains v.
145+
func HeaderContains(k, v string) othttp.Filter {
146+
return func(r *http.Request) bool {
147+
for _, hv := range r.Header.Values(k) {
148+
if strings.Contains(hv, v) {
149+
return true
150+
}
151+
}
152+
return false
153+
}
154+
}

‎plugin/othttp/filters/filters_test.go

+266
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
// Copyright 2020, OpenTelemetry Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package filters
16+
17+
import (
18+
"net/http"
19+
"net/url"
20+
"testing"
21+
22+
"go.opentelemetry.io/otel/plugin/othttp"
23+
)
24+
25+
type scenario struct {
26+
name string
27+
filter othttp.Filter
28+
req *http.Request
29+
exp bool
30+
}
31+
32+
func TestAny(t *testing.T) {
33+
for _, s := range []scenario{
34+
{
35+
name: "no matching filters",
36+
filter: Any(Path("/foo"), Hostname("bar.baz")),
37+
req: &http.Request{URL: &url.URL{Path: "/boo", Host: "baz.bar:8080"}},
38+
exp: false,
39+
},
40+
{
41+
name: "one matching filter",
42+
filter: Any(Path("/foo"), Hostname("bar.baz")),
43+
req: &http.Request{URL: &url.URL{Path: "/foo", Host: "baz.bar:8080"}},
44+
exp: true,
45+
},
46+
{
47+
name: "all matching filters",
48+
filter: Any(Path("/foo"), Hostname("bar.baz")),
49+
req: &http.Request{URL: &url.URL{Path: "/foo", Host: "bar.baz:8080"}},
50+
exp: true,
51+
},
52+
} {
53+
res := s.filter(s.req)
54+
if s.exp != res {
55+
t.Errorf("Failed testing %q. Expected %t, got %t", s.name, s.exp, res)
56+
}
57+
}
58+
}
59+
60+
func TestAll(t *testing.T) {
61+
for _, s := range []scenario{
62+
{
63+
name: "no matching filters",
64+
filter: All(Path("/foo"), Hostname("bar.baz")),
65+
req: &http.Request{URL: &url.URL{Path: "/boo", Host: "baz.bar:8080"}},
66+
exp: false,
67+
},
68+
{
69+
name: "one matching filter",
70+
filter: All(Path("/foo"), Hostname("bar.baz")),
71+
req: &http.Request{URL: &url.URL{Path: "/foo", Host: "baz.bar:8080"}},
72+
exp: false,
73+
},
74+
{
75+
name: "all matching filters",
76+
filter: All(Path("/foo"), Hostname("bar.baz")),
77+
req: &http.Request{URL: &url.URL{Path: "/foo", Host: "bar.baz:8080"}},
78+
exp: true,
79+
},
80+
} {
81+
res := s.filter(s.req)
82+
if s.exp != res {
83+
t.Errorf("Failed testing %q. Expected %t, got %t", s.name, s.exp, res)
84+
}
85+
}
86+
}
87+
88+
func TestNone(t *testing.T) {
89+
for _, s := range []scenario{
90+
{
91+
name: "no matching filters",
92+
filter: None(Path("/foo"), Hostname("bar.baz")),
93+
req: &http.Request{URL: &url.URL{Path: "/boo", Host: "baz.bar:8080"}},
94+
exp: true,
95+
},
96+
{
97+
name: "one matching filter",
98+
filter: None(Path("/foo"), Hostname("bar.baz")),
99+
req: &http.Request{URL: &url.URL{Path: "/foo", Host: "baz.bar:8080"}},
100+
exp: false,
101+
},
102+
{
103+
name: "all matching filters",
104+
filter: None(Path("/foo"), Hostname("bar.baz")),
105+
req: &http.Request{URL: &url.URL{Path: "/foo", Host: "bar.baz:8080"}},
106+
exp: false,
107+
},
108+
} {
109+
res := s.filter(s.req)
110+
if s.exp != res {
111+
t.Errorf("Failed testing %q. Expected %t, got %t", s.name, s.exp, res)
112+
}
113+
}
114+
}
115+
116+
func TestNot(t *testing.T) {
117+
req := &http.Request{URL: &url.URL{Path: "/foo", Host: "bar.baz:8080"}}
118+
filter := Path("/foo")
119+
if filter(req) == Not(filter)(req) {
120+
t.Error("Not filter should invert the result of the supplied filter")
121+
}
122+
}
123+
124+
func TestPathPrefix(t *testing.T) {
125+
for _, s := range []scenario{
126+
{
127+
name: "non-matching prefix",
128+
filter: PathPrefix("/foo"),
129+
req: &http.Request{URL: &url.URL{Path: "/boo/far", Host: "baz.bar:8080"}},
130+
exp: false,
131+
},
132+
{
133+
name: "matching prefix",
134+
filter: PathPrefix("/foo"),
135+
req: &http.Request{URL: &url.URL{Path: "/foo/bar", Host: "bar.baz:8080"}},
136+
exp: true,
137+
},
138+
} {
139+
res := s.filter(s.req)
140+
if s.exp != res {
141+
t.Errorf("Failed testing %q. Expected %t, got %t", s.name, s.exp, res)
142+
}
143+
}
144+
}
145+
146+
func TestMethod(t *testing.T) {
147+
for _, s := range []scenario{
148+
{
149+
name: "non-matching method",
150+
filter: Method(http.MethodGet),
151+
req: &http.Request{Method: http.MethodHead, URL: &url.URL{Path: "/boo/far", Host: "baz.bar:8080"}},
152+
exp: false,
153+
},
154+
{
155+
name: "matching method",
156+
filter: Method(http.MethodGet),
157+
req: &http.Request{Method: http.MethodGet, URL: &url.URL{Path: "/boo/far", Host: "baz.bar:8080"}},
158+
exp: true,
159+
},
160+
} {
161+
res := s.filter(s.req)
162+
if s.exp != res {
163+
t.Errorf("Failed testing %q. Expected %t, got %t", s.name, s.exp, res)
164+
}
165+
}
166+
}
167+
168+
func TestQuery(t *testing.T) {
169+
matching, _ := url.Parse("http://bar.baz:8080/foo/bar?key=value")
170+
nonMatching, _ := url.Parse("http://bar.baz:8080/foo/bar?key=other")
171+
for _, s := range []scenario{
172+
{
173+
name: "non-matching query parameter",
174+
filter: Query("key", "value"),
175+
req: &http.Request{Method: http.MethodHead, URL: nonMatching},
176+
exp: false,
177+
},
178+
{
179+
name: "matching query parameter",
180+
filter: Query("key", "value"),
181+
req: &http.Request{Method: http.MethodGet, URL: matching},
182+
exp: true,
183+
},
184+
} {
185+
res := s.filter(s.req)
186+
if s.exp != res {
187+
t.Errorf("Failed testing %q. Expected %t, got %t", s.name, s.exp, res)
188+
}
189+
}
190+
}
191+
192+
func TestQueryContains(t *testing.T) {
193+
matching, _ := url.Parse("http://bar.baz:8080/foo/bar?key=value")
194+
nonMatching, _ := url.Parse("http://bar.baz:8080/foo/bar?key=other")
195+
for _, s := range []scenario{
196+
{
197+
name: "non-matching query parameter",
198+
filter: QueryContains("key", "alu"),
199+
req: &http.Request{Method: http.MethodHead, URL: nonMatching},
200+
exp: false,
201+
},
202+
{
203+
name: "matching query parameter",
204+
filter: QueryContains("key", "alu"),
205+
req: &http.Request{Method: http.MethodGet, URL: matching},
206+
exp: true,
207+
},
208+
} {
209+
res := s.filter(s.req)
210+
if s.exp != res {
211+
t.Errorf("Failed testing %q. Expected %t, got %t", s.name, s.exp, res)
212+
}
213+
}
214+
}
215+
216+
func TestHeader(t *testing.T) {
217+
matching := http.Header{}
218+
matching.Add("key", "value")
219+
nonMatching := http.Header{}
220+
nonMatching.Add("key", "other")
221+
for _, s := range []scenario{
222+
{
223+
name: "non-matching query parameter",
224+
filter: Header("key", "value"),
225+
req: &http.Request{Method: http.MethodHead, Header: nonMatching},
226+
exp: false,
227+
},
228+
{
229+
name: "matching query parameter",
230+
filter: Header("key", "value"),
231+
req: &http.Request{Method: http.MethodGet, Header: matching},
232+
exp: true,
233+
},
234+
} {
235+
res := s.filter(s.req)
236+
if s.exp != res {
237+
t.Errorf("Failed testing %q. Expected %t, got %t", s.name, s.exp, res)
238+
}
239+
}
240+
}
241+
242+
func TestHeaderContains(t *testing.T) {
243+
matching := http.Header{}
244+
matching.Add("key", "value")
245+
nonMatching := http.Header{}
246+
nonMatching.Add("key", "other")
247+
for _, s := range []scenario{
248+
{
249+
name: "non-matching query parameter",
250+
filter: HeaderContains("key", "alu"),
251+
req: &http.Request{Method: http.MethodHead, Header: nonMatching},
252+
exp: false,
253+
},
254+
{
255+
name: "matching query parameter",
256+
filter: HeaderContains("key", "alu"),
257+
req: &http.Request{Method: http.MethodGet, Header: matching},
258+
exp: true,
259+
},
260+
} {
261+
res := s.filter(s.req)
262+
if s.exp != res {
263+
t.Errorf("Failed testing %q. Expected %t, got %t", s.name, s.exp, res)
264+
}
265+
}
266+
}

‎plugin/othttp/handler.go

+25
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ const (
4141
WriteErrorKey = core.Key("http.write_error") // if an error occurred while writing a reply, the string of the error (io.EOF is not recorded)
4242
)
4343

44+
// Filter is a predicate used to determine whether a given http.request should
45+
// be traced. A Filter must return true if the request should be traced.
46+
type Filter func(*http.Request) bool
47+
4448
// Handler is http middleware that corresponds to the http.Handler interface and
4549
// is designed to wrap a http.Mux (or equivalent), while individual routes on
4650
// the mux are wrapped with WithRouteTag. A Handler will add various attributes
@@ -54,6 +58,7 @@ type Handler struct {
5458
spanStartOptions []trace.StartOption
5559
readEvent bool
5660
writeEvent bool
61+
filters []Filter
5762
}
5863

5964
// Option function used for setting *optional* Handler properties
@@ -93,6 +98,18 @@ func WithSpanOptions(opts ...trace.StartOption) Option {
9398
}
9499
}
95100

101+
// WithFilter adds a filter to the list of filters used by the handler.
102+
// If any filter indicates to exclude a request then the request will not be
103+
// traced. All filters must allow a request to be traced for a Span to be created.
104+
// If no filters are provided then all requests are traced.
105+
// Filters will be invoked for each processed request, it is advised to make them
106+
// simple and fast.
107+
func WithFilter(f Filter) Option {
108+
return func(h *Handler) {
109+
h.filters = append(h.filters, f)
110+
}
111+
}
112+
96113
type event int
97114

98115
// Different types of events that can be recorded, see WithMessageEvents
@@ -141,6 +158,14 @@ func NewHandler(handler http.Handler, operation string, opts ...Option) http.Han
141158

142159
// ServeHTTP serves HTTP requests (http.Handler)
143160
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
161+
for _, f := range h.filters {
162+
if !f(r) {
163+
// Simply pass through to the handler if a filter rejects the request
164+
h.handler.ServeHTTP(w, r)
165+
return
166+
}
167+
}
168+
144169
opts := append([]trace.StartOption{}, h.spanStartOptions...) // start with the configured options
145170

146171
ctx := propagation.ExtractHTTP(r.Context(), h.props, r.Header)

‎plugin/othttp/handler_test.go

+41
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,44 @@ func TestBasics(t *testing.T) {
5959
t.Fatalf("got %q, expected %q", got, expected)
6060
}
6161
}
62+
63+
func TestBasicFilter(t *testing.T) {
64+
rr := httptest.NewRecorder()
65+
66+
var id uint64
67+
tracer := mocktrace.MockTracer{StartSpanID: &id}
68+
69+
h := NewHandler(
70+
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
71+
if _, err := io.WriteString(w, "hello world"); err != nil {
72+
t.Fatal(err)
73+
}
74+
}), "test_handler",
75+
WithTracer(&tracer),
76+
WithFilter(func(r *http.Request) bool {
77+
return false
78+
}),
79+
)
80+
81+
r, err := http.NewRequest(http.MethodGet, "http://localhost/", nil)
82+
if err != nil {
83+
t.Fatal(err)
84+
}
85+
h.ServeHTTP(rr, r)
86+
if got, expected := rr.Result().StatusCode, http.StatusOK; got != expected {
87+
t.Fatalf("got %d, expected %d", got, expected)
88+
}
89+
if got := rr.Header().Get("Traceparent"); got != "" {
90+
t.Fatal("expected empty trace header")
91+
}
92+
if got, expected := id, uint64(0); got != expected {
93+
t.Fatalf("got %d, expected %d", got, expected)
94+
}
95+
d, err := ioutil.ReadAll(rr.Result().Body)
96+
if err != nil {
97+
t.Fatal(err)
98+
}
99+
if got, expected := string(d), "hello world"; got != expected {
100+
t.Fatalf("got %q, expected %q", got, expected)
101+
}
102+
}

0 commit comments

Comments
 (0)
Please sign in to comment.