Skip to content

Commit c4104ea

Browse files
fix(bigquery/storage/managedwriter): context refactoring (#8275)
Co-authored-by: Alvaro Viebrantz <[email protected]>
1 parent 6e0227d commit c4104ea

File tree

8 files changed

+265
-44
lines changed

8 files changed

+265
-44
lines changed

bigquery/storage/managedwriter/client.go

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ type Client struct {
4545
rawClient *storage.BigQueryWriteClient
4646
projectID string
4747

48+
// retained context. primarily used for connection management and the underlying
49+
// client.
50+
ctx context.Context
51+
cancel context.CancelFunc
52+
4853
// cfg retains general settings (custom ClientOptions).
4954
cfg *writerClientConfig
5055

@@ -66,21 +71,27 @@ func NewClient(ctx context.Context, projectID string, opts ...option.ClientOptio
6671
}
6772
o = append(o, opts...)
6873

69-
rawClient, err := storage.NewBigQueryWriteClient(ctx, o...)
74+
cCtx, cancel := context.WithCancel(ctx)
75+
76+
rawClient, err := storage.NewBigQueryWriteClient(cCtx, o...)
7077
if err != nil {
78+
cancel()
7179
return nil, err
7280
}
7381
rawClient.SetGoogleClientInfo("gccl", internal.Version)
7482

7583
// Handle project autodetection.
7684
projectID, err = detect.ProjectID(ctx, projectID, "", opts...)
7785
if err != nil {
86+
cancel()
7887
return nil, err
7988
}
8089

8190
return &Client{
8291
rawClient: rawClient,
8392
projectID: projectID,
93+
ctx: cCtx,
94+
cancel: cancel,
8495
cfg: newWriterClientConfig(opts...),
8596
pools: make(map[string]*connectionPool),
8697
}, nil
@@ -103,6 +114,10 @@ func (c *Client) Close() error {
103114
if err := c.rawClient.Close(); err != nil && firstErr == nil {
104115
firstErr = err
105116
}
117+
// Cancel the retained client context.
118+
if c.cancel != nil {
119+
c.cancel()
120+
}
106121
return firstErr
107122
}
108123

@@ -114,8 +129,11 @@ func (c *Client) NewManagedStream(ctx context.Context, opts ...WriterOption) (*M
114129
}
115130

116131
// createOpenF builds the opener function we need to access the AppendRows bidi stream.
117-
func createOpenF(ctx context.Context, streamFunc streamClientFunc) func(opts ...gax.CallOption) (storagepb.BigQueryWrite_AppendRowsClient, error) {
118-
return func(opts ...gax.CallOption) (storagepb.BigQueryWrite_AppendRowsClient, error) {
132+
func createOpenF(streamFunc streamClientFunc, routingHeader string) func(ctx context.Context, opts ...gax.CallOption) (storagepb.BigQueryWrite_AppendRowsClient, error) {
133+
return func(ctx context.Context, opts ...gax.CallOption) (storagepb.BigQueryWrite_AppendRowsClient, error) {
134+
if routingHeader != "" {
135+
ctx = metadata.AppendToOutgoingContext(ctx, "x-goog-request-params", routingHeader)
136+
}
119137
arc, err := streamFunc(ctx, opts...)
120138
if err != nil {
121139
return nil, err
@@ -167,11 +185,11 @@ func (c *Client) buildManagedStream(ctx context.Context, streamFunc streamClient
167185
if err != nil {
168186
return nil, err
169187
}
170-
// Add the writer to the pool, and derive context from the pool.
188+
// Add the writer to the pool.
171189
if err := pool.addWriter(writer); err != nil {
172190
return nil, err
173191
}
174-
writer.ctx, writer.cancel = context.WithCancel(pool.ctx)
192+
writer.ctx, writer.cancel = context.WithCancel(ctx)
175193

176194
// Attach any tag keys to the context on the writer, so instrumentation works as expected.
177195
writer.ctx = setupWriterStatContext(writer)
@@ -218,7 +236,7 @@ func (c *Client) resolvePool(ctx context.Context, settings *streamSettings, stre
218236
}
219237

220238
// No existing pool available, create one for the location and add to shared pools.
221-
pool, err := c.createPool(ctx, loc, streamFunc)
239+
pool, err := c.createPool(loc, streamFunc)
222240
if err != nil {
223241
return nil, err
224242
}
@@ -227,24 +245,28 @@ func (c *Client) resolvePool(ctx context.Context, settings *streamSettings, stre
227245
}
228246

229247
// createPool builds a connectionPool.
230-
func (c *Client) createPool(ctx context.Context, location string, streamFunc streamClientFunc) (*connectionPool, error) {
231-
cCtx, cancel := context.WithCancel(ctx)
248+
func (c *Client) createPool(location string, streamFunc streamClientFunc) (*connectionPool, error) {
249+
cCtx, cancel := context.WithCancel(c.ctx)
232250

233251
if c.cfg == nil {
234252
cancel()
235253
return nil, fmt.Errorf("missing client config")
236254
}
237-
if location != "" {
238-
// add location header to the retained pool context.
239-
cCtx = metadata.AppendToOutgoingContext(ctx, "x-goog-request-params", fmt.Sprintf("write_location=%s", location))
240-
}
255+
256+
var routingHeader string
257+
/*
258+
* TODO: set once backend respects the new routing header
259+
* if location != "" && c.projectID != "" {
260+
* routingHeader = fmt.Sprintf("write_location=projects/%s/locations/%s", c.projectID, location)
261+
* }
262+
*/
241263

242264
pool := &connectionPool{
243265
id: newUUID(poolIDPrefix),
244266
location: location,
245267
ctx: cCtx,
246268
cancel: cancel,
247-
open: createOpenF(ctx, streamFunc),
269+
open: createOpenF(streamFunc, routingHeader),
248270
callOptions: c.cfg.defaultAppendRowsCallOptions,
249271
baseFlowController: newFlowController(c.cfg.defaultInflightRequests, c.cfg.defaultInflightBytes),
250272
}

bigquery/storage/managedwriter/client_test.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,13 @@ func TestTableParentFromStreamName(t *testing.T) {
5555
}
5656

5757
func TestCreatePool_Location(t *testing.T) {
58+
t.Skip("skipping until new write_location is allowed")
5859
c := &Client{
59-
cfg: &writerClientConfig{},
60+
cfg: &writerClientConfig{},
61+
ctx: context.Background(),
62+
projectID: "myproj",
6063
}
61-
pool, err := c.createPool(context.Background(), "foo", nil)
64+
pool, err := c.createPool("foo", nil)
6265
if err != nil {
6366
t.Fatalf("createPool: %v", err)
6467
}
@@ -72,7 +75,7 @@ func TestCreatePool_Location(t *testing.T) {
7275
}
7376
found := false
7477
for _, v := range vals {
75-
if v == "write_location=foo" {
78+
if v == "write_location=projects/myproj/locations/foo" {
7679
found = true
7780
break
7881
}
@@ -151,8 +154,9 @@ func TestCreatePool(t *testing.T) {
151154
for _, tc := range testCases {
152155
c := &Client{
153156
cfg: tc.cfg,
157+
ctx: context.Background(),
154158
}
155-
pool, err := c.createPool(context.Background(), "", nil)
159+
pool, err := c.createPool("", nil)
156160
if err != nil {
157161
t.Errorf("case %q: createPool errored unexpectedly: %v", tc.desc, err)
158162
continue

bigquery/storage/managedwriter/connection.go

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ type connectionPool struct {
5454

5555
// We centralize the open function on the pool, rather than having an instance of the open func on every
5656
// connection. Opening the connection is a stateless operation.
57-
open func(opts ...gax.CallOption) (storagepb.BigQueryWrite_AppendRowsClient, error)
57+
open func(ctx context.Context, opts ...gax.CallOption) (storagepb.BigQueryWrite_AppendRowsClient, error)
5858

5959
// We specify default calloptions for the pool.
6060
// Explicit connections may have their own calloptions as well.
@@ -137,7 +137,7 @@ func (cp *connectionPool) openWithRetry(co *connection) (storagepb.BigQueryWrite
137137
r := &unaryRetryer{}
138138
for {
139139
recordStat(cp.ctx, AppendClientOpenCount, 1)
140-
arc, err := cp.open(cp.mergeCallOptions(co)...)
140+
arc, err := cp.open(co.ctx, cp.mergeCallOptions(co)...)
141141
if err != nil {
142142
bo, shouldRetry := r.Retry(err)
143143
if shouldRetry {
@@ -151,6 +151,7 @@ func (cp *connectionPool) openWithRetry(co *connection) (storagepb.BigQueryWrite
151151
return nil, nil, err
152152
}
153153
}
154+
154155
// The channel relationship with its ARC is 1:1. If we get a new ARC, create a new pending
155156
// write channel and fire up the associated receive processor. The channel ensures that
156157
// responses for a connection are processed in the same order that appends were sent.
@@ -159,7 +160,7 @@ func (cp *connectionPool) openWithRetry(co *connection) (storagepb.BigQueryWrite
159160
depth = d
160161
}
161162
ch := make(chan *pendingWrite, depth)
162-
go connRecvProcessor(co, arc, ch)
163+
go connRecvProcessor(co.ctx, co, arc, ch)
163164
return arc, ch, nil
164165
}
165166
}
@@ -441,13 +442,17 @@ func (co *connection) getStream(arc *storagepb.BigQueryWrite_AppendRowsClient, f
441442
if arc != co.arc && !forceReconnect {
442443
return co.arc, co.pending, nil
443444
}
444-
// We need to (re)open a connection. Cleanup previous connection and channel if they are present.
445+
// We need to (re)open a connection. Cleanup previous connection, channel, and context if they are present.
445446
if co.arc != nil && (*co.arc) != (storagepb.BigQueryWrite_AppendRowsClient)(nil) {
446447
(*co.arc).CloseSend()
447448
}
448449
if co.pending != nil {
449450
close(co.pending)
450451
}
452+
if co.cancel != nil {
453+
co.cancel()
454+
co.ctx, co.cancel = context.WithCancel(co.pool.ctx)
455+
}
451456

452457
co.arc = new(storagepb.BigQueryWrite_AppendRowsClient)
453458
// We're going to (re)open the connection, so clear any optimizer state.
@@ -464,10 +469,10 @@ type streamClientFunc func(context.Context, ...gax.CallOption) (storagepb.BigQue
464469
// connRecvProcessor is used to propagate append responses back up with the originating write requests. It
465470
// It runs as a goroutine. A connection object allows for reconnection, and each reconnection establishes a new
466471
// processing gorouting and backing channel.
467-
func connRecvProcessor(co *connection, arc storagepb.BigQueryWrite_AppendRowsClient, ch <-chan *pendingWrite) {
472+
func connRecvProcessor(ctx context.Context, co *connection, arc storagepb.BigQueryWrite_AppendRowsClient, ch <-chan *pendingWrite) {
468473
for {
469474
select {
470-
case <-co.ctx.Done():
475+
case <-ctx.Done():
471476
// Context is done, so we're not going to get further updates. Mark all work left in the channel
472477
// with the context error. We don't attempt to re-enqueue in this case.
473478
for {
@@ -478,7 +483,7 @@ func connRecvProcessor(co *connection, arc storagepb.BigQueryWrite_AppendRowsCli
478483
// It's unlikely this connection will recover here, but for correctness keep the flow controller
479484
// state correct by releasing.
480485
co.release(pw)
481-
pw.markDone(nil, co.ctx.Err())
486+
pw.markDone(nil, ctx.Err())
482487
}
483488
case nextWrite, ok := <-ch:
484489
if !ok {
@@ -493,12 +498,12 @@ func connRecvProcessor(co *connection, arc storagepb.BigQueryWrite_AppendRowsCli
493498
continue
494499
}
495500
// Record that we did in fact get a response from the backend.
496-
recordStat(co.ctx, AppendResponses, 1)
501+
recordStat(ctx, AppendResponses, 1)
497502

498503
if status := resp.GetError(); status != nil {
499504
// The response from the backend embedded a status error. We record that the error
500505
// occurred, and tag it based on the response code of the status.
501-
if tagCtx, tagErr := tag.New(co.ctx, tag.Insert(keyError, codes.Code(status.GetCode()).String())); tagErr == nil {
506+
if tagCtx, tagErr := tag.New(ctx, tag.Insert(keyError, codes.Code(status.GetCode()).String())); tagErr == nil {
502507
recordStat(tagCtx, AppendResponseErrors, 1)
503508
}
504509
respErr := grpcstatus.ErrorProto(status)

bigquery/storage/managedwriter/connection_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ func TestConnection_OpenWithRetry(t *testing.T) {
6161
for _, tc := range testCases {
6262
pool := &connectionPool{
6363
ctx: context.Background(),
64-
open: func(opts ...gax.CallOption) (storagepb.BigQueryWrite_AppendRowsClient, error) {
64+
open: func(ctx context.Context, opts ...gax.CallOption) (storagepb.BigQueryWrite_AppendRowsClient, error) {
6565
if len(tc.errors) == 0 {
6666
panic("out of errors")
6767
}
@@ -162,12 +162,12 @@ func TestConnectionPool_OpenCallOptionPropagation(t *testing.T) {
162162
pool := &connectionPool{
163163
ctx: ctx,
164164
cancel: cancel,
165-
open: createOpenF(ctx, func(ctx context.Context, opts ...gax.CallOption) (storage.BigQueryWrite_AppendRowsClient, error) {
165+
open: createOpenF(func(ctx context.Context, opts ...gax.CallOption) (storage.BigQueryWrite_AppendRowsClient, error) {
166166
if len(opts) == 0 {
167167
t.Fatalf("no options were propagated")
168168
}
169169
return nil, fmt.Errorf("no real client")
170-
}),
170+
}, ""),
171171
callOptions: []gax.CallOption{
172172
gax.WithGRPCOptions(grpc.MaxCallRecvMsgSize(99)),
173173
},

0 commit comments

Comments
 (0)