From 19163e1b8984e2577ea366d8aa97f0f75c6f6ac6 Mon Sep 17 00:00:00 2001 From: greatfilter <76988454+greatfilter@users.noreply.github.com> Date: Wed, 19 May 2021 16:18:20 -0400 Subject: [PATCH] Adds HTTP/2 and gRPC support to Martian Proxy (#318) * HTTP/2 and gRPC support for Martian. Changes are opt-in via h2.Config being set in mitm.Config so have no impact on existing uses. Closes issue #268 --- .travis.yml | 8 +- cmd/proxy/main.go | 6 +- go.mod | 8 +- go.sum | 91 ++++ h2/grpc/grpc.go | 342 ++++++++++++++ h2/h2.go | 156 +++++++ h2/h2_test.go | 449 ++++++++++++++++++ h2/processor.go | 141 ++++++ h2/queued_frames.go | 224 +++++++++ h2/relay.go | 618 +++++++++++++++++++++++++ h2/testing/certs.go | 134 ++++++ h2/testing/fixture.go | 204 ++++++++ h2/testing/test_service.go | 67 +++ h2/testservice/test_service.pb.go | 365 +++++++++++++++ h2/testservice/test_service.proto | 46 ++ h2/testservice/test_service_grpc.pb.go | 212 +++++++++ mitm/mitm.go | 24 +- proxy.go | 3 + 18 files changed, 3090 insertions(+), 8 deletions(-) create mode 100644 h2/grpc/grpc.go create mode 100644 h2/h2.go create mode 100644 h2/h2_test.go create mode 100644 h2/processor.go create mode 100644 h2/queued_frames.go create mode 100644 h2/relay.go create mode 100644 h2/testing/certs.go create mode 100644 h2/testing/fixture.go create mode 100644 h2/testing/test_service.go create mode 100644 h2/testservice/test_service.pb.go create mode 100644 h2/testservice/test_service.proto create mode 100644 h2/testservice/test_service_grpc.pb.go diff --git a/.travis.yml b/.travis.yml index f87da0b38..19f18364a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,12 +3,16 @@ sudo: false language: go go: - - 1.10.x - - 1.11.x + - 1.13.x install: + - go get -u golang.org/x/net/http2 + - go get -u golang.org/x/net/http2/hpack - go get -u golang.org/x/net/websocket - go get -u golang.org/x/lint/golint + - go get -u google.golang.org/protobuf/proto + - go get -u google.golang.org/grpc + - go get -u github.com/golang/snappy script: - golint ./... diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index 801b1bc8f..b6dcc4edb 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -253,6 +253,8 @@ var ( ) func main() { + martian.Init() + p := martian.NewProxy() defer p.Close() @@ -440,10 +442,6 @@ func main() { os.Exit(0) } -func init() { - martian.Init() -} - // configure installs a configuration handler at path. func configure(pattern string, handler http.Handler, mux *http.ServeMux) { if *allowCORS { diff --git a/go.mod b/go.mod index f9a54ed45..31b92adbe 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,10 @@ module github.com/google/martian/v3 go 1.11 -require golang.org/x/net v0.0.0-20190628185345-da137c7871d7 +require ( + github.com/golang/protobuf v1.5.2 + github.com/golang/snappy v0.0.3 + golang.org/x/net v0.0.0-20190628185345-da137c7871d7 + google.golang.org/grpc v1.37.0 + google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0 // indirect +) diff --git a/go.sum b/go.sum index 549168c4c..aa7f07d5b 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,96 @@ +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= +github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= +github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= +github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA= +github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/martian v2.1.0+incompatible h1:/CP5g8u/VJHijgedC/Legn3BAbAaWPgecwXBIDzw5no= +github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190628185345-da137c7871d7 h1:rTIdg5QFRR7XCaK4LCjBiPbx8j4DQRpdYMnGn/bJUEU= golang.org/x/net v0.0.0-20190628185345-da137c7871d7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210505214959-0714010a04ed h1:V9kAVxLvz1lkufatrpHuUVyJ/5tR3Ms7rk951P4mI98= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013 h1:+kGHl1aib/qcwaRi1CbqBZ1rk19r85MNUf8HaBghugY= +google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= +google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= +google.golang.org/grpc v1.37.0 h1:uSZWeQJX5j11bIQ4AJoj+McDBo29cY1MCoC1wO3ts+c= +google.golang.org/grpc v1.37.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM= +google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0 h1:M1YKkFIboKNieVO5DLUEVzQfGwJD30Nv2jfUgzb5UcE= +google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0/go.mod h1:6Kw0yEErY5E/yWrBtf03jp27GLLJujG4z/JK95pnjjw= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.25.0 h1:Ejskq+SyPohKW+1uil0JJMtmHCgJPJ/qWTxr8qp+R4c= +google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0 h1:bxAC2xTBsZGibn2RTntX0oH50xLsqy1OxA9tTL3p/lk= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/h2/grpc/grpc.go b/h2/grpc/grpc.go new file mode 100644 index 000000000..202137487 --- /dev/null +++ b/h2/grpc/grpc.go @@ -0,0 +1,342 @@ +// Copyright 2021 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package grpc contains gRPC functionality for Martian proxy. +package grpc + +import ( + "bytes" + "compress/flate" + "compress/gzip" + "encoding/binary" + "fmt" + "io/ioutil" + "net/url" + "sync/atomic" + + "github.com/golang/snappy" + "github.com/google/martian/v3/h2" + "golang.org/x/net/http2" + "golang.org/x/net/http2/hpack" +) + +// Encoding is the grpc-encoding type. See Content-Coding entry at: +// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests +type Encoding uint8 + +const ( + // Identity indicates that no compression is used. + Identity Encoding = iota + // Gzip indicates that Gzip compression is used. + Gzip + // Deflate indicates that Deflate compression is used. + Deflate + // Snappy indicates that Snappy compression is used. + Snappy +) + +// ProcessorFactory creates gRPC processors that implement the Processor interface, which abstracts +// away some of the details of the underlying HTTP/2 protocol. A processor must forward +// invocations to the given `server` or `client` processors, which will arrange to have the data +// forwarded to the destination, with possible edits. Nil values are safe to return and no +// processing occurs in such cases. NOTE: an interface may have a non-nil type with a nil value. +// Such values are treated as valid processors. +type ProcessorFactory func(url *url.URL, server, client Processor) (Processor, Processor) + +// AsStreamProcessorFactory converts a ProcessorFactory into a StreamProcessorFactory. It creates +// an adapter that abstracts HTTP/2 frames into a representation that is closer to gRPC. +func AsStreamProcessorFactory(f ProcessorFactory) h2.StreamProcessorFactory { + return func(url *url.URL, sinks *h2.Processors) (h2.Processor, h2.Processor) { + var cToS, sToC h2.Processor + + // A grpc.Processor is translated into an h2.Processor in layers. + // + // adapter → processor → emitter → sink + // \_____________________________↗ + // + // * The adapter wraps the grpc.Processor interface so that it conforms with h2.Processor. It + // performs some processing to translate HTTP/2 frames into gRPC concepts. Frames that are + // not relevant to gRPC are forwarded directly to the sink. + // * The processor is the gRPC processing logic provided by the client factory. + // * The emitter wraps an h2.Processor sink and translates the processed gRPC data into HTTP/2 + // frames. + cToSEmitter := &emitter{sink: sinks.ForDirection(h2.ClientToServer)} + sToCEmitter := &emitter{sink: sinks.ForDirection(h2.ServerToClient)} + cToSProcessor, sToCProcessor := f(url, cToSEmitter, sToCEmitter) + + // enabled indicates whether the stream should be processed as gRPC. It is shared between the + // the two adapters because its detection is on a client-to-server HEADER frame and the state + // applies bidirectionally. + enabled := int32(0) + if cToSProcessor != nil { + cToSEmitter.adapter = &adapter{ + enabled: &enabled, + dir: h2.ClientToServer, + processor: cToSProcessor, + sink: sinks.ForDirection(h2.ClientToServer), + } + cToS = cToSEmitter.adapter + } + if sToCProcessor != nil { + sToCEmitter.adapter = &adapter{ + enabled: &enabled, + dir: h2.ServerToClient, + processor: sToCProcessor, + sink: sinks.ForDirection(h2.ServerToClient), + } + sToC = sToCEmitter.adapter + } + return cToS, sToC + } +} + +// Processor processes gRPC traffic. +type Processor interface { + h2.HeaderProcessor + // Message receives serialized messages. + Message(data []byte, streamEnded bool) error +} + +// dataState represents one of two possible states when consuming gRPC DATA frames. +type dataState uint8 + +const ( + readingMetadata dataState = iota + readingMessageData +) + +// adapter wraps the Processor interface with an h2.Processor interface. It filters streams that +// are not gRPC and handles decompressing the message data. +type adapter struct { + enabled *int32 + + dir h2.Direction + + processor Processor + sink h2.Processor + + encoding Encoding + + // State for the data interpreter. + buffer bytes.Buffer + state dataState + compressed bool + length uint32 +} + +func (a *adapter) Header( + headers []hpack.HeaderField, + streamEnded bool, + priority http2.PriorityParam, +) error { + if !a.isEnabled() { + for _, h := range headers { + if h.Name == "content-type" && h.Value == "application/grpc" { + atomic.StoreInt32(a.enabled, 1) + break + } + } + if !a.isEnabled() { + return a.sink.Header(headers, streamEnded, priority) + } + } + + for _, h := range headers { + if h.Name == "grpc-encoding" { + switch h.Value { + case "identity": + a.encoding = Identity + case "gzip": + a.encoding = Gzip + case "deflate": + a.encoding = Deflate + case "snappy": + a.encoding = Snappy + default: + return fmt.Errorf("unrecognized grpc-encoding %s in %v", h.Value, headers) + } + } + } + return a.processor.Header(headers, streamEnded, priority) +} + +func (a *adapter) Data(data []byte, streamEnded bool) error { + if !a.isEnabled() { + return a.sink.Data(data, streamEnded) + } + + a.buffer.Write(data) + + for { + switch a.state { + case readingMetadata: + if streamEnded && a.buffer.Len() == 0 { + // gRPC may send empty DATA frames to end a stream. + if err := a.processor.Message(nil, true); err != nil { + return err + } + } + if a.buffer.Len() < 5 { + return nil + } + compressed, _ := a.buffer.ReadByte() + a.compressed = compressed > 0 + if err := binary.Read(&a.buffer, binary.BigEndian, &a.length); err != nil { + return fmt.Errorf("reading message length: %w", err) + } + a.state = readingMessageData + case readingMessageData: + if uint32(a.buffer.Len()) < a.length { + return nil + } + data := make([]byte, a.length) + a.buffer.Read(data) + + if a.compressed { + switch a.encoding { + case Identity: + case Gzip: + var err error + data, err = gunzip(data) + if err != nil { + return fmt.Errorf("gunzipping data: %w", err) + } + case Deflate: + var err error + data, err = deflate(data) + if err != nil { + return fmt.Errorf("deflating data: %w", err) + } + case Snappy: + var err error + data, err = ioutil.ReadAll(snappy.NewReader(bytes.NewReader(data))) + if err != nil { + return fmt.Errorf("uncompressing snappy: %w", err) + } + default: + panic(fmt.Sprintf("unexpected enocding: %v", a.encoding)) + } + } + a.state = readingMetadata + + // Only marks stream ended for the message if there is no data remaining. For ease of + // implementation, this proxy aligns messages with data frames. This means that if a data + // frame with stream ended contains multiple messages, the earlier ones should not be + // marked with stream ended. + // + // As explained in https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#data-frames, + // this reframing is safe because gRPC implementations won't be making any assumptions about + // the framing. + if err := a.processor.Message(data, streamEnded && a.buffer.Len() == 0); err != nil { + return err + } + default: + panic(fmt.Sprintf("unexpected state: %v", a.state)) + } + if a.buffer.Len() == 0 { + return nil + } + } +} + +func (a *adapter) Priority(priority http2.PriorityParam) error { + return a.sink.Priority(priority) +} + +func (a *adapter) RSTStream(errCode http2.ErrCode) error { + return a.sink.RSTStream(errCode) +} + +func (a *adapter) PushPromise(promiseID uint32, headers []hpack.HeaderField) error { + return a.sink.PushPromise(promiseID, headers) +} + +func (a *adapter) isEnabled() bool { + return atomic.LoadInt32(a.enabled) > 0 +} + +// emitter is a Processor implementation that wraps a h2.Processor instance, forwarding traffic to +// it. It handles recompression of the data. +type emitter struct { + sink h2.Processor + // adapter is a reference to the adapter needed to retrieve state. + adapter *adapter +} + +func (e *emitter) Header( + headers []hpack.HeaderField, + streamEnded bool, + priority http2.PriorityParam, +) error { + return e.sink.Header(headers, streamEnded, priority) +} + +func (e *emitter) Message(data []byte, streamEnded bool) error { + // Applies compression to `data` depending on `adapter`'s state. + if e.adapter.compressed { + switch e.adapter.encoding { + case Identity: + case Gzip: + var buf bytes.Buffer + w := gzip.NewWriter(&buf) + if _, err := w.Write(data); err != nil { + return fmt.Errorf("gzipping message data: %w", err) + } + if err := w.Close(); err != nil { + return fmt.Errorf("gzipping message data: %w", err) + } + data = buf.Bytes() + case Deflate: + var buf bytes.Buffer + w, _ := flate.NewWriter(&buf, -1) + if _, err := w.Write(data); err != nil { + return fmt.Errorf("flate compressing message data: %w", err) + } + if err := w.Close(); err != nil { + return fmt.Errorf("flate compressing message data: %w", err) + } + data = buf.Bytes() + case Snappy: + data = snappy.Encode(nil, data) + } + } + var buf bytes.Buffer + // Writes the compression status. + if e.adapter.compressed { + buf.WriteByte(1) + } else { + buf.WriteByte(0) + } + binary.Write(&buf, binary.BigEndian, uint32(len(data))) // Writes the length of the data. + buf.Write(data) // Writes the actual data. + return e.sink.Data(buf.Bytes(), streamEnded) +} + +func gunzip(data []byte) ([]byte, error) { + r, err := gzip.NewReader(bytes.NewReader(data)) + if err != nil { + return nil, err + } + return ioutil.ReadAll(r) +} + +func deflate(data []byte) (_ []byte, rerr error) { + r := flate.NewReader(bytes.NewReader(data)) + defer func() { + if err := r.Close(); err != nil && rerr != nil { + rerr = err + } + }() + return ioutil.ReadAll(r) +} diff --git a/h2/h2.go b/h2/h2.go new file mode 100644 index 000000000..2df63c611 --- /dev/null +++ b/h2/h2.go @@ -0,0 +1,156 @@ +// Copyright 2021 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package h2 contains basic HTTP/2 handling for Martian. +package h2 + +import ( + "bytes" + "crypto/tls" + "crypto/x509" + "encoding/hex" + "fmt" + "io" + "net/url" + "sync" + + "github.com/google/martian/v3/log" + "golang.org/x/net/http2" +) + +var ( + // connectionPreface is the constant value of the connection preface. + // https://tools.ietf.org/html/rfc7540#section-3.5 + connectionPreface = []byte("PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n") +) + +// Config stores the configuration information needed for HTTP/2 processing. +type Config struct { + // AllowedHostsFilter is a function returning true if the argument is a host for which H2 is + // permitted. + AllowedHostsFilter func(string) bool + + // RootCAs is the pool of CA certificates used by the MitM client to authenticate the server. + RootCAs *x509.CertPool + + // StreamProcessorFactories is a list of factories used to instantiate a chain of HTTP/2 stream + // processors. A chain is created for every stream. + StreamProcessorFactories []StreamProcessorFactory + + // EnableDebugLogs turns on fine-grained debug logging for HTTP/2. + EnableDebugLogs bool +} + +// Proxy proxies HTTP/2 traffic between a client connection, `cc`, and the HTTP/2 `url` assuming +// h2 is being used. Since no browsers use h2c, it's safe to assume all traffic uses TLS. +func (c *Config) Proxy(closing chan bool, cc io.ReadWriter, url *url.URL) error { + if c.EnableDebugLogs { + log.Infof("\u001b[1;35mProxying %v with HTTP/2\u001b[0m", url) + } + sc, err := tls.Dial("tcp", url.Host, &tls.Config{ + RootCAs: c.RootCAs, + NextProtos: []string{"h2"}, + }) + if err != nil { + return fmt.Errorf("connecting h2 to %v: %w", url, err) + } + if err := forwardPreface(sc, cc); err != nil { + return fmt.Errorf("initializing h2 with %v: %w", url, err) + } + + cf, sf := http2.NewFramer(cc, cc), http2.NewFramer(sc, sc) + cToS := newRelay(ClientToServer, "client", url.String(), cf, sf, &c.EnableDebugLogs) + sToC := newRelay(ServerToClient, url.String(), "client", sf, cf, &c.EnableDebugLogs) + + // Completes circular parts of the initialization. + + // The client-to-server relay depends on the server-to-client relay and vice versa. + cToS.peer, sToC.peer = sToC, cToS + + // Creating processors is circular because the create function references the relays and the + // relays need to call create. + cToS.processors = &streamProcessors{ + create: func(id uint32) *Processors { + p := &Processors{cToS: &relayAdapter{id, cToS}, sToC: &relayAdapter{id, sToC}} + // Chains the pipeline of processors together. + for i := len(c.StreamProcessorFactories) - 1; i >= 0; i-- { + cToS, sToC := c.StreamProcessorFactories[i](url, p) + // Bypasses any nil processors. + if cToS == nil { + cToS = p.ForDirection(ClientToServer) + } + if sToC == nil { + sToC = p.ForDirection(ServerToClient) + } + p = &Processors{cToS: cToS, sToC: sToC} + } + return p + }, + } + sToC.processors = cToS.processors + + var wg sync.WaitGroup + wg.Add(2) + go func() { // Forwards frames from client to server. + defer wg.Done() + if err := cToS.relayFrames(closing); err != nil { + log.Errorf("relaying frame from client to %v: %v", url, err) + } + }() + go func() { // Forwards frames from server to client. + defer wg.Done() + if err := sToC.relayFrames(closing); err != nil { + log.Errorf("relaying frame from %v to client: %v", url, err) + } + }() + wg.Wait() + return nil +} + +// forwardPreface forwards the connection preface from the client to the server. +func forwardPreface(server io.Writer, client io.Reader) error { + preface := make([]byte, len(connectionPreface)) + if _, err := client.Read(preface); err != nil { + return fmt.Errorf("reading preface: %w", err) + } + if !bytes.Equal(preface, connectionPreface) { + return fmt.Errorf("client sent unexpected preface: %s", hex.Dump(preface)) + } + for m := len(connectionPreface); m > 0; { + n, err := server.Write([]byte(preface)) + if err != nil { + return fmt.Errorf("writing preface: %w", err) + } + preface = preface[n:] + m -= n + } + return nil +} + +type streamProcessors struct { + // processors stores `*Processors` instances keyed by uint32 stream ID. + processors sync.Map + + // create creates `*Processors` for the given stream ID. + create func(uint32) *Processors +} + +// Get returns a the processor with the given ID and direction. +func (s *streamProcessors) Get(id uint32, dir Direction) Processor { + value, ok := s.processors.Load(id) + if !ok { + value, _ = s.processors.LoadOrStore(id, s.create(id)) + } + return value.(*Processors).ForDirection(dir) +} diff --git a/h2/h2_test.go b/h2/h2_test.go new file mode 100644 index 000000000..b75fba0c3 --- /dev/null +++ b/h2/h2_test.go @@ -0,0 +1,449 @@ +// Copyright 2021 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package h2_test + +import ( + "context" + "encoding/base64" + "fmt" + "io" + "math/rand" + "net/url" + "sync" + "testing" + + "github.com/google/martian/v3/h2" + mgrpc "github.com/google/martian/v3/h2/grpc" + ht "github.com/google/martian/v3/h2/testing" + "golang.org/x/net/http2" + "golang.org/x/net/http2/hpack" + "google.golang.org/grpc" + "google.golang.org/grpc/encoding/gzip" + "google.golang.org/protobuf/proto" + + tspb "github.com/google/martian/v3/h2/testservice" +) + +type requestProcessor struct { + dest mgrpc.Processor + requests *[]*tspb.EchoRequest +} + +func (p *requestProcessor) Header( + headers []hpack.HeaderField, + streamEnded bool, + priority http2.PriorityParam, +) error { + return p.dest.Header(headers, streamEnded, priority) +} + +func (p *requestProcessor) Message(data []byte, streamEnded bool) error { + msg := &tspb.EchoRequest{} + if err := proto.Unmarshal(data, msg); err != nil { + return fmt.Errorf("unmarshalling request: %w", err) + } + *p.requests = append(*p.requests, msg) + return p.dest.Message(data, streamEnded) +} + +type responseProcessor struct { + dest mgrpc.Processor + responses *[]*tspb.EchoResponse +} + +func (p *responseProcessor) Header( + headers []hpack.HeaderField, + streamEnded bool, + priority http2.PriorityParam, +) error { + return p.dest.Header(headers, streamEnded, priority) +} + +func (p *responseProcessor) Message(data []byte, streamEnded bool) error { + msg := &tspb.EchoResponse{} + if err := proto.Unmarshal(data, msg); err != nil { + return fmt.Errorf("unmarshalling response: %w", err) + } + *p.responses = append(*p.responses, msg) + return p.dest.Message(data, streamEnded) +} + +func TestEcho(t *testing.T) { + // This is a basic smoke test. It verifies that the end-to-end flow works and that gRPC messages + // are observed as expected in processors. + var requests []*tspb.EchoRequest + var responses []*tspb.EchoResponse + fixture, err := ht.New([]h2.StreamProcessorFactory{ + mgrpc.AsStreamProcessorFactory( + func(_ *url.URL, server, client mgrpc.Processor) (mgrpc.Processor, mgrpc.Processor) { + return &requestProcessor{server, &requests}, &responseProcessor{client, &responses} + }), + }) + if err != nil { + t.Fatalf("ht.New(...) = %v, want nil", err) + } + defer func() { + if err := fixture.Close(); err != nil { + t.Fatalf("f.Close() = %v, want nil", err) + } + }() + + ctx := context.Background() + req := &tspb.EchoRequest{ + Payload: "Hello", + } + resp, err := fixture.Echo(ctx, req) + if err != nil { + t.Fatalf("fixture.Echo(...) = _, %v, want _, nil", err) + } + if got, want := resp.GetPayload(), req.GetPayload(); got != want { + t.Errorf("resp.GetPayload() = %s, want = %s", got, want) + } + + // Verifies the captured requests and responses. + if got := len(requests); got != 1 { + t.Fatalf("len(requests) = %d, want 1", got) + } + if got, want := requests[0].GetPayload(), req.GetPayload(); got != want { + t.Errorf("requests[0].GetPayload() = %s, want = %s", got, want) + } + if got := len(responses); got != 1 { + t.Fatalf("len(requests) = %d, want 1", got) + } + if got, want := responses[0].GetPayload(), req.GetPayload(); got != want { + t.Errorf("responses[0].GetPayload() = %s, want = %s", got, want) + } +} + +type requestEditor struct { + dest mgrpc.Processor +} + +func (p *requestEditor) Header( + headers []hpack.HeaderField, + streamEnded bool, + priority http2.PriorityParam, +) error { + return p.dest.Header(headers, streamEnded, priority) +} + +func (p *requestEditor) Message(_ []byte, streamEnded bool) error { + msg := &tspb.EchoRequest{ + Payload: "Goodbye", + } + data, err := proto.Marshal(msg) + if err != nil { + return fmt.Errorf("marshalling request: %w", err) + } + return p.dest.Message(data, streamEnded) +} + +func TestRequestEditor(t *testing.T) { + // This test inserts a request modifier that changes the payload from "Hello" to "Goodbye". + fixture, err := ht.New([]h2.StreamProcessorFactory{ + mgrpc.AsStreamProcessorFactory( + func(_ *url.URL, server, client mgrpc.Processor) (mgrpc.Processor, mgrpc.Processor) { + return &requestEditor{server}, nil + }), + }) + if err != nil { + t.Fatalf("ht.New(...) = %v, want nil", err) + } + defer func() { + if err := fixture.Close(); err != nil { + t.Fatalf("f.Close() = %v, want nil", err) + } + }() + + ctx := context.Background() + req := &tspb.EchoRequest{ + Payload: "Hello", + } + resp, err := fixture.Echo(ctx, req) + if err != nil { + t.Fatalf("fixture.Echo(...) = _, %v, want _, nil", err) + } + if got, want := resp.GetPayload(), "Goodbye"; got != want { + t.Errorf("resp.GetPayload() = %s, want = %s", got, want) + } +} + +type plusOne struct { + dest mgrpc.Processor +} + +func (p *plusOne) Header( + headers []hpack.HeaderField, + streamEnded bool, + priority http2.PriorityParam, +) error { + return p.dest.Header(headers, streamEnded, priority) +} + +func (p *plusOne) Message(data []byte, streamEnded bool) error { + msg := &tspb.SumRequest{} + if err := proto.Unmarshal(data, msg); err != nil { + return fmt.Errorf("unmarshalling request: %w", err) + } + msg.Values = append(msg.Values, 1) + + data, err := proto.Marshal(msg) + if err != nil { + return fmt.Errorf("marshalling request: %w", err) + } + return p.dest.Message(data, streamEnded) +} + +func TestProcessorChaining(t *testing.T) { + // This test constructs a chain of processors and checks that the effects are correctly applied + // at the result. + fixture, err := ht.New([]h2.StreamProcessorFactory{ + mgrpc.AsStreamProcessorFactory( + func(_ *url.URL, server, client mgrpc.Processor) (mgrpc.Processor, mgrpc.Processor) { + return &plusOne{server}, nil + }), + mgrpc.AsStreamProcessorFactory( + func(_ *url.URL, server, client mgrpc.Processor) (mgrpc.Processor, mgrpc.Processor) { + return &plusOne{server}, nil + }), + mgrpc.AsStreamProcessorFactory( + func(_ *url.URL, server, client mgrpc.Processor) (mgrpc.Processor, mgrpc.Processor) { + return &plusOne{server}, nil + }), + }) + if err != nil { + t.Fatalf("ht.New(...) = %v, want nil", err) + } + defer func() { + if err := fixture.Close(); err != nil { + t.Fatalf("f.Close() = %v, want nil", err) + } + }() + + ctx := context.Background() + req := &tspb.SumRequest{ + Values: []int32{5}, + } + resp, err := fixture.Sum(ctx, req) + if err != nil { + t.Fatalf("fixture.Sum(...) = _, %v, want _, nil", err) + } + if got, want := resp.GetValue(), int32(8); got != want { + t.Errorf("resp.GetValue() = %d, want = %d", got, want) + } +} + +type headerCapture struct { + dest mgrpc.Processor + headers *[][]hpack.HeaderField +} + +func (h *headerCapture) Header( + headers []hpack.HeaderField, + streamEnded bool, + priority http2.PriorityParam, +) error { + c := make([]hpack.HeaderField, len(headers)) + copy(c, headers) + *h.headers = append(*h.headers, c) + return h.dest.Header(headers, streamEnded, priority) +} + +func (h *headerCapture) Message(data []byte, streamEnded bool) error { + return h.dest.Message(data, streamEnded) +} + +func TestLargeEcho(t *testing.T) { + // Sends a >128KB payload through the proxy. Since the standard gRPC frame size is only 16KB, + // this exercises frame merging, splitting and flow control code. + payload := make([]byte, 128*1024) + rand.Read(payload) + req := &tspb.EchoRequest{ + Payload: base64.StdEncoding.EncodeToString(payload), + } + + // This test also covers using gzip compression. Ideally, we would test more compression types + // but the golang gRPC implementation only provides a gzip compressor. + tests := []struct { + name string + useCompression bool + }{ + {"RawData", false}, + {"Gzip", true}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var cToSHeaders, sToCHeaders [][]hpack.HeaderField + fixture, err := ht.New([]h2.StreamProcessorFactory{ + mgrpc.AsStreamProcessorFactory( + func(_ *url.URL, server, client mgrpc.Processor) (mgrpc.Processor, mgrpc.Processor) { + return &headerCapture{server, &cToSHeaders}, &headerCapture{client, &sToCHeaders} + }), + }) + if err != nil { + t.Fatalf("ht.New(...) = %v, want nil", err) + } + defer func() { + if err := fixture.Close(); err != nil { + t.Fatalf("f.Close() = %v, want nil", err) + } + }() + + ctx := context.Background() + var resp *tspb.EchoResponse + if tc.useCompression { + resp, err = fixture.Echo(ctx, req, grpc.UseCompressor(gzip.Name)) + } else { + resp, err = fixture.Echo(ctx, req) + } + if err != nil { + t.Fatalf("fixture.Echo(...) = _, %v, want _, nil", err) + } + if got, want := resp.GetPayload(), req.GetPayload(); got != want { + t.Errorf("resp.GetPayload() = %s, want = %s", got, want) + } + // Verifies that grpc-encoding=gzip is present in the first headers on the stream when + // compression is active. + for _, headers := range [][]hpack.HeaderField{cToSHeaders[0], sToCHeaders[0]} { + foundGRPCEncoding := false + for _, h := range headers { + if h.Name == "grpc-encoding" { + foundGRPCEncoding = true + if got, want := h.Value, "gzip"; got != want { + t.Errorf("h.Value = %s, want %s", got, want) + } + } + } + if got, want := foundGRPCEncoding, tc.useCompression; got != want { + t.Errorf("foundGRPCEncoding = %t, want %t", got, want) + } + } + }) + } +} + +type noopProcessor struct { + sink mgrpc.Processor +} + +func (p *noopProcessor) Header( + headers []hpack.HeaderField, + streamEnded bool, + priority http2.PriorityParam, +) error { + return p.sink.Header(headers, streamEnded, priority) +} + +func (p *noopProcessor) Message(data []byte, streamEnded bool) error { + return p.sink.Message(data, streamEnded) +} + +func TestStream(t *testing.T) { + tests := []struct { + name string + factory h2.StreamProcessorFactory + }{ + { + "NilH2Processor", + func(_ *url.URL, _ *h2.Processors) (h2.Processor, h2.Processor) { + return nil, nil + }, + }, + { + // This differs from NilH2Processor only in how mgrpc.AsStreamProcessorFactory handles nil + // grpc.Processor values. It should end up processing exactly the same as + // h2.StreamProcessorFactory afterwards. + "NilGRPCProcessor", + mgrpc.AsStreamProcessorFactory( + func(_ *url.URL, _, _ mgrpc.Processor) (mgrpc.Processor, mgrpc.Processor) { + return nil, nil + }), + }, + { + // This differs from NilGRPCProcessor in that NilGRPCProcessor ends up behaving like + // NilH2Processor and no gRPC processing takes place. NoopProcessor causes the frames to + // be processed as gRPC. + "NoopGRPCProcessor", + mgrpc.AsStreamProcessorFactory( + func(_ *url.URL, server, client mgrpc.Processor) (mgrpc.Processor, mgrpc.Processor) { + return &noopProcessor{server}, &noopProcessor{client} + }), + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + fixture, err := ht.New([]h2.StreamProcessorFactory{tc.factory}) + if err != nil { + t.Fatalf("ht.New(...) = %v, want nil", err) + } + defer func() { + if err := fixture.Close(); err != nil { + t.Fatalf("f.Close() = %v, want nil", err) + } + }() + ctx := context.Background() + stream, err := fixture.DoubleEcho(ctx) + if err != nil { + t.Fatalf("fixture.DoubleEcho(ctx) = _, %v, want _, nil", err) + } + + var received []*tspb.EchoResponse + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + for { + resp, err := stream.Recv() + if err == io.EOF { + return + } + if err != nil { + t.Errorf("stream.Recv() = %v, want nil", err) + return + } + received = append(received, resp) + } + }() + + var sent []*tspb.EchoRequest + for i := 0; i < 5; i++ { + payload := make([]byte, 20*1024) + rand.Read(payload) + req := &tspb.EchoRequest{ + Payload: base64.StdEncoding.EncodeToString(payload), + } + if err := stream.Send(req); err != nil { + t.Fatalf("stream.Send(req) = %v, want nil", err) + } + sent = append(sent, req) + } + if err := stream.CloseSend(); err != nil { + t.Fatalf("stream.CloseSend() = %v, want nil", err) + } + wg.Wait() + + for i, req := range sent { + want := req.GetPayload() + if got := received[2*i].GetPayload(); got != want { + t.Errorf("received[2*i].GetPayload() = %s, want %s", got, want) + } + if got := received[2*i+1].GetPayload(); got != want { + t.Errorf("received[2*i+1].GetPayload() = %s, want %s", got, want) + } + } + }) + } +} diff --git a/h2/processor.go b/h2/processor.go new file mode 100644 index 000000000..f1515175a --- /dev/null +++ b/h2/processor.go @@ -0,0 +1,141 @@ +// Copyright 2021 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package h2 + +import ( + "fmt" + "net/url" + + "golang.org/x/net/http2" + "golang.org/x/net/http2/hpack" +) + +// Direction indicates the direction of the traffic flow. +type Direction uint8 + +const ( + // ClientToServer indicates traffic flowing from client-to-server. + ClientToServer Direction = iota + // ServerToClient indicates traffic flowing from server-to-client. + ServerToClient +) + +// StreamProcessorFactory is implemented by clients that wish to observe or edit HTTP/2 frames +// flowing through the proxy. It creates a pair of processors for the bidirectional stream. A +// processor consumes frames then calls the corresponding sink methods to forward frames to the +// destination, modifying the frame if needed. +// +// Returns the client-to-server and server-to-client processors. Nil values are safe to return and +// no processing occurs in such cases. NOTE: an interface may have a non-nil type with a nil value. +// Such values are treated as valid processors. +// +// Concurrency: there is a separate client-to-server and server-to-client thread. Calls against +// the `ClientToServer` sink must be made on the client-to-server thread and calls against +// the `ServerToClient` sink must be made on the server-to-client thread. Implementors should +// guard interactions across threads. +type StreamProcessorFactory func(url *url.URL, sinks *Processors) (Processor, Processor) + +// Processors encapsulates the two traffic receiving endpoints. +type Processors struct { + cToS, sToC Processor +} + +// ForDirection returns the processor receiving traffic in the given direction. +func (s *Processors) ForDirection(dir Direction) Processor { + switch dir { + case ClientToServer: + return s.cToS + case ServerToClient: + return s.sToC + } + panic(fmt.Sprintf("invalid direction: %v", dir)) +} + +// Processor accepts the possible stream frames. +// +// This API abstracts away some of the lower level HTTP/2 mechanisms. +// CONTINUATION frames are appropriately buffered and turned into Header calls and Header or +// PushPromise calls are split into CONTINUATION frames when needed. +// +// The proxy handles WINDOW_UPDATE frames and flow control, managing it independently for both +// endpoints. +type Processor interface { + DataFrameProcessor + HeaderProcessor + PriorityFrameProcessor + RSTStreamProcessor + PushPromiseProcessor +} + +// DataFrameProcessor processes data frames. +type DataFrameProcessor interface { + Data(data []byte, streamEnded bool) error +} + +// HeaderProcessor processes headers, abstracting out continuations. +type HeaderProcessor interface { + Header( + headers []hpack.HeaderField, + streamEnded bool, + priority http2.PriorityParam, + ) error +} + +// PriorityFrameProcessor processes priority frames. +type PriorityFrameProcessor interface { + Priority(http2.PriorityParam) error +} + +// RSTStreamProcessor processes RSTStream frames. +type RSTStreamProcessor interface { + RSTStream(http2.ErrCode) error +} + +// PushPromiseProcessor processes push promises, abstracting out continuations. +type PushPromiseProcessor interface { + PushPromise(promiseID uint32, headers []hpack.HeaderField) error +} + +// relayAdapter implements the Processor interface by delegating to an underlying relay. +type relayAdapter struct { + id uint32 + relay *relay +} + +func (r *relayAdapter) Data(data []byte, streamEnded bool) error { + return r.relay.data(r.id, data, streamEnded) +} + +func (r *relayAdapter) Header( + headers []hpack.HeaderField, + streamEnded bool, + priority http2.PriorityParam, +) error { + return r.relay.header(r.id, headers, streamEnded, priority) +} + +func (r *relayAdapter) Priority(priority http2.PriorityParam) error { + r.relay.priority(r.id, priority) + return nil +} + +func (r *relayAdapter) RSTStream(errCode http2.ErrCode) error { + r.relay.rstStream(r.id, errCode) + return nil +} + +func (r *relayAdapter) PushPromise(promiseID uint32, headers []hpack.HeaderField) error { + return r.relay.pushPromise(r.id, promiseID, headers) +} diff --git a/h2/queued_frames.go b/h2/queued_frames.go new file mode 100644 index 000000000..5b2691880 --- /dev/null +++ b/h2/queued_frames.go @@ -0,0 +1,224 @@ +// Copyright 2021 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package h2 + +import ( + "bytes" + "fmt" + + "golang.org/x/net/http2" +) + +// queuedFrame stores frames that belong to a stream and need to be kept in order. The need for +// this stems from flow control needed in the context of gRPC. Since a gRPC message can be split +// over multiple DATA frames, the proxy needs to buffer such frames so they can be reassembled +// into messages and edited before being forwarded. +// +// Note that the proxy does man-in-the-middle flow control independently to each endpoint instead +// of forwarding endpoint flow-control messages to each other directly. This is necessary because +// multiple DATA frames need to be captured before they can be forwarded. While the data frames are +// being held in the proxy, the destination of those frames cannot see them to send WINDOW_UPDATE +// acknowledgements and the sender will stop sending data. So the proxy must emit its own +// WINDOW_UPDATEs. +// +// Example: While DATA frames are being output-buffered due to pending WINDOW_UPDATE frames from +// the destination, it's possible for the source to send subsequent HEADER frames. Those HEADER +// frames must be queued after the DATA frames for consistency with HTTP/2's total ordering of +// frames within a stream. +// +// While the example only illustrates the need for HEADER frame buffering, a similar argument +// applies to other types of stream frames. WINDOW_UPDATE is a special case that is associated +// with a stream but does not require buffering or special ordering. This is because WINDOW_UPDATEs +// are basically acknowledgements for messages coming from the peer endpoint. In other words, +// WINDOW_UPDATE frames are associated with messages being received instead of messages being sent. +// The asynchrony of receiving remote messages should allow reordering freedom. +type queuedFrame interface { + // StreamID is the stream ID for the frame. + StreamID() uint32 + + // flowControlSize returns the size of this frame for the purposes of flow control. It is only + // non-zero for DATA frames. + flowControlSize() int + + // send writes the frame to the provided framer. This is not thread-safe and the caller should be + // holding appropriate locks. + send(*http2.Framer) error +} + +type queuedDataFrame struct { + streamID uint32 + endStream bool + data []byte +} + +func (f *queuedDataFrame) StreamID() uint32 { + return f.streamID +} + +func (f *queuedDataFrame) flowControlSize() int { + return len(f.data) +} + +func (f *queuedDataFrame) send(dest *http2.Framer) error { + return dest.WriteData(f.streamID, f.endStream, f.data) +} + +func (f *queuedDataFrame) String() string { + return fmt.Sprintf("data[id=%d, endStream=%t, len=%d]", f.streamID, f.endStream, len(f.data)) +} + +type queuedHeaderFrame struct { + streamID uint32 + endStream bool + priority http2.PriorityParam + chunks [][]byte +} + +func (f *queuedHeaderFrame) StreamID() uint32 { + return f.streamID +} + +func (*queuedHeaderFrame) flowControlSize() int { + return 0 +} + +func (f *queuedHeaderFrame) send(dest *http2.Framer) error { + if err := dest.WriteHeaders(http2.HeadersFrameParam{ + StreamID: f.streamID, + BlockFragment: f.chunks[0], + EndStream: f.endStream, + EndHeaders: len(f.chunks) <= 1, + PadLength: 0, + Priority: f.priority, + }); err != nil { + return fmt.Errorf("sending header %v: %w", f, err) + } + for i := 1; i < len(f.chunks); i++ { + headersEnded := i == len(f.chunks)-1 + if err := dest.WriteContinuation(f.streamID, headersEnded, f.chunks[i]); err != nil { + return fmt.Errorf("sending header continuations %v: %w", f, err) + } + } + return nil +} + +func (f *queuedHeaderFrame) String() string { + var buf bytes.Buffer // strings.Builder is not available on App Engine. + fmt.Fprintf(&buf, "header[id=%d, endStream=%t", f.streamID, f.endStream) + fmt.Fprintf(&buf, ", priority=%v, chunk lengths=[", f.priority) + for i, c := range f.chunks { + if i > 0 { + fmt.Fprintf(&buf, ",") + } + fmt.Fprintf(&buf, "%d", len(c)) + } + fmt.Fprintf(&buf, "]]") + return buf.String() +} + +type queuedPushPromiseFrame struct { + streamID uint32 + promiseID uint32 + chunks [][]byte +} + +func (f *queuedPushPromiseFrame) StreamID() uint32 { + return f.streamID +} + +func (*queuedPushPromiseFrame) flowControlSize() int { + return 0 +} + +func (f *queuedPushPromiseFrame) send(dest *http2.Framer) error { + if err := dest.WritePushPromise(http2.PushPromiseParam{ + StreamID: f.streamID, + PromiseID: f.promiseID, + BlockFragment: f.chunks[0], + EndHeaders: len(f.chunks) <= 1, + PadLength: 0, + }); err != nil { + return fmt.Errorf("sending push promise %v: %w", f, err) + } + for i := 1; i < len(f.chunks); i++ { + headersEnded := i == len(f.chunks)-1 + if err := dest.WriteContinuation(f.streamID, headersEnded, f.chunks[i]); err != nil { + return fmt.Errorf("sending push promise continuations %v: %w", f, err) + } + } + return nil +} + +func (f *queuedPushPromiseFrame) String() string { + var buf bytes.Buffer + fmt.Fprintf(&buf, "push promise[streamID=%d, promiseID= %d", f.streamID, f.promiseID) + fmt.Fprintf(&buf, ", chunk lengths=[") + for i, c := range f.chunks { + if i > 0 { + fmt.Fprintf(&buf, ",") + } + fmt.Fprintf(&buf, "%d", len(c)) + } + fmt.Fprintf(&buf, "]]") + return buf.String() +} + +type queuedPriorityFrame struct { + streamID uint32 + priority http2.PriorityParam +} + +func (f *queuedPriorityFrame) StreamID() uint32 { + return f.streamID +} + +func (*queuedPriorityFrame) flowControlSize() int { + return 0 +} + +func (f *queuedPriorityFrame) send(dest *http2.Framer) error { + if err := dest.WritePriority(f.streamID, f.priority); err != nil { + return fmt.Errorf("sending %v: %w", f, err) + } + return nil +} + +func (f *queuedPriorityFrame) String() string { + return fmt.Sprintf("priority[id=%d, priority=%v]", f.streamID, f.priority) +} + +type queuedRSTStreamFrame struct { + streamID uint32 + errCode http2.ErrCode +} + +func (f *queuedRSTStreamFrame) StreamID() uint32 { + return f.streamID +} + +func (*queuedRSTStreamFrame) flowControlSize() int { + return 0 +} + +func (f *queuedRSTStreamFrame) send(dest *http2.Framer) error { + if err := dest.WriteRSTStream(f.streamID, f.errCode); err != nil { + return fmt.Errorf("sending %v: %w", f, err) + } + return nil +} + +func (f *queuedRSTStreamFrame) String() string { + return fmt.Sprintf("RSTStream[id=%d, errCode=%v]", f.streamID, f.errCode) +} diff --git a/h2/relay.go b/h2/relay.go new file mode 100644 index 000000000..4397eea48 --- /dev/null +++ b/h2/relay.go @@ -0,0 +1,618 @@ +// Copyright 2021 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package h2 + +import ( + "bytes" + "container/list" + "errors" + "fmt" + "io" + "math" + "sync" + "sync/atomic" + + "github.com/google/martian/v3/log" + "golang.org/x/net/http2" + "golang.org/x/net/http2/hpack" +) + +const ( + // See: https://httpwg.org/specs/rfc7540.html#SettingValues + initialMaxFrameSize = 16384 + initialMaxHeaderTableSize = 4096 + + // See: https://tools.ietf.org/html/rfc7540#section-6.9.2 + defaultInitialWindowSize = 65535 + + // headersPriorityMetadataLength is the length of the priority metadata that optionally occurs at + // the beginning of the payload of the header frame. + // + // See: https://tools.ietf.org/html/rfc7540#section-6.2 + headersPriorityMetadataLength = 5 + + // pushPromiseMetadataLength is the length of the metadata that is part of the payload of the + // pushPromise frame. This does not include the padding length octet, which isn't needed due to + // the relaxed security constraints of a development proxy. + // + // See: https://tools.ietf.org/html/rfc7540#section-6.6 + pushPromiseMetadataLength = 4 + + // outputChannelSize is the size of the output channel. Roughly, it should be large enough to + // allow a window's worth of frames to minimize synchronization overhead. + outputChannelSize = 15 +) + +// relay encapsulates a flow of h2 traffic in one direction. +type relay struct { + dir Direction + + // srcLabel and destLabel are used only to create debugging messages. + srcLabel, destLabel string + + src *http2.Framer + + // destMu guards writes to dest, which may occur on from either the `relayFrames` thread of + // this relay or `peer`. `peer` writes WINDOW_UPDATE frames to this relay when it receives + // DATA frames. + destMu sync.Mutex + dest *http2.Framer + + // maxFrameSize is set by the peer relay and is accessed atomically. + maxFrameSize uint32 + + // The decoder and encoder settings can be adjusted by the peer connection so access to these + // fields must be guarded. + decoderMu sync.Mutex + decoder *hpack.Decoder + + encoderMu sync.Mutex + encoder *hpack.Encoder + reencoded bytes.Buffer // handle to the output buffer of `encoder` + + // headerBuffer collects header fragments that are received across multiple frames, i.e., + // when there are continuation frames. + headerBuffer bytes.Buffer + continuationState continuationState + + // flowMu guards access to flow-control related fields. + flowMu sync.Mutex + initialWindowSize uint32 + connectionWindowSize int // "global" connection-level window size + // outputBuffers is output pending available window size per-stream + outputBuffers map[uint32]*outputBuffer + // output stores stream output that is ready to be sent over HTTP/2. It provides a way to + // guarantee frame order without blocking on each frame being sent. + output chan queuedFrame + + enableDebugLogs *bool + + // The following fields depend on a circular dependency between the relays in opposite directions + // so must be set explicitly after initialization. + + // processors stores per HTTP/2 stream processors. + processors *streamProcessors + + peer *relay // relay for traffic from the peer +} + +// newRelay initializes a relay for the given direction. This performs only partial initialization +// due to circular dependency. +func newRelay( + dir Direction, + srcLabel, destLabel string, + src, dest *http2.Framer, + enableDebugLogs *bool, +) *relay { + ret := &relay{ + dir: dir, + srcLabel: srcLabel, + destLabel: destLabel, + src: src, + dest: dest, + maxFrameSize: initialMaxFrameSize, + decoder: hpack.NewDecoder(initialMaxHeaderTableSize, nil), + initialWindowSize: defaultInitialWindowSize, + connectionWindowSize: defaultInitialWindowSize, + outputBuffers: make(map[uint32]*outputBuffer), + output: make(chan queuedFrame, outputChannelSize), + enableDebugLogs: enableDebugLogs, + } + ret.encoder = hpack.NewEncoder(&ret.reencoded) + + // These limits seem to be part of the Go implementation of hpack. They exist because in a + // production system, there must be limits on the resources requested by clients. However, this + // is irrevelevant in a development proxy context. + ret.decoder.SetAllowedMaxDynamicTableSize(math.MaxUint32) + ret.encoder.SetMaxDynamicTableSizeLimit(math.MaxUint32) + return ret +} + +// relayFrames reads frames from `f.src` to `f.dest` until an error occurs or the connection closes. +func (r *relay) relayFrames(closing chan bool) error { + // Shutting down producer-consumers linked by channels is subtle. In this function, the writer + // goroutine consumes frames from `r.output`, which are populated by the reader goroutine. If + // the writer shuts down before the reader, the reader may deadlock on inserting frames into + // `r.output`. The writer therefore has to keep processing until the reader is done. This is + // coordinated via `readerDone`. + // + // A second subtlely is that errors on the writer goroutine should stop the reader goroutine. + // This is communicated via `writeErr`. To avoid deadlocks, even after the error occurs, the + // writer thread must still wait until `readerDone` has been communicated to stop processing. + + // Communicates to the consuming writer goroutine that the reader (the calling goroutine of this + // method) is done. + readerDone := make(chan struct{}) + defer func() { readerDone <- struct{}{} }() + + // Communicates errors occuring on the writer goroutine to the reader goroutine. + writerErr := make(chan error, 1) + + // This writer goroutine consumes the strictly ordered frames in `r.output` and delivers them. + go func() { + var err error + for { + select { + case f := <-r.output: + if err == nil { + r.destMu.Lock() + err = f.send(r.dest) + r.destMu.Unlock() + if err != nil { + writerErr <- err + } + } + // Once an output error has occurred, the remaining frames are drained from the channel + // without sending them. + case <-readerDone: + return + } + } + }() + + // This channel is buffered to allow the ReadFrame goroutine to drain on closing. + frameReady := make(chan struct{}, 1) + for { + var frame http2.Frame + var err error + go func() { + // ReadFrame is called in its own goroutine to make this function responsive to closing. It + // does not need to block here to close. + frame, err = r.src.ReadFrame() + frameReady <- struct{}{} + }() + select { + case <-frameReady: + if err != nil { + if err == io.EOF { + return nil + } + return fmt.Errorf("reading frame: %w", err) + } + if err := r.processFrame(frame); err != nil { + return fmt.Errorf("processing frame: %w", err) + } + if *r.enableDebugLogs { + log.Infof("%s--%v-->%s", r.srcLabel, frame, r.destLabel) + } + case err := <-writerErr: + return fmt.Errorf("sending frame: %w", err) + case <-closing: + // The ReadFrame goroutine is abandoned at this point. It completes as soon as the blocking + // ReadFrame call completes, but could potentially leak for an unspecified duration. + return nil + } + } +} + +func (r *relay) processFrame(f http2.Frame) error { + var err error + switch f := f.(type) { + case *http2.DataFrame: + // The proxy's window increments as soon as it receives data. This assumes that the proxy has + // ample resources because it is inteded for testing and development. + if err = r.peer.sendWindowUpdates(f); err == nil { + err = r.processor(f.StreamID).Data(f.Data(), f.StreamEnded()) + } + case *http2.HeadersFrame: + if !f.HeadersEnded() { + r.headerBuffer.Reset() + r.headerBuffer.Write(f.HeaderBlockFragment()) + r.continuationState = &headerContinuation{f.Priority} + } else { + var headers []hpack.HeaderField + headers, err = r.decodeFull(f.HeaderBlockFragment()) + if err != nil { + return fmt.Errorf("decoding header %v: %w", f, err) + } + err = r.processor(f.StreamID).Header(headers, f.StreamEnded(), f.Priority) + } + case *http2.PriorityFrame: + err = r.processor(f.StreamID).Priority(f.PriorityParam) + case *http2.RSTStreamFrame: + err = r.processor(f.StreamID).RSTStream(f.ErrCode) + case *http2.SettingsFrame: + if f.IsAck() { + r.destMu.Lock() + err = r.dest.WriteSettingsAck() + r.destMu.Unlock() + } else { + var settings []http2.Setting + if err = f.ForeachSetting(func(s http2.Setting) error { + switch s.ID { + case http2.SettingHeaderTableSize: + r.peer.updateTableSize(s.Val) + case http2.SettingInitialWindowSize: + r.peer.updateInitialWindowSize(s.Val) + case http2.SettingMaxFrameSize: + r.peer.updateMaxFrameSize(s.Val) + } + settings = append(settings, s) + return nil + }); err == nil { + r.destMu.Lock() + err = r.dest.WriteSettings(settings...) + r.destMu.Unlock() + } + } + case *http2.PushPromiseFrame: + if !f.HeadersEnded() { + r.headerBuffer.Reset() + r.headerBuffer.Write(f.HeaderBlockFragment()) + r.continuationState = &pushPromiseContinuation{f.PromiseID} + } else { + var headers []hpack.HeaderField + headers, err = r.decodeFull(f.HeaderBlockFragment()) + if err != nil { + return fmt.Errorf("decoding push promise %v: %w", f, err) + } + err = r.processor(f.StreamID).PushPromise(f.PromiseID, headers) + } + case *http2.PingFrame: + r.destMu.Lock() + err = r.dest.WritePing(f.IsAck(), f.Data) + r.destMu.Unlock() + case *http2.GoAwayFrame: + r.destMu.Lock() + err = r.dest.WriteGoAway(f.LastStreamID, f.ErrCode, f.DebugData()) + r.destMu.Unlock() + case *http2.WindowUpdateFrame: + r.peer.updateWindow(f) + case *http2.ContinuationFrame: + r.headerBuffer.Write(f.HeaderBlockFragment()) + if f.HeadersEnded() { + var headers []hpack.HeaderField + headers, err = r.decodeFull(r.headerBuffer.Bytes()) + if err != nil { + return fmt.Errorf("decoding headers for continuation %v: %w", f, err) + } + err = r.continuationState.complete(r.processor(f.StreamID), headers) + } + default: + err = errors.New("unrecognized frame type") + } + return err +} + +func (r *relay) processor(id uint32) Processor { + return r.processors.Get(id, r.dir) +} + +func (r *relay) updateTableSize(v uint32) { + r.decoderMu.Lock() + r.decoder.SetMaxDynamicTableSize(v) + r.decoderMu.Unlock() + + r.encoderMu.Lock() + r.encoder.SetMaxDynamicTableSize(v) + r.encoderMu.Unlock() +} + +func (r *relay) updateMaxFrameSize(v uint32) { + atomic.StoreUint32(&r.maxFrameSize, v) +} + +// updateInitialWindowSize updates the initial window size and updates all stream windows based on +// the difference. Note that this should not include the connection window. +// See: https://tools.ietf.org/html/rfc7540#section-6.9.2 +// +// This is called by `peer`, so requires a thread-safe implementation. +func (r *relay) updateInitialWindowSize(v uint32) { + r.flowMu.Lock() + delta := int(v) - int(r.initialWindowSize) + r.initialWindowSize = v + for _, w := range r.outputBuffers { + w.windowSize += delta + } + r.flowMu.Unlock() + // Since all the stream windows may be impacted, all the queues need to be checked for newly + // eligible frames. + r.sendQueuedFramesUnderWindowSize() +} + +// updateWindow updates the specified window size and may result in the sending of data frames. +func (r *relay) updateWindow(f *http2.WindowUpdateFrame) { + if f.StreamID == 0 { + // A stream ID of 0 means updating the global connection window size. This may cause any + // queued frame belonging to any stream to become eligible for sending. + r.flowMu.Lock() + r.connectionWindowSize += int(f.Increment) + r.flowMu.Unlock() + r.sendQueuedFramesUnderWindowSize() + } + + r.flowMu.Lock() + w := r.outputBuffer(f.StreamID) + w.windowSize += int(f.Increment) + w.emitEligibleFrames(r.output, &r.connectionWindowSize) + r.flowMu.Unlock() +} + +func (r *relay) data(id uint32, data []byte, streamEnded bool) error { + // This implementation only allows `WriteData` without padding. Padding is used to improve the + // security against attacks like CRIME, but this isn't relevant for a development proxy. + // + // If padding were allowed, this length would need to vary depending on whether the padding + // length octet is present. + maxPayloadLength := atomic.LoadUint32(&r.maxFrameSize) + + r.flowMu.Lock() + w := r.outputBuffer(id) + r.flowMu.Unlock() + // If data is larger than what would be permitted at the current max frame size setting, the data + // is split across multiple frames. + for { + nextPayloadLength := uint32(len(data)) + if nextPayloadLength > maxPayloadLength { + nextPayloadLength = maxPayloadLength + } + nextPayload := make([]byte, nextPayloadLength) + copy(nextPayload, data) + data = data[nextPayloadLength:] + f := &queuedDataFrame{id, streamEnded && len(data) == 0, nextPayload} + + r.flowMu.Lock() + w.enqueue(f) + w.emitEligibleFrames(r.output, &r.connectionWindowSize) + r.flowMu.Unlock() + + // Some protocols send empty data frames with END_STREAM so the check is done here at the end + // of the loop instead of at the beginning of the loop. + if len(data) == 0 { + break + } + } + return nil +} + +func (r *relay) header( + id uint32, + headers []hpack.HeaderField, + streamEnded bool, + priority http2.PriorityParam, +) error { + encoded, err := r.encodeFull(headers) + if err != nil { + return fmt.Errorf("encoding headers %v: %w", headers, err) + } + + maxPayloadLength := atomic.LoadUint32(&r.maxFrameSize) + // Padding is not implemented because the extra security is not needed for a development proxy. + // If it were used, a single padding length octet should be deducted from the max header fragment + // length. + maxHeaderFragmentLength := maxPayloadLength + if !priority.IsZero() { + maxHeaderFragmentLength -= headersPriorityMetadataLength + } + chunks := splitIntoChunks(int(maxHeaderFragmentLength), int(maxPayloadLength), encoded) + + r.enqueueFrame(&queuedHeaderFrame{ + streamID: id, + endStream: streamEnded, + priority: priority, + chunks: chunks, + }) + return nil +} + +func (r *relay) priority(id uint32, priority http2.PriorityParam) { + r.enqueueFrame(&queuedPriorityFrame{ + streamID: id, + priority: priority, + }) +} + +func (r *relay) rstStream(id uint32, errCode http2.ErrCode) { + r.enqueueFrame(&queuedRSTStreamFrame{ + streamID: id, + errCode: errCode, + }) +} + +func (r *relay) pushPromise(id, promiseID uint32, headers []hpack.HeaderField) error { + encoded, err := r.encodeFull(headers) + if err != nil { + return fmt.Errorf("encoding push promise headers %v: %w", headers, err) + } + + maxPayloadLength := atomic.LoadUint32(&r.maxFrameSize) + maxHeaderFragmentLength := maxPayloadLength - pushPromiseMetadataLength + chunks := splitIntoChunks(int(maxHeaderFragmentLength), int(maxPayloadLength), encoded) + + r.enqueueFrame(&queuedPushPromiseFrame{ + streamID: id, + promiseID: promiseID, + chunks: chunks, + }) + return nil +} + +func (r *relay) enqueueFrame(f queuedFrame) { + // The frame is first added to the appropriate stream. + r.flowMu.Lock() + w := r.outputBuffer(f.StreamID()) + w.enqueue(f) + w.emitEligibleFrames(r.output, &r.connectionWindowSize) + r.flowMu.Unlock() +} + +func (r *relay) sendQueuedFramesUnderWindowSize() { + r.flowMu.Lock() + for _, w := range r.outputBuffers { + w.emitEligibleFrames(r.output, &r.connectionWindowSize) + } + r.flowMu.Unlock() +} + +// outputBuffer returns the outputBuffer instance for the given stream, creating one if needed. +// +// This method is not thread-safe. The caller should be holding `flowMu`. +func (r *relay) outputBuffer(streamID uint32) *outputBuffer { + w, ok := r.outputBuffers[streamID] + if !ok { + w = &outputBuffer{ + windowSize: int(r.initialWindowSize), + } + r.outputBuffers[streamID] = w + } + return w +} + +// sendWindowUpdates sends WINDOW_UPDATE frames effectively acknowledging consumption of the +// given data frame. +func (r *relay) sendWindowUpdates(f *http2.DataFrame) error { + if len(f.Data()) <= 0 { + return nil + } + r.destMu.Lock() + defer r.destMu.Unlock() + // First updates the connection level window. + if err := r.dest.WriteWindowUpdate(0, uint32(len(f.Data()))); err != nil { + return err + } + // Next updates the stream specific window. + return r.dest.WriteWindowUpdate(f.StreamID, uint32(len(f.Data()))) +} + +func (r *relay) decodeFull(data []byte) ([]hpack.HeaderField, error) { + r.decoderMu.Lock() + defer r.decoderMu.Unlock() + return r.decoder.DecodeFull(data) +} + +func (r *relay) encodeFull(headers []hpack.HeaderField) ([]byte, error) { + r.encoderMu.Lock() + defer r.encoderMu.Unlock() + + r.reencoded.Reset() + var buf bytes.Buffer + for _, h := range headers { + if *r.enableDebugLogs { + if h.Name == "content-type" && h.Value == "application/grpc" { + fmt.Fprintf(&buf, " \u001b[1;36m%v\u001b[0m\n", h) + } else { + fmt.Fprintf(&buf, " %v\n", h) + } + } + if err := r.encoder.WriteField(h); err != nil { + return nil, fmt.Errorf("reencoding header field %v in %v: %w", h, headers, err) + } + } + if *r.enableDebugLogs { + log.Infof("sending headers %s -> %s:\n%s", r.srcLabel, r.destLabel, buf.Bytes()) + } + return r.reencoded.Bytes(), nil +} + +// outputBuffer stores enqueued output frames for a given stream. +type outputBuffer struct { + // windowSize indicates how much data the receiver is ready to process. + windowSize int + queue list.List // contains queuedFrame elements +} + +// emitEligibleFrames emits frames that would fit under both the stream window size and the +// given connection window size. It updates the given connectionWindowSize if applicable. +// +// This is not thread-safe. The caller should be holding `relay.flowMu`. +func (w *outputBuffer) emitEligibleFrames(output chan queuedFrame, connectionWindowSize *int) { + for e := w.queue.Front(); e != nil; { + f := e.Value.(queuedFrame) + if f.flowControlSize() > *connectionWindowSize || f.flowControlSize() > w.windowSize { + break + } + output <- f + + *connectionWindowSize -= f.flowControlSize() + w.windowSize -= f.flowControlSize() + + next := e.Next() + w.queue.Remove(e) + e = next + } +} + +// enqueue adds the frame to this stream output. This is not thread-safe. The caller must hold +// relay.flowMu. +func (w *outputBuffer) enqueue(f queuedFrame) { + w.queue.PushBack(f) +} + +// continuationState holds the context needed to interpret CONTINUATION frames, specifically whether +// the parents were HEADERS or PUSH_PROMISE frames. +type continuationState interface { + complete(s Processor, headers []hpack.HeaderField) error +} + +type headerContinuation struct { + priority http2.PriorityParam +} + +func (h *headerContinuation) complete(s Processor, headers []hpack.HeaderField) error { + return s.Header(headers, true, h.priority) +} + +type pushPromiseContinuation struct { + promiseID uint32 +} + +func (p *pushPromiseContinuation) complete(s Processor, headers []hpack.HeaderField) error { + return s.PushPromise(p.promiseID, headers) +} + +// splitIntoChunks splits header payloads into chunks that respect frame size limits. +func splitIntoChunks(firstChunkMax, continuationMax int, data []byte) [][]byte { + var chunks [][]byte + + firstChunkLength := len(data) + if firstChunkLength > firstChunkMax { + firstChunkLength = firstChunkMax + } + buf := make([]byte, firstChunkLength) + copy(buf, data[:firstChunkLength]) + chunks = append(chunks, buf) + remaining := data[firstChunkLength:] + for len(remaining) > 0 { + nextChunkLength := len(remaining) + if nextChunkLength > continuationMax { + nextChunkLength = continuationMax + } + buf = make([]byte, nextChunkLength) + copy(buf, remaining[:nextChunkLength]) + chunks = append(chunks, buf) + remaining = remaining[nextChunkLength:] + } + return chunks +} diff --git a/h2/testing/certs.go b/h2/testing/certs.go new file mode 100644 index 000000000..fe73004e0 --- /dev/null +++ b/h2/testing/certs.go @@ -0,0 +1,134 @@ +// Copyright 2021 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testing + +import ( + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "fmt" + "log" + "os" + "time" + + "github.com/google/martian/v3/cybervillains" + "github.com/google/martian/v3/mitm" + "google.golang.org/grpc/credentials" +) + +var ( + // CA is the certificate authority. It uses the Cybervillains key pair. + CA *x509.Certificate + // CAKey is the private key of the certificate authority. + CAKey crypto.PrivateKey + + // RootCAs is a certificate pool containing `CA`. + RootCAs *x509.CertPool + + // ClientTLS is a set of transport credentials to use with chains signed by `CA`. + ClientTLS credentials.TransportCredentials + + // Localhost is a certificate for "localhost" signed by `CA`. + Localhost *tls.Certificate +) + +func init() { + var err error + CA, CAKey, err = initCA() + if err != nil { + log.Fatalf("Error initializing Cybervillains CA: %v", err) + } + + RootCAs = x509.NewCertPool() + RootCAs.AddCert(CA) + ClientTLS = credentials.NewClientTLSFromCert(RootCAs, "") + + Localhost, err = initLocalhostCert(CA, CAKey) + if err != nil { + log.Fatalf("Error creating localhost server certificate: %v", err) + } +} + +func initCA() (*x509.Certificate, crypto.PrivateKey, error) { + chain, err := tls.X509KeyPair([]byte(cybervillains.Cert), []byte(cybervillains.Key)) + if err != nil { + return nil, nil, fmt.Errorf("creating Cybervillains root: %w", err) + } + cert, err := x509.ParseCertificate(chain.Certificate[0]) + if err != nil { + return nil, nil, fmt.Errorf("parsing Cybervillains certificate: %w", err) + } + return cert, chain.PrivateKey, nil +} + +func initLocalhostCert(ca *x509.Certificate, caPriv crypto.PrivateKey) (*tls.Certificate, error) { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, fmt.Errorf("generating random key: %w", err) + } + + // Subject Key Identifier support for end entity certificate. + // https://www.ietf.org/rfc/rfc3280.txt (section 4.2.1.2) + pkixpub, err := x509.MarshalPKIXPublicKey(priv.Public()) + if err != nil { + return nil, fmt.Errorf("marshalling public key: %w", err) + } + hasher := sha256.New() + hasher.Write(pkixpub) + keyID := hasher.Sum(nil) + + serial, err := rand.Int(rand.Reader, mitm.MaxSerialNumber) + if err != nil { + return nil, fmt.Errorf("generating serial number: %w", err) + } + + hostname, err := os.Hostname() + if err != nil { + return nil, fmt.Errorf("getting hostname for creating cert: %w", err) + } + + tmpl := &x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{ + CommonName: hostname, + Organization: []string{"Martian Proxy"}, + }, + SubjectKeyId: keyID, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + DNSNames: []string{hostname}, + } + + der, err := x509.CreateCertificate(rand.Reader, tmpl, ca, priv.Public(), caPriv) + if err != nil { + return nil, fmt.Errorf("creating X509 server certificate: %w", err) + } + x509c, err := x509.ParseCertificate(der) + if err != nil { + return nil, fmt.Errorf("parsing DER encoded certificate: %w", err) + } + return &tls.Certificate{ + Certificate: [][]byte{x509c.Raw, ca.Raw}, + PrivateKey: priv, + Leaf: x509c, + }, nil +} diff --git a/h2/testing/fixture.go b/h2/testing/fixture.go new file mode 100644 index 000000000..56793dbbd --- /dev/null +++ b/h2/testing/fixture.go @@ -0,0 +1,204 @@ +// Copyright 2021 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package testing contains a test fixture for working with gRPC over HTTP/2. +package testing + +import ( + "crypto/tls" + "fmt" + "io/ioutil" + "net" + "net/http" + "os" + "strconv" + "sync" + "time" + + "github.com/google/martian/v3" + "github.com/google/martian/v3/h2" + "github.com/google/martian/v3/mitm" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + + tspb "github.com/google/martian/v3/h2/testservice" +) + +var ( + // proxyPort is a global variable that stores the listener used by the proxy. This value is + // shared globally because golang http transport code caches the environment variable values, in + // particular HTTPS_PROXY. + proxyPort int +) + +// Fixture encapsulates the TestService gRPC server, a proxy and a gRPC client. +type Fixture struct { + // TestServiceClient is a client pointing at the service and redirected through the proxy. + tspb.TestServiceClient + + wg sync.WaitGroup + server *grpc.Server + // serverErr is any error returned by invoking `Serve` on the gRPC server. + serverErr error + + proxyListener net.Listener + proxy *martian.Proxy + + conn *grpc.ClientConn +} + +// New creates a new instance of the Fixture. It is not possible for there to be more than one +// instance concurrently because clients decide whether to use the proxy based on the global +// HTTPS_PROXY environment variable. +func New(spf []h2.StreamProcessorFactory) (*Fixture, error) { + f := &Fixture{} + + // Starts the gRPC server. + f.server = grpc.NewServer(grpc.Creds(credentials.NewServerTLSFromCert(Localhost))) + tspb.RegisterTestServiceServer(f.server, &Server{}) + + lis, err := net.Listen("tcp", ":0") + if err != nil { + return nil, fmt.Errorf("creating listener for gRPC service: %w", err) + } + + f.wg.Add(1) + go func() { + defer f.wg.Done() + f.serverErr = f.server.Serve(lis) + }() + + hostname, err := os.Hostname() + if err != nil { + return nil, fmt.Errorf("getting hostname: %w", err) + } + + // Creates a listener for the proxy, obtaining a new port if needed. + if proxyPort == 0 { + // Attempts a query to port server first, falling back if it is unavailable. Ports that are + // provided by listening on ":0" can be recyled by the OS leading to flakiness in certain + // environments since we need the same port to be available across multiple instances of the + // test fixture. + proxyPort = queryPortServer() + if proxyPort == 0 { + var err error + f.proxyListener, err = net.Listen("tcp", ":0") + if err != nil { + return nil, fmt.Errorf("creating listener for proxy; %w", err) + } + proxyPort = f.proxyListener.Addr().(*net.TCPAddr).Port + } + proxyTarget := hostname + ":" + strconv.Itoa(proxyPort) + // Sets the HTTPS_PROXY environment variable so that http requests will go through the proxy. + os.Setenv("HTTPS_PROXY", fmt.Sprintf("http://%s", proxyTarget)) + fmt.Printf("proxy at %s\n", proxyTarget) + } + if f.proxyListener == nil { + var err error + f.proxyListener, err = net.Listen("tcp", fmt.Sprintf(":%d", proxyPort)) + if err != nil { + return nil, fmt.Errorf("creating listener for proxy; %w", err) + } + } + + // Starts the proxy. + f.proxy, err = newProxy(spf) + if err != nil { + return nil, fmt.Errorf("creating proxy: %w", err) + } + go func() { + f.proxy.Serve(f.proxyListener) + }() + + port := lis.Addr().(*net.TCPAddr).Port + target := hostname + ":" + strconv.Itoa(port) + + fmt.Printf("server at %s\n", target) + + // Connects a gRPC client with the service via the proxy. + f.conn, err = grpc.Dial(target, grpc.WithTransportCredentials(ClientTLS)) + if err != nil { + return nil, fmt.Errorf("error dialing %s: %w", target, err) + } + f.TestServiceClient = tspb.NewTestServiceClient(f.conn) + + return f, nil +} + +// Close cleans up the servers and connections. +func (f *Fixture) Close() error { + f.conn.Close() + f.server.Stop() + f.proxy.Close() + f.wg.Wait() + + if err := f.proxyListener.Close(); err != nil { + return fmt.Errorf("closing proxy listener: %w", err) + } + return f.serverErr +} + +func newProxy(spf []h2.StreamProcessorFactory) (*martian.Proxy, error) { + p := martian.NewProxy() + mc, err := mitm.NewConfig(CA, CAKey) + if err != nil { + return nil, fmt.Errorf("creating mitm config: %w", err) + } + mc.SetValidity(time.Hour) + mc.SetOrganization("Martian Proxy") + mc.SetH2Config(&h2.Config{ + AllowedHostsFilter: func(_ string) bool { return true }, + RootCAs: RootCAs, + StreamProcessorFactories: spf, + EnableDebugLogs: true, + }) + + p.SetMITM(mc) + + tr := &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: RootCAs, + }, + } + p.SetRoundTripper(tr) + + return p, nil +} + +func queryPortServer() int { + // portpicker isn't available in third_party. + if portServer := os.Getenv("PORTSERVER_ADDRESS"); portServer != "" { + c, err := net.Dial("unix", portServer) + if err != nil { + // failed connection to portServer; this is normal in many circumstances. + return 0 + } + defer c.Close() + if _, err := fmt.Fprintf(c, "%d\n", os.Getpid()); err != nil { + return 0 + } + buf, err := ioutil.ReadAll(c) + if err != nil || len(buf) == 0 { + return 0 + } + buf = buf[:len(buf)-1] // remove newline char + port, err := strconv.Atoi(string(buf)) + if err != nil { + return 0 + } + fmt.Printf("got port %d\n", port) + return port + } + return 0 +} diff --git a/h2/testing/test_service.go b/h2/testing/test_service.go new file mode 100644 index 000000000..3d14b9c6b --- /dev/null +++ b/h2/testing/test_service.go @@ -0,0 +1,67 @@ +// Copyright 2021 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testing + +import ( + "context" + "io" + + tspb "github.com/google/martian/v3/h2/testservice" +) + +// Server is a testing gRPC server. +type Server struct { + tspb.UnimplementedTestServiceServer +} + +// Echo handles TestService.Echo RPCs. +func (s *Server) Echo(ctx context.Context, in *tspb.EchoRequest) (*tspb.EchoResponse, error) { + return &tspb.EchoResponse{ + Payload: in.GetPayload(), + }, nil +} + +// Sum handles TestService.Sum RPCs. +func (s *Server) Sum(_ context.Context, in *tspb.SumRequest) (*tspb.SumResponse, error) { + sum := int32(0) + for _, v := range in.GetValues() { + sum += v + } + return &tspb.SumResponse{ + Value: sum, + }, nil +} + +// DoubleEcho handles TestService.DoubleEcho RPCs. +func (s *Server) DoubleEcho(stream tspb.TestService_DoubleEchoServer) error { + for { + req, err := stream.Recv() + if err == io.EOF { + return nil + } + if err != nil { + return err + } + resp := &tspb.EchoResponse{ + Payload: req.GetPayload(), + } + if err := stream.Send(resp); err != nil { + return err + } + if err := stream.Send(resp); err != nil { + return err + } + } +} diff --git a/h2/testservice/test_service.pb.go b/h2/testservice/test_service.pb.go new file mode 100644 index 000000000..03d5f160a --- /dev/null +++ b/h2/testservice/test_service.pb.go @@ -0,0 +1,365 @@ +// Copyright 2021 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.26.0 +// protoc v3.6.1 +// source: test_service.proto + +package testservice + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type EchoRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Payload string `protobuf:"bytes,1,opt,name=payload,proto3" json:"payload,omitempty"` +} + +func (x *EchoRequest) Reset() { + *x = EchoRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_test_service_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *EchoRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*EchoRequest) ProtoMessage() {} + +func (x *EchoRequest) ProtoReflect() protoreflect.Message { + mi := &file_test_service_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use EchoRequest.ProtoReflect.Descriptor instead. +func (*EchoRequest) Descriptor() ([]byte, []int) { + return file_test_service_proto_rawDescGZIP(), []int{0} +} + +func (x *EchoRequest) GetPayload() string { + if x != nil { + return x.Payload + } + return "" +} + +type EchoResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Payload string `protobuf:"bytes,1,opt,name=payload,proto3" json:"payload,omitempty"` +} + +func (x *EchoResponse) Reset() { + *x = EchoResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_test_service_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *EchoResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*EchoResponse) ProtoMessage() {} + +func (x *EchoResponse) ProtoReflect() protoreflect.Message { + mi := &file_test_service_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use EchoResponse.ProtoReflect.Descriptor instead. +func (*EchoResponse) Descriptor() ([]byte, []int) { + return file_test_service_proto_rawDescGZIP(), []int{1} +} + +func (x *EchoResponse) GetPayload() string { + if x != nil { + return x.Payload + } + return "" +} + +type SumRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Values []int32 `protobuf:"varint,1,rep,packed,name=values,proto3" json:"values,omitempty"` +} + +func (x *SumRequest) Reset() { + *x = SumRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_test_service_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SumRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SumRequest) ProtoMessage() {} + +func (x *SumRequest) ProtoReflect() protoreflect.Message { + mi := &file_test_service_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SumRequest.ProtoReflect.Descriptor instead. +func (*SumRequest) Descriptor() ([]byte, []int) { + return file_test_service_proto_rawDescGZIP(), []int{2} +} + +func (x *SumRequest) GetValues() []int32 { + if x != nil { + return x.Values + } + return nil +} + +type SumResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Value int32 `protobuf:"varint,1,opt,name=value,proto3" json:"value,omitempty"` +} + +func (x *SumResponse) Reset() { + *x = SumResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_test_service_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SumResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SumResponse) ProtoMessage() {} + +func (x *SumResponse) ProtoReflect() protoreflect.Message { + mi := &file_test_service_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SumResponse.ProtoReflect.Descriptor instead. +func (*SumResponse) Descriptor() ([]byte, []int) { + return file_test_service_proto_rawDescGZIP(), []int{3} +} + +func (x *SumResponse) GetValue() int32 { + if x != nil { + return x.Value + } + return 0 +} + +var File_test_service_proto protoreflect.FileDescriptor + +var file_test_service_proto_rawDesc = []byte{ + 0x0a, 0x12, 0x74, 0x65, 0x73, 0x74, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0c, 0x74, 0x65, 0x73, 0x74, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x69, + 0x63, 0x65, 0x22, 0x27, 0x0a, 0x0b, 0x45, 0x63, 0x68, 0x6f, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x12, 0x18, 0x0a, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x22, 0x28, 0x0a, 0x0c, 0x45, + 0x63, 0x68, 0x6f, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x70, + 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x70, 0x61, + 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x22, 0x24, 0x0a, 0x0a, 0x53, 0x75, 0x6d, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x18, 0x01, 0x20, + 0x03, 0x28, 0x05, 0x52, 0x06, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x22, 0x23, 0x0a, 0x0b, 0x53, + 0x75, 0x6d, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, + 0x6c, 0x75, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, + 0x32, 0xd7, 0x01, 0x0a, 0x0b, 0x54, 0x65, 0x73, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, + 0x12, 0x3f, 0x0a, 0x04, 0x45, 0x63, 0x68, 0x6f, 0x12, 0x19, 0x2e, 0x74, 0x65, 0x73, 0x74, 0x5f, + 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x45, 0x63, 0x68, 0x6f, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x74, 0x65, 0x73, 0x74, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x69, + 0x63, 0x65, 0x2e, 0x45, 0x63, 0x68, 0x6f, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, + 0x00, 0x12, 0x3c, 0x0a, 0x03, 0x53, 0x75, 0x6d, 0x12, 0x18, 0x2e, 0x74, 0x65, 0x73, 0x74, 0x5f, + 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x53, 0x75, 0x6d, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x74, 0x65, 0x73, 0x74, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, + 0x65, 0x2e, 0x53, 0x75, 0x6d, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, + 0x49, 0x0a, 0x0a, 0x44, 0x6f, 0x75, 0x62, 0x6c, 0x65, 0x45, 0x63, 0x68, 0x6f, 0x12, 0x19, 0x2e, + 0x74, 0x65, 0x73, 0x74, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x45, 0x63, 0x68, + 0x6f, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x74, 0x65, 0x73, 0x74, 0x5f, + 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x45, 0x63, 0x68, 0x6f, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x2a, 0x5a, 0x28, 0x67, 0x69, + 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, + 0x6d, 0x61, 0x72, 0x74, 0x69, 0x61, 0x6e, 0x2f, 0x68, 0x32, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x73, + 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_test_service_proto_rawDescOnce sync.Once + file_test_service_proto_rawDescData = file_test_service_proto_rawDesc +) + +func file_test_service_proto_rawDescGZIP() []byte { + file_test_service_proto_rawDescOnce.Do(func() { + file_test_service_proto_rawDescData = protoimpl.X.CompressGZIP(file_test_service_proto_rawDescData) + }) + return file_test_service_proto_rawDescData +} + +var file_test_service_proto_msgTypes = make([]protoimpl.MessageInfo, 4) +var file_test_service_proto_goTypes = []interface{}{ + (*EchoRequest)(nil), // 0: test_service.EchoRequest + (*EchoResponse)(nil), // 1: test_service.EchoResponse + (*SumRequest)(nil), // 2: test_service.SumRequest + (*SumResponse)(nil), // 3: test_service.SumResponse +} +var file_test_service_proto_depIdxs = []int32{ + 0, // 0: test_service.TestService.Echo:input_type -> test_service.EchoRequest + 2, // 1: test_service.TestService.Sum:input_type -> test_service.SumRequest + 0, // 2: test_service.TestService.DoubleEcho:input_type -> test_service.EchoRequest + 1, // 3: test_service.TestService.Echo:output_type -> test_service.EchoResponse + 3, // 4: test_service.TestService.Sum:output_type -> test_service.SumResponse + 1, // 5: test_service.TestService.DoubleEcho:output_type -> test_service.EchoResponse + 3, // [3:6] is the sub-list for method output_type + 0, // [0:3] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_test_service_proto_init() } +func file_test_service_proto_init() { + if File_test_service_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_test_service_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*EchoRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_test_service_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*EchoResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_test_service_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SumRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_test_service_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SumResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_test_service_proto_rawDesc, + NumEnums: 0, + NumMessages: 4, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_test_service_proto_goTypes, + DependencyIndexes: file_test_service_proto_depIdxs, + MessageInfos: file_test_service_proto_msgTypes, + }.Build() + File_test_service_proto = out.File + file_test_service_proto_rawDesc = nil + file_test_service_proto_goTypes = nil + file_test_service_proto_depIdxs = nil +} diff --git a/h2/testservice/test_service.proto b/h2/testservice/test_service.proto new file mode 100644 index 000000000..ab4fdaee7 --- /dev/null +++ b/h2/testservice/test_service.proto @@ -0,0 +1,46 @@ +// Copyright 2021 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package test_service; + +option go_package = "github.com/google/martian/h2/testservice"; + +message EchoRequest { + string payload = 1; +} + +message EchoResponse { + string payload = 1; +} + +message SumRequest { + repeated int32 values = 1; +} + +message SumResponse { + int32 value = 1; +} + +service TestService { + // The server returns the client message as-is. + rpc Echo(EchoRequest) returns (EchoResponse) {} + + // The server returns the sum of the input values. + rpc Sum(SumRequest) returns (SumResponse) {} + + // The server returns every message twice. + rpc DoubleEcho(stream EchoRequest) returns (stream EchoResponse) {} +} diff --git a/h2/testservice/test_service_grpc.pb.go b/h2/testservice/test_service_grpc.pb.go new file mode 100644 index 000000000..4f10b5ce3 --- /dev/null +++ b/h2/testservice/test_service_grpc.pb.go @@ -0,0 +1,212 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. + +package testservice + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.32.0 or later. +const _ = grpc.SupportPackageIsVersion7 + +// TestServiceClient is the client API for TestService service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type TestServiceClient interface { + // The server returns the client message as-is. + Echo(ctx context.Context, in *EchoRequest, opts ...grpc.CallOption) (*EchoResponse, error) + // The server returns the sum of the input values. + Sum(ctx context.Context, in *SumRequest, opts ...grpc.CallOption) (*SumResponse, error) + // The server returns every message twice. + DoubleEcho(ctx context.Context, opts ...grpc.CallOption) (TestService_DoubleEchoClient, error) +} + +type testServiceClient struct { + cc grpc.ClientConnInterface +} + +func NewTestServiceClient(cc grpc.ClientConnInterface) TestServiceClient { + return &testServiceClient{cc} +} + +func (c *testServiceClient) Echo(ctx context.Context, in *EchoRequest, opts ...grpc.CallOption) (*EchoResponse, error) { + out := new(EchoResponse) + err := c.cc.Invoke(ctx, "/test_service.TestService/Echo", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *testServiceClient) Sum(ctx context.Context, in *SumRequest, opts ...grpc.CallOption) (*SumResponse, error) { + out := new(SumResponse) + err := c.cc.Invoke(ctx, "/test_service.TestService/Sum", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *testServiceClient) DoubleEcho(ctx context.Context, opts ...grpc.CallOption) (TestService_DoubleEchoClient, error) { + stream, err := c.cc.NewStream(ctx, &TestService_ServiceDesc.Streams[0], "/test_service.TestService/DoubleEcho", opts...) + if err != nil { + return nil, err + } + x := &testServiceDoubleEchoClient{stream} + return x, nil +} + +type TestService_DoubleEchoClient interface { + Send(*EchoRequest) error + Recv() (*EchoResponse, error) + grpc.ClientStream +} + +type testServiceDoubleEchoClient struct { + grpc.ClientStream +} + +func (x *testServiceDoubleEchoClient) Send(m *EchoRequest) error { + return x.ClientStream.SendMsg(m) +} + +func (x *testServiceDoubleEchoClient) Recv() (*EchoResponse, error) { + m := new(EchoResponse) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +// TestServiceServer is the server API for TestService service. +// All implementations must embed UnimplementedTestServiceServer +// for forward compatibility +type TestServiceServer interface { + // The server returns the client message as-is. + Echo(context.Context, *EchoRequest) (*EchoResponse, error) + // The server returns the sum of the input values. + Sum(context.Context, *SumRequest) (*SumResponse, error) + // The server returns every message twice. + DoubleEcho(TestService_DoubleEchoServer) error + mustEmbedUnimplementedTestServiceServer() +} + +// UnimplementedTestServiceServer must be embedded to have forward compatible implementations. +type UnimplementedTestServiceServer struct { +} + +func (UnimplementedTestServiceServer) Echo(context.Context, *EchoRequest) (*EchoResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method Echo not implemented") +} +func (UnimplementedTestServiceServer) Sum(context.Context, *SumRequest) (*SumResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method Sum not implemented") +} +func (UnimplementedTestServiceServer) DoubleEcho(TestService_DoubleEchoServer) error { + return status.Errorf(codes.Unimplemented, "method DoubleEcho not implemented") +} +func (UnimplementedTestServiceServer) mustEmbedUnimplementedTestServiceServer() {} + +// UnsafeTestServiceServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to TestServiceServer will +// result in compilation errors. +type UnsafeTestServiceServer interface { + mustEmbedUnimplementedTestServiceServer() +} + +func RegisterTestServiceServer(s grpc.ServiceRegistrar, srv TestServiceServer) { + s.RegisterService(&TestService_ServiceDesc, srv) +} + +func _TestService_Echo_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(EchoRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(TestServiceServer).Echo(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/test_service.TestService/Echo", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(TestServiceServer).Echo(ctx, req.(*EchoRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _TestService_Sum_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(SumRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(TestServiceServer).Sum(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/test_service.TestService/Sum", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(TestServiceServer).Sum(ctx, req.(*SumRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _TestService_DoubleEcho_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(TestServiceServer).DoubleEcho(&testServiceDoubleEchoServer{stream}) +} + +type TestService_DoubleEchoServer interface { + Send(*EchoResponse) error + Recv() (*EchoRequest, error) + grpc.ServerStream +} + +type testServiceDoubleEchoServer struct { + grpc.ServerStream +} + +func (x *testServiceDoubleEchoServer) Send(m *EchoResponse) error { + return x.ServerStream.SendMsg(m) +} + +func (x *testServiceDoubleEchoServer) Recv() (*EchoRequest, error) { + m := new(EchoRequest) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +// TestService_ServiceDesc is the grpc.ServiceDesc for TestService service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var TestService_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "test_service.TestService", + HandlerType: (*TestServiceServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Echo", + Handler: _TestService_Echo_Handler, + }, + { + MethodName: "Sum", + Handler: _TestService_Sum_Handler, + }, + }, + Streams: []grpc.StreamDesc{ + { + StreamName: "DoubleEcho", + Handler: _TestService_DoubleEcho_Handler, + ServerStreams: true, + ClientStreams: true, + }, + }, + Metadata: "test_service.proto", +} diff --git a/mitm/mitm.go b/mitm/mitm.go index 29d113bf7..83ee60f5f 100644 --- a/mitm/mitm.go +++ b/mitm/mitm.go @@ -32,6 +32,7 @@ import ( "sync" "time" + "github.com/google/martian/v3/h2" "github.com/google/martian/v3/log" ) @@ -49,6 +50,7 @@ type Config struct { keyID []byte validity time.Duration org string + h2Config *h2.Config getCertificate func(*tls.ClientHelloInfo) (*tls.Certificate, error) roots *x509.CertPool skipVerify bool @@ -164,6 +166,16 @@ func (c *Config) SetOrganization(org string) { c.org = org } +// SetH2Config configures processing of HTTP/2 streams. +func (c *Config) SetH2Config(h2Config *h2.Config) { + c.h2Config = h2Config +} + +// H2Config returns the current HTTP/2 configuration. +func (c *Config) H2Config() *h2.Config { + return c.h2Config +} + // SetHandshakeErrorCallback sets the handshakeErrorCallback function. func (c *Config) SetHandshakeErrorCallback(cb func(*http.Request, error)) { c.handshakeErrorCallback = cb @@ -197,6 +209,10 @@ func (c *Config) TLS() *tls.Config { // TLSForHost returns a *tls.Config that will generate certificates on-the-fly // using SNI from the connection, or fall back to the provided hostname. func (c *Config) TLSForHost(hostname string) *tls.Config { + nextProtos := []string{"http/1.1"} + if c.h2AllowedHost(hostname) { + nextProtos = []string{"h2", "http/1.1"} + } return &tls.Config{ InsecureSkipVerify: c.skipVerify, GetCertificate: func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { @@ -207,10 +223,16 @@ func (c *Config) TLSForHost(hostname string) *tls.Config { return c.cert(host) }, - NextProtos: []string{"http/1.1"}, + NextProtos: nextProtos, } } +func (c *Config) h2AllowedHost(host string) bool { + return c.h2Config != nil && + c.h2Config.AllowedHostsFilter != nil && + c.h2Config.AllowedHostsFilter(host) +} + func (c *Config) cert(hostname string) (*tls.Certificate, error) { // Remove the port if it exists. host, _, err := net.SplitHostPort(hostname) diff --git a/proxy.go b/proxy.go index dd3fa2768..650493b55 100644 --- a/proxy.go +++ b/proxy.go @@ -343,6 +343,9 @@ func (p *Proxy) handleConnectRequest(ctx *Context, req *http.Request, session *S p.mitm.HandshakeErrorCallback(req, err) return err } + if tlsconn.ConnectionState().NegotiatedProtocol == "h2" { + return p.mitm.H2Config().Proxy(p.closing, tlsconn, req.URL) + } var nconn net.Conn nconn = tlsconn