import os import json import torch from PIL import Image from torch.utils.data import Dataset from transformers import AutoProcessor from torch.utils.data import DataLoader import pickle import requests from datasets import Dataset, load_dataset import pandas as pd import numpy as np class ClipDataset(Dataset): '''ClipDataset class for loading the CLIP dataset''' def __init__(self, coco_data, model_name, tokenizer): self.tokenizer = tokenizer self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) self.caption_dataset = coco_data def __len__(self): #Return the length of the dataset return len(self.caption_dataset) def __getitem__(self, idx): #Get the image url and caption img_url = self.caption_dataset[idx]["image_url"] caption = self.caption_dataset[idx]["caption"] #Get the image and caption embeddings image = Image.open(requests.get(img_url,stream=True).raw) width, height = image.size new_width = 224 new_height = new_width * height // width new_height = 224 new_width = new_height * width // height image = image.resize((new_width, new_height), Image.LANCZOS) image_processed = self.processor(images=image, return_tensors="pt") ['pixel_values'] image_sqeezed = image_processed.squeeze(0) tokenized_caption = self.tokenizer(caption, return_tensors="pt", return_attention_mask=False) tokenized_caption_ids = tokenized_caption['input_ids'].squeeze(0) return(image_sqeezed , tokenized_caption_ids) def collate_fn_phase1(batch): #Unzip the batch image_embeddings, captions = zip(*batch) #Stack the image embeddings image_embeddings_stacked = torch.stack(image_embeddings, dim=0) #Pad the captions, padded value is the token captions_padded = torch.nn.utils.rnn.pad_sequence(captions, batch_first=True, padding_value=50256) #Return the stacked image embeddings and padded captions return (image_embeddings_stacked, captions_padded) def get_data_loaders_phase1(data_dir, clip_model_name, tokenizer, train_batch_size, val_batch_size, num_workers): # Load the data with open(os.path.join(data_dir, 'coco_train.pkl'), 'rb') as fp: train_pkl = pickle.load(fp) with open(os.path.join(data_dir, "coco_val.pkl"), "rb") as fp: val_pkl = pickle.load(fp) # train data loaders train_dataloader = DataLoader(ClipDataset(train_pkl, clip_model_name, tokenizer), collate_fn=collate_fn_phase1, batch_size=train_batch_size, num_workers = num_workers, shuffle=True, pin_memory=True) # val data loaders val_dataloader = DataLoader(ClipDataset(val_pkl, clip_model_name, tokenizer), collate_fn=collate_fn_phase1, batch_size=val_batch_size, num_workers = num_workers, shuffle=False, pin_memory=True) return train_dataloader, val_dataloader ##################################### Phase 2 ######################################### class ClipDatasetPhase2(Dataset): '''ClipDataset class for loading the CLIP dataset''' def __init__(self, data_frame, model_name, tokenizer): self.tokenizer = tokenizer self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) self.df = data_frame def __len__(self): #Return the length of the dataset return len(self.df) def __getitem__(self, idx): #Get the image url and QAs img_url = self.df.ImageUrl[idx[0]] que = self.df.Question[idx[0]] ans = self.df.Answer[idx[0]] print("img_url", img_url) print("que", que) print("ans", ans) #Get the image and caption embeddings if img_url is None: print("img_url is None") image_sqeezed = None else: image = Image.open(requests.get(img_url,stream=True).raw) width, height = image.size new_width = 224 new_height = new_width * height // width new_height = 224 new_width = new_height * width // height image = image.resize((new_width, new_height), Image.LANCZOS) image_processed = self.processor(images=image, return_tensors="pt") ['pixel_values'] image_sqeezed = image_processed.squeeze(0) que_ids = self.tokenizer(que, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0) ans_ids = self.tokenizer(ans, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0) return(image_sqeezed , que_ids, ans_ids) def collate_fn_phase2(batch): #Unzip the batch image_embeddings, ques, ans = zip(*batch) #Stack the image embeddings if image_embeddings[0] is None: image_embeddings_stacked = None else: image_embeddings_stacked = torch.stack(image_embeddings, dim=0) #Pad the QAs, padded value is the token ques_padded = torch.nn.utils.rnn.pad_sequence(ques, batch_first=True, padding_value=50256) ans_padded = torch.nn.utils.rnn.pad_sequence(ans, batch_first=True, padding_value=50256) #Return the stacked image embeddings and padded QAs return (image_embeddings_stacked, ques_padded, ans_padded) def prep_data(df): df_assistant = df[(df.role == "assistant") & (df["rank"] == 0.0)].copy() df_prompter = df[(df.role == "prompter")].copy() df_prompter = df_prompter.set_index("message_id") df_assistant["Answer"] = df_assistant["text"].values inputs = [] for _, row in df_assistant.iterrows(): input = df_prompter.loc[row.parent_id] inputs.append(input.text) df_assistant["Question"] = inputs df_assistant["ImageUrl"] = None df_assistant = df_assistant[df_assistant.lang == "en"] df_assistant = df_assistant[ ["ImageUrl","Question", "Answer", "message_id"] ].rename(columns={"message_id": "Ids"}) return df_assistant def get_i150_df(config): with open(config.get("i150k_json"), "r") as fp: i150k_json_read = json.load(fp) max_tokens = 100 image_urls = [] ques_list = [] ans_list = [] id_list = [] for idx, data in enumerate(i150k_json_read): image = data['image'] image_url = 'http://images.cocodataset.org/train2017/' + image id_ = data["id"] iterator = iter(data['conversations']) for i in iterator: ques = i ans = next(iterator) if (len(ques["value"])>100 or len(ans["value"])>max_tokens): continue if ques["from"] == "human" and ans["from"] == "gpt": image_urls.append(image_url) ques_list.append(ques["value"].replace("\n","").replace("","")) ans_list.append(ans["value"]) id_list.append(id_) df_i150k = pd.DataFrame(list(zip(image_urls, ques_list, ans_list, id_list)), columns =["ImageUrl", "Question", "Answer", "Ids"]) msk = np.random.rand(len(df_i150k)) < 0.96 train_df = df_i150k[msk] test_df = df_i150k[~msk] return train_df, test_df def get_oas_df(config): train_ds, val_ds = load_dataset(config.get("QA_datasetName"), split=["train", "validation"]) train_df = prep_data(train_ds.to_pandas()) test_df = prep_data(val_ds.to_pandas()) return train_df, test_df def get_data_loaders_phase2(tokenizer, config): train_i150k, test_i150k = get_i150_df(config) train_oas, test_oas = get_oas_df(config) train_df = pd.concat([train_i150k, train_oas]).reset_index(drop=True) val_df = pd.concat([test_i150k, test_oas]).reset_index(drop=True) # train data loaders train_dataloader = DataLoader(ClipDatasetPhase2(train_df, config.get("clip_model_name"), tokenizer), collate_fn=collate_fn_phase2, batch_size=config.get("train_batch_size"), num_workers = config.get("num_workers"), shuffle=True, pin_memory=True) # val data loaders val_dataloader = DataLoader(ClipDatasetPhase2(val_df, config.get("clip_model_name"), tokenizer), collate_fn=collate_fn_phase2, batch_size=config.get("val_batch_size"), num_workers = config.get("num_workers"), shuffle=False, pin_memory=True) return train_dataloader, val_dataloader