import json
from dataclasses import dataclass
from io import StringIO
from typing import Literal, Optional, TypedDict, cast
from zipfile import ZipFile

import polars as pl
import solara
from Bio.PDB import MMCIFParser, Structure
from ipymolstar import PDBeMolstar
from ipymolstar.widget import QueryParam
from polarify import polarify
from solara.components.file_drop import FileInfo


class ColorData(TypedDict):
    """Per-residue colour payload handed to the PDBeMolstar widget."""

    data: list[QueryParam]
    # Key spelled exactly as in the dicts built in this module
    # (NO_COLOR_DATA and the plddt colour data in load_result).
    nonSelectedColor: None


class TooltipData(TypedDict):
    """Per-residue tooltip payload handed to the PDBeMolstar widget."""

    data: list[QueryParam]


class CustomData(TypedDict):
    """Inline structure payload: the CIF text plus its format flags."""

    data: str
    format: Literal["cif"]
    binary: Literal[False]


@dataclass
class AlphaFoldData:
    """Everything derived from one structure/JSON pair in the uploaded zip."""

    name: str  # structure file name without the ".cif" suffix
    structure: Structure  # parsed by the module-level MMCIFParser
    atom_data: pl.DataFrame  # one row per atom: name, resn, chain, plddt
    residue_data: pl.DataFrame  # per-residue mean plddt and confidence label
    custom_data: CustomData
    color_data: ColorData
    tooltip_data: TooltipData

    def write_atoms(self) -> str:
        """Return the atom table serialised with polars' ``write_csv``."""
        return self.atom_data.write_csv()


# RGB colour used for each confidence label produced by assign_confidence.
COLOR_LUT = {
    "very-high": {"r": 16, "g": 109, "b": 255},
    "confident": {"r": 16, "g": 207, "b": 241},
    "low": {"r": 246, "g": 237, "b": 18},
    "very-low": {"r": 239, "g": 130, "b": 30},
}

# Empty payloads used when no per-residue colouring / tooltips are wanted.
NO_COLOR_DATA = {"data": [], "nonSelectedColor": None}
NO_TOOLTIP_DATA = {"data": []}

PARSER = MMCIFParser()

# Application state shared between the task and the page.
result_index = solara.reactive(0)
file_info = solara.reactive(cast(Optional[FileInfo], None))


# Map a plddt expression to a confidence label:
# < 50 "very-low", < 70 "low", < 90 "confident", otherwise "very-high".
# The body is rewritten by polarify, so it is kept to plain assignments,
# if/elif branches and a return (documented here instead of in a docstring).
@polarify
def assign_confidence(x: pl.Expr) -> pl.Expr:
    s = pl.lit("very-high")
    if x < 50:
        s = pl.lit("very-low")
    elif x < 70:
        s = pl.lit("low")
    elif x < 90:
        s = pl.lit("confident")
    return s


@solara.lab.task
def load_result() -> Optional[AlphaFoldData]:
    """Load the result selected by ``result_index`` from the uploaded zip.

    The ``.cif`` members and the members whose name contains ``full_data``
    are each sorted by name and indexed with ``result_index.value``; the two
    lists are assumed to line up pairwise.

    Returns:
        An ``AlphaFoldData`` holding the parsed structure, the per-atom and
        per-residue tables and the widget payloads derived from them.

    Raises:
        IndexError: if the archive has no structure or JSON file at the
            selected index.
    """
    f_idx = result_index.value
    with ZipFile(file_info.value["file_obj"]) as zf:
        members = zf.namelist()
        cif_files = sorted(f for f in members if f.endswith(".cif"))
        json_files = sorted(f for f in members if "full_data" in f)
        try:
            structure_file = cif_files[f_idx]
            json_data_file = json_files[f_idx]
        except IndexError as err:
            raise IndexError(
                f"Result index {f_idx} is out of range: archive contains "
                f"{len(cif_files)} .cif file(s) and "
                f"{len(json_files)} full_data file(s)"
            ) from err

        with zf.open(json_data_file) as json_f:
            json_load = json.load(json_f)
        cif_str = zf.read(structure_file).decode("utf-8")

    # removesuffix, not rstrip: rstrip(".cif") strips any trailing run of the
    # characters ".", "c", "i", "f" and would mangle e.g. "abc.cif" to "ab".
    alphafold_name = structure_file.removesuffix(".cif")
    structure = PARSER.get_structure(alphafold_name, StringIO(cif_str))

    res_names = pl.Series(
        (atom.get_parent().resname for atom in structure.get_atoms()),
        dtype=pl.Categorical,
    )
    res_numbers = pl.Series(
        atom.get_parent().id[1] for atom in structure.get_atoms()
    )
    chain_ids = pl.Series(json_load["atom_chain_ids"], dtype=pl.Categorical)

    atoms_df = pl.DataFrame(
        {
            "name": res_names,
            "resn": res_numbers,
            "chain": chain_ids,
            "plddt": json_load["atom_plddts"],
        }
    )

    residue_df = (
        atoms_df.group_by(["chain", "resn", "name"])
        .agg(pl.col("plddt").mean().alias("mean_plddt"))
        .sort(["chain", "resn"])
        .with_columns(
            assign_confidence(pl.col("mean_plddt"))
            .alias("confidence")
            .cast(pl.Categorical)
        )
    )

    custom_data: CustomData = {
        "data": cif_str,
        "format": "cif",
        "binary": False,
    }

    color_query = []
    tooltip_query = []
    # Row layout follows the group keys, then the aggregated and derived
    # columns: chain, resn, name, mean_plddt, confidence.
    for chain_id, res_number, _res_name, mean_plddt, confidence in (
        residue_df.iter_rows()
    ):
        color_query.append(
            {
                "struct_asym_id": chain_id,
                "residue_number": res_number,
                "color": COLOR_LUT[confidence],
            }
        )
        tooltip_query.append(
            {
                "struct_asym_id": chain_id,
                "residue_number": res_number,
                "tooltip": f"Confidence: {confidence}; plddt: {mean_plddt:.2f}",
            }
        )

    plddt_color_data = {"data": color_query, "nonSelectedColor": None}
    plddt_tooltip_data = {"data": tooltip_query}

    return AlphaFoldData(
        name=alphafold_name,
        structure=structure,
        atom_data=atoms_df,
        residue_data=residue_df,
        custom_data=custom_data,
        color_data=plddt_color_data,
        tooltip_data=plddt_tooltip_data,
    )


@solara.component
def Page():
    """Main page: upload/controls sidebar plus the structure viewer."""
    color_mode = solara.use_reactive("chain")
    spin = solara.use_reactive(True)
    dark_effective = solara.lab.use_dark_effective()

    def on_color_mode(value: str):
        color_mode.set(value)

    def set_result_index(value: int):
        # Selecting another result immediately reloads from the archive.
        result_index.set(value)
        load_result()

    solara.Title("Solarafold result viewer")
    with solara.AppBar():
        solara.lab.ThemeToggle()

    with solara.Sidebar():
        solara.FileDrop(label="Upload zip file", on_file=file_info.set, lazy=True)
        solara.Button(
            label="Load result",
            on_click=load_result,
            block=True,
            disabled=file_info.value is None,
        )

        if not load_result.not_called:
            disabled = load_result.pending
            # Downloads need an actual result; after a failed load the task
            # is no longer pending but still has no value to read from.
            no_data = disabled or load_result.value is None

            solara.Select(
                label="Result index",
                value=result_index.value,
                on_value=set_result_index,
                # NOTE(review): assumes five results per archive - confirm.
                values=list(range(5)),
                disabled=disabled,
            )
            solara.Select(
                label="Color mode",
                values=["chain", "plddt"],
                value=color_mode.value,
                on_value=on_color_mode,
                disabled=disabled,
            )
            solara.Checkbox(label="Spin", value=spin)

            def write_atoms():
                return load_result.value.atom_data.write_csv()

            solara.FileDownload(
                write_atoms,
                filename="NA" if no_data else f"{load_result.value.name}_atoms.csv",
                children=[
                    solara.Button(
                        "Download atom plddt",
                        block=True,
                        disabled=no_data,
                    )
                ],
            )

            solara.Div(style={"height": "20px"})

            def write_residues():
                return load_result.value.residue_data.write_csv()

            solara.FileDownload(
                write_residues,
                filename="NA"
                if no_data
                else f"{load_result.value.name}_residues.csv",
                children=[
                    solara.Button(
                        "Download residue plddt",
                        block=True,
                        disabled=no_data,
                    )
                ],
            )

    if load_result.not_called:
        solara.HTML(
            tag="p",
            unsafe_innerHTML='Drag and drop an alphafold3 result .zip file to get started. You can download an example file here.',
        )
    elif load_result.pending:
        solara.ProgressLinear(load_result.pending)
    elif load_result.error:
        # Surface load failures instead of leaving the main area blank.
        solara.Error(f"Failed to load result: {load_result.exception}")
    elif load_result.finished:
        fold_data: AlphaFoldData = load_result.value
        color_data = (
            NO_COLOR_DATA if color_mode.value == "chain" else fold_data.color_data
        )
        with solara.Card():
            theme = "dark" if dark_effective else "light"
            # The key changes with the theme so the widget is recreated when
            # the theme is toggled.
            PDBeMolstar.element(
                height="calc(100vh - 150px)",
                custom_data=fold_data.custom_data,
                color_data=color_data,
                tooltips=fold_data.tooltip_data,
                show_water=False,
                spin=spin.value,
                theme=theme,
            ).key(f"pdbemolstar-{dark_effective}")


@solara.component
def Layout(children):
    """App layout whose toolbar follows the effective dark-mode setting."""
    dark_effective = solara.lab.use_dark_effective()
    return solara.AppLayout(
        children=children, toolbar_dark=dark_effective, color=None
    )  # if dark_effective else "primary")