{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80baea1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1 Prepate dataset\n",
    "# 2 Load pretrained Tokenizer, call it with dataset -> encoding\n",
    "# 3 Build PyTorch Dataset with encodings\n",
    "# 4 Load pretrained model\n",
    "# 5 a) Load Trainer and train it\n",
    "#   b) or use native Pytorch training pipeline\n",
    "from pathlib import Path\n",
    "from sklearn.model_selection import train_test_split\n",
    "import torch\n",
    "from torch.utils.data import Dataset\n",
    "from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification\n",
    "from transformers import Trainer, TrainingArguments\n",
    "\n",
    "model_name = \"distilbert-base-uncased\"\n",
    "\n",
    "def read_imdb_split(split_dir): # helper function to get text and label\n",
    "    split_dir = Path(split_dir)\n",
    "    texts = []\n",
    "    labels = []\n",
    "    for label_dir in [\"pos\", \"neg\"]:\n",
    "        thres = 0\n",
    "        for text_file in (split_dir/label_dir).iterdir():\n",
    "            if thres < 100:\n",
    "                f = open(text_file, encoding='utf8')\n",
    "                texts.append(f.read())\n",
    "                labels.append(0 if label_dir == \"neg\" else 1)\n",
    "                thres += 1\n",
    "\n",
    "    return texts, labels\n",
    "\n",
    "train_texts, train_labels = read_imdb_split(\"aclImdb/train\")\n",
    "test_texts, test_labels = read_imdb_split(\"aclImdb/test\")\n",
    "\n",
    "train_texts, val_texts, train_labels, val_labels = train_test_split(train_texts, train_labels, test_size=.2)\n",
    "\n",
    "\n",
    "class IMDBDataset(Dataset):\n",
    "    def __init__(self, encodings, labels):\n",
    "        self.encodings = encodings\n",
    "        self.labels = labels\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}\n",
    "        item[\"labels\"] = torch.tensor(self.labels[idx])\n",
    "        return item\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.labels)\n",
    "    \n",
    "tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)\n",
    "\n",
    "train_encodings = tokenizer(train_texts, truncation=True, padding=True)\n",
    "val_encodings = tokenizer(val_texts, truncation=True, padding=True)\n",
    "test_encodings = tokenizer(test_texts, truncation=True, padding=True)\n",
    "\n",
    "train_dataset = IMDBDataset(train_encodings, train_labels)\n",
    "val_dataset = IMDBDataset(val_encodings, val_labels)\n",
    "test_dataset = IMDBDataset(test_encodings, test_labels)\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    output_dir='./results',\n",
    "    num_train_epochs=2,\n",
    "    per_device_train_batch_size=16,\n",
    "    per_device_eval_batch_size=64,\n",
    "    warmup_steps=500,\n",
    "    learning_rate=5e-5,\n",
    "    weight_decay=0.01,\n",
    "    logging_dir='./logs',\n",
    "    logging_steps=10\n",
    ")\n",
    "\n",
    "model = DistilBertForSequenceClassification.from_pretrained(model_name)\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=val_dataset\n",
    ")\n",
    "\n",
    "trainer.train() \n",
    "\n",
    "\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}