Spaces:
Runtime error
Runtime error
import dash | |
import dash_bootstrap_components as dbc | |
from dash import dcc | |
from dash import html | |
from dash.dependencies import Input, Output, State | |
from typing import List, Tuple | |
from scipy.spatial.distance import cdist | |
import pandas as pd | |
import numpy as np | |
import plotly.express as px | |
import plotly.graph_objects as go | |
# Table of per-protein metadata. Other code in this file reads its "pdb_id",
# "classification" and "splits" columns.
# NOTE: unpickling can execute arbitrary code, so this file must come from a
# trusted source.
df = pd.read_pickle('all_embeddings_with_splits.p')

app = dash.Dash(external_stylesheets=[dbc.themes.BOOTSTRAP])

# Page layout: a narrow left column with the plot controls and the list of
# closest points, and a wide right column holding the embedding graph.
app.layout = dbc.Container(
    [
        html.H1("Embedding Plots"),
        html.Hr(),
        html.Div(
            [
                dbc.Row(
                    [
                        dbc.Col(
                            [
                                html.Label('Algorithm:'),
                                dcc.Dropdown(
                                    id="algorithm-dropdown",
                                    options=[
                                        {"label": "PCA", "value": "pca"},
                                        {"label": "UMAP", "value": "umap"},
                                        {"label": "tSNE", "value": "tsne"},
                                        {"label": "PaCMAP", "value": "pacmap"},
                                    ],
                                    value="pacmap",
                                    clearable=False,
                                    searchable=False,
                                    style={"margin-bottom": "10px"}
                                ),
                                html.Label('Number of dimensions:'),
                                dcc.Dropdown(
                                    id="num-components-dropdown",
                                    options=[
                                        {"label": "2", "value": 2},
                                        {"label": "3", "value": 3}
                                    ],
                                    value=3,
                                    clearable=False,
                                    searchable=False,
                                    style={"margin-bottom": "10px"}
                                ),
                                html.Label('Color by:'),
                                dcc.Dropdown(
                                    id="color-by",
                                    options=[
                                        {
                                            "label": "Protein Classification",
                                            "value": "classification"
                                        },
                                        {
                                            "label": "Split (train/test/val/gpcr)",
                                            "value": "split"
                                        }
                                    ],
                                    value="classification",
                                    clearable=False,
                                    searchable=False,
                                    style={"margin-bottom": "10px"}
                                ),
                                html.Span(
                                    [
                                        "Keep the top ",
                                        dcc.Input(
                                            id="top-n-classes",
                                            type="number",
                                            value=10,
                                            min=1,
                                            # Upper bound: one slot per distinct classification.
                                            max=len(df["classification"].unique()),
                                            step=1,
                                            style={"width": "50px"}
                                        ),
                                        " classes."
                                    ],
                                    style={"margin-bottom": "20px"}
                                ),
                                html.Br(),
                                dbc.Button(
                                    "Update",
                                    id="update-button",
                                    color="primary",
                                    n_clicks=0,
                                    style={"width": "100%", "margin": "10px 0px"}
                                ),
                                # Scrollable area for the closest-point cards.
                                dbc.Container(
                                    id="closest-points",
                                    style={"max-height": "65vh", "overflow-y": "auto"}
                                ),
                            ],
                            width={"size": 2, "order": 1},
                        ),
                        dbc.Col(
                            dcc.Graph(
                                id="embedding-graph",
                                style={"height": "100%", "width": "100%"},
                            ),
                            width={"size": 10, "order": 2},
                        ),
                    ],
                    style={"height": "95vh"}
                )
            ],
            # "vh" is the viewport-height unit; the previous "100hv" is not a
            # valid CSS length, so the declaration was ignored.
            style={"height": "100vh"}
        ),
        html.Hr(),
    ],
    fluid=True,
)
def load_embedding(algorithm: str, num_components: int) -> np.array:
    """Read a precomputed embedding matrix from a ``.npy`` file.

    Parameters
    ----------
    algorithm : str
        Name of the projection algorithm. ``"pca"`` always maps to
        ``pca.npy``; any other value is used as the file-name prefix.
    num_components : int
        Number of dimensions. Only used to build the file name for
        non-PCA algorithms (``<algorithm><num_components>d.npy``).

    Returns
    -------
    np.array
        The array stored in the selected file, exactly as ``np.load``
        returns it.
    """
    # A single PCA file is used whatever dimensionality was requested.
    if algorithm == "pca":
        filename = "pca.npy"
    else:
        filename = f"{algorithm}{num_components!s}d.npy"
    return np.load(filename)
def get_top_n_classifications(df: pd.DataFrame, n: int) -> List[str]:
    """Return the ``n`` most frequent values of the ``classification`` column.

    Parameters
    ----------
    df : pd.DataFrame
        Table with a ``classification`` column.
    n : int
        How many of the most frequent classifications to keep.

    Returns
    -------
    List[str]
        Classification values ordered from most to least frequent.
    """
    class_counts = df["classification"].value_counts()
    most_frequent = class_counts.nlargest(n)
    return most_frequent.index.tolist()
def update_embedding_graph(n_clicks: int,
                           algorithm: str,
                           num_components: int,
                           top_n_classes: int,
                           color_by: str) -> go.Figure:
    """Build the 2D or 3D scatter plot of the selected embedding.

    One trace is added per legend label (a split name, a top-N
    classification, or ``"other"``), so each group can be toggled from the
    legend.

    NOTE(review): no ``@app.callback`` registration for this function is
    visible in this file. Confirm it is wired to the "update-button" click
    and the "embedding-graph" figure.

    Parameters
    ----------
    n_clicks : int
        Click count of the update button; nothing is drawn until it is > 0.
    algorithm : str
        Embedding algorithm, forwarded to ``load_embedding``.
    num_components : int
        2 for a flat scatter plot, 3 for a 3D scatter plot.
    top_n_classes : int
        Number of most frequent classifications that get their own colour
        when colouring by classification.
    color_by : str
        ``"split"`` colours by the ``splits`` column; any other value
        colours by the ``classification`` column.

    Returns
    -------
    go.Figure
        The figure to display.

    Raises
    ------
    dash.exceptions.PreventUpdate
        If the button has not been clicked yet, or the top-N input is empty
        while colouring by classification.
    ValueError
        If ``num_components`` is neither 2 nor 3.
    """
    if n_clicks is None or n_clicks <= 0:
        raise dash.exceptions.PreventUpdate
    if num_components not in (2, 3):
        raise ValueError(
            f"num_components must be 2 or 3, got {num_components!r}")

    embedding = load_embedding(algorithm, num_components)
    labels, color_map = _color_labels_and_map(color_by, top_n_classes)

    fig = go.Figure()
    _add_label_traces(fig, embedding, labels, color_map, num_components)

    if num_components == 3:
        fig.update_layout(
            scene=dict(
                xaxis=dict(showgrid=False, showticklabels=False, title=""),
                yaxis=dict(showgrid=False, showticklabels=False, title=""),
                zaxis=dict(showgrid=False, showticklabels=False, title=""),
            ),
        )
        fig.update_scenes(xaxis_visible=False, yaxis_visible=False,
                          zaxis_visible=False)
    else:
        # Scenes only exist for 3D plots; cartesian axes have to be hidden
        # through the x/y axis objects instead.
        fig.update_xaxes(visible=False)
        fig.update_yaxes(visible=False)

    fig.update_layout(
        legend=dict(
            x=0,
            y=1,
            itemsizing='constant',
            itemclick='toggle',
            itemdoubleclick='toggleothers',
            traceorder='reversed',
            itemwidth=30,
        ),
        margin=dict(l=0, r=0, b=0, t=0),
        plot_bgcolor='rgba(0,0,0,0)',
        paper_bgcolor='rgba(0,0,0,0)',
    )
    return fig


def _color_labels_and_map(color_by, top_n_classes):
    """Return the per-row legend labels and the label -> colour mapping."""
    if color_by == "split":
        color_map = {
            "gpcr": "red",
            "train": "blue",
            "val": "green",
            "test": "orange",
            "unknown": "grey",
        }
        splits = df["splits"]
        # Rows without a split fall into the existing "unknown" bucket
        # rather than producing a label that has no colour.
        labels = splits.astype(object).where(splits.notna(), "unknown")
        return labels, color_map

    if top_n_classes is None:
        # The numeric input is empty or invalid; keep the current figure.
        raise dash.exceptions.PreventUpdate

    top_classes = get_top_n_classifications(df, n=top_n_classes)
    classes = df["classification"]
    # Work on a derived Series so the shared module-level frame is never
    # modified by a request.
    labels = classes.astype(object).where(classes.isin(top_classes), "other")

    # Cycle through the palette so asking for more classes than there are
    # palette colours reuses colours rather than failing.
    palette = px.colors.qualitative.Plotly
    color_map = {
        name: palette[i % len(palette)] for i, name in enumerate(top_classes)
    }
    color_map.setdefault("other", "grey")
    return labels, color_map


def _add_label_traces(fig, embedding, labels, color_map, num_components):
    """Add one marker trace per distinct label to ``fig``."""
    hover_template = (
        "<b>PDB ID</b>: %{customdata[0]}<br>"
        "<b>Classification</b>: %{customdata[1]}<br>"
        "<extra></extra>"
    )
    marker_size = 2.5 if num_components == 3 else 7.5
    for label in labels.unique():
        rows = np.flatnonzero((labels == label).to_numpy())
        coords = embedding[rows]
        # Labels without an explicit colour are drawn as faded grey points.
        color = color_map.get(label, "grey")
        trace_kwargs = dict(
            x=coords[:, 0],
            y=coords[:, 1],
            mode='markers',
            name=str(label),
            marker=dict(
                size=marker_size,
                color=color,
                opacity=1 if color != 'grey' else 0.3,
            ),
            hovertemplate=hover_template,
            customdata=df.iloc[rows][['pdb_id', 'classification']].to_numpy(),
        )
        if num_components == 3:
            fig.add_trace(go.Scatter3d(z=coords[:, 2], **trace_kwargs))
        else:
            fig.add_trace(go.Scatter(**trace_kwargs))
#### GET CLOSEST POINTS | |
def extract_info_from_clickData(clickData: dict) -> Tuple[str, str]:
    """Pull the PDB id and classification out of a graph click event.

    Only the first entry of ``clickData["points"]`` is read, and only its
    ``customdata`` field, e.g.::

        {
            "points": [
                {
                    "x": 11.330583,
                    "y": 15.741333,
                    "z": -5.3435574,
                    "curveNumber": 2,
                    "pointNumber": 982,
                    "customdata": [
                        "1zfp",
                        "complex (signal transduction/peptide)"
                    ]
                }
            ]
        }

    Parameters
    ----------
    clickData : dict
        Click event dictionary with the shape shown above.

    Returns
    -------
    Tuple[str, str]
        ``(pdb_id, classification)`` of the clicked point.
    """
    clicked_point = clickData["points"][0]
    custom = clicked_point["customdata"]
    return custom[0], custom[1]
def find_closest_n_points(df: pd.DataFrame,
                          embedding: np.ndarray,
                          index: int = None,
                          pdb_id: str = None,
                          n: int = 20) -> Tuple[list, list]:
    """Find the ``n`` points closest to a reference point in an embedding.

    The reference point is given either by its row position (``index``) or
    by its ``pdb_id``; when both are given, ``pdb_id`` wins. Distances are
    whatever ``scipy.spatial.distance.cdist`` computes by default. The
    reference point itself is part of the result (at distance zero).

    Parameters
    ----------
    df : pd.DataFrame
        Table with ``pdb_id`` and ``classification`` columns whose rows are
        in the same order as the rows of ``embedding``.
    embedding : np.ndarray
        2D array with one row of coordinates per row of ``df``.
    index : int, optional
        Row position (not index label) of the reference point.
    pdb_id : str, optional
        PDB id of the reference point. If it occurs several times, the
        first occurrence is used.
    n : int
        Number of closest points to return.

    Returns
    -------
    Tuple[list, list]
        The PDB ids and the classifications of the ``n`` closest points,
        ordered from nearest to farthest.

    Raises
    ------
    ValueError
        If ``pdb_id`` is not present in ``df``, or if neither ``index`` nor
        ``pdb_id`` is provided.
    """
    if pdb_id:
        # Resolve to a row *position*: the embedding and ``df.iloc`` are both
        # positional, so the frame's index labels must not be used here.
        matches = np.flatnonzero((df["pdb_id"] == pdb_id).to_numpy())
        if matches.size == 0:
            raise ValueError(f"pdb_id {pdb_id!r} not found in df")
        index = int(matches[0])
    if index is None:
        raise ValueError("either index or pdb_id must be provided")

    distances = cdist(embedding[index, np.newaxis], embedding)[0]
    # A stable sort keeps equidistant points in row order.
    closest_indices = np.argsort(distances, kind="stable")[:n]
    closest_rows = df.iloc[closest_indices]
    closest_ids = closest_rows["pdb_id"].tolist()
    closest_ids_classifications = closest_rows["classification"].tolist()
    return closest_ids, closest_ids_classifications
def update_closest_points_div(
        clickData: dict,
        algorithm: str,
        num_components: int) -> list:
    """Build the list of cards describing the points nearest to a click.

    NOTE(review): no ``@app.callback`` registration for this function is
    visible in this file. Presumably it fills the children of the
    "closest-points" container when a point of "embedding-graph" is
    clicked; confirm the wiring.

    Parameters
    ----------
    clickData : dict
        Click event of the graph, or ``None`` when nothing has been clicked.
    algorithm : str
        Embedding algorithm, forwarded to ``load_embedding``.
    num_components : int
        Number of embedding dimensions, forwarded to ``load_embedding``.

    Returns
    -------
    list
        One ``dbc.Card`` per closest point (PDB id and classification), or
        a single placeholder ``html.Div`` when there is no click yet.
    """
    if clickData is None:
        # The placeholder carries no id: reusing "closest-points" here would
        # duplicate the id of the container this content is placed in.
        return [html.Div("Click on a data point to see the closest points.")]

    pdb_id, _ = extract_info_from_clickData(clickData)
    # Only read the embedding from disk once there is a point to look up.
    embedding = load_embedding(algorithm, num_components)
    closest_ids, closest_ids_classifications = find_closest_n_points(
        df, embedding, pdb_id=pdb_id)

    return [
        dbc.Card(
            dbc.CardBody(
                [
                    html.P(point_id, className="card-title"),
                    html.P(point_classification, className="card-text"),
                ]
            ),
            className="mb-3",
        )
        for point_id, point_classification in zip(
            closest_ids, closest_ids_classifications)
    ]
# Script entry point: start the Dash server only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    app.run_server(
        debug=True,
    )