Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support "limit" in count query. #384

Merged
merged 9 commits into from Nov 30, 2022
9 changes: 8 additions & 1 deletion google/cloud/datastore/aggregation.py
Expand Up @@ -174,6 +174,7 @@ def add_aggregations(self, aggregations):
def fetch(
self,
client=None,
limit=None,
eventual=False,
retry=None,
timeout=None,
Expand Down Expand Up @@ -204,7 +205,7 @@ def fetch(
>>> client.put_multi([andy, sally, bobby])
>>> query = client.query(kind='Andy')
>>> aggregation_query = client.aggregation_query(query)
>>> result = aggregation_query.count(alias="total").fetch()
>>> result = aggregation_query.count(alias="total").fetch(limit=5)
>>> result
<google.cloud.datastore.aggregation.AggregationResultIterator object at ...>

Expand Down Expand Up @@ -248,6 +249,7 @@ def fetch(
return AggregationResultIterator(
self,
client,
limit=limit,
eventual=eventual,
retry=retry,
timeout=timeout,
Expand Down Expand Up @@ -293,6 +295,7 @@ def __init__(
self,
aggregation_query,
client,
limit=None,
eventual=False,
retry=None,
timeout=None,
Expand All @@ -308,6 +311,7 @@ def __init__(
self._retry = retry
self._timeout = timeout
self._read_time = read_time
self._limit = limit
# The attributes below will change over the life of the iterator.
self._more_results = True

Expand All @@ -322,6 +326,9 @@ def _build_protobuf(self):
state of the iterator.
"""
pb = self._aggregation_query._to_pb()
if self._limit is not None and self._limit > 0:
for aggregation in pb.aggregations:
aggregation.count.up_to = self._limit
return pb

def _process_query_results(self, response_pb):
Expand Down
20 changes: 20 additions & 0 deletions tests/system/test_aggregation_query.py
Expand Up @@ -93,6 +93,26 @@ def test_aggregation_query_with_alias(aggregation_query_client, nested_query):
assert r.value > 0


def test_aggregation_query_with_limit(aggregation_query_client, nested_query):
query = nested_query

aggregation_query = aggregation_query_client.aggregation_query(query)
aggregation_query.count(alias="total")
result = _do_fetch(aggregation_query) # count without limit
assert len(result) == 1
for r in result[0]:
assert r.alias == "total"
assert r.value == 8

aggregation_query = aggregation_query_client.aggregation_query(query)
aggregation_query.count(alias="total_up_to")
result = _do_fetch(aggregation_query, limit=2) # count with limit = 2
assert len(result) == 1
for r in result[0]:
assert r.alias == "total_up_to"
assert r.value == 2


def test_aggregation_query_multiple_aggregations(
aggregation_query_client, nested_query
):
Expand Down
31 changes: 26 additions & 5 deletions tests/unit/test_aggregation.py
Expand Up @@ -127,6 +127,22 @@ def test_query_fetch_w_explicit_client_w_retry_w_timeout(client):
assert iterator._timeout == timeout


def test_query_fetch_w_explicit_client_w_limit(client):
from google.cloud.datastore.aggregation import AggregationResultIterator

other_client = _make_client()
query = _make_query(client)
aggregation_query = _make_aggregation_query(client=client, query=query)
limit = 2

iterator = aggregation_query.fetch(client=other_client, limit=limit)

assert isinstance(iterator, AggregationResultIterator)
assert iterator._aggregation_query is aggregation_query
assert iterator.client is other_client
assert iterator._limit == limit


def test_iterator_constructor_defaults():
query = object()
client = object()
Expand All @@ -149,12 +165,10 @@ def test_iterator_constructor_explicit():
aggregation_query = AggregationQuery(client=client, query=query)
retry = mock.Mock()
timeout = 100000
limit = 2

iterator = _make_aggregation_iterator(
aggregation_query,
client,
retry=retry,
timeout=timeout,
aggregation_query, client, retry=retry, timeout=timeout, limit=limit
)

assert not iterator._started
Expand All @@ -165,6 +179,7 @@ def test_iterator_constructor_explicit():
assert iterator._more_results
assert iterator._retry == retry
assert iterator._timeout == timeout
assert iterator._limit == limit


def test_iterator__build_protobuf_empty():
Expand All @@ -186,14 +201,20 @@ def test_iterator__build_protobuf_all_values():

client = _Client(None)
query = _make_query(client)
alias = "total"
limit = 2
aggregation_query = AggregationQuery(client=client, query=query)
aggregation_query.count(alias)

iterator = _make_aggregation_iterator(aggregation_query, client)
iterator = _make_aggregation_iterator(aggregation_query, client, limit=limit)
iterator.num_results = 4

pb = iterator._build_protobuf()
expected_pb = query_pb2.AggregationQuery()
expected_pb.nested_query = query_pb2.Query()
expected_count_pb = query_pb2.AggregationQuery.Aggregation(alias=alias)
expected_count_pb.count.up_to = limit
expected_pb.aggregations.append(expected_count_pb)
assert pb == expected_pb


Expand Down