import multiprocessing as mp
import pathlib
from typing import Any

import datasets
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from src import config
from src import tokenizer as tk


class CaptionDataset(Dataset):
    """Pairs each Hugging Face dataset record with its locally downloaded image."""

    def __init__(self, dataset: datasets.Dataset, img_path: pathlib.Path) -> None:
        self.dataset = dataset
        self.img_path = img_path

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        item = self.dataset[idx]
        # The local filename is the basename of the record's URL.
        image = Image.open(self.img_path / item["url"].rsplit("/", 1)[-1]).convert("RGB")
        return {"image": image, "caption": item["short_caption"]}
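

# Illustrative sketch (values made up, not from the dataset) of the record
# shape CaptionDataset assumes: a "url" whose basename names a file already
# present under img_path, and a "short_caption" string:
#
#   {"url": "https://example.com/imgs/000001.jpg", "short_caption": "a red bicycle"}
#
# for which __getitem__ opens img_path / "000001.jpg".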


class CollateFn:
    """Batches PIL images through the transform and captions through the tokenizer."""

    def __init__(self, tokenizer: tk.Tokenizer, transform: transforms.Compose):
        self.tokenizer = tokenizer
        self.transform = transform

    def __call__(self, batch: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
        stacked_images = torch.stack([self.transform(item["image"]) for item in batch])
        tokenized_text = self.tokenizer([item["caption"] for item in batch])

        return {
            "image": stacked_images,
            **tokenized_text,
        }
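

# What ends up in the batch besides "image" depends entirely on what
# tk.Tokenizer returns; for HF-style tokenizers that is typically a mapping
# with "input_ids" and "attention_mask" tensors, but that is an assumption
# here -- the actual keys are defined in src/tokenizer.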


def _get_dataloaders(
    train_ds: Dataset,
    valid_ds: Dataset,
    training_config: config.TrainerConfig,
    collate_fn: CollateFn,
) -> tuple[DataLoader, DataLoader]:
    common_params = {
        "batch_size": training_config.batch_size,
        "pin_memory": True,
        # Rough heuristic: leave most cores free for the training process.
        "num_workers": mp.cpu_count() // 3,
        "collate_fn": collate_fn,
    }
    train_loader = DataLoader(
        train_ds,
        shuffle=True,
        drop_last=True,  # keep every training batch the same size
        **common_params,
    )
    valid_loader = DataLoader(
        valid_ds,
        shuffle=False,
        drop_last=False,
        **common_params,
    )
    return train_loader, valid_loader
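

# Note: mp.cpu_count() // 3 evaluates to 0 on 1-2 core machines, which makes
# the DataLoader fall back to loading data in the main process. A defensive
# variant (a sketch, not part of this module) would be:
#
#   num_workers = max(1, mp.cpu_count() // 3)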


def get_dataset(
    transform: transforms.Compose,
    tokenizer: tk.Tokenizer,
    hyper_parameters: config.TrainerConfig,
) -> tuple[DataLoader, DataLoader]:
    dataset: datasets.Dataset = datasets.load_dataset(
        hyper_parameters._data_config.dataset, split="train"
    )
    # Hold out 10% of the training split for validation; the fixed seed keeps
    # the split reproducible across runs.
    train_test_dataset = dataset.train_test_split(seed=42, test_size=0.1)
    train_ds = CaptionDataset(train_test_dataset["train"], config.IMAGE_DOWNLOAD_PATH)
    valid_ds = CaptionDataset(train_test_dataset["test"], config.IMAGE_DOWNLOAD_PATH)
    collate_fn = CollateFn(tokenizer, transform)

    return _get_dataloaders(
        train_ds=train_ds,
        valid_ds=valid_ds,
        training_config=hyper_parameters,
        collate_fn=collate_fn,
    )


if __name__ == "__main__":
    import os

    from tqdm.auto import tqdm

    # Avoid the HF tokenizers fork warning once DataLoader workers spawn.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    hyper_parameters = config.TrainerConfig()
    transform = transforms.Compose([transforms.Resize((128, 128)), transforms.ToTensor()])
    tokenizer = tk.Tokenizer(
        hyper_parameters._model_config.text_model, hyper_parameters._model_config.max_len
    )
    train_dl, valid_dl = get_dataset(transform, tokenizer, hyper_parameters)

    # Smoke test: pull one batch and report tensor shapes.
    batch = next(iter(train_dl))
    print({k: v.shape for k, v in batch.items()})
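
    # With the 128x128 resize above, "image" is (batch_size, 3, 128, 128); the
    # remaining keys and shapes come from tk.Tokenizer (presumably token id /
    # mask tensors padded to max_len, though that depends on src/tokenizer).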

    # Iterate the full epoch once to exercise image decoding and collation.
    for batch in tqdm(train_dl):
        pass