Spaces:
Running
Running
from __future__ import annotations | |
from functools import partial | |
from importlib.metadata import entry_points | |
from typing import TYPE_CHECKING, Any, Callable, Generic, cast | |
from typing_extensions import TypeAliasType, TypeIs, TypeVar | |
from altair.utils.deprecation import deprecated_warn | |
if TYPE_CHECKING: | |
from types import TracebackType | |
# Placeholder for an arbitrary class; used by the ``_is_type`` guard factory.
T = TypeVar("T")
# Placeholder for whatever a plugin callable returns.
R = TypeVar("R")

# A plugin is any callable, with arbitrary arguments, that returns ``R``.
Plugin = TypeAliasType(
    "Plugin",
    Callable[..., R],
    type_params=(R,),
)
# A concrete plugin type; always some kind of ``Plugin``.
PluginT = TypeVar("PluginT", bound=Plugin[Any])
# Shape of a type-narrowing predicate that decides whether an object is a plugin.
IsPlugin = Callable[[object], TypeIs[Plugin[Any]]]
def _is_type(tp: type[T], /) -> Callable[[object], TypeIs[type[T]]]: | |
""" | |
Converts a type to guard function. | |
Added for compatibility with original `PluginRegistry` default. | |
""" | |
def func(obj: object, /) -> TypeIs[type[T]]: | |
return isinstance(obj, tp) | |
return func | |
class NoSuchEntryPoint(Exception):
    """
    Raised when a named entry point cannot be found in an entry point group.

    Parameters
    ----------
    group
        The entry point group that was searched.
    name
        The entry point name that was requested.
    """

    def __init__(self, group, name):
        # Forward both values to ``Exception`` so that ``args`` is populated.
        # ``copy``/``pickle`` rebuild exceptions via ``cls(*self.args)``; with
        # empty ``args`` that call would fail on the two required parameters.
        super().__init__(group, name)
        self.group = group
        self.name = name

    def __str__(self):
        return f"No {self.name!r} entry point found in group {self.group!r}"
class PluginEnabler:
    """
    Handle returned by ``enable()`` that doubles as a context manager.

    The plugin is switched on as soon as this object is constructed.
    Used in a ``with`` statement, the registry is rolled back to the state
    it had before construction when the block exits::

        with plugins.enable("name"):
            do_something()  # 'name' plugin temporarily enabled
        # plugins back to original state
    """

    def __init__(self, registry: PluginRegistry, name: str, **options):
        # Snapshot first, then enable: the snapshot must describe the
        # registry *before* this plugin was switched on.
        snapshot: dict[str, Any] = registry._get_state()
        self.registry: PluginRegistry = registry
        self.name: str = name
        self.options: dict[str, Any] = options
        self.original_state: dict[str, Any] = snapshot
        registry._enable(name, **options)

    def __enter__(self) -> PluginEnabler:
        return self

    def __exit__(self, typ: type, value: Exception, traceback: TracebackType) -> None:
        # Restore the snapshot taken at construction; returns ``None`` so
        # exceptions raised inside the ``with`` block are not suppressed.
        self.registry._set_state(self.original_state)

    def __repr__(self) -> str:
        registry_cls_name = self.registry.__class__.__name__
        return f"{registry_cls_name}.enable({self.name!r})"
class PluginRegistry(Generic[PluginT, R]):
    """
    A registry for plugins.

    This is a plugin registry that allows plugins to be loaded/registered
    in two ways:

    1. Through an explicit call to ``.register(name, value)``.
    2. By looking for other Python packages that are installed and provide
       a setuptools entry point group.

    When you create an instance of this class, provide the name of the
    entry point group to use::

        reg = PluginRegistry("my_entrypoint_group")

    """

    # this is a mapping of name to error message to allow custom error messages
    # in case an entrypoint is not found
    entrypoint_err_messages: dict[str, str] = {}

    # global settings is a key-value mapping of settings that are stored globally
    # in the registry rather than passed to the plugins
    _global_settings: dict[str, Any] = {}

    def __init__(
        self, entry_point_group: str = "", plugin_type: IsPlugin = callable
    ) -> None:
        """
        Create a PluginRegistry for a named entry point group.

        Parameters
        ----------
        entry_point_group: str
            The name of the entry point group.
        plugin_type
            A type narrowing function that will optionally be used for runtime
            type checking loaded plugins.
            Passing a plain type is still accepted, but emits a deprecation
            warning and is converted to a guard function.

        References
        ----------
        https://typing.readthedocs.io/en/latest/spec/narrowing.html
        """
        self.entry_point_group: str = entry_point_group
        self.plugin_type: IsPlugin
        if plugin_type is not callable and isinstance(plugin_type, type):
            msg = (
                f"Pass a callable `TypeIs` function to `plugin_type` instead.\n"
                f"{type(self).__name__!r}(plugin_type)\n\n"
                f"See also:\n"
                f"https://typing.readthedocs.io/en/latest/spec/narrowing.html\n"
                f"https://docs.astral.sh/ruff/rules/assert/"
            )
            deprecated_warn(msg, version="5.4.0")
            self.plugin_type = cast(IsPlugin, _is_type(plugin_type))
        else:
            self.plugin_type = plugin_type
        self._active: Plugin[R] | None = None
        self._active_name: str = ""
        self._plugins: dict[str, PluginT] = {}
        self._options: dict[str, Any] = {}
        # Per-instance copy, so changing settings on one registry does not
        # write through to the class-level defaults.
        self._global_settings: dict[str, Any] = self.__class__._global_settings.copy()

    def register(self, name: str, value: PluginT | None) -> PluginT | None:
        """
        Register a plugin by name and value.

        This method is used for explicit registration of a plugin and shouldn't be
        used to manage entry point managed plugins, which are auto-loaded.

        Parameters
        ----------
        name: str
            The name of the plugin.
        value: PluginType or None
            The actual plugin object to register or None to unregister that plugin.

        Returns
        -------
        plugin: PluginType or None
            The plugin that was registered or unregistered
            (``None`` when unregistering a name that was not registered).

        Raises
        ------
        TypeError
            If ``value`` is not ``None`` and is rejected by ``plugin_type``.
        """
        if value is None:
            return self._plugins.pop(name, None)
        elif self.plugin_type(value):
            self._plugins[name] = value
            return value
        else:
            msg = f"{type(value).__name__!r} is not compatible with {type(self).__name__!r}"
            raise TypeError(msg)

    def names(self) -> list[str]:
        """List the names of the registered and entry points plugins."""
        entry_point_names = (
            ep.name for ep in importlib_metadata_get(self.entry_point_group)
        )
        # A set removes names present both explicitly and as entry points.
        return sorted({*self._plugins, *entry_point_names})

    def _get_state(self) -> dict[str, Any]:
        """Return a dictionary representing the current state of the registry."""
        return {
            "_active": self._active,
            "_active_name": self._active_name,
            "_plugins": self._plugins.copy(),
            "_options": self._options.copy(),
            "_global_settings": self._global_settings.copy(),
        }

    def _set_state(self, state: dict[str, Any]) -> None:
        """
        Reset the state of the registry.

        ``state`` must have exactly the keys produced by ``_get_state``.
        """
        assert set(state.keys()) == {
            "_active",
            "_active_name",
            "_plugins",
            "_options",
            "_global_settings",
        }
        for key, val in state.items():
            setattr(self, key, val)

    def _enable(self, name: str, **options) -> None:
        """
        Make ``name`` the active plugin, loading it from an entry point if needed.

        Options whose keys match an existing global setting update that
        setting; the remaining options are stored for the plugin itself.

        Raises
        ------
        ValueError
            If ``name`` cannot be resolved and a custom message is defined
            for it in ``entrypoint_err_messages``.
        NoSuchEntryPoint
            If ``name`` cannot be resolved and no custom message is defined.
        """
        if name not in self._plugins:
            try:
                # Single-element unpacking: raises ``ValueError`` unless
                # exactly one entry point in the group has this name.
                (ep,) = (
                    ep
                    for ep in importlib_metadata_get(self.entry_point_group)
                    if ep.name == name
                )
            except ValueError as err:
                if name in self.entrypoint_err_messages:
                    raise ValueError(self.entrypoint_err_messages[name]) from err
                else:
                    raise NoSuchEntryPoint(self.entry_point_group, name) from err
            value = cast(PluginT, ep.load())
            self.register(name, value)
        self._active_name = name
        self._active = self._plugins[name]
        # The intersection is materialized as a set before the loop, so
        # popping from ``options`` while iterating is safe.
        for key in options.keys() & self._global_settings.keys():
            self._global_settings[key] = options.pop(key)
        self._options = options

    def enable(self, name: str | None = None, **options) -> PluginEnabler:
        """
        Enable a plugin by name.

        This can be either called directly, or used as a context manager.

        Parameters
        ----------
        name : string (optional)
            The name of the plugin to enable. If not specified, then use the
            current active name.
        **options :
            Any additional parameters will be passed to the plugin as keyword
            arguments

        Returns
        -------
        PluginEnabler:
            An object that allows enable() to be used as a context manager
        """
        if name is None:
            # Read the stored name directly: ``self.active`` without a call
            # would yield a bound method here, not the plugin name.
            name = self._active_name
        return PluginEnabler(self, name, **options)

    def active(self) -> str:
        """Return the name of the currently active plugin."""
        return self._active_name

    def options(self) -> dict[str, Any]:
        """Return the current options dictionary."""
        return self._options

    def get(self) -> partial[R] | Plugin[R] | None:
        """
        Return the currently active plugin.

        When options are set, the plugin is wrapped in a ``partial`` that
        binds them as keyword arguments. Returns ``None`` when no plugin
        is active.

        Raises
        ------
        TypeError
            If the active plugin is rejected by ``plugin_type``.
        """
        if (func := self._active) and self.plugin_type(func):
            return partial(func, **self._options) if self._options else func
        elif self._active is not None:
            msg = (
                f"{type(self).__name__!r} requires all plugins to be callable objects, "
                f"but {type(self._active).__name__!r} is not callable."
            )
            raise TypeError(msg)
        elif TYPE_CHECKING:
            # NOTE: The `None` return is implicit, but `mypy` isn't satisfied
            # - `ruff` will factor out explicit `None` return
            # - `pyright` has no issue
            raise NotImplementedError

    def __repr__(self) -> str:
        # Use the stored name rather than ``self.active`` so the repr shows
        # the plugin name instead of a bound-method object.
        return f"{type(self).__name__}(active={self._active_name!r}, registered={self.names()!r})"
def importlib_metadata_get(group):
    """
    Return the installed entry points that belong to ``group``.

    Works with both entry point APIs: the ``select`` method and the older
    mapping-style ``get`` lookup.
    """
    discovered = entry_points()
    # 'select' was introduced in Python 3.10 and 'get' got deprecated
    # We don't check for Python version here as by checking with hasattr we
    # also get compatibility with the importlib_metadata package which had a different
    # deprecation cycle for 'get'
    if not hasattr(discovered, "select"):
        return discovered.get(group, [])
    return discovered.select(group=group)  # pyright: ignore