-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use importlib to load numba extensions #5209
Changes from 22 commits
2b672f2
bedf8da
4999274
22f4f04
22e91be
8f3c285
760ae11
c3d6990
44bcc9c
22e230a
86af993
0d03446
bbe8594
1fbb79f
6184d29
7992f28
169d106
e5fdc82
61ae870
c842eda
724990f
d94e66d
fbd66c3
b5f75f6
59fa60d
4456c37
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -36,6 +36,7 @@ requirements: | |||||||||
- setuptools | ||||||||||
# On channel https://anaconda.org/numba/ | ||||||||||
- llvmlite >=0.39.0dev0,<0.39 | ||||||||||
- importlib_metadata # [py<39] | ||||||||||
# TBB devel version is to match TBB libs. | ||||||||||
# 2020.3 is the last version with the "old" ABI | ||||||||||
# NOTE: ppc64le exclusion is temporary until packages are more generally | ||||||||||
|
@@ -46,6 +47,7 @@ requirements: | |||||||||
- numpy >=1.18 | ||||||||||
- setuptools | ||||||||||
# On channel https://anaconda.org/numba/ | ||||||||||
- importlib_metadata # [py<39] | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
looks like this comment is now on the incorrect line. |
||||||||||
- llvmlite >=0.39.0dev0,<0.39 | ||||||||||
run_constrained: | ||||||||||
# If TBB is present it must be at least version 2021 | ||||||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -98,6 +98,8 @@ fi | |||||
# Install latest llvmlite build | ||||||
$CONDA_INSTALL -c numba/label/dev llvmlite | ||||||
|
||||||
# Install importlib-metadata for Python < 3.8 | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
if [ $PYTHON \< "3.9" ]; then $CONDA_INSTALL importlib_metadata; fi | ||||||
|
||||||
# Install dependencies for building the documentation | ||||||
if [ "$BUILD_DOC" == "yes" ]; then $CONDA_INSTALL sphinx=2.4.4 docutils=0.17 sphinx_rtd_theme pygments numpydoc; fi | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,31 +1,58 @@ | ||
import logging | ||
import warnings | ||
|
||
from pkg_resources import iter_entry_points | ||
from numba.core.config import PYVERSION | ||
|
||
if PYVERSION < (3, 9): | ||
try: | ||
import importlib_metadata | ||
except ImportError as ex: | ||
raise ImportError( | ||
"importlib_metadata backport is required for Python version < 3.9, " | ||
"try:\n" | ||
"$ conda/pip install importlib_metadata" | ||
) from ex | ||
else: | ||
from importlib import metadata as importlib_metadata | ||
|
||
|
||
_already_initialized = False | ||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def init_all(): | ||
'''Execute all `numba_extensions` entry points with the name `init` | ||
"""Execute all `numba_extensions` entry points with the name `init` | ||
|
||
If extensions have already been initialized, this function does nothing. | ||
''' | ||
""" | ||
global _already_initialized | ||
if _already_initialized: | ||
return | ||
|
||
# Must put this here to avoid extensions re-triggering initialization | ||
_already_initialized = True | ||
|
||
for entry_point in iter_entry_points('numba_extensions', 'init'): | ||
def load_ep(entry_point): | ||
"""Loads a given entry point. Warns and logs on failure. | ||
""" | ||
logger.debug('Loading extension: %s', entry_point) | ||
try: | ||
func = entry_point.load() | ||
func() | ||
except Exception as e: | ||
msg = "Numba extension module '{}' failed to load due to '{}({})'." | ||
warnings.warn(msg.format(entry_point.module_name, type(e).__name__, | ||
str(e)), stacklevel=2) | ||
msg = (f"Numba extension module '{entry_point.module}' " | ||
f"failed to load due to '{type(e).__name__}({str(e)})'.") | ||
warnings.warn(msg, stacklevel=3) | ||
logger.debug('Extension loading failed for: %s', entry_point) | ||
|
||
eps = importlib_metadata.entry_points() | ||
# Split, Python 3.10+ and importlib_metadata 3.6+ have the "selectable" | ||
# interface, versions prior to that do not. See "compatibility note" in: | ||
# https://docs.python.org/3.10/library/importlib.metadata.html#entry-points | ||
if hasattr(eps, 'select'): | ||
for entry_point in eps.select(group="numba_extensions", name="init"): | ||
load_ep(entry_point) | ||
else: | ||
for entry_point in eps.get("numba_extensions", ()): | ||
if entry_point.name == "init": | ||
load_ep(entry_point) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,22 @@ | ||
import sys | ||
from unittest import mock | ||
|
||
import types | ||
import warnings | ||
import unittest | ||
import os | ||
import subprocess | ||
import threading | ||
import pkg_resources | ||
|
||
from numba import njit | ||
from numba import config, njit | ||
from numba.tests.support import TestCase | ||
from numba.testing.main import _TIMEOUT as _RUNNER_TIMEOUT | ||
|
||
if config.PYVERSION < (3, 9): | ||
import importlib_metadata | ||
else: | ||
from importlib import metadata as importlib_metadata | ||
|
||
_TEST_TIMEOUT = _RUNNER_TIMEOUT - 60. | ||
|
||
|
||
|
@@ -31,45 +37,35 @@ def test_init_entrypoint(self): | |
# loosely based on Pandas test from: | ||
# https://github.com/pandas-dev/pandas/pull/27488 | ||
|
||
# FIXME: Python 2 workaround because nonlocal doesn't exist | ||
counters = {'init': 0} | ||
|
||
def init_function(): | ||
counters['init'] += 1 | ||
|
||
mod = types.ModuleType("_test_numba_extension") | ||
mod.init_func = init_function | ||
mod = mock.Mock(__name__='_test_numba_extension') | ||
|
||
try: | ||
# will remove this module at the end of the test | ||
sys.modules[mod.__name__] = mod | ||
|
||
# We are registering an entry point using the "numba" package | ||
# ("distribution" in pkg_resources-speak) itself, though these are | ||
# normally registered by other packages. | ||
dist = "numba" | ||
entrypoints = pkg_resources.get_entry_map(dist) | ||
my_entrypoint = pkg_resources.EntryPoint( | ||
"init", # name of entry point | ||
mod.__name__, # module with entry point object | ||
attrs=['init_func'], # name of entry point object | ||
dist=pkg_resources.get_distribution(dist) | ||
my_entrypoint = importlib_metadata.EntryPoint( | ||
'init', '_test_numba_extension:init_func', 'numba_extensions', | ||
) | ||
entrypoints.setdefault('numba_extensions', | ||
{})['init'] = my_entrypoint | ||
|
||
from numba.core import entrypoints | ||
# Allow reinitialization | ||
entrypoints._already_initialized = False | ||
with mock.patch.object( | ||
importlib_metadata, | ||
'entry_points', | ||
return_value={'numba_extensions': (my_entrypoint,)}, | ||
): | ||
|
||
from numba.core import entrypoints | ||
|
||
entrypoints.init_all() | ||
# Allow reinitialization | ||
entrypoints._already_initialized = False | ||
stuartarchibald marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
entrypoints.init_all() | ||
|
||
# was our init function called? | ||
self.assertEqual(counters['init'], 1) | ||
# was our init function called? | ||
mod.init_func.assert_called_once() | ||
Comment on lines
-68
to
+64
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Seems like a similar change can also be made to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar change made - note this also involves configuring the mock with a side effect for |
||
|
||
# ensure we do not initialize twice | ||
entrypoints.init_all() | ||
self.assertEqual(counters['init'], 1) | ||
# ensure we do not initialize twice | ||
entrypoints.init_all() | ||
mod.init_func.assert_called_once() | ||
finally: | ||
# remove fake module | ||
if mod.__name__ in sys.modules: | ||
|
@@ -93,36 +89,34 @@ def init_function(): | |
# will remove this module at the end of the test | ||
sys.modules[mod.__name__] = mod | ||
|
||
# We are registering an entry point using the "numba" package | ||
# ("distribution" in pkg_resources-speak) itself, though these are | ||
# normally registered by other packages. | ||
dist = "numba" | ||
entrypoints = pkg_resources.get_entry_map(dist) | ||
my_entrypoint = pkg_resources.EntryPoint( | ||
"init", # name of entry point | ||
mod.__name__, # module with entry point object | ||
attrs=['init_func'], # name of entry point object | ||
dist=pkg_resources.get_distribution(dist) | ||
my_entrypoint = importlib_metadata.EntryPoint( | ||
'init', | ||
'_test_numba_bad_extension:init_func', | ||
'numba_extensions', | ||
) | ||
entrypoints.setdefault('numba_extensions', | ||
{})['init'] = my_entrypoint | ||
|
||
from numba.core import entrypoints | ||
# Allow reinitialization | ||
entrypoints._already_initialized = False | ||
with mock.patch.object( | ||
importlib_metadata, | ||
'entry_points', | ||
return_value={'numba_extensions': (my_entrypoint,)}, | ||
): | ||
|
||
with warnings.catch_warnings(record=True) as w: | ||
entrypoints.init_all() | ||
from numba.core import entrypoints | ||
# Allow reinitialization | ||
entrypoints._already_initialized = False | ||
stuartarchibald marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
bad_str = "Numba extension module '_test_numba_bad_extension'" | ||
for x in w: | ||
if bad_str in str(x): | ||
break | ||
else: | ||
raise ValueError("Expected warning message not found") | ||
with warnings.catch_warnings(record=True) as w: | ||
entrypoints.init_all() | ||
|
||
# was our init function called? | ||
self.assertEqual(counters['init'], 1) | ||
bad_str = "Numba extension module '_test_numba_bad_extension'" | ||
for x in w: | ||
if bad_str in str(x): | ||
break | ||
else: | ||
raise ValueError("Expected warning message not found") | ||
|
||
# was our init function called? | ||
self.assertEqual(counters['init'], 1) | ||
|
||
finally: | ||
# remove fake module | ||
|
@@ -188,26 +182,23 @@ def box_dummy(typ, val, c): | |
# will remove this module at the end of the test | ||
sys.modules[mod.__name__] = mod | ||
|
||
# We are registering an entry point using the "numba" package | ||
# ("distribution" in pkg_resources-speak) itself, though these are | ||
# normally registered by other packages. | ||
dist = "numba" | ||
entrypoints = pkg_resources.get_entry_map(dist) | ||
my_entrypoint = pkg_resources.EntryPoint( | ||
"init", # name of entry point | ||
mod.__name__, # module with entry point object | ||
attrs=['init_func'], # name of entry point object | ||
dist=pkg_resources.get_distribution(dist) | ||
my_entrypoint = importlib_metadata.EntryPoint( | ||
'init', | ||
'_test_numba_init_sequence:init_func', | ||
'numba_extensions', | ||
) | ||
entrypoints.setdefault('numba_extensions', | ||
{})['init'] = my_entrypoint | ||
|
||
@njit | ||
def foo(x): | ||
return x | ||
|
||
ival = _DummyClass(10) | ||
foo(ival) | ||
with mock.patch.object( | ||
importlib_metadata, | ||
'entry_points', | ||
return_value={'numba_extensions': (my_entrypoint,)}, | ||
): | ||
@njit | ||
def foo(x): | ||
return x | ||
|
||
ival = _DummyClass(10) | ||
foo(ival) | ||
finally: | ||
# remove fake module | ||
if mod.__name__ in sys.modules: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.