From a4b666a4a11b04903cf7a48f74e525205d13250e Mon Sep 17 00:00:00 2001 From: Mariatta Wijaya Date: Wed, 30 Nov 2022 08:34:28 -0800 Subject: [PATCH] feat: Support "limit" in count query. (#384) * Move the limit to aggregation_query.fetch * Add test coverage --- google/cloud/datastore/aggregation.py | 9 +++++++- tests/system/test_aggregation_query.py | 20 +++++++++++++++++ tests/unit/test_aggregation.py | 31 +++++++++++++++++++++----- 3 files changed, 54 insertions(+), 6 deletions(-) diff --git a/google/cloud/datastore/aggregation.py b/google/cloud/datastore/aggregation.py index bb75d94e..24d2abcc 100644 --- a/google/cloud/datastore/aggregation.py +++ b/google/cloud/datastore/aggregation.py @@ -174,6 +174,7 @@ def add_aggregations(self, aggregations): def fetch( self, client=None, + limit=None, eventual=False, retry=None, timeout=None, @@ -204,7 +205,7 @@ def fetch( >>> client.put_multi([andy, sally, bobby]) >>> query = client.query(kind='Andy') >>> aggregation_query = client.aggregation_query(query) - >>> result = aggregation_query.count(alias="total").fetch() + >>> result = aggregation_query.count(alias="total").fetch(limit=5) >>> result @@ -248,6 +249,7 @@ def fetch( return AggregationResultIterator( self, client, + limit=limit, eventual=eventual, retry=retry, timeout=timeout, @@ -293,6 +295,7 @@ def __init__( self, aggregation_query, client, + limit=None, eventual=False, retry=None, timeout=None, @@ -308,6 +311,7 @@ def __init__( self._retry = retry self._timeout = timeout self._read_time = read_time + self._limit = limit # The attributes below will change over the life of the iterator. self._more_results = True @@ -322,6 +326,9 @@ def _build_protobuf(self): state of the iterator. """ pb = self._aggregation_query._to_pb() + if self._limit is not None and self._limit > 0: + for aggregation in pb.aggregations: + aggregation.count.up_to = self._limit return pb def _process_query_results(self, response_pb): diff --git a/tests/system/test_aggregation_query.py b/tests/system/test_aggregation_query.py index 3e5120da..b912e96b 100644 --- a/tests/system/test_aggregation_query.py +++ b/tests/system/test_aggregation_query.py @@ -93,6 +93,26 @@ def test_aggregation_query_with_alias(aggregation_query_client, nested_query): assert r.value > 0 +def test_aggregation_query_with_limit(aggregation_query_client, nested_query): + query = nested_query + + aggregation_query = aggregation_query_client.aggregation_query(query) + aggregation_query.count(alias="total") + result = _do_fetch(aggregation_query) # count without limit + assert len(result) == 1 + for r in result[0]: + assert r.alias == "total" + assert r.value == 8 + + aggregation_query = aggregation_query_client.aggregation_query(query) + aggregation_query.count(alias="total_up_to") + result = _do_fetch(aggregation_query, limit=2) # count with limit = 2 + assert len(result) == 1 + for r in result[0]: + assert r.alias == "total_up_to" + assert r.value == 2 + + def test_aggregation_query_multiple_aggregations( aggregation_query_client, nested_query ): diff --git a/tests/unit/test_aggregation.py b/tests/unit/test_aggregation.py index 8b28a908..afa9dc53 100644 --- a/tests/unit/test_aggregation.py +++ b/tests/unit/test_aggregation.py @@ -127,6 +127,22 @@ def test_query_fetch_w_explicit_client_w_retry_w_timeout(client): assert iterator._timeout == timeout +def test_query_fetch_w_explicit_client_w_limit(client): + from google.cloud.datastore.aggregation import AggregationResultIterator + + other_client = _make_client() + query = _make_query(client) + aggregation_query = _make_aggregation_query(client=client, query=query) + limit = 2 + + iterator = aggregation_query.fetch(client=other_client, limit=limit) + + assert isinstance(iterator, AggregationResultIterator) + assert iterator._aggregation_query is aggregation_query + assert iterator.client is other_client + assert iterator._limit == limit + + def test_iterator_constructor_defaults(): query = object() client = object() @@ -149,12 +165,10 @@ def test_iterator_constructor_explicit(): aggregation_query = AggregationQuery(client=client, query=query) retry = mock.Mock() timeout = 100000 + limit = 2 iterator = _make_aggregation_iterator( - aggregation_query, - client, - retry=retry, - timeout=timeout, + aggregation_query, client, retry=retry, timeout=timeout, limit=limit ) assert not iterator._started @@ -165,6 +179,7 @@ def test_iterator_constructor_explicit(): assert iterator._more_results assert iterator._retry == retry assert iterator._timeout == timeout + assert iterator._limit == limit def test_iterator__build_protobuf_empty(): @@ -186,14 +201,20 @@ def test_iterator__build_protobuf_all_values(): client = _Client(None) query = _make_query(client) + alias = "total" + limit = 2 aggregation_query = AggregationQuery(client=client, query=query) + aggregation_query.count(alias) - iterator = _make_aggregation_iterator(aggregation_query, client) + iterator = _make_aggregation_iterator(aggregation_query, client, limit=limit) iterator.num_results = 4 pb = iterator._build_protobuf() expected_pb = query_pb2.AggregationQuery() expected_pb.nested_query = query_pb2.Query() + expected_count_pb = query_pb2.AggregationQuery.Aggregation(alias=alias) + expected_count_pb.count.up_to = limit + expected_pb.aggregations.append(expected_count_pb) assert pb == expected_pb