diff --git a/v2/internal/api.go b/v2/internal/api.go index 57ae049e..44314a2b 100644 --- a/v2/internal/api.go +++ b/v2/internal/api.go @@ -258,9 +258,8 @@ func WithContext(parent context.Context, req *http.Request) context.Context { } // RegisterTestRequest registers the HTTP request req for testing, such that -// any API calls are sent to the provided URL. It returns a closure to delete -// the registration. -// It should only be used by aetest package. +// any API calls are sent to the provided URL. +// It should only be used by test code or test helpers like aetest. func RegisterTestRequest(req *http.Request, apiURL *url.URL, appID string) *http.Request { ctx := req.Context() ctx = withAPIHostOverride(ctx, apiURL.Hostname()) diff --git a/v2/internal/api_test.go b/v2/internal/api_test.go index 50b48c9e..77071db5 100644 --- a/v2/internal/api_test.go +++ b/v2/internal/api_test.go @@ -141,51 +141,35 @@ func (f *fakeAPIHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { }) } -func setup() (f *fakeAPIHandler, c *aeContext, cleanup func()) { +func makeTestRequest(apiURL *url.URL) *http.Request { + req := &http.Request{ + Header: http.Header{ + ticketHeader: []string{"s3cr3t"}, + dapperHeader: []string{"trace-001"}, + }, + } + return RegisterTestRequest(req, apiURL, "") +} + +func setup() (f *fakeAPIHandler, r *http.Request, cleanup func()) { f = &fakeAPIHandler{} srv := httptest.NewServer(f) - u, err := url.Parse(srv.URL + apiPath) - restoreAPIHost := restoreEnvVar("API_HOST") - restoreAPIPort := restoreEnvVar("API_HOST") - os.Setenv("API_HOST", u.Hostname()) - os.Setenv("API_PORT", u.Port()) + apiURL, err := url.Parse(srv.URL + apiPath) if err != nil { panic(fmt.Sprintf("url.Parse(%q): %v", srv.URL+apiPath, err)) } - return f, &aeContext{ - req: &http.Request{ - Header: http.Header{ - ticketHeader: []string{"s3cr3t"}, - dapperHeader: []string{"trace-001"}, - }, - }, - }, func() { - restoreAPIHost() - restoreAPIPort() - srv.Close() - } -} - -func restoreEnvVar(key string) (cleanup func()) { - oldval, ok := os.LookupEnv(key) - return func() { - if ok { - os.Setenv(key, oldval) - } else { - os.Unsetenv(key) - } - } + return f, makeTestRequest(apiURL), srv.Close } func TestAPICall(t *testing.T) { - _, c, cleanup := setup() + _, r, cleanup := setup() defer cleanup() req := &basepb.StringProto{ Value: proto.String("Doctor Who"), } res := &basepb.StringProto{} - err := Call(toContext(c), "actordb", "LookupActor", req, res) + err := Call(r.Context(), "actordb", "LookupActor", req, res) if err != nil { t.Fatalf("API call failed: %v", err) } @@ -195,18 +179,16 @@ func TestAPICall(t *testing.T) { } func TestAPICallTicketUnavailable(t *testing.T) { - resetEnv := SetTestEnv() - defer resetEnv() - f, c, cleanup := setup() + f, r, cleanup := setup() defer cleanup() f.allowMissingTicket = true - c.req.Header.Set(ticketHeader, "") + r.Header.Set(ticketHeader, "") req := &basepb.StringProto{ Value: proto.String("Doctor Who"), } res := &basepb.StringProto{} - err := Call(toContext(c), "actordb", "LookupActor", req, res) + err := Call(r.Context(), "actordb", "LookupActor", req, res) if err != nil { t.Fatalf("API call failed: %v", err) } @@ -216,7 +198,7 @@ func TestAPICallTicketUnavailable(t *testing.T) { } func TestAPICallRPCFailure(t *testing.T) { - f, c, cleanup := setup() + f, r, cleanup := setup() defer cleanup() testCases := []struct { @@ -230,7 +212,7 @@ func TestAPICallRPCFailure(t *testing.T) { } f.hang = make(chan int) // only for RunSlowly for _, tc := range testCases { - ctx, _ := context.WithTimeout(toContext(c), 100*time.Millisecond) + ctx, _ := context.WithTimeout(r.Context(), 100*time.Millisecond) err := Call(ctx, "errors", tc.method, &basepb.VoidProto{}, &basepb.VoidProto{}) ce, ok := err.(*CallError) if !ok { @@ -247,9 +229,7 @@ func TestAPICallRPCFailure(t *testing.T) { } func TestAPICallDialFailure(t *testing.T) { - // See what happens if the API host is unresponsive. - // This should time out quickly, not hang forever. - // We intentially don't set up the fakeAPIHandler for this test to cause the dail failure. + // we intentially don't set up the fakeAPIHandler for this test to cause the dail failure start := time.Now() err := Call(context.Background(), "foo", "bar", &basepb.VoidProto{}, &basepb.VoidProto{}) const max = 1 * time.Second @@ -323,16 +303,10 @@ func TestAPICallAllocations(t *testing.T) { } // Run the test API server in a subprocess so we aren't counting its allocations. - cleanup := launchHelperProcess(t) + apiURL, cleanup := launchHelperProcess(t) defer cleanup() - c := &aeContext{ - req: &http.Request{ - Header: http.Header{ - ticketHeader: []string{"s3cr3t"}, - dapperHeader: []string{"trace-001"}, - }, - }, - } + + r := makeTestRequest(apiURL) req := &basepb.StringProto{ Value: proto.String("Doctor Who"), @@ -340,7 +314,7 @@ func TestAPICallAllocations(t *testing.T) { res := &basepb.StringProto{} var apiErr error avg := testing.AllocsPerRun(100, func() { - ctx, _ := context.WithTimeout(toContext(c), 100*time.Millisecond) + ctx, _ := context.WithTimeout(r.Context(), 100*time.Millisecond) if err := Call(ctx, "actordb", "LookupActor", req, res); err != nil && apiErr == nil { apiErr = err // get the first error only } @@ -356,7 +330,7 @@ func TestAPICallAllocations(t *testing.T) { } } -func launchHelperProcess(t *testing.T) (cleanup func()) { +func launchHelperProcess(t *testing.T) (apiURL *url.URL, cleanup func()) { cmd := exec.Command(os.Args[0], "-test.run=TestHelperProcess") cmd.Env = []string{"GO_WANT_HELPER_PROCESS=1"} stdin, err := cmd.StdinPipe() @@ -391,13 +365,7 @@ func launchHelperProcess(t *testing.T) (cleanup func()) { t.Fatal("Helper process never reported") } - restoreAPIHost := restoreEnvVar("API_HOST") - restoreAPIPort := restoreEnvVar("API_HOST") - os.Setenv("API_HOST", u.Hostname()) - os.Setenv("API_PORT", u.Port()) - return func() { - restoreAPIHost() - restoreAPIPort() + return u, func() { stdin.Close() if err := cmd.Wait(); err != nil { t.Errorf("Helper process did not exit cleanly: %v", err) diff --git a/v2/internal/net_test.go b/v2/internal/net_test.go index 7d1c2e7d..70ba0e64 100644 --- a/v2/internal/net_test.go +++ b/v2/internal/net_test.go @@ -26,7 +26,7 @@ func TestDialLimit(t *testing.T) { } }() - f, c, cleanup := setup() // setup is in api_test.go + f, r, cleanup := setup() // setup is in api_test.go defer cleanup() f.hang = make(chan int) @@ -37,12 +37,12 @@ func TestDialLimit(t *testing.T) { for i := 0; i < 2; i++ { go func() { defer wg.Done() - Call(toContext(c), "errors", "RunSlowly", &basepb.VoidProto{}, &basepb.VoidProto{}) + Call(r.Context(), "errors", "RunSlowly", &basepb.VoidProto{}, &basepb.VoidProto{}) }() } time.Sleep(50 * time.Millisecond) // let those two RPCs start - ctx, _ := context.WithTimeout(toContext(c), 50*time.Millisecond) + ctx, _ := context.WithTimeout(r.Context(), 50*time.Millisecond) err := Call(ctx, "errors", "Non200", &basepb.VoidProto{}, &basepb.VoidProto{}) if err != errTimeout { t.Errorf("Non200 RPC returned with err %v, want errTimeout", err)