diff --git a/.release-please-manifest.json b/.release-please-manifest.json index f122d158c5..882f663e6b 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "2.16.1" + ".": "2.17.0" } \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index e82a53b506..602974f2c1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,14 @@ [1]: https://pypi.org/project/google-cloud-firestore/#history +## [2.17.0](https://github.com/googleapis/python-firestore/compare/v2.16.1...v2.17.0) (2024-07-12) + + +### Features + +* Support async Vector Search ([#901](https://github.com/googleapis/python-firestore/issues/901)) ([2de1620](https://github.com/googleapis/python-firestore/commit/2de16209409c9d9ba41d3444400e6a39ee1b2936)) +* Use generator for stream results ([#926](https://github.com/googleapis/python-firestore/issues/926)) ([3e5df35](https://github.com/googleapis/python-firestore/commit/3e5df3565c9fc6f73f60207a46ebe1cd70c4df8d)) + ## [2.16.1](https://github.com/googleapis/python-firestore/compare/v2.16.0...v2.16.1) (2024-04-17) diff --git a/google/cloud/firestore/gapic_version.py b/google/cloud/firestore/gapic_version.py index 8edfaef714..8d4f4cfb61 100644 --- a/google/cloud/firestore/gapic_version.py +++ b/google/cloud/firestore/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "2.16.1" # {x-release-please-version} +__version__ = "2.17.0" # {x-release-please-version} diff --git a/google/cloud/firestore_admin_v1/gapic_version.py b/google/cloud/firestore_admin_v1/gapic_version.py index 8edfaef714..8d4f4cfb61 100644 --- a/google/cloud/firestore_admin_v1/gapic_version.py +++ b/google/cloud/firestore_admin_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "2.16.1" # {x-release-please-version} +__version__ = "2.17.0" # {x-release-please-version} diff --git a/google/cloud/firestore_bundle/gapic_version.py b/google/cloud/firestore_bundle/gapic_version.py index 8edfaef714..8d4f4cfb61 100644 --- a/google/cloud/firestore_bundle/gapic_version.py +++ b/google/cloud/firestore_bundle/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "2.16.1" # {x-release-please-version} +__version__ = "2.17.0" # {x-release-please-version} diff --git a/google/cloud/firestore_v1/__init__.py b/google/cloud/firestore_v1/__init__.py index 1d143556fe..1aff5ec740 100644 --- a/google/cloud/firestore_v1/__init__.py +++ b/google/cloud/firestore_v1/__init__.py @@ -23,42 +23,45 @@ __version__ = package_version.__version__ +from typing import List + from google.cloud.firestore_v1 import types -from google.cloud.firestore_v1._helpers import GeoPoint -from google.cloud.firestore_v1._helpers import ExistsOption -from google.cloud.firestore_v1._helpers import LastUpdateOption -from google.cloud.firestore_v1._helpers import ReadAfterWriteError -from google.cloud.firestore_v1._helpers import WriteOption -from google.cloud.firestore_v1.base_aggregation import CountAggregation -from google.cloud.firestore_v1.base_query import And -from google.cloud.firestore_v1.base_query import FieldFilter -from google.cloud.firestore_v1.base_query import Or +from google.cloud.firestore_v1._helpers import ( + ExistsOption, + GeoPoint, + LastUpdateOption, + ReadAfterWriteError, + WriteOption, +) from google.cloud.firestore_v1.async_batch import AsyncWriteBatch from google.cloud.firestore_v1.async_client import AsyncClient from google.cloud.firestore_v1.async_collection import AsyncCollectionReference from google.cloud.firestore_v1.async_document import AsyncDocumentReference from google.cloud.firestore_v1.async_query import AsyncQuery -from google.cloud.firestore_v1.async_transaction import async_transactional -from google.cloud.firestore_v1.async_transaction import AsyncTransaction +from google.cloud.firestore_v1.async_transaction import ( + AsyncTransaction, + async_transactional, +) +from google.cloud.firestore_v1.base_aggregation import CountAggregation from google.cloud.firestore_v1.base_document import DocumentSnapshot +from google.cloud.firestore_v1.base_query import And, FieldFilter, Or from google.cloud.firestore_v1.batch import WriteBatch from google.cloud.firestore_v1.client import Client from google.cloud.firestore_v1.collection import CollectionReference from google.cloud.firestore_v1.document import DocumentReference -from google.cloud.firestore_v1.query import CollectionGroup -from google.cloud.firestore_v1.query import Query -from google.cloud.firestore_v1.transaction import Transaction -from google.cloud.firestore_v1.transaction import transactional -from google.cloud.firestore_v1.transforms import ArrayRemove -from google.cloud.firestore_v1.transforms import ArrayUnion -from google.cloud.firestore_v1.transforms import DELETE_FIELD -from google.cloud.firestore_v1.transforms import Increment -from google.cloud.firestore_v1.transforms import Maximum -from google.cloud.firestore_v1.transforms import Minimum -from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP +from google.cloud.firestore_v1.query import CollectionGroup, Query +from google.cloud.firestore_v1.transaction import Transaction, transactional +from google.cloud.firestore_v1.transforms import ( + DELETE_FIELD, + SERVER_TIMESTAMP, + ArrayRemove, + ArrayUnion, + Increment, + Maximum, + Minimum, +) from google.cloud.firestore_v1.watch import Watch - # TODO(https://github.com/googleapis/python-firestore/issues/93): this is all on the generated surface. We require this to match # firestore.py. So comment out until needed on customer level for certain. # from .services.firestore import FirestoreClient @@ -102,8 +105,6 @@ # from .types.write import DocumentDelete # from .types.write import DocumentRemove from .types.write import DocumentTransform -from typing import List - # from .types.write import ExistenceFilter # from .types.write import Write diff --git a/google/cloud/firestore_v1/_helpers.py b/google/cloud/firestore_v1/_helpers.py index 932b3746b5..c829321df9 100644 --- a/google/cloud/firestore_v1/_helpers.py +++ b/google/cloud/firestore_v1/_helpers.py @@ -16,26 +16,6 @@ import datetime import json - -import google -from google.api_core.datetime_helpers import DatetimeWithNanoseconds -from google.api_core import gapic_v1 -from google.protobuf import struct_pb2 -from google.type import latlng_pb2 # type: ignore -import grpc # type: ignore - -from google.cloud import exceptions # type: ignore -from google.cloud._helpers import _datetime_to_pb_timestamp # type: ignore -from google.cloud.firestore_v1.vector import Vector -from google.cloud.firestore_v1.types.write import DocumentTransform -from google.cloud.firestore_v1 import transforms -from google.cloud.firestore_v1 import types -from google.cloud.firestore_v1.field_path import FieldPath -from google.cloud.firestore_v1.field_path import parse_field_path -from google.cloud.firestore_v1.types import common -from google.cloud.firestore_v1.types import document -from google.cloud.firestore_v1.types import write -from google.protobuf.timestamp_pb2 import Timestamp # type: ignore from typing import ( Any, Dict, @@ -48,6 +28,22 @@ Union, ) +import grpc # type: ignore +from google.api_core import gapic_v1 +from google.api_core.datetime_helpers import DatetimeWithNanoseconds +from google.cloud._helpers import _datetime_to_pb_timestamp # type: ignore +from google.protobuf import struct_pb2 +from google.protobuf.timestamp_pb2 import Timestamp # type: ignore +from google.type import latlng_pb2 # type: ignore + +import google +from google.cloud import exceptions # type: ignore +from google.cloud.firestore_v1 import transforms, types +from google.cloud.firestore_v1.field_path import FieldPath, parse_field_path +from google.cloud.firestore_v1.types import common, document, write +from google.cloud.firestore_v1.types.write import DocumentTransform +from google.cloud.firestore_v1.vector import Vector + _EmptyDict: transforms.Sentinel _GRPC_ERROR_MAPPING: dict diff --git a/google/cloud/firestore_v1/aggregation.py b/google/cloud/firestore_v1/aggregation.py index 609f82f75a..65106122ab 100644 --- a/google/cloud/firestore_v1/aggregation.py +++ b/google/cloud/firestore_v1/aggregation.py @@ -20,18 +20,22 @@ """ from __future__ import annotations -from google.api_core import exceptions -from google.api_core import gapic_v1 -from google.api_core import retry as retries +from typing import TYPE_CHECKING, Any, Generator, List, Optional, Union +from google.api_core import exceptions, gapic_v1 +from google.api_core import retry as retries from google.cloud.firestore_v1.base_aggregation import ( AggregationResult, BaseAggregationQuery, _query_response_to_result, ) +from google.cloud.firestore_v1.base_document import DocumentSnapshot +from google.cloud.firestore_v1.stream_generator import StreamGenerator -from typing import Generator, Union, List, Any +# Types needed only for Type Hints +if TYPE_CHECKING: + from google.cloud.firestore_v1 import transaction # pragma: NO COVER class AggregationQuery(BaseAggregationQuery): @@ -99,36 +103,34 @@ def _retry_query_after_exception(self, exc, retry, transaction): return False - def stream( + def _make_stream( self, - transaction=None, - retry: Union[ - retries.Retry, None, gapic_v1.method._MethodDefault - ] = gapic_v1.method.DEFAULT, - timeout: float | None = None, + transaction: Optional[transaction.Transaction] = None, + retry: Optional[retries.Retry] = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, ) -> Union[Generator[List[AggregationResult], Any, None]]: - """Runs the aggregation query. + """Internal method for stream(). Runs the aggregation query. - This sends a ``RunAggregationQuery`` RPC and then returns an iterator which - consumes each document returned in the stream of ``RunAggregationQueryResponse`` - messages. + This sends a ``RunAggregationQuery`` RPC and then returns a generator + which consumes each document returned in the stream of + ``RunAggregationQueryResponse`` messages. - If a ``transaction`` is used and it already has write operations - added, this method cannot be used (i.e. read-after-write is not - allowed). + If a ``transaction`` is used and it already has write operations added, + this method cannot be used (i.e. read-after-write is not allowed). Args: transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): An existing transaction that this query will run in. - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. Defaults to a system-specified policy. - timeout (float): The timeout for this request. Defaults to a - system-specified value. + retry (Optional[google.api_core.retry.Retry]): Designation of what + errors, if any, should be retried. Defaults to a + system-specified policy. + timeout (Optional[float]): The timeout for this request. Defaults + to a system-specified value. Yields: :class:`~google.cloud.firestore_v1.base_aggregation.AggregationResult`: - The result of aggregations of this query + The result of aggregations of this query. """ response_iterator = self._get_stream_iterator( @@ -154,3 +156,38 @@ def stream( break result = _query_response_to_result(response) yield result + + def stream( + self, + transaction: Optional["transaction.Transaction"] = None, + retry: Optional[retries.Retry] = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + ) -> "StreamGenerator[DocumentSnapshot]": + """Runs the aggregation query. + + This sends a ``RunAggregationQuery`` RPC and then returns a generator + which consumes each document returned in the stream of + ``RunAggregationQueryResponse`` messages. + + If a ``transaction`` is used and it already has write operations added, + this method cannot be used (i.e. read-after-write is not allowed). + + Args: + transaction + (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + An existing transaction that this query will run in. + retry (Optional[google.api_core.retry.Retry]): Designation of what + errors, if any, should be retried. Defaults to a + system-specified policy. + timeout (Optinal[float]): The timeout for this request. Defaults + to a system-specified value. + + Returns: + `StreamGenerator[DocumentSnapshot]`: A generator of the query results. + """ + inner_generator = self._make_stream( + transaction=transaction, + retry=retry, + timeout=timeout, + ) + return StreamGenerator(inner_generator) diff --git a/google/cloud/firestore_v1/async_aggregation.py b/google/cloud/firestore_v1/async_aggregation.py index c39b50c5e4..1c75f0cfd8 100644 --- a/google/cloud/firestore_v1/async_aggregation.py +++ b/google/cloud/firestore_v1/async_aggregation.py @@ -20,18 +20,22 @@ """ from __future__ import annotations +from typing import TYPE_CHECKING, AsyncGenerator, List, Optional, Union + from google.api_core import gapic_v1 from google.api_core import retry_async as retries -from typing import List, Union, AsyncGenerator - - +from google.cloud.firestore_v1 import transaction +from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator from google.cloud.firestore_v1.base_aggregation import ( AggregationResult, - _query_response_to_result, BaseAggregationQuery, + _query_response_to_result, ) +if TYPE_CHECKING: # pragma: NO COVER + from google.cloud.firestore_v1.base_document import DocumentSnapshot + class AsyncAggregationQuery(BaseAggregationQuery): """Represents an aggregation query to the Firestore API.""" @@ -76,17 +80,15 @@ async def get( result = [aggregation async for aggregation in stream_result] return result # type: ignore - async def stream( + async def _make_stream( self, - transaction=None, - retry: Union[ - retries.AsyncRetry, None, gapic_v1.method._MethodDefault - ] = gapic_v1.method.DEFAULT, - timeout: float | None = None, + transaction: Optional[transaction.Transaction] = None, + retry: Optional[retries.AsyncRetry] = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, ) -> Union[AsyncGenerator[List[AggregationResult], None]]: - """Runs the aggregation query. + """Internal method for stream(). Runs the aggregation query. - This sends a ``RunAggregationQuery`` RPC and then returns an iterator which + This sends a ``RunAggregationQuery`` RPC and then returns a generator which consumes each document returned in the stream of ``RunAggregationQueryResponse`` messages. @@ -95,13 +97,14 @@ async def stream( allowed). Args: - transaction - (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): - An existing transaction that this query will run in. - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. Defaults to a system-specified policy. - timeout (float): The timeout for this request. Defaults to a - system-specified value. + transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.\ + Transaction`]): + An existing transaction that the query will run in. + retry (Optional[google.api_core.retry.Retry]): Designation of what + errors, if any, should be retried. Defaults to a + system-specified policy. + timeout (Optional[float]): The timeout for this request. Defaults + to a system-specified value. Yields: :class:`~google.cloud.firestore_v1.base_aggregation.AggregationResult`: @@ -122,3 +125,40 @@ async def stream( async for response in response_iterator: result = _query_response_to_result(response) yield result + + def stream( + self, + transaction: Optional[transaction.Transaction] = None, + retry: Optional[retries.AsyncRetry] = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + ) -> "AsyncStreamGenerator[DocumentSnapshot]": + """Runs the aggregation query. + + This sends a ``RunAggregationQuery`` RPC and then returns a generator + which consumes each document returned in the stream of + ``RunAggregationQueryResponse`` messages. + + If a ``transaction`` is used and it already has write operations added, + this method cannot be used (i.e. read-after-write is not allowed). + + Args: + transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.\ + Transaction`]): + An existing transaction that the query will run in. + retry (Optional[google.api_core.retry.Retry]): Designation of what + errors, if any, should be retried. Defaults to a + system-specified policy. + timeout (Optional[float]): The timeout for this request. Defaults + to a system-specified value. + + Returns: + `AsyncStreamGenerator[DocumentSnapshot]`: + A generator of the query results. + """ + + inner_generator = self._make_stream( + transaction=transaction, + retry=retry, + timeout=timeout, + ) + return AsyncStreamGenerator(inner_generator) diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index 20541c3770..f14ec6573b 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -24,24 +24,21 @@ :class:`~google.cloud.firestore_v1.async_document.AsyncDocumentReference` """ +from typing import TYPE_CHECKING, Any, AsyncGenerator, Iterable, List, Optional, Union + from google.api_core import gapic_v1 from google.api_core import retry_async as retries -from google.cloud.firestore_v1.base_client import ( - BaseClient, - _CLIENT_INFO, - _parse_batch_get, # type: ignore - _path_helper, -) - -from google.cloud.firestore_v1.async_query import AsyncCollectionGroup from google.cloud.firestore_v1.async_batch import AsyncWriteBatch from google.cloud.firestore_v1.async_collection import AsyncCollectionReference from google.cloud.firestore_v1.async_document import ( AsyncDocumentReference, DocumentSnapshot, ) +from google.cloud.firestore_v1.async_query import AsyncCollectionGroup from google.cloud.firestore_v1.async_transaction import AsyncTransaction +from google.cloud.firestore_v1.base_client import _parse_batch_get # type: ignore +from google.cloud.firestore_v1.base_client import _CLIENT_INFO, BaseClient, _path_helper from google.cloud.firestore_v1.field_path import FieldPath from google.cloud.firestore_v1.services.firestore import ( async_client as firestore_client, @@ -49,7 +46,6 @@ from google.cloud.firestore_v1.services.firestore.transports import ( grpc_asyncio as firestore_grpc_transport, ) -from typing import Any, AsyncGenerator, Iterable, List, Optional, Union, TYPE_CHECKING if TYPE_CHECKING: from google.cloud.firestore_v1.bulk_writer import BulkWriter # pragma: NO COVER diff --git a/google/cloud/firestore_v1/async_collection.py b/google/cloud/firestore_v1/async_collection.py index 093117d40b..7032b1bdcb 100644 --- a/google/cloud/firestore_v1/async_collection.py +++ b/google/cloud/firestore_v1/async_collection.py @@ -14,22 +14,26 @@ """Classes for representing collections for the Google Cloud Firestore API.""" +from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional, Tuple + from google.api_core import gapic_v1 from google.api_core import retry_async as retries +from google.cloud.firestore_v1 import ( + async_aggregation, + async_document, + async_query, + transaction, +) from google.cloud.firestore_v1.base_collection import ( BaseCollectionReference, _item_to_document_ref, ) -from google.cloud.firestore_v1 import async_query, async_document, async_aggregation - from google.cloud.firestore_v1.document import DocumentReference -from typing import AsyncIterator -from typing import Any, AsyncGenerator, Tuple - -# Types needed only for Type Hints -from google.cloud.firestore_v1.transaction import Transaction +if TYPE_CHECKING: # pragma: NO COVER + from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator + from google.cloud.firestore_v1.base_document import DocumentSnapshot class AsyncCollectionReference(BaseCollectionReference[async_query.AsyncQuery]): @@ -176,9 +180,9 @@ async def list_documents( async def get( self, - transaction: Transaction = None, - retry: retries.AsyncRetry = gapic_v1.method.DEFAULT, - timeout: float = None, + transaction: Optional[transaction.Transaction] = None, + retry: Optional[retries.AsyncRetry] = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, ) -> list: """Read the documents in this collection. @@ -189,14 +193,14 @@ async def get( transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): An existing transaction that this query will run in. - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. Defaults to a system-specified policy. - timeout (float): The timeout for this request. Defaults to a - system-specified value. + retry (Optional[google.api_core.retry.Retry]): Designation of what + errors, if any, should be retried. Defaults to a + system-specified policy. + timeout (Otional[float]): The timeout for this request. Defaults + to a system-specified value. - If a ``transaction`` is used and it already has write operations - added, this method cannot be used (i.e. read-after-write is not - allowed). + If a ``transaction`` is used and it already has write operations added, + this method cannot be used (i.e. read-after-write is not allowed). Returns: list: The documents in this collection that match the query. @@ -205,15 +209,15 @@ async def get( return await query.get(transaction=transaction, **kwargs) - async def stream( + def stream( self, - transaction: Transaction = None, - retry: retries.AsyncRetry = gapic_v1.method.DEFAULT, - timeout: float = None, - ) -> AsyncIterator[async_document.DocumentSnapshot]: + transaction: Optional[transaction.Transaction] = None, + retry: Optional[retries.AsyncRetry] = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + ) -> "AsyncStreamGenerator[DocumentSnapshot]": """Read the documents in this collection. - This sends a ``RunQuery`` RPC and then returns an iterator which + This sends a ``RunQuery`` RPC and then returns a generator which consumes each document returned in the stream of ``RunQueryResponse`` messages. @@ -232,16 +236,16 @@ async def stream( transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.\ Transaction`]): An existing transaction that the query will run in. - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. Defaults to a system-specified policy. - timeout (float): The timeout for this request. Defaults to a - system-specified value. + retry (Optional[google.api_core.retry.Retry]): Designation of what + errors, if any, should be retried. Defaults to a + system-specified policy. + timeout (Optional[float]): The timeout for this request. Defaults + to a system-specified value. - Yields: - :class:`~google.cloud.firestore_v1.document.DocumentSnapshot`: - The next document that fulfills the query. + Returns: + `AsyncStreamGenerator[DocumentSnapshot]`: A generator of the query + results. """ query, kwargs = self._prep_get_or_stream(retry, timeout) - async for d in query.stream(transaction=transaction, **kwargs): - yield d # pytype: disable=name-error + return query.stream(transaction=transaction, **kwargs) diff --git a/google/cloud/firestore_v1/async_document.py b/google/cloud/firestore_v1/async_document.py index 75250d0b4c..a697e86302 100644 --- a/google/cloud/firestore_v1/async_document.py +++ b/google/cloud/firestore_v1/async_document.py @@ -15,21 +15,20 @@ """Classes for representing documents for the Google Cloud Firestore API.""" import datetime import logging +from typing import AsyncGenerator, Iterable from google.api_core import gapic_v1 from google.api_core import retry_async as retries from google.cloud._helpers import _datetime_to_pb_timestamp # type: ignore +from google.protobuf.timestamp_pb2 import Timestamp +from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.base_document import ( BaseDocumentReference, DocumentSnapshot, _first_write_result, ) -from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.types import write -from google.protobuf.timestamp_pb2 import Timestamp -from typing import AsyncGenerator, Iterable - logger = logging.getLogger(__name__) diff --git a/google/cloud/firestore_v1/async_query.py b/google/cloud/firestore_v1/async_query.py index 8ee4012904..15f81be247 100644 --- a/google/cloud/firestore_v1/async_query.py +++ b/google/cloud/firestore_v1/async_query.py @@ -20,28 +20,31 @@ """ from __future__ import annotations +from typing import TYPE_CHECKING, AsyncGenerator, List, Optional, Type + from google.api_core import gapic_v1 from google.api_core import retry_async as retries from google.cloud import firestore_v1 +from google.cloud.firestore_v1 import async_document, transaction +from google.cloud.firestore_v1.async_aggregation import AsyncAggregationQuery +from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator +from google.cloud.firestore_v1.async_vector_query import AsyncVectorQuery from google.cloud.firestore_v1.base_query import ( BaseCollectionGroup, BaseQuery, QueryPartition, - _query_response_to_snapshot, _collection_group_query_response_to_snapshot, _enum_from_direction, + _query_response_to_snapshot, ) -from google.cloud.firestore_v1 import async_document -from google.cloud.firestore_v1.async_aggregation import AsyncAggregationQuery -from google.cloud.firestore_v1.base_document import DocumentSnapshot -from typing import AsyncGenerator, List, Optional, Type, TYPE_CHECKING - if TYPE_CHECKING: # pragma: NO COVER # Types needed only for Type Hints - from google.cloud.firestore_v1.transaction import Transaction + from google.cloud.firestore_v1.base_document import DocumentSnapshot + from google.cloud.firestore_v1.base_vector_query import DistanceMeasure from google.cloud.firestore_v1.field_path import FieldPath + from google.cloud.firestore_v1.vector import Vector class AsyncQuery(BaseQuery): @@ -171,9 +174,9 @@ async def _chunkify( async def get( self, - transaction: Transaction = None, - retry: retries.AsyncRetry = gapic_v1.method.DEFAULT, - timeout: float = None, + transaction: Optional[transaction.Transaction] = None, + retry: Optional[retries.AsyncRetry] = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, ) -> list: """Read the documents in the collection that match this query. @@ -184,10 +187,11 @@ async def get( transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): An existing transaction that this query will run in. - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. Defaults to a system-specified policy. - timeout (float): The timeout for this request. Defaults to a - system-specified value. + retry (Optional[google.api_core.retry.Retry]): Designation of what + errors, if any, should be retried. Defaults to a + system-specified policy. + timeout (Otional[float]): The timeout for this request. Defaults + to a system-specified value. If a ``transaction`` is used and it already has write operations added, this method cannot be used (i.e. read-after-write is not @@ -209,14 +213,45 @@ async def get( else self.ASCENDING ) self._limit_to_last = False - - result = self.stream(transaction=transaction, retry=retry, timeout=timeout) + result = self.stream( + transaction=transaction, + retry=retry, + timeout=timeout, + ) result = [d async for d in result] if is_limited_to_last: result = list(reversed(result)) return result + def find_nearest( + self, + vector_field: str, + query_vector: Vector, + limit: int, + distance_measure: DistanceMeasure, + ) -> AsyncVectorQuery: + """ + Finds the closest vector embeddings to the given query vector. + + Args: + vector_field(str): An indexed vector field to search upon. Only documents which contain + vectors whose dimensionality match the query_vector can be returned. + query_vector(Vector): The query vector that we are searching on. Must be a vector of no more + than 2048 dimensions. + limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000. + distance_measure(:class:`DistanceMeasure`): The Distance Measure to use. + + Returns: + :class`~firestore_v1.vector_query.VectorQuery`: the vector query. + """ + return AsyncVectorQuery(self).find_nearest( + vector_field=vector_field, + query_vector=query_vector, + limit=limit, + distance_measure=distance_measure, + ) + def count( self, alias: str | None = None ) -> Type["firestore_v1.async_aggregation.AsyncAggregationQuery"]: @@ -264,15 +299,16 @@ def avg( """ return AsyncAggregationQuery(self).avg(field_ref, alias=alias) - async def stream( + async def _make_stream( self, - transaction=None, - retry: retries.AsyncRetry = gapic_v1.method.DEFAULT, - timeout: float = None, + transaction: Optional[transaction.Transaction] = None, + retry: Optional[retries.AsyncRetry] = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, ) -> AsyncGenerator[async_document.DocumentSnapshot, None]: - """Read the documents in the collection that match this query. + """Internal method for stream(). Read the documents in the collection + that match this query. - This sends a ``RunQuery`` RPC and then returns an iterator which + This sends a ``RunQuery`` RPC and then returns a generator which consumes each document returned in the stream of ``RunQueryResponse`` messages. @@ -288,13 +324,14 @@ async def stream( allowed). Args: - transaction - (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): - An existing transaction that this query will run in. - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. Defaults to a system-specified policy. - timeout (float): The timeout for this request. Defaults to a - system-specified value. + transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.\ + Transaction`]): + An existing transaction that the query will run in. + retry (Optional[google.api_core.retry.Retry]): Designation of what + errors, if any, should be retried. Defaults to a + system-specified policy. + timeout (Optional[float]): The timeout for this request. Defaults + to a system-specified value. Yields: :class:`~google.cloud.firestore_v1.async_document.DocumentSnapshot`: @@ -324,6 +361,50 @@ async def stream( if snapshot is not None: yield snapshot + def stream( + self, + transaction: Optional[transaction.Transaction] = None, + retry: Optional[retries.AsyncRetry] = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + ) -> "AsyncStreamGenerator[DocumentSnapshot]": + """Read the documents in the collection that match this query. + + This sends a ``RunQuery`` RPC and then returns a generator which + consumes each document returned in the stream of ``RunQueryResponse`` + messages. + + .. note:: + + The underlying stream of responses will time out after + the ``max_rpc_timeout_millis`` value set in the GAPIC + client configuration for the ``RunQuery`` API. Snapshots + not consumed from the iterator before that point will be lost. + + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). + + Args: + transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.\ + Transaction`]): + An existing transaction that the query will run in. + retry (Optional[google.api_core.retry.Retry]): Designation of what + errors, if any, should be retried. Defaults to a + system-specified policy. + timeout (Optional[float]): The timeout for this request. Defaults + to a system-specified value. + + Returns: + `AsyncStreamGenerator[DocumentSnapshot]`: A generator of the query + results. + """ + inner_generator = self._make_stream( + transaction=transaction, + retry=retry, + timeout=timeout, + ) + return AsyncStreamGenerator(inner_generator) + @staticmethod def _get_collection_reference_class() -> ( Type["firestore_v1.async_collection.AsyncCollectionReference"] diff --git a/google/cloud/firestore_v1/async_stream_generator.py b/google/cloud/firestore_v1/async_stream_generator.py new file mode 100644 index 0000000000..ca0481c0d1 --- /dev/null +++ b/google/cloud/firestore_v1/async_stream_generator.py @@ -0,0 +1,41 @@ +# Copyright 2024 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Classes for iterating over stream results async for the Google Cloud +Firestore API. +""" + +from collections import abc + + +class AsyncStreamGenerator(abc.AsyncGenerator): + """Asynchronous generator for the streamed results.""" + + def __init__(self, response_generator): + self._generator = response_generator + + def __aiter__(self): + return self._generator + + def __anext__(self): + return self._generator.__anext__() + + def asend(self, value=None): + return self._generator.asend(value) + + def athrow(self, exp=None): + return self._generator.athrow(exp) + + def aclose(self): + return self._generator.aclose() diff --git a/google/cloud/firestore_v1/async_transaction.py b/google/cloud/firestore_v1/async_transaction.py index 18a20b8e12..6b01fffd6c 100644 --- a/google/cloud/firestore_v1/async_transaction.py +++ b/google/cloud/firestore_v1/async_transaction.py @@ -17,34 +17,31 @@ import asyncio import random +from typing import Any, AsyncGenerator, Callable, Coroutine -from google.api_core import gapic_v1 +from google.api_core import exceptions, gapic_v1 from google.api_core import retry_async as retries +from google.cloud.firestore_v1 import _helpers, async_batch, types +from google.cloud.firestore_v1.async_document import ( + AsyncDocumentReference, + DocumentSnapshot, +) +from google.cloud.firestore_v1.async_query import AsyncQuery from google.cloud.firestore_v1.base_transaction import ( - _BaseTransactional, - BaseTransaction, - MAX_ATTEMPTS, _CANT_BEGIN, - _CANT_ROLLBACK, _CANT_COMMIT, - _WRITE_READ_ONLY, + _CANT_ROLLBACK, + _EXCEED_ATTEMPTS_TEMPLATE, _INITIAL_SLEEP, _MAX_SLEEP, _MULTIPLIER, - _EXCEED_ATTEMPTS_TEMPLATE, + _WRITE_READ_ONLY, + MAX_ATTEMPTS, + BaseTransaction, + _BaseTransactional, ) -from google.api_core import exceptions -from google.cloud.firestore_v1 import async_batch -from google.cloud.firestore_v1 import _helpers -from google.cloud.firestore_v1 import types - -from google.cloud.firestore_v1.async_document import AsyncDocumentReference -from google.cloud.firestore_v1.async_document import DocumentSnapshot -from google.cloud.firestore_v1.async_query import AsyncQuery -from typing import Any, AsyncGenerator, Callable, Coroutine - # Types needed only for Type Hints from google.cloud.firestore_v1.client import Client diff --git a/google/cloud/firestore_v1/async_vector_query.py b/google/cloud/firestore_v1/async_vector_query.py new file mode 100644 index 0000000000..a77bc4343f --- /dev/null +++ b/google/cloud/firestore_v1/async_vector_query.py @@ -0,0 +1,129 @@ +# Copyright 2024 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import AsyncGenerator, List, Optional, TypeVar, Union + +from google.api_core import gapic_v1 +from google.api_core import retry_async as retries + +from google.cloud.firestore_v1 import async_document +from google.cloud.firestore_v1.base_document import DocumentSnapshot +from google.cloud.firestore_v1.base_query import ( + BaseQuery, + _collection_group_query_response_to_snapshot, + _query_response_to_snapshot, +) +from google.cloud.firestore_v1.base_vector_query import BaseVectorQuery + +TAsyncVectorQuery = TypeVar("TAsyncVectorQuery", bound="AsyncVectorQuery") + + +class AsyncVectorQuery(BaseVectorQuery): + """Represents an async vector query to the Firestore API.""" + + def __init__( + self, + nested_query: Union[BaseQuery, TAsyncVectorQuery], + ) -> None: + """Presents the vector query. + Args: + nested_query (BaseQuery | VectorQuery): the base query to apply as the prefilter. + """ + super(AsyncVectorQuery, self).__init__(nested_query) + + async def get( + self, + transaction=None, + retry: retries.AsyncRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + ) -> List[DocumentSnapshot]: + """Runs the vector query. + + This sends a ``RunQuery`` RPC and returns a list of document messages. + + Args: + transaction + (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + An existing transaction that this query will run in. + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. + + Returns: + list: The vector query results. + """ + stream_result = self.stream( + transaction=transaction, retry=retry, timeout=timeout + ) + result = [snapshot async for snapshot in stream_result] + return result # type: ignore + + async def stream( + self, + transaction=None, + retry: retries.AsyncRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + ) -> AsyncGenerator[async_document.DocumentSnapshot, None]: + """Reads the documents in the collection that match this query. + + This sends a ``RunQuery`` RPC and then returns an iterator which + consumes each document returned in the stream of ``RunQueryResponse`` + messages. + + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). + + Args: + transaction + (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + An existing transaction that this query will run in. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. Defaults to a system-specified policy. + timeout (float): The timeout for this request. Defaults to a + system-specified value. + + Yields: + :class:`~google.cloud.firestore_v1.document.DocumentSnapshot`: + The next document that fulfills the query. + """ + request, expected_prefix, kwargs = self._prep_stream( + transaction, + retry, + timeout, + ) + + response_iterator = await self._client._firestore_api.run_query( + request=request, + metadata=self._client._rpc_metadata, + **kwargs, + ) + + async for response in response_iterator: + if self._nested_query._all_descendants: + snapshot = _collection_group_query_response_to_snapshot( + response, self._nested_query._parent + ) + else: + snapshot = _query_response_to_snapshot( + response, self._nested_query._parent, expected_prefix + ) + if snapshot is not None: + yield snapshot diff --git a/google/cloud/firestore_v1/base_aggregation.py b/google/cloud/firestore_v1/base_aggregation.py index d6097c136b..f922663791 100644 --- a/google/cloud/firestore_v1/base_aggregation.py +++ b/google/cloud/firestore_v1/base_aggregation.py @@ -23,20 +23,32 @@ from __future__ import annotations import abc - - from abc import ABC - -from typing import List, Coroutine, Union, Tuple, Generator, Any, AsyncGenerator +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + Coroutine, + Generator, + List, + Optional, + Tuple, + Union, +) from google.api_core import gapic_v1 from google.api_core import retry as retries - -from google.cloud.firestore_v1.field_path import FieldPath -from google.cloud.firestore_v1.types import RunAggregationQueryResponse -from google.cloud.firestore_v1.types import StructuredAggregationQuery from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1.field_path import FieldPath +from google.cloud.firestore_v1.types import ( + RunAggregationQueryResponse, + StructuredAggregationQuery, +) + +# Types needed only for Type Hints +if TYPE_CHECKING: + from google.cloud.firestore_v1 import transaction # pragma: NO COVER class AggregationResult(object): @@ -243,32 +255,27 @@ def get( @abc.abstractmethod def stream( self, - transaction=None, - retry: Union[ - retries.Retry, None, gapic_v1.method._MethodDefault - ] = gapic_v1.method.DEFAULT, - timeout: float | None = None, + transaction: Optional[transaction.Transaction] = None, + retry: Optional[retries.Retry] = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, ) -> ( Generator[List[AggregationResult], Any, None] | AsyncGenerator[List[AggregationResult], None] ): """Runs the aggregation query. - This sends a``RunAggregationQuery`` RPC and returns an iterator in the stream of ``RunAggregationQueryResponse`` messages. + This sends a``RunAggregationQuery`` RPC and returns a generator in the stream of ``RunAggregationQueryResponse`` messages. Args: transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): An existing transaction that this query will run in. - If a ``transaction`` is used and it already has write operations - added, this method cannot be used (i.e. read-after-write is not - allowed). - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. Defaults to a system-specified policy. - timeout (float): The timeout for this request. Defaults to a - system-specified value. + retry (Optional[google.api_core.retry.Retry]): Designation of what + errors, if any, should be retried. Defaults to a + system-specified policy. + timeout (Optinal[float]): The timeout for this request. Defaults + to a system-specified value. Returns: - list: The aggregation query results - + A generator of the query results. """ diff --git a/google/cloud/firestore_v1/base_batch.py b/google/cloud/firestore_v1/base_batch.py index ca3a66c897..4b08c0d304 100644 --- a/google/cloud/firestore_v1/base_batch.py +++ b/google/cloud/firestore_v1/base_batch.py @@ -19,6 +19,7 @@ # Types needed only for Type Hints from google.api_core import retry as retries + from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.base_document import BaseDocumentReference diff --git a/google/cloud/firestore_v1/base_client.py b/google/cloud/firestore_v1/base_client.py index 585fc7e564..1886cd7c8a 100644 --- a/google/cloud/firestore_v1/base_client.py +++ b/google/cloud/firestore_v1/base_client.py @@ -25,22 +25,6 @@ """ import os -import grpc # type: ignore - -from google.auth.credentials import AnonymousCredentials -import google.api_core.client_options -import google.api_core.path_template -from google.api_core import retry as retries -from google.api_core.gapic_v1 import client_info -from google.cloud.client import ClientWithProject # type: ignore - -from google.cloud.firestore_v1 import _helpers -from google.cloud.firestore_v1 import __version__ -from google.cloud.firestore_v1 import types -from google.cloud.firestore_v1.base_document import DocumentSnapshot - -from google.cloud.firestore_v1.field_path import render_field_path -from google.cloud.firestore_v1.bulk_writer import BulkWriter, BulkWriterOptions from typing import ( Any, AsyncGenerator, @@ -52,13 +36,27 @@ Union, ) +import google.api_core.client_options +import google.api_core.path_template +import grpc # type: ignore +from google.api_core import retry as retries +from google.api_core.gapic_v1 import client_info +from google.auth.credentials import AnonymousCredentials +from google.cloud.client import ClientWithProject # type: ignore + +from google.cloud.firestore_v1 import __version__, _helpers, types +from google.cloud.firestore_v1.base_batch import BaseWriteBatch + # Types needed only for Type Hints from google.cloud.firestore_v1.base_collection import BaseCollectionReference -from google.cloud.firestore_v1.base_document import BaseDocumentReference -from google.cloud.firestore_v1.base_transaction import BaseTransaction -from google.cloud.firestore_v1.base_batch import BaseWriteBatch +from google.cloud.firestore_v1.base_document import ( + BaseDocumentReference, + DocumentSnapshot, +) from google.cloud.firestore_v1.base_query import BaseQuery - +from google.cloud.firestore_v1.base_transaction import BaseTransaction +from google.cloud.firestore_v1.bulk_writer import BulkWriter, BulkWriterOptions +from google.cloud.firestore_v1.field_path import render_field_path DEFAULT_DATABASE = "(default)" """str: The default database used in a :class:`~google.cloud.firestore_v1.client.Client`.""" diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index 98f690e6d9..e2065dc2f8 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -14,42 +14,40 @@ """Classes for representing collections for the Google Cloud Firestore API.""" from __future__ import annotations -import random - -from google.api_core import retry as retries - -from google.cloud.firestore_v1 import _helpers -from google.cloud.firestore_v1.base_vector_query import DistanceMeasure -from google.cloud.firestore_v1.document import DocumentReference -from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery -from google.cloud.firestore_v1.base_vector_query import BaseVectorQuery -from google.cloud.firestore_v1.base_query import QueryType -from google.cloud.firestore_v1.vector import Vector - +import random from typing import ( - Optional, + TYPE_CHECKING, Any, AsyncGenerator, + AsyncIterator, Coroutine, Generator, Generic, - AsyncIterator, - Iterator, Iterable, + Iterator, NoReturn, + Optional, Tuple, Union, - TYPE_CHECKING, ) +from google.api_core import retry as retries + +from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery +from google.cloud.firestore_v1.base_query import QueryType +from google.cloud.firestore_v1.base_vector_query import BaseVectorQuery, DistanceMeasure +from google.cloud.firestore_v1.document import DocumentReference +from google.cloud.firestore_v1.vector import Vector if TYPE_CHECKING: # pragma: NO COVER # Types needed only for Type Hints + from firestore_v1.vector_query import VectorQuery + from google.cloud.firestore_v1.base_document import DocumentSnapshot - from google.cloud.firestore_v1.transaction import Transaction from google.cloud.firestore_v1.field_path import FieldPath - from firestore_v1.vector_query import VectorQuery + from google.cloud.firestore_v1.transaction import Transaction _AUTO_ID_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" diff --git a/google/cloud/firestore_v1/base_document.py b/google/cloud/firestore_v1/base_document.py index 3997b5b4db..1418ea34d0 100644 --- a/google/cloud/firestore_v1/base_document.py +++ b/google/cloud/firestore_v1/base_document.py @@ -15,18 +15,15 @@ """Classes for representing documents for the Google Cloud Firestore API.""" import copy +from typing import Any, Dict, Iterable, NoReturn, Optional, Tuple, Union from google.api_core import retry as retries -from google.cloud.firestore_v1.types import Document from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1 import field_path as field_path_module -from google.cloud.firestore_v1.types import common # Types needed only for Type Hints -from google.cloud.firestore_v1.types import firestore -from google.cloud.firestore_v1.types import write -from typing import Any, Dict, Iterable, NoReturn, Optional, Union, Tuple +from google.cloud.firestore_v1.types import Document, common, firestore, write class BaseDocumentReference(object): diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index c8c2f3ceb2..73ed00206b 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -24,22 +24,8 @@ import copy import math import warnings - -from google.api_core import retry as retries -from google.protobuf import wrappers_pb2 - -from google.cloud import firestore_v1 -from google.cloud.firestore_v1 import _helpers -from google.cloud.firestore_v1 import document -from google.cloud.firestore_v1 import field_path as field_path_module -from google.cloud.firestore_v1 import transforms -from google.cloud.firestore_v1.base_vector_query import DistanceMeasure -from google.cloud.firestore_v1.types import StructuredQuery -from google.cloud.firestore_v1.types import query -from google.cloud.firestore_v1.types import Cursor -from google.cloud.firestore_v1.types import RunQueryResponse -from google.cloud.firestore_v1.order import Order from typing import ( + TYPE_CHECKING, Any, Dict, Generator, @@ -50,12 +36,27 @@ Type, TypeVar, Union, - TYPE_CHECKING, ) -from google.cloud.firestore_v1.vector import Vector + +from google.api_core import retry as retries +from google.protobuf import wrappers_pb2 + +from google.cloud import firestore_v1 +from google.cloud.firestore_v1 import _helpers, document +from google.cloud.firestore_v1 import field_path as field_path_module +from google.cloud.firestore_v1 import transforms # Types needed only for Type Hints from google.cloud.firestore_v1.base_document import DocumentSnapshot +from google.cloud.firestore_v1.base_vector_query import DistanceMeasure +from google.cloud.firestore_v1.order import Order +from google.cloud.firestore_v1.types import ( + Cursor, + RunQueryResponse, + StructuredQuery, + query, +) +from google.cloud.firestore_v1.vector import Vector if TYPE_CHECKING: # pragma: NO COVER from google.cloud.firestore_v1.base_vector_query import BaseVectorQuery @@ -978,7 +979,7 @@ def _to_protobuf(self) -> StructuredQuery: def find_nearest( self, vector_field: str, - queryVector: Vector, + query_vector: Vector, limit: int, distance_measure: DistanceMeasure, ) -> BaseVectorQuery: diff --git a/google/cloud/firestore_v1/base_transaction.py b/google/cloud/firestore_v1/base_transaction.py index b4e5dd0382..5b6e76e1b0 100644 --- a/google/cloud/firestore_v1/base_transaction.py +++ b/google/cloud/firestore_v1/base_transaction.py @@ -14,10 +14,11 @@ """Helpers for applying Google Cloud Firestore changes in a transaction.""" +from typing import Any, Coroutine, NoReturn, Optional, Union + from google.api_core import retry as retries from google.cloud.firestore_v1 import types -from typing import Any, Coroutine, NoReturn, Optional, Union _CANT_BEGIN: str _CANT_COMMIT: str diff --git a/google/cloud/firestore_v1/base_vector_query.py b/google/cloud/firestore_v1/base_vector_query.py index e41717d2b5..0c5c61b3e8 100644 --- a/google/cloud/firestore_v1/base_vector_query.py +++ b/google/cloud/firestore_v1/base_vector_query.py @@ -16,16 +16,17 @@ """ import abc - from abc import ABC from enum import Enum from typing import Iterable, Optional, Tuple, Union + from google.api_core import gapic_v1 from google.api_core import retry as retries + +from google.cloud.firestore_v1 import _helpers, document from google.cloud.firestore_v1.base_document import DocumentSnapshot from google.cloud.firestore_v1.types import query from google.cloud.firestore_v1.vector import Vector -from google.cloud.firestore_v1 import _helpers class DistanceMeasure(Enum): @@ -117,3 +118,11 @@ def find_nearest( self._limit = limit self._distance_measure = distance_measure return self + + def stream( + self, + transaction=None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + ) -> Iterable[document.DocumentSnapshot]: + """Reads the documents in the collection that match this query.""" diff --git a/google/cloud/firestore_v1/bulk_writer.py b/google/cloud/firestore_v1/bulk_writer.py index 9f7d0f6240..4c1c7bde9e 100644 --- a/google/cloud/firestore_v1/bulk_writer.py +++ b/google/cloud/firestore_v1/bulk_writer.py @@ -23,9 +23,8 @@ import functools import logging import time - from dataclasses import dataclass -from typing import Callable, Dict, List, Optional, Union, TYPE_CHECKING +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union from google.rpc import status_pb2 # type: ignore diff --git a/google/cloud/firestore_v1/client.py b/google/cloud/firestore_v1/client.py index 05c135479b..8bdaf7f815 100644 --- a/google/cloud/firestore_v1/client.py +++ b/google/cloud/firestore_v1/client.py @@ -24,30 +24,30 @@ :class:`~google.cloud.firestore_v1.document.DocumentReference` """ +from typing import TYPE_CHECKING, Any, Generator, Iterable, List, Optional, Union + from google.api_core import gapic_v1 from google.api_core import retry as retries from google.cloud.firestore_v1.base_client import ( - BaseClient, _CLIENT_INFO, + BaseClient, _parse_batch_get, _path_helper, ) -from google.cloud.firestore_v1.query import CollectionGroup +# Types needed only for Type Hints +from google.cloud.firestore_v1.base_document import DocumentSnapshot from google.cloud.firestore_v1.batch import WriteBatch from google.cloud.firestore_v1.collection import CollectionReference from google.cloud.firestore_v1.document import DocumentReference from google.cloud.firestore_v1.field_path import FieldPath -from google.cloud.firestore_v1.transaction import Transaction +from google.cloud.firestore_v1.query import CollectionGroup from google.cloud.firestore_v1.services.firestore import client as firestore_client from google.cloud.firestore_v1.services.firestore.transports import ( grpc as firestore_grpc_transport, ) -from typing import Any, Generator, Iterable, List, Optional, Union, TYPE_CHECKING - -# Types needed only for Type Hints -from google.cloud.firestore_v1.base_document import DocumentSnapshot +from google.cloud.firestore_v1.transaction import Transaction if TYPE_CHECKING: from google.cloud.firestore_v1.bulk_writer import BulkWriter # pragma: NO COVER diff --git a/google/cloud/firestore_v1/collection.py b/google/cloud/firestore_v1/collection.py index 30ddd4bcc0..96dadf2e70 100644 --- a/google/cloud/firestore_v1/collection.py +++ b/google/cloud/firestore_v1/collection.py @@ -14,22 +14,23 @@ """Classes for representing collections for the Google Cloud Firestore API.""" +from typing import TYPE_CHECKING, Any, Callable, Generator, Optional, Tuple, Union + from google.api_core import gapic_v1 from google.api_core import retry as retries +from google.cloud.firestore_v1 import aggregation, document +from google.cloud.firestore_v1 import query as query_mod +from google.cloud.firestore_v1 import transaction, vector_query from google.cloud.firestore_v1.base_collection import ( BaseCollectionReference, _item_to_document_ref, ) -from google.cloud.firestore_v1 import query as query_mod -from google.cloud.firestore_v1 import aggregation -from google.cloud.firestore_v1 import vector_query from google.cloud.firestore_v1.watch import Watch -from google.cloud.firestore_v1 import document -from typing import Any, Callable, Generator, Tuple, Union -# Types needed only for Type Hints -from google.cloud.firestore_v1.transaction import Transaction +if TYPE_CHECKING: # pragma: NO COVER + from google.cloud.firestore_v1.base_document import DocumentSnapshot + from google.cloud.firestore_v1.stream_generator import StreamGenerator class CollectionReference(BaseCollectionReference[query_mod.Query]): @@ -165,7 +166,7 @@ def _chunkify(self, chunk_size: int): def get( self, - transaction: Union[Transaction, None] = None, + transaction: Union[transaction.Transaction, None] = None, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: Union[float, None] = None, ) -> list: @@ -176,7 +177,7 @@ def get( Args: transaction - (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + (Optional[:class:`~google.cloud.firestore_v1.transaction.transaction.Transaction`]): An existing transaction that this query will run in. retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. Defaults to a system-specified policy. @@ -196,10 +197,10 @@ def get( def stream( self, - transaction: Union[Transaction, None] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: Union[float, None] = None, - ) -> Generator[document.DocumentSnapshot, Any, None]: + transaction: Optional[transaction.Transaction] = None, + retry: Optional[retries.Retry] = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + ) -> "StreamGenerator[DocumentSnapshot]": """Read the documents in this collection. This sends a ``RunQuery`` RPC and then returns an iterator which @@ -219,16 +220,16 @@ def stream( Args: transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.\ - Transaction`]): + transaction.Transaction`]): An existing transaction that the query will run in. - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. Defaults to a system-specified policy. - timeout (float): The timeout for this request. Defaults to a - system-specified value. + retry (Optional[google.api_core.retry.Retry]): Designation of what + errors, if any, should be retried. Defaults to a + system-specified policy. + timeout (Optional[float]): The timeout for this request. Defaults + to a system-specified value. - Yields: - :class:`~google.cloud.firestore_v1.document.DocumentSnapshot`: - The next document that fulfills the query. + Returns: + `StreamGenerator[DocumentSnapshot]`: A generator of the query results. """ query, kwargs = self._prep_get_or_stream(retry, timeout) diff --git a/google/cloud/firestore_v1/document.py b/google/cloud/firestore_v1/document.py index 00d682d2bb..305d10df6f 100644 --- a/google/cloud/firestore_v1/document.py +++ b/google/cloud/firestore_v1/document.py @@ -15,22 +15,21 @@ """Classes for representing documents for the Google Cloud Firestore API.""" import datetime import logging +from typing import Any, Callable, Generator, Iterable from google.api_core import gapic_v1 from google.api_core import retry as retries from google.cloud._helpers import _datetime_to_pb_timestamp # type: ignore +from google.protobuf.timestamp_pb2 import Timestamp +from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.base_document import ( BaseDocumentReference, DocumentSnapshot, _first_write_result, ) -from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1.watch import Watch -from google.protobuf.timestamp_pb2 import Timestamp -from typing import Any, Callable, Generator, Iterable - logger = logging.getLogger(__name__) diff --git a/google/cloud/firestore_v1/field_path.py b/google/cloud/firestore_v1/field_path.py index 24683fb843..df7d10a789 100644 --- a/google/cloud/firestore_v1/field_path.py +++ b/google/cloud/firestore_v1/field_path.py @@ -14,12 +14,10 @@ """Utilities for managing / converting field paths to / from strings.""" -from collections import abc - import re +from collections import abc from typing import Iterable - _FIELD_PATH_MISSING_TOP = "{!r} is not contained in the data" _FIELD_PATH_MISSING_KEY = "{!r} is not contained in the data for the key {!r}" _FIELD_PATH_WRONG_TYPE = ( diff --git a/google/cloud/firestore_v1/gapic_version.py b/google/cloud/firestore_v1/gapic_version.py index 8edfaef714..8d4f4cfb61 100644 --- a/google/cloud/firestore_v1/gapic_version.py +++ b/google/cloud/firestore_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "2.16.1" # {x-release-please-version} +__version__ = "2.17.0" # {x-release-please-version} diff --git a/google/cloud/firestore_v1/order.py b/google/cloud/firestore_v1/order.py index 0803a60e3f..9395d05b96 100644 --- a/google/cloud/firestore_v1/order.py +++ b/google/cloud/firestore_v1/order.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from enum import Enum -from google.cloud.firestore_v1._helpers import decode_value import math +from enum import Enum from typing import Any +from google.cloud.firestore_v1._helpers import decode_value + class TypeOrder(Enum): """The supported Data Type. diff --git a/google/cloud/firestore_v1/query.py b/google/cloud/firestore_v1/query.py index c46a06918a..b5bd5ec4fd 100644 --- a/google/cloud/firestore_v1/query.py +++ b/google/cloud/firestore_v1/query.py @@ -20,28 +20,27 @@ """ from __future__ import annotations -from google.cloud import firestore_v1 -from google.cloud.firestore_v1.base_document import DocumentSnapshot -from google.api_core import exceptions -from google.api_core import gapic_v1 +from typing import TYPE_CHECKING, Any, Callable, Generator, List, Optional, Type + +from google.api_core import exceptions, gapic_v1 from google.api_core import retry as retries +from google.cloud import firestore_v1 +from google.cloud.firestore_v1 import aggregation, transaction +from google.cloud.firestore_v1.base_document import DocumentSnapshot from google.cloud.firestore_v1.base_query import ( BaseCollectionGroup, BaseQuery, QueryPartition, - _query_response_to_snapshot, _collection_group_query_response_to_snapshot, _enum_from_direction, + _query_response_to_snapshot, ) from google.cloud.firestore_v1.base_vector_query import DistanceMeasure -from google.cloud.firestore_v1.vector_query import VectorQuery +from google.cloud.firestore_v1.stream_generator import StreamGenerator from google.cloud.firestore_v1.vector import Vector -from google.cloud.firestore_v1 import aggregation - -from google.cloud.firestore_v1 import document +from google.cloud.firestore_v1.vector_query import VectorQuery from google.cloud.firestore_v1.watch import Watch -from typing import Any, Callable, Generator, List, Optional, Type, TYPE_CHECKING if TYPE_CHECKING: # pragma: NO COVER from google.cloud.firestore_v1.field_path import FieldPath @@ -171,7 +170,11 @@ def get( ) self._limit_to_last = False - result = self.stream(transaction=transaction, retry=retry, timeout=timeout) + result = self.stream( + transaction=transaction, + retry=retry, + timeout=timeout, + ) if is_limited_to_last: result = reversed(list(result)) @@ -312,15 +315,17 @@ def avg( """ return aggregation.AggregationQuery(self).avg(field_ref, alias=alias) - def stream( + def _make_stream( self, - transaction=None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - ) -> Generator[document.DocumentSnapshot, Any, None]: - """Read the documents in the collection that match this query. - - This sends a ``RunQuery`` RPC and then returns an iterator which + transaction: Optional[transaction.Transaction] = None, + retry: Optional[retries.Retry] = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + ) -> Generator[DocumentSnapshot, Any, None]: + """Internal method for stream(). Read the documents in the collection + that match this query. + + Internal method for stream(). + This sends a ``RunQuery`` RPC and then returns a generator which consumes each document returned in the stream of ``RunQueryResponse`` messages. @@ -336,13 +341,14 @@ def stream( allowed). Args: - transaction - (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): - An existing transaction that this query will run in. - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. Defaults to a system-specified policy. - timeout (float): The timeout for this request. Defaults to a - system-specified value. + transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.\ + Transaction`]): + An existing transaction that the query will run in. + retry (Optional[google.api_core.retry.Retry]): Designation of what + errors, if any, should be retried. Defaults to a + system-specified policy. + timeout (Optional[float]): The timeout for this request. Defaults + to a system-specified value. Yields: :class:`~google.cloud.firestore_v1.document.DocumentSnapshot`: @@ -386,6 +392,49 @@ def stream( last_snapshot = snapshot yield snapshot + def stream( + self, + transaction: Optional[transaction.Transaction] = None, + retry: Optional[retries.Retry] = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + ) -> "StreamGenerator[DocumentSnapshot]": + """Read the documents in the collection that match this query. + + This sends a ``RunQuery`` RPC and then returns a generator which + consumes each document returned in the stream of ``RunQueryResponse`` + messages. + + .. note:: + + The underlying stream of responses will time out after + the ``max_rpc_timeout_millis`` value set in the GAPIC + client configuration for the ``RunQuery`` API. Snapshots + not consumed from the iterator before that point will be lost. + + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). + + Args: + transaction + (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + An existing transaction that this query will run in. + retry (Optional[google.api_core.retry.Retry]): Designation of what + errors, if any, should be retried. Defaults to a + system-specified policy. + timeout (Optinal[float]): The timeout for this request. Defaults + to a system-specified value. + + Returns: + `StreamGenerator[DocumentSnapshot]`: A generator of the query results. + """ + inner_generator = self._make_stream( + transaction=transaction, + retry=retry, + timeout=timeout, + ) + return StreamGenerator(inner_generator) + def on_snapshot(self, callback: Callable) -> Watch: """Monitor the documents in this collection that match this query. @@ -415,7 +464,7 @@ def on_snapshot(docs, changes, read_time): # Terminate this watch query_watch.unsubscribe() """ - return Watch.for_query(self, callback, document.DocumentSnapshot) + return Watch.for_query(self, callback, DocumentSnapshot) @staticmethod def _get_collection_reference_class() -> ( diff --git a/google/cloud/firestore_v1/rate_limiter.py b/google/cloud/firestore_v1/rate_limiter.py index 8ca98dbe88..4cd06d8666 100644 --- a/google/cloud/firestore_v1/rate_limiter.py +++ b/google/cloud/firestore_v1/rate_limiter.py @@ -13,8 +13,8 @@ # limitations under the License. import datetime -from typing import NoReturn, Optional import warnings +from typing import NoReturn, Optional def utcnow(): diff --git a/google/cloud/firestore_v1/stream_generator.py b/google/cloud/firestore_v1/stream_generator.py new file mode 100644 index 0000000000..0a95af8d1f --- /dev/null +++ b/google/cloud/firestore_v1/stream_generator.py @@ -0,0 +1,40 @@ +# Copyright 2024 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Classes for iterating over stream results for the Google Cloud Firestore API. +""" + +from collections import abc + + +class StreamGenerator(abc.Generator): + """Generator for the streamed results.""" + + def __init__(self, response_generator): + self._generator = response_generator + + def __iter__(self): + return self._generator + + def __next__(self): + return self._generator.__next__() + + def send(self, value=None): + return self._generator.send(value) + + def throw(self, exp=None): + return self._generator.throw(exp) + + def close(self): + return self._generator.close() diff --git a/google/cloud/firestore_v1/transaction.py b/google/cloud/firestore_v1/transaction.py index 3c175a4ced..1691b56792 100644 --- a/google/cloud/firestore_v1/transaction.py +++ b/google/cloud/firestore_v1/transaction.py @@ -17,34 +17,31 @@ import random import time +from typing import Any, Callable, Generator -from google.api_core import gapic_v1 +from google.api_core import exceptions, gapic_v1 from google.api_core import retry as retries +from google.cloud.firestore_v1 import _helpers, batch + +# Types needed only for Type Hints +from google.cloud.firestore_v1.base_document import DocumentSnapshot from google.cloud.firestore_v1.base_transaction import ( - _BaseTransactional, - BaseTransaction, - MAX_ATTEMPTS, _CANT_BEGIN, - _CANT_ROLLBACK, _CANT_COMMIT, - _WRITE_READ_ONLY, + _CANT_ROLLBACK, + _EXCEED_ATTEMPTS_TEMPLATE, _INITIAL_SLEEP, _MAX_SLEEP, _MULTIPLIER, - _EXCEED_ATTEMPTS_TEMPLATE, + _WRITE_READ_ONLY, + MAX_ATTEMPTS, + BaseTransaction, + _BaseTransactional, ) - -from google.api_core import exceptions -from google.cloud.firestore_v1 import batch from google.cloud.firestore_v1.document import DocumentReference -from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.query import Query - -# Types needed only for Type Hints -from google.cloud.firestore_v1.base_document import DocumentSnapshot from google.cloud.firestore_v1.types import CommitResponse -from typing import Any, Callable, Generator class Transaction(batch.WriteBatch, BaseTransaction): diff --git a/google/cloud/firestore_v1/vector.py b/google/cloud/firestore_v1/vector.py index 3aa5cdc75d..3349b57e1f 100644 --- a/google/cloud/firestore_v1/vector.py +++ b/google/cloud/firestore_v1/vector.py @@ -14,8 +14,7 @@ # limitations under the License. import collections - -from typing import Tuple, Sequence +from typing import Sequence, Tuple class Vector(collections.abc.Sequence): diff --git a/google/cloud/firestore_v1/vector_query.py b/google/cloud/firestore_v1/vector_query.py index 1e8e990839..a419dba63a 100644 --- a/google/cloud/firestore_v1/vector_query.py +++ b/google/cloud/firestore_v1/vector_query.py @@ -15,17 +15,24 @@ """Classes for representing vector queries for the Google Cloud Firestore API. """ -from google.cloud.firestore_v1.base_vector_query import BaseVectorQuery -from typing import Iterable, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Generator, Iterable, Optional, TypeVar, Union + from google.api_core import gapic_v1 from google.api_core import retry as retries -from google.cloud.firestore_v1.base_document import DocumentSnapshot -from google.cloud.firestore_v1 import document + from google.cloud.firestore_v1.base_query import ( BaseQuery, - _query_response_to_snapshot, _collection_group_query_response_to_snapshot, + _query_response_to_snapshot, ) +from google.cloud.firestore_v1.base_vector_query import BaseVectorQuery +from google.cloud.firestore_v1.stream_generator import StreamGenerator + +# Types needed only for Type Hints +if TYPE_CHECKING: # pragma: NO COVER + from google.cloud.firestore_v1 import transaction + from google.cloud.firestore_v1.base_document import DocumentSnapshot + TVectorQuery = TypeVar("TVectorQuery", bound="VectorQuery") @@ -48,7 +55,7 @@ def get( transaction=None, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, - ) -> Iterable[DocumentSnapshot]: + ) -> Iterable["DocumentSnapshot"]: """Runs the vector query. This sends a ``RunQuery`` RPC and returns a list of document messages. @@ -88,15 +95,15 @@ def _get_stream_iterator(self, transaction, retry, timeout): return response_iterator, expected_prefix - def stream( + def _make_stream( self, - transaction=None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - ) -> Iterable[document.DocumentSnapshot]: + transaction: Optional["transaction.Transaction"] = None, + retry: Optional[retries.Retry] = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + ) -> Generator["DocumentSnapshot", Any, None]: """Reads the documents in the collection that match this query. - This sends a ``RunQuery`` RPC and then returns an iterator which + This sends a ``RunQuery`` RPC and then returns a generator which consumes each document returned in the stream of ``RunQueryResponse`` messages. @@ -108,10 +115,11 @@ def stream( transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): An existing transaction that this query will run in. - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. Defaults to a system-specified policy. - timeout (float): The timeout for this request. Defaults to a - system-specified value. + retry (Optional[google.api_core.retry.Retry]): Designation of what + errors, if any, should be retried. Defaults to a + system-specified policy. + timeout (Optional[float]): The timeout for this request. Defaults + to a system-specified value. Yields: :class:`~google.cloud.firestore_v1.document.DocumentSnapshot`: @@ -139,3 +147,39 @@ def stream( ) if snapshot is not None: yield snapshot + + def stream( + self, + transaction: Optional["transaction.Transaction"] = None, + retry: Optional[retries.Retry] = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + ) -> "StreamGenerator[DocumentSnapshot]": + """Reads the documents in the collection that match this query. + + This sends a ``RunQuery`` RPC and then returns a generator which + consumes each document returned in the stream of ``RunQueryResponse`` + messages. + + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). + + Args: + transaction + (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + An existing transaction that this query will run in. + retry (Optional[google.api_core.retry.Retry]): Designation of what + errors, if any, should be retried. Defaults to a + system-specified policy. + timeout (Optinal[float]): The timeout for this request. Defaults + to a system-specified value. + + Returns: + `StreamGenerator[DocumentSnapshot]`: A generator of the query results. + """ + inner_generator = self._make_stream( + transaction=transaction, + retry=retry, + timeout=timeout, + ) + return StreamGenerator(inner_generator) diff --git a/google/cloud/firestore_v1/watch.py b/google/cloud/firestore_v1/watch.py index 555b895019..63bb522b92 100644 --- a/google/cloud/firestore_v1/watch.py +++ b/google/cloud/firestore_v1/watch.py @@ -13,21 +13,21 @@ # limitations under the License. import collections -from enum import Enum import functools import logging import threading +from enum import Enum -from google.api_core.bidi import ResumableBidiRpc -from google.api_core.bidi import BackgroundConsumer -from google.api_core import exceptions import grpc # type: ignore +from google.api_core import exceptions +from google.api_core.bidi import BackgroundConsumer, ResumableBidiRpc -from google.cloud.firestore_v1.types.firestore import ListenRequest -from google.cloud.firestore_v1.types.firestore import Target -from google.cloud.firestore_v1.types.firestore import TargetChange from google.cloud.firestore_v1 import _helpers - +from google.cloud.firestore_v1.types.firestore import ( + ListenRequest, + Target, + TargetChange, +) TargetChangeType = TargetChange.TargetChangeType diff --git a/setup.py b/setup.py index 46ca556b4b..38f6d0e6ee 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,6 @@ import setuptools - # Package metadata. name = "google-cloud-firestore" diff --git a/tests/system/test__helpers.py b/tests/system/test__helpers.py index 5a683a44f6..d6ee9b9449 100644 --- a/tests/system/test__helpers.py +++ b/tests/system/test__helpers.py @@ -1,8 +1,9 @@ import os import re + +from test_utils.system import EmulatorCreds, unique_resource_id + from google.cloud.firestore_v1.base_client import _FIRESTORE_EMULATOR_HOST -from test_utils.system import unique_resource_id -from test_utils.system import EmulatorCreds FIRESTORE_CREDS = os.environ.get("FIRESTORE_APPLICATION_CREDENTIALS") FIRESTORE_PROJECT = os.environ.get("GCLOUD_PROJECT") diff --git a/tests/system/test_system.py b/tests/system/test_system.py index 17ca974a60..87cd89d3e1 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -16,34 +16,33 @@ import itertools import math import operator +from time import sleep +from typing import Callable, Dict, List, Optional import google.auth -from google.oauth2 import service_account import pytest - -from google.api_core.exceptions import AlreadyExists -from google.api_core.exceptions import FailedPrecondition -from google.api_core.exceptions import InvalidArgument -from google.api_core.exceptions import NotFound +from google.api_core.exceptions import ( + AlreadyExists, + FailedPrecondition, + InvalidArgument, + NotFound, +) from google.cloud._helpers import _datetime_to_pb_timestamp +from google.oauth2 import service_account + from google.cloud import firestore_v1 as firestore -from google.cloud.firestore_v1.base_query import FieldFilter, And, Or +from google.cloud.firestore_v1.base_query import And, FieldFilter, Or from google.cloud.firestore_v1.base_vector_query import DistanceMeasure from google.cloud.firestore_v1.vector import Vector - - -from time import sleep -from typing import Callable, Dict, List, Optional - from tests.system.test__helpers import ( + EMULATOR_CREDS, FIRESTORE_CREDS, + FIRESTORE_EMULATOR, + FIRESTORE_OTHER_DB, FIRESTORE_PROJECT, - RANDOM_ID_REGEX, MISSING_DOCUMENT, + RANDOM_ID_REGEX, UNIQUE_RESOURCE_ID, - EMULATOR_CREDS, - FIRESTORE_EMULATOR, - FIRESTORE_OTHER_DB, ) @@ -1239,8 +1238,8 @@ def test_batch(client, cleanup, database): @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) def test_live_bulk_writer(client, cleanup, database): - from google.cloud.firestore_v1.client import Client from google.cloud.firestore_v1.bulk_writer import BulkWriter + from google.cloud.firestore_v1.client import Client db: Client = client bw: BulkWriter = db.bulk_writer() diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py index 5b681e7b33..696f5a6f7a 100644 --- a/tests/system/test_system_async.py +++ b/tests/system/test_system_async.py @@ -16,38 +16,38 @@ import datetime import itertools import math -import pytest -import pytest_asyncio import operator -import google.auth - from typing import Callable, Dict, List, Optional -from google.oauth2 import service_account - -from google.api_core import retry_async as retries +import google.auth +import pytest +import pytest_asyncio from google.api_core import exceptions as core_exceptions - -from google.api_core.exceptions import AlreadyExists -from google.api_core.exceptions import FailedPrecondition -from google.api_core.exceptions import InvalidArgument -from google.api_core.exceptions import NotFound +from google.api_core import retry_async as retries +from google.api_core.exceptions import ( + AlreadyExists, + FailedPrecondition, + InvalidArgument, + NotFound, +) from google.cloud._helpers import _datetime_to_pb_timestamp -from google.cloud import firestore_v1 as firestore -from google.cloud.firestore_v1.base_query import FieldFilter, And, Or +from google.oauth2 import service_account +from google.cloud import firestore_v1 as firestore +from google.cloud.firestore_v1.base_query import And, FieldFilter, Or +from google.cloud.firestore_v1.base_vector_query import DistanceMeasure +from google.cloud.firestore_v1.vector import Vector from tests.system.test__helpers import ( + EMULATOR_CREDS, FIRESTORE_CREDS, + FIRESTORE_EMULATOR, + FIRESTORE_OTHER_DB, FIRESTORE_PROJECT, - RANDOM_ID_REGEX, MISSING_DOCUMENT, + RANDOM_ID_REGEX, UNIQUE_RESOURCE_ID, - EMULATOR_CREDS, - FIRESTORE_EMULATOR, - FIRESTORE_OTHER_DB, ) - RETRIES = retries.AsyncRetry( initial=0.1, maximum=60.0, @@ -339,6 +339,47 @@ async def test_document_update_w_int_field(client, cleanup, database): assert snapshot1.to_dict() == expected +@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_vector_search_collection(client, database): + collection_id = "vector_search" + collection = client.collection(collection_id) + vector_query = collection.where("color", "==", "red").find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + limit=1, + distance_measure=DistanceMeasure.EUCLIDEAN, + ) + returned = await vector_query.get() + assert isinstance(returned, list) + assert len(returned) == 1 + assert returned[0].to_dict() == { + "embedding": Vector([1.0, 2.0, 3.0]), + "color": "red", + } + + +@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_vector_search_collection_group(client, database): + collection_id = "vector_search" + collection_group = client.collection_group(collection_id) + + vector_query = collection_group.where("color", "==", "red").find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=DistanceMeasure.EUCLIDEAN, + limit=1, + ) + returned = await vector_query.get() + assert isinstance(returned, list) + assert len(returned) == 1 + assert returned[0].to_dict() == { + "embedding": Vector([1.0, 2.0, 3.0]), + "color": "red", + } + + @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Internal Issue b/137867104") @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) async def test_update_document(client, cleanup, database): diff --git a/tests/unit/gapic/firestore_admin_v1/test_firestore_admin.py b/tests/unit/gapic/firestore_admin_v1/test_firestore_admin.py index cd3009184e..9564476625 100644 --- a/tests/unit/gapic/firestore_admin_v1/test_firestore_admin.py +++ b/tests/unit/gapic/firestore_admin_v1/test_firestore_admin.py @@ -22,41 +22,48 @@ except ImportError: # pragma: NO COVER import mock -import grpc -from grpc.experimental import aio -from collections.abc import Iterable -from google.protobuf import json_format import json import math -import pytest -from google.api_core import api_core_version -from proto.marshal.rules.dates import DurationRule, TimestampRule -from proto.marshal.rules import wrappers -from requests import Response -from requests import Request, PreparedRequest -from requests.sessions import Session -from google.protobuf import json_format +from collections.abc import Iterable -from google.api_core import client_options -from google.api_core import exceptions as core_exceptions -from google.api_core import future -from google.api_core import gapic_v1 -from google.api_core import grpc_helpers -from google.api_core import grpc_helpers_async -from google.api_core import operation +import google.auth +import grpc +import pytest from google.api_core import operation_async # type: ignore -from google.api_core import operations_v1 -from google.api_core import path_template +from google.api_core import api_core_version, client_options +from google.api_core import exceptions as core_exceptions +from google.api_core import ( + future, + gapic_v1, + grpc_helpers, + grpc_helpers_async, + operation, + operations_v1, + path_template, +) from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError +from google.cloud.location import locations_pb2 +from google.longrunning import operations_pb2 # type: ignore +from google.oauth2 import service_account +from google.protobuf import duration_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore +from google.protobuf import field_mask_pb2 # type: ignore +from google.protobuf import timestamp_pb2 # type: ignore +from google.protobuf import json_format +from google.type import dayofweek_pb2 # type: ignore +from grpc.experimental import aio +from proto.marshal.rules import wrappers +from proto.marshal.rules.dates import DurationRule, TimestampRule +from requests import PreparedRequest, Request, Response +from requests.sessions import Session + from google.cloud.firestore_admin_v1.services.firestore_admin import ( FirestoreAdminAsyncClient, -) -from google.cloud.firestore_admin_v1.services.firestore_admin import ( FirestoreAdminClient, + pagers, + transports, ) -from google.cloud.firestore_admin_v1.services.firestore_admin import pagers -from google.cloud.firestore_admin_v1.services.firestore_admin import transports from google.cloud.firestore_admin_v1.types import backup from google.cloud.firestore_admin_v1.types import database from google.cloud.firestore_admin_v1.types import database as gfa_database @@ -67,15 +74,6 @@ from google.cloud.firestore_admin_v1.types import index as gfa_index from google.cloud.firestore_admin_v1.types import operation as gfa_operation from google.cloud.firestore_admin_v1.types import schedule -from google.cloud.location import locations_pb2 -from google.longrunning import operations_pb2 # type: ignore -from google.oauth2 import service_account -from google.protobuf import duration_pb2 # type: ignore -from google.protobuf import empty_pb2 # type: ignore -from google.protobuf import field_mask_pb2 # type: ignore -from google.protobuf import timestamp_pb2 # type: ignore -from google.type import dayofweek_pb2 # type: ignore -import google.auth def client_cert_source_callback(): diff --git a/tests/unit/gapic/firestore_v1/test_firestore.py b/tests/unit/gapic/firestore_v1/test_firestore.py index bec710de7c..2cfa0bfda1 100644 --- a/tests/unit/gapic/firestore_v1/test_firestore.py +++ b/tests/unit/gapic/firestore_v1/test_firestore.py @@ -22,50 +22,44 @@ except ImportError: # pragma: NO COVER import mock -import grpc -from grpc.experimental import aio -from collections.abc import Iterable -from google.protobuf import json_format import json import math -import pytest -from google.api_core import api_core_version -from proto.marshal.rules.dates import DurationRule, TimestampRule -from proto.marshal.rules import wrappers -from requests import Response -from requests import Request, PreparedRequest -from requests.sessions import Session -from google.protobuf import json_format +from collections.abc import Iterable -from google.api_core import client_options +import google.auth +import grpc +import pytest +from google.api_core import api_core_version, client_options from google.api_core import exceptions as core_exceptions -from google.api_core import gapic_v1 -from google.api_core import grpc_helpers -from google.api_core import grpc_helpers_async -from google.api_core import path_template +from google.api_core import gapic_v1, grpc_helpers, grpc_helpers_async, path_template from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.firestore_v1.services.firestore import FirestoreAsyncClient -from google.cloud.firestore_v1.services.firestore import FirestoreClient -from google.cloud.firestore_v1.services.firestore import pagers -from google.cloud.firestore_v1.services.firestore import transports -from google.cloud.firestore_v1.types import aggregation_result -from google.cloud.firestore_v1.types import common -from google.cloud.firestore_v1.types import document -from google.cloud.firestore_v1.types import document as gf_document -from google.cloud.firestore_v1.types import firestore -from google.cloud.firestore_v1.types import query -from google.cloud.firestore_v1.types import query_profile -from google.cloud.firestore_v1.types import write as gf_write from google.cloud.location import locations_pb2 from google.longrunning import operations_pb2 # type: ignore from google.oauth2 import service_account from google.protobuf import struct_pb2 # type: ignore from google.protobuf import timestamp_pb2 # type: ignore from google.protobuf import wrappers_pb2 # type: ignore +from google.protobuf import json_format from google.rpc import status_pb2 # type: ignore from google.type import latlng_pb2 # type: ignore -import google.auth +from grpc.experimental import aio +from proto.marshal.rules import wrappers +from proto.marshal.rules.dates import DurationRule, TimestampRule +from requests import PreparedRequest, Request, Response +from requests.sessions import Session + +from google.cloud.firestore_v1.services.firestore import ( + FirestoreAsyncClient, + FirestoreClient, + pagers, + transports, +) +from google.cloud.firestore_v1.types import aggregation_result, common +from google.cloud.firestore_v1.types import document +from google.cloud.firestore_v1.types import document as gf_document +from google.cloud.firestore_v1.types import firestore, query, query_profile +from google.cloud.firestore_v1.types import write as gf_write def client_cert_source_callback(): diff --git a/tests/unit/test_firestore_shim.py b/tests/unit/test_firestore_shim.py index df7d951ad0..5353d28435 100644 --- a/tests/unit/test_firestore_shim.py +++ b/tests/unit/test_firestore_shim.py @@ -24,8 +24,7 @@ def test_version_from_gapic_version_meatches_firestore_v1(self): self.assertEqual(gapic_version.__version__, gapic_version_v1.__version__) def test_shim_matches_firestore_v1(self): - from google.cloud import firestore - from google.cloud import firestore_v1 + from google.cloud import firestore, firestore_v1 self.assertEqual(firestore.__all__, firestore_v1.__all__) diff --git a/tests/unit/v1/_test_helpers.py b/tests/unit/v1/_test_helpers.py index 2734d78751..340ccb30eb 100644 --- a/tests/unit/v1/_test_helpers.py +++ b/tests/unit/v1/_test_helpers.py @@ -14,19 +14,18 @@ import concurrent.futures import datetime -import mock import typing -import google +import mock +from google.cloud._helpers import UTC, _datetime_to_pb_timestamp # type: ignore +from google.protobuf.timestamp_pb2 import Timestamp # type: ignore +import google +from google.cloud.firestore_v1._helpers import build_timestamp from google.cloud.firestore_v1.async_client import AsyncClient from google.cloud.firestore_v1.base_client import BaseClient -from google.cloud.firestore_v1.document import DocumentReference, DocumentSnapshot -from google.cloud._helpers import _datetime_to_pb_timestamp, UTC # type: ignore -from google.cloud.firestore_v1._helpers import build_timestamp from google.cloud.firestore_v1.client import Client -from google.protobuf.timestamp_pb2 import Timestamp # type: ignore - +from google.cloud.firestore_v1.document import DocumentReference, DocumentSnapshot DEFAULT_TEST_PROJECT = "project-project" @@ -78,10 +77,10 @@ def make_async_aggregation_query(*args, **kw): def make_aggregation_query_response(aggregations, read_time=None, transaction=None): - from google.cloud.firestore_v1.types import firestore from google.cloud._helpers import _datetime_to_pb_timestamp + from google.cloud.firestore_v1 import _helpers - from google.cloud.firestore_v1.types import aggregation_result + from google.cloud.firestore_v1.types import aggregation_result, firestore if read_time is None: now = datetime.datetime.now(tz=datetime.timezone.utc) diff --git a/tests/unit/v1/conformance_tests.py b/tests/unit/v1/conformance_tests.py index 779c83b0e3..5eb378d2ee 100644 --- a/tests/unit/v1/conformance_tests.py +++ b/tests/unit/v1/conformance_tests.py @@ -16,14 +16,10 @@ # import proto # type: ignore - - -from google.cloud.firestore_v1.types import common -from google.cloud.firestore_v1.types import document -from google.cloud.firestore_v1.types import firestore -from google.cloud.firestore_v1.types import query as gcf_query from google.protobuf import timestamp_pb2 as timestamp # type: ignore +from google.cloud.firestore_v1.types import common, document, firestore +from google.cloud.firestore_v1.types import query as gcf_query __protobuf__ = proto.module( package="tests.unit.v1", diff --git a/tests/unit/v1/test__helpers.py b/tests/unit/v1/test__helpers.py index 5d9c9e490e..db891741a6 100644 --- a/tests/unit/v1/test__helpers.py +++ b/tests/unit/v1/test__helpers.py @@ -18,7 +18,6 @@ import mock import pytest - from tests.unit.v1._test_helpers import make_test_credentials @@ -149,6 +148,7 @@ def test_verify_path_w_success_document(): def test_encode_value_w_none(): from google.protobuf import struct_pb2 + from google.cloud.firestore_v1._helpers import encode_value result = encode_value(None) @@ -184,9 +184,10 @@ def test_encode_value_w_float(): def test_encode_value_w_datetime_with_nanos(): from google.api_core.datetime_helpers import DatetimeWithNanoseconds - from google.cloud.firestore_v1._helpers import encode_value from google.protobuf import timestamp_pb2 + from google.cloud.firestore_v1._helpers import encode_value + dt_seconds = 1488768504 dt_nanos = 458816991 timestamp_pb = timestamp_pb2.Timestamp(seconds=dt_seconds, nanos=dt_nanos) @@ -199,6 +200,7 @@ def test_encode_value_w_datetime_with_nanos(): def test_encode_value_w_datetime_wo_nanos(): from google.protobuf import timestamp_pb2 + from google.cloud.firestore_v1._helpers import encode_value dt_seconds = 1488768504 @@ -245,8 +247,7 @@ def test_encode_value_w_reference_value(): def test_encode_value_w_geo_point(): - from google.cloud.firestore_v1._helpers import encode_value - from google.cloud.firestore_v1._helpers import GeoPoint + from google.cloud.firestore_v1._helpers import GeoPoint, encode_value value = GeoPoint(50.5, 88.75) result = encode_value(value) @@ -296,11 +297,10 @@ def test_encode_value_w_bad_type(): def test_encode_dict_w_many_types(): - from google.protobuf import struct_pb2 - from google.protobuf import timestamp_pb2 + from google.protobuf import struct_pb2, timestamp_pb2 + from google.cloud.firestore_v1._helpers import encode_dict - from google.cloud.firestore_v1.types.document import ArrayValue - from google.cloud.firestore_v1.types.document import MapValue + from google.cloud.firestore_v1.types.document import ArrayValue, MapValue dt_seconds = 1497397225 dt_nanos = 465964000 @@ -355,8 +355,10 @@ def test_encode_dict_w_many_types(): def test_reference_value_to_document_w_bad_format(): - from google.cloud.firestore_v1._helpers import BAD_REFERENCE_ERROR - from google.cloud.firestore_v1._helpers import reference_value_to_document + from google.cloud.firestore_v1._helpers import ( + BAD_REFERENCE_ERROR, + reference_value_to_document, + ) reference_value = "not/the/right/format" with pytest.raises(ValueError) as exc_info: @@ -367,8 +369,8 @@ def test_reference_value_to_document_w_bad_format(): def test_reference_value_to_document_w_same_client(): - from google.cloud.firestore_v1.document import DocumentReference from google.cloud.firestore_v1._helpers import reference_value_to_document + from google.cloud.firestore_v1.document import DocumentReference client = _make_client() document = client.document("that", "this") @@ -383,8 +385,10 @@ def test_reference_value_to_document_w_same_client(): def test_reference_value_to_document_w_different_client(): - from google.cloud.firestore_v1._helpers import WRONG_APP_REFERENCE - from google.cloud.firestore_v1._helpers import reference_value_to_document + from google.cloud.firestore_v1._helpers import ( + WRONG_APP_REFERENCE, + reference_value_to_document, + ) client1 = _make_client(project="kirk") document = client1.document("tin", "foil") @@ -431,11 +435,12 @@ def test_documentreferencevalue_w_broken(): def test_document_snapshot_to_protobuf_w_real_snapshot(): + from google.protobuf import timestamp_pb2 # type: ignore + from google.cloud.firestore_v1._helpers import document_snapshot_to_protobuf - from google.cloud.firestore_v1.types import Document from google.cloud.firestore_v1.base_document import DocumentSnapshot from google.cloud.firestore_v1.document import DocumentReference - from google.protobuf import timestamp_pb2 # type: ignore + from google.cloud.firestore_v1.types import Document client = _make_client() snapshot = DocumentSnapshot( @@ -468,6 +473,7 @@ def test_document_snapshot_to_protobuf_w_non_existant_snapshot(): def test_decode_value_w_none(): from google.protobuf import struct_pb2 + from google.cloud.firestore_v1._helpers import decode_value value = _value_pb(null_value=struct_pb2.NULL_VALUE) @@ -500,10 +506,11 @@ def test_decode_value_w_float(): def test_decode_value_w_datetime(): - from google.cloud.firestore_v1._helpers import decode_value from google.api_core.datetime_helpers import DatetimeWithNanoseconds from google.protobuf import timestamp_pb2 + from google.cloud.firestore_v1._helpers import decode_value + dt_seconds = 552855006 dt_nanos = 766961828 @@ -531,8 +538,8 @@ def test_decode_value_w_bytes(): def test_decode_value_w_reference(): - from google.cloud.firestore_v1.document import DocumentReference from google.cloud.firestore_v1._helpers import decode_value + from google.cloud.firestore_v1.document import DocumentReference client = _make_client() path = ("then", "there-was-one") @@ -547,8 +554,7 @@ def test_decode_value_w_reference(): def test_decode_value_w_geo_point(): - from google.cloud.firestore_v1._helpers import GeoPoint - from google.cloud.firestore_v1._helpers import decode_value + from google.cloud.firestore_v1._helpers import GeoPoint, decode_value geo_pt = GeoPoint(latitude=42.5, longitude=99.0625) value = _value_pb(geo_point_value=geo_pt.to_protobuf()) @@ -556,8 +562,8 @@ def test_decode_value_w_geo_point(): def test_decode_value_w_array(): - from google.cloud.firestore_v1.types import document from google.cloud.firestore_v1._helpers import decode_value + from google.cloud.firestore_v1.types import document sub_value1 = _value_pb(boolean_value=True) sub_value2 = _value_pb(double_value=14.1396484375) @@ -574,8 +580,8 @@ def test_decode_value_w_array(): def test_decode_value_w_map(): - from google.cloud.firestore_v1.types import document from google.cloud.firestore_v1._helpers import decode_value + from google.cloud.firestore_v1.types import document sub_value1 = _value_pb(integer_value=187680) sub_value2 = _value_pb(string_value="how low can you go?") @@ -590,8 +596,8 @@ def test_decode_value_w_map(): def test_decode_value_w_nested_map(): - from google.cloud.firestore_v1.types import document from google.cloud.firestore_v1._helpers import decode_value + from google.cloud.firestore_v1.types import document actual_value1 = 1009876 actual_value2 = "hey you guys" @@ -646,12 +652,11 @@ def test_decode_value_w_unknown_value_type(): def test_decode_dict_w_many_types(): - from google.protobuf import struct_pb2 - from google.protobuf import timestamp_pb2 - from google.cloud.firestore_v1.types.document import ArrayValue - from google.cloud.firestore_v1.types.document import MapValue - from google.cloud.firestore_v1.field_path import FieldPath + from google.protobuf import struct_pb2, timestamp_pb2 + from google.cloud.firestore_v1._helpers import decode_dict + from google.cloud.firestore_v1.field_path import FieldPath + from google.cloud.firestore_v1.types.document import ArrayValue, MapValue dt_seconds = 1394037350 dt_nanos = 667285000 @@ -711,8 +716,8 @@ def _dummy_ref_string(collection_id): def test_get_doc_id_w_success(): - from google.cloud.firestore_v1.types import document from google.cloud.firestore_v1._helpers import get_doc_id + from google.cloud.firestore_v1.types import document prefix = _dummy_ref_string("sub-collection") actual_id = "this-is-the-one" @@ -724,8 +729,8 @@ def test_get_doc_id_w_success(): def test_get_doc_id_w_failure(): - from google.cloud.firestore_v1.types import document from google.cloud.firestore_v1._helpers import get_doc_id + from google.cloud.firestore_v1.types import document actual_prefix = _dummy_ref_string("the-right-one") wrong_prefix = _dummy_ref_string("the-wrong-one") @@ -742,8 +747,7 @@ def test_get_doc_id_w_failure(): def test_extract_fields_w_empty_document(): - from google.cloud.firestore_v1._helpers import extract_fields - from google.cloud.firestore_v1._helpers import _EmptyDict + from google.cloud.firestore_v1._helpers import _EmptyDict, extract_fields document_data = {} prefix_path = _make_field_path() @@ -779,8 +783,7 @@ def test_extract_fields_w_shallow_keys(): def test_extract_fields_w_nested(): - from google.cloud.firestore_v1._helpers import _EmptyDict - from google.cloud.firestore_v1._helpers import extract_fields + from google.cloud.firestore_v1._helpers import _EmptyDict, extract_fields document_data = {"b": {"a": {"d": 4, "c": 3, "g": {}}, "e": 7}, "f": 5} prefix_path = _make_field_path() @@ -797,8 +800,7 @@ def test_extract_fields_w_nested(): def test_extract_fields_w_expand_dotted(): - from google.cloud.firestore_v1._helpers import _EmptyDict - from google.cloud.firestore_v1._helpers import extract_fields + from google.cloud.firestore_v1._helpers import _EmptyDict, extract_fields document_data = { "b": {"a": {"d": 4, "c": 3, "g": {}, "k.l.m": 17}, "e": 7}, @@ -845,8 +847,7 @@ def test_set_field_value_normal_value_w_nested(): def test_set_field_value_empty_dict_w_shallow(): - from google.cloud.firestore_v1._helpers import _EmptyDict - from google.cloud.firestore_v1._helpers import set_field_value + from google.cloud.firestore_v1._helpers import _EmptyDict, set_field_value document = {} field_path = _make_field_path("a") @@ -858,8 +859,7 @@ def test_set_field_value_empty_dict_w_shallow(): def test_set_field_value_empty_dict_w_nested(): - from google.cloud.firestore_v1._helpers import _EmptyDict - from google.cloud.firestore_v1._helpers import set_field_value + from google.cloud.firestore_v1._helpers import _EmptyDict, set_field_value document = {} field_path = _make_field_path("a", "b", "c") @@ -1369,8 +1369,8 @@ def test_documentextractor_get_update_pb_w_exists_precondition(): def test_documentextractor_get_update_pb_wo_exists_precondition(): - from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1._helpers import encode_dict + from google.cloud.firestore_v1.types import write document_data = {"a": 1} inst = _make_document_extractor(document_data) @@ -1395,9 +1395,9 @@ def test_documentextractor_get_field_transform_pbs_miss(): def test_documentextractor_get_field_transform_pbs_w_server_timestamp(): - from google.cloud.firestore_v1.types import write - from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP from google.cloud.firestore_v1._helpers import REQUEST_TIME_ENUM + from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP + from google.cloud.firestore_v1.types import write document_data = {"a": SERVER_TIMESTAMP} inst = _make_document_extractor(document_data) @@ -1413,9 +1413,9 @@ def test_documentextractor_get_field_transform_pbs_w_server_timestamp(): def test_documentextractor_get_transform_pb_w_server_timestamp_w_exists_precondition(): - from google.cloud.firestore_v1.types import write - from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP from google.cloud.firestore_v1._helpers import REQUEST_TIME_ENUM + from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP + from google.cloud.firestore_v1.types import write document_data = {"a": SERVER_TIMESTAMP} inst = _make_document_extractor(document_data) @@ -1435,9 +1435,9 @@ def test_documentextractor_get_transform_pb_w_server_timestamp_w_exists_precondi def test_documentextractor_get_transform_pb_w_server_timestamp_wo_exists_precondition(): - from google.cloud.firestore_v1.types import write - from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP from google.cloud.firestore_v1._helpers import REQUEST_TIME_ENUM + from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP + from google.cloud.firestore_v1.types import write document_data = {"a": {"b": {"c": SERVER_TIMESTAMP}}} inst = _make_document_extractor(document_data) @@ -1462,8 +1462,8 @@ def _array_value_to_list(array_value): def test_documentextractor_get_transform_pb_w_array_remove(): - from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1.transforms import ArrayRemove + from google.cloud.firestore_v1.types import write values = [2, 4, 8] document_data = {"a": {"b": {"c": ArrayRemove(values)}}} @@ -1484,8 +1484,8 @@ def test_documentextractor_get_transform_pb_w_array_remove(): def test_documentextractor_get_transform_pb_w_array_union(): - from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1.transforms import ArrayUnion + from google.cloud.firestore_v1.types import write values = [1, 3, 5] document_data = {"a": {"b": {"c": ArrayUnion(values)}}} @@ -1506,8 +1506,8 @@ def test_documentextractor_get_transform_pb_w_array_union(): def test_documentextractor_get_transform_pb_w_increment_int(): - from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1.transforms import Increment + from google.cloud.firestore_v1.types import write value = 1 document_data = {"a": {"b": {"c": Increment(value)}}} @@ -1528,8 +1528,8 @@ def test_documentextractor_get_transform_pb_w_increment_int(): def test_documentextractor_get_transform_pb_w_increment_float(): - from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1.transforms import Increment + from google.cloud.firestore_v1.types import write value = 3.1415926 document_data = {"a": {"b": {"c": Increment(value)}}} @@ -1550,8 +1550,8 @@ def test_documentextractor_get_transform_pb_w_increment_float(): def test_documentextractor_get_transform_pb_w_maximum_int(): - from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1.transforms import Maximum + from google.cloud.firestore_v1.types import write value = 1 document_data = {"a": {"b": {"c": Maximum(value)}}} @@ -1572,8 +1572,8 @@ def test_documentextractor_get_transform_pb_w_maximum_int(): def test_documentextractor_get_transform_pb_w_maximum_float(): - from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1.transforms import Maximum + from google.cloud.firestore_v1.types import write value = 3.1415926 document_data = {"a": {"b": {"c": Maximum(value)}}} @@ -1594,8 +1594,8 @@ def test_documentextractor_get_transform_pb_w_maximum_float(): def test_documentextractor_get_transform_pb_w_minimum_int(): - from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1.transforms import Minimum + from google.cloud.firestore_v1.types import write value = 1 document_data = {"a": {"b": {"c": Minimum(value)}}} @@ -1616,8 +1616,8 @@ def test_documentextractor_get_transform_pb_w_minimum_int(): def test_documentextractor_get_transform_pb_w_minimum_float(): - from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1.transforms import Minimum + from google.cloud.firestore_v1.types import write value = 3.1415926 document_data = {"a": {"b": {"c": Minimum(value)}}} @@ -1638,10 +1638,8 @@ def test_documentextractor_get_transform_pb_w_minimum_float(): def _make_write_w_document_for_create(document_path, **data): - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1._helpers import encode_dict - from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.types import common, document, write return write.Write( update=document.Document(name=document_path, fields=encode_dict(data)), @@ -1662,8 +1660,8 @@ def _add_field_transforms_for_create(update_pb, fields): def __pbs_for_create_helper(do_transform=False, empty_val=False): - from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP from google.cloud.firestore_v1._helpers import pbs_for_create + from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP document_path = _make_ref_string("little", "town", "of", "ham") document_data = {"cheese": 1.5, "crackers": True} @@ -1705,9 +1703,8 @@ def test__pbs_for_create_w_transform_and_empty_value(): def _make_write_w_document_for_set_no_merge(document_path, **data): - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1._helpers import encode_dict + from google.cloud.firestore_v1.types import document, write return write.Write( update=document.Document(name=document_path, fields=encode_dict(data)) @@ -1740,8 +1737,8 @@ def test__pbs_for_set_w_empty_document(): def test__pbs_for_set_w_only_server_timestamp(): - from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP from google.cloud.firestore_v1._helpers import pbs_for_set_no_merge + from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP document_path = _make_ref_string("little", "town", "of", "ham") document_data = {"butter": SERVER_TIMESTAMP} @@ -1755,8 +1752,8 @@ def test__pbs_for_set_w_only_server_timestamp(): def _pbs_for_set_no_merge_helper(do_transform=False, empty_val=False): - from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP from google.cloud.firestore_v1._helpers import pbs_for_set_no_merge + from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP document_path = _make_ref_string("little", "town", "of", "ham") document_data = {"cheese": 1.5, "crackers": True} @@ -1994,9 +1991,8 @@ def test_documentextractorformerge_apply_merge_list_fields_w_array_union(): def _make_write_w_document_for_set_w_merge(document_path, **data): - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1._helpers import encode_dict + from google.cloud.firestore_v1.types import document, write return write.Write( update=document.Document(name=document_path, fields=encode_dict(data)) @@ -2054,8 +2050,8 @@ def test__pbs_for_set_with_merge_w_merge_field_wo_transform(): def test__pbs_for_set_with_merge_w_merge_true_w_only_transform(): - from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP from google.cloud.firestore_v1._helpers import pbs_for_set_with_merge + from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP document_path = _make_ref_string("little", "town", "of", "ham") document_data = {"butter": SERVER_TIMESTAMP} @@ -2070,8 +2066,8 @@ def test__pbs_for_set_with_merge_w_merge_true_w_only_transform(): def test__pbs_for_set_with_merge_w_merge_true_w_transform(): - from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP from google.cloud.firestore_v1._helpers import pbs_for_set_with_merge + from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP document_path = _make_ref_string("little", "town", "of", "ham") update_data = {"cheese": 1.5, "crackers": True} @@ -2088,8 +2084,8 @@ def test__pbs_for_set_with_merge_w_merge_true_w_transform(): def test__pbs_for_set_with_merge_w_merge_field_w_transform(): - from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP from google.cloud.firestore_v1._helpers import pbs_for_set_with_merge + from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP document_path = _make_ref_string("little", "town", "of", "ham") update_data = {"cheese": 1.5, "crackers": True} @@ -2110,8 +2106,8 @@ def test__pbs_for_set_with_merge_w_merge_field_w_transform(): def test__pbs_for_set_with_merge_w_merge_field_w_transform_masking_simple(): - from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP from google.cloud.firestore_v1._helpers import pbs_for_set_with_merge + from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP document_path = _make_ref_string("little", "town", "of", "ham") update_data = {"cheese": 1.5, "crackers": True} @@ -2130,8 +2126,8 @@ def test__pbs_for_set_with_merge_w_merge_field_w_transform_masking_simple(): def test__pbs_for_set_with_merge_w_merge_field_w_transform_parent(): - from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP from google.cloud.firestore_v1._helpers import pbs_for_set_with_merge + from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP document_path = _make_ref_string("little", "town", "of", "ham") update_data = {"cheese": 1.5, "crackers": True} @@ -2215,14 +2211,11 @@ def test_documentextractorforupdate_ctor_w_nested_dotted_keys(): def _pbs_for_update_helper(option=None, do_transform=False, **write_kwargs): - from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1 import DocumentTransform, _helpers + from google.cloud.firestore_v1._helpers import pbs_for_update from google.cloud.firestore_v1.field_path import FieldPath from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP - from google.cloud.firestore_v1 import DocumentTransform - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import write - from google.cloud.firestore_v1._helpers import pbs_for_update + from google.cloud.firestore_v1.types import common, document, write document_path = _make_ref_string("toy", "car", "onion", "garlic") field_path1 = "bitez.yum" @@ -2286,8 +2279,8 @@ def test__pbs_for_update_w_update_and_transform(): def _pb_for_delete_helper(option=None, **write_kwargs): - from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1._helpers import pb_for_delete + from google.cloud.firestore_v1.types import write document_path = _make_ref_string("chicken", "philly", "one", "two") write_pb = pb_for_delete(document_path, option) @@ -2302,8 +2295,9 @@ def test__pb_for_delete_wo_option(): def test__pb_for_delete_w_option(): from google.protobuf import timestamp_pb2 - from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.types import common update_time = timestamp_pb2.Timestamp(seconds=1309700594, nanos=822211297) option = _helpers.LastUpdateOption(update_time) @@ -2319,8 +2313,8 @@ def test_get_transaction_id_w_no_transaction(): def test_get_transaction_id_w_invalid_transaction(): - from google.cloud.firestore_v1.transaction import Transaction from google.cloud.firestore_v1._helpers import get_transaction_id + from google.cloud.firestore_v1.transaction import Transaction transaction = Transaction(mock.sentinel.client) assert not transaction.in_progress @@ -2329,9 +2323,11 @@ def test_get_transaction_id_w_invalid_transaction(): def test_get_transaction_id_w_after_writes_not_allowed(): - from google.cloud.firestore_v1._helpers import ReadAfterWriteError + from google.cloud.firestore_v1._helpers import ( + ReadAfterWriteError, + get_transaction_id, + ) from google.cloud.firestore_v1.transaction import Transaction - from google.cloud.firestore_v1._helpers import get_transaction_id transaction = Transaction(mock.sentinel.client) transaction._id = b"under-hook" @@ -2342,8 +2338,8 @@ def test_get_transaction_id_w_after_writes_not_allowed(): def test_get_transaction_id_w_after_writes_allowed(): - from google.cloud.firestore_v1.transaction import Transaction from google.cloud.firestore_v1._helpers import get_transaction_id + from google.cloud.firestore_v1.transaction import Transaction transaction = Transaction(mock.sentinel.client) txn_id = b"we-are-0fine" @@ -2355,8 +2351,8 @@ def test_get_transaction_id_w_after_writes_allowed(): def test_get_transaction_id_w_good_transaction(): - from google.cloud.firestore_v1.transaction import Transaction from google.cloud.firestore_v1._helpers import get_transaction_id + from google.cloud.firestore_v1.transaction import Transaction transaction = Transaction(mock.sentinel.client) txn_id = b"doubt-it" @@ -2416,9 +2412,9 @@ def test_lastupdateoption___eq___same_timestamp(): def test_lastupdateoption_modify_write_update_time(): from google.protobuf import timestamp_pb2 - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1._helpers import LastUpdateOption + from google.cloud.firestore_v1.types import common, write timestamp_pb = timestamp_pb2.Timestamp(seconds=683893592, nanos=229362000) option = LastUpdateOption(timestamp_pb) @@ -2462,9 +2458,8 @@ def test_existsoption___eq___same_exists(): def test_existsoption_modify_write(): - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1._helpers import ExistsOption + from google.cloud.firestore_v1.types import common, write for exists in (True, False): option = ExistsOption(exists) @@ -2478,6 +2473,7 @@ def test_existsoption_modify_write(): def test_make_retry_timeout_kwargs_default(): from google.api_core.gapic_v1.method import DEFAULT + from google.cloud.firestore_v1._helpers import make_retry_timeout_kwargs kwargs = make_retry_timeout_kwargs(DEFAULT, None) @@ -2495,6 +2491,7 @@ def test_make_retry_timeout_kwargs_retry_None(): def test_make_retry_timeout_kwargs_retry_only(): from google.api_core.retry import Retry + from google.cloud.firestore_v1._helpers import make_retry_timeout_kwargs retry = Retry(predicate=object()) @@ -2505,6 +2502,7 @@ def test_make_retry_timeout_kwargs_retry_only(): def test_make_retry_timeout_kwargs_timeout_only(): from google.api_core.gapic_v1.method import DEFAULT + from google.cloud.firestore_v1._helpers import make_retry_timeout_kwargs timeout = 123.0 @@ -2515,6 +2513,7 @@ def test_make_retry_timeout_kwargs_timeout_only(): def test_make_retry_timeout_kwargs_retry_and_timeout(): from google.api_core.retry import Retry + from google.cloud.firestore_v1._helpers import make_retry_timeout_kwargs retry = Retry(predicate=object()) diff --git a/tests/unit/v1/test_aggregation.py b/tests/unit/v1/test_aggregation.py index d19cf69e81..59fe5378c8 100644 --- a/tests/unit/v1/test_aggregation.py +++ b/tests/unit/v1/test_aggregation.py @@ -12,18 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import types +from datetime import datetime, timedelta, timezone + import mock import pytest - -from datetime import datetime, timezone, timedelta - from google.cloud.firestore_v1.base_aggregation import ( + AggregationResult, + AvgAggregation, CountAggregation, SumAggregation, - AvgAggregation, - AggregationResult, ) from tests.unit.v1._test_helpers import ( make_aggregation_query, @@ -358,9 +356,10 @@ def test_aggregation_query_prep_stream_with_transaction(): def _aggregation_query_get_helper(retry=None, timeout=None, read_time=None): - from google.cloud.firestore_v1 import _helpers from google.cloud._helpers import _datetime_to_pb_timestamp + from google.cloud.firestore_v1 import _helpers + # Create a minimal fake GAPIC. firestore_api = mock.Mock(spec=["run_aggregation_query"]) @@ -492,9 +491,9 @@ def _aggregation_query_stream_w_retriable_exc_helper( transaction=None, expect_retry=True, ): - from google.api_core import exceptions - from google.api_core import gapic_v1 - from google.cloud.firestore_v1 import _helpers + from google.api_core import exceptions, gapic_v1 + + from google.cloud.firestore_v1 import _helpers, stream_generator if retry is _not_passed: retry = gapic_v1.method.DEFAULT @@ -536,7 +535,7 @@ def _stream_w_exception(*_args, **_kw): get_response = aggregation_query.stream(transaction=transaction, **kwargs) - assert isinstance(get_response, types.GeneratorType) + assert isinstance(get_response, stream_generator.StreamGenerator) if expect_retry: returned = list(get_response) else: diff --git a/tests/unit/v1/test_async_aggregation.py b/tests/unit/v1/test_async_aggregation.py index 4ed97ddb98..e51592ae3a 100644 --- a/tests/unit/v1/test_async_aggregation.py +++ b/tests/unit/v1/test_async_aggregation.py @@ -12,27 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest - +from datetime import datetime, timedelta, timezone -from datetime import datetime, timezone, timedelta +import pytest from google.cloud.firestore_v1.base_aggregation import ( + AggregationResult, + AvgAggregation, CountAggregation, SumAggregation, - AvgAggregation, - AggregationResult, ) - -from tests.unit.v1.test__helpers import AsyncIter -from tests.unit.v1.test__helpers import AsyncMock from tests.unit.v1._test_helpers import ( + make_aggregation_query_response, + make_async_aggregation_query, make_async_client, make_async_query, - make_async_aggregation_query, - make_aggregation_query_response, ) - +from tests.unit.v1.test__helpers import AsyncIter, AsyncMock _PROJECT = "PROJECT" @@ -298,9 +294,10 @@ def test_async_aggregation_query_prep_stream_with_transaction(): @pytest.mark.asyncio async def _async_aggregation_query_get_helper(retry=None, timeout=None, read_time=None): - from google.cloud.firestore_v1 import _helpers from google.cloud._helpers import _datetime_to_pb_timestamp + from google.cloud.firestore_v1 import _helpers + # Create a minimal fake GAPIC. firestore_api = AsyncMock(spec=["run_aggregation_query"]) diff --git a/tests/unit/v1/test_async_batch.py b/tests/unit/v1/test_async_batch.py index f44d0caa75..43fa809819 100644 --- a/tests/unit/v1/test_async_batch.py +++ b/tests/unit/v1/test_async_batch.py @@ -34,9 +34,9 @@ def test_constructor(): async def _commit_helper(retry=None, timeout=None): from google.protobuf import timestamp_pb2 + from google.cloud.firestore_v1 import _helpers - from google.cloud.firestore_v1.types import firestore - from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1.types import firestore, write # Create a minimal fake GAPIC with a dummy result. firestore_api = AsyncMock(spec=["commit"]) @@ -98,8 +98,8 @@ async def test_commit_w_retry_timeout(): @pytest.mark.asyncio async def test_as_context_mgr_wo_error(): from google.protobuf import timestamp_pb2 - from google.cloud.firestore_v1.types import firestore - from google.cloud.firestore_v1.types import write + + from google.cloud.firestore_v1.types import firestore, write firestore_api = AsyncMock(spec=["commit"]) timestamp = timestamp_pb2.Timestamp(seconds=1234567, nanos=123456798) diff --git a/tests/unit/v1/test_async_client.py b/tests/unit/v1/test_async_client.py index e2a2624c26..ee624d382b 100644 --- a/tests/unit/v1/test_async_client.py +++ b/tests/unit/v1/test_async_client.py @@ -18,9 +18,7 @@ import mock import pytest -from tests.unit.v1.test__helpers import AsyncIter -from tests.unit.v1.test__helpers import AsyncMock - +from tests.unit.v1.test__helpers import AsyncIter, AsyncMock PROJECT = "my-prahjekt" @@ -190,8 +188,8 @@ def test_asyncclient_document_factory_w_nested_path(): async def _collections_helper(retry=None, timeout=None): - from google.cloud.firestore_v1.async_collection import AsyncCollectionReference from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.async_collection import AsyncCollectionReference collection_ids = ["users", "projects"] @@ -256,8 +254,8 @@ async def _invoke_get_all(client, references, document_pbs, **kwargs): async def _get_all_helper(num_snapshots=2, txn_id=None, retry=None, timeout=None): from google.cloud.firestore_v1 import _helpers - from google.cloud.firestore_v1.types import common from google.cloud.firestore_v1.async_document import DocumentSnapshot + from google.cloud.firestore_v1.types import common client = _make_default_async_client() @@ -400,8 +398,7 @@ def test_asyncclient_sync_copy(): @pytest.mark.asyncio async def test_asyncclient_recursive_delete(): - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import firestore + from google.cloud.firestore_v1.types import document, firestore client = _make_default_async_client() client._firestore_api_internal = AsyncMock(spec=["run_query"]) @@ -438,8 +435,7 @@ def _get_chunk(*args, **kwargs): @pytest.mark.asyncio async def test_asyncclient_recursive_delete_from_document(): - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import firestore + from google.cloud.firestore_v1.types import document, firestore client = _make_default_async_client() client._firestore_api_internal = mock.Mock( @@ -550,9 +546,10 @@ def _make_batch_response(**kwargs): def _doc_get_info(ref_string, values): - from google.cloud.firestore_v1.types import document from google.cloud._helpers import _datetime_to_pb_timestamp + from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.types import document now = datetime.datetime.now(tz=datetime.timezone.utc) read_time = _datetime_to_pb_timestamp(now) diff --git a/tests/unit/v1/test_async_collection.py b/tests/unit/v1/test_async_collection.py index c5bce0ae8d..43884911b4 100644 --- a/tests/unit/v1/test_async_collection.py +++ b/tests/unit/v1/test_async_collection.py @@ -17,9 +17,8 @@ import mock import pytest -from tests.unit.v1.test__helpers import AsyncIter -from tests.unit.v1.test__helpers import AsyncMock from tests.unit.v1._test_helpers import DEFAULT_TEST_PROJECT, make_async_client +from tests.unit.v1.test__helpers import AsyncIter, AsyncMock def _make_async_collection_reference(*args, **kwargs): @@ -56,8 +55,8 @@ def test_asynccollectionreference_constructor(): def test_asynccollectionreference_query_method_matching(): - from google.cloud.firestore_v1.async_query import AsyncQuery from google.cloud.firestore_v1.async_collection import AsyncCollectionReference + from google.cloud.firestore_v1.async_query import AsyncQuery query_methods = _get_public_methods(AsyncQuery) collection_methods = _get_public_methods(AsyncCollectionReference) @@ -129,10 +128,10 @@ def test_async_collection_avg(): @pytest.mark.asyncio async def test_asynccollectionreference_add_auto_assigned(): - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.async_document import AsyncDocumentReference from google.cloud.firestore_v1 import SERVER_TIMESTAMP from google.cloud.firestore_v1._helpers import pbs_for_create + from google.cloud.firestore_v1.async_document import AsyncDocumentReference + from google.cloud.firestore_v1.types import document # Create a minimal fake GAPIC add attach it to a real client. firestore_api = AsyncMock(spec=["create_document", "commit"]) @@ -186,10 +185,8 @@ async def test_asynccollectionreference_add_auto_assigned(): def _write_pb_for_create(document_path, document_data): - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.types import common, document, write return write.Write( update=document.Document( @@ -200,8 +197,8 @@ def _write_pb_for_create(document_path, document_data): async def _add_helper(retry=None, timeout=None): - from google.cloud.firestore_v1.async_document import AsyncDocumentReference from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.async_document import AsyncDocumentReference # Create a minimal fake GAPIC with a dummy response. firestore_api = AsyncMock(spec=["commit"]) @@ -265,8 +262,7 @@ async def test_asynccollectionreference_add_w_retry_timeout(): @pytest.mark.asyncio async def test_asynccollectionreference_chunkify(): - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import firestore + from google.cloud.firestore_v1.types import document, firestore client = make_async_client() col = client.collection("my-collection") @@ -307,9 +303,10 @@ async def _get_chunk(*args, **kwargs): @pytest.mark.asyncio async def _list_documents_helper(page_size=None, retry=None, timeout=None): - from google.cloud.firestore_v1 import _helpers - from google.api_core.page_iterator_async import AsyncIterator from google.api_core.page_iterator import Page + from google.api_core.page_iterator_async import AsyncIterator + + from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.async_document import AsyncDocumentReference from google.cloud.firestore_v1.types.document import Document diff --git a/tests/unit/v1/test_async_document.py b/tests/unit/v1/test_async_document.py index 41a5abff56..8d67e78f08 100644 --- a/tests/unit/v1/test_async_document.py +++ b/tests/unit/v1/test_async_document.py @@ -17,8 +17,8 @@ import mock import pytest -from tests.unit.v1.test__helpers import AsyncIter, AsyncMock from tests.unit.v1._test_helpers import make_async_client +from tests.unit.v1.test__helpers import AsyncIter, AsyncMock def _make_async_document_reference(*args, **kwargs): @@ -55,10 +55,8 @@ def _make_commit_repsonse(write_results=None): def _write_pb_for_create(document_path, document_data): - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.types import common, document, write return write.Write( update=document.Document( @@ -118,8 +116,10 @@ async def test_asyncdocumentreference_create_w_retry_timeout(): @pytest.mark.asyncio async def test_asyncdocumentreference_create_empty(): # Create a minimal fake GAPIC with a dummy response. - from google.cloud.firestore_v1.async_document import AsyncDocumentReference - from google.cloud.firestore_v1.async_document import DocumentSnapshot + from google.cloud.firestore_v1.async_document import ( + AsyncDocumentReference, + DocumentSnapshot, + ) firestore_api = AsyncMock(spec=["commit"]) document_reference = mock.create_autospec(AsyncDocumentReference) @@ -144,10 +144,8 @@ async def test_asyncdocumentreference_create_empty(): def _write_pb_for_set(document_path, document_data, merge): - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.types import common, document, write write_pbs = write.Write( update=document.Document( @@ -221,10 +219,8 @@ async def test_asyncdocumentreference_set_merge(): def _write_pb_for_update(document_path, update_values, field_paths): - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.types import common, document, write return write.Write( update=document.Document( @@ -405,10 +401,8 @@ async def _get_helper( timeout=None, ): from google.cloud.firestore_v1 import _helpers - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import firestore from google.cloud.firestore_v1.transaction import Transaction + from google.cloud.firestore_v1.types import common, document, firestore # Create a minimal fake GAPIC with a dummy response. create_time = 123 diff --git a/tests/unit/v1/test_async_query.py b/tests/unit/v1/test_async_query.py index c0f3d0d9ed..cacf0220b1 100644 --- a/tests/unit/v1/test_async_query.py +++ b/tests/unit/v1/test_async_query.py @@ -17,15 +17,13 @@ import mock import pytest -from tests.unit.v1.test__helpers import AsyncIter -from tests.unit.v1.test__helpers import AsyncMock -from tests.unit.v1.test_base_query import _make_query_response -from tests.unit.v1.test_base_query import _make_cursor_pb from tests.unit.v1._test_helpers import ( DEFAULT_TEST_PROJECT, make_async_client, make_async_query, ) +from tests.unit.v1.test__helpers import AsyncIter, AsyncMock +from tests.unit.v1.test_base_query import _make_cursor_pb, _make_query_response def test_asyncquery_constructor(): @@ -161,8 +159,8 @@ async def test_asyncquery_get_limit_to_last(): def test_asyncquery_sum(): - from google.cloud.firestore_v1.field_path import FieldPath from google.cloud.firestore_v1.base_aggregation import SumAggregation + from google.cloud.firestore_v1.field_path import FieldPath client = make_async_client() parent = client.collection("dee") @@ -190,8 +188,8 @@ def test_asyncquery_sum(): def test_asyncquery_avg(): - from google.cloud.firestore_v1.field_path import FieldPath from google.cloud.firestore_v1.base_aggregation import AvgAggregation + from google.cloud.firestore_v1.field_path import FieldPath client = make_async_client() parent = client.collection("dee") @@ -235,8 +233,7 @@ async def test_asyncquery_chunkify_w_empty(): @pytest.mark.asyncio async def test_asyncquery_chunkify_w_chunksize_lt_limit(): - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import firestore + from google.cloud.firestore_v1.types import document, firestore client = make_async_client() firestore_api = AsyncMock(spec=["run_query"]) @@ -283,8 +280,7 @@ async def test_asyncquery_chunkify_w_chunksize_lt_limit(): @pytest.mark.asyncio async def test_asyncquery_chunkify_w_chunksize_gt_limit(): - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import firestore + from google.cloud.firestore_v1.types import document, firestore client = make_async_client() @@ -316,6 +312,7 @@ async def test_asyncquery_chunkify_w_chunksize_gt_limit(): async def _stream_helper(retry=None, timeout=None): from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator # Create a minimal fake GAPIC. firestore_api = AsyncMock(spec=["run_query"]) @@ -340,7 +337,7 @@ async def _stream_helper(retry=None, timeout=None): get_response = query.stream(**kwargs) - assert isinstance(get_response, types.AsyncGeneratorType) + assert isinstance(get_response, AsyncStreamGenerator) returned = [x async for x in get_response] assert len(returned) == 1 snapshot = returned[0] @@ -392,6 +389,8 @@ async def test_asyncquery_stream_with_limit_to_last(): @pytest.mark.asyncio async def test_asyncquery_stream_with_transaction(): + from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator + # Create a minimal fake GAPIC. firestore_api = AsyncMock(spec=["run_query"]) @@ -417,7 +416,7 @@ async def test_asyncquery_stream_with_transaction(): # Execute the query and check the response. query = make_async_query(parent) get_response = query.stream(transaction=transaction) - assert isinstance(get_response, types.AsyncGeneratorType) + assert isinstance(get_response, AsyncStreamGenerator) returned = [x async for x in get_response] assert len(returned) == 1 snapshot = returned[0] @@ -437,6 +436,8 @@ async def test_asyncquery_stream_with_transaction(): @pytest.mark.asyncio async def test_asyncquery_stream_no_results(): + from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator + # Create a minimal fake GAPIC with a dummy response. firestore_api = AsyncMock(spec=["run_query"]) empty_response = _make_query_response() @@ -452,7 +453,7 @@ async def test_asyncquery_stream_no_results(): query = make_async_query(parent) get_response = query.stream() - assert isinstance(get_response, types.AsyncGeneratorType) + assert isinstance(get_response, AsyncStreamGenerator) assert [x async for x in get_response] == [] # Verify the mock call. @@ -469,6 +470,8 @@ async def test_asyncquery_stream_no_results(): @pytest.mark.asyncio async def test_asyncquery_stream_second_response_in_empty_stream(): + from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator + # Create a minimal fake GAPIC with a dummy response. firestore_api = AsyncMock(spec=["run_query"]) empty_response1 = _make_query_response() @@ -485,7 +488,7 @@ async def test_asyncquery_stream_second_response_in_empty_stream(): query = make_async_query(parent) get_response = query.stream() - assert isinstance(get_response, types.AsyncGeneratorType) + assert isinstance(get_response, AsyncStreamGenerator) assert [x async for x in get_response] == [] # Verify the mock call. @@ -502,6 +505,8 @@ async def test_asyncquery_stream_second_response_in_empty_stream(): @pytest.mark.asyncio async def test_asyncquery_stream_with_skipped_results(): + from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator + # Create a minimal fake GAPIC. firestore_api = AsyncMock(spec=["run_query"]) @@ -523,7 +528,7 @@ async def test_asyncquery_stream_with_skipped_results(): # Execute the query and check the response. query = make_async_query(parent) get_response = query.stream() - assert isinstance(get_response, types.AsyncGeneratorType) + assert isinstance(get_response, AsyncStreamGenerator) returned = [x async for x in get_response] assert len(returned) == 1 snapshot = returned[0] @@ -544,6 +549,8 @@ async def test_asyncquery_stream_with_skipped_results(): @pytest.mark.asyncio async def test_asyncquery_stream_empty_after_first_response(): + from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator + # Create a minimal fake GAPIC. firestore_api = AsyncMock(spec=["run_query"]) @@ -565,7 +572,7 @@ async def test_asyncquery_stream_empty_after_first_response(): # Execute the query and check the response. query = make_async_query(parent) get_response = query.stream() - assert isinstance(get_response, types.AsyncGeneratorType) + assert isinstance(get_response, AsyncStreamGenerator) returned = [x async for x in get_response] assert len(returned) == 1 snapshot = returned[0] @@ -586,6 +593,8 @@ async def test_asyncquery_stream_empty_after_first_response(): @pytest.mark.asyncio async def test_asyncquery_stream_w_collection_group(): + from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator + # Create a minimal fake GAPIC. firestore_api = AsyncMock(spec=["run_query"]) @@ -609,7 +618,7 @@ async def test_asyncquery_stream_w_collection_group(): query = make_async_query(parent) query._all_descendants = True get_response = query.stream() - assert isinstance(get_response, types.AsyncGeneratorType) + assert isinstance(get_response, AsyncStreamGenerator) returned = [x async for x in get_response] assert len(returned) == 1 snapshot = returned[0] diff --git a/tests/unit/v1/test_async_stream_generator.py b/tests/unit/v1/test_async_stream_generator.py new file mode 100644 index 0000000000..c2e7507b5d --- /dev/null +++ b/tests/unit/v1/test_async_stream_generator.py @@ -0,0 +1,95 @@ +# Copyright 2024 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + + +def _make_async_stream_generator(iterable): + from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator + + async def _inner_generator(): + for i in iterable: + X = yield i + if X: + yield X + + return AsyncStreamGenerator(_inner_generator()) + + +@pytest.mark.asyncio +async def test_async_stream_generator_aiter(): + expected_results = [0, 1, 2] + inst = _make_async_stream_generator(expected_results) + + actual_results = [] + async for result in inst: + actual_results.append(result) + + assert expected_results == actual_results + + +@pytest.mark.asyncio +async def test_async_stream_generator_anext(): + expected_results = [0, 1] + inst = _make_async_stream_generator(expected_results) + + actual_results = [] + + # Use inst.__anext__() instead of anext(inst), because built-in anext() + # was introduced in Python 3.10. + actual_results.append(await inst.__anext__()) + actual_results.append(await inst.__anext__()) + + with pytest.raises(StopAsyncIteration): + await inst.__anext__() + + assert expected_results == actual_results + + +@pytest.mark.asyncio +async def test_async_stream_generator_asend(): + expected_results = [0, 1] + inst = _make_async_stream_generator(expected_results) + + actual_results = [] + + # Use inst.__anext__() instead of anext(inst), because built-in anext() + # was introduced in Python 3.10. + actual_results.append(await inst.__anext__()) + assert await inst.asend(2) == 2 + actual_results.append(await inst.__anext__()) + + with pytest.raises(StopAsyncIteration): + await inst.__anext__() + + assert expected_results == actual_results + + +@pytest.mark.asyncio +async def test_async_stream_generator_athrow(): + inst = _make_async_stream_generator([]) + with pytest.raises(ValueError): + await inst.athrow(ValueError) + + +@pytest.mark.asyncio +async def test_stream_generator_aclose(): + expected_results = [0, 1] + inst = _make_async_stream_generator(expected_results) + + await inst.aclose() + + # Verifies that generator is closed. + with pytest.raises(StopAsyncIteration): + await inst.__anext__() diff --git a/tests/unit/v1/test_async_transaction.py b/tests/unit/v1/test_async_transaction.py index 7c1ab0650d..3c62e83d1b 100644 --- a/tests/unit/v1/test_async_transaction.py +++ b/tests/unit/v1/test_async_transaction.py @@ -195,8 +195,7 @@ async def test_asynctransaction__rollback_failure(): @pytest.mark.asyncio async def test_asynctransaction__commit(): - from google.cloud.firestore_v1.types import firestore - from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1.types import firestore, write # Create a minimal fake GAPIC with a dummy result. firestore_api = AsyncMock() @@ -316,8 +315,8 @@ async def test_asynctransaction_get_all_w_retry_timeout(): async def _get_w_document_ref_helper(retry=None, timeout=None): - from google.cloud.firestore_v1.async_document import AsyncDocumentReference from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.async_document import AsyncDocumentReference client = AsyncMock(spec=["get_all"]) transaction = _make_async_transaction(client) @@ -345,8 +344,8 @@ async def test_asynctransaction_get_w_document_ref_w_retry_timeout(): async def _get_w_query_helper(retry=None, timeout=None): - from google.cloud.firestore_v1.async_query import AsyncQuery from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.async_query import AsyncQuery client = AsyncMock(spec=[]) transaction = _make_async_transaction(client) @@ -493,9 +492,8 @@ async def test_asynctransactional___call__success_first_attempt(): @pytest.mark.asyncio async def test_asynctransactional___call__success_second_attempt(): from google.api_core import exceptions - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.types import firestore - from google.cloud.firestore_v1.types import write + + from google.cloud.firestore_v1.types import common, firestore, write to_wrap = AsyncMock(return_value=mock.sentinel.result, spec=[]) wrapped = _make_async_transactional(to_wrap) @@ -553,8 +551,9 @@ async def test_asynctransactional___call__failure_max_attempts(max_attempts): rasie retryable error and exhause max_attempts """ from google.api_core import exceptions - from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.async_transaction import _EXCEED_ATTEMPTS_TEMPLATE + from google.cloud.firestore_v1.types import common to_wrap = AsyncMock(return_value=mock.sentinel.result, spec=[]) wrapped = _make_async_transactional(to_wrap) @@ -623,6 +622,7 @@ async def test_asynctransactional___call__failure_readonly(max_attempts): readonly transaction should never retry """ from google.api_core import exceptions + from google.cloud.firestore_v1.types import common to_wrap = AsyncMock(return_value=mock.sentinel.result, spec=[]) @@ -789,8 +789,10 @@ async def test_asynctransactional___call__failure_with_rollback_failure(): def test_async_transactional_factory(): - from google.cloud.firestore_v1.async_transaction import _AsyncTransactional - from google.cloud.firestore_v1.async_transaction import async_transactional + from google.cloud.firestore_v1.async_transaction import ( + _AsyncTransactional, + async_transactional, + ) wrapped = async_transactional(mock.sentinel.callable_) assert isinstance(wrapped, _AsyncTransactional) @@ -832,6 +834,7 @@ async def test__commit_with_retry_success_first_attempt(_sleep): @pytest.mark.asyncio async def test__commit_with_retry_success_third_attempt(_sleep): from google.api_core import exceptions + from google.cloud.firestore_v1.async_transaction import _commit_with_retry # Create a minimal fake GAPIC with a dummy result. @@ -874,6 +877,7 @@ async def test__commit_with_retry_success_third_attempt(_sleep): @pytest.mark.asyncio async def test__commit_with_retry_failure_first_attempt(_sleep): from google.api_core import exceptions + from google.cloud.firestore_v1.async_transaction import _commit_with_retry # Create a minimal fake GAPIC with a dummy result. @@ -910,6 +914,7 @@ async def test__commit_with_retry_failure_first_attempt(_sleep): @pytest.mark.asyncio async def test__commit_with_retry_failure_second_attempt(_sleep): from google.api_core import exceptions + from google.cloud.firestore_v1.async_transaction import _commit_with_retry # Create a minimal fake GAPIC with a dummy result. @@ -1011,9 +1016,9 @@ def _make_client(project="feral-tom-cat"): def _make_transaction(txn_id, **txn_kwargs): from google.protobuf import empty_pb2 - from google.cloud.firestore_v1.types import firestore - from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1.async_transaction import AsyncTransaction + from google.cloud.firestore_v1.types import firestore, write # Create a fake GAPIC ... firestore_api = AsyncMock() diff --git a/tests/unit/v1/test_async_vector_query.py b/tests/unit/v1/test_async_vector_query.py new file mode 100644 index 0000000000..69e855b530 --- /dev/null +++ b/tests/unit/v1/test_async_vector_query.py @@ -0,0 +1,236 @@ +# Copyright 2024 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from google.cloud.firestore_v1._helpers import encode_value, make_retry_timeout_kwargs +from google.cloud.firestore_v1.base_vector_query import DistanceMeasure +from google.cloud.firestore_v1.types.query import StructuredQuery +from google.cloud.firestore_v1.vector import Vector +from tests.unit.v1._test_helpers import make_async_client, make_async_query, make_query +from tests.unit.v1.test__helpers import AsyncIter, AsyncMock +from tests.unit.v1.test_base_query import _make_query_response + +_PROJECT = "PROJECT" +_TXN_ID = b"\x00\x00\x01-work-\xf2" + + +def _transaction(client): + transaction = client.transaction() + txn_id = _TXN_ID + transaction._id = txn_id + return transaction + + +def _expected_pb(parent, vector_field, vector, distance_type, limit): + query = make_query(parent) + expected_pb = query._to_protobuf() + expected_pb.find_nearest = StructuredQuery.FindNearest( + vector_field=StructuredQuery.FieldReference(field_path=vector_field), + query_vector=encode_value(vector.to_map_value()), + distance_measure=distance_type, + limit=limit, + ) + return expected_pb + + +@pytest.mark.parametrize( + "distance_measure, expected_distance", + [ + ( + DistanceMeasure.EUCLIDEAN, + StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN, + ), + (DistanceMeasure.COSINE, StructuredQuery.FindNearest.DistanceMeasure.COSINE), + ( + DistanceMeasure.DOT_PRODUCT, + StructuredQuery.FindNearest.DistanceMeasure.DOT_PRODUCT, + ), + ], +) +@pytest.mark.asyncio +async def test_async_vector_query_with_filter(distance_measure, expected_distance): + # Create a minimal fake GAPIC. + firestore_api = AsyncMock(spec=["run_query"]) + client = make_async_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("dee") + query = make_async_query(parent) + parent_path, expected_prefix = parent._parent_info() + + data = {"snooze": 10, "embedding": Vector([1.0, 2.0, 3.0])} + response_pb1 = _make_query_response( + name="{}/test_doc".format(expected_prefix), data=data + ) + response_pb2 = _make_query_response( + name="{}/test_doc".format(expected_prefix), data=data + ) + + kwargs = make_retry_timeout_kwargs(retry=None, timeout=None) + + # Execute the vector query and check the response. + firestore_api.run_query.return_value = AsyncIter([response_pb1, response_pb2]) + + vector_async__query = query.where("snooze", "==", 10).find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=distance_measure, + limit=5, + ) + + returned = await vector_async__query.get(transaction=_transaction(client), **kwargs) + assert isinstance(returned, list) + assert len(returned) == 2 + assert returned[0].to_dict() == data + + expected_pb = _expected_pb( + parent=parent, + vector_field="embedding", + vector=Vector([1.0, 2.0, 3.0]), + distance_type=expected_distance, + limit=5, + ) + expected_pb.where = StructuredQuery.Filter( + field_filter=StructuredQuery.FieldFilter( + field=StructuredQuery.FieldReference(field_path="snooze"), + op=StructuredQuery.FieldFilter.Operator.EQUAL, + value=encode_value(10), + ) + ) + + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": expected_pb, + "transaction": _TXN_ID, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + +@pytest.mark.parametrize( + "distance_measure, expected_distance", + [ + ( + DistanceMeasure.EUCLIDEAN, + StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN, + ), + (DistanceMeasure.COSINE, StructuredQuery.FindNearest.DistanceMeasure.COSINE), + ( + DistanceMeasure.DOT_PRODUCT, + StructuredQuery.FindNearest.DistanceMeasure.DOT_PRODUCT, + ), + ], +) +@pytest.mark.asyncio +async def test_vector_query_collection_group(distance_measure, expected_distance): + # Create a minimal fake GAPIC. + firestore_api = AsyncMock(spec=["run_query"]) + client = make_async_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection group reference as parent. + collection_group_ref = client.collection_group("dee") + + data = {"snooze": 10, "embedding": Vector([1.0, 2.0, 3.0])} + response_pb = _make_query_response(name="xxx/test_doc", data=data) + + kwargs = make_retry_timeout_kwargs(retry=None, timeout=None) + + # Execute the vector query and check the response. + firestore_api.run_query.return_value = AsyncIter([response_pb]) + + vector_query = collection_group_ref.where("snooze", "==", 10).find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=distance_measure, + limit=5, + ) + + returned = await vector_query.get(transaction=_transaction(client), **kwargs) + assert isinstance(returned, list) + assert len(returned) == 1 + assert returned[0].to_dict() == data + + parent = client.collection("dee") + parent_path, expected_prefix = parent._parent_info() + + expected_pb = _expected_pb( + parent=parent, + vector_field="embedding", + vector=Vector([1.0, 2.0, 3.0]), + distance_type=expected_distance, + limit=5, + ) + expected_pb.where = StructuredQuery.Filter( + field_filter=StructuredQuery.FieldFilter( + field=StructuredQuery.FieldReference(field_path="snooze"), + op=StructuredQuery.FieldFilter.Operator.EQUAL, + value=encode_value(10), + ) + ) + expected_pb.from_ = [ + StructuredQuery.CollectionSelector(collection_id="dee", all_descendants=True) + ] + + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": expected_pb, + "transaction": _TXN_ID, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + +@pytest.mark.asyncio +async def test_async_query_stream_multiple_empty_response_in_stream(): + # Create a minimal fake GAPIC with a dummy response. + firestore_api = AsyncMock(spec=["run_query"]) + empty_response1 = _make_query_response() + empty_response2 = _make_query_response() + run_query_response = AsyncIter([empty_response1, empty_response2]) + firestore_api.run_query.return_value = run_query_response + + # Attach the fake GAPIC to a real client. + client = make_async_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("dah", "dah", "dum") + async_vector_query = parent.where("snooze", "==", 10).find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=DistanceMeasure.EUCLIDEAN, + limit=5, + ) + + result = [snapshot async for snapshot in async_vector_query.stream()] + + assert list(result) == [] + + # Verify the mock call. + parent_path, _ = parent._parent_info() + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": async_vector_query._to_protobuf(), + "transaction": None, + }, + metadata=client._rpc_metadata, + ) diff --git a/tests/unit/v1/test_base_batch.py b/tests/unit/v1/test_base_batch.py index eedb6625a3..3bd7c7e806 100644 --- a/tests/unit/v1/test_base_batch.py +++ b/tests/unit/v1/test_base_batch.py @@ -47,9 +47,7 @@ def test_basewritebatch__add_write_pbs(): def test_basewritebatch_create(): - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1.types import common, document, write client = _make_client() batch = _make_derived_write_batch(client) @@ -73,8 +71,7 @@ def test_basewritebatch_create(): def test_basewritebatch_set(): - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1.types import document, write client = _make_client() batch = _make_derived_write_batch(client) @@ -96,8 +93,7 @@ def test_basewritebatch_set(): def test_basewritebatch_set_merge(): - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1.types import document, write client = _make_client() batch = _make_derived_write_batch(client) @@ -120,9 +116,7 @@ def test_basewritebatch_set_merge(): def test_basewritebatch_update(): - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1.types import common, document, write client = _make_client() batch = _make_derived_write_batch(client) diff --git a/tests/unit/v1/test_base_client.py b/tests/unit/v1/test_base_client.py index 57d278daa2..e7eddcdeaa 100644 --- a/tests/unit/v1/test_base_client.py +++ b/tests/unit/v1/test_base_client.py @@ -13,8 +13,8 @@ # limitations under the License. import datetime -import grpc +import grpc import mock import pytest @@ -34,8 +34,11 @@ def _make_default_base_client(): def test_baseclient_constructor_with_emulator_host_defaults(): from google.auth.credentials import AnonymousCredentials - from google.cloud.firestore_v1.base_client import _DEFAULT_EMULATOR_PROJECT - from google.cloud.firestore_v1.base_client import _FIRESTORE_EMULATOR_HOST + + from google.cloud.firestore_v1.base_client import ( + _DEFAULT_EMULATOR_PROJECT, + _FIRESTORE_EMULATOR_HOST, + ) emulator_host = "localhost:8081" @@ -49,6 +52,7 @@ def test_baseclient_constructor_with_emulator_host_defaults(): def test_baseclient_constructor_with_emulator_host_w_project(): from google.auth.credentials import AnonymousCredentials + from google.cloud.firestore_v1.base_client import _FIRESTORE_EMULATOR_HOST emulator_host = "localhost:8081" @@ -61,8 +65,10 @@ def test_baseclient_constructor_with_emulator_host_w_project(): def test_baseclient_constructor_with_emulator_host_w_creds(): - from google.cloud.firestore_v1.base_client import _DEFAULT_EMULATOR_PROJECT - from google.cloud.firestore_v1.base_client import _FIRESTORE_EMULATOR_HOST + from google.cloud.firestore_v1.base_client import ( + _DEFAULT_EMULATOR_PROJECT, + _FIRESTORE_EMULATOR_HOST, + ) credentials = _make_credentials() emulator_host = "localhost:8081" @@ -296,6 +302,7 @@ def test_baseclient_field_path(): def test_baseclient_write_option_last_update(): from google.protobuf import timestamp_pb2 + from google.cloud.firestore_v1._helpers import LastUpdateOption from google.cloud.firestore_v1.base_client import BaseClient @@ -320,8 +327,7 @@ def test_baseclient_write_option_exists(): def test_baseclient_write_open_neither_arg(): - from google.cloud.firestore_v1.base_client import _BAD_OPTION_ERR - from google.cloud.firestore_v1.base_client import BaseClient + from google.cloud.firestore_v1.base_client import _BAD_OPTION_ERR, BaseClient with pytest.raises(TypeError) as exc_info: BaseClient.write_option() @@ -330,8 +336,7 @@ def test_baseclient_write_open_neither_arg(): def test_baseclient_write_multiple_args(): - from google.cloud.firestore_v1.base_client import _BAD_OPTION_ERR - from google.cloud.firestore_v1.base_client import BaseClient + from google.cloud.firestore_v1.base_client import _BAD_OPTION_ERR, BaseClient with pytest.raises(TypeError) as exc_info: BaseClient.write_option(exists=False, last_update_time=mock.sentinel.timestamp) @@ -340,8 +345,7 @@ def test_baseclient_write_multiple_args(): def test_baseclient_write_bad_arg(): - from google.cloud.firestore_v1.base_client import _BAD_OPTION_ERR - from google.cloud.firestore_v1.base_client import BaseClient + from google.cloud.firestore_v1.base_client import _BAD_OPTION_ERR, BaseClient with pytest.raises(TypeError) as exc_info: BaseClient.write_option(spinach="popeye") @@ -376,8 +380,7 @@ def test__get_reference_success(): def test__get_reference_failure(): - from google.cloud.firestore_v1.base_client import _BAD_DOC_TEMPLATE - from google.cloud.firestore_v1.base_client import _get_reference + from google.cloud.firestore_v1.base_client import _BAD_DOC_TEMPLATE, _get_reference doc_path = "1/888/call-now" with pytest.raises(ValueError) as exc_info: @@ -399,10 +402,11 @@ def _dummy_ref_string(): def test__parse_batch_get_found(): - from google.cloud.firestore_v1.types import document from google.cloud._helpers import _datetime_to_pb_timestamp - from google.cloud.firestore_v1.document import DocumentSnapshot + from google.cloud.firestore_v1.base_client import _parse_batch_get + from google.cloud.firestore_v1.document import DocumentSnapshot + from google.cloud.firestore_v1.types import document now = datetime.datetime.now(tz=datetime.timezone.utc) read_time = _datetime_to_pb_timestamp(now) @@ -434,8 +438,8 @@ def test__parse_batch_get_found(): def test__parse_batch_get_missing(): - from google.cloud.firestore_v1.document import DocumentReference from google.cloud.firestore_v1.base_client import _parse_batch_get + from google.cloud.firestore_v1.document import DocumentReference ref_string = _dummy_ref_string() response_pb = _make_batch_response(missing=ref_string) @@ -475,8 +479,8 @@ def test__get_doc_mask_w_none(): def test__get_doc_mask_w_paths(): - from google.cloud.firestore_v1.types import common from google.cloud.firestore_v1.base_client import _get_doc_mask + from google.cloud.firestore_v1.types import common field_paths = ["a.b", "c"] result = _get_doc_mask(field_paths) diff --git a/tests/unit/v1/test_base_collection.py b/tests/unit/v1/test_base_collection.py index e867a30981..22baa0c5f3 100644 --- a/tests/unit/v1/test_base_collection.py +++ b/tests/unit/v1/test_base_collection.py @@ -302,8 +302,8 @@ def test_basecollectionreference_where_w___name___w_value_as_list_of_docref(mock @mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True) def test_basecollectionreference_order_by(mock_query): - from google.cloud.firestore_v1.base_query import BaseQuery from google.cloud.firestore_v1.base_collection import BaseCollectionReference + from google.cloud.firestore_v1.base_query import BaseQuery with mock.patch.object(BaseCollectionReference, "_query") as _query: _query.return_value = mock_query @@ -424,8 +424,7 @@ def test_basecollectionreference_end_at(mock_query): @mock.patch("random.choice") def test__auto_id(mock_rand_choice): - from google.cloud.firestore_v1.base_collection import _AUTO_ID_CHARS - from google.cloud.firestore_v1.base_collection import _auto_id + from google.cloud.firestore_v1.base_collection import _AUTO_ID_CHARS, _auto_id mock_result = "0123456789abcdefghij" mock_rand_choice.side_effect = list(mock_result) diff --git a/tests/unit/v1/test_base_document.py b/tests/unit/v1/test_base_document.py index 28fcc5b2a4..8098afd76a 100644 --- a/tests/unit/v1/test_base_document.py +++ b/tests/unit/v1/test_base_document.py @@ -274,6 +274,7 @@ def test_documentsnapshot___eq___same_reference_same_data(): @pytest.mark.xfail(strict=False) def test_documentsnapshot___hash__(): import datetime + from proto.datetime_helpers import DatetimeWithNanoseconds client = mock.MagicMock() @@ -401,8 +402,9 @@ def test__consume_single_get_failure_too_many(): def test__first_write_result_success(): from google.protobuf import timestamp_pb2 - from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1.base_document import _first_write_result + from google.cloud.firestore_v1.types import write single_result = write.WriteResult( update_time=timestamp_pb2.Timestamp(seconds=1368767504, nanos=458000123) @@ -421,8 +423,8 @@ def test__first_write_result_failure_not_enough(): def test__first_write_result_more_than_one(): - from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1.base_document import _first_write_result + from google.cloud.firestore_v1.types import write result1 = write.WriteResult() result2 = write.WriteResult() diff --git a/tests/unit/v1/test_base_query.py b/tests/unit/v1/test_base_query.py index a3369954bb..227b46933f 100644 --- a/tests/unit/v1/test_base_query.py +++ b/tests/unit/v1/test_base_query.py @@ -225,9 +225,7 @@ def test_basequery_where_invalid_path(): def test_basequery_where(): from google.cloud.firestore_v1.base_query import BaseQuery - from google.cloud.firestore_v1.types import StructuredQuery - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import query + from google.cloud.firestore_v1.types import StructuredQuery, document, query query_inst = _make_base_query_all_fields( skip_fields=("field_filters",), all_descendants=True @@ -433,8 +431,8 @@ def test_basequery_order_by_invalid_path(): def test_basequery_order_by(): - from google.cloud.firestore_v1.types import StructuredQuery from google.cloud.firestore_v1.base_query import BaseQuery + from google.cloud.firestore_v1.types import StructuredQuery query1 = _make_base_query_all_fields(skip_fields=("orders",), all_descendants=True) @@ -760,10 +758,8 @@ def test_basequery_end_at(): def test_basequery_where_filter_keyword_arg(): - from google.cloud.firestore_v1.types import StructuredQuery - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import query - from google.cloud.firestore_v1.base_query import FieldFilter, And, Or + from google.cloud.firestore_v1.base_query import And, FieldFilter, Or + from google.cloud.firestore_v1.types import StructuredQuery, document, query op_class = StructuredQuery.FieldFilter.Operator @@ -877,7 +873,7 @@ def test_basequery_where_cannot_pass_both_positional_and_keyword_filter_arg(): def test_basequery_where_cannot_pass_filter_without_keyword_arg(): - from google.cloud.firestore_v1.base_query import FieldFilter, And + from google.cloud.firestore_v1.base_query import And, FieldFilter field_path_1 = "x.y" op_str_1 = ">" @@ -900,10 +896,9 @@ def test_basequery_where_cannot_pass_filter_without_keyword_arg(): def test_basequery_where_mix_of_field_and_composite(): - from google.cloud.firestore_v1.base_query import FieldFilter, And, Or - from google.cloud.firestore_v1.types import query + from google.cloud.firestore_v1.base_query import And, FieldFilter, Or + from google.cloud.firestore_v1.types import document, query from google.cloud.firestore_v1.types.query import StructuredQuery - from google.cloud.firestore_v1.types import document op_class = StructuredQuery.FieldFilter.Operator @@ -1046,10 +1041,7 @@ def test_basequery__filters_pb_empty(): def test_basequery__filters_pb_single(): - from google.cloud.firestore_v1.types import StructuredQuery - - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import query + from google.cloud.firestore_v1.types import StructuredQuery, document, query query1 = _make_base_query(mock.sentinel.parent) query2 = query1.where("x.y", ">", 50.5) @@ -1065,10 +1057,7 @@ def test_basequery__filters_pb_single(): def test_basequery__filters_pb_multi(): - from google.cloud.firestore_v1.types import StructuredQuery - - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import query + from google.cloud.firestore_v1.types import StructuredQuery, document, query query1 = _make_base_query(mock.sentinel.parent) query2 = query1.where("x.y", ">", 50.5) @@ -1418,10 +1407,8 @@ def test_basequery__normalize_cursor_w___name___wo_slash(): def test_basequery__to_protobuf_all_fields(): from google.protobuf import wrappers_pb2 - from google.cloud.firestore_v1.types import StructuredQuery - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import query + from google.cloud.firestore_v1.types import StructuredQuery, document, query parent = mock.Mock(id="cat", spec=["id"]) query1 = _make_base_query(parent) @@ -1484,10 +1471,7 @@ def test_basequery__to_protobuf_select_only(): def test_basequery__to_protobuf_where_only(): - from google.cloud.firestore_v1.types import StructuredQuery - - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import query + from google.cloud.firestore_v1.types import StructuredQuery, document, query parent = mock.Mock(id="dog", spec=["id"]) query1 = _make_base_query(parent) @@ -1509,9 +1493,7 @@ def test_basequery__to_protobuf_where_only(): def test_basequery__to_protobuf_order_by_only(): - from google.cloud.firestore_v1.types import StructuredQuery - - from google.cloud.firestore_v1.types import query + from google.cloud.firestore_v1.types import StructuredQuery, query parent = mock.Mock(id="fish", spec=["id"]) query1 = _make_base_query(parent) @@ -1528,10 +1510,7 @@ def test_basequery__to_protobuf_order_by_only(): def test_basequery__to_protobuf_start_at_only(): # NOTE: "only" is wrong since we must have ``order_by`` as well. - from google.cloud.firestore_v1.types import StructuredQuery - - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import query + from google.cloud.firestore_v1.types import StructuredQuery, document, query parent = mock.Mock(id="phish", spec=["id"]) query_inst = _make_base_query(parent).order_by("X.Y").start_after({"X": {"Y": "Z"}}) @@ -1548,10 +1527,7 @@ def test_basequery__to_protobuf_start_at_only(): def test_basequery__to_protobuf_end_at_only(): # NOTE: "only" is wrong since we must have ``order_by`` as well. - from google.cloud.firestore_v1.types import StructuredQuery - - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import query + from google.cloud.firestore_v1.types import StructuredQuery, document, query parent = mock.Mock(id="ghoti", spec=["id"]) query_inst = _make_base_query(parent).order_by("a").end_at({"a": 88}) @@ -1585,6 +1561,7 @@ def test_basequery__to_protobuf_offset_only(): def test_basequery__to_protobuf_limit_only(): from google.protobuf import wrappers_pb2 + from google.cloud.firestore_v1.types import query parent = mock.Mock(id="donut", spec=["id"]) @@ -1701,8 +1678,8 @@ def test_basequery_comparator_missing_order_by_field_in_data_raises(): def test_basequery_recursive_multiple(): - from google.cloud.firestore_v1.collection import CollectionReference from google.cloud.firestore_v1.base_query import BaseQuery + from google.cloud.firestore_v1.collection import CollectionReference class DerivedQuery(BaseQuery): @staticmethod @@ -1813,9 +1790,9 @@ def test__isnan_invalid(): def test__enum_from_direction_success(): - from google.cloud.firestore_v1.types import StructuredQuery from google.cloud.firestore_v1.base_query import _enum_from_direction from google.cloud.firestore_v1.query import Query + from google.cloud.firestore_v1.types import StructuredQuery dir_class = StructuredQuery.Direction assert _enum_from_direction(Query.ASCENDING) == dir_class.ASCENDING @@ -1834,9 +1811,8 @@ def test__enum_from_direction_failure(): def test__filter_pb_unary(): - from google.cloud.firestore_v1.types import StructuredQuery from google.cloud.firestore_v1.base_query import _filter_pb - from google.cloud.firestore_v1.types import query + from google.cloud.firestore_v1.types import StructuredQuery, query unary_pb = query.StructuredQuery.UnaryFilter( field=query.StructuredQuery.FieldReference(field_path="a.b.c"), @@ -1848,10 +1824,8 @@ def test__filter_pb_unary(): def test__filter_pb_field(): - from google.cloud.firestore_v1.types import StructuredQuery - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import query from google.cloud.firestore_v1.base_query import _filter_pb + from google.cloud.firestore_v1.types import StructuredQuery, document, query field_filter_pb = query.StructuredQuery.FieldFilter( field=query.StructuredQuery.FieldReference(field_path="XYZ"), @@ -1877,9 +1851,9 @@ def test__cursor_pb_no_pair(): def test__cursor_pb_success(): - from google.cloud.firestore_v1.types import query from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.base_query import _cursor_pb + from google.cloud.firestore_v1.types import query data = [1.5, 10, True] cursor_pair = data, True @@ -1956,10 +1930,10 @@ def test__collection_group_query_response_to_snapshot_after_offset(): def test__collection_group_query_response_to_snapshot_response(): - from google.cloud.firestore_v1.document import DocumentSnapshot from google.cloud.firestore_v1.base_query import ( _collection_group_query_response_to_snapshot, ) + from google.cloud.firestore_v1.document import DocumentSnapshot client = make_client() collection = client.collection("a", "b", "c") @@ -1989,10 +1963,10 @@ def _make_order_pb(field_path, direction): def _make_query_response(**kwargs): # kwargs supported are ``skipped_results``, ``name`` and ``data`` - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import firestore from google.cloud._helpers import _datetime_to_pb_timestamp + from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.types import document, firestore now = datetime.datetime.now(tz=datetime.timezone.utc) read_time = _datetime_to_pb_timestamp(now) diff --git a/tests/unit/v1/test_batch.py b/tests/unit/v1/test_batch.py index ba641751c4..5e51222981 100644 --- a/tests/unit/v1/test_batch.py +++ b/tests/unit/v1/test_batch.py @@ -32,9 +32,9 @@ def test_writebatch_ctor(): def _commit_helper(retry=None, timeout=None): from google.protobuf import timestamp_pb2 + from google.cloud.firestore_v1 import _helpers - from google.cloud.firestore_v1.types import firestore - from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1.types import firestore, write # Create a minimal fake GAPIC with a dummy result. firestore_api = mock.Mock(spec=["commit"]) @@ -93,8 +93,8 @@ def test_writebatch_commit_w_retry_timeout(): def test_writebatch_as_context_mgr_wo_error(): from google.protobuf import timestamp_pb2 - from google.cloud.firestore_v1.types import firestore - from google.cloud.firestore_v1.types import write + + from google.cloud.firestore_v1.types import firestore, write firestore_api = mock.Mock(spec=["commit"]) timestamp = timestamp_pb2.Timestamp(seconds=1234567, nanos=123456798) diff --git a/tests/unit/v1/test_bulk_batch.py b/tests/unit/v1/test_bulk_batch.py index 97cd66a417..bd23c61dca 100644 --- a/tests/unit/v1/test_bulk_batch.py +++ b/tests/unit/v1/test_bulk_batch.py @@ -30,8 +30,7 @@ def test_bulkwritebatch_ctor(): def _write_helper(retry=None, timeout=None): from google.cloud.firestore_v1 import _helpers - from google.cloud.firestore_v1.types import firestore - from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1.types import firestore, write # Create a minimal fake GAPIC with a dummy result. firestore_api = mock.Mock(spec=["batch_write"]) diff --git a/tests/unit/v1/test_bulk_writer.py b/tests/unit/v1/test_bulk_writer.py index ce62250e88..ac7d2e1da0 100644 --- a/tests/unit/v1/test_bulk_writer.py +++ b/tests/unit/v1/test_bulk_writer.py @@ -19,17 +19,15 @@ import mock import pytest -from google.cloud.firestore_v1 import async_client -from google.cloud.firestore_v1 import client -from google.cloud.firestore_v1 import base_client +from google.cloud.firestore_v1 import async_client, base_client, client def _make_no_send_bulk_writer(*args, **kwargs): from google.rpc import status_pb2 + from google.cloud.firestore_v1._helpers import build_timestamp from google.cloud.firestore_v1.bulk_batch import BulkWriteBatch - from google.cloud.firestore_v1.bulk_writer import BulkWriter - from google.cloud.firestore_v1.bulk_writer import BulkWriterOperation + from google.cloud.firestore_v1.bulk_writer import BulkWriter, BulkWriterOperation from google.cloud.firestore_v1.types.firestore import BatchWriteResponse from google.cloud.firestore_v1.types.write import WriteResult from tests.unit.v1._test_helpers import FakeThreadPoolExecutor @@ -133,8 +131,7 @@ def test_basebulkwriter_ctor_defaults(self): self._basebulkwriter_ctor_helper() def test_basebulkwriter_ctor_explicit(self): - from google.cloud.firestore_v1.bulk_writer import BulkRetry - from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions + from google.cloud.firestore_v1.bulk_writer import BulkRetry, BulkWriterOptions options = BulkWriterOptions(retry=BulkRetry.immediate) self._basebulkwriter_ctor_helper(options=options) @@ -372,9 +369,11 @@ def _on_error(error, bw) -> bool: assert len(bw._operations) == 0 def test_basebulkwriter_invokes_error_callbacks_successfully_multiple_retries(self): - from google.cloud.firestore_v1.bulk_writer import BulkRetry - from google.cloud.firestore_v1.bulk_writer import BulkWriteFailure - from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions + from google.cloud.firestore_v1.bulk_writer import ( + BulkRetry, + BulkWriteFailure, + BulkWriterOptions, + ) client = self._make_client() bw = _make_no_send_bulk_writer( @@ -416,8 +415,7 @@ def _on_error(error, bw) -> bool: assert len(bw._operations) == 0 def test_basebulkwriter_default_error_handler(self): - from google.cloud.firestore_v1.bulk_writer import BulkRetry - from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions + from google.cloud.firestore_v1.bulk_writer import BulkRetry, BulkWriterOptions client = self._make_client() bw = _make_no_send_bulk_writer( @@ -440,9 +438,11 @@ def _on_error(error, bw): assert bw._attempts == 15 def test_basebulkwriter_handles_errors_and_successes_correctly(self): - from google.cloud.firestore_v1.bulk_writer import BulkRetry - from google.cloud.firestore_v1.bulk_writer import BulkWriteFailure - from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions + from google.cloud.firestore_v1.bulk_writer import ( + BulkRetry, + BulkWriteFailure, + BulkWriterOptions, + ) client = self._make_client() bw = _make_no_send_bulk_writer( @@ -485,9 +485,11 @@ def _on_error(error, bw) -> bool: assert len(bw._operations) == 0 def test_basebulkwriter_create_retriable(self): - from google.cloud.firestore_v1.bulk_writer import BulkRetry - from google.cloud.firestore_v1.bulk_writer import BulkWriteFailure - from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions + from google.cloud.firestore_v1.bulk_writer import ( + BulkRetry, + BulkWriteFailure, + BulkWriterOptions, + ) client = self._make_client() bw = _make_no_send_bulk_writer( @@ -516,9 +518,11 @@ def _on_error(error, bw) -> bool: assert len(bw._operations) == 0 def test_basebulkwriter_delete_retriable(self): - from google.cloud.firestore_v1.bulk_writer import BulkRetry - from google.cloud.firestore_v1.bulk_writer import BulkWriteFailure - from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions + from google.cloud.firestore_v1.bulk_writer import ( + BulkRetry, + BulkWriteFailure, + BulkWriterOptions, + ) client = self._make_client() bw = _make_no_send_bulk_writer( @@ -547,9 +551,11 @@ def _on_error(error, bw) -> bool: assert len(bw._operations) == 0 def test_basebulkwriter_set_retriable(self): - from google.cloud.firestore_v1.bulk_writer import BulkRetry - from google.cloud.firestore_v1.bulk_writer import BulkWriteFailure - from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions + from google.cloud.firestore_v1.bulk_writer import ( + BulkRetry, + BulkWriteFailure, + BulkWriterOptions, + ) client = self._make_client() bw = _make_no_send_bulk_writer( @@ -578,9 +584,11 @@ def _on_error(error, bw) -> bool: assert len(bw._operations) == 0 def test_basebulkwriter_update_retriable(self): - from google.cloud.firestore_v1.bulk_writer import BulkRetry - from google.cloud.firestore_v1.bulk_writer import BulkWriteFailure - from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions + from google.cloud.firestore_v1.bulk_writer import ( + BulkRetry, + BulkWriteFailure, + BulkWriterOptions, + ) client = self._make_client() bw = _make_no_send_bulk_writer( @@ -609,8 +617,7 @@ def _on_error(error, bw) -> bool: assert len(bw._operations) == 0 def test_basebulkwriter_serial_calls_send_correctly(self): - from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions - from google.cloud.firestore_v1.bulk_writer import SendMode + from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions, SendMode client = self._make_client() bw = _make_no_send_bulk_writer( @@ -779,8 +786,10 @@ def test_scheduling_max_in_flight_honored(): def test_scheduling_operation_retry_scheduling(): - from google.cloud.firestore_v1.bulk_writer import BulkWriterCreateOperation - from google.cloud.firestore_v1.bulk_writer import OperationRetry + from google.cloud.firestore_v1.bulk_writer import ( + BulkWriterCreateOperation, + OperationRetry, + ) now = datetime.datetime.now() one_second_from_now = now + datetime.timedelta(seconds=1) diff --git a/tests/unit/v1/test_bundle.py b/tests/unit/v1/test_bundle.py index 15ee737581..d4b9a894b9 100644 --- a/tests/unit/v1/test_bundle.py +++ b/tests/unit/v1/test_bundle.py @@ -19,11 +19,9 @@ import mock import pytest -from google.cloud.firestore_v1 import base_query -from google.cloud.firestore_v1 import collection +from google.cloud.firestore_v1 import base_query, collection from google.cloud.firestore_v1 import query as query_mod from tests.unit.v1 import _test_helpers - from tests.unit.v1._test_helpers import DEFAULT_TEST_PROJECT @@ -61,10 +59,11 @@ def _bundled_collection_helper( and this method arranges all of the necessary mocks so that unit tests can think they are evaluating a live query. """ + from google.protobuf.timestamp_pb2 import Timestamp # type: ignore + from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.types.document import Document from google.cloud.firestore_v1.types.firestore import RunQueryResponse - from google.protobuf.timestamp_pb2 import Timestamp # type: ignore client = self.get_client() template = client._database_string + "/documents/col/{}" @@ -136,6 +135,7 @@ def test_add_document(self): def test_add_newer_document(self): from google.protobuf.timestamp_pb2 import Timestamp # type: ignore + from google.cloud.firestore_bundle import FirestoreBundle bundle = FirestoreBundle("test") @@ -158,6 +158,7 @@ def test_add_newer_document(self): def test_add_older_document(self): from google.protobuf.timestamp_pb2 import Timestamp # type: ignore + from google.cloud.firestore_bundle import FirestoreBundle bundle = FirestoreBundle("test") @@ -245,8 +246,8 @@ def test_bundle_build(self): assert isinstance(bundle.build(), str) def test_get_documents(self): - from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_bundle import FirestoreBundle + from google.cloud.firestore_v1 import _helpers bundle = FirestoreBundle("test") query: query_mod.Query = self._bundled_query_helper() # type: ignore @@ -454,6 +455,7 @@ def test_build_round_trip_more_unicode(self): def test_roundtrip_binary_data(self): import sys + from google.cloud.firestore_bundle import FirestoreBundle from google.cloud.firestore_v1 import _helpers @@ -475,6 +477,7 @@ def test_deserialize_from_seconds_nanos(self): '{"seconds": 123, "nanos": 456}', instead of an ISO-formatted string. This tests deserialization from that format.""" from google.protobuf.json_format import ParseError + from google.cloud.firestore_v1 import _helpers client = _test_helpers.make_client(project_name="fir-bundles-test") @@ -613,8 +616,7 @@ def test_not_actually_a_bundle_at_all(self): _helpers.deserialize_bundle("{}", client) def test_add_invalid_bundle_element_type(self): - from google.cloud.firestore_bundle import FirestoreBundle - from google.cloud.firestore_bundle import BundleElement + from google.cloud.firestore_bundle import BundleElement, FirestoreBundle client = _test_helpers.make_client() bundle = FirestoreBundle("asdf") diff --git a/tests/unit/v1/test_client.py b/tests/unit/v1/test_client.py index 3442358d5c..edb411c9ff 100644 --- a/tests/unit/v1/test_client.py +++ b/tests/unit/v1/test_client.py @@ -19,8 +19,8 @@ import pytest from google.cloud.firestore_v1.base_client import ( - DEFAULT_DATABASE, _DEFAULT_EMULATOR_PROJECT, + DEFAULT_DATABASE, ) PROJECT = "my-prahjekt" @@ -348,8 +348,8 @@ def _get_all_helper( num_snapshots=2, txn_id=None, retry=None, timeout=None, database=None ): from google.cloud.firestore_v1 import _helpers - from google.cloud.firestore_v1.types import common from google.cloud.firestore_v1.async_document import DocumentSnapshot + from google.cloud.firestore_v1.types import common client = _make_default_client(database=database) @@ -475,8 +475,7 @@ def test_client_get_all_unknown_result(database): @pytest.mark.parametrize("database", [None, DEFAULT_DATABASE, "somedb"]) def test_client_recursive_delete(database): - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import firestore + from google.cloud.firestore_v1.types import document, firestore client = _make_default_client(database=database) client._firestore_api_internal = mock.Mock(spec=["run_query"]) @@ -513,8 +512,7 @@ def _get_chunk(*args, **kwargs): @pytest.mark.parametrize("database", [None, DEFAULT_DATABASE, "somedb"]) def test_client_recursive_delete_from_document(database): - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import firestore + from google.cloud.firestore_v1.types import document, firestore client = _make_default_client(database=database) client._firestore_api_internal = mock.Mock( @@ -631,9 +629,10 @@ def _make_batch_response(**kwargs): def _doc_get_info(ref_string, values): - from google.cloud.firestore_v1.types import document from google.cloud._helpers import _datetime_to_pb_timestamp + from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.types import document now = datetime.datetime.now(tz=datetime.timezone.utc) read_time = _datetime_to_pb_timestamp(now) diff --git a/tests/unit/v1/test_collection.py b/tests/unit/v1/test_collection.py index f3bc099b97..98c83664e1 100644 --- a/tests/unit/v1/test_collection.py +++ b/tests/unit/v1/test_collection.py @@ -39,8 +39,8 @@ def _get_public_methods(klass): def test_query_method_matching(): - from google.cloud.firestore_v1.query import Query from google.cloud.firestore_v1.collection import CollectionReference + from google.cloud.firestore_v1.query import Query query_methods = _get_public_methods(Query) collection_methods = _get_public_methods(CollectionReference) @@ -134,10 +134,10 @@ def test_constructor(): def test_add_auto_assigned(): - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.document import DocumentReference from google.cloud.firestore_v1 import SERVER_TIMESTAMP from google.cloud.firestore_v1._helpers import pbs_for_create + from google.cloud.firestore_v1.document import DocumentReference + from google.cloud.firestore_v1.types import document from tests.unit.v1 import _test_helpers # Create a minimal fake GAPIC add attach it to a real client. @@ -194,10 +194,8 @@ def test_add_auto_assigned(): def _write_pb_for_create(document_path, document_data): - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.types import common, document, write return write.Write( update=document.Document( @@ -208,8 +206,8 @@ def _write_pb_for_create(document_path, document_data): def _add_helper(retry=None, timeout=None): - from google.cloud.firestore_v1.document import DocumentReference from google.cloud.firestore_v1 import _helpers as _fs_v1_helpers + from google.cloud.firestore_v1.document import DocumentReference from tests.unit.v1 import _test_helpers # Create a minimal fake GAPIC with a dummy response. @@ -269,9 +267,9 @@ def test_add_w_retry_timeout(): def _list_documents_helper(page_size=None, retry=None, timeout=None): + from google.api_core.page_iterator import Iterator, Page + from google.cloud.firestore_v1 import _helpers as _fs_v1_helpers - from google.api_core.page_iterator import Iterator - from google.api_core.page_iterator import Page from google.cloud.firestore_v1.document import DocumentReference from google.cloud.firestore_v1.services.firestore.client import FirestoreClient from google.cloud.firestore_v1.types.document import Document diff --git a/tests/unit/v1/test_cross_language.py b/tests/unit/v1/test_cross_language.py index 44f7985f1c..d2adeb2ba6 100644 --- a/tests/unit/v1/test_cross_language.py +++ b/tests/unit/v1/test_cross_language.py @@ -18,15 +18,11 @@ import os import mock -import pytest - import proto as proto_plus - -from google.cloud.firestore_v1.types import document -from google.cloud.firestore_v1.types import firestore -from google.cloud.firestore_v1.types import write +import pytest from google.protobuf.timestamp_pb2 import Timestamp +from google.cloud.firestore_v1.types import document, firestore, write from tests.unit.v1 import conformance_tests @@ -87,9 +83,10 @@ def _mock_firestore_api(): def _make_client_document(firestore_api, testcase): + import google.auth.credentials + from google.cloud.firestore_v1 import Client from google.cloud.firestore_v1.base_client import DEFAULT_DATABASE - import google.auth.credentials _, project, _, database, _, doc_path = testcase.doc_ref_path.split("/", 5) assert database == DEFAULT_DATABASE @@ -219,11 +216,10 @@ def test_listen_testprotos(test_proto): # pragma: NO COVER # and then an expected list of 'snapshots' (local 'Snapshot'), containing # 'docs' (list of 'google.firestore_v1.Document'), # 'changes' (list lof local 'DocChange', and 'read_time' timestamp. - from google.cloud.firestore_v1 import Client - from google.cloud.firestore_v1 import DocumentSnapshot - from google.cloud.firestore_v1 import Watch import google.auth.credentials + from google.cloud.firestore_v1 import Client, DocumentSnapshot, Watch + testcase = test_proto.listen testname = test_proto.description @@ -303,10 +299,12 @@ def test_query_testprotos(test_proto): # pragma: NO COVER def convert_data(v): # Replace the strings 'ServerTimestamp' and 'Delete' with the corresponding # sentinels. - from google.cloud.firestore_v1 import ArrayRemove - from google.cloud.firestore_v1 import ArrayUnion - from google.cloud.firestore_v1 import DELETE_FIELD - from google.cloud.firestore_v1 import SERVER_TIMESTAMP + from google.cloud.firestore_v1 import ( + DELETE_FIELD, + SERVER_TIMESTAMP, + ArrayRemove, + ArrayUnion, + ) if v == "ServerTimestamp": return SERVER_TIMESTAMP @@ -453,8 +451,8 @@ def parse_query(testcase): # 'path': str # 'json_data': str from google.auth.credentials import Credentials - from google.cloud.firestore_v1 import Client - from google.cloud.firestore_v1 import Query + + from google.cloud.firestore_v1 import Client, Query _directions = {"asc": Query.ASCENDING, "desc": Query.DESCENDING} @@ -507,8 +505,7 @@ def parse_path(path): def parse_cursor(cursor, client): - from google.cloud.firestore_v1 import DocumentReference - from google.cloud.firestore_v1 import DocumentSnapshot + from google.cloud.firestore_v1 import DocumentReference, DocumentSnapshot if "doc_snapshot" in cursor: path = parse_path(cursor.doc_snapshot.path) diff --git a/tests/unit/v1/test_document.py b/tests/unit/v1/test_document.py index d7ab541a22..b9116ae61d 100644 --- a/tests/unit/v1/test_document.py +++ b/tests/unit/v1/test_document.py @@ -53,10 +53,8 @@ def _make_commit_repsonse(write_results=None): def _write_pb_for_create(document_path, document_data): - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.types import common, document, write return write.Write( update=document.Document( @@ -116,8 +114,7 @@ def test_documentreference_create_w_retry_timeout(database): @pytest.mark.parametrize("database", [None, "somedb"]) def test_documentreference_create_empty(database): # Create a minimal fake GAPIC with a dummy response. - from google.cloud.firestore_v1.document import DocumentReference - from google.cloud.firestore_v1.document import DocumentSnapshot + from google.cloud.firestore_v1.document import DocumentReference, DocumentSnapshot firestore_api = mock.Mock(spec=["commit"]) document_reference = mock.create_autospec(DocumentReference) @@ -142,10 +139,8 @@ def test_documentreference_create_empty(database): def _write_pb_for_set(document_path, document_data, merge): - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.types import common, document, write write_pbs = write.Write( update=document.Document( @@ -218,10 +213,8 @@ def test_documentreference_set_merge(database): def _write_pb_for_update(document_path, update_values, field_paths): - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.types import common, document, write return write.Write( update=document.Document( @@ -234,6 +227,7 @@ def _write_pb_for_update(document_path, update_values, field_paths): def _update_helper(retry=None, timeout=None, database=None, **option_kwargs): from collections import OrderedDict + from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.transforms import DELETE_FIELD @@ -401,10 +395,8 @@ def _get_helper( database=None, ): from google.cloud.firestore_v1 import _helpers - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import firestore from google.cloud.firestore_v1.transaction import Transaction + from google.cloud.firestore_v1.types import common, document, firestore # Create a minimal fake GAPIC with a dummy response. create_time = 123 @@ -529,8 +521,8 @@ def test_documentreference_get_with_transaction(database): def _collections_helper(page_size=None, retry=None, timeout=None, database=None): - from google.cloud.firestore_v1.collection import CollectionReference from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.collection import CollectionReference from google.cloud.firestore_v1.services.firestore.client import FirestoreClient collection_ids = ["coll-1", "coll-2"] diff --git a/tests/unit/v1/test_order.py b/tests/unit/v1/test_order.py index f1100a098b..8b723b14f7 100644 --- a/tests/unit/v1/test_order.py +++ b/tests/unit/v1/test_order.py @@ -179,7 +179,7 @@ def test_order_compare_w_failure_to_find_type(): def test_order_all_value_present(): - from google.cloud.firestore_v1.order import TypeOrder, _TYPE_ORDER_MAP + from google.cloud.firestore_v1.order import _TYPE_ORDER_MAP, TypeOrder for type_order in TypeOrder: assert type_order in _TYPE_ORDER_MAP @@ -236,17 +236,17 @@ def nullValue(): def _timestamp_value(seconds, nanos): - from google.cloud.firestore_v1.types import document from google.protobuf import timestamp_pb2 + from google.cloud.firestore_v1.types import document + return document.Value( timestamp_value=timestamp_pb2.Timestamp(seconds=seconds, nanos=nanos) ) def _geoPoint_value(latitude, longitude): - from google.cloud.firestore_v1._helpers import encode_value - from google.cloud.firestore_v1._helpers import GeoPoint + from google.cloud.firestore_v1._helpers import GeoPoint, encode_value return encode_value(GeoPoint(latitude, longitude)) diff --git a/tests/unit/v1/test_query.py b/tests/unit/v1/test_query.py index a7f2e60162..b7add63f36 100644 --- a/tests/unit/v1/test_query.py +++ b/tests/unit/v1/test_query.py @@ -18,11 +18,8 @@ import pytest from google.cloud.firestore_v1.base_client import DEFAULT_DATABASE - -from tests.unit.v1.test_base_query import _make_cursor_pb -from tests.unit.v1.test_base_query import _make_query_response - from tests.unit.v1._test_helpers import DEFAULT_TEST_PROJECT, make_client, make_query +from tests.unit.v1.test_base_query import _make_cursor_pb, _make_query_response def test_query_constructor(): @@ -154,8 +151,8 @@ def test_query_get_limit_to_last(database): @pytest.mark.parametrize("database", [None, "somedb"]) def test_query_sum(database): - from google.cloud.firestore_v1.field_path import FieldPath from google.cloud.firestore_v1.base_aggregation import SumAggregation + from google.cloud.firestore_v1.field_path import FieldPath client = make_client(database=database) parent = client.collection("dee") @@ -184,8 +181,8 @@ def test_query_sum(database): @pytest.mark.parametrize("database", [None, "somedb"]) def test_query_avg(database): - from google.cloud.firestore_v1.field_path import FieldPath from google.cloud.firestore_v1.base_aggregation import AvgAggregation + from google.cloud.firestore_v1.field_path import FieldPath client = make_client(database=database) parent = client.collection("dee") @@ -306,6 +303,7 @@ def test_query_chunkify_w_chunksize_gt_limit(database, expected): def _query_stream_helper(retry=None, timeout=None, database=None): from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.stream_generator import StreamGenerator # Create a minimal fake GAPIC. firestore_api = mock.Mock(spec=["run_query"]) @@ -330,7 +328,7 @@ def _query_stream_helper(retry=None, timeout=None, database=None): get_response = query.stream(**kwargs) - assert isinstance(get_response, types.GeneratorType) + assert isinstance(get_response, StreamGenerator) returned = list(get_response) assert len(returned) == 1 snapshot = returned[0] @@ -380,6 +378,8 @@ def test_query_stream_with_limit_to_last(database): @pytest.mark.parametrize("database", [None, "somedb"]) def test_query_stream_with_transaction(database): + from google.cloud.firestore_v1.stream_generator import StreamGenerator + # Create a minimal fake GAPIC. firestore_api = mock.Mock(spec=["run_query"]) @@ -405,7 +405,7 @@ def test_query_stream_with_transaction(database): # Execute the query and check the response. query = make_query(parent) get_response = query.stream(transaction=transaction) - assert isinstance(get_response, types.GeneratorType) + assert isinstance(get_response, StreamGenerator) returned = list(get_response) assert len(returned) == 1 snapshot = returned[0] @@ -425,6 +425,8 @@ def test_query_stream_with_transaction(database): @pytest.mark.parametrize("database", [None, "somedb"]) def test_query_stream_no_results(database): + from google.cloud.firestore_v1.stream_generator import StreamGenerator + # Create a minimal fake GAPIC with a dummy response. firestore_api = mock.Mock(spec=["run_query"]) empty_response = _make_query_response() @@ -440,7 +442,7 @@ def test_query_stream_no_results(database): query = make_query(parent) get_response = query.stream() - assert isinstance(get_response, types.GeneratorType) + assert isinstance(get_response, StreamGenerator) assert list(get_response) == [] # Verify the mock call. @@ -458,6 +460,8 @@ def test_query_stream_no_results(database): @pytest.mark.parametrize("database", [None, "somedb"]) def test_query_stream_second_response_in_empty_stream(database): + from google.cloud.firestore_v1.stream_generator import StreamGenerator + # Create a minimal fake GAPIC with a dummy response. firestore_api = mock.Mock(spec=["run_query"]) empty_response1 = _make_query_response() @@ -474,7 +478,7 @@ def test_query_stream_second_response_in_empty_stream(database): query = make_query(parent) get_response = query.stream() - assert isinstance(get_response, types.GeneratorType) + assert isinstance(get_response, StreamGenerator) assert list(get_response) == [] # Verify the mock call. @@ -491,6 +495,8 @@ def test_query_stream_second_response_in_empty_stream(database): @pytest.mark.parametrize("database", [None, "somedb"]) def test_query_stream_with_skipped_results(database): + from google.cloud.firestore_v1.stream_generator import StreamGenerator + # Create a minimal fake GAPIC. firestore_api = mock.Mock(spec=["run_query"]) @@ -512,7 +518,7 @@ def test_query_stream_with_skipped_results(database): # Execute the query and check the response. query = make_query(parent) get_response = query.stream() - assert isinstance(get_response, types.GeneratorType) + assert isinstance(get_response, StreamGenerator) returned = list(get_response) assert len(returned) == 1 snapshot = returned[0] @@ -533,6 +539,8 @@ def test_query_stream_with_skipped_results(database): @pytest.mark.parametrize("database", [None, "somedb"]) def test_query_stream_empty_after_first_response(database): + from google.cloud.firestore_v1.stream_generator import StreamGenerator + # Create a minimal fake GAPIC. firestore_api = mock.Mock(spec=["run_query"]) @@ -554,7 +562,7 @@ def test_query_stream_empty_after_first_response(database): # Execute the query and check the response. query = make_query(parent) get_response = query.stream() - assert isinstance(get_response, types.GeneratorType) + assert isinstance(get_response, StreamGenerator) returned = list(get_response) assert len(returned) == 1 snapshot = returned[0] @@ -575,6 +583,8 @@ def test_query_stream_empty_after_first_response(database): @pytest.mark.parametrize("database", [None, "somedb"]) def test_query_stream_w_collection_group(database): + from google.cloud.firestore_v1.stream_generator import StreamGenerator + # Create a minimal fake GAPIC. firestore_api = mock.Mock(spec=["run_query"]) @@ -598,7 +608,7 @@ def test_query_stream_w_collection_group(database): query = make_query(parent) query._all_descendants = True get_response = query.stream() - assert isinstance(get_response, types.GeneratorType) + assert isinstance(get_response, StreamGenerator) returned = list(get_response) assert len(returned) == 1 snapshot = returned[0] @@ -625,9 +635,10 @@ def test_query_stream_w_collection_group(database): def _query_stream_w_retriable_exc_helper( retry=_not_passed, timeout=None, transaction=None, expect_retry=True, database=None ): - from google.api_core import exceptions - from google.api_core import gapic_v1 + from google.api_core import exceptions, gapic_v1 + from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.stream_generator import StreamGenerator if retry is _not_passed: retry = gapic_v1.method.DEFAULT @@ -668,7 +679,7 @@ def _stream_w_exception(*_args, **_kw): get_response = query.stream(transaction=transaction, **kwargs) - assert isinstance(get_response, types.GeneratorType) + assert isinstance(get_response, StreamGenerator) if expect_retry: returned = list(get_response) else: diff --git a/tests/unit/v1/test_rate_limiter.py b/tests/unit/v1/test_rate_limiter.py index c23b85ae03..3767108ae4 100644 --- a/tests/unit/v1/test_rate_limiter.py +++ b/tests/unit/v1/test_rate_limiter.py @@ -13,9 +13,9 @@ # limitations under the License. import datetime -import pytest import freezegun +import pytest from google.cloud.firestore_v1 import rate_limiter diff --git a/tests/unit/v1/test_stream_generator.py b/tests/unit/v1/test_stream_generator.py new file mode 100644 index 0000000000..bfc11cf6f6 --- /dev/null +++ b/tests/unit/v1/test_stream_generator.py @@ -0,0 +1,84 @@ +# Copyright 2024 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + + +def _make_stream_generator(iterable): + from google.cloud.firestore_v1.stream_generator import StreamGenerator + + def _inner_generator(): + for i in iterable: + X = yield i + if X: + yield X + + return StreamGenerator(_inner_generator()) + + +def test_stream_generator_iter(): + expected_results = [0, 1, 2] + inst = _make_stream_generator(expected_results) + + actual_results = [] + for result in inst: + actual_results.append(result) + + assert expected_results == actual_results + + +def test_stream_generator_next(): + expected_results = [0, 1] + inst = _make_stream_generator(expected_results) + + actual_results = [] + actual_results.append(next(inst)) + actual_results.append(next(inst)) + + with pytest.raises(StopIteration): + next(inst) + + assert expected_results == actual_results + + +def test_stream_generator_send(): + expected_results = [0, 1] + inst = _make_stream_generator(expected_results) + + actual_results = [] + actual_results.append(next(inst)) + assert inst.send(2) == 2 + actual_results.append(next(inst)) + + with pytest.raises(StopIteration): + next(inst) + + assert expected_results == actual_results + + +def test_stream_generator_throw(): + inst = _make_stream_generator([]) + with pytest.raises(ValueError): + inst.throw(ValueError) + + +def test_stream_generator_close(): + expected_results = [0, 1] + inst = _make_stream_generator(expected_results) + + inst.close() + + # Verifies that generator is closed. + with pytest.raises(StopIteration): + next(inst) diff --git a/tests/unit/v1/test_transaction.py b/tests/unit/v1/test_transaction.py index 26bb5cc9ca..fc56d2f9b0 100644 --- a/tests/unit/v1/test_transaction.py +++ b/tests/unit/v1/test_transaction.py @@ -125,6 +125,7 @@ def test_transaction__begin_failure(database): @pytest.mark.parametrize("database", [None, "somedb"]) def test_transaction__rollback(database): from google.protobuf import empty_pb2 + from google.cloud.firestore_v1.services.firestore import client as firestore_client # Create a minimal fake GAPIC with a dummy result. @@ -169,6 +170,7 @@ def test_transaction__rollback_not_allowed(database): @pytest.mark.parametrize("database", [None, "somedb"]) def test_transaction__rollback_failure(database): from google.api_core import exceptions + from google.cloud.firestore_v1.services.firestore import client as firestore_client # Create a minimal fake GAPIC with a dummy failure. @@ -204,8 +206,7 @@ def test_transaction__rollback_failure(database): @pytest.mark.parametrize("database", [None, "somedb"]) def test_transaction__commit(database): from google.cloud.firestore_v1.services.firestore import client as firestore_client - from google.cloud.firestore_v1.types import firestore - from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1.types import firestore, write # Create a minimal fake GAPIC with a dummy result. firestore_api = mock.create_autospec( @@ -257,6 +258,7 @@ def test_transaction__commit_not_allowed(): @pytest.mark.parametrize("database", [None, "somedb"]) def test_transaction__commit_failure(database): from google.api_core import exceptions + from google.cloud.firestore_v1.services.firestore import client as firestore_client # Create a minimal fake GAPIC with a dummy failure. @@ -327,8 +329,8 @@ def test_transaction_get_all_w_retry_timeout(): def _transaction_get_w_document_ref_helper(retry=None, timeout=None): - from google.cloud.firestore_v1.document import DocumentReference from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.document import DocumentReference client = mock.Mock(spec=["get_all"]) transaction = _make_transaction(client) @@ -498,9 +500,8 @@ def test__transactional___call__success_first_attempt(database): @pytest.mark.parametrize("database", [None, "somedb"]) def test__transactional___call__success_second_attempt(database): from google.api_core import exceptions - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.types import firestore - from google.cloud.firestore_v1.types import write + + from google.cloud.firestore_v1.types import common, firestore, write to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) wrapped = _make__transactional(to_wrap) @@ -558,8 +559,9 @@ def test_transactional___call__failure_max_attempts(database, max_attempts): rasie retryable error and exhause max_attempts """ from google.api_core import exceptions - from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.transaction import _EXCEED_ATTEMPTS_TEMPLATE + from google.cloud.firestore_v1.types import common to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) wrapped = _make__transactional(to_wrap) @@ -630,6 +632,7 @@ def test_transactional___call__failure_readonly(database, max_attempts): readonly transaction should never retry """ from google.api_core import exceptions + from google.cloud.firestore_v1.types import common to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) @@ -800,8 +803,7 @@ def test_transactional___call__failure_with_rollback_failure(database): def test_transactional_factory(): - from google.cloud.firestore_v1.transaction import _Transactional - from google.cloud.firestore_v1.transaction import transactional + from google.cloud.firestore_v1.transaction import _Transactional, transactional wrapped = transactional(mock.sentinel.callable_) assert isinstance(wrapped, _Transactional) @@ -844,6 +846,7 @@ def test__commit_with_retry_success_first_attempt(_sleep, database): @pytest.mark.parametrize("database", [None, "somedb"]) def test__commit_with_retry_success_third_attempt(_sleep, database): from google.api_core import exceptions + from google.cloud.firestore_v1.services.firestore import client as firestore_client from google.cloud.firestore_v1.transaction import _commit_with_retry @@ -888,6 +891,7 @@ def test__commit_with_retry_success_third_attempt(_sleep, database): @pytest.mark.parametrize("database", [None, "somedb"]) def test__commit_with_retry_failure_first_attempt(_sleep, database): from google.api_core import exceptions + from google.cloud.firestore_v1.services.firestore import client as firestore_client from google.cloud.firestore_v1.transaction import _commit_with_retry @@ -926,6 +930,7 @@ def test__commit_with_retry_failure_first_attempt(_sleep, database): @pytest.mark.parametrize("database", [None, "somedb"]) def test__commit_with_retry_failure_second_attempt(_sleep, database): from google.api_core import exceptions + from google.cloud.firestore_v1.services.firestore import client as firestore_client from google.cloud.firestore_v1.transaction import _commit_with_retry @@ -1026,10 +1031,10 @@ def _make_client(project="feral-tom-cat", database=None): def _make_transaction_pb(txn_id, database=None, **txn_kwargs): from google.protobuf import empty_pb2 + from google.cloud.firestore_v1.services.firestore import client as firestore_client - from google.cloud.firestore_v1.types import firestore - from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1.transaction import Transaction + from google.cloud.firestore_v1.types import firestore, write # Create a fake GAPIC ... firestore_api = mock.create_autospec( diff --git a/tests/unit/v1/test_transforms.py b/tests/unit/v1/test_transforms.py index 1a46f27216..67cf5a6eb4 100644 --- a/tests/unit/v1/test_transforms.py +++ b/tests/unit/v1/test_transforms.py @@ -117,9 +117,10 @@ def test__numericvalue___eq___same_value(): def test__server_timestamp_is_same_after_copy(): - from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP import copy + from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP + value = SERVER_TIMESTAMP value_copy = copy.copy(value) @@ -127,9 +128,10 @@ def test__server_timestamp_is_same_after_copy(): def test__server_timestamp_is_same_after_deepcopy(): - from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP import copy + from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP + value = SERVER_TIMESTAMP value_copy = copy.deepcopy(value) diff --git a/tests/unit/v1/test_vector.py b/tests/unit/v1/test_vector.py index 6ca1ce4134..e411eac47b 100644 --- a/tests/unit/v1/test_vector.py +++ b/tests/unit/v1/test_vector.py @@ -13,15 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import google.auth.credentials +from unittest import mock +import google.auth.credentials from google.api_core import gapic_v1 + +from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.client import Client from google.cloud.firestore_v1.document import DocumentReference -from google.cloud.firestore_v1.vector import Vector from google.cloud.firestore_v1.types import common, document, firestore, write -from google.cloud.firestore_v1 import _helpers -from unittest import mock +from google.cloud.firestore_v1.vector import Vector def _make_commit_repsonse(): diff --git a/tests/unit/v1/test_vector_query.py b/tests/unit/v1/test_vector_query.py index 92dca45c4d..beb0941413 100644 --- a/tests/unit/v1/test_vector_query.py +++ b/tests/unit/v1/test_vector_query.py @@ -14,19 +14,13 @@ import mock import pytest -import types +from google.cloud.firestore_v1._helpers import encode_value, make_retry_timeout_kwargs +from google.cloud.firestore_v1.base_vector_query import DistanceMeasure from google.cloud.firestore_v1.types.query import StructuredQuery from google.cloud.firestore_v1.vector import Vector -from google.cloud.firestore_v1.base_vector_query import DistanceMeasure - -from tests.unit.v1._test_helpers import ( - make_vector_query, - make_client, - make_query, -) +from tests.unit.v1._test_helpers import make_client, make_query, make_vector_query from tests.unit.v1.test_base_query import _make_query_response -from google.cloud.firestore_v1._helpers import encode_value, make_retry_timeout_kwargs _PROJECT = "PROJECT" _TXN_ID = b"\x00\x00\x01-work-\xf2" @@ -326,6 +320,8 @@ def test_vector_query_collection_group(distance_measure, expected_distance): def test_query_stream_multiple_empty_response_in_stream(): + from google.cloud.firestore_v1 import stream_generator + # Create a minimal fake GAPIC with a dummy response. firestore_api = mock.Mock(spec=["run_query"]) empty_response1 = _make_query_response() @@ -347,7 +343,7 @@ def test_query_stream_multiple_empty_response_in_stream(): ) get_response = vector_query.stream() - assert isinstance(get_response, types.GeneratorType) + assert isinstance(get_response, stream_generator.StreamGenerator) assert list(get_response) == [] # Verify the mock call. diff --git a/tests/unit/v1/test_watch.py b/tests/unit/v1/test_watch.py index 2d7927a1de..094248e933 100644 --- a/tests/unit/v1/test_watch.py +++ b/tests/unit/v1/test_watch.py @@ -90,6 +90,7 @@ def test_watchresult_ctor(): def test__maybe_wrap_exception_w_grpc_error(): import grpc from google.api_core.exceptions import GoogleAPICallError + from google.cloud.firestore_v1.watch import _maybe_wrap_exception exc = grpc.RpcError() @@ -121,6 +122,7 @@ def test_document_watch_comparator_wdiff_doc(): def test__should_recover_w_unavailable(): from google.api_core.exceptions import ServiceUnavailable + from google.cloud.firestore_v1.watch import _should_recover exception = ServiceUnavailable("testing") @@ -138,6 +140,7 @@ def test__should_recover_w_non_recoverable(): def test__should_terminate_w_unavailable(): from google.api_core.exceptions import Cancelled + from google.cloud.firestore_v1.watch import _should_terminate exception = Cancelled("testing") @@ -194,8 +197,7 @@ def _make_watch(snapshots=None, comparator=_document_watch_comparator): def test_watch_ctor(): - from google.cloud.firestore_v1.watch import _should_recover - from google.cloud.firestore_v1.watch import _should_terminate + from google.cloud.firestore_v1.watch import _should_recover, _should_terminate with mock.patch("google.cloud.firestore_v1.watch.ResumableBidiRpc") as rpc: with mock.patch("google.cloud.firestore_v1.watch.BackgroundConsumer") as bc: @@ -406,6 +408,7 @@ def test_watch_on_snapshot_target_no_change_no_target_ids_not_current(): def test_watch_on_snapshot_target_no_change_no_target_ids_current(): import datetime + from proto.datetime_helpers import DatetimeWithNanoseconds inst = _make_watch() @@ -512,8 +515,7 @@ def test_watch_on_snapshot_target_unknown(): def test_watch_on_snapshot_document_change_removed(): from google.cloud.firestore_v1.types.document import Document - from google.cloud.firestore_v1.watch import WATCH_TARGET_ID - from google.cloud.firestore_v1.watch import ChangeType + from google.cloud.firestore_v1.watch import WATCH_TARGET_ID, ChangeType inst = _make_watch() proto = _make_listen_response() @@ -982,8 +984,7 @@ def Thread(self, name, target, kwargs): def _make_listen_response(): - from google.cloud.firestore_v1.types.firestore import ListenResponse - from google.cloud.firestore_v1.types.firestore import TargetChange + from google.cloud.firestore_v1.types.firestore import ListenResponse, TargetChange response = ListenResponse() tc = response.target_change