import copy import math import os from glob import glob from typing import Dict, List, Optional, Tuple, Union import cv2 import numpy as np import streamlit as st import torch import torch.nn as nn import torchvision.transforms as TT from einops import rearrange, repeat from imwatermark import WatermarkEncoder from omegaconf import ListConfig, OmegaConf from PIL import Image from safetensors.torch import load_file as load_safetensors from torch import autocast from torchvision import transforms from torchvision.utils import make_grid, save_image from scripts.demo.discretization import (Img2ImgDiscretizationWrapper, Txt2NoisyDiscretizationWrapper) from scripts.util.detection.nsfw_and_watermark_dectection import \ DeepFloydDataFiltering from sgm.inference.helpers import embed_watermark from sgm.modules.diffusionmodules.guiders import (LinearPredictionGuider, VanillaCFG) from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler, DPMPP2SAncestralSampler, EulerAncestralSampler, EulerEDMSampler, HeunEDMSampler, LinearMultistepSampler) from sgm.util import append_dims, default, instantiate_from_config @st.cache_resource() def init_st(version_dict, load_ckpt=True, load_filter=True): state = dict() if not "model" in state: config = version_dict["config"] ckpt = version_dict["ckpt"] config = OmegaConf.load(config) model, msg = load_model_from_config(config, ckpt if load_ckpt else None) state["msg"] = msg state["model"] = model state["ckpt"] = ckpt if load_ckpt else None state["config"] = config if load_filter: state["filter"] = DeepFloydDataFiltering(verbose=False) return state def load_model(model): model.cuda() lowvram_mode = False def set_lowvram_mode(mode): global lowvram_mode lowvram_mode = mode def initial_model_load(model): global lowvram_mode if lowvram_mode: model.model.half() else: model.cuda() return model def unload_model(model): global lowvram_mode if lowvram_mode: model.cpu() torch.cuda.empty_cache() def load_model_from_config(config, ckpt=None, verbose=True): model = instantiate_from_config(config.model) if ckpt is not None: print(f"Loading model from {ckpt}") if ckpt.endswith("ckpt"): pl_sd = torch.load(ckpt, map_location="cpu") if "global_step" in pl_sd: global_step = pl_sd["global_step"] st.info(f"loaded ckpt from global step {global_step}") print(f"Global Step: {pl_sd['global_step']}") sd = pl_sd["state_dict"] elif ckpt.endswith("safetensors"): sd = load_safetensors(ckpt) else: raise NotImplementedError msg = None m, u = model.load_state_dict(sd, strict=False) if len(m) > 0 and verbose: print("missing keys:") print(m) if len(u) > 0 and verbose: print("unexpected keys:") print(u) else: msg = None model = initial_model_load(model) model.eval() return model, msg def get_unique_embedder_keys_from_conditioner(conditioner): return list(set([x.input_key for x in conditioner.embedders])) def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None): # Hardcoded demo settings; might undergo some changes in the future value_dict = {} for key in keys: if key == "txt": if prompt is None: prompt = "A professional photograph of an astronaut riding a pig" if negative_prompt is None: negative_prompt = "" prompt = st.text_input("Prompt", prompt) negative_prompt = st.text_input("Negative prompt", negative_prompt) value_dict["prompt"] = prompt value_dict["negative_prompt"] = negative_prompt if key == "original_size_as_tuple": orig_width = st.number_input( "orig_width", value=init_dict["orig_width"], min_value=16, ) orig_height = st.number_input( "orig_height", value=init_dict["orig_height"], min_value=16, ) value_dict["orig_width"] = orig_width value_dict["orig_height"] = orig_height if key == "crop_coords_top_left": crop_coord_top = st.number_input("crop_coords_top", value=0, min_value=0) crop_coord_left = st.number_input("crop_coords_left", value=0, min_value=0) value_dict["crop_coords_top"] = crop_coord_top value_dict["crop_coords_left"] = crop_coord_left if key == "aesthetic_score": value_dict["aesthetic_score"] = 6.0 value_dict["negative_aesthetic_score"] = 2.5 if key == "target_size_as_tuple": value_dict["target_width"] = init_dict["target_width"] value_dict["target_height"] = init_dict["target_height"] if key in ["fps_id", "fps"]: fps = st.number_input("fps", value=6, min_value=1) value_dict["fps"] = fps value_dict["fps_id"] = fps - 1 if key == "motion_bucket_id": mb_id = st.number_input("motion bucket id", 0, 511, value=127) value_dict["motion_bucket_id"] = mb_id if key == "pool_image": st.text("Image for pool conditioning") image = load_img( key="pool_image_input", size=224, center_crop=True, ) if image is None: st.info("Need an image here") image = torch.zeros(1, 3, 224, 224) value_dict["pool_image"] = image return value_dict def perform_save_locally(save_path, samples): os.makedirs(os.path.join(save_path), exist_ok=True) base_count = len(os.listdir(os.path.join(save_path))) samples = embed_watermark(samples) for sample in samples: sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c") Image.fromarray(sample.astype(np.uint8)).save( os.path.join(save_path, f"{base_count:09}.png") ) base_count += 1 def init_save_locally(_dir, init_value: bool = False): save_locally = st.sidebar.checkbox("Save images locally", value=init_value) if save_locally: save_path = st.text_input("Save path", value=os.path.join(_dir, "samples")) else: save_path = None return save_locally, save_path def get_guider(options, key): guider = st.sidebar.selectbox( f"Discretization #{key}", [ "VanillaCFG", "IdentityGuider", "LinearPredictionGuider", ], options.get("guider", 0), ) additional_guider_kwargs = options.pop("additional_guider_kwargs", {}) if guider == "IdentityGuider": guider_config = { "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider" } elif guider == "VanillaCFG": scale = st.number_input( f"cfg-scale #{key}", value=options.get("cfg", 5.0), min_value=0.0, ) guider_config = { "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG", "params": { "scale": scale, **additional_guider_kwargs, }, } elif guider == "LinearPredictionGuider": max_scale = st.number_input( f"max-cfg-scale #{key}", value=options.get("cfg", 1.5), min_value=1.0, ) min_scale = st.number_input( f"min guidance scale", value=options.get("min_cfg", 1.0), min_value=1.0, max_value=10.0, ) guider_config = { "target": "sgm.modules.diffusionmodules.guiders.LinearPredictionGuider", "params": { "max_scale": max_scale, "min_scale": min_scale, "num_frames": options["num_frames"], **additional_guider_kwargs, }, } else: raise NotImplementedError return guider_config def init_sampling( key=1, img2img_strength: Optional[float] = None, specify_num_samples: bool = True, stage2strength: Optional[float] = None, options: Optional[Dict[str, int]] = None, ): options = {} if options is None else options num_rows, num_cols = 1, 1 if specify_num_samples: num_cols = st.number_input( f"num cols #{key}", value=num_cols, min_value=1, max_value=10 ) steps = st.sidebar.number_input( f"steps #{key}", value=options.get("num_steps", 40), min_value=1, max_value=1000 ) sampler = st.sidebar.selectbox( f"Sampler #{key}", [ "EulerEDMSampler", "HeunEDMSampler", "EulerAncestralSampler", "DPMPP2SAncestralSampler", "DPMPP2MSampler", "LinearMultistepSampler", ], options.get("sampler", 0), ) discretization = st.sidebar.selectbox( f"Discretization #{key}", [ "LegacyDDPMDiscretization", "EDMDiscretization", ], options.get("discretization", 0), ) discretization_config = get_discretization(discretization, options=options, key=key) guider_config = get_guider(options=options, key=key) sampler = get_sampler(sampler, steps, discretization_config, guider_config, key=key) if img2img_strength is not None: st.warning( f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper" ) sampler.discretization = Img2ImgDiscretizationWrapper( sampler.discretization, strength=img2img_strength ) if stage2strength is not None: sampler.discretization = Txt2NoisyDiscretizationWrapper( sampler.discretization, strength=stage2strength, original_steps=steps ) return sampler, num_rows, num_cols def get_discretization(discretization, options, key=1): if discretization == "LegacyDDPMDiscretization": discretization_config = { "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization", } elif discretization == "EDMDiscretization": sigma_min = st.number_input( f"sigma_min #{key}", value=options.get("sigma_min", 0.03) ) # 0.0292 sigma_max = st.number_input( f"sigma_max #{key}", value=options.get("sigma_max", 14.61) ) # 14.6146 rho = st.number_input(f"rho #{key}", value=options.get("rho", 3.0)) discretization_config = { "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization", "params": { "sigma_min": sigma_min, "sigma_max": sigma_max, "rho": rho, }, } return discretization_config def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1): if sampler_name == "EulerEDMSampler" or sampler_name == "HeunEDMSampler": s_churn = st.sidebar.number_input(f"s_churn #{key}", value=0.0, min_value=0.0) s_tmin = st.sidebar.number_input(f"s_tmin #{key}", value=0.0, min_value=0.0) s_tmax = st.sidebar.number_input(f"s_tmax #{key}", value=999.0, min_value=0.0) s_noise = st.sidebar.number_input(f"s_noise #{key}", value=1.0, min_value=0.0) if sampler_name == "EulerEDMSampler": sampler = EulerEDMSampler( num_steps=steps, discretization_config=discretization_config, guider_config=guider_config, s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise, verbose=True, ) elif sampler_name == "HeunEDMSampler": sampler = HeunEDMSampler( num_steps=steps, discretization_config=discretization_config, guider_config=guider_config, s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise, verbose=True, ) elif ( sampler_name == "EulerAncestralSampler" or sampler_name == "DPMPP2SAncestralSampler" ): s_noise = st.sidebar.number_input("s_noise", value=1.0, min_value=0.0) eta = st.sidebar.number_input("eta", value=1.0, min_value=0.0) if sampler_name == "EulerAncestralSampler": sampler = EulerAncestralSampler( num_steps=steps, discretization_config=discretization_config, guider_config=guider_config, eta=eta, s_noise=s_noise, verbose=True, ) elif sampler_name == "DPMPP2SAncestralSampler": sampler = DPMPP2SAncestralSampler( num_steps=steps, discretization_config=discretization_config, guider_config=guider_config, eta=eta, s_noise=s_noise, verbose=True, ) elif sampler_name == "DPMPP2MSampler": sampler = DPMPP2MSampler( num_steps=steps, discretization_config=discretization_config, guider_config=guider_config, verbose=True, ) elif sampler_name == "LinearMultistepSampler": order = st.sidebar.number_input("order", value=4, min_value=1) sampler = LinearMultistepSampler( num_steps=steps, discretization_config=discretization_config, guider_config=guider_config, order=order, verbose=True, ) else: raise ValueError(f"unknown sampler {sampler_name}!") return sampler def get_interactive_image() -> Image.Image: image = st.file_uploader("Input", type=["jpg", "JPEG", "png"]) if image is not None: image = Image.open(image) if not image.mode == "RGB": image = image.convert("RGB") return image def load_img( display: bool = True, size: Union[None, int, Tuple[int, int]] = None, center_crop: bool = False, ): image = get_interactive_image() if image is None: return None if display: st.image(image) w, h = image.size print(f"loaded input image of size ({w}, {h})") transform = [] if size is not None: transform.append(transforms.Resize(size)) if center_crop: transform.append(transforms.CenterCrop(size)) transform.append(transforms.ToTensor()) transform.append(transforms.Lambda(lambda x: 2.0 * x - 1.0)) transform = transforms.Compose(transform) img = transform(image)[None, ...] st.text(f"input min/max/mean: {img.min():.3f}/{img.max():.3f}/{img.mean():.3f}") return img def get_init_img(batch_size=1, key=None): init_image = load_img(key=key).cuda() init_image = repeat(init_image, "1 ... -> b ...", b=batch_size) return init_image def do_sample( model, sampler, value_dict, num_samples, H, W, C, F, force_uc_zero_embeddings: Optional[List] = None, force_cond_zero_embeddings: Optional[List] = None, batch2model_input: List = None, return_latents=False, filter=None, T=None, additional_batch_uc_fields=None, decoding_t=None, ): force_uc_zero_embeddings = default(force_uc_zero_embeddings, []) batch2model_input = default(batch2model_input, []) additional_batch_uc_fields = default(additional_batch_uc_fields, []) st.text("Sampling") outputs = st.empty() precision_scope = autocast with torch.no_grad(): with precision_scope("cuda"): with model.ema_scope(): if T is not None: num_samples = [num_samples, T] else: num_samples = [num_samples] load_model(model.conditioner) batch, batch_uc = get_batch( get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples, T=T, additional_batch_uc_fields=additional_batch_uc_fields, ) c, uc = model.conditioner.get_unconditional_conditioning( batch, batch_uc=batch_uc, force_uc_zero_embeddings=force_uc_zero_embeddings, force_cond_zero_embeddings=force_cond_zero_embeddings, ) unload_model(model.conditioner) for k in c: if not k == "crossattn": c[k], uc[k] = map( lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc) ) if k in ["crossattn", "concat"] and T is not None: uc[k] = repeat(uc[k], "b ... -> b t ...", t=T) uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=T) c[k] = repeat(c[k], "b ... -> b t ...", t=T) c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=T) additional_model_inputs = {} for k in batch2model_input: if k == "image_only_indicator": assert T is not None if isinstance( sampler.guider, (VanillaCFG, LinearPredictionGuider) ): additional_model_inputs[k] = torch.zeros( num_samples[0] * 2, num_samples[1] ).to("cuda") else: additional_model_inputs[k] = torch.zeros(num_samples).to( "cuda" ) else: additional_model_inputs[k] = batch[k] shape = (math.prod(num_samples), C, H // F, W // F) randn = torch.randn(shape).to("cuda") def denoiser(input, sigma, c): return model.denoiser( model.model, input, sigma, c, **additional_model_inputs ) load_model(model.denoiser) load_model(model.model) samples_z = sampler(denoiser, randn, cond=c, uc=uc) unload_model(model.model) unload_model(model.denoiser) load_model(model.first_stage_model) model.en_and_decode_n_samples_a_time = ( decoding_t # Decode n frames at a time ) samples_x = model.decode_first_stage(samples_z) samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) unload_model(model.first_stage_model) if filter is not None: samples = filter(samples) if T is None: grid = torch.stack([samples]) grid = rearrange(grid, "n b c h w -> (n h) (b w) c") outputs.image(grid.cpu().numpy()) else: as_vids = rearrange(samples, "(b t) c h w -> b t c h w", t=T) for i, vid in enumerate(as_vids): grid = rearrange(make_grid(vid, nrow=4), "c h w -> h w c") st.image( grid.cpu().numpy(), f"Sample #{i} as image", ) if return_latents: return samples, samples_z return samples def get_batch( keys, value_dict: dict, N: Union[List, ListConfig], device: str = "cuda", T: int = None, additional_batch_uc_fields: List[str] = [], ): # Hardcoded demo setups; might undergo some changes in the future batch = {} batch_uc = {} for key in keys: if key == "txt": batch["txt"] = [value_dict["prompt"]] * math.prod(N) batch_uc["txt"] = [value_dict["negative_prompt"]] * math.prod(N) elif key == "original_size_as_tuple": batch["original_size_as_tuple"] = ( torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]) .to(device) .repeat(math.prod(N), 1) ) elif key == "crop_coords_top_left": batch["crop_coords_top_left"] = ( torch.tensor( [value_dict["crop_coords_top"], value_dict["crop_coords_left"]] ) .to(device) .repeat(math.prod(N), 1) ) elif key == "aesthetic_score": batch["aesthetic_score"] = ( torch.tensor([value_dict["aesthetic_score"]]) .to(device) .repeat(math.prod(N), 1) ) batch_uc["aesthetic_score"] = ( torch.tensor([value_dict["negative_aesthetic_score"]]) .to(device) .repeat(math.prod(N), 1) ) elif key == "target_size_as_tuple": batch["target_size_as_tuple"] = ( torch.tensor([value_dict["target_height"], value_dict["target_width"]]) .to(device) .repeat(math.prod(N), 1) ) elif key == "fps": batch[key] = ( torch.tensor([value_dict["fps"]]).to(device).repeat(math.prod(N)) ) elif key == "fps_id": batch[key] = ( torch.tensor([value_dict["fps_id"]]).to(device).repeat(math.prod(N)) ) elif key == "motion_bucket_id": batch[key] = ( torch.tensor([value_dict["motion_bucket_id"]]) .to(device) .repeat(math.prod(N)) ) elif key == "pool_image": batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to( device, dtype=torch.half ) elif key == "cond_aug": batch[key] = repeat( torch.tensor([value_dict["cond_aug"]]).to("cuda"), "1 -> b", b=math.prod(N), ) elif key == "cond_frames": batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0]) elif key == "cond_frames_without_noise": batch[key] = repeat( value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0] ) else: batch[key] = value_dict[key] if T is not None: batch["num_video_frames"] = T for key in batch.keys(): if key not in batch_uc and isinstance(batch[key], torch.Tensor): batch_uc[key] = torch.clone(batch[key]) elif key in additional_batch_uc_fields and key not in batch_uc: batch_uc[key] = copy.copy(batch[key]) return batch, batch_uc @torch.no_grad() def do_img2img( img, model, sampler, value_dict, num_samples, force_uc_zero_embeddings: Optional[List] = None, force_cond_zero_embeddings: Optional[List] = None, additional_kwargs={}, offset_noise_level: int = 0.0, return_latents=False, skip_encode=False, filter=None, add_noise=True, ): st.text("Sampling") outputs = st.empty() precision_scope = autocast with torch.no_grad(): with precision_scope("cuda"): with model.ema_scope(): load_model(model.conditioner) batch, batch_uc = get_batch( get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, [num_samples], ) c, uc = model.conditioner.get_unconditional_conditioning( batch, batch_uc=batch_uc, force_uc_zero_embeddings=force_uc_zero_embeddings, force_cond_zero_embeddings=force_cond_zero_embeddings, ) unload_model(model.conditioner) for k in c: c[k], uc[k] = map(lambda y: y[k][:num_samples].to("cuda"), (c, uc)) for k in additional_kwargs: c[k] = uc[k] = additional_kwargs[k] if skip_encode: z = img else: load_model(model.first_stage_model) z = model.encode_first_stage(img) unload_model(model.first_stage_model) noise = torch.randn_like(z) sigmas = sampler.discretization(sampler.num_steps).cuda() sigma = sigmas[0] st.info(f"all sigmas: {sigmas}") st.info(f"noising sigma: {sigma}") if offset_noise_level > 0.0: noise = noise + offset_noise_level * append_dims( torch.randn(z.shape[0], device=z.device), z.ndim ) if add_noise: noised_z = z + noise * append_dims(sigma, z.ndim).cuda() noised_z = noised_z / torch.sqrt( 1.0 + sigmas[0] ** 2.0 ) # Note: hardcoded to DDPM-like scaling. need to generalize later. else: noised_z = z / torch.sqrt(1.0 + sigmas[0] ** 2.0) def denoiser(x, sigma, c): return model.denoiser(model.model, x, sigma, c) load_model(model.denoiser) load_model(model.model) samples_z = sampler(denoiser, noised_z, cond=c, uc=uc) unload_model(model.model) unload_model(model.denoiser) load_model(model.first_stage_model) samples_x = model.decode_first_stage(samples_z) unload_model(model.first_stage_model) samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) if filter is not None: samples = filter(samples) grid = rearrange(grid, "n b c h w -> (n h) (b w) c") outputs.image(grid.cpu().numpy()) if return_latents: return samples, samples_z return samples def get_resizing_factor( desired_shape: Tuple[int, int], current_shape: Tuple[int, int] ) -> float: r_bound = desired_shape[1] / desired_shape[0] aspect_r = current_shape[1] / current_shape[0] if r_bound >= 1.0: if aspect_r >= r_bound: factor = min(desired_shape) / min(current_shape) else: if aspect_r < 1.0: factor = max(desired_shape) / min(current_shape) else: factor = max(desired_shape) / max(current_shape) else: if aspect_r <= r_bound: factor = min(desired_shape) / min(current_shape) else: if aspect_r > 1: factor = max(desired_shape) / min(current_shape) else: factor = max(desired_shape) / max(current_shape) return factor def get_interactive_image(key=None) -> Image.Image: image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key) if image is not None: image = Image.open(image) if not image.mode == "RGB": image = image.convert("RGB") return image def load_img_for_prediction( W: int, H: int, display=True, key=None, device="cuda" ) -> torch.Tensor: image = get_interactive_image(key=key) if image is None: return None if display: st.image(image) w, h = image.size image = np.array(image).transpose(2, 0, 1) image = torch.from_numpy(image).to(dtype=torch.float32) / 255.0 image = image.unsqueeze(0) rfs = get_resizing_factor((H, W), (h, w)) resize_size = [int(np.ceil(rfs * s)) for s in (h, w)] top = (resize_size[0] - H) // 2 left = (resize_size[1] - W) // 2 image = torch.nn.functional.interpolate( image, resize_size, mode="area", antialias=False ) image = TT.functional.crop(image, top=top, left=left, height=H, width=W) if display: numpy_img = np.transpose(image[0].numpy(), (1, 2, 0)) pil_image = Image.fromarray((numpy_img * 255).astype(np.uint8)) st.image(pil_image) return image.to(device) * 2.0 - 1.0 def save_video_as_grid_and_mp4( video_batch: torch.Tensor, save_path: str, T: int, fps: int = 5 ): os.makedirs(save_path, exist_ok=True) base_count = len(glob(os.path.join(save_path, "*.mp4"))) video_batch = rearrange(video_batch, "(b t) c h w -> b t c h w", t=T) video_batch = embed_watermark(video_batch) for vid in video_batch: save_image(vid, fp=os.path.join(save_path, f"{base_count:06d}.png"), nrow=4) video_path = os.path.join(save_path, f"{base_count:06d}.mp4") writer = cv2.VideoWriter( video_path, cv2.VideoWriter_fourcc(*"MP4V"), fps, (vid.shape[-1], vid.shape[-2]), ) vid = ( (rearrange(vid, "t c h w -> t h w c") * 255).cpu().numpy().astype(np.uint8) ) for frame in vid: frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) writer.write(frame) writer.release() video_path_h264 = video_path[:-4] + "_h264.mp4" os.system(f"ffmpeg -i {video_path} -c:v libx264 {video_path_h264}") with open(video_path_h264, "rb") as f: video_bytes = f.read() st.video(video_bytes) base_count += 1