Christina Theodoris
commited on
Commit
•
6caf480
1
Parent(s):
61f15d2
Add memory-efficient method for computing emb summary statistics
Browse files- geneformer/emb_extractor.py +62 -19
geneformer/emb_extractor.py
CHANGED
@@ -14,7 +14,8 @@ Usage:
|
|
14 |
emb_label=["disease","cell_type"],
|
15 |
labels_to_plot=["disease","cell_type"],
|
16 |
forward_batch_size=100,
|
17 |
-
nproc=16
|
|
|
18 |
embs = embex.extract_embs("path/to/model",
|
19 |
"path/to/input_data",
|
20 |
"path/to/output_directory",
|
@@ -33,6 +34,7 @@ import matplotlib.pyplot as plt
|
|
33 |
import numpy as np
|
34 |
import pandas as pd
|
35 |
import pickle
|
|
|
36 |
import scanpy as sc
|
37 |
import seaborn as sns
|
38 |
import torch
|
@@ -54,20 +56,28 @@ from .in_silico_perturber import downsample_and_sort, \
|
|
54 |
|
55 |
logger = logging.getLogger(__name__)
|
56 |
|
57 |
-
#
|
58 |
def get_embs(model,
|
59 |
filtered_input_data,
|
60 |
emb_mode,
|
61 |
layer_to_quant,
|
62 |
pad_token_id,
|
63 |
-
forward_batch_size
|
|
|
64 |
|
65 |
model_input_size = get_model_input_size(model)
|
66 |
total_batch_length = len(filtered_input_data)
|
67 |
-
if ((total_batch_length-1)/forward_batch_size).is_integer():
|
68 |
-
forward_batch_size = forward_batch_size-1
|
69 |
|
70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
for i in trange(0, total_batch_length, forward_batch_size):
|
72 |
max_range = min(i+forward_batch_size, total_batch_length)
|
73 |
|
@@ -81,29 +91,52 @@ def get_embs(model,
|
|
81 |
max_len,
|
82 |
pad_token_id,
|
83 |
model_input_size)
|
84 |
-
|
85 |
with torch.no_grad():
|
86 |
outputs = model(
|
87 |
input_ids = input_data_minibatch.to("cuda"),
|
88 |
attention_mask = gen_attention_mask(minibatch)
|
89 |
)
|
90 |
-
|
91 |
embs_i = outputs.hidden_states[layer_to_quant]
|
92 |
|
93 |
if emb_mode == "cell":
|
94 |
mean_embs = mean_nonpadding_embs(embs_i, original_lens)
|
95 |
-
|
|
|
|
|
|
|
|
|
|
|
96 |
|
97 |
del outputs
|
98 |
del minibatch
|
99 |
del input_data_minibatch
|
100 |
del embs_i
|
101 |
del mean_embs
|
102 |
-
torch.cuda.empty_cache()
|
103 |
-
|
104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
return embs_stack
|
106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
def label_embs(embs, downsampled_data, emb_labels):
|
108 |
embs_df = pd.DataFrame(embs.cpu())
|
109 |
if emb_labels is not None:
|
@@ -131,7 +164,6 @@ def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict):
|
|
131 |
|
132 |
sc.pl.umap(adata, color=label, save=output_file, **default_kwargs_dict)
|
133 |
|
134 |
-
|
135 |
def gen_heatmap_class_colors(labels, df):
|
136 |
pal = sns.cubehelix_palette(len(Counter(labels).keys()), light=0.9, dark=0.1, hue=1, reverse=True, start=1, rot=-2)
|
137 |
lut = dict(zip(map(str, Counter(labels).keys()), pal))
|
@@ -208,6 +240,7 @@ class EmbExtractor:
|
|
208 |
"labels_to_plot": {None, list},
|
209 |
"forward_batch_size": {int},
|
210 |
"nproc": {int},
|
|
|
211 |
}
|
212 |
def __init__(
|
213 |
self,
|
@@ -222,6 +255,7 @@ class EmbExtractor:
|
|
222 |
labels_to_plot=None,
|
223 |
forward_batch_size=100,
|
224 |
nproc=4,
|
|
|
225 |
token_dictionary_file=TOKEN_DICTIONARY_FILE,
|
226 |
):
|
227 |
"""
|
@@ -263,6 +297,10 @@ class EmbExtractor:
|
|
263 |
Batch size for forward pass.
|
264 |
nproc : int
|
265 |
Number of CPU processes to use.
|
|
|
|
|
|
|
|
|
266 |
token_dictionary_file : Path
|
267 |
Path to pickle file containing token dictionary (Ensembl ID:token).
|
268 |
"""
|
@@ -278,6 +316,7 @@ class EmbExtractor:
|
|
278 |
self.labels_to_plot = labels_to_plot
|
279 |
self.forward_batch_size = forward_batch_size
|
280 |
self.nproc = nproc
|
|
|
281 |
|
282 |
self.validate_options()
|
283 |
|
@@ -353,14 +392,19 @@ class EmbExtractor:
|
|
353 |
self.emb_mode,
|
354 |
layer_to_quant,
|
355 |
self.pad_token_id,
|
356 |
-
self.forward_batch_size
|
357 |
-
|
358 |
|
|
|
|
|
|
|
|
|
|
|
359 |
# save embeddings to output_path
|
360 |
output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
|
361 |
embs_df.to_csv(output_path)
|
362 |
-
|
363 |
-
return embs_df
|
364 |
|
365 |
def plot_embs(self,
|
366 |
embs,
|
@@ -446,5 +490,4 @@ class EmbExtractor:
|
|
446 |
continue
|
447 |
output_prefix_label = output_prefix + f"_heatmap_{label}"
|
448 |
output_file = (Path(output_directory) / output_prefix_label).with_suffix(".pdf")
|
449 |
-
plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict)
|
450 |
-
|
|
|
14 |
emb_label=["disease","cell_type"],
|
15 |
labels_to_plot=["disease","cell_type"],
|
16 |
forward_batch_size=100,
|
17 |
+
nproc=16,
|
18 |
+
summary_stat=None)
|
19 |
embs = embex.extract_embs("path/to/model",
|
20 |
"path/to/input_data",
|
21 |
"path/to/output_directory",
|
|
|
34 |
import numpy as np
|
35 |
import pandas as pd
|
36 |
import pickle
|
37 |
+
from tdigest import TDigest
|
38 |
import scanpy as sc
|
39 |
import seaborn as sns
|
40 |
import torch
|
|
|
56 |
|
57 |
logger = logging.getLogger(__name__)
|
58 |
|
59 |
+
# extract embeddings
|
60 |
def get_embs(model,
|
61 |
filtered_input_data,
|
62 |
emb_mode,
|
63 |
layer_to_quant,
|
64 |
pad_token_id,
|
65 |
+
forward_batch_size,
|
66 |
+
summary_stat):
|
67 |
|
68 |
model_input_size = get_model_input_size(model)
|
69 |
total_batch_length = len(filtered_input_data)
|
|
|
|
|
70 |
|
71 |
+
if summary_stat is None:
|
72 |
+
embs_list = []
|
73 |
+
elif summary_stat is not None:
|
74 |
+
# test embedding extraction for example cell and extract # emb dims
|
75 |
+
example = filtered_input_data.select([i for i in range(1)])
|
76 |
+
example.set_format(type="torch")
|
77 |
+
emb_dims = test_emb(model, example["input_ids"], layer_to_quant)
|
78 |
+
# initiate tdigests for # of emb dims
|
79 |
+
embs_tdigests = [TDigest() for _ in range(emb_dims)]
|
80 |
+
|
81 |
for i in trange(0, total_batch_length, forward_batch_size):
|
82 |
max_range = min(i+forward_batch_size, total_batch_length)
|
83 |
|
|
|
91 |
max_len,
|
92 |
pad_token_id,
|
93 |
model_input_size)
|
94 |
+
|
95 |
with torch.no_grad():
|
96 |
outputs = model(
|
97 |
input_ids = input_data_minibatch.to("cuda"),
|
98 |
attention_mask = gen_attention_mask(minibatch)
|
99 |
)
|
100 |
+
|
101 |
embs_i = outputs.hidden_states[layer_to_quant]
|
102 |
|
103 |
if emb_mode == "cell":
|
104 |
mean_embs = mean_nonpadding_embs(embs_i, original_lens)
|
105 |
+
if summary_stat is None:
|
106 |
+
embs_list += [mean_embs]
|
107 |
+
elif summary_stat is not None:
|
108 |
+
# update tdigests with current batch for each emb dim
|
109 |
+
# note: tdigest batch update known to be slow so updating serially
|
110 |
+
[embs_tdigests[j].update(mean_embs[i,j].item()) for i in range(mean_embs.size(0)) for j in range(emb_dims)]
|
111 |
|
112 |
del outputs
|
113 |
del minibatch
|
114 |
del input_data_minibatch
|
115 |
del embs_i
|
116 |
del mean_embs
|
117 |
+
torch.cuda.empty_cache()
|
118 |
+
|
119 |
+
if summary_stat is None:
|
120 |
+
embs_stack = torch.cat(embs_list)
|
121 |
+
# calculate summary stat embs from approximated tdigests
|
122 |
+
elif summary_stat is not None:
|
123 |
+
if summary_stat == "mean":
|
124 |
+
summary_emb_list = [embs_tdigests[i].trimmed_mean(0,100) for i in range(emb_dims)]
|
125 |
+
elif summary_stat == "median":
|
126 |
+
summary_emb_list = [embs_tdigests[i].percentile(50) for i in range(emb_dims)]
|
127 |
+
embs_stack = torch.tensor(summary_emb_list)
|
128 |
+
|
129 |
return embs_stack
|
130 |
|
131 |
+
def test_emb(model, example, layer_to_quant):
|
132 |
+
with torch.no_grad():
|
133 |
+
outputs = model(
|
134 |
+
input_ids = example.to("cuda")
|
135 |
+
)
|
136 |
+
|
137 |
+
embs_test = outputs.hidden_states[layer_to_quant]
|
138 |
+
return embs_test.size()[2]
|
139 |
+
|
140 |
def label_embs(embs, downsampled_data, emb_labels):
|
141 |
embs_df = pd.DataFrame(embs.cpu())
|
142 |
if emb_labels is not None:
|
|
|
164 |
|
165 |
sc.pl.umap(adata, color=label, save=output_file, **default_kwargs_dict)
|
166 |
|
|
|
167 |
def gen_heatmap_class_colors(labels, df):
|
168 |
pal = sns.cubehelix_palette(len(Counter(labels).keys()), light=0.9, dark=0.1, hue=1, reverse=True, start=1, rot=-2)
|
169 |
lut = dict(zip(map(str, Counter(labels).keys()), pal))
|
|
|
240 |
"labels_to_plot": {None, list},
|
241 |
"forward_batch_size": {int},
|
242 |
"nproc": {int},
|
243 |
+
"summary_stat": {None, "mean", "median"},
|
244 |
}
|
245 |
def __init__(
|
246 |
self,
|
|
|
255 |
labels_to_plot=None,
|
256 |
forward_batch_size=100,
|
257 |
nproc=4,
|
258 |
+
summary_stat=None,
|
259 |
token_dictionary_file=TOKEN_DICTIONARY_FILE,
|
260 |
):
|
261 |
"""
|
|
|
297 |
Batch size for forward pass.
|
298 |
nproc : int
|
299 |
Number of CPU processes to use.
|
300 |
+
summary_stat : {None, "mean", "median"}
|
301 |
+
If not None, outputs only approximated mean or median embedding of input data.
|
302 |
+
Recommended if encountering memory constraints while generating goal embedding positions.
|
303 |
+
Slower but more memory-efficient.
|
304 |
token_dictionary_file : Path
|
305 |
Path to pickle file containing token dictionary (Ensembl ID:token).
|
306 |
"""
|
|
|
316 |
self.labels_to_plot = labels_to_plot
|
317 |
self.forward_batch_size = forward_batch_size
|
318 |
self.nproc = nproc
|
319 |
+
self.summary_stat = summary_stat
|
320 |
|
321 |
self.validate_options()
|
322 |
|
|
|
392 |
self.emb_mode,
|
393 |
layer_to_quant,
|
394 |
self.pad_token_id,
|
395 |
+
self.forward_batch_size,
|
396 |
+
self.summary_stat)
|
397 |
|
398 |
+
if self.summary_stat is None:
|
399 |
+
embs_df = label_embs(embs, downsampled_data, self.emb_label)
|
400 |
+
elif self.summary_stat is not None:
|
401 |
+
embs_df = pd.DataFrame(embs.cpu()).T
|
402 |
+
|
403 |
# save embeddings to output_path
|
404 |
output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
|
405 |
embs_df.to_csv(output_path)
|
406 |
+
|
407 |
+
return embs_df
|
408 |
|
409 |
def plot_embs(self,
|
410 |
embs,
|
|
|
490 |
continue
|
491 |
output_prefix_label = output_prefix + f"_heatmap_{label}"
|
492 |
output_file = (Path(output_directory) / output_prefix_label).with_suffix(".pdf")
|
493 |
+
plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict)
|
|