diff --git a/docs/aggregations.rst b/docs/aggregations.rst new file mode 100644 index 00000000..d287fbc5 --- /dev/null +++ b/docs/aggregations.rst @@ -0,0 +1,6 @@ +Aggregations +~~~~~~~~~~~~ + +.. automodule:: google.cloud.datastore.aggregation + :members: + :show-inheritance: diff --git a/docs/index.rst b/docs/index.rst index 4866c891..890ec56a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -18,6 +18,7 @@ API Reference entities keys queries + aggregations transactions batches helpers diff --git a/google/cloud/datastore/_http.py b/google/cloud/datastore/_http.py index 60b8af89..61209e98 100644 --- a/google/cloud/datastore/_http.py +++ b/google/cloud/datastore/_http.py @@ -280,6 +280,39 @@ def run_query(self, request, retry=None, timeout=None): timeout=timeout, ) + def run_aggregation_query(self, request, retry=None, timeout=None): + """Perform a ``runAggregationQuery`` request. + + :type request: :class:`_datastore_pb2.RunAggregationQueryRequest` or dict + :param request: + Parameter bundle for API request. + + :type retry: :class:`google.api_core.retry.Retry` + :param retry: (Optional) retry policy for the request + + :type timeout: float or tuple(float, float) + :param timeout: (Optional) timeout for the request + + :rtype: :class:`.datastore_pb2.RunAggregationQueryResponse` + :returns: The returned protobuf response object. + """ + request_pb = _make_request_pb( + request, _datastore_pb2.RunAggregationQueryRequest + ) + project_id = request_pb.project_id + + return _rpc( + self.client._http, + project_id, + "runAggregationQuery", + self.client._base_url, + self.client._client_info, + request_pb, + _datastore_pb2.RunAggregationQueryResponse, + retry=retry, + timeout=timeout, + ) + def begin_transaction(self, request, retry=None, timeout=None): """Perform a ``beginTransaction`` request. 
diff --git a/google/cloud/datastore/aggregation.py b/google/cloud/datastore/aggregation.py new file mode 100644 index 00000000..bb75d94e --- /dev/null +++ b/google/cloud/datastore/aggregation.py @@ -0,0 +1,431 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Create / interact with Google Cloud Datastore aggregation queries.""" +import abc +from abc import ABC + +from google.api_core import page_iterator + +from google.cloud.datastore_v1.types import entity as entity_pb2 +from google.cloud.datastore_v1.types import query as query_pb2 +from google.cloud.datastore import helpers +from google.cloud.datastore.query import _pb_from_query + + +_NOT_FINISHED = query_pb2.QueryResultBatch.MoreResultsType.NOT_FINISHED +_NO_MORE_RESULTS = query_pb2.QueryResultBatch.MoreResultsType.NO_MORE_RESULTS + +_FINISHED = ( + _NO_MORE_RESULTS, + query_pb2.QueryResultBatch.MoreResultsType.MORE_RESULTS_AFTER_LIMIT, + query_pb2.QueryResultBatch.MoreResultsType.MORE_RESULTS_AFTER_CURSOR, +) + + +class BaseAggregation(ABC): + """ + Base class representing an Aggregation operation in Datastore + """ + + @abc.abstractmethod + def _to_pb(self): + """ + Convert this instance to the protobuf representation + """ + + +class CountAggregation(BaseAggregation): + """ + Representation of a "Count" aggregation query. + + :type alias: str + :param alias: The alias for the aggregation. 
+ + :type value: int + :param value: The resulting value from the aggregation. + + """ + + def __init__(self, alias=None): + self.alias = alias + + def _to_pb(self): + """ + Convert this instance to the protobuf representation + """ + aggregation_pb = query_pb2.AggregationQuery.Aggregation() + aggregation_pb.count = query_pb2.AggregationQuery.Aggregation.Count() + aggregation_pb.alias = self.alias + return aggregation_pb + + +class AggregationResult(object): + """ + A class representing result from Aggregation Query + + :type alias: str + :param alias: The alias for the aggregation. + + :type value: int + :param value: The resulting value from the aggregation. + + """ + + def __init__(self, alias, value): + self.alias = alias + self.value = value + + def __repr__(self): + return "<Aggregation alias=%s, value=%s>" % (self.alias, self.value) + + +class AggregationQuery(object): + """An Aggregation query against the Cloud Datastore. + + This class serves as an abstraction for creating aggregations over query + in the Cloud Datastore. + + :type client: :class:`google.cloud.datastore.client.Client` + :param client: The client used to connect to Datastore. + + :type query: :class:`google.cloud.datastore.query.Query` + :param query: The query used for aggregations. + """ + + def __init__( + self, + client, + query, + ): + + self._client = client + self._nested_query = query + self._aggregations = [] + + @property + def project(self): + """Get the project for this AggregationQuery. + + :rtype: str + :returns: The project for the query. 
+ """ + return self._nested_query._project or self._client.project + + @property + def namespace(self): + """The nested query's namespace + + :rtype: str or None + :returns: the namespace assigned to this query + """ + return self._nested_query._namespace or self._client.namespace + + def _to_pb(self): + """ + Returns the protobuf representation for this Aggregation Query + """ + pb = query_pb2.AggregationQuery() + pb.nested_query = _pb_from_query(self._nested_query) + for aggregation in self._aggregations: + aggregation_pb = aggregation._to_pb() + pb.aggregations.append(aggregation_pb) + return pb + + def count(self, alias=None): + """ + Adds a count over the nested query + + :type alias: str + :param alias: (Optional) The alias for the count + """ + count_aggregation = CountAggregation(alias=alias) + self._aggregations.append(count_aggregation) + return self + + def add_aggregation(self, aggregation): + """ + Adds an aggregation operation to the nested query + + :type aggregation: :class:`google.cloud.datastore.aggregation.BaseAggregation` + :param aggregation: An aggregation operation, e.g. a CountAggregation + """ + self._aggregations.append(aggregation) + + def add_aggregations(self, aggregations): + """ + Adds a list of aggregations to the nested query + :type aggregations: list + :param aggregations: a list of aggregation operations + """ + self._aggregations.extend(aggregations) + + def fetch( + self, + client=None, + eventual=False, + retry=None, + timeout=None, + read_time=None, + ): + """Execute the Aggregation Query; return an iterator for the aggregation results. + + For example: + + .. testsetup:: aggregation-query-fetch + + import uuid + + from google.cloud import datastore + + unique = str(uuid.uuid4())[0:8] + client = datastore.Client(namespace='ns{}'.format(unique)) + + + .. 
doctest:: aggregation-query-fetch + + >>> andy = datastore.Entity(client.key('Person', 1234)) + >>> andy['name'] = 'Andy' + >>> sally = datastore.Entity(client.key('Person', 2345)) + >>> sally['name'] = 'Sally' + >>> bobby = datastore.Entity(client.key('Person', 3456)) + >>> bobby['name'] = 'Bobby' + >>> client.put_multi([andy, sally, bobby]) + >>> query = client.query(kind='Andy') + >>> aggregation_query = client.aggregation_query(query) + >>> result = aggregation_query.count(alias="total").fetch() + >>> result + + + .. testcleanup:: aggregation-query-fetch + + client.delete(andy.key) + + :type client: :class:`google.cloud.datastore.client.Client` + :param client: (Optional) client used to connect to datastore. + If not supplied, uses the query's value. + + :type eventual: bool + :param eventual: (Optional) Defaults to strongly consistent (False). + Setting True will use eventual consistency, + but cannot be used inside a transaction or + with read_time, otherwise will raise + ValueError. + + :type retry: :class:`google.api_core.retry.Retry` + :param retry: + A retry object used to retry requests. If ``None`` is specified, + requests will be retried using a default configuration. + + :type timeout: float + :param timeout: + Time, in seconds, to wait for the request to complete. + Note that if ``retry`` is specified, the timeout applies + to each individual attempt. + + :type read_time: datetime + :param read_time: + (Optional) use read_time read consistency, cannot be used inside a + transaction or with eventual consistency, or will raise ValueError. + + :rtype: :class:`AggregationIterator` + :returns: The iterator for the aggregation query. + """ + if client is None: + client = self._client + + return AggregationResultIterator( + self, + client, + eventual=eventual, + retry=retry, + timeout=timeout, + read_time=read_time, + ) + + +class AggregationResultIterator(page_iterator.Iterator): + """Represent the state of a given execution of a Query. 
+ + :type aggregation_query: :class:`~google.cloud.datastore.aggregation.AggregationQuery` + :param aggregation_query: AggregationQuery object holding permanent configuration (i.e. + things that don't change on with each page in + a results set). + + :type client: :class:`~google.cloud.datastore.client.Client` + :param client: The client used to make a request. + + :type eventual: bool + :param eventual: (Optional) Defaults to strongly consistent (False). + Setting True will use eventual consistency, + but cannot be used inside a transaction or + with read_time, otherwise will raise ValueError. + + :type retry: :class:`google.api_core.retry.Retry` + :param retry: + A retry object used to retry requests. If ``None`` is specified, + requests will be retried using a default configuration. + + :type timeout: float + :param timeout: + Time, in seconds, to wait for the request to complete. + Note that if ``retry`` is specified, the timeout applies + to each individual attempt. + + :type read_time: datetime + :param read_time: (Optional) Runs the query with read time consistency. + Cannot be used with eventual consistency or inside a + transaction, otherwise will raise ValueError. This feature is in private preview. + """ + + def __init__( + self, + aggregation_query, + client, + eventual=False, + retry=None, + timeout=None, + read_time=None, + ): + super(AggregationResultIterator, self).__init__( + client=client, + item_to_value=_item_to_aggregation_result, + ) + + self._aggregation_query = aggregation_query + self._eventual = eventual + self._retry = retry + self._timeout = timeout + self._read_time = read_time + # The attributes below will change over the life of the iterator. + self._more_results = True + + def _build_protobuf(self): + """Build a query protobuf. + + Relies on the current state of the iterator. + + :rtype: + :class:`.query_pb2.AggregationQuery.Aggregation` + :returns: The aggregation_query protobuf object for the current + state of the iterator. 
+ """ + pb = self._aggregation_query._to_pb() + return pb + + def _process_query_results(self, response_pb): + """Process the response from a datastore query. + + :type response_pb: :class:`.datastore_pb2.RunQueryResponse` + :param response_pb: The protobuf response from a ``runQuery`` request. + + :rtype: iterable + :returns: The next page of entity results. + :raises ValueError: If ``more_results`` is an unexpected value. + """ + + if response_pb.batch.more_results == _NOT_FINISHED: + self._more_results = True + elif response_pb.batch.more_results in _FINISHED: + self._more_results = False + else: + raise ValueError("Unexpected value returned for `more_results`.") + + return [ + result.aggregate_properties + for result in response_pb.batch.aggregation_results + ] + + def _next_page(self): + """Get the next page in the iterator. + + :rtype: :class:`~google.cloud.iterator.Page` + :returns: The next page in the iterator (or :data:`None` if + there are no pages left). + """ + if not self._more_results: + return None + + query_pb = self._build_protobuf() + transaction = self.client.current_transaction + if transaction is None: + transaction_id = None + else: + transaction_id = transaction.id + read_options = helpers.get_read_options( + self._eventual, transaction_id, self._read_time + ) + + partition_id = entity_pb2.PartitionId( + project_id=self._aggregation_query.project, + namespace_id=self._aggregation_query.namespace, + ) + + kwargs = {} + + if self._retry is not None: + kwargs["retry"] = self._retry + + if self._timeout is not None: + kwargs["timeout"] = self._timeout + + response_pb = self.client._datastore_api.run_aggregation_query( + request={ + "project_id": self._aggregation_query.project, + "partition_id": partition_id, + "read_options": read_options, + "aggregation_query": query_pb, + }, + **kwargs, + ) + + while response_pb.batch.more_results == _NOT_FINISHED: + # We haven't finished processing. 
A likely reason is we haven't + # skipped all of the results yet. Don't return any results. + # Instead, rerun query, adjusting offsets. Datastore doesn't process + # more than 1000 skipped results in a query. + old_query_pb = query_pb + query_pb = query_pb2.AggregationQuery() + query_pb._pb.CopyFrom(old_query_pb._pb) # copy for testability + + response_pb = self.client._datastore_api.run_aggregation_query( + request={ + "project_id": self._aggregation_query.project, + "partition_id": partition_id, + "read_options": read_options, + "aggregation_query": query_pb, + }, + **kwargs, + ) + + item_pbs = self._process_query_results(response_pb) + return page_iterator.Page(self, item_pbs, self.item_to_value) + + +# pylint: disable=unused-argument +def _item_to_aggregation_result(iterator, pb): + """Convert a raw protobuf aggregation result to the native object. + + :type iterator: :class:`~google.api_core.page_iterator.Iterator` + :param iterator: The iterator that is currently in use. + + :type pb: + :class:`proto.marshal.collections.maps.MapComposite` + :param pb: The aggregation properties pb from the aggregation query result + + :rtype: :class:`google.cloud.datastore.aggregation.AggregationResult` + :returns: The list of AggregationResults + """ + results = [AggregationResult(alias=k, value=pb[k].integer_value) for k in pb.keys()] + return results diff --git a/google/cloud/datastore/client.py b/google/cloud/datastore/client.py index 212ba1d4..e90a3415 100644 --- a/google/cloud/datastore/client.py +++ b/google/cloud/datastore/client.py @@ -28,6 +28,8 @@ from google.cloud.datastore.entity import Entity from google.cloud.datastore.key import Key from google.cloud.datastore.query import Query +from google.cloud.datastore.aggregation import AggregationQuery + from google.cloud.datastore.transaction import Transaction try: @@ -837,6 +839,86 @@ def do_something_with(entity): kwargs["namespace"] = self.namespace return Query(self, **kwargs) + def aggregation_query(self, 
query): + """Proxy to :class:`google.cloud.datastore.aggregation.AggregationQuery`. + + Using aggregation_query to count over a query: + + .. testsetup:: aggregation_query + + import uuid + + from google.cloud import datastore + from google.cloud.datastore.aggregation import CountAggregation + + unique = str(uuid.uuid4())[0:8] + client = datastore.Client(namespace='ns{}'.format(unique)) + + def do_something_with(entity): + pass + + .. doctest:: aggregation_query + + >>> query = client.query(kind='MyKind') + >>> aggregation_query = client.aggregation_query(query) + >>> aggregation_query.count(alias='total') + + >>> aggregation_query.fetch() + + + Adding an aggregation to the aggregation_query + + .. doctest:: aggregation_query + + >>> query = client.query(kind='MyKind') + >>> aggregation_query.add_aggregation(CountAggregation(alias='total')) + >>> aggregation_query.fetch() + + + Adding multiple aggregations to the aggregation_query + + .. doctest:: aggregation_query + + >>> query = client.query(kind='MyKind') + >>> total_count = CountAggregation(alias='total') + >>> all_count = CountAggregation(alias='all') + >>> aggregation_query.add_aggregations([total_count, all_count]) + >>> aggregation_query.fetch() + + + + Using the aggregation_query iterator + + .. doctest:: aggregation_query + + >>> query = client.query(kind='MyKind') + >>> aggregation_query = client.aggregation_query(query) + >>> aggregation_query.count(alias='total') + + >>> aggregation_query_iter = aggregation_query.fetch() + >>> for aggregation_result in aggregation_query_iter: + ... do_something_with(aggregation_result) + + or manually page through results + + .. 
doctest:: aggregation_query + + >>> aggregation_query_iter = aggregation_query.fetch() + >>> pages = aggregation_query_iter.pages + >>> + >>> first_page = next(pages) + >>> first_page_entities = list(first_page) + >>> aggregation_query_iter.next_page_token is None + True + + :param kwargs: Parameters for initializing and instance of + :class:`~google.cloud.datastore.aggregation.AggregationQuery`. + + :rtype: :class:`~google.cloud.datastore.aggregation.AggregationQuery` + :returns: An AggregationQuery object. + """ + return AggregationQuery(self, query) + def reserve_ids_sequential(self, complete_key, num_ids, retry=None, timeout=None): """Reserve a list of IDs sequentially from a complete key. diff --git a/tests/system/test_aggregation_query.py b/tests/system/test_aggregation_query.py new file mode 100644 index 00000000..3e5120da --- /dev/null +++ b/tests/system/test_aggregation_query.py @@ -0,0 +1,214 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from google.api_core import exceptions +from test_utils.retry import RetryErrors + +from .utils import clear_datastore +from .utils import populate_datastore +from . 
import _helpers + + +retry_503 = RetryErrors(exceptions.ServiceUnavailable) + + +def _make_iterator(aggregation_query, **kw): + # Do retry for errors raised during initial API call + return retry_503(aggregation_query.fetch)(**kw) + + +def _pull_iterator(aggregation_query, **kw): + return list(_make_iterator(aggregation_query, **kw)) + + +def _do_fetch(aggregation_query, **kw): + # Do retry for errors raised during iteration + return retry_503(_pull_iterator)(aggregation_query, **kw) + + +@pytest.fixture(scope="session") +def aggregation_query_client(datastore_client): + return _helpers.clone_client(datastore_client, namespace=None) + + +@pytest.fixture(scope="session") +def ancestor_key(aggregation_query_client, in_emulator): + + # In the emulator, re-populating the datastore is cheap. + if in_emulator: + populate_datastore.add_characters(client=aggregation_query_client) + + ancestor_key = aggregation_query_client.key(*populate_datastore.ANCESTOR) + + yield ancestor_key + + # In the emulator, destroy the query entities. 
+ if in_emulator: + clear_datastore.remove_all_entities(client=aggregation_query_client) + + +def _make_query(aggregation_query_client, ancestor_key): + return aggregation_query_client.query(kind="Character", ancestor=ancestor_key) + + +@pytest.fixture(scope="function") +def nested_query(aggregation_query_client, ancestor_key): + return _make_query(aggregation_query_client, ancestor_key) + + +def test_aggregation_query_default(aggregation_query_client, nested_query): + query = nested_query + + aggregation_query = aggregation_query_client.aggregation_query(query) + aggregation_query.count() + result = _do_fetch(aggregation_query) + assert len(result) == 1 + for r in result[0]: + assert r.alias == "property_1" + assert r.value == 8 + + +def test_aggregation_query_with_alias(aggregation_query_client, nested_query): + query = nested_query + + aggregation_query = aggregation_query_client.aggregation_query(query) + aggregation_query.count(alias="total") + result = _do_fetch(aggregation_query) + assert len(result) == 1 + for r in result[0]: + assert r.alias == "total" + assert r.value > 0 + + +def test_aggregation_query_multiple_aggregations( + aggregation_query_client, nested_query +): + query = nested_query + + aggregation_query = aggregation_query_client.aggregation_query(query) + aggregation_query.count(alias="total") + aggregation_query.count(alias="all") + result = _do_fetch(aggregation_query) + assert len(result) == 1 + for r in result[0]: + assert r.alias in ["all", "total"] + assert r.value > 0 + + +def test_aggregation_query_add_aggregation(aggregation_query_client, nested_query): + from google.cloud.datastore.aggregation import CountAggregation + + query = nested_query + + aggregation_query = aggregation_query_client.aggregation_query(query) + count_aggregation = CountAggregation(alias="total") + aggregation_query.add_aggregation(count_aggregation) + result = _do_fetch(aggregation_query) + assert len(result) == 1 + for r in result[0]: + assert r.alias == 
"total" + assert r.value > 0 + + +def test_aggregation_query_add_aggregations(aggregation_query_client, nested_query): + from google.cloud.datastore.aggregation import CountAggregation + + query = nested_query + + aggregation_query = aggregation_query_client.aggregation_query(query) + count_aggregation_1 = CountAggregation(alias="total") + count_aggregation_2 = CountAggregation(alias="all") + aggregation_query.add_aggregations([count_aggregation_1, count_aggregation_2]) + result = _do_fetch(aggregation_query) + assert len(result) == 1 + for r in result[0]: + assert r.alias in ["total", "all"] + assert r.value > 0 + + +def test_aggregation_query_add_aggregations_duplicated_alias( + aggregation_query_client, nested_query +): + from google.cloud.datastore.aggregation import CountAggregation + from google.api_core.exceptions import BadRequest + + query = nested_query + + aggregation_query = aggregation_query_client.aggregation_query(query) + count_aggregation_1 = CountAggregation(alias="total") + count_aggregation_2 = CountAggregation(alias="total") + aggregation_query.add_aggregations([count_aggregation_1, count_aggregation_2]) + with pytest.raises(BadRequest): + _do_fetch(aggregation_query) + + aggregation_query = aggregation_query_client.aggregation_query(query) + aggregation_query.add_aggregation(count_aggregation_1) + aggregation_query.add_aggregation(count_aggregation_2) + with pytest.raises(BadRequest): + _do_fetch(aggregation_query) + + aggregation_query = aggregation_query_client.aggregation_query(query) + aggregation_query.count(alias="total") + aggregation_query.count(alias="total") + with pytest.raises(BadRequest): + _do_fetch(aggregation_query) + + +def test_aggregation_query_with_nested_query_filtered( + aggregation_query_client, nested_query +): + query = nested_query + + query.add_filter("appearances", ">=", 20) + expected_matches = 6 + + # We expect 6, but allow the query to get 1 extra. 
+ entities = _do_fetch(query, limit=expected_matches + 1) + + assert len(entities) == expected_matches + + aggregation_query = aggregation_query_client.aggregation_query(query) + aggregation_query.count(alias="total") + result = _do_fetch(aggregation_query) + assert len(result) == 1 + + for r in result[0]: + assert r.alias == "total" + assert r.value == 6 + + +def test_aggregation_query_with_nested_query_multiple_filters( + aggregation_query_client, nested_query +): + query = nested_query + + query.add_filter("appearances", ">=", 26) + query = query.add_filter("family", "=", "Stark") + expected_matches = 4 + + # We expect 4, but allow the query to get 1 extra. + entities = _do_fetch(query, limit=expected_matches + 1) + + assert len(entities) == expected_matches + + aggregation_query = aggregation_query_client.aggregation_query(query) + aggregation_query.count(alias="total") + result = _do_fetch(aggregation_query) + assert len(result) == 1 + + for r in result[0]: + assert r.alias == "total" + assert r.value == 4 diff --git a/tests/unit/test__http.py b/tests/unit/test__http.py index a03397d5..f9e0a29f 100644 --- a/tests/unit/test__http.py +++ b/tests/unit/test__http.py @@ -557,6 +557,98 @@ def test_api_run_query_w_namespace_nonempty_result(): _run_query_helper(namespace=namespace, found=1) +def _run_aggregation_query_helper( + transaction=None, + retry=None, + timeout=None, +): + from google.cloud.datastore_v1.types import datastore as datastore_pb2 + from google.cloud.datastore_v1.types import entity as entity_pb2 + from google.cloud.datastore_v1.types import query as query_pb2 + from google.cloud.datastore_v1.types import aggregation_result + + project = "PROJECT" + kind = "Nonesuch" + query_pb = query_pb2.Query(kind=[query_pb2.KindExpression(name=kind)]) + + aggregation_query_pb = query_pb2.AggregationQuery() + aggregation_query_pb.nested_query = query_pb + count_aggregation = query_pb2.AggregationQuery.Aggregation() + count_aggregation.alias = "total" + 
aggregation_query_pb.aggregations.append(count_aggregation) + partition_kw = {"project_id": project} + + partition_id = entity_pb2.PartitionId(**partition_kw) + + options_kw = {} + + if transaction is not None: + options_kw["transaction"] = transaction + read_options = datastore_pb2.ReadOptions(**options_kw) + + batch_kw = { + "more_results": query_pb2.QueryResultBatch.MoreResultsType.NO_MORE_RESULTS, + } + rsp_pb = datastore_pb2.RunAggregationQueryResponse( + batch=aggregation_result.AggregationResultBatch(**batch_kw) + ) + + http = _make_requests_session( + [_make_response(content=rsp_pb._pb.SerializeToString())] + ) + client_info = _make_client_info() + client = mock.Mock( + _http=http, + _base_url="test.invalid", + _client_info=client_info, + spec=["_http", "_base_url", "_client_info"], + ) + ds_api = _make_http_datastore_api(client) + request = { + "project_id": project, + "partition_id": partition_id, + "read_options": read_options, + "aggregation_query": aggregation_query_pb, + } + kwargs = _retry_timeout_kw(retry, timeout, http) + + response = ds_api.run_aggregation_query(request=request, **kwargs) + + assert response == rsp_pb._pb + + uri = _build_expected_url(client._base_url, project, "runAggregationQuery") + request = _verify_protobuf_call( + http, + uri, + datastore_pb2.RunAggregationQueryRequest(), + retry=retry, + timeout=timeout, + ) + + assert request.partition_id == partition_id._pb + assert request.aggregation_query == aggregation_query_pb._pb + assert request.read_options == read_options._pb + + +def test_api_run_aggregation_query_simple(): + _run_aggregation_query_helper() + + +def test_api_run_aggregation_query_w_retry(): + retry = mock.MagicMock() + _run_aggregation_query_helper(retry=retry) + + +def test_api_run_aggregation_query_w_timeout(): + timeout = 5.0 + _run_aggregation_query_helper(timeout=timeout) + + +def test_api_run_aggregation_query_w_transaction(): + transaction = b"TRANSACTION" + 
_run_aggregation_query_helper(transaction=transaction) + + def _begin_transaction_helper(options=None, retry=None, timeout=None): from google.cloud.datastore_v1.types import datastore as datastore_pb2 diff --git a/tests/unit/test_aggregation.py b/tests/unit/test_aggregation.py new file mode 100644 index 00000000..8b28a908 --- /dev/null +++ b/tests/unit/test_aggregation.py @@ -0,0 +1,415 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import mock +import pytest + +from google.cloud.datastore.aggregation import CountAggregation, AggregationQuery + +from tests.unit.test_query import _make_query, _make_client + +_PROJECT = "PROJECT" + + +def test_count_aggregation_to_pb(): + from google.cloud.datastore_v1.types import query as query_pb2 + + count_aggregation = CountAggregation(alias="total") + + expected_aggregation_query_pb = query_pb2.AggregationQuery.Aggregation() + expected_aggregation_query_pb.count = query_pb2.AggregationQuery.Aggregation.Count() + expected_aggregation_query_pb.alias = count_aggregation.alias + assert count_aggregation._to_pb() == expected_aggregation_query_pb + + +@pytest.fixture +def client(): + return _make_client() + + +def test_pb_over_query(client): + from google.cloud.datastore.query import _pb_from_query + + query = _make_query(client) + aggregation_query = _make_aggregation_query(client=client, query=query) + pb = aggregation_query._to_pb() + assert pb.nested_query == _pb_from_query(query) + assert 
pb.aggregations == []


def test_pb_over_query_with_count(client):
    # count() should add exactly one CountAggregation over the nested query.
    from google.cloud.datastore.query import _pb_from_query

    query = _make_query(client)
    aggregation_query = _make_aggregation_query(client=client, query=query)

    aggregation_query.count(alias="total")
    pb = aggregation_query._to_pb()
    assert pb.nested_query == _pb_from_query(query)
    assert len(pb.aggregations) == 1
    assert pb.aggregations[0] == CountAggregation(alias="total")._to_pb()


def test_pb_over_query_with_add_aggregation(client):
    # add_aggregation() with a single aggregation mirrors count().
    from google.cloud.datastore.query import _pb_from_query

    query = _make_query(client)
    aggregation_query = _make_aggregation_query(client=client, query=query)

    aggregation_query.add_aggregation(CountAggregation(alias="total"))
    pb = aggregation_query._to_pb()
    assert pb.nested_query == _pb_from_query(query)
    assert len(pb.aggregations) == 1
    assert pb.aggregations[0] == CountAggregation(alias="total")._to_pb()


def test_pb_over_query_with_add_aggregations(client):
    # add_aggregations() appends every aggregation, preserving order.
    from google.cloud.datastore.query import _pb_from_query

    aggregations = [
        CountAggregation(alias="total"),
        CountAggregation(alias="all"),
    ]

    query = _make_query(client)
    aggregation_query = _make_aggregation_query(client=client, query=query)

    aggregation_query.add_aggregations(aggregations)
    pb = aggregation_query._to_pb()
    assert pb.nested_query == _pb_from_query(query)
    assert len(pb.aggregations) == 2
    assert pb.aggregations[0] == CountAggregation(alias="total")._to_pb()
    assert pb.aggregations[1] == CountAggregation(alias="all")._to_pb()


def test_query_fetch_defaults_w_client_attr(client):
    # fetch() with no arguments falls back to the query's own client and
    # leaves retry/timeout unset.
    from google.cloud.datastore.aggregation import AggregationResultIterator

    query = _make_query(client)
    aggregation_query = _make_aggregation_query(client=client, query=query)
    iterator = aggregation_query.fetch()

    assert isinstance(iterator, AggregationResultIterator)
    assert iterator._aggregation_query is aggregation_query
    assert iterator.client is client
    assert iterator._retry is None
    assert iterator._timeout is None


def test_query_fetch_w_explicit_client_w_retry_w_timeout(client):
    # Explicit client/retry/timeout passed to fetch() win over the defaults.
    from google.cloud.datastore.aggregation import AggregationResultIterator

    other_client = _make_client()
    query = _make_query(client)
    aggregation_query = _make_aggregation_query(client=client, query=query)
    retry = mock.Mock()
    timeout = 100000

    iterator = aggregation_query.fetch(
        client=other_client, retry=retry, timeout=timeout
    )

    assert isinstance(iterator, AggregationResultIterator)
    assert iterator._aggregation_query is aggregation_query
    assert iterator.client is other_client
    assert iterator._retry == retry
    assert iterator._timeout == timeout


def test_iterator_constructor_defaults():
    # A freshly constructed iterator has not started and has empty counters.
    query = object()
    client = object()
    aggregation_query = AggregationQuery(client=client, query=query)
    iterator = _make_aggregation_iterator(aggregation_query, client)

    assert not iterator._started
    assert iterator.client is client
    assert iterator.page_number == 0
    assert iterator.num_results == 0
    assert iterator._aggregation_query is aggregation_query
    assert iterator._more_results
    assert iterator._retry is None
    assert iterator._timeout is None


def test_iterator_constructor_explicit():
    # retry/timeout supplied at construction are stored verbatim.
    query = object()
    client = object()
    aggregation_query = AggregationQuery(client=client, query=query)
    retry = mock.Mock()
    timeout = 100000

    iterator = _make_aggregation_iterator(
        aggregation_query,
        client,
        retry=retry,
        timeout=timeout,
    )

    assert not iterator._started
    assert iterator.client is client
    assert iterator.page_number == 0
    assert iterator.num_results == 0
    assert iterator._aggregation_query is aggregation_query
    assert iterator._more_results
    assert iterator._retry == retry
    assert iterator._timeout == timeout


def test_iterator__build_protobuf_empty():
    # With no aggregations added, the protobuf is an empty AggregationQuery
    # wrapping an empty nested Query.
    from google.cloud.datastore_v1.types import query as query_pb2

    client = _Client(None)
    query = _make_query(client)
    aggregation_query = AggregationQuery(client=client, query=query)
    iterator = _make_aggregation_iterator(aggregation_query, client)

    pb = iterator._build_protobuf()
    expected_pb = query_pb2.AggregationQuery()
    expected_pb.nested_query = query_pb2.Query()
    assert pb == expected_pb


def test_iterator__build_protobuf_all_values():
    from google.cloud.datastore_v1.types import query as query_pb2

    client = _Client(None)
    query = _make_query(client)
    aggregation_query = AggregationQuery(client=client, query=query)

    iterator = _make_aggregation_iterator(aggregation_query, client)
    # NOTE(review): num_results is not reflected in the generated protobuf,
    # so the expectation is identical to the empty case above.
    iterator.num_results = 4

    pb = iterator._build_protobuf()
    expected_pb = query_pb2.AggregationQuery()
    expected_pb.nested_query = query_pb2.Query()
    assert pb == expected_pb


def test_iterator__process_query_results():
    # NOT_FINISHED keeps the iterator's _more_results flag set.
    from google.cloud.datastore_v1.types import query as query_pb2
    from google.cloud.datastore.aggregation import AggregationResult

    iterator = _make_aggregation_iterator(None, None)

    aggregation_pbs = [AggregationResult(alias="total", value=1)]

    more_results_enum = query_pb2.QueryResultBatch.MoreResultsType.NOT_FINISHED
    response_pb = _make_aggregation_query_response(aggregation_pbs, more_results_enum)
    result = iterator._process_query_results(response_pb)
    assert result == [
        r.aggregate_properties for r in response_pb.batch.aggregation_results
    ]
    assert iterator._more_results


def test_iterator__process_query_results_finished_result():
    # NO_MORE_RESULTS clears the iterator's _more_results flag.
    from google.cloud.datastore_v1.types import query as query_pb2
    from google.cloud.datastore.aggregation import AggregationResult

    iterator = _make_aggregation_iterator(None, None)

    aggregation_pbs = [AggregationResult(alias="total", value=1)]

    more_results_enum = query_pb2.QueryResultBatch.MoreResultsType.NO_MORE_RESULTS
    response_pb = _make_aggregation_query_response(aggregation_pbs, more_results_enum)
    result = iterator._process_query_results(response_pb)
    assert result == [
        r.aggregate_properties for r in response_pb.batch.aggregation_results
    ]
    assert iterator._more_results is False


def test_iterator__process_query_results_unexpected_result():
    # An unspecified more_results enum is rejected with ValueError.
    from google.cloud.datastore_v1.types import query as query_pb2
    from google.cloud.datastore.aggregation import AggregationResult

    iterator = _make_aggregation_iterator(None, None)

    aggregation_pbs = [AggregationResult(alias="total", value=1)]

    more_results_enum = (
        query_pb2.QueryResultBatch.MoreResultsType.MORE_RESULTS_TYPE_UNSPECIFIED
    )
    response_pb = _make_aggregation_query_response(aggregation_pbs, more_results_enum)
    with pytest.raises(ValueError):
        iterator._process_query_results(response_pb)


def test_aggregation_iterator__next_page():
    _next_page_helper()


def test_iterator__next_page_w_retry():
    retry = mock.Mock()
    _next_page_helper(retry=retry)


def test_iterator__next_page_w_timeout():
    _next_page_helper(timeout=100000)


def test_iterator__next_page_in_transaction():
    txn_id = b"1xo1md\xe2\x98\x83"
    _next_page_helper(txn_id=txn_id)


def test_iterator__next_page_no_more():
    # Once _more_results is False, _next_page() returns None without any RPC.
    from google.cloud.datastore.query import Query

    ds_api = _make_datastore_api_for_aggregation()
    client = _Client(None, datastore_api=ds_api)
    query = Query(client)

    iterator = _make_aggregation_iterator(query, client)
    iterator._more_results = False
    page = iterator._next_page()
    assert page is None
    ds_api.run_aggregation_query.assert_not_called()


def _next_page_helper(txn_id=None, retry=None, timeout=None):
    """Drive two pages through the iterator and verify both RPC invocations.

    ``txn_id``, ``retry``, and ``timeout`` are optional knobs so the wrapper
    tests above can exercise each variation with the same plumbing.
    """
    from google.api_core import page_iterator
    from google.cloud.datastore_v1.types import datastore as datastore_pb2
    from google.cloud.datastore_v1.types import entity as entity_pb2
    from google.cloud.datastore_v1.types import query as query_pb2
    from google.cloud.datastore.aggregation import AggregationResult

    more_enum = query_pb2.QueryResultBatch.MoreResultsType.NOT_FINISHED
    aggregation_pbs = [AggregationResult(alias="total", value=1)]

    # First response says "not finished"; second terminates the iteration.
    result_1 = _make_aggregation_query_response([], more_enum)
    result_2 = _make_aggregation_query_response(
        aggregation_pbs, query_pb2.QueryResultBatch.MoreResultsType.NO_MORE_RESULTS
    )

    project = "prujekt"
    ds_api = _make_datastore_api_for_aggregation(result_1, result_2)
    if txn_id is None:
        client = _Client(project, datastore_api=ds_api)
    else:
        transaction = mock.Mock(id=txn_id, spec=["id"])
        client = _Client(project, datastore_api=ds_api, transaction=transaction)

    query = _make_query(client)
    kwargs = {}

    if retry is not None:
        kwargs["retry"] = retry

    if timeout is not None:
        kwargs["timeout"] = timeout

    it_kwargs = kwargs.copy()  # so it doesn't get overwritten later

    aggregation_query = AggregationQuery(client=client, query=query)

    iterator = _make_aggregation_iterator(aggregation_query, client, **it_kwargs)
    page = iterator._next_page()

    assert isinstance(page, page_iterator.Page)
    assert page._parent is iterator

    partition_id = entity_pb2.PartitionId(project_id=project)
    if txn_id is not None:
        read_options = datastore_pb2.ReadOptions(transaction=txn_id)
    else:
        read_options = datastore_pb2.ReadOptions()

    aggregation_query = AggregationQuery(client=client, query=query)
    assert ds_api.run_aggregation_query.call_count == 2
    expected_call = mock.call(
        request={
            "project_id": project,
            "partition_id": partition_id,
            "read_options": read_options,
            "aggregation_query": aggregation_query._to_pb(),
        },
        **kwargs
    )
    assert ds_api.run_aggregation_query.call_args_list == (
        [expected_call, expected_call]
    )


def test__item_to_aggregation_result():
    from google.cloud.datastore.aggregation import _item_to_aggregation_result
    from google.cloud.datastore.aggregation import AggregationResult

    with mock.patch(
        "proto.marshal.collections.maps.MapComposite"
    ) as map_composite_mock:
        map_composite_mock.keys.return_value = {"total": {"integer_value": 1}}

        result = _item_to_aggregation_result(None, map_composite_mock)

        assert len(result) == 1
        # ``is`` pins the exact type; ``type(x) == T`` is the anti-idiom.
        assert type(result[0]) is AggregationResult

        assert result[0].alias == "total"
        assert result[0].value == map_composite_mock.__getitem__().integer_value


class _Client(object):
    """Minimal stand-in for ``datastore.Client`` used by these tests."""

    def __init__(self, project, datastore_api=None, namespace=None, transaction=None):
        self.project = project
        self._datastore_api = datastore_api
        self.namespace = namespace
        self._transaction = transaction

    @property
    def current_transaction(self):
        # Mirrors the real client's read-only accessor.
        return self._transaction


def _make_aggregation_query(*args, **kw):
    """Construct an ``AggregationQuery`` (import deferred to call time)."""
    from google.cloud.datastore.aggregation import AggregationQuery

    return AggregationQuery(*args, **kw)


def _make_aggregation_iterator(*args, **kw):
    """Construct an ``AggregationResultIterator`` (import deferred)."""
    from google.cloud.datastore.aggregation import AggregationResultIterator

    return AggregationResultIterator(*args, **kw)


def _make_aggregation_query_response(aggregation_pbs, more_results_enum):
    """Build a ``RunAggregationQueryResponse`` from alias/value results."""
    from google.cloud.datastore_v1.types import datastore as datastore_pb2
    from google.cloud.datastore_v1.types import aggregation_result

    aggregation_results = []
    for aggr in aggregation_pbs:
        result = aggregation_result.AggregationResult()
        result.aggregate_properties.alias = aggr.alias
        result.aggregate_properties.value = aggr.value
        aggregation_results.append(result)

    return datastore_pb2.RunAggregationQueryResponse(
        batch=aggregation_result.AggregationResultBatch(
            aggregation_results=aggregation_results,
            more_results=more_results_enum,
        )
    )


def _make_datastore_api_for_aggregation(*results):
    """Mock datastore API whose ``run_aggregation_query`` yields *results*."""
    if len(results) == 0:
        run_aggregation_query = mock.Mock(return_value=None, spec=[])
    else:
        run_aggregation_query = mock.Mock(side_effect=results, spec=[])

    return mock.Mock(
        run_aggregation_query=run_aggregation_query, spec=["run_aggregation_query"]
    )


diff --git 
a/tests/unit/test_client.py b/tests/unit/test_client.py index 2a15677a..3e35f74e 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -1515,6 +1515,42 @@ def test_client_query_w_namespace_collision(): ) +def test_client_aggregation_query_w_defaults(): + creds = _make_credentials() + client = _make_client(credentials=creds) + query = client.query() + patch = mock.patch( + "google.cloud.datastore.client.AggregationQuery", spec=["__call__"] + ) + with patch as mock_klass: + aggregation_query = client.aggregation_query(query=query) + assert aggregation_query is mock_klass.return_value + mock_klass.assert_called_once_with(client, query) + + +def test_client_aggregation_query_w_namespace(): + namespace = object() + + creds = _make_credentials() + client = _make_client(namespace=namespace, credentials=creds) + query = client.query() + + aggregation_query = client.aggregation_query(query=query) + assert aggregation_query.namespace == namespace + + +def test_client_aggregation_query_w_namespace_collision(): + namespace1 = object() + namespace2 = object() + + creds = _make_credentials() + client = _make_client(namespace=namespace1, credentials=creds) + query = client.query(namespace=namespace2) + + aggregation_query = client.aggregation_query(query=query) + assert aggregation_query.namespace == namespace2 + + def test_client_reserve_ids_multi_w_partial_key(): incomplete_key = _Key(_Key.kind, None) creds = _make_credentials()