Skip to content

Commit

Permalink
Merge pull request #5209 from svrakitin/fix/4927-slow-import
Browse files Browse the repository at this point in the history
Use importlib to load numba extensions
  • Loading branch information
sklam committed Feb 7, 2022
2 parents d9efc9e + 4456c37 commit fa5c9e1
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 90 deletions.
2 changes: 2 additions & 0 deletions buildscripts/condarecipe.local/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ requirements:
- python
- numpy
- setuptools
- importlib_metadata # [py<39]
# On channel https://anaconda.org/numba/
- llvmlite >=0.39.0dev0,<0.39
# TBB devel version is to match TBB libs.
Expand All @@ -45,6 +46,7 @@ requirements:
- python >=3.7
- numpy >=1.18
- setuptools
- importlib_metadata # [py<39]
# On channel https://anaconda.org/numba/
- llvmlite >=0.39.0dev0,<0.39
run_constrained:
Expand Down
2 changes: 2 additions & 0 deletions buildscripts/incremental/setup_conda_environment.cmd
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ conda create -n %CONDA_ENV% -q -y python=%PYTHON% numpy=%NUMPY% cffi pip scipy j
call activate %CONDA_ENV%
@rem Install latest llvmlite build
%CONDA_INSTALL% -c numba/label/dev llvmlite
@rem Install required backports for older Pythons
if %PYTHON% LSS 3.9 (%CONDA_INSTALL% importlib_metadata)
@rem Install dependencies for building the documentation
if "%BUILD_DOC%" == "yes" (%CONDA_INSTALL% sphinx sphinx_rtd_theme pygments)
@rem Install dependencies for code coverage (codecov.io)
Expand Down
2 changes: 2 additions & 0 deletions buildscripts/incremental/setup_conda_environment.sh
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ fi
# Install latest llvmlite build
$CONDA_INSTALL -c numba/label/dev llvmlite

# Install importlib-metadata for Python < 3.9
if [ $PYTHON \< "3.9" ]; then $CONDA_INSTALL importlib_metadata; fi

# Install dependencies for building the documentation
if [ "$BUILD_DOC" == "yes" ]; then $CONDA_INSTALL sphinx=2.4.4 docutils=0.17 sphinx_rtd_theme pygments numpydoc; fi
Expand Down
1 change: 1 addition & 0 deletions docs/source/user/installing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ vary with target operating system and hardware. The following lists them all
threading backend.
* ``tbb-devel`` - provides TBB headers/libraries for compiling TBB support
into Numba's threading backend (version >= 2021 required).
* ``importlib_metadata`` (for Python versions < 3.9)

* Optional runtime are:

Expand Down
41 changes: 34 additions & 7 deletions numba/core/entrypoints.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,58 @@
import logging
import warnings

from pkg_resources import iter_entry_points
from numba.core.config import PYVERSION

if PYVERSION < (3, 9):
try:
import importlib_metadata
except ImportError as ex:
raise ImportError(
"importlib_metadata backport is required for Python version < 3.9, "
"try:\n"
"$ conda/pip install importlib_metadata"
) from ex
else:
from importlib import metadata as importlib_metadata


_already_initialized = False
logger = logging.getLogger(__name__)


def init_all():
'''Execute all `numba_extensions` entry points with the name `init`
"""Execute all `numba_extensions` entry points with the name `init`
If extensions have already been initialized, this function does nothing.
'''
"""
global _already_initialized
if _already_initialized:
return

# Must put this here to avoid extensions re-triggering initialization
_already_initialized = True

for entry_point in iter_entry_points('numba_extensions', 'init'):
def load_ep(entry_point):
"""Loads a given entry point. Warns and logs on failure.
"""
logger.debug('Loading extension: %s', entry_point)
try:
func = entry_point.load()
func()
except Exception as e:
msg = "Numba extension module '{}' failed to load due to '{}({})'."
warnings.warn(msg.format(entry_point.module_name, type(e).__name__,
str(e)), stacklevel=2)
msg = (f"Numba extension module '{entry_point.module}' "
f"failed to load due to '{type(e).__name__}({str(e)})'.")
warnings.warn(msg, stacklevel=3)
logger.debug('Extension loading failed for: %s', entry_point)

eps = importlib_metadata.entry_points()
# Split, Python 3.10+ and importlib_metadata 3.6+ have the "selectable"
# interface, versions prior to that do not. See "compatibility note" in:
# https://docs.python.org/3.10/library/importlib.metadata.html#entry-points
if hasattr(eps, 'select'):
for entry_point in eps.select(group="numba_extensions", name="init"):
load_ep(entry_point)
else:
for entry_point in eps.get("numba_extensions", ()):
if entry_point.name == "init":
load_ep(entry_point)
150 changes: 67 additions & 83 deletions numba/tests/test_entrypoints.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
import sys
from unittest import mock

import types
import warnings
import unittest
import os
import subprocess
import threading
import pkg_resources

from numba import njit
from numba import config, njit
from numba.tests.support import TestCase
from numba.testing.main import _TIMEOUT as _RUNNER_TIMEOUT

if config.PYVERSION < (3, 9):
import importlib_metadata
else:
from importlib import metadata as importlib_metadata

_TEST_TIMEOUT = _RUNNER_TIMEOUT - 60.


Expand All @@ -31,45 +37,35 @@ def test_init_entrypoint(self):
# loosely based on Pandas test from:
# https://github.com/pandas-dev/pandas/pull/27488

# FIXME: Python 2 workaround because nonlocal doesn't exist
counters = {'init': 0}

def init_function():
counters['init'] += 1

mod = types.ModuleType("_test_numba_extension")
mod.init_func = init_function
mod = mock.Mock(__name__='_test_numba_extension')

try:
# will remove this module at the end of the test
sys.modules[mod.__name__] = mod

# We are registering an entry point using the "numba" package
# ("distribution" in pkg_resources-speak) itself, though these are
# normally registered by other packages.
dist = "numba"
entrypoints = pkg_resources.get_entry_map(dist)
my_entrypoint = pkg_resources.EntryPoint(
"init", # name of entry point
mod.__name__, # module with entry point object
attrs=['init_func'], # name of entry point object
dist=pkg_resources.get_distribution(dist)
my_entrypoint = importlib_metadata.EntryPoint(
'init', '_test_numba_extension:init_func', 'numba_extensions',
)
entrypoints.setdefault('numba_extensions',
{})['init'] = my_entrypoint

from numba.core import entrypoints
# Allow reinitialization
entrypoints._already_initialized = False
with mock.patch.object(
importlib_metadata,
'entry_points',
return_value={'numba_extensions': (my_entrypoint,)},
):

from numba.core import entrypoints

entrypoints.init_all()
# Allow reinitialization
entrypoints._already_initialized = False

entrypoints.init_all()

# was our init function called?
self.assertEqual(counters['init'], 1)
# was our init function called?
mod.init_func.assert_called_once()

# ensure we do not initialize twice
entrypoints.init_all()
self.assertEqual(counters['init'], 1)
# ensure we do not initialize twice
entrypoints.init_all()
mod.init_func.assert_called_once()
finally:
# remove fake module
if mod.__name__ in sys.modules:
Expand All @@ -79,50 +75,41 @@ def test_entrypoint_tolerance(self):
# loosely based on Pandas test from:
# https://github.com/pandas-dev/pandas/pull/27488

# FIXME: Python 2 workaround because nonlocal doesn't exist
counters = {'init': 0}

def init_function():
counters['init'] += 1
raise ValueError("broken")

mod = types.ModuleType("_test_numba_bad_extension")
mod.init_func = init_function
mod = mock.Mock(__name__='_test_numba_bad_extension')
mod.configure_mock(**{'init_func.side_effect': ValueError('broken')})

try:
# will remove this module at the end of the test
sys.modules[mod.__name__] = mod

# We are registering an entry point using the "numba" package
# ("distribution" in pkg_resources-speak) itself, though these are
# normally registered by other packages.
dist = "numba"
entrypoints = pkg_resources.get_entry_map(dist)
my_entrypoint = pkg_resources.EntryPoint(
"init", # name of entry point
mod.__name__, # module with entry point object
attrs=['init_func'], # name of entry point object
dist=pkg_resources.get_distribution(dist)
my_entrypoint = importlib_metadata.EntryPoint(
'init',
'_test_numba_bad_extension:init_func',
'numba_extensions',
)
entrypoints.setdefault('numba_extensions',
{})['init'] = my_entrypoint

from numba.core import entrypoints
# Allow reinitialization
entrypoints._already_initialized = False
with mock.patch.object(
importlib_metadata,
'entry_points',
return_value={'numba_extensions': (my_entrypoint,)},
):

with warnings.catch_warnings(record=True) as w:
entrypoints.init_all()
from numba.core import entrypoints
# Allow reinitialization
entrypoints._already_initialized = False

bad_str = "Numba extension module '_test_numba_bad_extension'"
for x in w:
if bad_str in str(x):
break
else:
raise ValueError("Expected warning message not found")
with warnings.catch_warnings(record=True) as w:
entrypoints.init_all()

# was our init function called?
self.assertEqual(counters['init'], 1)
bad_str = "Numba extension module '_test_numba_bad_extension'"
for x in w:
if bad_str in str(x):
break
else:
raise ValueError("Expected warning message not found")

# was our init function called?
mod.init_func.assert_called_once()

finally:
# remove fake module
Expand Down Expand Up @@ -188,26 +175,23 @@ def box_dummy(typ, val, c):
# will remove this module at the end of the test
sys.modules[mod.__name__] = mod

# We are registering an entry point using the "numba" package
# ("distribution" in pkg_resources-speak) itself, though these are
# normally registered by other packages.
dist = "numba"
entrypoints = pkg_resources.get_entry_map(dist)
my_entrypoint = pkg_resources.EntryPoint(
"init", # name of entry point
mod.__name__, # module with entry point object
attrs=['init_func'], # name of entry point object
dist=pkg_resources.get_distribution(dist)
my_entrypoint = importlib_metadata.EntryPoint(
'init',
'_test_numba_init_sequence:init_func',
'numba_extensions',
)
entrypoints.setdefault('numba_extensions',
{})['init'] = my_entrypoint

@njit
def foo(x):
return x

ival = _DummyClass(10)
foo(ival)
with mock.patch.object(
importlib_metadata,
'entry_points',
return_value={'numba_extensions': (my_entrypoint,)},
):
@njit
def foo(x):
return x

ival = _DummyClass(10)
foo(ival)
finally:
# remove fake module
if mod.__name__ in sys.modules:
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ def check_file_at_path(path2file):
'llvmlite >={},<{}'.format(min_llvmlite_version, max_llvmlite_version),
'numpy >={}'.format(min_numpy_run_version),
'setuptools',
'importlib_metadata; python_version < "3.9"',
]

metadata = dict(
Expand Down

0 comments on commit fa5c9e1

Please sign in to comment.