Skip to content

Commit

Permalink
feat: allow exceptions to be included in batch responses (#1043)
Browse files Browse the repository at this point in the history
* feat: allow exceptions to be included in batch responses

* fix docstring

* address comments and update tests

* more tests
  • Loading branch information
cojenco committed May 31, 2023
1 parent f4d8637 commit 94a35ba
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 8 deletions.
39 changes: 33 additions & 6 deletions google/cloud/storage/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,18 +133,27 @@ class Batch(Connection):
:type client: :class:`google.cloud.storage.client.Client`
:param client: The client to use for making connections.
:type raise_exception: bool
:param raise_exception:
(Optional) Defaults to True. If True, instead of adding exceptions
to the list of return responses, the final exception will be raised.
Note that exceptions are unwrapped only after all operations have completed,
whether they succeeded or failed, and only the last exception is raised.
"""

_MAX_BATCH_SIZE = 1000

def __init__(self, client):
def __init__(self, client, raise_exception=True):
api_endpoint = client._connection.API_BASE_URL
client_info = client._connection._client_info
super(Batch, self).__init__(
client, client_info=client_info, api_endpoint=api_endpoint
)
self._requests = []
self._target_objects = []
self._responses = []
self._raise_exception = raise_exception

def _do_request(
self, method, url, headers, data, target_object, timeout=_DEFAULT_TIMEOUT
Expand Down Expand Up @@ -219,24 +228,34 @@ def _prepare_batch_request(self):
_, body = payload.split("\n\n", 1)
return dict(multi._headers), body, timeout

def _finish_futures(self, responses):
def _finish_futures(self, responses, raise_exception=True):
"""Apply all the batch responses to the futures created.
:type responses: list of (headers, payload) tuples.
:param responses: List of headers and payloads from each response in
the batch.
:type raise_exception: bool
:param raise_exception:
(Optional) Defaults to True. If True, instead of adding exceptions
to the list of return responses, the final exception will be raised.
Note that exceptions are unwrapped only after all operations have completed,
whether they succeeded or failed, and only the last exception is raised.
:raises: :class:`ValueError` if no requests have been deferred.
"""
# If a bad status occurs, we track it, but don't raise an exception
# until all futures have been populated.
# If raise_exception=False, we add exceptions to the list of responses.
exception_args = None

if len(self._target_objects) != len(responses): # pragma: NO COVER
raise ValueError("Expected a response for every request.")

for target_object, subresponse in zip(self._target_objects, responses):
if not 200 <= subresponse.status_code < 300:
# For backwards compatibility, only the final exception will be raised.
# Set raise_exception=False to include all exceptions in the list of return responses.
if not 200 <= subresponse.status_code < 300 and raise_exception:
exception_args = exception_args or subresponse
elif target_object is not None:
try:
Expand All @@ -247,9 +266,16 @@ def _finish_futures(self, responses):
if exception_args is not None:
raise exceptions.from_http_response(exception_args)

def finish(self):
def finish(self, raise_exception=True):
"""Submit a single `multipart/mixed` request with deferred requests.
:type raise_exception: bool
:param raise_exception:
(Optional) Defaults to True. If True, instead of adding exceptions
to the list of return responses, the final exception will be raised.
Note that exceptions are unwrapped only after all operations have completed,
whether they succeeded or failed, and only the last exception is raised.
:rtype: list of tuples
:returns: one ``(headers, payload)`` tuple per deferred request.
"""
Expand All @@ -269,7 +295,8 @@ def finish(self):
raise exceptions.from_http_response(response)

responses = list(_unpack_batch_response(response))
self._finish_futures(responses)
self._finish_futures(responses, raise_exception=raise_exception)
self._responses = responses
return responses

def current(self):
Expand All @@ -283,7 +310,7 @@ def __enter__(self):
def __exit__(self, exc_type, exc_val, exc_tb):
try:
if exc_type is None:
self.finish()
self.finish(raise_exception=self._raise_exception)
finally:
self._client._pop_batch()

Expand Down
11 changes: 9 additions & 2 deletions google/cloud/storage/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,17 +307,24 @@ def bucket(self, bucket_name, user_project=None):
"""
return Bucket(client=self, name=bucket_name, user_project=user_project)

def batch(self):
def batch(self, raise_exception=True):
"""Factory constructor for batch object.
.. note::
This will not make an HTTP request; it simply instantiates
a batch object owned by this client.
:type raise_exception: bool
:param raise_exception:
(Optional) Defaults to True. If True, instead of adding exceptions
to the list of return responses, the final exception will be raised.
Note that exceptions are unwrapped only after all operations have completed,
whether they succeeded or failed, and only the last exception is raised.
:rtype: :class:`google.cloud.storage.batch.Batch`
:returns: The batch object created.
"""
return Batch(client=self)
return Batch(client=self, raise_exception=raise_exception)

def _get_resource(
self,
Expand Down
88 changes: 88 additions & 0 deletions tests/unit/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ def test_finish_nonempty(self):
result = batch.finish()

self.assertEqual(len(result), len(batch._requests))
self.assertEqual(len(result), len(batch._responses))

response1, response2, response3 = result

Expand Down Expand Up @@ -438,6 +439,55 @@ def test_finish_nonempty_with_status_failure(self):
self._check_subrequest_payload(chunks[0], "GET", url, {})
self._check_subrequest_payload(chunks[1], "GET", url, {})

def test_finish_no_raise_exception(self):
    """``finish(raise_exception=False)`` keeps failed sub-responses instead of raising.

    Two GET requests are deferred against a canned two-part batch reply in
    which the second part is a 404. The test checks that both responses are
    stored on the batch, that each target receives its own payload (including
    the error payload), and that a single multipart POST was issued.
    """
    endpoint = "http://api.example.com/other_api"
    reply_headers = {"content-type": 'multipart/mixed; boundary="DEADBEEF="'}
    batch_reply = _make_response(
        content=_TWO_PART_MIME_RESPONSE_WITH_FAIL,
        headers=reply_headers,
    )
    http = _make_requests_session([batch_reply])
    batch = self._make_one(_Client(_Connection(http=http)))
    batch.API_BASE_URL = "http://api.example.com"
    first_target = _MockObject()
    second_target = _MockObject()

    batch._do_request("GET", endpoint, {}, None, first_target, timeout=42)
    batch._do_request("GET", endpoint, {}, None, second_target, timeout=420)

    # Before finish() the batch should simply hold the targets, in request order.
    self.assertEqual(list(batch._target_objects), [first_target, second_target])

    batch.finish(raise_exception=False)

    self.assertEqual(len(batch._requests), 2)
    self.assertEqual(len(batch._responses), 2)

    # The successful payload lands on the first target, the Not Found error
    # payload on the second one.
    self.assertEqual(first_target._properties, {"foo": 1, "bar": 2})
    self.assertEqual(
        second_target._properties, {"error": {"message": "Not Found"}}
    )

    expected_url = f"{batch.API_BASE_URL}/batch/storage/v1"
    http.request.assert_called_once_with(
        method="POST",
        url=expected_url,
        headers=mock.ANY,
        data=mock.ANY,
        timeout=420,  # the last request timeout prevails
    )

    _, request_body, _, boundary = self._get_mutlipart_request(http)

    chunks = self._get_payload_chunks(boundary, request_body)
    self.assertEqual(len(chunks), 2)
    for chunk in chunks:
        self._check_subrequest_payload(chunk, "GET", endpoint, {})

    # Both sub-responses are retained, with their individual status codes.
    for index, expected_status in enumerate((200, 404)):
        self.assertEqual(batch._responses[index].status_code, expected_status)

def test_finish_nonempty_non_multipart_response(self):
url = "http://api.example.com/other_api"
http = _make_requests_session([_make_response()])
Expand Down Expand Up @@ -497,6 +547,7 @@ def test_as_context_mgr_wo_error(self):

self.assertEqual(list(client._batch_stack), [])
self.assertEqual(len(batch._requests), 3)
self.assertEqual(len(batch._responses), 3)
self.assertEqual(batch._requests[0][0], "POST")
self.assertEqual(batch._requests[1][0], "PATCH")
self.assertEqual(batch._requests[2][0], "DELETE")
Expand All @@ -505,6 +556,43 @@ def test_as_context_mgr_wo_error(self):
self.assertEqual(target2._properties, {"foo": 1, "bar": 3})
self.assertEqual(target3._properties, b"")

def test_as_context_mgr_no_raise_exception(self):
    """A batch built with ``raise_exception=False`` exits its ``with`` block cleanly.

    Two GET requests are issued inside the context manager against a canned
    two-part batch reply whose second part is a 404. On exit the batch must be
    popped from the client's stack, and both the successful and the failed
    sub-response must be stored and applied to their targets.
    """
    from google.cloud.storage.client import Client

    endpoint = "http://api.example.com/other_api"
    reply_headers = {"content-type": 'multipart/mixed; boundary="DEADBEEF="'}
    batch_reply = _make_response(
        content=_TWO_PART_MIME_RESPONSE_WITH_FAIL,
        headers=reply_headers,
    )
    http = _make_requests_session([batch_reply])
    project = "PROJECT"
    client = Client(project=project, credentials=_make_credentials())
    client._http_internal = http

    self.assertEqual(list(client._batch_stack), [])

    first_target = _MockObject()
    second_target = _MockObject()

    with self._make_one(client, raise_exception=False) as batch:
        # While the block is active, the batch is the only entry on the stack.
        self.assertEqual(list(client._batch_stack), [batch])
        for target in (first_target, second_target):
            batch._make_request("GET", endpoint, {}, target_object=target)

    self.assertEqual(list(client._batch_stack), [])
    self.assertEqual(len(batch._requests), 2)
    self.assertEqual(len(batch._responses), 2)
    for request in batch._requests:
        self.assertEqual(request[0], "GET")
    self.assertEqual(batch._target_objects, [first_target, second_target])

    # The 404 is kept in the responses and its error payload is applied to
    # the second target rather than being raised.
    for index, expected_status in enumerate((200, 404)):
        self.assertEqual(batch._responses[index].status_code, expected_status)
    self.assertEqual(first_target._properties, {"foo": 1, "bar": 2})
    self.assertEqual(
        second_target._properties, {"error": {"message": "Not Found"}}
    )

def test_as_context_mgr_w_error(self):
from google.cloud.storage.batch import _FutureDict
from google.cloud.storage.client import Client
Expand Down

0 comments on commit 94a35ba

Please sign in to comment.