Skip to content

Commit c4c5bfa

Browse files
authored
fix: add mocks to query get tests (#109)
1 parent edf7bd1 commit c4c5bfa

File tree

2 files changed

+51
-77
lines changed

2 files changed

+51
-77
lines changed

tests/unit/v1/test_async_query.py

Lines changed: 32 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,16 @@
2121
from tests.unit.v1.test_base_query import _make_credentials, _make_query_response
2222

2323

24+
class MockAsyncIter:
25+
def __init__(self, count=3):
26+
# count is arbitrary value
27+
self.count = count
28+
29+
async def __aiter__(self, **_):
30+
for i in range(self.count):
31+
yield i
32+
33+
2434
class TestAsyncQuery(aiounittest.AsyncTestCase):
2535
@staticmethod
2636
def _get_target_class():
@@ -45,53 +55,37 @@ def test_constructor(self):
4555
self.assertFalse(query._all_descendants)
4656

4757
@pytest.mark.asyncio
48-
async def test_get_simple(self):
58+
async def test_get(self):
4959
import warnings
5060

51-
# Create a minimal fake GAPIC.
52-
firestore_api = mock.Mock(spec=["run_query"])
61+
with mock.patch.object(self._get_target_class(), "stream") as stream_mock:
62+
stream_mock.return_value = MockAsyncIter(3)
5363

54-
# Attach the fake GAPIC to a real client.
55-
client = _make_client()
56-
client._firestore_api_internal = firestore_api
64+
# Create a minimal fake GAPIC.
65+
firestore_api = mock.Mock(spec=["run_query"])
5766

58-
# Make a **real** collection reference as parent.
59-
parent = client.collection("dee")
67+
# Attach the fake GAPIC to a real client.
68+
client = _make_client()
69+
client._firestore_api_internal = firestore_api
6070

61-
# Add a dummy response to the minimal fake GAPIC.
62-
_, expected_prefix = parent._parent_info()
63-
name = "{}/sleep".format(expected_prefix)
64-
data = {"snooze": 10}
65-
response_pb = _make_query_response(name=name, data=data)
66-
firestore_api.run_query.return_value = iter([response_pb])
71+
# Make a **real** collection reference as parent.
72+
parent = client.collection("dee")
6773

68-
# Execute the query and check the response.
69-
query = self._make_one(parent)
70-
71-
with warnings.catch_warnings(record=True) as warned:
72-
get_response = query.get()
73-
self.assertIsInstance(get_response, types.AsyncGeneratorType)
74-
returned = [x async for x in get_response]
74+
# Execute the query and check the response.
75+
query = self._make_one(parent)
7576

76-
self.assertEqual(len(returned), 1)
77-
snapshot = returned[0]
78-
self.assertEqual(snapshot.reference._path, ("dee", "sleep"))
79-
self.assertEqual(snapshot.to_dict(), data)
77+
with warnings.catch_warnings(record=True) as warned:
78+
get_response = query.get()
79+
returned = [x async for x in get_response]
8080

81-
# Verify the mock call.
82-
parent_path, _ = parent._parent_info()
83-
firestore_api.run_query.assert_called_once_with(
84-
request={
85-
"parent": parent_path,
86-
"structured_query": query._to_protobuf(),
87-
"transaction": None,
88-
},
89-
metadata=client._rpc_metadata,
90-
)
81+
# Verify that `get` merely wraps `stream`.
82+
stream_mock.assert_called_once()
83+
self.assertIsInstance(get_response, types.AsyncGeneratorType)
84+
self.assertEqual(returned, list(range(stream_mock.return_value.count)))
9185

92-
# Verify the deprecation
93-
self.assertEqual(len(warned), 1)
94-
self.assertIs(warned[0].category, DeprecationWarning)
86+
# Verify the deprecation.
87+
self.assertEqual(len(warned), 1)
88+
self.assertIs(warned[0].category, DeprecationWarning)
9589

9690
@pytest.mark.asyncio
9791
async def test_stream_simple(self):

tests/unit/v1/test_query.py

Lines changed: 19 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -43,53 +43,33 @@ def test_constructor(self):
4343
self.assertIsNone(query._end_at)
4444
self.assertFalse(query._all_descendants)
4545

46-
def test_get_simple(self):
46+
def test_get(self):
4747
import warnings
4848

49-
# Create a minimal fake GAPIC.
50-
firestore_api = mock.Mock(spec=["run_query"])
49+
with mock.patch.object(self._get_target_class(), "stream") as stream_mock:
50+
# Create a minimal fake GAPIC.
51+
firestore_api = mock.Mock(spec=["run_query"])
5152

52-
# Attach the fake GAPIC to a real client.
53-
client = _make_client()
54-
client._firestore_api_internal = firestore_api
53+
# Attach the fake GAPIC to a real client.
54+
client = _make_client()
55+
client._firestore_api_internal = firestore_api
5556

56-
# Make a **real** collection reference as parent.
57-
parent = client.collection("dee")
57+
# Make a **real** collection reference as parent.
58+
parent = client.collection("dee")
5859

59-
# Add a dummy response to the minimal fake GAPIC.
60-
_, expected_prefix = parent._parent_info()
61-
name = "{}/sleep".format(expected_prefix)
62-
data = {"snooze": 10}
63-
response_pb = _make_query_response(name=name, data=data)
64-
firestore_api.run_query.return_value = iter([response_pb])
60+
# Execute the query and check the response.
61+
query = self._make_one(parent)
6562

66-
# Execute the query and check the response.
67-
query = self._make_one(parent)
63+
with warnings.catch_warnings(record=True) as warned:
64+
get_response = query.get()
6865

69-
with warnings.catch_warnings(record=True) as warned:
70-
get_response = query.get()
71-
72-
self.assertIsInstance(get_response, types.GeneratorType)
73-
returned = list(get_response)
74-
self.assertEqual(len(returned), 1)
75-
snapshot = returned[0]
76-
self.assertEqual(snapshot.reference._path, ("dee", "sleep"))
77-
self.assertEqual(snapshot.to_dict(), data)
78-
79-
# Verify the mock call.
80-
parent_path, _ = parent._parent_info()
81-
firestore_api.run_query.assert_called_once_with(
82-
request={
83-
"parent": parent_path,
84-
"structured_query": query._to_protobuf(),
85-
"transaction": None,
86-
},
87-
metadata=client._rpc_metadata,
88-
)
66+
# Verify that `get` merely wraps `stream`.
67+
stream_mock.assert_called_once()
68+
self.assertEqual(get_response, stream_mock.return_value)
8969

90-
# Verify the deprecation
91-
self.assertEqual(len(warned), 1)
92-
self.assertIs(warned[0].category, DeprecationWarning)
70+
# Verify the deprecation.
71+
self.assertEqual(len(warned), 1)
72+
self.assertIs(warned[0].category, DeprecationWarning)
9373

9474
def test_stream_simple(self):
9575
# Create a minimal fake GAPIC.

0 commit comments

Comments
 (0)