21
21
from tests .unit .v1 .test_base_query import _make_credentials , _make_query_response
22
22
23
23
24
+ class MockAsyncIter :
25
+ def __init__ (self , count = 3 ):
26
+ # count is arbitrary value
27
+ self .count = count
28
+
29
+ async def __aiter__ (self , ** _ ):
30
+ for i in range (self .count ):
31
+ yield i
32
+
33
+
24
34
class TestAsyncQuery (aiounittest .AsyncTestCase ):
25
35
@staticmethod
26
36
def _get_target_class ():
@@ -45,53 +55,37 @@ def test_constructor(self):
45
55
self .assertFalse (query ._all_descendants )
46
56
47
57
@pytest .mark .asyncio
48
- async def test_get_simple (self ):
58
+ async def test_get (self ):
49
59
import warnings
50
60
51
- # Create a minimal fake GAPIC.
52
- firestore_api = mock . Mock ( spec = [ "run_query" ] )
61
+ with mock . patch . object ( self . _get_target_class (), "stream" ) as stream_mock :
62
+ stream_mock . return_value = MockAsyncIter ( 3 )
53
63
54
- # Attach the fake GAPIC to a real client.
55
- client = _make_client ()
56
- client ._firestore_api_internal = firestore_api
64
+ # Create a minimal fake GAPIC.
65
+ firestore_api = mock .Mock (spec = ["run_query" ])
57
66
58
- # Make a **real** collection reference as parent.
59
- parent = client .collection ("dee" )
67
+ # Attach the fake GAPIC to a real client.
68
+ client = _make_client ()
69
+ client ._firestore_api_internal = firestore_api
60
70
61
- # Add a dummy response to the minimal fake GAPIC.
62
- _ , expected_prefix = parent ._parent_info ()
63
- name = "{}/sleep" .format (expected_prefix )
64
- data = {"snooze" : 10 }
65
- response_pb = _make_query_response (name = name , data = data )
66
- firestore_api .run_query .return_value = iter ([response_pb ])
71
+ # Make a **real** collection reference as parent.
72
+ parent = client .collection ("dee" )
67
73
68
- # Execute the query and check the response.
69
- query = self ._make_one (parent )
70
-
71
- with warnings .catch_warnings (record = True ) as warned :
72
- get_response = query .get ()
73
- self .assertIsInstance (get_response , types .AsyncGeneratorType )
74
- returned = [x async for x in get_response ]
74
+ # Execute the query and check the response.
75
+ query = self ._make_one (parent )
75
76
76
- self .assertEqual (len (returned ), 1 )
77
- snapshot = returned [0 ]
78
- self .assertEqual (snapshot .reference ._path , ("dee" , "sleep" ))
79
- self .assertEqual (snapshot .to_dict (), data )
77
+ with warnings .catch_warnings (record = True ) as warned :
78
+ get_response = query .get ()
79
+ returned = [x async for x in get_response ]
80
80
81
- # Verify the mock call.
82
- parent_path , _ = parent ._parent_info ()
83
- firestore_api .run_query .assert_called_once_with (
84
- request = {
85
- "parent" : parent_path ,
86
- "structured_query" : query ._to_protobuf (),
87
- "transaction" : None ,
88
- },
89
- metadata = client ._rpc_metadata ,
90
- )
81
+ # Verify that `get` merely wraps `stream`.
82
+ stream_mock .assert_called_once ()
83
+ self .assertIsInstance (get_response , types .AsyncGeneratorType )
84
+ self .assertEqual (returned , list (range (stream_mock .return_value .count )))
91
85
92
- # Verify the deprecation
93
- self .assertEqual (len (warned ), 1 )
94
- self .assertIs (warned [0 ].category , DeprecationWarning )
86
+ # Verify the deprecation.
87
+ self .assertEqual (len (warned ), 1 )
88
+ self .assertIs (warned [0 ].category , DeprecationWarning )
95
89
96
90
@pytest .mark .asyncio
97
91
async def test_stream_simple (self ):
0 commit comments