Christina Theodoris
committed on
Commit
•
67f674c
1
Parent(s):
d20ad0a
Add uniform max len for padding for predictions
Browse files
examples/gene_classification.ipynb
CHANGED
@@ -139,14 +139,15 @@
|
|
139 |
"metadata": {},
|
140 |
"outputs": [],
|
141 |
"source": [
|
142 |
-
"def preprocess_classifier_batch(cell_batch):\n",
|
143 |
-
"
|
|
|
144 |
" def pad_label_example(example):\n",
|
145 |
" example[\"labels\"] = np.pad(example[\"labels\"], \n",
|
146 |
-
" (0,
|
147 |
" mode='constant', constant_values=-100)\n",
|
148 |
" example[\"input_ids\"] = np.pad(example[\"input_ids\"], \n",
|
149 |
-
" (0,
|
150 |
" mode='constant', constant_values=token_dictionary.get(\"<pad>\"))\n",
|
151 |
" example[\"attention_mask\"] = (example[\"input_ids\"] != token_dictionary.get(\"<pad>\")).astype(int)\n",
|
152 |
" return example\n",
|
@@ -158,10 +159,19 @@
|
|
158 |
" predict_logits = []\n",
|
159 |
" predict_labels = []\n",
|
160 |
" model.eval()\n",
|
161 |
-
"
|
162 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
163 |
" batch_evalset = evalset.select([i for i in range(i, max_range)])\n",
|
164 |
-
" padded_batch = preprocess_classifier_batch(batch_evalset)\n",
|
165 |
" padded_batch.set_format(type=\"torch\")\n",
|
166 |
" \n",
|
167 |
" input_data_batch = padded_batch[\"input_ids\"]\n",
|
@@ -224,7 +234,16 @@
|
|
224 |
" all_weighted_roc_auc = [a*b for a,b in zip(all_roc_auc, wts)]\n",
|
225 |
" roc_auc = np.sum(all_weighted_roc_auc)\n",
|
226 |
" roc_auc_sd = math.sqrt(np.average((all_roc_auc-roc_auc)**2, weights=wts))\n",
|
227 |
-
" return mean_tpr, roc_auc, roc_auc_sd"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
228 |
]
|
229 |
},
|
230 |
{
|
@@ -327,7 +346,7 @@
|
|
327 |
" \n",
|
328 |
" # load model\n",
|
329 |
" model = BertForTokenClassification.from_pretrained(\n",
|
330 |
-
" \"/
|
331 |
" num_labels=2,\n",
|
332 |
" output_attentions = False,\n",
|
333 |
" output_hidden_states = False\n",
|
|
|
139 |
"metadata": {},
|
140 |
"outputs": [],
|
141 |
"source": [
|
142 |
+
"def preprocess_classifier_batch(cell_batch, max_len):\n",
|
143 |
+
" if max_len == None:\n",
|
144 |
+
" max_len = max([len(i) for i in cell_batch[\"input_ids\"]])\n",
|
145 |
" def pad_label_example(example):\n",
|
146 |
" example[\"labels\"] = np.pad(example[\"labels\"], \n",
|
147 |
+
" (0, max_len-len(example[\"input_ids\"])), \n",
|
148 |
" mode='constant', constant_values=-100)\n",
|
149 |
" example[\"input_ids\"] = np.pad(example[\"input_ids\"], \n",
|
150 |
+
" (0, max_len-len(example[\"input_ids\"])), \n",
|
151 |
" mode='constant', constant_values=token_dictionary.get(\"<pad>\"))\n",
|
152 |
" example[\"attention_mask\"] = (example[\"input_ids\"] != token_dictionary.get(\"<pad>\")).astype(int)\n",
|
153 |
" return example\n",
|
|
|
159 |
" predict_logits = []\n",
|
160 |
" predict_labels = []\n",
|
161 |
" model.eval()\n",
|
162 |
+
" \n",
|
163 |
+
" # ensure there are at least 2 examples in each batch to avoid incorrect tensor dims\n",
|
164 |
+
" evalset_len = len(evalset)\n",
|
165 |
+
" max_divisible = find_largest_div(evalset_len, forward_batch_size)\n",
|
166 |
+
" if len(evalset) - max_divisible == 1:\n",
|
167 |
+
" evalset_len = max_divisible\n",
|
168 |
+
" \n",
|
169 |
+
" max_evalset_len = max(evalset.select([i for i in range(evalset_len)])[\"length\"])\n",
|
170 |
+
" \n",
|
171 |
+
" for i in range(0, evalset_len, forward_batch_size):\n",
|
172 |
+
" max_range = min(i+forward_batch_size, evalset_len)\n",
|
173 |
" batch_evalset = evalset.select([i for i in range(i, max_range)])\n",
|
174 |
+
" padded_batch = preprocess_classifier_batch(batch_evalset, max_evalset_len)\n",
|
175 |
" padded_batch.set_format(type=\"torch\")\n",
|
176 |
" \n",
|
177 |
" input_data_batch = padded_batch[\"input_ids\"]\n",
|
|
|
234 |
" all_weighted_roc_auc = [a*b for a,b in zip(all_roc_auc, wts)]\n",
|
235 |
" roc_auc = np.sum(all_weighted_roc_auc)\n",
|
236 |
" roc_auc_sd = math.sqrt(np.average((all_roc_auc-roc_auc)**2, weights=wts))\n",
|
237 |
+
" return mean_tpr, roc_auc, roc_auc_sd\n",
|
238 |
+
"\n",
|
239 |
+
"# Function to find the largest number smaller\n",
|
240 |
+
"# than or equal to N that is divisible by K\n",
|
241 |
+
"def find_largest_div(N, K):\n",
|
242 |
+
" rem = N % K\n",
|
243 |
+
" if(rem == 0):\n",
|
244 |
+
" return N\n",
|
245 |
+
" else:\n",
|
246 |
+
" return N - rem"
|
247 |
]
|
248 |
},
|
249 |
{
|
|
|
346 |
" \n",
|
347 |
" # load model\n",
|
348 |
" model = BertForTokenClassification.from_pretrained(\n",
|
349 |
+
" \"/gladstone/theodoris/lab/ctheodoris/archive/geneformer_files/geneformer/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/\",\n",
|
350 |
" num_labels=2,\n",
|
351 |
" output_attentions = False,\n",
|
352 |
" output_hidden_states = False\n",
|