from __future__ import annotations

import json
import pathlib
from typing import Any

import anywidget
import traitlets

import altair as alt
from altair import TopLevelSpec
from altair.utils._vegafusion_data import (
    compile_to_vegafusion_chart_state,
    using_vegafusion,
)
from altair.utils.selection import IndexSelection, IntervalSelection, PointSelection

_here = pathlib.Path(__file__).parent

# Maps the selection-type tags stored in ``JupyterChart._selection_types`` to
# the class used to build the public selection value.
_SELECTION_CLASSES = {
    "index": IndexSelection,
    "point": PointSelection,
    "interval": IntervalSelection,
}


class Params(traitlets.HasTraits):
    """Traitlet class storing a JupyterChart's params."""

    def __init__(self, trait_values):
        """
        Create one dynamically typed trait per entry of ``trait_values``.

        Parameters
        ----------
        trait_values: dict
            Mapping of param name to initial value. The trait type is picked
            from the Python type of the initial value.
        """
        super().__init__()

        for key, value in trait_values.items():
            # NOTE: bool is a subclass of int, so boolean values also match
            # this first branch.
            if isinstance(value, (int, float)):
                traitlet_type = traitlets.Float()
            elif isinstance(value, str):
                traitlet_type = traitlets.Unicode()
            elif isinstance(value, list):
                traitlet_type = traitlets.List()
            elif isinstance(value, dict):
                traitlet_type = traitlets.Dict()
            else:
                traitlet_type = traitlets.Any()

            # Add the new trait.
            self.add_traits(**{key: traitlet_type})

            # Set the trait's value.
            setattr(self, key, value)

    def __repr__(self):
        return f"Params({self.trait_values()})"


class Selections(traitlets.HasTraits):
    """Traitlet class storing a JupyterChart's selections."""

    def __init__(self, trait_values):
        """
        Create one read-only trait per entry of ``trait_values``.

        Parameters
        ----------
        trait_values: dict
            Mapping of selection name to an ``IndexSelection``,
            ``PointSelection`` or ``IntervalSelection`` instance.

        Raises
        ------
        ValueError
            If a value is not one of the supported selection classes.
        """
        super().__init__()

        for key, value in trait_values.items():
            if isinstance(value, IndexSelection):
                traitlet_type = traitlets.Instance(IndexSelection)
            elif isinstance(value, PointSelection):
                traitlet_type = traitlets.Instance(PointSelection)
            elif isinstance(value, IntervalSelection):
                traitlet_type = traitlets.Instance(IntervalSelection)
            else:
                msg = f"Unexpected selection type: {type(value)}"
                raise ValueError(msg)

            # Add the new trait.
            self.add_traits(**{key: traitlet_type})

            # Set the trait's value.
            setattr(self, key, value)

            # Make read-only
            self.observe(self._make_read_only, names=key)

    def __repr__(self):
        return f"Selections({self.trait_values()})"

    def _make_read_only(self, change):
        """Work around to make traits read-only, but still allow us to change them internally."""
        if change["name"] in self.traits() and change["old"] != change["new"]:
            # Restore the previous value before reporting the error.
            self._set_value(change["name"], change["old"])
            msg = (
                "Selections may not be set from Python.\n"
                f"Attempted to set select: {change['name']}"
            )
            raise ValueError(msg)

    def _set_value(self, key, value):
        """Set ``key`` to ``value`` with the read-only observer temporarily removed."""
        self.unobserve(self._make_read_only, names=key)
        setattr(self, key, value)
        self.observe(self._make_read_only, names=key)


def load_js_src() -> str:
    """Return the contents of the ``js/index.js`` file next to this module."""
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    return (_here / "js" / "index.js").read_text(encoding="utf-8")


class JupyterChart(anywidget.AnyWidget):
    _esm = load_js_src()
    _css = r"""
    .vega-embed {
        /* Make sure action menu isn't cut off */
        overflow: visible;
    }
    """

    # Public traitlets
    chart = traitlets.Instance(TopLevelSpec, allow_none=True)
    spec = traitlets.Dict(allow_none=True).tag(sync=True)
    debounce_wait = traitlets.Float(default_value=10).tag(sync=True)
    max_wait = traitlets.Bool(default_value=True).tag(sync=True)
    local_tz = traitlets.Unicode(default_value=None, allow_none=True).tag(sync=True)
    debug = traitlets.Bool(default_value=False)
    embed_options = traitlets.Dict(default_value=None, allow_none=True).tag(sync=True)

    # Internal selection traitlets
    _selection_types = traitlets.Dict()
    _vl_selections = traitlets.Dict().tag(sync=True)

    # Internal param traitlets
    _params = traitlets.Dict().tag(sync=True)

    # Internal comm traitlets for VegaFusion support
    _chart_state = traitlets.Any(allow_none=True)
    _js_watch_plan = traitlets.Any(allow_none=True).tag(sync=True)
    _js_to_py_updates = traitlets.Any(allow_none=True).tag(sync=True)
    _py_to_js_updates = traitlets.Any(allow_none=True).tag(sync=True)

    # Track whether charts are configured for offline use
    _is_offline = False

    @classmethod
    def enable_offline(cls, offline: bool = True):
        """
        Configure JupyterChart's offline behavior.

        Parameters
        ----------
        offline: bool
            If True, configure JupyterChart to operate in offline mode where JavaScript
            dependencies are loaded from vl-convert.
            If False, configure it to operate in online mode where JavaScript dependencies
            are loaded from CDN dynamically. This is the default behavior.
        """
        from altair.utils._importers import import_vl_convert, vl_version_for_vl_convert

        if offline:
            if cls._is_offline:
                # Already offline
                return

            vlc = import_vl_convert()

            src_lines = load_js_src().split("\n")

            # Remove leading lines with only whitespace, comments, or imports
            while src_lines and (
                len(src_lines[0].strip()) == 0
                or src_lines[0].startswith(("import", "//"))
            ):
                src_lines.pop(0)

            src = "\n".join(src_lines)

            # vl-convert's javascript_bundle function creates a self-contained JavaScript bundle
            # for JavaScript snippets that import from a small set of dependencies that
            # vl-convert includes. To see the available imports and their imported names, run
            #       import vl_convert as vlc
            #       help(vlc.javascript_bundle)
            bundled_src = vlc.javascript_bundle(
                src, vl_version=vl_version_for_vl_convert()
            )
            cls._esm = bundled_src
            cls._is_offline = True
        else:
            cls._esm = load_js_src()
            cls._is_offline = False

    def __init__(
        self,
        chart: TopLevelSpec,
        debounce_wait: int = 10,
        max_wait: bool = True,
        debug: bool = False,
        embed_options: dict | None = None,
        **kwargs: Any,
    ):
        """
        Jupyter Widget for displaying and updating Altair Charts, and retrieving selection and parameter values.

        Parameters
        ----------
        chart: Chart
            Altair Chart instance
        debounce_wait: int
             Debouncing wait time in milliseconds. Updates will be sent from the client to the kernel
             after debounce_wait milliseconds of no chart interactions.
        max_wait: bool
             If True (default), updates will be sent from the client to the kernel every debounce_wait
             milliseconds even if there are ongoing chart interactions. If False, updates will not be
             sent until chart interactions have completed.
        debug: bool
             If True, debug messages will be printed
        embed_options: dict
             Options to pass to vega-embed.
             See https://github.com/vega/vega-embed?tab=readme-ov-file#options
        """
        # Start with empty containers; the "chart" observer replaces them
        # whenever a chart is assigned.
        self.params = Params({})
        self.selections = Selections({})
        super().__init__(
            chart=chart,
            debounce_wait=debounce_wait,
            max_wait=max_wait,
            debug=debug,
            embed_options=embed_options,
            **kwargs,
        )

    @traitlets.observe("chart")
    def _on_change_chart(self, change):  # noqa: C901
        """Updates the JupyterChart's internal state when the wrapped Chart instance changes."""
        new_chart = change.new
        selection_watches = []
        selection_types = {}
        initial_params = {}
        initial_vl_selections = {}
        empty_selections = {}

        if new_chart is None:
            with self.hold_sync():
                self.spec = None
                self._selection_types = selection_types
                self._vl_selections = initial_vl_selections
                self._params = initial_params
            return

        params = getattr(new_chart, "params", [])

        if params is not alt.Undefined:
            # Iterate the value fetched above so the getattr default is
            # honoured when the chart has no "params" attribute.
            for param in params:
                if isinstance(param.name, alt.ParameterName):
                    clean_name = param.name.to_json().strip('"')
                else:
                    clean_name = param.name

                select = getattr(param, "select", alt.Undefined)

                if select is not alt.Undefined:
                    if not isinstance(select, dict):
                        select = select.to_dict()

                    select_type = select["type"]
                    if select_type == "point":
                        if not (
                            select.get("fields", None) or select.get("encodings", None)
                        ):
                            # Point selection with no associated fields or encodings specified.
                            # This is an index-based selection
                            selection_types[clean_name] = "index"
                            empty_selections[clean_name] = IndexSelection(
                                name=clean_name, value=[], store=[]
                            )
                        else:
                            selection_types[clean_name] = "point"
                            empty_selections[clean_name] = PointSelection(
                                name=clean_name, value=[], store=[]
                            )
                    elif select_type == "interval":
                        selection_types[clean_name] = "interval"
                        empty_selections[clean_name] = IntervalSelection(
                            name=clean_name, value={}, store=[]
                        )
                    else:
                        # ``select`` is a plain dict here, so report the
                        # extracted type rather than an attribute lookup.
                        msg = f"Unexpected selection type {select_type}"
                        raise ValueError(msg)
                    selection_watches.append(clean_name)
                    initial_vl_selections[clean_name] = {"value": None, "store": []}
                else:
                    # Identity check: "!=" may not yield a plain bool for
                    # every kind of param value.
                    clean_value = (
                        param.value if param.value is not alt.Undefined else None
                    )
                    initial_params[clean_name] = clean_value

        # Handle the params generated by transforms
        for param_name in collect_transform_params(new_chart):
            initial_params[param_name] = None

        # Setup params
        self.params = Params(initial_params)

        def on_param_traitlet_changed(param_change):
            # Assign a fresh dict so the change to "_params" is detected.
            new_params = dict(self._params)
            new_params[param_change["name"]] = param_change["new"]
            self._params = new_params

        self.params.observe(on_param_traitlet_changed)

        # Setup selections
        self.selections = Selections(empty_selections)

        # Update properties all together
        with self.hold_sync():
            if using_vegafusion():
                if self.local_tz is None:
                    self.spec = None

                    # Registering a bound method (rather than a new closure)
                    # keeps this to a single handler across chart changes.
                    self.observe(self._on_local_tz_change, ["local_tz"])
                else:
                    self._init_with_vegafusion(self.local_tz)
            else:
                self.spec = new_chart.to_dict()
            self._selection_types = selection_types
            self._vl_selections = initial_vl_selections
            self._params = initial_params

    def _on_local_tz_change(self, change):
        """Initialize VegaFusion state with the new value of ``local_tz``."""
        self._init_with_vegafusion(change["new"])

    def _init_with_vegafusion(self, local_tz: str):
        """
        Compile the current chart to a VegaFusion chart state and publish its spec.

        Does nothing when no chart is set.

        Parameters
        ----------
        local_tz: str
            Time zone passed to ``compile_to_vegafusion_chart_state``.
        """
        if self.chart is not None:
            vegalite_spec = self.chart.to_dict(context={"pre_transform": False})
            with self.hold_sync():
                self._chart_state = compile_to_vegafusion_chart_state(
                    vegalite_spec, local_tz
                )
                self._js_watch_plan = self._chart_state.get_watch_plan()[
                    "client_to_server"
                ]
                self.spec = self._chart_state.get_transformed_spec()

                # Callback to update chart state and send updates back to client.
                # A bound method is used so repeated initialization does not
                # stack up additional handlers.
                self.observe(self._on_js_to_py_updates, ["_js_to_py_updates"])

    def _on_js_to_py_updates(self, change):
        """Apply updates from JavaScript to the chart state and send the results back."""
        if self.debug:
            updates_str = json.dumps(change["new"], indent=2)
            print(f"JavaScript to Python VegaFusion updates:\n {updates_str}")
        updates = self._chart_state.update(change["new"])
        if self.debug:
            updates_str = json.dumps(updates, indent=2)
            print(f"Python to JavaScript VegaFusion updates:\n {updates_str}")
        self._py_to_js_updates = updates

    @traitlets.observe("_params")
    def _on_change_params(self, change):
        """Copy every entry of the new ``_params`` dict onto ``self.params``."""
        for param_name, value in change.new.items():
            setattr(self.params, param_name, value)

    @traitlets.observe("_vl_selections")
    def _on_change_selections(self, change):
        """Updates the JupyterChart's public selections traitlet in response to changes that the JavaScript logic makes to the internal _vl_selections traitlet."""
        for selection_name, selection_dict in change.new.items():
            value = selection_dict["value"]
            store = selection_dict["store"]
            selection_type = self._selection_types[selection_name]
            selection_cls = _SELECTION_CLASSES.get(selection_type)
            if selection_cls is None:
                # Unrecognized type tags are skipped.
                continue
            self.selections._set_value(
                selection_name,
                selection_cls.from_vega(selection_name, signal=value, store=store),
            )


def collect_transform_params(chart: TopLevelSpec) -> set[str]:
    """
    Collect the names of params that are defined by transforms.

    Parameters
    ----------
    chart: Chart from which to extract transform params

    Returns
    -------
    set of param names
    """
    transform_params: set[str] = set()

    # Handle recursive case
    for prop in ("layer", "concat", "hconcat", "vconcat"):
        children = getattr(chart, prop, [])
        if children is alt.Undefined:
            # Nothing to recurse into for this property.
            continue
        for child in children:
            transform_params.update(collect_transform_params(child))

    # Handle chart's own transforms
    transforms = getattr(chart, "transform", [])
    if transforms is alt.Undefined:
        transforms = []
    for tx in transforms:
        if hasattr(tx, "param"):
            transform_params.add(tx.param)

    return transform_params