VenkateshRoshan committed
Commit 3138612 · 1 Parent(s): be7ebcc

Training script is now working

config/__pycache__/config.cpython-310.pyc CHANGED
Binary files a/config/__pycache__/config.cpython-310.pyc and b/config/__pycache__/config.cpython-310.pyc differ
 
config/config.py CHANGED
@@ -8,4 +8,5 @@ class Config:
     EPOCHS = 10
     DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
     AWS_S3_BUCKET = 'your-s3-bucket-name'
-    DATASET_PATH = '../Datasets/Flickr8K/'
+    DATASET_PATH = '../Datasets/Flickr8K/'
+    BATCH_SIZE = 32
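
Note that model.py and train.py below also read Config.VIT_MODEL, Config.GPT2_MODEL, Config.MAX_SEQ_LEN and Config.LEARNING_RATE, which fall outside this hunk. A minimal sketch of how the full class might look; the four values marked below are assumptions, not part of the diff:

    import torch

    class Config:
        VIT_MODEL = 'google/vit-base-patch16-224'  # assumed checkpoint name
        GPT2_MODEL = 'gpt2'                        # assumed checkpoint name
        MAX_SEQ_LEN = 64                           # assumed value
        LEARNING_RATE = 1e-4                       # assumed value
        EPOCHS = 10
        DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
        AWS_S3_BUCKET = 'your-s3-bucket-name'
        DATASET_PATH = '../Datasets/Flickr8K/'
        BATCH_SIZE = 32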
data/__pycache__/dataLoader.cpython-310.pyc CHANGED
Binary files a/data/__pycache__/dataLoader.cpython-310.pyc and b/data/__pycache__/dataLoader.cpython-310.pyc differ
 
data/dataLoader.py CHANGED
@@ -1,52 +1,59 @@
-import numpy as np
 import os
-import cv2
+import pandas as pd
 from PIL import Image
 from torchvision import transforms
-import pandas as pd
+from torch.utils.data import Dataset
+
+class ImageCaptionDataset(Dataset):
+    """
+    Custom PyTorch Dataset class to handle loading and transforming image-caption pairs
+    where image paths and captions are provided in a CSV file.
+
+    Attributes:
+        caption_file (str): Path to the CSV file containing image paths and captions.
+        transform (torchvision.transforms.Compose): Transformations to apply on the images.
+    """
 
-class dataLoader:
-    def __init__(self, path):
-        self.path = path
-        self.img_path = path + 'images/'
-        self.caption_path = path + 'captions.csv'
-        self.img_list = os.listdir(self.img_path)
-        self.caption_dict = self.get_caption_dict()
-        self.transform = transforms.Compose([
-            transforms.Resize((224, 224)),
-            transforms.ToTensor()
+    def __init__(self, caption_file: str, file_path: str, transform=None):
+        """
+        Initialize dataset with caption CSV file and optional transform.
+
+        Args:
+            caption_file (str): Path to the CSV file where each row has an image path and caption.
+            transform (callable, optional): Optional transform to apply on an image.
+        """
+        self.df = pd.read_csv(caption_file)
+        self.image_path = file_path
+        self.transform = transform or transforms.Compose([
+            transforms.Resize((224, 224)),  # Resize to 224x224 for ViT
+            transforms.ToTensor(),          # Convert to tensor in [0, 1]
+            # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
         ])
 
-    def get_caption_dict(self):
-        caption_dict = {}
-        df = pd.read_csv(self.caption_path, delimiter=',')
-        for i in range(len(df)):
-            img_name = df.iloc[i, 0]
-            caption = df.iloc[i, 1]
-            caption_dict[img_name] = caption
-        return caption_dict
-
-    def get_image(self, img_name):
-        img = Image.open(self.img_path + img_name)
-        img = self.transform(img)
-        return img
-
-    def get_caption(self, img_name):
-        return self.caption_dict[img_name]
-
-    def get_batch(self, batch_size):
-        batch = np.random.choice(self.img_list, batch_size)
-        images = []
-        captions = []
-        for img_name in batch:
-            images.append(self.get_image(img_name))
-            captions.append(self.get_caption(img_name))
-        return images, captions
-
-    def get_all(self):
-        images = []
-        captions = []
-        for img_name in self.img_list:
-            images.append(self.get_image(img_name))
-            captions.append(self.get_caption(img_name))
-        return images, captions
+    def __len__(self):
+        """Return the total number of samples in the dataset."""
+        return len(self.df)
+
+    def __getitem__(self, idx):
+        """
+        Retrieve an image and its corresponding caption by index.
+
+        Args:
+            idx (int): Index of the data item.
+
+        Returns:
+            tuple: (image, caption) where image is the transformed image tensor and caption is the associated text.
+        """
+        img_path = self.df.iloc[idx, 0]  # The first column contains image paths
+        caption = self.df.iloc[idx, 1]   # The second column contains captions
+        # Load image
+        image = Image.open(self.image_path + img_path).convert('RGB')
+
+        # Apply transformations to the image
+        if self.transform:
+            image = self.transform(image)
+
+        return image, caption
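
Because __getitem__ returns the caption as a plain string, the default DataLoader collate stacks the image tensors and hands the captions back as a tuple of strings, which is what train.py below tokenizes. A quick sanity-check sketch (paths taken from this commit's Config values):

    from torch.utils.data import DataLoader
    from data.dataLoader import ImageCaptionDataset

    dataset = ImageCaptionDataset(
        caption_file='../Datasets/Flickr8K/captions.csv',
        file_path='../Datasets/Flickr8K/images/',
    )
    loader = DataLoader(dataset, batch_size=4, shuffle=True)

    images, captions = next(iter(loader))
    print(images.shape)   # torch.Size([4, 3, 224, 224])
    print(len(captions))  # 4 caption strings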
 
 
 
 
 
 
 
main.py DELETED
@@ -1,14 +0,0 @@
-import numpy as np
-import os
-import cv2
-from PIL import Image
-from matplotlib import pyplot as plt
-
-from config.config import Config
-from data.dataLoader import dataLoader
-
-if __name__ == '__main__':
-    dl = dataLoader(Config.DATASET_PATH)
-    images, captions = dl.get_all()
-    print('Number of images:', len(images))
-    print('Number of captions:', len(captions))
models/__pycache__/model.cpython-310.pyc ADDED
Binary file (2.69 kB).
 
models/model.py ADDED
@@ -0,0 +1,56 @@
+import torch
+from transformers import ViTModel, ViTFeatureExtractor, GPT2LMHeadModel, GPT2Tokenizer
+from config.config import Config
+from torchsummary import summary
+from torchvision import transforms
+
+class ImageCaptioningModel:
+    def __init__(self):
+        """Initialize the ViT and GPT-2 models for image captioning."""
+        self.device = Config.DEVICE
+        self.vit_model = ViTModel.from_pretrained(Config.VIT_MODEL).to(self.device)
+        self.feature_extractor = ViTFeatureExtractor.from_pretrained(Config.VIT_MODEL)
+        self.gpt2_model = GPT2LMHeadModel.from_pretrained(Config.GPT2_MODEL).to(self.device)
+        self.tokenizer = GPT2Tokenizer.from_pretrained(Config.GPT2_MODEL)
+        self.tokenizer.pad_token = self.tokenizer.eos_token  # GPT-2 has no pad token by default
+
+    def extract_image_features(self, images):
+        """Extract features from images using ViT."""
+        # do_rescale=False because the dataset's ToTensor() already scales pixels to [0, 1]
+        pixel_values = self.feature_extractor(images=images, return_tensors="pt", do_rescale=False).pixel_values.to(self.device)
+        with torch.no_grad():
+            outputs = self.vit_model(pixel_values)
+        return outputs.last_hidden_state[:, 0, :]  # CLS embedding: [batch_size, hidden_size]
+
+    def prepare_gpt2_inputs(self, image_features, captions):
+        """Prepare GPT-2 inputs: prepend the image feature as a pseudo-token."""
+        # Tokenize the captions
+        tokenized_captions = self.tokenizer(captions, padding="longest", truncation=True,
+                                            max_length=Config.MAX_SEQ_LEN, return_tensors="pt").to(self.device)
+
+        # Get the word embeddings for the tokens
+        token_embeddings = self.gpt2_model.transformer.wte(tokenized_captions['input_ids'])
+
+        # Concatenate image features with token embeddings
+        image_features = image_features.unsqueeze(1)  # Reshape to [batch_size, 1, hidden_size]
+        inputs_embeds = torch.cat((image_features, token_embeddings), dim=1)  # Concatenate along the sequence dimension
+
+        # Adjust input_ids to account for the image feature token (BOS stands in for the image position)
+        batch_size = image_features.shape[0]
+        image_token_id = torch.full((batch_size, 1), fill_value=self.tokenizer.bos_token_id, device=self.device)
+        input_ids = torch.cat((image_token_id, tokenized_captions['input_ids']), dim=1)
+
+        # Adjust attention_mask to account for the image feature token
+        image_attention = torch.ones((batch_size, 1), device=self.device)
+        attention_mask = torch.cat((image_attention, tokenized_captions['attention_mask']), dim=1)
+
+        return inputs_embeds, input_ids, attention_mask
+
+    def save(self, path):
+        """Save model to disk."""
+        self.gpt2_model.save_pretrained(path)
+        self.tokenizer.save_pretrained(path)
+
+    def load(self, path):
+        """Load model from disk."""
+        self.gpt2_model = GPT2LMHeadModel.from_pretrained(path).to(self.device)
+        self.tokenizer = GPT2Tokenizer.from_pretrained(path)  # tokenizers are plain Python objects; no .to(device)
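
The commit trains but never decodes, so there is no caption-generation path yet. A rough greedy-decoding sketch of how these pieces could be used at inference time (a hypothetical helper, not part of the commit; it assumes ViT and GPT-2 share a hidden size, which holds for the base variants of both):

    import torch

    @torch.no_grad()
    def greedy_caption(model, image_tensor, max_len=30):
        # 'model' is an ImageCaptioningModel; 'image_tensor' is one [3, 224, 224]
        # sample in [0, 1], e.g. straight from ImageCaptionDataset.
        # Seed the sequence with the ViT CLS feature: shape [1, 1, hidden_size].
        embeds = model.extract_image_features(image_tensor.unsqueeze(0)).unsqueeze(1)
        token_ids = []
        for _ in range(max_len):
            logits = model.gpt2_model(inputs_embeds=embeds).logits
            next_id = logits[:, -1, :].argmax(dim=-1)  # greedy pick of the next token
            if next_id.item() == model.tokenizer.eos_token_id:
                break
            token_ids.append(next_id.item())
            # Embed the chosen token and extend the input sequence
            next_embed = model.gpt2_model.transformer.wte(next_id).unsqueeze(1)
            embeds = torch.cat((embeds, next_embed), dim=1)
        return model.tokenizer.decode(token_ids)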
train.py ADDED
@@ -0,0 +1,78 @@
+import torch
+import numpy as np
+from tqdm import tqdm
+
+from torch.utils.data import DataLoader
+from data.dataLoader import ImageCaptionDataset
+from config.config import Config
+from models.model import ImageCaptioningModel
+
+from torchsummary import summary
+
+
+def train_model(model, dataLoader, optimizer, loss_fn):
+    model.gpt2_model.train()
+    for epoch in range(Config.EPOCHS):
+        epoch_loss = 0
+        for batch_idx, (images, captions) in tqdm(enumerate(dataLoader)):
+            print(f'\rBatch {batch_idx + 1}/{len(dataLoader)} , Loss : {epoch_loss/(batch_idx+1):.4f}\t', end='')
+            images = images.to(Config.DEVICE)
+            captions = list(captions)
+
+            # Extract image features with ViT
+            image_features = model.extract_image_features(images)
+            input_embeds, input_ids, attention_mask = model.prepare_gpt2_inputs(image_features, captions)
+
+            # Image token, text tokens and attention mask must line up in sequence length
+            assert input_embeds.shape[1] == input_ids.shape[1] == attention_mask.shape[1]
+
+            optimizer.zero_grad()
+            # GPT-2 computes the language-modelling loss internally from 'labels'
+            outputs = model.gpt2_model(inputs_embeds=input_embeds, labels=input_ids, attention_mask=attention_mask)
+
+            loss = outputs.loss
+            loss.backward()
+            optimizer.step()
+
+            epoch_loss += loss.item()
+
+        print(f'Epoch {epoch + 1}, Loss: {epoch_loss:.4f}')
+
+    # Save the model
+    model.save('model')
+
+
+if __name__ == '__main__':
+    # Initialize dataset using the CSV file
+    dataset = ImageCaptionDataset(
+        caption_file=Config.DATASET_PATH + 'captions.csv',  # Path to captions CSV file
+        file_path=Config.DATASET_PATH + 'images/',          # Path to images folder
+    )
+
+    # Create DataLoader for batch processing
+    dataloader = DataLoader(
+        dataset,
+        batch_size=Config.BATCH_SIZE,  # Batch size from config
+        shuffle=True,                  # Shuffle the data
+        num_workers=4                  # Number of subprocesses for data loading
+    )
+
+    # Initialize the model and optimizer; loss_fn is unused because
+    # GPT-2 returns its own loss when 'labels' are supplied.
+    model = ImageCaptioningModel()
+    optimizer = torch.optim.Adam(model.gpt2_model.parameters(), lr=Config.LEARNING_RATE)
+    loss_fn = torch.nn.CrossEntropyLoss()
+    train_model(model, dataloader, optimizer, loss_fn)
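
With these files in place, training should start from the repository root, so that config.config, data.dataLoader and models.model resolve as packages:

    python train.py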