Bring coverage up to 100%
Mariatta committed Oct 17, 2022
1 parent 16ebb21 commit 4bf8c04
Showing 3 changed files with 172 additions and 13 deletions.
17 changes: 8 additions & 9 deletions google/cloud/datastore/aggregation.py
@@ -44,7 +44,6 @@ def _to_pb(self):
        """
        Convert this instance to the protobuf representation
        """
-        raise NotImplementedError


class CountAggregation(BaseAggregation):
@@ -302,7 +301,7 @@ def __init__(
    ):
        super(AggregationResultIterator, self).__init__(
            client=client,
-            item_to_value=_item_to_pb,
+            item_to_value=_item_to_aggregation_result,
        )

        self._aggregation_query = aggregation_query
@@ -319,8 +318,8 @@ def _build_protobuf(self):
        Relies on the current state of the iterator.
        :rtype:
            :class:`.query_pb2.Query`
-        :returns: The query protobuf object for the current
-            :class:`.query_pb2.AggregationQuery.Aggregation`
+        :returns: The aggregation_query protobuf object for the current
+            state of the iterator.
        """
        pb = self._aggregation_query._to_pb()
@@ -417,18 +416,18 @@ def _next_page(self):


# pylint: disable=unused-argument
-def _item_to_pb(iterator, pb):
+def _item_to_aggregation_result(iterator, pb):
    """Convert a raw protobuf aggregation result to the native object.
    :type iterator: :class:`~google.api_core.page_iterator.Iterator`
    :param iterator: The iterator that is currently in use.
    :type pb:
-        :class:`.entity_pb2.Entity`
-    :param entity_pb: An entity protobuf to convert to a native entity.
+        :class:`proto.marshal.collections.maps.MapComposite`
+    :param pb: The aggregation properties pb from the aggregation query result
-    :rtype: :class:`~proto.marshal.collections.maps.MapComposite`
-    :returns: The next entity in the page.
+    :rtype: :class:`google.cloud.datastore.aggregation.AggregationResult`
+    :returns: The list of AggregationResults
    """
    results = [AggregationResult(alias=k, value=pb[k].integer_value) for k in pb.keys()]
    return results
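For readers skimming the diff, the renamed converter does nothing more than unpack a protobuf map of aggregation values into AggregationResult objects. Below is a minimal sketch, not part of the commit, using a plain dict with a stand-in value object in place of the MapComposite of Value protos the iterator actually passes in, and assuming AggregationResult exposes the alias and value it was built with.

# Illustrative sketch only (not from this commit): the conversion performed by
# _item_to_aggregation_result, with a plain dict standing in for the
# aggregate_properties MapComposite of datastore Value protos.
from types import SimpleNamespace

from google.cloud.datastore.aggregation import AggregationResult

fake_pb = {"total": SimpleNamespace(integer_value=42)}  # stand-in for aggregate_properties

results = [
    AggregationResult(alias=alias, value=value.integer_value)
    for alias, value in fake_pb.items()
]
for result in results:
    print(result.alias, result.value)  # -> total 42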
145 changes: 141 additions & 4 deletions tests/unit/test_aggregation.py
@@ -169,10 +169,9 @@ def test_iterator_constructor_explicit():

def test_iterator__build_protobuf_empty():
    from google.cloud.datastore_v1.types import query as query_pb2
-    from google.cloud.datastore.query import Query

    client = _Client(None)
-    query = Query(client)
+    query = _make_query(client)
    aggregation_query = AggregationQuery(client=client, query=query)
    iterator = _make_aggregation_iterator(aggregation_query, client)

@@ -184,10 +183,9 @@ def test_iterator__build_protobuf_empty():

def test_iterator__build_protobuf_all_values():
    from google.cloud.datastore_v1.types import query as query_pb2
-    from google.cloud.datastore.query import Query

    client = _Client(None)
-    query = Query(client)
+    query = _make_query(client)
    aggregation_query = AggregationQuery(client=client, query=query)

    iterator = _make_aggregation_iterator(aggregation_query, client)
@@ -216,6 +214,145 @@ def test_iterator__process_query_results():
    assert iterator._more_results


+def test_iterator__process_query_results_finished_result():
+    from google.cloud.datastore_v1.types import query as query_pb2
+    from google.cloud.datastore.aggregation import AggregationResult
+
+    iterator = _make_aggregation_iterator(None, None)
+
+    aggregation_pbs = [AggregationResult(alias="total", value=1)]
+
+    more_results_enum = query_pb2.QueryResultBatch.MoreResultsType.NO_MORE_RESULTS
+    response_pb = _make_aggregation_query_response(aggregation_pbs, more_results_enum)
+    result = iterator._process_query_results(response_pb)
+    assert result == [
+        r.aggregate_properties for r in response_pb.batch.aggregation_results
+    ]
+    assert iterator._more_results is False
+
+
+def test_iterator__process_query_results_unexpected_result():
+    from google.cloud.datastore_v1.types import query as query_pb2
+    from google.cloud.datastore.aggregation import AggregationResult
+
+    iterator = _make_aggregation_iterator(None, None)
+
+    aggregation_pbs = [AggregationResult(alias="total", value=1)]
+
+    more_results_enum = (
+        query_pb2.QueryResultBatch.MoreResultsType.MORE_RESULTS_TYPE_UNSPECIFIED
+    )
+    response_pb = _make_aggregation_query_response(aggregation_pbs, more_results_enum)
+    with pytest.raises(ValueError):
+        iterator._process_query_results(response_pb)
+
+
+def test_aggregation_iterator__next_page():
+    _next_page_helper()
+
+
+def test_iterator__next_page_w_retry():
+    retry = mock.Mock()
+    _next_page_helper(retry=retry)
+
+
+def test_iterator__next_page_w_timeout():
+    _next_page_helper(timeout=100000)
+
+
+def test_iterator__next_page_in_transaction():
+    txn_id = b"1xo1md\xe2\x98\x83"
+    _next_page_helper(txn_id=txn_id)
+
+
+def _next_page_helper(txn_id=None, retry=None, timeout=None):
+    from google.api_core import page_iterator
+    from google.cloud.datastore.query import Query
+    from google.cloud.datastore_v1.types import datastore as datastore_pb2
+    from google.cloud.datastore_v1.types import entity as entity_pb2
+    from google.cloud.datastore_v1.types import query as query_pb2
+    from google.cloud.datastore.aggregation import AggregationResult
+
+    more_enum = query_pb2.QueryResultBatch.MoreResultsType.NOT_FINISHED
+    aggregation_pbs = [AggregationResult(alias="total", value=1)]
+
+    result_1 = _make_aggregation_query_response([], more_enum)
+    result_2 = _make_aggregation_query_response(
+        aggregation_pbs, query_pb2.QueryResultBatch.MoreResultsType.NO_MORE_RESULTS
+    )
+
+    project = "prujekt"
+    ds_api = _make_datastore_api_for_aggregation(result_1, result_2)
+    if txn_id is None:
+        client = _Client(project, datastore_api=ds_api)
+    else:
+        transaction = mock.Mock(id=txn_id, spec=["id"])
+        client = _Client(project, datastore_api=ds_api, transaction=transaction)
+
+    query = _make_query(client)
+    kwargs = {}
+
+    if retry is not None:
+        kwargs["retry"] = retry
+
+    if timeout is not None:
+        kwargs["timeout"] = timeout
+
+    it_kwargs = kwargs.copy()  # so it doesn't get overwritten later
+
+    aggregation_query = AggregationQuery(client=client, query=query)
+
+    iterator = _make_aggregation_iterator(aggregation_query, client, **it_kwargs)
+    page = iterator._next_page()
+
+    assert isinstance(page, page_iterator.Page)
+    assert page._parent is iterator
+
+    partition_id = entity_pb2.PartitionId(project_id=project)
+    if txn_id is not None:
+        read_options = datastore_pb2.ReadOptions(transaction=txn_id)
+    else:
+        read_options = datastore_pb2.ReadOptions()
+
+    aggregation_query = AggregationQuery(client=client, query=query)
+    assert ds_api.run_aggregation_query.call_count == 2
+    expected_call = mock.call(
+        request={
+            "project_id": project,
+            "partition_id": partition_id,
+            "read_options": read_options,
+            "aggregation_query": aggregation_query._to_pb(),
+        },
+        **kwargs
+    )
+    assert ds_api.run_aggregation_query.call_args_list == (
+        [expected_call, expected_call]
+    )
+
+
+def test__item_to_aggregation_result():
+    from google.cloud.datastore.aggregation import _item_to_aggregation_result
+    from google.cloud.datastore_v1.types import query as query_pb2
+    from google.cloud.datastore.aggregation import AggregationResult
+
+    client = _Client(_PROJECT)
+    query = _make_query(client)
+
+    aggregation_query = AggregationQuery(client=client, query=query)
+
+    iterator = _make_aggregation_iterator(aggregation_query, client)
+
+    aggregation_pbs = [AggregationResult(alias="total", value=1)]
+
+    more_results_enum = query_pb2.QueryResultBatch.MoreResultsType.NOT_FINISHED
+    response_pb = _make_aggregation_query_response(aggregation_pbs, more_results_enum)
+    aggregation_result = iterator._process_query_results(response_pb)[0]
+
+    result = _item_to_aggregation_result(iterator, aggregation_result)
+
+    assert result is not None
+
+
class _Client(object):
    def __init__(self, project, datastore_api=None, namespace=None, transaction=None):
        self.project = project
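The tests above rely on a _make_aggregation_query_response helper that is defined elsewhere in this test module and not shown in the diff. A hedged sketch of the kind of fake response it presumably assembles follows; the batch, aggregation_results, and aggregate_properties names come from the assertions above, while RunAggregationQueryResponse and AggregationResultBatch are assumed proto message names and the real helper may differ.

# Hypothetical sketch of the response _make_aggregation_query_response builds;
# the actual helper lives elsewhere in the test module and may differ.
from google.cloud.datastore_v1.types import datastore as datastore_pb2
from google.cloud.datastore_v1.types import entity as entity_pb2
from google.cloud.datastore_v1.types import query as query_pb2


def make_fake_aggregation_response(aggregation_pbs, more_results_enum):
    # Pack each native AggregationResult(alias=..., value=...) into the
    # aggregate_properties map of a single proto AggregationResult.
    aggregation_results = [
        query_pb2.AggregationResult(
            aggregate_properties={
                pb.alias: entity_pb2.Value(integer_value=pb.value)
                for pb in aggregation_pbs
            }
        )
    ]
    return datastore_pb2.RunAggregationQueryResponse(
        batch=query_pb2.AggregationResultBatch(
            aggregation_results=aggregation_results,
            more_results=more_results_enum,
        )
    )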
23 changes: 23 additions & 0 deletions tests/unit/test_client.py
@@ -1528,6 +1528,29 @@ def test_client_aggregation_query_w_defaults():
    mock_klass.assert_called_once_with(client, query)


+def test_client_aggregation_query_w_namespace():
+    namespace = object()
+
+    creds = _make_credentials()
+    client = _make_client(namespace=namespace, credentials=creds)
+    query = client.query()
+
+    aggregation_query = client.aggregation_query(query=query)
+    assert aggregation_query.namespace == namespace
+
+
+def test_client_aggregation_query_w_namespace_collision():
+    namespace1 = object()
+    namespace2 = object()
+
+    creds = _make_credentials()
+    client = _make_client(namespace=namespace1, credentials=creds)
+    query = client.query(namespace=namespace2)
+
+    aggregation_query = client.aggregation_query(query=query)
+    assert aggregation_query.namespace == namespace2
+
+
def test_client_reserve_ids_multi_w_partial_key():
    incomplete_key = _Key(_Key.kind, None)
    creds = _make_credentials()
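Taken together, the two new client tests pin down namespace precedence: a namespace set on the underlying query wins over the client's default. Below is a hedged end-to-end sketch of that behaviour, not part of the commit, assuming the public count()/fetch() methods on AggregationQuery and using made-up project, namespace, and kind values.

# Illustrative usage sketch, not from this commit. Assumes the public
# AggregationQuery.count()/fetch() API; project/kind/namespace names are made up.
from google.cloud import datastore

client = datastore.Client(project="my-project", namespace="client-ns")

# The query's namespace takes precedence over the client's default namespace.
query = client.query(kind="Task", namespace="query-ns")
aggregation_query = client.aggregation_query(query=query)
assert aggregation_query.namespace == "query-ns"

aggregation_query.count(alias="total")
for result_row in aggregation_query.fetch():
    # Each fetched item is a list of AggregationResult objects
    # (see _item_to_aggregation_result above).
    for aggregation in result_row:
        print(aggregation.alias, aggregation.value)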
