import random
import math
import os

import torch
import torch.utils.checkpoint
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
import numpy as np
from PIL import Image

from trainer_util import *
from clip_segmentation import ClipSeg

class bcolors:
    HEADER = '\033[95m'
    OKBLUE = '\033[94m'
    OKCYAN = '\033[96m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    FAIL = '\033[91m'
    ENDC = '\033[0m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

ASPECT_2048 = [[2048, 2048],
               [2112, 1984], [1984, 2112], [2176, 1920], [1920, 2176],
               [2240, 1856], [1856, 2240], [2304, 1792], [1792, 2304],
               [2368, 1728], [1728, 2368], [2432, 1664], [1664, 2432],
               [2496, 1600], [1600, 2496], [2560, 1536], [1536, 2560],
               [2624, 1472], [1472, 2624]]

ASPECT_1984 = [[1984, 1984],
               [2048, 1920], [1920, 2048], [2112, 1856], [1856, 2112],
               [2176, 1792], [1792, 2176], [2240, 1728], [1728, 2240],
               [2304, 1664], [1664, 2304], [2368, 1600], [1600, 2368],
               [2432, 1536], [1536, 2432], [2496, 1472], [1472, 2496],
               [2560, 1408], [1408, 2560]]

ASPECT_1920 = [[1920, 1920],
               [1984, 1856], [1856, 1984], [2048, 1792], [1792, 2048],
               [2112, 1728], [1728, 2112], [2176, 1664], [1664, 2176],
               [2240, 1600], [1600, 2240], [2304, 1536], [1536, 2304],
               [2368, 1472], [1472, 2368], [2432, 1408], [1408, 2432],
               [2496, 1344], [1344, 2496]]

ASPECT_1856 = [[1856, 1856],
               [1920, 1792], [1792, 1920], [1984, 1728], [1728, 1984],
               [2048, 1664], [1664, 2048], [2112, 1600], [1600, 2112],
               [2176, 1536], [1536, 2176], [2240, 1472], [1472, 2240],
               [2304, 1408], [1408, 2304], [2368, 1344], [1344, 2368],
               [2432, 1280], [1280, 2432]]

ASPECT_1792 = [[1792, 1792],
               [1856, 1728], [1728, 1856], [1920, 1664], [1664, 1920],
               [1984, 1600], [1600, 1984], [2048, 1536], [1536, 2048],
               [2112, 1472], [1472, 2112], [2176, 1408], [1408, 2176],
               [2240, 1344], [1344, 2240], [2304, 1280], [1280, 2304],
               [2368, 1216], [1216, 2368]]

ASPECT_1728 = [[1728, 1728],
               [1792, 1664], [1664, 1792], [1856, 1600], [1600, 1856],
               [1920, 1536], [1536, 1920], [1984, 1472], [1472, 1984],
               [2048, 1408], [1408, 2048], [2112, 1344], [1344, 2112],
               [2176, 1280], [1280, 2176], [2240, 1216], [1216, 2240],
               [2304, 1152], [1152, 2304]]

ASPECT_1664 = [[1664, 1664],
               [1728, 1600], [1600, 1728], [1792, 1536], [1536, 1792],
               [1856, 1472], [1472, 1856], [1920, 1408], [1408, 1920],
               [1984, 1344], [1344, 1984], [2048, 1280], [1280, 2048],
               [2112, 1216], [1216, 2112], [2176, 1152], [1152, 2176],
               [2240, 1088], [1088, 2240]]

ASPECT_1600 = [[1600, 1600],
               [1664, 1536], [1536, 1664], [1728, 1472], [1472, 1728],
               [1792, 1408], [1408, 1792], [1856, 1344], [1344, 1856],
               [1920, 1280], [1280, 1920], [1984, 1216], [1216, 1984],
               [2048, 1152], [1152, 2048], [2112, 1088], [1088, 2112],
               [2176, 1024], [1024, 2176]]

ASPECT_1536 = [[1536, 1536],
               [1600, 1472], [1472, 1600], [1664, 1408], [1408, 1664],
               [1728, 1344], [1344, 1728], [1792, 1280], [1280, 1792],
               [1856, 1216], [1216, 1856], [1920, 1152], [1152, 1920],
               [1984, 1088], [1088, 1984], [2048, 1024], [1024, 2048],
               [2112, 960], [960, 2112]]

ASPECT_1472 = [[1472, 1472],
               [1536, 1408], [1408, 1536], [1600, 1344], [1344, 1600],
               [1664, 1280], [1280, 1664], [1728, 1216], [1216, 1728],
               [1792, 1152], [1152, 1792], [1856, 1088], [1088, 1856],
               [1920, 1024], [1024, 1920], [1984, 960], [960, 1984],
               [2048, 896], [896, 2048]]

ASPECT_1408 = [[1408, 1408],
               [1472, 1344], [1344, 1472], [1536, 1280], [1280, 1536],
               [1600, 1216], [1216, 1600], [1664, 1152], [1152, 1664],
               [1728, 1088], [1088, 1728], [1792, 1024], [1024, 1792],
               [1856, 960], [960, 1856], [1920, 896], [896, 1920],
               [1984, 832], [832, 1984]]

ASPECT_1344 = [[1344, 1344],
               [1408, 1280], [1280, 1408], [1472, 1216], [1216, 1472],
               [1536, 1152], [1152, 1536], [1600, 1088], [1088, 1600],
               [1664, 1024], [1024, 1664], [1728, 960], [960, 1728],
               [1792, 896], [896, 1792], [1856, 832], [832, 1856],
               [1920, 768], [768, 1920]]

ASPECT_1280 = [[1280, 1280],
               [1344, 1216], [1216, 1344], [1408, 1152], [1152, 1408],
               [1472, 1088], [1088, 1472], [1536, 1024], [1024, 1536],
               [1600, 960], [960, 1600], [1664, 896], [896, 1664],
               [1728, 832], [832, 1728], [1792, 768], [768, 1792],
               [1856, 704], [704, 1856]]

ASPECT_1216 = [[1216, 1216],
               [1280, 1152], [1152, 1280], [1344, 1088], [1088, 1344],
               [1408, 1024], [1024, 1408], [1472, 960], [960, 1472],
               [1536, 896], [896, 1536], [1600, 832], [832, 1600],
               [1664, 768], [768, 1664], [1728, 704], [704, 1728],
               [1792, 640], [640, 1792]]

ASPECT_1152 = [[1152, 1152],
               [1216, 1088], [1088, 1216], [1280, 1024], [1024, 1280],
               [1344, 960], [960, 1344], [1408, 896], [896, 1408],
               [1472, 832], [832, 1472], [1536, 768], [768, 1536],
               [1600, 704], [704, 1600], [1664, 640], [640, 1664],
               [1728, 576], [576, 1728]]

ASPECT_1088 = [[1088, 1088],
               [1152, 1024], [1024, 1152], [1216, 960], [960, 1216],
               [1280, 896], [896, 1280], [1344, 832], [832, 1344],
               [1408, 768], [768, 1408], [1472, 704], [704, 1472],
               [1536, 640], [640, 1536], [1600, 576], [576, 1600],
               [1664, 512], [512, 1664]]

ASPECT_832 = [[832, 832],
              [896, 768], [768, 896], [960, 704], [704, 960],
              [1024, 640], [640, 1024], [1152, 576], [576, 1152],
              [1280, 512], [512, 1280], [1344, 512], [512, 1344],
              [1408, 448], [448, 1408], [1472, 448], [448, 1472],
              [1536, 384], [384, 1536], [1600, 384], [384, 1600]]

ASPECT_896 = [[896, 896],
              [960, 832], [832, 960], [1024, 768], [768, 1024],
              [1088, 704], [704, 1088], [1152, 704], [704, 1152],
              [1216, 640], [640, 1216], [1280, 640], [640, 1280],
              [1344, 576], [576, 1344], [1408, 576], [576, 1408],
              [1472, 512], [512, 1472], [1536, 512], [512, 1536],
              [1600, 448], [448, 1600], [1664, 448], [448, 1664]]

ASPECT_960 = [[960, 960],
              [1024, 896], [896, 1024], [1088, 832], [832, 1088],
              [1152, 768], [768, 1152], [1216, 704], [704, 1216],
              [1280, 640], [640, 1280], [1344, 576], [576, 1344],
              [1408, 512], [512, 1408], [1472, 448], [448, 1472],
              [1536, 384], [384, 1536]]

ASPECT_1024 = [[1024, 1024],
               [1088, 960], [960, 1088], [1152, 896], [896, 1152],
               [1216, 832], [832, 1216], [1344, 768], [768, 1344],
               [1472, 704], [704, 1472], [1600, 640], [640, 1600],
               [1728, 576], [576, 1728], [1792, 576], [576, 1792]]

ASPECT_768 = [[768, 768],
              [896, 640], [640, 896], [832, 704], [704, 832],
              [960, 576], [576, 960], [1024, 576], [576, 1024],
              [1088, 512], [512, 1088], [1152, 512], [512, 1152],
              [1216, 448], [448, 1216], [1280, 448], [448, 1280],
              [1344, 384], [384, 1344], [1408, 384], [384, 1408],
              [1472, 320], [320, 1472], [1536, 320], [320, 1536]]

ASPECT_704 = [[704, 704],
              [768, 640], [640, 768], [832, 576], [576, 832],
              [896, 512], [512, 896], [960, 512], [512, 960],
              [1024, 448], [448, 1024], [1088, 448], [448, 1088],
              [1152, 384], [384, 1152], [1216, 384], [384, 1216],
              [1280, 384], [384, 1280], [1280, 320], [320, 1280],
              [1408, 320], [320, 1408], [1536, 320], [320, 1536]]

ASPECT_640 = [[640, 640],
              [704, 576], [576, 704], [768, 512], [512, 768],
              [896, 448], [448, 896], [1024, 384], [384, 1024],
              [1280, 320], [320, 1280], [1408, 256], [256, 1408],
              [1472, 256], [256, 1472], [1536, 256], [256, 1536],
              [1600, 256], [256, 1600]]

ASPECT_576 = [[576, 576],
              [640, 512], [512, 640], [640, 448], [448, 640],
              [704, 448], [448, 704], [832, 384], [384, 832],
              [1024, 320], [320, 1024], [1280, 256], [256, 1280]]

ASPECT_512 = [[512, 512],
              [576, 448], [448, 576], [640, 384], [384, 640],
              [768, 320], [320, 768], [832, 256], [256, 832],
              [896, 256], [256, 896], [960, 256], [256, 960],
              [1024, 256], [256, 1024]]

ASPECT_448 = [[448, 448],
              [512, 384], [384, 512], [576, 320], [320, 576],
              [768, 256], [256, 768]]

ASPECT_384 = [[384, 384],
              [448, 320], [320, 448], [576, 256], [256, 576],
              [768, 192], [192, 768]]

ASPECT_320 = [[320, 320],
              [384, 256], [256, 384], [512, 192], [192, 512]]

ASPECT_256 = [[256, 256],
              [320, 192], [192, 320], [512, 128], [128, 512]]

ASPECTS = ASPECT_512

def get_aspect_buckets(resolution, mode=''):
    if resolution < 256:
        raise ValueError("Resolution must be at least 256")
    # Round the requested resolution down to the nearest multiple of 64 and look
    # up the matching bucket table defined above.
    rounded_resolution = int(resolution / 64) * 64
    try:
        print(f" {bcolors.WARNING} Rounded resolution to: {rounded_resolution}{bcolors.ENDC}")
        all_image_sizes = __get_all_aspects()
        if mode == 'MJ':
            all_image_sizes = [x[0:3] for x in all_image_sizes]
        aspects = next(filter(lambda sizes: sizes[0][0] == rounded_resolution, all_image_sizes), None)
        ASPECTS = aspects
        return aspects
    except Exception as e:
        print(f" {bcolors.FAIL} *** Could not find selected resolution: {rounded_resolution}{bcolors.ENDC}")
        raise e

def __get_all_aspects():
    return [ASPECT_256, ASPECT_320, ASPECT_384, ASPECT_448, ASPECT_512, ASPECT_576, ASPECT_640, ASPECT_704,
            ASPECT_768, ASPECT_832, ASPECT_896, ASPECT_960, ASPECT_1024, ASPECT_1088, ASPECT_1152, ASPECT_1216,
            ASPECT_1280, ASPECT_1344, ASPECT_1408, ASPECT_1472, ASPECT_1536, ASPECT_1600, ASPECT_1664,
            ASPECT_1728, ASPECT_1792, ASPECT_1856, ASPECT_1920, ASPECT_1984, ASPECT_2048]

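# Example (illustrative): get_aspect_buckets(768) rounds the value to 768 and returns
# ASPECT_768, i.e. the bucket list whose first entry is the square [768, 768]; every
# other pair is a wider or taller resolution with a similar total pixel count.
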
class AutoBucketing(Dataset):
    def __init__(self,
                 concepts_list,
                 tokenizer=None,
                 flip_p=0.0,
                 repeats=1,
                 debug_level=0,
                 batch_size=1,
                 set='val',
                 resolution=512,
                 center_crop=False,
                 use_image_names_as_captions=True,
                 shuffle_captions=False,
                 add_class_images_to_dataset=None,
                 balance_datasets=False,
                 crop_jitter=20,
                 with_prior_loss=False,
                 use_text_files_as_captions=False,
                 aspect_mode='dynamic',
                 action_preference='dynamic',
                 seed=555,
                 model_variant='base',
                 extra_module=None,
                 mask_prompts=None,
                 load_mask=False,
                 ):

        self.debug_level = debug_level
        self.resolution = resolution
        self.center_crop = center_crop
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.concepts_list = concepts_list
        self.use_image_names_as_captions = use_image_names_as_captions
        self.shuffle_captions = shuffle_captions
        self.num_train_images = 0
        self.num_reg_images = 0
        self.image_train_items = []
        self.image_reg_items = []
        self.add_class_images_to_dataset = add_class_images_to_dataset
        self.balance_datasets = balance_datasets
        self.crop_jitter = crop_jitter
        self.with_prior_loss = with_prior_loss
        self.use_text_files_as_captions = use_text_files_as_captions
        self.aspect_mode = aspect_mode
        self.action_preference = action_preference
        self.model_variant = model_variant
        self.extra_module = extra_module
        self.image_transforms = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )
        self.mask_transforms = transforms.Compose(
            [
                transforms.ToTensor(),
            ]
        )
        self.depth_image_transforms = transforms.Compose(
            [
                transforms.ToTensor(),
            ]
        )
        self.seed = seed

        print(f" {bcolors.WARNING}Creating Auto Bucketing Dataloader{bcolors.ENDC}")

        shared_dataloader = DataLoaderMultiAspect(concepts_list,
                                                  debug_level=debug_level,
                                                  resolution=self.resolution,
                                                  seed=self.seed,
                                                  batch_size=self.batch_size,
                                                  flip_p=flip_p,
                                                  use_image_names_as_captions=self.use_image_names_as_captions,
                                                  add_class_images_to_dataset=self.add_class_images_to_dataset,
                                                  balance_datasets=self.balance_datasets,
                                                  with_prior_loss=self.with_prior_loss,
                                                  use_text_files_as_captions=self.use_text_files_as_captions,
                                                  aspect_mode=self.aspect_mode,
                                                  action_preference=self.action_preference,
                                                  model_variant=self.model_variant,
                                                  extra_module=self.extra_module,
                                                  mask_prompts=mask_prompts,
                                                  load_mask=load_mask,
                                                  )

        if self.with_prior_loss and self.add_class_images_to_dataset == False:
            self.image_train_items, self.class_train_items = shared_dataloader.get_all_images()
            self.num_train_images = self.num_train_images + len(self.image_train_items)
            self.num_reg_images = self.num_reg_images + len(self.class_train_items)
            self._length = max(math.trunc(self.num_train_images * repeats), batch_size, math.trunc(self.num_reg_images * repeats)) - self.num_train_images % self.batch_size
            self.num_train_images = self.num_train_images + self.num_reg_images
        else:
            self.image_train_items = shared_dataloader.get_all_images()
            self.num_train_images = self.num_train_images + len(self.image_train_items)
            self._length = max(math.trunc(self.num_train_images * repeats), batch_size) - self.num_train_images % self.batch_size

        print()
        print(f" {bcolors.WARNING} ** Validation Set: {set}, steps: {self._length / batch_size:.0f}, repeats: {repeats} {bcolors.ENDC}")
        print()

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        idx = i % self.num_train_images
        image_train_item = self.image_train_items[idx]

        example = self.__get_image_for_trainer(image_train_item, debug_level=self.debug_level)
        if self.with_prior_loss and self.add_class_images_to_dataset == False:
            idx = i % self.num_reg_images
            class_train_item = self.class_train_items[idx]
            example_class = self.__get_image_for_trainer(class_train_item, debug_level=self.debug_level, class_img=True)
            example = {**example, **example_class}

        return example

    def normalize8(self, I):
        # Rescale an array to the full 0-255 range and cast to uint8.
        mn = I.min()
        mx = I.max()

        mx -= mn
        if mx == 0:
            # Guard against division by zero for constant images.
            mx = 1

        I = ((I - mn) / mx) * 255
        return I.astype(np.uint8)

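    # Worked example (illustrative): for a float array with min -1.0 and max 1.0,
    # normalize8 maps -1.0 -> 0, 0.0 -> 127, and 1.0 -> 255 before casting to uint8.
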
    def __get_image_for_trainer(self, image_train_item, debug_level=0, class_img=False):
        example = {}
        save = debug_level > 2

        if class_img == False:
            image_train_tmp = image_train_item.hydrate(crop=False, save=0, crop_jitter=self.crop_jitter)
            image_train_tmp_image = Image.fromarray(self.normalize8(image_train_tmp.image)).convert("RGB")

            instance_prompt = image_train_tmp.caption
            if self.shuffle_captions:
                caption_parts = instance_prompt.split(",")
                random.shuffle(caption_parts)
                instance_prompt = ",".join(caption_parts)

            example["instance_images"] = self.image_transforms(image_train_tmp_image)
            if image_train_tmp.mask is not None:
                image_train_tmp_mask = Image.fromarray(self.normalize8(image_train_tmp.mask)).convert("L")
                example["mask"] = self.mask_transforms(image_train_tmp_mask)
            if self.model_variant == 'depth2img':
                image_train_tmp_depth = Image.fromarray(self.normalize8(image_train_tmp.extra)).convert("L")
                example["instance_depth_images"] = self.depth_image_transforms(image_train_tmp_depth)

            example["instance_prompt_ids"] = self.tokenizer(
                instance_prompt,
                padding="do_not_pad",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
            ).input_ids
            image_train_item.self_destruct()
            return example

        if class_img == True:
            image_train_tmp = image_train_item.hydrate(crop=False, save=4, crop_jitter=self.crop_jitter)
            image_train_tmp_image = Image.fromarray(self.normalize8(image_train_tmp.image)).convert("RGB")
            if self.model_variant == 'depth2img':
                image_train_tmp_depth = Image.fromarray(self.normalize8(image_train_tmp.extra)).convert("L")
                example["class_depth_images"] = self.depth_image_transforms(image_train_tmp_depth)
            example["class_images"] = self.image_transforms(image_train_tmp_image)
            example["class_prompt_ids"] = self.tokenizer(
                image_train_tmp.caption,
                padding="do_not_pad",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
            ).input_ids
            image_train_item.self_destruct()
            return example

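# Usage sketch (illustrative; "my_tokenizer" and the concrete paths/prompts are
# assumptions, not part of this module):
#   dataset = AutoBucketing(concepts_list=[{"instance_data_dir": "data/dog",
#                                           "instance_prompt": "photo of sks dog",
#                                           "class_data_dir": "data/dog_class",
#                                           "class_prompt": "photo of a dog"}],
#                           tokenizer=my_tokenizer, batch_size=4, resolution=768)
#   example = dataset[0]  # dict with "instance_images" and "instance_prompt_ids"
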
_RANDOM_TRIM = 0.04

class ImageTrainItem():
    """
    image: PIL Image
    mask: PIL Image
    extra: PIL Image (e.g. depth map)
    caption: caption/identifier string
    target_wh: (width, height) bucket this image is assigned to
    pathname: path to image file
    flip_p: probability of flipping image (0.0 to 1.0)
    """
    def __init__(self, image: Image, mask: Image, extra: Image, caption: str, target_wh: list, pathname: str, flip_p=0.0, model_variant='base', load_mask=False):
        self.caption = caption
        self.target_wh = target_wh
        self.pathname = pathname
        self.mask_pathname = os.path.splitext(pathname)[0] + "-masklabel.png"
        self.depth_pathname = os.path.splitext(pathname)[0] + "-depth.png"
        self.flip_p = flip_p
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)
        self.cropped_img = None
        self.model_variant = model_variant
        self.load_mask = load_mask
        self.is_dupe = []
        self.variant_warning = False

        self.image = image
        self.mask = mask
        self.extra = extra

    def self_destruct(self):
        # Release the pixel data once the example has been produced and remember
        # that this item has been served one more time.
        self.image = None
        self.mask = None
        self.extra = None
        self.cropped_img = None
        self.is_dupe.append(1)

    def load_image(self, pathname, crop, jitter_amount, flip):
        if len(self.is_dupe) > 0:
            self.flip = transforms.RandomHorizontalFlip(p=1.0 if flip else 0.0)
        image = Image.open(pathname).convert('RGB')

        width, height = image.size
        if crop:
            cropped_img = self.__autocrop(image)
            image = cropped_img.resize((512, 512), resample=Image.Resampling.LANCZOS)
        else:
            width, height = image.size

            if self.target_wh[0] == self.target_wh[1]:
                if width > height:
                    left = random.randint(0, width - height)
                    image = image.crop((left, 0, height + left, height))
                    width = height
                elif height > width:
                    top = random.randint(0, height - width)
                    image = image.crop((0, top, width, width + top))
                    height = width
                elif width > self.target_wh[0]:
                    slice = min(int(self.target_wh[0] * _RANDOM_TRIM), width - self.target_wh[0])
                    slicew_ratio = random.random()
                    left = int(slice * slicew_ratio)
                    right = width - int(slice * (1 - slicew_ratio))
                    sliceh_ratio = random.random()
                    top = int(slice * sliceh_ratio)
                    bottom = height - int(slice * (1 - sliceh_ratio))

                    image = image.crop((left, top, right, bottom))
            else:
                image_aspect = width / height
                target_aspect = self.target_wh[0] / self.target_wh[1]
                if image_aspect > target_aspect:
                    new_width = int(height * target_aspect)
                    jitter_amount = max(min(jitter_amount, int(abs(width - new_width) / 2)), 0)
                    left = jitter_amount
                    right = left + new_width
                    image = image.crop((left, 0, right, height))
                else:
                    new_height = int(width / target_aspect)
                    jitter_amount = max(min(jitter_amount, int(abs(height - new_height) / 2)), 0)
                    top = jitter_amount
                    bottom = top + new_height
                    image = image.crop((0, top, width, bottom))

            image = image.resize(self.target_wh, resample=Image.Resampling.LANCZOS)

        image = self.flip(image)
        return image

    def hydrate(self, crop=False, save=False, crop_jitter=20):
        """
        crop: hard center crop to 512x512
        save: save the cropped image to disk, for manual inspection of resize/crop
        crop_jitter: randomly shift the crop by up to N pixels when using multiple aspect ratios to improve training quality
        """
        if self.image is None:
            chance = float(len(self.is_dupe)) / 10.0

            flip_p = self.flip_p + chance if chance < 1.0 else 1.0
            flip = random.uniform(0, 1) < flip_p

            if len(self.is_dupe) > 0:
                crop_jitter = crop_jitter + (len(self.is_dupe) * 10) if crop_jitter < 50 else 50

            jitter_amount = random.randint(0, crop_jitter)

            self.image = self.load_image(self.pathname, crop, jitter_amount, flip)

            if self.model_variant == "inpainting" or self.load_mask:
                if os.path.exists(self.mask_pathname) and self.load_mask:
                    self.mask = self.load_image(self.mask_pathname, crop, jitter_amount, flip)
                else:
                    if self.variant_warning == False:
                        print(f" {bcolors.FAIL} ** Warning: No mask found for an image, using an empty mask but make sure you're training the right model variant.{bcolors.ENDC}")
                        self.variant_warning = True
                    self.mask = Image.new('RGB', self.image.size, color="white").convert("L")

            if self.model_variant == "depth2img":
                if os.path.exists(self.depth_pathname):
                    self.extra = self.load_image(self.depth_pathname, crop, jitter_amount, flip)
                else:
                    if self.variant_warning == False:
                        print(f" {bcolors.FAIL} ** Warning: No depth found for an image, using an empty depth but make sure you're training the right model variant.{bcolors.ENDC}")
                        self.variant_warning = True
                    self.extra = Image.new('RGB', self.image.size, color="white").convert("L")

        if type(self.image) is not np.ndarray:
            if save:
                base_name = os.path.basename(self.pathname)
                if not os.path.exists("test/output"):
                    os.makedirs("test/output")
                self.image.save(f"test/output/{base_name}")

            self.image = np.array(self.image).astype(np.uint8)
            self.image = (self.image / 127.5 - 1.0).astype(np.float32)
        if self.mask is not None and type(self.mask) is not np.ndarray:
            self.mask = np.array(self.mask).astype(np.uint8)
            self.mask = (self.mask / 255.0).astype(np.float32)
        if self.extra is not None and type(self.extra) is not np.ndarray:
            self.extra = np.array(self.extra).astype(np.uint8)
            self.extra = (self.extra / 255.0).astype(np.float32)

        return self

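# Sketch of the hydrate lifecycle (illustrative; the path and caption are assumptions):
#   item = ImageTrainItem(image=None, mask=None, extra=None, caption="photo of sks dog",
#                         target_wh=[768, 512], pathname="data/dog/001.jpg")
#   item.hydrate(crop_jitter=20)   # loads, crops to the bucket aspect, scales to [-1, 1]
#   item.self_destruct()           # frees the pixels and marks the item as served once
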
class CachedLatentsDataset(Dataset):

    def __init__(self, cache_paths=(), batch_size=None, tokenizer=None, text_encoder=None, dtype=None, model_variant='base', shuffle_per_epoch=False, args=None):
        self.cache_paths = cache_paths
        self.tokenizer = tokenizer
        self.args = args
        self.text_encoder = text_encoder

        text_encoder_device = next(self.text_encoder.parameters()).device
        self.empty_batch = [self.tokenizer('', padding="do_not_pad", truncation=True, max_length=self.tokenizer.model_max_length,).input_ids for i in range(batch_size)]

        if self.args.train_text_encoder != True:
            self.empty_tokens = tokenizer.pad({"input_ids": self.empty_batch}, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt",).to(text_encoder_device).input_ids
            self.empty_tokens.to(text_encoder_device, dtype=dtype)
            self.empty_tokens = self.text_encoder(self.empty_tokens)[0]
        else:
            self.empty_tokens = tokenizer.pad({"input_ids": self.empty_batch}, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt",).input_ids
            self.empty_tokens.to(text_encoder_device, dtype=dtype)

        self.conditional_dropout = args.conditional_dropout
        self.conditional_indexes = []
        self.model_variant = model_variant
        self.shuffle_per_epoch = shuffle_per_epoch

    def __len__(self):
        return len(self.cache_paths)

    def __getitem__(self, index):
        if index == 0:
            if self.shuffle_per_epoch == True:
                self.cache_paths = tuple(random.sample(self.cache_paths, len(self.cache_paths)))
            if len(self.cache_paths) > 1:
                possible_indexes_extension = None
                possible_indexes = list(range(0, len(self.cache_paths)))

                if self.conditional_dropout != None:
                    if len(self.conditional_indexes) == 0:
                        self.conditional_indexes = random.sample(possible_indexes, k=int(math.ceil(len(possible_indexes) * self.conditional_dropout)))
                    else:
                        possible_indexes_extension = [i for i in possible_indexes if i not in self.conditional_indexes]

                        possible_indexes_extension = possible_indexes_extension + possible_indexes_extension
                        possible_indexes_extension = possible_indexes_extension + self.conditional_indexes
                        self.conditional_indexes = random.sample(possible_indexes_extension, k=int(math.ceil(len(possible_indexes) * self.conditional_dropout)))

                        if len(self.conditional_indexes) != len(set(self.conditional_indexes)):
                            self.conditional_indexes_non_dupe = list(set(self.conditional_indexes))

                            for i in range(len(self.conditional_indexes) - len(self.conditional_indexes_non_dupe)):
                                while True:
                                    random_value = random.choice(possible_indexes_extension)
                                    if random_value not in self.conditional_indexes_non_dupe:
                                        self.conditional_indexes_non_dupe.append(random_value)
                                        break
                            self.conditional_indexes = self.conditional_indexes_non_dupe
        self.cache = torch.load(self.cache_paths[index])
        self.latents = self.cache.latents_cache[0]
        self.tokens = self.cache.tokens_cache[0]
        self.extra_cache = None
        self.mask_cache = None
        if self.cache.mask_cache is not None:
            self.mask_cache = self.cache.mask_cache[0]
        self.mask_mean_cache = None
        if self.cache.mask_mean_cache is not None:
            self.mask_mean_cache = self.cache.mask_mean_cache[0]
        if index in self.conditional_indexes:
            self.text_encoder = self.empty_tokens
        else:
            self.text_encoder = self.cache.text_encoder_cache[0]
        if self.model_variant != 'base':
            self.extra_cache = self.cache.extra_cache[0]
        del self.cache
        return self.latents, self.text_encoder, self.mask_cache, self.mask_mean_cache, self.extra_cache, self.tokens

    def add_pt_cache(self, cache_path):
        if len(self.cache_paths) == 0:
            self.cache_paths = (cache_path,)
        else:
            self.cache_paths += (cache_path,)

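# Conditional-dropout sketch (illustrative numbers): with 10 cached batches and
# args.conditional_dropout = 0.1, ceil(10 * 0.1) = 1 batch index is re-drawn at the
# start of each epoch; that batch is returned with the empty-prompt encoding so the
# model periodically trains unconditionally (useful for classifier-free guidance).
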
class LatentsDataset(Dataset):
    def __init__(self, latents_cache=None, text_encoder_cache=None, mask_cache=None, mask_mean_cache=None, extra_cache=None, tokens_cache=None):
        self.latents_cache = latents_cache
        self.text_encoder_cache = text_encoder_cache
        self.mask_cache = mask_cache
        self.mask_mean_cache = mask_mean_cache
        self.extra_cache = extra_cache
        self.tokens_cache = tokens_cache

    def add_latent(self, latent, text_encoder, cached_mask, cached_extra, tokens_cache):
        self.latents_cache.append(latent)
        self.text_encoder_cache.append(text_encoder)
        self.mask_cache.append(cached_mask)
        self.mask_mean_cache.append(None if cached_mask is None else cached_mask.mean())
        self.extra_cache.append(cached_extra)
        self.tokens_cache.append(tokens_cache)

    def __len__(self):
        return len(self.latents_cache)

    def __getitem__(self, index):
        return self.latents_cache[index], self.text_encoder_cache[index], self.mask_cache[index], self.mask_mean_cache[index], self.extra_cache[index], self.tokens_cache[index]

class DataLoaderMultiAspect():
    """
    Data loader for multi-aspect-ratio training and bucketing

    concept_list: list of concept dicts describing the training data
    batch_size: number of images per batch
    flip_p: probability of flipping image horizontally (i.e. 0-0.5)
    """
    def __init__(
            self,
            concept_list,
            seed=555,
            debug_level=0,
            resolution=512,
            batch_size=1,
            flip_p=0.0,
            use_image_names_as_captions=True,
            add_class_images_to_dataset=False,
            balance_datasets=False,
            with_prior_loss=False,
            use_text_files_as_captions=False,
            aspect_mode='dynamic',
            action_preference='add',
            model_variant='base',
            extra_module=None,
            mask_prompts=None,
            load_mask=False,
    ):
        self.resolution = resolution
        self.debug_level = debug_level
        self.flip_p = flip_p
        self.use_image_names_as_captions = use_image_names_as_captions
        self.balance_datasets = balance_datasets
        self.with_prior_loss = with_prior_loss
        self.add_class_images_to_dataset = add_class_images_to_dataset
        self.use_text_files_as_captions = use_text_files_as_captions
        self.aspect_mode = aspect_mode
        self.action_preference = action_preference
        self.seed = seed
        self.model_variant = model_variant
        self.extra_module = extra_module
        self.load_mask = load_mask
        prepared_train_data = []

        self.aspects = get_aspect_buckets(resolution)

        print(f" {bcolors.WARNING} Preloading images...{bcolors.ENDC}")

        if balance_datasets:
            print(f" {bcolors.WARNING} Balancing datasets...{bcolors.ENDC}")

            # Count the images in every concept so each one can be capped at the
            # size of the smallest concept (unless it opts out via 'do_not_balance').
            for concept in concept_list:
                count = 0
                if 'use_sub_dirs' in concept:
                    if concept['use_sub_dirs'] == 1:
                        tot = 0
                        for root, dirs, files in os.walk(concept['instance_data_dir']):
                            tot += len(files)
                        count = tot
                    else:
                        count = len(os.listdir(concept['instance_data_dir']))
                else:
                    count = len(os.listdir(concept['instance_data_dir']))
                print(f"{concept['instance_data_dir']} has count of {count}")
                concept['count'] = count

            min_concept = min(concept_list, key=lambda x: x['count'])
            min_concept_num_images = min_concept['count']
            print(" Min concept: ", min_concept['instance_data_dir'], " with ", min_concept_num_images, " images")

            balance_concept_list = []
            for concept in concept_list:
                if 'do_not_balance' in concept:
                    if concept['do_not_balance'] == True:
                        balance_concept_list.append(-1)
                    else:
                        balance_concept_list.append(min_concept_num_images)
                else:
                    balance_concept_list.append(min_concept_num_images)

        for concept in concept_list:
            if 'use_sub_dirs' in concept:
                if concept['use_sub_dirs'] == True:
                    use_sub_dirs = True
                else:
                    use_sub_dirs = False
            else:
                use_sub_dirs = False
            self.image_paths = []

            min_concept_num_images = None
            if balance_datasets:
                min_concept_num_images = balance_concept_list[concept_list.index(concept)]
            data_root = concept['instance_data_dir']
            data_root_class = concept['class_data_dir']
            concept_prompt = concept['instance_prompt']
            concept_class_prompt = concept['class_prompt']
            if 'flip_p' in concept.keys():
                flip_p = concept['flip_p']
                if flip_p == '':
                    flip_p = 0.0
                else:
                    flip_p = float(flip_p)
            self.__recurse_data_root(self=self, recurse_root=data_root, use_sub_dirs=use_sub_dirs)
            random.Random(self.seed).shuffle(self.image_paths)
            if self.model_variant == 'depth2img':
                print(f" {bcolors.WARNING} ** Loading Depth2Img Pipeline To Process Dataset{bcolors.ENDC}")
                self.vae_scale_factor = self.extra_module.depth_images(self.image_paths)
            prepared_train_data.extend(self.__prescan_images(debug_level, self.image_paths, flip_p, use_image_names_as_captions, concept_prompt, use_text_files_as_captions=self.use_text_files_as_captions)[0:min_concept_num_images])
            if add_class_images_to_dataset:
                self.image_paths = []
                self.__recurse_data_root(self=self, recurse_root=data_root_class, use_sub_dirs=use_sub_dirs)
                random.Random(self.seed).shuffle(self.image_paths)
                use_image_names_as_captions = False
                prepared_train_data.extend(self.__prescan_images(debug_level, self.image_paths, flip_p, use_image_names_as_captions, concept_class_prompt, use_text_files_as_captions=self.use_text_files_as_captions))

        self.image_caption_pairs = self.__bucketize_images(prepared_train_data, batch_size=batch_size, debug_level=debug_level, aspect_mode=self.aspect_mode, action_preference=self.action_preference)
        if self.with_prior_loss and add_class_images_to_dataset == False:
            self.class_image_caption_pairs = []
            for concept in concept_list:
                self.class_images_path = []
                data_root_class = concept['class_data_dir']
                concept_class_prompt = concept['class_prompt']
                self.__recurse_data_root(self=self, recurse_root=data_root_class, use_sub_dirs=use_sub_dirs, class_images=True)
                random.Random(seed).shuffle(self.image_paths)
                if self.model_variant == 'depth2img':
                    print(f" {bcolors.WARNING} ** Depth2Img To Process Class Dataset{bcolors.ENDC}")
                    self.vae_scale_factor = self.extra_module.depth_images(self.image_paths)
                use_image_names_as_captions = False
                self.class_image_caption_pairs.extend(self.__prescan_images(debug_level, self.class_images_path, flip_p, use_image_names_as_captions, concept_class_prompt, use_text_files_as_captions=self.use_text_files_as_captions))
            self.class_image_caption_pairs = self.__bucketize_images(self.class_image_caption_pairs, batch_size=batch_size, debug_level=debug_level, aspect_mode=self.aspect_mode, action_preference=self.action_preference)
        if mask_prompts is not None:
            print(f" {bcolors.WARNING} Checking and generating missing masks...{bcolors.ENDC}")
            clip_seg = ClipSeg()
            clip_seg.mask_images(self.image_paths, mask_prompts)
            del clip_seg
        if debug_level > 0:
            print(f" * DLMA Example: {self.image_caption_pairs[0]} images")

        print(f" {bcolors.WARNING} Number of image-caption pairs: {len(self.image_caption_pairs)}{bcolors.ENDC}")
        if len(self.image_caption_pairs) == 0:
            raise Exception("All the buckets are empty. Please check your data or reduce the batch size.")

    def get_all_images(self):
        if self.with_prior_loss == False:
            return self.image_caption_pairs
        else:
            return self.image_caption_pairs, self.class_image_caption_pairs

    def __prescan_images(self, debug_level: int, image_paths: list, flip_p=0.0, use_image_names_as_captions=True, concept=None, use_text_files_as_captions=False):
        """
        Create ImageTrainItem objects with metadata for hydration later
        """
        decorated_image_train_items = []

        for pathname in image_paths:
            identifier = concept
            # Always derive the filename-based caption so it is available as a
            # fallback even when captions are read from .txt files.
            caption_from_filename = os.path.splitext(os.path.basename(pathname))[0].split("_")[0]
            if use_image_names_as_captions:
                identifier = caption_from_filename
            if use_text_files_as_captions:
                txt_file_path = os.path.splitext(pathname)[0] + ".txt"

                if os.path.exists(txt_file_path):
                    try:
                        with open(txt_file_path, 'r', encoding='utf-8', errors='ignore') as f:
                            identifier = f.readline().rstrip()
                        if len(identifier) < 1:
                            raise ValueError(f" *** Could not find valid text in: {txt_file_path}")
                    except Exception as e:
                        print(f" {bcolors.FAIL} *** Error reading {txt_file_path} to get caption, falling back to filename{bcolors.ENDC}")
                        print(e)
                        identifier = caption_from_filename

            image = Image.open(pathname)
            width, height = image.size
            image_aspect = width / height

            # Assign the image to the bucket whose aspect ratio is closest to its own.
            target_wh = min(self.aspects, key=lambda aspects: abs(aspects[0] / aspects[1] - image_aspect))

            image_train_item = ImageTrainItem(image=None, mask=None, extra=None, caption=identifier, target_wh=target_wh, pathname=pathname, flip_p=flip_p, model_variant=self.model_variant, load_mask=self.load_mask)

            decorated_image_train_items.append(image_train_item)
        return decorated_image_train_items

    @staticmethod
    def __bucketize_images(prepared_train_data: list, batch_size=1, debug_level=0, aspect_mode='dynamic', action_preference='add'):
        """
        Put images into buckets based on aspect ratio with batch_size*n images per bucket, discards remainder
        """
        buckets = {}
        for image_caption_pair in prepared_train_data:
            target_wh = image_caption_pair.target_wh

            if (target_wh[0], target_wh[1]) not in buckets:
                buckets[(target_wh[0], target_wh[1])] = []
            buckets[(target_wh[0], target_wh[1])].append(image_caption_pair)
        print(f" ** Number of buckets: {len(buckets)}")
        for bucket in buckets:
            bucket_len = len(buckets[bucket])

            truncate_amount = bucket_len % batch_size
            add_amount = batch_size - bucket_len % batch_size
            action = None

            if aspect_mode == 'dynamic':
                if batch_size == bucket_len:
                    action = None
                elif add_amount < truncate_amount and add_amount != 0 and add_amount != batch_size or truncate_amount == 0:
                    action = 'add'
                elif truncate_amount < add_amount and truncate_amount != 0 and truncate_amount != batch_size and batch_size < bucket_len:
                    action = 'truncate'
                elif truncate_amount == add_amount:
                    if action_preference == 'add':
                        action = 'add'
                    elif action_preference == 'truncate':
                        action = 'truncate'
                elif batch_size > bucket_len:
                    action = 'add'
            elif aspect_mode == 'add':
                action = 'add'
            elif aspect_mode == 'truncate':
                action = 'truncate'

            if action is None:
                current_bucket_size = bucket_len
                print(f" ** Bucket {bucket} found {bucket_len}, nice!")
            elif action == 'add':
                # Duplicate random images until the bucket is a multiple of batch_size.
                shuffleBucket = random.sample(buckets[bucket], bucket_len)

                current_bucket_size = bucket_len
                truncate_count = (bucket_len) % batch_size

                addAmount = batch_size - truncate_count
                if addAmount != batch_size:
                    added = 0
                    while added != addAmount:
                        randomIndex = random.randint(0, len(shuffleBucket) - 1)

                        buckets[bucket].append(shuffleBucket[randomIndex])
                        added += 1
                    print(f" ** Bucket {bucket} found {bucket_len} images, will {bcolors.OKCYAN}duplicate {added} images{bcolors.ENDC} due to batch size {bcolors.WARNING}{batch_size}{bcolors.ENDC}")
                else:
                    print(f" ** Bucket {bucket} found {bucket_len}, {bcolors.OKGREEN}nice!{bcolors.ENDC}")
            elif action == 'truncate':
                # Drop the remainder so the bucket is a multiple of batch_size.
                truncate_count = (bucket_len) % batch_size
                current_bucket_size = bucket_len
                buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count]
                print(f" ** Bucket {bucket} found {bucket_len} images, will {bcolors.FAIL}drop {truncate_count} images{bcolors.ENDC} due to batch size {bcolors.WARNING}{batch_size}{bcolors.ENDC}")

        image_caption_pairs = []
        for bucket in buckets:
            image_caption_pairs.extend(buckets[bucket])

        return image_caption_pairs

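    # Worked example (illustrative): with batch_size=4, a bucket of 10 images has
    # truncate_amount=2 and add_amount=2; in 'dynamic' mode the tie is broken by
    # action_preference, so 'add' duplicates 2 images (12 total) while 'truncate'
    # drops 2 images (8 total).
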
    @staticmethod
    def __recurse_data_root(self, recurse_root, use_sub_dirs=True, class_images=False):
        progress_bar = tqdm(os.listdir(recurse_root), desc=f" {bcolors.WARNING} ** Processing {recurse_root}{bcolors.ENDC}")
        for f in os.listdir(recurse_root):
            current = os.path.join(recurse_root, f)
            if os.path.isfile(current):
                ext = os.path.splitext(f)[1].lower()
                if '-depth' in f or '-masklabel' in f:
                    progress_bar.update(1)
                    continue
                if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp']:
                    try:
                        img = Image.open(current)
                    except Exception:
                        print(f" ** Skipping {current} because it failed to open, please check the file")
                        progress_bar.update(1)
                        continue
                    del img
                    if class_images == False:
                        self.image_paths.append(current)
                    else:
                        self.class_images_path.append(current)
                    progress_bar.update(1)
        if use_sub_dirs:
            sub_dirs = []

            for d in os.listdir(recurse_root):
                current = os.path.join(recurse_root, d)
                if os.path.isdir(current):
                    sub_dirs.append(current)

            for dir in sub_dirs:
                self.__recurse_data_root(self=self, recurse_root=dir)

class NormalDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.
    """

    def __init__(
            self,
            concepts_list,
            tokenizer,
            with_prior_preservation=True,
            size=512,
            center_crop=False,
            num_class_images=None,
            use_image_names_as_captions=False,
            shuffle_captions=False,
            repeats=1,
            use_text_files_as_captions=False,
            seed=555,
            model_variant='base',
            extra_module=None,
            mask_prompts=None,
            load_mask=None,
    ):
        self.use_image_names_as_captions = use_image_names_as_captions
        self.shuffle_captions = shuffle_captions
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer
        self.with_prior_preservation = with_prior_preservation
        self.use_text_files_as_captions = use_text_files_as_captions
        self.image_paths = []
        self.class_images_path = []
        self.seed = seed
        self.model_variant = model_variant
        self.variant_warning = False
        self.vae_scale_factor = None
        self.load_mask = load_mask
        for concept in concepts_list:
            if 'use_sub_dirs' in concept:
                if concept['use_sub_dirs'] == True:
                    use_sub_dirs = True
                else:
                    use_sub_dirs = False
            else:
                use_sub_dirs = False

            for i in range(repeats):
                self.__recurse_data_root(self, concept, use_sub_dirs=use_sub_dirs)

            if with_prior_preservation:
                for i in range(repeats):
                    self.__recurse_data_root(self, concept, use_sub_dirs=False, class_images=True)
        if mask_prompts is not None:
            print(f" {bcolors.WARNING} Checking and generating missing masks{bcolors.ENDC}")
            clip_seg = ClipSeg()
            clip_seg.mask_images(self.image_paths, mask_prompts)
            del clip_seg

        random.Random(seed).shuffle(self.image_paths)
        self.num_instance_images = len(self.image_paths)
        self._length = self.num_instance_images
        self.num_class_images = len(self.class_images_path)
        self._length = max(self.num_class_images, self.num_instance_images)
        if self.model_variant == 'depth2img':
            print(f" {bcolors.WARNING} ** Loading Depth2Img Pipeline To Process Dataset{bcolors.ENDC}")
            self.vae_scale_factor = extra_module.depth_images(self.image_paths)
            if self.with_prior_preservation:
                print(f" {bcolors.WARNING} ** Loading Depth2Img Class Processing{bcolors.ENDC}")
                extra_module.depth_images(self.class_images_path)
        print(f" {bcolors.WARNING} ** Dataset length: {self._length}, {int(self.num_instance_images / repeats)} images using {repeats} repeats{bcolors.ENDC}")

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )
        self.mask_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
            ]
        )
        self.depth_image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
            ]
        )

    @staticmethod
    def __recurse_data_root(self, recurse_root, use_sub_dirs=True, class_images=False):
        if isinstance(recurse_root, dict):
            if class_images == True:
                concept_token = recurse_root['class_prompt']
                data = recurse_root['class_data_dir']
            else:
                concept_token = recurse_root['instance_prompt']
                data = recurse_root['instance_data_dir']
        else:
            # recurse_root is a plain directory path rather than a concept dict,
            # so there is no prompt to attach to the images found here.
            concept_token = None
            data = recurse_root

        progress_bar = tqdm(os.listdir(data), desc=f" {bcolors.WARNING} ** Processing {data}{bcolors.ENDC}")
        for f in os.listdir(data):
            current = os.path.join(data, f)
            if os.path.isfile(current):
                if '-depth' in f or '-masklabel' in f:
                    continue
                ext = os.path.splitext(f)[1].lower()
                if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp']:
                    try:
                        img = Image.open(current)
                    except Exception:
                        print(f" ** Skipping {current} because it failed to open, please check the file")
                        progress_bar.update(1)
                        continue
                    del img
                    if class_images == False:
                        self.image_paths.append([current, concept_token])
                    else:
                        self.class_images_path.append([current, concept_token])
                    progress_bar.update(1)
        if use_sub_dirs:
            sub_dirs = []

            for d in os.listdir(data):
                current = os.path.join(data, d)
                if os.path.isdir(current):
                    sub_dirs.append(current)

            for dir in sub_dirs:
                if class_images == False:
                    self.__recurse_data_root(self=self, recurse_root={'instance_data_dir': dir, 'instance_prompt': concept_token})
                else:
                    self.__recurse_data_root(self=self, recurse_root={'class_data_dir': dir, 'class_prompt': concept_token})

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        instance_path, instance_prompt = self.image_paths[index % self.num_instance_images]
        og_prompt = instance_prompt
        instance_image = Image.open(instance_path)
        if self.model_variant == "inpainting" or self.load_mask:
            mask_pathname = os.path.splitext(instance_path)[0] + "-masklabel.png"
            if os.path.exists(mask_pathname) and self.load_mask:
                mask = Image.open(mask_pathname).convert("L")
            else:
                if self.variant_warning == False:
                    print(f" {bcolors.FAIL} ** Warning: No mask found for an image, using an empty mask but make sure you're training the right model variant.{bcolors.ENDC}")
                    self.variant_warning = True
                size = instance_image.size
                mask = Image.new('RGB', size, color="white").convert("L")
            example["mask"] = self.mask_transforms(mask)
        if self.model_variant == "depth2img":
            depth_pathname = os.path.splitext(instance_path)[0] + "-depth.png"
            if os.path.exists(depth_pathname):
                depth_image = Image.open(depth_pathname).convert("L")
            else:
                if self.variant_warning == False:
                    print(f" {bcolors.FAIL} ** Warning: No depth image found for an image, using an empty depth image but make sure you're training the right model variant.{bcolors.ENDC}")
                    self.variant_warning = True
                size = instance_image.size
                depth_image = Image.new('RGB', size, color="white").convert("L")
            example["instance_depth_images"] = self.depth_image_transforms(depth_image)

        if self.use_image_names_as_captions == True:
            instance_prompt = str(instance_path).split(os.sep)[-1].split('.')[0].split('_')[0]

        if self.use_text_files_as_captions == True:
            # Swap the image extension for .txt and use that file's first line as
            # the caption if it exists.
            last_dot = str(instance_path).rfind('.')
            txt_path = str(instance_path)[:last_dot] + '.txt'

            if os.path.exists(txt_path):
                with open(txt_path, encoding='utf-8') as f:
                    instance_prompt = f.readline().rstrip()

        if self.shuffle_captions:
            caption_parts = instance_prompt.split(",")
            random.shuffle(caption_parts)
            instance_prompt = ",".join(caption_parts)

        instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)
        example["instance_prompt_ids"] = self.tokenizer(
            instance_prompt,
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids
        if self.with_prior_preservation:
            class_path, class_prompt = self.class_images_path[index % self.num_class_images]
            class_image = Image.open(class_path)
            if not class_image.mode == "RGB":
                class_image = class_image.convert("RGB")

            if self.model_variant == "inpainting":
                mask_pathname = os.path.splitext(class_path)[0] + "-masklabel.png"
                if os.path.exists(mask_pathname):
                    mask = Image.open(mask_pathname).convert("L")
                else:
                    if self.variant_warning == False:
                        print(f" {bcolors.FAIL} ** Warning: No mask found for an image, using an empty mask but make sure you're training the right model variant.{bcolors.ENDC}")
                        self.variant_warning = True
                    size = instance_image.size
                    mask = Image.new('RGB', size, color="white").convert("L")
                example["class_mask"] = self.mask_transforms(mask)
            if self.model_variant == "depth2img":
                depth_pathname = os.path.splitext(class_path)[0] + "-depth.png"
                if os.path.exists(depth_pathname):
                    depth_image = Image.open(depth_pathname)
                else:
                    if self.variant_warning == False:
                        print(f" {bcolors.FAIL} ** Warning: No depth image found for an image, using an empty depth image but make sure you're training the right model variant.{bcolors.ENDC}")
                        self.variant_warning = True
                    size = instance_image.size
                    depth_image = Image.new('RGB', size, color="white").convert("L")
                example["class_depth_images"] = self.depth_image_transforms(depth_image)
            example["class_images"] = self.image_transforms(class_image)
            example["class_prompt_ids"] = self.tokenizer(
                class_prompt,
                padding="do_not_pad",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
            ).input_ids

        return example

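# Usage sketch (illustrative; "my_tokenizer" and "concepts" are assumptions):
#   dataset = NormalDataset(concepts_list=concepts, tokenizer=my_tokenizer,
#                           size=512, center_crop=True, with_prior_preservation=False)
#   item = dataset[0]  # fixed-size crops, unlike AutoBucketing's per-aspect buckets
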
class PromptDataset(Dataset):
    "A simple dataset to prepare the prompts to generate class images on multiple GPUs."

    def __init__(self, prompt, num_samples):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        example = {}
        example["prompt"] = self.prompt
        example["index"] = index
        return example
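
# Usage sketch (illustrative): wrap PromptDataset in a DataLoader to fan the same
# class prompt out across workers when pre-generating regularization images.
#   from torch.utils.data import DataLoader
#   sample_dataset = PromptDataset("photo of a dog", num_samples=50)
#   sample_dataloader = DataLoader(sample_dataset, batch_size=4)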