Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support "limit" in count query. #384

Merged
merged 9 commits into from Nov 30, 2022
9 changes: 8 additions & 1 deletion google/cloud/datastore/aggregation.py
Expand Up @@ -174,6 +174,7 @@ def add_aggregations(self, aggregations):
def fetch(
self,
client=None,
limit=None,
eventual=False,
retry=None,
timeout=None,
Expand Down Expand Up @@ -204,7 +205,7 @@ def fetch(
>>> client.put_multi([andy, sally, bobby])
>>> query = client.query(kind='Andy')
>>> aggregation_query = client.aggregation_query(query)
>>> result = aggregation_query.count(alias="total").fetch()
>>> result = aggregation_query.count(alias="total").fetch(fetch=5)
Mariatta marked this conversation as resolved.
Show resolved Hide resolved
>>> result
<google.cloud.datastore.aggregation.AggregationResultIterator object at ...>

Expand Down Expand Up @@ -248,6 +249,7 @@ def fetch(
return AggregationResultIterator(
self,
client,
limit=limit,
eventual=eventual,
retry=retry,
timeout=timeout,
Expand Down Expand Up @@ -293,6 +295,7 @@ def __init__(
self,
aggregation_query,
client,
limit=None,
eventual=False,
retry=None,
timeout=None,
Expand All @@ -308,6 +311,7 @@ def __init__(
self._retry = retry
self._timeout = timeout
self._read_time = read_time
self._limit = limit
# The attributes below will change over the life of the iterator.
self._more_results = True

Expand All @@ -322,6 +326,9 @@ def _build_protobuf(self):
state of the iterator.
"""
pb = self._aggregation_query._to_pb()
if self._limit is not None and self._limit > 0:
for aggregation in pb.aggregations:
aggregation.count.up_to = self._limit
return pb

def _process_query_results(self, response_pb):
Expand Down
20 changes: 20 additions & 0 deletions tests/system/test_aggregation_query.py
Expand Up @@ -93,6 +93,26 @@ def test_aggregation_query_with_alias(aggregation_query_client, nested_query):
assert r.value > 0


def test_aggregation_query_with_limit(aggregation_query_client, nested_query):
query = nested_query

aggregation_query = aggregation_query_client.aggregation_query(query)
aggregation_query.count(alias="total")
result = _do_fetch(aggregation_query) # count without limit
assert len(result) == 1
for r in result[0]:
assert r.alias == "total"
assert r.value == 8

aggregation_query = aggregation_query_client.aggregation_query(query)
aggregation_query.count(alias="total_up_to")
result = _do_fetch(aggregation_query, limit=2) # count with limit = 2
assert len(result) == 1
for r in result[0]:
assert r.alias == "total_up_to"
assert r.value == 2


def test_aggregation_query_multiple_aggregations(
aggregation_query_client, nested_query
):
Expand Down