Update for gene classification (#330)
Browse files- Update for gene classification (f49922d12ee1fe946a511e6d96b9cda14ce7c22b)
Co-authored-by: Han Chen <hchen725@users.noreply.huggingface.co>
- geneformer/classifier_utils.py +72 -33
geneformer/classifier_utils.py
CHANGED
@@ -1,4 +1,6 @@
|
|
|
|
1 |
import logging
|
|
|
2 |
import random
|
3 |
from collections import Counter, defaultdict
|
4 |
|
@@ -6,6 +8,7 @@ import numpy as np
|
|
6 |
import pandas as pd
|
7 |
from scipy.stats import chisquare, ranksums
|
8 |
from sklearn.metrics import accuracy_score, f1_score
|
|
|
9 |
|
10 |
from . import perturber_utils as pu
|
11 |
|
@@ -133,61 +136,55 @@ def label_gene_classes(example, class_id_dict, gene_class_dict):
|
|
133 |
]
|
134 |
|
135 |
|
136 |
-
def
|
137 |
data, targets, labels, train_index, eval_index, max_ncells, iteration_num, num_proc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
138 |
):
|
139 |
# generate cross-validation splits
|
140 |
targets = np.array(targets)
|
141 |
labels = np.array(labels)
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
label_dict_eval = dict(zip(targets_eval, labels_eval))
|
146 |
|
147 |
# function to filter by whether contains train or eval labels
|
148 |
-
def
|
149 |
-
a =
|
150 |
-
b = example["input_ids"]
|
151 |
-
return not set(a).isdisjoint(b)
|
152 |
-
|
153 |
-
def if_contains_eval_label(example):
|
154 |
-
a = targets_eval
|
155 |
b = example["input_ids"]
|
156 |
return not set(a).isdisjoint(b)
|
157 |
|
158 |
# filter dataset for examples containing classes for this split
|
159 |
-
logger.info(f"Filtering
|
160 |
-
|
161 |
logger.info(
|
162 |
-
f"Filtered {round((1-len(
|
163 |
-
)
|
164 |
-
logger.info(f"Filtering evalation data for genes in split {iteration_num}")
|
165 |
-
eval_data = data.filter(if_contains_eval_label, num_proc=num_proc)
|
166 |
-
logger.info(
|
167 |
-
f"Filtered {round((1-len(eval_data)/len(data))*100)}%; {len(eval_data)} remain\n"
|
168 |
)
|
169 |
|
170 |
# subsample to max_ncells
|
171 |
-
|
172 |
-
eval_data = downsample_and_shuffle(eval_data, max_ncells, None, None)
|
173 |
|
174 |
# relabel genes for this split
|
175 |
-
def
|
176 |
example["labels"] = [
|
177 |
-
|
178 |
]
|
179 |
return example
|
180 |
|
181 |
-
|
182 |
-
example["labels"] = [
|
183 |
-
label_dict_eval.get(token_id, -100) for token_id in example["input_ids"]
|
184 |
-
]
|
185 |
-
return example
|
186 |
|
187 |
-
|
188 |
-
eval_data = eval_data.map(eval_classes_to_ids, num_proc=num_proc)
|
189 |
-
|
190 |
-
return train_data, eval_data
|
191 |
|
192 |
|
193 |
def prep_gene_classifier_all_data(data, targets, labels, max_ncells, num_proc):
|
@@ -423,3 +420,45 @@ def get_default_train_args(model, classifier, data, output_dir):
|
|
423 |
training_args.update(default_training_args)
|
424 |
|
425 |
return training_args, freeze_layers
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
import logging
|
3 |
+
import os
|
4 |
import random
|
5 |
from collections import Counter, defaultdict
|
6 |
|
|
|
8 |
import pandas as pd
|
9 |
from scipy.stats import chisquare, ranksums
|
10 |
from sklearn.metrics import accuracy_score, f1_score
|
11 |
+
from sklearn.model_selection import StratifiedKFold, train_test_split
|
12 |
|
13 |
from . import perturber_utils as pu
|
14 |
|
|
|
136 |
]
|
137 |
|
138 |
|
139 |
+
def prep_gene_classifier_train_eval_split(
|
140 |
data, targets, labels, train_index, eval_index, max_ncells, iteration_num, num_proc
|
141 |
+
):
|
142 |
+
# generate cross-validation splits
|
143 |
+
train_data = prep_gene_classifier_split(
|
144 |
+
data, targets, labels, train_index, "train", max_ncells, iteration_num, num_proc
|
145 |
+
)
|
146 |
+
eval_data = prep_gene_classifier_split(
|
147 |
+
data, targets, labels, eval_index, "eval", max_ncells, iteration_num, num_proc
|
148 |
+
)
|
149 |
+
return train_data, eval_data
|
150 |
+
|
151 |
+
|
152 |
+
def prep_gene_classifier_split(
|
153 |
+
data, targets, labels, index, subset_name, max_ncells, iteration_num, num_proc
|
154 |
):
|
155 |
# generate cross-validation splits
|
156 |
targets = np.array(targets)
|
157 |
labels = np.array(labels)
|
158 |
+
targets_subset = targets[index]
|
159 |
+
labels_subset = labels[index]
|
160 |
+
label_dict_subset = dict(zip(targets_subset, labels_subset))
|
|
|
161 |
|
162 |
# function to filter by whether contains train or eval labels
|
163 |
+
def if_contains_subset_label(example):
|
164 |
+
a = targets_subset
|
|
|
|
|
|
|
|
|
|
|
165 |
b = example["input_ids"]
|
166 |
return not set(a).isdisjoint(b)
|
167 |
|
168 |
# filter dataset for examples containing classes for this split
|
169 |
+
logger.info(f"Filtering data for {subset_name} genes in split {iteration_num}")
|
170 |
+
subset_data = data.filter(if_contains_subset_label, num_proc=num_proc)
|
171 |
logger.info(
|
172 |
+
f"Filtered {round((1-len(subset_data)/len(data))*100)}%; {len(subset_data)} remain\n"
|
|
|
|
|
|
|
|
|
|
|
173 |
)
|
174 |
|
175 |
# subsample to max_ncells
|
176 |
+
subset_data = downsample_and_shuffle(subset_data, max_ncells, None, None)
|
|
|
177 |
|
178 |
# relabel genes for this split
|
179 |
+
def subset_classes_to_ids(example):
|
180 |
example["labels"] = [
|
181 |
+
label_dict_subset.get(token_id, -100) for token_id in example["input_ids"]
|
182 |
]
|
183 |
return example
|
184 |
|
185 |
+
subset_data = subset_data.map(subset_classes_to_ids, num_proc=num_proc)
|
|
|
|
|
|
|
|
|
186 |
|
187 |
+
return subset_data
|
|
|
|
|
|
|
188 |
|
189 |
|
190 |
def prep_gene_classifier_all_data(data, targets, labels, max_ncells, num_proc):
|
|
|
420 |
training_args.update(default_training_args)
|
421 |
|
422 |
return training_args, freeze_layers
|
423 |
+
|
424 |
+
|
425 |
+
def load_best_model(directory, model_type, num_classes, mode="eval"):
|
426 |
+
file_dict = dict()
|
427 |
+
for subdir, dirs, files in os.walk(directory):
|
428 |
+
for file in files:
|
429 |
+
if file.endswith("result.json"):
|
430 |
+
with open(f"{subdir}/{file}", "rb") as fp:
|
431 |
+
result_json = json.load(fp)
|
432 |
+
file_dict[f"{subdir}"] = result_json["eval_macro_f1"]
|
433 |
+
file_df = pd.DataFrame(
|
434 |
+
{"dir": file_dict.keys(), "eval_macro_f1": file_dict.values()}
|
435 |
+
)
|
436 |
+
model_superdir = (
|
437 |
+
"run-"
|
438 |
+
+ file_df.iloc[file_df["eval_macro_f1"].idxmax()]["dir"]
|
439 |
+
.split("_objective_")[2]
|
440 |
+
.split("_")[0]
|
441 |
+
)
|
442 |
+
|
443 |
+
for subdir, dirs, files in os.walk(f"{directory}/{model_superdir}"):
|
444 |
+
for file in files:
|
445 |
+
if file.endswith("model.safetensors"):
|
446 |
+
model = pu.load_model(model_type, num_classes, f"{subdir}", mode)
|
447 |
+
return model
|
448 |
+
|
449 |
+
|
450 |
+
class StratifiedKFold3(StratifiedKFold):
|
451 |
+
def split(self, targets, labels, test_ratio=0.5, groups=None):
|
452 |
+
s = super().split(targets, labels, groups)
|
453 |
+
for train_indxs, test_indxs in s:
|
454 |
+
if test_ratio == 0:
|
455 |
+
yield train_indxs, test_indxs, None
|
456 |
+
else:
|
457 |
+
labels_test = np.array(labels)[test_indxs]
|
458 |
+
valid_indxs, test_indxs = train_test_split(
|
459 |
+
test_indxs,
|
460 |
+
stratify=labels_test,
|
461 |
+
test_size=test_ratio,
|
462 |
+
random_state=0,
|
463 |
+
)
|
464 |
+
yield train_indxs, valid_indxs, test_indxs
|