47
47
Optional ,
48
48
Tuple ,
49
49
Type ,
50
+ TypeVar ,
50
51
Union ,
51
52
)
52
53
102
103
103
104
_not_passed = object ()
104
105
106
+ QueryType = TypeVar ("QueryType" , bound = "BaseQuery" )
107
+
105
108
106
109
class BaseFilter (abc .ABC ):
107
110
"""Base class for Filters"""
@@ -319,7 +322,7 @@ def _client(self):
319
322
"""
320
323
return self ._parent ._client
321
324
322
- def select (self , field_paths : Iterable [str ]) -> "BaseQuery" :
325
+ def select (self : QueryType , field_paths : Iterable [str ]) -> QueryType :
323
326
"""Project documents matching query to a limited set of fields.
324
327
325
328
See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for
@@ -354,7 +357,7 @@ def select(self, field_paths: Iterable[str]) -> "BaseQuery":
354
357
return self ._copy (projection = new_projection )
355
358
356
359
def _copy (
357
- self ,
360
+ self : QueryType ,
358
361
* ,
359
362
projection : Optional [query .StructuredQuery .Projection ] = _not_passed ,
360
363
field_filters : Optional [Tuple [query .StructuredQuery .FieldFilter ]] = _not_passed ,
@@ -366,7 +369,7 @@ def _copy(
366
369
end_at : Optional [Tuple [dict , bool ]] = _not_passed ,
367
370
all_descendants : Optional [bool ] = _not_passed ,
368
371
recursive : Optional [bool ] = _not_passed ,
369
- ) -> "BaseQuery" :
372
+ ) -> QueryType :
370
373
return self .__class__ (
371
374
self ._parent ,
372
375
projection = self ._evaluate_param (projection , self ._projection ),
@@ -389,13 +392,13 @@ def _evaluate_param(self, value, fallback_value):
389
392
return value if value is not _not_passed else fallback_value
390
393
391
394
def where (
392
- self ,
395
+ self : QueryType ,
393
396
field_path : Optional [str ] = None ,
394
397
op_string : Optional [str ] = None ,
395
398
value = None ,
396
399
* ,
397
400
filter = None ,
398
- ) -> "BaseQuery" :
401
+ ) -> QueryType :
399
402
"""Filter the query on a field.
400
403
401
404
See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for
@@ -492,7 +495,9 @@ def _make_order(field_path, direction) -> StructuredQuery.Order:
492
495
direction = _enum_from_direction (direction ),
493
496
)
494
497
495
- def order_by (self , field_path : str , direction : str = ASCENDING ) -> "BaseQuery" :
498
+ def order_by (
499
+ self : QueryType , field_path : str , direction : str = ASCENDING
500
+ ) -> QueryType :
496
501
"""Modify the query to add an order clause on a specific field.
497
502
498
503
See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for
@@ -526,7 +531,7 @@ def order_by(self, field_path: str, direction: str = ASCENDING) -> "BaseQuery":
526
531
new_orders = self ._orders + (order_pb ,)
527
532
return self ._copy (orders = new_orders )
528
533
529
- def limit (self , count : int ) -> "BaseQuery" :
534
+ def limit (self : QueryType , count : int ) -> QueryType :
530
535
"""Limit a query to return at most `count` matching results.
531
536
532
537
If the current query already has a `limit` set, this will override it.
@@ -545,7 +550,7 @@ def limit(self, count: int) -> "BaseQuery":
545
550
"""
546
551
return self ._copy (limit = count , limit_to_last = False )
547
552
548
- def limit_to_last (self , count : int ) -> "BaseQuery" :
553
+ def limit_to_last (self : QueryType , count : int ) -> QueryType :
549
554
"""Limit a query to return the last `count` matching results.
550
555
If the current query already has a `limit_to_last`
551
556
set, this will override it.
@@ -570,7 +575,7 @@ def _resolve_chunk_size(self, num_loaded: int, chunk_size: int) -> int:
570
575
return max (self ._limit - num_loaded , 0 )
571
576
return chunk_size
572
577
573
- def offset (self , num_to_skip : int ) -> "BaseQuery" :
578
+ def offset (self : QueryType , num_to_skip : int ) -> QueryType :
574
579
"""Skip to an offset in a query.
575
580
576
581
If the current query already has specified an offset, this will
@@ -601,11 +606,11 @@ def _check_snapshot(self, document_snapshot) -> None:
601
606
raise ValueError ("Cannot use snapshot from another collection as a cursor." )
602
607
603
608
def _cursor_helper (
604
- self ,
609
+ self : QueryType ,
605
610
document_fields_or_snapshot : Union [DocumentSnapshot , dict , list , tuple ],
606
611
before : bool ,
607
612
start : bool ,
608
- ) -> "BaseQuery" :
613
+ ) -> QueryType :
609
614
"""Set values to be used for a ``start_at`` or ``end_at`` cursor.
610
615
611
616
The values will later be used in a query protobuf.
@@ -658,8 +663,9 @@ def _cursor_helper(
658
663
return self ._copy (** query_kwargs )
659
664
660
665
def start_at (
661
- self , document_fields_or_snapshot : Union [DocumentSnapshot , dict , list , tuple ]
662
- ) -> "BaseQuery" :
666
+ self : QueryType ,
667
+ document_fields_or_snapshot : Union [DocumentSnapshot , dict , list , tuple ],
668
+ ) -> QueryType :
663
669
"""Start query results at a particular document value.
664
670
665
671
The result set will **include** the document specified by
@@ -690,8 +696,9 @@ def start_at(
690
696
return self ._cursor_helper (document_fields_or_snapshot , before = True , start = True )
691
697
692
698
def start_after (
693
- self , document_fields_or_snapshot : Union [DocumentSnapshot , dict , list , tuple ]
694
- ) -> "BaseQuery" :
699
+ self : QueryType ,
700
+ document_fields_or_snapshot : Union [DocumentSnapshot , dict , list , tuple ],
701
+ ) -> QueryType :
695
702
"""Start query results after a particular document value.
696
703
697
704
The result set will **exclude** the document specified by
@@ -723,8 +730,9 @@ def start_after(
723
730
)
724
731
725
732
def end_before (
726
- self , document_fields_or_snapshot : Union [DocumentSnapshot , dict , list , tuple ]
727
- ) -> "BaseQuery" :
733
+ self : QueryType ,
734
+ document_fields_or_snapshot : Union [DocumentSnapshot , dict , list , tuple ],
735
+ ) -> QueryType :
728
736
"""End query results before a particular document value.
729
737
730
738
The result set will **exclude** the document specified by
@@ -756,8 +764,9 @@ def end_before(
756
764
)
757
765
758
766
def end_at (
759
- self , document_fields_or_snapshot : Union [DocumentSnapshot , dict , list , tuple ]
760
- ) -> "BaseQuery" :
767
+ self : QueryType ,
768
+ document_fields_or_snapshot : Union [DocumentSnapshot , dict , list , tuple ],
769
+ ) -> QueryType :
761
770
"""End query results at a particular document value.
762
771
763
772
The result set will **include** the document specified by
@@ -1003,7 +1012,7 @@ def stream(
1003
1012
def on_snapshot (self , callback ) -> NoReturn :
1004
1013
raise NotImplementedError
1005
1014
1006
- def recursive (self ) -> "BaseQuery" :
1015
+ def recursive (self : QueryType ) -> QueryType :
1007
1016
"""Returns a copy of this query whose iterator will yield all matching
1008
1017
documents as well as each of their descendent subcollections and documents.
1009
1018
0 commit comments