tiny_clip / src /data.py
sachin's picture
Change where images weres stored
a8c8fe0
raw
history blame
3.35 kB
import multiprocessing as mp
import pathlib
from typing import Any
import datasets
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from src import config
from src import tokenizer as tk
class CaptionDatset(Dataset):
def __init__(self, dataset: datasets.Dataset, img_path: pathlib.Path) -> None:
self.dataset = dataset
self.img_path = img_path
def __len__(self) -> int:
return len(self.dataset)
def __getitem__(self, idx: int) -> dict[str, Any]:
item = self.dataset[idx]
image = Image.open(self.img_path / item["url"].rsplit("/", 1)[-1]).convert("RGB")
return {"image": image, "caption": item["short_caption"]}
class CollateFn:
def __init__(self, tokenizer: tk.Tokenizer, transform: transforms.Compose):
self.tokenizer = tokenizer
self.transform = transform
def __call__(self, batch: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
stacked_images = torch.stack([self.transform(item["image"]) for item in batch])
tokenized_text = self.tokenizer([item["caption"] for item in batch])
return {
"image": stacked_images,
**tokenized_text,
}
def _get_dataloaders(
train_ds: Dataset,
valid_ds: Dataset,
training_config: config.TrainerConfig,
collate_fn: CollateFn,
) -> tuple[DataLoader, DataLoader]:
common_params = {
"batch_size": training_config.batch_size,
"pin_memory": True,
"num_workers": mp.cpu_count() // 3,
"collate_fn": collate_fn,
}
train_loader = DataLoader(
train_ds,
shuffle=True,
drop_last=True,
**common_params,
)
valid_loader = DataLoader(
valid_ds,
shuffle=False,
drop_last=False,
**common_params,
)
return train_loader, valid_loader
def get_dataset(
transform: transforms.Compose,
tokenizer: tk.Tokenizer,
hyper_parameters: config.TrainerConfig,
) -> tuple[DataLoader, DataLoader]:
dataset: datasets.Dataset = datasets.load_dataset(
hyper_parameters._data_config.dataset, split="train"
) # type: ignore
train_test_dataset = dataset.train_test_split(seed=42, test_size=0.1)
train_ds = CaptionDatset(train_test_dataset["train"], config.IMAGE_DOWNLOAD_PATH)
valid_ds = CaptionDatset(train_test_dataset["test"], config.IMAGE_DOWNLOAD_PATH)
collate_fn = CollateFn(tokenizer, transform)
return _get_dataloaders(
train_ds=train_ds,
valid_ds=valid_ds,
training_config=hyper_parameters,
collate_fn=collate_fn,
)
if __name__ == "__main__":
# do not want to do these imports in general
import os
from tqdm.auto import tqdm
os.environ["TOKENIZERS_PARALLELISM"] = "false"
hyper_parameters = config.TrainerConfig()
transform = transforms.Compose([transforms.Resize((128, 128)), transforms.ToTensor()])
tokenizer = tk.Tokenizer(
hyper_parameters._model_config.text_model, hyper_parameters._model_config.max_len
)
train_dl, valid_dl = get_dataset(transform, tokenizer, hyper_parameters)
batch = next(iter(train_dl))
print({k: v.shape for k, v in batch.items()}) # torch.Size([1, 3, 128, 128])
for batch in tqdm(train_dl):
continue