From 9d55a331c7d77025054e85f23bc23c614fab6856 Mon Sep 17 00:00:00 2001 From: "Jason R. Coombs" Date: Mon, 22 Feb 2021 08:47:11 -0500 Subject: [PATCH] Separate compatibility shim from canonical EntryPoints container. --- importlib_metadata/__init__.py | 41 ++++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/importlib_metadata/__init__.py b/importlib_metadata/__init__.py index 313beca2..94a82ffe 100644 --- a/importlib_metadata/__init__.py +++ b/importlib_metadata/__init__.py @@ -165,23 +165,12 @@ class EntryPoints(tuple): __slots__ = () - def __getitem__(self, name) -> Union[EntryPoint, 'EntryPoints']: + def __getitem__(self, name): # -> EntryPoint: try: - match = next(iter(self.select(name=name))) - return match + return next(iter(self.select(name=name))) except StopIteration: - if name in self.groups: - return self._group_getitem(name) raise KeyError(name) - def _group_getitem(self, name): - """ - For backward compatability, supply .__getitem__ for groups. - """ - msg = "GroupedEntryPoints.__getitem__ is deprecated for groups. Use select." - warnings.warn(msg, DeprecationWarning) - return self.select(group=name) - def select(self, **params): return EntryPoints(ep for ep in self if ep.matches(**params)) @@ -193,6 +182,23 @@ def names(self): def groups(self): return set(ep.group for ep in self) + @classmethod + def _from_text_for(cls, text, dist): + return cls(ep._for(dist) for ep in EntryPoint._from_text(text)) + + +class LegacyGroupedEntryPoints(EntryPoints): + def __getitem__(self, name) -> Union[EntryPoint, 'EntryPoints']: + try: + return super().__getitem__(name) + except KeyError: + if name not in self.groups: + raise + + msg = "GroupedEntryPoints.__getitem__ is deprecated for groups. Use select." + warnings.warn(msg, DeprecationWarning) + return self.select(group=name) + def get(self, group, default=None): """ For backward compatibility, supply .get @@ -202,9 +208,10 @@ def get(self, group, default=None): is_flake8 or warnings.warn(msg, DeprecationWarning) return self.select(group=group) or default - @classmethod - def _from_text_for(cls, text, dist): - return cls(ep._for(dist) for ep in EntryPoint._from_text(text)) + def select(self, **params): + if not params: + return self + return super().select(**params) class PackagePath(pathlib.PurePosixPath): @@ -704,7 +711,7 @@ def entry_points(**params): eps = itertools.chain.from_iterable( dist.entry_points for dist in unique(distributions()) ) - return EntryPoints(eps).select(**params) + return LegacyGroupedEntryPoints(eps).select(**params) def files(distribution_name):