From af02bb3cf51dc4a05068b97957076dd86e354fb3 Mon Sep 17 00:00:00 2001 From: Zev Goldstein Date: Thu, 14 Jul 2022 12:24:05 -0400 Subject: [PATCH] allow API calls without GAE context The ApiHost used to require a security ticket for all API calls, so the client side code used to be able to assume that any request without one was invalid and reject it. The backend now is able to handle requests without security tickets in many cases, so the client side check is now just getting in the way. This change removes that check and lets the backend attempt to handle all requests. The way the client implementation happened to require a security ticket was actually by requiring a GAE context. This change removes that constraint as well and removes the now-unecessary BackgroundContext. --- aetest/instance_vm.go | 14 +---- appengine_vm.go | 9 +-- internal/api.go | 127 ++++++++++++-------------------------- internal/api_common.go | 14 ++++- internal/api_test.go | 71 ++++++++++----------- v2/aetest/instance.go | 13 +--- v2/appengine.go | 5 +- v2/internal/api.go | 120 +++++++++++------------------------ v2/internal/api_common.go | 12 ++++ v2/internal/api_test.go | 66 +++++++++----------- 10 files changed, 173 insertions(+), 278 deletions(-) diff --git a/aetest/instance_vm.go b/aetest/instance_vm.go index 48aa5d48..e1647fdb 100644 --- a/aetest/instance_vm.go +++ b/aetest/instance_vm.go @@ -18,7 +18,6 @@ import ( "regexp" "time" - "golang.org/x/net/context" "google.golang.org/appengine/internal" ) @@ -61,7 +60,6 @@ type instance struct { appDir string appID string startupTimeout time.Duration - relFuncs []func() // funcs to release any associated contexts } // NewRequest returns an *http.Request associated with this instance. @@ -72,21 +70,11 @@ func (i *instance) NewRequest(method, urlStr string, body io.Reader) (*http.Requ } // Associate this request. - req, release := internal.RegisterTestRequest(req, i.apiURL, func(ctx context.Context) context.Context { - ctx = internal.WithAppIDOverride(ctx, "dev~"+i.appID) - return ctx - }) - i.relFuncs = append(i.relFuncs, release) - - return req, nil + return internal.RegisterTestRequest(req, i.apiURL, "dev~"+i.appID), nil } // Close kills the child api_server.py process, releasing its resources. func (i *instance) Close() (err error) { - for _, rel := range i.relFuncs { - rel() - } - i.relFuncs = nil child := i.child if child == nil { return nil diff --git a/appengine_vm.go b/appengine_vm.go index 78d2575a..6e1d041c 100644 --- a/appengine_vm.go +++ b/appengine_vm.go @@ -8,12 +8,13 @@ package appengine import ( - "golang.org/x/net/context" - - "google.golang.org/appengine/internal" + "context" ) // BackgroundContext returns a context not associated with a request. +// +// Deprecated: App Engine no longer has a special background context. +// Just use context.Background(). func BackgroundContext() context.Context { - return internal.BackgroundContext() + return context.Background() } diff --git a/internal/api.go b/internal/api.go index 750a0077..28d4e288 100644 --- a/internal/api.go +++ b/internal/api.go @@ -24,8 +24,9 @@ import ( "sync/atomic" "time" + netcontext "context" + "github.com/golang/protobuf/proto" - netcontext "golang.org/x/net/context" basepb "google.golang.org/appengine/internal/base" logpb "google.golang.org/appengine/internal/log" @@ -33,8 +34,7 @@ import ( ) const ( - apiPath = "/rpc_http" - defaultTicketSuffix = "/default.20150612t184001.0" + apiPath = "/rpc_http" ) var ( @@ -66,21 +66,22 @@ var ( IdleConnTimeout: 90 * time.Second, }, } - - defaultTicketOnce sync.Once - defaultTicket string - backgroundContextOnce sync.Once - backgroundContext netcontext.Context ) -func apiURL() *url.URL { +func apiURL(ctx netcontext.Context) *url.URL { host, port := "appengine.googleapis.internal", "10001" if h := os.Getenv("API_HOST"); h != "" { host = h } + if hostOverride := ctx.Value(&apiHostOverrideKey); hostOverride != nil { + host = hostOverride.(string) + } if p := os.Getenv("API_PORT"); p != "" { port = p } + if portOverride := ctx.Value(&apiPortOverrideKey); portOverride != nil { + port = portOverride.(string) + } return &url.URL{ Scheme: "http", Host: host + ":" + port, @@ -98,7 +99,6 @@ func handleHTTPMiddleware(next http.Handler) http.Handler { c := &context{ req: r, outHeader: w.Header(), - apiURL: apiURL(), } r = r.WithContext(withContext(r.Context(), c)) c.req = r @@ -235,8 +235,6 @@ type context struct { lines []*logpb.UserAppLogLine flushes int } - - apiURL *url.URL } var contextKey = "holds a *context" @@ -304,59 +302,19 @@ func WithContext(parent netcontext.Context, req *http.Request) netcontext.Contex } } -// DefaultTicket returns a ticket used for background context or dev_appserver. -func DefaultTicket() string { - defaultTicketOnce.Do(func() { - if IsDevAppServer() { - defaultTicket = "testapp" + defaultTicketSuffix - return - } - appID := partitionlessAppID() - escAppID := strings.Replace(strings.Replace(appID, ":", "_", -1), ".", "_", -1) - majVersion := VersionID(nil) - if i := strings.Index(majVersion, "."); i > 0 { - majVersion = majVersion[:i] - } - defaultTicket = fmt.Sprintf("%s/%s.%s.%s", escAppID, ModuleName(nil), majVersion, InstanceID()) - }) - return defaultTicket -} - -func BackgroundContext() netcontext.Context { - backgroundContextOnce.Do(func() { - // Compute background security ticket. - ticket := DefaultTicket() - - c := &context{ - req: &http.Request{ - Header: http.Header{ - ticketHeader: []string{ticket}, - }, - }, - apiURL: apiURL(), - } - backgroundContext = toContext(c) - - // TODO(dsymonds): Wire up the shutdown handler to do a final flush. - go c.logFlusher(make(chan int)) - }) - - return backgroundContext -} - // RegisterTestRequest registers the HTTP request req for testing, such that -// any API calls are sent to the provided URL. It returns a closure to delete -// the registration. +// any API calls are sent to the provided URL. // It should only be used by aetest package. -func RegisterTestRequest(req *http.Request, apiURL *url.URL, decorate func(netcontext.Context) netcontext.Context) (*http.Request, func()) { - c := &context{ - req: req, - apiURL: apiURL, - } - ctx := withContext(decorate(req.Context()), c) - req = req.WithContext(ctx) - c.req = req - return req, func() {} +func RegisterTestRequest(req *http.Request, apiURL *url.URL, appID string) *http.Request { + ctx := req.Context() + ctx = withAPIHostOverride(ctx, apiURL.Hostname()) + ctx = withAPIPortOverride(ctx, apiURL.Port()) + ctx = WithAppIDOverride(ctx, appID) + + // use the unregistered request as a placeholder so that withContext can read the headers + c := &context{req: req} + c.req = req.WithContext(withContext(ctx, c)) + return c.req } var errTimeout = &CallError{ @@ -401,10 +359,11 @@ func (c *context) WriteHeader(code int) { c.outCode = code } -func (c *context) post(body []byte, timeout time.Duration) (b []byte, err error) { +func post(ctx netcontext.Context, body []byte, timeout time.Duration) (b []byte, err error) { + apiURL := apiURL(ctx) hreq := &http.Request{ Method: "POST", - URL: c.apiURL, + URL: apiURL, Header: http.Header{ apiEndpointHeader: apiEndpointHeaderValue, apiMethodHeader: apiMethodHeaderValue, @@ -413,13 +372,16 @@ func (c *context) post(body []byte, timeout time.Duration) (b []byte, err error) }, Body: ioutil.NopCloser(bytes.NewReader(body)), ContentLength: int64(len(body)), - Host: c.apiURL.Host, - } - if info := c.req.Header.Get(dapperHeader); info != "" { - hreq.Header.Set(dapperHeader, info) + Host: apiURL.Host, } - if info := c.req.Header.Get(traceHeader); info != "" { - hreq.Header.Set(traceHeader, info) + c := fromContext(ctx) + if c != nil { + if info := c.req.Header.Get(dapperHeader); info != "" { + hreq.Header.Set(dapperHeader, info) + } + if info := c.req.Header.Get(traceHeader); info != "" { + hreq.Header.Set(traceHeader, info) + } } tr := apiHTTPClient.Transport.(*http.Transport) @@ -480,10 +442,6 @@ func Call(ctx netcontext.Context, service, method string, in, out proto.Message) } c := fromContext(ctx) - if c == nil { - // Give a good error message rather than a panic lower down. - return errNotAppEngineContext - } // Apply transaction modifications if we're in a transaction. if t := transactionFromContext(ctx); t != nil { @@ -504,20 +462,13 @@ func Call(ctx netcontext.Context, service, method string, in, out proto.Message) return err } - ticket := c.req.Header.Get(ticketHeader) - // Use a test ticket under test environment. - if ticket == "" { - if appid := ctx.Value(&appIDOverrideKey); appid != nil { - ticket = appid.(string) + defaultTicketSuffix + ticket := "" + if c != nil { + ticket = c.req.Header.Get(ticketHeader) + if dri := c.req.Header.Get(devRequestIdHeader); IsDevAppServer() && dri != "" { + ticket = dri } } - // Fall back to use background ticket when the request ticket is not available in Flex or dev_appserver. - if ticket == "" { - ticket = DefaultTicket() - } - if dri := c.req.Header.Get(devRequestIdHeader); IsDevAppServer() && dri != "" { - ticket = dri - } req := &remotepb.Request{ ServiceName: &service, Method: &method, @@ -529,7 +480,7 @@ func Call(ctx netcontext.Context, service, method string, in, out proto.Message) return err } - hrespBody, err := c.post(hreqBody, timeout) + hrespBody, err := post(ctx, hreqBody, timeout) if err != nil { return err } diff --git a/internal/api_common.go b/internal/api_common.go index e0c0b214..07755e3a 100644 --- a/internal/api_common.go +++ b/internal/api_common.go @@ -5,11 +5,11 @@ package internal import ( + netcontext "context" "errors" "os" "github.com/golang/protobuf/proto" - netcontext "golang.org/x/net/context" ) var errNotAppEngineContext = errors.New("not an App Engine context") @@ -55,6 +55,18 @@ func WithAppIDOverride(ctx netcontext.Context, appID string) netcontext.Context return netcontext.WithValue(ctx, &appIDOverrideKey, appID) } +var apiHostOverrideKey = "holds a string, being the alternate API_HOST" + +func withAPIHostOverride(ctx netcontext.Context, apiHost string) netcontext.Context { + return netcontext.WithValue(ctx, &apiHostOverrideKey, apiHost) +} + +var apiPortOverrideKey = "holds a string, being the alternate API_PORT" + +func withAPIPortOverride(ctx netcontext.Context, apiPort string) netcontext.Context { + return netcontext.WithValue(ctx, &apiPortOverrideKey, apiPort) +} + var namespaceKey = "holds the namespace string" func withNamespace(ctx netcontext.Context, ns string) netcontext.Context { diff --git a/internal/api_test.go b/internal/api_test.go index c20f52f9..0e033dfd 100644 --- a/internal/api_test.go +++ b/internal/api_test.go @@ -10,6 +10,7 @@ package internal import ( "bufio" "bytes" + netcontext "context" "fmt" "io" "io/ioutil" @@ -24,7 +25,6 @@ import ( "time" "github.com/golang/protobuf/proto" - netcontext "golang.org/x/net/context" basepb "google.golang.org/appengine/internal/base" remotepb "google.golang.org/appengine/internal/remote_api" @@ -41,6 +41,8 @@ type fakeAPIHandler struct { hang chan int // used for RunSlowly RPC LogFlushes int32 // atomic + + allowMissingTicket bool } func (f *fakeAPIHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -67,7 +69,7 @@ func (f *fakeAPIHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { http.Error(w, fmt.Sprintf("Bad encoded API request: %v", err), 500) return } - if *apiReq.RequestId != "s3cr3t" && *apiReq.RequestId != DefaultTicket() { + if *apiReq.RequestId != "s3cr3t" && !f.allowMissingTicket { writeResponse(&remotepb.Response{ RpcError: &remotepb.RpcError{ Code: proto.Int32(int32(remotepb.RpcError_SECURITY_VIOLATION)), @@ -147,18 +149,25 @@ func setup() (f *fakeAPIHandler, c *context, cleanup func()) { f = &fakeAPIHandler{} srv := httptest.NewServer(f) u, err := url.Parse(srv.URL + apiPath) + revertAPIHost := restoreEnvVar("API_HOST") + restoreAPIPort := restoreEnvVar("API_HOST") + os.Setenv("API_HOST", u.Hostname()) + os.Setenv("API_PORT", u.Port()) if err != nil { panic(fmt.Sprintf("url.Parse(%q): %v", srv.URL+apiPath, err)) } return f, &context{ - req: &http.Request{ - Header: http.Header{ - ticketHeader: []string{"s3cr3t"}, - dapperHeader: []string{"trace-001"}, + req: &http.Request{ + Header: http.Header{ + ticketHeader: []string{"s3cr3t"}, + dapperHeader: []string{"trace-001"}, + }, }, - }, - apiURL: u, - }, srv.Close + }, func() { + revertAPIHost() + restoreAPIPort() + srv.Close() + } } func restoreEnvVar(key string) (cleanup func()) { @@ -192,8 +201,9 @@ func TestAPICall(t *testing.T) { func TestAPICallTicketUnavailable(t *testing.T) { resetEnv := SetTestEnv() defer resetEnv() - _, c, cleanup := setup() + f, c, cleanup := setup() defer cleanup() + f.allowMissingTicket = true c.req.Header.Set(ticketHeader, "") req := &basepb.StringProto{ @@ -243,13 +253,9 @@ func TestAPICallRPCFailure(t *testing.T) { func TestAPICallDialFailure(t *testing.T) { // See what happens if the API host is unresponsive. // This should time out quickly, not hang forever. - _, c, cleanup := setup() - defer cleanup() - // Reset the URL to the production address so that dialing fails. - c.apiURL = apiURL() - + // We intentially don't set up the fakeAPIHandler for this test to cause the dail failure. start := time.Now() - err := Call(toContext(c), "foo", "bar", &basepb.VoidProto{}, &basepb.VoidProto{}) + err := Call(netcontext.Background(), "foo", "bar", &basepb.VoidProto{}, &basepb.VoidProto{}) const max = 1 * time.Second if taken := time.Since(start); taken > max { t.Errorf("Dial hang took too long: %v > %v", taken, max) @@ -282,7 +288,6 @@ func TestDelayedLogFlushing(t *testing.T) { http.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) { logC := WithContext(netcontext.Background(), r) - fromContext(logC).apiURL = c.apiURL // Otherwise it will try to use the default URL. Logf(logC, 1, "It's a lovely day.") w.WriteHeader(200) time.Sleep(1200 * time.Millisecond) @@ -344,7 +349,6 @@ func TestLogFlushing(t *testing.T) { path := "/quick_log_" + tc.logToLogservice http.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) { logC := WithContext(netcontext.Background(), r) - fromContext(logC).apiURL = c.apiURL // Otherwise it will try to use the default URL. Logf(logC, 1, "It's a lovely day.") w.WriteHeader(200) w.Write(make([]byte, 100<<10)) // write 100 KB to force HTTP flush @@ -435,7 +439,8 @@ func TestAPICallAllocations(t *testing.T) { } // Run the test API server in a subprocess so we aren't counting its allocations. - u, cleanup := launchHelperProcess(t) + cleanup := launchHelperProcess(t) + defer cleanup() c := &context{ req: &http.Request{ @@ -444,7 +449,6 @@ func TestAPICallAllocations(t *testing.T) { dapperHeader: []string{"trace-001"}, }, }, - apiURL: u, } req := &basepb.StringProto{ @@ -469,7 +473,7 @@ func TestAPICallAllocations(t *testing.T) { } } -func launchHelperProcess(t *testing.T) (apiURL *url.URL, cleanup func()) { +func launchHelperProcess(t *testing.T) (cleanup func()) { cmd := exec.Command(os.Args[0], "-test.run=TestHelperProcess") cmd.Env = []string{"GO_WANT_HELPER_PROCESS=1"} stdin, err := cmd.StdinPipe() @@ -504,7 +508,13 @@ func launchHelperProcess(t *testing.T) (apiURL *url.URL, cleanup func()) { t.Fatal("Helper process never reported") } - return u, func() { + restoreAPIHost := restoreEnvVar("API_HOST") + restoreAPIPort := restoreEnvVar("API_HOST") + os.Setenv("API_HOST", u.Hostname()) + os.Setenv("API_PORT", u.Port()) + return func() { + restoreAPIHost() + restoreAPIPort() stdin.Close() if err := cmd.Wait(); err != nil { t.Errorf("Helper process did not exit cleanly: %v", err) @@ -529,20 +539,3 @@ func TestHelperProcess(*testing.T) { // Wait for stdin to be closed. io.Copy(ioutil.Discard, os.Stdin) } - -func TestBackgroundContext(t *testing.T) { - resetEnv := SetTestEnv() - defer resetEnv() - - ctx, key := fromContext(BackgroundContext()), "X-Magic-Ticket-Header" - if g, w := ctx.req.Header.Get(key), "my-app-id/default.20150612t184001.0"; g != w { - t.Errorf("%v = %q, want %q", key, g, w) - } - - // Check that using the background context doesn't panic. - req := &basepb.StringProto{ - Value: proto.String("Doctor Who"), - } - res := &basepb.StringProto{} - Call(BackgroundContext(), "actordb", "LookupActor", req, res) // expected to fail -} diff --git a/v2/aetest/instance.go b/v2/aetest/instance.go index 3143a89c..985a52cc 100644 --- a/v2/aetest/instance.go +++ b/v2/aetest/instance.go @@ -110,7 +110,6 @@ type instance struct { appDir string appID string startupTimeout time.Duration - relFuncs []func() // funcs to release any associated contexts } // NewRequest returns an *http.Request associated with this instance. @@ -121,21 +120,11 @@ func (i *instance) NewRequest(method, urlStr string, body io.Reader) (*http.Requ } // Associate this request. - req, release := internal.RegisterTestRequest(req, i.apiURL, func(ctx context.Context) context.Context { - ctx = internal.WithAppIDOverride(ctx, "dev~"+i.appID) - return ctx - }) - i.relFuncs = append(i.relFuncs, release) - - return req, nil + return internal.RegisterTestRequest(req, i.apiURL, "dev~"+i.appID), nil } // Close kills the child api_server.py process, releasing its resources. func (i *instance) Close() (err error) { - for _, rel := range i.relFuncs { - rel() - } - i.relFuncs = nil child := i.child if child == nil { return nil diff --git a/v2/appengine.go b/v2/appengine.go index cabc4ac0..02e1f24e 100644 --- a/v2/appengine.go +++ b/v2/appengine.go @@ -135,6 +135,9 @@ func APICall(ctx context.Context, service, method string, in, out proto.Message) } // BackgroundContext returns a context not associated with a request. +// +// Deprecated: App Engine no longer has a special background context. +// Just use context.Background(). func BackgroundContext() context.Context { - return internal.BackgroundContext() + return context.Background() } diff --git a/v2/internal/api.go b/v2/internal/api.go index 96df8204..ef5ab9f9 100644 --- a/v2/internal/api.go +++ b/v2/internal/api.go @@ -17,8 +17,6 @@ import ( "os" "runtime" "strconv" - "strings" - "sync" "sync/atomic" "time" @@ -28,8 +26,7 @@ import ( ) const ( - apiPath = "/rpc_http" - defaultTicketSuffix = "/default.20150612t184001.0" + apiPath = "/rpc_http" ) var ( @@ -62,23 +59,24 @@ var ( }, } - defaultTicketOnce sync.Once - defaultTicket string - backgroundContextOnce sync.Once - backgroundContext netcontext.Context - logStream io.Writer = os.Stderr // For test hooks. timeNow func() time.Time = time.Now // For test hooks. ) -func apiURL() *url.URL { +func apiURL(ctx netcontext.Context) *url.URL { host, port := "appengine.googleapis.internal", "10001" if h := os.Getenv("API_HOST"); h != "" { host = h } + if hostOverride := ctx.Value(&apiHostOverrideKey); hostOverride != nil { + host = hostOverride.(string) + } if p := os.Getenv("API_PORT"); p != "" { port = p } + if portOverride := ctx.Value(&apiPortOverrideKey); portOverride != nil { + port = portOverride.(string) + } return &url.URL{ Scheme: "http", Host: host + ":" + port, @@ -90,7 +88,6 @@ func handleHTTP(w http.ResponseWriter, r *http.Request) { c := &context{ req: r, outHeader: w.Header(), - apiURL: apiURL(), } r = r.WithContext(withContext(r.Context(), c)) c.req = r @@ -183,8 +180,6 @@ type context struct { outCode int outHeader http.Header outBody []byte - - apiURL *url.URL } var contextKey = "holds a *context" @@ -252,56 +247,20 @@ func WithContext(parent netcontext.Context, req *http.Request) netcontext.Contex } } -// DefaultTicket returns a ticket used for background context or dev_appserver. -func DefaultTicket() string { - defaultTicketOnce.Do(func() { - if IsDevAppServer() { - defaultTicket = "testapp" + defaultTicketSuffix - return - } - appID := partitionlessAppID() - escAppID := strings.Replace(strings.Replace(appID, ":", "_", -1), ".", "_", -1) - majVersion := VersionID(nil) - if i := strings.Index(majVersion, "."); i > 0 { - majVersion = majVersion[:i] - } - defaultTicket = fmt.Sprintf("%s/%s.%s.%s", escAppID, ModuleName(nil), majVersion, InstanceID()) - }) - return defaultTicket -} - -func BackgroundContext() netcontext.Context { - backgroundContextOnce.Do(func() { - // Compute background security ticket. - ticket := DefaultTicket() - - c := &context{ - req: &http.Request{ - Header: http.Header{ - ticketHeader: []string{ticket}, - }, - }, - apiURL: apiURL(), - } - backgroundContext = toContext(c) - }) - - return backgroundContext -} - // RegisterTestRequest registers the HTTP request req for testing, such that // any API calls are sent to the provided URL. It returns a closure to delete // the registration. // It should only be used by aetest package. -func RegisterTestRequest(req *http.Request, apiURL *url.URL, decorate func(netcontext.Context) netcontext.Context) (*http.Request, func()) { - c := &context{ - req: req, - apiURL: apiURL, - } - ctx := withContext(decorate(req.Context()), c) - req = req.WithContext(ctx) - c.req = req - return req, func() {} +func RegisterTestRequest(req *http.Request, apiURL *url.URL, appID string) *http.Request { + ctx := req.Context() + ctx = withAPIHostOverride(ctx, apiURL.Hostname()) + ctx = withAPIPortOverride(ctx, apiURL.Port()) + ctx = WithAppIDOverride(ctx, appID) + + // use the unregistered request as a placeholder so that withContext can read the headers + c := &context{req: req} + c.req = req.WithContext(withContext(ctx, c)) + return c.req } var errTimeout = &CallError{ @@ -346,10 +305,11 @@ func (c *context) WriteHeader(code int) { c.outCode = code } -func (c *context) post(body []byte, timeout time.Duration) (b []byte, err error) { +func post(ctx netcontext.Context, body []byte, timeout time.Duration) (b []byte, err error) { + apiURL := apiURL(ctx) hreq := &http.Request{ Method: "POST", - URL: c.apiURL, + URL: apiURL, Header: http.Header{ apiEndpointHeader: apiEndpointHeaderValue, apiMethodHeader: apiMethodHeaderValue, @@ -358,13 +318,16 @@ func (c *context) post(body []byte, timeout time.Duration) (b []byte, err error) }, Body: ioutil.NopCloser(bytes.NewReader(body)), ContentLength: int64(len(body)), - Host: c.apiURL.Host, - } - if info := c.req.Header.Get(dapperHeader); info != "" { - hreq.Header.Set(dapperHeader, info) + Host: apiURL.Host, } - if info := c.req.Header.Get(traceHeader); info != "" { - hreq.Header.Set(traceHeader, info) + c := fromContext(ctx) + if c != nil { + if info := c.req.Header.Get(dapperHeader); info != "" { + hreq.Header.Set(dapperHeader, info) + } + if info := c.req.Header.Get(traceHeader); info != "" { + hreq.Header.Set(traceHeader, info) + } } tr := apiHTTPClient.Transport.(*http.Transport) @@ -425,10 +388,6 @@ func Call(ctx netcontext.Context, service, method string, in, out proto.Message) } c := fromContext(ctx) - if c == nil { - // Give a good error message rather than a panic lower down. - return errNotAppEngineContext - } // Apply transaction modifications if we're in a transaction. if t := transactionFromContext(ctx); t != nil { @@ -449,20 +408,13 @@ func Call(ctx netcontext.Context, service, method string, in, out proto.Message) return err } - ticket := c.req.Header.Get(ticketHeader) - // Use a test ticket under test environment. - if ticket == "" { - if appid := ctx.Value(&appIDOverrideKey); appid != nil { - ticket = appid.(string) + defaultTicketSuffix + ticket := "" + if c != nil { + ticket = c.req.Header.Get(ticketHeader) + if dri := c.req.Header.Get(devRequestIdHeader); IsDevAppServer() && dri != "" { + ticket = dri } } - // Fall back to use background ticket when the request ticket is not available in Flex or dev_appserver. - if ticket == "" { - ticket = DefaultTicket() - } - if dri := c.req.Header.Get(devRequestIdHeader); IsDevAppServer() && dri != "" { - ticket = dri - } req := &remotepb.Request{ ServiceName: &service, Method: &method, @@ -474,7 +426,7 @@ func Call(ctx netcontext.Context, service, method string, in, out proto.Message) return err } - hrespBody, err := c.post(hreqBody, timeout) + hrespBody, err := post(ctx, hreqBody, timeout) if err != nil { return err } diff --git a/v2/internal/api_common.go b/v2/internal/api_common.go index a952aaa2..07755e3a 100644 --- a/v2/internal/api_common.go +++ b/v2/internal/api_common.go @@ -55,6 +55,18 @@ func WithAppIDOverride(ctx netcontext.Context, appID string) netcontext.Context return netcontext.WithValue(ctx, &appIDOverrideKey, appID) } +var apiHostOverrideKey = "holds a string, being the alternate API_HOST" + +func withAPIHostOverride(ctx netcontext.Context, apiHost string) netcontext.Context { + return netcontext.WithValue(ctx, &apiHostOverrideKey, apiHost) +} + +var apiPortOverrideKey = "holds a string, being the alternate API_PORT" + +func withAPIPortOverride(ctx netcontext.Context, apiPort string) netcontext.Context { + return netcontext.WithValue(ctx, &apiPortOverrideKey, apiPort) +} + var namespaceKey = "holds the namespace string" func withNamespace(ctx netcontext.Context, ns string) netcontext.Context { diff --git a/v2/internal/api_test.go b/v2/internal/api_test.go index c851f5a6..bc4b5158 100644 --- a/v2/internal/api_test.go +++ b/v2/internal/api_test.go @@ -37,6 +37,8 @@ type fakeAPIHandler struct { hang chan int // used for RunSlowly RPC LogFlushes int32 // atomic + + allowMissingTicket bool } func (f *fakeAPIHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -63,7 +65,7 @@ func (f *fakeAPIHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { http.Error(w, fmt.Sprintf("Bad encoded API request: %v", err), 500) return } - if *apiReq.RequestId != "s3cr3t" && *apiReq.RequestId != DefaultTicket() { + if *apiReq.RequestId != "s3cr3t" && !f.allowMissingTicket { writeResponse(&remotepb.Response{ RpcError: &remotepb.RpcError{ Code: proto.Int32(int32(remotepb.RpcError_SECURITY_VIOLATION)), @@ -143,18 +145,25 @@ func setup() (f *fakeAPIHandler, c *context, cleanup func()) { f = &fakeAPIHandler{} srv := httptest.NewServer(f) u, err := url.Parse(srv.URL + apiPath) + revertAPIHost := restoreEnvVar("API_HOST") + restoreAPIPort := restoreEnvVar("API_HOST") + os.Setenv("API_HOST", u.Hostname()) + os.Setenv("API_PORT", u.Port()) if err != nil { panic(fmt.Sprintf("url.Parse(%q): %v", srv.URL+apiPath, err)) } return f, &context{ - req: &http.Request{ - Header: http.Header{ - ticketHeader: []string{"s3cr3t"}, - dapperHeader: []string{"trace-001"}, + req: &http.Request{ + Header: http.Header{ + ticketHeader: []string{"s3cr3t"}, + dapperHeader: []string{"trace-001"}, + }, }, - }, - apiURL: u, - }, srv.Close + }, func() { + revertAPIHost() + restoreAPIPort() + srv.Close() + } } func restoreEnvVar(key string) (cleanup func()) { @@ -188,8 +197,9 @@ func TestAPICall(t *testing.T) { func TestAPICallTicketUnavailable(t *testing.T) { resetEnv := SetTestEnv() defer resetEnv() - _, c, cleanup := setup() + f, c, cleanup := setup() defer cleanup() + f.allowMissingTicket = true c.req.Header.Set(ticketHeader, "") req := &basepb.StringProto{ @@ -239,13 +249,9 @@ func TestAPICallRPCFailure(t *testing.T) { func TestAPICallDialFailure(t *testing.T) { // See what happens if the API host is unresponsive. // This should time out quickly, not hang forever. - _, c, cleanup := setup() - defer cleanup() - // Reset the URL to the production address so that dialing fails. - c.apiURL = apiURL() - + // We intentially don't set up the fakeAPIHandler for this test to cause the dail failure. start := time.Now() - err := Call(toContext(c), "foo", "bar", &basepb.VoidProto{}, &basepb.VoidProto{}) + err := Call(netcontext.Background(), "foo", "bar", &basepb.VoidProto{}, &basepb.VoidProto{}) const max = 1 * time.Second if taken := time.Since(start); taken > max { t.Errorf("Dial hang took too long: %v > %v", taken, max) @@ -317,7 +323,7 @@ func TestAPICallAllocations(t *testing.T) { } // Run the test API server in a subprocess so we aren't counting its allocations. - u, cleanup := launchHelperProcess(t) + cleanup := launchHelperProcess(t) defer cleanup() c := &context{ req: &http.Request{ @@ -326,7 +332,6 @@ func TestAPICallAllocations(t *testing.T) { dapperHeader: []string{"trace-001"}, }, }, - apiURL: u, } req := &basepb.StringProto{ @@ -351,7 +356,7 @@ func TestAPICallAllocations(t *testing.T) { } } -func launchHelperProcess(t *testing.T) (apiURL *url.URL, cleanup func()) { +func launchHelperProcess(t *testing.T) (cleanup func()) { cmd := exec.Command(os.Args[0], "-test.run=TestHelperProcess") cmd.Env = []string{"GO_WANT_HELPER_PROCESS=1"} stdin, err := cmd.StdinPipe() @@ -386,7 +391,13 @@ func launchHelperProcess(t *testing.T) (apiURL *url.URL, cleanup func()) { t.Fatal("Helper process never reported") } - return u, func() { + restoreAPIHost := restoreEnvVar("API_HOST") + restoreAPIPort := restoreEnvVar("API_HOST") + os.Setenv("API_HOST", u.Hostname()) + os.Setenv("API_PORT", u.Port()) + return func() { + restoreAPIHost() + restoreAPIPort() stdin.Close() if err := cmd.Wait(); err != nil { t.Errorf("Helper process did not exit cleanly: %v", err) @@ -411,20 +422,3 @@ func TestHelperProcess(*testing.T) { // Wait for stdin to be closed. io.Copy(ioutil.Discard, os.Stdin) } - -func TestBackgroundContext(t *testing.T) { - resetEnv := SetTestEnv() - defer resetEnv() - - ctx, key := fromContext(BackgroundContext()), "X-Magic-Ticket-Header" - if g, w := ctx.req.Header.Get(key), "my-app-id/default.20150612t184001.0"; g != w { - t.Errorf("%v = %q, want %q", key, g, w) - } - - // Check that using the background context doesn't panic. - req := &basepb.StringProto{ - Value: proto.String("Doctor Who"), - } - res := &basepb.StringProto{} - Call(BackgroundContext(), "actordb", "LookupActor", req, res) // expected to fail -}