raulminan committed on
Commit
4ddda00
1 Parent(s): 588df03

Upload 3 files

Browse files
Files changed (3) hide show
  1. Dockerfile +12 -0
  2. app.py +389 -0
  3. requirements.txt +8 -0
Dockerfile ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Container image for the Dash embedding explorer (app.py).
FROM python:3.10

WORKDIR /app

# Install dependencies before copying the sources so this layer stays cached
# when only the application code or the data files change.
COPY requirements.txt ./
RUN pip install --no-cache-dir -r requirements.txt

# Copies the whole build context into /app. This already includes app.py and
# the .p / .npy data files it loads, so a second COPY listing those files
# individually is redundant and has been dropped.
COPY . .

# NOTE(review): these flags only take effect if app.py parses --host/--port
# and forwards them to the Dash server - verify before deploying.
# NOTE(review): if this image runs as a Hugging Face Docker Space, the
# documented default app_port is 7860; confirm that 7680 is intended.
CMD ["python", "app.py", "--host", "0.0.0.0", "--port", "7680"]
12
+
app.py ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dash
2
+ import dash_bootstrap_components as dbc
3
+ from dash import dcc
4
+ from dash import html
5
+ from dash.dependencies import Input, Output, State
6
+
7
+ from typing import List, Tuple
8
+ from scipy.spatial.distance import cdist
9
+
10
+ import pandas as pd
11
+ import numpy as np
12
+ import plotly.express as px
13
+ import plotly.graph_objects as go
14
+
15
+
16
# Table shared by every callback in this module. The code in this file reads
# its "pdb_id", "classification" and "splits" columns.
df = pd.read_pickle('all_embeddings_with_splits.p')

app = dash.Dash(external_stylesheets=[dbc.themes.BOOTSTRAP])

# Page layout: a narrow control column on the left (algorithm, dimensionality,
# colouring, top-N input, update button and the list of closest points) and the
# embedding scatter plot filling the rest of the row.
app.layout = dbc.Container(
    [
        html.H1("Embedding Plots"),
        html.Hr(),
        html.Div(
            [
                dbc.Row(
                    [
                        dbc.Col(
                            [
                                html.Label('Algorithm:'),
                                dcc.Dropdown(
                                    id="algorithm-dropdown",
                                    options=[
                                        {"label": "PCA", "value": "pca"},
                                        {"label": "UMAP", "value": "umap"},
                                        {"label": "tSNE", "value": "tsne"},
                                        {"label": "PaCMAP", "value": "pacmap"},
                                    ],
                                    value="pacmap",
                                    clearable=False,
                                    searchable=False,
                                    style={"margin-bottom": "10px"}
                                ),
                                html.Label('Number of dimensions:'),
                                dcc.Dropdown(
                                    id="num-components-dropdown",
                                    options=[
                                        {"label": "2", "value": 2},
                                        {"label": "3", "value": 3}
                                    ],
                                    value=3,
                                    clearable=False,
                                    searchable=False,
                                    style={"margin-bottom": "10px"}
                                ),
                                html.Label('Color by:'),
                                dcc.Dropdown(
                                    id="color-by",
                                    options=[
                                        {
                                            "label": "Protein Classification",
                                            "value": "classification"
                                        },
                                        {
                                            "label": "Split (train/test/val/gpcr)",
                                            "value": "split"
                                        }
                                    ],
                                    value="classification",
                                    clearable=False,
                                    searchable=False,
                                    style={"margin-bottom": "10px"}
                                ),
                                html.Span(
                                    [
                                        "Keep the top ",
                                        dcc.Input(
                                            id="top-n-classes",
                                            type="number",
                                            value=10,
                                            min=1,
                                            # Upper bound: number of distinct
                                            # classification values in df.
                                            max=len(df["classification"].unique()),
                                            step=1,
                                            style={"width": "50px"}
                                        ),
                                        " classes."
                                    ],
                                    style={"margin-bottom": "20px"}
                                ),
                                html.Br(),
                                dbc.Button(
                                    "Update",
                                    id="update-button",
                                    color="primary",
                                    n_clicks=0,
                                    style={"width": "100%", "margin": "10px 0px"}
                                ),
                                # Filled by the clickData callback with one
                                # card per neighbouring point.
                                dbc.Container(
                                    id="closest-points",
                                    style={"max-height": "65vh", "overflow-y": "auto"}
                                ),
                            ],
                            width={"size": 2, "order": 1},
                        ),
                        dbc.Col(
                            dcc.Graph(
                                id="embedding-graph",
                                style={"height": "100%", "width": "100%"},
                            ),
                            width={"size": 10, "order": 2},
                        ),
                    ],
                    style={"height": "95vh"}
                )
            ],
            # Was "100hv", which is not a CSS unit and is dropped by browsers;
            # the viewport-height unit is "vh".
            style={"height": "100vh"}
        ),
        html.Hr(),
    ],
    fluid=True,
)
121
+
122
def load_embedding(algorithm: str, num_components: int) -> np.ndarray:
    """Load a precomputed embedding from a ``.npy`` file in the working directory.

    Parameters
    ----------
    algorithm : str
        One of ``"pca"``, ``"umap"``, ``"tsne"`` or ``"pacmap"``.
    num_components : int
        Number of dimensions (2 or 3). Ignored for ``"pca"``, which always
        reads ``pca.npy``.

    Returns
    -------
    np.ndarray
        The array stored in ``pca.npy`` or in
        ``{algorithm}{num_components}d.npy``.

    Raises
    ------
    ValueError
        If ``algorithm`` or ``num_components`` is not one of the supported
        values.
    """
    if algorithm == "pca":
        return np.load("pca.npy")

    # Both values arrive from the browser and are interpolated into a file
    # name, so only the known combinations are allowed through.
    if algorithm not in ("umap", "tsne", "pacmap"):
        raise ValueError(f"Unsupported algorithm: {algorithm!r}")
    try:
        dims = int(num_components)
    except (TypeError, ValueError):
        raise ValueError(
            f"Unsupported number of dimensions: {num_components!r}"
        ) from None
    if dims not in (2, 3):
        raise ValueError(f"Unsupported number of dimensions: {num_components!r}")

    return np.load(f"{algorithm}{dims}d.npy")
142
+
143
def get_top_n_classifications(df: pd.DataFrame, n: int) -> List[str]:
    """Return the ``n`` most frequent values of ``df["classification"]``.

    The result is ordered from most to least frequent.
    """
    frequencies = df["classification"].value_counts()
    most_common = frequencies.nlargest(n)
    return most_common.index.tolist()
145
+
146
def _color_labels_and_map(color_by: str, top_n_classes: int) -> Tuple[pd.Series, dict]:
    """Return the per-row legend label and the label -> colour mapping."""
    if color_by == "split":
        split_colors = {
            "gpcr": "red",
            "train": "blue",
            "val": "green",
            "test": "orange",
            "unknown": "grey",
        }
        return df["splits"].copy(), split_colors

    top_classes = get_top_n_classifications(df, n=top_n_classes)
    labels = df["classification"].copy()
    labels[~labels.isin(top_classes)] = "other"

    # The qualitative palette has a fixed length, so wrap around instead of
    # indexing past its end when more classes are requested than colours exist.
    palette = px.colors.qualitative.Plotly
    color_map = {c: palette[i % len(palette)] for i, c in enumerate(top_classes)}
    color_map.setdefault("other", "grey")
    return labels, color_map


def _add_label_traces(fig: go.Figure,
                      embedding: np.ndarray,
                      labels: pd.Series,
                      color_map: dict,
                      num_components: int) -> None:
    """Add one marker trace per distinct label to ``fig`` (2-D or 3-D)."""
    is_3d = num_components == 3
    trace_type = go.Scatter3d if is_3d else go.Scatter
    marker_size = 2.5 if is_3d else 7.5

    # NaN never compares equal to itself, so rows with a missing label cannot
    # be selected below; skip them rather than adding an empty trace.
    for label in labels.dropna().unique():
        rows = np.where(labels == label)[0]
        points = embedding[rows]
        # Labels without an explicit colour are drawn like the "other" bucket.
        color = color_map.get(label, "grey")

        coords = dict(x=points[:, 0], y=points[:, 1])
        if is_3d:
            coords["z"] = points[:, 2]

        fig.add_trace(
            trace_type(
                mode='markers',
                name=str(label),
                marker=dict(
                    size=marker_size,
                    color=color,
                    opacity=1 if color != 'grey' else 0.3,
                ),
                hovertemplate=
                "<b>PDB ID</b>: %{customdata[0]}<br>" +
                "<b>Classification</b>: %{customdata[1]}<br>" +
                "<extra></extra>",
                customdata=df.iloc[rows][['pdb_id', 'classification']],
                **coords,
            )
        )


@app.callback(
    Output("embedding-graph", "figure"),
    [
        Input("update-button", "n_clicks"),
    ],
    [
        State("algorithm-dropdown", "value"),
        State("num-components-dropdown", "value"),
        State("top-n-classes", "value"),
        State("color-by", "value"),
    ]
)
def update_embedding_graph(n_clicks: int,
                           algorithm: str,
                           num_components: int,
                           top_n_classes: int,
                           color_by: str) -> go.Figure:
    """Rebuild the scatter plot when the "Update" button is clicked.

    Parameters
    ----------
    n_clicks : int
        Click count of the update button; nothing is drawn until it is > 0.
    algorithm : str
        Value of the algorithm dropdown, forwarded to ``load_embedding``.
    num_components : int
        2 for a ``go.Scatter`` plot, 3 for a ``go.Scatter3d`` plot.
    top_n_classes : int
        Number of most frequent classifications that keep their own colour
        when colouring by classification; all others become "other".
    color_by : str
        ``"split"`` to colour by ``df["splits"]``, anything else to colour by
        ``df["classification"]``.

    Returns
    -------
    go.Figure
        The figure with one trace per colour label.

    Raises
    ------
    dash.exceptions.PreventUpdate
        If the button has not been clicked, the number of dimensions is not
        2 or 3, or the top-N input is empty while colouring by classification.
    """
    if not n_clicks or num_components not in (2, 3):
        raise dash.exceptions.PreventUpdate
    # A cleared numeric input is delivered as None.
    if color_by != "split" and top_n_classes is None:
        raise dash.exceptions.PreventUpdate

    embedding = load_embedding(algorithm, num_components)

    # Keep the labels local: writing them into the module-level df would share
    # one user's colouring state with every other request.
    labels, color_map = _color_labels_and_map(color_by, top_n_classes)

    fig = go.Figure()
    _add_label_traces(fig, embedding, labels, color_map, num_components)

    if num_components == 3:
        fig.update_layout(
            scene=dict(
                xaxis=dict(showgrid=False, showticklabels=False, title=""),
                yaxis=dict(showgrid=False, showticklabels=False, title=""),
                zaxis=dict(showgrid=False, showticklabels=False, title=""),
            ),
        )
        fig.update_scenes(xaxis_visible=False, yaxis_visible=False, zaxis_visible=False)
    else:
        # update_scenes only affects 3-D scenes; cartesian axes are hidden
        # through update_xaxes / update_yaxes.
        fig.update_xaxes(visible=False)
        fig.update_yaxes(visible=False)

    fig.update_layout(
        legend=dict(
            x=0,
            y=1,
            itemsizing='constant',
            itemclick='toggle',
            itemdoubleclick='toggleothers',
            traceorder='reversed',
            itemwidth=30,
        ),
        margin=dict(l=0, r=0, b=0, t=0),
        plot_bgcolor='rgba(0,0,0,0)',
        paper_bgcolor='rgba(0,0,0,0)',
    )
    return fig
265
+
266
+ #### GET CLOSEST POINTS
267
+
268
def extract_info_from_clickData(clickData: dict) -> Tuple[str, str]:
    """Pull the PDB id and classification out of a graph ``clickData`` payload.

    Only the first clicked point is inspected. Its ``customdata`` entry is
    expected to hold the PDB id at position 0 and the classification at
    position 1, e.g.::

        {
            "points": [
                {
                    "x": 11.330583,
                    "y": 15.741333,
                    "z": -5.3435574,
                    "curveNumber": 2,
                    "pointNumber": 982,
                    "bbox": {
                        "x0": 704.3911532022826,
                        "x1": 704.3911532022826,
                        "y0": 393.5066681413661,
                        "y1": 393.5066681413661
                    },
                    "customdata": [
                        "1zfp",
                        "complex (signal transduction/peptide)"
                    ]
                }
            ]
        }

    Parameters
    ----------
    clickData : dict
        Click payload of a point on a go.Figure graph, shaped as above.

    Returns
    -------
    Tuple[str, str]
        ``(pdb_id, classification)`` of the first clicked point.
    """
    first_point = clickData["points"][0]
    custom = first_point["customdata"]
    return custom[0], custom[1]
312
+
313
def find_closest_n_points(df: pd.DataFrame,
                          embedding: np.ndarray,
                          index: int = None,
                          pdb_id: str = None,
                          n: int = 20) -> Tuple[list, list]:
    """Find the ``n`` rows of ``embedding`` closest to a reference point.

    The reference point is selected either by row position (``index``) or by
    ``pdb_id``; when both are given, ``pdb_id`` wins. The reference point
    itself is part of the result, since its distance to itself is zero.

    Parameters
    ----------
    df : pd.DataFrame
        Table with "pdb_id" and "classification" columns, whose rows are in
        the same order as the rows of ``embedding``.
    embedding : np.ndarray
        A 2D numpy array with the embedding coordinates, one row per point.
    index : int, optional
        Row position of the reference point.
    pdb_id : str, optional
        PDB id of the reference point. If several rows share the id, the
        first one is used.
    n : int
        The number of closest points to retrieve.

    Returns
    -------
    Tuple[list, list]
        The PDB ids and the classifications of the ``n`` closest points,
        ordered by increasing distance as computed by ``cdist``.

    Raises
    ------
    ValueError
        If neither ``index`` nor ``pdb_id`` is given, or ``pdb_id`` does not
        occur in ``df``.
    """
    if pdb_id:
        # Resolve to a row *position*: embedding rows and df.iloc are both
        # positional, so an index label would be wrong on a non-default index.
        matches = np.where((df["pdb_id"] == pdb_id).to_numpy())[0]
        if matches.size == 0:
            raise ValueError(f"pdb_id {pdb_id!r} not found in df")
        index = int(matches[0])
    elif index is None:
        raise ValueError("either index or pdb_id must be provided")

    # cdist needs 2D inputs, hence the (1, d) reference row.
    reference = embedding[index][np.newaxis, :]
    distances = cdist(reference, embedding)[0]
    closest_indices = np.argsort(distances)[:n]

    neighbours = df.iloc[closest_indices]
    closest_ids = neighbours["pdb_id"].tolist()
    closest_ids_classifications = neighbours["classification"].tolist()

    return closest_ids, closest_ids_classifications
345
+
346
+
347
@app.callback(
    Output("closest-points", "children"),
    [
        Input("embedding-graph", "clickData")
    ],
    [
        State("algorithm-dropdown", "value"),
        State("num-components-dropdown", "value"),
    ]
)
def update_closest_points_div(
    clickData: dict,
    algorithm: str,
    num_components: int):
    """Fill the "closest-points" container after a click on the graph.

    Parameters
    ----------
    clickData : dict
        Click payload of the embedding graph, or None before any click.
    algorithm : str
        Current value of the algorithm dropdown.
    num_components : int
        Current value of the number-of-dimensions dropdown.

    Returns
    -------
    A placeholder ``html.Div`` when nothing has been clicked yet, otherwise a
    list with one ``dbc.Card`` (PDB id and classification) per closest point.
    """
    if clickData is None:
        # No id on the placeholder: it is rendered inside the container that
        # already owns the "closest-points" id.
        return html.Div("Click on a data point to see the closest points.")

    pdb_id, _ = extract_info_from_clickData(clickData)

    # NOTE(review): the dropdown values are read at click time, so they may no
    # longer match the plotted embedding if they were changed without pressing
    # "Update" - confirm whether that is acceptable.
    embedding = load_embedding(algorithm, num_components)
    closest_ids, closest_ids_classifications = find_closest_n_points(
        df, embedding, pdb_id=pdb_id)

    return [
        dbc.Card(
            dbc.CardBody(
                [
                    html.P(neighbour_id, className="card-title"),
                    html.P(neighbour_class, className="card-text"),
                ]
            ),
            className="mb-3",
        )
        for neighbour_id, neighbour_class in zip(closest_ids,
                                                 closest_ids_classifications)
    ]
386
+
387
+
388
if __name__ == "__main__":
    import argparse

    # The container entry point starts this script with --host/--port; without
    # parsing them the server would always use Dash's default bind address.
    parser = argparse.ArgumentParser(
        description="Serve the embedding plots Dash app.")
    parser.add_argument("--host", default=None,
                        help="Address to bind to (default: Dash's default).")
    parser.add_argument("--port", type=int, default=None,
                        help="Port to listen on (default: Dash's default).")
    parser.add_argument("--no-debug", dest="debug", action="store_false",
                        help="Disable Dash debug mode (enabled by default).")
    args = parser.parse_args()

    # Only forward what was given so Dash keeps its own defaults otherwise.
    run_options = {"debug": args.debug}
    if args.host is not None:
        run_options["host"] = args.host
    if args.port is not None:
        run_options["port"] = args.port

    app.run_server(**run_options)
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ dash==2.9.3
2
+ dash-bootstrap-components==1.4.1
3
+ dash-core-components==2.0.0
4
+ dash-html-components==2.0.0
5
+ plotly==5.14.1
6
+ numpy==1.23.5
7
+ pandas==1.5.0
8
+ scipy==1.10.0