diff --git a/.github/.OwlBot.lock.yaml b/.github/.OwlBot.lock.yaml index 6201596218..8b90899d21 100644 --- a/.github/.OwlBot.lock.yaml +++ b/.github/.OwlBot.lock.yaml @@ -13,4 +13,5 @@ # limitations under the License. docker: image: gcr.io/cloud-devrel-public-resources/owlbot-python:latest - digest: sha256:5651442a6336971a2fb2df40fb56b3337df67cafa14c0809cc89cb34ccee1b8e + digest: sha256:2dc6f67639bee669c33c6277a624ab9857d363e2fd33ac5b02d417b7d25f1ffc +# created: 2024-08-15T17:41:26.438340772Z diff --git a/.github/workflows/system_emulated.yml b/.github/workflows/system_emulated.yml index ec60eae65f..66f4367a68 100644 --- a/.github/workflows/system_emulated.yml +++ b/.github/workflows/system_emulated.yml @@ -20,7 +20,7 @@ jobs: python-version: '3.7' - name: Setup GCloud SDK - uses: google-github-actions/setup-gcloud@v2.1.0 + uses: google-github-actions/setup-gcloud@v2.1.1 - name: Install / run Nox run: | diff --git a/.kokoro/docker/docs/Dockerfile b/.kokoro/docker/docs/Dockerfile index a26ce61930..e5410e296b 100644 --- a/.kokoro/docker/docs/Dockerfile +++ b/.kokoro/docker/docs/Dockerfile @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ubuntu:22.04 +from ubuntu:24.04 ENV DEBIAN_FRONTEND noninteractive @@ -40,7 +40,6 @@ RUN apt-get update \ libssl-dev \ libsqlite3-dev \ portaudio19-dev \ - python3-distutils \ redis-server \ software-properties-common \ ssh \ @@ -60,28 +59,31 @@ RUN apt-get update \ && rm -rf /var/lib/apt/lists/* \ && rm -f /var/cache/apt/archives/*.deb -###################### Install python 3.9.13 -# Download python 3.9.13 -RUN wget https://www.python.org/ftp/python/3.9.13/Python-3.9.13.tgz +###################### Install python 3.10.14 for docs/docfx session + +# Download python 3.10.14 +RUN wget https://www.python.org/ftp/python/3.10.14/Python-3.10.14.tgz # Extract files -RUN tar -xvf Python-3.9.13.tgz +RUN tar -xvf Python-3.10.14.tgz -# Install python 3.9.13 -RUN ./Python-3.9.13/configure --enable-optimizations +# Install python 3.10.14 +RUN ./Python-3.10.14/configure --enable-optimizations RUN make altinstall +ENV PATH /usr/local/bin/python3.10:$PATH + ###################### Install pip RUN wget -O /tmp/get-pip.py 'https://bootstrap.pypa.io/get-pip.py' \ - && python3 /tmp/get-pip.py \ + && python3.10 /tmp/get-pip.py \ && rm /tmp/get-pip.py # Test pip -RUN python3 -m pip +RUN python3.10 -m pip # Install build requirements COPY requirements.txt /requirements.txt -RUN python3 -m pip install --require-hashes -r requirements.txt +RUN python3.10 -m pip install --require-hashes -r requirements.txt -CMD ["python3.8"] +CMD ["python3.10"] diff --git a/.kokoro/docker/docs/requirements.txt b/.kokoro/docker/docs/requirements.txt index 0e5d70f20f..7129c77155 100644 --- a/.kokoro/docker/docs/requirements.txt +++ b/.kokoro/docker/docs/requirements.txt @@ -4,9 +4,9 @@ # # pip-compile --allow-unsafe --generate-hashes requirements.in # -argcomplete==3.2.3 \ - --hash=sha256:bf7900329262e481be5a15f56f19736b376df6f82ed27576fa893652c5de6c23 \ - --hash=sha256:c12355e0494c76a2a7b73e3a59b09024ca0ba1e279fb9ed6c1b82d5b74b6a70c +argcomplete==3.4.0 \ + --hash=sha256:69a79e083a716173e5532e0fa3bef45f793f4e61096cf52b5a42c0211c8b8aa5 \ + --hash=sha256:c2abcdfe1be8ace47ba777d4fce319eb13bf8ad9dace8d085dcad6eded88057f # via nox colorlog==6.8.2 \ --hash=sha256:3e3e079a41feb5a1b64f978b5ea4f46040a94f11f0e8bbb8261e3dbbeca64d44 \ @@ -16,23 +16,27 @@ distlib==0.3.8 \ --hash=sha256:034db59a0b96f8ca18035f36290806a9a6e6bd9d1ff91e45a7f172eb17e51784 \ --hash=sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64 # via virtualenv -filelock==3.13.1 \ - --hash=sha256:521f5f56c50f8426f5e03ad3b281b490a87ef15bc6c526f168290f0c7148d44e \ - --hash=sha256:57dbda9b35157b05fb3e58ee91448612eb674172fab98ee235ccb0b5bee19a1c +filelock==3.15.4 \ + --hash=sha256:2207938cbc1844345cb01a5a95524dae30f0ce089eba5b00378295a17e3e90cb \ + --hash=sha256:6ca1fffae96225dab4c6eaf1c4f4f28cd2568d3ec2a44e15a08520504de468e7 # via virtualenv -nox==2024.3.2 \ - --hash=sha256:e53514173ac0b98dd47585096a55572fe504fecede58ced708979184d05440be \ - --hash=sha256:f521ae08a15adbf5e11f16cb34e8d0e6ea521e0b92868f684e91677deb974553 +nox==2024.4.15 \ + --hash=sha256:6492236efa15a460ecb98e7b67562a28b70da006ab0be164e8821177577c0565 \ + --hash=sha256:ecf6700199cdfa9e5ea0a41ff5e6ef4641d09508eda6edb89d9987864115817f # via -r requirements.in -packaging==24.0 \ - --hash=sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5 \ - --hash=sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9 +packaging==24.1 \ + --hash=sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002 \ + --hash=sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124 # via nox -platformdirs==4.2.0 \ - --hash=sha256:0614df2a2f37e1a662acbd8e2b25b92ccf8632929bc6d43467e17fe89c75e068 \ - --hash=sha256:ef0cc731df711022c174543cb70a9b5bd22e5a9337c8624ef2c2ceb8ddad8768 +platformdirs==4.2.2 \ + --hash=sha256:2d7a1657e36a80ea911db832a8a6ece5ee53d8de21edd5cc5879af6530b1bfee \ + --hash=sha256:38b7b51f512eed9e84a22788b4bce1de17c0adb134d6becb09836e37d8654cd3 # via virtualenv -virtualenv==20.25.1 \ - --hash=sha256:961c026ac520bac5f69acb8ea063e8a4f071bcc9457b9c1f28f6b085c511583a \ - --hash=sha256:e08e13ecdca7a0bd53798f356d5831434afa5b07b93f0abdf0797b7a06ffe197 +tomli==2.0.1 \ + --hash=sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc \ + --hash=sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f + # via nox +virtualenv==20.26.3 \ + --hash=sha256:4c43a2a236279d9ea36a0d76f98d84bd6ca94ac4e0f4a3b9d46d05e10fea542a \ + --hash=sha256:8cc4a31139e796e9a7de2cd5cf2489de1217193116a8fd42328f1bd65f434589 # via nox diff --git a/.kokoro/publish-docs.sh b/.kokoro/publish-docs.sh index 38f083f05a..233205d580 100755 --- a/.kokoro/publish-docs.sh +++ b/.kokoro/publish-docs.sh @@ -21,18 +21,18 @@ export PYTHONUNBUFFERED=1 export PATH="${HOME}/.local/bin:${PATH}" # Install nox -python3 -m pip install --require-hashes -r .kokoro/requirements.txt -python3 -m nox --version +python3.10 -m pip install --require-hashes -r .kokoro/requirements.txt +python3.10 -m nox --version # build docs nox -s docs # create metadata -python3 -m docuploader create-metadata \ +python3.10 -m docuploader create-metadata \ --name=$(jq --raw-output '.name // empty' .repo-metadata.json) \ - --version=$(python3 setup.py --version) \ + --version=$(python3.10 setup.py --version) \ --language=$(jq --raw-output '.language // empty' .repo-metadata.json) \ - --distribution-name=$(python3 setup.py --name) \ + --distribution-name=$(python3.10 setup.py --name) \ --product-page=$(jq --raw-output '.product_documentation // empty' .repo-metadata.json) \ --github-repository=$(jq --raw-output '.repo // empty' .repo-metadata.json) \ --issue-tracker=$(jq --raw-output '.issue_tracker // empty' .repo-metadata.json) @@ -40,18 +40,18 @@ python3 -m docuploader create-metadata \ cat docs.metadata # upload docs -python3 -m docuploader upload docs/_build/html --metadata-file docs.metadata --staging-bucket "${STAGING_BUCKET}" +python3.10 -m docuploader upload docs/_build/html --metadata-file docs.metadata --staging-bucket "${STAGING_BUCKET}" # docfx yaml files nox -s docfx # create metadata. -python3 -m docuploader create-metadata \ +python3.10 -m docuploader create-metadata \ --name=$(jq --raw-output '.name // empty' .repo-metadata.json) \ - --version=$(python3 setup.py --version) \ + --version=$(python3.10 setup.py --version) \ --language=$(jq --raw-output '.language // empty' .repo-metadata.json) \ - --distribution-name=$(python3 setup.py --name) \ + --distribution-name=$(python3.10 setup.py --name) \ --product-page=$(jq --raw-output '.product_documentation // empty' .repo-metadata.json) \ --github-repository=$(jq --raw-output '.repo // empty' .repo-metadata.json) \ --issue-tracker=$(jq --raw-output '.issue_tracker // empty' .repo-metadata.json) @@ -59,4 +59,4 @@ python3 -m docuploader create-metadata \ cat docs.metadata # upload docs -python3 -m docuploader upload docs/_build/html/docfx_yaml --metadata-file docs.metadata --destination-prefix docfx --staging-bucket "${V2_STAGING_BUCKET}" +python3.10 -m docuploader upload docs/_build/html/docfx_yaml --metadata-file docs.metadata --destination-prefix docfx --staging-bucket "${V2_STAGING_BUCKET}" diff --git a/.kokoro/requirements.txt b/.kokoro/requirements.txt index 35ece0e4d2..9622baf0ba 100644 --- a/.kokoro/requirements.txt +++ b/.kokoro/requirements.txt @@ -20,9 +20,9 @@ cachetools==5.3.3 \ --hash=sha256:0abad1021d3f8325b2fc1d2e9c8b9c9d57b04c3932657a72465447332c24d945 \ --hash=sha256:ba29e2dfa0b8b556606f097407ed1aa62080ee108ab0dc5ec9d6a723a007d105 # via google-auth -certifi==2024.6.2 \ - --hash=sha256:3cd43f1c6fa7dedc5899d69d3ad0398fd018ad1a17fba83ddaf78aa46c747516 \ - --hash=sha256:ddc6c8ce995e6987e7faf5e3f1b02b302836a0e5d98ece18392cb1a36c72ad56 +certifi==2024.7.4 \ + --hash=sha256:5a1e7645bc0ec61a09e26c36f6106dd4cf40c6db3a1fb6352b0244e7fb057c7b \ + --hash=sha256:c198e21b1289c2ab85ee4e67bb4b4ef3ead0892059901a8d5b622f24a1101e90 # via requests cffi==1.16.0 \ --hash=sha256:0c9ef6ff37e974b73c25eecc13952c55bceed9112be2d9d938ded8e856138bcc \ @@ -371,23 +371,23 @@ more-itertools==10.3.0 \ # via # jaraco-classes # jaraco-functools -nh3==0.2.17 \ - --hash=sha256:0316c25b76289cf23be6b66c77d3608a4fdf537b35426280032f432f14291b9a \ - --hash=sha256:1a814dd7bba1cb0aba5bcb9bebcc88fd801b63e21e2450ae6c52d3b3336bc911 \ - --hash=sha256:1aa52a7def528297f256de0844e8dd680ee279e79583c76d6fa73a978186ddfb \ - --hash=sha256:22c26e20acbb253a5bdd33d432a326d18508a910e4dcf9a3316179860d53345a \ - --hash=sha256:40015514022af31975c0b3bca4014634fa13cb5dc4dbcbc00570acc781316dcc \ - --hash=sha256:40d0741a19c3d645e54efba71cb0d8c475b59135c1e3c580f879ad5514cbf028 \ - --hash=sha256:551672fd71d06cd828e282abdb810d1be24e1abb7ae2543a8fa36a71c1006fe9 \ - --hash=sha256:66f17d78826096291bd264f260213d2b3905e3c7fae6dfc5337d49429f1dc9f3 \ - --hash=sha256:85cdbcca8ef10733bd31f931956f7fbb85145a4d11ab9e6742bbf44d88b7e351 \ - --hash=sha256:a3f55fabe29164ba6026b5ad5c3151c314d136fd67415a17660b4aaddacf1b10 \ - --hash=sha256:b4427ef0d2dfdec10b641ed0bdaf17957eb625b2ec0ea9329b3d28806c153d71 \ - --hash=sha256:ba73a2f8d3a1b966e9cdba7b211779ad8a2561d2dba9674b8a19ed817923f65f \ - --hash=sha256:c21bac1a7245cbd88c0b0e4a420221b7bfa838a2814ee5bb924e9c2f10a1120b \ - --hash=sha256:c551eb2a3876e8ff2ac63dff1585236ed5dfec5ffd82216a7a174f7c5082a78a \ - --hash=sha256:c790769152308421283679a142dbdb3d1c46c79c823008ecea8e8141db1a2062 \ - --hash=sha256:d7a25fd8c86657f5d9d576268e3b3767c5cd4f42867c9383618be8517f0f022a +nh3==0.2.18 \ + --hash=sha256:0411beb0589eacb6734f28d5497ca2ed379eafab8ad8c84b31bb5c34072b7164 \ + --hash=sha256:14c5a72e9fe82aea5fe3072116ad4661af5cf8e8ff8fc5ad3450f123e4925e86 \ + --hash=sha256:19aaba96e0f795bd0a6c56291495ff59364f4300d4a39b29a0abc9cb3774a84b \ + --hash=sha256:34c03fa78e328c691f982b7c03d4423bdfd7da69cd707fe572f544cf74ac23ad \ + --hash=sha256:36c95d4b70530b320b365659bb5034341316e6a9b30f0b25fa9c9eff4c27a204 \ + --hash=sha256:3a157ab149e591bb638a55c8c6bcb8cdb559c8b12c13a8affaba6cedfe51713a \ + --hash=sha256:42c64511469005058cd17cc1537578eac40ae9f7200bedcfd1fc1a05f4f8c200 \ + --hash=sha256:5f36b271dae35c465ef5e9090e1fdaba4a60a56f0bb0ba03e0932a66f28b9189 \ + --hash=sha256:6955369e4d9f48f41e3f238a9e60f9410645db7e07435e62c6a9ea6135a4907f \ + --hash=sha256:7b7c2a3c9eb1a827d42539aa64091640bd275b81e097cd1d8d82ef91ffa2e811 \ + --hash=sha256:8ce0f819d2f1933953fca255db2471ad58184a60508f03e6285e5114b6254844 \ + --hash=sha256:94a166927e53972a9698af9542ace4e38b9de50c34352b962f4d9a7d4c927af4 \ + --hash=sha256:a7f1b5b2c15866f2db413a3649a8fe4fd7b428ae58be2c0f6bca5eefd53ca2be \ + --hash=sha256:c8b3a1cebcba9b3669ed1a84cc65bf005728d2f0bc1ed2a6594a992e817f3a50 \ + --hash=sha256:de3ceed6e661954871d6cd78b410213bdcb136f79aafe22aa7182e028b8c7307 \ + --hash=sha256:f0eca9ca8628dbb4e916ae2491d72957fdd35f7a5d326b7032a345f111ac07fe # via readme-renderer nox==2024.4.15 \ --hash=sha256:6492236efa15a460ecb98e7b67562a28b70da006ab0be164e8821177577c0565 \ @@ -460,9 +460,9 @@ python-dateutil==2.9.0.post0 \ --hash=sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3 \ --hash=sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427 # via gcp-releasetool -readme-renderer==43.0 \ - --hash=sha256:1818dd28140813509eeed8d62687f7cd4f7bad90d4db586001c5dc09d4fde311 \ - --hash=sha256:19db308d86ecd60e5affa3b2a98f017af384678c63c88e5d4556a380e674f3f9 +readme-renderer==44.0 \ + --hash=sha256:2fbca89b81a08526aadf1357a8c2ae889ec05fb03f5da67f9769c9a592166151 \ + --hash=sha256:8712034eabbfa6805cacf1402b4eeb2a73028f72d1166d6f5cb7f9c047c5d1e1 # via twine requests==2.32.3 \ --hash=sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760 \ diff --git a/.release-please-manifest.json b/.release-please-manifest.json index b337b52396..a627e662e0 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "2.17.2" + ".": "2.18.0" } \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index ae15d43845..786b1399b7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,18 @@ [1]: https://pypi.org/project/google-cloud-firestore/#history +## [2.18.0](https://github.com/googleapis/python-firestore/compare/v2.17.2...v2.18.0) (2024-08-26) + + +### Features + +* Support returning computed distance and set distance thresholds on VectorQueries ([#960](https://github.com/googleapis/python-firestore/issues/960)) ([5c2192d](https://github.com/googleapis/python-firestore/commit/5c2192d3c66f6b6a11f122affbfb29556a77a535)) + + +### Bug Fixes + +* Remove custom retry loop ([#948](https://github.com/googleapis/python-firestore/issues/948)) ([04bb206](https://github.com/googleapis/python-firestore/commit/04bb20628a8e68a0ad86433c18c37734b6f282c8)) + ## [2.17.2](https://github.com/googleapis/python-firestore/compare/v2.17.1...v2.17.2) (2024-08-13) diff --git a/google/cloud/firestore/gapic_version.py b/google/cloud/firestore/gapic_version.py index 7f7a51c626..f09943f6bd 100644 --- a/google/cloud/firestore/gapic_version.py +++ b/google/cloud/firestore/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "2.17.2" # {x-release-please-version} +__version__ = "2.18.0" # {x-release-please-version} diff --git a/google/cloud/firestore_admin_v1/gapic_version.py b/google/cloud/firestore_admin_v1/gapic_version.py index 7f7a51c626..f09943f6bd 100644 --- a/google/cloud/firestore_admin_v1/gapic_version.py +++ b/google/cloud/firestore_admin_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "2.17.2" # {x-release-please-version} +__version__ = "2.18.0" # {x-release-please-version} diff --git a/google/cloud/firestore_bundle/gapic_version.py b/google/cloud/firestore_bundle/gapic_version.py index 7f7a51c626..f09943f6bd 100644 --- a/google/cloud/firestore_bundle/gapic_version.py +++ b/google/cloud/firestore_bundle/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "2.17.2" # {x-release-please-version} +__version__ = "2.18.0" # {x-release-please-version} diff --git a/google/cloud/firestore_v1/async_query.py b/google/cloud/firestore_v1/async_query.py index 15f81be247..ca83c26306 100644 --- a/google/cloud/firestore_v1/async_query.py +++ b/google/cloud/firestore_v1/async_query.py @@ -230,17 +230,25 @@ def find_nearest( query_vector: Vector, limit: int, distance_measure: DistanceMeasure, + *, + distance_result_field: Optional[str] = None, + distance_threshold: Optional[float] = None, ) -> AsyncVectorQuery: """ Finds the closest vector embeddings to the given query vector. Args: - vector_field(str): An indexed vector field to search upon. Only documents which contain + vector_field (str): An indexed vector field to search upon. Only documents which contain vectors whose dimensionality match the query_vector can be returned. - query_vector(Vector): The query vector that we are searching on. Must be a vector of no more + query_vector (Vector): The query vector that we are searching on. Must be a vector of no more than 2048 dimensions. limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000. - distance_measure(:class:`DistanceMeasure`): The Distance Measure to use. + distance_measure (:class:`DistanceMeasure`): The Distance Measure to use. + distance_result_field (Optional[str]): + Name of the field to output the result of the vector distance + calculation. If unset then the distance will not be returned. + distance_threshold (Optional[float]): + A threshold for which no less similar documents will be returned. Returns: :class`~firestore_v1.vector_query.VectorQuery`: the vector query. @@ -250,6 +258,8 @@ def find_nearest( query_vector=query_vector, limit=limit, distance_measure=distance_measure, + distance_result_field=distance_result_field, + distance_threshold=distance_threshold, ) def count( diff --git a/google/cloud/firestore_v1/async_transaction.py b/google/cloud/firestore_v1/async_transaction.py index 6b01fffd6c..7281a68e56 100644 --- a/google/cloud/firestore_v1/async_transaction.py +++ b/google/cloud/firestore_v1/async_transaction.py @@ -15,14 +15,12 @@ """Helpers for applying Google Cloud Firestore changes in a transaction.""" -import asyncio -import random from typing import Any, AsyncGenerator, Callable, Coroutine from google.api_core import exceptions, gapic_v1 from google.api_core import retry_async as retries -from google.cloud.firestore_v1 import _helpers, async_batch, types +from google.cloud.firestore_v1 import _helpers, async_batch from google.cloud.firestore_v1.async_document import ( AsyncDocumentReference, DocumentSnapshot, @@ -33,18 +31,12 @@ _CANT_COMMIT, _CANT_ROLLBACK, _EXCEED_ATTEMPTS_TEMPLATE, - _INITIAL_SLEEP, - _MAX_SLEEP, - _MULTIPLIER, _WRITE_READ_ONLY, MAX_ATTEMPTS, BaseTransaction, _BaseTransactional, ) -# Types needed only for Type Hints -from google.cloud.firestore_v1.client import Client - class AsyncTransaction(async_batch.AsyncWriteBatch, BaseTransaction): """Accumulate read-and-write operations to be sent in a transaction. @@ -140,8 +132,13 @@ async def _commit(self) -> list: if not self.in_progress: raise ValueError(_CANT_COMMIT) - commit_response = await _commit_with_retry( - self._client, self._write_pbs, self._id + commit_response = await self._client._firestore_api.commit( + request={ + "database": self._client._database_string, + "writes": self._write_pbs, + "transaction": self._id, + }, + metadata=self._client._rpc_metadata, ) self._clean_up() @@ -313,76 +310,3 @@ def async_transactional( the wrapped callable. """ return _AsyncTransactional(to_wrap) - - -# TODO(crwilcox): this was 'coroutine' from pytype merge-pyi... -async def _commit_with_retry( - client: Client, write_pbs: list, transaction_id: bytes -) -> types.CommitResponse: - """Call ``Commit`` on the GAPIC client with retry / sleep. - - Retries the ``Commit`` RPC on Unavailable. Usually this RPC-level - retry is handled by the underlying GAPICd client, but in this case it - doesn't because ``Commit`` is not always idempotent. But here we know it - is "idempotent"-like because it has a transaction ID. We also need to do - our own retry to special-case the ``INVALID_ARGUMENT`` error. - - Args: - client (:class:`~google.cloud.firestore_v1.client.Client`): - A client with GAPIC client and configuration details. - write_pbs (List[:class:`google.cloud.proto.firestore.v1.write.Write`, ...]): - A ``Write`` protobuf instance to be committed. - transaction_id (bytes): - ID of an existing transaction that this commit will run in. - - Returns: - :class:`google.cloud.firestore_v1.types.CommitResponse`: - The protobuf response from ``Commit``. - - Raises: - ~google.api_core.exceptions.GoogleAPICallError: If a non-retryable - exception is encountered. - """ - current_sleep = _INITIAL_SLEEP - while True: - try: - return await client._firestore_api.commit( - request={ - "database": client._database_string, - "writes": write_pbs, - "transaction": transaction_id, - }, - metadata=client._rpc_metadata, - ) - except exceptions.ServiceUnavailable: - # Retry - pass - - current_sleep = await _sleep(current_sleep) - - -async def _sleep( - current_sleep: float, max_sleep: float = _MAX_SLEEP, multiplier: float = _MULTIPLIER -) -> float: - """Sleep and produce a new sleep time. - - .. _Exponential Backoff And Jitter: https://www.awsarchitectureblog.com/\ - 2015/03/backoff.html - - Select a duration between zero and ``current_sleep``. It might seem - counterintuitive to have so much jitter, but - `Exponential Backoff And Jitter`_ argues that "full jitter" is - the best strategy. - - Args: - current_sleep (float): The current "max" for sleep interval. - max_sleep (Optional[float]): Eventual "max" sleep time - multiplier (Optional[float]): Multiplier for exponential backoff. - - Returns: - float: Newly doubled ``current_sleep`` or ``max_sleep`` (whichever - is smaller) - """ - actual_sleep = random.uniform(0.0, current_sleep) - await asyncio.sleep(actual_sleep) - return min(multiplier * current_sleep, max_sleep) diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index e2065dc2f8..18c62aa33b 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -550,23 +550,35 @@ def find_nearest( query_vector: Vector, limit: int, distance_measure: DistanceMeasure, + *, + distance_result_field: Optional[str] = None, + distance_threshold: Optional[float] = None, ) -> VectorQuery: """ Finds the closest vector embeddings to the given query vector. Args: - vector_field(str): An indexed vector field to search upon. Only documents which contain + vector_field (str): An indexed vector field to search upon. Only documents which contain vectors whose dimensionality match the query_vector can be returned. - query_vector(Vector): The query vector that we are searching on. Must be a vector of no more + query_vector (Vector): The query vector that we are searching on. Must be a vector of no more than 2048 dimensions. limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000. - distance_measure(:class:`DistanceMeasure`): The Distance Measure to use. + distance_measure (:class:`DistanceMeasure`): The Distance Measure to use. + distance_result_field (Optional[str]): + Name of the field to output the result of the vector distance calculation + distance_threshold (Optional[float]): + A threshold for which no less similar documents will be returned. Returns: :class`~firestore_v1.vector_query.VectorQuery`: the vector query. """ return self._vector_query().find_nearest( - vector_field, query_vector, limit, distance_measure + vector_field, + query_vector, + limit, + distance_measure, + distance_result_field=distance_result_field, + distance_threshold=distance_threshold, ) diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 73ed00206b..cfed454b93 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -982,6 +982,9 @@ def find_nearest( query_vector: Vector, limit: int, distance_measure: DistanceMeasure, + *, + distance_result_field: Optional[str] = None, + distance_threshold: Optional[float] = None, ) -> BaseVectorQuery: raise NotImplementedError diff --git a/google/cloud/firestore_v1/base_transaction.py b/google/cloud/firestore_v1/base_transaction.py index 5b6e76e1b0..09f0c1fb9a 100644 --- a/google/cloud/firestore_v1/base_transaction.py +++ b/google/cloud/firestore_v1/base_transaction.py @@ -39,12 +39,6 @@ _CANT_ROLLBACK: str = _MISSING_ID_TEMPLATE.format("rolled back") _CANT_COMMIT: str = _MISSING_ID_TEMPLATE.format("committed") _WRITE_READ_ONLY: str = "Cannot perform write operation in read-only transaction." -_INITIAL_SLEEP: float = 1.0 -"""float: Initial "max" for sleep interval. To be used in :func:`_sleep`.""" -_MAX_SLEEP: float = 30.0 -"""float: Eventual "max" sleep time. To be used in :func:`_sleep`.""" -_MULTIPLIER: float = 2.0 -"""float: Multiplier for exponential backoff. To be used in :func:`_sleep`.""" _EXCEED_ATTEMPTS_TEMPLATE: str = "Failed to commit transaction in {:d} attempts." _CANT_RETRY_READ_ONLY: str = "Only read-write transactions can be retried." diff --git a/google/cloud/firestore_v1/base_vector_query.py b/google/cloud/firestore_v1/base_vector_query.py index 0c5c61b3e8..26cd5b1997 100644 --- a/google/cloud/firestore_v1/base_vector_query.py +++ b/google/cloud/firestore_v1/base_vector_query.py @@ -45,6 +45,8 @@ def __init__(self, nested_query) -> None: self._query_vector: Optional[Vector] = None self._limit: Optional[int] = None self._distance_measure: Optional[DistanceMeasure] = None + self._distance_result_field: Optional[str] = None + self._distance_threshold: Optional[float] = None @property def _client(self): @@ -69,6 +71,11 @@ def _to_protobuf(self) -> query.StructuredQuery: else: raise ValueError("Invalid distance_measure") + # Coerce ints to floats as required by the protobuf. + distance_threshold_proto = None + if self._distance_threshold is not None: + distance_threshold_proto = float(self._distance_threshold) + pb = self._nested_query._to_protobuf() pb.find_nearest = query.StructuredQuery.FindNearest( vector_field=query.StructuredQuery.FieldReference( @@ -77,6 +84,8 @@ def _to_protobuf(self) -> query.StructuredQuery: query_vector=_helpers.encode_value(self._query_vector), distance_measure=distance_measure_proto, limit=self._limit, + distance_result_field=self._distance_result_field, + distance_threshold=distance_threshold_proto, ) return pb @@ -111,12 +120,17 @@ def find_nearest( query_vector: Vector, limit: int, distance_measure: DistanceMeasure, + *, + distance_result_field: Optional[str] = None, + distance_threshold: Optional[float] = None, ): """Finds the closest vector embeddings to the given query vector.""" self._vector_field = vector_field self._query_vector = query_vector self._limit = limit self._distance_measure = distance_measure + self._distance_result_field = distance_result_field + self._distance_threshold = distance_threshold return self def stream( diff --git a/google/cloud/firestore_v1/gapic_version.py b/google/cloud/firestore_v1/gapic_version.py index 7f7a51c626..f09943f6bd 100644 --- a/google/cloud/firestore_v1/gapic_version.py +++ b/google/cloud/firestore_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "2.17.2" # {x-release-please-version} +__version__ = "2.18.0" # {x-release-please-version} diff --git a/google/cloud/firestore_v1/query.py b/google/cloud/firestore_v1/query.py index b5bd5ec4fd..eb8f51dc8d 100644 --- a/google/cloud/firestore_v1/query.py +++ b/google/cloud/firestore_v1/query.py @@ -251,17 +251,26 @@ def find_nearest( query_vector: Vector, limit: int, distance_measure: DistanceMeasure, + *, + distance_result_field: Optional[str] = None, + distance_threshold: Optional[float] = None, ) -> Type["firestore_v1.vector_query.VectorQuery"]: """ Finds the closest vector embeddings to the given query vector. Args: - vector_field(str): An indexed vector field to search upon. Only documents which contain + vector_field (str): An indexed vector field to search upon. Only documents which contain vectors whose dimensionality match the query_vector can be returned. - query_vector(Vector): The query vector that we are searching on. Must be a vector of no more + query_vector (Vector): The query vector that we are searching on. Must be a vector of no more than 2048 dimensions. limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000. - distance_measure(:class:`DistanceMeasure`): The Distance Measure to use. + distance_measure (:class:`DistanceMeasure`): The Distance Measure to use. + distance_result_field (Optional[str]): + Name of the field to output the result of the vector distance + calculation. If unset then the distance will not be returned. + distance_threshold (Optional[float]): + A threshold for which no less similar documents will be returned. + Returns: :class`~firestore_v1.vector_query.VectorQuery`: the vector query. @@ -271,6 +280,8 @@ def find_nearest( query_vector=query_vector, limit=limit, distance_measure=distance_measure, + distance_result_field=distance_result_field, + distance_threshold=distance_threshold, ) def count( diff --git a/google/cloud/firestore_v1/transaction.py b/google/cloud/firestore_v1/transaction.py index 1691b56792..8f92ddaf0d 100644 --- a/google/cloud/firestore_v1/transaction.py +++ b/google/cloud/firestore_v1/transaction.py @@ -15,8 +15,6 @@ """Helpers for applying Google Cloud Firestore changes in a transaction.""" -import random -import time from typing import Any, Callable, Generator from google.api_core import exceptions, gapic_v1 @@ -31,9 +29,6 @@ _CANT_COMMIT, _CANT_ROLLBACK, _EXCEED_ATTEMPTS_TEMPLATE, - _INITIAL_SLEEP, - _MAX_SLEEP, - _MULTIPLIER, _WRITE_READ_ONLY, MAX_ATTEMPTS, BaseTransaction, @@ -41,7 +36,6 @@ ) from google.cloud.firestore_v1.document import DocumentReference from google.cloud.firestore_v1.query import Query -from google.cloud.firestore_v1.types import CommitResponse class Transaction(batch.WriteBatch, BaseTransaction): @@ -138,7 +132,14 @@ def _commit(self) -> list: if not self.in_progress: raise ValueError(_CANT_COMMIT) - commit_response = _commit_with_retry(self._client, self._write_pbs, self._id) + commit_response = self._client._firestore_api.commit( + request={ + "database": self._client._database_string, + "writes": self._write_pbs, + "transaction": self._id, + }, + metadata=self._client._rpc_metadata, + ) self._clean_up() return list(commit_response.write_results) @@ -301,75 +302,3 @@ def transactional(to_wrap: Callable) -> _Transactional: the wrapped callable. """ return _Transactional(to_wrap) - - -def _commit_with_retry( - client, write_pbs: list, transaction_id: bytes -) -> CommitResponse: - """Call ``Commit`` on the GAPIC client with retry / sleep. - - Retries the ``Commit`` RPC on Unavailable. Usually this RPC-level - retry is handled by the underlying GAPICd client, but in this case it - doesn't because ``Commit`` is not always idempotent. But here we know it - is "idempotent"-like because it has a transaction ID. We also need to do - our own retry to special-case the ``INVALID_ARGUMENT`` error. - - Args: - client (:class:`~google.cloud.firestore_v1.client.Client`): - A client with GAPIC client and configuration details. - write_pbs (List[:class:`google.cloud.proto.firestore.v1.write.Write`, ...]): - A ``Write`` protobuf instance to be committed. - transaction_id (bytes): - ID of an existing transaction that this commit will run in. - - Returns: - :class:`google.cloud.firestore_v1.types.CommitResponse`: - The protobuf response from ``Commit``. - - Raises: - ~google.api_core.exceptions.GoogleAPICallError: If a non-retryable - exception is encountered. - """ - current_sleep = _INITIAL_SLEEP - while True: - try: - return client._firestore_api.commit( - request={ - "database": client._database_string, - "writes": write_pbs, - "transaction": transaction_id, - }, - metadata=client._rpc_metadata, - ) - except exceptions.ServiceUnavailable: - # Retry - pass - - current_sleep = _sleep(current_sleep) - - -def _sleep( - current_sleep: float, max_sleep: float = _MAX_SLEEP, multiplier: float = _MULTIPLIER -) -> float: - """Sleep and produce a new sleep time. - - .. _Exponential Backoff And Jitter: https://www.awsarchitectureblog.com/\ - 2015/03/backoff.html - - Select a duration between zero and ``current_sleep``. It might seem - counterintuitive to have so much jitter, but - `Exponential Backoff And Jitter`_ argues that "full jitter" is - the best strategy. - - Args: - current_sleep (float): The current "max" for sleep interval. - max_sleep (Optional[float]): Eventual "max" sleep time - multiplier (Optional[float]): Multiplier for exponential backoff. - - Returns: - float: Newly doubled ``current_sleep`` or ``max_sleep`` (whichever - is smaller) - """ - actual_sleep = random.uniform(0.0, current_sleep) - time.sleep(actual_sleep) - return min(multiplier * current_sleep, max_sleep) diff --git a/noxfile.py b/noxfile.py index e033449eee..41f545a68f 100644 --- a/noxfile.py +++ b/noxfile.py @@ -365,7 +365,7 @@ def cover(session): session.run("coverage", "erase") -@nox.session(python="3.9") +@nox.session(python="3.10") def docs(session): """Build the docs for this library.""" diff --git a/tests/system/test_system.py b/tests/system/test_system.py index dc9d86a102..b67b8aecca 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -176,15 +176,22 @@ def on_snapshot(docs, changes, read_time): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_vector_search_collection(client, database): - # Documents and Indexs are a manual step from util/boostrap_vector_index.py +@pytest.mark.parametrize( + "distance_measure", + [ + DistanceMeasure.EUCLIDEAN, + DistanceMeasure.COSINE, + ], +) +def test_vector_search_collection(client, database, distance_measure): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" collection = client.collection(collection_id) vector_query = collection.find_nearest( vector_field="embedding", query_vector=Vector([1.0, 2.0, 3.0]), - distance_measure=DistanceMeasure.EUCLIDEAN, + distance_measure=distance_measure, limit=1, ) returned = vector_query.get() @@ -198,15 +205,22 @@ def test_vector_search_collection(client, database): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_vector_search_collection_with_filter(client, database): - # Documents and Indexs are a manual step from util/boostrap_vector_index.py +@pytest.mark.parametrize( + "distance_measure", + [ + DistanceMeasure.EUCLIDEAN, + DistanceMeasure.COSINE, + ], +) +def test_vector_search_collection_with_filter(client, database, distance_measure): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" collection = client.collection(collection_id) vector_query = collection.where("color", "==", "red").find_nearest( vector_field="embedding", query_vector=Vector([1.0, 2.0, 3.0]), - distance_measure=DistanceMeasure.EUCLIDEAN, + distance_measure=distance_measure, limit=1, ) returned = vector_query.get() @@ -220,15 +234,82 @@ def test_vector_search_collection_with_filter(client, database): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_vector_search_collection_group(client, database): - # Documents and Indexs are a manual step from util/boostrap_vector_index.py +def test_vector_search_collection_with_distance_parameters_euclid(client, database): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py + collection_id = "vector_search" + collection = client.collection(collection_id) + + vector_query = collection.find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=DistanceMeasure.EUCLIDEAN, + limit=3, + distance_result_field="vector_distance", + distance_threshold=1.0, + ) + returned = vector_query.get() + assert isinstance(returned, list) + assert len(returned) == 2 + assert returned[0].to_dict() == { + "embedding": Vector([1.0, 2.0, 3.0]), + "color": "red", + "vector_distance": 0.0, + } + assert returned[1].to_dict() == { + "embedding": Vector([2.0, 2.0, 3.0]), + "color": "red", + "vector_distance": 1.0, + } + + +@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_vector_search_collection_with_distance_parameters_cosine(client, database): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py + collection_id = "vector_search" + collection = client.collection(collection_id) + + vector_query = collection.find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=DistanceMeasure.COSINE, + limit=3, + distance_result_field="vector_distance", + distance_threshold=0.02, + ) + returned = vector_query.get() + assert isinstance(returned, list) + assert len(returned) == 2 + assert returned[0].to_dict() == { + "embedding": Vector([1.0, 2.0, 3.0]), + "color": "red", + "vector_distance": 0.0, + } + assert returned[1].to_dict() == { + "embedding": Vector([3.0, 4.0, 5.0]), + "color": "yellow", + "vector_distance": 0.017292370176009153, + } + + +@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize( + "distance_measure", + [ + DistanceMeasure.EUCLIDEAN, + DistanceMeasure.COSINE, + ], +) +def test_vector_search_collection_group(client, database, distance_measure): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" collection_group = client.collection_group(collection_id) vector_query = collection_group.find_nearest( vector_field="embedding", query_vector=Vector([1.0, 2.0, 3.0]), - distance_measure=DistanceMeasure.EUCLIDEAN, + distance_measure=distance_measure, limit=1, ) returned = vector_query.get() @@ -241,16 +322,23 @@ def test_vector_search_collection_group(client, database): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") +@pytest.mark.parametrize( + "distance_measure", + [ + DistanceMeasure.EUCLIDEAN, + DistanceMeasure.COSINE, + ], +) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_vector_search_collection_group_with_filter(client, database): - # Documents and Indexs are a manual step from util/boostrap_vector_index.py +def test_vector_search_collection_group_with_filter(client, database, distance_measure): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" collection_group = client.collection_group(collection_id) vector_query = collection_group.where("color", "==", "red").find_nearest( vector_field="embedding", query_vector=Vector([1.0, 2.0, 3.0]), - distance_measure=DistanceMeasure.EUCLIDEAN, + distance_measure=distance_measure, limit=1, ) returned = vector_query.get() @@ -262,6 +350,70 @@ def test_vector_search_collection_group_with_filter(client, database): } +@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_vector_search_collection_group_with_distance_parameters_euclid( + client, database +): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py + collection_id = "vector_search" + collection_group = client.collection_group(collection_id) + + vector_query = collection_group.find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=DistanceMeasure.EUCLIDEAN, + limit=3, + distance_result_field="vector_distance", + distance_threshold=1.0, + ) + returned = vector_query.get() + assert isinstance(returned, list) + assert len(returned) == 2 + assert returned[0].to_dict() == { + "embedding": Vector([1.0, 2.0, 3.0]), + "color": "red", + "vector_distance": 0.0, + } + assert returned[1].to_dict() == { + "embedding": Vector([2.0, 2.0, 3.0]), + "color": "red", + "vector_distance": 1.0, + } + + +@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_vector_search_collection_group_with_distance_parameters_cosine( + client, database +): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py + collection_id = "vector_search" + collection_group = client.collection_group(collection_id) + + vector_query = collection_group.find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=DistanceMeasure.COSINE, + limit=3, + distance_result_field="vector_distance", + distance_threshold=0.02, + ) + returned = vector_query.get() + assert isinstance(returned, list) + assert len(returned) == 2 + assert returned[0].to_dict() == { + "embedding": Vector([1.0, 2.0, 3.0]), + "color": "red", + "vector_distance": 0.0, + } + assert returned[1].to_dict() == { + "embedding": Vector([3.0, 4.0, 5.0]), + "color": "yellow", + "vector_distance": 0.017292370176009153, + } + + @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) def test_create_document_w_subcollection(client, cleanup, database): collection_id = "doc-create-sub" + UNIQUE_RESOURCE_ID diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py index df574e0fa7..78bd64c5c5 100644 --- a/tests/system/test_system_async.py +++ b/tests/system/test_system_async.py @@ -341,15 +341,22 @@ async def test_document_update_w_int_field(client, cleanup, database): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -async def test_vector_search_collection(client, database): - # Documents and Indexs are a manual step from util/boostrap_vector_index.py +@pytest.mark.parametrize( + "distance_measure", + [ + DistanceMeasure.EUCLIDEAN, + DistanceMeasure.COSINE, + ], +) +async def test_vector_search_collection(client, database, distance_measure): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" collection = client.collection(collection_id) vector_query = collection.find_nearest( vector_field="embedding", query_vector=Vector([1.0, 2.0, 3.0]), limit=1, - distance_measure=DistanceMeasure.EUCLIDEAN, + distance_measure=distance_measure, ) returned = await vector_query.get() assert isinstance(returned, list) @@ -362,15 +369,22 @@ async def test_vector_search_collection(client, database): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -async def test_vector_search_collection_with_filter(client, database): - # Documents and Indexs are a manual step from util/boostrap_vector_index.py +@pytest.mark.parametrize( + "distance_measure", + [ + DistanceMeasure.EUCLIDEAN, + DistanceMeasure.COSINE, + ], +) +async def test_vector_search_collection_with_filter(client, database, distance_measure): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" collection = client.collection(collection_id) vector_query = collection.where("color", "==", "red").find_nearest( vector_field="embedding", query_vector=Vector([1.0, 2.0, 3.0]), limit=1, - distance_measure=DistanceMeasure.EUCLIDEAN, + distance_measure=distance_measure, ) returned = await vector_query.get() assert isinstance(returned, list) @@ -383,15 +397,86 @@ async def test_vector_search_collection_with_filter(client, database): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -async def test_vector_search_collection_group(client, database): - # Documents and Indexs are a manual step from util/boostrap_vector_index.py +async def test_vector_search_collection_with_distance_parameters_euclid( + client, database +): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py + collection_id = "vector_search" + collection = client.collection(collection_id) + + vector_query = collection.find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=DistanceMeasure.EUCLIDEAN, + limit=3, + distance_result_field="vector_distance", + distance_threshold=1.0, + ) + returned = await vector_query.get() + assert isinstance(returned, list) + assert len(returned) == 2 + assert returned[0].to_dict() == { + "embedding": Vector([1.0, 2.0, 3.0]), + "color": "red", + "vector_distance": 0.0, + } + assert returned[1].to_dict() == { + "embedding": Vector([2.0, 2.0, 3.0]), + "color": "red", + "vector_distance": 1.0, + } + + +@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_vector_search_collection_with_distance_parameters_cosine( + client, database +): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py + collection_id = "vector_search" + collection = client.collection(collection_id) + + vector_query = collection.find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=DistanceMeasure.COSINE, + limit=3, + distance_result_field="vector_distance", + distance_threshold=0.02, + ) + returned = await vector_query.get() + assert isinstance(returned, list) + assert len(returned) == 2 + assert returned[0].to_dict() == { + "embedding": Vector([1.0, 2.0, 3.0]), + "color": "red", + "vector_distance": 0.0, + } + assert returned[1].to_dict() == { + "embedding": Vector([3.0, 4.0, 5.0]), + "color": "yellow", + "vector_distance": 0.017292370176009153, + } + + +@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize( + "distance_measure", + [ + DistanceMeasure.EUCLIDEAN, + DistanceMeasure.COSINE, + ], +) +async def test_vector_search_collection_group(client, database, distance_measure): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" collection_group = client.collection_group(collection_id) vector_query = collection_group.find_nearest( vector_field="embedding", query_vector=Vector([1.0, 2.0, 3.0]), - distance_measure=DistanceMeasure.EUCLIDEAN, + distance_measure=distance_measure, limit=1, ) returned = await vector_query.get() @@ -405,15 +490,24 @@ async def test_vector_search_collection_group(client, database): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -async def test_vector_search_collection_group_with_filter(client, database): - # Documents and Indexs are a manual step from util/boostrap_vector_index.py +@pytest.mark.parametrize( + "distance_measure", + [ + DistanceMeasure.EUCLIDEAN, + DistanceMeasure.COSINE, + ], +) +async def test_vector_search_collection_group_with_filter( + client, database, distance_measure +): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" collection_group = client.collection_group(collection_id) vector_query = collection_group.where("color", "==", "red").find_nearest( vector_field="embedding", query_vector=Vector([1.0, 2.0, 3.0]), - distance_measure=DistanceMeasure.EUCLIDEAN, + distance_measure=distance_measure, limit=1, ) returned = await vector_query.get() @@ -425,6 +519,70 @@ async def test_vector_search_collection_group_with_filter(client, database): } +@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_vector_search_collection_group_with_distance_parameters_euclid( + client, database +): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py + collection_id = "vector_search" + collection_group = client.collection_group(collection_id) + + vector_query = collection_group.find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=DistanceMeasure.EUCLIDEAN, + limit=3, + distance_result_field="vector_distance", + distance_threshold=1.0, + ) + returned = await vector_query.get() + assert isinstance(returned, list) + assert len(returned) == 2 + assert returned[0].to_dict() == { + "embedding": Vector([1.0, 2.0, 3.0]), + "color": "red", + "vector_distance": 0.0, + } + assert returned[1].to_dict() == { + "embedding": Vector([2.0, 2.0, 3.0]), + "color": "red", + "vector_distance": 1.0, + } + + +@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_vector_search_collection_group_with_distance_parameters_cosine( + client, database +): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py + collection_id = "vector_search" + collection_group = client.collection_group(collection_id) + + vector_query = collection_group.find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=DistanceMeasure.COSINE, + limit=3, + distance_result_field="vector_distance", + distance_threshold=0.02, + ) + returned = await vector_query.get() + assert isinstance(returned, list) + assert len(returned) == 2 + assert returned[0].to_dict() == { + "embedding": Vector([1.0, 2.0, 3.0]), + "color": "red", + "vector_distance": 0.0, + } + assert returned[1].to_dict() == { + "embedding": Vector([3.0, 4.0, 5.0]), + "color": "yellow", + "vector_distance": 0.017292370176009153, + } + + @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Internal Issue b/137867104") @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) async def test_update_document(client, cleanup, database): diff --git a/tests/unit/v1/_test_helpers.py b/tests/unit/v1/_test_helpers.py index 340ccb30eb..564ec32bc3 100644 --- a/tests/unit/v1/_test_helpers.py +++ b/tests/unit/v1/_test_helpers.py @@ -108,6 +108,12 @@ def make_vector_query(*args, **kw): return VectorQuery(*args, **kw) +def make_async_vector_query(*args, **kw): + from google.cloud.firestore_v1.async_vector_query import AsyncVectorQuery + + return AsyncVectorQuery(*args, **kw) + + def build_test_timestamp( year: int = 2021, month: int = 1, diff --git a/tests/unit/v1/test_async_transaction.py b/tests/unit/v1/test_async_transaction.py index 3c62e83d1b..85d693950e 100644 --- a/tests/unit/v1/test_async_transaction.py +++ b/tests/unit/v1/test_async_transaction.py @@ -799,208 +799,6 @@ def test_async_transactional_factory(): assert wrapped.to_wrap is mock.sentinel.callable_ -@mock.patch("google.cloud.firestore_v1.async_transaction._sleep") -@pytest.mark.asyncio -async def test__commit_with_retry_success_first_attempt(_sleep): - from google.cloud.firestore_v1.async_transaction import _commit_with_retry - - # Create a minimal fake GAPIC with a dummy result. - firestore_api = AsyncMock() - - # Attach the fake GAPIC to a real client. - client = _make_client("summer") - client._firestore_api_internal = firestore_api - - # Call function and check result. - txn_id = b"cheeeeeez" - commit_response = await _commit_with_retry(client, mock.sentinel.write_pbs, txn_id) - assert commit_response is firestore_api.commit.return_value - - # Verify mocks used. - _sleep.assert_not_called() - firestore_api.commit.assert_called_once_with( - request={ - "database": client._database_string, - "writes": mock.sentinel.write_pbs, - "transaction": txn_id, - }, - metadata=client._rpc_metadata, - ) - - -@mock.patch( - "google.cloud.firestore_v1.async_transaction._sleep", side_effect=[2.0, 4.0] -) -@pytest.mark.asyncio -async def test__commit_with_retry_success_third_attempt(_sleep): - from google.api_core import exceptions - - from google.cloud.firestore_v1.async_transaction import _commit_with_retry - - # Create a minimal fake GAPIC with a dummy result. - firestore_api = AsyncMock() - - # Make sure the first two requests fail and the third succeeds. - firestore_api.commit.side_effect = [ - exceptions.ServiceUnavailable("Server sleepy."), - exceptions.ServiceUnavailable("Server groggy."), - mock.sentinel.commit_response, - ] - - # Attach the fake GAPIC to a real client. - client = _make_client("outside") - client._firestore_api_internal = firestore_api - - # Call function and check result. - txn_id = b"the-world\x00" - commit_response = await _commit_with_retry(client, mock.sentinel.write_pbs, txn_id) - assert commit_response is mock.sentinel.commit_response - - # Verify mocks used. - # Ensure _sleep is called after commit failures, with intervals of 1 and 2 seconds - assert _sleep.call_count == 2 - _sleep.assert_any_call(1.0) - _sleep.assert_any_call(2.0) - # commit() called same way 3 times. - commit_call = mock.call( - request={ - "database": client._database_string, - "writes": mock.sentinel.write_pbs, - "transaction": txn_id, - }, - metadata=client._rpc_metadata, - ) - assert firestore_api.commit.mock_calls == [commit_call, commit_call, commit_call] - - -@mock.patch("google.cloud.firestore_v1.async_transaction._sleep") -@pytest.mark.asyncio -async def test__commit_with_retry_failure_first_attempt(_sleep): - from google.api_core import exceptions - - from google.cloud.firestore_v1.async_transaction import _commit_with_retry - - # Create a minimal fake GAPIC with a dummy result. - firestore_api = AsyncMock() - - # Make sure the first request fails with an un-retryable error. - exc = exceptions.ResourceExhausted("We ran out of fries.") - firestore_api.commit.side_effect = exc - - # Attach the fake GAPIC to a real client. - client = _make_client("peanut-butter") - client._firestore_api_internal = firestore_api - - # Call function and check result. - txn_id = b"\x08\x06\x07\x05\x03\x00\x09-jenny" - with pytest.raises(exceptions.ResourceExhausted) as exc_info: - await _commit_with_retry(client, mock.sentinel.write_pbs, txn_id) - - assert exc_info.value is exc - - # Verify mocks used. - _sleep.assert_not_called() - firestore_api.commit.assert_called_once_with( - request={ - "database": client._database_string, - "writes": mock.sentinel.write_pbs, - "transaction": txn_id, - }, - metadata=client._rpc_metadata, - ) - - -@mock.patch("google.cloud.firestore_v1.async_transaction._sleep", return_value=2.0) -@pytest.mark.asyncio -async def test__commit_with_retry_failure_second_attempt(_sleep): - from google.api_core import exceptions - - from google.cloud.firestore_v1.async_transaction import _commit_with_retry - - # Create a minimal fake GAPIC with a dummy result. - firestore_api = AsyncMock() - - # Make sure the first request fails retry-able and second - # fails non-retryable. - exc1 = exceptions.ServiceUnavailable("Come back next time.") - exc2 = exceptions.InternalServerError("Server on fritz.") - firestore_api.commit.side_effect = [exc1, exc2] - - # Attach the fake GAPIC to a real client. - client = _make_client("peanut-butter") - client._firestore_api_internal = firestore_api - - # Call function and check result. - txn_id = b"the-journey-when-and-where-well-go" - with pytest.raises(exceptions.InternalServerError) as exc_info: - await _commit_with_retry(client, mock.sentinel.write_pbs, txn_id) - - assert exc_info.value is exc2 - - # Verify mocks used. - _sleep.assert_called_once_with(1.0) - # commit() called same way 2 times. - commit_call = mock.call( - request={ - "database": client._database_string, - "writes": mock.sentinel.write_pbs, - "transaction": txn_id, - }, - metadata=client._rpc_metadata, - ) - assert firestore_api.commit.mock_calls == [commit_call, commit_call] - - -@mock.patch("random.uniform", return_value=5.5) -@mock.patch("asyncio.sleep", return_value=None) -@pytest.mark.asyncio -async def test_sleep_defaults(sleep, uniform): - from google.cloud.firestore_v1.async_transaction import _sleep - - curr_sleep = 10.0 - assert uniform.return_value <= curr_sleep - - new_sleep = await _sleep(curr_sleep) - assert new_sleep == 2.0 * curr_sleep - - uniform.assert_called_once_with(0.0, curr_sleep) - sleep.assert_called_once_with(uniform.return_value) - - -@mock.patch("random.uniform", return_value=10.5) -@mock.patch("asyncio.sleep", return_value=None) -@pytest.mark.asyncio -async def test_sleep_explicit(sleep, uniform): - from google.cloud.firestore_v1.async_transaction import _sleep - - curr_sleep = 12.25 - assert uniform.return_value <= curr_sleep - - multiplier = 1.5 - new_sleep = await _sleep(curr_sleep, max_sleep=100.0, multiplier=multiplier) - assert new_sleep == multiplier * curr_sleep - - uniform.assert_called_once_with(0.0, curr_sleep) - sleep.assert_called_once_with(uniform.return_value) - - -@mock.patch("random.uniform", return_value=6.75) -@mock.patch("asyncio.sleep", return_value=None) -@pytest.mark.asyncio -async def test_sleep_exceeds_max(sleep, uniform): - from google.cloud.firestore_v1.async_transaction import _sleep - - curr_sleep = 20.0 - assert uniform.return_value <= curr_sleep - - max_sleep = 38.5 - new_sleep = await _sleep(curr_sleep, max_sleep=max_sleep, multiplier=2.0) - assert new_sleep == max_sleep - - uniform.assert_called_once_with(0.0, curr_sleep) - sleep.assert_called_once_with(uniform.return_value) - - def _make_credentials(): import google.auth.credentials diff --git a/tests/unit/v1/test_async_vector_query.py b/tests/unit/v1/test_async_vector_query.py index 8b2a95a26b..390190b534 100644 --- a/tests/unit/v1/test_async_vector_query.py +++ b/tests/unit/v1/test_async_vector_query.py @@ -18,7 +18,12 @@ from google.cloud.firestore_v1.base_vector_query import DistanceMeasure from google.cloud.firestore_v1.types.query import StructuredQuery from google.cloud.firestore_v1.vector import Vector -from tests.unit.v1._test_helpers import make_async_client, make_async_query, make_query +from tests.unit.v1._test_helpers import ( + make_async_client, + make_async_query, + make_async_vector_query, + make_query, +) from tests.unit.v1.test__helpers import AsyncIter, AsyncMock from tests.unit.v1.test_base_query import _make_query_response @@ -33,7 +38,15 @@ def _transaction(client): return transaction -def _expected_pb(parent, vector_field, vector, distance_type, limit): +def _expected_pb( + parent, + vector_field, + vector, + distance_type, + limit, + distance_result_field=None, + distance_threshold=None, +): query = make_query(parent) expected_pb = query._to_protobuf() expected_pb.find_nearest = StructuredQuery.FindNearest( @@ -41,10 +54,40 @@ def _expected_pb(parent, vector_field, vector, distance_type, limit): query_vector=encode_value(vector.to_map_value()), distance_measure=distance_type, limit=limit, + distance_result_field=distance_result_field, + distance_threshold=distance_threshold, ) return expected_pb +def test_async_vector_query_int_threshold_constructor_to_pb(): + client = make_async_client() + parent = client.collection("dee") + query = make_async_query(parent) + vector_query = make_async_vector_query(query) + + assert vector_query._nested_query == query + assert vector_query._client == query._parent._client + + vector_query.find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=DistanceMeasure.EUCLIDEAN, + limit=5, + distance_threshold=5, + ) + + expected_pb = query._to_protobuf() + expected_pb.find_nearest = StructuredQuery.FindNearest( + vector_field=StructuredQuery.FieldReference(field_path="embedding"), + query_vector=encode_value(Vector([1.0, 2.0, 3.0]).to_map_value()), + distance_measure=StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN, + limit=5, + distance_threshold=5.0, + ) + assert vector_query._to_protobuf() == expected_pb + + @pytest.mark.parametrize( "distance_measure, expected_distance", [ @@ -188,6 +231,154 @@ async def test_async_vector_query_with_filter(distance_measure, expected_distanc ) +@pytest.mark.parametrize( + "distance_measure, expected_distance", + [ + ( + DistanceMeasure.EUCLIDEAN, + StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN, + ), + (DistanceMeasure.COSINE, StructuredQuery.FindNearest.DistanceMeasure.COSINE), + ( + DistanceMeasure.DOT_PRODUCT, + StructuredQuery.FindNearest.DistanceMeasure.DOT_PRODUCT, + ), + ], +) +@pytest.mark.asyncio +async def test_async_vector_query_with_distance_result_field( + distance_measure, expected_distance +): + # Create a minimal fake GAPIC. + firestore_api = AsyncMock(spec=["run_query"]) + client = make_async_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("dee") + query = make_async_query(parent) + parent_path, expected_prefix = parent._parent_info() + + data = {"snooze": 10, "embedding": Vector([1.0, 2.0, 3.5]), "vector_distance": 0.5} + response_pb1 = _make_query_response( + name="{}/test_doc".format(expected_prefix), data=data + ) + response_pb2 = _make_query_response( + name="{}/test_doc".format(expected_prefix), data=data + ) + + kwargs = make_retry_timeout_kwargs(retry=None, timeout=None) + + # Execute the vector query and check the response. + firestore_api.run_query.return_value = AsyncIter([response_pb1, response_pb2]) + + vector_async__query = query.find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=distance_measure, + limit=5, + distance_result_field="vector_distance", + ) + + returned = await vector_async__query.get(transaction=_transaction(client), **kwargs) + assert isinstance(returned, list) + assert len(returned) == 2 + assert returned[0].to_dict() == data + + expected_pb = _expected_pb( + parent=parent, + vector_field="embedding", + vector=Vector([1.0, 2.0, 3.0]), + distance_type=expected_distance, + limit=5, + distance_result_field="vector_distance", + ) + + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": expected_pb, + "transaction": _TXN_ID, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + +@pytest.mark.parametrize( + "distance_measure, expected_distance", + [ + ( + DistanceMeasure.EUCLIDEAN, + StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN, + ), + (DistanceMeasure.COSINE, StructuredQuery.FindNearest.DistanceMeasure.COSINE), + ( + DistanceMeasure.DOT_PRODUCT, + StructuredQuery.FindNearest.DistanceMeasure.DOT_PRODUCT, + ), + ], +) +@pytest.mark.asyncio +async def test_async_vector_query_with_distance_threshold( + distance_measure, expected_distance +): + # Create a minimal fake GAPIC. + firestore_api = AsyncMock(spec=["run_query"]) + client = make_async_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("dee") + query = make_async_query(parent) + parent_path, expected_prefix = parent._parent_info() + + data = {"snooze": 10, "embedding": Vector([1.0, 2.0, 3.5])} + response_pb1 = _make_query_response( + name="{}/test_doc".format(expected_prefix), data=data + ) + response_pb2 = _make_query_response( + name="{}/test_doc".format(expected_prefix), data=data + ) + + kwargs = make_retry_timeout_kwargs(retry=None, timeout=None) + + # Execute the vector query and check the response. + firestore_api.run_query.return_value = AsyncIter([response_pb1, response_pb2]) + + vector_async__query = query.find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=distance_measure, + limit=5, + distance_threshold=125.5, + ) + + returned = await vector_async__query.get(transaction=_transaction(client), **kwargs) + assert isinstance(returned, list) + assert len(returned) == 2 + assert returned[0].to_dict() == data + + expected_pb = _expected_pb( + parent=parent, + vector_field="embedding", + vector=Vector([1.0, 2.0, 3.0]), + distance_type=expected_distance, + limit=5, + distance_threshold=125.5, + ) + + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": expected_pb, + "transaction": _TXN_ID, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + @pytest.mark.parametrize( "distance_measure, expected_distance", [ diff --git a/tests/unit/v1/test_transaction.py b/tests/unit/v1/test_transaction.py index fc56d2f9b0..d37be34ea0 100644 --- a/tests/unit/v1/test_transaction.py +++ b/tests/unit/v1/test_transaction.py @@ -810,212 +810,6 @@ def test_transactional_factory(): assert wrapped.to_wrap is mock.sentinel.callable_ -@mock.patch("google.cloud.firestore_v1.transaction._sleep") -@pytest.mark.parametrize("database", [None, "somedb"]) -def test__commit_with_retry_success_first_attempt(_sleep, database): - from google.cloud.firestore_v1.services.firestore import client as firestore_client - from google.cloud.firestore_v1.transaction import _commit_with_retry - - # Create a minimal fake GAPIC with a dummy result. - firestore_api = mock.create_autospec( - firestore_client.FirestoreClient, instance=True - ) - - # Attach the fake GAPIC to a real client. - client = _make_client("summer", database=database) - client._firestore_api_internal = firestore_api - - # Call function and check result. - txn_id = b"cheeeeeez" - commit_response = _commit_with_retry(client, mock.sentinel.write_pbs, txn_id) - assert commit_response is firestore_api.commit.return_value - - # Verify mocks used. - _sleep.assert_not_called() - firestore_api.commit.assert_called_once_with( - request={ - "database": client._database_string, - "writes": mock.sentinel.write_pbs, - "transaction": txn_id, - }, - metadata=client._rpc_metadata, - ) - - -@mock.patch("google.cloud.firestore_v1.transaction._sleep", side_effect=[2.0, 4.0]) -@pytest.mark.parametrize("database", [None, "somedb"]) -def test__commit_with_retry_success_third_attempt(_sleep, database): - from google.api_core import exceptions - - from google.cloud.firestore_v1.services.firestore import client as firestore_client - from google.cloud.firestore_v1.transaction import _commit_with_retry - - # Create a minimal fake GAPIC with a dummy result. - firestore_api = mock.create_autospec( - firestore_client.FirestoreClient, instance=True - ) - # Make sure the first two requests fail and the third succeeds. - firestore_api.commit.side_effect = [ - exceptions.ServiceUnavailable("Server sleepy."), - exceptions.ServiceUnavailable("Server groggy."), - mock.sentinel.commit_response, - ] - - # Attach the fake GAPIC to a real client. - client = _make_client("outside", database=database) - client._firestore_api_internal = firestore_api - - # Call function and check result. - txn_id = b"the-world\x00" - commit_response = _commit_with_retry(client, mock.sentinel.write_pbs, txn_id) - assert commit_response is mock.sentinel.commit_response - - # Verify mocks used. - # Ensure _sleep is called after commit failures, with intervals of 1 and 2 seconds - assert _sleep.call_count == 2 - _sleep.assert_any_call(1.0) - _sleep.assert_any_call(2.0) - # commit() called same way 3 times. - commit_call = mock.call( - request={ - "database": client._database_string, - "writes": mock.sentinel.write_pbs, - "transaction": txn_id, - }, - metadata=client._rpc_metadata, - ) - assert firestore_api.commit.mock_calls == [commit_call, commit_call, commit_call] - - -@mock.patch("google.cloud.firestore_v1.transaction._sleep") -@pytest.mark.parametrize("database", [None, "somedb"]) -def test__commit_with_retry_failure_first_attempt(_sleep, database): - from google.api_core import exceptions - - from google.cloud.firestore_v1.services.firestore import client as firestore_client - from google.cloud.firestore_v1.transaction import _commit_with_retry - - # Create a minimal fake GAPIC with a dummy result. - firestore_api = mock.create_autospec( - firestore_client.FirestoreClient, instance=True - ) - # Make sure the first request fails with an un-retryable error. - exc = exceptions.ResourceExhausted("We ran out of fries.") - firestore_api.commit.side_effect = exc - - # Attach the fake GAPIC to a real client. - client = _make_client("peanut-butter", database=database) - client._firestore_api_internal = firestore_api - - # Call function and check result. - txn_id = b"\x08\x06\x07\x05\x03\x00\x09-jenny" - with pytest.raises(exceptions.ResourceExhausted) as exc_info: - _commit_with_retry(client, mock.sentinel.write_pbs, txn_id) - - assert exc_info.value is exc - - # Verify mocks used. - _sleep.assert_not_called() - firestore_api.commit.assert_called_once_with( - request={ - "database": client._database_string, - "writes": mock.sentinel.write_pbs, - "transaction": txn_id, - }, - metadata=client._rpc_metadata, - ) - - -@mock.patch("google.cloud.firestore_v1.transaction._sleep", return_value=2.0) -@pytest.mark.parametrize("database", [None, "somedb"]) -def test__commit_with_retry_failure_second_attempt(_sleep, database): - from google.api_core import exceptions - - from google.cloud.firestore_v1.services.firestore import client as firestore_client - from google.cloud.firestore_v1.transaction import _commit_with_retry - - # Create a minimal fake GAPIC with a dummy result. - firestore_api = mock.create_autospec( - firestore_client.FirestoreClient, instance=True - ) - # Make sure the first request fails retry-able and second - # fails non-retryable. - exc1 = exceptions.ServiceUnavailable("Come back next time.") - exc2 = exceptions.InternalServerError("Server on fritz.") - firestore_api.commit.side_effect = [exc1, exc2] - - # Attach the fake GAPIC to a real client. - client = _make_client("peanut-butter", database=database) - client._firestore_api_internal = firestore_api - - # Call function and check result. - txn_id = b"the-journey-when-and-where-well-go" - with pytest.raises(exceptions.InternalServerError) as exc_info: - _commit_with_retry(client, mock.sentinel.write_pbs, txn_id) - - assert exc_info.value is exc2 - - # Verify mocks used. - _sleep.assert_called_once_with(1.0) - # commit() called same way 2 times. - commit_call = mock.call( - request={ - "database": client._database_string, - "writes": mock.sentinel.write_pbs, - "transaction": txn_id, - }, - metadata=client._rpc_metadata, - ) - assert firestore_api.commit.mock_calls == [commit_call, commit_call] - - -@mock.patch("random.uniform", return_value=5.5) -@mock.patch("time.sleep", return_value=None) -def test_defaults(sleep, uniform): - from google.cloud.firestore_v1.transaction import _sleep - - curr_sleep = 10.0 - assert uniform.return_value <= curr_sleep - - new_sleep = _sleep(curr_sleep) - assert new_sleep == 2.0 * curr_sleep - - uniform.assert_called_once_with(0.0, curr_sleep) - sleep.assert_called_once_with(uniform.return_value) - - -@mock.patch("random.uniform", return_value=10.5) -@mock.patch("time.sleep", return_value=None) -def test_explicit(sleep, uniform): - from google.cloud.firestore_v1.transaction import _sleep - - curr_sleep = 12.25 - assert uniform.return_value <= curr_sleep - - multiplier = 1.5 - new_sleep = _sleep(curr_sleep, max_sleep=100.0, multiplier=multiplier) - assert new_sleep == multiplier * curr_sleep - - uniform.assert_called_once_with(0.0, curr_sleep) - sleep.assert_called_once_with(uniform.return_value) - - -@mock.patch("random.uniform", return_value=6.75) -@mock.patch("time.sleep", return_value=None) -def test_exceeds_max(sleep, uniform): - from google.cloud.firestore_v1.transaction import _sleep - - curr_sleep = 20.0 - assert uniform.return_value <= curr_sleep - - max_sleep = 38.5 - new_sleep = _sleep(curr_sleep, max_sleep=max_sleep, multiplier=2.0) - assert new_sleep == max_sleep - - uniform.assert_called_once_with(0.0, curr_sleep) - sleep.assert_called_once_with(uniform.return_value) - - def _make_credentials(): import google.auth.credentials diff --git a/tests/unit/v1/test_vector_query.py b/tests/unit/v1/test_vector_query.py index beb0941413..a5b1d342bd 100644 --- a/tests/unit/v1/test_vector_query.py +++ b/tests/unit/v1/test_vector_query.py @@ -54,6 +54,8 @@ def test_vector_query_constructor_to_pb(distance_measure, expected_distance): query_vector=Vector([1.0, 2.0, 3.0]), distance_measure=distance_measure, limit=5, + distance_result_field="vector_distance", + distance_threshold=125.5, ) expected_pb = query._to_protobuf() @@ -62,6 +64,36 @@ def test_vector_query_constructor_to_pb(distance_measure, expected_distance): query_vector=encode_value(Vector([1.0, 2.0, 3.0]).to_map_value()), distance_measure=expected_distance, limit=5, + distance_result_field="vector_distance", + distance_threshold=125.5, + ) + assert vector_query._to_protobuf() == expected_pb + + +def test_vector_query_int_threshold_constructor_to_pb(): + client = make_client() + parent = client.collection("dee") + query = make_query(parent) + vector_query = make_vector_query(query) + + assert vector_query._nested_query == query + assert vector_query._client == query._parent._client + + vector_query.find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=DistanceMeasure.EUCLIDEAN, + limit=5, + distance_threshold=5, + ) + + expected_pb = query._to_protobuf() + expected_pb.find_nearest = StructuredQuery.FindNearest( + vector_field=StructuredQuery.FieldReference(field_path="embedding"), + query_vector=encode_value(Vector([1.0, 2.0, 3.0]).to_map_value()), + distance_measure=StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN, + limit=5, + distance_threshold=5.0, ) assert vector_query._to_protobuf() == expected_pb @@ -92,7 +124,15 @@ def _transaction(client): return transaction -def _expected_pb(parent, vector_field, vector, distance_type, limit): +def _expected_pb( + parent, + vector_field, + vector, + distance_type, + limit, + distance_result_field=None, + distance_threshold=None, +): query = make_query(parent) expected_pb = query._to_protobuf() expected_pb.find_nearest = StructuredQuery.FindNearest( @@ -100,6 +140,8 @@ def _expected_pb(parent, vector_field, vector, distance_type, limit): query_vector=encode_value(vector.to_map_value()), distance_measure=distance_type, limit=limit, + distance_result_field=distance_result_field, + distance_threshold=distance_threshold, ) return expected_pb @@ -168,6 +210,138 @@ def test_vector_query(distance_measure, expected_distance): ) +@pytest.mark.parametrize( + "distance_measure, expected_distance", + [ + ( + DistanceMeasure.EUCLIDEAN, + StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN, + ), + (DistanceMeasure.COSINE, StructuredQuery.FindNearest.DistanceMeasure.COSINE), + ( + DistanceMeasure.DOT_PRODUCT, + StructuredQuery.FindNearest.DistanceMeasure.DOT_PRODUCT, + ), + ], +) +def test_vector_query_with_distance_result_field(distance_measure, expected_distance): + # Create a minimal fake GAPIC. + firestore_api = mock.Mock(spec=["run_query"]) + client = make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("dee") + parent_path, expected_prefix = parent._parent_info() + + data = {"snooze": 10, "embedding": Vector([1.0, 2.0, 3.5]), "vector_distance": 0.5} + response_pb = _make_query_response( + name="{}/test_doc".format(expected_prefix), data=data + ) + + kwargs = make_retry_timeout_kwargs(retry=None, timeout=None) + + # Execute the vector query and check the response. + firestore_api.run_query.return_value = iter([response_pb]) + + vector_query = parent.find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=distance_measure, + limit=5, + distance_result_field="vector_distance", + ) + + returned = vector_query.get(transaction=_transaction(client), **kwargs) + assert isinstance(returned, list) + assert len(returned) == 1 + assert returned[0].to_dict() == data + + expected_pb = _expected_pb( + parent=parent, + vector_field="embedding", + vector=Vector([1.0, 2.0, 3.0]), + distance_type=expected_distance, + limit=5, + distance_result_field="vector_distance", + ) + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": expected_pb, + "transaction": _TXN_ID, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + +@pytest.mark.parametrize( + "distance_measure, expected_distance", + [ + ( + DistanceMeasure.EUCLIDEAN, + StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN, + ), + (DistanceMeasure.COSINE, StructuredQuery.FindNearest.DistanceMeasure.COSINE), + ( + DistanceMeasure.DOT_PRODUCT, + StructuredQuery.FindNearest.DistanceMeasure.DOT_PRODUCT, + ), + ], +) +def test_vector_query_with_distance_threshold(distance_measure, expected_distance): + # Create a minimal fake GAPIC. + firestore_api = mock.Mock(spec=["run_query"]) + client = make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("dee") + parent_path, expected_prefix = parent._parent_info() + + data = {"snooze": 10, "embedding": Vector([1.0, 2.0, 3.5])} + response_pb = _make_query_response( + name="{}/test_doc".format(expected_prefix), data=data + ) + + kwargs = make_retry_timeout_kwargs(retry=None, timeout=None) + + # Execute the vector query and check the response. + firestore_api.run_query.return_value = iter([response_pb]) + + vector_query = parent.find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=distance_measure, + limit=5, + distance_threshold=0.75, + ) + + returned = vector_query.get(transaction=_transaction(client), **kwargs) + assert isinstance(returned, list) + assert len(returned) == 1 + assert returned[0].to_dict() == data + + expected_pb = _expected_pb( + parent=parent, + vector_field="embedding", + vector=Vector([1.0, 2.0, 3.0]), + distance_type=expected_distance, + limit=5, + distance_threshold=0.75, + ) + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": expected_pb, + "transaction": _TXN_ID, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + @pytest.mark.parametrize( "distance_measure, expected_distance", [