diff --git a/CHANGES.rst b/CHANGES.rst index 2bc3a356..f8df681d 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,3 +1,39 @@ +v3.6.0 +====== + +* #284: Introduces new ``EntryPoints`` object, a tuple of + ``EntryPoint`` objects but with convenience properties for + selecting and inspecting the results: + + - ``.select()`` accepts ``group`` or ``name`` keyword + parameters and returns a new ``EntryPoints`` tuple + with only those that match the selection. + - ``.groups`` property presents all of the group names. + - ``.names`` property presents the names of the entry points. + - Item access (e.g. ``eps[name]``) retrieves a single + entry point by name. + + ``entry_points`` now accepts "selection parameters", + same as ``EntryPoint.select()``. + + ``entry_points()`` now provides a future-compatible + ``SelectableGroups`` object that supplies the above interface + but remains a dict for compatibility. + + In the future, ``entry_points()`` will return an + ``EntryPoints`` object, but provide for backward + compatibility with a deprecated ``__getitem__`` + accessor by group and a ``get()`` method. + + If passing selection parameters to ``entry_points``, the + future behavior is invoked and an ``EntryPoints`` is the + result. + + Construction of entry points using + ``dict([EntryPoint, ...])`` is now deprecated and raises + an appropriate DeprecationWarning and will be removed in + a future version. + v3.5.0 ====== diff --git a/docs/using.rst b/docs/using.rst index 00409867..97941452 100644 --- a/docs/using.rst +++ b/docs/using.rst @@ -67,18 +67,20 @@ This package provides the following functionality via its public API. Entry points ------------ -The ``entry_points()`` function returns a dictionary of all entry points, -keyed by group. Entry points are represented by ``EntryPoint`` instances; +The ``entry_points()`` function returns a collection of entry points. +Entry points are represented by ``EntryPoint`` instances; each ``EntryPoint`` has a ``.name``, ``.group``, and ``.value`` attributes and a ``.load()`` method to resolve the value. There are also ``.module``, ``.attr``, and ``.extras`` attributes for getting the components of the ``.value`` attribute:: >>> eps = entry_points() - >>> list(eps) + >>> sorted(eps.groups) ['console_scripts', 'distutils.commands', 'distutils.setup_keywords', 'egg_info.writers', 'setuptools.installation'] - >>> scripts = eps['console_scripts'] - >>> wheel = [ep for ep in scripts if ep.name == 'wheel'][0] + >>> scripts = eps.select(group='console_scripts') + >>> 'wheel' in scripts.names + True + >>> wheel = scripts['wheel'] >>> wheel EntryPoint(name='wheel', value='wheel.cli:main', group='console_scripts') >>> wheel.module diff --git a/importlib_metadata/__init__.py b/importlib_metadata/__init__.py index 4c420a55..b0b1ae0e 100644 --- a/importlib_metadata/__init__.py +++ b/importlib_metadata/__init__.py @@ -5,8 +5,10 @@ import sys import zipp import email +import inspect import pathlib import operator +import warnings import functools import itertools import posixpath @@ -130,10 +132,6 @@ def _from_text(cls, text): config.read_string(text) return cls._from_config(config) - @classmethod - def _from_text_for(cls, text, dist): - return (ep._for(dist) for ep in cls._from_text(text)) - def _for(self, dist): self.dist = dist return self @@ -141,11 +139,12 @@ def _for(self, dist): def __iter__(self): """ Supply iter so one may construct dicts of EntryPoints by name. - - >>> eps = [EntryPoint('a', 'b', 'c'), EntryPoint('d', 'e', 'f')] - >>> dict(eps)['a'] - EntryPoint(name='a', value='b', group='c') """ + msg = ( + "Construction of dict of EntryPoints is deprecated in " + "favor of EntryPoints." + ) + warnings.warn(msg, DeprecationWarning) return iter((self.name, self)) def __reduce__(self): @@ -154,6 +153,118 @@ def __reduce__(self): (self.name, self.value, self.group), ) + def matches(self, **params): + attrs = (getattr(self, param) for param in params) + return all(map(operator.eq, params.values(), attrs)) + + +class EntryPoints(tuple): + """ + An immutable collection of selectable EntryPoint objects. + """ + + __slots__ = () + + def __getitem__(self, name): # -> EntryPoint: + try: + return next(iter(self.select(name=name))) + except StopIteration: + raise KeyError(name) + + def select(self, **params): + return EntryPoints(ep for ep in self if ep.matches(**params)) + + @property + def names(self): + return set(ep.name for ep in self) + + @property + def groups(self): + """ + For coverage while SelectableGroups is present. + >>> EntryPoints().groups + set() + """ + return set(ep.group for ep in self) + + @classmethod + def _from_text_for(cls, text, dist): + return cls(ep._for(dist) for ep in EntryPoint._from_text(text)) + + +class SelectableGroups(dict): + """ + A backward- and forward-compatible result from + entry_points that fully implements the dict interface. + """ + + @classmethod + def load(cls, eps): + by_group = operator.attrgetter('group') + ordered = sorted(eps, key=by_group) + grouped = itertools.groupby(ordered, by_group) + return cls((group, EntryPoints(eps)) for group, eps in grouped) + + @property + def _all(self): + return EntryPoints(itertools.chain.from_iterable(self.values())) + + @property + def groups(self): + return self._all.groups + + @property + def names(self): + """ + for coverage: + >>> SelectableGroups().names + set() + """ + return self._all.names + + def select(self, **params): + if not params: + return self + return self._all.select(**params) + + +class LegacyGroupedEntryPoints(EntryPoints): # pragma: nocover + """ + Compatibility wrapper around EntryPoints to provide + much of the 'dict' interface previously returned by + entry_points. + """ + + def __getitem__(self, name) -> Union[EntryPoint, 'EntryPoints']: + """ + When accessed by name that matches a group, return the group. + """ + group = self.select(group=name) + if group: + msg = "GroupedEntryPoints.__getitem__ is deprecated for groups. Use select." + warnings.warn(msg, DeprecationWarning, stacklevel=2) + return group + + return super().__getitem__(name) + + def get(self, group, default=None): + """ + For backward compatibility, supply .get. + """ + is_flake8 = any('flake8' in str(frame) for frame in inspect.stack()) + msg = "GroupedEntryPoints.get is deprecated. Use select." + is_flake8 or warnings.warn(msg, DeprecationWarning, stacklevel=2) + return self.select(group=group) or default + + def select(self, **params): + """ + Prevent transform to EntryPoints during call to entry_points if + no selection parameters were passed. + """ + if not params: + return self + return super().select(**params) + class PackagePath(pathlib.PurePosixPath): """A reference to a path in a package""" @@ -310,7 +421,7 @@ def version(self): @property def entry_points(self): - return list(EntryPoint._from_text_for(self.read_text('entry_points.txt'), self)) + return EntryPoints._from_text_for(self.read_text('entry_points.txt'), self) @property def files(self): @@ -643,19 +754,29 @@ def version(distribution_name): return distribution(distribution_name).version -def entry_points(): +def entry_points(**params) -> Union[EntryPoints, SelectableGroups]: """Return EntryPoint objects for all installed packages. - :return: EntryPoint objects for all installed packages. + Pass selection parameters (group or name) to filter the + result to entry points matching those properties (see + EntryPoints.select()). + + For compatibility, returns ``SelectableGroups`` object unless + selection parameters are supplied. In the future, this function + will return ``LegacyGroupedEntryPoints`` instead of + ``SelectableGroups`` and eventually will only return + ``EntryPoints``. + + For maximum future compatibility, pass selection parameters + or invoke ``.select`` with parameters on the result. + + :return: EntryPoints or SelectableGroups for all installed packages. """ unique = functools.partial(unique_everseen, key=operator.attrgetter('name')) eps = itertools.chain.from_iterable( dist.entry_points for dist in unique(distributions()) ) - by_group = operator.attrgetter('group') - ordered = sorted(eps, key=by_group) - grouped = itertools.groupby(ordered, by_group) - return {group: tuple(eps) for group, eps in grouped} + return SelectableGroups.load(eps).select(**params) def files(distribution_name): diff --git a/tests/test_api.py b/tests/test_api.py index 04a5b9d3..a2810acc 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,6 +1,7 @@ import re import textwrap import unittest +import warnings from . import fixtures from importlib_metadata import ( @@ -64,13 +65,16 @@ def test_read_text(self): self.assertEqual(top_level.read_text(), 'mod\n') def test_entry_points(self): - entries = dict(entry_points()['entries']) + eps = entry_points() + assert 'entries' in eps.groups + entries = eps.select(group='entries') + assert 'main' in entries.names ep = entries['main'] self.assertEqual(ep.value, 'mod:main') self.assertEqual(ep.extras, []) def test_entry_points_distribution(self): - entries = dict(entry_points()['entries']) + entries = entry_points(group='entries') for entry in ("main", "ns:sub"): ep = entries[entry] self.assertIn(ep.dist.name, ('distinfo-pkg', 'egginfo-pkg')) @@ -96,14 +100,61 @@ def test_entry_points_unique_packages(self): }, } fixtures.build_files(alt_pkg, alt_site_dir) - entries = dict(entry_points()['entries']) + entries = entry_points(group='entries') assert not any( ep.dist.name == 'distinfo-pkg' and ep.dist.version == '1.0.0' - for ep in entries.values() + for ep in entries ) # ns:sub doesn't exist in alt_pkg assert 'ns:sub' not in entries + def test_entry_points_missing_name(self): + with self.assertRaises(KeyError): + entry_points(group='entries')['missing'] + + def test_entry_points_missing_group(self): + assert entry_points(group='missing') == () + + def test_entry_points_dict_construction(self): + """ + Prior versions of entry_points() returned simple lists and + allowed casting those lists into maps by name using ``dict()``. + Capture this now deprecated use-case. + """ + with warnings.catch_warnings(record=True) as caught: + eps = dict(entry_points(group='entries')) + + assert 'main' in eps + assert eps['main'] == entry_points(group='entries')['main'] + + # check warning + expected = next(iter(caught)) + assert expected.category is DeprecationWarning + assert "Construction of dict of EntryPoints is deprecated" in str(expected) + + def test_entry_points_groups_getitem(self): + """ + Prior versions of entry_points() returned a dict. Ensure + that callers using '.__getitem__()' are supported but warned to + migrate. + """ + with warnings.catch_warnings(record=True): + entry_points()['entries'] == entry_points(group='entries') + + with self.assertRaises(KeyError): + entry_points()['missing'] + + def test_entry_points_groups_get(self): + """ + Prior versions of entry_points() returned a dict. Ensure + that callers using '.get()' are supported but warned to + migrate. + """ + with warnings.catch_warnings(record=True): + entry_points().get('missing', 'default') == 'default' + entry_points().get('entries', 'default') == entry_points()['entries'] + entry_points().get('missing', ()) == () + def test_metadata_for_this_package(self): md = metadata('egginfo-pkg') assert md['author'] == 'Steven Ma' diff --git a/tests/test_main.py b/tests/test_main.py index 74979be8..e8a66c0d 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -3,6 +3,7 @@ import pickle import textwrap import unittest +import warnings import importlib import importlib_metadata import pyfakefs.fake_filesystem_unittest as ffs @@ -57,13 +58,11 @@ def test_import_nonexistent_module(self): importlib.import_module('does_not_exist') def test_resolve(self): - entries = dict(entry_points()['entries']) - ep = entries['main'] + ep = entry_points(group='entries')['main'] self.assertEqual(ep.load().__name__, "main") def test_entrypoint_with_colon_in_name(self): - entries = dict(entry_points()['entries']) - ep = entries['ns:sub'] + ep = entry_points(group='entries')['ns:sub'] self.assertEqual(ep.value, 'mod:main') def test_resolve_without_attr(self): @@ -249,7 +248,8 @@ def test_json_dump(self): json should not expect to be able to dump an EntryPoint """ with self.assertRaises(Exception): - json.dumps(self.ep) + with warnings.catch_warnings(record=True): + json.dumps(self.ep) def test_module(self): assert self.ep.module == 'value' diff --git a/tests/test_zip.py b/tests/test_zip.py index 67311da2..4279046d 100644 --- a/tests/test_zip.py +++ b/tests/test_zip.py @@ -45,7 +45,7 @@ def test_zip_version_does_not_match(self): version('definitely-not-installed') def test_zip_entry_points(self): - scripts = dict(entry_points()['console_scripts']) + scripts = entry_points(group='console_scripts') entry_point = scripts['example'] self.assertEqual(entry_point.value, 'example:main') entry_point = scripts['Example']