Christina Theodoris commited on
Commit
ebe5ee8
1 Parent(s): 402ba9b

Update gene classification example to create directory after training arguments are defined

Browse files
Files changed (1) hide show
  1. examples/gene_classification.ipynb +22 -21
examples/gene_classification.ipynb CHANGED
@@ -36,6 +36,7 @@
36
  "from sklearn import preprocessing\n",
37
  "from sklearn.metrics import accuracy_score, auc, confusion_matrix, ConfusionMatrixDisplay, roc_curve\n",
38
  "from sklearn.model_selection import StratifiedKFold\n",
 
39
  "from transformers import BertForTokenClassification\n",
40
  "from transformers import Trainer\n",
41
  "from transformers.training_args import TrainingArguments\n",
@@ -424,26 +425,6 @@
424
  "## Fine-Tune With Gene Classification Learning Objective and Quantify Predictive Performance"
425
  ]
426
  },
427
- {
428
- "cell_type": "code",
429
- "execution_count": null,
430
- "metadata": {},
431
- "outputs": [],
432
- "source": [
433
- "# define output directory path\n",
434
- "current_date = datetime.datetime.now()\n",
435
- "datestamp = f\"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}\"\n",
436
- "training_output_dir = f\"/path/to/models/{datestamp}_geneformer_GeneClassifier_dosageTF_L{max_sequence_length}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_n{subsample_size}_F{freeze_layers}/\"\n",
437
- "\n",
438
- "# ensure not overwriting previously saved model\n",
439
- "ksplit_model_test = os.path.join(training_output_dir, \"ksplit0/models/pytorch_model.bin\")\n",
440
- "if os.path.isfile(ksplit_model_test) == True:\n",
441
- " raise Exception(\"Model already saved to this directory.\")\n",
442
- "\n",
443
- "# make output directory\n",
444
- "subprocess.call(f'mkdir {training_output_dir}', shell=True)"
445
- ]
446
- },
447
  {
448
  "cell_type": "code",
449
  "execution_count": null,
@@ -489,6 +470,7 @@
489
  " \"learning_rate\": max_lr,\n",
490
  " \"do_train\": True,\n",
491
  " \"evaluation_strategy\": \"no\",\n",
 
492
  " \"logging_steps\": 100,\n",
493
  " \"group_by_length\": True,\n",
494
  " \"length_column_name\": \"length\",\n",
@@ -499,10 +481,29 @@
499
  " \"per_device_train_batch_size\": geneformer_batch_size,\n",
500
  " \"per_device_eval_batch_size\": geneformer_batch_size,\n",
501
  " \"num_train_epochs\": epochs,\n",
502
- " \"load_best_model_at_end\": True,\n",
503
  "}"
504
  ]
505
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
506
  {
507
  "cell_type": "code",
508
  "execution_count": 23,
 
36
  "from sklearn import preprocessing\n",
37
  "from sklearn.metrics import accuracy_score, auc, confusion_matrix, ConfusionMatrixDisplay, roc_curve\n",
38
  "from sklearn.model_selection import StratifiedKFold\n",
39
+ "import torch\n",
40
  "from transformers import BertForTokenClassification\n",
41
  "from transformers import Trainer\n",
42
  "from transformers.training_args import TrainingArguments\n",
 
425
  "## Fine-Tune With Gene Classification Learning Objective and Quantify Predictive Performance"
426
  ]
427
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
428
  {
429
  "cell_type": "code",
430
  "execution_count": null,
 
470
  " \"learning_rate\": max_lr,\n",
471
  " \"do_train\": True,\n",
472
  " \"evaluation_strategy\": \"no\",\n",
473
+ " \"save_strategy\": \"epoch\",\n",
474
  " \"logging_steps\": 100,\n",
475
  " \"group_by_length\": True,\n",
476
  " \"length_column_name\": \"length\",\n",
 
481
  " \"per_device_train_batch_size\": geneformer_batch_size,\n",
482
  " \"per_device_eval_batch_size\": geneformer_batch_size,\n",
483
  " \"num_train_epochs\": epochs,\n",
 
484
  "}"
485
  ]
486
  },
487
+ {
488
+ "cell_type": "code",
489
+ "execution_count": null,
490
+ "metadata": {},
491
+ "outputs": [],
492
+ "source": [
493
+ "# define output directory path\n",
494
+ "current_date = datetime.datetime.now()\n",
495
+ "datestamp = f\"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}\"\n",
496
+ "training_output_dir = f\"/path/to/models/{datestamp}_geneformer_GeneClassifier_dosageTF_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_n{subsample_size}_F{freeze_layers}/\"\n",
497
+ "\n",
498
+ "# ensure not overwriting previously saved model\n",
499
+ "ksplit_model_test = os.path.join(training_output_dir, \"ksplit0/models/pytorch_model.bin\")\n",
500
+ "if os.path.isfile(ksplit_model_test) == True:\n",
501
+ " raise Exception(\"Model already saved to this directory.\")\n",
502
+ "\n",
503
+ "# make output directory\n",
504
+ "subprocess.call(f'mkdir {training_output_dir}', shell=True)"
505
+ ]
506
+ },
507
  {
508
  "cell_type": "code",
509
  "execution_count": 23,