import itertools
import os
import random
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

import numpy as np
import safetensors.torch
import torch
import torch.nn.functional as F
import torchvision.transforms
import torchvision.transforms.functional as TF
from PIL import Image
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import default_collate
from transformers import (CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizerFast)

from diffusion import (default_num_train_timesteps, euler_ode_solver_diffusion_loop, make_sigmas)
from sdxl_models import (SDXLAdapter, SDXLControlNet, SDXLControlNetFull, SDXLControlNetPreEncodedControlnetCond, SDXLUNet, SDXLVae)


class SDXLTraining:
    text_encoder_one: CLIPTextModel
    text_encoder_two: CLIPTextModelWithProjection
    vae: SDXLVae
    sigmas: torch.Tensor
    unet: SDXLUNet
    adapter: Optional[SDXLAdapter]
    controlnet: Optional[Union[SDXLControlNet, SDXLControlNetFull]]
    train_unet: bool
    train_unet_up_blocks: bool
    mixed_precision: Optional[torch.dtype]
    timestep_sampling: Literal["uniform", "cubic"]
    validation_images_logged: bool
    log_validation_input_images_every_time: bool
    get_sdxl_conditioning_images: Callable[[Image.Image], Dict[str, Any]]

    def __init__(
        self,
        device,
        train_unet,
        get_sdxl_conditioning_images,
        train_unet_up_blocks=False,
        unet_resume_from=None,
        controlnet_cls=None,
        controlnet_resume_from=None,
        adapter_cls=None,
        adapter_resume_from=None,
        mixed_precision=None,
        timestep_sampling="uniform",
        log_validation_input_images_every_time=True,
    ):
        self.text_encoder_one = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", variant="fp16", torch_dtype=torch.float16)
        self.text_encoder_one.to(device=device)
        self.text_encoder_one.requires_grad_(False)
        self.text_encoder_one.eval()

        self.text_encoder_two = CLIPTextModelWithProjection.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder_2", variant="fp16", torch_dtype=torch.float16)
        self.text_encoder_two.to(device=device)
        self.text_encoder_two.requires_grad_(False)
        self.text_encoder_two.eval()

        self.vae = SDXLVae.load_fp16_fix(device=device)
        self.vae.requires_grad_(False)
        self.vae.eval()

        self.sigmas = make_sigmas(device=device)

        if train_unet:
            if unet_resume_from is None:
                self.unet = SDXLUNet.load_fp32(device=device)
            else:
                self.unet = SDXLUNet.load(unet_resume_from, device=device)
            self.unet.requires_grad_(True)
            self.unet.train()
            self.unet = DDP(self.unet, device_ids=[device])
        elif train_unet_up_blocks:
            if unet_resume_from is None:
                self.unet = SDXLUNet.load_fp32(device=device)
            else:
                self.unet = SDXLUNet.load_fp32(device=device, overrides=[unet_resume_from])
            self.unet.requires_grad_(False)
            self.unet.eval()
            self.unet.up_blocks.requires_grad_(True)
            self.unet.up_blocks.train()
            self.unet = DDP(self.unet, device_ids=[device], find_unused_parameters=True)
        else:
            self.unet = SDXLUNet.load_fp16(device=device)
            self.unet.requires_grad_(False)
            self.unet.eval()

        if controlnet_cls is not None:
            if controlnet_resume_from is None:
                self.controlnet = controlnet_cls.from_unet(self.unet)
                self.controlnet.to(device)
            else:
                self.controlnet = controlnet_cls.load(controlnet_resume_from, device=device)
            self.controlnet.train()
            self.controlnet.requires_grad_(True)
            # TODO add back
            # controlnet.enable_gradient_checkpointing()
            # TODO - should be able to remove find_unused_parameters.
            # Comes from the pre-encoded controlnet.
            self.controlnet = DDP(self.controlnet, device_ids=[device], find_unused_parameters=True)
        else:
            self.controlnet = None

        if adapter_cls is not None:
            if adapter_resume_from is None:
                self.adapter = adapter_cls()
                self.adapter.to(device=device)
            else:
                self.adapter = adapter_cls.load(adapter_resume_from, device=device)
            self.adapter.train()
            self.adapter.requires_grad_(True)
            self.adapter = DDP(self.adapter, device_ids=[device])
        else:
            self.adapter = None

        self.mixed_precision = mixed_precision
        self.timestep_sampling = timestep_sampling
        self.validation_images_logged = False
        self.log_validation_input_images_every_time = log_validation_input_images_every_time
        self.get_sdxl_conditioning_images = get_sdxl_conditioning_images
        self.train_unet = train_unet
        self.train_unet_up_blocks = train_unet_up_blocks

    def train_step(self, batch):
        with torch.no_grad():
            if isinstance(self.unet, DDP):
                unet_dtype = self.unet.module.dtype
                unet_device = self.unet.module.device
            else:
                unet_dtype = self.unet.dtype
                unet_device = self.unet.device

            micro_conditioning = batch["micro_conditioning"].to(device=unet_device)

            image = batch["image"].to(self.vae.device, dtype=self.vae.dtype)
            latents = self.vae.encode(image).to(dtype=unet_dtype)

            text_input_ids_one = batch["text_input_ids_one"].to(self.text_encoder_one.device)
            text_input_ids_two = batch["text_input_ids_two"].to(self.text_encoder_two.device)

            encoder_hidden_states, pooled_encoder_hidden_states = sdxl_text_conditioning(self.text_encoder_one, self.text_encoder_two, text_input_ids_one, text_input_ids_two)
            encoder_hidden_states = encoder_hidden_states.to(dtype=unet_dtype)
            pooled_encoder_hidden_states = pooled_encoder_hidden_states.to(dtype=unet_dtype)

            bsz = latents.shape[0]

            if self.timestep_sampling == "uniform":
                timesteps = torch.randint(0, default_num_train_timesteps, (bsz,), device=unet_device)
            elif self.timestep_sampling == "cubic":
                # Cubic sampling to sample a random timestep for each image
                timesteps = torch.rand((bsz,), device=unet_device)
                timesteps = (1 - timesteps**3) * default_num_train_timesteps
                timesteps = timesteps.long()
                timesteps = timesteps.clamp(0, default_num_train_timesteps - 1)
            else:
                assert False

            # Reshape so the per-sample sigma broadcasts over the latent channel and spatial dimensions.
            sigmas_ = self.sigmas[timesteps].to(dtype=latents.dtype).reshape(-1, 1, 1, 1)

            noise = torch.randn_like(latents)

            noisy_latents = latents + noise * sigmas_

            scaled_noisy_latents = noisy_latents / ((sigmas_**2 + 1) ** 0.5)

            if "conditioning_image" in batch:
                conditioning_image = batch["conditioning_image"].to(unet_device)

                if self.controlnet is not None and isinstance(self.controlnet.module, SDXLControlNetPreEncodedControlnetCond):
                    controlnet_device = self.controlnet.module.device
                    controlnet_dtype = self.controlnet.module.dtype
                    # The pre-encoded controlnet takes the vae encoded conditioning image concatenated
                    # with the downsampled conditioning mask instead of raw pixel values.
                    conditioning_image = self.vae.encode(conditioning_image.to(self.vae.dtype)).to(device=controlnet_device, dtype=controlnet_dtype)
                    conditioning_image_mask = TF.resize(batch["conditioning_image_mask"], conditioning_image.shape[2:]).to(device=controlnet_device, dtype=controlnet_dtype)
                    conditioning_image = torch.concat((conditioning_image, conditioning_image_mask), dim=1)

        with torch.autocast(
            "cuda",
            self.mixed_precision,
            enabled=self.mixed_precision is not None,
        ):
            down_block_additional_residuals = None
            mid_block_additional_residual = None
            add_to_down_block_inputs = None
            add_to_output = None

            if self.adapter is not None:
                down_block_additional_residuals = self.adapter(conditioning_image)

            if self.controlnet is not None:
                controlnet_out = self.controlnet(
                    x_t=scaled_noisy_latents,
                    t=timesteps,
                    encoder_hidden_states=encoder_hidden_states,
                    micro_conditioning=micro_conditioning,
                    pooled_encoder_hidden_states=pooled_encoder_hidden_states,
                    controlnet_cond=conditioning_image,
                )

                down_block_additional_residuals = controlnet_out["down_block_res_samples"]
                mid_block_additional_residual = controlnet_out["mid_block_res_sample"]
                add_to_down_block_inputs = controlnet_out.get("add_to_down_block_inputs", None)
                add_to_output = controlnet_out.get("add_to_output", None)

            model_pred = self.unet(
                x_t=scaled_noisy_latents,
                t=timesteps,
                encoder_hidden_states=encoder_hidden_states,
                micro_conditioning=micro_conditioning,
                pooled_encoder_hidden_states=pooled_encoder_hidden_states,
                down_block_additional_residuals=down_block_additional_residuals,
                mid_block_additional_residual=mid_block_additional_residual,
                add_to_down_block_inputs=add_to_down_block_inputs,
                add_to_output=add_to_output,
            ).sample

            loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")

        return loss

    @torch.no_grad()
    def log_validation(self, step, num_validation_images: int, validation_prompts: Optional[List[str]] = None, validation_images: Optional[List[str]] = None):
        import wandb

        if isinstance(self.unet, DDP):
            unet = self.unet.module
            unet.eval()
            unet_set_to_eval = True
        else:
            unet = self.unet
            unet_set_to_eval = False

        if self.adapter is not None:
            adapter = self.adapter.module
            adapter.eval()
        else:
            adapter = None

        if self.controlnet is not None:
            controlnet = self.controlnet.module
            controlnet.eval()
        else:
            controlnet = None

        formatted_validation_images = None

        if validation_images is not None:
            formatted_validation_images = []
            wandb_validation_images = []

            for validation_image_path in validation_images:
                validation_image = Image.open(validation_image_path)
                validation_image = validation_image.convert("RGB")
                validation_image = validation_image.resize((1024, 1024))

                conditioning_images = self.get_sdxl_conditioning_images(validation_image)

                conditioning_image = conditioning_images["conditioning_image"]

                if controlnet is not None and isinstance(controlnet, SDXLControlNetPreEncodedControlnetCond):
                    conditioning_image = self.vae.encode(conditioning_image[None, :, :, :].to(self.vae.device, dtype=self.vae.dtype))
                    conditioning_image_mask = TF.resize(conditioning_images["conditioning_image_mask"], conditioning_image.shape[2:]).to(device=conditioning_image.device, dtype=conditioning_image.dtype)
                    conditioning_image = torch.concat((conditioning_image, conditioning_image_mask), dim=1)

                formatted_validation_images.append(conditioning_image)
                wandb_validation_images.append(wandb.Image(conditioning_images["conditioning_image_as_pil"]))

            if self.log_validation_input_images_every_time or not self.validation_images_logged:
                wandb.log({"validation_conditioning": wandb_validation_images}, step=step)
                self.validation_images_logged = True

        generator = torch.Generator().manual_seed(0)

        output_validation_images = []

        for formatted_validation_image, validation_prompt in zip(formatted_validation_images, validation_prompts):
            for _ in range(num_validation_images):
                with torch.autocast("cuda"):
                    x_0 = sdxl_diffusion_loop(
                        prompts=validation_prompt,
                        images=formatted_validation_image,
                        unet=unet,
                        text_encoder_one=self.text_encoder_one,
                        text_encoder_two=self.text_encoder_two,
                        controlnet=controlnet,
                        adapter=adapter,
                        sigmas=self.sigmas,
                        generator=generator,
                    )

                x_0 = self.vae.decode(x_0)
                x_0 = self.vae.output_tensor_to_pil(x_0)[0]

                output_validation_images.append(wandb.Image(x_0, caption=validation_prompt))

        wandb.log({"validation": output_validation_images}, step=step)

        if unet_set_to_eval:
            unet.train()

        if adapter is not None:
            adapter.train()
        if controlnet is not None:
            controlnet.train()

    def parameters(self):
        if self.train_unet:
            return self.unet.parameters()

        if self.controlnet is not None and self.train_unet_up_blocks:
            # The unet is DDP wrapped when only its up blocks are trained, so go through .module.
            return itertools.chain(self.controlnet.parameters(), self.unet.module.up_blocks.parameters())

        if self.controlnet is not None:
            return self.controlnet.parameters()

        if self.adapter is not None:
            return self.adapter.parameters()

        assert False

    def save(self, save_to):
        if self.train_unet:
            safetensors.torch.save_file(self.unet.module.state_dict(), os.path.join(save_to, "unet.safetensors"))

        if self.controlnet is not None and self.train_unet_up_blocks:
            safetensors.torch.save_file(self.controlnet.module.state_dict(), os.path.join(save_to, "controlnet.safetensors"))
            safetensors.torch.save_file(self.unet.module.up_blocks.state_dict(), os.path.join(save_to, "unet.safetensors"))

        if self.controlnet is not None:
            safetensors.torch.save_file(self.controlnet.module.state_dict(), os.path.join(save_to, "controlnet.safetensors"))

        if self.adapter is not None:
            safetensors.torch.save_file(self.adapter.module.state_dict(), os.path.join(save_to, "adapter.safetensors"))


def get_sdxl_dataset(train_shards: str, shuffle_buffer_size: int, batch_size: int, proportion_empty_prompts: float, get_sdxl_conditioning_images=None):
    import webdataset as wds

    dataset = (
        wds.WebDataset(
            train_shards,
            resampled=True,
            handler=wds.ignore_and_continue,
        )
        .shuffle(shuffle_buffer_size)
        .decode("pil", handler=wds.ignore_and_continue)
        .rename(
            image="jpg;png;jpeg;webp",
            text="text;txt;caption",
            metadata="json",
            handler=wds.warn_and_continue,
        )
        .map(lambda d: make_sample(d, proportion_empty_prompts=proportion_empty_prompts, get_sdxl_conditioning_images=get_sdxl_conditioning_images))
        .select(lambda sample: "conditioning_image" not in sample or sample["conditioning_image"] is not None)
    )

    dataset = dataset.batched(batch_size, partial=False, collation_fn=default_collate)

    return dataset


@torch.no_grad()
def make_sample(d, proportion_empty_prompts, get_sdxl_conditioning_images=None):
    image = d["image"]
    metadata = d["metadata"]

    if random.random() < proportion_empty_prompts:
        text = ""
    else:
        text = d["text"]

    c_top, c_left, _, _ = get_random_crop_params([image.height, image.width], [1024, 1024])

    original_width = int(metadata.get("original_width", 0.0))
    original_height = int(metadata.get("original_height", 0.0))

    micro_conditioning = torch.tensor([original_width, original_height, c_top, c_left, 1024, 1024])

    text_input_ids_one = sdxl_tokenize_one(text)[0]
    text_input_ids_two = sdxl_tokenize_two(text)[0]

    image = image.convert("RGB")

    image = TF.resize(
        image,
        1024,
        interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
    )

    image = TF.crop(
        image,
        c_top,
        c_left,
        1024,
        1024,
    )

    sample = {
        "micro_conditioning": micro_conditioning,
        "text_input_ids_one": text_input_ids_one,
        "text_input_ids_two": text_input_ids_two,
        "image": SDXLVae.input_pil_to_tensor(image),
    }

    if get_sdxl_conditioning_images is not None:
        conditioning_images = get_sdxl_conditioning_images(image)

        sample["conditioning_image"] = conditioning_images["conditioning_image"]

        # Not all conditioning image helpers return a mask, so use .get.
        if conditioning_images.get("conditioning_image_mask") is not None:
            sample["conditioning_image_mask"] = conditioning_images["conditioning_image_mask"]

    return sample
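# Illustrative sketch (not part of the module, not executed on import): one way the dataset and
# SDXLTraining might be wired together inside an initialized torch.distributed process group.
# The optimizer, learning rate, worker count, and webdataset WebLoader usage below are
# assumptions for the example, not values taken from this repo.
#
#   training = SDXLTraining(device=device, train_unet=False, controlnet_cls=SDXLControlNet, get_sdxl_conditioning_images=get_controlnet_canny_conditioning_image)
#   dataset = get_sdxl_dataset(train_shards, shuffle_buffer_size=1000, batch_size=8, proportion_empty_prompts=0.1, get_sdxl_conditioning_images=get_controlnet_canny_conditioning_image)
#   dataloader = wds.WebLoader(dataset, batch_size=None, num_workers=8)
#   optimizer = torch.optim.AdamW(training.parameters(), lr=1e-5)
#   for step, batch in enumerate(dataloader):
#       loss = training.train_step(batch)
#       loss.backward()
#       optimizer.step()
#       optimizer.zero_grad()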
def get_random_crop_params(input_size: Tuple[int, int], output_size: Tuple[int, int]) -> Tuple[int, int, int, int]:
    h, w = input_size
    th, tw = output_size

    if h < th or w < tw:
        raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")

    if w == tw and h == th:
        return 0, 0, h, w

    i = torch.randint(0, h - th + 1, size=(1,)).item()
    j = torch.randint(0, w - tw + 1, size=(1,)).item()

    return i, j, th, tw


def get_adapter_openpose_conditioning_image(image, open_pose):
    resolution = image.width

    conditioning_image = open_pose(image, detect_resolution=resolution, image_resolution=resolution, return_pil=False)

    if (conditioning_image == 0).all():
        # No pose was detected. Return a null conditioning image so the sample is filtered out
        # by the dataset's select step.
        return dict(conditioning_image=None, conditioning_image_as_pil=None)

    conditioning_image_as_pil = Image.fromarray(conditioning_image)

    conditioning_image = TF.to_tensor(conditioning_image)

    return dict(conditioning_image=conditioning_image, conditioning_image_as_pil=conditioning_image_as_pil)


def get_controlnet_canny_conditioning_image(image):
    import cv2

    conditioning_image = np.array(image)
    conditioning_image = cv2.Canny(conditioning_image, 100, 200)
    conditioning_image = conditioning_image[:, :, None]
    conditioning_image = np.concatenate([conditioning_image, conditioning_image, conditioning_image], axis=2)

    conditioning_image_as_pil = Image.fromarray(conditioning_image)

    conditioning_image = TF.to_tensor(conditioning_image)

    return dict(conditioning_image=conditioning_image, conditioning_image_as_pil=conditioning_image_as_pil)


def get_controlnet_pre_encoded_controlnet_inpainting_conditioning_image(image, conditioning_image_mask):
    resolution = image.width

    if conditioning_image_mask is None:
        if random.random() <= 0.25:
            conditioning_image_mask = np.ones((resolution, resolution), np.float32)
        else:
            conditioning_image_mask = random.choice([make_random_rectangle_mask, make_random_irregular_mask, make_outpainting_mask])(resolution, resolution)

        conditioning_image_mask = torch.from_numpy(conditioning_image_mask)

    conditioning_image_mask = conditioning_image_mask[None, :, :]

    conditioning_image = TF.to_tensor(image)

    # Where the mask is 1, zero out the pixels. Note that the mask must also be concatenated with
    # the conditioning image so that the network knows the zeroed out pixels come from the mask
    # and are not just zero in the original image.
    conditioning_image = conditioning_image * (conditioning_image_mask < 0.5)

    conditioning_image_as_pil = TF.to_pil_image(conditioning_image)

    conditioning_image = TF.normalize(conditioning_image, [0.5], [0.5])

    return dict(conditioning_image=conditioning_image, conditioning_image_mask=conditioning_image_mask, conditioning_image_as_pil=conditioning_image_as_pil)


def get_controlnet_inpainting_conditioning_image(image, conditioning_image_mask):
    resolution = image.width

    if conditioning_image_mask is None:
        if random.random() <= 0.25:
            conditioning_image_mask = np.ones((resolution, resolution), np.float32)
        else:
            conditioning_image_mask = random.choice([make_random_rectangle_mask, make_random_irregular_mask, make_outpainting_mask])(resolution, resolution)

        conditioning_image_mask = torch.from_numpy(conditioning_image_mask)

    conditioning_image_mask = conditioning_image_mask[None, :, :]

    conditioning_image = TF.to_tensor(image)

    # Just zero out the pixels which will be masked.
    conditioning_image_as_pil = TF.to_pil_image(conditioning_image * (conditioning_image_mask < 0.5))

    # Where the mask is set to 1, set the pixel to the "special" masked image value of -1.
    # -1 is outside of the 0-1 range that the controlnet normalized input is in.
    conditioning_image = conditioning_image * (conditioning_image_mask < 0.5) + -1.0 * (conditioning_image_mask >= 0.5)

    return dict(conditioning_image=conditioning_image, conditioning_image_mask=conditioning_image_mask, conditioning_image_as_pil=conditioning_image_as_pil)
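# Illustrative sketch (assumed filename, not executed on import): exercising the conditioning image
# helpers on their own. For a 1024x1024 RGB PIL image, the canny helper returns a [3, 1024, 1024]
# tensor in [0, 1] plus a PIL copy for logging, and the inpainting helper additionally returns a
# [1, 1024, 1024] mask.
#
#   image = Image.open("example.png").convert("RGB").resize((1024, 1024))
#   out = get_controlnet_canny_conditioning_image(image)
#   out["conditioning_image"].shape       # torch.Size([3, 1024, 1024])
#   out = get_controlnet_inpainting_conditioning_image(image, conditioning_image_mask=None)
#   out["conditioning_image_mask"].shape  # torch.Size([1, 1024, 1024])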
# TODO: it would be nice to just call a function from https://github.com/huggingface/tokenizers,
# i.e. afaik tokenizing shouldn't require holding any state.
tokenizer_one = CLIPTokenizerFast.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="tokenizer")

tokenizer_two = CLIPTokenizerFast.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="tokenizer_2")


def sdxl_tokenize_one(prompts):
    return tokenizer_one(
        prompts,
        padding="max_length",
        max_length=tokenizer_one.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids


def sdxl_tokenize_two(prompts):
    return tokenizer_two(
        prompts,
        padding="max_length",
        max_length=tokenizer_two.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids


def sdxl_text_conditioning(text_encoder_one, text_encoder_two, text_input_ids_one, text_input_ids_two):
    prompt_embeds_1 = text_encoder_one(
        text_input_ids_one,
        output_hidden_states=True,
    ).hidden_states[-2]

    prompt_embeds_1 = prompt_embeds_1.view(prompt_embeds_1.shape[0], prompt_embeds_1.shape[1], -1)

    prompt_embeds_2 = text_encoder_two(
        text_input_ids_two,
        output_hidden_states=True,
    )

    pooled_encoder_hidden_states = prompt_embeds_2[0]

    prompt_embeds_2 = prompt_embeds_2.hidden_states[-2]

    prompt_embeds_2 = prompt_embeds_2.view(prompt_embeds_2.shape[0], prompt_embeds_2.shape[1], -1)

    encoder_hidden_states = torch.cat((prompt_embeds_1, prompt_embeds_2), dim=-1)

    return encoder_hidden_states, pooled_encoder_hidden_states


def make_random_rectangle_mask(
    height,
    width,
    margin=10,
    bbox_min_size=100,
    bbox_max_size=512,
    min_times=1,
    max_times=2,
):
    mask = np.zeros((height, width), np.float32)

    bbox_max_size = min(bbox_max_size, height - margin * 2, width - margin * 2)

    times = np.random.randint(min_times, max_times + 1)

    for i in range(times):
        box_width = np.random.randint(bbox_min_size, bbox_max_size)
        box_height = np.random.randint(bbox_min_size, bbox_max_size)

        start_x = np.random.randint(margin, width - margin - box_width + 1)
        start_y = np.random.randint(margin, height - margin - box_height + 1)

        mask[start_y : start_y + box_height, start_x : start_x + box_width] = 1

    return mask


def make_random_irregular_mask(height, width, max_angle=4, max_len=60, max_width=256, min_times=1, max_times=2):
    import cv2

    mask = np.zeros((height, width), np.float32)

    times = np.random.randint(min_times, max_times + 1)

    for i in range(times):
        start_x = np.random.randint(width)
        start_y = np.random.randint(height)

        for j in range(1 + np.random.randint(5)):
            angle = 0.01 + np.random.randint(max_angle)

            if i % 2 == 0:
                angle = 2 * 3.1415926 - angle

            length = 10 + np.random.randint(max_len)
            brush_w = 5 + np.random.randint(max_width)

            end_x = np.clip((start_x + length * np.sin(angle)).astype(np.int32), 0, width)
            end_y = np.clip((start_y + length * np.cos(angle)).astype(np.int32), 0, height)

            choice = random.randint(0, 2)

            if choice == 0:
                cv2.line(mask, (start_x, start_y), (end_x, end_y), 1.0, brush_w)
            elif choice == 1:
                cv2.circle(mask, (start_x, start_y), radius=brush_w, color=1.0, thickness=-1)
            elif choice == 2:
                radius = brush_w // 2
                mask[
                    start_y - radius : start_y + radius,
                    start_x - radius : start_x + radius,
                ] = 1
            else:
                assert False

            start_x, start_y = end_x, end_y

    return mask
def make_outpainting_mask(height, width, probs=[0.5, 0.5, 0.5, 0.5]):
    mask = np.zeros((height, width), np.float32)

    at_least_one_mask_applied = False

    coords = [
        [(0, 0), (1, get_padding(height))],
        [(0, 0), (get_padding(width), 1)],
        [(0, 1 - get_padding(height)), (1, 1)],
        [(1 - get_padding(width), 0), (1, 1)],
    ]

    for pp, coord in zip(probs, coords):
        if np.random.random() < pp:
            at_least_one_mask_applied = True
            mask = apply_padding(mask=mask, coord=coord)

    if not at_least_one_mask_applied:
        idx = np.random.choice(range(len(coords)), p=np.array(probs) / sum(probs))
        mask = apply_padding(mask=mask, coord=coords[idx])

    return mask


def get_padding(size, min_padding_percent=0.04, max_padding_percent=0.5):
    n1 = int(min_padding_percent * size)
    n2 = int(max_padding_percent * size)
    return np.random.randint(n1, n2) / size


def apply_padding(mask, coord):
    height, width = mask.shape

    mask[
        int(coord[0][0] * height) : int(coord[1][0] * height),
        int(coord[0][1] * width) : int(coord[1][1] * width),
    ] = 1

    return mask


@torch.no_grad()
def sdxl_diffusion_loop(
    prompts: Union[str, List[str]],
    unet,
    text_encoder_one,
    text_encoder_two,
    images=None,
    controlnet=None,
    adapter=None,
    sigmas=None,
    timesteps=None,
    x_T=None,
    micro_conditioning=None,
    guidance_scale=5.0,
    generator=None,
    negative_prompts=None,
    diffusion_loop=euler_ode_solver_diffusion_loop,
):
    if isinstance(prompts, str):
        prompts = [prompts]

    batch_size = len(prompts)

    if negative_prompts is not None and guidance_scale > 1.0:
        prompts += negative_prompts

    encoder_hidden_states, pooled_encoder_hidden_states = sdxl_text_conditioning(
        text_encoder_one,
        text_encoder_two,
        sdxl_tokenize_one(prompts).to(text_encoder_one.device),
        sdxl_tokenize_two(prompts).to(text_encoder_two.device),
    )

    encoder_hidden_states = encoder_hidden_states.to(unet.dtype)
    pooled_encoder_hidden_states = pooled_encoder_hidden_states.to(unet.dtype)

    if guidance_scale > 1.0:
        if negative_prompts is None:
            negative_encoder_hidden_states = torch.zeros_like(encoder_hidden_states)
            negative_pooled_encoder_hidden_states = torch.zeros_like(pooled_encoder_hidden_states)
        else:
            encoder_hidden_states, negative_encoder_hidden_states = torch.chunk(encoder_hidden_states, 2)
            pooled_encoder_hidden_states, negative_pooled_encoder_hidden_states = torch.chunk(pooled_encoder_hidden_states, 2)
    else:
        negative_encoder_hidden_states = None
        negative_pooled_encoder_hidden_states = None

    if sigmas is None:
        sigmas = make_sigmas(device=unet.device)

    if timesteps is None:
        timesteps = torch.linspace(0, sigmas.numel() - 1, 50, dtype=torch.long, device=unet.device)

    if x_T is None:
        x_T = torch.randn((batch_size, 4, 1024 // 8, 1024 // 8), dtype=unet.dtype, device=unet.device, generator=generator)
        x_T = x_T * ((sigmas[timesteps[-1]] ** 2 + 1) ** 0.5)

    if micro_conditioning is None:
        micro_conditioning = torch.tensor([[1024, 1024, 0, 0, 1024, 1024]], dtype=torch.long, device=unet.device)
        micro_conditioning = micro_conditioning.expand(batch_size, -1)

    if adapter is not None:
        down_block_additional_residuals = adapter(images.to(dtype=adapter.dtype, device=adapter.device))
    else:
        down_block_additional_residuals = None

    if controlnet is not None:
        controlnet_cond = images.to(dtype=controlnet.dtype, device=controlnet.device)
    else:
        controlnet_cond = None

    eps_theta = lambda *args, **kwargs: sdxl_eps_theta(
        *args,
        **kwargs,
        unet=unet,
        encoder_hidden_states=encoder_hidden_states,
        pooled_encoder_hidden_states=pooled_encoder_hidden_states,
        negative_encoder_hidden_states=negative_encoder_hidden_states,
        negative_pooled_encoder_hidden_states=negative_pooled_encoder_hidden_states,
        micro_conditioning=micro_conditioning,
        guidance_scale=guidance_scale,
        controlnet=controlnet,
        controlnet_cond=controlnet_cond,
        down_block_additional_residuals=down_block_additional_residuals,
    )

    x_0 = diffusion_loop(eps_theta=eps_theta, timesteps=timesteps, sigmas=sigmas, x_T=x_T)

    return x_0
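# Illustrative sketch (not executed on import): a minimal text-to-image call using the loop above,
# assuming the unet, text encoders, and vae were loaded as in the __main__ block below. The prompt
# and output filename are made up for the example.
#
#   latents = sdxl_diffusion_loop(
#       prompts="a photo of a dog",
#       unet=unet,
#       text_encoder_one=text_encoder_one,
#       text_encoder_two=text_encoder_two,
#       sigmas=make_sigmas(device=unet.device).to(unet.dtype),
#       generator=torch.Generator(unet.device).manual_seed(0),
#   )
#   vae.output_tensor_to_pil(vae.decode(latents))[0].save("dog.png")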
@torch.no_grad()
def sdxl_eps_theta(
    x_t,
    t,
    sigma,
    unet,
    encoder_hidden_states,
    pooled_encoder_hidden_states,
    negative_encoder_hidden_states,
    negative_pooled_encoder_hidden_states,
    micro_conditioning,
    guidance_scale,
    controlnet=None,
    controlnet_cond=None,
    down_block_additional_residuals=None,
):
    # TODO - how does this not affect the ode we are solving
    scaled_x_t = x_t / ((sigma**2 + 1) ** 0.5)

    if guidance_scale > 1.0:
        scaled_x_t = torch.concat([scaled_x_t, scaled_x_t])

        encoder_hidden_states = torch.concat((encoder_hidden_states, negative_encoder_hidden_states))
        pooled_encoder_hidden_states = torch.concat((pooled_encoder_hidden_states, negative_pooled_encoder_hidden_states))

        micro_conditioning = torch.concat([micro_conditioning, micro_conditioning])

        if controlnet_cond is not None:
            controlnet_cond = torch.concat([controlnet_cond, controlnet_cond])

    if controlnet is not None:
        controlnet_out = controlnet(
            x_t=scaled_x_t.to(controlnet.dtype),
            t=t,
            encoder_hidden_states=encoder_hidden_states.to(controlnet.dtype),
            micro_conditioning=micro_conditioning.to(controlnet.dtype),
            pooled_encoder_hidden_states=pooled_encoder_hidden_states.to(controlnet.dtype),
            controlnet_cond=controlnet_cond,
        )

        down_block_additional_residuals = [x.to(unet.dtype) for x in controlnet_out["down_block_res_samples"]]
        mid_block_additional_residual = controlnet_out["mid_block_res_sample"].to(unet.dtype)

        add_to_down_block_inputs = controlnet_out.get("add_to_down_block_inputs", None)
        if add_to_down_block_inputs is not None:
            add_to_down_block_inputs = [x.to(unet.dtype) for x in add_to_down_block_inputs]

        add_to_output = controlnet_out.get("add_to_output", None)
        if add_to_output is not None:
            add_to_output = add_to_output.to(unet.dtype)
    else:
        mid_block_additional_residual = None
        add_to_down_block_inputs = None
        add_to_output = None

    eps_hat = unet(
        x_t=scaled_x_t,
        t=t,
        encoder_hidden_states=encoder_hidden_states,
        micro_conditioning=micro_conditioning,
        pooled_encoder_hidden_states=pooled_encoder_hidden_states,
        down_block_additional_residuals=down_block_additional_residuals,
        mid_block_additional_residual=mid_block_additional_residual,
        add_to_down_block_inputs=add_to_down_block_inputs,
        add_to_output=add_to_output,
    )

    if guidance_scale > 1.0:
        eps_hat, eps_hat_uncond = eps_hat.chunk(2)

        eps_hat = eps_hat_uncond + guidance_scale * (eps_hat - eps_hat_uncond)

    return eps_hat


known_negative_prompt = "text, watermark, low-quality, signature, moiré pattern, downsampling, aliasing, distorted, blurry, glossy, blur, jpeg artifacts, compression artifacts, poorly drawn, low-resolution, bad, distortion, twisted, excessive, exaggerated pose, exaggerated limbs, grainy, symmetrical, duplicate, error, pattern, beginner, pixelated, fake, hyper, glitch, overexposed, high-contrast, bad-contrast"
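# Illustrative sketch (assumed script filename and prompt): the command line entry point below can
# be invoked along the lines of
#
#   python sdxl.py --prompts "a photo of a dog" --num_images_per_prompt 2 --guidance_scale 5.0 --seed 0
#
# with optional --controlnet/--controlnet_checkpoint or --adapter_checkpoint flags, plus --images
# and --masks, for conditioned generation.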
args.add_argument("--adapter_checkpoint", required=False, type=str, default=None) args.add_argument("--device", required=False, default=None) args.add_argument("--dtype", required=False, default="fp16", choices=["fp16", "fp32"]) args.add_argument("--guidance_scale", required=False, default=5.0, type=float) args.add_argument("--seed", required=False, type=int) args = args.parse_args() if args.device is None: if torch.cuda.is_available(): device = "cuda" elif torch.backends.mps.is_available(): device = "mps" if args.dtype == "fp16": dtype = torch.float16 text_encoder_one = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", variant="fp16", torch_dtype=torch.float16) text_encoder_one.to(device=device) text_encoder_two = CLIPTextModelWithProjection.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder_2", variant="fp16", torch_dtype=torch.float16) text_encoder_two.to(device=device) vae = SDXLVae.load_fp16_fix(device=device) vae.to(torch.float16) unet = SDXLUNet.load_fp16(device=device) elif args.dtype == "fp32": dtype = torch.float32 text_encoder_one = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder") text_encoder_one.to(device=device) text_encoder_two = CLIPTextModelWithProjection.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder_2") text_encoder_two.to(device=device) vae = SDXLVae.load_fp16_fix(device=device) unet = SDXLUNet.load_fp32(device=device) else: assert False if args.controlnet == "SDXLControlNet": controlnet = SDXLControlNet.load(args.controlnet_checkpoint, device=device) controlnet.to(dtype) elif args.controlnet == "SDXLControlNetFull": controlnet = SDXLControlNetFull.load(args.controlnet_checkpoint, device=device) controlnet.to(dtype) elif args.controlnet == "SDXLControlNetPreEncodedControlnetCond": controlnet = SDXLControlNetPreEncodedControlnetCond.load(args.controlnet_checkpoint, device=device) controlnet.to(dtype) else: controlnet = None if args.adapter_checkpoint is not None: adapter = SDXLAdapter.load(args.adapter_checkpoint, device=device) adapter.to(dtype) else: adapter = None sigmas = make_sigmas(device=device).to(unet.dtype) timesteps = torch.linspace(0, sigmas.numel() - 1, args.num_inference_steps, dtype=torch.long, device=unet.device) prompts = [] for prompt in args.prompts: prompts += [prompt] * args.num_images_per_prompt if args.use_known_negative_prompt: args.negative_prompts = [known_negative_prompt] if args.negative_prompts is None: negative_prompts = None elif len(args.negative_prompts) == 1: negative_prompts = args.negative_prompts * len(prompts) elif len(args.negative_prompts) == len(args.prompts): negative_prompts = [] for negative_prompt in args.negative_prompts: negative_prompts += [negative_prompt] * args.num_images_per_prompt else: assert False if args.images is not None: images = [] for image_idx, image in enumerate(args.images): image = Image.open(image) image = image.convert("RGB") image = image.resize((1024, 1024)) image = TF.to_tensor(image) if args.masks is not None: mask = args.masks[image_idx] mask = Image.open(mask) mask = mask.convert("L") mask = mask.resize((1024, 1024)) mask = TF.to_tensor(mask) if isinstance(controlnet, SDXLControlNetPreEncodedControlnetCond): image = image * (mask < 0.5) image = TF.normalize(image, [0.5], [0.5]) image = vae.encode(image[None, :, :, :].to(dtype=vae.dtype, device=vae.device)).to(dtype=controlnet.dtype, device=controlnet.device) mask = TF.resize(mask, 
                    mask = TF.resize(mask, (1024 // 8, 1024 // 8))[None, :, :, :].to(dtype=image.dtype, device=image.device)
                    image = torch.concat((image, mask), dim=1)
                else:
                    image = (image * (mask < 0.5) + -1.0 * (mask >= 0.5)).to(dtype=dtype, device=device)
                    image = image[None, :, :, :]

            images += [image] * args.num_images_per_prompt

        images = torch.concat(images)
    else:
        images = None

    if args.seed is None:
        generator = None
    else:
        generator = torch.Generator(device).manual_seed(args.seed)

    images = sdxl_diffusion_loop(
        prompts=prompts,
        unet=unet,
        text_encoder_one=text_encoder_one,
        text_encoder_two=text_encoder_two,
        images=images,
        controlnet=controlnet,
        adapter=adapter,
        sigmas=sigmas,
        timesteps=timesteps,
        guidance_scale=args.guidance_scale,
        negative_prompts=negative_prompts,
        generator=generator,
    )

    images = vae.output_tensor_to_pil(vae.decode(images))

    for i, image in enumerate(images):
        image.save(f"out_{i}.png")