diff --git a/google/auth/_exponential_backoff.py b/google/auth/_exponential_backoff.py new file mode 100644 index 000000000..b5801bec9 --- /dev/null +++ b/google/auth/_exponential_backoff.py @@ -0,0 +1,111 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import time + +import six + +# The default amount of retry attempts +_DEFAULT_RETRY_TOTAL_ATTEMPTS = 3 + +# The default initial backoff period (1.0 second). +_DEFAULT_INITIAL_INTERVAL_SECONDS = 1.0 + +# The default randomization factor (0.1 which results in a random period ranging +# between 10% below and 10% above the retry interval). +_DEFAULT_RANDOMIZATION_FACTOR = 0.1 + +# The default multiplier value (2 which is 100% increase per back off). +_DEFAULT_MULTIPLIER = 2.0 + +"""Exponential Backoff Utility + +This is a private module that implements the exponential back off algorithm. +It can be used as a utility for code that needs to retry on failure, for example +an HTTP request. +""" + + +class ExponentialBackoff(six.Iterator): + """An exponential backoff iterator. This can be used in a for loop to + perform requests with exponential backoff. + + Args: + total_attempts Optional[int]: + The maximum amount of retries that should happen. + The default value is 3 attempts. + initial_wait_seconds Optional[int]: + The amount of time to sleep in the first backoff. This parameter + should be in seconds. + The default value is 1 second. + randomization_factor Optional[float]: + The amount of jitter that should be in each backoff. For example, + a value of 0.1 will introduce a jitter range of 10% to the + current backoff period. + The default value is 0.1. + multiplier Optional[float]: + The backoff multipler. This adjusts how much each backoff will + increase. For example a value of 2.0 leads to a 200% backoff + on each attempt. If the initial_wait is 1.0 it would look like + this sequence [1.0, 2.0, 4.0, 8.0]. + The default value is 2.0. + """ + + def __init__( + self, + total_attempts=_DEFAULT_RETRY_TOTAL_ATTEMPTS, + initial_wait_seconds=_DEFAULT_INITIAL_INTERVAL_SECONDS, + randomization_factor=_DEFAULT_RANDOMIZATION_FACTOR, + multiplier=_DEFAULT_MULTIPLIER, + ): + self._total_attempts = total_attempts + self._initial_wait_seconds = initial_wait_seconds + + self._current_wait_in_seconds = self._initial_wait_seconds + + self._randomization_factor = randomization_factor + self._multiplier = multiplier + self._backoff_count = 0 + + def __iter__(self): + self._backoff_count = 0 + self._current_wait_in_seconds = self._initial_wait_seconds + return self + + def __next__(self): + if self._backoff_count >= self._total_attempts: + raise StopIteration + self._backoff_count += 1 + + jitter_variance = self._current_wait_in_seconds * self._randomization_factor + jitter = random.uniform( + self._current_wait_in_seconds - jitter_variance, + self._current_wait_in_seconds + jitter_variance, + ) + + time.sleep(jitter) + + self._current_wait_in_seconds *= self._multiplier + return self._backoff_count + + @property + def total_attempts(self): + """The total amount of backoff attempts that will be made.""" + return self._total_attempts + + @property + def backoff_count(self): + """The current amount of backoff attempts that have been made.""" + return self._backoff_count diff --git a/google/auth/exceptions.py b/google/auth/exceptions.py index e9e737780..7760c87b8 100644 --- a/google/auth/exceptions.py +++ b/google/auth/exceptions.py @@ -18,6 +18,15 @@ class GoogleAuthError(Exception): """Base class for all google.auth errors.""" + def __init__(self, *args, **kwargs): + super(GoogleAuthError, self).__init__(*args) + retryable = kwargs.get("retryable", False) + self._retryable = retryable + + @property + def retryable(self): + return self._retryable + class TransportError(GoogleAuthError): """Used to indicate an error occurred during an HTTP request.""" @@ -44,6 +53,10 @@ class MutualTLSChannelError(GoogleAuthError): class ClientCertError(GoogleAuthError): """Used to indicate that client certificate is missing or invalid.""" + @property + def retryable(self): + return False + class OAuthError(GoogleAuthError): """Used to indicate an error occurred during an OAuth related HTTP @@ -53,9 +66,9 @@ class OAuthError(GoogleAuthError): class ReauthFailError(RefreshError): """An exception for when reauth failed.""" - def __init__(self, message=None): + def __init__(self, message=None, **kwargs): super(ReauthFailError, self).__init__( - "Reauthentication failed. {0}".format(message) + "Reauthentication failed. {0}".format(message), **kwargs ) diff --git a/google/auth/transport/__init__.py b/google/auth/transport/__init__.py index 374e7b4d7..8334145a1 100644 --- a/google/auth/transport/__init__.py +++ b/google/auth/transport/__init__.py @@ -29,9 +29,21 @@ import six from six.moves import http_client +TOO_MANY_REQUESTS = 429 # Python 2.7 six is missing this status code. + +DEFAULT_RETRYABLE_STATUS_CODES = ( + http_client.INTERNAL_SERVER_ERROR, + http_client.SERVICE_UNAVAILABLE, + http_client.REQUEST_TIMEOUT, + TOO_MANY_REQUESTS, +) +"""Sequence[int]: HTTP status codes indicating a request can be retried. +""" + + DEFAULT_REFRESH_STATUS_CODES = (http_client.UNAUTHORIZED,) """Sequence[int]: Which HTTP status code indicate that credentials should be -refreshed and a request should be retried. +refreshed. """ DEFAULT_MAX_REFRESH_ATTEMPTS = 2 diff --git a/google/oauth2/_client.py b/google/oauth2/_client.py index 847c5db8a..7f866d446 100644 --- a/google/oauth2/_client.py +++ b/google/oauth2/_client.py @@ -30,9 +30,11 @@ from six.moves import http_client from six.moves import urllib +from google.auth import _exponential_backoff from google.auth import _helpers from google.auth import exceptions from google.auth import jwt +from google.auth import transport _URLENCODED_CONTENT_TYPE = "application/x-www-form-urlencoded" _JSON_CONTENT_TYPE = "application/json" @@ -40,17 +42,22 @@ _REFRESH_GRANT_TYPE = "refresh_token" -def _handle_error_response(response_data): +def _handle_error_response(response_data, retryable_error): """Translates an error response into an exception. Args: response_data (Mapping | str): The decoded response data. + retryable_error Optional[bool]: A boolean indicating if an error is retryable. + Defaults to False. Raises: google.auth.exceptions.RefreshError: The errors contained in response_data. """ + + retryable_error = retryable_error if retryable_error else False + if isinstance(response_data, six.string_types): - raise exceptions.RefreshError(response_data) + raise exceptions.RefreshError(response_data, retryable=retryable_error) try: error_details = "{}: {}".format( response_data["error"], response_data.get("error_description") @@ -59,7 +66,45 @@ def _handle_error_response(response_data): except (KeyError, ValueError): error_details = json.dumps(response_data) - raise exceptions.RefreshError(error_details, response_data) + raise exceptions.RefreshError( + error_details, response_data, retryable=retryable_error + ) + + +def _can_retry(status_code, response_data): + """Checks if a request can be retried by inspecting the status code + and response body of the request. + + Args: + status_code (int): The response status code. + response_data (Mapping | str): The decoded response data. + + Returns: + bool: True if the response is retryable. False otherwise. + """ + if status_code in transport.DEFAULT_RETRYABLE_STATUS_CODES: + return True + + try: + # For a failed response, response_body could be a string + error_desc = response_data.get("error_description") or "" + error_code = response_data.get("error") or "" + + # Per Oauth 2.0 RFC https://www.rfc-editor.org/rfc/rfc6749.html#section-4.1.2.1 + # This is needed because a redirect will not return a 500 status code. + retryable_error_descriptions = { + "internal_failure", + "server_error", + "temporarily_unavailable", + } + + if any(e in retryable_error_descriptions for e in (error_code, error_desc)): + return True + + except AttributeError: + pass + + return False def _parse_expiry(response_data): @@ -81,7 +126,13 @@ def _parse_expiry(response_data): def _token_endpoint_request_no_throw( - request, token_uri, body, access_token=None, use_json=False, **kwargs + request, + token_uri, + body, + access_token=None, + use_json=False, + can_retry=True, + **kwargs ): """Makes a request to the OAuth 2.0 authorization server's token endpoint. This function doesn't throw on response errors. @@ -95,6 +146,7 @@ def _token_endpoint_request_no_throw( access_token (Optional(str)): The access token needed to make the request. use_json (Optional(bool)): Use urlencoded format or json format for the content type. The default value is False. + can_retry (bool): Enable or disable request retry behavior. kwargs: Additional arguments passed on to the request method. The kwargs will be passed to `requests.request` method, see: https://docs.python-requests.org/en/latest/api/#requests.request. @@ -104,8 +156,10 @@ def _token_endpoint_request_no_throw( side SSL certificate verification. Returns: - Tuple(bool, Mapping[str, str]): A boolean indicating if the request is - successful, and a mapping for the JSON-decoded response data. + Tuple(bool, Mapping[str, str], Optional[bool]): A boolean indicating + if the request is successful, a mapping for the JSON-decoded response + data and in the case of an error a boolean indicating if the error + is retryable. """ if use_json: headers = {"Content-Type": _JSON_CONTENT_TYPE} @@ -117,10 +171,7 @@ def _token_endpoint_request_no_throw( if access_token: headers["Authorization"] = "Bearer {}".format(access_token) - retry = 0 - # retry to fetch token for maximum of two times if any internal failure - # occurs. - while True: + def _perform_request(): response = request( method="POST", url=token_uri, headers=headers, body=body, **kwargs ) @@ -129,32 +180,44 @@ def _token_endpoint_request_no_throw( if hasattr(response.data, "decode") else response.data ) - - if response.status == http_client.OK: + response_data = "" + try: # response_body should be a JSON response_data = json.loads(response_body) - break - else: - # For a failed response, response_body could be a string - try: - response_data = json.loads(response_body) - error_desc = response_data.get("error_description") or "" - error_code = response_data.get("error") or "" - if ( - any(e == "internal_failure" for e in (error_code, error_desc)) - and retry < 1 - ): - retry += 1 - continue - except ValueError: - response_data = response_body - return False, response_data - - return True, response_data + except ValueError: + response_data = response_body + + if response.status == http_client.OK: + return True, response_data, None + + retryable_error = _can_retry( + status_code=response.status, response_data=response_data + ) + + return False, response_data, retryable_error + + request_succeeded, response_data, retryable_error = _perform_request() + + if request_succeeded or not retryable_error or not can_retry: + return request_succeeded, response_data, retryable_error + + retries = _exponential_backoff.ExponentialBackoff() + for _ in retries: + request_succeeded, response_data, retryable_error = _perform_request() + if request_succeeded or not retryable_error: + return request_succeeded, response_data, retryable_error + + return False, response_data, retryable_error def _token_endpoint_request( - request, token_uri, body, access_token=None, use_json=False, **kwargs + request, + token_uri, + body, + access_token=None, + use_json=False, + can_retry=True, + **kwargs ): """Makes a request to the OAuth 2.0 authorization server's token endpoint. @@ -167,6 +230,7 @@ def _token_endpoint_request( access_token (Optional(str)): The access token needed to make the request. use_json (Optional(bool)): Use urlencoded format or json format for the content type. The default value is False. + can_retry (bool): Enable or disable request retry behavior. kwargs: Additional arguments passed on to the request method. The kwargs will be passed to `requests.request` method, see: https://docs.python-requests.org/en/latest/api/#requests.request. @@ -182,15 +246,22 @@ def _token_endpoint_request( google.auth.exceptions.RefreshError: If the token endpoint returned an error. """ - response_status_ok, response_data = _token_endpoint_request_no_throw( - request, token_uri, body, access_token=access_token, use_json=use_json, **kwargs + + response_status_ok, response_data, retryable_error = _token_endpoint_request_no_throw( + request, + token_uri, + body, + access_token=access_token, + use_json=use_json, + can_retry=can_retry, + **kwargs ) if not response_status_ok: - _handle_error_response(response_data) + _handle_error_response(response_data, retryable_error) return response_data -def jwt_grant(request, token_uri, assertion): +def jwt_grant(request, token_uri, assertion, can_retry=True): """Implements the JWT Profile for OAuth 2.0 Authorization Grants. For more details, see `rfc7523 section 4`_. @@ -201,6 +272,7 @@ def jwt_grant(request, token_uri, assertion): token_uri (str): The OAuth 2.0 authorizations server's token endpoint URI. assertion (str): The OAuth 2.0 assertion. + can_retry (bool): Enable or disable request retry behavior. Returns: Tuple[str, Optional[datetime], Mapping[str, str]]: The access token, @@ -214,12 +286,16 @@ def jwt_grant(request, token_uri, assertion): """ body = {"assertion": assertion, "grant_type": _JWT_GRANT_TYPE} - response_data = _token_endpoint_request(request, token_uri, body) + response_data = _token_endpoint_request( + request, token_uri, body, can_retry=can_retry + ) try: access_token = response_data["access_token"] except KeyError as caught_exc: - new_exc = exceptions.RefreshError("No access token in response.", response_data) + new_exc = exceptions.RefreshError( + "No access token in response.", response_data, retryable=False + ) six.raise_from(new_exc, caught_exc) expiry = _parse_expiry(response_data) @@ -227,7 +303,7 @@ def jwt_grant(request, token_uri, assertion): return access_token, expiry, response_data -def id_token_jwt_grant(request, token_uri, assertion): +def id_token_jwt_grant(request, token_uri, assertion, can_retry=True): """Implements the JWT Profile for OAuth 2.0 Authorization Grants, but requests an OpenID Connect ID Token instead of an access token. @@ -242,6 +318,7 @@ def id_token_jwt_grant(request, token_uri, assertion): URI. assertion (str): JWT token signed by a service account. The token's payload must include a ``target_audience`` claim. + can_retry (bool): Enable or disable request retry behavior. Returns: Tuple[str, Optional[datetime], Mapping[str, str]]: @@ -254,12 +331,16 @@ def id_token_jwt_grant(request, token_uri, assertion): """ body = {"assertion": assertion, "grant_type": _JWT_GRANT_TYPE} - response_data = _token_endpoint_request(request, token_uri, body) + response_data = _token_endpoint_request( + request, token_uri, body, can_retry=can_retry + ) try: id_token = response_data["id_token"] except KeyError as caught_exc: - new_exc = exceptions.RefreshError("No ID token in response.", response_data) + new_exc = exceptions.RefreshError( + "No ID token in response.", response_data, retryable=False + ) six.raise_from(new_exc, caught_exc) payload = jwt.decode(id_token, verify=False) @@ -288,7 +369,9 @@ def _handle_refresh_grant_response(response_data, refresh_token): try: access_token = response_data["access_token"] except KeyError as caught_exc: - new_exc = exceptions.RefreshError("No access token in response.", response_data) + new_exc = exceptions.RefreshError( + "No access token in response.", response_data, retryable=False + ) six.raise_from(new_exc, caught_exc) refresh_token = response_data.get("refresh_token", refresh_token) @@ -305,6 +388,7 @@ def refresh_grant( client_secret, scopes=None, rapt_token=None, + can_retry=True, ): """Implements the OAuth 2.0 refresh token grant. @@ -324,6 +408,7 @@ def refresh_grant( token has a wild card scope (e.g. 'https://www.googleapis.com/auth/any-api'). rapt_token (Optional(str)): The reauth Proof Token. + can_retry (bool): Enable or disable request retry behavior. Returns: Tuple[str, str, Optional[datetime], Mapping[str, str]]: The access @@ -347,5 +432,7 @@ def refresh_grant( if rapt_token: body["rapt"] = rapt_token - response_data = _token_endpoint_request(request, token_uri, body) + response_data = _token_endpoint_request( + request, token_uri, body, can_retry=can_retry + ) return _handle_refresh_grant_response(response_data, refresh_token) diff --git a/google/oauth2/_client_async.py b/google/oauth2/_client_async.py index cf5121137..428084a70 100644 --- a/google/oauth2/_client_async.py +++ b/google/oauth2/_client_async.py @@ -30,13 +30,14 @@ from six.moves import http_client from six.moves import urllib +from google.auth import _exponential_backoff from google.auth import exceptions from google.auth import jwt from google.oauth2 import _client as client async def _token_endpoint_request_no_throw( - request, token_uri, body, access_token=None, use_json=False + request, token_uri, body, access_token=None, use_json=False, can_retry=True ): """Makes a request to the OAuth 2.0 authorization server's token endpoint. This function doesn't throw on response errors. @@ -50,10 +51,13 @@ async def _token_endpoint_request_no_throw( access_token (Optional(str)): The access token needed to make the request. use_json (Optional(bool)): Use urlencoded format or json format for the content type. The default value is False. + can_retry (bool): Enable or disable request retry behavior. Returns: - Tuple(bool, Mapping[str, str]): A boolean indicating if the request is - successful, and a mapping for the JSON-decoded response data. + Tuple(bool, Mapping[str, str], Optional[bool]): A boolean indicating + if the request is successful, a mapping for the JSON-decoded response + data and in the case of an error a boolean indicating if the error + is retryable. """ if use_json: headers = {"Content-Type": client._JSON_CONTENT_TYPE} @@ -65,11 +69,7 @@ async def _token_endpoint_request_no_throw( if access_token: headers["Authorization"] = "Bearer {}".format(access_token) - retry = 0 - # retry to fetch token for maximum of two times if any internal failure - # occurs. - while True: - + async def _perform_request(): response = await request( method="POST", url=token_uri, headers=headers, body=body ) @@ -83,26 +83,36 @@ async def _token_endpoint_request_no_throw( else response_body1 ) - response_data = json.loads(response_body) + try: + response_data = json.loads(response_body) + except ValueError: + response_data = response_body if response.status == http_client.OK: - break - else: - error_desc = response_data.get("error_description") or "" - error_code = response_data.get("error") or "" - if ( - any(e == "internal_failure" for e in (error_code, error_desc)) - and retry < 1 - ): - retry += 1 - continue - return response.status == http_client.OK, response_data + return True, response_data, None + + retryable_error = client._can_retry( + status_code=response.status, response_data=response_data + ) + + return False, response_data, retryable_error + + request_succeeded, response_data, retryable_error = await _perform_request() + + if request_succeeded or not retryable_error or not can_retry: + return request_succeeded, response_data, retryable_error + + retries = _exponential_backoff.ExponentialBackoff() + for _ in retries: + request_succeeded, response_data, retryable_error = await _perform_request() + if request_succeeded or not retryable_error: + return request_succeeded, response_data, retryable_error - return response.status == http_client.OK, response_data + return False, response_data, retryable_error async def _token_endpoint_request( - request, token_uri, body, access_token=None, use_json=False + request, token_uri, body, access_token=None, use_json=False, can_retry=True ): """Makes a request to the OAuth 2.0 authorization server's token endpoint. @@ -115,6 +125,7 @@ async def _token_endpoint_request( access_token (Optional(str)): The access token needed to make the request. use_json (Optional(bool)): Use urlencoded format or json format for the content type. The default value is False. + can_retry (bool): Enable or disable request retry behavior. Returns: Mapping[str, str]: The JSON-decoded response data. @@ -123,15 +134,21 @@ async def _token_endpoint_request( google.auth.exceptions.RefreshError: If the token endpoint returned an error. """ - response_status_ok, response_data = await _token_endpoint_request_no_throw( - request, token_uri, body, access_token=access_token, use_json=use_json + + response_status_ok, response_data, retryable_error = await _token_endpoint_request_no_throw( + request, + token_uri, + body, + access_token=access_token, + use_json=use_json, + can_retry=can_retry, ) if not response_status_ok: - client._handle_error_response(response_data) + client._handle_error_response(response_data, retryable_error) return response_data -async def jwt_grant(request, token_uri, assertion): +async def jwt_grant(request, token_uri, assertion, can_retry=True): """Implements the JWT Profile for OAuth 2.0 Authorization Grants. For more details, see `rfc7523 section 4`_. @@ -142,6 +159,7 @@ async def jwt_grant(request, token_uri, assertion): token_uri (str): The OAuth 2.0 authorizations server's token endpoint URI. assertion (str): The OAuth 2.0 assertion. + can_retry (bool): Enable or disable request retry behavior. Returns: Tuple[str, Optional[datetime], Mapping[str, str]]: The access token, @@ -155,12 +173,16 @@ async def jwt_grant(request, token_uri, assertion): """ body = {"assertion": assertion, "grant_type": client._JWT_GRANT_TYPE} - response_data = await _token_endpoint_request(request, token_uri, body) + response_data = await _token_endpoint_request( + request, token_uri, body, can_retry=can_retry + ) try: access_token = response_data["access_token"] except KeyError as caught_exc: - new_exc = exceptions.RefreshError("No access token in response.", response_data) + new_exc = exceptions.RefreshError( + "No access token in response.", response_data, retryable=False + ) six.raise_from(new_exc, caught_exc) expiry = client._parse_expiry(response_data) @@ -168,7 +190,7 @@ async def jwt_grant(request, token_uri, assertion): return access_token, expiry, response_data -async def id_token_jwt_grant(request, token_uri, assertion): +async def id_token_jwt_grant(request, token_uri, assertion, can_retry=True): """Implements the JWT Profile for OAuth 2.0 Authorization Grants, but requests an OpenID Connect ID Token instead of an access token. @@ -183,6 +205,7 @@ async def id_token_jwt_grant(request, token_uri, assertion): URI. assertion (str): JWT token signed by a service account. The token's payload must include a ``target_audience`` claim. + can_retry (bool): Enable or disable request retry behavior. Returns: Tuple[str, Optional[datetime], Mapping[str, str]]: @@ -195,12 +218,16 @@ async def id_token_jwt_grant(request, token_uri, assertion): """ body = {"assertion": assertion, "grant_type": client._JWT_GRANT_TYPE} - response_data = await _token_endpoint_request(request, token_uri, body) + response_data = await _token_endpoint_request( + request, token_uri, body, can_retry=can_retry + ) try: id_token = response_data["id_token"] except KeyError as caught_exc: - new_exc = exceptions.RefreshError("No ID token in response.", response_data) + new_exc = exceptions.RefreshError( + "No ID token in response.", response_data, retryable=False + ) six.raise_from(new_exc, caught_exc) payload = jwt.decode(id_token, verify=False) @@ -217,6 +244,7 @@ async def refresh_grant( client_secret, scopes=None, rapt_token=None, + can_retry=True, ): """Implements the OAuth 2.0 refresh token grant. @@ -236,6 +264,7 @@ async def refresh_grant( token has a wild card scope (e.g. 'https://www.googleapis.com/auth/any-api'). rapt_token (Optional(str)): The reauth Proof Token. + can_retry (bool): Enable or disable request retry behavior. Returns: Tuple[str, Optional[str], Optional[datetime], Mapping[str, str]]: The @@ -259,5 +288,7 @@ async def refresh_grant( if rapt_token: body["rapt"] = rapt_token - response_data = await _token_endpoint_request(request, token_uri, body) + response_data = await _token_endpoint_request( + request, token_uri, body, can_retry=can_retry + ) return client._handle_refresh_grant_response(response_data, refresh_token) diff --git a/google/oauth2/_reauth_async.py b/google/oauth2/_reauth_async.py index 30b0b0b1e..6b69c6e67 100644 --- a/google/oauth2/_reauth_async.py +++ b/google/oauth2/_reauth_async.py @@ -292,7 +292,7 @@ async def refresh_grant( if rapt_token: body["rapt"] = rapt_token - response_status_ok, response_data = await _client_async._token_endpoint_request_no_throw( + response_status_ok, response_data, retryable_error = await _client_async._token_endpoint_request_no_throw( request, token_uri, body ) if ( @@ -317,12 +317,13 @@ async def refresh_grant( ( response_status_ok, response_data, + retryable_error, ) = await _client_async._token_endpoint_request_no_throw( request, token_uri, body ) if not response_status_ok: - _client._handle_error_response(response_data) + _client._handle_error_response(response_data, retryable_error) refresh_response = _client._handle_refresh_grant_response( response_data, refresh_token ) diff --git a/google/oauth2/reauth.py b/google/oauth2/reauth.py index 2c32bda2a..ad2ad1b2e 100644 --- a/google/oauth2/reauth.py +++ b/google/oauth2/reauth.py @@ -319,7 +319,7 @@ def refresh_grant( if rapt_token: body["rapt"] = rapt_token - response_status_ok, response_data = _client._token_endpoint_request_no_throw( + response_status_ok, response_data, retryable_error = _client._token_endpoint_request_no_throw( request, token_uri, body ) if ( @@ -339,12 +339,14 @@ def refresh_grant( request, client_id, client_secret, refresh_token, token_uri, scopes=scopes ) body["rapt"] = rapt_token - (response_status_ok, response_data) = _client._token_endpoint_request_no_throw( - request, token_uri, body - ) + ( + response_status_ok, + response_data, + retryable_error, + ) = _client._token_endpoint_request_no_throw(request, token_uri, body) if not response_status_ok: - _client._handle_error_response(response_data) + _client._handle_error_response(response_data, retryable_error) return _client._handle_refresh_grant_response(response_data, refresh_token) + ( rapt_token, ) diff --git a/system_tests/secrets.tar.enc b/system_tests/secrets.tar.enc index f32ace960..735379ef1 100644 Binary files a/system_tests/secrets.tar.enc and b/system_tests/secrets.tar.enc differ diff --git a/tests/compute_engine/test_credentials.py b/tests/compute_engine/test_credentials.py index ff01720c4..ebce176e8 100644 --- a/tests/compute_engine/test_credentials.py +++ b/tests/compute_engine/test_credentials.py @@ -609,7 +609,7 @@ def test_refresh_error(self, sign, get, utcnow): request = mock.create_autospec(transport.Request, instance=True) response = mock.Mock() response.data = b'{"error": "http error"}' - response.status = 500 + response.status = 404 # Throw a 404 so the request is not retried. request.side_effect = [response] self.credentials = credentials.IDTokenCredentials( diff --git a/tests/oauth2/test__client.py b/tests/oauth2/test__client.py index bd4cc5001..13c42dc52 100644 --- a/tests/oauth2/test__client.py +++ b/tests/oauth2/test__client.py @@ -47,12 +47,14 @@ ) -def test__handle_error_response(): +@pytest.mark.parametrize("retryable", [True, False]) +def test__handle_error_response(retryable): response_data = {"error": "help", "error_description": "I'm alive"} with pytest.raises(exceptions.RefreshError) as excinfo: - _client._handle_error_response(response_data) + _client._handle_error_response(response_data, retryable) + assert excinfo.value.retryable == retryable assert excinfo.match(r"help: I\'m alive") @@ -60,8 +62,9 @@ def test__handle_error_response_no_error(): response_data = {"foo": "bar"} with pytest.raises(exceptions.RefreshError) as excinfo: - _client._handle_error_response(response_data) + _client._handle_error_response(response_data, False) + assert not excinfo.value.retryable assert excinfo.match(r"{\"foo\": \"bar\"}") @@ -69,11 +72,33 @@ def test__handle_error_response_not_json(): response_data = "this is an error message" with pytest.raises(exceptions.RefreshError) as excinfo: - _client._handle_error_response(response_data) + _client._handle_error_response(response_data, False) + assert not excinfo.value.retryable assert excinfo.match(response_data) +def test__can_retry_retryable(): + retryable_codes = transport.DEFAULT_RETRYABLE_STATUS_CODES + for status_code in range(100, 600): + if status_code in retryable_codes: + assert _client._can_retry(status_code, {"error": "invalid_scope"}) + else: + assert not _client._can_retry(status_code, {"error": "invalid_scope"}) + + +@pytest.mark.parametrize( + "response_data", [{"error": "internal_failure"}, {"error": "server_error"}] +) +def test__can_retry_message(response_data): + assert _client._can_retry(http_client.OK, response_data) + + +@pytest.mark.parametrize("response_data", [{"error": "invalid_scope"}]) +def test__can_retry_no_retry_message(response_data): + assert not _client._can_retry(http_client.OK, response_data) + + @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) def test__parse_expiry(unused_utcnow): result = _client._parse_expiry({"expires_in": 500}) @@ -154,8 +179,8 @@ def test__token_endpoint_request_internal_failure_error(): _client._token_endpoint_request( request, "http://example.com", {"error_description": "internal_failure"} ) - # request should be called twice due to the retry - assert request.call_count == 2 + # request should be called once and then with 3 retries + assert request.call_count == 4 request = make_request( {"error": "internal_failure"}, status=http_client.BAD_REQUEST @@ -165,7 +190,55 @@ def test__token_endpoint_request_internal_failure_error(): _client._token_endpoint_request( request, "http://example.com", {"error": "internal_failure"} ) - # request should be called twice due to the retry + # request should be called once and then with 3 retries + assert request.call_count == 4 + + +def test__token_endpoint_request_internal_failure_and_retry_failure_error(): + retryable_error = mock.create_autospec(transport.Response, instance=True) + retryable_error.status = http_client.BAD_REQUEST + retryable_error.data = json.dumps({"error_description": "internal_failure"}).encode( + "utf-8" + ) + + unretryable_error = mock.create_autospec(transport.Response, instance=True) + unretryable_error.status = http_client.BAD_REQUEST + unretryable_error.data = json.dumps({"error_description": "invalid_scope"}).encode( + "utf-8" + ) + + request = mock.create_autospec(transport.Request) + + request.side_effect = [retryable_error, retryable_error, unretryable_error] + + with pytest.raises(exceptions.RefreshError): + _client._token_endpoint_request( + request, "http://example.com", {"error_description": "invalid_scope"} + ) + # request should be called three times. Two retryable errors and one + # unretryable error to break the retry loop. + assert request.call_count == 3 + + +def test__token_endpoint_request_internal_failure_and_retry_succeeds(): + retryable_error = mock.create_autospec(transport.Response, instance=True) + retryable_error.status = http_client.BAD_REQUEST + retryable_error.data = json.dumps({"error_description": "internal_failure"}).encode( + "utf-8" + ) + + response = mock.create_autospec(transport.Response, instance=True) + response.status = http_client.OK + response.data = json.dumps({"hello": "world"}).encode("utf-8") + + request = mock.create_autospec(transport.Request) + + request.side_effect = [retryable_error, response] + + _ = _client._token_endpoint_request( + request, "http://example.com", {"test": "params"} + ) + assert request.call_count == 2 @@ -219,8 +292,9 @@ def test_jwt_grant_no_access_token(): } ) - with pytest.raises(exceptions.RefreshError): + with pytest.raises(exceptions.RefreshError) as excinfo: _client.jwt_grant(request, "http://example.com", "assertion_value") + assert not excinfo.value.retryable def test_id_token_jwt_grant(): @@ -255,8 +329,9 @@ def test_id_token_jwt_grant_no_access_token(): } ) - with pytest.raises(exceptions.RefreshError): + with pytest.raises(exceptions.RefreshError) as excinfo: _client.id_token_jwt_grant(request, "http://example.com", "assertion_value") + assert not excinfo.value.retryable @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) @@ -348,7 +423,104 @@ def test_refresh_grant_no_access_token(): } ) - with pytest.raises(exceptions.RefreshError): + with pytest.raises(exceptions.RefreshError) as excinfo: _client.refresh_grant( request, "http://example.com", "refresh_token", "client_id", "client_secret" ) + assert not excinfo.value.retryable + + +@mock.patch("google.oauth2._client._parse_expiry", return_value=None) +@mock.patch.object(_client, "_token_endpoint_request", autospec=True) +def test_jwt_grant_retry_default(mock_token_endpoint_request, mock_expiry): + _client.jwt_grant(mock.Mock(), mock.Mock(), mock.Mock()) + mock_token_endpoint_request.assert_called_with( + mock.ANY, mock.ANY, mock.ANY, can_retry=True + ) + + +@pytest.mark.parametrize("can_retry", [True, False]) +@mock.patch("google.oauth2._client._parse_expiry", return_value=None) +@mock.patch.object(_client, "_token_endpoint_request", autospec=True) +def test_jwt_grant_retry_with_retry( + mock_token_endpoint_request, mock_expiry, can_retry +): + _client.jwt_grant(mock.Mock(), mock.Mock(), mock.Mock(), can_retry=can_retry) + mock_token_endpoint_request.assert_called_with( + mock.ANY, mock.ANY, mock.ANY, can_retry=can_retry + ) + + +@mock.patch("google.auth.jwt.decode", return_value={"exp": 0}) +@mock.patch.object(_client, "_token_endpoint_request", autospec=True) +def test_id_token_jwt_grant_retry_default(mock_token_endpoint_request, mock_jwt_decode): + _client.id_token_jwt_grant(mock.Mock(), mock.Mock(), mock.Mock()) + mock_token_endpoint_request.assert_called_with( + mock.ANY, mock.ANY, mock.ANY, can_retry=True + ) + + +@pytest.mark.parametrize("can_retry", [True, False]) +@mock.patch("google.auth.jwt.decode", return_value={"exp": 0}) +@mock.patch.object(_client, "_token_endpoint_request", autospec=True) +def test_id_token_jwt_grant_retry_with_retry( + mock_token_endpoint_request, mock_jwt_decode, can_retry +): + _client.id_token_jwt_grant( + mock.Mock(), mock.Mock(), mock.Mock(), can_retry=can_retry + ) + mock_token_endpoint_request.assert_called_with( + mock.ANY, mock.ANY, mock.ANY, can_retry=can_retry + ) + + +@mock.patch("google.oauth2._client._parse_expiry", return_value=None) +@mock.patch.object(_client, "_token_endpoint_request", autospec=True) +def test_refresh_grant_retry_default(mock_token_endpoint_request, mock_parse_expiry): + _client.refresh_grant( + mock.Mock(), mock.Mock(), mock.Mock(), mock.Mock(), mock.Mock() + ) + mock_token_endpoint_request.assert_called_with( + mock.ANY, mock.ANY, mock.ANY, can_retry=True + ) + + +@pytest.mark.parametrize("can_retry", [True, False]) +@mock.patch("google.oauth2._client._parse_expiry", return_value=None) +@mock.patch.object(_client, "_token_endpoint_request", autospec=True) +def test_refresh_grant_retry_with_retry( + mock_token_endpoint_request, mock_parse_expiry, can_retry +): + _client.refresh_grant( + mock.Mock(), + mock.Mock(), + mock.Mock(), + mock.Mock(), + mock.Mock(), + can_retry=can_retry, + ) + mock_token_endpoint_request.assert_called_with( + mock.ANY, mock.ANY, mock.ANY, can_retry=can_retry + ) + + +@pytest.mark.parametrize("can_retry", [True, False]) +def test__token_endpoint_request_no_throw_with_retry(can_retry): + response_data = {"error": "help", "error_description": "I'm alive"} + body = "dummy body" + + mock_response = mock.create_autospec(transport.Response, instance=True) + mock_response.status = http_client.INTERNAL_SERVER_ERROR + mock_response.data = json.dumps(response_data).encode("utf-8") + + mock_request = mock.create_autospec(transport.Request) + mock_request.return_value = mock_response + + _client._token_endpoint_request_no_throw( + mock_request, mock.Mock(), body, mock.Mock(), mock.Mock(), can_retry=can_retry + ) + + if can_retry: + assert mock_request.call_count == 4 + else: + assert mock_request.call_count == 1 diff --git a/tests/oauth2/test_reauth.py b/tests/oauth2/test_reauth.py index ae64be009..df0636b18 100644 --- a/tests/oauth2/test_reauth.py +++ b/tests/oauth2/test_reauth.py @@ -260,7 +260,7 @@ def test_refresh_grant_failed(): with mock.patch( "google.oauth2._client._token_endpoint_request_no_throw" ) as mock_token_request: - mock_token_request.return_value = (False, {"error": "Bad request"}) + mock_token_request.return_value = (False, {"error": "Bad request"}, False) with pytest.raises(exceptions.RefreshError) as excinfo: reauth.refresh_grant( MOCK_REQUEST, @@ -273,6 +273,7 @@ def test_refresh_grant_failed(): enable_reauth_refresh=True, ) assert excinfo.match(r"Bad request") + assert not excinfo.value.retryable mock_token_request.assert_called_with( MOCK_REQUEST, "token_uri", @@ -292,8 +293,8 @@ def test_refresh_grant_success(): "google.oauth2._client._token_endpoint_request_no_throw" ) as mock_token_request: mock_token_request.side_effect = [ - (False, {"error": "invalid_grant", "error_subtype": "rapt_required"}), - (True, {"access_token": "access_token"}), + (False, {"error": "invalid_grant", "error_subtype": "rapt_required"}, True), + (True, {"access_token": "access_token"}, None), ] with mock.patch( "google.oauth2.reauth.get_rapt_token", return_value="new_rapt_token" @@ -319,8 +320,8 @@ def test_refresh_grant_reauth_refresh_disabled(): "google.oauth2._client._token_endpoint_request_no_throw" ) as mock_token_request: mock_token_request.side_effect = [ - (False, {"error": "invalid_grant", "error_subtype": "rapt_required"}), - (True, {"access_token": "access_token"}), + (False, {"error": "invalid_grant", "error_subtype": "rapt_required"}, True), + (True, {"access_token": "access_token"}, None), ] with pytest.raises(exceptions.RefreshError) as excinfo: reauth.refresh_grant( diff --git a/tests/test__exponential_backoff.py b/tests/test__exponential_backoff.py new file mode 100644 index 000000000..06a54527e --- /dev/null +++ b/tests/test__exponential_backoff.py @@ -0,0 +1,41 @@ +# Copyright 2022 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import mock + +from google.auth import _exponential_backoff + + +@mock.patch("time.sleep", return_value=None) +def test_exponential_backoff(mock_time): + eb = _exponential_backoff.ExponentialBackoff() + curr_wait = eb._current_wait_in_seconds + iteration_count = 0 + + for attempt in eb: + backoff_interval = mock_time.call_args[0][0] + jitter = curr_wait * eb._randomization_factor + + assert (curr_wait - jitter) <= backoff_interval <= (curr_wait + jitter) + assert attempt == iteration_count + 1 + assert eb.backoff_count == iteration_count + 1 + assert eb._current_wait_in_seconds == eb._multiplier ** (iteration_count + 1) + + curr_wait = eb._current_wait_in_seconds + iteration_count += 1 + + assert eb.total_attempts == _exponential_backoff._DEFAULT_RETRY_TOTAL_ATTEMPTS + assert eb.backoff_count == _exponential_backoff._DEFAULT_RETRY_TOTAL_ATTEMPTS + assert iteration_count == _exponential_backoff._DEFAULT_RETRY_TOTAL_ATTEMPTS + assert mock_time.call_count == _exponential_backoff._DEFAULT_RETRY_TOTAL_ATTEMPTS diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py new file mode 100644 index 000000000..6f542498f --- /dev/null +++ b/tests/test_exceptions.py @@ -0,0 +1,55 @@ +# Copyright 2022 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest # type: ignore + +from google.auth import exceptions # type:ignore + + +@pytest.fixture( + params=[ + exceptions.GoogleAuthError, + exceptions.TransportError, + exceptions.RefreshError, + exceptions.UserAccessTokenError, + exceptions.DefaultCredentialsError, + exceptions.MutualTLSChannelError, + exceptions.OAuthError, + exceptions.ReauthFailError, + exceptions.ReauthSamlChallengeFailError, + ] +) +def retryable_exception(request): + return request.param + + +@pytest.fixture(params=[exceptions.ClientCertError]) +def non_retryable_exception(request): + return request.param + + +def test_default_retryable_exceptions(retryable_exception): + assert not retryable_exception().retryable + + +@pytest.mark.parametrize("retryable", [True, False]) +def test_retryable_exceptions(retryable_exception, retryable): + retryable_exception = retryable_exception(retryable=retryable) + assert retryable_exception.retryable == retryable + + +@pytest.mark.parametrize("retryable", [True, False]) +def test_non_retryable_exceptions(non_retryable_exception, retryable): + non_retryable_exception = non_retryable_exception(retryable=retryable) + assert not non_retryable_exception.retryable diff --git a/tests_async/oauth2/test__client_async.py b/tests_async/oauth2/test__client_async.py index 91874cdd4..402083672 100644 --- a/tests_async/oauth2/test__client_async.py +++ b/tests_async/oauth2/test__client_async.py @@ -29,10 +29,10 @@ from tests.oauth2 import test__client as test_client -def make_request(response_data, status=http_client.OK): +def make_request(response_data, status=http_client.OK, text=False): response = mock.AsyncMock(spec=["transport.Response"]) response.status = status - data = json.dumps(response_data).encode("utf-8") + data = response_data if text else json.dumps(response_data).encode("utf-8") response.data = mock.AsyncMock(spec=["__call__", "read"]) response.data.read = mock.AsyncMock(spec=["__call__"], return_value=data) response.content = mock.AsyncMock(spec=["__call__"], return_value=data) @@ -62,6 +62,27 @@ async def test__token_endpoint_request(): assert result == {"test": "response"} +@pytest.mark.asyncio +async def test__token_endpoint_request_text(): + + request = make_request("response", text=True) + + result = await _client._token_endpoint_request( + request, "http://example.com", {"test": "params"} + ) + + # Check request call + request.assert_called_with( + method="POST", + url="http://example.com", + headers={"Content-Type": "application/x-www-form-urlencoded"}, + body="test=params".encode("utf-8"), + ) + + # Check result + assert result == "response" + + @pytest.mark.asyncio async def test__token_endpoint_request_json(): @@ -95,8 +116,9 @@ async def test__token_endpoint_request_json(): async def test__token_endpoint_request_error(): request = make_request({}, status=http_client.BAD_REQUEST) - with pytest.raises(exceptions.RefreshError): + with pytest.raises(exceptions.RefreshError) as excinfo: await _client._token_endpoint_request(request, "http://example.com", {}) + assert not excinfo.value.retryable @pytest.mark.asyncio @@ -105,10 +127,11 @@ async def test__token_endpoint_request_internal_failure_error(): {"error_description": "internal_failure"}, status=http_client.BAD_REQUEST ) - with pytest.raises(exceptions.RefreshError): + with pytest.raises(exceptions.RefreshError) as excinfo: await _client._token_endpoint_request( request, "http://example.com", {"error_description": "internal_failure"} ) + assert excinfo.value.retryable request = make_request( {"error": "internal_failure"}, status=http_client.BAD_REQUEST @@ -118,6 +141,61 @@ async def test__token_endpoint_request_internal_failure_error(): await _client._token_endpoint_request( request, "http://example.com", {"error": "internal_failure"} ) + assert excinfo.value.retryable + + +@pytest.mark.asyncio +async def test__token_endpoint_request_internal_failure_and_retry_failure_error(): + retryable_error = mock.AsyncMock(spec=["transport.Response"]) + retryable_error.status = http_client.BAD_REQUEST + data = json.dumps({"error_description": "internal_failure"}).encode("utf-8") + retryable_error.data = mock.AsyncMock(spec=["__call__", "read"]) + retryable_error.data.read = mock.AsyncMock(spec=["__call__"], return_value=data) + retryable_error.content = mock.AsyncMock(spec=["__call__"], return_value=data) + + unretryable_error = mock.AsyncMock(spec=["transport.Response"]) + unretryable_error.status = http_client.BAD_REQUEST + data = json.dumps({"error_description": "invalid_scope"}).encode("utf-8") + unretryable_error.data = mock.AsyncMock(spec=["__call__", "read"]) + unretryable_error.data.read = mock.AsyncMock(spec=["__call__"], return_value=data) + unretryable_error.content = mock.AsyncMock(spec=["__call__"], return_value=data) + + request = mock.AsyncMock(spec=["transport.Request"]) + request.side_effect = [retryable_error, retryable_error, unretryable_error] + + with pytest.raises(exceptions.RefreshError): + await _client._token_endpoint_request( + request, "http://example.com", {"error_description": "invalid_scope"} + ) + # request should be called three times. Two retryable errors and one + # unretryable error to break the retry loop. + assert request.call_count == 3 + + +@pytest.mark.asyncio +async def test__token_endpoint_request_internal_failure_and_retry_succeeds(): + retryable_error = mock.AsyncMock(spec=["transport.Response"]) + retryable_error.status = http_client.BAD_REQUEST + data = json.dumps({"error_description": "internal_failure"}).encode("utf-8") + retryable_error.data = mock.AsyncMock(spec=["__call__", "read"]) + retryable_error.data.read = mock.AsyncMock(spec=["__call__"], return_value=data) + retryable_error.content = mock.AsyncMock(spec=["__call__"], return_value=data) + + response = mock.AsyncMock(spec=["transport.Response"]) + response.status = http_client.OK + data = json.dumps({"hello": "world"}).encode("utf-8") + response.data = mock.AsyncMock(spec=["__call__", "read"]) + response.data.read = mock.AsyncMock(spec=["__call__"], return_value=data) + response.content = mock.AsyncMock(spec=["__call__"], return_value=data) + + request = mock.AsyncMock(spec=["transport.Request"]) + request.side_effect = [retryable_error, response] + + _ = await _client._token_endpoint_request( + request, "http://example.com", {"test": "params"} + ) + + assert request.call_count == 2 def verify_request_params(request, params): @@ -128,8 +206,8 @@ def verify_request_params(request, params): assert request_params[key][0] == value -@mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) @pytest.mark.asyncio +@mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) async def test_jwt_grant(utcnow): request = make_request( {"access_token": "token", "expires_in": 500, "extra": "data"} @@ -161,8 +239,9 @@ async def test_jwt_grant_no_access_token(): } ) - with pytest.raises(exceptions.RefreshError): + with pytest.raises(exceptions.RefreshError) as excinfo: await _client.jwt_grant(request, "http://example.com", "assertion_value") + assert not excinfo.value.retryable @pytest.mark.asyncio @@ -200,14 +279,15 @@ async def test_id_token_jwt_grant_no_access_token(): } ) - with pytest.raises(exceptions.RefreshError): + with pytest.raises(exceptions.RefreshError) as excinfo: await _client.id_token_jwt_grant( request, "http://example.com", "assertion_value" ) + assert not excinfo.value.retryable -@mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) @pytest.mark.asyncio +@mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) async def test_refresh_grant(unused_utcnow): request = make_request( { @@ -246,8 +326,8 @@ async def test_refresh_grant(unused_utcnow): assert extra_data["extra"] == "data" -@mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) @pytest.mark.asyncio +@mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) async def test_refresh_grant_with_scopes(unused_utcnow): request = make_request( { @@ -298,7 +378,121 @@ async def test_refresh_grant_no_access_token(): } ) - with pytest.raises(exceptions.RefreshError): + with pytest.raises(exceptions.RefreshError) as excinfo: await _client.refresh_grant( request, "http://example.com", "refresh_token", "client_id", "client_secret" ) + assert not excinfo.value.retryable + + +@pytest.mark.asyncio +@mock.patch("google.oauth2._client._parse_expiry", return_value=None) +@mock.patch.object(_client, "_token_endpoint_request", autospec=True) +async def test_jwt_grant_retry_default(mock_token_endpoint_request, mock_expiry): + _ = await _client.jwt_grant(mock.Mock(), mock.Mock(), mock.Mock()) + mock_token_endpoint_request.assert_called_with( + mock.ANY, mock.ANY, mock.ANY, can_retry=True + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("can_retry", [True, False]) +@mock.patch("google.oauth2._client._parse_expiry", return_value=None) +@mock.patch.object(_client, "_token_endpoint_request", autospec=True) +async def test_jwt_grant_retry_with_retry( + mock_token_endpoint_request, mock_expiry, can_retry +): + _ = await _client.jwt_grant( + mock.AsyncMock(), mock.Mock(), mock.Mock(), can_retry=can_retry + ) + mock_token_endpoint_request.assert_called_with( + mock.ANY, mock.ANY, mock.ANY, can_retry=can_retry + ) + + +@pytest.mark.asyncio +@mock.patch("google.auth.jwt.decode", return_value={"exp": 0}) +@mock.patch.object(_client, "_token_endpoint_request", autospec=True) +async def test_id_token_jwt_grant_retry_default( + mock_token_endpoint_request, mock_jwt_decode +): + _ = await _client.id_token_jwt_grant(mock.Mock(), mock.Mock(), mock.Mock()) + mock_token_endpoint_request.assert_called_with( + mock.ANY, mock.ANY, mock.ANY, can_retry=True + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("can_retry", [True, False]) +@mock.patch("google.auth.jwt.decode", return_value={"exp": 0}) +@mock.patch.object(_client, "_token_endpoint_request", autospec=True) +async def test_id_token_jwt_grant_retry_with_retry( + mock_token_endpoint_request, mock_jwt_decode, can_retry +): + _ = await _client.id_token_jwt_grant( + mock.AsyncMock(), mock.AsyncMock(), mock.AsyncMock(), can_retry=can_retry + ) + mock_token_endpoint_request.assert_called_with( + mock.ANY, mock.ANY, mock.ANY, can_retry=can_retry + ) + + +@pytest.mark.asyncio +@mock.patch("google.oauth2._client._parse_expiry", return_value=None) +@mock.patch.object(_client, "_token_endpoint_request", autospec=True) +async def test_refresh_grant_retry_default( + mock_token_endpoint_request, mock_parse_expiry +): + _ = await _client.refresh_grant( + mock.AsyncMock(), + mock.AsyncMock(), + mock.AsyncMock(), + mock.AsyncMock(), + mock.AsyncMock(), + ) + mock_token_endpoint_request.assert_called_with( + mock.ANY, mock.ANY, mock.ANY, can_retry=True + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("can_retry", [True, False]) +@mock.patch("google.oauth2._client._parse_expiry", return_value=None) +@mock.patch.object(_client, "_token_endpoint_request", autospec=True) +async def test_refresh_grant_retry_with_retry( + mock_token_endpoint_request, mock_parse_expiry, can_retry +): + _ = await _client.refresh_grant( + mock.AsyncMock(), + mock.AsyncMock(), + mock.AsyncMock(), + mock.AsyncMock(), + mock.AsyncMock(), + can_retry=can_retry, + ) + mock_token_endpoint_request.assert_called_with( + mock.ANY, mock.ANY, mock.ANY, can_retry=can_retry + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("can_retry", [True, False]) +async def test__token_endpoint_request_no_throw_with_retry(can_retry): + mock_request = make_request( + {"error": "help", "error_description": "I'm alive"}, + http_client.INTERNAL_SERVER_ERROR, + ) + + _ = await _client._token_endpoint_request_no_throw( + mock_request, + mock.AsyncMock(), + "body", + mock.AsyncMock(), + mock.AsyncMock(), + can_retry=can_retry, + ) + + if can_retry: + assert mock_request.call_count == 4 + else: + assert mock_request.call_count == 1 diff --git a/tests_async/oauth2/test_reauth_async.py b/tests_async/oauth2/test_reauth_async.py index 8f51bd3a7..40ca92717 100644 --- a/tests_async/oauth2/test_reauth_async.py +++ b/tests_async/oauth2/test_reauth_async.py @@ -279,7 +279,7 @@ async def test_refresh_grant_failed(): with mock.patch( "google.oauth2._client_async._token_endpoint_request_no_throw" ) as mock_token_request: - mock_token_request.return_value = (False, {"error": "Bad request"}) + mock_token_request.return_value = (False, {"error": "Bad request"}, True) with pytest.raises(exceptions.RefreshError) as excinfo: await _reauth_async.refresh_grant( MOCK_REQUEST, @@ -291,6 +291,7 @@ async def test_refresh_grant_failed(): rapt_token="rapt_token", ) assert excinfo.match(r"Bad request") + assert excinfo.value.retryable mock_token_request.assert_called_with( MOCK_REQUEST, "token_uri", @@ -311,8 +312,8 @@ async def test_refresh_grant_success(): "google.oauth2._client_async._token_endpoint_request_no_throw" ) as mock_token_request: mock_token_request.side_effect = [ - (False, {"error": "invalid_grant", "error_subtype": "rapt_required"}), - (True, {"access_token": "access_token"}), + (False, {"error": "invalid_grant", "error_subtype": "rapt_required"}, True), + (True, {"access_token": "access_token"}, None), ] with mock.patch( "google.oauth2._reauth_async.get_rapt_token", return_value="new_rapt_token" @@ -339,11 +340,16 @@ async def test_refresh_grant_reauth_refresh_disabled(): "google.oauth2._client_async._token_endpoint_request_no_throw" ) as mock_token_request: mock_token_request.side_effect = [ - (False, {"error": "invalid_grant", "error_subtype": "rapt_required"}), - (True, {"access_token": "access_token"}), + ( + False, + {"error": "invalid_grant", "error_subtype": "rapt_required"}, + False, + ), + (True, {"access_token": "access_token"}, None), ] with pytest.raises(exceptions.RefreshError) as excinfo: assert await _reauth_async.refresh_grant( MOCK_REQUEST, "token_uri", "refresh_token", "client_id", "client_secret" ) assert excinfo.match(r"Reauthentication is needed") + assert not excinfo.value.retryable