New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Testing: Add type annotations to common/logging #6714
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,37 +20,56 @@ | |
import sys | ||
from collections.abc import Callable, Iterator, Mapping, Sequence | ||
from traceback import format_tb | ||
from typing import TYPE_CHECKING, Any, Optional | ||
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, get_args | ||
|
||
from rucio.common.config import config_get, config_get_bool | ||
|
||
if TYPE_CHECKING: | ||
from logging import LogRecord | ||
from logging import LogRecord, _SysExcInfoType | ||
|
||
from _typeshed import OptExcInfo | ||
from flask import Flask | ||
|
||
|
||
# Mapping from ECS field paths | ||
# https://www.elastic.co/guide/en/ecs-logging/overview/current/intro.html#_field_mapping | ||
# https://www.elastic.co/guide/en/ecs/8.5/ecs-field-reference.html | ||
# to python log record attributes: | ||
# https://docs.python.org/3/library/logging.html#logrecord-attributes | ||
BUILTIN_FIELDS = ( | ||
('@timestamp', 'asctime'), | ||
('message', 'message'), | ||
('log.level', 'levelname'), | ||
('log.origin.function', 'funcName'), | ||
('log.origin.file.line', 'lineno'), | ||
('log.origin.file.name', 'filename'), | ||
('log.logger', 'name'), | ||
('process.pid', 'process'), | ||
('process.name', 'processName'), | ||
('process.thread.id', 'thread'), | ||
('process.thread.name', 'threadName'), | ||
) | ||
ECS_TO_LOG_RECORD_MAP = dict(BUILTIN_FIELDS) | ||
LOG_RECORD_TO_ECS_MAP = dict((f[1], f[0]) for f in BUILTIN_FIELDS) | ||
|
||
|
||
def _json_serializable(obj: Any): | ||
ECS_FIELDS = Literal[ | ||
'@timestamp', | ||
'message', | ||
'log.level', | ||
'log.origin.function', | ||
'log.origin.file.line', | ||
'log.origin.file.name', | ||
'log.logger', | ||
'process.pid', | ||
'process.name', | ||
'process.thread.id', | ||
'process.thread.name' | ||
] | ||
|
||
LOG_RECORDS = Literal[ | ||
'asctime', | ||
'message', | ||
'levelname', | ||
'funcName', | ||
'lineno', | ||
'filename', | ||
'name', | ||
'process', | ||
'processName', | ||
'thread', | ||
'threadName' | ||
] | ||
|
||
BUILTIN_FIELDS: tuple[tuple[ECS_FIELDS, LOG_RECORDS], ...] = tuple((x, y) for x, y in zip(get_args(ECS_FIELDS), get_args(LOG_RECORDS))) | ||
ECS_TO_LOG_RECORD_MAP: dict[ECS_FIELDS, LOG_RECORDS] = dict(BUILTIN_FIELDS) | ||
LOG_RECORD_TO_ECS_MAP: dict[LOG_RECORDS, ECS_FIELDS] = dict((f[1], f[0]) for f in BUILTIN_FIELDS) | ||
|
||
|
||
def _json_serializable(obj: Any) -> Union[dict[Any, Any], str]: | ||
try: | ||
return obj.__dict__ | ||
except AttributeError: | ||
|
@@ -160,11 +179,11 @@ def _timestamp_formatter(record_formatter: "LogDataSource", record: "LogRecord") | |
yield record_formatter.ecs_fields[0], datetime.datetime.utcfromtimestamp(record.created).isoformat(timespec='milliseconds') + 'Z' | ||
|
||
|
||
def _ecs_field_to_record_attribute(field_name): | ||
def _ecs_field_to_record_attribute(field_name: Union[ECS_FIELDS, str]) -> Union[LOG_RECORDS, str]: | ||
""" | ||
Sanitize the path-like field name into a symbol which can be the name of an object attribute. | ||
""" | ||
record = ECS_TO_LOG_RECORD_MAP.get(field_name) | ||
record = ECS_TO_LOG_RECORD_MAP.get(field_name) # type: ignore | ||
if record: | ||
return record | ||
return field_name.replace('-', '_').replace('.', '_') | ||
|
@@ -195,7 +214,7 @@ def __eq__(self, other: Any): | |
def __str__(self): | ||
return self.__class__.__name__ + '(' + ', '.join(self.ecs_fields) + ')' | ||
|
||
def format(self, record: "LogRecord"): | ||
def format(self, record: "LogRecord") -> Optional[Iterator[tuple[str, Any]]]: | ||
if not self._formatter: | ||
return | ||
for field_name, field_value in self._formatter(self, record): | ||
|
@@ -212,7 +231,7 @@ def __init__(self): | |
) | ||
|
||
@staticmethod | ||
def _get_exc_info(record): | ||
def _get_exc_info(record: "LogRecord") -> Optional[Union["OptExcInfo", "_SysExcInfoType"]]: | ||
exc_info = record.exc_info | ||
if not exc_info: | ||
return None | ||
|
@@ -222,7 +241,7 @@ def _get_exc_info(record): | |
return exc_info | ||
return None | ||
|
||
def format(self, record: "LogRecord"): | ||
def format(self, record: "LogRecord") -> Iterator[tuple[str, Optional[str]]]: | ||
exc_info = self._get_exc_info(record) | ||
message = record.getMessage() | ||
error_type, error_message, stack_trace = None, None, None | ||
|
@@ -253,11 +272,11 @@ class ConstantStrDataSource(LogDataSource): | |
Prints a constant string for the given ECS field. | ||
""" | ||
|
||
def __init__(self, ecs_field, _str): | ||
def __init__(self, ecs_field: ECS_FIELDS, _str: str): | ||
log_record = ECS_TO_LOG_RECORD_MAP.get(ecs_field, None) | ||
self._str = _str | ||
|
||
def _formatter(data_source: LogDataSource, record: "LogRecord"): | ||
def _formatter(data_source: LogDataSource, record: "LogRecord") -> Iterator[tuple[str, str]]: | ||
yield self.ecs_fields[0], self._str | ||
|
||
super().__init__(ecs_fields=(ecs_field,), formatter=_formatter, dst_record_attr=log_record) | ||
|
@@ -284,7 +303,7 @@ def __init__( | |
fmt: Optional[str] = None, | ||
validate: Optional[bool] = None, | ||
output_json: bool = False, | ||
additional_fields: Optional[Mapping[str, str]] = None | ||
additional_fields: Optional[Mapping[ECS_FIELDS, str]] = None | ||
): | ||
_kwargs = {} | ||
if validate is not None: | ||
|
@@ -344,15 +363,15 @@ def __init__( | |
self.output_json = output_json | ||
super().__init__(fmt=fmt, style='%', **_kwargs) | ||
|
||
def format(self, record): | ||
json_record = dict(itertools.chain.from_iterable(f.format(record) for f in self._desired_data_sources)) | ||
def format(self, record: "LogRecord") -> str: | ||
json_record = dict(itertools.chain.from_iterable(f.format(record) for f in self._desired_data_sources)) # type: ignore | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The […] This logic should likely be changed so that it only runs […] |
||
if self.output_json: | ||
return self._to_json(_unflatten_dict(json_record)) | ||
else: | ||
return super().format(record) | ||
|
||
@staticmethod | ||
def _to_json(record): | ||
def _to_json(record: dict[str, Any]) -> str: | ||
try: | ||
return json.dumps(record, default=_json_serializable) | ||
except (TypeError, ValueError, OverflowError): | ||
|
@@ -362,7 +381,7 @@ def _to_json(record): | |
return '{}' | ||
|
||
|
||
def rucio_log_formatter(process_name: Optional[str] = None): | ||
def rucio_log_formatter(process_name: Optional[str] = None) -> RucioFormatter: | ||
config_logformat = config_get('common', 'logformat', raise_exception=False, default='%(asctime)s\t%(name)s\t%(process)d\t%(levelname)s\t%(message)s') | ||
output_json = config_get_bool('common', 'logjson', default=False) | ||
additional_fields = {} | ||
|
@@ -371,7 +390,7 @@ def rucio_log_formatter(process_name: Optional[str] = None): | |
return RucioFormatter(fmt=config_logformat, output_json=output_json, additional_fields=additional_fields) | ||
|
||
|
||
def setup_logging(application=None, process_name=None): | ||
def setup_logging(application: Optional["Flask"] = None, process_name: Optional[str] = None) -> None: | ||
""" | ||
Configures the logging by setting the output stream to stdout and | ||
configures log level and log format. | ||
|
@@ -387,17 +406,15 @@ def setup_logging(application=None, process_name=None): | |
application.logger.addHandler(stdouthandler) | ||
|
||
|
||
def formatted_logger(innerfunc, formatstr="%s"): | ||
def formatted_logger(innerfunc: Callable, formatstr: str = "%s") -> Callable: | ||
""" | ||
Decorates the passed function, formatting log input by | ||
the passed formatstr. The format string must always include a %s. | ||
|
||
:param innerfunc: function to be decorated. Must take (level, msg) arguments. | ||
:type innerfunc: Callable | ||
:param formatstr: format string with %s as placeholder. | ||
:type formatstr: str | ||
""" | ||
@functools.wraps(innerfunc) | ||
def log_format(level, msg, *args, **kwargs): | ||
def log_format(level: int, msg: object, *args, **kwargs) -> Callable: | ||
return innerfunc(level, formatstr % msg, *args, **kwargs) | ||
return log_format |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `type: ignore` here is to ignore the error that occurs when a `str` that's not in `ECS_FIELDS` is passed to `_ecs_field_to_record_attribute`, in which case the function returns `field_name.replace('-', '_').replace('.', '_')`.
I personally don't like this logic being in this function, as the function name doesn't say anything about replacing those symbols within a string, and from the function name `_ecs_field_to_record_attribute` you would assume that it only takes in `ECS_FIELDS` and returns `LOG_RECORDS`.
I think it would be a good idea to change the way this works in the future.