Skip to content

Commit

Permalink
http2: allow testing Transports with testSyncHooks
Browse files Browse the repository at this point in the history
Change-Id: Icafc4860ef0691e5133221a0b53bb1d2158346cc
Reviewed-on: https://go-review.googlesource.com/c/net/+/572378
Reviewed-by: Jonathan Amsterdam <jba@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
  • Loading branch information
neild committed Mar 19, 2024
1 parent 9e0498d commit 6e2c99c
Show file tree
Hide file tree
Showing 3 changed files with 288 additions and 250 deletions.
202 changes: 158 additions & 44 deletions http2/clientconn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,62 +99,57 @@ type testClientConn struct {

roundtrips []*testRoundTrip

rerr error // returned by Read
rbuf bytes.Buffer // sent to the test conn
wbuf bytes.Buffer // sent by the test conn
rerr error // returned by Read
netConnClosed bool // set when the ClientConn closes the net.Conn
rbuf bytes.Buffer // sent to the test conn
wbuf bytes.Buffer // sent by the test conn
}

func newTestClientConn(t *testing.T, opts ...func(*Transport)) *testClientConn {
t.Helper()

tr := &Transport{}
for _, o := range opts {
o(tr)
}

func newTestClientConnFromClientConn(t *testing.T, cc *ClientConn) *testClientConn {
tc := &testClientConn{
t: t,
tr: tr,
hooks: newTestSyncHooks(),
tr: cc.t,
cc: cc,
hooks: cc.t.syncHooks,
}
cc.tconn = (*testClientConnNetConn)(tc)
tc.enc = hpack.NewEncoder(&tc.encbuf)
tc.fr = NewFramer(&tc.rbuf, &tc.wbuf)
tc.fr.ReadMetaHeaders = hpack.NewDecoder(initialHeaderTableSize, nil)
tc.fr.SetMaxReadFrameSize(10 << 20)

t.Cleanup(func() {
tc.sync()
if tc.rerr == nil {
tc.rerr = io.EOF
}
tc.sync()
if tc.hooks.total != 0 {
t.Errorf("%v goroutines still running after test completed", tc.hooks.total)
}

})
return tc
}

tc.hooks.newclientconn = func(cc *ClientConn) {
tc.cc = cc
}
const singleUse = false
_, err := tc.tr.newClientConn((*testClientConnNetConn)(tc), singleUse, tc.hooks)
if err != nil {
t.Fatal(err)
}
tc.sync()
tc.hooks.newclientconn = nil

func (tc *testClientConn) readClientPreface() {
tc.t.Helper()
// Read the client's HTTP/2 preface, sent prior to any HTTP/2 frames.
buf := make([]byte, len(clientPreface))
if _, err := io.ReadFull(&tc.wbuf, buf); err != nil {
t.Fatalf("reading preface: %v", err)
tc.t.Fatalf("reading preface: %v", err)
}
if !bytes.Equal(buf, clientPreface) {
t.Fatalf("client preface: %q, want %q", buf, clientPreface)
tc.t.Fatalf("client preface: %q, want %q", buf, clientPreface)
}
}

return tc
func newTestClientConn(t *testing.T, opts ...func(*Transport)) *testClientConn {
t.Helper()

tt := newTestTransport(t, opts...)
const singleUse = false
_, err := tt.tr.newClientConn(nil, singleUse, tt.tr.syncHooks)
if err != nil {
t.Fatalf("newClientConn: %v", err)
}

return tt.getConn()
}

// sync waits for the ClientConn under test to reach a stable state,
Expand Down Expand Up @@ -349,7 +344,7 @@ func (b *testRequestBody) closeWithError(err error) {
// the request times out, or some other terminal condition is reached.)
func (tc *testClientConn) roundTrip(req *http.Request) *testRoundTrip {
rt := &testRoundTrip{
tc: tc,
t: tc.t,
donec: make(chan struct{}),
}
tc.roundtrips = append(tc.roundtrips, rt)
Expand All @@ -362,6 +357,9 @@ func (tc *testClientConn) roundTrip(req *http.Request) *testRoundTrip {
tc.hooks.newstream = nil

tc.t.Cleanup(func() {
if !rt.done() {
return
}
res, _ := rt.result()
if res != nil {
res.Body.Close()
Expand Down Expand Up @@ -460,6 +458,14 @@ func (tc *testClientConn) writeContinuation(streamID uint32, endHeaders bool, he
tc.sync()
}

func (tc *testClientConn) writeRSTStream(streamID uint32, code ErrCode) {
tc.t.Helper()
if err := tc.fr.WriteRSTStream(streamID, code); err != nil {
tc.t.Fatal(err)
}
tc.sync()
}

func (tc *testClientConn) writePing(ack bool, data [8]byte) {
tc.t.Helper()
if err := tc.fr.WritePing(ack, data); err != nil {
Expand Down Expand Up @@ -491,9 +497,25 @@ func (tc *testClientConn) closeWrite(err error) {
tc.sync()
}

// inflowWindow returns the amount of inbound flow control available for a stream,
// or for the connection if streamID is 0.
func (tc *testClientConn) inflowWindow(streamID uint32) int32 {
tc.cc.mu.Lock()
defer tc.cc.mu.Unlock()
if streamID == 0 {
return tc.cc.inflow.avail + tc.cc.inflow.unsent
}
cs := tc.cc.streams[streamID]
if cs == nil {
tc.t.Errorf("no stream with id %v", streamID)
return -1
}
return cs.inflow.avail + cs.inflow.unsent
}

// testRoundTrip manages a RoundTrip in progress.
type testRoundTrip struct {
tc *testClientConn
t *testing.T
resp *http.Response
respErr error
donec chan struct{}
Expand All @@ -502,6 +524,9 @@ type testRoundTrip struct {

// streamID returns the HTTP/2 stream ID of the request.
func (rt *testRoundTrip) streamID() uint32 {
if rt.cs == nil {
panic("stream ID unknown")
}
return rt.cs.ID
}

Expand All @@ -517,20 +542,20 @@ func (rt *testRoundTrip) done() bool {

// result returns the result of the RoundTrip.
func (rt *testRoundTrip) result() (*http.Response, error) {
t := rt.tc.t
t := rt.t
t.Helper()
select {
case <-rt.donec:
default:
t.Fatalf("RoundTrip (stream %v) is not done; want it to be", rt.streamID())
t.Fatalf("RoundTrip is not done; want it to be")
}
return rt.resp, rt.respErr
}

// response returns the response of a successful RoundTrip.
// If the RoundTrip unexpectedly failed, it calls t.Fatal.
func (rt *testRoundTrip) response() *http.Response {
t := rt.tc.t
t := rt.t
t.Helper()
resp, err := rt.result()
if err != nil {
Expand All @@ -544,15 +569,15 @@ func (rt *testRoundTrip) response() *http.Response {

// err returns the (possibly nil) error result of RoundTrip.
func (rt *testRoundTrip) err() error {
t := rt.tc.t
t := rt.t
t.Helper()
_, err := rt.result()
return err
}

// wantStatus indicates the expected response StatusCode.
func (rt *testRoundTrip) wantStatus(want int) {
t := rt.tc.t
t := rt.t
t.Helper()
if got := rt.response().StatusCode; got != want {
t.Fatalf("got response status %v, want %v", got, want)
Expand All @@ -561,15 +586,15 @@ func (rt *testRoundTrip) wantStatus(want int) {

// body reads the contents of the response body.
func (rt *testRoundTrip) readBody() ([]byte, error) {
t := rt.tc.t
t := rt.t
t.Helper()
return io.ReadAll(rt.response().Body)
}

// wantBody indicates the expected response body.
// (Note that this consumes the body.)
func (rt *testRoundTrip) wantBody(want []byte) {
t := rt.tc.t
t := rt.t
t.Helper()
got, err := rt.readBody()
if err != nil {
Expand All @@ -582,7 +607,7 @@ func (rt *testRoundTrip) wantBody(want []byte) {

// wantHeaders indicates the expected response headers.
func (rt *testRoundTrip) wantHeaders(want http.Header) {
t := rt.tc.t
t := rt.t
t.Helper()
res := rt.response()
if diff := diffHeaders(res.Header, want); diff != "" {
Expand All @@ -592,7 +617,7 @@ func (rt *testRoundTrip) wantHeaders(want http.Header) {

// wantTrailers indicates the expected response trailers.
func (rt *testRoundTrip) wantTrailers(want http.Header) {
t := rt.tc.t
t := rt.t
t.Helper()
res := rt.response()
if diff := diffHeaders(res.Trailer, want); diff != "" {
Expand Down Expand Up @@ -630,7 +655,8 @@ func (nc *testClientConnNetConn) Write(b []byte) (n int, err error) {
return nc.wbuf.Write(b)
}

func (*testClientConnNetConn) Close() error {
func (nc *testClientConnNetConn) Close() error {
nc.netConnClosed = true
return nil
}

Expand All @@ -639,3 +665,91 @@ func (*testClientConnNetConn) RemoteAddr() (_ net.Addr) { return }
func (*testClientConnNetConn) SetDeadline(t time.Time) error { return nil }
func (*testClientConnNetConn) SetReadDeadline(t time.Time) error { return nil }
func (*testClientConnNetConn) SetWriteDeadline(t time.Time) error { return nil }

// A testTransport allows testing Transport.RoundTrip against fake servers.
// Tests that aren't specifically exercising RoundTrip's retry loop or connection pooling
// should use testClientConn instead.
type testTransport struct {
t *testing.T
tr *Transport

ccs []*testClientConn
}

func newTestTransport(t *testing.T, opts ...func(*Transport)) *testTransport {
tr := &Transport{
syncHooks: newTestSyncHooks(),
}
for _, o := range opts {
o(tr)
}

tt := &testTransport{
t: t,
tr: tr,
}
tr.syncHooks.newclientconn = func(cc *ClientConn) {
tt.ccs = append(tt.ccs, newTestClientConnFromClientConn(t, cc))
}

t.Cleanup(func() {
tt.sync()
if len(tt.ccs) > 0 {
t.Fatalf("%v test ClientConns created, but not examined by test", len(tt.ccs))
}
if tt.tr.syncHooks.total != 0 {
t.Errorf("%v goroutines still running after test completed", tt.tr.syncHooks.total)
}
})

return tt
}

func (tt *testTransport) sync() {
tt.tr.syncHooks.waitInactive()
}

func (tt *testTransport) advance(d time.Duration) {
tt.tr.syncHooks.advance(d)
tt.sync()
}

func (tt *testTransport) hasConn() bool {
return len(tt.ccs) > 0
}

func (tt *testTransport) getConn() *testClientConn {
tt.t.Helper()
if len(tt.ccs) == 0 {
tt.t.Fatalf("no new ClientConns created; wanted one")
}
tc := tt.ccs[0]
tt.ccs = tt.ccs[1:]
tc.sync()
tc.readClientPreface()
return tc
}

func (tt *testTransport) roundTrip(req *http.Request) *testRoundTrip {
rt := &testRoundTrip{
t: tt.t,
donec: make(chan struct{}),
}
tt.tr.syncHooks.goRun(func() {
defer close(rt.donec)
rt.resp, rt.respErr = tt.tr.RoundTrip(req)
})
tt.sync()

tt.t.Cleanup(func() {
if !rt.done() {
return
}
res, _ := rt.result()
if res != nil {
res.Body.Close()
}
})

return rt
}

0 comments on commit 6e2c99c

Please sign in to comment.