Skip to content

Commit

Permalink
Add decorator for suppress optional internal methods in Amazon Provid…
Browse files Browse the repository at this point in the history
…er (#34034)
  • Loading branch information
Taragolis committed Sep 3, 2023
1 parent 4f20b0f commit bf2d411
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 38 deletions.
63 changes: 25 additions & 38 deletions airflow/providers/amazon/aws/hooks/base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import json
import logging
import os
import uuid
import warnings
from copy import deepcopy
from functools import cached_property, wraps
Expand All @@ -56,6 +55,8 @@
)
from airflow.hooks.base import BaseHook
from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
from airflow.providers.amazon.aws.utils.identifiers import generate_uuid
from airflow.providers.amazon.aws.utils.suppress import return_on_error
from airflow.providers_manager import ProvidersManager
from airflow.utils.helpers import exactly_one
from airflow.utils.log.logging_mixin import LoggingMixin
Expand Down Expand Up @@ -471,21 +472,17 @@ def __init__(
self._verify = verify

@classmethod
@return_on_error("Unknown")
def _get_provider_version(cls) -> str:
"""Check the Providers Manager for the package version."""
try:
manager = ProvidersManager()
hook = manager.hooks[cls.conn_type]
if not hook:
# This gets caught immediately, but without it MyPy complains
# Item "None" of "Optional[HookInfo]" has no attribute "package_name"
# on the following line and static checks fail.
raise ValueError(f"Hook info for {cls.conn_type} not found in the Provider Manager.")
provider = manager.providers[hook.package_name]
return provider.version
except Exception:
# Under no condition should an error here ever cause an issue for the user.
return "Unknown"
manager = ProvidersManager()
hook = manager.hooks[cls.conn_type]
if not hook:
# This gets caught immediately, but without it MyPy complains
# Item "None" of "Optional[HookInfo]" has no attribute "package_name"
# on the following line and static checks fail.
raise ValueError(f"Hook info for {cls.conn_type} not found in the Provider Manager.")
return manager.providers[hook.package_name].version

@staticmethod
def _find_class_name(target_function_name: str) -> str:
Expand All @@ -505,19 +502,17 @@ def _find_class_name(target_function_name: str) -> str:
# Return the name of the class object.
return frame_class_object.__name__

@return_on_error("Unknown")
def _get_caller(self, target_function_name: str = "execute") -> str:
"""Given a function name, walk the stack and return the name of the class which called it last."""
try:
caller = self._find_class_name(target_function_name)
if caller == "BaseSensorOperator":
# If the result is a BaseSensorOperator, then look for whatever last called "poke".
return self._get_caller("poke")
return caller
except Exception:
# Under no condition should an error here ever cause an issue for the user.
return "Unknown"
caller = self._find_class_name(target_function_name)
if caller == "BaseSensorOperator":
# If the result is a BaseSensorOperator, then look for whatever last called "poke".
return self._get_caller("poke")
return caller

@staticmethod
@return_on_error("00000000-0000-0000-0000-000000000000")
def _generate_dag_key() -> str:
"""Generate a DAG key.
Expand All @@ -526,25 +521,17 @@ def _generate_dag_key() -> str:
can not (reasonably) be reversed. No personal data can be inferred or
extracted from the resulting UUID.
"""
try:
dag_id = os.environ["AIRFLOW_CTX_DAG_ID"]
return str(uuid.uuid5(uuid.NAMESPACE_OID, dag_id))
except Exception:
# Under no condition should an error here ever cause an issue for the user.
return "00000000-0000-0000-0000-000000000000"
return generate_uuid(os.environ.get("AIRFLOW_CTX_DAG_ID"))

@staticmethod
@return_on_error("Unknown")
def _get_airflow_version() -> str:
"""Fetch and return the current Airflow version."""
try:
# This can be a circular import under specific configurations.
# Importing locally to either avoid or catch it if it does happen.
from airflow import __version__ as airflow_version

return airflow_version
except Exception:
# Under no condition should an error here ever cause an issue for the user.
return "Unknown"
# This can be a circular import under specific configurations.
# Importing locally to either avoid or catch it if it does happen.
from airflow import __version__ as airflow_version

return airflow_version

def _generate_user_agent_extra_field(self, existing_user_agent_extra: str) -> str:
user_agent_extra_values = [
Expand Down
74 changes: 74 additions & 0 deletions airflow/providers/amazon/aws/utils/suppress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Module for suppress errors in Amazon Provider.
.. warning::
Only for internal usage, this module might be changed or removed in the future
without any further notice.
:meta: private
"""

from __future__ import annotations

import logging
from functools import wraps
from typing import Callable, TypeVar

from airflow.typing_compat import ParamSpec

PS = ParamSpec("PS")
RT = TypeVar("RT")

log = logging.getLogger(__name__)


def return_on_error(return_value: RT):
"""
Helper decorator which suppress any ``Exception`` raised in decorator function.
Main use-case when functional is optional, however any error on functions/methods might
raise any error which are subclass of ``Exception``.
.. note::
Decorator doesn't intend to catch ``BaseException``,
e.g. ``GeneratorExit``, ``KeyboardInterrupt``, ``SystemExit`` and others.
.. warning::
Only for internal usage, this decorator might be changed or removed in the future
without any further notice.
:param return_value: Return value if decorated function/method raise any ``Exception``.
:meta: private
"""

def decorator(func: Callable[PS, RT]) -> Callable[PS, RT]:
@wraps(func)
def wrapper(*args, **kwargs) -> RT:
try:
return func(*args, **kwargs)
except Exception:
log.debug(
"Encountered error during execution function/method %r", func.__name__, exc_info=True
)
return return_value

return wrapper

return decorator
80 changes: 80 additions & 0 deletions tests/providers/amazon/aws/utils/test_suppress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import pytest

from airflow.providers.amazon.aws.utils.suppress import return_on_error


def test_suppress_function(caplog):
@return_on_error("error")
def fn(value: str, exc: Exception | None = None) -> str:
if exc:
raise exc
return value

caplog.set_level("DEBUG", "airflow.providers.amazon.aws.utils.suppress")
caplog.clear()

assert fn("no-error") == "no-error"
assert not caplog.messages

assert fn("foo", ValueError("boooo")) == "error"
assert "Encountered error during execution function/method 'fn'" in caplog.messages

caplog.clear()
with pytest.raises(SystemExit, match="42"):
# We do not plan to catch exception which only based on `BaseExceptions`
fn("bar", SystemExit(42))
assert not caplog.messages

# We catch even serious exceptions, e.g. we do not provide mandatory argument here
assert fn() == "error"
assert "Encountered error during execution function/method 'fn'" in caplog.messages


def test_suppress_methods():
class FakeClass:
@return_on_error("Oops!… I Did It Again")
def some_method(self, value, exc: Exception | None = None) -> str:
if exc:
raise exc
return value

@staticmethod
@return_on_error(0)
def some_staticmethod(value, exc: Exception | None = None) -> int:
if exc:
raise exc
return value

@classmethod
@return_on_error("It's fine")
def some_classmethod(cls, value, exc: Exception | None = None) -> str:
if exc:
raise exc
return value

assert FakeClass().some_method("no-error") == "no-error"
assert FakeClass.some_staticmethod(42) == 42
assert FakeClass.some_classmethod("really-no-error-here") == "really-no-error-here"

assert FakeClass().some_method("foo", KeyError("foo")) == "Oops!… I Did It Again"
assert FakeClass.some_staticmethod(42, RuntimeError("bar")) == 0
assert FakeClass.some_classmethod("bar", OSError("Windows detected!")) == "It's fine"

0 comments on commit bf2d411

Please sign in to comment.