Skip to content

Commit

Permalink
fix: add user agent in python-storage when calling resumable media (W…
Browse files Browse the repository at this point in the history
…IP) (#715)

* fix: add user agent in python-storage when calling resumable media (WIP)

* put the things in the right places

* starting to get some tests working

* a bit closer, still some things not quite working

* almost there

* first cleanup

* ensure up to date resumable media

* next round of cleanup

* lint

* update resumable media

* get tests passing

* lint
  • Loading branch information
Aaron Gabriel Neyer committed Mar 9, 2022
1 parent 4fbbf02 commit c7bf615
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 34 deletions.
22 changes: 22 additions & 0 deletions google/cloud/storage/_helpers.py
Expand Up @@ -581,3 +581,25 @@ def _api_core_retry_to_resumable_media_retry(retry, num_retries=None):
return resumable_media.RetryStrategy(max_retries=num_retries)
else:
return resumable_media.RetryStrategy(max_retries=0)


def _get_default_headers(
user_agent,
content_type="application/json; charset=UTF-8",
x_upload_content_type=None,
):
"""Get the headers for a request.
Args:
user_agent (str): The user-agent for requests.
Returns:
Dict: The headers to be used for the request.
"""
return {
"Accept": "application/json",
"Accept-Encoding": "gzip, deflate",
"User-Agent": user_agent,
"x-goog-api-client": user_agent,
"content-type": content_type,
"x-upload-content-type": x_upload_content_type or content_type,
}
16 changes: 9 additions & 7 deletions google/cloud/storage/blob.py
Expand Up @@ -64,6 +64,7 @@
from google.cloud.storage._helpers import _bucket_bound_hostname_url
from google.cloud.storage._helpers import _raise_if_more_than_one_set
from google.cloud.storage._helpers import _api_core_retry_to_resumable_media_retry
from google.cloud.storage._helpers import _get_default_headers
from google.cloud.storage._signing import generate_signed_url_v2
from google.cloud.storage._signing import generate_signed_url_v4
from google.cloud.storage._helpers import _NUM_RETRIES_MESSAGE
Expand Down Expand Up @@ -1720,7 +1721,7 @@ def _get_writable_metadata(self):

return object_metadata

def _get_upload_arguments(self, content_type):
def _get_upload_arguments(self, client, content_type):
"""Get required arguments for performing an upload.
The content type returned will be determined in order of precedence:
Expand All @@ -1739,9 +1740,12 @@ def _get_upload_arguments(self, content_type):
* An object metadata dictionary
* The ``content_type`` as a string (according to precedence)
"""
headers = _get_encryption_headers(self._encryption_key)
object_metadata = self._get_writable_metadata()
content_type = self._get_content_type(content_type)
headers = {
**_get_default_headers(client._connection.user_agent, content_type),
**_get_encryption_headers(self._encryption_key),
}
object_metadata = self._get_writable_metadata()
return headers, object_metadata, content_type

def _do_multipart_upload(
Expand Down Expand Up @@ -1860,7 +1864,7 @@ def _do_multipart_upload(
transport = self._get_transport(client)
if "metadata" in self._properties and "metadata" not in self._changes:
self._changes.add("metadata")
info = self._get_upload_arguments(content_type)
info = self._get_upload_arguments(client, content_type)
headers, object_metadata, content_type = info

hostname = _get_host_name(client._connection)
Expand Down Expand Up @@ -2045,7 +2049,7 @@ def _initiate_resumable_upload(
transport = self._get_transport(client)
if "metadata" in self._properties and "metadata" not in self._changes:
self._changes.add("metadata")
info = self._get_upload_arguments(content_type)
info = self._get_upload_arguments(client, content_type)
headers, object_metadata, content_type = info
if extra_headers is not None:
headers.update(extra_headers)
Expand Down Expand Up @@ -2230,15 +2234,13 @@ def _do_resumable_upload(
checksum=checksum,
retry=retry,
)

while not upload.finished:
try:
response = upload.transmit_next_chunk(transport, timeout=timeout)
except resumable_media.DataCorruption:
# Attempt to delete the corrupted object.
self.delete()
raise

return response

def _do_upload(
Expand Down
2 changes: 2 additions & 0 deletions google/cloud/storage/client.py
Expand Up @@ -31,6 +31,7 @@
from google.cloud._helpers import _LocalStack, _NOW
from google.cloud.client import ClientWithProject
from google.cloud.exceptions import NotFound
from google.cloud.storage._helpers import _get_default_headers
from google.cloud.storage._helpers import _get_environ_project
from google.cloud.storage._helpers import _get_storage_host
from google.cloud.storage._helpers import _BASE_STORAGE_URI
Expand Down Expand Up @@ -1131,6 +1132,7 @@ def download_blob_to_file(
_add_etag_match_headers(
headers, if_etag_match=if_etag_match, if_etag_not_match=if_etag_not_match,
)
headers = {**_get_default_headers(self._connection.user_agent), **headers}

transport = self._http
try:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -31,7 +31,7 @@
"google-auth >= 1.25.0, < 3.0dev",
"google-api-core >= 1.31.5, <3.0.0dev,!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0",
"google-cloud-core >= 1.6.0, < 3.0dev",
"google-resumable-media >= 1.3.0",
"google-resumable-media >= 2.3.2",
"requests >= 2.18.0, < 3.0.0dev",
"protobuf",
]
Expand Down
87 changes: 63 additions & 24 deletions tests/unit/test_blob.py
Expand Up @@ -26,6 +26,7 @@
import mock
import pytest

from google.cloud.storage._helpers import _get_default_headers
from google.cloud.storage.retry import (
DEFAULT_RETRY,
DEFAULT_RETRY_IF_METAGENERATION_SPECIFIED,
Expand Down Expand Up @@ -2212,16 +2213,19 @@ def test__set_metadata_to_none(self):
def test__get_upload_arguments(self):
name = u"blob-name"
key = b"[pXw@,p@@AfBfrR3x-2b2SCHR,.?YwRO"
client = mock.Mock(_connection=_Connection)
client._connection.user_agent = "testing 1.2.3"
blob = self._make_one(name, bucket=None, encryption_key=key)
blob.content_disposition = "inline"

content_type = u"image/jpeg"
info = blob._get_upload_arguments(content_type)
info = blob._get_upload_arguments(client, content_type)

headers, object_metadata, new_content_type = info
header_key_value = "W3BYd0AscEBAQWZCZnJSM3gtMmIyU0NIUiwuP1l3Uk8="
header_key_hash_value = "G0++dxF4q5rG4o9kE8gvEKn15RH6wLm0wXV1MgAlXOg="
expected_headers = {
**_get_default_headers(client._connection.user_agent, content_type),
"X-Goog-Encryption-Algorithm": "AES256",
"X-Goog-Encryption-Key": header_key_value,
"X-Goog-Encryption-Key-Sha256": header_key_hash_value,
Expand Down Expand Up @@ -2368,7 +2372,11 @@ def _do_multipart_success(
+ data_read
+ b"\r\n--==0==--"
)
headers = {"content-type": b'multipart/related; boundary="==0=="'}
headers = _get_default_headers(
client._connection.user_agent,
b'multipart/related; boundary="==0=="',
"application/xml",
)
client._http.request.assert_called_once_with(
"POST", upload_url, data=payload, headers=headers, timeout=expected_timeout
)
Expand Down Expand Up @@ -2614,10 +2622,17 @@ def _initiate_resumable_helper(

self.assertEqual(upload.upload_url, upload_url)
if extra_headers is None:
self.assertEqual(upload._headers, {})
self.assertEqual(
upload._headers,
_get_default_headers(client._connection.user_agent, content_type),
)
else:
self.assertEqual(upload._headers, extra_headers)
self.assertIsNot(upload._headers, extra_headers)
expected_headers = {
**_get_default_headers(client._connection.user_agent, content_type),
**extra_headers,
}
self.assertEqual(upload._headers, expected_headers)
self.assertIsNot(upload._headers, expected_headers)
self.assertFalse(upload.finished)
if chunk_size is None:
if blob_chunk_size is None:
Expand Down Expand Up @@ -2656,10 +2671,9 @@ def _initiate_resumable_helper(
# Check the mocks.
blob._get_writable_metadata.assert_called_once_with()
payload = json.dumps(object_metadata).encode("utf-8")
expected_headers = {
"content-type": "application/json; charset=UTF-8",
"x-upload-content-type": content_type,
}
expected_headers = _get_default_headers(
client._connection.user_agent, x_upload_content_type=content_type
)
if size is not None:
expected_headers["x-upload-content-length"] = str(size)
if extra_headers is not None:
Expand Down Expand Up @@ -2778,6 +2792,7 @@ def _make_resumable_transport(

@staticmethod
def _do_resumable_upload_call0(
client,
blob,
content_type,
size=None,
Expand All @@ -2796,10 +2811,9 @@ def _do_resumable_upload_call0(
)
if predefined_acl is not None:
upload_url += "&predefinedAcl={}".format(predefined_acl)
expected_headers = {
"content-type": "application/json; charset=UTF-8",
"x-upload-content-type": content_type,
}
expected_headers = _get_default_headers(
client._connection.user_agent, x_upload_content_type=content_type
)
if size is not None:
expected_headers["x-upload-content-length"] = str(size)
payload = json.dumps({"name": blob.name}).encode("utf-8")
Expand All @@ -2809,6 +2823,7 @@ def _do_resumable_upload_call0(

@staticmethod
def _do_resumable_upload_call1(
client,
blob,
content_type,
data,
Expand All @@ -2828,6 +2843,9 @@ def _do_resumable_upload_call1(
content_range = "bytes 0-{:d}/{:d}".format(blob.chunk_size - 1, size)

expected_headers = {
**_get_default_headers(
client._connection.user_agent, x_upload_content_type=content_type
),
"content-type": content_type,
"content-range": content_range,
}
Expand All @@ -2842,6 +2860,7 @@ def _do_resumable_upload_call1(

@staticmethod
def _do_resumable_upload_call2(
client,
blob,
content_type,
data,
Expand All @@ -2859,6 +2878,9 @@ def _do_resumable_upload_call2(
blob.chunk_size, total_bytes - 1, total_bytes
)
expected_headers = {
**_get_default_headers(
client._connection.user_agent, x_upload_content_type=content_type
),
"content-type": content_type,
"content-range": content_range,
}
Expand All @@ -2884,13 +2906,11 @@ def _do_resumable_helper(
data_corruption=False,
retry=None,
):
bucket = _Bucket(name="yesterday")
blob = self._make_one(u"blob-name", bucket=bucket)
blob.chunk_size = blob._CHUNK_SIZE_MULTIPLE
self.assertIsNotNone(blob.chunk_size)

CHUNK_SIZE = 256 * 1024
USER_AGENT = "testing 1.2.3"
content_type = u"text/html"
# Data to be uploaded.
data = b"<html>" + (b"A" * blob.chunk_size) + b"</html>"
data = b"<html>" + (b"A" * CHUNK_SIZE) + b"</html>"
total_bytes = len(data)
if use_size:
size = total_bytes
Expand All @@ -2899,17 +2919,29 @@ def _do_resumable_helper(

# Create mocks to be checked for doing transport.
resumable_url = "http://test.invalid?upload_id=and-then-there-was-1"
headers1 = {"location": resumable_url}
headers2 = {"range": "bytes=0-{:d}".format(blob.chunk_size - 1)}
headers1 = {
**_get_default_headers(USER_AGENT, content_type),
"location": resumable_url,
}
headers2 = {
**_get_default_headers(USER_AGENT, content_type),
"range": "bytes=0-{:d}".format(CHUNK_SIZE - 1),
}
headers3 = _get_default_headers(USER_AGENT, content_type)
transport, responses = self._make_resumable_transport(
headers1, headers2, {}, total_bytes, data_corruption=data_corruption
headers1, headers2, headers3, total_bytes, data_corruption=data_corruption
)

# Create some mock arguments and call the method under test.
client = mock.Mock(_http=transport, _connection=_Connection, spec=["_http"])
client._connection.API_BASE_URL = "https://storage.googleapis.com"
client._connection.user_agent = USER_AGENT
stream = io.BytesIO(data)
content_type = u"text/html"

bucket = _Bucket(name="yesterday")
blob = self._make_one(u"blob-name", bucket=bucket)
blob.chunk_size = blob._CHUNK_SIZE_MULTIPLE
self.assertIsNotNone(blob.chunk_size)

if timeout is None:
expected_timeout = self._get_default_timeout()
Expand Down Expand Up @@ -2939,6 +2971,7 @@ def _do_resumable_helper(

# Check the mocks.
call0 = self._do_resumable_upload_call0(
client,
blob,
content_type,
size=size,
Expand All @@ -2950,6 +2983,7 @@ def _do_resumable_helper(
timeout=expected_timeout,
)
call1 = self._do_resumable_upload_call1(
client,
blob,
content_type,
data,
Expand All @@ -2963,6 +2997,7 @@ def _do_resumable_helper(
timeout=expected_timeout,
)
call2 = self._do_resumable_upload_call2(
client,
blob,
content_type,
data,
Expand Down Expand Up @@ -3510,6 +3545,7 @@ def _create_resumable_upload_session_helper(
size = 10000
client = mock.Mock(_http=transport, _connection=_Connection, spec=[u"_http"])
client._connection.API_BASE_URL = "https://storage.googleapis.com"
client._connection.user_agent = "testing 1.2.3"

if timeout is None:
expected_timeout = self._get_default_timeout()
Expand Down Expand Up @@ -3556,7 +3592,9 @@ def _create_resumable_upload_session_helper(
upload_url += "?" + urlencode(qs_params)
payload = b'{"name": "blob-name"}'
expected_headers = {
"content-type": "application/json; charset=UTF-8",
**_get_default_headers(
client._connection.user_agent, x_upload_content_type=content_type
),
"x-upload-content-length": str(size),
"x-upload-content-type": content_type,
}
Expand Down Expand Up @@ -5739,6 +5777,7 @@ class _Connection(object):

API_BASE_URL = "http://example.com"
USER_AGENT = "testing 1.2.3"
user_agent = "testing 1.2.3"
credentials = object()


Expand Down
12 changes: 10 additions & 2 deletions tests/unit/test_client.py
Expand Up @@ -28,6 +28,7 @@
from google.oauth2.service_account import Credentials

from google.cloud.storage._helpers import STORAGE_EMULATOR_ENV_VAR
from google.cloud.storage._helpers import _get_default_headers
from google.cloud.storage.retry import DEFAULT_RETRY
from google.cloud.storage.retry import DEFAULT_RETRY_IF_GENERATION_SPECIFIED

Expand Down Expand Up @@ -1567,7 +1568,10 @@ def test_download_blob_to_file_with_failure(self):

self.assertEqual(file_obj.tell(), 0)

headers = {"accept-encoding": "gzip"}
headers = {
**_get_default_headers(client._connection.user_agent),
"accept-encoding": "gzip",
}
blob._do_download.assert_called_once_with(
client._http,
file_obj,
Expand Down Expand Up @@ -1598,7 +1602,10 @@ def test_download_blob_to_file_with_uri(self):
):
client.download_blob_to_file("gs://bucket_name/path/to/object", file_obj)

headers = {"accept-encoding": "gzip"}
headers = {
**_get_default_headers(client._connection.user_agent),
"accept-encoding": "gzip",
}
blob._do_download.assert_called_once_with(
client._http,
file_obj,
Expand Down Expand Up @@ -1714,6 +1721,7 @@ def _download_blob_to_file_helper(
if_etag_not_match = [if_etag_not_match]
headers["If-None-Match"] = ", ".join(if_etag_not_match)

headers = {**_get_default_headers(client._connection.user_agent), **headers}
blob._do_download.assert_called_once_with(
client._http,
file_obj,
Expand Down

0 comments on commit c7bf615

Please sign in to comment.