makiisthebes commited on
Commit
5357119
·
verified ·
1 Parent(s): 4fd64a9

Bert Image Classifier

Browse files
Files changed (4) hide show
  1. CW1_Report.pdf +0 -0
  2. bert.pth +3 -0
  3. bert_image_classifier.py +210 -0
  4. testing_dataset.ipynb +99 -0
CW1_Report.pdf ADDED
Binary file (218 kB). View file
 
bert.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:614ca99aac317874fa033c37e7617881330b5617ec0779e4ec7c51c384d39ee0
3
+ size 38605699
bert_image_classifier.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Michael Peres ~ 09/01/2024
2
+ # Bert Based Transformer Model for Image Classification
3
+ # ----------------------------------------------------------------------------------------------------------------------
4
+ # Import Modules
5
+ # pip install transformers torchvision
6
+ from transformers import BertModel, BertTokenizer, BertConfig
7
+ from transformers import get_linear_schedule_with_warmup
8
+ from transformers import BertForSequenceClassification
9
+ from torchvision.utils import make_grid, save_image
10
+ from torch.utils.data import Dataset, DataLoader
11
+ from torchvision.datasets import MNIST, CIFAR10
12
+ from torchvision import datasets, transforms
13
+ from tqdm.notebook import tqdm, trange
14
+ from torch.optim import AdamW, Adam
15
+ import matplotlib.pyplot as plt
16
+ import torch.nn.functional as F
17
+ import math, os, torch
18
+ import torch.nn as nn
19
+
20
+ # ----------------------------------------------------------------------------------------------------------------------
21
+ # This is a simple implementation, where the first hidden state,
22
+ # which is the encoded class token, is used as the input to an MLP head for classification.
23
+ # The model is trained on CIFAR-10 dataset, which is a dataset of 60,000 32x32 color images in 10 classes,
24
+ # with 6,000 images per class.
25
+ # This model will only contain the encoder part of the BERT model, and the classification head.
26
+
27
+
28
+ # ----------------------------------------------------------------------------------------------------------------------
29
+ # Some understanding of the BERT model is required to understand this code, here are the dimensions and documentation.
30
+ # From documentation, https://huggingface.co/transformers/v3.0.2/model_doc/bert.html
31
+
32
+ # BERT Parameters include:
33
+
34
+ # - hidden size: 256
35
+ # - intermediate size: 1024
36
+ # - number of hidden_layers: 12
37
+ # - num of attention heads: 8
38
+ # - max position embeddings: 256
39
+ # - vocab size: 100
40
+ # - bos_token_id: 101
41
+ # - eos_token_id: 102
42
+ # - cls_token_id: 103
43
+
44
+ # But what do all of these mean in terms of the question.
45
+
46
+ # Hidden size, this represents the dimensionality of the input embeddings D.
47
+
48
+ # Intermediate size is the number of neurons in the hidden layer of the feedforward,
49
+ # the feed forward would have dims, Hidden Size D -> Intermediate Size -> Hidden Size D
50
+
51
+ # Num of hidden layers, means the number of hidden layers in the transformer encoder,
52
+ # layers refer to transformer blocks, so more transformer blocks in the model.
53
+
54
+ # Num of attention heads refers to the number of attention heads inside the multi-head attention module of each hidden layer.
55
+
56
+ # Max position embeddings refers to the maximum input (sequence) length the model can handle; it should be larger for models that handle longer inputs.
57
+
58
+ # vocab size refers to the set of tokens the model is trained on, which has a specific length,
59
+ # in our case it is 100, which is confusing, because we have pixel intensities between 0-255.
60
+
61
+ # bos token id is the id of the beginning-of-sentence token, which marks sentence boundaries and is useful for text generation tasks.
62
+
63
+ # eos token id is the id of the end-of-sentence token, which I don't see in the documentation for the BERT config.
64
+
65
+ # cls token id is the id of the token that is inserted at the beginning of each input instance.
66
+
67
+ # output_hidden_states = True, means to output all the hidden states for us to view.
68
+
69
+
70
+ # ----------------------------------------------------------------------------------------------------------------------
71
+
72
+ # Preparing CIFAR10 Image Dataset, and DataLoaders for Training and Testing
73
# Per-channel mean / std applied to every image, for training and validation alike.
_channel_mean = (0.4914, 0.4822, 0.4465)
_channel_std = (0.2023, 0.1994, 0.2010)

# Training pipeline: random horizontal flip and a padded random 32x32 crop act
# as augmentation before tensor conversion and normalisation. Augmentation is
# important here: without it training overfits very quickly and never reaches
# a good generalisation accuracy.
_train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(_channel_mean, _channel_std),
])

# Validation pipeline: no augmentation, only tensor conversion and normalisation.
_val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_channel_mean, _channel_std),
])

# CIFAR-10 train / test splits, downloaded into ./data/ when not already present.
dataset = CIFAR10(root='./data/', train=True, download=True, transform=_train_transform)
val_dataset = CIFAR10(root='./data/', train=False, download=True, transform=_val_transform)
86
+
87
+ # Model Configuration and Hyperparameters
88
# Use the GPU when one is available; otherwise fall back to the CPU so the
# objects below can still be constructed on a CUDA-less machine.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): bos/eos/cls token ids (101-103) lie outside vocab_size=100.
# The training loop only ever feeds `inputs_embeds`, so no token id is looked
# up in the embedding table; confirm before feeding `input_ids`.
config = BertConfig(
    hidden_size=256,
    intermediate_size=1024,
    num_hidden_layers=12,
    num_attention_heads=8,
    max_position_embeddings=256,
    vocab_size=100,
    bos_token_id=101,
    eos_token_id=102,
    cls_token_id=103,
    output_hidden_states=False,
)

# Transformer encoder (randomly initialised from `config`, not pretrained).
model = BertModel(config).to(device)

# Patch embedding: a 4x4 convolution with stride 4 maps each non-overlapping
# 4x4 RGB patch to one hidden_size-dimensional vector.
patch_embed = nn.Conv2d(3, config.hidden_size, kernel_size=4, stride=4).to(device)

# Learnable class token, scaled by 1/sqrt(hidden_size); it is created directly
# on the target device so it stays a leaf tensor the optimizer can update.
CLS_token = nn.Parameter(
    torch.randn(1, 1, config.hidden_size, device=device) / math.sqrt(config.hidden_size)
)

# MLP head mapping the encoded class token to the 10 class logits.
readout = nn.Sequential(
    nn.Linear(config.hidden_size, config.hidden_size),
    nn.GELU(),
    nn.Linear(config.hidden_size, 10),
).to(device)

# Every module above is already on `device`, so no further device moves are
# needed before handing the parameters to the optimizer.
optimizer = AdamW(
    [
        *model.parameters(),
        *patch_embed.parameters(),
        *readout.parameters(),
        CLS_token,
    ],
    lr=5e-4,
)

# DataLoaders
batch_size = 192  # 96
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
110
+
111
+ # ----------------------------------------------------------------------------------------------------------------------
112
+ # Understanding CLS Token:
113
+ # print("CLASS TOKEN shape:")
114
+ # print(CLS_token.shape)
115
+ #
116
+ # reshaped_cls = CLS_token.expand(192, 1, -1)
117
+ # print("CLS Reshaped shape", reshaped_cls.shape) # 192, 1, 256
118
+ # # We are telling the CLS to have the same shape as patch embeddings.
119
+ #
120
+ # imgs, labels = next(iter(train_loader))
121
+ # patch_embs = patch_embed(imgs.cuda()).flatten(2).permute(0, 2, 1)
122
+ #
123
+ # input_embs = torch.cat([reshaped_cls, patch_embs], dim=1)
124
+ # print("Patch Embeddings Shape", patch_embs.shape)
125
+ #
126
+ # print("Input Embedding Shape", input_embs.shape)
127
+
128
+ # ----------------------------------------------------------------------------------------------------------------------
129
+ # Understanding Output of Model Transformer:
130
+
131
+ # Hidden State state dimension: 192, 12, 65, 256
132
+ # Last Hidden state dimension: 192, 65, 256
133
+ # Pooler Output: 192, 256
134
+
135
+ # in essence pool all the tokens outputs, so we have a one value per complete sample,
136
+ # completely removing the information for each token.
137
+
138
+ #
139
+ # # We should understand output of a model,
140
+ # representations = output.last_hidden_state[:, 0, :]
141
+ # print(output.last_hidden_state.shape) # Out of memory.
142
+ # print(representations.shape)
143
+
144
+ # ----------------------------------------------------------------------------------------------------------------------
145
+
146
# Training Loop
EPOCHS = 30

# Run every batch on whichever device the encoder's weights live on.
device = next(model.parameters()).device


def _classify(imgs):
    """Return class logits for a batch of images already on the model's device.

    The images are cut into patches by `patch_embed`, the learnable class
    token is prepended, the sequence is run through the BERT encoder, and the
    encoded class token (position 0) is passed to the `readout` head.
    """
    patch_embs = patch_embed(imgs)                       # (batch, hidden, H/4, W/4)
    patch_embs = patch_embs.flatten(2).permute(0, 2, 1)  # (batch, num_patches, hidden)
    cls_tokens = CLS_token.expand(imgs.shape[0], 1, -1)  # one class token per sample
    input_embs = torch.cat([cls_tokens, patch_embs], dim=1)
    output = model(inputs_embeds=input_embs)
    return readout(output.last_hidden_state[:, 0, :])


loss_list = []  # mean training loss per epoch
acc_list = []   # training accuracy per epoch

for epoch in trange(EPOCHS, leave=False):
    # Re-enable training mode each epoch: validation below switches the model
    # to eval mode, which would otherwise leave dropout disabled from the
    # second epoch onwards.
    model.train()
    # Reset the running sums so each epoch's metrics cover that epoch only
    # (and are not contaminated by the previous validation pass).
    correct_cnt = 0
    total_loss = 0

    pbar = tqdm(train_loader, leave=False)
    for imgs, labels in pbar:
        imgs = imgs.to(device)
        labels = labels.to(device)

        logit = _classify(imgs)
        loss = F.cross_entropy(logit, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        batch_loss = loss.item()
        pbar.set_description(f"loss: {batch_loss:.4f}")
        # Weight by batch size so the epoch mean is exact even when the last
        # batch is smaller than the others.
        total_loss += batch_loss * imgs.shape[0]
        correct_cnt += (logit.argmax(dim=1) == labels).sum().item()

    loss_list.append(round(total_loss / len(dataset), 4))
    acc_list.append(round(correct_cnt / len(dataset), 4))

    # test on validation set
    model.eval()
    correct_cnt = 0
    total_loss = 0

    # No gradients are needed for evaluation; skipping graph construction
    # saves memory and time.
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs = imgs.to(device)
            labels = labels.to(device)

            logit = _classify(imgs)
            loss = F.cross_entropy(logit, labels)
            total_loss += loss.item() * imgs.shape[0]
            correct_cnt += (logit.argmax(dim=1) == labels).sum().item()

    print(f"val loss: {total_loss / len(val_dataset):.4f}, val acc: {correct_cnt / len(val_dataset):.4f}")
196
+
197
+ # Plotting Loss and Accuracy
198
# Draw the per-epoch training loss and training accuracy on one shared figure.
plt.figure()
for curve, curve_label in ((loss_list, "loss"), (acc_list, "accuracy")):
    plt.plot(curve, label=curve_label)
plt.legend()
plt.show()
203
+ # ----------------------------------------------------------------------------------------------------------------------
204
+
205
+ # Saving Model Parameters
206
# The encoder weights keep their original file name and format, so existing
# `model.load_state_dict(torch.load("bert.pth"))` callers are unaffected.
torch.save(model.state_dict(), "bert.pth")

# The encoder alone cannot classify images: the patch embedding, the class
# token and the readout head are trained too, so they are stored alongside it.
torch.save(
    {
        "patch_embed": patch_embed.state_dict(),
        "readout": readout.state_dict(),
        "CLS_token": CLS_token.detach().cpu(),
    },
    "bert_image_classifier_extras.pth",
)
207
+
208
+ # ----------------------------------------------------------------------------------------------------------------------
209
+ # Reference: Tutorial for Harvard Medical School ML from Scratch Series: Transformer from Scratch
210
+ # ----------------------------------------------------------------------------------------------------------------------
testing_dataset.ipynb ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 32,
6
+ "id": "1efa9df0-5f50-415c-b574-fae1236cb2b7",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "Files already downloaded and verified\n"
14
+ ]
15
+ }
16
+ ],
17
+ "source": [
18
+ "from torch.utils.data import Dataset, DataLoader\n",
19
+ "from torchvision.datasets import MNIST, CIFAR10\n",
20
+ "from torchvision import datasets, transforms\n",
21
+ "\n",
22
+ "\n",
23
+ "dataset = CIFAR10(root='./data/', train=True, download=True, transform=\n",
24
+ "transforms.Compose([\n",
25
+ " transforms.RandomHorizontalFlip(),\n",
26
+ " transforms.RandomCrop(32, padding=4),\n",
27
+ " transforms.ToTensor(),\n",
28
+ " transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
29
+ "]))\n",
30
+ "batch_size = 192 # 96\n",
31
+ "train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)\n"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": 36,
37
+ "id": "e9254b0b-0b70-4d87-b9eb-62a37212ba5a",
38
+ "metadata": {},
39
+ "outputs": [
40
+ {
41
+ "name": "stderr",
42
+ "output_type": "stream",
43
+ "text": [
44
+ "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n"
45
+ ]
46
+ },
47
+ {
48
+ "name": "stdout",
49
+ "output_type": "stream",
50
+ "text": [
51
+ "torch.Size([3, 32, 32])\n"
52
+ ]
53
+ },
54
+ {
55
+ "data": {
56
+ "image/png": "",
57
+ "text/plain": [
58
+ "<Figure size 640x480 with 1 Axes>"
59
+ ]
60
+ },
61
+ "metadata": {},
62
+ "output_type": "display_data"
63
+ }
64
+ ],
65
+ "source": [
66
+ "img = next(iter(train_loader))[0]\n",
67
+ "print(img[0].shape)\n",
68
+ "\n",
69
+ "\n",
70
+ "from matplotlib import pyplot as plt\n",
71
+ "plt.imshow(img[0].permute(1,2 , 0), interpolation='nearest')\n",
72
+ "plt.show()\n",
73
+ "\n",
74
+ "\n"
75
+ ]
76
+ }
77
+ ],
78
+ "metadata": {
79
+ "kernelspec": {
80
+ "display_name": "Python 3 (ipykernel)",
81
+ "language": "python",
82
+ "name": "python3"
83
+ },
84
+ "language_info": {
85
+ "codemirror_mode": {
86
+ "name": "ipython",
87
+ "version": 3
88
+ },
89
+ "file_extension": ".py",
90
+ "mimetype": "text/x-python",
91
+ "name": "python",
92
+ "nbconvert_exporter": "python",
93
+ "pygments_lexer": "ipython3",
94
+ "version": "3.10.13"
95
+ }
96
+ },
97
+ "nbformat": 4,
98
+ "nbformat_minor": 5
99
+ }