import json
import os
import random

import pandas as pd
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import cv2

# File suffixes (compared case-insensitively) treated as dataset images.
IMAGE_EXTENSIONS = ('.jpg', '.png')


def canny_processor(image, low_threshold=100, high_threshold=200):
    """Return a 3-channel PIL image of the Canny edges of *image*.

    Args:
        image: A PIL image (or anything ``np.array`` accepts) holding 8-bit
            pixel data, as ``cv2.Canny`` requires.
        low_threshold: Lower hysteresis threshold passed to ``cv2.Canny``.
        high_threshold: Upper hysteresis threshold passed to ``cv2.Canny``.

    Returns:
        A PIL image whose three channels are identical copies of the
        single-channel edge map produced by ``cv2.Canny``.
    """
    edges = cv2.Canny(np.array(image), low_threshold, high_threshold)
    # cv2.Canny yields one channel; replicate it so the hint has the same
    # channel count as the RGB image it accompanies.
    edges = edges[:, :, None]
    edges = np.concatenate([edges, edges, edges], axis=2)
    return Image.fromarray(edges)


def c_crop(image):
    """Return the largest centred square crop of a PIL *image*.

    Offsets are computed with integer division so the crop box is exactly
    ``min(width, height)`` pixels on each side. Float box coordinates would
    be rounded by Pillow, which can produce a crop one pixel off square.
    """
    width, height = image.size
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    return image.crop((left, top, left + side, top + side))


class CustomImageDataset(Dataset):
    """Dataset of ``(image, canny_hint, caption)`` samples from one directory.

    Every ``.jpg`` / ``.png`` file directly inside *img_dir* is a sample.
    Its caption is the ``caption`` field of a JSON file that shares the
    image's path minus its extension (``photo.jpg`` -> ``photo.json``).
    """

    def __init__(self, img_dir, img_size=512, max_retries=100):
        """
        Args:
            img_dir: Directory scanned (non-recursively) for image files.
            img_size: Side length in pixels of the square output tensors.
            max_retries: How many randomly chosen replacement samples
                ``__getitem__`` tries after a sample fails to load before
                raising.
        """
        self.images = sorted(
            os.path.join(img_dir, name)
            for name in os.listdir(img_dir)
            # Suffix match, so names like "x.jpg.json" are not mistaken
            # for images.
            if name.lower().endswith(IMAGE_EXTENSIONS)
        )
        self.img_size = img_size
        self.max_retries = max_retries

    def __len__(self):
        """Return the number of image files found in the directory."""
        return len(self.images)

    def _load_sample(self, idx):
        """Load sample *idx* as ``(img, hint, prompt)``; raises on any failure."""
        path = self.images[idx]
        # convert("RGB") forces three channels (grayscale / RGBA / palette
        # files would otherwise break the permute below or mismatch the
        # 3-channel hint). It also loads the pixels, so the file can be
        # closed right away.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        img = c_crop(img)
        img = img.resize((self.img_size, self.img_size))
        hint = canny_processor(img)

        # uint8 [0, 255] -> float64 [-1, 1], then HWC -> CHW.
        img = torch.from_numpy((np.array(img) / 127.5) - 1).permute(2, 0, 1)
        hint = torch.from_numpy((np.array(hint) / 127.5) - 1).permute(2, 0, 1)

        # splitext strips only the file extension; splitting the whole path
        # on "." would break for directories such as "./data" or "~/.cache".
        json_path = os.path.splitext(path)[0] + '.json'
        with open(json_path, encoding="utf-8") as f:
            prompt = json.load(f)['caption']
        return img, hint, prompt

    def __getitem__(self, idx):
        """Return ``(img, hint, prompt)`` for sample *idx*.

        ``img`` is the centre-cropped, resized RGB image and ``hint`` its
        Canny edge map. Both are float64 tensors of shape
        ``(3, img_size, img_size)`` with values in ``[-1, 1]``. ``prompt``
        is the value of the sample's JSON ``caption`` field.

        If a sample fails to load, the error is printed and a randomly
        chosen sample is tried instead, up to ``max_retries`` times.

        Raises:
            RuntimeError: if ``max_retries + 1`` consecutive loads fail.
        """
        attempts = self.max_retries + 1
        for _ in range(attempts):
            try:
                return self._load_sample(idx)
            except Exception as e:  # deliberate best-effort: skip bad samples
                print(f"Skipping sample {idx}: {e!r}")
                idx = random.randint(0, len(self.images) - 1)
        raise RuntimeError(f"Failed to load a sample after {attempts} attempts")


def loader(train_batch_size, num_workers, shuffle=False, **args):
    """Build a ``DataLoader`` over a ``CustomImageDataset``.

    Args:
        train_batch_size: Samples per batch.
        num_workers: Passed to ``DataLoader`` as ``num_workers``.
        shuffle: Reshuffle sample order every epoch. Defaults to ``False``
            (samples are visited in sorted filename order).
        **args: Forwarded to ``CustomImageDataset`` (``img_dir``,
            ``img_size``, ``max_retries``).
    """
    dataset = CustomImageDataset(**args)
    return DataLoader(
        dataset,
        batch_size=train_batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
    )