czl's picture
simplify UI
38d4034 verified
raw
history blame
No virus
35.6 kB
"""
Helper scripts for generating synthetic images using a diffusion model.
Functions:
- get_top_misclassified
- get_class_list
- generateClassPairs
- outputDirectory
- pipe_img
- createPrompts
- interpolatePrompts
- slerp
- get_middle_elements
- remove_middle
- genClassImg
- getMetadata
- groupbyInterpolation
- ungroupInterpolation
- groupAllbyInterpolation
- getPairIndices
- generateImagesFromDataset
- generateTrace
"""
import json
import os
import numpy as np
import pandas as pd
import torch
from DeepCache import DeepCacheSDHelper
from diffusers import (
LMSDiscreteScheduler,
StableDiffusionImg2ImgPipeline,
)
from torch import nn
from torchmetrics.functional.image import structural_similarity_index_measure as ssim
from torchvision import transforms
def get_top_misclassified(val_classifier_json):
    """
    Load the per-class "top_n_classes" entries from a validation classifier JSON file.

    Args:
        val_classifier_json (str): The path to the validation classifier JSON file. The file
            must contain a "val_metrics_details" mapping keyed by class name.

    Returns:
        dict: A dictionary whose keys are the class names and whose values are whatever is
            stored under each class's "top_n_classes" field.
    """
    with open(val_classifier_json) as json_file:
        details = json.load(json_file)["val_metrics_details"]
    # One row per class; only the "top_n_classes" column is of interest here.
    metrics_table = pd.DataFrame.from_dict(details, orient="index")
    return dict(metrics_table["top_n_classes"].items())
def get_class_list(val_classifier_json):
    """
    Return the class names found in a validation classifier JSON file.

    Args:
        val_classifier_json (str): The path to the validation classifier JSON file.

    Returns:
        list: The keys of the file's "val_metrics_details" mapping, sorted ascending.
    """
    with open(val_classifier_json, "r") as json_file:
        details = json.load(json_file)["val_metrics_details"]
    class_names = list(details.keys())
    class_names.sort()
    return class_names
def generateClassPairs(val_classifier_json):
    """
    Build the unique, order-independent pairs of (class, misclassified-as class).

    Args:
        val_classifier_json (str): The path to the validation classifier JSON file.

    Returns:
        list: A sorted list of 2-tuples. Each tuple is itself sorted, so (a, b) and (b, a)
            collapse into a single pair.
    """
    confusions = get_top_misclassified(val_classifier_json)
    unique_pairs = {
        tuple(sorted([class_name, confused_with]))
        for class_name, confused_classes in confusions.items()
        for confused_with in confused_classes
    }
    return sorted(unique_pairs)
def outputDirectory(class_pairs, synth_path, metadata_path):
    """
    Creates the output directory structure for the synthesized data.

    One folder named after each element of `class_pairs` is created under `synth_path`,
    and `metadata_path` is created as well. Existing directories are left untouched, so
    the function can be called repeatedly.

    Args:
        class_pairs (iterable): Identifiers used as folder names under `synth_path`.
        synth_path (str): The path to the directory where the synthesized data will be stored.
        metadata_path (str): The path to the directory where the metadata will be stored.

    Returns:
        None

    Raises:
        FileExistsError: If one of the target paths exists but is not a directory.
    """
    for class_id in class_pairs:
        # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
        os.makedirs(f"{synth_path}/{class_id}", exist_ok=True)
    os.makedirs(metadata_path, exist_ok=True)
    print("Info: Output directory ready.")
def pipe_img(
    model_path,
    device="cuda",
    apply_optimization=True,
    use_torchcompile=False,
    ci_cb=(5, 1),
    use_safetensors=None,
    cpu_offload=False,
    scheduler=None,
):
    """
    Build a Stable Diffusion image-to-image pipeline and move it to `device`.

    Args:
        model_path (str): The path to the pretrained model.
        device (str, optional): The device the pipeline is moved to. Defaults to "cuda".
        apply_optimization (bool, optional): Whether to enable DeepCache (and, when CUDA is
            available, xformers memory-efficient attention). Defaults to True.
        use_torchcompile (bool, optional): Whether to wrap the UNet with torch.compile. Defaults to False.
        ci_cb (tuple, optional): (cache interval, cache branch ID) passed to DeepCache. Defaults to (5, 1).
        use_safetensors (bool, optional): Forwarded to `from_pretrained`. Defaults to None.
        cpu_offload (bool, optional): Whether to enable model CPU offloading. Defaults to False.
        scheduler (LMSDiscreteScheduler, optional): The scheduler for the pipeline. When None, an
            LMSDiscreteScheduler with the settings below is created. Defaults to None.

    Returns:
        StableDiffusionImg2ImgPipeline: The configured image-to-image pipeline.
    """
    ###############################
    # Reference:
    # Akimov, R. (2024) Images Interpolation with Stable Diffusion - Hugging Face Open-Source AI Cookbook. Available at: https://huggingface.co/learn/cookbook/en/stable_diffusion_interpolation (Accessed: 4 June 2024).
    ###############################
    if scheduler is None:
        default_scheduler_settings = dict(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
            steps_offset=1,
        )
        scheduler = LMSDiscreteScheduler(**default_scheduler_settings)

    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        model_path,
        scheduler=scheduler,
        torch_dtype=torch.float32,
        use_safetensors=use_safetensors,
    )
    pipe = pipe.to(device)

    if cpu_offload:
        pipe.enable_model_cpu_offload()

    if apply_optimization:
        deep_cache = DeepCacheSDHelper(pipe=pipe)
        interval, branch_id = ci_cb
        # lower is faster but lower quality
        deep_cache.set_params(cache_interval=interval, cache_branch_id=branch_id)
        deep_cache.enable()
        # NOTE(review): treated as part of the optimization branch; confirm the intended nesting.
        if torch.cuda.is_available():
            pipe.enable_xformers_memory_efficient_attention()

    if use_torchcompile:
        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

    return pipe
def createPrompts(
    class_name_pairs,
    prompt_structure=None,
    use_default_negative_prompt=False,
    negative_prompt=None,
):
    """
    Build the positive (and optionally negative) text prompts for a pair of classes.

    Args:
        class_name_pairs (list): A sequence whose first two items are the class names.
        prompt_structure (str, optional): Template containing the "<class_name>" placeholder.
            Defaults to "a photo of a <class_name>".
        use_default_negative_prompt (bool, optional): When True, the built-in negative prompt is
            used and any `negative_prompt` argument is overridden. Defaults to False.
        negative_prompt (str, optional): The negative prompt to steer the generation away from certain features.

    Returns:
        tuple: (prompts, negative_prompts).
            prompts (list): One prompt per class, in the order given.
            negative_prompts (list or None): The negative prompt repeated once per prompt, or
                None when no negative prompt is in effect.

    Raises:
        ValueError: If a custom `prompt_structure` lacks the "<class_name>" placeholder.
    """
    placeholder = "<class_name>"
    if prompt_structure is None:
        prompt_structure = "a photo of a <class_name>"
    elif placeholder not in prompt_structure:
        raise ValueError(
            "The prompt structure must contain the <class_name> placeholder."
        )

    if use_default_negative_prompt:
        negative_prompt = (
            "blurry image, disfigured, deformed, distorted, cartoon, drawings"
        )

    first_class, second_class = class_name_pairs[0], class_name_pairs[1]
    prompts = [
        prompt_structure.replace(placeholder, name)
        for name in (first_class, second_class)
    ]

    if negative_prompt is not None:
        # Negative prompts that can be used to steer the generation away from certain features.
        return prompts, [negative_prompt] * len(prompts)

    print("Info: Negative prompt not provided, returning as None.")
    return prompts, None
def interpolatePrompts(
    prompts,
    pipeline,
    num_interpolation_steps,
    sample_mid_interpolation,
    remove_n_middle=0,
    device="cuda",
):
    """
    Interpolates prompts by generating intermediate embeddings between pairs of prompts.

    Args:
        prompts (List[str]): A list of prompts to be interpolated.
        pipeline: The pipeline object containing the tokenizer and text encoder.
        num_interpolation_steps (int): The number of interpolation steps between each pair of prompts.
        sample_mid_interpolation (int): The number of intermediate embeddings to sample from the
            middle of the interpolation between the first two prompts.
        remove_n_middle (int, optional): The number of middle embeddings to remove from the sampled
            embeddings. Defaults to 0.
        device (str, optional): The device to run the interpolation on. Defaults to "cuda".

    Returns:
        interpolated_prompt_embeds (torch.Tensor): The interpolated prompt embeddings, concatenated
            along the first dimension and moved to `device`.
        prompt_metadata (dict): One entry per interpolation step of the first prompt pair, with
            similarity scores to both end points and nearest class information. The "selected"
            flag marks the steps inside the sampled middle window (before any middle removal).

    Raises:
        ValueError: If `sample_mid_interpolation` exceeds the number of interpolation steps, or
            `remove_n_middle` exceeds the number of sampled embeddings (or either is negative).

    e.g. if num_interpolation_steps = 10, sample_mid_interpolation = 6, remove_n_middle = 2
        Interpolated: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
        Sampled:            [2, 3, 4, 5, 6, 7]
        Removed:                  x  x
        Returns:            [2, 3,       6, 7]
    """

    ###############################
    # Reference:
    # Akimov, R. (2024) Images Interpolation with Stable Diffusion - Hugging Face Open-Source AI Cookbook. Available at: https://huggingface.co/learn/cookbook/en/stable_diffusion_interpolation (Accessed: 4 June 2024).
    ###############################
    def slerp(v0, v1, num, t0=0, t1=1):
        """
        Performs spherical linear interpolation between two vectors.

        Args:
            v0 (torch.Tensor): The starting vector.
            v1 (torch.Tensor): The ending vector.
            num (int): The number of interpolation points.
            t0 (float, optional): The starting time. Defaults to 0.
            t1 (float, optional): The ending time. Defaults to 1.

        Returns:
            torch.Tensor: The interpolated vectors, stacked along a new first dimension.
        """
        ###############################
        # Reference:
        # Karpathy, A. (2022) hacky stablediffusion code for generating videos, Gist. Available at: https://gist.github.com/karpathy/00103b0037c5aaea32fe1da1af553355 (Accessed: 4 June 2024).
        ###############################
        v0 = v0.detach().cpu().numpy()
        v1 = v1.detach().cpu().numpy()

        def interpolation(t, v0, v1, DOT_THRESHOLD=0.9995):
            """helper function to spherically interpolate two arrays v0 v1"""
            dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
            if np.abs(dot) > DOT_THRESHOLD:
                # Nearly parallel vectors: fall back to plain linear interpolation.
                v2 = (1 - t) * v0 + t * v1
            else:
                theta_0 = np.arccos(dot)
                sin_theta_0 = np.sin(theta_0)
                theta_t = theta_0 * t
                sin_theta_t = np.sin(theta_t)
                s0 = np.sin(theta_0 - theta_t) / sin_theta_0
                s1 = sin_theta_t / sin_theta_0
                v2 = s0 * v0 + s1 * v1
            return v2

        t = np.linspace(t0, t1, num)
        v3 = torch.tensor(np.array([interpolation(t[i], v0, v1) for i in range(num)]))
        return v3

    def get_middle_elements(lst, n):
        """
        Returns the `n` middle elements of `lst` together with the range of their indices.

        Args:
            lst (list or torch.Tensor): The sequence from which to extract the middle elements.
            n (int): The number of middle elements to extract.

        Returns:
            tuple: The slice of middle elements and the range of indices it covers.

        Raises:
            ValueError: If n is negative or greater than the length of the sequence.

        Examples:
            lst = [1, 2, 3, 4, 5]
            get_middle_elements(lst, 3)
            ([2, 3, 4], range(1, 4))
        """
        if n < 0 or n > len(lst):
            # A larger n would produce a negative start index and a silently wrong slice.
            raise ValueError(
                "Invalid value for n. It should be non-negative and no greater than the list length"
            )
        start = len(lst) // 2 - n // 2
        end = start + n
        return lst[start:end], range(start, end)

    def remove_middle(data, n):
        """
        Remove the middle n elements from a list or tensor.

        Args:
            data (list or torch.Tensor): The input sequence.
            n (int): The number of elements to remove from the middle of the sequence.

        Returns:
            list or torch.Tensor: The sequence with the middle n elements removed.

        Raises:
            ValueError: If n is negative or greater than the length of the sequence.
        """
        if n < 0 or n > len(data):
            raise ValueError(
                "Invalid value for n. It should be non-negative and no greater than the list length"
            )
        start = len(data) // 2 - n // 2
        head, tail = data[:start], data[start + n :]
        if isinstance(data, torch.Tensor):
            # `+` on tensors adds element-wise; the two halves must be concatenated instead.
            return torch.cat((head, tail), dim=0)
        return head + tail

    batch_size = len(prompts)

    # Tokenizing and encoding prompts into embeddings.
    prompts_tokens = pipeline.tokenizer(
        prompts,
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    # No gradients are needed: slerp detaches the embeddings anyway.
    with torch.no_grad():
        prompts_embeds = pipeline.text_encoder(prompts_tokens.input_ids.to(device))[0]

    # Interpolating between embeddings pairs for the given number of interpolation steps.
    interpolated_prompt_embeds = [
        slerp(prompts_embeds[i], prompts_embeds[i + 1], num_interpolation_steps)
        for i in range(batch_size - 1)
    ]
    full_interpolated_prompt_embeds = interpolated_prompt_embeds[:]

    # Only the first prompt pair is narrowed down to its middle window.
    interpolated_prompt_embeds[0], sample_range = get_middle_elements(
        interpolated_prompt_embeds[0], sample_mid_interpolation
    )
    if remove_n_middle > 0:
        interpolated_prompt_embeds[0] = remove_middle(
            interpolated_prompt_embeds[0], remove_n_middle
        )

    prompt_metadata = dict()
    similarity = nn.CosineSimilarity(dim=-1, eps=1e-6)
    full_steps = full_interpolated_prompt_embeds[0]
    class1_embed = full_steps[0]
    class2_embed = full_steps[num_interpolation_steps - 1]
    for i in range(num_interpolation_steps):
        class1_sim = similarity(class1_embed, full_steps[i]).mean().item()
        class2_sim = similarity(class2_embed, full_steps[i]).mean().item()
        relative_distance = class1_sim / (class1_sim + class2_sim)
        prompt_metadata[i] = {
            "selected": i in sample_range,
            "similarity": {
                "class1": class1_sim,
                "class2": class2_sim,
                "class1_relative_distance": relative_distance,
                "class2_relative_distance": 1 - relative_distance,
            },
            # 0 when the step is more similar to the first prompt, 1 otherwise.
            "nearest_class": int(relative_distance < 0.5),
        }

    interpolated_prompt_embeds = torch.cat(interpolated_prompt_embeds, dim=0).to(device)
    return interpolated_prompt_embeds, prompt_metadata
def genClassImg(
    pipeline,
    pos_embed,
    neg_embed,
    input_image,
    generator,
    latents,
    num_imgs=1,
    height=512,
    width=512,
    num_inference_steps=25,
    guidance_scale=7.5,
):
    """
    Run the pipeline once for a single prompt embedding and return the first image.

    Args:
        pipeline: The callable pipeline used for image generation.
        pos_embed: The positive prompt embedding; a leading batch dimension is added before the call.
        neg_embed: The negative prompt embedding, or None to run without one.
        input_image: The input image passed to the pipeline as `image`.
        generator: The random generator passed to the pipeline.
        latents: The latent vectors passed to the pipeline.
        num_imgs: The number of images per prompt (default is 1).
        height: The height of the generated images (default is 512).
        width: The width of the generated images (default is 512).
        num_inference_steps: The number of inference steps for image generation (default is 25).
        guidance_scale: The scale factor for guidance (default is 7.5).

    Returns:
        The first entry of the pipeline output's `images`.
    """
    negative_batch = None if neg_embed is None else neg_embed[None, ...]
    call_kwargs = dict(
        height=height,
        width=width,
        num_images_per_prompt=num_imgs,
        prompt_embeds=pos_embed[None, ...],
        negative_prompt_embeds=negative_batch,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
        latents=latents,
        image=input_image,
    )
    output = pipeline(**call_kwargs)
    return output.images[0]
def getMetadata(
    class_pairs,
    path,
    seed,
    guidance_scale,
    num_inference_steps,
    num_interpolation_steps,
    sample_mid_interpolation,
    height,
    width,
    prompts,
    negative_prompts,
    pipeline,
    prompt_metadata,
    negative_prompt_metadata,
    ssim_metadata=None,
    save_json=True,
    save_path=".",
):
    """
    Collect the generation settings into one dictionary and optionally write it as JSON.

    Args:
        class_pairs (list): Class identifiers (strings); also joined with "_" for the file name.
        path (str): Path to the data.
        seed (int): Seed value for randomization.
        guidance_scale (float): Scale factor for guidance.
        num_inference_steps (int): Number of inference steps.
        num_interpolation_steps (int): Number of interpolation steps.
        sample_mid_interpolation: Mid-interpolation sampling setting, stored as given.
        height (int): Height of the image.
        width (int): Width of the image.
        prompts (list): List of prompts.
        negative_prompts (list): List of negative prompts, or None.
        pipeline (object): Pipeline object; its `config` is stored as a dict.
        prompt_metadata (dict): Metadata for prompts.
        negative_prompt_metadata (dict): Metadata for negative prompts.
        ssim_metadata (dict, optional): SSIM scores metadata. Defaults to None.
        save_json (bool, optional): Flag to save metadata as JSON. Defaults to True.
        save_path (str, optional): Directory the JSON file is written to. Defaults to ".".

    Returns:
        dict: Generated metadata.
    """
    metadata = {
        "class_pairs": class_pairs,
        "path": path,
        "seed": seed,
        "params": {
            "CFG": guidance_scale,
            "inferenceSteps": num_inference_steps,
            "interpolationSteps": num_interpolation_steps,
            "sampleMidInterpolation": sample_mid_interpolation,
            "height": height,
            "width": width,
        },
    }

    has_negatives = negative_prompts is not None
    for index, prompt_text in enumerate(prompts):
        metadata[f"prompt_text_{index}"] = prompt_text
        if has_negatives:
            metadata[f"negative_prompt_text_{index}"] = negative_prompts[index]

    metadata["pipe_config"] = dict(pipeline.config)
    metadata["prompt_embed_similarity"] = prompt_metadata
    metadata["negative_prompt_embed_similarity"] = negative_prompt_metadata

    if ssim_metadata is not None:
        print("Info: SSIM scores are available.")
        metadata["ssim_scores"] = ssim_metadata

    if save_json:
        # File name: the class identifiers joined by "_", followed by the seed.
        json_path = os.path.join(save_path, f"{'_'.join(class_pairs)}_{seed}.json")
        with open(json_path, "w") as json_file:
            json.dump(metadata, json_file, indent=4)

    return metadata
def groupbyInterpolation(dir_to_classfolder):
    """
    Group files in a directory by interpolation step.

    The step is read from the file name as the text between the first "_" and the next "."
    (or the second "_", if any), e.g. "123-4_7.jpg" is moved into the sub-folder "7". Entries that are not regular files
    (such as step folders created by an earlier run) and files without an "_" in their name
    are left where they are, so the function can safely be run more than once.

    Args:
        dir_to_classfolder (str): The path to the directory containing the files.

    Returns:
        None
    """
    for file_name in os.listdir(dir_to_classfolder):
        file_path = os.path.join(dir_to_classfolder, file_name)
        name_parts = file_name.split(sep="_")
        if len(name_parts) < 2 or not os.path.isfile(file_path):
            # Not a "<prefix>_<step>.<ext>" file: nothing to group.
            continue
        interpolation_step = name_parts[1].split(sep=".")[0]
        # create a subfolder for each step of the interpolation
        new_dir = os.path.join(dir_to_classfolder, interpolation_step)
        os.makedirs(new_dir, exist_ok=True)
        os.rename(file_path, os.path.join(new_dir, file_name))
def ungroupInterpolation(dir_to_classfolder):
    """
    Flatten `dir_to_classfolder`: move every entry of each sub-directory up into
    `dir_to_classfolder` itself, then remove the emptied sub-directory.

    Args:
        dir_to_classfolder (str): The path to the directory containing the subdirectories.

    Returns:
        None
    """
    for entry in os.listdir(dir_to_classfolder):
        step_dir = os.path.join(dir_to_classfolder, entry)
        if not os.path.isdir(step_dir):
            continue
        for file_name in os.listdir(step_dir):
            source = os.path.join(step_dir, file_name)
            target = os.path.join(dir_to_classfolder, file_name)
            os.rename(source, target)
        os.rmdir(step_dir)
def groupAllbyInterpolation(
    data_path,
    group=True,
    fn_group=groupbyInterpolation,
    fn_ungroup=ungroupInterpolation,
):
    """
    Apply the grouping (or ungrouping) function to every class folder under `data_path`.

    Args:
        data_path (str): The path to the data; each sub-directory is treated as one class.
        group (bool, optional): True to group, False to ungroup. Defaults to True.
        fn_group (function, optional): The function to use for grouping. Defaults to groupbyInterpolation.
        fn_ungroup (function, optional): The function to use for ungrouping. Defaults to ungroupInterpolation.
    """
    action = fn_group if group else fn_ungroup
    # Class folders are visited in sorted order; non-directory entries are skipped.
    for class_name in sorted(os.listdir(data_path)):
        class_path = os.path.join(data_path, class_name)
        if not os.path.isdir(class_path):
            continue
        action(class_path)
        print(f"Processed {class_name}")
def getPairIndices(subset_len, total_pair_count=1, seed=None):
    """
    Split the shuffled indices of a subset into `total_pair_count` equally sized groups.

    The indices 0..subset_len-1 are shuffled and cut into groups of
    ceil(subset_len / total_pair_count) elements. When the indices do not divide evenly, the
    last group(s) are padded with indices taken in ascending order (0, 1, 2, ..., wrapping
    around if needed), so some indices may appear twice.

    Args:
        subset_len (int): The length of the subset.
        total_pair_count (int, optional): The number of groups to generate. Defaults to 1.
        seed (int, optional): The seed value for the random number generator. Defaults to None.

    Returns:
        list: Exactly `total_pair_count` lists of indices, all of the same length.

    Raises:
        ValueError: If `total_pair_count` is smaller than 1 or `subset_len` is negative.
    """
    if total_pair_count < 1:
        raise ValueError("total_pair_count must be a positive integer.")
    if subset_len < 0:
        raise ValueError("subset_len must be non-negative.")
    if subset_len == 0:
        # Nothing to distribute: every group is empty.
        return [[] for _ in range(total_pair_count)]

    rng = np.random.default_rng(seed)
    group_size = (subset_len + total_pair_count - 1) // total_pair_count
    numbers = list(range(subset_len))
    rng.shuffle(numbers)

    # Pad up to a full rectangle. Padding by a single group's remainder is not enough when
    # more than one group would otherwise be missing (e.g. subset_len=10, total_pair_count=7).
    padding_needed = group_size * total_pair_count - subset_len
    numbers.extend(i % subset_len for i in range(padding_needed))

    groups = np.array(numbers).reshape(total_pair_count, group_size)
    return groups.tolist()
def generateImagesFromDataset(
    img_subsets,
    class_iterables,
    pipeline,
    interpolated_prompt_embeds,
    interpolated_negative_prompts_embeds,
    num_inference_steps,
    guidance_scale,
    height=512,
    width=512,
    seed=None,
    save_path=".",
    class_pairs=("0", "1"),
    save_image=True,
    image_type="jpg",
    interpolate_range="full",
    device="cuda",
    return_images=False,
):
    """
    Generates images from a dataset using the given parameters.

    For every class in `class_pairs`, one batch of image indices is popped from
    `class_iterables[class_id]`. Each of those images is run through the pipeline with a
    randomly assigned interpolation step, and the SSIM between input and output is recorded.

    Args:
        img_subsets (dict): A dictionary containing image subsets for each class. Indexing a
            subset must return an (image tensor, target) pair.
        class_iterables (dict): For each class, a list of index batches; the last one is popped.
        pipeline (object): The pipeline object used for image generation.
        interpolated_prompt_embeds (list): A list of interpolated prompt embeddings.
        interpolated_negative_prompts_embeds (list): A list of interpolated negative prompt
            embeddings, or None to generate without negative prompts.
        num_inference_steps (int): The number of inference steps for image generation.
        guidance_scale (float): The guidance scale passed to the pipeline.
        height (int, optional): The height of the generated images. Defaults to 512.
        width (int, optional): The width of the generated images. Defaults to 512.
        seed (int, optional): The seed value for random number generation. Defaults to None.
        save_path (str, optional): The path to save the generated images. Defaults to ".".
        class_pairs (tuple, optional): A tuple containing pairs of class identifiers. Defaults to ("0", "1").
        save_image (bool, optional): Whether to save the generated images. Defaults to True.
        image_type (str, optional): The file format of the saved images. Defaults to "jpg".
        interpolate_range (str, optional): The range of interpolation for prompt embeddings.
            Possible values are "full", "nearest", or "furthest"; any other value behaves like
            "full". Defaults to "full".
        device (str, optional): The device to use for image generation. Defaults to "cuda".
        return_images (bool, optional): Whether to return the generated images. Defaults to False.

    Returns:
        dict or tuple: If return_images is True, returns a tuple of (dictionary with the generated
            images for each class, dictionary with the SSIM scores for each class and interpolation step).
            If return_images is False, returns only the SSIM dictionary.
    """
    if seed is None:
        seed = torch.Generator().seed()
    # torch.manual_seed seeds the global generator and returns it.
    generator = torch.manual_seed(seed)
    rng = np.random.default_rng(seed)

    # Generating initial U-Net latent vectors from a random normal distribution.
    latents = torch.randn(
        (1, pipeline.unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    ).to(device)

    embed_len = len(interpolated_prompt_embeds)
    if interpolated_negative_prompts_embeds is None:
        # No negative prompts: pair every positive embedding with None (genClassImg accepts it).
        embed_pairs_list = [(embed, None) for embed in interpolated_prompt_embeds]
    else:
        embed_pairs_list = list(
            zip(interpolated_prompt_embeds, interpolated_negative_prompts_embeds)
        )

    first_half = range(0, embed_len // 2)
    second_half = range(embed_len // 2, embed_len)
    if interpolate_range == "nearest":
        steps_range, multiplier = (first_half, second_half), 2
    elif interpolate_range == "furthest":
        # uses opposite class of images of the text interpolation
        steps_range, multiplier = (second_half, first_half), 2
    else:
        steps_range, multiplier = (range(embed_len), range(embed_len)), 1

    # Pillow save options per image type; other types let Pillow infer the format from the extension.
    save_options = {
        "jpg": {"format": "JPEG", "quality": 95},
        "png": {"format": "PNG"},
    }.get(image_type, {})
    to_tensor = transforms.ToTensor()

    class_images = dict()
    class_ssim = dict()
    for class_iter, class_id in enumerate(class_pairs):
        if return_images:
            class_images[class_id] = list()
        class_ssim[class_id] = {
            i: {"ssim_sum": 0, "ssim_count": 0, "ssim_avg": 0} for i in range(embed_len)
        }
        subset = img_subsets[class_id]
        subset_len = len(subset)

        # group_map maps an image id (list position) to an interpolation step (element).
        # steps_range[class_iter] selects the steps allowed for this class, e.g. for 16 steps
        # range(0, 8) for one class and range(8, 16) for the other; the list is then repeated
        # so that it covers the whole subset.
        # NOTE(review): with an odd number of embeddings in "nearest"/"furthest" mode the
        # shorter half can yield fewer entries than subset_len. The construction is kept as is
        # so the step assignment stays identical to generateTrace for the same seed.
        group_map = (
            list(steps_range[class_iter]) * multiplier * (subset_len // embed_len + 1)
        )
        # shuffle the steps to interpolate for each image
        rng.shuffle(group_map)

        iter_indices = class_iterables[class_id].pop()
        step_stats = class_ssim[class_id]
        # generate images for each image in the class, using its randomly assigned step
        for image_id in iter_indices:
            img, _ = subset[image_id]
            input_image = img.unsqueeze(0)
            interpolate_step = group_map[image_id]
            prompt_embeds, negative_prompt_embeds = embed_pairs_list[interpolate_step]
            generated_image = genClassImg(
                pipeline,
                prompt_embeds,
                negative_prompt_embeds,
                input_image,
                generator,
                latents,
                num_imgs=1,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
            )

            pred_image = to_tensor(generated_image).unsqueeze(0)
            ssim_score = ssim(pred_image, input_image).item()
            step_stats[interpolate_step]["ssim_sum"] += ssim_score
            step_stats[interpolate_step]["ssim_count"] += 1

            if return_images:
                class_images[class_id].append(generated_image)
            if save_image:
                generated_image.save(
                    f"{save_path}/{class_id}/{seed}-{image_id}_{interpolate_step}.{image_type}",
                    **save_options,
                )

        # calculate ssim avg for the class
        for stats in step_stats.values():
            if stats["ssim_count"] > 0:
                stats["ssim_avg"] = stats["ssim_sum"] / stats["ssim_count"]

    if return_images:
        return class_images, class_ssim
    return class_ssim
def generateTrace(
    prompts,
    img_subsets,
    class_iterables,
    interpolated_prompt_embeds,
    interpolated_negative_prompts_embeds,
    subset_indices,
    seed=None,
    save_path=".",
    class_pairs=("0", "1"),
    image_type="jpg",
    interpolate_range="full",
    save_prompt_embeds=False,
):
    """
    Generate a trace dictionary containing information about the generated images.

    No images are generated here: the function replays the seeded assignment of interpolation
    steps to image ids and records, per image, the input file, the output file name and the
    prompts involved.

    Args:
        prompts (list): List of prompt texts; prompts[0] is recorded as the positive prompt
            text and prompts[1] as the negative prompt text.
        img_subsets (dict): Dictionary containing image subsets for each class. Each subset must
            support len() and expose `dataset.samples`.
        class_iterables (dict): For each class, a list of index batches; the last one is popped.
        interpolated_prompt_embeds (torch.Tensor): Tensor containing interpolated prompt embeddings.
        interpolated_negative_prompts_embeds (torch.Tensor): Tensor containing interpolated negative
            prompt embeddings, or None when no negative prompts are used.
        subset_indices (dict): Dictionary containing indices of subsets for each class. The first
            index is used as the offset into `dataset.samples` (presumably the subsets are
            contiguous; verify against the caller).
        seed (int, optional): Seed value for random number generation. Defaults to None.
        save_path (str, optional): Path to save the generated images. Defaults to ".".
        class_pairs (tuple, optional): Tuple containing class pairs. Defaults to ("0", "1").
        image_type (str, optional): Type of the generated images. Defaults to "jpg".
        interpolate_range (str, optional): Range of interpolation ("full", "nearest" or
            "furthest"; any other value behaves like "full"). Defaults to "full".
        save_prompt_embeds (bool, optional): Flag to save prompt embeddings. Defaults to False.

    Returns:
        dict: Trace dictionary of parallel lists, one entry per processed image.
    """
    trace_dict = {
        "class_pairs": list(),
        "class_id": list(),
        "image_id": list(),
        "interpolation_step": list(),
        "embed_len": list(),
        "pos_prompt_text": list(),
        "neg_prompt_text": list(),
        "input_file_path": list(),
        "output_file_path": list(),
        "input_prompts_embed": list(),
    }

    if seed is None:
        seed = torch.Generator().seed()
    rng = np.random.default_rng(seed)

    embed_len = len(interpolated_prompt_embeds)

    # The embeddings are only copied to host memory when they are actually stored in the trace.
    embed_pairs_list = None
    if save_prompt_embeds:
        pos_arrays = interpolated_prompt_embeds.detach().cpu().numpy()
        if interpolated_negative_prompts_embeds is None:
            embed_pairs_list = [(pos_array, None) for pos_array in pos_arrays]
        else:
            neg_arrays = interpolated_negative_prompts_embeds.detach().cpu().numpy()
            embed_pairs_list = list(zip(pos_arrays, neg_arrays))

    first_half = range(0, embed_len // 2)
    second_half = range(embed_len // 2, embed_len)
    if interpolate_range == "nearest":
        steps_range, multiplier = (first_half, second_half), 2
    elif interpolate_range == "furthest":
        # uses opposite class of images of the text interpolation
        steps_range, multiplier = (second_half, first_half), 2
    else:
        steps_range, multiplier = (range(embed_len), range(embed_len)), 1

    pos_prompt = None
    neg_prompt = None
    for class_iter, class_id in enumerate(class_pairs):
        class_ds = img_subsets[class_id]
        subset_len = len(class_ds)

        # group_map maps an image id (list position) to an interpolation step (element).
        # steps_range[class_iter] selects the steps allowed for this class; the list is repeated
        # so that it covers the whole subset. The construction and the shuffle must stay
        # identical to generateImagesFromDataset so both assign the same steps for a seed.
        group_map = (
            list(steps_range[class_iter]) * multiplier * (subset_len // embed_len + 1)
        )
        rng.shuffle(group_map)

        iter_indices = class_iterables[class_id].pop()
        for image_id in iter_indices:
            interpolate_step = group_map[image_id]
            sample_count = subset_indices[class_id][0] + image_id
            input_file = os.path.normpath(class_ds.dataset.samples[sample_count][0])
            pos_prompt = prompts[0]
            neg_prompt = prompts[1]
            output_file = f"{save_path}/{class_id}/{seed}-{image_id}_{interpolate_step}.{image_type}"
            if save_prompt_embeds:
                input_prompts_embed = embed_pairs_list[interpolate_step]
            else:
                input_prompts_embed = None

            trace_dict["class_pairs"].append(class_pairs)
            trace_dict["class_id"].append(class_id)
            trace_dict["image_id"].append(image_id)
            trace_dict["interpolation_step"].append(interpolate_step)
            trace_dict["embed_len"].append(embed_len)
            trace_dict["pos_prompt_text"].append(pos_prompt)
            trace_dict["neg_prompt_text"].append(neg_prompt)
            trace_dict["input_file_path"].append(input_file)
            trace_dict["output_file_path"].append(output_file)
            trace_dict["input_prompts_embed"].append(input_prompts_embed)

    return trace_dict