File size: 9,705 Bytes
1380717
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
from __future__ import annotations

from functools import partial
from importlib.metadata import entry_points
from typing import TYPE_CHECKING, Any, Callable, Generic, cast
from typing_extensions import TypeAliasType, TypeIs, TypeVar

from altair.utils.deprecation import deprecated_warn

if TYPE_CHECKING:
    from types import TracebackType

# Generic placeholder for an arbitrary type (used by the `_is_type` helper).
T = TypeVar("T")
# The return type of a plugin callable.
R = TypeVar("R")
# A plugin is any callable (arbitrary arguments) returning `R`;
# the alias itself is generic over `R`.
Plugin = TypeAliasType("Plugin", Callable[..., R], type_params=(R,))
# A concrete plugin type, bounded by `Plugin` with an unconstrained return type.
PluginT = TypeVar("PluginT", bound=Plugin[Any])
# A type-narrowing predicate: accepts any object and reports whether it is a `Plugin`.
IsPlugin = Callable[[object], TypeIs[Plugin[Any]]]


def _is_type(tp: type[T], /) -> Callable[[object], TypeIs[type[T]]]:
    """
    Build a guard function out of a class.

    The returned callable takes a single positional object and reports
    whether that object passes ``isinstance(obj, tp)``.

    Exists so that a plain type can still be supplied where the
    ``PluginRegistry`` now expects a guard function.
    """
    # NOTE(review): the annotation says the guard narrows to ``type[T]``, while
    # ``isinstance`` checks for an *instance* of ``tp`` - confirm which is intended.

    def is_instance(candidate: object, /) -> TypeIs[type[T]]:
        return isinstance(candidate, tp)

    return is_instance


class NoSuchEntryPoint(Exception):
    """
    Error describing an entry point that could not be found.

    Attributes are stored as given:

    - ``group``: the entry point group that was searched.
    - ``name``: the entry point name that was requested.
    """

    def __init__(self, group, name):
        self.group, self.name = group, name

    def __str__(self):
        # Both values are rendered with ``repr`` so that strings appear quoted.
        name, group = self.name, self.group
        return f"No {name!r} entry point found in group {group!r}"


class PluginEnabler:
    """
    Context manager for enabling plugins.

    This object lets you use enable() as a context manager to
    temporarily enable a given plugin::

        with plugins.enable("name"):
            do_something()  # 'name' plugin temporarily enabled
        # plugins back to original state

    The plugin is enabled as soon as this object is constructed, not when
    the ``with`` block is entered. The registry state captured just before
    enabling is only restored when ``__exit__`` runs.
    """

    def __init__(self, registry: PluginRegistry, name: str, **options):
        """
        Snapshot the registry state, then enable ``name`` on the registry.

        Parameters
        ----------
        registry
            The registry to operate on. Its ``_get_state``, ``_enable`` and
            ``_set_state`` methods are used.
        name
            Name of the plugin to enable.
        **options
            Forwarded unchanged to ``registry._enable``.
        """
        self.registry: PluginRegistry = registry
        self.name: str = name
        self.options: dict[str, Any] = options
        # The snapshot must be taken *before* enabling so that ``__exit__``
        # can roll the registry back to its previous state.
        self.original_state: dict[str, Any] = registry._get_state()
        self.registry._enable(name, **options)

    def __enter__(self) -> PluginEnabler:
        """Return this enabler; the plugin has already been enabled in ``__init__``."""
        return self

    def __exit__(
        self,
        typ: type[BaseException] | None,
        value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        """
        Restore the registry state captured at construction time.

        All three arguments are ``None`` when the block exits without an
        exception. Returning ``None`` means an in-flight exception is never
        suppressed.
        """
        self.registry._set_state(self.original_state)

    def __repr__(self) -> str:
        return f"{self.registry.__class__.__name__}.enable({self.name!r})"


class PluginRegistry(Generic[PluginT, R]):
    """
    A registry for plugins.

    This is a plugin registry that allows plugins to be loaded/registered
    in two ways:

    1. Through an explicit call to ``.register(name, value)``.
    2. By looking for other Python packages that are installed and provide
       a setuptools entry point group.

    When you create an instance of this class, provide the name of the
    entry point group to use::

        reg = PluginRegistry("my_entrypoint_group")

    """

    # this is a mapping of name to error message to allow custom error messages
    # in case an entrypoint is not found
    entrypoint_err_messages: dict[str, str] = {}

    # global settings is a key-value mapping of settings that are stored globally
    # in the registry rather than passed to the plugins
    _global_settings: dict[str, Any] = {}

    def __init__(
        self, entry_point_group: str = "", plugin_type: IsPlugin = callable
    ) -> None:
        """
        Create a PluginRegistry for a named entry point group.

        Parameters
        ----------
        entry_point_group: str
            The name of the entry point group.
        plugin_type
            A type narrowing function that will optionally be used for runtime
            type checking loaded plugins.
            Passing a class instead is deprecated; it is wrapped in an
            ``isinstance``-based guard after emitting a deprecation warning.

        References
        ----------
        https://typing.readthedocs.io/en/latest/spec/narrowing.html
        """
        self.entry_point_group: str = entry_point_group
        self.plugin_type: IsPlugin
        if plugin_type is not callable and isinstance(plugin_type, type):
            msg = (
                f"Pass a callable `TypeIs` function to `plugin_type` instead.\n"
                f"{type(self).__name__!r}(plugin_type)\n\n"
                f"See also:\n"
                f"https://typing.readthedocs.io/en/latest/spec/narrowing.html\n"
                f"https://docs.astral.sh/ruff/rules/assert/"
            )
            deprecated_warn(msg, version="5.4.0")
            self.plugin_type = cast(IsPlugin, _is_type(plugin_type))
        else:
            self.plugin_type = plugin_type
        self._active: Plugin[R] | None = None
        self._active_name: str = ""
        self._plugins: dict[str, PluginT] = {}
        self._options: dict[str, Any] = {}
        # Each instance starts from its own copy of the class-level defaults.
        self._global_settings: dict[str, Any] = self.__class__._global_settings.copy()

    def register(self, name: str, value: PluginT | None) -> PluginT | None:
        """
        Register a plugin by name and value.

        This method is used for explicit registration of a plugin and shouldn't be
        used to manage entry point managed plugins, which are auto-loaded.

        Parameters
        ----------
        name: str
            The name of the plugin.
        value: PluginT or None
            The actual plugin object to register or None to unregister that plugin.

        Returns
        -------
        plugin: PluginT or None
            The plugin that was registered or unregistered.
            ``None`` when unregistering a name that was not registered.

        Raises
        ------
        TypeError
            If ``value`` is not ``None`` and is rejected by ``plugin_type``.
        """
        if value is None:
            return self._plugins.pop(name, None)
        elif self.plugin_type(value):
            self._plugins[name] = value
            return value
        else:
            msg = f"{type(value).__name__!r} is not compatible with {type(self).__name__!r}"
            raise TypeError(msg)

    def names(self) -> list[str]:
        """List the names of the registered and entry points plugins."""
        e_points = importlib_metadata_get(self.entry_point_group)
        # A set removes names that are both registered and provided by an entry point.
        return sorted({*self._plugins, *(ep.name for ep in e_points)})

    def _get_state(self) -> dict[str, Any]:
        """
        Return a dictionary representing the current state of the registry.

        The mutable mappings are shallow-copied, so later changes to the
        registry do not leak into the returned snapshot.
        """
        return {
            "_active": self._active,
            "_active_name": self._active_name,
            "_plugins": self._plugins.copy(),
            "_options": self._options.copy(),
            "_global_settings": self._global_settings.copy(),
        }

    def _set_state(self, state: dict[str, Any]) -> None:
        """
        Reset the state of the registry.

        ``state`` is expected to have exactly the keys produced by ``_get_state``;
        every entry is assigned to the attribute of the same name.
        """
        # Internal invariant check, not input validation: the state comes from `_get_state`.
        assert set(state.keys()) == {
            "_active",
            "_active_name",
            "_plugins",
            "_options",
            "_global_settings",
        }
        for key, val in state.items():
            setattr(self, key, val)

    def _enable(self, name: str, **options) -> None:
        """
        Make ``name`` the active plugin, loading it from an entry point if needed.

        Options whose key is already present in the global settings update those
        settings; all remaining options replace the stored plugin options.

        Raises
        ------
        ValueError
            If no unique entry point matches and a custom message is registered
            for ``name`` in ``entrypoint_err_messages``.
        NoSuchEntryPoint
            If no unique entry point matches and there is no custom message.
        TypeError
            If the loaded entry point is rejected by ``register``.
        """
        if name not in self._plugins:
            try:
                # Single-element unpacking raises ``ValueError`` both when nothing
                # matches and when more than one entry point matches.
                (ep,) = (
                    ep
                    for ep in importlib_metadata_get(self.entry_point_group)
                    if ep.name == name
                )
            except ValueError as err:
                if name in self.entrypoint_err_messages:
                    raise ValueError(self.entrypoint_err_messages[name]) from err
                else:
                    raise NoSuchEntryPoint(self.entry_point_group, name) from err
            value = cast(PluginT, ep.load())
            self.register(name, value)
        self._active_name = name
        self._active = self._plugins[name]
        # Route recognised keys to the global settings; the key set is computed
        # up front, so popping from ``options`` inside the loop is safe.
        for key in set(options.keys()) & set(self._global_settings.keys()):
            self._global_settings[key] = options.pop(key)
        self._options = options

    def enable(self, name: str | None = None, **options) -> PluginEnabler:
        """
        Enable a plugin by name.

        This can be either called directly, or used as a context manager.

        Parameters
        ----------
        name : string (optional)
            The name of the plugin to enable. If not specified, then use the
            current active name.
        **options :
            Any additional parameters will be passed to the plugin as keyword
            arguments

        Returns
        -------
        PluginEnabler:
            An object that allows enable() to be used as a context manager
        """
        if name is None:
            name = self.active
        return PluginEnabler(self, name, **options)

    @property
    def active(self) -> str:
        """Return the name of the currently active plugin."""
        return self._active_name

    @property
    def options(self) -> dict[str, Any]:
        """Return the current options dictionary."""
        return self._options

    def get(self) -> partial[R] | Plugin[R] | None:
        """
        Return the currently active plugin.

        Returns
        -------
        partial, plugin or None
            ``None`` when no plugin is active. Otherwise the active plugin,
            wrapped in a ``functools.partial`` that binds the stored options as
            keyword arguments when any options are set.

        Raises
        ------
        TypeError
            If the active plugin is rejected by ``plugin_type``.
        """
        func = self._active
        # Compare against ``None`` explicitly: a plugin object may be falsy
        # (e.g. it defines ``__bool__`` or ``__len__``) and still be a valid callable.
        if func is None:
            return None
        if not self.plugin_type(func):
            msg = (
                f"{type(self).__name__!r} requires all plugins to be callable objects, "
                f"but {type(func).__name__!r} is not callable."
            )
            raise TypeError(msg)
        return partial(func, **self._options) if self._options else func

    def __repr__(self) -> str:
        return f"{type(self).__name__}(active={self.active!r}, registered={self.names()!r})"


def importlib_metadata_get(group):
    """
    Return the entry points registered under ``group``.

    Uses ``select(group=...)`` when the object returned by ``entry_points()``
    offers it, and otherwise falls back to the mapping-style
    ``get(group, [])`` lookup.
    """
    discovered = entry_points()
    # 'select' arrived with Python 3.10, at which point 'get' was deprecated.
    # Probing for the attribute (rather than checking the Python version) also
    # covers the importlib_metadata backport, whose deprecation cycle for 'get'
    # was different.
    select = getattr(discovered, "select", None)
    if select is None:
        return discovered.get(group, [])
    return select(group=group)