VenkateshRoshan committed
Commit 3138612 · 1 Parent(s): be7ebcc
Training script is now working
Browse files
- config/__pycache__/config.cpython-310.pyc  +0 -0
- config/config.py  +2 -1
- data/__pycache__/dataLoader.cpython-310.pyc  +0 -0
- data/dataLoader.py  +53 -46
- main.py  +0 -14
- models/__pycache__/model.cpython-310.pyc  +0 -0
- models/model.py  +56 -0
- train.py  +78 -0
config/__pycache__/config.cpython-310.pyc
CHANGED
Binary files a/config/__pycache__/config.cpython-310.pyc and b/config/__pycache__/config.cpython-310.pyc differ
config/config.py
CHANGED
@@ -8,4 +8,5 @@ class Config:
     EPOCHS = 10
     DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
     AWS_S3_BUCKET = 'your-s3-bucket-name'
-    DATASET_PATH = '../Datasets/Flickr8K/'
+    DATASET_PATH = '../Datasets/Flickr8K/'
+    BATCH_SIZE = 32
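model.py and train.py below also read Config.VIT_MODEL, Config.GPT2_MODEL, Config.MAX_SEQ_LEN and Config.LEARNING_RATE, which fall outside this hunk. A minimal sketch of what the full Config plausibly looks like; the model names and numeric values marked as assumptions are illustrative, not taken from this commit:

    import torch

    class Config:
        VIT_MODEL = 'google/vit-base-patch16-224'  # assumption: any ViT checkpoint name
        GPT2_MODEL = 'gpt2'                        # assumption
        MAX_SEQ_LEN = 128                          # assumption
        LEARNING_RATE = 1e-4                       # assumption
        EPOCHS = 10
        DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
        AWS_S3_BUCKET = 'your-s3-bucket-name'
        DATASET_PATH = '../Datasets/Flickr8K/'
        BATCH_SIZE = 32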
data/__pycache__/dataLoader.cpython-310.pyc
CHANGED
Binary files a/data/__pycache__/dataLoader.cpython-310.pyc and b/data/__pycache__/dataLoader.cpython-310.pyc differ
data/dataLoader.py
CHANGED
@@ -1,52 +1,59 @@
-import numpy as np
 import os
+import pandas as pd
 from PIL import Image
 from torchvision import transforms
+from torch.utils.data import Dataset
+
+class ImageCaptionDataset(Dataset):
+    """
+    Custom PyTorch Dataset class to handle loading and transforming image-caption pairs
+    where image paths and captions are provided in a CSV file.
+
+    Attributes:
+        caption_file (str): Path to the CSV file containing image paths and captions.
+        transform (torchvision.transforms.Compose): Transformations to apply on the images.
+    """

+    def __init__(self, caption_file: str, file_path: str, transform=None):
+        """
+        Initialize dataset with caption CSV file and optional transform.
+
+        Args:
+            caption_file (str): Path to the CSV file where each row has an image path and caption.
+            transform (callable, optional): Optional transform to apply on an image.
+        """
+        self.df = pd.read_csv(caption_file)
+        self.image_path = file_path
+        self.transform = transform or transforms.Compose([
+            transforms.Resize((224, 224)),  # Resize to 224x224 for ViT
+            transforms.ToTensor(),          # Convert to tensor
+            # Normalize to have values in the range [0, 1]
+            # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
         ])

-    def get_all(self):
-        images = []
-        captions = []
-        for img_name in self.img_list:
-            images.append(self.get_image(img_name))
-            captions.append(self.get_caption(img_name))
-        return images, captions
+    def __len__(self):
+        """
+        Return the total number of samples in the dataset.
+        """
+        return len(self.df)
+
+    def __getitem__(self, idx):
+        """
+        Retrieve an image and its corresponding caption by index.
+
+        Args:
+            idx (int): Index of the data item.
+
+        Returns:
+            tuple: (image, caption) where image is the transformed image tensor and caption is the associated text.
+        """
+        img_path = self.df.iloc[idx, 0]  # The first column contains image paths
+        caption = self.df.iloc[idx, 1]   # The second column contains captions
+        # Load image
+        image = Image.open(self.image_path + img_path).convert('RGB')
+
+        # Apply transformations to the image
+        if self.transform:
+            image = self.transform(image)
+
+        return image, caption
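The new ImageCaptionDataset expects a CSV whose first column holds image filenames (relative to file_path) and whose second column holds the caption text. A quick smoke test, assuming captions.csv and an images/ folder under Config.DATASET_PATH as train.py wires them up:

    from torch.utils.data import DataLoader
    from config.config import Config
    from data.dataLoader import ImageCaptionDataset

    dataset = ImageCaptionDataset(
        caption_file=Config.DATASET_PATH + 'captions.csv',
        file_path=Config.DATASET_PATH + 'images/',
    )
    image, caption = dataset[0]   # single (tensor, str) pair
    print(image.shape)            # torch.Size([3, 224, 224]) after Resize + ToTensor

    loader = DataLoader(dataset, batch_size=Config.BATCH_SIZE, shuffle=True)
    batch_images, batch_captions = next(iter(loader))
    print(batch_images.shape, len(batch_captions))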
main.py
DELETED
@@ -1,14 +0,0 @@
-import numpy as np
-import os
-import cv2
-from PIL import Image
-from matplotlib import pyplot as plt
-
-from config.config import Config
-from data.dataLoader import dataLoader
-
-if __name__ == '__main__':
-    dl = dataLoader(Config.DATASET_PATH)
-    images, captions = dl.get_all()
-    print('Number of images:', len(images))
-    print('Number of captions:', len(captions))
models/__pycache__/model.cpython-310.pyc
ADDED
Binary file (2.69 kB).
models/model.py
ADDED
@@ -0,0 +1,56 @@
+import torch
+from transformers import ViTModel, ViTFeatureExtractor, GPT2LMHeadModel, GPT2Tokenizer
+from config.config import Config
+from torchsummary import summary
+from torchvision import transforms
+
+class ImageCaptioningModel:
+    def __init__(self):
+        """Initialize the ViT and GPT-2 models for image captioning."""
+        self.device = Config.DEVICE
+        self.vit_model = ViTModel.from_pretrained(Config.VIT_MODEL).to(self.device)
+        self.feature_extractor = ViTFeatureExtractor.from_pretrained(Config.VIT_MODEL)
+        self.gpt2_model = GPT2LMHeadModel.from_pretrained(Config.GPT2_MODEL).to(self.device)
+        self.tokenizer = GPT2Tokenizer.from_pretrained(Config.GPT2_MODEL)
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+
+    def extract_image_features(self, images):
+        """Extract features from images using ViT."""
+        pixel_values = self.feature_extractor(images=images, return_tensors="pt", do_rescale=False).pixel_values.to(self.device)
+        with torch.no_grad():
+            outputs = self.vit_model(pixel_values)
+        return outputs.last_hidden_state[:, 0, :]  # [batch_size, hidden_size]
+
+    def prepare_gpt2_inputs(self, image_features, captions):
+        """Prepare GPT-2 inputs."""
+        # Tokenize the captions
+        tokenized_captions = self.tokenizer(captions, padding="longest", truncation=True,
+                                            max_length=Config.MAX_SEQ_LEN, return_tensors="pt").to(self.device)
+
+        # Get the word embeddings for the tokens
+        token_embeddings = self.gpt2_model.transformer.wte(tokenized_captions['input_ids'])
+
+        # Concatenate image features with token embeddings
+        image_features = image_features.unsqueeze(1)  # Reshape to [batch_size, 1, hidden_size]
+        inputs_embeds = torch.cat((image_features, token_embeddings), dim=1)  # Concatenate along the sequence dimension
+
+        # Adjust input_ids to account for the image feature token
+        batch_size = image_features.shape[0]
+        image_token_id = torch.full((batch_size, 1), fill_value=self.tokenizer.bos_token_id, device=self.device)
+        input_ids = torch.cat((image_token_id, tokenized_captions['input_ids']), dim=1)
+
+        # Adjust attention_mask to account for the image feature token
+        image_attention = torch.ones((batch_size, 1), device=self.device)
+        attention_mask = torch.cat((image_attention, tokenized_captions['attention_mask']), dim=1)
+
+        return inputs_embeds, input_ids, attention_mask
+
+    def save(self, path):
+        """Save model to disk."""
+        self.gpt2_model.save_pretrained(path)
+        self.tokenizer.save_pretrained(path)
+
+    def load(self, path):
+        """Load model from disk."""
+        self.gpt2_model = GPT2LMHeadModel.from_pretrained(path).to(self.device)
+        self.tokenizer = GPT2Tokenizer.from_pretrained(path).to(self.device)
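One caveat in load(): Hugging Face tokenizers are plain Python objects rather than torch modules, so GPT2Tokenizer.from_pretrained(path).to(self.device) will raise an AttributeError when called. A minimal corrected sketch, leaving the rest of the class as committed:

    def load(self, path):
        """Load model and tokenizer from disk; only the model is moved to the device."""
        self.gpt2_model = GPT2LMHeadModel.from_pretrained(path).to(self.device)
        self.tokenizer = GPT2Tokenizer.from_pretrained(path)  # tokenizers have no .to()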
train.py
ADDED
@@ -0,0 +1,78 @@
+import torch
+import numpy as np
+from tqdm import tqdm
+
+from torch.utils.data import DataLoader
+from data.dataLoader import ImageCaptionDataset
+from config.config import Config
+from models.model import ImageCaptioningModel
+
+from torchsummary import summary
+
+
+def train_model(model, dataLoader, optimizer, loss_fn):
+
+    model.gpt2_model.train()
+    for epoch in range(Config.EPOCHS):
+        epoch_loss = 0
+        for batch_idx, (images, captions) in tqdm(enumerate(dataLoader)):
+            print(f'\rBatch {batch_idx + 1}/{len(dataLoader)} , Loss : {epoch_loss/(batch_idx+1):.4f}\t', end='')
+            images = images.to(Config.DEVICE)
+            captions = [caption for caption in captions]
+
+            # extract image features
+            image_features = model.extract_image_features(images)
+            # print("Image Features shape:", image_features.shape)
+            input_embeds, input_ids, attention_mask = model.prepare_gpt2_inputs(image_features, captions)
+
+            # print("Input Embeds shape:", input_embeds.shape)
+            # print("Input IDs shape:", input_ids.shape)
+            # print("Attention Mask shape:", attention_mask.shape)
+            # Match Inputs Embeds and Input Ids and Attention Masks
+            assert input_embeds.shape[1] == input_ids.shape[1] == attention_mask.shape[1]
+
+            optimizer.zero_grad()
+            outputs = model.gpt2_model(inputs_embeds=input_embeds, labels=input_ids, attention_mask=attention_mask)
+
+            loss = outputs.loss
+            loss.backward()
+            optimizer.step()
+
+            epoch_loss += loss.item()
+
+        print(f'Epoch {epoch + 1}, Loss: {epoch_loss:.4f}')
+
+    # Save the model
+    model.save('model')
+
+    # return model
+
+
+if __name__ == '__main__':
+    # Initialize dataset using the CSV file
+    dataset = ImageCaptionDataset(
+        caption_file=Config.DATASET_PATH + 'captions.csv',  # Path to captions CSV file
+        file_path=Config.DATASET_PATH + '/images/',         # Path to images folder
+    )
+
+    # Create DataLoader for batch processing
+    dataloader = DataLoader(
+        dataset,
+        batch_size=Config.BATCH_SIZE,  # Specify the batch size
+        shuffle=True,                  # Shuffle the data
+        num_workers=4                  # Number of subprocesses for data loading
+    )
+
+    # # Iterate over the dataloader
+    # for batch_idx, (images, captions) in enumerate(dataloader):
+    #     print(f'Batch {batch_idx + 1}:')
+    #     print(f'Images shape: {images.shape}')
+    #     print(f'Captions: {captions}')
+    #     # Pass 'images' and 'captions' to your model for training/validation
+
+    # Initialize the ImageCaptioningModel
+    model = ImageCaptioningModel()
+    optimizer = torch.optim.Adam(model.gpt2_model.parameters(), lr=Config.LEARNING_RATE)
+    loss_fn = torch.nn.CrossEntropyLoss()
+    train_model(model, dataloader, optimizer, loss_fn)
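Because the tokenizer's pad_token is set to eos_token, padded caption positions currently count toward the loss as EOS targets. A common refinement, not part of this commit, is to mask those positions with the -100 ignore index before handing labels to GPT-2; a sketch against the tensors returned by prepare_gpt2_inputs:

    # inside the batch loop, after prepare_gpt2_inputs(...)
    labels = input_ids.clone()
    labels[attention_mask == 0] = -100  # positions labelled -100 are ignored by the LM loss
    outputs = model.gpt2_model(inputs_embeds=input_embeds,
                               labels=labels,
                               attention_mask=attention_mask)
    loss = outputs.loss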