Spaces:
Sleeping
Sleeping
import copy | |
import math | |
import os | |
from glob import glob | |
from typing import Dict, List, Optional, Tuple, Union | |
import cv2 | |
import numpy as np | |
import streamlit as st | |
import torch | |
import torch.nn as nn | |
import torchvision.transforms as TT | |
from einops import rearrange, repeat | |
from imwatermark import WatermarkEncoder | |
from omegaconf import ListConfig, OmegaConf | |
from PIL import Image | |
from safetensors.torch import load_file as load_safetensors | |
from torch import autocast | |
from torchvision import transforms | |
from torchvision.utils import make_grid, save_image | |
from scripts.demo.discretization import (Img2ImgDiscretizationWrapper, | |
Txt2NoisyDiscretizationWrapper) | |
from scripts.util.detection.nsfw_and_watermark_dectection import \ | |
DeepFloydDataFiltering | |
from sgm.inference.helpers import embed_watermark | |
from sgm.modules.diffusionmodules.guiders import (LinearPredictionGuider, | |
VanillaCFG) | |
from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler, | |
DPMPP2SAncestralSampler, | |
EulerAncestralSampler, | |
EulerEDMSampler, | |
HeunEDMSampler, | |
LinearMultistepSampler) | |
from sgm.util import append_dims, default, instantiate_from_config | |
def init_st(version_dict, load_ckpt=True, load_filter=True): | |
state = dict() | |
if not "model" in state: | |
config = version_dict["config"] | |
ckpt = version_dict["ckpt"] | |
config = OmegaConf.load(config) | |
model, msg = load_model_from_config(config, ckpt if load_ckpt else None) | |
state["msg"] = msg | |
state["model"] = model | |
state["ckpt"] = ckpt if load_ckpt else None | |
state["config"] = config | |
if load_filter: | |
state["filter"] = DeepFloydDataFiltering(verbose=False) | |
return state | |
def load_model(model): | |
model.cuda() | |
lowvram_mode = False | |
def set_lowvram_mode(mode): | |
global lowvram_mode | |
lowvram_mode = mode | |
def initial_model_load(model): | |
global lowvram_mode | |
if lowvram_mode: | |
model.model.half() | |
else: | |
model.cuda() | |
return model | |
def unload_model(model): | |
global lowvram_mode | |
if lowvram_mode: | |
model.cpu() | |
torch.cuda.empty_cache() | |
def load_model_from_config(config, ckpt=None, verbose=True): | |
model = instantiate_from_config(config.model) | |
if ckpt is not None: | |
print(f"Loading model from {ckpt}") | |
if ckpt.endswith("ckpt"): | |
pl_sd = torch.load(ckpt, map_location="cpu") | |
if "global_step" in pl_sd: | |
global_step = pl_sd["global_step"] | |
st.info(f"loaded ckpt from global step {global_step}") | |
print(f"Global Step: {pl_sd['global_step']}") | |
sd = pl_sd["state_dict"] | |
elif ckpt.endswith("safetensors"): | |
sd = load_safetensors(ckpt) | |
else: | |
raise NotImplementedError | |
msg = None | |
m, u = model.load_state_dict(sd, strict=False) | |
if len(m) > 0 and verbose: | |
print("missing keys:") | |
print(m) | |
if len(u) > 0 and verbose: | |
print("unexpected keys:") | |
print(u) | |
else: | |
msg = None | |
model = initial_model_load(model) | |
model.eval() | |
return model, msg | |
def get_unique_embedder_keys_from_conditioner(conditioner): | |
return list(set([x.input_key for x in conditioner.embedders])) | |
def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None): | |
# Hardcoded demo settings; might undergo some changes in the future | |
value_dict = {} | |
for key in keys: | |
if key == "txt": | |
if prompt is None: | |
prompt = "A professional photograph of an astronaut riding a pig" | |
if negative_prompt is None: | |
negative_prompt = "" | |
prompt = st.text_input("Prompt", prompt) | |
negative_prompt = st.text_input("Negative prompt", negative_prompt) | |
value_dict["prompt"] = prompt | |
value_dict["negative_prompt"] = negative_prompt | |
if key == "original_size_as_tuple": | |
orig_width = st.number_input( | |
"orig_width", | |
value=init_dict["orig_width"], | |
min_value=16, | |
) | |
orig_height = st.number_input( | |
"orig_height", | |
value=init_dict["orig_height"], | |
min_value=16, | |
) | |
value_dict["orig_width"] = orig_width | |
value_dict["orig_height"] = orig_height | |
if key == "crop_coords_top_left": | |
crop_coord_top = st.number_input("crop_coords_top", value=0, min_value=0) | |
crop_coord_left = st.number_input("crop_coords_left", value=0, min_value=0) | |
value_dict["crop_coords_top"] = crop_coord_top | |
value_dict["crop_coords_left"] = crop_coord_left | |
if key == "aesthetic_score": | |
value_dict["aesthetic_score"] = 6.0 | |
value_dict["negative_aesthetic_score"] = 2.5 | |
if key == "target_size_as_tuple": | |
value_dict["target_width"] = init_dict["target_width"] | |
value_dict["target_height"] = init_dict["target_height"] | |
if key in ["fps_id", "fps"]: | |
fps = st.number_input("fps", value=6, min_value=1) | |
value_dict["fps"] = fps | |
value_dict["fps_id"] = fps - 1 | |
if key == "motion_bucket_id": | |
mb_id = st.number_input("motion bucket id", 0, 511, value=127) | |
value_dict["motion_bucket_id"] = mb_id | |
if key == "pool_image": | |
st.text("Image for pool conditioning") | |
image = load_img( | |
key="pool_image_input", | |
size=224, | |
center_crop=True, | |
) | |
if image is None: | |
st.info("Need an image here") | |
image = torch.zeros(1, 3, 224, 224) | |
value_dict["pool_image"] = image | |
return value_dict | |
def perform_save_locally(save_path, samples): | |
os.makedirs(os.path.join(save_path), exist_ok=True) | |
base_count = len(os.listdir(os.path.join(save_path))) | |
samples = embed_watermark(samples) | |
for sample in samples: | |
sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c") | |
Image.fromarray(sample.astype(np.uint8)).save( | |
os.path.join(save_path, f"{base_count:09}.png") | |
) | |
base_count += 1 | |
def init_save_locally(_dir, init_value: bool = False): | |
save_locally = st.sidebar.checkbox("Save images locally", value=init_value) | |
if save_locally: | |
save_path = st.text_input("Save path", value=os.path.join(_dir, "samples")) | |
else: | |
save_path = None | |
return save_locally, save_path | |
def get_guider(options, key): | |
guider = st.sidebar.selectbox( | |
f"Discretization #{key}", | |
[ | |
"VanillaCFG", | |
"IdentityGuider", | |
"LinearPredictionGuider", | |
], | |
options.get("guider", 0), | |
) | |
additional_guider_kwargs = options.pop("additional_guider_kwargs", {}) | |
if guider == "IdentityGuider": | |
guider_config = { | |
"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider" | |
} | |
elif guider == "VanillaCFG": | |
scale = st.number_input( | |
f"cfg-scale #{key}", | |
value=options.get("cfg", 5.0), | |
min_value=0.0, | |
) | |
guider_config = { | |
"target": "sgm.modules.diffusionmodules.guiders.VanillaCFG", | |
"params": { | |
"scale": scale, | |
**additional_guider_kwargs, | |
}, | |
} | |
elif guider == "LinearPredictionGuider": | |
max_scale = st.number_input( | |
f"max-cfg-scale #{key}", | |
value=options.get("cfg", 1.5), | |
min_value=1.0, | |
) | |
min_scale = st.number_input( | |
f"min guidance scale", | |
value=options.get("min_cfg", 1.0), | |
min_value=1.0, | |
max_value=10.0, | |
) | |
guider_config = { | |
"target": "sgm.modules.diffusionmodules.guiders.LinearPredictionGuider", | |
"params": { | |
"max_scale": max_scale, | |
"min_scale": min_scale, | |
"num_frames": options["num_frames"], | |
**additional_guider_kwargs, | |
}, | |
} | |
else: | |
raise NotImplementedError | |
return guider_config | |
def init_sampling( | |
key=1, | |
img2img_strength: Optional[float] = None, | |
specify_num_samples: bool = True, | |
stage2strength: Optional[float] = None, | |
options: Optional[Dict[str, int]] = None, | |
): | |
options = {} if options is None else options | |
num_rows, num_cols = 1, 1 | |
if specify_num_samples: | |
num_cols = st.number_input( | |
f"num cols #{key}", value=num_cols, min_value=1, max_value=10 | |
) | |
steps = st.sidebar.number_input( | |
f"steps #{key}", value=options.get("num_steps", 40), min_value=1, max_value=1000 | |
) | |
sampler = st.sidebar.selectbox( | |
f"Sampler #{key}", | |
[ | |
"EulerEDMSampler", | |
"HeunEDMSampler", | |
"EulerAncestralSampler", | |
"DPMPP2SAncestralSampler", | |
"DPMPP2MSampler", | |
"LinearMultistepSampler", | |
], | |
options.get("sampler", 0), | |
) | |
discretization = st.sidebar.selectbox( | |
f"Discretization #{key}", | |
[ | |
"LegacyDDPMDiscretization", | |
"EDMDiscretization", | |
], | |
options.get("discretization", 0), | |
) | |
discretization_config = get_discretization(discretization, options=options, key=key) | |
guider_config = get_guider(options=options, key=key) | |
sampler = get_sampler(sampler, steps, discretization_config, guider_config, key=key) | |
if img2img_strength is not None: | |
st.warning( | |
f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper" | |
) | |
sampler.discretization = Img2ImgDiscretizationWrapper( | |
sampler.discretization, strength=img2img_strength | |
) | |
if stage2strength is not None: | |
sampler.discretization = Txt2NoisyDiscretizationWrapper( | |
sampler.discretization, strength=stage2strength, original_steps=steps | |
) | |
return sampler, num_rows, num_cols | |
def get_discretization(discretization, options, key=1): | |
if discretization == "LegacyDDPMDiscretization": | |
discretization_config = { | |
"target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization", | |
} | |
elif discretization == "EDMDiscretization": | |
sigma_min = st.number_input( | |
f"sigma_min #{key}", value=options.get("sigma_min", 0.03) | |
) # 0.0292 | |
sigma_max = st.number_input( | |
f"sigma_max #{key}", value=options.get("sigma_max", 14.61) | |
) # 14.6146 | |
rho = st.number_input(f"rho #{key}", value=options.get("rho", 3.0)) | |
discretization_config = { | |
"target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization", | |
"params": { | |
"sigma_min": sigma_min, | |
"sigma_max": sigma_max, | |
"rho": rho, | |
}, | |
} | |
return discretization_config | |
def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1): | |
if sampler_name == "EulerEDMSampler" or sampler_name == "HeunEDMSampler": | |
s_churn = st.sidebar.number_input(f"s_churn #{key}", value=0.0, min_value=0.0) | |
s_tmin = st.sidebar.number_input(f"s_tmin #{key}", value=0.0, min_value=0.0) | |
s_tmax = st.sidebar.number_input(f"s_tmax #{key}", value=999.0, min_value=0.0) | |
s_noise = st.sidebar.number_input(f"s_noise #{key}", value=1.0, min_value=0.0) | |
if sampler_name == "EulerEDMSampler": | |
sampler = EulerEDMSampler( | |
num_steps=steps, | |
discretization_config=discretization_config, | |
guider_config=guider_config, | |
s_churn=s_churn, | |
s_tmin=s_tmin, | |
s_tmax=s_tmax, | |
s_noise=s_noise, | |
verbose=True, | |
) | |
elif sampler_name == "HeunEDMSampler": | |
sampler = HeunEDMSampler( | |
num_steps=steps, | |
discretization_config=discretization_config, | |
guider_config=guider_config, | |
s_churn=s_churn, | |
s_tmin=s_tmin, | |
s_tmax=s_tmax, | |
s_noise=s_noise, | |
verbose=True, | |
) | |
elif ( | |
sampler_name == "EulerAncestralSampler" | |
or sampler_name == "DPMPP2SAncestralSampler" | |
): | |
s_noise = st.sidebar.number_input("s_noise", value=1.0, min_value=0.0) | |
eta = st.sidebar.number_input("eta", value=1.0, min_value=0.0) | |
if sampler_name == "EulerAncestralSampler": | |
sampler = EulerAncestralSampler( | |
num_steps=steps, | |
discretization_config=discretization_config, | |
guider_config=guider_config, | |
eta=eta, | |
s_noise=s_noise, | |
verbose=True, | |
) | |
elif sampler_name == "DPMPP2SAncestralSampler": | |
sampler = DPMPP2SAncestralSampler( | |
num_steps=steps, | |
discretization_config=discretization_config, | |
guider_config=guider_config, | |
eta=eta, | |
s_noise=s_noise, | |
verbose=True, | |
) | |
elif sampler_name == "DPMPP2MSampler": | |
sampler = DPMPP2MSampler( | |
num_steps=steps, | |
discretization_config=discretization_config, | |
guider_config=guider_config, | |
verbose=True, | |
) | |
elif sampler_name == "LinearMultistepSampler": | |
order = st.sidebar.number_input("order", value=4, min_value=1) | |
sampler = LinearMultistepSampler( | |
num_steps=steps, | |
discretization_config=discretization_config, | |
guider_config=guider_config, | |
order=order, | |
verbose=True, | |
) | |
else: | |
raise ValueError(f"unknown sampler {sampler_name}!") | |
return sampler | |
def get_interactive_image() -> Image.Image: | |
image = st.file_uploader("Input", type=["jpg", "JPEG", "png"]) | |
if image is not None: | |
image = Image.open(image) | |
if not image.mode == "RGB": | |
image = image.convert("RGB") | |
return image | |
def load_img( | |
display: bool = True, | |
size: Union[None, int, Tuple[int, int]] = None, | |
center_crop: bool = False, | |
): | |
image = get_interactive_image() | |
if image is None: | |
return None | |
if display: | |
st.image(image) | |
w, h = image.size | |
print(f"loaded input image of size ({w}, {h})") | |
transform = [] | |
if size is not None: | |
transform.append(transforms.Resize(size)) | |
if center_crop: | |
transform.append(transforms.CenterCrop(size)) | |
transform.append(transforms.ToTensor()) | |
transform.append(transforms.Lambda(lambda x: 2.0 * x - 1.0)) | |
transform = transforms.Compose(transform) | |
img = transform(image)[None, ...] | |
st.text(f"input min/max/mean: {img.min():.3f}/{img.max():.3f}/{img.mean():.3f}") | |
return img | |
def get_init_img(batch_size=1, key=None): | |
init_image = load_img(key=key).cuda() | |
init_image = repeat(init_image, "1 ... -> b ...", b=batch_size) | |
return init_image | |
def do_sample( | |
model, | |
sampler, | |
value_dict, | |
num_samples, | |
H, | |
W, | |
C, | |
F, | |
force_uc_zero_embeddings: Optional[List] = None, | |
force_cond_zero_embeddings: Optional[List] = None, | |
batch2model_input: List = None, | |
return_latents=False, | |
filter=None, | |
T=None, | |
additional_batch_uc_fields=None, | |
decoding_t=None, | |
): | |
force_uc_zero_embeddings = default(force_uc_zero_embeddings, []) | |
batch2model_input = default(batch2model_input, []) | |
additional_batch_uc_fields = default(additional_batch_uc_fields, []) | |
st.text("Sampling") | |
outputs = st.empty() | |
precision_scope = autocast | |
with torch.no_grad(): | |
with precision_scope("cuda"): | |
with model.ema_scope(): | |
if T is not None: | |
num_samples = [num_samples, T] | |
else: | |
num_samples = [num_samples] | |
load_model(model.conditioner) | |
batch, batch_uc = get_batch( | |
get_unique_embedder_keys_from_conditioner(model.conditioner), | |
value_dict, | |
num_samples, | |
T=T, | |
additional_batch_uc_fields=additional_batch_uc_fields, | |
) | |
c, uc = model.conditioner.get_unconditional_conditioning( | |
batch, | |
batch_uc=batch_uc, | |
force_uc_zero_embeddings=force_uc_zero_embeddings, | |
force_cond_zero_embeddings=force_cond_zero_embeddings, | |
) | |
unload_model(model.conditioner) | |
for k in c: | |
if not k == "crossattn": | |
c[k], uc[k] = map( | |
lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc) | |
) | |
if k in ["crossattn", "concat"] and T is not None: | |
uc[k] = repeat(uc[k], "b ... -> b t ...", t=T) | |
uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=T) | |
c[k] = repeat(c[k], "b ... -> b t ...", t=T) | |
c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=T) | |
additional_model_inputs = {} | |
for k in batch2model_input: | |
if k == "image_only_indicator": | |
assert T is not None | |
if isinstance( | |
sampler.guider, (VanillaCFG, LinearPredictionGuider) | |
): | |
additional_model_inputs[k] = torch.zeros( | |
num_samples[0] * 2, num_samples[1] | |
).to("cuda") | |
else: | |
additional_model_inputs[k] = torch.zeros(num_samples).to( | |
"cuda" | |
) | |
else: | |
additional_model_inputs[k] = batch[k] | |
shape = (math.prod(num_samples), C, H // F, W // F) | |
randn = torch.randn(shape).to("cuda") | |
def denoiser(input, sigma, c): | |
return model.denoiser( | |
model.model, input, sigma, c, **additional_model_inputs | |
) | |
load_model(model.denoiser) | |
load_model(model.model) | |
samples_z = sampler(denoiser, randn, cond=c, uc=uc) | |
unload_model(model.model) | |
unload_model(model.denoiser) | |
load_model(model.first_stage_model) | |
model.en_and_decode_n_samples_a_time = ( | |
decoding_t # Decode n frames at a time | |
) | |
samples_x = model.decode_first_stage(samples_z) | |
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) | |
unload_model(model.first_stage_model) | |
if filter is not None: | |
samples = filter(samples) | |
if T is None: | |
grid = torch.stack([samples]) | |
grid = rearrange(grid, "n b c h w -> (n h) (b w) c") | |
outputs.image(grid.cpu().numpy()) | |
else: | |
as_vids = rearrange(samples, "(b t) c h w -> b t c h w", t=T) | |
for i, vid in enumerate(as_vids): | |
grid = rearrange(make_grid(vid, nrow=4), "c h w -> h w c") | |
st.image( | |
grid.cpu().numpy(), | |
f"Sample #{i} as image", | |
) | |
if return_latents: | |
return samples, samples_z | |
return samples | |
def get_batch( | |
keys, | |
value_dict: dict, | |
N: Union[List, ListConfig], | |
device: str = "cuda", | |
T: int = None, | |
additional_batch_uc_fields: List[str] = [], | |
): | |
# Hardcoded demo setups; might undergo some changes in the future | |
batch = {} | |
batch_uc = {} | |
for key in keys: | |
if key == "txt": | |
batch["txt"] = [value_dict["prompt"]] * math.prod(N) | |
batch_uc["txt"] = [value_dict["negative_prompt"]] * math.prod(N) | |
elif key == "original_size_as_tuple": | |
batch["original_size_as_tuple"] = ( | |
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]) | |
.to(device) | |
.repeat(math.prod(N), 1) | |
) | |
elif key == "crop_coords_top_left": | |
batch["crop_coords_top_left"] = ( | |
torch.tensor( | |
[value_dict["crop_coords_top"], value_dict["crop_coords_left"]] | |
) | |
.to(device) | |
.repeat(math.prod(N), 1) | |
) | |
elif key == "aesthetic_score": | |
batch["aesthetic_score"] = ( | |
torch.tensor([value_dict["aesthetic_score"]]) | |
.to(device) | |
.repeat(math.prod(N), 1) | |
) | |
batch_uc["aesthetic_score"] = ( | |
torch.tensor([value_dict["negative_aesthetic_score"]]) | |
.to(device) | |
.repeat(math.prod(N), 1) | |
) | |
elif key == "target_size_as_tuple": | |
batch["target_size_as_tuple"] = ( | |
torch.tensor([value_dict["target_height"], value_dict["target_width"]]) | |
.to(device) | |
.repeat(math.prod(N), 1) | |
) | |
elif key == "fps": | |
batch[key] = ( | |
torch.tensor([value_dict["fps"]]).to(device).repeat(math.prod(N)) | |
) | |
elif key == "fps_id": | |
batch[key] = ( | |
torch.tensor([value_dict["fps_id"]]).to(device).repeat(math.prod(N)) | |
) | |
elif key == "motion_bucket_id": | |
batch[key] = ( | |
torch.tensor([value_dict["motion_bucket_id"]]) | |
.to(device) | |
.repeat(math.prod(N)) | |
) | |
elif key == "pool_image": | |
batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to( | |
device, dtype=torch.half | |
) | |
elif key == "cond_aug": | |
batch[key] = repeat( | |
torch.tensor([value_dict["cond_aug"]]).to("cuda"), | |
"1 -> b", | |
b=math.prod(N), | |
) | |
elif key == "cond_frames": | |
batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0]) | |
elif key == "cond_frames_without_noise": | |
batch[key] = repeat( | |
value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0] | |
) | |
else: | |
batch[key] = value_dict[key] | |
if T is not None: | |
batch["num_video_frames"] = T | |
for key in batch.keys(): | |
if key not in batch_uc and isinstance(batch[key], torch.Tensor): | |
batch_uc[key] = torch.clone(batch[key]) | |
elif key in additional_batch_uc_fields and key not in batch_uc: | |
batch_uc[key] = copy.copy(batch[key]) | |
return batch, batch_uc | |
def do_img2img( | |
img, | |
model, | |
sampler, | |
value_dict, | |
num_samples, | |
force_uc_zero_embeddings: Optional[List] = None, | |
force_cond_zero_embeddings: Optional[List] = None, | |
additional_kwargs={}, | |
offset_noise_level: int = 0.0, | |
return_latents=False, | |
skip_encode=False, | |
filter=None, | |
add_noise=True, | |
): | |
st.text("Sampling") | |
outputs = st.empty() | |
precision_scope = autocast | |
with torch.no_grad(): | |
with precision_scope("cuda"): | |
with model.ema_scope(): | |
load_model(model.conditioner) | |
batch, batch_uc = get_batch( | |
get_unique_embedder_keys_from_conditioner(model.conditioner), | |
value_dict, | |
[num_samples], | |
) | |
c, uc = model.conditioner.get_unconditional_conditioning( | |
batch, | |
batch_uc=batch_uc, | |
force_uc_zero_embeddings=force_uc_zero_embeddings, | |
force_cond_zero_embeddings=force_cond_zero_embeddings, | |
) | |
unload_model(model.conditioner) | |
for k in c: | |
c[k], uc[k] = map(lambda y: y[k][:num_samples].to("cuda"), (c, uc)) | |
for k in additional_kwargs: | |
c[k] = uc[k] = additional_kwargs[k] | |
if skip_encode: | |
z = img | |
else: | |
load_model(model.first_stage_model) | |
z = model.encode_first_stage(img) | |
unload_model(model.first_stage_model) | |
noise = torch.randn_like(z) | |
sigmas = sampler.discretization(sampler.num_steps).cuda() | |
sigma = sigmas[0] | |
st.info(f"all sigmas: {sigmas}") | |
st.info(f"noising sigma: {sigma}") | |
if offset_noise_level > 0.0: | |
noise = noise + offset_noise_level * append_dims( | |
torch.randn(z.shape[0], device=z.device), z.ndim | |
) | |
if add_noise: | |
noised_z = z + noise * append_dims(sigma, z.ndim).cuda() | |
noised_z = noised_z / torch.sqrt( | |
1.0 + sigmas[0] ** 2.0 | |
) # Note: hardcoded to DDPM-like scaling. need to generalize later. | |
else: | |
noised_z = z / torch.sqrt(1.0 + sigmas[0] ** 2.0) | |
def denoiser(x, sigma, c): | |
return model.denoiser(model.model, x, sigma, c) | |
load_model(model.denoiser) | |
load_model(model.model) | |
samples_z = sampler(denoiser, noised_z, cond=c, uc=uc) | |
unload_model(model.model) | |
unload_model(model.denoiser) | |
load_model(model.first_stage_model) | |
samples_x = model.decode_first_stage(samples_z) | |
unload_model(model.first_stage_model) | |
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) | |
if filter is not None: | |
samples = filter(samples) | |
grid = rearrange(grid, "n b c h w -> (n h) (b w) c") | |
outputs.image(grid.cpu().numpy()) | |
if return_latents: | |
return samples, samples_z | |
return samples | |
def get_resizing_factor( | |
desired_shape: Tuple[int, int], current_shape: Tuple[int, int] | |
) -> float: | |
r_bound = desired_shape[1] / desired_shape[0] | |
aspect_r = current_shape[1] / current_shape[0] | |
if r_bound >= 1.0: | |
if aspect_r >= r_bound: | |
factor = min(desired_shape) / min(current_shape) | |
else: | |
if aspect_r < 1.0: | |
factor = max(desired_shape) / min(current_shape) | |
else: | |
factor = max(desired_shape) / max(current_shape) | |
else: | |
if aspect_r <= r_bound: | |
factor = min(desired_shape) / min(current_shape) | |
else: | |
if aspect_r > 1: | |
factor = max(desired_shape) / min(current_shape) | |
else: | |
factor = max(desired_shape) / max(current_shape) | |
return factor | |
def get_interactive_image(key=None) -> Image.Image: | |
image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key) | |
if image is not None: | |
image = Image.open(image) | |
if not image.mode == "RGB": | |
image = image.convert("RGB") | |
return image | |
def load_img_for_prediction( | |
W: int, H: int, display=True, key=None, device="cuda" | |
) -> torch.Tensor: | |
image = get_interactive_image(key=key) | |
if image is None: | |
return None | |
if display: | |
st.image(image) | |
w, h = image.size | |
image = np.array(image).transpose(2, 0, 1) | |
image = torch.from_numpy(image).to(dtype=torch.float32) / 255.0 | |
image = image.unsqueeze(0) | |
rfs = get_resizing_factor((H, W), (h, w)) | |
resize_size = [int(np.ceil(rfs * s)) for s in (h, w)] | |
top = (resize_size[0] - H) // 2 | |
left = (resize_size[1] - W) // 2 | |
image = torch.nn.functional.interpolate( | |
image, resize_size, mode="area", antialias=False | |
) | |
image = TT.functional.crop(image, top=top, left=left, height=H, width=W) | |
if display: | |
numpy_img = np.transpose(image[0].numpy(), (1, 2, 0)) | |
pil_image = Image.fromarray((numpy_img * 255).astype(np.uint8)) | |
st.image(pil_image) | |
return image.to(device) * 2.0 - 1.0 | |
def save_video_as_grid_and_mp4( | |
video_batch: torch.Tensor, save_path: str, T: int, fps: int = 5 | |
): | |
os.makedirs(save_path, exist_ok=True) | |
base_count = len(glob(os.path.join(save_path, "*.mp4"))) | |
video_batch = rearrange(video_batch, "(b t) c h w -> b t c h w", t=T) | |
video_batch = embed_watermark(video_batch) | |
for vid in video_batch: | |
save_image(vid, fp=os.path.join(save_path, f"{base_count:06d}.png"), nrow=4) | |
video_path = os.path.join(save_path, f"{base_count:06d}.mp4") | |
writer = cv2.VideoWriter( | |
video_path, | |
cv2.VideoWriter_fourcc(*"MP4V"), | |
fps, | |
(vid.shape[-1], vid.shape[-2]), | |
) | |
vid = ( | |
(rearrange(vid, "t c h w -> t h w c") * 255).cpu().numpy().astype(np.uint8) | |
) | |
for frame in vid: | |
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) | |
writer.write(frame) | |
writer.release() | |
video_path_h264 = video_path[:-4] + "_h264.mp4" | |
os.system(f"ffmpeg -i {video_path} -c:v libx264 {video_path_h264}") | |
with open(video_path_h264, "rb") as f: | |
video_bytes = f.read() | |
st.video(video_bytes) | |
base_count += 1 | |