-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use importlib to load numba extensions #5209
Changes from 20 commits
2b672f2
bedf8da
4999274
22f4f04
22e91be
8f3c285
760ae11
c3d6990
44bcc9c
22e230a
86af993
0d03446
bbe8594
1fbb79f
6184d29
7992f28
169d106
e5fdc82
61ae870
c842eda
724990f
d94e66d
fbd66c3
b5f75f6
59fa60d
4456c37
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -83,6 +83,8 @@ fi | |||||
# Install latest llvmlite build | ||||||
$CONDA_INSTALL -c numba/label/dev llvmlite | ||||||
|
||||||
# Install importlib-metadata for Python < 3.8 | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
if [ $PYTHON \< "3.9" ]; then $CONDA_INSTALL importlib_metadata; fi | ||||||
# Install dependencies for building the documentation | ||||||
if [ "$BUILD_DOC" == "yes" ]; then $CONDA_INSTALL sphinx=2.4.4 sphinx_rtd_theme pygments numpydoc; fi | ||||||
if [ "$BUILD_DOC" == "yes" ]; then $PIP_INSTALL rstcheck; fi | ||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -1,31 +1,47 @@ | ||||||||||||
import logging | ||||||||||||
import warnings | ||||||||||||
|
||||||||||||
from pkg_resources import iter_entry_points | ||||||||||||
from numba.core.config import PYVERSION | ||||||||||||
|
||||||||||||
if PYVERSION < (3, 9): | ||||||||||||
try: | ||||||||||||
import importlib_metadata | ||||||||||||
except ImportError as ex: | ||||||||||||
raise ImportError( | ||||||||||||
"importlib_metadata backport is required for Python version < 3.9, " | ||||||||||||
"try:\n" | ||||||||||||
"$ conda/pip install importlib_metadata" | ||||||||||||
) from ex | ||||||||||||
else: | ||||||||||||
from importlib import metadata as importlib_metadata | ||||||||||||
|
||||||||||||
|
||||||||||||
_already_initialized = False | ||||||||||||
logger = logging.getLogger(__name__) | ||||||||||||
|
||||||||||||
|
||||||||||||
def init_all(): | ||||||||||||
'''Execute all `numba_extensions` entry points with the name `init` | ||||||||||||
"""Execute all `numba_extensions` entry points with the name `init` | ||||||||||||
|
||||||||||||
If extensions have already been initialized, this function does nothing. | ||||||||||||
''' | ||||||||||||
""" | ||||||||||||
global _already_initialized | ||||||||||||
if _already_initialized: | ||||||||||||
return | ||||||||||||
|
||||||||||||
# Must put this here to avoid extensions re-triggering initialization | ||||||||||||
_already_initialized = True | ||||||||||||
|
||||||||||||
for entry_point in iter_entry_points('numba_extensions', 'init'): | ||||||||||||
logger.debug('Loading extension: %s', entry_point) | ||||||||||||
try: | ||||||||||||
func = entry_point.load() | ||||||||||||
func() | ||||||||||||
except Exception as e: | ||||||||||||
msg = "Numba extension module '{}' failed to load due to '{}({})'." | ||||||||||||
warnings.warn(msg.format(entry_point.module_name, type(e).__name__, | ||||||||||||
str(e)), stacklevel=2) | ||||||||||||
logger.debug('Extension loading failed for: %s', entry_point) | ||||||||||||
for entry_point in importlib_metadata.entry_points().get( | ||||||||||||
"numba_extensions", () | ||||||||||||
): | ||||||||||||
if entry_point.name == "init": | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this should be:
Suggested change
which then means we need There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks @eric-wieser, seems like that could work. Were this change made then I think there would also need to be a minimum version of >>> import importlib_metadata as imd
>>> imd.entry_points(group='foo')
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: entry_points() got an unexpected keyword argument 'group'
>>> imd.__version__
'1.6.1 Assuming v3.6.0 is correct, its release date is February 2021, I wonder if this is too new for a non-optional dependency with respect to some of the places with longer release/dependency cycles, e.g. HPC systems with internal curated package mirrors and linux distros. Also, as a cross-check from the 3.10 docs: https://docs.python.org/3.10/library/importlib.metadata.html
and here: https://docs.python.org/3.10/library/importlib.metadata.html#entry-points Am going to raise this at the next public meeting to try and get a resolution: https://numba.discourse.group/t/weekly-public-meeting-every-tuesday-for-2021/658/2 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perhaps if compatibility is a concern, we should just do feature detection on There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think something a bit like this might work: def init_all():
"""Execute all `numba_extensions` entry points with the name `init`
If extensions have already been initialized, this function does nothing.
"""
global _already_initialized
if _already_initialized:
return
# Must put this here to avoid extensions re-triggering initialization
_already_initialized = True
def load_ep(entry_point):
"""Loads a given entry point. Warns and logs on failure.
"""
logger.debug('Loading extension: %s', entry_point)
try:
func = entry_point.load()
func()
except Exception as e:
msg = (f"Numba extension module '{entry_point.module}' "
f"failed to load due to '{type(e).__name__}({str(e)})'.")
warnings.warn(msg, stacklevel=3)
logger.debug('Extension loading failed for: %s', entry_point)
eps = importlib_metadata.entry_points()
# Split, Python 3.10+ and importlib_metadata 3.6+ have the "selectable"
# interface, versions prior to that do not. See "compatibility note" in:
# https://docs.python.org/3.10/library/importlib.metadata.html#entry-points
if hasattr(eps, 'select'):
for entry_point in eps.select(group="numba_extensions", name="init"):
load_ep(entry_point)
else:
for entry_point in eps.get("numba_extensions", ()):
if entry_point.name == "init":
load_ep(entry_point) CC @gmarkall who was interested in carrying this PR to completion. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for this - I've applied and tested it locally with Python 3.7 - 3.10 and importlib_metadata 4.10.1 - it all seemed fine locally so I've pushed it along with a merge / conflict resolution from master. |
||||||||||||
logger.debug('Loading extension: %s', entry_point) | ||||||||||||
try: | ||||||||||||
func = entry_point.load() | ||||||||||||
func() | ||||||||||||
except Exception as e: | ||||||||||||
msg = (f"Numba extension module '{entry_point.module}' " | ||||||||||||
f"failed to load due to '{type(e).__name__}({str(e)})'.") | ||||||||||||
warnings.warn(msg, stacklevel=2) | ||||||||||||
logger.debug('Extension loading failed for: %s', entry_point) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,22 @@ | ||
import sys | ||
from unittest import mock | ||
|
||
import types | ||
import warnings | ||
import unittest | ||
import os | ||
import subprocess | ||
import threading | ||
import pkg_resources | ||
|
||
from numba import njit | ||
from numba import config, njit | ||
from numba.tests.support import TestCase | ||
from numba.testing.main import _TIMEOUT as _RUNNER_TIMEOUT | ||
|
||
if config.PYVERSION < (3, 9): | ||
import importlib_metadata | ||
else: | ||
from importlib import metadata as importlib_metadata | ||
|
||
_TEST_TIMEOUT = _RUNNER_TIMEOUT - 60. | ||
|
||
|
||
|
@@ -31,45 +37,35 @@ def test_init_entrypoint(self): | |
# loosely based on Pandas test from: | ||
# https://github.com/pandas-dev/pandas/pull/27488 | ||
|
||
# FIXME: Python 2 workaround because nonlocal doesn't exist | ||
counters = {'init': 0} | ||
|
||
def init_function(): | ||
counters['init'] += 1 | ||
|
||
mod = types.ModuleType("_test_numba_extension") | ||
mod.init_func = init_function | ||
mod = mock.Mock(__name__='_test_numba_extension') | ||
|
||
try: | ||
# will remove this module at the end of the test | ||
sys.modules[mod.__name__] = mod | ||
|
||
# We are registering an entry point using the "numba" package | ||
# ("distribution" in pkg_resources-speak) itself, though these are | ||
# normally registered by other packages. | ||
dist = "numba" | ||
entrypoints = pkg_resources.get_entry_map(dist) | ||
my_entrypoint = pkg_resources.EntryPoint( | ||
"init", # name of entry point | ||
mod.__name__, # module with entry point object | ||
attrs=['init_func'], # name of entry point object | ||
dist=pkg_resources.get_distribution(dist) | ||
my_entrypoint = importlib_metadata.EntryPoint( | ||
'init', '_test_numba_extension:init_func', 'numba_extensions', | ||
) | ||
entrypoints.setdefault('numba_extensions', | ||
{})['init'] = my_entrypoint | ||
|
||
from numba.core import entrypoints | ||
# Allow reinitialization | ||
entrypoints._already_initialized = False | ||
with mock.patch.object( | ||
importlib_metadata, | ||
'entry_points', | ||
return_value={'numba_extensions': (my_entrypoint,)}, | ||
): | ||
|
||
from numba.core import entrypoints | ||
|
||
entrypoints.init_all() | ||
# Allow reinitialization | ||
entrypoints._already_initialized = False | ||
stuartarchibald marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
entrypoints.init_all() | ||
|
||
# was our init function called? | ||
self.assertEqual(counters['init'], 1) | ||
# was our init function called? | ||
mod.init_func.assert_called_once() | ||
Comment on lines
-68
to
+64
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Seems like a similar change can also be made to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar change made - note this also involves configuring the mock with a side effect for |
||
|
||
# ensure we do not initialize twice | ||
entrypoints.init_all() | ||
self.assertEqual(counters['init'], 1) | ||
# ensure we do not initialize twice | ||
entrypoints.init_all() | ||
mod.init_func.assert_called_once() | ||
finally: | ||
# remove fake module | ||
if mod.__name__ in sys.modules: | ||
|
@@ -93,36 +89,34 @@ def init_function(): | |
# will remove this module at the end of the test | ||
sys.modules[mod.__name__] = mod | ||
|
||
# We are registering an entry point using the "numba" package | ||
# ("distribution" in pkg_resources-speak) itself, though these are | ||
# normally registered by other packages. | ||
dist = "numba" | ||
entrypoints = pkg_resources.get_entry_map(dist) | ||
my_entrypoint = pkg_resources.EntryPoint( | ||
"init", # name of entry point | ||
mod.__name__, # module with entry point object | ||
attrs=['init_func'], # name of entry point object | ||
dist=pkg_resources.get_distribution(dist) | ||
my_entrypoint = importlib_metadata.EntryPoint( | ||
'init', | ||
'_test_numba_bad_extension:init_func', | ||
'numba_extensions', | ||
) | ||
entrypoints.setdefault('numba_extensions', | ||
{})['init'] = my_entrypoint | ||
|
||
from numba.core import entrypoints | ||
# Allow reinitialization | ||
entrypoints._already_initialized = False | ||
with mock.patch.object( | ||
importlib_metadata, | ||
'entry_points', | ||
return_value={'numba_extensions': (my_entrypoint,)}, | ||
): | ||
|
||
with warnings.catch_warnings(record=True) as w: | ||
entrypoints.init_all() | ||
from numba.core import entrypoints | ||
# Allow reinitialization | ||
entrypoints._already_initialized = False | ||
stuartarchibald marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
bad_str = "Numba extension module '_test_numba_bad_extension'" | ||
for x in w: | ||
if bad_str in str(x): | ||
break | ||
else: | ||
raise ValueError("Expected warning message not found") | ||
with warnings.catch_warnings(record=True) as w: | ||
entrypoints.init_all() | ||
|
||
# was our init function called? | ||
self.assertEqual(counters['init'], 1) | ||
bad_str = "Numba extension module '_test_numba_bad_extension'" | ||
for x in w: | ||
if bad_str in str(x): | ||
break | ||
else: | ||
raise ValueError("Expected warning message not found") | ||
|
||
# was our init function called? | ||
self.assertEqual(counters['init'], 1) | ||
|
||
finally: | ||
# remove fake module | ||
|
@@ -188,26 +182,23 @@ def box_dummy(typ, val, c): | |
# will remove this module at the end of the test | ||
sys.modules[mod.__name__] = mod | ||
|
||
# We are registering an entry point using the "numba" package | ||
# ("distribution" in pkg_resources-speak) itself, though these are | ||
# normally registered by other packages. | ||
dist = "numba" | ||
entrypoints = pkg_resources.get_entry_map(dist) | ||
my_entrypoint = pkg_resources.EntryPoint( | ||
"init", # name of entry point | ||
mod.__name__, # module with entry point object | ||
attrs=['init_func'], # name of entry point object | ||
dist=pkg_resources.get_distribution(dist) | ||
my_entrypoint = importlib_metadata.EntryPoint( | ||
'init', | ||
'_test_numba_init_sequence:init_func', | ||
'numba_extensions', | ||
) | ||
entrypoints.setdefault('numba_extensions', | ||
{})['init'] = my_entrypoint | ||
|
||
@njit | ||
def foo(x): | ||
return x | ||
|
||
ival = _DummyClass(10) | ||
foo(ival) | ||
with mock.patch.object( | ||
importlib_metadata, | ||
'entry_points', | ||
return_value={'numba_extensions': (my_entrypoint,)}, | ||
): | ||
@njit | ||
def foo(x): | ||
return x | ||
|
||
ival = _DummyClass(10) | ||
foo(ival) | ||
finally: | ||
# remove fake module | ||
if mod.__name__ in sys.modules: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks like this comment is now on the incorrect line.