Spaces:
Running
Running
File size: 9,705 Bytes
1380717 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 |
from __future__ import annotations
from functools import partial
from importlib.metadata import entry_points
from typing import TYPE_CHECKING, Any, Callable, Generic, cast
from typing_extensions import TypeAliasType, TypeIs, TypeVar
from altair.utils.deprecation import deprecated_warn
if TYPE_CHECKING:
from types import TracebackType
# Generic type variables shared by the helpers and classes below.
T = TypeVar("T")
R = TypeVar("R")
# A plugin is any callable; ``R`` is the type of the value it returns when called.
Plugin = TypeAliasType("Plugin", Callable[..., R], type_params=(R,))
# The concrete plugin type stored by a ``PluginRegistry`` (bounded by ``Plugin``).
PluginT = TypeVar("PluginT", bound=Plugin[Any])
# Signature of a type-narrowing predicate used to validate plugins at runtime
# (see ``PluginRegistry.__init__(plugin_type=...)``).
IsPlugin = Callable[[object], TypeIs[Plugin[Any]]]
def _is_type(tp: type[T], /) -> Callable[[object], TypeIs[type[T]]]:
"""
Converts a type to guard function.
Added for compatibility with original `PluginRegistry` default.
"""
def func(obj: object, /) -> TypeIs[type[T]]:
return isinstance(obj, tp)
return func
class NoSuchEntryPoint(Exception):
    """Raised when no entry point called ``name`` is found in entry point ``group``."""

    def __init__(self, group: str, name: str) -> None:
        # Forward both values to ``Exception`` so ``self.args`` is always
        # populated -- even when constructed with keyword arguments. Unpickling
        # re-creates the exception from ``args``, so an empty ``args`` would make
        # it impossible to send across process boundaries.
        super().__init__(group, name)
        self.group = group
        self.name = name

    def __str__(self) -> str:
        return f"No {self.name!r} entry point found in group {self.group!r}"
class PluginEnabler:
    """
    Context manager for enabling plugins.

    This object lets you use enable() as a context manager to
    temporarily enable a given plugin::

        with plugins.enable("name"):
            do_something()  # 'name' plugin temporarily enabled
        # plugins back to original state

    Note that the plugin is enabled as soon as this object is constructed, so
    calling ``enable(...)`` without ``with`` also takes effect (permanently).
    """

    def __init__(self, registry: PluginRegistry, name: str, **options) -> None:
        self.registry: PluginRegistry = registry
        self.name: str = name
        self.options: dict[str, Any] = options
        # Snapshot the registry *before* enabling so ``__exit__`` can restore it.
        self.original_state: dict[str, Any] = registry._get_state()
        self.registry._enable(name, **options)

    def __enter__(self) -> PluginEnabler:
        return self

    def __exit__(
        self,
        typ: type[BaseException] | None,
        value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        # Arguments are all ``None`` on a normal exit. Returning ``None`` means
        # any exception raised inside the ``with`` body still propagates.
        self.registry._set_state(self.original_state)

    def __repr__(self) -> str:
        return f"{self.registry.__class__.__name__}.enable({self.name!r})"
class PluginRegistry(Generic[PluginT, R]):
    """
    A registry for plugins.

    This is a plugin registry that allows plugins to be loaded/registered
    in two ways:

    1. Through an explicit call to ``.register(name, value)``.
    2. By looking for other Python packages that are installed and provide
       a setuptools entry point group.

    When you create an instance of this class, provide the name of the
    entry point group to use::

        reg = PluginRegistry("my_entrypoint_group")
    """

    # this is a mapping of name to error message to allow custom error messages
    # in case an entrypoint is not found
    entrypoint_err_messages: dict[str, str] = {}

    # global settings is a key-value mapping of settings that are stored globally
    # in the registry rather than passed to the plugins
    _global_settings: dict[str, Any] = {}

    def __init__(
        self, entry_point_group: str = "", plugin_type: IsPlugin = callable
    ) -> None:
        """
        Create a PluginRegistry for a named entry point group.

        Parameters
        ----------
        entry_point_group: str
            The name of the entry point group.
        plugin_type
            A type narrowing function that will optionally be used for runtime
            type checking loaded plugins.

        References
        ----------
        https://typing.readthedocs.io/en/latest/spec/narrowing.html
        """
        self.entry_point_group: str = entry_point_group
        self.plugin_type: IsPlugin
        if plugin_type is not callable and isinstance(plugin_type, type):
            # Legacy API: a bare class was passed instead of a guard function.
            # Warn, then wrap it so ``self.plugin_type`` is always a predicate.
            msg = (
                f"Pass a callable `TypeIs` function to `plugin_type` instead.\n"
                f"{type(self).__name__!r}(plugin_type)\n\n"
                f"See also:\n"
                f"https://typing.readthedocs.io/en/latest/spec/narrowing.html\n"
                f"https://docs.astral.sh/ruff/rules/assert/"
            )
            deprecated_warn(msg, version="5.4.0")
            self.plugin_type = cast(IsPlugin, _is_type(plugin_type))
        else:
            self.plugin_type = plugin_type
        # The active plugin and its name; set by ``_enable``.
        self._active: Plugin[R] | None = None
        self._active_name: str = ""
        self._plugins: dict[str, PluginT] = {}
        # Keyword options bound to the active plugin by ``get``.
        self._options: dict[str, Any] = {}
        # Per-instance copy so changes made via ``enable`` do not mutate the
        # class-level defaults shared by every registry of this class.
        self._global_settings: dict[str, Any] = self.__class__._global_settings.copy()

    def register(self, name: str, value: PluginT | None) -> PluginT | None:
        """
        Register a plugin by name and value.

        This method is used for explicit registration of a plugin and shouldn't be
        used to manage entry point managed plugins, which are auto-loaded.

        Parameters
        ----------
        name: str
            The name of the plugin.
        value: PluginType or None
            The actual plugin object to register or None to unregister that plugin.

        Returns
        -------
        plugin: PluginType or None
            The plugin that was registered or unregistered.

        Raises
        ------
        TypeError
            If ``value`` is rejected by ``plugin_type``.
        """
        if value is None:
            return self._plugins.pop(name, None)
        elif self.plugin_type(value):
            self._plugins[name] = value
            return value
        else:
            msg = f"{type(value).__name__!r} is not compatible with {type(self).__name__!r}"
            raise TypeError(msg)

    def names(self) -> list[str]:
        """List the names of the registered and entry points plugins (sorted, without duplicates)."""
        entry_point_names = (
            ep.name for ep in importlib_metadata_get(self.entry_point_group)
        )
        return sorted({*self._plugins, *entry_point_names})

    def _get_state(self) -> dict[str, Any]:
        """Return a dictionary representing the current state of the registry."""
        # Containers are shallow-copied so later mutations of the registry do
        # not alter the snapshot.
        return {
            "_active": self._active,
            "_active_name": self._active_name,
            "_plugins": self._plugins.copy(),
            "_options": self._options.copy(),
            "_global_settings": self._global_settings.copy(),
        }

    def _set_state(self, state: dict[str, Any]) -> None:
        """Reset the state of the registry from a ``_get_state`` snapshot."""
        assert set(state.keys()) == {
            "_active",
            "_active_name",
            "_plugins",
            "_options",
            "_global_settings",
        }
        for key, val in state.items():
            setattr(self, key, val)

    def _enable(self, name: str, **options) -> None:
        """
        Make ``name`` the active plugin, loading it from entry points if needed.

        Options whose keys already exist in the registry's global settings are
        stored there; the remaining options are kept for ``get``.

        Raises
        ------
        ValueError
            If no matching entry point exists and ``name`` has a custom message
            in ``entrypoint_err_messages``.
        NoSuchEntryPoint
            If no matching entry point exists otherwise.
        """
        if name not in self._plugins:
            try:
                # Single-element unpacking raises ValueError unless exactly one
                # entry point in the group has this name.
                (ep,) = (
                    ep
                    for ep in importlib_metadata_get(self.entry_point_group)
                    if ep.name == name
                )
            except ValueError as err:
                if name in self.entrypoint_err_messages:
                    raise ValueError(self.entrypoint_err_messages[name]) from err
                else:
                    raise NoSuchEntryPoint(self.entry_point_group, name) from err
            value = cast(PluginT, ep.load())
            self.register(name, value)
        self._active_name = name
        self._active = self._plugins[name]
        # Route options that name a global setting into the global settings
        # instead of passing them to the plugin.
        for key in set(options.keys()) & set(self._global_settings.keys()):
            self._global_settings[key] = options.pop(key)
        self._options = options

    def enable(self, name: str | None = None, **options) -> PluginEnabler:
        """
        Enable a plugin by name.

        This can be either called directly, or used as a context manager.

        Parameters
        ----------
        name : string (optional)
            The name of the plugin to enable. If not specified, then use the
            current active name.
        **options :
            Any additional parameters will be passed to the plugin as keyword
            arguments

        Returns
        -------
        PluginEnabler:
            An object that allows enable() to be used as a context manager
        """
        if name is None:
            name = self.active
        return PluginEnabler(self, name, **options)

    @property
    def active(self) -> str:
        """Return the name of the currently active plugin."""
        return self._active_name

    @property
    def options(self) -> dict[str, Any]:
        """Return the current options dictionary."""
        return self._options

    def get(self) -> partial[R] | Plugin[R] | None:
        """
        Return the currently active plugin.

        Returns
        -------
        The active plugin, wrapped in ``functools.partial`` with the stored
        options when there are any; ``None`` when no plugin is active.

        Raises
        ------
        TypeError
            If the active plugin is rejected by ``plugin_type``.
        """
        # Compare against ``None`` explicitly: a valid plugin object that is
        # merely falsy (e.g. a callable defining ``__len__`` returning 0) must
        # not be misreported as "not callable".
        if (func := self._active) is not None and self.plugin_type(func):
            return partial(func, **self._options) if self._options else func
        elif func is not None:
            msg = (
                f"{type(self).__name__!r} requires all plugins to be callable objects, "
                f"but {type(self._active).__name__!r} is not callable."
            )
            raise TypeError(msg)
        elif TYPE_CHECKING:
            # NOTE: The `None` return is implicit, but `mypy` isn't satisfied
            # - `ruff` will factor out explicit `None` return
            # - `pyright` has no issue
            raise NotImplementedError

    def __repr__(self) -> str:
        return f"{type(self).__name__}(active={self.active!r}, registered={self.names()!r})"
def importlib_metadata_get(group):
    """
    Return the installed entry points registered under ``group``.

    Uses ``select(group=...)`` when the entry-points object provides it, and
    falls back to the older dict-style ``get(group, [])`` otherwise.
    """
    all_entry_points = entry_points()
    # 'select' was introduced in Python 3.10 and 'get' was deprecated. Feature
    # detection (instead of a Python version check) also keeps this compatible
    # with the importlib_metadata package, which deprecated 'get' on a
    # different schedule.
    if not hasattr(all_entry_points, "select"):
        return all_entry_points.get(group, [])
    return all_entry_points.select(group=group)  # pyright: ignore
|