Vageesh1 committed on
Commit
4e527a6
1 Parent(s): 2c14e54

Upload 4 files

Files changed (4)
  1. neuralnet/dataset.py +139 -0
  2. neuralnet/model.py +71 -0
  3. neuralnet/train.py +130 -0
  4. neuralnet/utils.py +42 -0
neuralnet/dataset.py CHANGED
@@ -0,0 +1,139 @@
+ import os  # when loading file paths
+ import pandas as pd  # for lookup in annotation file
+ import spacy  # for tokenizer
+ import torch
+ from torch.nn.utils.rnn import pad_sequence  # pad batch
+ from torch.utils.data import DataLoader, Dataset
+ from PIL import Image  # Load img
+ import torchvision.transforms as transforms
+ import json
+
+ # Download with: python -m spacy download en_core_web_sm
+ spacy_eng = spacy.load("en_core_web_sm")
+
+
+ class Vocabulary:
+     def __init__(self, freq_threshold):
+         self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
+         self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
+         self.freq_threshold = freq_threshold
+
+     def __len__(self):
+         return len(self.stoi)
+
+     @staticmethod
+     def tokenizer_eng(text):
+         return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]
+
+     def build_vocabulary(self, sentence_list):
+         frequencies = {}
+         idx = 4
+
+         for sentence in sentence_list:
+             for word in self.tokenizer_eng(sentence):
+                 if word not in frequencies:
+                     frequencies[word] = 1
+                 else:
+                     frequencies[word] += 1
+
+                 # Add a word once it has been seen freq_threshold times
+                 if frequencies[word] == self.freq_threshold:
+                     self.stoi[word] = idx
+                     self.itos[idx] = word
+                     idx += 1
+
+     def numericalize(self, text):
+         tokenized_text = self.tokenizer_eng(text)
+
+         return [
+             self.stoi[token] if token in self.stoi else self.stoi["<UNK>"]
+             for token in tokenized_text
+         ]
+
+
+ class FlickrDataset(Dataset):
+     def __init__(self, root_dir, captions_file, transform=None, freq_threshold=5):
+         self.root_dir = root_dir
+         self.df = pd.read_csv(captions_file)
+         self.transform = transform
+
+         # Get img, caption columns
+         self.imgs = self.df["image_name"]
+         self.captions = self.df["comment"]
+
+         # Initialize vocabulary and build vocab
+         self.vocab = Vocabulary(freq_threshold)
+         self.vocab.build_vocabulary(self.captions.tolist())
+
+     def __len__(self):
+         return len(self.df)
+
+     def __getitem__(self, index):
+         caption = self.captions[index]
+         img_id = self.imgs[index]
+         img = Image.open(os.path.join(self.root_dir, img_id)).convert("RGB")
+
+         if self.transform is not None:
+             img = self.transform(img)
+
+         numericalized_caption = [self.vocab.stoi["<SOS>"]]
+         numericalized_caption += self.vocab.numericalize(caption)
+         numericalized_caption.append(self.vocab.stoi["<EOS>"])
+
+         return img, torch.tensor(numericalized_caption)
+
+
+ class MyCollate:
+     def __init__(self, pad_idx):
+         self.pad_idx = pad_idx
+
+     def __call__(self, batch):
+         imgs = [item[0].unsqueeze(0) for item in batch]
+         imgs = torch.cat(imgs, dim=0)
+         targets = [item[1] for item in batch]
+         targets = pad_sequence(targets, batch_first=False, padding_value=self.pad_idx)
+
+         return imgs, targets
+
+
+ def get_loader(
+     root_folder,
+     annotation_file,
+     transform,
+     batch_size=64,
+     num_workers=2,
+     shuffle=True,
+     pin_memory=True,
+ ):
+     dataset = FlickrDataset(root_folder, annotation_file, transform=transform)
+
+     pad_idx = dataset.vocab.stoi["<PAD>"]
+
+     loader = DataLoader(
+         dataset=dataset,
+         batch_size=batch_size,
+         num_workers=num_workers,
+         shuffle=shuffle,
+         pin_memory=pin_memory,
+         collate_fn=MyCollate(pad_idx=pad_idx),
+     )
+
+     return loader, dataset
+
+
+ if __name__ == "__main__":
+     transform = transforms.Compose(
+         [transforms.Resize((224, 224)), transforms.ToTensor()]
+     )
+
+     loader, dataset = get_loader(
+         "/home/koushik/vscode/Projects/pytorch/img2text_v1/flickr30k/flickr30k_images/",
+         "/home/koushik/vscode/Projects/pytorch/img2text_v1/flickr30k/results.csv",
+         transform=transform,
+     )
+
+     for idx, (imgs, captions) in enumerate(loader):
+         print(imgs.shape)
+         print(captions.shape)
+         print(len(dataset.vocab))
+         vocab = {"itos": dataset.vocab.itos, "stoi": dataset.vocab.stoi}
+         json.dump(vocab, open("vocab.json", "w"))  # the vocab file that train.py loads
+         break
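A small standalone check of the Vocabulary class above, as a sketch only: it assumes the en_core_web_sm model is installed, and the toy sentences and freq_threshold=1 are purely illustrative.

from dataset import Vocabulary

vocab = Vocabulary(freq_threshold=1)                   # keep every word seen at least once
vocab.build_vocabulary(["A dog runs", "A dog sleeps"])

print(vocab.stoi)                                      # special tokens at 0-3, corpus words from index 4
print(vocab.numericalize("A dog jumps"))               # unseen "jumps" falls back to <UNK> (index 3)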
neuralnet/model.py ADDED
@@ -0,0 +1,71 @@
+ import torch
+ import torch.nn as nn
+ import torchvision.models as models
+
+
+ class InceptionEncoder(nn.Module):
+     def __init__(self, embed_size, train_CNN=False):
+         super(InceptionEncoder, self).__init__()
+         self.train_CNN = train_CNN
+         self.inception = models.inception_v3(pretrained=True, aux_logits=False)
+         # Replace the classifier head with a projection to the embedding size
+         self.inception.fc = nn.Linear(self.inception.fc.in_features, embed_size)
+         self.relu = nn.ReLU()
+         self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)
+         self.dropout = nn.Dropout(0.5)
+
+     def forward(self, images):
+         features = self.inception(images)
+         norm_features = self.bn(features)
+         return self.dropout(self.relu(norm_features))
+
+
+ class LstmDecoder(nn.Module):
+     def __init__(self, embed_size, hidden_size, vocab_size, num_layers, device='cpu'):
+         super(LstmDecoder, self).__init__()
+         self.num_layers = num_layers
+         self.hidden_size = hidden_size
+         self.device = device
+         self.embed = nn.Embedding(vocab_size, embed_size)
+         self.lstm = nn.LSTM(embed_size, hidden_size, num_layers=self.num_layers)
+         self.linear = nn.Linear(hidden_size, vocab_size)
+         self.dropout = nn.Dropout(0.5)
+
+     def forward(self, encoder_out, captions):
+         # Zero-initialised hidden and cell states for the LSTM
+         h0 = torch.zeros(self.num_layers, encoder_out.shape[0], self.hidden_size).to(self.device).requires_grad_()
+         c0 = torch.zeros(self.num_layers, encoder_out.shape[0], self.hidden_size).to(self.device).requires_grad_()
+         embeddings = self.dropout(self.embed(captions))
+         # Prepend the image embedding as the first step of the input sequence
+         embeddings = torch.cat((encoder_out.unsqueeze(0), embeddings), dim=0)
+         hiddens, (hn, cn) = self.lstm(embeddings, (h0.detach(), c0.detach()))
+         outputs = self.linear(hiddens)
+         return outputs
+
+
+ class SeqToSeq(nn.Module):
+     def __init__(self, embed_size, hidden_size, vocab_size, num_layers, device='cpu'):
+         super(SeqToSeq, self).__init__()
+         self.encoder = InceptionEncoder(embed_size)
+         self.decoder = LstmDecoder(embed_size, hidden_size, vocab_size, num_layers, device)
+
+     def forward(self, images, captions):
+         features = self.encoder(images)
+         outputs = self.decoder(features, captions)
+         return outputs
+
+     def caption_image(self, image, vocabulary, max_length=50):
+         # Greedy decoding: start from the image embedding, then feed back the
+         # previously predicted token until <EOS> or max_length is reached.
+         result_caption = []
+
+         with torch.no_grad():
+             x = self.encoder(image).unsqueeze(0)
+             states = None
+
+             for _ in range(max_length):
+                 hiddens, states = self.decoder.lstm(x, states)
+                 output = self.decoder.linear(hiddens.squeeze(0))
+                 predicted = output.argmax(1)
+                 result_caption.append(predicted.item())
+                 x = self.decoder.embed(predicted).unsqueeze(0)
+
+                 # vocabulary is the itos mapping with string keys (loaded from JSON)
+                 if vocabulary[str(predicted.item())] == "<EOS>":
+                     break
+
+         return [vocabulary[str(idx)] for idx in result_caption]
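A quick shape check for SeqToSeq, as a sketch only: it assumes the pretrained Inception weights can be downloaded, and the batch size, caption length and vocab size below are arbitrary.

import torch
from model import SeqToSeq

model = SeqToSeq(embed_size=256, hidden_size=512, vocab_size=5000, num_layers=3)
imgs = torch.randn(2, 3, 299, 299)        # inception_v3 expects 299x299 RGB inputs
caps = torch.randint(0, 5000, (20, 2))    # (seq_len, batch) of token indices

out = model(imgs, caps)
print(out.shape)                          # (seq_len + 1, batch, vocab_size) -- the image embedding adds one step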
neuralnet/train.py ADDED
@@ -0,0 +1,130 @@
+ import torch
+ from tqdm import tqdm
+ import torch.nn as nn
+ import torch.optim as optim
+ import torchvision.transforms as transforms
+ from torch.utils.tensorboard import SummaryWriter  # For TensorBoard
+ from utils import save_checkpoint, load_checkpoint, print_examples
+ from dataset import get_loader
+ from model import SeqToSeq
+ from tabulate import tabulate  # To tabulate loss and epoch
+ import argparse
+ import json
+
+
+ def main(args):
+     transform = transforms.Compose(
+         [
+             transforms.Resize((356, 356)),
+             transforms.RandomCrop((299, 299)),
+             transforms.ToTensor(),
+             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+         ]
+     )
+
+     train_loader, _ = get_loader(
+         root_folder=args.root_dir,
+         annotation_file=args.csv_file,
+         transform=transform,
+         batch_size=args.batch_size,
+         num_workers=2,
+     )
+     vocab = json.load(open('vocab.json'))
+
+     torch.backends.cudnn.benchmark = True
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     load_model = False
+     save_model = True
+     train_CNN = False
+
+     # Hyperparameters
+     embed_size = args.embed_size
+     hidden_size = args.hidden_size
+     vocab_size = len(vocab['stoi'])
+     num_layers = args.num_layers
+     learning_rate = args.lr
+     num_epochs = args.num_epochs
+
+     # For TensorBoard logging
+     writer = SummaryWriter(args.log_dir)
+     step = 0
+
+     # Initialize model, loss, optimizer
+     model_params = {'embed_size': embed_size, 'hidden_size': hidden_size,
+                     'vocab_size': vocab_size, 'num_layers': num_layers}
+     model = SeqToSeq(**model_params, device=device).to(device)
+     criterion = nn.CrossEntropyLoss(ignore_index=vocab['stoi']["<PAD>"])
+     optimizer = optim.Adam(model.parameters(), lr=learning_rate)
+
+     # Only fine-tune the encoder's final FC layer; train the rest of the CNN
+     # only when train_CNN is True
+     for name, param in model.encoder.inception.named_parameters():
+         if "fc.weight" in name or "fc.bias" in name:
+             param.requires_grad = True
+         else:
+             param.requires_grad = train_CNN
+
+     # Load from a saved checkpoint
+     if load_model:
+         step = load_checkpoint(torch.load(args.save_path), model, optimizer)
+
+     model.train()
+     best_loss, best_epoch = 10, 0
+     for epoch in range(num_epochs):
+         print_examples(model, device, vocab['itos'])
+
+         for idx, (imgs, captions) in tqdm(
+                 enumerate(train_loader), total=len(train_loader), leave=False):
+             imgs = imgs.to(device)
+             captions = captions.to(device)
+
+             outputs = model(imgs, captions[:-1])
+             loss = criterion(
+                 outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1)
+             )
+
+             writer.add_scalar("Training loss", loss.item(), global_step=step)
+             step += 1
+
+             optimizer.zero_grad()
+             loss.backward()
+             optimizer.step()
+
+         train_loss = loss.item()
+         if train_loss < best_loss:
+             best_loss = train_loss
+             best_epoch = epoch + 1
+             if save_model:
+                 checkpoint = {
+                     "model_params": model_params,
+                     "state_dict": model.state_dict(),
+                     "optimizer": optimizer.state_dict(),
+                     "step": step,
+                 }
+                 save_checkpoint(checkpoint, args.save_path)
+
+         table = [["Loss:", train_loss],
+                  ["Step:", step],
+                  ["Epoch:", epoch + 1],
+                  ["Best Loss:", best_loss],
+                  ["Best Epoch:", best_epoch]]
+         print(tabulate(table))
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument('--root_dir', type=str, default='./flickr30k/flickr30k_images', help='path to images folder')
+     parser.add_argument('--csv_file', type=str, default='./flickr30k/results.csv', help='path to captions csv file')
+     parser.add_argument('--log_dir', type=str, default='./drive/MyDrive/TensorBoard/', help='path to save tensorboard logs')
+     parser.add_argument('--save_path', type=str, default='./drive/MyDrive/checkpoints/Seq2Seq.pt', help='path to save checkpoint')
+     # Training and model hyperparameters
+     parser.add_argument('--batch_size', type=int, default=64)
+     parser.add_argument('--num_epochs', type=int, default=100)
+     parser.add_argument('--embed_size', type=int, default=256)
+     parser.add_argument('--hidden_size', type=int, default=512)
+     parser.add_argument('--lr', type=float, default=0.001)
+     parser.add_argument('--num_layers', type=int, default=3, help='number of lstm layers')
+
+     args = parser.parse_args()
+
+     main(args)
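Why `captions[:-1]` above pairs with the full `captions` tensor in the loss: the decoder prepends the image embedding, so its output has one more time step than the caption tokens it is fed. A standalone shape sketch with made-up sizes (not the project's real data):

import torch
import torch.nn as nn

T, N, V = 10, 2, 100                      # caption length, batch size, vocab size (arbitrary)
captions = torch.randint(1, V, (T, N))    # (seq_len, batch), as MyCollate produces
outputs = torch.randn(T, N, V)            # stand-in for model(imgs, captions[:-1]): (T-1)+1 = T steps

criterion = nn.CrossEntropyLoss(ignore_index=0)   # 0 = <PAD>
loss = criterion(outputs.reshape(-1, V), captions.reshape(-1))
print(loss.item())                        # scalar loss over the T*N positions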
neuralnet/utils.py ADDED
@@ -0,0 +1,42 @@
+ import torch
+ import torchvision.transforms as transforms
+ from PIL import Image
+
+
+ def print_examples(model, device, vocab):
+     transform = transforms.Compose(
+         [transforms.Resize((299, 299)),
+          transforms.ToTensor(),
+          transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
+     )
+
+     model.eval()
+
+     test_img1 = transform(Image.open("./test_examples/dog.png").convert("RGB")).unsqueeze(0)
+     print("dog.png PREDICTION: " + " ".join(model.caption_image(test_img1.to(device), vocab)))
+
+     test_img2 = transform(Image.open("./test_examples/dirt_bike.png").convert("RGB")).unsqueeze(0)
+     print("dirt_bike.png PREDICTION: " + " ".join(model.caption_image(test_img2.to(device), vocab)))
+
+     test_img3 = transform(Image.open("./test_examples/surfing.png").convert("RGB")).unsqueeze(0)
+     print("surfing.png PREDICTION: " + " ".join(model.caption_image(test_img3.to(device), vocab)))
+
+     test_img4 = transform(Image.open("./test_examples/horse.png").convert("RGB")).unsqueeze(0)
+     print("horse.png PREDICTION: " + " ".join(model.caption_image(test_img4.to(device), vocab)))
+
+     test_img5 = transform(Image.open("./test_examples/camera.png").convert("RGB")).unsqueeze(0)
+     print("camera.png PREDICTION: " + " ".join(model.caption_image(test_img5.to(device), vocab)))
+
+     model.train()
+
+
+ def save_checkpoint(state, filename="/content/drive/MyDrive/checkpoints/Seq2Seq.pt"):
+     print("=> Saving checkpoint")
+     torch.save(state, filename)
+
+
+ def load_checkpoint(checkpoint, model, optimizer):
+     print("=> Loading checkpoint")
+     model.load_state_dict(checkpoint["state_dict"])
+     optimizer.load_state_dict(checkpoint["optimizer"])
+     step = checkpoint["step"]
+     return step
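A minimal inference sketch tying the pieces together, assuming a checkpoint written by train.py, a vocab.json from dataset.py, and the ./test_examples images all exist (the paths are illustrative):

import json
import torch
from model import SeqToSeq
from utils import print_examples

vocab = json.load(open("vocab.json"))
checkpoint = torch.load("./drive/MyDrive/checkpoints/Seq2Seq.pt", map_location="cpu")

model = SeqToSeq(**checkpoint["model_params"])        # sizes saved alongside the weights in train.py
model.load_state_dict(checkpoint["state_dict"])

print_examples(model, torch.device("cpu"), vocab["itos"])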