diff --git a/.github/.OwlBot.lock.yaml b/.github/.OwlBot.lock.yaml
index a3da1b0d4c..7f291dbd5f 100644
--- a/.github/.OwlBot.lock.yaml
+++ b/.github/.OwlBot.lock.yaml
@@ -13,5 +13,5 @@
 # limitations under the License.
 docker:
   image: gcr.io/cloud-devrel-public-resources/owlbot-python:latest
-  digest: sha256:3e3800bb100af5d7f9e810d48212b37812c1856d20ffeafb99ebe66461b61fc7
-# created: 2023-08-02T10:53:29.114535628Z
+  digest: sha256:4f9b3b106ad0beafc2c8a415e3f62c1a0cc23cabea115dbe841b848f581cfe99
+# created: 2023-10-18T20:26:37.410353675Z
diff --git a/.github/sync-repo-settings.yaml b/.github/sync-repo-settings.yaml
new file mode 100644
index 0000000000..75719fb3fc
--- /dev/null
+++ b/.github/sync-repo-settings.yaml
@@ -0,0 +1,46 @@
+# Rules for main branch protection
+branchProtectionRules:
+# Identifies the protection rule pattern. Name of the branch to be protected.
+# Defaults to `main`
+- pattern: main
+  # Can admins overwrite branch protection.
+  # Defaults to `true`
+  isAdminEnforced: true
+  # Number of approving reviews required to update matching branches.
+  # Defaults to `1`
+  requiredApprovingReviewCount: 1
+  # Are reviews from code owners required to update matching branches.
+  # Defaults to `false`
+  requiresCodeOwnerReviews: true
+  # Require up to date branches
+  requiresStrictStatusChecks: true
+  # List of required status check contexts that must pass for commits to be accepted to matching branches.
+  requiredStatusCheckContexts:
+    - 'Kokoro'
+    - 'Kokoro system-3.7'
+    - 'cla/google'
+    - 'OwlBot Post Processor'
+    - 'docs'
+    - 'docfx'
+    - 'lint'
+    - 'unit (3.7)'
+    - 'unit (3.8)'
+    - 'unit (3.9)'
+    - 'unit (3.10)'
+    - 'unit (3.11)'
+    - 'cover'
+    - 'run-systests'
+# List of explicit permissions to add (additive only)
+permissionRules:
+  # Team slug to add to repository permissions
+  - team: yoshi-admins
+    # Access level required, one of push|pull|admin|maintain|triage
+    permission: admin
+  # Team slug to add to repository permissions
+  - team: yoshi-python-admins
+    # Access level required, one of push|pull|admin|maintain|triage
+    permission: admin
+  # Team slug to add to repository permissions
+  - team: yoshi-python
+    # Access level required, one of push|pull|admin|maintain|triage
+    permission: push
diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml
index de40c7ae88..20622633a7 100644
--- a/.github/workflows/mypy.yml
+++ b/.github/workflows/mypy.yml
@@ -8,7 +8,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
     - name: Checkout
-      uses: actions/checkout@v3
+      uses: actions/checkout@v4
     - name: Setup Python
       uses: actions/setup-python@v4
       with:
diff --git a/.github/workflows/system_emulated.yml b/.github/workflows/system_emulated.yml
index d89a049992..44d56657f9 100644
--- a/.github/workflows/system_emulated.yml
+++ b/.github/workflows/system_emulated.yml
@@ -7,12 +7,12 @@ on:
 
 jobs:
   run-systests:
-    runs-on: ubuntu-20.04
+    runs-on: ubuntu-22.04
 
     steps:
     - name: Checkout
-      uses: actions/checkout@v3
+      uses: actions/checkout@v4
 
     - name: Setup Python
       uses: actions/setup-python@v4
diff --git a/.gitignore b/.gitignore
index 861c70e56f..6b3c7fdbc2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -51,6 +51,7 @@ docs.metadata
 
 # Virtual environment
 env/
+venv/
 
 # Test logs
 coverage.xml
diff --git a/.kokoro/requirements.txt b/.kokoro/requirements.txt
index 029bd342de..16170d0ca7 100644
--- a/.kokoro/requirements.txt
+++ b/.kokoro/requirements.txt
@@ -113,30 +113,30 @@ commonmark==0.9.1 \
     --hash=sha256:452f9dc859be7f06631ddcb328b6919c67984aca654e5fefb3914d54691aed60 \
     --hash=sha256:da2f38c92590f83de410ba1a3cbceafbc74fee9def35f9251ba9a971d6d66fd9
     # via rich
-cryptography==41.0.3 \
-    --hash=sha256:0d09fb5356f975974dbcb595ad2d178305e5050656affb7890a1583f5e02a306 \
-    --hash=sha256:23c2d778cf829f7d0ae180600b17e9fceea3c2ef8b31a99e3c694cbbf3a24b84 \
-    --hash=sha256:3fb248989b6363906827284cd20cca63bb1a757e0a2864d4c1682a985e3dca47 \
-    --hash=sha256:41d7aa7cdfded09b3d73a47f429c298e80796c8e825ddfadc84c8a7f12df212d \
-    --hash=sha256:42cb413e01a5d36da9929baa9d70ca90d90b969269e5a12d39c1e0d475010116 \
-    --hash=sha256:4c2f0d35703d61002a2bbdcf15548ebb701cfdd83cdc12471d2bae80878a4207 \
-    --hash=sha256:4fd871184321100fb400d759ad0cddddf284c4b696568204d281c902fc7b0d81 \
-    --hash=sha256:5259cb659aa43005eb55a0e4ff2c825ca111a0da1814202c64d28a985d33b087 \
-    --hash=sha256:57a51b89f954f216a81c9d057bf1a24e2f36e764a1ca9a501a6964eb4a6800dd \
-    --hash=sha256:652627a055cb52a84f8c448185922241dd5217443ca194d5739b44612c5e6507 \
-    --hash=sha256:67e120e9a577c64fe1f611e53b30b3e69744e5910ff3b6e97e935aeb96005858 \
-    --hash=sha256:6af1c6387c531cd364b72c28daa29232162010d952ceb7e5ca8e2827526aceae \
-    --hash=sha256:6d192741113ef5e30d89dcb5b956ef4e1578f304708701b8b73d38e3e1461f34 \
-    --hash=sha256:7efe8041897fe7a50863e51b77789b657a133c75c3b094e51b5e4b5cec7bf906 \
-    --hash=sha256:84537453d57f55a50a5b6835622ee405816999a7113267739a1b4581f83535bd \
-    --hash=sha256:8f09daa483aedea50d249ef98ed500569841d6498aa9c9f4b0531b9964658922 \
-    --hash=sha256:95dd7f261bb76948b52a5330ba5202b91a26fbac13ad0e9fc8a3ac04752058c7 \
-    --hash=sha256:a74fbcdb2a0d46fe00504f571a2a540532f4c188e6ccf26f1f178480117b33c4 \
-    --hash=sha256:a983e441a00a9d57a4d7c91b3116a37ae602907a7618b882c8013b5762e80574 \
-    --hash=sha256:ab8de0d091acbf778f74286f4989cf3d1528336af1b59f3e5d2ebca8b5fe49e1 \
-    --hash=sha256:aeb57c421b34af8f9fe830e1955bf493a86a7996cc1338fe41b30047d16e962c \
-    --hash=sha256:ce785cf81a7bdade534297ef9e490ddff800d956625020ab2ec2780a556c313e \
-    --hash=sha256:d0d651aa754ef58d75cec6edfbd21259d93810b73f6ec246436a21b7841908de
+cryptography==41.0.4 \
+    --hash=sha256:004b6ccc95943f6a9ad3142cfabcc769d7ee38a3f60fb0dddbfb431f818c3a67 \
+    --hash=sha256:047c4603aeb4bbd8db2756e38f5b8bd7e94318c047cfe4efeb5d715e08b49311 \
+    --hash=sha256:0d9409894f495d465fe6fda92cb70e8323e9648af912d5b9141d616df40a87b8 \
+    --hash=sha256:23a25c09dfd0d9f28da2352503b23e086f8e78096b9fd585d1d14eca01613e13 \
+    --hash=sha256:2ed09183922d66c4ec5fdaa59b4d14e105c084dd0febd27452de8f6f74704143 \
+    --hash=sha256:35c00f637cd0b9d5b6c6bd11b6c3359194a8eba9c46d4e875a3660e3b400005f \
+    --hash=sha256:37480760ae08065437e6573d14be973112c9e6dcaf5f11d00147ee74f37a3829 \
+    --hash=sha256:3b224890962a2d7b57cf5eeb16ccaafba6083f7b811829f00476309bce2fe0fd \
+    --hash=sha256:5a0f09cefded00e648a127048119f77bc2b2ec61e736660b5789e638f43cc397 \
+    --hash=sha256:5b72205a360f3b6176485a333256b9bcd48700fc755fef51c8e7e67c4b63e3ac \
+    --hash=sha256:7e53db173370dea832190870e975a1e09c86a879b613948f09eb49324218c14d \
+    --hash=sha256:7febc3094125fc126a7f6fb1f420d0da639f3f32cb15c8ff0dc3997c4549f51a \
+    --hash=sha256:80907d3faa55dc5434a16579952ac6da800935cd98d14dbd62f6f042c7f5e839 \
+    --hash=sha256:86defa8d248c3fa029da68ce61fe735432b047e32179883bdb1e79ed9bb8195e \
+    --hash=sha256:8ac4f9ead4bbd0bc8ab2d318f97d85147167a488be0e08814a37eb2f439d5cf6 \
+    --hash=sha256:93530900d14c37a46ce3d6c9e6fd35dbe5f5601bf6b3a5c325c7bffc030344d9 \
+    --hash=sha256:9eeb77214afae972a00dee47382d2591abe77bdae166bda672fb1e24702a3860 \
+    --hash=sha256:b5f4dfe950ff0479f1f00eda09c18798d4f49b98f4e2006d644b3301682ebdca \
+    --hash=sha256:c3391bd8e6de35f6f1140e50aaeb3e2b3d6a9012536ca23ab0d9c35ec18c8a91 \
+    --hash=sha256:c880eba5175f4307129784eca96f4e70b88e57aa3f680aeba3bab0e980b0f37d \
+    --hash=sha256:cecfefa17042941f94ab54f769c8ce0fe14beff2694e9ac684176a2535bf9714 \
+    --hash=sha256:e40211b4923ba5a6dc9769eab704bdb3fbb58d56c5b336d30996c24fcf12aadb \
+    --hash=sha256:efc8ad4e6fc4f1752ebfb58aefece8b4e3c4cae940b0994d43649bdfce8d0d4f
     # via
     #   gcp-releasetool
     #   secretstorage
@@ -382,6 +382,7 @@ protobuf==3.20.3 \
     #   gcp-docuploader
     #   gcp-releasetool
     #   google-api-core
+    #   googleapis-common-protos
 pyasn1==0.4.8 \
     --hash=sha256:39c7e2ec30515947ff4e87fb6f456dfc6e84857d34be479c9d4a4ba4bf46aa5d \
     --hash=sha256:aef77c9fb94a3ac588e87841208bdec464471d9871bd5050a287cc9a475cd0ba
@@ -466,9 +467,9 @@ typing-extensions==4.4.0 \
     --hash=sha256:1511434bb92bf8dd198c12b1cc812e800d4181cfcb867674e0f8279cc93087aa \
     --hash=sha256:16fa4864408f655d35ec496218b85f79b3437c829e93320c7c9215ccfd92489e
     # via -r requirements.in
-urllib3==1.26.12 \
-    --hash=sha256:3fa96cf423e6987997fc326ae8df396db2a8b7c667747d47ddd8ecba91f4a74e \
-    --hash=sha256:b930dd878d5a8afb066a637fbb35144fe7901e3b209d1cd4f524bd0e9deee997
+urllib3==1.26.18 \
+    --hash=sha256:34b97092d7e0a3a8cf7cd10e386f401b3737364026c45e622aa02903dffe0f07 \
+    --hash=sha256:f8ecc1bba5667413457c529ab955bf8c67b45db799d159066261719e328580a0
     # via
     #   requests
     #   twine
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 19409cbd37..6a8e169506 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -22,7 +22,7 @@ repos:
       - id: end-of-file-fixer
       - id: check-yaml
   - repo: https://github.com/psf/black
-    rev: 22.3.0
+    rev: 23.7.0
     hooks:
       - id: black
   - repo: https://github.com/pycqa/flake8
diff --git a/.release-please-manifest.json b/.release-please-manifest.json
index 997329e9fc..c61c7bc41e 100644
--- a/.release-please-manifest.json
+++ b/.release-please-manifest.json
@@ -1,3 +1,3 @@
 {
-    ".": "2.12.0"
+    ".": "2.13.0"
 }
\ No newline at end of file
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 89bbf4e1d5..e59ddbdc1d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,24 @@
 
 [1]: https://pypi.org/project/google-cloud-firestore/#history
 
+## [2.13.0](https://github.com/googleapis/python-firestore/compare/v2.12.0...v2.13.0) (2023-10-23)
+
+
+### Features
+
+* Sum/Avg aggregation queries ([#715](https://github.com/googleapis/python-firestore/issues/715)) ([443475b](https://github.com/googleapis/python-firestore/commit/443475b01395a1749b02035313c54e1d775da09b))
+
+
+### Bug Fixes
+
+* Ensure transactions rollback on failure ([#767](https://github.com/googleapis/python-firestore/issues/767)) ([cdaf25b](https://github.com/googleapis/python-firestore/commit/cdaf25b35d27355e4ea577843004fdc2d16bb4ac))
+* Improve AsyncQuery typing ([#782](https://github.com/googleapis/python-firestore/issues/782)) ([ae1247b](https://github.com/googleapis/python-firestore/commit/ae1247b4502d395eac7b387dbdd5ef162264069f))
+
+
+### Documentation
+
+* Minor formatting ([41b5ea0](https://github.com/googleapis/python-firestore/commit/41b5ea091245bea291c8de841205ecb53a26087f))
+
 ## [2.12.0](https://github.com/googleapis/python-firestore/compare/v2.11.1...v2.12.0) (2023-08-07)
 
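Usage note (not part of the diff): the headline feature in this release is the Sum/Avg aggregation API (#715), implemented in the `base_aggregation.py`, `query.py`, `async_query.py`, and `base_collection.py` hunks below. A minimal sketch of the intended call pattern — the "users" collection and its numeric "age" field are invented for illustration:

```python
# Minimal sketch of the Sum/Avg aggregations introduced in 2.13.0 (#715).
# The "users" collection and its numeric "age" field are assumed here;
# they are not part of this change.
from google.cloud import firestore

client = firestore.Client()
users = client.collection("users")

# sum()/avg() return an AggregationQuery; aliases name the result fields.
agg_query = users.sum("age", alias="total_age").avg("age", alias="avg_age")

# get() runs the RunAggregationQuery RPC and returns batches of
# AggregationResult objects carrying (alias, value, read_time).
for batch in agg_query.get():
    for agg in batch:
        print(agg.alias, agg.value)
```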
diff --git a/google/cloud/firestore/gapic_version.py b/google/cloud/firestore/gapic_version.py
index 16ae0e953c..a3c9255942 100644
--- a/google/cloud/firestore/gapic_version.py
+++ b/google/cloud/firestore/gapic_version.py
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-__version__ = "2.12.0"  # {x-release-please-version}
+__version__ = "2.13.0"  # {x-release-please-version}
diff --git a/google/cloud/firestore_admin_v1/gapic_version.py b/google/cloud/firestore_admin_v1/gapic_version.py
index 16ae0e953c..a3c9255942 100644
--- a/google/cloud/firestore_admin_v1/gapic_version.py
+++ b/google/cloud/firestore_admin_v1/gapic_version.py
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-__version__ = "2.12.0"  # {x-release-please-version}
+__version__ = "2.13.0"  # {x-release-please-version}
diff --git a/google/cloud/firestore_admin_v1/services/firestore_admin/async_client.py b/google/cloud/firestore_admin_v1/services/firestore_admin/async_client.py
index 342c3ca7a2..db277475bb 100644
--- a/google/cloud/firestore_admin_v1/services/firestore_admin/async_client.py
+++ b/google/cloud/firestore_admin_v1/services/firestore_admin/async_client.py
@@ -54,7 +54,7 @@
 from google.cloud.firestore_admin_v1.types import index as gfa_index
 from google.cloud.firestore_admin_v1.types import operation as gfa_operation
 from google.cloud.location import locations_pb2  # type: ignore
-from google.longrunning import operations_pb2
+from google.longrunning import operations_pb2  # type: ignore
 from google.protobuf import empty_pb2  # type: ignore
 from google.protobuf import field_mask_pb2  # type: ignore
 from .transports.base import FirestoreAdminTransport, DEFAULT_CLIENT_INFO
@@ -1453,6 +1453,7 @@ async def sample_create_database():
                 database, which will become the final
                 component of the database's resource
                 name.
+
                 The value must be set to "(default)".
 
                 This corresponds to the ``database_id`` field
diff --git a/google/cloud/firestore_admin_v1/services/firestore_admin/client.py b/google/cloud/firestore_admin_v1/services/firestore_admin/client.py
index 5d3ba62f06..0b4b04e2fc 100644
--- a/google/cloud/firestore_admin_v1/services/firestore_admin/client.py
+++ b/google/cloud/firestore_admin_v1/services/firestore_admin/client.py
@@ -58,7 +58,7 @@
 from google.cloud.firestore_admin_v1.types import index as gfa_index
 from google.cloud.firestore_admin_v1.types import operation as gfa_operation
 from google.cloud.location import locations_pb2  # type: ignore
-from google.longrunning import operations_pb2
+from google.longrunning import operations_pb2  # type: ignore
 from google.protobuf import empty_pb2  # type: ignore
 from google.protobuf import field_mask_pb2  # type: ignore
 from .transports.base import FirestoreAdminTransport, DEFAULT_CLIENT_INFO
@@ -1684,6 +1684,7 @@ def sample_create_database():
                 database, which will become the final
                 component of the database's resource
                 name.
+
                 The value must be set to "(default)".
 
                 This corresponds to the ``database_id`` field
diff --git a/google/cloud/firestore_admin_v1/services/firestore_admin/transports/base.py b/google/cloud/firestore_admin_v1/services/firestore_admin/transports/base.py
index e80fc6f3fb..c7176773ea 100644
--- a/google/cloud/firestore_admin_v1/services/firestore_admin/transports/base.py
+++ b/google/cloud/firestore_admin_v1/services/firestore_admin/transports/base.py
@@ -32,7 +32,6 @@
 from google.cloud.firestore_admin_v1.types import firestore_admin
 from google.cloud.firestore_admin_v1.types import index
 from google.cloud.location import locations_pb2  # type: ignore
-from google.longrunning import operations_pb2
 from google.longrunning import operations_pb2  # type: ignore
 from google.protobuf import empty_pb2  # type: ignore
 
diff --git a/google/cloud/firestore_admin_v1/services/firestore_admin/transports/grpc.py b/google/cloud/firestore_admin_v1/services/firestore_admin/transports/grpc.py
index d42b405ca8..fe6ecbdd91 100644
--- a/google/cloud/firestore_admin_v1/services/firestore_admin/transports/grpc.py
+++ b/google/cloud/firestore_admin_v1/services/firestore_admin/transports/grpc.py
@@ -30,7 +30,6 @@
 from google.cloud.firestore_admin_v1.types import firestore_admin
 from google.cloud.firestore_admin_v1.types import index
 from google.cloud.location import locations_pb2  # type: ignore
-from google.longrunning import operations_pb2
 from google.longrunning import operations_pb2  # type: ignore
 from google.protobuf import empty_pb2  # type: ignore
 from .base import FirestoreAdminTransport, DEFAULT_CLIENT_INFO
diff --git a/google/cloud/firestore_admin_v1/services/firestore_admin/transports/grpc_asyncio.py b/google/cloud/firestore_admin_v1/services/firestore_admin/transports/grpc_asyncio.py
index a313e3be68..ebc9c46890 100644
--- a/google/cloud/firestore_admin_v1/services/firestore_admin/transports/grpc_asyncio.py
+++ b/google/cloud/firestore_admin_v1/services/firestore_admin/transports/grpc_asyncio.py
@@ -30,7 +30,6 @@
 from google.cloud.firestore_admin_v1.types import firestore_admin
 from google.cloud.firestore_admin_v1.types import index
 from google.cloud.location import locations_pb2  # type: ignore
-from google.longrunning import operations_pb2
 from google.longrunning import operations_pb2  # type: ignore
 from google.protobuf import empty_pb2  # type: ignore
 from .base import FirestoreAdminTransport, DEFAULT_CLIENT_INFO
diff --git a/google/cloud/firestore_admin_v1/services/firestore_admin/transports/rest.py b/google/cloud/firestore_admin_v1/services/firestore_admin/transports/rest.py
index efe2da02e5..0264c2b1ca 100644
--- a/google/cloud/firestore_admin_v1/services/firestore_admin/transports/rest.py
+++ b/google/cloud/firestore_admin_v1/services/firestore_admin/transports/rest.py
@@ -29,7 +29,6 @@
 from google.protobuf import json_format
 from google.api_core import operations_v1
 from google.cloud.location import locations_pb2  # type: ignore
-from google.longrunning import operations_pb2
 from requests import __version__ as requests_version
 import dataclasses
 import re
@@ -46,8 +45,8 @@
 from google.cloud.firestore_admin_v1.types import field
 from google.cloud.firestore_admin_v1.types import firestore_admin
 from google.cloud.firestore_admin_v1.types import index
-from google.longrunning import operations_pb2  # type: ignore
 from google.protobuf import empty_pb2  # type: ignore
+from google.longrunning import operations_pb2  # type: ignore
 
 from .base import (
     FirestoreAdminTransport,
@@ -2054,7 +2053,6 @@ def __call__(
             timeout: Optional[float] = None,
             metadata: Sequence[Tuple[str, str]] = (),
         ) -> None:
-
             r"""Call the cancel operation method over HTTP.
 
             Args:
@@ -2120,7 +2118,6 @@ def __call__(
             timeout: Optional[float] = None,
             metadata: Sequence[Tuple[str, str]] = (),
         ) -> None:
-
             r"""Call the delete operation method over HTTP.
 
             Args:
@@ -2183,7 +2180,6 @@ def __call__(
             timeout: Optional[float] = None,
             metadata: Sequence[Tuple[str, str]] = (),
         ) -> operations_pb2.Operation:
-
             r"""Call the get operation method over HTTP.
 
             Args:
@@ -2250,7 +2246,6 @@ def __call__(
             timeout: Optional[float] = None,
             metadata: Sequence[Tuple[str, str]] = (),
         ) -> operations_pb2.ListOperationsResponse:
-
             r"""Call the list operations method over HTTP.
 
             Args:
diff --git a/google/cloud/firestore_admin_v1/types/database.py b/google/cloud/firestore_admin_v1/types/database.py
index f78aab4342..c615bbe2fc 100644
--- a/google/cloud/firestore_admin_v1/types/database.py
+++ b/google/cloud/firestore_admin_v1/types/database.py
@@ -104,12 +104,14 @@ class ConcurrencyMode(proto.Enum):
             Use pessimistic concurrency control by
             default. This mode is available for Cloud
             Firestore databases.
+
             This is the default setting for Cloud
             Firestore.
         OPTIMISTIC_WITH_ENTITY_GROUPS (3):
             Use optimistic concurrency control with
             entity groups by default.
+
             This is the only available mode for Cloud
             Datastore.
+
             This mode is also available for Cloud
             Firestore with Datastore Mode but is not
             recommended.
diff --git a/google/cloud/firestore_admin_v1/types/field.py b/google/cloud/firestore_admin_v1/types/field.py
index acfa02cb18..dfba26d49d 100644
--- a/google/cloud/firestore_admin_v1/types/field.py
+++ b/google/cloud/firestore_admin_v1/types/field.py
@@ -32,6 +32,7 @@
 
 class Field(proto.Message):
     r"""Represents a single field in the database.
+
     Fields are grouped by their "Collection Group", which represent
     all collections in the database with the same id.
 
diff --git a/google/cloud/firestore_admin_v1/types/firestore_admin.py b/google/cloud/firestore_admin_v1/types/firestore_admin.py
index 31ab5c9290..5d2b56d28f 100644
--- a/google/cloud/firestore_admin_v1/types/firestore_admin.py
+++ b/google/cloud/firestore_admin_v1/types/firestore_admin.py
@@ -80,6 +80,7 @@ class CreateDatabaseRequest(proto.Message):
             Required. The ID to use for the database,
             which will become the final component of the
            database's resource name.
+
             The value must be set to "(default)".
     """
 
diff --git a/google/cloud/firestore_admin_v1/types/index.py b/google/cloud/firestore_admin_v1/types/index.py
index e5743dcbd6..4846a0d99a 100644
--- a/google/cloud/firestore_admin_v1/types/index.py
+++ b/google/cloud/firestore_admin_v1/types/index.py
@@ -44,6 +44,7 @@ class Index(proto.Message):
             that is the child of a specific document,
             specified at query time, and that has the same
             collection id.
+
             Indexes with a collection group query scope
             specified allow queries against all
             collections descended from a specific
             document, specified at
diff --git a/google/cloud/firestore_bundle/gapic_version.py b/google/cloud/firestore_bundle/gapic_version.py
index 16ae0e953c..a3c9255942 100644
--- a/google/cloud/firestore_bundle/gapic_version.py
+++ b/google/cloud/firestore_bundle/gapic_version.py
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-__version__ = "2.12.0"  # {x-release-please-version}
+__version__ = "2.13.0"  # {x-release-please-version}
diff --git a/google/cloud/firestore_v1/_helpers.py b/google/cloud/firestore_v1/_helpers.py
index 3b6b7886bc..9c8976bb6a 100644
--- a/google/cloud/firestore_v1/_helpers.py
+++ b/google/cloud/firestore_v1/_helpers.py
@@ -440,7 +440,6 @@ def extract_fields(
         yield prefix_path, _EmptyDict
     else:
         for key, value in sorted(document_data.items()):
-
             if expand_dots:
                 sub_key = FieldPath.from_string(key)
             else:
@@ -503,7 +502,6 @@ def __init__(self, document_data) -> None:
         iterator = self._get_document_iterator(prefix_path)
 
         for field_path, value in iterator:
-
             if field_path == prefix_path and value is _EmptyDict:
                 self.empty_document = True
 
@@ -565,7 +563,6 @@ def _get_update_mask(self, allow_empty_mask=False) -> None:
     def get_update_pb(
         self, document_path, exists=None, allow_empty_mask=False
     ) -> types.write.Write:
-
         if exists is not None:
             current_document = common.Precondition(exists=exists)
         else:
@@ -762,7 +759,6 @@ def _normalize_merge_paths(self, merge) -> list:
         return merge_paths
 
     def _apply_merge_paths(self, merge) -> None:
-
         if self.empty_document:
             raise ValueError("Cannot merge specific fields with empty document.")
 
@@ -773,7 +769,6 @@ def _apply_merge_paths(self, merge) -> None:
         self.merge = merge_paths
 
         for merge_path in merge_paths:
-
             if merge_path in self.transform_paths:
                 self.transform_merge.append(merge_path)
 
@@ -1187,7 +1182,6 @@ def deserialize_bundle(
     bundle: Optional[FirestoreBundle] = None
     data: Dict
     for data in _parse_bundle_elements_data(serialized):
-
         # BundleElements are serialized as JSON containing one key outlining
         # the type, with all further data nested under that key
         keys: List[str] = list(data.keys())
diff --git a/google/cloud/firestore_v1/async_collection.py b/google/cloud/firestore_v1/async_collection.py
index e997455092..293a1e0f5b 100644
--- a/google/cloud/firestore_v1/async_collection.py
+++ b/google/cloud/firestore_v1/async_collection.py
@@ -32,7 +32,7 @@
 from google.cloud.firestore_v1.transaction import Transaction
 
 
-class AsyncCollectionReference(BaseCollectionReference):
+class AsyncCollectionReference(BaseCollectionReference[async_query.AsyncQuery]):
     """A reference to a collection in a Firestore database.
 
     The collection may already exist or this class can facilitate creation
diff --git a/google/cloud/firestore_v1/async_query.py b/google/cloud/firestore_v1/async_query.py
index efa172520a..d03ab72b87 100644
--- a/google/cloud/firestore_v1/async_query.py
+++ b/google/cloud/firestore_v1/async_query.py
@@ -34,13 +34,14 @@
 )
 
 from google.cloud.firestore_v1 import async_document
+from google.cloud.firestore_v1.async_aggregation import AsyncAggregationQuery
 from google.cloud.firestore_v1.base_document import DocumentSnapshot
 
-from typing import AsyncGenerator, List, Optional, Type
-
-# Types needed only for Type Hints
-from google.cloud.firestore_v1.transaction import Transaction
+from typing import AsyncGenerator, List, Optional, Type, TYPE_CHECKING
 
-from google.cloud.firestore_v1.async_aggregation import AsyncAggregationQuery
+if TYPE_CHECKING:  # pragma: NO COVER
+    # Types needed only for Type Hints
+    from google.cloud.firestore_v1.transaction import Transaction
+    from google.cloud.firestore_v1.field_path import FieldPath
 
 
 class AsyncQuery(BaseQuery):
@@ -222,8 +223,8 @@ def count(
         """Adds a count over the nested query.
 
         Args:
-            alias
-                (Optional[str]): The alias for the count
+            alias(Optional[str]): Optional name of the field to store the result of the aggregation into.
+                If not provided, Firestore will pick a default name following the format field_<incremental_id++>.
 
         Returns:
             :class:`~google.cloud.firestore_v1.async_aggregation.AsyncAggregationQuery`:
@@ -231,6 +232,38 @@ def count(
         """
         return AsyncAggregationQuery(self).count(alias=alias)
 
+    def sum(
+        self, field_ref: str | FieldPath, alias: str | None = None
+    ) -> Type["firestore_v1.async_aggregation.AsyncAggregationQuery"]:
+        """Adds a sum over the nested query.
+
+        Args:
+            field_ref(Union[str, google.cloud.firestore_v1.field_path.FieldPath]): The field to aggregate across.
+            alias(Optional[str]): Optional name of the field to store the result of the aggregation into.
+                If not provided, Firestore will pick a default name following the format field_<incremental_id++>.
+
+        Returns:
+            :class:`~google.cloud.firestore_v1.async_aggregation.AsyncAggregationQuery`:
+            An instance of an AsyncAggregationQuery object
+        """
+        return AsyncAggregationQuery(self).sum(field_ref, alias=alias)
+
+    def avg(
+        self, field_ref: str | FieldPath, alias: str | None = None
+    ) -> Type["firestore_v1.async_aggregation.AsyncAggregationQuery"]:
+        """Adds an avg over the nested query.
+
+        Args:
+            field_ref(Union[str, google.cloud.firestore_v1.field_path.FieldPath]): The field to aggregate across.
+            alias(Optional[str]): Optional name of the field to store the result of the aggregation into.
+                If not provided, Firestore will pick a default name following the format field_<incremental_id++>.
+
+        Returns:
+            :class:`~google.cloud.firestore_v1.async_aggregation.AsyncAggregationQuery`:
+            An instance of an AsyncAggregationQuery object
+        """
+        return AsyncAggregationQuery(self).avg(field_ref, alias=alias)
+
     async def stream(
         self,
         transaction=None,
@@ -292,9 +325,9 @@ async def stream(
             yield snapshot
 
     @staticmethod
-    def _get_collection_reference_class() -> Type[
-        "firestore_v1.async_collection.AsyncCollectionReference"
-    ]:
+    def _get_collection_reference_class() -> (
+        Type["firestore_v1.async_collection.AsyncCollectionReference"]
+    ):
         from google.cloud.firestore_v1.async_collection import AsyncCollectionReference
 
         return AsyncCollectionReference
diff --git a/google/cloud/firestore_v1/async_transaction.py b/google/cloud/firestore_v1/async_transaction.py
index f4ecf32d34..b504bebadc 100644
--- a/google/cloud/firestore_v1/async_transaction.py
+++ b/google/cloud/firestore_v1/async_transaction.py
@@ -110,6 +110,7 @@ async def _rollback(self) -> None:
 
         Raises:
             ValueError: If no transaction is in progress.
+            google.api_core.exceptions.GoogleAPICallError: If the rollback fails.
         """
         if not self.in_progress:
             raise ValueError(_CANT_ROLLBACK)
@@ -124,6 +125,7 @@ async def _rollback(self) -> None:
                 metadata=self._client._rpc_metadata,
             )
         finally:
+            # clean up, even if rollback fails
             self._clean_up()
 
     async def _commit(self) -> list:
@@ -223,10 +225,6 @@ async def _pre_commit(
     ) -> Coroutine:
         """Begin transaction and call the wrapped coroutine.
 
-        If the coroutine raises an exception, the transaction will be rolled
-        back. If not, the transaction will be "ready" for ``Commit`` (i.e.
-        it will have staged writes).
-
         Args:
             transaction
                 (:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`):
@@ -250,41 +248,7 @@ async def _pre_commit(
         self.current_id = transaction._id
         if self.retry_id is None:
             self.retry_id = self.current_id
-        try:
-            return await self.to_wrap(transaction, *args, **kwargs)
-        except:  # noqa
-            # NOTE: If ``rollback`` fails this will lose the information
-            # from the original failure.
-            await transaction._rollback()
-            raise
-
-    async def _maybe_commit(self, transaction: AsyncTransaction) -> bool:
-        """Try to commit the transaction.
-
-        If the transaction is read-write and the ``Commit`` fails with the
-        ``ABORTED`` status code, it will be retried.  Any other failure will
-        not be caught.
-
-        Args:
-            transaction
-                (:class:`~google.cloud.firestore_v1.transaction.Transaction`):
-                The transaction to be ``Commit``-ed.
-
-        Returns:
-            bool: Indicating if the commit succeeded.
-        """
-        try:
-            await transaction._commit()
-            return True
-        except exceptions.GoogleAPICallError as exc:
-            if transaction._read_only:
-                raise
-
-            if isinstance(exc, exceptions.Aborted):
-                # If a read-write transaction returns ABORTED, retry.
-                return False
-            else:
-                raise
+        return await self.to_wrap(transaction, *args, **kwargs)
 
     async def __call__(self, transaction, *args, **kwargs):
         """Execute the wrapped callable within a transaction.
@@ -306,22 +270,35 @@ async def __call__(self, transaction, *args, **kwargs):
             ``max_attempts``.
         """
         self._reset()
+        retryable_exceptions = (
+            (exceptions.Aborted) if not transaction._read_only else ()
+        )
+        last_exc = None
 
-        for attempt in range(transaction._max_attempts):
-            result = await self._pre_commit(transaction, *args, **kwargs)
-            succeeded = await self._maybe_commit(transaction)
-            if succeeded:
-                return result
-
-            # Subsequent requests will use the failed transaction ID as part of
-            # the ``BeginTransactionRequest`` when restarting this transaction
-            # (via ``options.retry_transaction``).  This preserves the "spot in
-            # line" of the transaction, so exponential backoff is not required
-            # in this case.
-
-        await transaction._rollback()
-        msg = _EXCEED_ATTEMPTS_TEMPLATE.format(transaction._max_attempts)
-        raise ValueError(msg)
+        try:
+            for attempt in range(transaction._max_attempts):
+                result = await self._pre_commit(transaction, *args, **kwargs)
+                try:
+                    await transaction._commit()
+                    return result
+                except retryable_exceptions as exc:
+                    last_exc = exc
+                # Retry attempts that result in retryable exceptions
+                # Subsequent requests will use the failed transaction ID as part of
+                # the ``BeginTransactionRequest`` when restarting this transaction
+                # (via ``options.retry_transaction``).  This preserves the "spot in
+                # line" of the transaction, so exponential backoff is not required
+                # in this case.
+
+            # retries exhausted
+            # wrap the last exception in a ValueError before raising
+            msg = _EXCEED_ATTEMPTS_TEMPLATE.format(transaction._max_attempts)
+            raise ValueError(msg) from last_exc
+
+        except BaseException:
+            # rollback the transaction on any error
+            # errors raised during _rollback will be chained to the original error through __context__
+            await transaction._rollback()
+            raise
 
 
 def async_transactional(
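Usage note (not part of the diff): with this rework (#767), rollback is no longer issued from `_pre_commit`; instead one `try`/`except BaseException` around the attempt loop rolls back on any failure, `Aborted` commits are retried for read-write transactions, and exhausting `max_attempts` raises a `ValueError` chained from the last `Aborted`. A hedged sketch of what that contract looks like to callers — the account documents and "balance" field are invented:

```python
# Hedged sketch of the new retry/rollback contract for async transactions.
# "accounts/alice", "accounts/bob" and the "balance" field are invented.
import asyncio

from google.cloud import firestore


@firestore.async_transactional
async def transfer(transaction, source_ref, dest_ref, amount):
    # Reads must happen before writes inside a transaction.
    source = await source_ref.get(transaction=transaction)
    transaction.update(source_ref, {"balance": source.get("balance") - amount})
    transaction.update(dest_ref, {"balance": firestore.Increment(amount)})


async def main():
    client = firestore.AsyncClient()
    source = client.document("accounts/alice")
    dest = client.document("accounts/bob")
    try:
        await transfer(client.transaction(max_attempts=5), source, dest, 100)
    except ValueError as exc:
        # Raised once every attempt ended in ABORTED; the last Aborted error
        # is chained via ``raise ... from last_exc`` (exc.__cause__).
        print("transaction gave up:", exc.__cause__)


asyncio.run(main())
```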
diff --git a/google/cloud/firestore_v1/base_aggregation.py b/google/cloud/firestore_v1/base_aggregation.py
index b7a6605b87..d6097c136b 100644
--- a/google/cloud/firestore_v1/base_aggregation.py
+++ b/google/cloud/firestore_v1/base_aggregation.py
@@ -33,8 +33,8 @@
 
 from google.api_core import retry as retries
 
+from google.cloud.firestore_v1.field_path import FieldPath
 from google.cloud.firestore_v1.types import RunAggregationQueryResponse
-
 from google.cloud.firestore_v1.types import StructuredAggregationQuery
 from google.cloud.firestore_v1 import _helpers
 
@@ -60,6 +60,9 @@ def __repr__(self):
 
 
 class BaseAggregation(ABC):
+    def __init__(self, alias: str | None = None):
+        self.alias = alias
+
     @abc.abstractmethod
     def _to_protobuf(self):
         """Convert this instance to the protobuf representation"""
@@ -67,7 +70,7 @@ def _to_protobuf(self):
 
 class CountAggregation(BaseAggregation):
     def __init__(self, alias: str | None = None):
-        self.alias = alias
+        super(CountAggregation, self).__init__(alias=alias)
 
     def _to_protobuf(self):
         """Convert this instance to the protobuf representation"""
@@ -77,13 +80,48 @@ def _to_protobuf(self):
         return aggregation_pb
 
 
+class SumAggregation(BaseAggregation):
+    def __init__(self, field_ref: str | FieldPath, alias: str | None = None):
+        if isinstance(field_ref, FieldPath):
+            # convert field path to string
+            field_ref = field_ref.to_api_repr()
+        self.field_ref = field_ref
+        super(SumAggregation, self).__init__(alias=alias)
+
+    def _to_protobuf(self):
+        """Convert this instance to the protobuf representation"""
+        aggregation_pb = StructuredAggregationQuery.Aggregation()
+        aggregation_pb.alias = self.alias
+        aggregation_pb.sum = StructuredAggregationQuery.Aggregation.Sum()
+        aggregation_pb.sum.field.field_path = self.field_ref
+        return aggregation_pb
+
+
+class AvgAggregation(BaseAggregation):
+    def __init__(self, field_ref: str | FieldPath, alias: str | None = None):
+        if isinstance(field_ref, FieldPath):
+            # convert field path to string
+            field_ref = field_ref.to_api_repr()
+        self.field_ref = field_ref
+        super(AvgAggregation, self).__init__(alias=alias)
+
+    def _to_protobuf(self):
+        """Convert this instance to the protobuf representation"""
+        aggregation_pb = StructuredAggregationQuery.Aggregation()
+        aggregation_pb.alias = self.alias
+        aggregation_pb.avg = StructuredAggregationQuery.Aggregation.Avg()
+        aggregation_pb.avg.field.field_path = self.field_ref
+        return aggregation_pb
+
+
 def _query_response_to_result(
     response_pb: RunAggregationQueryResponse,
 ) -> List[AggregationResult]:
     results = [
         AggregationResult(
             alias=key,
-            value=response_pb.result.aggregate_fields[key].integer_value,
+            value=response_pb.result.aggregate_fields[key].integer_value
+            or response_pb.result.aggregate_fields[key].double_value,
             read_time=response_pb.read_time,
         )
         for key in response_pb.result.aggregate_fields.pb.keys()
@@ -95,11 +133,9 @@ def _query_response_to_result(
 class BaseAggregationQuery(ABC):
     """Represents an aggregation query to the Firestore API."""
 
-    def __init__(
-        self,
-        nested_query,
-    ) -> None:
+    def __init__(self, nested_query, alias: str | None = None) -> None:
         self._nested_query = nested_query
+        self._alias = alias
         self._collection_ref = nested_query._parent
         self._aggregations: List[BaseAggregation] = []
 
@@ -115,6 +151,22 @@ def count(self, alias: str | None = None):
         self._aggregations.append(count_aggregation)
         return self
 
+    def sum(self, field_ref: str | FieldPath, alias: str | None = None):
+        """
+        Adds a sum over the nested query
+        """
+        sum_aggregation = SumAggregation(field_ref, alias=alias)
+        self._aggregations.append(sum_aggregation)
+        return self
+
+    def avg(self, field_ref: str | FieldPath, alias: str | None = None):
+        """
+        Adds an avg over the nested query
+        """
+        avg_aggregation = AvgAggregation(field_ref, alias=alias)
+        self._aggregations.append(avg_aggregation)
+        return self
+
     def add_aggregation(self, aggregation: BaseAggregation) -> None:
         """
         Adds an aggregation operation to the nested query
@@ -196,9 +248,10 @@ def stream(
             retries.Retry, None, gapic_v1.method._MethodDefault
         ] = gapic_v1.method.DEFAULT,
         timeout: float | None = None,
-    ) -> Generator[List[AggregationResult], Any, None] | AsyncGenerator[
-        List[AggregationResult], None
-    ]:
+    ) -> (
+        Generator[List[AggregationResult], Any, None]
+        | AsyncGenerator[List[AggregationResult], None]
+    ):
         """Runs the aggregation query.
 
         This sends a``RunAggregationQuery`` RPC and returns an iterator in the stream
         of ``RunAggregationQueryResponse`` messages.
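Reviewer note (not part of the diff): a local illustration of how the new aggregation classes serialize, using only what the hunk above defines. `_to_protobuf()` is private API, shown purely to make the protobuf mapping concrete; no RPC is made.

```python
# Illustrative only: SumAggregation/AvgAggregation protobuf serialization.
from google.cloud.firestore_v1.base_aggregation import AvgAggregation, SumAggregation
from google.cloud.firestore_v1.field_path import FieldPath

sum_pb = SumAggregation(FieldPath("stats", "points"), alias="total")._to_protobuf()
print(sum_pb.alias)                  # total
print(sum_pb.sum.field.field_path)   # stats.points  (FieldPath -> to_api_repr())

avg_pb = AvgAggregation("points", alias="mean")._to_protobuf()
print(avg_pb.avg.field.field_path)   # points
```

Note also the `integer_value or double_value` fallback in `_query_response_to_result` above: sums over integer fields come back from the server as `integer_value` while averages come back as `double_value`, and the result object now surfaces whichever one is set.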
"""Classes for representing collections for the Google Cloud Firestore API.""" +from __future__ import annotations import random from google.api_core import retry as retries @@ -20,6 +21,7 @@ from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.document import DocumentReference from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery +from google.cloud.firestore_v1.base_query import QueryType from typing import ( @@ -28,23 +30,27 @@ AsyncGenerator, Coroutine, Generator, + Generic, AsyncIterator, Iterator, Iterable, NoReturn, Tuple, Union, + TYPE_CHECKING, ) -# Types needed only for Type Hints -from google.cloud.firestore_v1.base_document import DocumentSnapshot -from google.cloud.firestore_v1.base_query import BaseQuery -from google.cloud.firestore_v1.transaction import Transaction + +if TYPE_CHECKING: # pragma: NO COVER + # Types needed only for Type Hints + from google.cloud.firestore_v1.base_document import DocumentSnapshot + from google.cloud.firestore_v1.transaction import Transaction + from google.cloud.firestore_v1.field_path import FieldPath _AUTO_ID_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" -class BaseCollectionReference(object): +class BaseCollectionReference(Generic[QueryType]): """A reference to a collection in a Firestore database. The collection may already exist or this class can facilitate creation @@ -108,7 +114,7 @@ def parent(self): parent_path = self._path[:-1] return self._client.document(*parent_path) - def _query(self) -> BaseQuery: + def _query(self) -> QueryType: raise NotImplementedError def _aggregation_query(self) -> BaseAggregationQuery: @@ -215,10 +221,10 @@ def list_documents( ]: raise NotImplementedError - def recursive(self) -> "BaseQuery": + def recursive(self) -> QueryType: return self._query().recursive() - def select(self, field_paths: Iterable[str]) -> BaseQuery: + def select(self, field_paths: Iterable[str]) -> QueryType: """Create a "select" query with this collection as parent. See @@ -243,8 +249,8 @@ def where( op_string: Optional[str] = None, value=None, *, - filter=None - ) -> BaseQuery: + filter=None, + ) -> QueryType: """Create a "where" query with this collection as parent. See @@ -280,7 +286,6 @@ def where( wrapped_names = [] for name in value: - if isinstance(name, str): name = self.document(name) @@ -291,7 +296,7 @@ def where( else: return query.where(filter=filter) - def order_by(self, field_path: str, **kwargs) -> BaseQuery: + def order_by(self, field_path: str, **kwargs) -> QueryType: """Create an "order by" query with this collection as parent. See @@ -313,7 +318,7 @@ def order_by(self, field_path: str, **kwargs) -> BaseQuery: query = self._query() return query.order_by(field_path, **kwargs) - def limit(self, count: int) -> BaseQuery: + def limit(self, count: int) -> QueryType: """Create a limited query with this collection as parent. .. note:: @@ -356,7 +361,7 @@ def limit_to_last(self, count: int): query = self._query() return query.limit_to_last(count) - def offset(self, num_to_skip: int) -> BaseQuery: + def offset(self, num_to_skip: int) -> QueryType: """Skip to an offset in a query with this collection as parent. See @@ -376,7 +381,7 @@ def offset(self, num_to_skip: int) -> BaseQuery: def start_at( self, document_fields: Union[DocumentSnapshot, dict, list, tuple] - ) -> BaseQuery: + ) -> QueryType: """Start query at a cursor with this collection as parent. 
See @@ -399,7 +404,7 @@ def start_at( def start_after( self, document_fields: Union[DocumentSnapshot, dict, list, tuple] - ) -> BaseQuery: + ) -> QueryType: """Start query after a cursor with this collection as parent. See @@ -422,7 +427,7 @@ def start_after( def end_before( self, document_fields: Union[DocumentSnapshot, dict, list, tuple] - ) -> BaseQuery: + ) -> QueryType: """End query before a cursor with this collection as parent. See @@ -445,7 +450,7 @@ def end_before( def end_at( self, document_fields: Union[DocumentSnapshot, dict, list, tuple] - ) -> BaseQuery: + ) -> QueryType: """End query at a cursor with this collection as parent. See @@ -507,6 +512,33 @@ def count(self, alias=None): """ return self._aggregation_query().count(alias=alias) + def sum(self, field_ref: str | FieldPath, alias=None): + """ + Adds a sum over the nested query. + + :type field_ref: Union[str, google.cloud.firestore_v1.field_path.FieldPath] + :param field_ref: The field to aggregate across. + + :type alias: Optional[str] + :param alias: Optional name of the field to store the result of the aggregation into. + If not provided, Firestore will pick a default name following the format field_. + + """ + return self._aggregation_query().sum(field_ref, alias=alias) + + def avg(self, field_ref: str | FieldPath, alias=None): + """ + Adds an avg over the nested query. + + :type field_ref: Union[str, google.cloud.firestore_v1.field_path.FieldPath] + :param field_ref: The field to aggregate across. + + :type alias: Optional[str] + :param alias: Optional name of the field to store the result of the aggregation into. + If not provided, Firestore will pick a default name following the format field_. + """ + return self._aggregation_query().avg(field_ref, alias=alias) + def _auto_id() -> str: """Generate a "random" automatically generated ID. diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 9fd2fe1c08..da1e41232e 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -47,12 +47,17 @@ Optional, Tuple, Type, + TypeVar, Union, + TYPE_CHECKING, ) # Types needed only for Type Hints from google.cloud.firestore_v1.base_document import DocumentSnapshot +if TYPE_CHECKING: # pragma: NO COVER + from google.cloud.firestore_v1.field_path import FieldPath + _BAD_DIR_STRING: str _BAD_OP_NAN_NULL: str _BAD_OP_STRING: str @@ -102,6 +107,8 @@ _not_passed = object() +QueryType = TypeVar("QueryType", bound="BaseQuery") + class BaseFilter(abc.ABC): """Base class for Filters""" @@ -319,7 +326,7 @@ def _client(self): """ return self._parent._client - def select(self, field_paths: Iterable[str]) -> "BaseQuery": + def select(self: QueryType, field_paths: Iterable[str]) -> QueryType: """Project documents matching query to a limited set of fields. 
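Reviewer note (not part of the diff): making `BaseCollectionReference` generic over `QueryType` means the concrete subclasses (wired up in `collection.py` and `async_collection.py` later in this diff) advertise their own query class from every builder method. A hedged sketch of what now type-checks without casts — the "users" collection and "age" field are invented:

```python
# Sketch of the typing improvement: builder methods on a concrete collection
# now resolve to its concrete query class under mypy.
from google.cloud import firestore
from google.cloud.firestore_v1.async_query import AsyncQuery
from google.cloud.firestore_v1.base_query import FieldFilter
from google.cloud.firestore_v1.query import Query

sync_client = firestore.Client()
async_client = firestore.AsyncClient()

# Previously these expressions were typed as BaseQuery and needed casts.
q: Query = sync_client.collection("users").where(filter=FieldFilter("age", ">=", 18))
aq: AsyncQuery = async_client.collection("users").order_by("age").limit(10)
```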
diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py
index 9fd2fe1c08..da1e41232e 100644
--- a/google/cloud/firestore_v1/base_query.py
+++ b/google/cloud/firestore_v1/base_query.py
@@ -47,12 +47,17 @@
     Optional,
     Tuple,
     Type,
+    TypeVar,
     Union,
+    TYPE_CHECKING,
 )
 
 # Types needed only for Type Hints
 from google.cloud.firestore_v1.base_document import DocumentSnapshot
 
+if TYPE_CHECKING:  # pragma: NO COVER
+    from google.cloud.firestore_v1.field_path import FieldPath
+
 _BAD_DIR_STRING: str
 _BAD_OP_NAN_NULL: str
 _BAD_OP_STRING: str
@@ -102,6 +107,8 @@
 
 _not_passed = object()
 
+QueryType = TypeVar("QueryType", bound="BaseQuery")
+
 
 class BaseFilter(abc.ABC):
     """Base class for Filters"""
@@ -319,7 +326,7 @@ def _client(self):
         """
         return self._parent._client
 
-    def select(self, field_paths: Iterable[str]) -> "BaseQuery":
+    def select(self: QueryType, field_paths: Iterable[str]) -> QueryType:
         """Project documents matching query to a limited set of fields.
 
         See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for
@@ -354,7 +361,7 @@ def select(self, field_paths: Iterable[str]) -> "BaseQuery":
         return self._copy(projection=new_projection)
 
     def _copy(
-        self,
+        self: QueryType,
         *,
         projection: Optional[query.StructuredQuery.Projection] = _not_passed,
         field_filters: Optional[Tuple[query.StructuredQuery.FieldFilter]] = _not_passed,
@@ -366,7 +373,7 @@ def _copy(
         end_at: Optional[Tuple[dict, bool]] = _not_passed,
         all_descendants: Optional[bool] = _not_passed,
         recursive: Optional[bool] = _not_passed,
-    ) -> "BaseQuery":
+    ) -> QueryType:
         return self.__class__(
             self._parent,
             projection=self._evaluate_param(projection, self._projection),
@@ -389,13 +396,13 @@ def _evaluate_param(self, value, fallback_value):
         return value if value is not _not_passed else fallback_value
 
     def where(
-        self,
+        self: QueryType,
         field_path: Optional[str] = None,
         op_string: Optional[str] = None,
         value=None,
         *,
         filter=None,
-    ) -> "BaseQuery":
+    ) -> QueryType:
         """Filter the query on a field.
 
         See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for
@@ -492,7 +499,9 @@ def _make_order(field_path, direction) -> StructuredQuery.Order:
             direction=_enum_from_direction(direction),
         )
 
-    def order_by(self, field_path: str, direction: str = ASCENDING) -> "BaseQuery":
+    def order_by(
+        self: QueryType, field_path: str, direction: str = ASCENDING
+    ) -> QueryType:
         """Modify the query to add an order clause on a specific field.
 
         See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for
@@ -526,7 +535,7 @@ def order_by(self, field_path: str, direction: str = ASCENDING) -> "BaseQuery":
         new_orders = self._orders + (order_pb,)
         return self._copy(orders=new_orders)
 
-    def limit(self, count: int) -> "BaseQuery":
+    def limit(self: QueryType, count: int) -> QueryType:
         """Limit a query to return at most `count` matching results.
 
         If the current query already has a `limit` set, this will override it.
@@ -545,7 +554,7 @@ def limit(self, count: int) -> "BaseQuery":
         """
         return self._copy(limit=count, limit_to_last=False)
 
-    def limit_to_last(self, count: int) -> "BaseQuery":
+    def limit_to_last(self: QueryType, count: int) -> QueryType:
         """Limit a query to return the last `count` matching results.
 
         If the current query already has a `limit_to_last` set, this will override it.
@@ -570,7 +579,7 @@ def _resolve_chunk_size(self, num_loaded: int, chunk_size: int) -> int:
             return max(self._limit - num_loaded, 0)
         return chunk_size
 
-    def offset(self, num_to_skip: int) -> "BaseQuery":
+    def offset(self: QueryType, num_to_skip: int) -> QueryType:
         """Skip to an offset in a query.
 
         If the current query already has specified an offset, this will
@@ -601,11 +610,11 @@ def _check_snapshot(self, document_snapshot) -> None:
             raise ValueError("Cannot use snapshot from another collection as a cursor.")
 
     def _cursor_helper(
-        self,
+        self: QueryType,
         document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple],
         before: bool,
         start: bool,
-    ) -> "BaseQuery":
+    ) -> QueryType:
         """Set values to be used for a ``start_at`` or ``end_at`` cursor.
 
         The values will later be used in a query protobuf.
@@ -658,8 +667,9 @@ def _cursor_helper(
         return self._copy(**query_kwargs)
 
     def start_at(
-        self, document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple]
-    ) -> "BaseQuery":
+        self: QueryType,
+        document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple],
+    ) -> QueryType:
         """Start query results at a particular document value.
 
         The result set will **include** the document specified by
@@ -690,8 +700,9 @@ def start_at(
         return self._cursor_helper(document_fields_or_snapshot, before=True, start=True)
 
     def start_after(
-        self, document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple]
-    ) -> "BaseQuery":
+        self: QueryType,
+        document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple],
+    ) -> QueryType:
         """Start query results after a particular document value.
 
         The result set will **exclude** the document specified by
@@ -723,8 +734,9 @@ def start_after(
         )
 
     def end_before(
-        self, document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple]
-    ) -> "BaseQuery":
+        self: QueryType,
+        document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple],
+    ) -> QueryType:
         """End query results before a particular document value.
 
         The result set will **exclude** the document specified by
@@ -756,8 +768,9 @@ def end_before(
         )
 
     def end_at(
-        self, document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple]
-    ) -> "BaseQuery":
+        self: QueryType,
+        document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple],
+    ) -> QueryType:
         """End query results at a particular document value.
 
         The result set will **include** the document specified by
@@ -808,7 +821,6 @@ def _filters_pb(self) -> Optional[StructuredQuery.Filter]:
             else:
                 return _filter_pb(filter)
         else:
-
             composite_filter = query.StructuredQuery.CompositeFilter(
                 op=StructuredQuery.CompositeFilter.Operator.AND,
             )
@@ -826,7 +838,6 @@ def _filters_pb(self) -> Optional[StructuredQuery.Filter]:
 def _normalize_projection(projection) -> StructuredQuery.Projection:
     """Helper: convert field paths to message."""
     if projection is not None:
-
         fields = list(projection.fields)
 
         if not fields:
@@ -963,6 +974,16 @@ def count(
     ) -> Type["firestore_v1.base_aggregation.BaseAggregationQuery"]:
         raise NotImplementedError
 
+    def sum(
+        self, field_ref: str | FieldPath, alias: str | None = None
+    ) -> Type["firestore_v1.base_aggregation.BaseAggregationQuery"]:
+        raise NotImplementedError
+
+    def avg(
+        self, field_ref: str | FieldPath, alias: str | None = None
+    ) -> Type["firestore_v1.base_aggregation.BaseAggregationQuery"]:
+        raise NotImplementedError
+
     def get(
         self,
         transaction=None,
@@ -1005,7 +1026,7 @@ def stream(
     def on_snapshot(self, callback) -> NoReturn:
         raise NotImplementedError
 
-    def recursive(self) -> "BaseQuery":
+    def recursive(self: QueryType) -> QueryType:
         """Returns a copy of this query whose iterator will yield all matching
         documents as well as each of their descendent subcollections and documents.
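Reviewer note (not part of the diff): the `QueryType = TypeVar("QueryType", bound="BaseQuery")` idiom above, annotated on `self`, is what lets every `_copy`-based builder method return the subclass it was invoked on. A self-contained sketch of the pattern with toy classes (on Python 3.11+, `typing.Self` achieves the same effect):

```python
# Generic-self idiom used by BaseQuery: annotate `self` with a bound TypeVar
# so chained builders preserve the subclass type. Toy classes, not Firestore.
from __future__ import annotations

from typing import TypeVar

QueryT = TypeVar("QueryT", bound="MiniQuery")


class MiniQuery:
    def __init__(self, limit: int | None = None) -> None:
        self._limit = limit

    def limit(self: QueryT, count: int) -> QueryT:
        # mirrors BaseQuery._copy(): build a new instance of type(self)
        return self.__class__(limit=count)


class MiniAsyncQuery(MiniQuery):
    pass


q = MiniAsyncQuery().limit(5)  # type checkers infer MiniAsyncQuery here
```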
diff --git a/google/cloud/firestore_v1/base_transaction.py b/google/cloud/firestore_v1/base_transaction.py index 1453212459..b4e5dd0382 100644 --- a/google/cloud/firestore_v1/base_transaction.py +++ b/google/cloud/firestore_v1/base_transaction.py @@ -185,8 +185,5 @@ def _reset(self) -> None: def _pre_commit(self, transaction, *args, **kwargs) -> NoReturn: raise NotImplementedError - def _maybe_commit(self, transaction) -> NoReturn: - raise NotImplementedError - def __call__(self, transaction, *args, **kwargs): raise NotImplementedError diff --git a/google/cloud/firestore_v1/bulk_writer.py b/google/cloud/firestore_v1/bulk_writer.py index 9c7c0d5c9e..6d86f46965 100644 --- a/google/cloud/firestore_v1/bulk_writer.py +++ b/google/cloud/firestore_v1/bulk_writer.py @@ -185,7 +185,6 @@ def _retry_operation( self, operation: "BulkWriterOperation", ) -> concurrent.futures.Future: - delay: int = 0 if self._options.retry == BulkRetry.exponential: delay = operation.attempts**2 # pragma: NO COVER @@ -365,7 +364,6 @@ def flush(self): return while True: - # Queue any waiting operations and try our luck again. # This can happen if users add a number of records not divisible by # 20 and then call flush (which should be ~19 out of 20 use cases). @@ -469,7 +467,6 @@ def _send_until_queue_is_empty(self): self._schedule_ready_retries() while self._queued_batches: - # For FIFO order, add to the right of this deque (via `append`) and take # from the left (via `popleft`). operations: List[BulkWriterOperation] = self._queued_batches.popleft() diff --git a/google/cloud/firestore_v1/client.py b/google/cloud/firestore_v1/client.py index 73d1f268bb..05c135479b 100644 --- a/google/cloud/firestore_v1/client.py +++ b/google/cloud/firestore_v1/client.py @@ -69,8 +69,7 @@ class Client(BaseClient): OAuth2 Credentials to use for this client. If not passed, falls back to the default inferred from the environment. database (Optional[str]): The database name that the client targets. - For now, :attr:`DEFAULT_DATABASE` (the default value) is the - only valid database. + If not passed, falls back to :attr:`DEFAULT_DATABASE`. client_info (Optional[google.api_core.gapic_v1.client_info.ClientInfo]): The client info used to send a user-agent string along with API requests. If ``None``, then default info will be used. Generally, diff --git a/google/cloud/firestore_v1/collection.py b/google/cloud/firestore_v1/collection.py index 12e9ec883d..f6ba1833d6 100644 --- a/google/cloud/firestore_v1/collection.py +++ b/google/cloud/firestore_v1/collection.py @@ -31,7 +31,7 @@ from google.cloud.firestore_v1.transaction import Transaction -class CollectionReference(BaseCollectionReference): +class CollectionReference(BaseCollectionReference[query_mod.Query]): """A reference to a collection in a Firestore database. The collection may already exist or this class can facilitate creation diff --git a/google/cloud/firestore_v1/gapic_version.py b/google/cloud/firestore_v1/gapic_version.py index 16ae0e953c..a3c9255942 100644 --- a/google/cloud/firestore_v1/gapic_version.py +++ b/google/cloud/firestore_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "2.12.0" # {x-release-please-version} +__version__ = "2.13.0" # {x-release-please-version} diff --git a/google/cloud/firestore_v1/query.py b/google/cloud/firestore_v1/query.py index 7cabfcc5f9..d37964dce0 100644 --- a/google/cloud/firestore_v1/query.py +++ b/google/cloud/firestore_v1/query.py @@ -38,7 +38,10 @@ from google.cloud.firestore_v1 import document from google.cloud.firestore_v1.watch import Watch -from typing import Any, Callable, Generator, List, Optional, Type +from typing import Any, Callable, Generator, List, Optional, Type, TYPE_CHECKING + +if TYPE_CHECKING: # pragma: NO COVER + from google.cloud.firestore_v1.field_path import FieldPath class Query(BaseQuery): @@ -174,7 +177,6 @@ def get( def _chunkify( self, chunk_size: int ) -> Generator[List[DocumentSnapshot], None, None]: - max_to_return: Optional[int] = self._limit num_returned: int = 0 original: Query = self._copy() @@ -243,11 +245,42 @@ def count( """ Adds a count over the query. - :type alias: str - :param alias: (Optional) The alias for the count + :type alias: Optional[str] + :param alias: Optional name of the field to store the result of the aggregation into. + If not provided, Firestore will pick a default name following the format field_. """ return aggregation.AggregationQuery(self).count(alias=alias) + def sum( + self, field_ref: str | FieldPath, alias: str | None = None + ) -> Type["firestore_v1.aggregation.AggregationQuery"]: + """ + Adds a sum over the query. + + :type field_ref: Union[str, google.cloud.firestore_v1.field_path.FieldPath] + :param field_ref: The field to aggregate across. + + :type alias: Optional[str] + :param alias: Optional name of the field to store the result of the aggregation into. + If not provided, Firestore will pick a default name following the format field_. + """ + return aggregation.AggregationQuery(self).sum(field_ref, alias=alias) + + def avg( + self, field_ref: str | FieldPath, alias: str | None = None + ) -> Type["firestore_v1.aggregation.AggregationQuery"]: + """ + Adds an avg over the query. + + :type field_ref: [Union[str, google.cloud.firestore_v1.field_path.FieldPath] + :param field_ref: The field to aggregate across. + + :type alias: Optional[str] + :param alias: Optional name of the field to store the result of the aggregation into. + If not provided, Firestore will pick a default name following the format field_. 
+ """ + return aggregation.AggregationQuery(self).avg(field_ref, alias=alias) + def stream( self, transaction=None, @@ -354,9 +387,9 @@ def on_snapshot(docs, changes, read_time): return Watch.for_query(self, callback, document.DocumentSnapshot) @staticmethod - def _get_collection_reference_class() -> Type[ - "firestore_v1.collection.CollectionReference" - ]: + def _get_collection_reference_class() -> ( + Type["firestore_v1.collection.CollectionReference"] + ): from google.cloud.firestore_v1.collection import CollectionReference return CollectionReference diff --git a/google/cloud/firestore_v1/services/firestore/async_client.py b/google/cloud/firestore_v1/services/firestore/async_client.py index b992d2afa9..a134b47f80 100644 --- a/google/cloud/firestore_v1/services/firestore/async_client.py +++ b/google/cloud/firestore_v1/services/firestore/async_client.py @@ -54,7 +54,7 @@ from google.cloud.firestore_v1.types import query from google.cloud.firestore_v1.types import write as gf_write from google.cloud.location import locations_pb2 # type: ignore -from google.longrunning import operations_pb2 +from google.longrunning import operations_pb2 # type: ignore from google.protobuf import timestamp_pb2 # type: ignore from google.rpc import status_pb2 # type: ignore from .transports.base import FirestoreTransport, DEFAULT_CLIENT_INFO @@ -64,6 +64,7 @@ class FirestoreAsyncClient: """The Cloud Firestore service. + Cloud Firestore is a fast, fully managed, serverless, cloud-native NoSQL document database that simplifies storing, syncing, and querying data for your mobile, web, and IoT apps at @@ -276,6 +277,7 @@ async def sample_get_document(): Returns: google.cloud.firestore_v1.types.Document: A Firestore document. + Must not exceed 1 MiB - 4 bytes. """ @@ -483,6 +485,7 @@ async def sample_update_document(): The fields to update. None of the field paths in the mask may contain a reserved name. + If the document exists on the server and has fields not referenced in the mask, they are left unchanged. @@ -502,6 +505,7 @@ async def sample_update_document(): Returns: google.cloud.firestore_v1.types.Document: A Firestore document. + Must not exceed 1 MiB - 4 bytes. """ @@ -673,6 +677,7 @@ def batch_get_documents( metadata: Sequence[Tuple[str, str]] = (), ) -> Awaitable[AsyncIterable[firestore.BatchGetDocumentsResponse]]: r"""Gets multiple documents. + Documents returned by this method are not guaranteed to be returned in the same order that they were requested. @@ -926,6 +931,7 @@ async def sample_commit(): should not be set. writes (:class:`MutableSequence[google.cloud.firestore_v1.types.Write]`): The writes to apply. + Always executed atomically and in order. This corresponds to the ``writes`` field @@ -1906,6 +1912,7 @@ async def sample_create_document(): Returns: google.cloud.firestore_v1.types.Document: A Firestore document. + Must not exceed 1 MiB - 4 bytes. 
""" diff --git a/google/cloud/firestore_v1/services/firestore/client.py b/google/cloud/firestore_v1/services/firestore/client.py index 23b50bd72e..bf1b75dddf 100644 --- a/google/cloud/firestore_v1/services/firestore/client.py +++ b/google/cloud/firestore_v1/services/firestore/client.py @@ -57,7 +57,7 @@ from google.cloud.firestore_v1.types import query from google.cloud.firestore_v1.types import write as gf_write from google.cloud.location import locations_pb2 # type: ignore -from google.longrunning import operations_pb2 +from google.longrunning import operations_pb2 # type: ignore from google.protobuf import timestamp_pb2 # type: ignore from google.rpc import status_pb2 # type: ignore from .transports.base import FirestoreTransport, DEFAULT_CLIENT_INFO @@ -103,6 +103,7 @@ def get_transport_class( class FirestoreClient(metaclass=FirestoreClientMeta): """The Cloud Firestore service. + Cloud Firestore is a fast, fully managed, serverless, cloud-native NoSQL document database that simplifies storing, syncing, and querying data for your mobile, web, and IoT apps at @@ -488,6 +489,7 @@ def sample_get_document(): Returns: google.cloud.firestore_v1.types.Document: A Firestore document. + Must not exceed 1 MiB - 4 bytes. """ @@ -673,6 +675,7 @@ def sample_update_document(): The fields to update. None of the field paths in the mask may contain a reserved name. + If the document exists on the server and has fields not referenced in the mask, they are left unchanged. @@ -692,6 +695,7 @@ def sample_update_document(): Returns: google.cloud.firestore_v1.types.Document: A Firestore document. + Must not exceed 1 MiB - 4 bytes. """ @@ -841,6 +845,7 @@ def batch_get_documents( metadata: Sequence[Tuple[str, str]] = (), ) -> Iterable[firestore.BatchGetDocumentsResponse]: r"""Gets multiple documents. + Documents returned by this method are not guaranteed to be returned in the same order that they were requested. @@ -1071,6 +1076,7 @@ def sample_commit(): should not be set. writes (MutableSequence[google.cloud.firestore_v1.types.Write]): The writes to apply. + Always executed atomically and in order. This corresponds to the ``writes`` field @@ -1954,6 +1960,7 @@ def sample_create_document(): Returns: google.cloud.firestore_v1.types.Document: A Firestore document. + Must not exceed 1 MiB - 4 bytes. 
""" diff --git a/google/cloud/firestore_v1/services/firestore/transports/base.py b/google/cloud/firestore_v1/services/firestore/transports/base.py index 0637e608f4..2230fdc1d2 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/base.py +++ b/google/cloud/firestore_v1/services/firestore/transports/base.py @@ -30,7 +30,7 @@ from google.cloud.firestore_v1.types import document as gf_document from google.cloud.firestore_v1.types import firestore from google.cloud.location import locations_pb2 # type: ignore -from google.longrunning import operations_pb2 +from google.longrunning import operations_pb2 # type: ignore from google.protobuf import empty_pb2 # type: ignore DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( diff --git a/google/cloud/firestore_v1/services/firestore/transports/grpc.py b/google/cloud/firestore_v1/services/firestore/transports/grpc.py index d6d34cd3d3..01c0227483 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/grpc.py +++ b/google/cloud/firestore_v1/services/firestore/transports/grpc.py @@ -28,7 +28,7 @@ from google.cloud.firestore_v1.types import document as gf_document from google.cloud.firestore_v1.types import firestore from google.cloud.location import locations_pb2 # type: ignore -from google.longrunning import operations_pb2 +from google.longrunning import operations_pb2 # type: ignore from google.protobuf import empty_pb2 # type: ignore from .base import FirestoreTransport, DEFAULT_CLIENT_INFO @@ -37,6 +37,7 @@ class FirestoreGrpcTransport(FirestoreTransport): """gRPC backend transport for Firestore. The Cloud Firestore service. + Cloud Firestore is a fast, fully managed, serverless, cloud-native NoSQL document database that simplifies storing, syncing, and querying data for your mobile, web, and IoT apps at @@ -354,6 +355,7 @@ def batch_get_documents( r"""Return a callable for the batch get documents method over gRPC. Gets multiple documents. + Documents returned by this method are not guaranteed to be returned in the same order that they were requested. diff --git a/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py b/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py index 79d8c0789b..d0366356de 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py +++ b/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py @@ -28,7 +28,7 @@ from google.cloud.firestore_v1.types import document as gf_document from google.cloud.firestore_v1.types import firestore from google.cloud.location import locations_pb2 # type: ignore -from google.longrunning import operations_pb2 +from google.longrunning import operations_pb2 # type: ignore from google.protobuf import empty_pb2 # type: ignore from .base import FirestoreTransport, DEFAULT_CLIENT_INFO from .grpc import FirestoreGrpcTransport @@ -38,6 +38,7 @@ class FirestoreGrpcAsyncIOTransport(FirestoreTransport): """gRPC AsyncIO backend transport for Firestore. The Cloud Firestore service. + Cloud Firestore is a fast, fully managed, serverless, cloud-native NoSQL document database that simplifies storing, syncing, and querying data for your mobile, web, and IoT apps at @@ -360,6 +361,7 @@ def batch_get_documents( r"""Return a callable for the batch get documents method over gRPC. Gets multiple documents. + Documents returned by this method are not guaranteed to be returned in the same order that they were requested. 
diff --git a/google/cloud/firestore_v1/services/firestore/transports/rest.py b/google/cloud/firestore_v1/services/firestore/transports/rest.py index 47b84e5581..bfa7dc45d1 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/rest.py +++ b/google/cloud/firestore_v1/services/firestore/transports/rest.py @@ -28,7 +28,6 @@ from google.protobuf import json_format from google.cloud.location import locations_pb2 # type: ignore -from google.longrunning import operations_pb2 from requests import __version__ as requests_version import dataclasses import re @@ -45,6 +44,7 @@ from google.cloud.firestore_v1.types import document as gf_document from google.cloud.firestore_v1.types import firestore from google.protobuf import empty_pb2 # type: ignore +from google.longrunning import operations_pb2 # type: ignore from .base import FirestoreTransport, DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO @@ -567,6 +567,7 @@ class FirestoreRestTransport(FirestoreTransport): """REST backend transport for Firestore. The Cloud Firestore service. + Cloud Firestore is a fast, fully managed, serverless, cloud-native NoSQL document database that simplifies storing, syncing, and querying data for your mobile, web, and IoT apps at @@ -1091,6 +1092,7 @@ def __call__( Returns: ~.document.Document: A Firestore document. + Must not exceed 1 MiB - 4 bytes. """ @@ -1264,6 +1266,7 @@ def __call__( Returns: ~.document.Document: A Firestore document. + Must not exceed 1 MiB - 4 bytes. """ @@ -1958,6 +1961,7 @@ def __call__( Returns: ~.gf_document.Document: A Firestore document. + Must not exceed 1 MiB - 4 bytes. """ @@ -2176,7 +2180,6 @@ def __call__( timeout: Optional[float] = None, metadata: Sequence[Tuple[str, str]] = (), ) -> None: - r"""Call the cancel operation method over HTTP. Args: @@ -2242,7 +2245,6 @@ def __call__( timeout: Optional[float] = None, metadata: Sequence[Tuple[str, str]] = (), ) -> None: - r"""Call the delete operation method over HTTP. Args: @@ -2305,7 +2307,6 @@ def __call__( timeout: Optional[float] = None, metadata: Sequence[Tuple[str, str]] = (), ) -> operations_pb2.Operation: - r"""Call the get operation method over HTTP. Args: @@ -2372,7 +2373,6 @@ def __call__( timeout: Optional[float] = None, metadata: Sequence[Tuple[str, str]] = (), ) -> operations_pb2.ListOperationsResponse: - r"""Call the list operations method over HTTP. Args: diff --git a/google/cloud/firestore_v1/transaction.py b/google/cloud/firestore_v1/transaction.py index cfcb968c8f..3c175a4ced 100644 --- a/google/cloud/firestore_v1/transaction.py +++ b/google/cloud/firestore_v1/transaction.py @@ -44,7 +44,7 @@ # Types needed only for Type Hints from google.cloud.firestore_v1.base_document import DocumentSnapshot from google.cloud.firestore_v1.types import CommitResponse -from typing import Any, Callable, Generator, Optional +from typing import Any, Callable, Generator class Transaction(batch.WriteBatch, BaseTransaction): @@ -108,6 +108,7 @@ def _rollback(self) -> None: Raises: ValueError: If no transaction is in progress. + google.api_core.exceptions.GoogleAPICallError: If the rollback fails. """ if not self.in_progress: raise ValueError(_CANT_ROLLBACK) @@ -122,6 +123,7 @@ def _rollback(self) -> None: metadata=self._client._rpc_metadata, ) finally: + # clean up, even if rollback fails self._clean_up() def _commit(self) -> list: @@ -214,10 +216,6 @@ def __init__(self, to_wrap) -> None: def _pre_commit(self, transaction: Transaction, *args, **kwargs) -> Any: """Begin transaction and call the wrapped callable. 
- If the callable raises an exception, the transaction will be rolled - back. If not, the transaction will be "ready" for ``Commit`` (i.e. - it will have staged writes). - Args: transaction (:class:`~google.cloud.firestore_v1.transaction.Transaction`): @@ -241,41 +239,7 @@ def _pre_commit(self, transaction: Transaction, *args, **kwargs) -> Any: self.current_id = transaction._id if self.retry_id is None: self.retry_id = self.current_id - try: - return self.to_wrap(transaction, *args, **kwargs) - except: # noqa - # NOTE: If ``rollback`` fails this will lose the information - # from the original failure. - transaction._rollback() - raise - - def _maybe_commit(self, transaction: Transaction) -> Optional[bool]: - """Try to commit the transaction. - - If the transaction is read-write and the ``Commit`` fails with the - ``ABORTED`` status code, it will be retried. Any other failure will - not be caught. - - Args: - transaction - (:class:`~google.cloud.firestore_v1.transaction.Transaction`): - The transaction to be ``Commit``-ed. - - Returns: - bool: Indicating if the commit succeeded. - """ - try: - transaction._commit() - return True - except exceptions.GoogleAPICallError as exc: - if transaction._read_only: - raise - - if isinstance(exc, exceptions.Aborted): - # If a read-write transaction returns ABORTED, retry. - return False - else: - raise + return self.to_wrap(transaction, *args, **kwargs) def __call__(self, transaction: Transaction, *args, **kwargs): """Execute the wrapped callable within a transaction. @@ -297,22 +261,34 @@ def __call__(self, transaction: Transaction, *args, **kwargs): ``max_attempts``. """ self._reset() + retryable_exceptions = ( + (exceptions.Aborted) if not transaction._read_only else () + ) + last_exc = None - for attempt in range(transaction._max_attempts): - result = self._pre_commit(transaction, *args, **kwargs) - succeeded = self._maybe_commit(transaction) - if succeeded: - return result - - # Subsequent requests will use the failed transaction ID as part of - # the ``BeginTransactionRequest`` when restarting this transaction - # (via ``options.retry_transaction``). This preserves the "spot in - # line" of the transaction, so exponential backoff is not required - # in this case. - - transaction._rollback() - msg = _EXCEED_ATTEMPTS_TEMPLATE.format(transaction._max_attempts) - raise ValueError(msg) + try: + for attempt in range(transaction._max_attempts): + result = self._pre_commit(transaction, *args, **kwargs) + try: + transaction._commit() + return result + except retryable_exceptions as exc: + last_exc = exc + # Retry attempts that result in retryable exceptions + # Subsequent requests will use the failed transaction ID as part of + # the ``BeginTransactionRequest`` when restarting this transaction + # (via ``options.retry_transaction``). This preserves the "spot in + # line" of the transaction, so exponential backoff is not required + # in this case. 
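+                    # Example with max_attempts=2: attempt 0 runs the callable
+                    # and its commit raises Aborted, so the loop re-runs the
+                    # callable against the retried transaction; if attempt 1's
+                    # commit succeeds its result is returned, otherwise the
+                    # loop ends and the ValueError below is raised.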
+            # retries exhausted
+            # wrap the last exception in a ValueError before raising
+            msg = _EXCEED_ATTEMPTS_TEMPLATE.format(transaction._max_attempts)
+            raise ValueError(msg) from last_exc
+        except BaseException:  # noqa: B901
+            # rollback the transaction on any error
+            # errors raised during _rollback will be chained to the original error through __context__
+            transaction._rollback()
+            raise


 def transactional(to_wrap: Callable) -> _Transactional:
diff --git a/google/cloud/firestore_v1/types/common.py b/google/cloud/firestore_v1/types/common.py
index 84c5541b38..da9a02befb 100644
--- a/google/cloud/firestore_v1/types/common.py
+++ b/google/cloud/firestore_v1/types/common.py
@@ -139,6 +139,7 @@ class ReadOnly(proto.Message):
     Attributes:
         read_time (google.protobuf.timestamp_pb2.Timestamp):
             Reads documents at the given time.
+
             This must be a microsecond precision timestamp within the
             past one hour, or if Point-in-Time Recovery is enabled, can
             additionally be a whole
diff --git a/google/cloud/firestore_v1/types/document.py b/google/cloud/firestore_v1/types/document.py
index 8c0477239f..2476d2d131 100644
--- a/google/cloud/firestore_v1/types/document.py
+++ b/google/cloud/firestore_v1/types/document.py
@@ -37,6 +37,7 @@
 class Document(proto.Message):
     r"""A Firestore document.
+
     Must not exceed 1 MiB - 4 bytes.

     Attributes:
@@ -137,12 +138,14 @@ class Value(proto.Message):
             This field is a member of `oneof`_ ``value_type``.
         timestamp_value (google.protobuf.timestamp_pb2.Timestamp):
             A timestamp value.
+
             Precise only to microseconds. When stored, any
             additional precision is rounded down.

             This field is a member of `oneof`_ ``value_type``.
         string_value (str):
             A string value.
+
             The string, represented as UTF-8, must not
             exceed 1 MiB - 89 bytes. Only the first 1,500
             bytes of the UTF-8 representation are considered
@@ -151,6 +154,7 @@
             This field is a member of `oneof`_ ``value_type``.
         bytes_value (bytes):
             A bytes value.
+
             Must not exceed 1 MiB - 89 bytes.
             Only the first 1,500 bytes are considered by
             queries.
@@ -168,6 +172,7 @@
             This field is a member of `oneof`_ ``value_type``.
         array_value (google.cloud.firestore_v1.types.ArrayValue):
             An array value.
+
             Cannot directly contain another array value,
             though it can contain a map which contains
             another array.
diff --git a/google/cloud/firestore_v1/types/firestore.py b/google/cloud/firestore_v1/types/firestore.py
index 29ee19b6aa..bde5556afc 100644
--- a/google/cloud/firestore_v1/types/firestore.py
+++ b/google/cloud/firestore_v1/types/firestore.py
@@ -189,6 +189,7 @@ class ListDocumentsRequest(proto.Message):
             This field is a member of `oneof`_ ``consistency_selector``.
         read_time (google.protobuf.timestamp_pb2.Timestamp):
             Perform the read at the provided time.
+
             This must be a microsecond precision timestamp within the
             past one hour, or if Point-in-Time Recovery is enabled, can
             additionally be a whole
@@ -347,6 +348,7 @@ class UpdateDocumentRequest(proto.Message):
             The fields to update.
             None of the field paths in the mask may contain
             a reserved name.
+
             If the document exists on the server and has
             fields not referenced in the mask, they are left
             unchanged.
@@ -599,6 +601,7 @@ class CommitRequest(proto.Message):
             ``projects/{project_id}/databases/{database_id}``.
         writes (MutableSequence[google.cloud.firestore_v1.types.Write]):
             The writes to apply.
+
             Always executed atomically and in order.
         transaction (bytes):
             If set, applies all writes in this
@@ -627,6 +630,7 @@ class CommitResponse(proto.Message):
     Attributes:
         write_results (MutableSequence[google.cloud.firestore_v1.types.WriteResult]):
             The result of applying the writes.
+
             The i-th write result corresponds to the i-th
             write in the request.
         commit_time (google.protobuf.timestamp_pb2.Timestamp):
@@ -850,6 +854,7 @@ class RunAggregationQueryRequest(proto.Message):
             This field is a member of `oneof`_ ``consistency_selector``.
         read_time (google.protobuf.timestamp_pb2.Timestamp):
             Executes the query at the given timestamp.
+
             This must be a microsecond precision timestamp within the
             past one hour, or if Point-in-Time Recovery is enabled, can
             additionally be a whole
@@ -894,6 +899,7 @@ class RunAggregationQueryResponse(proto.Message):
     Attributes:
         result (google.cloud.firestore_v1.types.AggregationResult):
             A single aggregation result.
+
             Not present when reporting partial progress.
         transaction (bytes):
             The transaction that was started as part of
@@ -1050,7 +1056,8 @@ class PartitionQueryResponse(proto.Message):
             - query, start_at B

             An empty result may indicate that the query has too few
-            results to be partitioned.
+            results to be partitioned, or that the query is not yet
+            supported for partitioning.
         next_page_token (str):
             A page token that may be used to request an additional set
             of results, up to the number specified by
@@ -1098,6 +1105,7 @@ class WriteRequest(proto.Message):
             left empty, a new write stream will be created.
         writes (MutableSequence[google.cloud.firestore_v1.types.Write]):
             The writes to apply.
+
             Always executed atomically and in order.
             This must be empty on the first request.
             This may be empty on the last request.
@@ -1160,9 +1168,11 @@ class WriteResponse(proto.Message):
             A token that represents the position of this
             response in the stream. This can be used by a
             client to resume the stream at this point.
+
             This field is always set.
         write_results (MutableSequence[google.cloud.firestore_v1.types.WriteResult]):
             The result of applying the writes.
+
             The i-th write result corresponds to the i-th
             write in the request.
         commit_time (google.protobuf.timestamp_pb2.Timestamp):
@@ -1351,9 +1361,26 @@ class Target(proto.Message):
             This field is a member of `oneof`_ ``resume_type``.
         target_id (int):
-            The target ID that identifies the target on
-            the stream. Must be a positive number and
-            non-zero.
+            The target ID that identifies the target on the stream. Must
+            be a positive number and non-zero.
+
+            If ``target_id`` is 0 (or unspecified), the server will
+            assign an ID for this target and return that in a
+            ``TargetChange::ADD`` event. Once a target with
+            ``target_id=0`` is added, all subsequent targets must also
+            have ``target_id=0``. If an ``AddTarget`` request with
+            ``target_id != 0`` is sent to the server after a target with
+            ``target_id=0`` is added, the server will immediately send a
+            response with a ``TargetChange::Remove`` event.
+
+            Note that if the client sends multiple ``AddTarget``
+            requests without an ID, the order of IDs returned in
+            ``TargetChange.target_ids`` is undefined. Therefore, clients
+            should provide a target ID instead of relying on the server
+            to assign one.
+
+            If ``target_id`` is non-zero, there must not be an existing
+            active target on this stream with the same ID.
         once (bool):
             If the target should be removed once it is
             current and consistent.
@@ -1462,6 +1489,7 @@ class TargetChange(proto.Message):
             The type of change that occurred.
         target_ids (MutableSequence[int]):
             The target IDs of targets that have changed.
+
             If empty, the change applies to all targets.

             The order of the target IDs is not defined.
@@ -1633,6 +1661,7 @@ class BatchWriteRequest(proto.Message):
             ``projects/{project_id}/databases/{database_id}``.
         writes (MutableSequence[google.cloud.firestore_v1.types.Write]):
             The writes to apply.
+
             Method does not apply writes atomically and does
             not guarantee ordering. Each write succeeds or
             fails independently. You cannot write to the
@@ -1664,10 +1693,12 @@ class BatchWriteResponse(proto.Message):
     Attributes:
         write_results (MutableSequence[google.cloud.firestore_v1.types.WriteResult]):
             The result of applying the writes.
+
             The i-th write result corresponds to the i-th
             write in the request.
         status (MutableSequence[google.rpc.status_pb2.Status]):
             The status of applying the writes.
+
             The i-th write status corresponds to the i-th
             write in the request.
     """
diff --git a/google/cloud/firestore_v1/types/query.py b/google/cloud/firestore_v1/types/query.py
index ac1b246260..bca04d71ea 100644
--- a/google/cloud/firestore_v1/types/query.py
+++ b/google/cloud/firestore_v1/types/query.py
@@ -472,11 +472,12 @@ class FieldReference(proto.Message):
     Attributes:
         field_path (str):
-            The relative path of the document being referenced.
+            A reference to a field in a document.

             Requires:

-            - Conform to [document field
+            - MUST be a dot-delimited (``.``) string of segments, where
+              each segment conforms to [document field
               name][google.firestore.v1.Document.fields] limitations.
@@ -763,6 +764,7 @@ class Cursor(proto.Message):
             The values that represent a position, in the
             order they appear in the order by clause of a
             query.
+
             Can contain fewer values than specified in
             the order by clause.
         before (bool):
diff --git a/google/cloud/firestore_v1/types/write.py b/google/cloud/firestore_v1/types/write.py
index 09b75240d6..e4e9e69b33 100644
--- a/google/cloud/firestore_v1/types/write.py
+++ b/google/cloud/firestore_v1/types/write.py
@@ -83,6 +83,7 @@ class Write(proto.Message):
             and in order.
         current_document (google.cloud.firestore_v1.types.Precondition):
             An optional precondition on the document.
+
             The write will fail if this is set and not met
             by the target document.
     """
diff --git a/google/cloud/firestore_v1/watch.py b/google/cloud/firestore_v1/watch.py
index d1ce5a57af..555b895019 100644
--- a/google/cloud/firestore_v1/watch.py
+++ b/google/cloud/firestore_v1/watch.py
@@ -230,7 +230,6 @@ def __init__(
         self._init_stream()

     def _init_stream(self):
-
         rpc_request = self._get_rpc_request

         self._rpc = ResumableBidiRpc(
@@ -401,7 +400,9 @@ def _on_snapshot_target_change_remove(self, target_change):

         error_message = "Error %s: %s" % (code, message)

-        raise RuntimeError(error_message)
+        raise RuntimeError(error_message) from exceptions.from_grpc_status(
+            code, message
+        )

     def _on_snapshot_target_change_reset(self, target_change):
         # Whatever changes have happened so far no longer matter.
@@ -443,7 +444,6 @@ def on_snapshot(self, proto):
         which = pb.WhichOneof("response_type")

         if which == "target_change":
-
             target_change_type = pb.target_change.target_change_type
             _LOGGER.debug(f"on_snapshot: target change: {target_change_type}")
diff --git a/noxfile.py b/noxfile.py
index e90d8b8d85..a620dad223 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -17,23 +17,24 @@
 # Generated by synthtool. DO NOT EDIT!
from __future__ import absolute_import + import os import pathlib import re import shutil +from typing import Dict, List import warnings import nox FLAKE8_VERSION = "flake8==6.1.0" -PYTYPE_VERSION = "pytype==2020.7.24" -BLACK_VERSION = "black==22.3.0" -ISORT_VERSION = "isort==5.10.1" +BLACK_VERSION = "black[jupyter]==23.7.0" +ISORT_VERSION = "isort==5.11.0" LINT_PATHS = ["docs", "google", "tests", "noxfile.py", "setup.py"] DEFAULT_PYTHON_VERSION = "3.8" -UNIT_TEST_PYTHON_VERSIONS = ["3.7", "3.8", "3.9", "3.10", "3.11"] +UNIT_TEST_PYTHON_VERSIONS: List[str] = ["3.7", "3.8", "3.9", "3.10", "3.11"] UNIT_TEST_STANDARD_DEPENDENCIES = [ "mock", "asyncmock", @@ -41,27 +42,29 @@ "pytest-cov", "pytest-asyncio", ] -UNIT_TEST_EXTERNAL_DEPENDENCIES = [ +UNIT_TEST_EXTERNAL_DEPENDENCIES: List[str] = [ "aiounittest", + "six", ] -UNIT_TEST_LOCAL_DEPENDENCIES = [] -UNIT_TEST_DEPENDENCIES = [] -UNIT_TEST_EXTRAS = [] -UNIT_TEST_EXTRAS_BY_PYTHON = {} +UNIT_TEST_LOCAL_DEPENDENCIES: List[str] = [] +UNIT_TEST_DEPENDENCIES: List[str] = [] +UNIT_TEST_EXTRAS: List[str] = [] +UNIT_TEST_EXTRAS_BY_PYTHON: Dict[str, List[str]] = {} -SYSTEM_TEST_PYTHON_VERSIONS = ["3.7"] -SYSTEM_TEST_STANDARD_DEPENDENCIES = [ +SYSTEM_TEST_PYTHON_VERSIONS: List[str] = ["3.7"] +SYSTEM_TEST_STANDARD_DEPENDENCIES: List[str] = [ "mock", "pytest", "google-cloud-testutils", ] -SYSTEM_TEST_EXTERNAL_DEPENDENCIES = [ +SYSTEM_TEST_EXTERNAL_DEPENDENCIES: List[str] = [ "pytest-asyncio", + "six", ] -SYSTEM_TEST_LOCAL_DEPENDENCIES = [] -SYSTEM_TEST_DEPENDENCIES = [] -SYSTEM_TEST_EXTRAS = [] -SYSTEM_TEST_EXTRAS_BY_PYTHON = {} +SYSTEM_TEST_LOCAL_DEPENDENCIES: List[str] = [] +SYSTEM_TEST_DEPENDENCIES: List[str] = [] +SYSTEM_TEST_EXTRAS: List[str] = [] +SYSTEM_TEST_EXTRAS_BY_PYTHON: Dict[str, List[str]] = {} CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute() @@ -76,6 +79,7 @@ "lint_setup_py", "blacken", "docs", + "format", ] # Error if a python version is missing @@ -212,7 +216,6 @@ def unit(session): def install_systemtest_dependencies(session, *constraints): - # Use pre-release gRPC for system tests. # Exclude version 1.52.0rc1 which has a known issue. 
# See https://github.com/grpc/grpc/issues/32163 diff --git a/owlbot.py b/owlbot.py index 4076f0c8dd..a125593ef3 100644 --- a/owlbot.py +++ b/owlbot.py @@ -139,8 +139,8 @@ def update_fixup_scripts(library): templated_files = common.py_library( samples=False, # set to True only if there are samples system_test_python_versions=["3.7"], - unit_test_external_dependencies=["aiounittest"], - system_test_external_dependencies=["pytest-asyncio"], + unit_test_external_dependencies=["aiounittest", "six"], + system_test_external_dependencies=["pytest-asyncio", "six"], microgenerator=True, cov_level=100, split_system_tests=True, diff --git a/tests/system/test_system.py b/tests/system/test_system.py index b48eb77f59..12e3b87b22 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -564,10 +564,14 @@ def query_docs(client, database): @pytest.fixture -def query(query_docs): - collection, stored, allowed_vals = query_docs - query = collection.where(filter=FieldFilter("a", "==", 1)) - return query +def collection(query_docs): + collection, _, _ = query_docs + return collection + + +@pytest.fixture +def query(collection): + return collection.where(filter=FieldFilter("a", "==", 1)) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) @@ -1411,7 +1415,6 @@ def _persist_documents( def _do_recursive_delete(client, bulk_writer, empty_philosophers=False): - if empty_philosophers: doc_paths = philosophers = [] else: @@ -1823,7 +1826,6 @@ def test_count_query_stream_default_alias(query, database): @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) def test_count_query_stream_with_alias(query, database): - count_query = query.count(alias="total") for result in count_query.stream(): for aggregation_result in result: @@ -1881,77 +1883,283 @@ def test_count_query_stream_empty_aggregation(query, database): assert "Aggregations can not be empty" in exc_info.value.message -@firestore.transactional -def create_in_transaction(collection_id, transaction, cleanup): - collection = client.collection(collection_id) +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_count_query_with_start_at(query, database): + """ + Ensure that count aggregation queries work when chained with a start_at - query = collection.where(filter=FieldFilter("a", "==", 1)) - count_query = query.count() + eg `col.where(...).startAt(...).count()` + """ + result = query.get() + start_doc = result[1] + # find count excluding first result + expected_count = len(result) - 1 + # start new query that starts at the second result + count_query = query.start_at(start_doc).count("a") + # ensure that the first doc was skipped in sum aggregation + for result in count_query.stream(): + for aggregation_result in result: + assert aggregation_result.value == expected_count - result = count_query.get(transaction=transaction) + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_sum_query_get_default_alias(collection, database): + sum_query = collection.sum("stats.product") + result = sum_query.get() + assert len(result) == 1 for r in result[0]: - assert r.value <= 2 - if r.value < 2: - document_id_3 = "doc3" + UNIQUE_RESOURCE_ID - document_3 = client.document(collection_id, document_id_3) - cleanup(document_3.delete) - document_3.create({"a": 1}) - else: - raise ValueError("Collection can't have more than 2 documents") + assert r.alias == "field_1" + assert r.value == 
100 -@firestore.transactional -def create_in_transaction_helper(transaction, client, collection_id, cleanup, database): - collection = client.collection(collection_id) - query = collection.where(filter=FieldFilter("a", "==", 1)) - count_query = query.count() - result = count_query.get(transaction=transaction) +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_sum_query_get_with_alias(collection, database): + sum_query = collection.sum("stats.product", alias="total") + result = sum_query.get() + assert len(result) == 1 + for r in result[0]: + assert r.alias == "total" + assert r.value == 100 + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_sum_query_get_with_limit(collection, database): + # sum without limit + sum_query = collection.sum("stats.product", alias="total") + result = sum_query.get() + assert len(result) == 1 for r in result[0]: - if r.value < 2: - document_id_3 = "doc3" + UNIQUE_RESOURCE_ID - document_3 = client.document(collection_id, document_id_3) - cleanup(document_3.delete) - document_3.create({"a": 1}) - else: # transaction is rolled back - raise ValueError("Collection can't have more than 2 docs") + assert r.alias == "total" + assert r.value == 100 + + # sum with limit + # limit query = [0,0,0,0,0,0,0,0,0,1,2,2] + sum_query = collection.limit(12).sum("stats.product", alias="total") + + result = sum_query.get() + assert len(result) == 1 + for r in result[0]: + assert r.alias == "total" + assert r.value == 5 @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_count_query_in_transaction(client, cleanup, database): - collection_id = "doc-create" + UNIQUE_RESOURCE_ID - document_id_1 = "doc1" + UNIQUE_RESOURCE_ID - document_id_2 = "doc2" + UNIQUE_RESOURCE_ID +def test_sum_query_get_multiple_aggregations(collection, database): + sum_query = collection.sum("stats.product", alias="total").sum( + "stats.product", alias="all" + ) - document_1 = client.document(collection_id, document_id_1) - document_2 = client.document(collection_id, document_id_2) + result = sum_query.get() + assert len(result[0]) == 2 - cleanup(document_1.delete) - cleanup(document_2.delete) + expected_aliases = ["total", "all"] + found_alias = set( + [r.alias for r in result[0]] + ) # ensure unique elements in the result + assert len(found_alias) == 2 + assert found_alias == set(expected_aliases) - document_1.create({"a": 1}) - document_2.create({"a": 1}) - transaction = client.transaction() +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_sum_query_stream_default_alias(collection, database): + sum_query = collection.sum("stats.product") + for result in sum_query.stream(): + for aggregation_result in result: + assert aggregation_result.alias == "field_1" + assert aggregation_result.value == 100 - with pytest.raises(ValueError) as exc: - create_in_transaction_helper( - transaction, client, collection_id, cleanup, database - ) - assert str(exc.value) == "Collection can't have more than 2 docs" - collection = client.collection(collection_id) +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_sum_query_stream_with_alias(collection, database): + sum_query = collection.sum("stats.product", alias="total") + for result in sum_query.stream(): + for aggregation_result in result: + assert aggregation_result.alias == "total" + assert aggregation_result.value == 100 - query = collection.where(filter=FieldFilter("a", "==", 
1)) - count_query = query.count() - result = count_query.get() + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_sum_query_stream_with_limit(collection, database): + # sum without limit + sum_query = collection.sum("stats.product", alias="total") + for result in sum_query.stream(): + for aggregation_result in result: + assert aggregation_result.alias == "total" + assert aggregation_result.value == 100 + + # sum with limit + sum_query = collection.limit(12).sum("stats.product", alias="total") + + for result in sum_query.stream(): + for aggregation_result in result: + assert aggregation_result.alias == "total" + assert aggregation_result.value == 5 + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_sum_query_stream_multiple_aggregations(collection, database): + sum_query = collection.sum("stats.product", alias="total").sum( + "stats.product", alias="all" + ) + + for result in sum_query.stream(): + for aggregation_result in result: + assert aggregation_result.alias in ["total", "all"] + + +# tests for issue reported in b/306241058 +# we will skip test in client for now, until backend fix is implemented +@pytest.mark.skip(reason="backend fix required") +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_sum_query_with_start_at(query, database): + """ + Ensure that sum aggregation queries work when chained with a start_at + + eg `col.where(...).startAt(...).sum()` + """ + result = query.get() + start_doc = result[1] + # find sum excluding first result + expected_sum = sum([doc.get("a") for doc in result[1:]]) + # start new query that starts at the second result + sum_result = query.start_at(start_doc).sum("a").get() + assert len(sum_result) == 1 + # ensure that the first doc was skipped in sum aggregation + assert sum_result[0].value == expected_sum + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_avg_query_get_default_alias(collection, database): + avg_query = collection.avg("stats.product") + result = avg_query.get() + assert len(result) == 1 for r in result[0]: - assert r.value == 2 # there are still only 2 docs + assert r.alias == "field_1" + assert r.value == 4.0 + assert isinstance(r.value, float) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_with_and_composite_filter(query_docs, database): - collection, stored, allowed_vals = query_docs +def test_avg_query_get_with_alias(collection, database): + avg_query = collection.avg("stats.product", alias="total") + result = avg_query.get() + assert len(result) == 1 + for r in result[0]: + assert r.alias == "total" + assert r.value == 4 + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_avg_query_get_with_limit(collection, database): + # avg without limit + avg_query = collection.avg("stats.product", alias="total") + result = avg_query.get() + assert len(result) == 1 + for r in result[0]: + assert r.alias == "total" + assert r.value == 4.0 + + # avg with limit + # limit result = [0,0,0,0,0,0,0,0,0,1,2,2] + avg_query = collection.limit(12).avg("stats.product", alias="total") + + result = avg_query.get() + assert len(result) == 1 + for r in result[0]: + assert r.alias == "total" + assert r.value == 5 / 12 + assert isinstance(r.value, float) + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_avg_query_get_multiple_aggregations(collection, 
database): + avg_query = collection.avg("stats.product", alias="total").avg( + "stats.product", alias="all" + ) + + result = avg_query.get() + assert len(result[0]) == 2 + + expected_aliases = ["total", "all"] + found_alias = set( + [r.alias for r in result[0]] + ) # ensure unique elements in the result + assert len(found_alias) == 2 + assert found_alias == set(expected_aliases) + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_avg_query_stream_default_alias(collection, database): + avg_query = collection.avg("stats.product") + for result in avg_query.stream(): + for aggregation_result in result: + assert aggregation_result.alias == "field_1" + assert aggregation_result.value == 4 + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_avg_query_stream_with_alias(collection, database): + avg_query = collection.avg("stats.product", alias="total") + for result in avg_query.stream(): + for aggregation_result in result: + assert aggregation_result.alias == "total" + assert aggregation_result.value == 4 + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_avg_query_stream_with_limit(collection, database): + # avg without limit + avg_query = collection.avg("stats.product", alias="total") + for result in avg_query.stream(): + for aggregation_result in result: + assert aggregation_result.alias == "total" + assert aggregation_result.value == 4 + + # avg with limit + avg_query = collection.limit(12).avg("stats.product", alias="total") + + for result in avg_query.stream(): + for aggregation_result in result: + assert aggregation_result.alias == "total" + assert aggregation_result.value == 5 / 12 + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_avg_query_stream_multiple_aggregations(collection, database): + avg_query = collection.avg("stats.product", alias="total").avg( + "stats.product", alias="all" + ) + + for result in avg_query.stream(): + for aggregation_result in result: + assert aggregation_result.alias in ["total", "all"] + + +# tests for issue reported in b/306241058 +# we will skip test in client for now, until backend fix is implemented +@pytest.mark.skip(reason="backend fix required") +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_avg_query_with_start_at(query, database): + """ + Ensure that avg aggregation queries work when chained with a start_at + + eg `col.where(...).startAt(...).avg()` + """ + from statistics import mean + + result = query.get() + start_doc = result[1] + # find average, excluding first result + expected_avg = mean([doc.get("a") for doc in result[1:]]) + # start new query that starts at the second result + avg_result = query.start_at(start_doc).avg("a").get() + assert len(avg_result) == 1 + # ensure that the first doc was skipped in avg aggregation + assert avg_result[0].value == expected_avg + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_query_with_and_composite_filter(collection, database): and_filter = And( filters=[ FieldFilter("stats.product", ">", 5), @@ -1966,8 +2174,7 @@ def test_query_with_and_composite_filter(query_docs, database): @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_with_or_composite_filter(query_docs, database): - collection, stored, allowed_vals = query_docs +def test_query_with_or_composite_filter(collection, database): or_filter = Or( 
filters=[ FieldFilter("stats.product", ">", 5), @@ -1990,8 +2197,7 @@ def test_query_with_or_composite_filter(query_docs, database): @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_with_complex_composite_filter(query_docs, database): - collection, stored, allowed_vals = query_docs +def test_query_with_complex_composite_filter(collection, database): field_filter = FieldFilter("b", "==", 0) or_filter = Or( filters=[FieldFilter("stats.sum", "==", 0), FieldFilter("stats.sum", "==", 4)] @@ -2035,48 +2241,140 @@ def test_query_with_complex_composite_filter(query_docs, database): assert b_not_3 is True +@pytest.mark.parametrize( + "aggregation_type,aggregation_args,expected", + [("count", (), 3), ("sum", ("b"), 12), ("avg", ("b"), 4)], +) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_or_query_in_transaction(client, cleanup, database): +def test_aggregation_query_in_transaction( + client, cleanup, database, aggregation_type, aggregation_args, expected +): + """ + Test creating an aggregation query inside a transaction + Should send transaction id along with request. Results should be consistent with non-transactional query + """ collection_id = "doc-create" + UNIQUE_RESOURCE_ID - document_id_1 = "doc1" + UNIQUE_RESOURCE_ID - document_id_2 = "doc2" + UNIQUE_RESOURCE_ID + doc_ids = [f"doc{i}" + UNIQUE_RESOURCE_ID for i in range(4)] + doc_refs = [client.document(collection_id, doc_id) for doc_id in doc_ids] + for doc_ref in doc_refs: + cleanup(doc_ref.delete) + doc_refs[0].create({"a": 3, "b": 1}) + doc_refs[1].create({"a": 5, "b": 1}) + doc_refs[2].create({"a": 5, "b": 10}) + doc_refs[3].create({"a": 10, "b": 0}) # should be ignored by query - document_1 = client.document(collection_id, document_id_1) - document_2 = client.document(collection_id, document_id_2) + collection = client.collection(collection_id) + query = collection.where(filter=FieldFilter("b", ">", 0)) + aggregation_query = getattr(query, aggregation_type)(*aggregation_args) - cleanup(document_1.delete) - cleanup(document_2.delete) + with client.transaction() as transaction: + # should fail if transaction has not been initiated + with pytest.raises(ValueError): + aggregation_query.get(transaction=transaction) - document_1.create({"a": 1, "b": 2}) - document_2.create({"a": 1, "b": 1}) + # should work when transaction is initiated through transactional decorator + @firestore.transactional + def in_transaction(transaction): + global inner_fn_ran + result = aggregation_query.get(transaction=transaction) + assert len(result) == 1 + assert len(result[0]) == 1 + assert result[0][0].value == expected + inner_fn_ran = True - transaction = client.transaction() + in_transaction(transaction) + # make sure we didn't skip assertions in inner function + assert inner_fn_ran is True - with pytest.raises(ValueError) as exc: - create_in_transaction_helper( - transaction, client, collection_id, cleanup, database - ) - assert str(exc.value) == "Collection can't have more than 2 docs" - collection = client.collection(collection_id) +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_or_query_in_transaction(client, cleanup, database): + """ + Test running or query inside a transaction. 
Should pass transaction id along with request + """ + collection_id = "doc-create" + UNIQUE_RESOURCE_ID + doc_ids = [f"doc{i}" + UNIQUE_RESOURCE_ID for i in range(5)] + doc_refs = [client.document(collection_id, doc_id) for doc_id in doc_ids] + for doc_ref in doc_refs: + cleanup(doc_ref.delete) + doc_refs[0].create({"a": 1, "b": 2}) + doc_refs[1].create({"a": 1, "b": 1}) + doc_refs[2].create({"a": 2, "b": 1}) # should be ignored by query + doc_refs[3].create({"a": 1, "b": 0}) # should be ignored by query + collection = client.collection(collection_id) query = collection.where(filter=FieldFilter("a", "==", 1)).where( filter=Or([FieldFilter("b", "==", 1), FieldFilter("b", "==", 2)]) ) - b_1 = False - b_2 = False - count = 0 - for result in query.stream(): - assert result.get("a") == 1 # assert a==1 is True in both results - assert result.get("b") == 1 or result.get("b") == 2 - if result.get("b") == 1: - b_1 = True - if result.get("b") == 2: - b_2 = True - count += 1 - - assert b_1 is True # assert one of them is b == 1 - assert b_2 is True # assert one of them is b == 2 - assert ( - count == 2 - ) # assert only 2 results, the third one was rolledback and not created + + with client.transaction() as transaction: + # should fail if transaction has not been initiated + with pytest.raises(ValueError): + query.get(transaction=transaction) + + # should work when transaction is initiated through transactional decorator + @firestore.transactional + def in_transaction(transaction): + global inner_fn_ran + result = query.get(transaction=transaction) + assert len(result) == 2 + # both documents should have a == 1 + assert result[0].get("a") == 1 + assert result[1].get("a") == 1 + # one document should have b == 1 and the other should have b == 2 + assert (result[0].get("b") == 1 and result[1].get("b") == 2) or ( + result[0].get("b") == 2 and result[1].get("b") == 1 + ) + inner_fn_ran = True + + in_transaction(transaction) + # make sure we didn't skip assertions in inner function + assert inner_fn_ran is True + + +@pytest.mark.parametrize("with_rollback,expected", [(True, 2), (False, 3)]) +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_transaction_rollback(client, cleanup, database, with_rollback, expected): + """ + Create a document in a transaction that is rolled back + Document should not show up in later queries + """ + collection_id = "doc-create" + UNIQUE_RESOURCE_ID + doc_ids = [f"doc{i}" + UNIQUE_RESOURCE_ID for i in range(3)] + doc_refs = [client.document(collection_id, doc_id) for doc_id in doc_ids] + for doc_ref in doc_refs: + cleanup(doc_ref.delete) + doc_refs[0].create({"a": 1}) + doc_refs[1].create({"a": 1}) + doc_refs[2].create({"a": 2}) # should be ignored by query + + transaction = client.transaction() + + @firestore.transactional + def in_transaction(transaction, rollback): + """ + create a document in a transaction that is rolled back (raises an exception) + """ + new_document_id = "in_transaction_doc" + UNIQUE_RESOURCE_ID + new_document_ref = client.document(collection_id, new_document_id) + cleanup(new_document_ref.delete) + transaction.create(new_document_ref, {"a": 1}) + if rollback: + raise RuntimeError("rollback") + + if with_rollback: + # run transaction in function that results in a rollback + with pytest.raises(RuntimeError) as exc: + in_transaction(transaction, with_rollback) + assert str(exc.value) == "rollback" + else: + # no rollback expected + in_transaction(transaction, with_rollback) + + collection = 
client.collection(collection_id) + + query = collection.where(filter=FieldFilter("a", "==", 1)).count() + result = query.get() + assert len(result) == 1 + assert len(result[0]) == 1 + assert result[0][0].value == expected diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py index bb7cff58fa..5201149167 100644 --- a/tests/system/test_system_async.py +++ b/tests/system/test_system_async.py @@ -609,11 +609,14 @@ async def query_docs(client): @pytest_asyncio.fixture -async def async_query(query_docs): - collection, stored, allowed_vals = query_docs - query = collection.where(filter=FieldFilter("a", "==", 1)) +async def collection(query_docs): + collection, _, _ = query_docs + yield collection - return query + +@pytest_asyncio.fixture +async def async_query(collection): + return collection.where(filter=FieldFilter("a", "==", 1)) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) @@ -1283,7 +1286,6 @@ async def _persist_documents( async def _do_recursive_delete(client, bulk_writer, empty_philosophers=False): - if empty_philosophers: philosophers = doc_paths = [] else: @@ -1514,7 +1516,6 @@ async def test_count_async_query_get_default_alias(async_query, database): @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) async def test_async_count_query_get_with_alias(async_query, database): - count_query = async_query.count(alias="total") result = await count_query.get() for r in result[0]: @@ -1523,7 +1524,6 @@ async def test_async_count_query_get_with_alias(async_query, database): @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) async def test_async_count_query_get_with_limit(async_query, database): - count_query = async_query.count(alias="total") result = await count_query.get() for r in result[0]: @@ -1540,7 +1540,6 @@ async def test_async_count_query_get_with_limit(async_query, database): @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) async def test_async_count_query_get_multiple_aggregations(async_query, database): - count_query = async_query.count(alias="total").count(alias="all") result = await count_query.get() @@ -1558,7 +1557,6 @@ async def test_async_count_query_get_multiple_aggregations(async_query, database async def test_async_count_query_get_multiple_aggregations_duplicated_alias( async_query, database ): - count_query = async_query.count(alias="total").count(alias="total") with pytest.raises(InvalidArgument) as exc_info: @@ -1580,8 +1578,7 @@ async def test_async_count_query_get_empty_aggregation(async_query, database): @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -async def test_count_async_query_stream_default_alias(async_query, database): - +async def test_async_count_query_stream_default_alias(async_query, database): count_query = async_query.count() async for result in count_query.stream(): @@ -1591,7 +1588,6 @@ async def test_count_async_query_stream_default_alias(async_query, database): @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) async def test_async_count_query_stream_with_alias(async_query, database): - count_query = async_query.count(alias="total") async for result in count_query.stream(): for aggregation_result in result: @@ -1615,7 +1611,6 @@ async def test_async_count_query_stream_with_limit(async_query, database): @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) async def 
test_async_count_query_stream_multiple_aggregations(async_query, database): - count_query = async_query.count(alias="total").count(alias="all") async for result in count_query.stream(): @@ -1628,7 +1623,6 @@ async def test_async_count_query_stream_multiple_aggregations(async_query, datab async def test_async_count_query_stream_multiple_aggregations_duplicated_alias( async_query, database ): - count_query = async_query.count(alias="total").count(alias="total") with pytest.raises(InvalidArgument) as exc_info: @@ -1651,6 +1645,201 @@ async def test_async_count_query_stream_empty_aggregation(async_query, database) assert "Aggregations can not be empty" in exc_info.value.message +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_async_sum_query_get_default_alias(collection, database): + sum_query = collection.sum("stats.product") + result = await sum_query.get() + for r in result[0]: + assert r.alias == "field_1" + assert r.value == 100 + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_async_sum_query_get_with_alias(collection, database): + sum_query = collection.sum("stats.product", alias="total") + result = await sum_query.get() + for r in result[0]: + assert r.alias == "total" + assert r.value == 100 + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_async_sum_query_get_with_limit(collection, database): + sum_query = collection.sum("stats.product", alias="total") + result = await sum_query.get() + for r in result[0]: + assert r.alias == "total" + assert r.value == 100 + + # sum with limit + sum_query = collection.limit(12).sum("stats.product", alias="total") + result = await sum_query.get() + for r in result[0]: + assert r.alias == "total" + assert r.value == 5 + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_async_sum_query_get_multiple_aggregations(collection, database): + sum_query = collection.sum("stats.product", alias="total").sum( + "stats.product", alias="all" + ) + + result = await sum_query.get() + assert len(result[0]) == 2 + + expected_aliases = ["total", "all"] + found_alias = set( + [r.alias for r in result[0]] + ) # ensure unique elements in the result + assert len(found_alias) == 2 + assert found_alias == set(expected_aliases) + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_async_sum_query_stream_default_alias(collection, database): + sum_query = collection.sum("stats.product") + + async for result in sum_query.stream(): + for aggregation_result in result: + assert aggregation_result.alias == "field_1" + assert aggregation_result.value == 100 + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_async_sum_query_stream_with_alias(collection, database): + sum_query = collection.sum("stats.product", alias="total") + async for result in sum_query.stream(): + for aggregation_result in result: + assert aggregation_result.alias == "total" + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_async_sum_query_stream_with_limit(collection, database): + # sum without limit + sum_query = collection.sum("stats.product", alias="total") + async for result in sum_query.stream(): + for aggregation_result in result: + assert aggregation_result.value == 100 + + # sum with limit + sum_query = collection.limit(12).sum("stats.product", alias="total") + async 
for result in sum_query.stream(): + for aggregation_result in result: + assert aggregation_result.value == 5 + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_async_sum_query_stream_multiple_aggregations(collection, database): + sum_query = collection.sum("stats.product", alias="total").sum( + "stats.product", alias="all" + ) + + async for result in sum_query.stream(): + assert len(result) == 2 + for aggregation_result in result: + assert aggregation_result.alias in ["total", "all"] + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_async_avg_query_get_default_alias(collection, database): + avg_query = collection.avg("stats.product") + result = await avg_query.get() + for r in result[0]: + assert r.alias == "field_1" + assert r.value == 4 + assert isinstance(r.value, float) + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_async_avg_query_get_with_alias(collection, database): + avg_query = collection.avg("stats.product", alias="total") + result = await avg_query.get() + for r in result[0]: + assert r.alias == "total" + assert r.value == 4 + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_async_avg_query_get_with_limit(collection, database): + avg_query = collection.avg("stats.product", alias="total") + result = await avg_query.get() + for r in result[0]: + assert r.alias == "total" + assert r.value == 4 + + # avg with limit + avg_query = collection.limit(12).avg("stats.product", alias="total") + result = await avg_query.get() + for r in result[0]: + assert r.alias == "total" + assert r.value == 5 / 12 + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_async_avg_query_get_multiple_aggregations(collection, database): + avg_query = collection.avg("stats.product", alias="total").avg( + "stats.product", alias="all" + ) + + result = await avg_query.get() + assert len(result[0]) == 2 + + expected_aliases = ["total", "all"] + found_alias = set( + [r.alias for r in result[0]] + ) # ensure unique elements in the result + assert len(found_alias) == 2 + assert found_alias == set(expected_aliases) + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_async_avg_query_stream_default_alias(collection, database): + avg_query = collection.avg("stats.product") + + async for result in avg_query.stream(): + for aggregation_result in result: + assert aggregation_result.alias == "field_1" + assert aggregation_result.value == 4.0 + assert isinstance(aggregation_result.value, float) + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_async_avg_query_stream_with_alias(collection, database): + avg_query = collection.avg("stats.product", alias="total") + async for result in avg_query.stream(): + for aggregation_result in result: + assert aggregation_result.alias == "total" + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_async_avg_query_stream_with_limit(collection, database): + # avg without limit + avg_query = collection.avg("stats.product", alias="total") + async for result in avg_query.stream(): + for aggregation_result in result: + assert aggregation_result.value == 4.0 + + # avg with limit + avg_query = collection.limit(12).avg("stats.product", alias="total") + async for result in avg_query.stream(): + for 
aggregation_result in result: + assert aggregation_result.value == 5 / 12 + assert isinstance(aggregation_result.value, float) + + +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_async_avg_query_stream_multiple_aggregations(collection, database): + avg_query = collection.avg("stats.product", alias="total").avg( + "stats.product", alias="all" + ) + + async for result in avg_query.stream(): + assert len(result) == 2 + for aggregation_result in result: + assert aggregation_result.alias in ["total", "all"] + + @firestore.async_transactional async def create_in_transaction_helper( transaction, client, collection_id, cleanup, database diff --git a/tests/unit/gapic/firestore_admin_v1/test_firestore_admin.py b/tests/unit/gapic/firestore_admin_v1/test_firestore_admin.py index 219c894a0a..95a774280b 100644 --- a/tests/unit/gapic/firestore_admin_v1/test_firestore_admin.py +++ b/tests/unit/gapic/firestore_admin_v1/test_firestore_admin.py @@ -65,7 +65,7 @@ from google.cloud.firestore_admin_v1.types import index as gfa_index from google.cloud.firestore_admin_v1.types import operation as gfa_operation from google.cloud.location import locations_pb2 -from google.longrunning import operations_pb2 +from google.longrunning import operations_pb2 # type: ignore from google.oauth2 import service_account from google.protobuf import empty_pb2 # type: ignore from google.protobuf import field_mask_pb2 # type: ignore @@ -4162,6 +4162,73 @@ def test_create_index_rest(request_type): "fields": [{"field_path": "field_path_value", "order": 1, "array_config": 1}], "state": 1, } + # The version of a generated dependency at test runtime may differ from the version used during generation. + # Delete any fields which are not present in the current runtime dependency + # See https://github.com/googleapis/gapic-generator-python/issues/1748 + + # Determine if the message type is proto-plus or protobuf + test_field = firestore_admin.CreateIndexRequest.meta.fields["index"] + + def get_message_fields(field): + # Given a field which is a message (composite type), return a list with + # all the fields of the message. + # If the field is not a composite type, return an empty list. 
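+        # E.g. for the ``index`` field of ``CreateIndexRequest`` this
+        # returns the fields of the ``Index`` message; for a scalar
+        # field it returns an empty list.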
+ message_fields = [] + + if hasattr(field, "message") and field.message: + is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR") + + if is_field_type_proto_plus_type: + message_fields = field.message.meta.fields.values() + # Add `# pragma: NO COVER` because there may not be any `*_pb2` field types + else: # pragma: NO COVER + message_fields = field.message.DESCRIPTOR.fields + return message_fields + + runtime_nested_fields = [ + (field.name, nested_field.name) + for field in get_message_fields(test_field) + for nested_field in get_message_fields(field) + ] + + subfields_not_in_runtime = [] + + # For each item in the sample request, create a list of sub fields which are not present at runtime + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for field, value in request_init["index"].items(): # pragma: NO COVER + result = None + is_repeated = False + # For repeated fields + if isinstance(value, list) and len(value): + is_repeated = True + result = value[0] + # For fields where the type is another message + if isinstance(value, dict): + result = value + + if result and hasattr(result, "keys"): + for subfield in result.keys(): + if (field, subfield) not in runtime_nested_fields: + subfields_not_in_runtime.append( + { + "field": field, + "subfield": subfield, + "is_repeated": is_repeated, + } + ) + + # Remove fields from the sample request which are not present in the runtime version of the dependency + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER + field = subfield_to_delete.get("field") + field_repeated = subfield_to_delete.get("is_repeated") + subfield = subfield_to_delete.get("subfield") + if subfield: + if field_repeated: + for i in range(0, len(request_init["index"][field])): + del request_init["index"][field][i][subfield] + else: + del request_init["index"][field][subfield] request = request_type(**request_init) # Mock the http request call within the method and fake a response. @@ -4346,13 +4413,6 @@ def test_create_index_rest_bad_request( request_init = { "parent": "projects/sample1/databases/sample2/collectionGroups/sample3" } - request_init["index"] = { - "name": "name_value", - "query_scope": 1, - "api_scope": 1, - "fields": [{"field_path": "field_path_value", "order": 1, "array_config": 1}], - "state": 1, - } request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. 
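An aside on the sanitization block above (repeated for each REST sample-request test in this patch): it tells proto-plus message classes apart from raw protobuf ones by probing for a DESCRIPTOR attribute. A minimal, self-contained sketch of that check; the two imported types are just convenient examples, not the only ones the tests touch:

from google.protobuf import timestamp_pb2
from google.cloud.firestore_admin_v1.types import firestore_admin


def is_proto_plus(message_cls) -> bool:
    # proto-plus wrappers expose field metadata via .meta.fields, while raw
    # *_pb2 classes carry a DESCRIPTOR attribute instead.
    return not hasattr(message_cls, "DESCRIPTOR")


assert is_proto_plus(firestore_admin.CreateIndexRequest)  # proto-plus type
assert not is_proto_plus(timestamp_pb2.Timestamp)  # raw *_pb2 protobuf type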
@@ -4461,8 +4521,9 @@ def test_list_indexes_rest(request_type): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - pb_return_value = firestore_admin.ListIndexesResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = firestore_admin.ListIndexesResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -4545,8 +4606,9 @@ def test_list_indexes_rest_required_fields( response_value = Response() response_value.status_code = 200 - pb_return_value = firestore_admin.ListIndexesResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = firestore_admin.ListIndexesResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -4684,8 +4746,9 @@ def test_list_indexes_rest_flattened(): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - pb_return_value = firestore_admin.ListIndexesResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = firestore_admin.ListIndexesResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -4814,8 +4877,9 @@ def test_get_index_rest(request_type): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - pb_return_value = index.Index.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = index.Index.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -4891,8 +4955,9 @@ def test_get_index_rest_required_fields(request_type=firestore_admin.GetIndexReq response_value = Response() response_value.status_code = 200 - pb_return_value = index.Index.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = index.Index.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -5019,8 +5084,9 @@ def test_get_index_rest_flattened(): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - pb_return_value = index.Index.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = index.Index.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -5346,8 +5412,9 @@ def test_get_field_rest(request_type): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - pb_return_value = field.Field.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value 
to protobuf type + return_value = field.Field.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -5420,8 +5487,9 @@ def test_get_field_rest_required_fields(request_type=firestore_admin.GetFieldReq response_value = Response() response_value.status_code = 200 - pb_return_value = field.Field.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = field.Field.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -5548,8 +5616,9 @@ def test_get_field_rest_flattened(): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - pb_return_value = field.Field.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = field.Field.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -5630,6 +5699,73 @@ def test_update_field_rest(request_type): }, "ttl_config": {"state": 1}, } + # The version of a generated dependency at test runtime may differ from the version used during generation. + # Delete any fields which are not present in the current runtime dependency + # See https://github.com/googleapis/gapic-generator-python/issues/1748 + + # Determine if the message type is proto-plus or protobuf + test_field = firestore_admin.UpdateFieldRequest.meta.fields["field"] + + def get_message_fields(field): + # Given a field which is a message (composite type), return a list with + # all the fields of the message. + # If the field is not a composite type, return an empty list. 
+ message_fields = [] + + if hasattr(field, "message") and field.message: + is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR") + + if is_field_type_proto_plus_type: + message_fields = field.message.meta.fields.values() + # Add `# pragma: NO COVER` because there may not be any `*_pb2` field types + else: # pragma: NO COVER + message_fields = field.message.DESCRIPTOR.fields + return message_fields + + runtime_nested_fields = [ + (field.name, nested_field.name) + for field in get_message_fields(test_field) + for nested_field in get_message_fields(field) + ] + + subfields_not_in_runtime = [] + + # For each item in the sample request, create a list of sub fields which are not present at runtime + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for field, value in request_init["field"].items(): # pragma: NO COVER + result = None + is_repeated = False + # For repeated fields + if isinstance(value, list) and len(value): + is_repeated = True + result = value[0] + # For fields where the type is another message + if isinstance(value, dict): + result = value + + if result and hasattr(result, "keys"): + for subfield in result.keys(): + if (field, subfield) not in runtime_nested_fields: + subfields_not_in_runtime.append( + { + "field": field, + "subfield": subfield, + "is_repeated": is_repeated, + } + ) + + # Remove fields from the sample request which are not present in the runtime version of the dependency + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER + field = subfield_to_delete.get("field") + field_repeated = subfield_to_delete.get("is_repeated") + subfield = subfield_to_delete.get("subfield") + if subfield: + if field_repeated: + for i in range(0, len(request_init["field"][field])): + del request_init["field"][field][i][subfield] + else: + del request_init["field"][field][subfield] request = request_type(**request_init) # Mock the http request call within the method and fake a response. @@ -5805,30 +5941,6 @@ def test_update_field_rest_bad_request( "name": "projects/sample1/databases/sample2/collectionGroups/sample3/fields/sample4" } } - request_init["field"] = { - "name": "projects/sample1/databases/sample2/collectionGroups/sample3/fields/sample4", - "index_config": { - "indexes": [ - { - "name": "name_value", - "query_scope": 1, - "api_scope": 1, - "fields": [ - { - "field_path": "field_path_value", - "order": 1, - "array_config": 1, - } - ], - "state": 1, - } - ], - "uses_ancestor_config": True, - "ancestor_field": "ancestor_field_value", - "reverting": True, - }, - "ttl_config": {"state": 1}, - } request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. 
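The pruning pass above is easier to follow on a toy payload. A hedged sketch under made-up field names (nothing below comes from the Firestore protos); it mirrors the dict handling of the loop above, minus the repeated-field branch:

# Sample request as generated, possibly containing subfields newer than the
# runtime dependency understands.
request_init = {"field": {"index_config": {"reverting": True, "not_in_runtime": 1}}}
# (field, subfield) pairs the runtime version of the message does define.
runtime_nested_fields = [("index_config", "reverting")]

for field, value in request_init["field"].items():
    result = value[0] if isinstance(value, list) and value else value
    if isinstance(result, dict):
        for subfield in list(result):  # copy the keys; we delete while iterating
            if (field, subfield) not in runtime_nested_fields:
                del result[subfield]

assert request_init == {"field": {"index_config": {"reverting": True}}}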
@@ -5937,8 +6049,9 @@ def test_list_fields_rest(request_type): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - pb_return_value = firestore_admin.ListFieldsResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = firestore_admin.ListFieldsResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -6021,8 +6134,9 @@ def test_list_fields_rest_required_fields( response_value = Response() response_value.status_code = 200 - pb_return_value = firestore_admin.ListFieldsResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = firestore_admin.ListFieldsResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -6160,8 +6274,9 @@ def test_list_fields_rest_flattened(): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - pb_return_value = firestore_admin.ListFieldsResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = firestore_admin.ListFieldsResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -6810,6 +6925,73 @@ def test_create_database_rest(request_type): "key_prefix": "key_prefix_value", "etag": "etag_value", } + # The version of a generated dependency at test runtime may differ from the version used during generation. + # Delete any fields which are not present in the current runtime dependency + # See https://github.com/googleapis/gapic-generator-python/issues/1748 + + # Determine if the message type is proto-plus or protobuf + test_field = firestore_admin.CreateDatabaseRequest.meta.fields["database"] + + def get_message_fields(field): + # Given a field which is a message (composite type), return a list with + # all the fields of the message. + # If the field is not a composite type, return an empty list. 
+ message_fields = [] + + if hasattr(field, "message") and field.message: + is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR") + + if is_field_type_proto_plus_type: + message_fields = field.message.meta.fields.values() + # Add `# pragma: NO COVER` because there may not be any `*_pb2` field types + else: # pragma: NO COVER + message_fields = field.message.DESCRIPTOR.fields + return message_fields + + runtime_nested_fields = [ + (field.name, nested_field.name) + for field in get_message_fields(test_field) + for nested_field in get_message_fields(field) + ] + + subfields_not_in_runtime = [] + + # For each item in the sample request, create a list of sub fields which are not present at runtime + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for field, value in request_init["database"].items(): # pragma: NO COVER + result = None + is_repeated = False + # For repeated fields + if isinstance(value, list) and len(value): + is_repeated = True + result = value[0] + # For fields where the type is another message + if isinstance(value, dict): + result = value + + if result and hasattr(result, "keys"): + for subfield in result.keys(): + if (field, subfield) not in runtime_nested_fields: + subfields_not_in_runtime.append( + { + "field": field, + "subfield": subfield, + "is_repeated": is_repeated, + } + ) + + # Remove fields from the sample request which are not present in the runtime version of the dependency + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER + field = subfield_to_delete.get("field") + field_repeated = subfield_to_delete.get("is_repeated") + subfield = subfield_to_delete.get("subfield") + if subfield: + if field_repeated: + for i in range(0, len(request_init["database"][field])): + del request_init["database"][field][i][subfield] + else: + del request_init["database"][field][subfield] request = request_type(**request_init) # Mock the http request call within the method and fake a response. @@ -7008,15 +7190,6 @@ def test_create_database_rest_bad_request( # send a request that will satisfy transcoding request_init = {"parent": "projects/sample1"} - request_init["database"] = { - "name": "name_value", - "location_id": "location_id_value", - "type_": 1, - "concurrency_mode": 1, - "app_engine_integration_mode": 1, - "key_prefix": "key_prefix_value", - "etag": "etag_value", - } request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. 
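The many `pb_return_value` → `return_value` hunks that follow are a pure rename around one two-step idiom, shown here in isolation with a type these tests use; json_format.MessageToJson() accepts raw protobuf messages, so the proto-plus wrapper is converted first:

from google.protobuf import json_format
from google.cloud.firestore_admin_v1.types import database

return_value = database.Database(name="projects/sample1/databases/sample2")
# Convert return value to protobuf type, then serialize it for the fake response.
return_value = database.Database.pb(return_value)
json_return_value = json_format.MessageToJson(return_value)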
@@ -7127,8 +7300,9 @@ def test_get_database_rest(request_type): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - pb_return_value = database.Database.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = database.Database.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -7212,8 +7386,9 @@ def test_get_database_rest_required_fields( response_value = Response() response_value.status_code = 200 - pb_return_value = database.Database.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = database.Database.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -7336,8 +7511,9 @@ def test_get_database_rest_flattened(): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - pb_return_value = database.Database.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = database.Database.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -7398,8 +7574,9 @@ def test_list_databases_rest(request_type): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - pb_return_value = firestore_admin.ListDatabasesResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = firestore_admin.ListDatabasesResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -7473,8 +7650,9 @@ def test_list_databases_rest_required_fields( response_value = Response() response_value.status_code = 200 - pb_return_value = firestore_admin.ListDatabasesResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = firestore_admin.ListDatabasesResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -7599,8 +7777,9 @@ def test_list_databases_rest_flattened(): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - pb_return_value = firestore_admin.ListDatabasesResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = firestore_admin.ListDatabasesResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -7660,6 +7839,73 @@ def test_update_database_rest(request_type): "key_prefix": "key_prefix_value", "etag": "etag_value", } + # The version of a generated dependency at test runtime may differ from the version used during generation. 
+ # Delete any fields which are not present in the current runtime dependency + # See https://github.com/googleapis/gapic-generator-python/issues/1748 + + # Determine if the message type is proto-plus or protobuf + test_field = firestore_admin.UpdateDatabaseRequest.meta.fields["database"] + + def get_message_fields(field): + # Given a field which is a message (composite type), return a list with + # all the fields of the message. + # If the field is not a composite type, return an empty list. + message_fields = [] + + if hasattr(field, "message") and field.message: + is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR") + + if is_field_type_proto_plus_type: + message_fields = field.message.meta.fields.values() + # Add `# pragma: NO COVER` because there may not be any `*_pb2` field types + else: # pragma: NO COVER + message_fields = field.message.DESCRIPTOR.fields + return message_fields + + runtime_nested_fields = [ + (field.name, nested_field.name) + for field in get_message_fields(test_field) + for nested_field in get_message_fields(field) + ] + + subfields_not_in_runtime = [] + + # For each item in the sample request, create a list of sub fields which are not present at runtime + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for field, value in request_init["database"].items(): # pragma: NO COVER + result = None + is_repeated = False + # For repeated fields + if isinstance(value, list) and len(value): + is_repeated = True + result = value[0] + # For fields where the type is another message + if isinstance(value, dict): + result = value + + if result and hasattr(result, "keys"): + for subfield in result.keys(): + if (field, subfield) not in runtime_nested_fields: + subfields_not_in_runtime.append( + { + "field": field, + "subfield": subfield, + "is_repeated": is_repeated, + } + ) + + # Remove fields from the sample request which are not present in the runtime version of the dependency + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER + field = subfield_to_delete.get("field") + field_repeated = subfield_to_delete.get("is_repeated") + subfield = subfield_to_delete.get("subfield") + if subfield: + if field_repeated: + for i in range(0, len(request_init["database"][field])): + del request_init["database"][field][i][subfield] + else: + del request_init["database"][field][subfield] request = request_type(**request_init) # Mock the http request call within the method and fake a response. @@ -7831,15 +8077,6 @@ def test_update_database_rest_bad_request( # send a request that will satisfy transcoding request_init = {"database": {"name": "projects/sample1/databases/sample2"}} - request_init["database"] = { - "name": "projects/sample1/databases/sample2", - "location_id": "location_id_value", - "type_": 1, - "concurrency_mode": 1, - "app_engine_integration_mode": 1, - "key_prefix": "key_prefix_value", - "etag": "etag_value", - } request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. 
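Before the diff moves on to tests/unit/gapic/firestore_v1/test_firestore.py: the Response() wrapping touched by nearly every hunk follows one recipe, sketched below. status_code and _content are the requests internals the generated tests set directly; the response type is one example among many:

from requests import Response
from google.protobuf import json_format
from google.cloud.firestore_admin_v1.types import firestore_admin

return_value = firestore_admin.ListDatabasesResponse()
response_value = Response()
response_value.status_code = 200
# Convert return value to protobuf type, serialize it, and stuff the bytes
# into the Response object the mocked transport will hand back.
return_value = firestore_admin.ListDatabasesResponse.pb(return_value)
json_return_value = json_format.MessageToJson(return_value)
response_value._content = json_return_value.encode("UTF-8")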
diff --git a/tests/unit/gapic/firestore_v1/test_firestore.py b/tests/unit/gapic/firestore_v1/test_firestore.py index fab28bd69f..6529897f9b 100644 --- a/tests/unit/gapic/firestore_v1/test_firestore.py +++ b/tests/unit/gapic/firestore_v1/test_firestore.py @@ -56,7 +56,7 @@ from google.cloud.firestore_v1.types import query from google.cloud.firestore_v1.types import write as gf_write from google.cloud.location import locations_pb2 -from google.longrunning import operations_pb2 +from google.longrunning import operations_pb2 # type: ignore from google.oauth2 import service_account from google.protobuf import struct_pb2 # type: ignore from google.protobuf import timestamp_pb2 # type: ignore @@ -4079,8 +4079,9 @@ def test_get_document_rest(request_type): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - pb_return_value = document.Document.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = document.Document.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -4161,8 +4162,9 @@ def test_get_document_rest_required_fields(request_type=firestore.GetDocumentReq response_value = Response() response_value.status_code = 200 - pb_return_value = document.Document.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = document.Document.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -4305,8 +4307,9 @@ def test_list_documents_rest(request_type): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - pb_return_value = firestore.ListDocumentsResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = firestore.ListDocumentsResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -4393,8 +4396,9 @@ def test_list_documents_rest_required_fields( response_value = Response() response_value.status_code = 200 - pb_return_value = firestore.ListDocumentsResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = firestore.ListDocumentsResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -4597,6 +4601,73 @@ def test_update_document_rest(request_type): "create_time": {"seconds": 751, "nanos": 543}, "update_time": {}, } + # The version of a generated dependency at test runtime may differ from the version used during generation. 
+ # Delete any fields which are not present in the current runtime dependency + # See https://github.com/googleapis/gapic-generator-python/issues/1748 + + # Determine if the message type is proto-plus or protobuf + test_field = firestore.UpdateDocumentRequest.meta.fields["document"] + + def get_message_fields(field): + # Given a field which is a message (composite type), return a list with + # all the fields of the message. + # If the field is not a composite type, return an empty list. + message_fields = [] + + if hasattr(field, "message") and field.message: + is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR") + + if is_field_type_proto_plus_type: + message_fields = field.message.meta.fields.values() + # Add `# pragma: NO COVER` because there may not be any `*_pb2` field types + else: # pragma: NO COVER + message_fields = field.message.DESCRIPTOR.fields + return message_fields + + runtime_nested_fields = [ + (field.name, nested_field.name) + for field in get_message_fields(test_field) + for nested_field in get_message_fields(field) + ] + + subfields_not_in_runtime = [] + + # For each item in the sample request, create a list of sub fields which are not present at runtime + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for field, value in request_init["document"].items(): # pragma: NO COVER + result = None + is_repeated = False + # For repeated fields + if isinstance(value, list) and len(value): + is_repeated = True + result = value[0] + # For fields where the type is another message + if isinstance(value, dict): + result = value + + if result and hasattr(result, "keys"): + for subfield in result.keys(): + if (field, subfield) not in runtime_nested_fields: + subfields_not_in_runtime.append( + { + "field": field, + "subfield": subfield, + "is_repeated": is_repeated, + } + ) + + # Remove fields from the sample request which are not present in the runtime version of the dependency + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER + field = subfield_to_delete.get("field") + field_repeated = subfield_to_delete.get("is_repeated") + subfield = subfield_to_delete.get("subfield") + if subfield: + if field_repeated: + for i in range(0, len(request_init["document"][field])): + del request_init["document"][field][i][subfield] + else: + del request_init["document"][field][subfield] request = request_type(**request_init) # Mock the http request call within the method and fake a response. 
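For context on why the pruning above has to run before `request = request_type(**request_init)`: proto-plus construction rejects fields the runtime proto does not define (the failure mode described in gapic-generator-python#1748), so a sample request generated against a newer proto would error out. With a pruned sample like the one in this test, construction is straightforward:

from google.cloud.firestore_v1.types import firestore

request_init = {
    "document": {
        "name": "projects/sample1/databases/sample2/documents/sample3/sample4",
        "fields": {},
    }
}
request = firestore.UpdateDocumentRequest(**request_init)
assert request.document.name.endswith("sample4")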
@@ -4609,8 +4680,9 @@ def test_update_document_rest(request_type): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - pb_return_value = gf_document.Document.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = gf_document.Document.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -4689,8 +4761,9 @@ def test_update_document_rest_required_fields( response_value = Response() response_value.status_code = 200 - pb_return_value = gf_document.Document.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = gf_document.Document.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -4787,12 +4860,6 @@ def test_update_document_rest_bad_request( "name": "projects/sample1/databases/sample2/documents/sample3/sample4" } } - request_init["document"] = { - "name": "projects/sample1/databases/sample2/documents/sample3/sample4", - "fields": {}, - "create_time": {"seconds": 751, "nanos": 543}, - "update_time": {}, - } request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. @@ -4835,8 +4902,9 @@ def test_update_document_rest_flattened(): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - pb_return_value = gf_document.Document.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = gf_document.Document.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -5162,8 +5230,9 @@ def test_batch_get_documents_rest(request_type): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - pb_return_value = firestore.BatchGetDocumentsResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = firestore.BatchGetDocumentsResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) json_return_value = "[{}]".format(json_return_value) @@ -5246,8 +5315,9 @@ def test_batch_get_documents_rest_required_fields( response_value = Response() response_value.status_code = 200 - pb_return_value = firestore.BatchGetDocumentsResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = firestore.BatchGetDocumentsResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) json_return_value = "[{}]".format(json_return_value) response_value._content = json_return_value.encode("UTF-8") @@ -5384,8 +5454,9 @@ def test_begin_transaction_rest(request_type): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - pb_return_value = firestore.BeginTransactionResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = 
firestore.BeginTransactionResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -5461,8 +5532,9 @@ def test_begin_transaction_rest_required_fields( response_value = Response() response_value.status_code = 200 - pb_return_value = firestore.BeginTransactionResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = firestore.BeginTransactionResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -5585,8 +5657,9 @@ def test_begin_transaction_rest_flattened(): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - pb_return_value = firestore.BeginTransactionResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = firestore.BeginTransactionResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -5649,8 +5722,9 @@ def test_commit_rest(request_type): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - pb_return_value = firestore.CommitResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = firestore.CommitResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -5723,8 +5797,9 @@ def test_commit_rest_required_fields(request_type=firestore.CommitRequest): response_value = Response() response_value.status_code = 200 - pb_return_value = firestore.CommitResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = firestore.CommitResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -5846,8 +5921,9 @@ def test_commit_rest_flattened(): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - pb_return_value = firestore.CommitResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = firestore.CommitResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -6177,8 +6253,9 @@ def test_run_query_rest(request_type): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - pb_return_value = firestore.RunQueryResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = firestore.RunQueryResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) json_return_value = "[{}]".format(json_return_value) @@ -6260,8 +6337,9 @@ def 
test_run_query_rest_required_fields(request_type=firestore.RunQueryRequest): response_value = Response() response_value.status_code = 200 - pb_return_value = firestore.RunQueryResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = firestore.RunQueryResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) json_return_value = "[{}]".format(json_return_value) response_value._content = json_return_value.encode("UTF-8") @@ -6396,8 +6474,9 @@ def test_run_aggregation_query_rest(request_type): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - pb_return_value = firestore.RunAggregationQueryResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = firestore.RunAggregationQueryResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) json_return_value = "[{}]".format(json_return_value) @@ -6480,8 +6559,9 @@ def test_run_aggregation_query_rest_required_fields( response_value = Response() response_value.status_code = 200 - pb_return_value = firestore.RunAggregationQueryResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = firestore.RunAggregationQueryResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) json_return_value = "[{}]".format(json_return_value) response_value._content = json_return_value.encode("UTF-8") @@ -6618,8 +6698,9 @@ def test_partition_query_rest(request_type): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - pb_return_value = firestore.PartitionQueryResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = firestore.PartitionQueryResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -6695,8 +6776,9 @@ def test_partition_query_rest_required_fields( response_value = Response() response_value.status_code = 200 - pb_return_value = firestore.PartitionQueryResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = firestore.PartitionQueryResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -6907,8 +6989,9 @@ def test_list_collection_ids_rest(request_type): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - pb_return_value = firestore.ListCollectionIdsResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = firestore.ListCollectionIdsResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -6985,8 +7068,9 @@ def test_list_collection_ids_rest_required_fields( response_value = Response() response_value.status_code = 200 - pb_return_value = firestore.ListCollectionIdsResponse.pb(return_value) - 
json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = firestore.ListCollectionIdsResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -7109,8 +7193,9 @@ def test_list_collection_ids_rest_flattened(): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - pb_return_value = firestore.ListCollectionIdsResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = firestore.ListCollectionIdsResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -7230,8 +7315,9 @@ def test_batch_write_rest(request_type): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - pb_return_value = firestore.BatchWriteResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = firestore.BatchWriteResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -7304,8 +7390,9 @@ def test_batch_write_rest_required_fields(request_type=firestore.BatchWriteReque response_value = Response() response_value.status_code = 200 - pb_return_value = firestore.BatchWriteResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = firestore.BatchWriteResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -7433,6 +7520,73 @@ def test_create_document_rest(request_type): "create_time": {"seconds": 751, "nanos": 543}, "update_time": {}, } + # The version of a generated dependency at test runtime may differ from the version used during generation. + # Delete any fields which are not present in the current runtime dependency + # See https://github.com/googleapis/gapic-generator-python/issues/1748 + + # Determine if the message type is proto-plus or protobuf + test_field = firestore.CreateDocumentRequest.meta.fields["document"] + + def get_message_fields(field): + # Given a field which is a message (composite type), return a list with + # all the fields of the message. + # If the field is not a composite type, return an empty list. 
+ message_fields = [] + + if hasattr(field, "message") and field.message: + is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR") + + if is_field_type_proto_plus_type: + message_fields = field.message.meta.fields.values() + # Add `# pragma: NO COVER` because there may not be any `*_pb2` field types + else: # pragma: NO COVER + message_fields = field.message.DESCRIPTOR.fields + return message_fields + + runtime_nested_fields = [ + (field.name, nested_field.name) + for field in get_message_fields(test_field) + for nested_field in get_message_fields(field) + ] + + subfields_not_in_runtime = [] + + # For each item in the sample request, create a list of sub fields which are not present at runtime + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for field, value in request_init["document"].items(): # pragma: NO COVER + result = None + is_repeated = False + # For repeated fields + if isinstance(value, list) and len(value): + is_repeated = True + result = value[0] + # For fields where the type is another message + if isinstance(value, dict): + result = value + + if result and hasattr(result, "keys"): + for subfield in result.keys(): + if (field, subfield) not in runtime_nested_fields: + subfields_not_in_runtime.append( + { + "field": field, + "subfield": subfield, + "is_repeated": is_repeated, + } + ) + + # Remove fields from the sample request which are not present in the runtime version of the dependency + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER + field = subfield_to_delete.get("field") + field_repeated = subfield_to_delete.get("is_repeated") + subfield = subfield_to_delete.get("subfield") + if subfield: + if field_repeated: + for i in range(0, len(request_init["document"][field])): + del request_init["document"][field][i][subfield] + else: + del request_init["document"][field][subfield] request = request_type(**request_init) # Mock the http request call within the method and fake a response. 
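One note ahead of the aggregation diffs further down: the new SumAggregation and AvgAggregation classes accept either a dotted string or a FieldPath, normalizing the latter to a string, which tests/unit/v1/test_aggregation.py verifies below. The same behavior in isolation:

from google.cloud.firestore_v1.field_path import FieldPath
from google.cloud.firestore_v1.base_aggregation import AvgAggregation, SumAggregation

field_path = FieldPath("foo", "bar")  # two path segments, joined with "."
assert SumAggregation(field_path, alias="total").field_ref == "foo.bar"
assert AvgAggregation("foo.bar", alias="mean").field_ref == "foo.bar"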
@@ -7445,8 +7599,9 @@ def test_create_document_rest(request_type): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - pb_return_value = document.Document.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = document.Document.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -7533,8 +7688,9 @@ def test_create_document_rest_required_fields( response_value = Response() response_value.status_code = 200 - pb_return_value = document.Document.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) + # Convert return value to protobuf type + return_value = document.Document.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -7635,12 +7791,6 @@ def test_create_document_rest_bad_request( "parent": "projects/sample1/databases/sample2/documents/sample3", "collection_id": "sample4", } - request_init["document"] = { - "name": "name_value", - "fields": {}, - "create_time": {"seconds": 751, "nanos": 543}, - "update_time": {}, - } request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. diff --git a/tests/unit/v1/test__helpers.py b/tests/unit/v1/test__helpers.py index 91b70c48d6..0e56a84952 100644 --- a/tests/unit/v1/test__helpers.py +++ b/tests/unit/v1/test__helpers.py @@ -1897,7 +1897,6 @@ def test_documentextractorformerge_apply_merge_list_fields_w_delete(): def test_documentextractorformerge_apply_merge_list_fields_w_prefixes(): - document_data = {"a": {"b": {"c": 123}}} inst = _make_document_extractor_for_merge(document_data) @@ -1906,7 +1905,6 @@ def test_documentextractorformerge_apply_merge_list_fields_w_prefixes(): def test_documentextractorformerge_apply_merge_lists_w_missing_data_paths(): - document_data = {"write_me": "value", "ignore_me": 123} inst = _make_document_extractor_for_merge(document_data) @@ -1915,7 +1913,6 @@ def test_documentextractorformerge_apply_merge_lists_w_missing_data_paths(): def test_documentextractorformerge_apply_merge_list_fields_w_non_merge_field(): - document_data = {"write_me": "value", "ignore_me": 123} inst = _make_document_extractor_for_merge(document_data) diff --git a/tests/unit/v1/test_aggregation.py b/tests/unit/v1/test_aggregation.py index 7b07aa9afa..d19cf69e81 100644 --- a/tests/unit/v1/test_aggregation.py +++ b/tests/unit/v1/test_aggregation.py @@ -21,6 +21,8 @@ from google.cloud.firestore_v1.base_aggregation import ( CountAggregation, + SumAggregation, + AvgAggregation, AggregationResult, ) from tests.unit.v1._test_helpers import ( @@ -46,6 +48,58 @@ def test_count_aggregation_to_pb(): assert count_aggregation._to_protobuf() == expected_aggregation_query_pb +def test_sum_aggregation_w_field_path(): + """ + SumAggregation should convert FieldPath inputs into strings + """ + from google.cloud.firestore_v1.field_path import FieldPath + + field_path = FieldPath("foo", "bar") + sum_aggregation = SumAggregation(field_path, alias="total") + assert sum_aggregation.field_ref == "foo.bar" + + +def test_avg_aggregation_w_field_path(): + """ + AvgAggregation should convert FieldPath inputs into strings + """ + from google.cloud.firestore_v1.field_path import FieldPath + + field_path = 
FieldPath("foo", "bar") + avg_aggregation = AvgAggregation(field_path, alias="total") + assert avg_aggregation.field_ref == "foo.bar" + + +def test_sum_aggregation_to_pb(): + from google.cloud.firestore_v1.types import query as query_pb2 + + sum_aggregation = SumAggregation("someref", alias="total") + + expected_aggregation_query_pb = query_pb2.StructuredAggregationQuery.Aggregation() + expected_aggregation_query_pb.sum = ( + query_pb2.StructuredAggregationQuery.Aggregation.Sum() + ) + expected_aggregation_query_pb.sum.field.field_path = "someref" + + expected_aggregation_query_pb.alias = sum_aggregation.alias + assert sum_aggregation._to_protobuf() == expected_aggregation_query_pb + + +def test_avg_aggregation_to_pb(): + from google.cloud.firestore_v1.types import query as query_pb2 + + avg_aggregation = AvgAggregation("someref", alias="total") + + expected_aggregation_query_pb = query_pb2.StructuredAggregationQuery.Aggregation() + expected_aggregation_query_pb.avg = ( + query_pb2.StructuredAggregationQuery.Aggregation.Avg() + ) + expected_aggregation_query_pb.avg.field.field_path = "someref" + expected_aggregation_query_pb.alias = avg_aggregation.alias + + assert avg_aggregation._to_protobuf() == expected_aggregation_query_pb + + def test_aggregation_query_constructor(): client = make_client() parent = client.collection("dee") @@ -64,11 +118,23 @@ def test_aggregation_query_add_aggregation(): query = make_query(parent) aggregation_query = make_aggregation_query(query) aggregation_query.add_aggregation(CountAggregation(alias="all")) + aggregation_query.add_aggregation(SumAggregation("sumref", alias="sum_all")) + aggregation_query.add_aggregation(AvgAggregation("avgref", alias="avg_all")) - assert len(aggregation_query._aggregations) == 1 + assert len(aggregation_query._aggregations) == 3 assert aggregation_query._aggregations[0].alias == "all" assert isinstance(aggregation_query._aggregations[0], CountAggregation) + assert len(aggregation_query._aggregations) == 3 + assert aggregation_query._aggregations[1].alias == "sum_all" + assert aggregation_query._aggregations[1].field_ref == "sumref" + assert isinstance(aggregation_query._aggregations[1], SumAggregation) + + assert len(aggregation_query._aggregations) == 3 + assert aggregation_query._aggregations[2].alias == "avg_all" + assert aggregation_query._aggregations[2].field_ref == "avgref" + assert isinstance(aggregation_query._aggregations[2], AvgAggregation) + def test_aggregation_query_add_aggregations(): client = make_client() @@ -77,15 +143,26 @@ def test_aggregation_query_add_aggregations(): aggregation_query = make_aggregation_query(query) aggregation_query.add_aggregations( - [CountAggregation(alias="all"), CountAggregation(alias="total")] + [ + CountAggregation(alias="all"), + CountAggregation(alias="total"), + SumAggregation("sumref", alias="sum_all"), + AvgAggregation("avgref", alias="avg_all"), + ] ) - assert len(aggregation_query._aggregations) == 2 + assert len(aggregation_query._aggregations) == 4 assert aggregation_query._aggregations[0].alias == "all" assert aggregation_query._aggregations[1].alias == "total" + assert aggregation_query._aggregations[2].alias == "sum_all" + assert aggregation_query._aggregations[2].field_ref == "sumref" + assert aggregation_query._aggregations[3].alias == "avg_all" + assert aggregation_query._aggregations[3].field_ref == "avgref" assert isinstance(aggregation_query._aggregations[0], CountAggregation) assert isinstance(aggregation_query._aggregations[1], CountAggregation) + assert 
isinstance(aggregation_query._aggregations[2], SumAggregation) + assert isinstance(aggregation_query._aggregations[3], AvgAggregation) def test_aggregation_query_count(): @@ -118,6 +195,102 @@ def test_aggregation_query_count_twice(): assert isinstance(aggregation_query._aggregations[1], CountAggregation) +def test_aggregation_query_sum(): + client = make_client() + parent = client.collection("dee") + query = make_query(parent) + aggregation_query = make_aggregation_query(query) + + aggregation_query.sum("someref", alias="all") + + assert len(aggregation_query._aggregations) == 1 + assert aggregation_query._aggregations[0].alias == "all" + assert aggregation_query._aggregations[0].field_ref == "someref" + + assert isinstance(aggregation_query._aggregations[0], SumAggregation) + + +def test_aggregation_query_sum_twice(): + client = make_client() + parent = client.collection("dee") + query = make_query(parent) + aggregation_query = make_aggregation_query(query) + + aggregation_query.sum("someref", alias="all").sum("another_ref", alias="total") + + assert len(aggregation_query._aggregations) == 2 + assert aggregation_query._aggregations[0].alias == "all" + assert aggregation_query._aggregations[0].field_ref == "someref" + assert aggregation_query._aggregations[1].alias == "total" + assert aggregation_query._aggregations[1].field_ref == "another_ref" + + assert isinstance(aggregation_query._aggregations[0], SumAggregation) + assert isinstance(aggregation_query._aggregations[1], SumAggregation) + + +def test_aggregation_query_sum_no_alias(): + client = make_client() + parent = client.collection("dee") + query = make_query(parent) + aggregation_query = make_aggregation_query(query) + + aggregation_query.sum("someref") + + assert len(aggregation_query._aggregations) == 1 + assert aggregation_query._aggregations[0].alias is None + assert aggregation_query._aggregations[0].field_ref == "someref" + + assert isinstance(aggregation_query._aggregations[0], SumAggregation) + + +def test_aggregation_query_avg(): + client = make_client() + parent = client.collection("dee") + query = make_query(parent) + aggregation_query = make_aggregation_query(query) + + aggregation_query.avg("someref", alias="all") + + assert len(aggregation_query._aggregations) == 1 + assert aggregation_query._aggregations[0].alias == "all" + assert aggregation_query._aggregations[0].field_ref == "someref" + + assert isinstance(aggregation_query._aggregations[0], AvgAggregation) + + +def test_aggregation_query_avg_twice(): + client = make_client() + parent = client.collection("dee") + query = make_query(parent) + aggregation_query = make_aggregation_query(query) + + aggregation_query.avg("someref", alias="all").avg("another_ref", alias="total") + + assert len(aggregation_query._aggregations) == 2 + assert aggregation_query._aggregations[0].alias == "all" + assert aggregation_query._aggregations[0].field_ref == "someref" + assert aggregation_query._aggregations[1].alias == "total" + assert aggregation_query._aggregations[1].field_ref == "another_ref" + + assert isinstance(aggregation_query._aggregations[0], AvgAggregation) + assert isinstance(aggregation_query._aggregations[1], AvgAggregation) + + +def test_aggregation_query_avg_no_alias(): + client = make_client() + parent = client.collection("dee") + query = make_query(parent) + aggregation_query = make_aggregation_query(query) + + aggregation_query.avg("someref") + + assert len(aggregation_query._aggregations) == 1 + assert aggregation_query._aggregations[0].alias is None + assert 
aggregation_query._aggregations[0].field_ref == "someref" + + assert isinstance(aggregation_query._aggregations[0], AvgAggregation) + + def test_aggregation_query_to_protobuf(): client = make_client() parent = client.collection("dee") @@ -125,11 +298,15 @@ def test_aggregation_query_to_protobuf(): aggregation_query = make_aggregation_query(query) aggregation_query.count(alias="all") + aggregation_query.sum("someref", alias="sumall") + aggregation_query.avg("anotherref", alias="avgall") pb = aggregation_query._to_protobuf() assert pb.structured_query == parent._query()._to_protobuf() - assert len(pb.aggregations) == 1 + assert len(pb.aggregations) == 3 assert pb.aggregations[0] == aggregation_query._aggregations[0]._to_protobuf() + assert pb.aggregations[1] == aggregation_query._aggregations[1]._to_protobuf() + assert pb.aggregations[2] == aggregation_query._aggregations[2]._to_protobuf() def test_aggregation_query_prep_stream(): @@ -139,6 +316,8 @@ def test_aggregation_query_prep_stream(): aggregation_query = make_aggregation_query(query) aggregation_query.count(alias="all") + aggregation_query.sum("someref", alias="sumall") + aggregation_query.avg("anotherref", alias="avgall") request, kwargs = aggregation_query._prep_stream() @@ -163,6 +342,8 @@ def test_aggregation_query_prep_stream_with_transaction(): aggregation_query = make_aggregation_query(query) aggregation_query.count(alias="all") + aggregation_query.sum("someref", alias="sumall") + aggregation_query.avg("anotherref", alias="avgall") request, kwargs = aggregation_query._prep_stream(transaction=transaction) @@ -194,6 +375,7 @@ def _aggregation_query_get_helper(retry=None, timeout=None, read_time=None): aggregation_query.count(alias="all") aggregation_result = AggregationResult(alias="total", value=5, read_time=read_time) + response_pb = make_aggregation_query_response( [aggregation_result], read_time=read_time ) @@ -446,31 +628,38 @@ def test_aggregation_from_query(): response_pb = make_aggregation_query_response( [aggregation_result], transaction=txn_id ) - firestore_api.run_aggregation_query.return_value = iter([response_pb]) retry = None timeout = None kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) # Execute the query and check the response. - aggregation_query = query.count(alias="total") - returned = aggregation_query.get(transaction=transaction, **kwargs) - assert isinstance(returned, list) - assert len(returned) == 1 - - for result in returned: - for r in result: - assert r.alias == aggregation_result.alias - assert r.value == aggregation_result.value - - # Verify the mock call. - parent_path, _ = parent._parent_info() - - firestore_api.run_aggregation_query.assert_called_once_with( - request={ - "parent": parent_path, - "structured_aggregation_query": aggregation_query._to_protobuf(), - "transaction": txn_id, - }, - metadata=client._rpc_metadata, - **kwargs, - ) + for aggregation_query in [ + query.count(alias="total"), + query.sum("foo", alias="total"), + query.avg("foo", alias="total"), + ]: + # reset api mock + firestore_api.run_aggregation_query.reset_mock() + firestore_api.run_aggregation_query.return_value = iter([response_pb]) + # run query + returned = aggregation_query.get(transaction=transaction, **kwargs) + assert isinstance(returned, list) + assert len(returned) == 1 + + for result in returned: + for r in result: + assert r.alias == aggregation_result.alias + assert r.value == aggregation_result.value + + # Verify the mock call. 
+ parent_path, _ = parent._parent_info() + + firestore_api.run_aggregation_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_aggregation_query": aggregation_query._to_protobuf(), + "transaction": txn_id, + }, + metadata=client._rpc_metadata, + **kwargs, + ) diff --git a/tests/unit/v1/test_async_aggregation.py b/tests/unit/v1/test_async_aggregation.py index 6ed2f74b62..4ed97ddb98 100644 --- a/tests/unit/v1/test_async_aggregation.py +++ b/tests/unit/v1/test_async_aggregation.py @@ -19,6 +19,8 @@ from google.cloud.firestore_v1.base_aggregation import ( CountAggregation, + SumAggregation, + AvgAggregation, AggregationResult, ) @@ -54,11 +56,22 @@ def test_async_aggregation_query_add_aggregation(): aggregation_query = make_async_aggregation_query(query) aggregation_query.add_aggregation(CountAggregation(alias="all")) + aggregation_query.add_aggregation(SumAggregation("someref", alias="sum_all")) + aggregation_query.add_aggregation(AvgAggregation("otherref", alias="avg_all")) + + assert len(aggregation_query._aggregations) == 3 - assert len(aggregation_query._aggregations) == 1 assert aggregation_query._aggregations[0].alias == "all" assert isinstance(aggregation_query._aggregations[0], CountAggregation) + assert aggregation_query._aggregations[1].field_ref == "someref" + assert aggregation_query._aggregations[1].alias == "sum_all" + assert isinstance(aggregation_query._aggregations[1], SumAggregation) + + assert aggregation_query._aggregations[2].field_ref == "otherref" + assert aggregation_query._aggregations[2].alias == "avg_all" + assert isinstance(aggregation_query._aggregations[2], AvgAggregation) + def test_async_aggregation_query_add_aggregations(): client = make_async_client() @@ -67,15 +80,28 @@ def test_async_aggregation_query_add_aggregations(): aggregation_query = make_async_aggregation_query(query) aggregation_query.add_aggregations( - [CountAggregation(alias="all"), CountAggregation(alias="total")] + [ + CountAggregation(alias="all"), + CountAggregation(alias="total"), + SumAggregation("someref", alias="sum_all"), + AvgAggregation("otherref", alias="avg_all"), + ] ) - assert len(aggregation_query._aggregations) == 2 + assert len(aggregation_query._aggregations) == 4 assert aggregation_query._aggregations[0].alias == "all" assert aggregation_query._aggregations[1].alias == "total" + assert aggregation_query._aggregations[2].field_ref == "someref" + assert aggregation_query._aggregations[2].alias == "sum_all" + + assert aggregation_query._aggregations[3].field_ref == "otherref" + assert aggregation_query._aggregations[3].alias == "avg_all" + assert isinstance(aggregation_query._aggregations[0], CountAggregation) assert isinstance(aggregation_query._aggregations[1], CountAggregation) + assert isinstance(aggregation_query._aggregations[2], SumAggregation) + assert isinstance(aggregation_query._aggregations[3], AvgAggregation) def test_async_aggregation_query_count(): @@ -108,6 +134,104 @@ def test_async_aggregation_query_count_twice(): assert isinstance(aggregation_query._aggregations[1], CountAggregation) +def test_async_aggregation_sum(): + client = make_async_client() + parent = client.collection("dee") + query = make_async_query(parent) + aggregation_query = make_async_aggregation_query(query) + + aggregation_query.sum("someref", alias="sum_all") + + assert len(aggregation_query._aggregations) == 1 + assert aggregation_query._aggregations[0].alias == "sum_all" + assert aggregation_query._aggregations[0].field_ref == "someref" + + assert 
isinstance(aggregation_query._aggregations[0], SumAggregation) + + +def test_async_aggregation_query_sum_twice(): + client = make_async_client() + parent = client.collection("dee") + query = make_async_query(parent) + aggregation_query = make_async_aggregation_query(query) + + aggregation_query.sum("someref", alias="sum_all").sum( + "another_ref", alias="sum_total" + ) + + assert len(aggregation_query._aggregations) == 2 + assert aggregation_query._aggregations[0].alias == "sum_all" + assert aggregation_query._aggregations[0].field_ref == "someref" + assert aggregation_query._aggregations[1].alias == "sum_total" + assert aggregation_query._aggregations[1].field_ref == "another_ref" + + assert isinstance(aggregation_query._aggregations[0], SumAggregation) + assert isinstance(aggregation_query._aggregations[1], SumAggregation) + + +def test_async_aggregation_sum_no_alias(): + client = make_async_client() + parent = client.collection("dee") + query = make_async_query(parent) + aggregation_query = make_async_aggregation_query(query) + + aggregation_query.sum("someref") + + assert len(aggregation_query._aggregations) == 1 + assert aggregation_query._aggregations[0].alias is None + assert aggregation_query._aggregations[0].field_ref == "someref" + + assert isinstance(aggregation_query._aggregations[0], SumAggregation) + + +def test_aggregation_query_avg(): + client = make_async_client() + parent = client.collection("dee") + query = make_async_query(parent) + aggregation_query = make_async_aggregation_query(query) + + aggregation_query.avg("someref", alias="all") + + assert len(aggregation_query._aggregations) == 1 + assert aggregation_query._aggregations[0].alias == "all" + assert aggregation_query._aggregations[0].field_ref == "someref" + + assert isinstance(aggregation_query._aggregations[0], AvgAggregation) + + +def test_aggregation_query_avg_twice(): + client = make_async_client() + parent = client.collection("dee") + query = make_async_query(parent) + aggregation_query = make_async_aggregation_query(query) + + aggregation_query.avg("someref", alias="all").avg("another_ref", alias="total") + + assert len(aggregation_query._aggregations) == 2 + assert aggregation_query._aggregations[0].alias == "all" + assert aggregation_query._aggregations[0].field_ref == "someref" + assert aggregation_query._aggregations[1].alias == "total" + assert aggregation_query._aggregations[1].field_ref == "another_ref" + + assert isinstance(aggregation_query._aggregations[0], AvgAggregation) + assert isinstance(aggregation_query._aggregations[1], AvgAggregation) + + +def test_aggregation_query_avg_no_alias(): + client = make_async_client() + parent = client.collection("dee") + query = make_async_query(parent) + aggregation_query = make_async_aggregation_query(query) + + aggregation_query.avg("someref") + + assert len(aggregation_query._aggregations) == 1 + assert aggregation_query._aggregations[0].alias is None + assert aggregation_query._aggregations[0].field_ref == "someref" + + assert isinstance(aggregation_query._aggregations[0], AvgAggregation) + + def test_async_aggregation_query_to_protobuf(): client = make_async_client() parent = client.collection("dee") @@ -115,11 +239,15 @@ def test_async_aggregation_query_to_protobuf(): aggregation_query = make_async_aggregation_query(query) aggregation_query.count(alias="all") + aggregation_query.sum("someref", alias="sum_all") + aggregation_query.avg("someref", alias="avg_all") pb = aggregation_query._to_protobuf() assert pb.structured_query == 
parent._query()._to_protobuf() - assert len(pb.aggregations) == 1 + assert len(pb.aggregations) == 3 assert pb.aggregations[0] == aggregation_query._aggregations[0]._to_protobuf() + assert pb.aggregations[1] == aggregation_query._aggregations[1]._to_protobuf() + assert pb.aggregations[2] == aggregation_query._aggregations[2]._to_protobuf() def test_async_aggregation_query_prep_stream(): @@ -129,7 +257,8 @@ def test_async_aggregation_query_prep_stream(): aggregation_query = make_async_aggregation_query(query) aggregation_query.count(alias="all") - + aggregation_query.sum("someref", alias="sum_all") + aggregation_query.avg("someref", alias="avg_all") request, kwargs = aggregation_query._prep_stream() parent_path, _ = parent._parent_info() @@ -152,6 +281,8 @@ def test_async_aggregation_query_prep_stream_with_transaction(): query = make_async_query(parent) aggregation_query = make_async_aggregation_query(query) aggregation_query.count(alias="all") + aggregation_query.sum("someref", alias="sum_all") + aggregation_query.avg("someref", alias="avg_all") request, kwargs = aggregation_query._prep_stream(transaction=transaction) @@ -196,7 +327,6 @@ async def _async_aggregation_query_get_helper(retry=None, timeout=None, read_tim assert len(returned) == 1 for result in returned: - for r in result: assert r.alias == aggregation_result.alias assert r.value == aggregation_result.value @@ -319,31 +449,38 @@ async def test_async_aggregation_from_query(): response_pb = make_aggregation_query_response( [aggregation_result], transaction=txn_id ) - firestore_api.run_aggregation_query.return_value = AsyncIter([response_pb]) retry = None timeout = None kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - # Execute the query and check the response. - aggregation_query = query.count(alias="total") - returned = await aggregation_query.get(transaction=transaction, **kwargs) - assert isinstance(returned, list) - assert len(returned) == 1 - - for result in returned: - for r in result: - assert r.alias == aggregation_result.alias - assert r.value == aggregation_result.value - - # Verify the mock call. - parent_path, _ = parent._parent_info() - - firestore_api.run_aggregation_query.assert_called_once_with( - request={ - "parent": parent_path, - "structured_aggregation_query": aggregation_query._to_protobuf(), - "transaction": txn_id, - }, - metadata=client._rpc_metadata, - **kwargs, - ) + # Execute each aggregation query type and check the response. + for aggregation_query in [ + query.count(alias="total"), + query.sum("foo", alias="total"), + query.avg("foo", alias="total"), + ]: + # reset api mock + firestore_api.run_aggregation_query.reset_mock() + firestore_api.run_aggregation_query.return_value = AsyncIter([response_pb]) + # run query + returned = await aggregation_query.get(transaction=transaction, **kwargs) + assert isinstance(returned, list) + assert len(returned) == 1 + + for result in returned: + for r in result: + assert r.alias == aggregation_result.alias + assert r.value == aggregation_result.value + + # Verify the mock call. 
+ parent_path, _ = parent._parent_info() + + firestore_api.run_aggregation_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_aggregation_query": aggregation_query._to_protobuf(), + "transaction": txn_id, + }, + metadata=client._rpc_metadata, + **kwargs, + ) diff --git a/tests/unit/v1/test_async_collection.py b/tests/unit/v1/test_async_collection.py index 0599937cca..c5bce0ae8d 100644 --- a/tests/unit/v1/test_async_collection.py +++ b/tests/unit/v1/test_async_collection.py @@ -97,6 +97,36 @@ def test_async_collection_count(): assert aggregation_query._aggregations[0].alias == alias +def test_async_collection_sum(): + firestore_api = AsyncMock(spec=["create_document", "commit"]) + client = make_async_client() + client._firestore_api_internal = firestore_api + collection = _make_async_collection_reference("grand-parent", client=client) + + alias = "total" + field_ref = "someref" + aggregation_query = collection.sum(field_ref, alias=alias) + + assert len(aggregation_query._aggregations) == 1 + assert aggregation_query._aggregations[0].alias == alias + assert aggregation_query._aggregations[0].field_ref == field_ref + + +def test_async_collection_avg(): + firestore_api = AsyncMock(spec=["create_document", "commit"]) + client = make_async_client() + client._firestore_api_internal = firestore_api + collection = _make_async_collection_reference("grand-parent", client=client) + + alias = "total" + field_ref = "someref" + aggregation_query = collection.avg(field_ref, alias=alias) + + assert len(aggregation_query._aggregations) == 1 + assert aggregation_query._aggregations[0].alias == alias + assert aggregation_query._aggregations[0].field_ref == field_ref + + @pytest.mark.asyncio async def test_asynccollectionreference_add_auto_assigned(): from google.cloud.firestore_v1.types import document diff --git a/tests/unit/v1/test_async_query.py b/tests/unit/v1/test_async_query.py index b74a215c3f..c0f3d0d9ed 100644 --- a/tests/unit/v1/test_async_query.py +++ b/tests/unit/v1/test_async_query.py @@ -160,6 +160,64 @@ async def test_asyncquery_get_limit_to_last(): ) +def test_asyncquery_sum(): + from google.cloud.firestore_v1.field_path import FieldPath + from google.cloud.firestore_v1.base_aggregation import SumAggregation + + client = make_async_client() + parent = client.collection("dee") + field_str = "field_str" + field_path = FieldPath("foo", "bar") + query = make_async_query(parent) + # test with only field populated + sum_query = query.sum(field_str) + sum_agg = sum_query._aggregations[0] + assert isinstance(sum_agg, SumAggregation) + assert sum_agg.field_ref == field_str + assert sum_agg.alias is None + # test with field and alias populated + sum_query = query.sum(field_str, alias="alias") + sum_agg = sum_query._aggregations[0] + assert isinstance(sum_agg, SumAggregation) + assert sum_agg.field_ref == field_str + assert sum_agg.alias == "alias" + # test with field_path + sum_query = query.sum(field_path, alias="alias") + sum_agg = sum_query._aggregations[0] + assert isinstance(sum_agg, SumAggregation) + assert sum_agg.field_ref == "foo.bar" + assert sum_agg.alias == "alias" + + +def test_asyncquery_avg(): + from google.cloud.firestore_v1.field_path import FieldPath + from google.cloud.firestore_v1.base_aggregation import AvgAggregation + + client = make_async_client() + parent = client.collection("dee") + field_str = "field_str" + field_path = FieldPath("foo", "bar") + query = make_async_query(parent) + # test with only field populated + avg_query = query.avg(field_str) 
+ avg_agg = avg_query._aggregations[0] + assert isinstance(avg_agg, AvgAggregation) + assert avg_agg.field_ref == field_str + assert avg_agg.alias is None + # test with field and alias populated + avg_query = query.avg(field_str, alias="alias") + avg_agg = avg_query._aggregations[0] + assert isinstance(avg_agg, AvgAggregation) + assert avg_agg.field_ref == field_str + assert avg_agg.alias == "alias" + # test with field_path + avg_query = query.avg(field_path, alias="alias") + avg_agg = avg_query._aggregations[0] + assert isinstance(avg_agg, AvgAggregation) + assert avg_agg.field_ref == "foo.bar" + assert avg_agg.alias == "alias" + + @pytest.mark.asyncio async def test_asyncquery_chunkify_w_empty(): client = make_async_client() diff --git a/tests/unit/v1/test_async_transaction.py b/tests/unit/v1/test_async_transaction.py index 12f704a6ec..7c1ab0650d 100644 --- a/tests/unit/v1/test_async_transaction.py +++ b/tests/unit/v1/test_async_transaction.py @@ -158,7 +158,6 @@ async def test_asynctransaction__rollback_not_allowed(): with pytest.raises(ValueError) as exc_info: await transaction._rollback() - assert exc_info.value.args == (_CANT_ROLLBACK,) @@ -460,135 +459,147 @@ async def test_asynctransactional__pre_commit_retry_id_already_set_success(): @pytest.mark.asyncio -async def test_asynctransactional__pre_commit_failure(): - exc = RuntimeError("Nope not today.") - to_wrap = AsyncMock(side_effect=exc, spec=[]) +async def test_asynctransactional___call__success_first_attempt(): + to_wrap = AsyncMock(return_value=mock.sentinel.result, spec=[]) wrapped = _make_async_transactional(to_wrap) - txn_id = b"gotta-fail" + txn_id = b"whole-enchilada" transaction = _make_transaction(txn_id) - with pytest.raises(RuntimeError) as exc_info: - await wrapped._pre_commit(transaction, 10, 20) - assert exc_info.value is exc + result = await wrapped(transaction, "a", b="c") + assert result is mock.sentinel.result assert transaction._id is None assert wrapped.current_id == txn_id assert wrapped.retry_id == txn_id # Verify mocks. - to_wrap.assert_called_once_with(transaction, 10, 20) + to_wrap.assert_called_once_with(transaction, "a", b="c") firestore_api = transaction._client._firestore_api firestore_api.begin_transaction.assert_called_once_with( request={"database": transaction._client._database_string, "options": None}, metadata=transaction._client._rpc_metadata, ) - firestore_api.rollback.assert_called_once_with( + firestore_api.rollback.assert_not_called() + firestore_api.commit.assert_called_once_with( request={ "database": transaction._client._database_string, + "writes": [], "transaction": txn_id, }, metadata=transaction._client._rpc_metadata, ) - firestore_api.commit.assert_not_called() @pytest.mark.asyncio -async def test_asynctransactional__pre_commit_failure_with_rollback_failure(): +async def test_asynctransactional___call__success_second_attempt(): from google.api_core import exceptions + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.types import firestore + from google.cloud.firestore_v1.types import write - exc1 = ValueError("I will not be only failure.") - to_wrap = AsyncMock(side_effect=exc1, spec=[]) + to_wrap = AsyncMock(return_value=mock.sentinel.result, spec=[]) wrapped = _make_async_transactional(to_wrap) - txn_id = b"both-will-fail" + txn_id = b"whole-enchilada" transaction = _make_transaction(txn_id) - # Actually force the ``rollback`` to fail as well. 
- exc2 = exceptions.InternalServerError("Rollback blues.") + + # Actually force the ``commit`` to fail on first / succeed on second. + exc = exceptions.Aborted("Contention junction.") firestore_api = transaction._client._firestore_api - firestore_api.rollback.side_effect = exc2 + firestore_api.commit.side_effect = [ + exc, + firestore.CommitResponse(write_results=[write.WriteResult()]), + ] - # Try to ``_pre_commit`` - with pytest.raises(exceptions.InternalServerError) as exc_info: - await wrapped._pre_commit(transaction, a="b", c="zebra") - assert exc_info.value is exc2 + # Call the __call__-able ``wrapped``. + result = await wrapped(transaction, "a", b="c") + assert result is mock.sentinel.result assert transaction._id is None assert wrapped.current_id == txn_id assert wrapped.retry_id == txn_id # Verify mocks. - to_wrap.assert_called_once_with(transaction, a="b", c="zebra") - firestore_api.begin_transaction.assert_called_once_with( - request={"database": transaction._client._database_string, "options": None}, - metadata=transaction._client._rpc_metadata, - ) - firestore_api.rollback.assert_called_once_with( - request={ - "database": transaction._client._database_string, - "transaction": txn_id, - }, - metadata=transaction._client._rpc_metadata, - ) - firestore_api.commit.assert_not_called() - - -@pytest.mark.asyncio -async def test_asynctransactional__maybe_commit_success(): - wrapped = _make_async_transactional(mock.sentinel.callable_) - - txn_id = b"nyet" - transaction = _make_transaction(txn_id) - transaction._id = txn_id # We won't call ``begin()``. - succeeded = await wrapped._maybe_commit(transaction) - assert succeeded - - # On success, _id is reset. - assert transaction._id is None - - # Verify mocks. + wrapped_call = mock.call(transaction, "a", b="c") + assert to_wrap.mock_calls == [wrapped_call, wrapped_call] firestore_api = transaction._client._firestore_api - firestore_api.begin_transaction.assert_not_called() + db_str = transaction._client._database_string + options_ = common.TransactionOptions( + read_write=common.TransactionOptions.ReadWrite(retry_transaction=txn_id) + ) + expected_calls = [ + mock.call( + request={"database": db_str, "options": None}, + metadata=transaction._client._rpc_metadata, + ), + mock.call( + request={"database": db_str, "options": options_}, + metadata=transaction._client._rpc_metadata, + ), + ] + assert firestore_api.begin_transaction.mock_calls == expected_calls firestore_api.rollback.assert_not_called() - firestore_api.commit.assert_called_once_with( - request={ - "database": transaction._client._database_string, - "writes": [], - "transaction": txn_id, - }, + commit_call = mock.call( + request={"database": db_str, "writes": [], "transaction": txn_id}, metadata=transaction._client._rpc_metadata, ) + assert firestore_api.commit.mock_calls == [commit_call, commit_call] +@pytest.mark.parametrize("max_attempts", [1, 5]) @pytest.mark.asyncio -async def test_asynctransactional__maybe_commit_failure_read_only(): +async def test_asynctransactional___call__failure_max_attempts(max_attempts): + """ + raise a retryable error and exhaust max_attempts + """ from google.api_core import exceptions + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.async_transaction import _EXCEED_ATTEMPTS_TEMPLATE - wrapped = _make_async_transactional(mock.sentinel.callable_) + to_wrap = AsyncMock(return_value=mock.sentinel.result, spec=[]) + wrapped = _make_async_transactional(to_wrap) - txn_id = b"failed" - transaction =
_make_transaction(txn_id, read_only=True) - transaction._id = txn_id # We won't call ``begin()``. - wrapped.current_id = txn_id # We won't call ``_pre_commit()``. - wrapped.retry_id = txn_id # We won't call ``_pre_commit()``. + txn_id = b"attempt_exhaustion" + transaction = _make_transaction(txn_id, max_attempts=max_attempts) - # Actually force the ``commit`` to fail (use ABORTED, but cannot - # retry since read-only). - exc = exceptions.Aborted("Read-only did a bad.") + # Actually force the ``commit`` to fail. + exc = exceptions.Aborted("Contention just once.") firestore_api = transaction._client._firestore_api firestore_api.commit.side_effect = exc - with pytest.raises(exceptions.Aborted) as exc_info: - await wrapped._maybe_commit(transaction) - assert exc_info.value is exc + # Call the __call__-able ``wrapped``. + with pytest.raises(ValueError) as exc_info: + await wrapped(transaction, "here", there=1.5) - assert transaction._id == txn_id + err_msg = _EXCEED_ATTEMPTS_TEMPLATE.format(transaction._max_attempts) + assert exc_info.value.args == (err_msg,) + # should retain cause exception + assert exc_info.value.__cause__ == exc + + assert transaction._id is None assert wrapped.current_id == txn_id assert wrapped.retry_id == txn_id # Verify mocks. - firestore_api.begin_transaction.assert_not_called() - firestore_api.rollback.assert_not_called() - firestore_api.commit.assert_called_once_with( + assert to_wrap.call_count == max_attempts + to_wrap.assert_called_with(transaction, "here", there=1.5) + assert firestore_api.begin_transaction.call_count == max_attempts + options_ = common.TransactionOptions( + read_write=common.TransactionOptions.ReadWrite(retry_transaction=txn_id) + ) + expected_calls = [ + mock.call( + request={ + "database": transaction._client._database_string, + "options": None if i == 0 else options_, + }, + metadata=transaction._client._rpc_metadata, + ) + for i in range(max_attempts) + ] + assert firestore_api.begin_transaction.call_args_list == expected_calls + assert firestore_api.commit.call_count == max_attempts + firestore_api.commit.assert_called_with( request={ "database": transaction._client._database_string, "writes": [], @@ -596,105 +607,63 @@ async def test_asynctransactional__maybe_commit_failure_read_only(): }, metadata=transaction._client._rpc_metadata, ) - - -@pytest.mark.asyncio -async def test_asynctransactional__maybe_commit_failure_can_retry(): - from google.api_core import exceptions - - wrapped = _make_async_transactional(mock.sentinel.callable_) - - txn_id = b"failed-but-retry" - transaction = _make_transaction(txn_id) - transaction._id = txn_id # We won't call ``begin()``. - wrapped.current_id = txn_id # We won't call ``_pre_commit()``. - wrapped.retry_id = txn_id # We won't call ``_pre_commit()``. - - # Actually force the ``commit`` to fail. - exc = exceptions.Aborted("Read-write did a bad.") - firestore_api = transaction._client._firestore_api - firestore_api.commit.side_effect = exc - - succeeded = await wrapped._maybe_commit(transaction) - assert not succeeded - - assert transaction._id == txn_id - assert wrapped.current_id == txn_id - assert wrapped.retry_id == txn_id - - # Verify mocks. 
- firestore_api.begin_transaction.assert_not_called() - firestore_api.rollback.assert_not_called() - firestore_api.commit.assert_called_once_with( + firestore_api.rollback.assert_called_once_with( request={ "database": transaction._client._database_string, - "writes": [], "transaction": txn_id, }, metadata=transaction._client._rpc_metadata, ) +@pytest.mark.parametrize("max_attempts", [1, 5]) @pytest.mark.asyncio -async def test_asynctransactional__maybe_commit_failure_cannot_retry(): +async def test_asynctransactional___call__failure_readonly(max_attempts): + """ + readonly transaction should never retry + """ from google.api_core import exceptions + from google.cloud.firestore_v1.types import common - wrapped = _make_async_transactional(mock.sentinel.callable_) + to_wrap = AsyncMock(return_value=mock.sentinel.result, spec=[]) + wrapped = _make_async_transactional(to_wrap) - txn_id = b"failed-but-not-retryable" - transaction = _make_transaction(txn_id) - transaction._id = txn_id # We won't call ``begin()``. - wrapped.current_id = txn_id # We won't call ``_pre_commit()``. - wrapped.retry_id = txn_id # We won't call ``_pre_commit()``. + txn_id = b"read_only_fail" + transaction = _make_transaction(txn_id, max_attempts=max_attempts, read_only=True) # Actually force the ``commit`` to fail. - exc = exceptions.InternalServerError("Real bad thing") + exc = exceptions.Aborted("Contention just once.") firestore_api = transaction._client._firestore_api firestore_api.commit.side_effect = exc - with pytest.raises(exceptions.InternalServerError) as exc_info: - await wrapped._maybe_commit(transaction) - assert exc_info.value is exc + # Call the __call__-able ``wrapped``. + with pytest.raises(exceptions.Aborted) as exc_info: + await wrapped(transaction, "here", there=1.5) - assert transaction._id == txn_id + assert exc_info.value == exc + + assert transaction._id is None assert wrapped.current_id == txn_id assert wrapped.retry_id == txn_id # Verify mocks. - firestore_api.begin_transaction.assert_not_called() - firestore_api.rollback.assert_not_called() - firestore_api.commit.assert_called_once_with( + to_wrap.assert_called_once_with(transaction, "here", there=1.5) + firestore_api.begin_transaction.assert_called_once_with( request={ "database": transaction._client._database_string, - "writes": [], - "transaction": txn_id, + "options": common.TransactionOptions( + read_only=common.TransactionOptions.ReadOnly() + ), }, metadata=transaction._client._rpc_metadata, ) - - -@pytest.mark.asyncio -async def test_asynctransactional___call__success_first_attempt(): - to_wrap = AsyncMock(return_value=mock.sentinel.result, spec=[]) - wrapped = _make_async_transactional(to_wrap) - - txn_id = b"whole-enchilada" - transaction = _make_transaction(txn_id) - result = await wrapped(transaction, "a", b="c") - assert result is mock.sentinel.result - - assert transaction._id is None - assert wrapped.current_id == txn_id - assert wrapped.retry_id == txn_id - - # Verify mocks. 
- to_wrap.assert_called_once_with(transaction, "a", b="c") - firestore_api = transaction._client._firestore_api - firestore_api.begin_transaction.assert_called_once_with( - request={"database": transaction._client._database_string, "options": None}, + firestore_api.rollback.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "transaction": txn_id, + }, metadata=transaction._client._rpc_metadata, ) - firestore_api.rollback.assert_not_called() firestore_api.commit.assert_called_once_with( request={ "database": transaction._client._database_string, @@ -705,93 +674,101 @@ async def test_asynctransactional___call__success_first_attempt(): ) +@pytest.mark.parametrize("max_attempts", [1, 5]) @pytest.mark.asyncio -async def test_asynctransactional___call__success_second_attempt(): +async def test_asynctransactional___call__failure_with_non_retryable(max_attempts): + """ + call fails due to an exception that is not retryable. + Should rollback raise immediately + """ from google.api_core import exceptions - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.types import firestore - from google.cloud.firestore_v1.types import write to_wrap = AsyncMock(return_value=mock.sentinel.result, spec=[]) wrapped = _make_async_transactional(to_wrap) - txn_id = b"whole-enchilada" - transaction = _make_transaction(txn_id) + txn_id = b"non_retryable" + transaction = _make_transaction(txn_id, max_attempts=max_attempts) - # Actually force the ``commit`` to fail on first / succeed on second. - exc = exceptions.Aborted("Contention junction.") + # Actually force the ``commit`` to fail. + exc = exceptions.InvalidArgument("non retryable") firestore_api = transaction._client._firestore_api - firestore_api.commit.side_effect = [ - exc, - firestore.CommitResponse(write_results=[write.WriteResult()]), - ] + firestore_api.commit.side_effect = exc # Call the __call__-able ``wrapped``. - result = await wrapped(transaction, "a", b="c") - assert result is mock.sentinel.result + with pytest.raises(exceptions.InvalidArgument) as exc_info: + await wrapped(transaction, "here", there=1.5) + + assert exc_info.value == exc assert transaction._id is None assert wrapped.current_id == txn_id - assert wrapped.retry_id == txn_id # Verify mocks. 
- wrapped_call = mock.call(transaction, "a", b="c") - assert to_wrap.mock_calls == [wrapped_call, wrapped_call] - firestore_api = transaction._client._firestore_api - db_str = transaction._client._database_string - options_ = common.TransactionOptions( - read_write=common.TransactionOptions.ReadWrite(retry_transaction=txn_id) + to_wrap.assert_called_once_with(transaction, "here", there=1.5) + firestore_api.begin_transaction.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "options": None, + }, + metadata=transaction._client._rpc_metadata, ) - expected_calls = [ - mock.call( - request={"database": db_str, "options": None}, - metadata=transaction._client._rpc_metadata, - ), - mock.call( - request={"database": db_str, "options": options_}, - metadata=transaction._client._rpc_metadata, - ), - ] - assert firestore_api.begin_transaction.mock_calls == expected_calls - firestore_api.rollback.assert_not_called() - commit_call = mock.call( - request={"database": db_str, "writes": [], "transaction": txn_id}, + firestore_api.rollback.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "transaction": txn_id, + }, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.commit.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "writes": [], + "transaction": txn_id, + }, metadata=transaction._client._rpc_metadata, ) - assert firestore_api.commit.mock_calls == [commit_call, commit_call] @pytest.mark.asyncio -async def test_asynctransactional___call__failure(): +async def test_asynctransactional___call__failure_with_rollback_failure(): + """ + Test a second failure raised during rollback; + it should maintain the first failure as __context__ + """ from google.api_core import exceptions - from google.cloud.firestore_v1.async_transaction import _EXCEED_ATTEMPTS_TEMPLATE to_wrap = AsyncMock(return_value=mock.sentinel.result, spec=[]) wrapped = _make_async_transactional(to_wrap) - txn_id = b"only-one-shot" + txn_id = b"non_retryable" transaction = _make_transaction(txn_id, max_attempts=1) # Actually force the ``commit`` to fail. - exc = exceptions.Aborted("Contention just once.") + exc = exceptions.InvalidArgument("first error") firestore_api = transaction._client._firestore_api firestore_api.commit.side_effect = exc + # also force a second error on rollback + rb_exc = exceptions.InternalServerError("second error") + firestore_api.rollback.side_effect = rb_exc # Call the __call__-able ``wrapped``. + # should raise second error with first error as __context__ with pytest.raises(exceptions.InternalServerError) as exc_info: await wrapped(transaction, "here", there=1.5) - err_msg = _EXCEED_ATTEMPTS_TEMPLATE.format(transaction._max_attempts) - assert exc_info.value.args == (err_msg,) + assert exc_info.value == rb_exc + assert exc_info.value.__context__ == exc assert transaction._id is None assert wrapped.current_id == txn_id - assert wrapped.retry_id == txn_id # Verify mocks.
to_wrap.assert_called_once_with(transaction, "here", there=1.5) firestore_api.begin_transaction.assert_called_once_with( - request={"database": transaction._client._database_string, "options": None}, + request={ + "database": transaction._client._database_string, + "options": None, + }, metadata=transaction._client._rpc_metadata, ) firestore_api.rollback.assert_called_once_with( diff --git a/tests/unit/v1/test_base_query.py b/tests/unit/v1/test_base_query.py index 4b8093f1a7..8075e71b05 100644 --- a/tests/unit/v1/test_base_query.py +++ b/tests/unit/v1/test_base_query.py @@ -760,7 +760,6 @@ def test_basequery_end_at(): def test_basequery_where_filter_keyword_arg(): - from google.cloud.firestore_v1.types import StructuredQuery from google.cloud.firestore_v1.types import document from google.cloud.firestore_v1.types import query @@ -862,7 +861,6 @@ def test_basequery_where_filter_keyword_arg(): def test_basequery_where_cannot_pass_both_positional_and_keyword_filter_arg(): - from google.cloud.firestore_v1.base_query import FieldFilter field_path_1 = "x.y" diff --git a/tests/unit/v1/test_bundle.py b/tests/unit/v1/test_bundle.py index 8508a79b21..15ee737581 100644 --- a/tests/unit/v1/test_bundle.py +++ b/tests/unit/v1/test_bundle.py @@ -28,7 +28,6 @@ class _CollectionQueryMixin: - # Path to each document where we don't specify custom collection names or # document Ids doc_key: str = ( diff --git a/tests/unit/v1/test_collection.py b/tests/unit/v1/test_collection.py index 04e6e21985..f3bc099b97 100644 --- a/tests/unit/v1/test_collection.py +++ b/tests/unit/v1/test_collection.py @@ -65,7 +65,6 @@ def test_collection_aggregation_query(): def test_collection_count(): - collection_id1 = "rooms" document_id = "roomA" collection_id2 = "messages" @@ -82,6 +81,44 @@ def test_collection_count(): assert aggregation_query._aggregations[0].alias == alias +def test_collection_sum(): + collection_id1 = "rooms" + document_id = "roomA" + collection_id2 = "messages" + client = mock.sentinel.client + + collection = _make_collection_reference( + collection_id1, document_id, collection_id2, client=client + ) + + alias = "total" + field_ref = "someref" + aggregation_query = collection.sum(field_ref, alias=alias) + + assert len(aggregation_query._aggregations) == 1 + assert aggregation_query._aggregations[0].alias == alias + assert aggregation_query._aggregations[0].field_ref == field_ref + + +def test_collection_avg(): + collection_id1 = "rooms" + document_id = "roomA" + collection_id2 = "messages" + client = mock.sentinel.client + + collection = _make_collection_reference( + collection_id1, document_id, collection_id2, client=client + ) + + alias = "total" + field_ref = "someref" + aggregation_query = collection.avg(field_ref, alias=alias) + + assert len(aggregation_query._aggregations) == 1 + assert aggregation_query._aggregations[0].alias == alias + assert aggregation_query._aggregations[0].field_ref == field_ref + + def test_constructor(): collection_id1 = "rooms" document_id = "roomA" @@ -339,7 +376,6 @@ def test_get_w_retry_timeout(query_class): @mock.patch("google.cloud.firestore_v1.query.Query", autospec=True) def test_get_with_transaction(query_class): - collection = _make_collection_reference("collection") transaction = mock.sentinel.txn get_response = collection.get(transaction=transaction) diff --git a/tests/unit/v1/test_cross_language.py b/tests/unit/v1/test_cross_language.py index 2c5823fc9c..44f7985f1c 100644 --- a/tests/unit/v1/test_cross_language.py +++ b/tests/unit/v1/test_cross_language.py @@ -465,7 
+465,6 @@ def parse_query(testcase): query = collection for clause in testcase.clauses: - if "select" in clause: field_paths = [ ".".join(field_path.field) for field_path in clause.select.fields diff --git a/tests/unit/v1/test_order.py b/tests/unit/v1/test_order.py index 1287e77a08..8abb295507 100644 --- a/tests/unit/v1/test_order.py +++ b/tests/unit/v1/test_order.py @@ -136,7 +136,6 @@ def test_order_compare_across_heterogenous_values(): for left in groups[i]: for j in range(len(groups)): for right in groups[j]: - expected = Order._compare_to(i, j) assert target.compare(left, right) == expected diff --git a/tests/unit/v1/test_query.py b/tests/unit/v1/test_query.py index ad972aa763..a7f2e60162 100644 --- a/tests/unit/v1/test_query.py +++ b/tests/unit/v1/test_query.py @@ -152,6 +152,66 @@ def test_query_get_limit_to_last(database): ) +@pytest.mark.parametrize("database", [None, "somedb"]) +def test_query_sum(database): + from google.cloud.firestore_v1.field_path import FieldPath + from google.cloud.firestore_v1.base_aggregation import SumAggregation + + client = make_client(database=database) + parent = client.collection("dee") + field_str = "field_str" + field_path = FieldPath("foo", "bar") + query = make_query(parent) + # test with only field populated + sum_query = query.sum(field_str) + sum_agg = sum_query._aggregations[0] + assert isinstance(sum_agg, SumAggregation) + assert sum_agg.field_ref == field_str + assert sum_agg.alias is None + # test with field and alias populated + sum_query = query.sum(field_str, alias="alias") + sum_agg = sum_query._aggregations[0] + assert isinstance(sum_agg, SumAggregation) + assert sum_agg.field_ref == field_str + assert sum_agg.alias == "alias" + # test with field_path + sum_query = query.sum(field_path, alias="alias") + sum_agg = sum_query._aggregations[0] + assert isinstance(sum_agg, SumAggregation) + assert sum_agg.field_ref == "foo.bar" + assert sum_agg.alias == "alias" + + +@pytest.mark.parametrize("database", [None, "somedb"]) +def test_query_avg(database): + from google.cloud.firestore_v1.field_path import FieldPath + from google.cloud.firestore_v1.base_aggregation import AvgAggregation + + client = make_client(database=database) + parent = client.collection("dee") + field_str = "field_str" + field_path = FieldPath("foo", "bar") + query = make_query(parent) + # test with only field populated + avg_query = query.avg(field_str) + avg_agg = avg_query._aggregations[0] + assert isinstance(avg_agg, AvgAggregation) + assert avg_agg.field_ref == field_str + assert avg_agg.alias is None + # test with field and alias populated + avg_query = query.avg(field_str, alias="alias") + avg_agg = avg_query._aggregations[0] + assert isinstance(avg_agg, AvgAggregation) + assert avg_agg.field_ref == field_str + assert avg_agg.alias == "alias" + # test with field_path + avg_query = query.avg(field_path, alias="alias") + avg_agg = avg_query._aggregations[0] + assert isinstance(avg_agg, AvgAggregation) + assert avg_agg.field_ref == "foo.bar" + assert avg_agg.alias == "alias" + + @pytest.mark.parametrize("database", [None, "somedb"]) def test_query_chunkify_w_empty(database): client = make_client(database=database) diff --git a/tests/unit/v1/test_transaction.py b/tests/unit/v1/test_transaction.py index 27366b276e..26bb5cc9ca 100644 --- a/tests/unit/v1/test_transaction.py +++ b/tests/unit/v1/test_transaction.py @@ -464,135 +464,149 @@ def test__transactional__pre_commit_retry_id_already_set_success(database): @pytest.mark.parametrize("database", [None, "somedb"]) -def 
test__transactional__pre_commit_failure(database): - exc = RuntimeError("Nope not today.") - to_wrap = mock.Mock(side_effect=exc, spec=[]) +def test__transactional___call__success_first_attempt(database): + to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) wrapped = _make__transactional(to_wrap) - txn_id = b"gotta-fail" + txn_id = b"whole-enchilada" transaction = _make_transaction_pb(txn_id, database=database) - with pytest.raises(RuntimeError) as exc_info: - wrapped._pre_commit(transaction, 10, 20) - assert exc_info.value is exc + result = wrapped(transaction, "a", b="c") + assert result is mock.sentinel.result assert transaction._id is None assert wrapped.current_id == txn_id assert wrapped.retry_id == txn_id # Verify mocks. - to_wrap.assert_called_once_with(transaction, 10, 20) + to_wrap.assert_called_once_with(transaction, "a", b="c") firestore_api = transaction._client._firestore_api firestore_api.begin_transaction.assert_called_once_with( request={"database": transaction._client._database_string, "options": None}, metadata=transaction._client._rpc_metadata, ) - firestore_api.rollback.assert_called_once_with( + firestore_api.rollback.assert_not_called() + firestore_api.commit.assert_called_once_with( request={ "database": transaction._client._database_string, + "writes": [], "transaction": txn_id, }, metadata=transaction._client._rpc_metadata, ) - firestore_api.commit.assert_not_called() @pytest.mark.parametrize("database", [None, "somedb"]) -def test__transactional__pre_commit_failure_with_rollback_failure(database): +def test__transactional___call__success_second_attempt(database): from google.api_core import exceptions + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.types import firestore + from google.cloud.firestore_v1.types import write - exc1 = ValueError("I will not be only failure.") - to_wrap = mock.Mock(side_effect=exc1, spec=[]) + to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) wrapped = _make__transactional(to_wrap) - txn_id = b"both-will-fail" + txn_id = b"whole-enchilada" transaction = _make_transaction_pb(txn_id, database=database) - # Actually force the ``rollback`` to fail as well. - exc2 = exceptions.InternalServerError("Rollback blues.") + + # Actually force the ``commit`` to fail on first / succeed on second. + exc = exceptions.Aborted("Contention junction.") firestore_api = transaction._client._firestore_api - firestore_api.rollback.side_effect = exc2 + firestore_api.commit.side_effect = [ + exc, + firestore.CommitResponse(write_results=[write.WriteResult()]), + ] - # Try to ``_pre_commit`` - with pytest.raises(exceptions.InternalServerError) as exc_info: - wrapped._pre_commit(transaction, a="b", c="zebra") - assert exc_info.value is exc2 + # Call the __call__-able ``wrapped``. + result = wrapped(transaction, "a", b="c") + assert result is mock.sentinel.result assert transaction._id is None assert wrapped.current_id == txn_id assert wrapped.retry_id == txn_id # Verify mocks. 
- to_wrap.assert_called_once_with(transaction, a="b", c="zebra") - firestore_api.begin_transaction.assert_called_once_with( - request={"database": transaction._client._database_string, "options": None}, - metadata=transaction._client._rpc_metadata, - ) - firestore_api.rollback.assert_called_once_with( - request={ - "database": transaction._client._database_string, - "transaction": txn_id, - }, - metadata=transaction._client._rpc_metadata, - ) - firestore_api.commit.assert_not_called() - - -@pytest.mark.parametrize("database", [None, "somedb"]) -def test__transactional__maybe_commit_success(database): - wrapped = _make__transactional(mock.sentinel.callable_) - - txn_id = b"nyet" - transaction = _make_transaction_pb(txn_id, database=database) - transaction._id = txn_id # We won't call ``begin()``. - succeeded = wrapped._maybe_commit(transaction) - assert succeeded - - # On success, _id is reset. - assert transaction._id is None - - # Verify mocks. + wrapped_call = mock.call(transaction, "a", b="c") + assert to_wrap.mock_calls, [wrapped_call == wrapped_call] firestore_api = transaction._client._firestore_api - firestore_api.begin_transaction.assert_not_called() + db_str = transaction._client._database_string + options_ = common.TransactionOptions( + read_write=common.TransactionOptions.ReadWrite(retry_transaction=txn_id) + ) + expected_calls = [ + mock.call( + request={"database": db_str, "options": None}, + metadata=transaction._client._rpc_metadata, + ), + mock.call( + request={"database": db_str, "options": options_}, + metadata=transaction._client._rpc_metadata, + ), + ] + assert firestore_api.begin_transaction.mock_calls == expected_calls firestore_api.rollback.assert_not_called() - firestore_api.commit.assert_called_once_with( - request={ - "database": transaction._client._database_string, - "writes": [], - "transaction": txn_id, - }, + commit_call = mock.call( + request={"database": db_str, "writes": [], "transaction": txn_id}, metadata=transaction._client._rpc_metadata, ) + assert firestore_api.commit.mock_calls == [commit_call, commit_call] @pytest.mark.parametrize("database", [None, "somedb"]) -def test__transactional__maybe_commit_failure_read_only(database): +@pytest.mark.parametrize("max_attempts", [1, 5]) +def test_transactional___call__failure_max_attempts(database, max_attempts): + """ + rasie retryable error and exhause max_attempts + """ from google.api_core import exceptions + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.transaction import _EXCEED_ATTEMPTS_TEMPLATE - wrapped = _make__transactional(mock.sentinel.callable_) + to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) + wrapped = _make__transactional(to_wrap) - txn_id = b"failed" - transaction = _make_transaction_pb(txn_id, read_only=True, database=database) - transaction._id = txn_id # We won't call ``begin()``. - wrapped.current_id = txn_id # We won't call ``_pre_commit()``. - wrapped.retry_id = txn_id # We won't call ``_pre_commit()``. + txn_id = b"attempt_exhaustion" + transaction = _make_transaction_pb( + txn_id, database=database, max_attempts=max_attempts + ) - # Actually force the ``commit`` to fail (use ABORTED, but cannot - # retry since read-only). - exc = exceptions.Aborted("Read-only did a bad.") + # Actually force the ``commit`` to fail. 
+ exc = exceptions.Aborted("Contention just once.") firestore_api = transaction._client._firestore_api firestore_api.commit.side_effect = exc - with pytest.raises(exceptions.Aborted) as exc_info: - wrapped._maybe_commit(transaction) - assert exc_info.value is exc + # Call the __call__-able ``wrapped``. + with pytest.raises(ValueError) as exc_info: + wrapped(transaction, "here", there=1.5) - assert transaction._id == txn_id + err_msg = _EXCEED_ATTEMPTS_TEMPLATE.format(transaction._max_attempts) + assert exc_info.value.args == (err_msg,) + # should retain cause exception + assert exc_info.value.__cause__ == exc + + assert transaction._id is None assert wrapped.current_id == txn_id assert wrapped.retry_id == txn_id # Verify mocks. - firestore_api.begin_transaction.assert_not_called() - firestore_api.rollback.assert_not_called() - firestore_api.commit.assert_called_once_with( + assert to_wrap.call_count == max_attempts + to_wrap.assert_called_with(transaction, "here", there=1.5) + assert firestore_api.begin_transaction.call_count == max_attempts + options_ = common.TransactionOptions( + read_write=common.TransactionOptions.ReadWrite(retry_transaction=txn_id) + ) + expected_calls = [ + mock.call( + request={ + "database": transaction._client._database_string, + "options": None if i == 0 else options_, + }, + metadata=transaction._client._rpc_metadata, + ) + for i in range(max_attempts) + ] + assert firestore_api.begin_transaction.call_args_list == expected_calls + assert firestore_api.commit.call_count == max_attempts + firestore_api.commit.assert_called_with( request={ "database": transaction._client._database_string, "writes": [], @@ -600,39 +614,9 @@ def test__transactional__maybe_commit_failure_read_only(database): }, metadata=transaction._client._rpc_metadata, ) - - -@pytest.mark.parametrize("database", [None, "somedb"]) -def test__transactional__maybe_commit_failure_can_retry(database): - from google.api_core import exceptions - - wrapped = _make__transactional(mock.sentinel.callable_) - - txn_id = b"failed-but-retry" - transaction = _make_transaction_pb(txn_id, database=database) - transaction._id = txn_id # We won't call ``begin()``. - wrapped.current_id = txn_id # We won't call ``_pre_commit()``. - wrapped.retry_id = txn_id # We won't call ``_pre_commit()``. - - # Actually force the ``commit`` to fail. - exc = exceptions.Aborted("Read-write did a bad.") - firestore_api = transaction._client._firestore_api - firestore_api.commit.side_effect = exc - - succeeded = wrapped._maybe_commit(transaction) - assert not succeeded - - assert transaction._id == txn_id - assert wrapped.current_id == txn_id - assert wrapped.retry_id == txn_id - - # Verify mocks. 
- firestore_api.begin_transaction.assert_not_called() - firestore_api.rollback.assert_not_called() - firestore_api.commit.assert_called_once_with( + firestore_api.rollback.assert_called_once_with( request={ "database": transaction._client._database_string, - "writes": [], "transaction": txn_id, }, metadata=transaction._client._rpc_metadata, @@ -640,65 +624,55 @@ def test__transactional__maybe_commit_failure_can_retry(database): @pytest.mark.parametrize("database", [None, "somedb"]) -def test__transactional__maybe_commit_failure_cannot_retry(database): +@pytest.mark.parametrize("max_attempts", [1, 5]) +def test_transactional___call__failure_readonly(database, max_attempts): + """ + readonly transaction should never retry + """ from google.api_core import exceptions + from google.cloud.firestore_v1.types import common - wrapped = _make__transactional(mock.sentinel.callable_) + to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) + wrapped = _make__transactional(to_wrap) - txn_id = b"failed-but-not-retryable" - transaction = _make_transaction_pb(txn_id, database=database) - transaction._id = txn_id # We won't call ``begin()``. - wrapped.current_id = txn_id # We won't call ``_pre_commit()``. - wrapped.retry_id = txn_id # We won't call ``_pre_commit()``. + txn_id = b"read_only_fail" + transaction = _make_transaction_pb( + txn_id, database=database, max_attempts=max_attempts, read_only=True + ) # Actually force the ``commit`` to fail. - exc = exceptions.InternalServerError("Real bad thing") + exc = exceptions.Aborted("Contention just once.") firestore_api = transaction._client._firestore_api firestore_api.commit.side_effect = exc - with pytest.raises(exceptions.InternalServerError) as exc_info: - wrapped._maybe_commit(transaction) - assert exc_info.value is exc + # Call the __call__-able ``wrapped``. + with pytest.raises(exceptions.Aborted) as exc_info: + wrapped(transaction, "here", there=1.5) - assert transaction._id == txn_id + assert exc_info.value == exc + + assert transaction._id is None assert wrapped.current_id == txn_id assert wrapped.retry_id == txn_id # Verify mocks. - firestore_api.begin_transaction.assert_not_called() - firestore_api.rollback.assert_not_called() - firestore_api.commit.assert_called_once_with( + to_wrap.assert_called_once_with(transaction, "here", there=1.5) + firestore_api.begin_transaction.assert_called_once_with( request={ "database": transaction._client._database_string, - "writes": [], - "transaction": txn_id, + "options": common.TransactionOptions( + read_only=common.TransactionOptions.ReadOnly() + ), }, metadata=transaction._client._rpc_metadata, ) - - -@pytest.mark.parametrize("database", [None, "somedb"]) -def test__transactional___call__success_first_attempt(database): - to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) - wrapped = _make__transactional(to_wrap) - - txn_id = b"whole-enchilada" - transaction = _make_transaction_pb(txn_id, database=database) - result = wrapped(transaction, "a", b="c") - assert result is mock.sentinel.result - - assert transaction._id is None - assert wrapped.current_id == txn_id - assert wrapped.retry_id == txn_id - - # Verify mocks. 
- to_wrap.assert_called_once_with(transaction, "a", b="c") - firestore_api = transaction._client._firestore_api - firestore_api.begin_transaction.assert_called_once_with( - request={"database": transaction._client._database_string, "options": None}, + firestore_api.rollback.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "transaction": txn_id, + }, metadata=transaction._client._rpc_metadata, ) - firestore_api.rollback.assert_not_called() firestore_api.commit.assert_called_once_with( request={ "database": transaction._client._database_string, @@ -710,92 +684,102 @@ def test__transactional___call__success_first_attempt(database): @pytest.mark.parametrize("database", [None, "somedb"]) -def test__transactional___call__success_second_attempt(database): +@pytest.mark.parametrize("max_attempts", [1, 5]) +def test_transactional___call__failure_with_non_retryable(database, max_attempts): + """ + call fails due to an exception that is not retryable. + Should rollback raise immediately + """ from google.api_core import exceptions - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.types import firestore - from google.cloud.firestore_v1.types import write to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) wrapped = _make__transactional(to_wrap) - txn_id = b"whole-enchilada" - transaction = _make_transaction_pb(txn_id, database=database) + txn_id = b"non_retryable" + transaction = _make_transaction_pb( + txn_id, database=database, max_attempts=max_attempts + ) - # Actually force the ``commit`` to fail on first / succeed on second. - exc = exceptions.Aborted("Contention junction.") + # Actually force the ``commit`` to fail. + exc = exceptions.InvalidArgument("non retryable") firestore_api = transaction._client._firestore_api - firestore_api.commit.side_effect = [ - exc, - firestore.CommitResponse(write_results=[write.WriteResult()]), - ] + firestore_api.commit.side_effect = exc # Call the __call__-able ``wrapped``. - result = wrapped(transaction, "a", b="c") - assert result is mock.sentinel.result + with pytest.raises(exceptions.InvalidArgument) as exc_info: + wrapped(transaction, "here", there=1.5) + + assert exc_info.value == exc assert transaction._id is None assert wrapped.current_id == txn_id - assert wrapped.retry_id == txn_id # Verify mocks. 
- wrapped_call = mock.call(transaction, "a", b="c") - assert to_wrap.mock_calls, [wrapped_call == wrapped_call] - firestore_api = transaction._client._firestore_api - db_str = transaction._client._database_string - options_ = common.TransactionOptions( - read_write=common.TransactionOptions.ReadWrite(retry_transaction=txn_id) + to_wrap.assert_called_once_with(transaction, "here", there=1.5) + firestore_api.begin_transaction.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "options": None, + }, + metadata=transaction._client._rpc_metadata, ) - expected_calls = [ - mock.call( - request={"database": db_str, "options": None}, - metadata=transaction._client._rpc_metadata, - ), - mock.call( - request={"database": db_str, "options": options_}, - metadata=transaction._client._rpc_metadata, - ), - ] - assert firestore_api.begin_transaction.mock_calls == expected_calls - firestore_api.rollback.assert_not_called() - commit_call = mock.call( - request={"database": db_str, "writes": [], "transaction": txn_id}, + firestore_api.rollback.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "transaction": txn_id, + }, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.commit.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "writes": [], + "transaction": txn_id, + }, metadata=transaction._client._rpc_metadata, ) - assert firestore_api.commit.mock_calls == [commit_call, commit_call] @pytest.mark.parametrize("database", [None, "somedb"]) -def test__transactional___call__failure(database): +def test_transactional___call__failure_with_rollback_failure(database): + """ + Test a second failure raised during rollback; + it should maintain the first failure as __context__ + """ from google.api_core import exceptions - from google.cloud.firestore_v1.base_transaction import _EXCEED_ATTEMPTS_TEMPLATE to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) wrapped = _make__transactional(to_wrap) - txn_id = b"only-one-shot" - transaction = _make_transaction_pb(txn_id, max_attempts=1, database=database) + txn_id = b"non_retryable" + transaction = _make_transaction_pb(txn_id, database=database, max_attempts=1) # Actually force the ``commit`` to fail. - exc = exceptions.Aborted("Contention just once.") + exc = exceptions.InvalidArgument("first error") firestore_api = transaction._client._firestore_api firestore_api.commit.side_effect = exc + # also force a second error on rollback + rb_exc = exceptions.InternalServerError("second error") + firestore_api.rollback.side_effect = rb_exc # Call the __call__-able ``wrapped``. + # should raise second error with first error as __context__ with pytest.raises(exceptions.InternalServerError) as exc_info: wrapped(transaction, "here", there=1.5) - err_msg = _EXCEED_ATTEMPTS_TEMPLATE.format(transaction._max_attempts) - assert exc_info.value.args == (err_msg,) + assert exc_info.value == rb_exc + assert exc_info.value.__context__ == exc assert transaction._id is None assert wrapped.current_id == txn_id - assert wrapped.retry_id == txn_id # Verify mocks.
to_wrap.assert_called_once_with(transaction, "here", there=1.5) firestore_api.begin_transaction.assert_called_once_with( - request={"database": transaction._client._database_string, "options": None}, + request={ + "database": transaction._client._database_string, + "options": None, + }, metadata=transaction._client._rpc_metadata, ) firestore_api.rollback.assert_called_once_with(