Christina Theodoris commited on
Commit
277b470
1 Parent(s): c33c308

Add alternative methods comparison examples

Browse files
benchmarking/castle_cell_type_annotation.r ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Usage: Rscript castle_cell_type_annotation.r organ
2
+
3
+ # parse ordered arguments
4
+ args <- commandArgs(trailingOnly=TRUE)
5
+ organ <- args[1]
6
+
7
+ suppressPackageStartupMessages(library(scater))
8
+ suppressPackageStartupMessages(library(xgboost))
9
+ suppressPackageStartupMessages(library(igraph))
10
+ BREAKS=c(-1, 0, 1, 6, Inf)
11
+ nFeatures = 100
12
+
13
+ print(paste("Training ", organ, sep=""))
14
+
15
+ # import training and test data
16
+ rootdir="/path/to/data/"
17
+ train_counts <- t(as.matrix(read.csv(file = paste(rootdir, organ, "_filtered_data_train.csv", sep=""), row.names = 1)))
18
+ test_counts <- t(as.matrix(read.csv(file = paste(rootdir, organ, "_filtered_data_test.csv", sep=""), row.names = 1)))
19
+ train_celltype <- as.matrix(read.csv(file = paste(rootdir, organ, "_filtered_celltype_train.csv", sep="")))
20
+ test_celltype <- as.matrix(read.csv(file = paste(rootdir, organ, "_filtered_celltype_test.csv", sep="")))
21
+
22
+ # select features
23
+ sourceCellTypes = as.factor(train_celltype[,"Cell_type"])
24
+ ds = rbind(train_counts,test_counts)
25
+ ds[is.na(ds)] <- 0
26
+ isSource = c(rep(TRUE,nrow(train_counts)), rep(FALSE,nrow(test_counts)))
27
+ topFeaturesAvg = colnames(ds[isSource,])[order(apply(ds[isSource,], 2, mean), decreasing = T)]
28
+ topFeaturesMi = names(sort(apply(ds[isSource,],2,function(x) { compare(cut(x,breaks=BREAKS),sourceCellTypes,method = "nmi") }), decreasing = T))
29
+ selectedFeatures = union(head(topFeaturesAvg, nFeatures) , head(topFeaturesMi, nFeatures) )
30
+ tmp = cor(ds[isSource,selectedFeatures], method = "pearson")
31
+ tmp[!lower.tri(tmp)] = 0
32
+ selectedFeatures = selectedFeatures[apply(tmp,2,function(x) any(x < 0.9))]
33
+ remove(tmp)
34
+
35
+ # bin expression values and expand features by bins
36
+ dsBins = apply(ds[, selectedFeatures], 2, cut, breaks= BREAKS)
37
+ nUniq = apply(dsBins, 2, function(x) { length(unique(x)) })
38
+ ds = model.matrix(~ . , as.data.frame(dsBins[,nUniq>1]))
39
+ remove(dsBins, nUniq)
40
+
41
+ # train model
42
+ train = runif(nrow(ds[isSource,]))<0.8
43
+ # slightly different setup for multiclass and binary classification
44
+ if (length(unique(sourceCellTypes)) > 2) {
45
+ xg=xgboost(data=ds[isSource,][train, ] ,
46
+ label=as.numeric(sourceCellTypes[train])-1,
47
+ objective="multi:softmax", num_class=length(unique(sourceCellTypes)),
48
+ eta=0.7 , nthread=5, nround=20, verbose=0,
49
+ gamma=0.001, max_depth=5, min_child_weight=10)
50
+ } else {
51
+ xg=xgboost(data=ds[isSource,][train, ] ,
52
+ label=as.numeric(sourceCellTypes[train])-1,
53
+ eta=0.7 , nthread=5, nround=20, verbose=0,
54
+ gamma=0.001, max_depth=5, min_child_weight=10)
55
+ }
56
+
57
+ # validate model
58
+ predictedClasses = predict(xg, ds[!isSource, ])
59
+ testCellTypes = as.factor(test_celltype[,"Cell_type"])
60
+ trueClasses <- as.numeric(testCellTypes)-1
61
+
62
+ cm <- as.matrix(table(Actual = trueClasses, Predicted = predictedClasses))
63
+ n <- sum(cm)
64
+ nc = nrow(cm) # number of classes
65
+ diag = diag(cm) # number of correctly classified instances per class
66
+ rowsums = apply(cm, 1, sum) # number of instances per class
67
+ colsums = apply(cm, 2, sum) # number of predictions per class
68
+ p = rowsums / n # distribution of instances over the actual classes
69
+ q = colsums / n # distribution of instances over the predicted classes
70
+ accuracy = sum(diag) / n
71
+ precision = diag / colsums
72
+ recall = diag / rowsums
73
+ f1 = 2 * precision * recall / (precision + recall)
74
+ macroF1 = mean(f1)
75
+
76
+ print(paste(organ, " accuracy: ", accuracy, sep=""))
77
+ print(paste(organ, " macroF1: ", macroF1, sep=""))
78
+
79
+ results_df = data.frame(Accuracy=c(accuracy),macroF1=c(macroF1))
80
+ write.csv(results_df,paste(rootdir, organ, "_castle_results_test.csv", sep=""), row.names = FALSE)
benchmarking/prepare_datasplits_for_cell_type_annotation.ipynb ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "25107132",
6
+ "metadata": {},
7
+ "source": [
8
+ "### Preparing train and test data splits for cell type annotation application"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": 3,
14
+ "id": "83d8d249-affe-45dd-915e-992b4b35b31a",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "import os\n",
19
+ "import pandas as pd\n",
20
+ "from sklearn.model_selection import train_test_split\n",
21
+ "from tqdm.notebook import tqdm\n",
22
+ "from collections import Counter\n",
23
+ "import pickle"
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "execution_count": 4,
29
+ "id": "e3e6a2bf-44c8-4164-9ecd-1686230ea8be",
30
+ "metadata": {},
31
+ "outputs": [
32
+ {
33
+ "data": {
34
+ "text/plain": [
35
+ "['pancreas',\n",
36
+ " 'liver',\n",
37
+ " 'blood',\n",
38
+ " 'lung',\n",
39
+ " 'spleen',\n",
40
+ " 'placenta',\n",
41
+ " 'colorectum',\n",
42
+ " 'kidney',\n",
43
+ " 'brain']"
44
+ ]
45
+ },
46
+ "execution_count": 4,
47
+ "metadata": {},
48
+ "output_type": "execute_result"
49
+ }
50
+ ],
51
+ "source": [
52
+ "rootdir = \"/path/to/data/\"\n",
53
+ "\n",
54
+ "# collect panel of tissues to test\n",
55
+ "dir_list = []\n",
56
+ "for dir_i in os.listdir(rootdir):\n",
57
+ " if (\"results\" not in dir_i) & (os.path.isdir(os.path.join(rootdir, dir_i))):\n",
58
+ " dir_list += [dir_i]\n",
59
+ "dir_list"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "code",
64
+ "execution_count": 5,
65
+ "id": "0b205eec-a518-472a-ab90-dd63ef9803cd",
66
+ "metadata": {},
67
+ "outputs": [
68
+ {
69
+ "data": {
70
+ "text/html": [
71
+ "<div>\n",
72
+ "<style scoped>\n",
73
+ " .dataframe tbody tr th:only-of-type {\n",
74
+ " vertical-align: middle;\n",
75
+ " }\n",
76
+ "\n",
77
+ " .dataframe tbody tr th {\n",
78
+ " vertical-align: top;\n",
79
+ " }\n",
80
+ "\n",
81
+ " .dataframe thead th {\n",
82
+ " text-align: right;\n",
83
+ " }\n",
84
+ "</style>\n",
85
+ "<table border=\"1\" class=\"dataframe\">\n",
86
+ " <thead>\n",
87
+ " <tr style=\"text-align: right;\">\n",
88
+ " <th></th>\n",
89
+ " <th>filter_pass</th>\n",
90
+ " <th>original_cell_id</th>\n",
91
+ " </tr>\n",
92
+ " </thead>\n",
93
+ " <tbody>\n",
94
+ " <tr>\n",
95
+ " <th>0</th>\n",
96
+ " <td>0</td>\n",
97
+ " <td>C_1</td>\n",
98
+ " </tr>\n",
99
+ " <tr>\n",
100
+ " <th>1</th>\n",
101
+ " <td>1</td>\n",
102
+ " <td>C_2</td>\n",
103
+ " </tr>\n",
104
+ " <tr>\n",
105
+ " <th>2</th>\n",
106
+ " <td>0</td>\n",
107
+ " <td>C_3</td>\n",
108
+ " </tr>\n",
109
+ " <tr>\n",
110
+ " <th>3</th>\n",
111
+ " <td>1</td>\n",
112
+ " <td>C_4</td>\n",
113
+ " </tr>\n",
114
+ " <tr>\n",
115
+ " <th>4</th>\n",
116
+ " <td>0</td>\n",
117
+ " <td>C_5</td>\n",
118
+ " </tr>\n",
119
+ " <tr>\n",
120
+ " <th>...</th>\n",
121
+ " <td>...</td>\n",
122
+ " <td>...</td>\n",
123
+ " </tr>\n",
124
+ " <tr>\n",
125
+ " <th>9590</th>\n",
126
+ " <td>1</td>\n",
127
+ " <td>C_9591</td>\n",
128
+ " </tr>\n",
129
+ " <tr>\n",
130
+ " <th>9591</th>\n",
131
+ " <td>1</td>\n",
132
+ " <td>C_9592</td>\n",
133
+ " </tr>\n",
134
+ " <tr>\n",
135
+ " <th>9592</th>\n",
136
+ " <td>1</td>\n",
137
+ " <td>C_9593</td>\n",
138
+ " </tr>\n",
139
+ " <tr>\n",
140
+ " <th>9593</th>\n",
141
+ " <td>1</td>\n",
142
+ " <td>C_9594</td>\n",
143
+ " </tr>\n",
144
+ " <tr>\n",
145
+ " <th>9594</th>\n",
146
+ " <td>1</td>\n",
147
+ " <td>C_9595</td>\n",
148
+ " </tr>\n",
149
+ " </tbody>\n",
150
+ "</table>\n",
151
+ "<p>9595 rows × 2 columns</p>\n",
152
+ "</div>"
153
+ ],
154
+ "text/plain": [
155
+ " filter_pass original_cell_id\n",
156
+ "0 0 C_1\n",
157
+ "1 1 C_2\n",
158
+ "2 0 C_3\n",
159
+ "3 1 C_4\n",
160
+ "4 0 C_5\n",
161
+ "... ... ...\n",
162
+ "9590 1 C_9591\n",
163
+ "9591 1 C_9592\n",
164
+ "9592 1 C_9593\n",
165
+ "9593 1 C_9594\n",
166
+ "9594 1 C_9595\n",
167
+ "\n",
168
+ "[9595 rows x 2 columns]"
169
+ ]
170
+ },
171
+ "execution_count": 5,
172
+ "metadata": {},
173
+ "output_type": "execute_result"
174
+ }
175
+ ],
176
+ "source": [
177
+ "# dictionary of cell barcodes that passed QC filtering applied by Geneformer \n",
178
+ "# to ensure same cells were used for comparison\n",
179
+ "with open(f\"{rootdir}deepsort_filter_dict.pickle\", \"rb\") as fp:\n",
180
+ " filter_dict = pickle.load(fp)\n",
181
+ "\n",
182
+ "# for example:\n",
183
+ "filter_dict[\"human_Placenta9595_data\"]"
184
+ ]
185
+ },
186
+ {
187
+ "cell_type": "code",
188
+ "execution_count": null,
189
+ "id": "207e3571-0236-4493-83b3-a89b67b16cb2",
190
+ "metadata": {
191
+ "tags": []
192
+ },
193
+ "outputs": [],
194
+ "source": [
195
+ "for dir_name in tqdm(dir_list):\n",
196
+ "\n",
197
+ " df = pd.DataFrame()\n",
198
+ " ct_df = pd.DataFrame(columns=[\"Cell\",\"Cell_type\"])\n",
199
+ " \n",
200
+ " subrootdir = f\"{rootdir}{dir_name}/\"\n",
201
+ " for subdir, dirs, files in os.walk(subrootdir):\n",
202
+ " for i in range(len(files)):\n",
203
+ " file = files[i]\n",
204
+ " if file.endswith(\"_data.csv\"):\n",
205
+ " file_prefix = file.replace(\"_data.csv\",\"\")\n",
206
+ " sample_prefix = file.replace(\".csv\",\"\")\n",
207
+ " filter_df = filter_dict[sample_prefix]\n",
208
+ " sample_to_analyze = list(filter_df[filter_df[\"filter_pass\"]==1][\"original_cell_id\"])\n",
209
+ " \n",
210
+ " # collect data for each tissue\n",
211
+ " df_i = pd.read_csv(f\"{subrootdir}{file}\", index_col=0)\n",
212
+ " df_i = df_i[sample_to_analyze]\n",
213
+ " df_i.columns = [f\"{i}_{cell_id}\" for cell_id in df_i.columns]\n",
214
+ " df = pd.concat([df,df_i],axis=1)\n",
215
+ " \n",
216
+ " # collect cell type metadata\n",
217
+ " ct_df_i = pd.read_csv(f\"{subrootdir}{file_prefix}_celltype.csv\", index_col=0)\n",
218
+ " ct_df_i.columns = [\"Cell\",\"Cell_type\"]\n",
219
+ " ct_df_i[\"Cell\"] = [f\"{i}_{cell_id}\" for cell_id in ct_df_i[\"Cell\"]]\n",
220
+ " ct_df = pd.concat([ct_df,ct_df_i],axis=0)\n",
221
+ " \n",
222
+ " # per published scDeepsort method, filter data for cell types >0.5% of data\n",
223
+ " ct_counts = Counter(ct_df[\"Cell_type\"])\n",
224
+ " total_count = sum(ct_counts.values())\n",
225
+ " nonrare_cell_types = [cell_type for cell_type,count in ct_counts.items() if count>(total_count*0.005)]\n",
226
+ " nonrare_cells = list(ct_df[ct_df[\"Cell_type\"].isin(nonrare_cell_types)][\"Cell\"])\n",
227
+ " df = df[df.columns.intersection(nonrare_cells)]\n",
228
+ "\n",
229
+ " # split into 80/20 train/test data\n",
230
+ " train, test = train_test_split(df.T, test_size=0.2)\n",
231
+ " train = train.T\n",
232
+ " test = test.T \n",
233
+ " \n",
234
+ " # save filtered train/test data\n",
235
+ " train.to_csv(f\"{subrootdir}{dir_name}_filtered_data_train.csv\")\n",
236
+ " test.to_csv(f\"{subrootdir}{dir_name}_filtered_data_test.csv\")\n",
237
+ "\n",
238
+ " # split metadata into train/test data\n",
239
+ " ct_df_train = ct_df[ct_df[\"Cell\"].isin(list(train.columns))]\n",
240
+ " ct_df_test = ct_df[ct_df[\"Cell\"].isin(list(test.columns))]\n",
241
+ " train_order_dict = dict(zip(train.columns,[i for i in range(len(train.columns))]))\n",
242
+ " test_order_dict = dict(zip(test.columns,[i for i in range(len(test.columns))]))\n",
243
+ " ct_df_train[\"order\"] = [train_order_dict[cell_id] for cell_id in ct_df_train[\"Cell\"]]\n",
244
+ " ct_df_test[\"order\"] = [test_order_dict[cell_id] for cell_id in ct_df_test[\"Cell\"]]\n",
245
+ " ct_df_train = ct_df_train.sort_values(\"order\")\n",
246
+ " ct_df_test = ct_df_test.sort_values(\"order\")\n",
247
+ " ct_df_train = ct_df_train.drop(\"order\",axis=1)\n",
248
+ " ct_df_test = ct_df_test.drop(\"order\",axis=1)\n",
249
+ " assert list(ct_df_train[\"Cell\"]) == list(train.columns)\n",
250
+ " assert list(ct_df_test[\"Cell\"]) == list(test.columns)\n",
251
+ " train_labels = list(Counter(ct_df_train[\"Cell_type\"]).keys())\n",
252
+ " test_labels = list(Counter(ct_df_test[\"Cell_type\"]).keys())\n",
253
+ " assert set(train_labels) == set(test_labels)\n",
254
+ " \n",
255
+ " # save train/test cell type annotations\n",
256
+ " ct_df_train.to_csv(f\"{subrootdir}{dir_name}_filtered_celltype_train.csv\")\n",
257
+ " ct_df_test.to_csv(f\"{subrootdir}{dir_name}_filtered_celltype_test.csv\")\n",
258
+ " "
259
+ ]
260
+ }
261
+ ],
262
+ "metadata": {
263
+ "kernelspec": {
264
+ "display_name": "Python 3.8.6 64-bit ('3.8.6')",
265
+ "language": "python",
266
+ "name": "python3"
267
+ },
268
+ "language_info": {
269
+ "codemirror_mode": {
270
+ "name": "ipython",
271
+ "version": 3
272
+ },
273
+ "file_extension": ".py",
274
+ "mimetype": "text/x-python",
275
+ "name": "python",
276
+ "nbconvert_exporter": "python",
277
+ "pygments_lexer": "ipython3",
278
+ "version": "3.8.6"
279
+ },
280
+ "vscode": {
281
+ "interpreter": {
282
+ "hash": "eba1599a1f7e611c14c87ccff6793920aa63510b01fc0e229d6dd014149b8829"
283
+ }
284
+ }
285
+ },
286
+ "nbformat": 4,
287
+ "nbformat_minor": 5
288
+ }
benchmarking/randomForest_token_classifier_dosageTF_10k.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
benchmarking/scDeepsort_train_predict.ipynb ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 3,
6
+ "id": "83d8d249-affe-45dd-915e-992b4b35b31a",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import os\n",
11
+ "import numpy as np\n",
12
+ "import pandas as pd\n",
13
+ "import deepsort\n",
14
+ "from sklearn.metrics import accuracy_score, f1_score\n",
15
+ "from tqdm.notebook import tqdm\n",
16
+ "import pickle"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": 4,
22
+ "id": "25de46ec-8a41-484d-8e14-d2b19768fc2c",
23
+ "metadata": {},
24
+ "outputs": [],
25
+ "source": [
26
+ "def compute_metrics(labels, preds):\n",
27
+ "\n",
28
+ " # calculate accuracy and macro f1 using sklearn's function\n",
29
+ " acc = accuracy_score(labels, preds)\n",
30
+ " macro_f1 = f1_score(labels, preds, average='macro')\n",
31
+ " return {\n",
32
+ " 'accuracy': acc,\n",
33
+ " 'macro_f1': macro_f1\n",
34
+ " }"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "execution_count": 5,
40
+ "id": "a4029b2b-afca-4300-82a2-082fec59f191",
41
+ "metadata": {},
42
+ "outputs": [
43
+ {
44
+ "data": {
45
+ "text/plain": [
46
+ "['pancreas',\n",
47
+ " 'liver',\n",
48
+ " 'blood',\n",
49
+ " 'lung',\n",
50
+ " 'spleen',\n",
51
+ " 'placenta',\n",
52
+ " 'colorectum',\n",
53
+ " 'kidney',\n",
54
+ " 'brain']"
55
+ ]
56
+ },
57
+ "execution_count": 5,
58
+ "metadata": {},
59
+ "output_type": "execute_result"
60
+ }
61
+ ],
62
+ "source": [
63
+ "rootdir = \"/path/to/data/\"\n",
64
+ "\n",
65
+ "dir_list = []\n",
66
+ "for dir_i in os.listdir(rootdir):\n",
67
+ " if (\"results\" not in dir_i) & (os.path.isdir(os.path.join(rootdir, dir_i))):\n",
68
+ " dir_list += [dir_i]\n",
69
+ "dir_list"
70
+ ]
71
+ },
72
+ {
73
+ "cell_type": "code",
74
+ "execution_count": null,
75
+ "id": "ddcdc5cd-871e-4fd2-8457-18d3049fa76c",
76
+ "metadata": {
77
+ "tags": []
78
+ },
79
+ "outputs": [],
80
+ "source": [
81
+ "output_dir = \"results_EDefault_filtered\"\n",
82
+ "n_epochs = \"Default\" # scDeepsort default epochs = 300\n",
83
+ "\n",
84
+ "results_dict = dict()\n",
85
+ "for dir_name in tqdm(dir_list):\n",
86
+ " print(f\"TRAINING: {dir_name}\")\n",
87
+ " subrootdir = f\"{rootdir}{dir_name}/\"\n",
88
+ " train_files = [(f\"{subrootdir}{dir_name}_filtered_data_train.csv\",f\"{subrootdir}{dir_name}_filtered_celltype_train.csv\")]\n",
89
+ " test_file = f\"{subrootdir}{dir_name}_filtered_data_test.csv\"\n",
90
+ " label_file = f\"{subrootdir}{dir_name}_filtered_celltype_test.csv\"\n",
91
+ " \n",
92
+ " # define the model\n",
93
+ " model = deepsort.DeepSortClassifier(species='human',\n",
94
+ " tissue=dir_name,\n",
95
+ " gpu_id=0,\n",
96
+ " random_seed=1,\n",
97
+ " validation_fraction=0) # use all training data (already held out 20% in test data file)\n",
98
+ "\n",
99
+ " # fit the model\n",
100
+ " model.fit(train_files, save_path=f\"{subrootdir}{output_dir}\")\n",
101
+ " \n",
102
+ " # use the saved model to predict cell types in test data\n",
103
+ " model.predict(input_file=test_file,\n",
104
+ " model_path=f\"{subrootdir}{output_dir}\",\n",
105
+ " save_path=f\"{subrootdir}{output_dir}\",\n",
106
+ " unsure_rate=0,\n",
107
+ " file_type='csv')\n",
108
+ " labels_df = pd.read_csv(label_file)\n",
109
+ " preds_df = pd.read_csv(f\"{subrootdir}{output_dir}/human_{dir_name}_{dir_name}_filtered_data_test.csv\")\n",
110
+ " label_cell_ids = labels_df[\"Cell\"]\n",
111
+ " pred_cell_ids = preds_df[\"index\"]\n",
112
+ " assert list(label_cell_ids) == list(pred_cell_ids)\n",
113
+ " labels = list(labels_df[\"Cell_type\"])\n",
114
+ " if isinstance(preds_df[\"cell_subtype\"][0],float):\n",
115
+ " if np.isnan(preds_df[\"cell_subtype\"][0]):\n",
116
+ " preds = list(preds_df[\"cell_type\"])\n",
117
+ " results = compute_metrics(labels, preds)\n",
118
+ " else:\n",
119
+ " preds1 = list(preds_df[\"cell_type\"])\n",
120
+ " preds2 = list(preds_df[\"cell_subtype\"])\n",
121
+ " results1 = compute_metrics(labels, preds1)\n",
122
+ " results2 = compute_metrics(labels, preds2)\n",
123
+ " if results2[\"accuracy\"] > results1[\"accuracy\"]:\n",
124
+ " results = results2\n",
125
+ " else:\n",
126
+ " results = results1\n",
127
+ " \n",
128
+ " print(f\"{dir_name}: {results}\")\n",
129
+ " results_dict[dir_name] = results\n",
130
+ " with open(f\"{subrootdir}deepsort_E{n_epochs}_filtered_pred_{dir_name}.pickle\", \"wb\") as output_file:\n",
131
+ " pickle.dump(results, output_file)\n",
132
+ "\n",
133
+ "# save results\n",
134
+ "with open(f\"{rootdir}deepsort_E{n_epochs}_filtered_pred_dict.pickle\", \"wb\") as output_file:\n",
135
+ " pickle.dump(results_dict, output_file)\n",
136
+ " "
137
+ ]
138
+ }
139
+ ],
140
+ "metadata": {
141
+ "kernelspec": {
142
+ "display_name": "Python 3.8.6 64-bit ('3.8.6')",
143
+ "language": "python",
144
+ "name": "python3"
145
+ },
146
+ "language_info": {
147
+ "codemirror_mode": {
148
+ "name": "ipython",
149
+ "version": 3
150
+ },
151
+ "file_extension": ".py",
152
+ "mimetype": "text/x-python",
153
+ "name": "python",
154
+ "nbconvert_exporter": "python",
155
+ "pygments_lexer": "ipython3",
156
+ "version": "3.8.6"
157
+ },
158
+ "vscode": {
159
+ "interpreter": {
160
+ "hash": "eba1599a1f7e611c14c87ccff6793920aa63510b01fc0e229d6dd014149b8829"
161
+ }
162
+ }
163
+ },
164
+ "nbformat": 4,
165
+ "nbformat_minor": 5
166
+ }