mrrahul011 committed on
Commit 6a46f5c · verified
1 Parent(s): d7ff968

Upload 2 files

Files changed (2)
  1. assignment23.py +408 -0
  2. best.pt +3 -0
assignment23.py ADDED
@@ -0,0 +1,408 @@
+ import os
+ import cv2
+ import gc
+ import numpy as np
+ import pandas as pd
+ import itertools
+ from tqdm.autonotebook import tqdm
+ import albumentations as A
+ import matplotlib.pyplot as plt
+
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ import timm
+ from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer
+
+
+ image_path = "./Images"
+ captions_path = "."
+
+
+ class CFG:
+     debug = False
+     image_path = image_path
+     captions_path = captions_path
+     batch_size = 32
+     num_workers = 2
+     head_lr = 1e-3
+     image_encoder_lr = 1e-4
+     text_encoder_lr = 1e-5
+     weight_decay = 1e-3
+     patience = 1
+     factor = 0.8
+     epochs = 4
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     model_name = 'resnet50'
+     image_embedding = 2048
+     text_encoder_model = "distilbert-base-uncased"
+     text_embedding = 768
+     text_tokenizer = "distilbert-base-uncased"
+     max_length = 200
+
+     pretrained = True  # for both image encoder and text encoder
+     trainable = True  # for both image encoder and text encoder
+     temperature = 1.0
+
+     # image size
+     size = 224
+
+     # for projection head; used for both image and text encoders
+     num_projection_layers = 1
+     projection_dim = 256
+     dropout = 0.1
+
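+ # Note: the paths above assume a captions.csv with id, image and caption columns next to an
+ # ./Images folder (a Flickr8k-style layout), which is what make_train_valid_dfs() and
+ # CLIPDataset below read.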
+ class AvgMeter:
+     def __init__(self, name="Metric"):
+         self.name = name
+         self.reset()
+
+     def reset(self):
+         self.avg, self.sum, self.count = [0] * 3
+
+     def update(self, val, count=1):
+         self.count += count
+         self.sum += val * count
+         self.avg = self.sum / self.count
+
+     def __repr__(self):
+         text = f"{self.name}: {self.avg:.4f}"
+         return text
+
+ def get_lr(optimizer):
+     for param_group in optimizer.param_groups:
+         return param_group["lr"]
+
+ class CLIPDataset(torch.utils.data.Dataset):
+     def __init__(self, image_filenames, captions, tokenizer, transforms):
+         """
+         image_filenames and captions must have the same length; so, if there are
+         multiple captions for each image, the image_filenames must have repetitive
+         file names
+         """
+
+         self.image_filenames = image_filenames
+         self.captions = list(captions)
+         self.encoded_captions = tokenizer(
+             list(captions), padding=True, truncation=True, max_length=CFG.max_length
+         )
+         self.transforms = transforms
+
+     def __getitem__(self, idx):
+         item = {
+             key: torch.tensor(values[idx])
+             for key, values in self.encoded_captions.items()
+         }
+
+         image = cv2.imread(f"{CFG.image_path}/{self.image_filenames[idx]}")
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+         image = self.transforms(image=image)['image']
+         item['image'] = torch.tensor(image).permute(2, 0, 1).float()
+         item['caption'] = self.captions[idx]
+
+         return item
+
+
+     def __len__(self):
+         return len(self.captions)
+
+
+
+ def get_transforms(mode="train"):
+     if mode == "train":
+         return A.Compose(
+             [
+                 A.Resize(CFG.size, CFG.size, always_apply=True),
+                 A.Normalize(max_pixel_value=255.0, always_apply=True),
+             ]
+         )
+     else:
+         return A.Compose(
+             [
+                 A.Resize(CFG.size, CFG.size, always_apply=True),
+                 A.Normalize(max_pixel_value=255.0, always_apply=True),
+             ]
+         )
+
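+ # A rough usage sketch (not part of the training flow) showing how one dataset item is built;
+ # the file names and captions here are hypothetical placeholders:
+ # tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
+ # dataset = CLIPDataset(
+ #     image_filenames=["example_1.jpg", "example_2.jpg"],
+ #     captions=["a dog on the grass", "a man riding a bike"],
+ #     tokenizer=tokenizer,
+ #     transforms=get_transforms(mode="valid"),
+ # )
+ # item = dataset[0]  # dict with 'image', 'caption', 'input_ids', 'attention_mask'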
+ class ImageEncoder(nn.Module):
+     """
+     Encode images to a fixed size vector
+     """
+
+     def __init__(
+         self, model_name=CFG.model_name, pretrained=CFG.pretrained, trainable=CFG.trainable
+     ):
+         super().__init__()
+         self.model = timm.create_model(
+             model_name, pretrained, num_classes=0, global_pool="avg"
+         )
+         for p in self.model.parameters():
+             p.requires_grad = trainable
+
+     def forward(self, x):
+         return self.model(x)
+
+ class TextEncoder(nn.Module):
+     def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained, trainable=CFG.trainable):
+         super().__init__()
+         if pretrained:
+             self.model = DistilBertModel.from_pretrained(model_name)
+         else:
+             self.model = DistilBertModel(config=DistilBertConfig())
+
+         for p in self.model.parameters():
+             p.requires_grad = trainable
+
+         # we are using the CLS token hidden representation as the sentence's embedding
+         self.target_token_idx = 0
+
+     def forward(self, input_ids, attention_mask):
+         output = self.model(input_ids=input_ids, attention_mask=attention_mask)
+         last_hidden_state = output.last_hidden_state
+         return last_hidden_state[:, self.target_token_idx, :]
+
+ class ProjectionHead(nn.Module):
+     def __init__(
+         self,
+         embedding_dim,
+         projection_dim=CFG.projection_dim,
+         dropout=CFG.dropout
+     ):
+         super().__init__()
+         self.projection = nn.Linear(embedding_dim, projection_dim)
+         self.gelu = nn.GELU()
+         self.fc = nn.Linear(projection_dim, projection_dim)
+         self.dropout = nn.Dropout(dropout)
+         self.layer_norm = nn.LayerNorm(projection_dim)
+
+     def forward(self, x):
+         projected = self.projection(x)
+         x = self.gelu(projected)
+         x = self.fc(x)
+         x = self.dropout(x)
+         x = x + projected
+         x = self.layer_norm(x)
+         return x
+
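+ # Note: the projection head maps the 2048-d ResNet50 features and the 768-d DistilBERT
+ # features into the same CFG.projection_dim space so they can be compared directly; the
+ # fc -> dropout branch is added back onto `projected` as a residual before the LayerNorm.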
+ class CLIPModel(nn.Module):
+     def __init__(
+         self,
+         temperature=CFG.temperature,
+         image_embedding=CFG.image_embedding,
+         text_embedding=CFG.text_embedding,
+     ):
+         super().__init__()
+         self.image_encoder = ImageEncoder()
+         self.text_encoder = TextEncoder()
+         self.image_projection = ProjectionHead(embedding_dim=image_embedding)
+         self.text_projection = ProjectionHead(embedding_dim=text_embedding)
+         self.temperature = temperature
+
+     def forward(self, batch):
+         # Getting Image and Text Features
+         image_features = self.image_encoder(batch["image"])
+         text_features = self.text_encoder(
+             input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
+         )
+         # Getting Image and Text Embeddings (with same dimension)
+         image_embeddings = self.image_projection(image_features)
+         text_embeddings = self.text_projection(text_features)
+
+         # Calculating the Loss
+         logits = (text_embeddings @ image_embeddings.T) / self.temperature
+         images_similarity = image_embeddings @ image_embeddings.T
+         texts_similarity = text_embeddings @ text_embeddings.T
+         targets = F.softmax(
+             (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
+         )
+         texts_loss = cross_entropy(logits, targets, reduction='none')
+         images_loss = cross_entropy(logits.T, targets.T, reduction='none')
+         loss = (images_loss + texts_loss) / 2.0  # shape: (batch_size)
+         return loss.mean()
+
+
+ def cross_entropy(preds, targets, reduction='none'):
+     log_softmax = nn.LogSoftmax(dim=-1)
+     loss = (-targets * log_softmax(preds)).sum(1)
+     if reduction == "none":
+         return loss
+     elif reduction == "mean":
+         return loss.mean()
+
+ # A simple Example
+
+ batch_size = 4
+ dim = 256
+ embeddings = torch.randn(batch_size, dim)
+ out = embeddings @ embeddings.T
+ print(F.softmax(out, dim=-1))
+
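+ # Why the targets in CLIPModel.forward are soft rather than an identity matrix: the target
+ # distribution is the softmax of the averaged image-image and text-text similarity matrices
+ # (scaled by the temperature), so two near-identical samples in a batch are not forced onto
+ # one-hot targets. When all samples in the batch are distinct, the targets approach the
+ # identity matrix used in the original CLIP loss.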
+ def make_train_valid_dfs():
+     dataframe = pd.read_csv(f"{CFG.captions_path}/captions.csv")
+     max_id = dataframe["id"].max() + 1 if not CFG.debug else 100
+     image_ids = np.arange(0, max_id)
+     np.random.seed(42)
+     valid_ids = np.random.choice(
+         image_ids, size=int(0.2 * len(image_ids)), replace=False
+     )
+     train_ids = [id_ for id_ in image_ids if id_ not in valid_ids]
+     train_dataframe = dataframe[dataframe["id"].isin(train_ids)].reset_index(drop=True)
+     valid_dataframe = dataframe[dataframe["id"].isin(valid_ids)].reset_index(drop=True)
+     return train_dataframe, valid_dataframe
+
+
+ def build_loaders(dataframe, tokenizer, mode):
+     transforms = get_transforms(mode=mode)
+     dataset = CLIPDataset(
+         dataframe["image"].values,
+         dataframe["caption"].values,
+         tokenizer=tokenizer,
+         transforms=transforms,
+     )
+     dataloader = torch.utils.data.DataLoader(
+         dataset,
+         batch_size=CFG.batch_size,
+         num_workers=CFG.num_workers,
+         shuffle=True if mode == "train" else False,
+     )
+     return dataloader
+
+ def train_epoch(model, train_loader, optimizer, lr_scheduler, step):
+     loss_meter = AvgMeter()
+     tqdm_object = tqdm(train_loader, total=len(train_loader))
+     for batch in tqdm_object:
+         batch = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}
+         loss = model(batch)
+         optimizer.zero_grad()
+         loss.backward()
+         optimizer.step()
+         if step == "batch":
+             lr_scheduler.step()
+
+         count = batch["image"].size(0)
+         loss_meter.update(loss.item(), count)
+
+         tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))
+     return loss_meter
+
+
+ def valid_epoch(model, valid_loader):
+     loss_meter = AvgMeter()
+
+     tqdm_object = tqdm(valid_loader, total=len(valid_loader))
+     for batch in tqdm_object:
+         batch = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}
+         loss = model(batch)
+
+         count = batch["image"].size(0)
+         loss_meter.update(loss.item(), count)
+
+         tqdm_object.set_postfix(valid_loss=loss_meter.avg)
+     return loss_meter
+
+
+ def main():
+     train_df, valid_df = make_train_valid_dfs()
+     tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
+     train_loader = build_loaders(train_df, tokenizer, mode="train")
+     valid_loader = build_loaders(valid_df, tokenizer, mode="valid")
+
+
+     model = CLIPModel().to(CFG.device)
+     params = [
+         {"params": model.image_encoder.parameters(), "lr": CFG.image_encoder_lr},
+         {"params": model.text_encoder.parameters(), "lr": CFG.text_encoder_lr},
+         {"params": itertools.chain(
+             model.image_projection.parameters(), model.text_projection.parameters()
+         ), "lr": CFG.head_lr, "weight_decay": CFG.weight_decay}
+     ]
+     optimizer = torch.optim.AdamW(params, weight_decay=0.)
+     lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+         optimizer, mode="min", patience=CFG.patience, factor=CFG.factor
+     )
+     step = "epoch"
+
+     best_loss = float('inf')
+     for epoch in range(CFG.epochs):
+         print(f"Epoch: {epoch + 1}")
+         model.train()
+         train_loss = train_epoch(model, train_loader, optimizer, lr_scheduler, step)
+         model.eval()
+         with torch.no_grad():
+             valid_loss = valid_epoch(model, valid_loader)
+
+         if valid_loss.avg < best_loss:
+             best_loss = valid_loss.avg
+             torch.save(model.state_dict(), "best.pt")
+             print("Saved Best Model!")
+
+         lr_scheduler.step(valid_loss.avg)
+
+ main()
+
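+ # Note: with step = "epoch" the ReduceLROnPlateau scheduler is only stepped once per epoch
+ # on the validation loss (lr_scheduler.step(valid_loss.avg) above); the per-batch branch in
+ # train_epoch is only taken if step is set to "batch".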
+ def get_image_embeddings(valid_df, model_path):
+     tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
+     valid_loader = build_loaders(valid_df, tokenizer, mode="valid")
+
+     model = CLIPModel().to(CFG.device)
+     model.load_state_dict(torch.load(model_path, map_location=CFG.device))
+     model.eval()
+
+     valid_image_embeddings = []
+     with torch.no_grad():
+         for batch in tqdm(valid_loader):
+             image_features = model.image_encoder(batch["image"].to(CFG.device))
+             image_embeddings = model.image_projection(image_features)
+             valid_image_embeddings.append(image_embeddings)
+     return model, torch.cat(valid_image_embeddings)
+
+ _, valid_df = make_train_valid_dfs()
+ model, image_embeddings = get_image_embeddings(valid_df, "best.pt")
+
+ def find_matches(model, image_embeddings, query, image_filenames, n=9):
+     tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
+     encoded_query = tokenizer([query])
+     batch = {
+         key: torch.tensor(values).to(CFG.device)
+         for key, values in encoded_query.items()
+     }
+     with torch.no_grad():
+         text_features = model.text_encoder(
+             input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
+         )
+         text_embeddings = model.text_projection(text_features)
+
+     image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
+     text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
+     dot_similarity = text_embeddings_n @ image_embeddings_n.T
+
+     values, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
+     matches = [image_filenames[idx] for idx in indices[::5]]
+
+     _, axes = plt.subplots(3, 3, figsize=(10, 10))
+     for match, ax in zip(matches, axes.flatten()):
+         image = cv2.imread(f"{CFG.image_path}/{match}")
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+         ax.imshow(image)
+         ax.axis("off")
+
+     plt.show()
+
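+ # Note: valid_df has several captions (and therefore several rows) per image, so the top
+ # n * 5 results are strided with indices[::5] to avoid showing the same image repeatedly
+ # (this assumes roughly 5 captions per image, as in the Flickr8k captions file).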
+ find_matches(model,
+              image_embeddings,
+              query="man and women on road",
+              image_filenames=valid_df['image'].values,
+              n=9)
+
+
+ def inference_CLIP(query_text):
+     _, valid_df = make_train_valid_dfs()
+     model, image_embeddings = get_image_embeddings(valid_df, "best.pt")
+     return find_matches(model,
+                         image_embeddings,
+                         query=query_text,
+                         # query="dogs on the grass",
+                         image_filenames=valid_df['image'].values,
+                         n=9)
+
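+ # Example usage sketch (hypothetical query string); uncomment to run retrieval end to end:
+ # inference_CLIP("two dogs running on the grass")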
best.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a67c13069b156ab6b439eeb5994c19f72ccd7f5736939ef25d4355b324a1457e
+ size 363250624