Skip to content

Commit

Permalink
fix(bigquery/storage/managedwriter): context refactoring (#8275)
Browse files Browse the repository at this point in the history
Co-authored-by: Alvaro Viebrantz <aviebrantz@google.com>
  • Loading branch information
shollyman and alvarowolfx committed Jul 21, 2023
1 parent 6e0227d commit c4104ea
Show file tree
Hide file tree
Showing 8 changed files with 265 additions and 44 deletions.
48 changes: 35 additions & 13 deletions bigquery/storage/managedwriter/client.go
Expand Up @@ -45,6 +45,11 @@ type Client struct {
rawClient *storage.BigQueryWriteClient
projectID string

// retained context. primarily used for connection management and the underlying
// client.
ctx context.Context
cancel context.CancelFunc

// cfg retains general settings (custom ClientOptions).
cfg *writerClientConfig

Expand All @@ -66,21 +71,27 @@ func NewClient(ctx context.Context, projectID string, opts ...option.ClientOptio
}
o = append(o, opts...)

rawClient, err := storage.NewBigQueryWriteClient(ctx, o...)
cCtx, cancel := context.WithCancel(ctx)

rawClient, err := storage.NewBigQueryWriteClient(cCtx, o...)
if err != nil {
cancel()
return nil, err
}
rawClient.SetGoogleClientInfo("gccl", internal.Version)

// Handle project autodetection.
projectID, err = detect.ProjectID(ctx, projectID, "", opts...)
if err != nil {
cancel()
return nil, err
}

return &Client{
rawClient: rawClient,
projectID: projectID,
ctx: cCtx,
cancel: cancel,
cfg: newWriterClientConfig(opts...),
pools: make(map[string]*connectionPool),
}, nil
Expand All @@ -103,6 +114,10 @@ func (c *Client) Close() error {
if err := c.rawClient.Close(); err != nil && firstErr == nil {
firstErr = err
}
// Cancel the retained client context.
if c.cancel != nil {
c.cancel()
}
return firstErr
}

Expand All @@ -114,8 +129,11 @@ func (c *Client) NewManagedStream(ctx context.Context, opts ...WriterOption) (*M
}

// createOpenF builds the opener function we need to access the AppendRows bidi stream.
func createOpenF(ctx context.Context, streamFunc streamClientFunc) func(opts ...gax.CallOption) (storagepb.BigQueryWrite_AppendRowsClient, error) {
return func(opts ...gax.CallOption) (storagepb.BigQueryWrite_AppendRowsClient, error) {
func createOpenF(streamFunc streamClientFunc, routingHeader string) func(ctx context.Context, opts ...gax.CallOption) (storagepb.BigQueryWrite_AppendRowsClient, error) {
return func(ctx context.Context, opts ...gax.CallOption) (storagepb.BigQueryWrite_AppendRowsClient, error) {
if routingHeader != "" {
ctx = metadata.AppendToOutgoingContext(ctx, "x-goog-request-params", routingHeader)
}
arc, err := streamFunc(ctx, opts...)
if err != nil {
return nil, err
Expand Down Expand Up @@ -167,11 +185,11 @@ func (c *Client) buildManagedStream(ctx context.Context, streamFunc streamClient
if err != nil {
return nil, err
}
// Add the writer to the pool, and derive context from the pool.
// Add the writer to the pool.
if err := pool.addWriter(writer); err != nil {
return nil, err
}
writer.ctx, writer.cancel = context.WithCancel(pool.ctx)
writer.ctx, writer.cancel = context.WithCancel(ctx)

// Attach any tag keys to the context on the writer, so instrumentation works as expected.
writer.ctx = setupWriterStatContext(writer)
Expand Down Expand Up @@ -218,7 +236,7 @@ func (c *Client) resolvePool(ctx context.Context, settings *streamSettings, stre
}

// No existing pool available, create one for the location and add to shared pools.
pool, err := c.createPool(ctx, loc, streamFunc)
pool, err := c.createPool(loc, streamFunc)
if err != nil {
return nil, err
}
Expand All @@ -227,24 +245,28 @@ func (c *Client) resolvePool(ctx context.Context, settings *streamSettings, stre
}

// createPool builds a connectionPool.
func (c *Client) createPool(ctx context.Context, location string, streamFunc streamClientFunc) (*connectionPool, error) {
cCtx, cancel := context.WithCancel(ctx)
func (c *Client) createPool(location string, streamFunc streamClientFunc) (*connectionPool, error) {
cCtx, cancel := context.WithCancel(c.ctx)

if c.cfg == nil {
cancel()
return nil, fmt.Errorf("missing client config")
}
if location != "" {
// add location header to the retained pool context.
cCtx = metadata.AppendToOutgoingContext(ctx, "x-goog-request-params", fmt.Sprintf("write_location=%s", location))
}

var routingHeader string
/*
* TODO: set once backend respects the new routing header
* if location != "" && c.projectID != "" {
* routingHeader = fmt.Sprintf("write_location=projects/%s/locations/%s", c.projectID, location)
* }
*/

pool := &connectionPool{
id: newUUID(poolIDPrefix),
location: location,
ctx: cCtx,
cancel: cancel,
open: createOpenF(ctx, streamFunc),
open: createOpenF(streamFunc, routingHeader),
callOptions: c.cfg.defaultAppendRowsCallOptions,
baseFlowController: newFlowController(c.cfg.defaultInflightRequests, c.cfg.defaultInflightBytes),
}
Expand Down
12 changes: 8 additions & 4 deletions bigquery/storage/managedwriter/client_test.go
Expand Up @@ -55,10 +55,13 @@ func TestTableParentFromStreamName(t *testing.T) {
}

func TestCreatePool_Location(t *testing.T) {
t.Skip("skipping until new write_location is allowed")
c := &Client{
cfg: &writerClientConfig{},
cfg: &writerClientConfig{},
ctx: context.Background(),
projectID: "myproj",
}
pool, err := c.createPool(context.Background(), "foo", nil)
pool, err := c.createPool("foo", nil)
if err != nil {
t.Fatalf("createPool: %v", err)
}
Expand All @@ -72,7 +75,7 @@ func TestCreatePool_Location(t *testing.T) {
}
found := false
for _, v := range vals {
if v == "write_location=foo" {
if v == "write_location=projects/myproj/locations/foo" {
found = true
break
}
Expand Down Expand Up @@ -151,8 +154,9 @@ func TestCreatePool(t *testing.T) {
for _, tc := range testCases {
c := &Client{
cfg: tc.cfg,
ctx: context.Background(),
}
pool, err := c.createPool(context.Background(), "", nil)
pool, err := c.createPool("", nil)
if err != nil {
t.Errorf("case %q: createPool errored unexpectedly: %v", tc.desc, err)
continue
Expand Down
23 changes: 14 additions & 9 deletions bigquery/storage/managedwriter/connection.go
Expand Up @@ -54,7 +54,7 @@ type connectionPool struct {

// We centralize the open function on the pool, rather than having an instance of the open func on every
// connection. Opening the connection is a stateless operation.
open func(opts ...gax.CallOption) (storagepb.BigQueryWrite_AppendRowsClient, error)
open func(ctx context.Context, opts ...gax.CallOption) (storagepb.BigQueryWrite_AppendRowsClient, error)

// We specify default calloptions for the pool.
// Explicit connections may have their own calloptions as well.
Expand Down Expand Up @@ -137,7 +137,7 @@ func (cp *connectionPool) openWithRetry(co *connection) (storagepb.BigQueryWrite
r := &unaryRetryer{}
for {
recordStat(cp.ctx, AppendClientOpenCount, 1)
arc, err := cp.open(cp.mergeCallOptions(co)...)
arc, err := cp.open(co.ctx, cp.mergeCallOptions(co)...)
if err != nil {
bo, shouldRetry := r.Retry(err)
if shouldRetry {
Expand All @@ -151,6 +151,7 @@ func (cp *connectionPool) openWithRetry(co *connection) (storagepb.BigQueryWrite
return nil, nil, err
}
}

// The channel relationship with its ARC is 1:1. If we get a new ARC, create a new pending
// write channel and fire up the associated receive processor. The channel ensures that
// responses for a connection are processed in the same order that appends were sent.
Expand All @@ -159,7 +160,7 @@ func (cp *connectionPool) openWithRetry(co *connection) (storagepb.BigQueryWrite
depth = d
}
ch := make(chan *pendingWrite, depth)
go connRecvProcessor(co, arc, ch)
go connRecvProcessor(co.ctx, co, arc, ch)
return arc, ch, nil
}
}
Expand Down Expand Up @@ -441,13 +442,17 @@ func (co *connection) getStream(arc *storagepb.BigQueryWrite_AppendRowsClient, f
if arc != co.arc && !forceReconnect {
return co.arc, co.pending, nil
}
// We need to (re)open a connection. Cleanup previous connection and channel if they are present.
// We need to (re)open a connection. Cleanup previous connection, channel, and context if they are present.
if co.arc != nil && (*co.arc) != (storagepb.BigQueryWrite_AppendRowsClient)(nil) {
(*co.arc).CloseSend()
}
if co.pending != nil {
close(co.pending)
}
if co.cancel != nil {
co.cancel()
co.ctx, co.cancel = context.WithCancel(co.pool.ctx)
}

co.arc = new(storagepb.BigQueryWrite_AppendRowsClient)
// We're going to (re)open the connection, so clear any optimizer state.
Expand All @@ -464,10 +469,10 @@ type streamClientFunc func(context.Context, ...gax.CallOption) (storagepb.BigQue
// connRecvProcessor is used to propagate append responses back up with the originating write requests.
// It runs as a goroutine. A connection object allows for reconnection, and each reconnection establishes a new
// processing goroutine and backing channel.
func connRecvProcessor(co *connection, arc storagepb.BigQueryWrite_AppendRowsClient, ch <-chan *pendingWrite) {
func connRecvProcessor(ctx context.Context, co *connection, arc storagepb.BigQueryWrite_AppendRowsClient, ch <-chan *pendingWrite) {
for {
select {
case <-co.ctx.Done():
case <-ctx.Done():
// Context is done, so we're not going to get further updates. Mark all work left in the channel
// with the context error. We don't attempt to re-enqueue in this case.
for {
Expand All @@ -478,7 +483,7 @@ func connRecvProcessor(co *connection, arc storagepb.BigQueryWrite_AppendRowsCli
// It's unlikely this connection will recover here, but for correctness keep the flow controller
// state correct by releasing.
co.release(pw)
pw.markDone(nil, co.ctx.Err())
pw.markDone(nil, ctx.Err())
}
case nextWrite, ok := <-ch:
if !ok {
Expand All @@ -493,12 +498,12 @@ func connRecvProcessor(co *connection, arc storagepb.BigQueryWrite_AppendRowsCli
continue
}
// Record that we did in fact get a response from the backend.
recordStat(co.ctx, AppendResponses, 1)
recordStat(ctx, AppendResponses, 1)

if status := resp.GetError(); status != nil {
// The response from the backend embedded a status error. We record that the error
// occurred, and tag it based on the response code of the status.
if tagCtx, tagErr := tag.New(co.ctx, tag.Insert(keyError, codes.Code(status.GetCode()).String())); tagErr == nil {
if tagCtx, tagErr := tag.New(ctx, tag.Insert(keyError, codes.Code(status.GetCode()).String())); tagErr == nil {
recordStat(tagCtx, AppendResponseErrors, 1)
}
respErr := grpcstatus.ErrorProto(status)
Expand Down
6 changes: 3 additions & 3 deletions bigquery/storage/managedwriter/connection_test.go
Expand Up @@ -61,7 +61,7 @@ func TestConnection_OpenWithRetry(t *testing.T) {
for _, tc := range testCases {
pool := &connectionPool{
ctx: context.Background(),
open: func(opts ...gax.CallOption) (storagepb.BigQueryWrite_AppendRowsClient, error) {
open: func(ctx context.Context, opts ...gax.CallOption) (storagepb.BigQueryWrite_AppendRowsClient, error) {
if len(tc.errors) == 0 {
panic("out of errors")
}
Expand Down Expand Up @@ -162,12 +162,12 @@ func TestConnectionPool_OpenCallOptionPropagation(t *testing.T) {
pool := &connectionPool{
ctx: ctx,
cancel: cancel,
open: createOpenF(ctx, func(ctx context.Context, opts ...gax.CallOption) (storage.BigQueryWrite_AppendRowsClient, error) {
open: createOpenF(func(ctx context.Context, opts ...gax.CallOption) (storage.BigQueryWrite_AppendRowsClient, error) {
if len(opts) == 0 {
t.Fatalf("no options were propagated")
}
return nil, fmt.Errorf("no real client")
}),
}, ""),
callOptions: []gax.CallOption{
gax.WithGRPCOptions(grpc.MaxCallRecvMsgSize(99)),
},
Expand Down

0 comments on commit c4104ea

Please sign in to comment.