Skip to content

Commit

Permalink
Merge pull request numpy#24014 from seberg/errstate-contextvar
Browse files Browse the repository at this point in the history
BUG: Make errstate decorator compatible with threading
  • Loading branch information
seberg committed Jun 27, 2023
2 parents cbc4263 + 3a4bbda commit c058561
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 14 deletions.
12 changes: 4 additions & 8 deletions numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import datetime as dt
import enum
from abc import abstractmethod
from types import TracebackType, MappingProxyType, GenericAlias
from contextlib import ContextDecorator
from contextlib import contextmanager

from numpy._pytesttester import PytestTester
Expand Down Expand Up @@ -3319,17 +3318,13 @@ class _CopyMode(enum.Enum):
# Warnings
class RankWarning(UserWarning): ...

_CallType = TypeVar("_CallType", bound=_ErrFunc | _SupportsWrite[str])
_CallType = TypeVar("_CallType", bound=Callable[..., Any])

class errstate(Generic[_CallType], ContextDecorator):
call: _CallType
kwargs: _ErrDictOptional

# Expand `**kwargs` into explicit keyword-only arguments
class errstate:
def __init__(
self,
*,
call: _CallType = ...,
call: _ErrFunc | _SupportsWrite[str] = ...,
all: None | _ErrKind = ...,
divide: None | _ErrKind = ...,
over: None | _ErrKind = ...,
Expand All @@ -3344,6 +3339,7 @@ class errstate(Generic[_CallType], ContextDecorator):
traceback: None | TracebackType,
/,
) -> None: ...
def __call__(self, func: _CallType) -> _CallType: ...

@contextmanager
def _no_nep50_warning() -> Generator[None, None, None]: ...
Expand Down
40 changes: 36 additions & 4 deletions numpy/core/_ufunc_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import collections.abc
import contextlib
import contextvars
import functools

from .._utils import set_module
from .umath import _make_extobj, _get_extobj_dict, _extobj_contextvar
Expand Down Expand Up @@ -330,7 +331,7 @@ class _unspecified:


@set_module('numpy')
class errstate(contextlib.ContextDecorator):
class errstate:
"""
errstate(**kwargs)
Expand All @@ -344,7 +345,11 @@ class errstate(contextlib.ContextDecorator):
.. versionchanged:: 1.17.0
`errstate` is also usable as a function decorator, saving
a level of indentation if an entire function is wrapped.
See :py:class:`contextlib.ContextDecorator` for more information.
.. versionchanged:: 2.0
`errstate` is now fully thread and asyncio safe, but may not be
entered more than once (unless sequentially).
It is not safe to decorate async functions using ``errstate``.
Parameters
----------
Expand Down Expand Up @@ -388,8 +393,8 @@ class errstate(contextlib.ContextDecorator):
>>> olderr = np.seterr(**olderr) # restore original state
"""
__slots__ = [
"_call", "_all", "_divide", "_over", "_under", "_invalid", "_token"]
__slots__ = (
"_call", "_all", "_divide", "_over", "_under", "_invalid", "_token")

def __init__(self, *, call=_Unspecified,
all=None, divide=None, over=None, under=None, invalid=None):
Expand All @@ -402,6 +407,7 @@ def __init__(self, *, call=_Unspecified,
self._invalid = invalid

def __enter__(self):
# Note that __call__ duplicates much of this logic
if self._token is not None:
raise TypeError("Cannot enter `np.errstate` twice.")
if self._call is _Unspecified:
Expand All @@ -421,6 +427,32 @@ def __exit__(self, *exc_info):
# Allow entering twice, so long as it is sequential:
self._token = None

def __call__(self, func):
# We need to customize `__call__` compared to `ContextDecorator`
# because we must store the token per-thread so cannot store it on
# the instance (we could create a new instance for this).
# This duplicates the code from `__enter__`.
@functools.wraps(func)
def inner(*args, **kwargs):
if self._call is _Unspecified:
extobj = _make_extobj(
all=self._all, divide=self._divide, over=self._over,
under=self._under, invalid=self._invalid)
else:
extobj = _make_extobj(
call=self._call,
all=self._all, divide=self._divide, over=self._over,
under=self._under, invalid=self._invalid)

_token = _extobj_contextvar.set(extobj)
try:
# Call the original, decorated, function:
return func(*args, **kwargs)
finally:
_extobj_contextvar.reset(_token)

return inner


NO_NEP50_WARNING = contextvars.ContextVar("_no_nep50_warning", default=False)

Expand Down
45 changes: 45 additions & 0 deletions numpy/core/tests/test_errstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,48 @@ def foo():
a // 0

foo()

@pytest.mark.skipif(IS_WASM, reason="wasm doesn't support asyncio")
def test_asyncio_safe(self):
# asyncio may not always work, lets assume its fine if missing
# Pyodide/wasm doesn't support it. If this test makes problems,
# it should just be skipped liberally (or run differently).
asyncio = pytest.importorskip("asyncio")

@np.errstate(invalid="ignore")
def decorated():
# Decorated non-async function (it is not safe to decorate an
# async one)
assert np.geterr()["invalid"] == "ignore"

async def func1():
decorated()
await asyncio.sleep(0.1)
decorated()

async def func2():
with np.errstate(invalid="raise"):
assert np.geterr()["invalid"] == "raise"
await asyncio.sleep(0.125)
assert np.geterr()["invalid"] == "raise"

# for good sport, a third one with yet another state:
async def func3():
with np.errstate(invalid="print"):
assert np.geterr()["invalid"] == "print"
await asyncio.sleep(0.11)
assert np.geterr()["invalid"] == "print"

async def main():
# simply run all three function multiple times:
await asyncio.gather(
func1(), func2(), func3(), func1(), func2(), func3(),
func1(), func2(), func3(), func1(), func2(), func3())

loop = asyncio.new_event_loop()
with np.errstate(invalid="warn"):
asyncio.run(main())
assert np.geterr()["invalid"] == "warn"

assert np.geterr()["invalid"] == "warn" # the default
loop.close()
4 changes: 2 additions & 2 deletions numpy/typing/tests/data/reveal/ufunc_config.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,5 @@ reveal_type(np.seterrcall(func)) # E: Union[None, def (builtins.str, builtins.i
reveal_type(np.seterrcall(Write())) # E: Union[None, def (builtins.str, builtins.int) -> Any, _SupportsWrite[builtins.str]]
reveal_type(np.geterrcall()) # E: Union[None, def (builtins.str, builtins.int) -> Any, _SupportsWrite[builtins.str]]

reveal_type(np.errstate(call=func, all="call")) # E: errstate[def (a: builtins.str, b: builtins.int)]
reveal_type(np.errstate(call=Write(), divide="log", over="log")) # E: errstate[ufunc_config.Write]
reveal_type(np.errstate(call=func, all="call")) # E: errstate
reveal_type(np.errstate(call=Write(), divide="log", over="log")) # E: errstate

0 comments on commit c058561

Please sign in to comment.