Skip to content

Commit

Permalink
feat: Support "limit" in count query. (#384)
Browse files Browse the repository at this point in the history
* Move the limit to aggregation_query.fetch
* Add test coverage
  • Loading branch information
Mariatta committed Nov 30, 2022
1 parent 953fd52 commit a4b666a
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 6 deletions.
9 changes: 8 additions & 1 deletion google/cloud/datastore/aggregation.py
Expand Up @@ -174,6 +174,7 @@ def add_aggregations(self, aggregations):
def fetch(
self,
client=None,
limit=None,
eventual=False,
retry=None,
timeout=None,
Expand Down Expand Up @@ -204,7 +205,7 @@ def fetch(
>>> client.put_multi([andy, sally, bobby])
>>> query = client.query(kind='Andy')
>>> aggregation_query = client.aggregation_query(query)
>>> result = aggregation_query.count(alias="total").fetch()
>>> result = aggregation_query.count(alias="total").fetch(limit=5)
>>> result
<google.cloud.datastore.aggregation.AggregationResultIterator object at ...>
Expand Down Expand Up @@ -248,6 +249,7 @@ def fetch(
return AggregationResultIterator(
self,
client,
limit=limit,
eventual=eventual,
retry=retry,
timeout=timeout,
Expand Down Expand Up @@ -293,6 +295,7 @@ def __init__(
self,
aggregation_query,
client,
limit=None,
eventual=False,
retry=None,
timeout=None,
Expand All @@ -308,6 +311,7 @@ def __init__(
self._retry = retry
self._timeout = timeout
self._read_time = read_time
self._limit = limit
# The attributes below will change over the life of the iterator.
self._more_results = True

Expand All @@ -322,6 +326,9 @@ def _build_protobuf(self):
state of the iterator.
"""
pb = self._aggregation_query._to_pb()
if self._limit is not None and self._limit > 0:
for aggregation in pb.aggregations:
aggregation.count.up_to = self._limit
return pb

def _process_query_results(self, response_pb):
Expand Down
20 changes: 20 additions & 0 deletions tests/system/test_aggregation_query.py
Expand Up @@ -93,6 +93,26 @@ def test_aggregation_query_with_alias(aggregation_query_client, nested_query):
assert r.value > 0


def test_aggregation_query_with_limit(aggregation_query_client, nested_query):
query = nested_query

aggregation_query = aggregation_query_client.aggregation_query(query)
aggregation_query.count(alias="total")
result = _do_fetch(aggregation_query) # count without limit
assert len(result) == 1
for r in result[0]:
assert r.alias == "total"
assert r.value == 8

aggregation_query = aggregation_query_client.aggregation_query(query)
aggregation_query.count(alias="total_up_to")
result = _do_fetch(aggregation_query, limit=2) # count with limit = 2
assert len(result) == 1
for r in result[0]:
assert r.alias == "total_up_to"
assert r.value == 2


def test_aggregation_query_multiple_aggregations(
aggregation_query_client, nested_query
):
Expand Down
31 changes: 26 additions & 5 deletions tests/unit/test_aggregation.py
Expand Up @@ -127,6 +127,22 @@ def test_query_fetch_w_explicit_client_w_retry_w_timeout(client):
assert iterator._timeout == timeout


def test_query_fetch_w_explicit_client_w_limit(client):
from google.cloud.datastore.aggregation import AggregationResultIterator

other_client = _make_client()
query = _make_query(client)
aggregation_query = _make_aggregation_query(client=client, query=query)
limit = 2

iterator = aggregation_query.fetch(client=other_client, limit=limit)

assert isinstance(iterator, AggregationResultIterator)
assert iterator._aggregation_query is aggregation_query
assert iterator.client is other_client
assert iterator._limit == limit


def test_iterator_constructor_defaults():
query = object()
client = object()
Expand All @@ -149,12 +165,10 @@ def test_iterator_constructor_explicit():
aggregation_query = AggregationQuery(client=client, query=query)
retry = mock.Mock()
timeout = 100000
limit = 2

iterator = _make_aggregation_iterator(
aggregation_query,
client,
retry=retry,
timeout=timeout,
aggregation_query, client, retry=retry, timeout=timeout, limit=limit
)

assert not iterator._started
Expand All @@ -165,6 +179,7 @@ def test_iterator_constructor_explicit():
assert iterator._more_results
assert iterator._retry == retry
assert iterator._timeout == timeout
assert iterator._limit == limit


def test_iterator__build_protobuf_empty():
Expand All @@ -186,14 +201,20 @@ def test_iterator__build_protobuf_all_values():

client = _Client(None)
query = _make_query(client)
alias = "total"
limit = 2
aggregation_query = AggregationQuery(client=client, query=query)
aggregation_query.count(alias)

iterator = _make_aggregation_iterator(aggregation_query, client)
iterator = _make_aggregation_iterator(aggregation_query, client, limit=limit)
iterator.num_results = 4

pb = iterator._build_protobuf()
expected_pb = query_pb2.AggregationQuery()
expected_pb.nested_query = query_pb2.Query()
expected_count_pb = query_pb2.AggregationQuery.Aggregation(alias=alias)
expected_count_pb.count.up_to = limit
expected_pb.aggregations.append(expected_count_pb)
assert pb == expected_pb


Expand Down

0 comments on commit a4b666a

Please sign in to comment.