diff --git a/aetest/instance_vm.go b/aetest/instance_vm.go index 48aa5d48..e1647fdb 100644 --- a/aetest/instance_vm.go +++ b/aetest/instance_vm.go @@ -18,7 +18,6 @@ import ( "regexp" "time" - "golang.org/x/net/context" "google.golang.org/appengine/internal" ) @@ -61,7 +60,6 @@ type instance struct { appDir string appID string startupTimeout time.Duration - relFuncs []func() // funcs to release any associated contexts } // NewRequest returns an *http.Request associated with this instance. @@ -72,21 +70,11 @@ func (i *instance) NewRequest(method, urlStr string, body io.Reader) (*http.Requ } // Associate this request. - req, release := internal.RegisterTestRequest(req, i.apiURL, func(ctx context.Context) context.Context { - ctx = internal.WithAppIDOverride(ctx, "dev~"+i.appID) - return ctx - }) - i.relFuncs = append(i.relFuncs, release) - - return req, nil + return internal.RegisterTestRequest(req, i.apiURL, "dev~"+i.appID), nil } // Close kills the child api_server.py process, releasing its resources. func (i *instance) Close() (err error) { - for _, rel := range i.relFuncs { - rel() - } - i.relFuncs = nil child := i.child if child == nil { return nil diff --git a/appengine_vm.go b/appengine_vm.go index 78d2575a..6e1d041c 100644 --- a/appengine_vm.go +++ b/appengine_vm.go @@ -8,12 +8,13 @@ package appengine import ( - "golang.org/x/net/context" - - "google.golang.org/appengine/internal" + "context" ) // BackgroundContext returns a context not associated with a request. +// +// Deprecated: App Engine no longer has a special background context. +// Just use context.Background(). func BackgroundContext() context.Context { - return internal.BackgroundContext() + return context.Background() } diff --git a/internal/api.go b/internal/api.go index 750a0077..2339da39 100644 --- a/internal/api.go +++ b/internal/api.go @@ -24,8 +24,9 @@ import ( "sync/atomic" "time" + netcontext "context" + "github.com/golang/protobuf/proto" - netcontext "golang.org/x/net/context" basepb "google.golang.org/appengine/internal/base" logpb "google.golang.org/appengine/internal/log" @@ -33,8 +34,7 @@ import ( ) const ( - apiPath = "/rpc_http" - defaultTicketSuffix = "/default.20150612t184001.0" + apiPath = "/rpc_http" ) var ( @@ -66,21 +66,22 @@ var ( IdleConnTimeout: 90 * time.Second, }, } - - defaultTicketOnce sync.Once - defaultTicket string - backgroundContextOnce sync.Once - backgroundContext netcontext.Context ) -func apiURL() *url.URL { +func apiURL(ctx netcontext.Context) *url.URL { host, port := "appengine.googleapis.internal", "10001" if h := os.Getenv("API_HOST"); h != "" { host = h } + if hostOverride := ctx.Value(apiHostOverrideKey); hostOverride != nil { + host = hostOverride.(string) + } if p := os.Getenv("API_PORT"); p != "" { port = p } + if portOverride := ctx.Value(apiPortOverrideKey); portOverride != nil { + port = portOverride.(string) + } return &url.URL{ Scheme: "http", Host: host + ":" + port, @@ -98,7 +99,6 @@ func handleHTTPMiddleware(next http.Handler) http.Handler { c := &context{ req: r, outHeader: w.Header(), - apiURL: apiURL(), } r = r.WithContext(withContext(r.Context(), c)) c.req = r @@ -235,8 +235,6 @@ type context struct { lines []*logpb.UserAppLogLine flushes int } - - apiURL *url.URL } var contextKey = "holds a *context" @@ -304,59 +302,19 @@ func WithContext(parent netcontext.Context, req *http.Request) netcontext.Contex } } -// DefaultTicket returns a ticket used for background context or dev_appserver. -func DefaultTicket() string { - defaultTicketOnce.Do(func() { - if IsDevAppServer() { - defaultTicket = "testapp" + defaultTicketSuffix - return - } - appID := partitionlessAppID() - escAppID := strings.Replace(strings.Replace(appID, ":", "_", -1), ".", "_", -1) - majVersion := VersionID(nil) - if i := strings.Index(majVersion, "."); i > 0 { - majVersion = majVersion[:i] - } - defaultTicket = fmt.Sprintf("%s/%s.%s.%s", escAppID, ModuleName(nil), majVersion, InstanceID()) - }) - return defaultTicket -} - -func BackgroundContext() netcontext.Context { - backgroundContextOnce.Do(func() { - // Compute background security ticket. - ticket := DefaultTicket() - - c := &context{ - req: &http.Request{ - Header: http.Header{ - ticketHeader: []string{ticket}, - }, - }, - apiURL: apiURL(), - } - backgroundContext = toContext(c) - - // TODO(dsymonds): Wire up the shutdown handler to do a final flush. - go c.logFlusher(make(chan int)) - }) - - return backgroundContext -} - // RegisterTestRequest registers the HTTP request req for testing, such that -// any API calls are sent to the provided URL. It returns a closure to delete -// the registration. +// any API calls are sent to the provided URL. // It should only be used by aetest package. -func RegisterTestRequest(req *http.Request, apiURL *url.URL, decorate func(netcontext.Context) netcontext.Context) (*http.Request, func()) { - c := &context{ - req: req, - apiURL: apiURL, - } - ctx := withContext(decorate(req.Context()), c) - req = req.WithContext(ctx) - c.req = req - return req, func() {} +func RegisterTestRequest(req *http.Request, apiURL *url.URL, appID string) *http.Request { + ctx := req.Context() + ctx = withAPIHostOverride(ctx, apiURL.Hostname()) + ctx = withAPIPortOverride(ctx, apiURL.Port()) + ctx = WithAppIDOverride(ctx, appID) + + // use the unregistered request as a placeholder so that withContext can read the headers + c := &context{req: req} + c.req = req.WithContext(withContext(ctx, c)) + return c.req } var errTimeout = &CallError{ @@ -401,10 +359,11 @@ func (c *context) WriteHeader(code int) { c.outCode = code } -func (c *context) post(body []byte, timeout time.Duration) (b []byte, err error) { +func post(ctx netcontext.Context, body []byte, timeout time.Duration) (b []byte, err error) { + apiURL := apiURL(ctx) hreq := &http.Request{ Method: "POST", - URL: c.apiURL, + URL: apiURL, Header: http.Header{ apiEndpointHeader: apiEndpointHeaderValue, apiMethodHeader: apiMethodHeaderValue, @@ -413,13 +372,16 @@ func (c *context) post(body []byte, timeout time.Duration) (b []byte, err error) }, Body: ioutil.NopCloser(bytes.NewReader(body)), ContentLength: int64(len(body)), - Host: c.apiURL.Host, - } - if info := c.req.Header.Get(dapperHeader); info != "" { - hreq.Header.Set(dapperHeader, info) + Host: apiURL.Host, } - if info := c.req.Header.Get(traceHeader); info != "" { - hreq.Header.Set(traceHeader, info) + c := fromContext(ctx) + if c != nil { + if info := c.req.Header.Get(dapperHeader); info != "" { + hreq.Header.Set(dapperHeader, info) + } + if info := c.req.Header.Get(traceHeader); info != "" { + hreq.Header.Set(traceHeader, info) + } } tr := apiHTTPClient.Transport.(*http.Transport) @@ -480,10 +442,6 @@ func Call(ctx netcontext.Context, service, method string, in, out proto.Message) } c := fromContext(ctx) - if c == nil { - // Give a good error message rather than a panic lower down. - return errNotAppEngineContext - } // Apply transaction modifications if we're in a transaction. if t := transactionFromContext(ctx); t != nil { @@ -504,20 +462,13 @@ func Call(ctx netcontext.Context, service, method string, in, out proto.Message) return err } - ticket := c.req.Header.Get(ticketHeader) - // Use a test ticket under test environment. - if ticket == "" { - if appid := ctx.Value(&appIDOverrideKey); appid != nil { - ticket = appid.(string) + defaultTicketSuffix + ticket := "" + if c != nil { + ticket = c.req.Header.Get(ticketHeader) + if dri := c.req.Header.Get(devRequestIdHeader); IsDevAppServer() && dri != "" { + ticket = dri } } - // Fall back to use background ticket when the request ticket is not available in Flex or dev_appserver. - if ticket == "" { - ticket = DefaultTicket() - } - if dri := c.req.Header.Get(devRequestIdHeader); IsDevAppServer() && dri != "" { - ticket = dri - } req := &remotepb.Request{ ServiceName: &service, Method: &method, @@ -529,7 +480,7 @@ func Call(ctx netcontext.Context, service, method string, in, out proto.Message) return err } - hrespBody, err := c.post(hreqBody, timeout) + hrespBody, err := post(ctx, hreqBody, timeout) if err != nil { return err } diff --git a/internal/api_common.go b/internal/api_common.go index e0c0b214..f6101d3b 100644 --- a/internal/api_common.go +++ b/internal/api_common.go @@ -5,13 +5,19 @@ package internal import ( + netcontext "context" "errors" "os" "github.com/golang/protobuf/proto" - netcontext "golang.org/x/net/context" ) +type ctxKey string + +func (c ctxKey) String() string { + return "appengine context key: " + string(c) +} + var errNotAppEngineContext = errors.New("not an App Engine context") type CallOverrideFunc func(ctx netcontext.Context, service, method string, in, out proto.Message) error @@ -55,6 +61,18 @@ func WithAppIDOverride(ctx netcontext.Context, appID string) netcontext.Context return netcontext.WithValue(ctx, &appIDOverrideKey, appID) } +var apiHostOverrideKey = ctxKey("holds a string, being the alternate API_HOST") + +func withAPIHostOverride(ctx netcontext.Context, apiHost string) netcontext.Context { + return netcontext.WithValue(ctx, apiHostOverrideKey, apiHost) +} + +var apiPortOverrideKey = ctxKey("holds a string, being the alternate API_PORT") + +func withAPIPortOverride(ctx netcontext.Context, apiPort string) netcontext.Context { + return netcontext.WithValue(ctx, apiPortOverrideKey, apiPort) +} + var namespaceKey = "holds the namespace string" func withNamespace(ctx netcontext.Context, ns string) netcontext.Context { diff --git a/internal/api_test.go b/internal/api_test.go index c20f52f9..c1be17db 100644 --- a/internal/api_test.go +++ b/internal/api_test.go @@ -10,6 +10,7 @@ package internal import ( "bufio" "bytes" + netcontext "context" "fmt" "io" "io/ioutil" @@ -18,13 +19,13 @@ import ( "net/url" "os" "os/exec" + "runtime" "strings" "sync/atomic" "testing" "time" "github.com/golang/protobuf/proto" - netcontext "golang.org/x/net/context" basepb "google.golang.org/appengine/internal/base" remotepb "google.golang.org/appengine/internal/remote_api" @@ -41,6 +42,8 @@ type fakeAPIHandler struct { hang chan int // used for RunSlowly RPC LogFlushes int32 // atomic + + allowMissingTicket bool } func (f *fakeAPIHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -67,7 +70,7 @@ func (f *fakeAPIHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { http.Error(w, fmt.Sprintf("Bad encoded API request: %v", err), 500) return } - if *apiReq.RequestId != "s3cr3t" && *apiReq.RequestId != DefaultTicket() { + if *apiReq.RequestId != "s3cr3t" && !f.allowMissingTicket { writeResponse(&remotepb.Response{ RpcError: &remotepb.RpcError{ Code: proto.Int32(int32(remotepb.RpcError_SECURITY_VIOLATION)), @@ -147,18 +150,25 @@ func setup() (f *fakeAPIHandler, c *context, cleanup func()) { f = &fakeAPIHandler{} srv := httptest.NewServer(f) u, err := url.Parse(srv.URL + apiPath) + restoreAPIHost := restoreEnvVar("API_HOST") + restoreAPIPort := restoreEnvVar("API_HOST") + os.Setenv("API_HOST", u.Hostname()) + os.Setenv("API_PORT", u.Port()) if err != nil { panic(fmt.Sprintf("url.Parse(%q): %v", srv.URL+apiPath, err)) } return f, &context{ - req: &http.Request{ - Header: http.Header{ - ticketHeader: []string{"s3cr3t"}, - dapperHeader: []string{"trace-001"}, + req: &http.Request{ + Header: http.Header{ + ticketHeader: []string{"s3cr3t"}, + dapperHeader: []string{"trace-001"}, + }, }, - }, - apiURL: u, - }, srv.Close + }, func() { + restoreAPIHost() + restoreAPIPort() + srv.Close() + } } func restoreEnvVar(key string) (cleanup func()) { @@ -192,8 +202,9 @@ func TestAPICall(t *testing.T) { func TestAPICallTicketUnavailable(t *testing.T) { resetEnv := SetTestEnv() defer resetEnv() - _, c, cleanup := setup() + f, c, cleanup := setup() defer cleanup() + f.allowMissingTicket = true c.req.Header.Set(ticketHeader, "") req := &basepb.StringProto{ @@ -243,13 +254,9 @@ func TestAPICallRPCFailure(t *testing.T) { func TestAPICallDialFailure(t *testing.T) { // See what happens if the API host is unresponsive. // This should time out quickly, not hang forever. - _, c, cleanup := setup() - defer cleanup() - // Reset the URL to the production address so that dialing fails. - c.apiURL = apiURL() - + // We intentially don't set up the fakeAPIHandler for this test to cause the dail failure. start := time.Now() - err := Call(toContext(c), "foo", "bar", &basepb.VoidProto{}, &basepb.VoidProto{}) + err := Call(netcontext.Background(), "foo", "bar", &basepb.VoidProto{}, &basepb.VoidProto{}) const max = 1 * time.Second if taken := time.Since(start); taken > max { t.Errorf("Dial hang took too long: %v > %v", taken, max) @@ -282,7 +289,6 @@ func TestDelayedLogFlushing(t *testing.T) { http.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) { logC := WithContext(netcontext.Background(), r) - fromContext(logC).apiURL = c.apiURL // Otherwise it will try to use the default URL. Logf(logC, 1, "It's a lovely day.") w.WriteHeader(200) time.Sleep(1200 * time.Millisecond) @@ -344,7 +350,6 @@ func TestLogFlushing(t *testing.T) { path := "/quick_log_" + tc.logToLogservice http.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) { logC := WithContext(netcontext.Background(), r) - fromContext(logC).apiURL = c.apiURL // Otherwise it will try to use the default URL. Logf(logC, 1, "It's a lovely day.") w.WriteHeader(200) w.Write(make([]byte, 100<<10)) // write 100 KB to force HTTP flush @@ -435,7 +440,8 @@ func TestAPICallAllocations(t *testing.T) { } // Run the test API server in a subprocess so we aren't counting its allocations. - u, cleanup := launchHelperProcess(t) + cleanup := launchHelperProcess(t) + defer cleanup() c := &context{ req: &http.Request{ @@ -444,7 +450,6 @@ func TestAPICallAllocations(t *testing.T) { dapperHeader: []string{"trace-001"}, }, }, - apiURL: u, } req := &basepb.StringProto{ @@ -463,13 +468,18 @@ func TestAPICallAllocations(t *testing.T) { } // Lots of room for improvement... - const min, max float64 = 60, 86 + var min, max float64 = 60, 86 + if strings.HasPrefix(runtime.Version(), "go1.11.") || strings.HasPrefix(runtime.Version(), "go1.12.") { + // add a bit more overhead for versions before go1.13 + // see https://go.dev/doc/go1.13#compilers + max = 90 + } if avg < min || max < avg { t.Errorf("Allocations per API call = %g, want in [%g,%g]", avg, min, max) } } -func launchHelperProcess(t *testing.T) (apiURL *url.URL, cleanup func()) { +func launchHelperProcess(t *testing.T) (cleanup func()) { cmd := exec.Command(os.Args[0], "-test.run=TestHelperProcess") cmd.Env = []string{"GO_WANT_HELPER_PROCESS=1"} stdin, err := cmd.StdinPipe() @@ -504,7 +514,13 @@ func launchHelperProcess(t *testing.T) (apiURL *url.URL, cleanup func()) { t.Fatal("Helper process never reported") } - return u, func() { + restoreAPIHost := restoreEnvVar("API_HOST") + restoreAPIPort := restoreEnvVar("API_HOST") + os.Setenv("API_HOST", u.Hostname()) + os.Setenv("API_PORT", u.Port()) + return func() { + restoreAPIHost() + restoreAPIPort() stdin.Close() if err := cmd.Wait(); err != nil { t.Errorf("Helper process did not exit cleanly: %v", err) @@ -529,20 +545,3 @@ func TestHelperProcess(*testing.T) { // Wait for stdin to be closed. io.Copy(ioutil.Discard, os.Stdin) } - -func TestBackgroundContext(t *testing.T) { - resetEnv := SetTestEnv() - defer resetEnv() - - ctx, key := fromContext(BackgroundContext()), "X-Magic-Ticket-Header" - if g, w := ctx.req.Header.Get(key), "my-app-id/default.20150612t184001.0"; g != w { - t.Errorf("%v = %q, want %q", key, g, w) - } - - // Check that using the background context doesn't panic. - req := &basepb.StringProto{ - Value: proto.String("Doctor Who"), - } - res := &basepb.StringProto{} - Call(BackgroundContext(), "actordb", "LookupActor", req, res) // expected to fail -} diff --git a/v2/aetest/instance.go b/v2/aetest/instance.go index 3143a89c..985a52cc 100644 --- a/v2/aetest/instance.go +++ b/v2/aetest/instance.go @@ -110,7 +110,6 @@ type instance struct { appDir string appID string startupTimeout time.Duration - relFuncs []func() // funcs to release any associated contexts } // NewRequest returns an *http.Request associated with this instance. @@ -121,21 +120,11 @@ func (i *instance) NewRequest(method, urlStr string, body io.Reader) (*http.Requ } // Associate this request. - req, release := internal.RegisterTestRequest(req, i.apiURL, func(ctx context.Context) context.Context { - ctx = internal.WithAppIDOverride(ctx, "dev~"+i.appID) - return ctx - }) - i.relFuncs = append(i.relFuncs, release) - - return req, nil + return internal.RegisterTestRequest(req, i.apiURL, "dev~"+i.appID), nil } // Close kills the child api_server.py process, releasing its resources. func (i *instance) Close() (err error) { - for _, rel := range i.relFuncs { - rel() - } - i.relFuncs = nil child := i.child if child == nil { return nil diff --git a/v2/appengine.go b/v2/appengine.go index cabc4ac0..02e1f24e 100644 --- a/v2/appengine.go +++ b/v2/appengine.go @@ -135,6 +135,9 @@ func APICall(ctx context.Context, service, method string, in, out proto.Message) } // BackgroundContext returns a context not associated with a request. +// +// Deprecated: App Engine no longer has a special background context. +// Just use context.Background(). func BackgroundContext() context.Context { - return internal.BackgroundContext() + return context.Background() } diff --git a/v2/internal/api.go b/v2/internal/api.go index 96df8204..9bf67ad6 100644 --- a/v2/internal/api.go +++ b/v2/internal/api.go @@ -17,8 +17,6 @@ import ( "os" "runtime" "strconv" - "strings" - "sync" "sync/atomic" "time" @@ -28,8 +26,7 @@ import ( ) const ( - apiPath = "/rpc_http" - defaultTicketSuffix = "/default.20150612t184001.0" + apiPath = "/rpc_http" ) var ( @@ -62,23 +59,24 @@ var ( }, } - defaultTicketOnce sync.Once - defaultTicket string - backgroundContextOnce sync.Once - backgroundContext netcontext.Context - logStream io.Writer = os.Stderr // For test hooks. timeNow func() time.Time = time.Now // For test hooks. ) -func apiURL() *url.URL { +func apiURL(ctx netcontext.Context) *url.URL { host, port := "appengine.googleapis.internal", "10001" if h := os.Getenv("API_HOST"); h != "" { host = h } + if hostOverride := ctx.Value(apiHostOverrideKey); hostOverride != nil { + host = hostOverride.(string) + } if p := os.Getenv("API_PORT"); p != "" { port = p } + if portOverride := ctx.Value(apiPortOverrideKey); portOverride != nil { + port = portOverride.(string) + } return &url.URL{ Scheme: "http", Host: host + ":" + port, @@ -90,7 +88,6 @@ func handleHTTP(w http.ResponseWriter, r *http.Request) { c := &context{ req: r, outHeader: w.Header(), - apiURL: apiURL(), } r = r.WithContext(withContext(r.Context(), c)) c.req = r @@ -183,8 +180,6 @@ type context struct { outCode int outHeader http.Header outBody []byte - - apiURL *url.URL } var contextKey = "holds a *context" @@ -252,56 +247,20 @@ func WithContext(parent netcontext.Context, req *http.Request) netcontext.Contex } } -// DefaultTicket returns a ticket used for background context or dev_appserver. -func DefaultTicket() string { - defaultTicketOnce.Do(func() { - if IsDevAppServer() { - defaultTicket = "testapp" + defaultTicketSuffix - return - } - appID := partitionlessAppID() - escAppID := strings.Replace(strings.Replace(appID, ":", "_", -1), ".", "_", -1) - majVersion := VersionID(nil) - if i := strings.Index(majVersion, "."); i > 0 { - majVersion = majVersion[:i] - } - defaultTicket = fmt.Sprintf("%s/%s.%s.%s", escAppID, ModuleName(nil), majVersion, InstanceID()) - }) - return defaultTicket -} - -func BackgroundContext() netcontext.Context { - backgroundContextOnce.Do(func() { - // Compute background security ticket. - ticket := DefaultTicket() - - c := &context{ - req: &http.Request{ - Header: http.Header{ - ticketHeader: []string{ticket}, - }, - }, - apiURL: apiURL(), - } - backgroundContext = toContext(c) - }) - - return backgroundContext -} - // RegisterTestRequest registers the HTTP request req for testing, such that // any API calls are sent to the provided URL. It returns a closure to delete // the registration. // It should only be used by aetest package. -func RegisterTestRequest(req *http.Request, apiURL *url.URL, decorate func(netcontext.Context) netcontext.Context) (*http.Request, func()) { - c := &context{ - req: req, - apiURL: apiURL, - } - ctx := withContext(decorate(req.Context()), c) - req = req.WithContext(ctx) - c.req = req - return req, func() {} +func RegisterTestRequest(req *http.Request, apiURL *url.URL, appID string) *http.Request { + ctx := req.Context() + ctx = withAPIHostOverride(ctx, apiURL.Hostname()) + ctx = withAPIPortOverride(ctx, apiURL.Port()) + ctx = WithAppIDOverride(ctx, appID) + + // use the unregistered request as a placeholder so that withContext can read the headers + c := &context{req: req} + c.req = req.WithContext(withContext(ctx, c)) + return c.req } var errTimeout = &CallError{ @@ -346,10 +305,11 @@ func (c *context) WriteHeader(code int) { c.outCode = code } -func (c *context) post(body []byte, timeout time.Duration) (b []byte, err error) { +func post(ctx netcontext.Context, body []byte, timeout time.Duration) (b []byte, err error) { + apiURL := apiURL(ctx) hreq := &http.Request{ Method: "POST", - URL: c.apiURL, + URL: apiURL, Header: http.Header{ apiEndpointHeader: apiEndpointHeaderValue, apiMethodHeader: apiMethodHeaderValue, @@ -358,13 +318,16 @@ func (c *context) post(body []byte, timeout time.Duration) (b []byte, err error) }, Body: ioutil.NopCloser(bytes.NewReader(body)), ContentLength: int64(len(body)), - Host: c.apiURL.Host, - } - if info := c.req.Header.Get(dapperHeader); info != "" { - hreq.Header.Set(dapperHeader, info) + Host: apiURL.Host, } - if info := c.req.Header.Get(traceHeader); info != "" { - hreq.Header.Set(traceHeader, info) + c := fromContext(ctx) + if c != nil { + if info := c.req.Header.Get(dapperHeader); info != "" { + hreq.Header.Set(dapperHeader, info) + } + if info := c.req.Header.Get(traceHeader); info != "" { + hreq.Header.Set(traceHeader, info) + } } tr := apiHTTPClient.Transport.(*http.Transport) @@ -425,10 +388,6 @@ func Call(ctx netcontext.Context, service, method string, in, out proto.Message) } c := fromContext(ctx) - if c == nil { - // Give a good error message rather than a panic lower down. - return errNotAppEngineContext - } // Apply transaction modifications if we're in a transaction. if t := transactionFromContext(ctx); t != nil { @@ -449,20 +408,13 @@ func Call(ctx netcontext.Context, service, method string, in, out proto.Message) return err } - ticket := c.req.Header.Get(ticketHeader) - // Use a test ticket under test environment. - if ticket == "" { - if appid := ctx.Value(&appIDOverrideKey); appid != nil { - ticket = appid.(string) + defaultTicketSuffix + ticket := "" + if c != nil { + ticket = c.req.Header.Get(ticketHeader) + if dri := c.req.Header.Get(devRequestIdHeader); IsDevAppServer() && dri != "" { + ticket = dri } } - // Fall back to use background ticket when the request ticket is not available in Flex or dev_appserver. - if ticket == "" { - ticket = DefaultTicket() - } - if dri := c.req.Header.Get(devRequestIdHeader); IsDevAppServer() && dri != "" { - ticket = dri - } req := &remotepb.Request{ ServiceName: &service, Method: &method, @@ -474,7 +426,7 @@ func Call(ctx netcontext.Context, service, method string, in, out proto.Message) return err } - hrespBody, err := c.post(hreqBody, timeout) + hrespBody, err := post(ctx, hreqBody, timeout) if err != nil { return err } diff --git a/v2/internal/api_common.go b/v2/internal/api_common.go index a952aaa2..f6101d3b 100644 --- a/v2/internal/api_common.go +++ b/v2/internal/api_common.go @@ -12,6 +12,12 @@ import ( "github.com/golang/protobuf/proto" ) +type ctxKey string + +func (c ctxKey) String() string { + return "appengine context key: " + string(c) +} + var errNotAppEngineContext = errors.New("not an App Engine context") type CallOverrideFunc func(ctx netcontext.Context, service, method string, in, out proto.Message) error @@ -55,6 +61,18 @@ func WithAppIDOverride(ctx netcontext.Context, appID string) netcontext.Context return netcontext.WithValue(ctx, &appIDOverrideKey, appID) } +var apiHostOverrideKey = ctxKey("holds a string, being the alternate API_HOST") + +func withAPIHostOverride(ctx netcontext.Context, apiHost string) netcontext.Context { + return netcontext.WithValue(ctx, apiHostOverrideKey, apiHost) +} + +var apiPortOverrideKey = ctxKey("holds a string, being the alternate API_PORT") + +func withAPIPortOverride(ctx netcontext.Context, apiPort string) netcontext.Context { + return netcontext.WithValue(ctx, apiPortOverrideKey, apiPort) +} + var namespaceKey = "holds the namespace string" func withNamespace(ctx netcontext.Context, ns string) netcontext.Context { diff --git a/v2/internal/api_test.go b/v2/internal/api_test.go index c851f5a6..67785b35 100644 --- a/v2/internal/api_test.go +++ b/v2/internal/api_test.go @@ -37,6 +37,8 @@ type fakeAPIHandler struct { hang chan int // used for RunSlowly RPC LogFlushes int32 // atomic + + allowMissingTicket bool } func (f *fakeAPIHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -63,7 +65,7 @@ func (f *fakeAPIHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { http.Error(w, fmt.Sprintf("Bad encoded API request: %v", err), 500) return } - if *apiReq.RequestId != "s3cr3t" && *apiReq.RequestId != DefaultTicket() { + if *apiReq.RequestId != "s3cr3t" && !f.allowMissingTicket { writeResponse(&remotepb.Response{ RpcError: &remotepb.RpcError{ Code: proto.Int32(int32(remotepb.RpcError_SECURITY_VIOLATION)), @@ -143,18 +145,25 @@ func setup() (f *fakeAPIHandler, c *context, cleanup func()) { f = &fakeAPIHandler{} srv := httptest.NewServer(f) u, err := url.Parse(srv.URL + apiPath) + restoreAPIHost := restoreEnvVar("API_HOST") + restoreAPIPort := restoreEnvVar("API_HOST") + os.Setenv("API_HOST", u.Hostname()) + os.Setenv("API_PORT", u.Port()) if err != nil { panic(fmt.Sprintf("url.Parse(%q): %v", srv.URL+apiPath, err)) } return f, &context{ - req: &http.Request{ - Header: http.Header{ - ticketHeader: []string{"s3cr3t"}, - dapperHeader: []string{"trace-001"}, + req: &http.Request{ + Header: http.Header{ + ticketHeader: []string{"s3cr3t"}, + dapperHeader: []string{"trace-001"}, + }, }, - }, - apiURL: u, - }, srv.Close + }, func() { + restoreAPIHost() + restoreAPIPort() + srv.Close() + } } func restoreEnvVar(key string) (cleanup func()) { @@ -188,8 +197,9 @@ func TestAPICall(t *testing.T) { func TestAPICallTicketUnavailable(t *testing.T) { resetEnv := SetTestEnv() defer resetEnv() - _, c, cleanup := setup() + f, c, cleanup := setup() defer cleanup() + f.allowMissingTicket = true c.req.Header.Set(ticketHeader, "") req := &basepb.StringProto{ @@ -239,13 +249,9 @@ func TestAPICallRPCFailure(t *testing.T) { func TestAPICallDialFailure(t *testing.T) { // See what happens if the API host is unresponsive. // This should time out quickly, not hang forever. - _, c, cleanup := setup() - defer cleanup() - // Reset the URL to the production address so that dialing fails. - c.apiURL = apiURL() - + // We intentially don't set up the fakeAPIHandler for this test to cause the dail failure. start := time.Now() - err := Call(toContext(c), "foo", "bar", &basepb.VoidProto{}, &basepb.VoidProto{}) + err := Call(netcontext.Background(), "foo", "bar", &basepb.VoidProto{}, &basepb.VoidProto{}) const max = 1 * time.Second if taken := time.Since(start); taken > max { t.Errorf("Dial hang took too long: %v > %v", taken, max) @@ -317,7 +323,7 @@ func TestAPICallAllocations(t *testing.T) { } // Run the test API server in a subprocess so we aren't counting its allocations. - u, cleanup := launchHelperProcess(t) + cleanup := launchHelperProcess(t) defer cleanup() c := &context{ req: &http.Request{ @@ -326,7 +332,6 @@ func TestAPICallAllocations(t *testing.T) { dapperHeader: []string{"trace-001"}, }, }, - apiURL: u, } req := &basepb.StringProto{ @@ -351,7 +356,7 @@ func TestAPICallAllocations(t *testing.T) { } } -func launchHelperProcess(t *testing.T) (apiURL *url.URL, cleanup func()) { +func launchHelperProcess(t *testing.T) (cleanup func()) { cmd := exec.Command(os.Args[0], "-test.run=TestHelperProcess") cmd.Env = []string{"GO_WANT_HELPER_PROCESS=1"} stdin, err := cmd.StdinPipe() @@ -386,7 +391,13 @@ func launchHelperProcess(t *testing.T) (apiURL *url.URL, cleanup func()) { t.Fatal("Helper process never reported") } - return u, func() { + restoreAPIHost := restoreEnvVar("API_HOST") + restoreAPIPort := restoreEnvVar("API_HOST") + os.Setenv("API_HOST", u.Hostname()) + os.Setenv("API_PORT", u.Port()) + return func() { + restoreAPIHost() + restoreAPIPort() stdin.Close() if err := cmd.Wait(); err != nil { t.Errorf("Helper process did not exit cleanly: %v", err) @@ -411,20 +422,3 @@ func TestHelperProcess(*testing.T) { // Wait for stdin to be closed. io.Copy(ioutil.Discard, os.Stdin) } - -func TestBackgroundContext(t *testing.T) { - resetEnv := SetTestEnv() - defer resetEnv() - - ctx, key := fromContext(BackgroundContext()), "X-Magic-Ticket-Header" - if g, w := ctx.req.Header.Get(key), "my-app-id/default.20150612t184001.0"; g != w { - t.Errorf("%v = %q, want %q", key, g, w) - } - - // Check that using the background context doesn't panic. - req := &basepb.StringProto{ - Value: proto.String("Doctor Who"), - } - res := &basepb.StringProto{} - Call(BackgroundContext(), "actordb", "LookupActor", req, res) // expected to fail -}