CoAdapter / ldm /data /dataset_wikiart.py
MC-E
first push
c05d22e
raw
history blame
2.15 kB
import json
import os.path
from PIL import Image
from torch.utils.data import DataLoader
from transformers import CLIPProcessor
from torchvision.transforms import transforms
import pytorch_lightning as pl
class WikiArtDataset():
    """Image / CLIP-style / caption dataset built from a WikiArt JSON index.

    ``meta_file`` is a JSON list of image paths, each relative to ``root``.
    A caption is derived from every file name: the text after the last ``_``
    is split on ``-`` and trailing purely-numeric tokens are dropped, e.g.
    ``vincent-van-gogh_the-starry-night-1889.jpg`` -> ``"the starry night"``.
    Entries whose title consists only of numeric tokens are skipped.

    Each item is a dict with:
        ``'jpg'``: the image (converted to RGB), resized so its shorter side is
            ``image_size``, randomly cropped to ``image_size`` x ``image_size``
            and converted with ``transforms.ToTensor``.
        ``'style'``: ``pixel_values[0]`` from the CLIP processor for the same
            image.
        ``'txt'``: the derived caption string.

    Args:
        meta_file: path to the JSON index file.
        root: directory the paths in ``meta_file`` are relative to.
        clip_version: model name passed to ``CLIPProcessor.from_pretrained``.
        image_size: resize / crop size used for the ``'jpg'`` tensor.
    """

    def __init__(self, meta_file, root='datasets/wikiart',
                 clip_version='openai/clip-vit-large-patch14', image_size=512):
        super().__init__()
        with open(meta_file, 'r', encoding='utf-8') as f:
            img_paths = json.load(f)

        self.files = []
        for img_path in img_paths:
            sentence = self._caption_from_path(img_path)
            if sentence is None:
                # Title was entirely numeric: no usable caption, skip entry.
                continue
            self.files.append({'img_path': os.path.join(root, img_path), 'sentence': sentence})

        self.processor = CLIPProcessor.from_pretrained(clip_version)
        self.jpg_transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.RandomCrop(image_size),
            transforms.ToTensor(),
        ])

    @staticmethod
    def _caption_from_path(img_path):
        """Return the caption encoded in a WikiArt file name, or ``None``.

        ``None`` is returned when every ``-``-separated token of the title
        (the part after the last ``_``) is purely numeric.
        """
        img_name = os.path.splitext(os.path.basename(img_path))[0]
        tokens = img_name.split('_')[-1].split('-')
        # Drop trailing numeric tokens (presumably years / duplicate indices).
        while tokens and tokens[-1].isdigit():
            tokens.pop()
        if not tokens:
            return None
        return ' '.join(tokens)

    def __getitem__(self, idx):
        entry = self.files[idx]
        # Use a context manager so the underlying file handle is released, and
        # force RGB so grayscale/CMYK/paletted images still produce 3-channel
        # tensors that can be collated with the rest of the batch.
        with Image.open(entry['img_path']) as img:
            im = img.convert('RGB')
        im_tensor = self.jpg_transform(im)
        clip_im = self.processor(images=im, return_tensors="pt")['pixel_values'][0]
        return {'jpg': im_tensor, 'style': clip_im, 'txt': entry['sentence']}

    def __len__(self):
        return len(self.files)
class WikiArtDataModule(pl.LightningDataModule):
    """Lightning data module that serves a shuffled training loader over WikiArt.

    The underlying ``WikiArtDataset`` is built eagerly in the constructor.

    Args:
        meta_file: path to the JSON index, forwarded to ``WikiArtDataset``.
        batch_size: number of samples per training batch.
        num_workers: number of DataLoader worker processes.
    """

    def __init__(self, meta_file, batch_size, num_workers):
        super().__init__()
        self.train_dataset = WikiArtDataset(meta_file)
        self.batch_size = batch_size
        self.num_workers = num_workers

    def train_dataloader(self):
        """Return a shuffled DataLoader over the training set with pinned memory."""
        loader_options = {
            'batch_size': self.batch_size,
            'shuffle': True,
            'num_workers': self.num_workers,
            'pin_memory': True,
        }
        return DataLoader(self.train_dataset, **loader_options)