Skip to content
This repository was archived by the owner on Jul 6, 2023. It is now read-only.

Commit bdcf454

Browse files
feat: add context manager support in client (#9)
- [ ] Regenerate this pull request now. chore: fix docstring for first attribute of protos committer: @busunkim96 PiperOrigin-RevId: 401271153 Source-Link: googleapis/googleapis@787f8c9 Source-Link: https://github.com/googleapis/googleapis-gen/commit/81decffe9fc72396a8153e756d1d67a6eecfd620 Copy-Tag: eyJwIjoiLmdpdGh1Yi8uT3dsQm90LnlhbWwiLCJoIjoiODFkZWNmZmU5ZmM3MjM5NmE4MTUzZTc1NmQxZDY3YTZlZWNmZDYyMCJ9
1 parent c9c9684 commit bdcf454

File tree

7 files changed

+117
-5
lines changed

7 files changed

+117
-5
lines changed

google/cloud/deploy_v1/services/cloud_deploy/async_client.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1811,6 +1811,12 @@ async def get_config(
18111811
# Done; return the response.
18121812
return response
18131813

1814+
async def __aenter__(self):
1815+
return self
1816+
1817+
async def __aexit__(self, exc_type, exc, tb):
1818+
await self.transport.close()
1819+
18141820

18151821
try:
18161822
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(

google/cloud/deploy_v1/services/cloud_deploy/client.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -474,10 +474,7 @@ def __init__(
474474
client_cert_source_for_mtls=client_cert_source_func,
475475
quota_project_id=client_options.quota_project_id,
476476
client_info=client_info,
477-
always_use_jwt_access=(
478-
Transport == type(self).get_transport_class("grpc")
479-
or Transport == type(self).get_transport_class("grpc_asyncio")
480-
),
477+
always_use_jwt_access=True,
481478
)
482479

483480
def list_delivery_pipelines(
@@ -2029,6 +2026,19 @@ def get_config(
20292026
# Done; return the response.
20302027
return response
20312028

2029+
def __enter__(self):
2030+
return self
2031+
2032+
def __exit__(self, type, value, traceback):
2033+
"""Releases underlying transport's resources.
2034+
2035+
.. warning::
2036+
ONLY use as a context manager if the transport is NOT shared
2037+
with other clients! Exiting the with block will CLOSE the transport
2038+
and may cause errors in other clients!
2039+
"""
2040+
self.transport.close()
2041+
20322042

20332043
try:
20342044
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(

google/cloud/deploy_v1/services/cloud_deploy/transports/base.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,15 @@ def _prep_wrapped_messages(self, client_info):
315315
),
316316
}
317317

318+
def close(self):
319+
"""Closes resources associated with the transport.
320+
321+
.. warning::
322+
Only call this method if the transport is NOT shared
323+
with other clients - this may cause errors in other clients!
324+
"""
325+
raise NotImplementedError()
326+
318327
@property
319328
def operations_client(self) -> operations_v1.OperationsClient:
320329
"""Return the client designed to process long-running operations."""

google/cloud/deploy_v1/services/cloud_deploy/transports/grpc.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -733,5 +733,8 @@ def get_config(
733733
)
734734
return self._stubs["get_config"]
735735

736+
def close(self):
737+
self.grpc_channel.close()
738+
736739

737740
__all__ = ("CloudDeployGrpcTransport",)

google/cloud/deploy_v1/services/cloud_deploy/transports/grpc_asyncio.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -755,5 +755,8 @@ def get_config(
755755
)
756756
return self._stubs["get_config"]
757757

758+
def close(self):
759+
return self.grpc_channel.close()
760+
758761

759762
__all__ = ("CloudDeployGrpcAsyncIOTransport",)

google/cloud/deploy_v1/types/cloud_deploy.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ class SerialPipeline(proto.Message):
146146

147147
class Stage(proto.Message):
148148
r"""Stage specifies a location to which to deploy.
149+
149150
Attributes:
150151
target_id (str):
151152
The target_id to which this stage points. This field refers
@@ -224,6 +225,7 @@ class PipelineCondition(proto.Message):
224225

225226
class ListDeliveryPipelinesRequest(proto.Message):
226227
r"""The request object for ``ListDeliveryPipelines``.
228+
227229
Attributes:
228230
parent (str):
229231
Required. The parent, which owns this collection of
@@ -260,6 +262,7 @@ class ListDeliveryPipelinesRequest(proto.Message):
260262

261263
class ListDeliveryPipelinesResponse(proto.Message):
262264
r"""The response object from ``ListDeliveryPipelines``.
265+
263266
Attributes:
264267
delivery_pipelines (Sequence[google.cloud.deploy_v1.types.DeliveryPipeline]):
265268
The ``DeliveryPipeline`` objects.
@@ -284,6 +287,7 @@ def raw_page(self):
284287

285288
class GetDeliveryPipelineRequest(proto.Message):
286289
r"""The request object for ``GetDeliveryPipeline``
290+
287291
Attributes:
288292
name (str):
289293
Required. Name of the ``DeliveryPipeline``. Format must be
@@ -295,6 +299,7 @@ class GetDeliveryPipelineRequest(proto.Message):
295299

296300
class CreateDeliveryPipelineRequest(proto.Message):
297301
r"""The request object for ``CreateDeliveryPipeline``.
302+
298303
Attributes:
299304
parent (str):
300305
Required. The parent collection in which the
@@ -339,6 +344,7 @@ class CreateDeliveryPipelineRequest(proto.Message):
339344

340345
class UpdateDeliveryPipelineRequest(proto.Message):
341346
r"""The request object for ``UpdateDeliveryPipeline``.
347+
342348
Attributes:
343349
update_mask (google.protobuf.field_mask_pb2.FieldMask):
344350
Required. Field mask is used to specify the fields to be
@@ -390,6 +396,7 @@ class UpdateDeliveryPipelineRequest(proto.Message):
390396

391397
class DeleteDeliveryPipelineRequest(proto.Message):
392398
r"""The request object for ``DeleteDeliveryPipeline``.
399+
393400
Attributes:
394401
name (str):
395402
Required. The name of the ``DeliveryPipeline`` to delete.
@@ -547,6 +554,7 @@ class ExecutionEnvironmentUsage(proto.Enum):
547554

548555
class DefaultPool(proto.Message):
549556
r"""Execution using the default Cloud Build pool.
557+
550558
Attributes:
551559
service_account (str):
552560
Optional. Google service account to use for execution. If
@@ -568,6 +576,7 @@ class DefaultPool(proto.Message):
568576

569577
class PrivatePool(proto.Message):
570578
r"""Execution using a private Cloud Build pool.
579+
571580
Attributes:
572581
worker_pool (str):
573582
Required. Resource name of the Cloud Build worker pool to
@@ -594,6 +603,7 @@ class PrivatePool(proto.Message):
594603

595604
class GkeCluster(proto.Message):
596605
r"""Information specifying a GKE Cluster.
606+
597607
Attributes:
598608
cluster (str):
599609
Information specifying a GKE Cluster. Format is
@@ -605,6 +615,7 @@ class GkeCluster(proto.Message):
605615

606616
class ListTargetsRequest(proto.Message):
607617
r"""The request object for ``ListTargets``.
618+
608619
Attributes:
609620
parent (str):
610621
Required. The parent, which owns this collection of targets.
@@ -641,6 +652,7 @@ class ListTargetsRequest(proto.Message):
641652

642653
class ListTargetsResponse(proto.Message):
643654
r"""The response object from ``ListTargets``.
655+
644656
Attributes:
645657
targets (Sequence[google.cloud.deploy_v1.types.Target]):
646658
The ``Target`` objects.
@@ -663,6 +675,7 @@ def raw_page(self):
663675

664676
class GetTargetRequest(proto.Message):
665677
r"""The request object for ``GetTarget``.
678+
666679
Attributes:
667680
name (str):
668681
Required. Name of the ``Target``. Format must be
@@ -674,6 +687,7 @@ class GetTargetRequest(proto.Message):
674687

675688
class CreateTargetRequest(proto.Message):
676689
r"""The request object for ``CreateTarget``.
690+
677691
Attributes:
678692
parent (str):
679693
Required. The parent collection in which the ``Target``
@@ -716,6 +730,7 @@ class CreateTargetRequest(proto.Message):
716730

717731
class UpdateTargetRequest(proto.Message):
718732
r"""The request object for ``UpdateTarget``.
733+
719734
Attributes:
720735
update_mask (google.protobuf.field_mask_pb2.FieldMask):
721736
Required. Field mask is used to specify the fields to be
@@ -764,6 +779,7 @@ class UpdateTargetRequest(proto.Message):
764779

765780
class DeleteTargetRequest(proto.Message):
766781
r"""The request object for ``DeleteTarget``.
782+
767783
Attributes:
768784
name (str):
769785
Required. The name of the ``Target`` to delete. Format
@@ -896,6 +912,7 @@ class RenderState(proto.Enum):
896912

897913
class TargetRender(proto.Message):
898914
r"""Details of rendering for a single target.
915+
899916
Attributes:
900917
rendering_build (str):
901918
Output only. The resource name of the Cloud Build ``Build``
@@ -953,6 +970,7 @@ class TargetRenderState(proto.Enum):
953970

954971
class BuildArtifact(proto.Message):
955972
r"""Description of an a image to use during Skaffold rendering.
973+
956974
Attributes:
957975
image (str):
958976
Image name in Skaffold configuration.
@@ -969,6 +987,7 @@ class BuildArtifact(proto.Message):
969987

970988
class TargetArtifact(proto.Message):
971989
r"""The artifacts produced by a target render operation.
990+
972991
Attributes:
973992
artifact_uri (str):
974993
Output only. URI of a directory containing
@@ -990,6 +1009,7 @@ class TargetArtifact(proto.Message):
9901009

9911010
class ListReleasesRequest(proto.Message):
9921011
r"""The request object for ``ListReleases``.
1012+
9931013
Attributes:
9941014
parent (str):
9951015
Required. The ``DeliveryPipeline`` which owns this
@@ -1025,6 +1045,7 @@ class ListReleasesRequest(proto.Message):
10251045

10261046
class ListReleasesResponse(proto.Message):
10271047
r"""The response object from ``ListReleases``.
1048+
10281049
Attributes:
10291050
releases (Sequence[google.cloud.deploy_v1.types.Release]):
10301051
The ``Release`` objects.
@@ -1047,6 +1068,7 @@ def raw_page(self):
10471068

10481069
class GetReleaseRequest(proto.Message):
10491070
r"""The request object for ``GetRelease``.
1071+
10501072
Attributes:
10511073
name (str):
10521074
Required. Name of the ``Release``. Format must be
@@ -1058,6 +1080,7 @@ class GetReleaseRequest(proto.Message):
10581080

10591081
class CreateReleaseRequest(proto.Message):
10601082
r"""The request object for ``CreateRelease``,
1083+
10611084
Attributes:
10621085
parent (str):
10631086
Required. The parent collection in which the ``Release``
@@ -1209,6 +1232,7 @@ class State(proto.Enum):
12091232

12101233
class ListRolloutsRequest(proto.Message):
12111234
r"""ListRolloutsRequest is the request object used by ``ListRollouts``.
1235+
12121236
Attributes:
12131237
parent (str):
12141238
Required. The ``Release`` which owns this collection of
@@ -1268,6 +1292,7 @@ def raw_page(self):
12681292

12691293
class GetRolloutRequest(proto.Message):
12701294
r"""GetRolloutRequest is the request object used by ``GetRollout``.
1295+
12711296
Attributes:
12721297
name (str):
12731298
Required. Name of the ``Rollout``. Format must be
@@ -1323,6 +1348,7 @@ class CreateRolloutRequest(proto.Message):
13231348

13241349
class OperationMetadata(proto.Message):
13251350
r"""Represents the metadata of the long-running operation.
1351+
13261352
Attributes:
13271353
create_time (google.protobuf.timestamp_pb2.Timestamp):
13281354
Output only. The time the operation was
@@ -1361,6 +1387,7 @@ class OperationMetadata(proto.Message):
13611387

13621388
class ApproveRolloutRequest(proto.Message):
13631389
r"""The request object used by ``ApproveRollout``.
1390+
13641391
Attributes:
13651392
name (str):
13661393
Required. Name of the Rollout. Format is
@@ -1375,11 +1402,13 @@ class ApproveRolloutRequest(proto.Message):
13751402

13761403

13771404
class ApproveRolloutResponse(proto.Message):
1378-
r"""The response object from ``ApproveRollout``. """
1405+
r"""The response object from ``ApproveRollout``.
1406+
"""
13791407

13801408

13811409
class Config(proto.Message):
13821410
r"""Service-wide configuration.
1411+
13831412
Attributes:
13841413
name (str):
13851414
Name of the configuration.
@@ -1401,6 +1430,7 @@ class Config(proto.Message):
14011430

14021431
class SkaffoldVersion(proto.Message):
14031432
r"""Details of a supported Skaffold version.
1433+
14041434
Attributes:
14051435
version (str):
14061436
Release version number. For example,
@@ -1416,6 +1446,7 @@ class SkaffoldVersion(proto.Message):
14161446

14171447
class GetConfigRequest(proto.Message):
14181448
r"""Request to get a configuration.
1449+
14191450
Attributes:
14201451
name (str):
14211452
Required. Name of requested configuration.

tests/unit/gapic/deploy_v1/test_cloud_deploy.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from google.api_core import grpc_helpers_async
3333
from google.api_core import operation_async # type: ignore
3434
from google.api_core import operations_v1
35+
from google.api_core import path_template
3536
from google.auth import credentials as ga_credentials
3637
from google.auth.exceptions import MutualTLSChannelError
3738
from google.cloud.deploy_v1.services.cloud_deploy import CloudDeployAsyncClient
@@ -5023,6 +5024,9 @@ def test_cloud_deploy_base_transport():
50235024
with pytest.raises(NotImplementedError):
50245025
getattr(transport, method)(request=object())
50255026

5027+
with pytest.raises(NotImplementedError):
5028+
transport.close()
5029+
50265030
# Additionally, the LRO client (a property) should
50275031
# also raise NotImplementedError
50285032
with pytest.raises(NotImplementedError):
@@ -5704,3 +5708,49 @@ def test_client_withDEFAULT_CLIENT_INFO():
57045708
credentials=ga_credentials.AnonymousCredentials(), client_info=client_info,
57055709
)
57065710
prep.assert_called_once_with(client_info)
5711+
5712+
5713+
@pytest.mark.asyncio
5714+
async def test_transport_close_async():
5715+
client = CloudDeployAsyncClient(
5716+
credentials=ga_credentials.AnonymousCredentials(), transport="grpc_asyncio",
5717+
)
5718+
with mock.patch.object(
5719+
type(getattr(client.transport, "grpc_channel")), "close"
5720+
) as close:
5721+
async with client:
5722+
close.assert_not_called()
5723+
close.assert_called_once()
5724+
5725+
5726+
def test_transport_close():
5727+
transports = {
5728+
"grpc": "_grpc_channel",
5729+
}
5730+
5731+
for transport, close_name in transports.items():
5732+
client = CloudDeployClient(
5733+
credentials=ga_credentials.AnonymousCredentials(), transport=transport
5734+
)
5735+
with mock.patch.object(
5736+
type(getattr(client.transport, close_name)), "close"
5737+
) as close:
5738+
with client:
5739+
close.assert_not_called()
5740+
close.assert_called_once()
5741+
5742+
5743+
def test_client_ctx():
5744+
transports = [
5745+
"grpc",
5746+
]
5747+
for transport in transports:
5748+
client = CloudDeployClient(
5749+
credentials=ga_credentials.AnonymousCredentials(), transport=transport
5750+
)
5751+
# Test client calls underlying transport.
5752+
with mock.patch.object(type(client.transport), "close") as close:
5753+
close.assert_not_called()
5754+
with client:
5755+
pass
5756+
close.assert_called()

0 commit comments

Comments
 (0)