diff --git a/importlib_metadata/__init__.py b/importlib_metadata/__init__.py index 313beca2..94a82ffe 100644 --- a/importlib_metadata/__init__.py +++ b/importlib_metadata/__init__.py @@ -165,23 +165,12 @@ class EntryPoints(tuple): __slots__ = () - def __getitem__(self, name) -> Union[EntryPoint, 'EntryPoints']: + def __getitem__(self, name): # -> EntryPoint: try: - match = next(iter(self.select(name=name))) - return match + return next(iter(self.select(name=name))) except StopIteration: - if name in self.groups: - return self._group_getitem(name) raise KeyError(name) - def _group_getitem(self, name): - """ - For backward compatability, supply .__getitem__ for groups. - """ - msg = "GroupedEntryPoints.__getitem__ is deprecated for groups. Use select." - warnings.warn(msg, DeprecationWarning) - return self.select(group=name) - def select(self, **params): return EntryPoints(ep for ep in self if ep.matches(**params)) @@ -193,6 +182,23 @@ def names(self): def groups(self): return set(ep.group for ep in self) + @classmethod + def _from_text_for(cls, text, dist): + return cls(ep._for(dist) for ep in EntryPoint._from_text(text)) + + +class LegacyGroupedEntryPoints(EntryPoints): + def __getitem__(self, name) -> Union[EntryPoint, 'EntryPoints']: + try: + return super().__getitem__(name) + except KeyError: + if name not in self.groups: + raise + + msg = "GroupedEntryPoints.__getitem__ is deprecated for groups. Use select." + warnings.warn(msg, DeprecationWarning) + return self.select(group=name) + def get(self, group, default=None): """ For backward compatibility, supply .get @@ -202,9 +208,10 @@ def get(self, group, default=None): is_flake8 or warnings.warn(msg, DeprecationWarning) return self.select(group=group) or default - @classmethod - def _from_text_for(cls, text, dist): - return cls(ep._for(dist) for ep in EntryPoint._from_text(text)) + def select(self, **params): + if not params: + return self + return super().select(**params) class PackagePath(pathlib.PurePosixPath): @@ -704,7 +711,7 @@ def entry_points(**params): eps = itertools.chain.from_iterable( dist.entry_points for dist in unique(distributions()) ) - return EntryPoints(eps).select(**params) + return LegacyGroupedEntryPoints(eps).select(**params) def files(distribution_name):