Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow API calls without GAE context #284

Merged
merged 2 commits into from Jul 18, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 1 addition & 13 deletions aetest/instance_vm.go
Expand Up @@ -18,7 +18,6 @@ import (
"regexp"
"time"

"golang.org/x/net/context"
"google.golang.org/appengine/internal"
)

Expand Down Expand Up @@ -61,7 +60,6 @@ type instance struct {
appDir string
appID string
startupTimeout time.Duration
relFuncs []func() // funcs to release any associated contexts
}

// NewRequest returns an *http.Request associated with this instance.
Expand All @@ -72,21 +70,11 @@ func (i *instance) NewRequest(method, urlStr string, body io.Reader) (*http.Requ
}

// Associate this request.
req, release := internal.RegisterTestRequest(req, i.apiURL, func(ctx context.Context) context.Context {
ctx = internal.WithAppIDOverride(ctx, "dev~"+i.appID)
return ctx
})
i.relFuncs = append(i.relFuncs, release)

return req, nil
return internal.RegisterTestRequest(req, i.apiURL, "dev~"+i.appID), nil
}

// Close kills the child api_server.py process, releasing its resources.
func (i *instance) Close() (err error) {
for _, rel := range i.relFuncs {
rel()
}
i.relFuncs = nil
child := i.child
if child == nil {
return nil
Expand Down
9 changes: 5 additions & 4 deletions appengine_vm.go
Expand Up @@ -8,12 +8,13 @@
package appengine

import (
"golang.org/x/net/context"

"google.golang.org/appengine/internal"
"context"
)

// BackgroundContext returns a context not associated with a request.
//
// Deprecated: App Engine no longer has a special background context.
// Just use context.Background().
func BackgroundContext() context.Context {
return internal.BackgroundContext()
return context.Background()
}
127 changes: 39 additions & 88 deletions internal/api.go
Expand Up @@ -24,17 +24,17 @@ import (
"sync/atomic"
"time"

netcontext "context"

"github.com/golang/protobuf/proto"
netcontext "golang.org/x/net/context"

basepb "google.golang.org/appengine/internal/base"
logpb "google.golang.org/appengine/internal/log"
remotepb "google.golang.org/appengine/internal/remote_api"
)

const (
apiPath = "/rpc_http"
defaultTicketSuffix = "/default.20150612t184001.0"
apiPath = "/rpc_http"
)

var (
Expand Down Expand Up @@ -66,21 +66,22 @@ var (
IdleConnTimeout: 90 * time.Second,
},
}

defaultTicketOnce sync.Once
defaultTicket string
backgroundContextOnce sync.Once
backgroundContext netcontext.Context
)

func apiURL() *url.URL {
func apiURL(ctx netcontext.Context) *url.URL {
host, port := "appengine.googleapis.internal", "10001"
if h := os.Getenv("API_HOST"); h != "" {
host = h
}
if hostOverride := ctx.Value(apiHostOverrideKey); hostOverride != nil {
host = hostOverride.(string)
}
if p := os.Getenv("API_PORT"); p != "" {
port = p
}
if portOverride := ctx.Value(apiPortOverrideKey); portOverride != nil {
port = portOverride.(string)
}
return &url.URL{
Scheme: "http",
Host: host + ":" + port,
Expand All @@ -98,7 +99,6 @@ func handleHTTPMiddleware(next http.Handler) http.Handler {
c := &context{
req: r,
outHeader: w.Header(),
apiURL: apiURL(),
}
r = r.WithContext(withContext(r.Context(), c))
c.req = r
Expand Down Expand Up @@ -235,8 +235,6 @@ type context struct {
lines []*logpb.UserAppLogLine
flushes int
}

apiURL *url.URL
}

var contextKey = "holds a *context"
Expand Down Expand Up @@ -304,59 +302,19 @@ func WithContext(parent netcontext.Context, req *http.Request) netcontext.Contex
}
}

// DefaultTicket returns a ticket used for background context or dev_appserver.
func DefaultTicket() string {
defaultTicketOnce.Do(func() {
if IsDevAppServer() {
defaultTicket = "testapp" + defaultTicketSuffix
return
}
appID := partitionlessAppID()
escAppID := strings.Replace(strings.Replace(appID, ":", "_", -1), ".", "_", -1)
majVersion := VersionID(nil)
if i := strings.Index(majVersion, "."); i > 0 {
majVersion = majVersion[:i]
}
defaultTicket = fmt.Sprintf("%s/%s.%s.%s", escAppID, ModuleName(nil), majVersion, InstanceID())
})
return defaultTicket
}

func BackgroundContext() netcontext.Context {
backgroundContextOnce.Do(func() {
// Compute background security ticket.
ticket := DefaultTicket()

c := &context{
req: &http.Request{
Header: http.Header{
ticketHeader: []string{ticket},
},
},
apiURL: apiURL(),
}
backgroundContext = toContext(c)

// TODO(dsymonds): Wire up the shutdown handler to do a final flush.
go c.logFlusher(make(chan int))
})

return backgroundContext
}

// RegisterTestRequest registers the HTTP request req for testing, such that
// any API calls are sent to the provided URL. It returns a closure to delete
// the registration.
// any API calls are sent to the provided URL.
// It should only be used by aetest package.
func RegisterTestRequest(req *http.Request, apiURL *url.URL, decorate func(netcontext.Context) netcontext.Context) (*http.Request, func()) {
c := &context{
req: req,
apiURL: apiURL,
}
ctx := withContext(decorate(req.Context()), c)
req = req.WithContext(ctx)
c.req = req
return req, func() {}
func RegisterTestRequest(req *http.Request, apiURL *url.URL, appID string) *http.Request {
ctx := req.Context()
ctx = withAPIHostOverride(ctx, apiURL.Hostname())
ctx = withAPIPortOverride(ctx, apiURL.Port())
ctx = WithAppIDOverride(ctx, appID)

// use the unregistered request as a placeholder so that withContext can read the headers
c := &context{req: req}
c.req = req.WithContext(withContext(ctx, c))
return c.req
}

var errTimeout = &CallError{
Expand Down Expand Up @@ -401,10 +359,11 @@ func (c *context) WriteHeader(code int) {
c.outCode = code
}

func (c *context) post(body []byte, timeout time.Duration) (b []byte, err error) {
func post(ctx netcontext.Context, body []byte, timeout time.Duration) (b []byte, err error) {
apiURL := apiURL(ctx)
hreq := &http.Request{
Method: "POST",
URL: c.apiURL,
URL: apiURL,
Header: http.Header{
apiEndpointHeader: apiEndpointHeaderValue,
apiMethodHeader: apiMethodHeaderValue,
Expand All @@ -413,13 +372,16 @@ func (c *context) post(body []byte, timeout time.Duration) (b []byte, err error)
},
Body: ioutil.NopCloser(bytes.NewReader(body)),
ContentLength: int64(len(body)),
Host: c.apiURL.Host,
}
if info := c.req.Header.Get(dapperHeader); info != "" {
hreq.Header.Set(dapperHeader, info)
Host: apiURL.Host,
}
if info := c.req.Header.Get(traceHeader); info != "" {
hreq.Header.Set(traceHeader, info)
c := fromContext(ctx)
if c != nil {
if info := c.req.Header.Get(dapperHeader); info != "" {
hreq.Header.Set(dapperHeader, info)
}
if info := c.req.Header.Get(traceHeader); info != "" {
hreq.Header.Set(traceHeader, info)
}
}

tr := apiHTTPClient.Transport.(*http.Transport)
Expand Down Expand Up @@ -480,10 +442,6 @@ func Call(ctx netcontext.Context, service, method string, in, out proto.Message)
}

c := fromContext(ctx)
if c == nil {
// Give a good error message rather than a panic lower down.
return errNotAppEngineContext
}

// Apply transaction modifications if we're in a transaction.
if t := transactionFromContext(ctx); t != nil {
Expand All @@ -504,20 +462,13 @@ func Call(ctx netcontext.Context, service, method string, in, out proto.Message)
return err
}

ticket := c.req.Header.Get(ticketHeader)
// Use a test ticket under test environment.
if ticket == "" {
if appid := ctx.Value(&appIDOverrideKey); appid != nil {
ticket = appid.(string) + defaultTicketSuffix
ticket := ""
if c != nil {
ticket = c.req.Header.Get(ticketHeader)
if dri := c.req.Header.Get(devRequestIdHeader); IsDevAppServer() && dri != "" {
ticket = dri
}
}
// Fall back to use background ticket when the request ticket is not available in Flex or dev_appserver.
if ticket == "" {
ticket = DefaultTicket()
}
if dri := c.req.Header.Get(devRequestIdHeader); IsDevAppServer() && dri != "" {
ticket = dri
}
req := &remotepb.Request{
ServiceName: &service,
Method: &method,
Expand All @@ -529,7 +480,7 @@ func Call(ctx netcontext.Context, service, method string, in, out proto.Message)
return err
}

hrespBody, err := c.post(hreqBody, timeout)
hrespBody, err := post(ctx, hreqBody, timeout)
if err != nil {
return err
}
Expand Down
20 changes: 19 additions & 1 deletion internal/api_common.go
Expand Up @@ -5,13 +5,19 @@
package internal

import (
netcontext "context"
"errors"
"os"

"github.com/golang/protobuf/proto"
netcontext "golang.org/x/net/context"
)

type ctxKey string

func (c ctxKey) String() string {
return "appengine context key: " + string(c)
}

var errNotAppEngineContext = errors.New("not an App Engine context")

type CallOverrideFunc func(ctx netcontext.Context, service, method string, in, out proto.Message) error
Expand Down Expand Up @@ -55,6 +61,18 @@ func WithAppIDOverride(ctx netcontext.Context, appID string) netcontext.Context
return netcontext.WithValue(ctx, &appIDOverrideKey, appID)
}

var apiHostOverrideKey = ctxKey("holds a string, being the alternate API_HOST")

func withAPIHostOverride(ctx netcontext.Context, apiHost string) netcontext.Context {
return netcontext.WithValue(ctx, apiHostOverrideKey, apiHost)
}

var apiPortOverrideKey = ctxKey("holds a string, being the alternate API_PORT")

func withAPIPortOverride(ctx netcontext.Context, apiPort string) netcontext.Context {
return netcontext.WithValue(ctx, apiPortOverrideKey, apiPort)
}

var namespaceKey = "holds the namespace string"

func withNamespace(ctx netcontext.Context, ns string) netcontext.Context {
Expand Down