Skip to content

Commit ae1247b

Browse files
authored
fix: improve AsyncQuery typing (#782)
1 parent d07eebf commit ae1247b

File tree

5 files changed

+52
-40
lines changed

5 files changed

+52
-40
lines changed

google/cloud/firestore_v1/async_collection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from google.cloud.firestore_v1.transaction import Transaction
3333

3434

35-
class AsyncCollectionReference(BaseCollectionReference):
35+
class AsyncCollectionReference(BaseCollectionReference[async_query.AsyncQuery]):
3636
"""A reference to a collection in a Firestore database.
3737
3838
The collection may already exist or this class can facilitate creation

google/cloud/firestore_v1/base_client.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -262,13 +262,15 @@ def _rpc_metadata(self):
262262

263263
return self._rpc_metadata_internal
264264

265-
def collection(self, *collection_path) -> BaseCollectionReference:
265+
def collection(self, *collection_path) -> BaseCollectionReference[BaseQuery]:
266266
raise NotImplementedError
267267

268268
def collection_group(self, collection_id: str) -> BaseQuery:
269269
raise NotImplementedError
270270

271-
def _get_collection_reference(self, collection_id: str) -> BaseCollectionReference:
271+
def _get_collection_reference(
272+
self, collection_id: str
273+
) -> BaseCollectionReference[BaseQuery]:
272274
"""Checks validity of collection_id and then uses subclasses collection implementation.
273275
274276
Args:
@@ -325,7 +327,7 @@ def _document_path_helper(self, *document_path) -> List[str]:
325327

326328
def recursive_delete(
327329
self,
328-
reference: Union[BaseCollectionReference, BaseDocumentReference],
330+
reference: Union[BaseCollectionReference[BaseQuery], BaseDocumentReference],
329331
bulk_writer: Optional["BulkWriter"] = None, # type: ignore
330332
) -> int:
331333
raise NotImplementedError
@@ -459,8 +461,8 @@ def collections(
459461
retry: retries.Retry = None,
460462
timeout: float = None,
461463
) -> Union[
462-
AsyncGenerator[BaseCollectionReference, Any],
463-
Generator[BaseCollectionReference, Any, Any],
464+
AsyncGenerator[BaseCollectionReference[BaseQuery], Any],
465+
Generator[BaseCollectionReference[BaseQuery], Any, Any],
464466
]:
465467
raise NotImplementedError
466468

google/cloud/firestore_v1/base_collection.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
AsyncGenerator,
2929
Coroutine,
3030
Generator,
31+
Generic,
3132
AsyncIterator,
3233
Iterator,
3334
Iterable,
@@ -38,13 +39,13 @@
3839

3940
# Types needed only for Type Hints
4041
from google.cloud.firestore_v1.base_document import DocumentSnapshot
41-
from google.cloud.firestore_v1.base_query import BaseQuery
42+
from google.cloud.firestore_v1.base_query import QueryType
4243
from google.cloud.firestore_v1.transaction import Transaction
4344

4445
_AUTO_ID_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
4546

4647

47-
class BaseCollectionReference(object):
48+
class BaseCollectionReference(Generic[QueryType]):
4849
"""A reference to a collection in a Firestore database.
4950
5051
The collection may already exist or this class can facilitate creation
@@ -108,7 +109,7 @@ def parent(self):
108109
parent_path = self._path[:-1]
109110
return self._client.document(*parent_path)
110111

111-
def _query(self) -> BaseQuery:
112+
def _query(self) -> QueryType:
112113
raise NotImplementedError
113114

114115
def _aggregation_query(self) -> BaseAggregationQuery:
@@ -215,10 +216,10 @@ def list_documents(
215216
]:
216217
raise NotImplementedError
217218

218-
def recursive(self) -> "BaseQuery":
219+
def recursive(self) -> QueryType:
219220
return self._query().recursive()
220221

221-
def select(self, field_paths: Iterable[str]) -> BaseQuery:
222+
def select(self, field_paths: Iterable[str]) -> QueryType:
222223
"""Create a "select" query with this collection as parent.
223224
224225
See
@@ -244,7 +245,7 @@ def where(
244245
value=None,
245246
*,
246247
filter=None
247-
) -> BaseQuery:
248+
) -> QueryType:
248249
"""Create a "where" query with this collection as parent.
249250
250251
See
@@ -290,7 +291,7 @@ def where(
290291
else:
291292
return query.where(filter=filter)
292293

293-
def order_by(self, field_path: str, **kwargs) -> BaseQuery:
294+
def order_by(self, field_path: str, **kwargs) -> QueryType:
294295
"""Create an "order by" query with this collection as parent.
295296
296297
See
@@ -312,7 +313,7 @@ def order_by(self, field_path: str, **kwargs) -> BaseQuery:
312313
query = self._query()
313314
return query.order_by(field_path, **kwargs)
314315

315-
def limit(self, count: int) -> BaseQuery:
316+
def limit(self, count: int) -> QueryType:
316317
"""Create a limited query with this collection as parent.
317318
318319
.. note::
@@ -355,7 +356,7 @@ def limit_to_last(self, count: int):
355356
query = self._query()
356357
return query.limit_to_last(count)
357358

358-
def offset(self, num_to_skip: int) -> BaseQuery:
359+
def offset(self, num_to_skip: int) -> QueryType:
359360
"""Skip to an offset in a query with this collection as parent.
360361
361362
See
@@ -375,7 +376,7 @@ def offset(self, num_to_skip: int) -> BaseQuery:
375376

376377
def start_at(
377378
self, document_fields: Union[DocumentSnapshot, dict, list, tuple]
378-
) -> BaseQuery:
379+
) -> QueryType:
379380
"""Start query at a cursor with this collection as parent.
380381
381382
See
@@ -398,7 +399,7 @@ def start_at(
398399

399400
def start_after(
400401
self, document_fields: Union[DocumentSnapshot, dict, list, tuple]
401-
) -> BaseQuery:
402+
) -> QueryType:
402403
"""Start query after a cursor with this collection as parent.
403404
404405
See
@@ -421,7 +422,7 @@ def start_after(
421422

422423
def end_before(
423424
self, document_fields: Union[DocumentSnapshot, dict, list, tuple]
424-
) -> BaseQuery:
425+
) -> QueryType:
425426
"""End query before a cursor with this collection as parent.
426427
427428
See
@@ -444,7 +445,7 @@ def end_before(
444445

445446
def end_at(
446447
self, document_fields: Union[DocumentSnapshot, dict, list, tuple]
447-
) -> BaseQuery:
448+
) -> QueryType:
448449
"""End query at a cursor with this collection as parent.
449450
450451
See

google/cloud/firestore_v1/base_query.py

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
Optional,
4848
Tuple,
4949
Type,
50+
TypeVar,
5051
Union,
5152
)
5253

@@ -102,6 +103,8 @@
102103

103104
_not_passed = object()
104105

106+
QueryType = TypeVar("QueryType", bound="BaseQuery")
107+
105108

106109
class BaseFilter(abc.ABC):
107110
"""Base class for Filters"""
@@ -319,7 +322,7 @@ def _client(self):
319322
"""
320323
return self._parent._client
321324

322-
def select(self, field_paths: Iterable[str]) -> "BaseQuery":
325+
def select(self: QueryType, field_paths: Iterable[str]) -> QueryType:
323326
"""Project documents matching query to a limited set of fields.
324327
325328
See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for
@@ -354,7 +357,7 @@ def select(self, field_paths: Iterable[str]) -> "BaseQuery":
354357
return self._copy(projection=new_projection)
355358

356359
def _copy(
357-
self,
360+
self: QueryType,
358361
*,
359362
projection: Optional[query.StructuredQuery.Projection] = _not_passed,
360363
field_filters: Optional[Tuple[query.StructuredQuery.FieldFilter]] = _not_passed,
@@ -366,7 +369,7 @@ def _copy(
366369
end_at: Optional[Tuple[dict, bool]] = _not_passed,
367370
all_descendants: Optional[bool] = _not_passed,
368371
recursive: Optional[bool] = _not_passed,
369-
) -> "BaseQuery":
372+
) -> QueryType:
370373
return self.__class__(
371374
self._parent,
372375
projection=self._evaluate_param(projection, self._projection),
@@ -389,13 +392,13 @@ def _evaluate_param(self, value, fallback_value):
389392
return value if value is not _not_passed else fallback_value
390393

391394
def where(
392-
self,
395+
self: QueryType,
393396
field_path: Optional[str] = None,
394397
op_string: Optional[str] = None,
395398
value=None,
396399
*,
397400
filter=None,
398-
) -> "BaseQuery":
401+
) -> QueryType:
399402
"""Filter the query on a field.
400403
401404
See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for
@@ -492,7 +495,9 @@ def _make_order(field_path, direction) -> StructuredQuery.Order:
492495
direction=_enum_from_direction(direction),
493496
)
494497

495-
def order_by(self, field_path: str, direction: str = ASCENDING) -> "BaseQuery":
498+
def order_by(
499+
self: QueryType, field_path: str, direction: str = ASCENDING
500+
) -> QueryType:
496501
"""Modify the query to add an order clause on a specific field.
497502
498503
See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for
@@ -526,7 +531,7 @@ def order_by(self, field_path: str, direction: str = ASCENDING) -> "BaseQuery":
526531
new_orders = self._orders + (order_pb,)
527532
return self._copy(orders=new_orders)
528533

529-
def limit(self, count: int) -> "BaseQuery":
534+
def limit(self: QueryType, count: int) -> QueryType:
530535
"""Limit a query to return at most `count` matching results.
531536
532537
If the current query already has a `limit` set, this will override it.
@@ -545,7 +550,7 @@ def limit(self, count: int) -> "BaseQuery":
545550
"""
546551
return self._copy(limit=count, limit_to_last=False)
547552

548-
def limit_to_last(self, count: int) -> "BaseQuery":
553+
def limit_to_last(self: QueryType, count: int) -> QueryType:
549554
"""Limit a query to return the last `count` matching results.
550555
If the current query already has a `limit_to_last`
551556
set, this will override it.
@@ -570,7 +575,7 @@ def _resolve_chunk_size(self, num_loaded: int, chunk_size: int) -> int:
570575
return max(self._limit - num_loaded, 0)
571576
return chunk_size
572577

573-
def offset(self, num_to_skip: int) -> "BaseQuery":
578+
def offset(self: QueryType, num_to_skip: int) -> QueryType:
574579
"""Skip to an offset in a query.
575580
576581
If the current query already has specified an offset, this will
@@ -601,11 +606,11 @@ def _check_snapshot(self, document_snapshot) -> None:
601606
raise ValueError("Cannot use snapshot from another collection as a cursor.")
602607

603608
def _cursor_helper(
604-
self,
609+
self: QueryType,
605610
document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple],
606611
before: bool,
607612
start: bool,
608-
) -> "BaseQuery":
613+
) -> QueryType:
609614
"""Set values to be used for a ``start_at`` or ``end_at`` cursor.
610615
611616
The values will later be used in a query protobuf.
@@ -658,8 +663,9 @@ def _cursor_helper(
658663
return self._copy(**query_kwargs)
659664

660665
def start_at(
661-
self, document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple]
662-
) -> "BaseQuery":
666+
self: QueryType,
667+
document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple],
668+
) -> QueryType:
663669
"""Start query results at a particular document value.
664670
665671
The result set will **include** the document specified by
@@ -690,8 +696,9 @@ def start_at(
690696
return self._cursor_helper(document_fields_or_snapshot, before=True, start=True)
691697

692698
def start_after(
693-
self, document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple]
694-
) -> "BaseQuery":
699+
self: QueryType,
700+
document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple],
701+
) -> QueryType:
695702
"""Start query results after a particular document value.
696703
697704
The result set will **exclude** the document specified by
@@ -723,8 +730,9 @@ def start_after(
723730
)
724731

725732
def end_before(
726-
self, document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple]
727-
) -> "BaseQuery":
733+
self: QueryType,
734+
document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple],
735+
) -> QueryType:
728736
"""End query results before a particular document value.
729737
730738
The result set will **exclude** the document specified by
@@ -756,8 +764,9 @@ def end_before(
756764
)
757765

758766
def end_at(
759-
self, document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple]
760-
) -> "BaseQuery":
767+
self: QueryType,
768+
document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple],
769+
) -> QueryType:
761770
"""End query results at a particular document value.
762771
763772
The result set will **include** the document specified by
@@ -1003,7 +1012,7 @@ def stream(
10031012
def on_snapshot(self, callback) -> NoReturn:
10041013
raise NotImplementedError
10051014

1006-
def recursive(self) -> "BaseQuery":
1015+
def recursive(self: QueryType) -> QueryType:
10071016
"""Returns a copy of this query whose iterator will yield all matching
10081017
documents as well as each of their descendent subcollections and documents.
10091018

google/cloud/firestore_v1/collection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from google.cloud.firestore_v1.transaction import Transaction
3232

3333

34-
class CollectionReference(BaseCollectionReference):
34+
class CollectionReference(BaseCollectionReference[query_mod.Query]):
3535
"""A reference to a collection in a Firestore database.
3636
3737
The collection may already exist or this class can facilitate creation

0 commit comments

Comments
 (0)