from __future__ import annotations import base64 import gzip import json from dataclasses import dataclass, fields from io import BytesIO from pathlib import Path from urllib.parse import parse_qsl import altair as alt import ipywidgets as widgets import numpy as np import polars as pl import solara import solara.lab from cmap import Colormap from ipymolstar.widget import PDBeMolstar from pydantic import BaseModel from make_link import decode_data base_v = np.vectorize(np.base_repr) PAD_SIZE = 0.05 # when not autoscale Y size of padding used def norm(x, vmin, vmax): return (x - vmin) / (vmax - vmin) class ColorTransform(BaseModel): name: str = "tol:rainbow_PuRd" norm_type: str = "linear" vmin: float = 0.0 vmax: float = 1.0 missing_data_color: str = "#8c8c8c" highlight_color: str = "#e933f8" def molstar_colors(self, data: pl.DataFrame) -> dict: data = data.drop_nulls() if self.norm_type == "categorical": values = data["value"] else: values = norm(data["value"], vmin=self.vmin, vmax=self.vmax) rgba_array = self.cmap(values, bytes=True) ints = rgba_array.astype(np.uint8).view(dtype=np.uint32).byteswap() padded = np.char.rjust(base_v(ints // 2**8, 16), 6, "0") hex_colors = np.char.add("#", padded).squeeze() color_data = { "data": [ {"residue_number": resi, "color": hcolor.lower()} for resi, hcolor in zip(data["residue_number"], hex_colors) ], "nonSelectedColor": self.missing_data_color, } return color_data @property def cmap(self) -> Colormap: return Colormap(self.name, bad=self.missing_data_color) @property def altair_scale(self) -> alt.Scale: if self.norm_type == "categorical": colors = self.cmap.to_altair(N=self.cmap.num_colors) domain = range(self.cmap.num_colors) else: colors = self.cmap.to_altair() domain = np.linspace(self.vmin, self.vmax, 256, endpoint=True) scale = alt.Scale(domain=list(domain), range=colors, clamp=True) return scale class AxisProperties(BaseModel): label: str = "x" unit: str = "au" autoscale_y: bool = True @property def title(self) -> str: return f"{self.label} ({self.unit})" def make_chart( data: pl.DataFrame, colors: ColorTransform, axis_properties: AxisProperties ) -> alt.LayerChart: xmin, xmax = data["residue_number"].min(), data["residue_number"].max() xpad = (xmax - xmin) * 0.05 xscale = alt.Scale(domain=(xmin - xpad, xmax + xpad)) if axis_properties.autoscale_y: y_scale = alt.Scale() elif colors.norm_type == "categorical": ypad = colors.cmap.num_colors * 0.05 y_scale = alt.Scale(domain=(0 - ypad, colors.cmap.num_colors - 1 + ypad)) else: ypad = (colors.vmax - colors.vmin) * 0.05 y_scale = alt.Scale(domain=(colors.vmin - ypad, colors.vmax + ypad)) zoom_x = alt.selection_interval( bind="scales", encodings=["x"], zoom="wheel![!event.shiftKey]", ) scatter = ( alt.Chart(data) .mark_circle(interpolate="basis", size=200) .encode( x=alt.X("residue_number:Q", title="Residue Number", scale=xscale), y=alt.Y( "value:Q", title=axis_properties.title, scale=y_scale, ), color=alt.Color( f"value:{'O' if colors.norm_type == 'categorical' else 'Q'}", scale=colors.altair_scale, title=axis_properties.title, ), ) .add_params(zoom_x) ) # Create a selection that chooses the nearest point & selects based on x-value nearest = alt.selection_point( name="point", nearest=True, on="pointerover", fields=["residue_number"], empty=False, clear="mouseout", ) select_residue = ( alt.Chart(data) .mark_point() .encode( x="residue_number:Q", opacity=alt.value(0), ) .add_params(nearest) ) # Draw a rule at the location of the selection rule = ( alt.Chart(data) .mark_rule(color=colors.highlight_color, size=2) .encode( x="residue_number:Q", ) .transform_filter(nearest) ) # vline = ( # alt.Chart(pd.DataFrame({"x": [0]})) # .mark_rule(color=colors.highlight_color, size=2) # .encode(x="x:Q") # ) line_position = alt.param(name="line_position", value=0.0) line_opacity = alt.param(name="line_opacity", value=1) df_line = pl.DataFrame({"x": [1.0]}) # Create vertical rule with parameter vline = ( alt.Chart(df_line) .mark_rule(color=colors.highlight_color, opacity=line_opacity, size=2) .encode(x=alt.X("p", type="quantitative")) .transform_calculate(p=alt.datum.x * line_position) .add_params(line_position, line_opacity) ) # Put the five layers into a chart and bind the data chart = ( alt.layer(scatter, vline, select_residue, rule).properties( width="container", height=480, # autosize height? ) # .configure(autosize="fit") ) return chart @solara.component def ScatterChart( data: pl.DataFrame, colors: ColorTransform, axis_properties: AxisProperties, on_selections, line_value, ): def mem_chart(): chart = make_chart(data, colors, axis_properties) return chart chart = solara.use_memo(mem_chart, dependencies=[data, colors, axis_properties]) if line_value is not None: params = {"line_position": line_value, "line_opacity": 1} else: params = {"line_position": 0.0, "line_opacity": 0} dark_effective = solara.lab.use_dark_effective() if dark_effective: options = {"actions": False, "theme": "dark"} else: options = {"actions": False} view = alt.JupyterChart.element( # type: ignore chart=chart, embed_options=options, _params=params, ) def bind(): real = solara.get_widget(view) real.selections.observe(on_selections, "point") # type: ignore solara.use_effect(bind, [data, colors]) def is_numeric(val) -> bool: if val is not None: return not np.isnan(val) return False @solara.component def ProteinView( title: str, molecule_id: str, data: pl.DataFrame, colors: ColorTransform, axis_properties: AxisProperties, dark_effective: bool, description: str = "", ): about_dialog = solara.use_reactive(False) fullscreen = solara.use_reactive(False) # residue number to highlight in altair chart line_number = solara.use_reactive(None) # residue number to highlight in protein view highlight_number = solara.use_reactive(None) if data.is_empty(): color_data = {} else: color_data = colors.molstar_colors(data) tooltips = { "data": [ { "residue_number": resi, "tooltip": f"{axis_properties.label}: {value:.2g} {axis_properties.unit}" if is_numeric(value) else "No data", } for resi, value in zip(data["residue_number"], data["value"]) ] } def on_molstar_mouseover(value): r = value.get("residueNumber", None) line_number.set(r) def on_molstar_mouseout(value): on_molstar_mouseover({}) def on_chart_selection(event): try: r = event["new"].value[0]["residue_number"] highlight_number.set(r) except (IndexError, KeyError): highlight_number.set(None) with solara.AppBar(): solara.AppBarTitle(title) with solara.Tooltip("Fullscreen"): solara.Button( icon_name="mdi-fullscreen", icon=True, on_click=lambda: fullscreen.set(not fullscreen.value), ) if description: with solara.Tooltip("About"): solara.Button( icon_name="mdi-information-outline", icon=True, on_click=lambda: about_dialog.set(True), ) solara.lab.ThemeToggle() with solara.v.Dialog( v_model=about_dialog.value, on_v_model=lambda _ignore: about_dialog.set(False) ): with solara.Card("About", margin=0): solara.Markdown(description) with solara.ColumnsResponsive([4, 8]): with solara.Card(style={"height": "550px"}): PDBeMolstar.element( # type: ignore theme="dark" if dark_effective else "light", molecule_id=molecule_id.lower(), color_data=color_data, hide_water=True, tooltips=tooltips, height="525px", highlight={"data": [{"residue_number": int(highlight_number.value)}]} if highlight_number.value else None, highlight_color=colors.highlight_color, on_mouseover_event=on_molstar_mouseover, on_mouseout_event=on_molstar_mouseout, hide_controls_icon=True, hide_expand_icon=True, hide_settings_icon=True, expanded=fullscreen.value, ).key(f"molstar-{dark_effective}") if not fullscreen.value: with solara.Card(style={"height": "550px"}): if data.is_empty(): solara.Text("No data") else: ScatterChart( data, colors, axis_properties, on_chart_selection, line_number.value, ) @solara.component def RoutedView(): route = solara.use_router() dark_effective = solara.lab.use_dark_effective() try: query_dict = {k: v for k, v in parse_qsl(route.search)} colors = ColorTransform(**query_dict) # type: ignore axis_properties = AxisProperties(**query_dict) # type: ignore data = decode_data(query_dict["data"]) ProteinView( query_dict["title"], molecule_id=query_dict["molecule_id"], data=data, colors=colors, axis_properties=axis_properties, dark_effective=dark_effective, description=query_dict.get("description", ""), ) except KeyError as err: solara.Warning(f"Error: {err}")