Skip to content

Commit a4b666a

Browse files
authored
feat: Support "limit" in count query. (#384)
* Move the limit to aggregation_query.fetch * Add test coverage
1 parent 953fd52 commit a4b666a

File tree

3 files changed

+54
-6
lines changed

3 files changed

+54
-6
lines changed

google/cloud/datastore/aggregation.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ def add_aggregations(self, aggregations):
174174
def fetch(
175175
self,
176176
client=None,
177+
limit=None,
177178
eventual=False,
178179
retry=None,
179180
timeout=None,
@@ -204,7 +205,7 @@ def fetch(
204205
>>> client.put_multi([andy, sally, bobby])
205206
>>> query = client.query(kind='Andy')
206207
>>> aggregation_query = client.aggregation_query(query)
207-
>>> result = aggregation_query.count(alias="total").fetch()
208+
>>> result = aggregation_query.count(alias="total").fetch(limit=5)
208209
>>> result
209210
<google.cloud.datastore.aggregation.AggregationResultIterator object at ...>
210211
@@ -248,6 +249,7 @@ def fetch(
248249
return AggregationResultIterator(
249250
self,
250251
client,
252+
limit=limit,
251253
eventual=eventual,
252254
retry=retry,
253255
timeout=timeout,
@@ -293,6 +295,7 @@ def __init__(
293295
self,
294296
aggregation_query,
295297
client,
298+
limit=None,
296299
eventual=False,
297300
retry=None,
298301
timeout=None,
@@ -308,6 +311,7 @@ def __init__(
308311
self._retry = retry
309312
self._timeout = timeout
310313
self._read_time = read_time
314+
self._limit = limit
311315
# The attributes below will change over the life of the iterator.
312316
self._more_results = True
313317

@@ -322,6 +326,9 @@ def _build_protobuf(self):
322326
state of the iterator.
323327
"""
324328
pb = self._aggregation_query._to_pb()
329+
if self._limit is not None and self._limit > 0:
330+
for aggregation in pb.aggregations:
331+
aggregation.count.up_to = self._limit
325332
return pb
326333

327334
def _process_query_results(self, response_pb):

tests/system/test_aggregation_query.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,26 @@ def test_aggregation_query_with_alias(aggregation_query_client, nested_query):
9393
assert r.value > 0
9494

9595

96+
def test_aggregation_query_with_limit(aggregation_query_client, nested_query):
97+
query = nested_query
98+
99+
aggregation_query = aggregation_query_client.aggregation_query(query)
100+
aggregation_query.count(alias="total")
101+
result = _do_fetch(aggregation_query) # count without limit
102+
assert len(result) == 1
103+
for r in result[0]:
104+
assert r.alias == "total"
105+
assert r.value == 8
106+
107+
aggregation_query = aggregation_query_client.aggregation_query(query)
108+
aggregation_query.count(alias="total_up_to")
109+
result = _do_fetch(aggregation_query, limit=2) # count with limit = 2
110+
assert len(result) == 1
111+
for r in result[0]:
112+
assert r.alias == "total_up_to"
113+
assert r.value == 2
114+
115+
96116
def test_aggregation_query_multiple_aggregations(
97117
aggregation_query_client, nested_query
98118
):

tests/unit/test_aggregation.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,22 @@ def test_query_fetch_w_explicit_client_w_retry_w_timeout(client):
127127
assert iterator._timeout == timeout
128128

129129

130+
def test_query_fetch_w_explicit_client_w_limit(client):
131+
from google.cloud.datastore.aggregation import AggregationResultIterator
132+
133+
other_client = _make_client()
134+
query = _make_query(client)
135+
aggregation_query = _make_aggregation_query(client=client, query=query)
136+
limit = 2
137+
138+
iterator = aggregation_query.fetch(client=other_client, limit=limit)
139+
140+
assert isinstance(iterator, AggregationResultIterator)
141+
assert iterator._aggregation_query is aggregation_query
142+
assert iterator.client is other_client
143+
assert iterator._limit == limit
144+
145+
130146
def test_iterator_constructor_defaults():
131147
query = object()
132148
client = object()
@@ -149,12 +165,10 @@ def test_iterator_constructor_explicit():
149165
aggregation_query = AggregationQuery(client=client, query=query)
150166
retry = mock.Mock()
151167
timeout = 100000
168+
limit = 2
152169

153170
iterator = _make_aggregation_iterator(
154-
aggregation_query,
155-
client,
156-
retry=retry,
157-
timeout=timeout,
171+
aggregation_query, client, retry=retry, timeout=timeout, limit=limit
158172
)
159173

160174
assert not iterator._started
@@ -165,6 +179,7 @@ def test_iterator_constructor_explicit():
165179
assert iterator._more_results
166180
assert iterator._retry == retry
167181
assert iterator._timeout == timeout
182+
assert iterator._limit == limit
168183

169184

170185
def test_iterator__build_protobuf_empty():
@@ -186,14 +201,20 @@ def test_iterator__build_protobuf_all_values():
186201

187202
client = _Client(None)
188203
query = _make_query(client)
204+
alias = "total"
205+
limit = 2
189206
aggregation_query = AggregationQuery(client=client, query=query)
207+
aggregation_query.count(alias)
190208

191-
iterator = _make_aggregation_iterator(aggregation_query, client)
209+
iterator = _make_aggregation_iterator(aggregation_query, client, limit=limit)
192210
iterator.num_results = 4
193211

194212
pb = iterator._build_protobuf()
195213
expected_pb = query_pb2.AggregationQuery()
196214
expected_pb.nested_query = query_pb2.Query()
215+
expected_count_pb = query_pb2.AggregationQuery.Aggregation(alias=alias)
216+
expected_count_pb.count.up_to = limit
217+
expected_pb.aggregations.append(expected_count_pb)
197218
assert pb == expected_pb
198219

199220

0 commit comments

Comments
 (0)