Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Testing: Add type annotations to common/logging #6714

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
91 changes: 54 additions & 37 deletions lib/rucio/common/logging.py
Expand Up @@ -20,37 +20,56 @@
import sys
from collections.abc import Callable, Iterator, Mapping, Sequence
from traceback import format_tb
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, get_args

from rucio.common.config import config_get, config_get_bool

if TYPE_CHECKING:
from logging import LogRecord
from logging import LogRecord, _SysExcInfoType

from _typeshed import OptExcInfo
from flask import Flask


# Mapping from ECS field paths
# https://www.elastic.co/guide/en/ecs-logging/overview/current/intro.html#_field_mapping
# https://www.elastic.co/guide/en/ecs/8.5/ecs-field-reference.html
# to python log record attributes:
# https://docs.python.org/3/library/logging.html#logrecord-attributes
BUILTIN_FIELDS = (
('@timestamp', 'asctime'),
('message', 'message'),
('log.level', 'levelname'),
('log.origin.function', 'funcName'),
('log.origin.file.line', 'lineno'),
('log.origin.file.name', 'filename'),
('log.logger', 'name'),
('process.pid', 'process'),
('process.name', 'processName'),
('process.thread.id', 'thread'),
('process.thread.name', 'threadName'),
)
ECS_TO_LOG_RECORD_MAP = dict(BUILTIN_FIELDS)
LOG_RECORD_TO_ECS_MAP = dict((f[1], f[0]) for f in BUILTIN_FIELDS)


def _json_serializable(obj: Any):
ECS_FIELDS = Literal[
'@timestamp',
'message',
'log.level',
'log.origin.function',
'log.origin.file.line',
'log.origin.file.name',
'log.logger',
'process.pid',
'process.name',
'process.thread.id',
'process.thread.name'
]

LOG_RECORDS = Literal[
'asctime',
'message',
'levelname',
'funcName',
'lineno',
'filename',
'name',
'process',
'processName',
'thread',
'threadName'
]

BUILTIN_FIELDS: tuple[tuple[ECS_FIELDS, LOG_RECORDS], ...] = tuple((x, y) for x, y in zip(get_args(ECS_FIELDS), get_args(LOG_RECORDS)))
ECS_TO_LOG_RECORD_MAP: dict[ECS_FIELDS, LOG_RECORDS] = dict(BUILTIN_FIELDS)
LOG_RECORD_TO_ECS_MAP: dict[LOG_RECORDS, ECS_FIELDS] = dict((f[1], f[0]) for f in BUILTIN_FIELDS)


def _json_serializable(obj: Any) -> Union[dict[Any, Any], str]:
try:
return obj.__dict__
except AttributeError:
Expand Down Expand Up @@ -160,11 +179,11 @@ def _timestamp_formatter(record_formatter: "LogDataSource", record: "LogRecord")
yield record_formatter.ecs_fields[0], datetime.datetime.utcfromtimestamp(record.created).isoformat(timespec='milliseconds') + 'Z'


def _ecs_field_to_record_attribute(field_name):
def _ecs_field_to_record_attribute(field_name: Union[ECS_FIELDS, str]) -> Union[LOG_RECORDS, str]:
"""
Sanitize the path-like field name into a symbol which can be the name of an object attribute.
"""
record = ECS_TO_LOG_RECORD_MAP.get(field_name)
record = ECS_TO_LOG_RECORD_MAP.get(field_name) # type: ignore
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type: ignore here suppresses the type-checker error reported because field_name may be a str that is not in ECS_FIELDS; in that case the lookup finds nothing and the function returns field_name.replace('-', '_').replace('.', '_').

I personally don't like this logic being in this function, as the function name doesn't say anything about replacing those symbols within a string, and from the function name _ecs_field_to_record_attribute you would assume that this only takes in ECS_FIELDS and returns LOG_RECORDS.

I think it would be a good idea to change the way this works in the future.

if record:
return record
return field_name.replace('-', '_').replace('.', '_')
Expand Down Expand Up @@ -195,7 +214,7 @@ def __eq__(self, other: Any):
def __str__(self):
return self.__class__.__name__ + '(' + ', '.join(self.ecs_fields) + ')'

def format(self, record: "LogRecord"):
def format(self, record: "LogRecord") -> Optional[Iterator[tuple[str, Any]]]:
if not self._formatter:
return
for field_name, field_value in self._formatter(self, record):
Expand All @@ -212,7 +231,7 @@ def __init__(self):
)

@staticmethod
def _get_exc_info(record):
def _get_exc_info(record: "LogRecord") -> Optional[Union["OptExcInfo", "_SysExcInfoType"]]:
exc_info = record.exc_info
if not exc_info:
return None
Expand All @@ -222,7 +241,7 @@ def _get_exc_info(record):
return exc_info
return None

def format(self, record: "LogRecord"):
def format(self, record: "LogRecord") -> Iterator[tuple[str, Optional[str]]]:
exc_info = self._get_exc_info(record)
message = record.getMessage()
error_type, error_message, stack_trace = None, None, None
Expand Down Expand Up @@ -253,11 +272,11 @@ class ConstantStrDataSource(LogDataSource):
Prints a constant string for the given ECS field.
"""

def __init__(self, ecs_field, _str):
def __init__(self, ecs_field: ECS_FIELDS, _str: str):
log_record = ECS_TO_LOG_RECORD_MAP.get(ecs_field, None)
self._str = _str

def _formatter(data_source: LogDataSource, record: "LogRecord"):
def _formatter(data_source: LogDataSource, record: "LogRecord") -> Iterator[tuple[str, str]]:
yield self.ecs_fields[0], self._str

super().__init__(ecs_fields=(ecs_field,), formatter=_formatter, dst_record_attr=log_record)
Expand All @@ -284,7 +303,7 @@ def __init__(
fmt: Optional[str] = None,
validate: Optional[bool] = None,
output_json: bool = False,
additional_fields: Optional[Mapping[str, str]] = None
additional_fields: Optional[Mapping[ECS_FIELDS, str]] = None
):
_kwargs = {}
if validate is not None:
Expand Down Expand Up @@ -344,15 +363,15 @@ def __init__(
self.output_json = output_json
super().__init__(fmt=fmt, style='%', **_kwargs)

def format(self, record):
json_record = dict(itertools.chain.from_iterable(f.format(record) for f in self._desired_data_sources))
def format(self, record: "LogRecord") -> str:
json_record = dict(itertools.chain.from_iterable(f.format(record) for f in self._desired_data_sources)) # type: ignore
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type: ignore here is due to the fact that LogDataSource.format returns None if there is no self._formatter, and the static type checker is not able to guarantee that f has a _formatter specified in this case.

This logic should likely be changed so that it only runs LogDataSource.format if _formatter exists, and otherwise returns None, but I think it's best to just use type: ignore for now and address this later.

if self.output_json:
return self._to_json(_unflatten_dict(json_record))
else:
return super().format(record)

@staticmethod
def _to_json(record):
def _to_json(record: dict[str, Any]) -> str:
try:
return json.dumps(record, default=_json_serializable)
except (TypeError, ValueError, OverflowError):
Expand All @@ -362,7 +381,7 @@ def _to_json(record):
return '{}'


def rucio_log_formatter(process_name: Optional[str] = None):
def rucio_log_formatter(process_name: Optional[str] = None) -> RucioFormatter:
config_logformat = config_get('common', 'logformat', raise_exception=False, default='%(asctime)s\t%(name)s\t%(process)d\t%(levelname)s\t%(message)s')
output_json = config_get_bool('common', 'logjson', default=False)
additional_fields = {}
Expand All @@ -371,7 +390,7 @@ def rucio_log_formatter(process_name: Optional[str] = None):
return RucioFormatter(fmt=config_logformat, output_json=output_json, additional_fields=additional_fields)


def setup_logging(application=None, process_name=None):
def setup_logging(application: Optional["Flask"] = None, process_name: Optional[str] = None) -> None:
"""
Configures the logging by setting the output stream to stdout and
configures log level and log format.
Expand All @@ -387,17 +406,15 @@ def setup_logging(application=None, process_name=None):
application.logger.addHandler(stdouthandler)


def formatted_logger(innerfunc, formatstr="%s"):
def formatted_logger(innerfunc: Callable, formatstr: str = "%s") -> Callable:
"""
Decorates the passed function, formatting log input by
the passed formatstr. The format string must always include a %s.

:param innerfunc: function to be decorated. Must take (level, msg) arguments.
:type innerfunc: Callable
:param formatstr: format string with %s as placeholder.
:type formatstr: str
"""
@functools.wraps(innerfunc)
def log_format(level, msg, *args, **kwargs):
def log_format(level: int, msg: object, *args, **kwargs) -> Callable:
return innerfunc(level, formatstr % msg, *args, **kwargs)
return log_format