Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fix remaining REST transport issues #1428

Merged
merged 4 commits into from
Sep 6, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 5 additions & 1 deletion gapic/schema/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def __str__(self) -> str:
# This module is from a different proto package
# Most commonly happens for a common proto
# https://pypi.org/project/googleapis-common-protos/
if not self.proto_package.startswith(self.api_naming.proto_package):
if self.is_external_type:
module_name = f'{self.module}_pb2'

# Return the dot-separated Python identifier.
Expand All @@ -102,6 +102,10 @@ def __str__(self) -> str:
# Return the Python identifier.
return '.'.join(self.parent + (self.name,))

@property
def is_external_type(self):
return not self.proto_package.startswith(self.api_naming.proto_package)

@cached_property
def __cached_string_repr(self):
return "({})".format(
Expand Down
28 changes: 22 additions & 6 deletions gapic/schema/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,19 @@ def recursive_mock_original_type(field):
# Not worth the hassle, just return an empty map.
return {}

msg_dict = {
f.name: recursive_mock_original_type(f)
for f in field.message.fields.values()
}
adr = field.type.meta.address
if adr.name == "Any" and adr.package == ("google", "protobuf"):
# If it is Any type pack a random but validly encoded type,
# Duration in this specific case.
msg_dict = {
"type_url": "type.googleapis.com/google.protobuf.Duration",
"value": b'\x08\x0c\x10\xdb\x07',
}
else:
msg_dict = {
f.name: recursive_mock_original_type(f)
for f in field.message.fields.values()
}

return [msg_dict] if field.repeated else msg_dict

Expand Down Expand Up @@ -237,9 +246,16 @@ def primitive_mock(self, suffix: int = 0) -> Union[bool, str, bytes, int, float,
if self.type.python_type == bool:
answer = True
elif self.type.python_type == str:
answer = f"{self.name}_value_{suffix}" if suffix else f"{self.name}_value"
if self.name == "type_url":
# It is most likely a mock for Any type. We don't really care
# which mock value to put, so lets put a value which makes
# Any deserializer happy, which will wtill work even if it
# is not Any.
answer = "type.googleapis.com/google.protobuf.Empty"
else:
answer = f"{self.name}_value{suffix}" if suffix else f"{self.name}_value"
elif self.type.python_type == bytes:
answer_str = f"{self.name}_blob_{suffix}" if suffix else f"{self.name}_blob"
answer_str = f"{self.name}_blob{suffix}" if suffix else f"{self.name}_blob"
answer = bytes(answer_str, encoding="utf-8")
elif self.type.python_type == int:
answer = sum([ord(i) for i in self.name]) + suffix
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ from google.api_core import rest_streaming
from google.api_core import path_template
from google.api_core import gapic_v1

from google.protobuf import json_format
{% if service.has_lro %}
from google.api_core import operations_v1
from google.protobuf import json_format
{% endif %}
from requests import __version__ as requests_version
import dataclasses
Expand Down Expand Up @@ -328,20 +328,19 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{% endfor %}{# rule in method.http_options #}
]
request, metadata = self._interceptor.pre_{{ method.name|snake_case }}(request, metadata)
request_kwargs = {{method.input.ident}}.to_dict(request)
transcoded_request = path_template.transcode(
http_options, **request_kwargs)
{% if method.input.ident.is_external_type %}
pb_request = request
{% else %}
pb_request = {{method.input.ident}}.pb(request)
{% endif %}
transcoded_request = path_template.transcode(http_options, pb_request)

{% set body_spec = method.http_options[0].body %}
{%- if body_spec %}
# Jsonify the request body
body = {% if body_spec == '*' -%}
{{method.input.ident}}.to_json(
{{method.input.ident}}(transcoded_request['body']),
{% else -%}
{{method.input.fields[body_spec].type.ident}}.to_json(
{{method.input.fields[body_spec].type.ident}}(transcoded_request['body']),
{% endif %}{# body_spec == "*" #}

body = json_format.MessageToJson(
transcoded_request['body'],
including_default_value_fields=False,
use_integers_for_enums={{ opts.rest_numeric_enums }}
)
Expand All @@ -351,12 +350,11 @@ class {{service.name}}RestTransport({{service.name}}Transport):
method = transcoded_request['method']

# Jsonify the query params
query_params = json.loads({{method.input.ident}}.to_json(
{{method.input.ident}}(transcoded_request['query_params']),
query_params = json.loads(json_format.MessageToJson(
transcoded_request['query_params'],
including_default_value_fields=False,
use_integers_for_enums={{ opts.rest_numeric_enums }}
use_integers_for_enums={{ opts.rest_numeric_enums }},
))

{% if method.input.required_fields %}
query_params.update(self._get_unset_required_fields(query_params))
{% endif %}{# required fields #}
Expand Down Expand Up @@ -391,10 +389,14 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{% elif method.server_streaming %}
resp = rest_streaming.ResponseIterator(response, {{method.output.ident}})
{% else %}
resp = {{method.output.ident}}.from_json(
response.content,
ignore_unknown_fields=True
)
resp = {{method.output.ident}}()
{% if method.output.ident.is_external_type %}
pb_resp = resp
{% else %}
pb_resp = {{method.output.ident}}.pb(resp)
{% endif %}

json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True)
{% endif %}{# method.lro #}
resp = self._interceptor.post_{{ method.name|snake_case }}(resp)
return resp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import grpc
from grpc.experimental import aio
{% if "rest" in opts.transport %}
from collections.abc import Iterable
from google.protobuf import json_format
import json
{% endif %}
import math
Expand Down Expand Up @@ -51,9 +52,6 @@ from google.api_core import future
from google.api_core import operation
from google.api_core import operations_v1
from google.longrunning import operations_pb2
{% if "rest" in opts.transport %}
from google.protobuf import json_format
{% endif %}{# rest transport #}
{% endif %}{# lro #}
{% if api.has_location_mixin %}
from google.cloud.location import locations_pb2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -855,7 +855,7 @@ def test_{{ method_name }}_rest(request_type):
request_init["{{ field.name }}"] = {{ field.merged_mock_value(method.http_options[0].sample_request(method).get(field.name)) }}
{% endif %}
{% endfor %}
request = request_type(request_init)
request = request_type(**request_init)
{% if method.client_streaming %}
requests = [request]
{% endif %}
Expand Down Expand Up @@ -902,11 +902,19 @@ def test_{{ method_name }}_rest(request_type):
json_return_value = ''
{% elif method.lro %}
json_return_value = json_format.MessageToJson(return_value)
{% elif method.server_streaming %}
json_return_value = "[{}]".format({{ method.output.ident }}.to_json(return_value))
{% else %}
json_return_value = {{ method.output.ident }}.to_json(return_value)
{% if method.output.ident.is_external_type %}
pb_return_value = return_value
{% else %}
pb_return_value = {{ method.output.ident }}.pb(return_value)
{% endif %}
json_return_value = json_format.MessageToJson(pb_return_value)
{% endif %}

{% if method.server_streaming %}
json_return_value = "[{}]".format(json_return_value)
{% endif %}

response_value._content = json_return_value.encode('UTF-8')
req.return_value = response_value
{% if method.client_streaming %}
Expand Down Expand Up @@ -965,20 +973,25 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
request_init["{{ req_field.name }}"] = {{ req_field.type.python_type(req_field.field_pb.default_value or 0) }}
{% endif %}{# default is str #}
{% endfor %}
request = request_type(request_init)
jsonified_request = json.loads(request_type.to_json(
request,
request = request_type(**request_init)
{% if method.input.ident.is_external_type %}
pb_request = request
{% else %}
pb_request = request_type.pb(request)
{% endif %}
jsonified_request = json.loads(json_format.MessageToJson(
pb_request,
including_default_value_fields=False,
use_integers_for_enums=False
))
))

# verify fields with default values are dropped
{% for req_field in method.input.required_fields if req_field.is_primitive and req_field.name in method.query_params %}
{% set field_name = req_field.name | camel_case %}
assert "{{ field_name }}" not in jsonified_request
{% endfor %}

unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).{{ method.name | snake_case }}._get_unset_required_fields(jsonified_request)
unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).{{ method.transport_safe_name | snake_case }}._get_unset_required_fields(jsonified_request)
jsonified_request.update(unset_fields)

# verify required fields with default values are now present
Expand All @@ -994,7 +1007,7 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
jsonified_request["{{ field_name }}"] = {{ mock_value }}
{% endfor %}

unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).{{ method.name | snake_case }}._get_unset_required_fields(jsonified_request)
unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).{{ method.transport_safe_name | snake_case }}._get_unset_required_fields(jsonified_request)
{% if method.query_params %}
# Check that path parameters and body parameters are not mixing in.
assert not set(unset_fields) - set(({% for param in method.query_params|sort %}"{{param}}", {% endfor %}))
Expand All @@ -1014,7 +1027,7 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
credentials=ga_credentials.AnonymousCredentials(),
transport='rest',
)
request = request_type(request_init)
request = request_type(**request_init)

# Designate an appropriate value for the returned response.
{% if method.void %}
Expand All @@ -1032,13 +1045,18 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
with mock.patch.object(path_template, 'transcode') as transcode:
# A uri without fields and an empty body will force all the
# request fields to show up in the query_params.
{% if method.input.ident.is_external_type %}
pb_request = request
{% else %}
pb_request = request_type.pb(request)
{% endif %}
transcode_result = {
'uri': 'v1/sample_method',
'method': "{{ method.http_options[0].method }}",
'query_params': request_init,
'query_params': pb_request,
}
{% if method.http_options[0].body %}
transcode_result['body'] = {}
transcode_result['body'] = pb_request
{% endif %}
transcode.return_value = transcode_result

Expand All @@ -1048,11 +1066,19 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
json_return_value = ''
{% elif method.lro %}
json_return_value = json_format.MessageToJson(return_value)
{% elif method.server_streaming %}
json_return_value = "[{}]".format({{ method.output.ident }}.to_json(return_value))
{% else %}
json_return_value = {{ method.output.ident }}.to_json(return_value)

{% if method.output.ident.is_external_type %}
pb_return_value = return_value
{% else %}
pb_return_value = {{ method.output.ident }}.pb(return_value)
{% endif %}
json_return_value = json_format.MessageToJson(pb_return_value)
{% endif %}
{% if method.server_streaming %}
json_return_value = "[{}]".format(json_return_value)
{% endif %}

response_value._content = json_return_value.encode('UTF-8')
req.return_value = response_value

Expand Down Expand Up @@ -1115,8 +1141,17 @@ def test_{{ method_name }}_rest_interceptors(null_interceptor):
{% if not method.void %}
post.assert_not_called()
{% endif %}

transcode.return_value = {"method": "post", "uri": "my_uri", "body": None, "query_params": {},}
{% if method.input.ident.is_external_type %}
pb_message = {{ method.input.ident }}()
{% else %}
pb_message = {{ method.input.ident }}.pb({{ method.input.ident }}())
{% endif %}
transcode.return_value = {
"method": "post",
"uri": "my_uri",
"body": pb_message,
"query_params": pb_message,
}

req.return_value = Response()
req.return_value.status_code = 200
Expand Down Expand Up @@ -1164,7 +1199,7 @@ def test_{{ method_name }}_rest_bad_request(transport: str = 'rest', request_typ
request_init["{{ field.name }}"] = {{ field.merged_mock_value(method.http_options[0].sample_request(method).get(field.name)) }}
{% endif %}
{% endfor %}
request = request_type(request_init)
request = request_type(**request_init)
{% if method.client_streaming %}
requests = [request]
{% endif %}
Expand Down Expand Up @@ -1222,12 +1257,17 @@ def test_{{ method_name }}_rest_flattened():
json_return_value = ''
{% elif method.lro %}
json_return_value = json_format.MessageToJson(return_value)
{% elif method.server_streaming %}
json_return_value = "[{}]".format({{ method.output.ident }}.to_json(return_value))
{% else %}
json_return_value = {{ method.output.ident }}.to_json(return_value)
{% if method.output.ident.is_external_type %}
pb_return_value = return_value
{% else %}
pb_return_value = {{ method.output.ident }}.pb(return_value)
{% endif %}
json_return_value = json_format.MessageToJson(pb_return_value)
{% endif %}
{% if method.server_streaming %}
json_return_value = "[{}]".format(json_return_value)
{% endif %}

response_value._content = json_return_value.encode('UTF-8')
req.return_value = response_value

Expand Down Expand Up @@ -1342,7 +1382,7 @@ def test_{{ method_name }}_rest_pager(transport: str = 'rest'):
return_values = tuple(Response() for i in response)
for return_val, response_val in zip(return_values, response):
{% if method.server_streaming %}
response_val = "[{}]".format({{ method.output.ident }}.to_json(response_val))
response_val = "[{}]".format(response_val)
{% endif %}
return_val._content = response_val.encode('UTF-8')
return_val.status_code = 200
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ async def sample_generate_access_token():
# Initialize request argument(s)
request = credentials_v1.GenerateAccessTokenRequest(
name="name_value",
scope=['scope_value_1', 'scope_value_2'],
scope=['scope_value1', 'scope_value2'],
)

# Make the request
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ def sample_generate_access_token():
# Initialize request argument(s)
request = credentials_v1.GenerateAccessTokenRequest(
name="name_value",
scope=['scope_value_1', 'scope_value_2'],
scope=['scope_value1', 'scope_value2'],
)

# Make the request
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ async def sample_generate_access_token():
# Initialize request argument(s)
request = credentials_v1.GenerateAccessTokenRequest(
name="name_value",
scope=['scope_value_1', 'scope_value_2'],
scope=['scope_value1', 'scope_value2'],
)

# Make the request
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def sample_generate_access_token():
# Initialize request argument(s)
request = credentials_v1.GenerateAccessTokenRequest(
name="name_value",
scope=['scope_value_1', 'scope_value_2'],
scope=['scope_value1', 'scope_value2'],
)

# Make the request
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ async def sample_list_log_entries():

# Initialize request argument(s)
request = logging_v2.ListLogEntriesRequest(
resource_names=['resource_names_value_1', 'resource_names_value_2'],
resource_names=['resource_names_value1', 'resource_names_value2'],
)

# Make the request
Expand Down Expand Up @@ -867,7 +867,7 @@ async def sample_tail_log_entries():

# Initialize request argument(s)
request = logging_v2.TailLogEntriesRequest(
resource_names=['resource_names_value_1', 'resource_names_value_2'],
resource_names=['resource_names_value1', 'resource_names_value2'],
)

# This method expects an iterator which contains
Expand Down