makiisthebes committed
Commit 5357119 · verified · 1 Parent(s): 4fd64a9

Bert Image Classifier

Files changed (4)
  1. CW1_Report.pdf +0 -0
  2. bert.pth +3 -0
  3. bert_image_classifier.py +210 -0
  4. testing_dataset.ipynb +99 -0
CW1_Report.pdf ADDED
Binary file (218 kB).
 
bert.pth ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:614ca99aac317874fa033c37e7617881330b5617ec0779e4ec7c51c384d39ee0
size 38605699
bert_image_classifier.py ADDED
@@ -0,0 +1,210 @@
# Michael Peres ~ 09/01/2024
# BERT-Based Transformer Model for Image Classification
# ----------------------------------------------------------------------------------------------------------------------
# Import Modules
# pip install transformers torchvision
from transformers import BertModel, BertTokenizer, BertConfig
from transformers import get_linear_schedule_with_warmup
from transformers import BertForSequenceClassification
from torchvision.utils import make_grid, save_image
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import MNIST, CIFAR10
from torchvision import datasets, transforms
from tqdm.notebook import tqdm, trange
from torch.optim import AdamW, Adam
import matplotlib.pyplot as plt
import torch.nn.functional as F
import math, os, torch
import torch.nn as nn

# ----------------------------------------------------------------------------------------------------------------------
# This is a simple implementation in which the first hidden state, i.e. the encoded class
# token, is used as the input to an MLP head for classification.
# The model is trained on the CIFAR-10 dataset, a dataset of 60,000 32x32 color images in
# 10 classes, with 6,000 images per class.
# Only the encoder part of the BERT model is used, together with the classification head.
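
# A quick, illustrative shape walk-through of the pipeline just described (the _demo names
# are throwaway and not part of the training code below): a 32x32 image is cut into 4x4
# patches, each projected to hidden_size, and a learned class token is prepended.
_demo_patch = nn.Conv2d(3, 256, kernel_size=4, stride=4)  # (B, 3, 32, 32) -> (B, 256, 8, 8)
_demo_tokens = _demo_patch(torch.randn(2, 3, 32, 32)).flatten(2).permute(0, 2, 1)  # (2, 64, 256)
_demo_cls = torch.zeros(2, 1, 256)  # one class token per sample
print(torch.cat([_demo_cls, _demo_tokens], dim=1).shape)  # torch.Size([2, 65, 256])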

# ----------------------------------------------------------------------------------------------------------------------
# Some understanding of the BERT model is required to follow this code; here are the dimensions and documentation.
# From the documentation, https://huggingface.co/transformers/v3.0.2/model_doc/bert.html

# BERT parameters include:

# - hidden size: 256
# - intermediate size: 1024
# - number of hidden layers: 12
# - number of attention heads: 8
# - max position embeddings: 256
# - vocab size: 100
# - bos_token_id: 101
# - eos_token_id: 102
# - cls_token_id: 103

# But what do all of these mean for the task at hand?

# Hidden size represents the dimensionality D of the input embeddings.

# Intermediate size is the number of neurons in the hidden layer of the feed-forward block;
# the feed-forward has dims Hidden Size D -> Intermediate Size -> Hidden Size D.

# Number of hidden layers is the number of layers in the transformer encoder;
# layers refer to transformer blocks, so more means more stacked transformer blocks in the model.

# Number of attention heads is the number of parallel heads in each layer's multi-head attention module.

# Max position embeddings is the maximum input length the model can handle; it should be
# larger for models that take longer inputs (here 256 easily covers our 65 tokens).

# Vocab size is the number of entries in the token-embedding table; here it is 100, which
# looks confusing given pixel intensities of 0-255, but since this model is fed
# `inputs_embeds` directly, the table is never actually indexed (see the sketch below).

# bos_token_id is the id of the beginning-of-sequence token, useful for marking sequence
# boundaries in text-generation tasks.

# eos_token_id is the id of the end-of-sequence token, which I don't see in the
# documentation for BertConfig.

# cls_token_id is the id of the classification token placed at the beginning of each input instance.

# output_hidden_states=True means the model outputs all the hidden states for us to view.

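# A minimal, illustrative check of the vocab_size point above (throwaway _demo names and a
# deliberately tiny config): when `inputs_embeds` is passed, BertModel skips the
# token-embedding lookup entirely, so vocab_size never has to cover the 0-255 pixel range.
_demo_cfg = BertConfig(hidden_size=32, intermediate_size=64, num_hidden_layers=1,
                       num_attention_heads=2, max_position_embeddings=70, vocab_size=100)
_demo_out = BertModel(_demo_cfg)(inputs_embeds=torch.randn(2, 65, 32))
print(_demo_out.last_hidden_state.shape)  # torch.Size([2, 65, 32])
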
# ----------------------------------------------------------------------------------------------------------------------

# Preparing the CIFAR-10 image dataset and DataLoaders for training and testing
dataset = CIFAR10(root='./data/', train=True, download=True, transform=
transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]))
# Augmentations are super important for this training; without them the model overfits
# very fast and never reaches good generalization accuracy.

val_dataset = CIFAR10(root='./data/', train=False, download=True, transform=transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]
))

# Model configuration and hyperparameters
config = BertConfig(hidden_size=256, intermediate_size=1024, num_hidden_layers=12,
                    num_attention_heads=8, max_position_embeddings=256, vocab_size=100,
                    bos_token_id=101, eos_token_id=102, cls_token_id=103,
                    output_hidden_states=False)

model = BertModel(config).cuda()
patch_embed = nn.Conv2d(3, config.hidden_size, kernel_size=4, stride=4).cuda()  # 4x4 patches -> 8x8 grid of tokens
CLS_token = nn.Parameter(torch.randn(1, 1, config.hidden_size, device="cuda") / math.sqrt(config.hidden_size))
readout = nn.Sequential(nn.Linear(config.hidden_size, config.hidden_size),
                        nn.GELU(),
                        nn.Linear(config.hidden_size, 10)
                        ).cuda()

for module in [patch_embed, readout, model]:
    module.cuda()  # modules move in place; CLS_token was created on the GPU above, and .cuda() on a tensor is not in-place

optimizer = AdamW([*model.parameters(),
                   *patch_embed.parameters(),
                   *readout.parameters(),
                   CLS_token], lr=5e-4)
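
# Illustrative sanity check (not part of the original script): all four trainable pieces,
# including the single CLS_token Parameter, should appear in the optimizer.
_n_opt = sum(p.numel() for g in optimizer.param_groups for p in g["params"])
print(f"parameters under optimization: {_n_opt:,}")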

# DataLoaders
batch_size = 192  # 96
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# ----------------------------------------------------------------------------------------------------------------------
# Understanding the CLS token:
# print("CLASS TOKEN shape:")
# print(CLS_token.shape)
#
# reshaped_cls = CLS_token.expand(192, 1, -1)
# print("CLS Reshaped shape", reshaped_cls.shape)  # 192, 1, 256
# # We expand the CLS token along the batch dimension so it lines up with the patch embeddings.
#
# imgs, labels = next(iter(train_loader))
# patch_embs = patch_embed(imgs.cuda()).flatten(2).permute(0, 2, 1)
#
# input_embs = torch.cat([reshaped_cls, patch_embs], dim=1)
# print("Patch Embeddings Shape", patch_embs.shape)
#
# print("Input Embedding Shape", input_embs.shape)

# ----------------------------------------------------------------------------------------------------------------------
# Understanding the output of the transformer model:

# hidden_states: a tuple of (num_hidden_layers + 1) tensors, each of shape (192, 65, 256)
# last_hidden_state: (192, 65, 256)
# pooler_output: (192, 256)

# pooler_output collapses the sequence to one vector per complete sample, removing the
# per-token information. Note it is not an average over tokens: BertPooler applies a
# learned Linear + tanh to the final hidden state of the first ([CLS]) token.

#
# # We should understand the output of the model:
# representations = output.last_hidden_state[:, 0, :]
# print(output.last_hidden_state.shape)  # Out of memory.
# print(representations.shape)

# ----------------------------------------------------------------------------------------------------------------------

# Training Loop
EPOCHS = 30
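# (get_linear_schedule_with_warmup is imported above but never used; if a warmup schedule
# were wanted, an illustrative wiring would be:
#   scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=500,
#                                               num_training_steps=EPOCHS * len(train_loader))
# with scheduler.step() called after each optimizer.step() in the loop below.)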

loss_list = []
acc_list = []
for epoch in trange(EPOCHS, leave=False):
    model.train()  # re-enable training mode; the validation pass below switches to eval each epoch
    correct_cnt = 0  # reset per epoch so training metrics do not absorb the validation counts
    total_loss = 0
    pbar = tqdm(train_loader, leave=False)
    for i, (imgs, labels) in enumerate(pbar):
        patch_embs = patch_embed(imgs.cuda())  # patch embeddings, (batch_size, 256, 8, 8)
        patch_embs = patch_embs.flatten(2).permute(0, 2, 1)  # (batch_size, HW=64, hidden=256)
        input_embs = torch.cat([CLS_token.expand(imgs.shape[0], 1, -1), patch_embs], dim=1)  # prepend CLS -> (batch_size, 65, 256)
        output = model(inputs_embeds=input_embs)
        # output.last_hidden_state: (batch_size, 65, 256); requesting all hidden states ran out of memory
        logit = readout(output.last_hidden_state[:, 0, :])  # classify from the CLS position
        loss = F.cross_entropy(logit, labels.cuda())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        pbar.set_description(f"loss: {loss.item():.4f}")
        total_loss += loss.item() * imgs.shape[0]
        correct_cnt += (logit.argmax(dim=1) == labels.cuda()).sum().item()

    loss_list.append(round(total_loss / len(dataset), 4))
    acc_list.append(round(correct_cnt / len(dataset), 4))
    # test on the validation set
    model.eval()
    correct_cnt = 0
    total_loss = 0

    with torch.no_grad():  # no gradients needed for evaluation
        for i, (imgs, labels) in enumerate(val_loader):
            patch_embs = patch_embed(imgs.cuda())
            patch_embs = patch_embs.flatten(2).permute(0, 2, 1)  # (batch_size, HW, hidden)
            input_embs = torch.cat([CLS_token.expand(imgs.shape[0], 1, -1), patch_embs], dim=1)
            output = model(inputs_embeds=input_embs)
            logit = readout(output.last_hidden_state[:, 0, :])
            loss = F.cross_entropy(logit, labels.cuda())
            total_loss += loss.item() * imgs.shape[0]
            correct_cnt += (logit.argmax(dim=1) == labels.cuda()).sum().item()

    print(f"val loss: {total_loss / len(val_dataset):.4f}, val acc: {correct_cnt / len(val_dataset):.4f}")

# Plotting loss and accuracy
plt.figure()
plt.plot(loss_list, label="loss")
plt.plot(acc_list, label="accuracy")
plt.legend()
plt.show()
# ----------------------------------------------------------------------------------------------------------------------

# Saving model parameters
torch.save(model.state_dict(), "bert.pth")
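
# Note that only the BERT encoder's weights are saved above; patch_embed, CLS_token and
# readout are also needed to rebuild the full classifier. One illustrative way to keep
# everything (the filename is hypothetical):
# torch.save({"bert": model.state_dict(), "patch_embed": patch_embed.state_dict(),
#             "cls_token": CLS_token, "readout": readout.state_dict()}, "bert_full.pth")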

# ----------------------------------------------------------------------------------------------------------------------
# Reference: tutorial from the Harvard Medical School "ML from Scratch" series: Transformer from Scratch
# ----------------------------------------------------------------------------------------------------------------------
testing_dataset.ipynb ADDED
@@ -0,0 +1,99 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "1efa9df0-5f50-415c-b574-fae1236cb2b7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torchvision.datasets import MNIST, CIFAR10\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "\n",
    "dataset = CIFAR10(root='./data/', train=True, download=True, transform=\n",
    "transforms.Compose([\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.RandomCrop(32, padding=4),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
    "]))\n",
    "batch_size = 192  # 96\n",
    "train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "e9254b0b-0b70-4d87-b9eb-62a37212ba5a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([3, 32, 32])\n"
     ]
    },
    {
     "data": {
      "image/png": "<base64 PNG omitted: a single augmented 32x32 CIFAR-10 sample rendered by imshow>",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "img = next(iter(train_loader))[0]\n",
    "print(img[0].shape)\n",
    "\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "plt.imshow(img[0].permute(1, 2, 0), interpolation='nearest')\n",
    "plt.show()\n",
    "\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
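
Note on the "Clipping input data" warning in the notebook output above: imshow received normalized (mean/std-scaled) pixels outside [0, 1]. A minimal sketch of how the cell could undo the CIFAR-10 normalization before plotting, assuming the same mean/std used in the transform:

import torch
mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
std = torch.tensor([0.2023, 0.1994, 0.2010]).view(3, 1, 1)
plt.imshow((img[0] * std + mean).clamp(0, 1).permute(1, 2, 0))
plt.show()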