patrickvonplaten's picture
update config
9379e34
#!/usr/bin/env python3
from diffusers import UNet2DModel, DDIMScheduler, VQModel
import torch
import PIL.Image
import numpy as np
import tqdm
# Seed used for the random generator (and for the output file name).
seed = 3

# 1. Unroll the full loop
# ==================================================================
# Choose the compute device: CUDA when torch reports it, CPU otherwise.
if torch.cuda.is_available():
    torch_device = "cuda"
else:
    torch_device = "cpu"

# Load the UNet, the VQ model and the DDIM scheduler from their
# respective subfolders of the current directory.
unet = UNet2DModel.from_pretrained("./", subfolder="unet")
vqvae = VQModel.from_pretrained("./", subfolder="vqvae")
scheduler = DDIMScheduler.from_config("./", subfolder="scheduler")

# Move both networks to the chosen device.
unet.to(torch_device)
vqvae.to(torch_device)
# Sample the starting Gaussian noise with a seeded generator, then move
# it to the target device. The shape is taken from the UNet:
# one sample, `in_channels` channels, `image_size` x `image_size`.
generator = torch.manual_seed(seed)
noise_shape = (1, unet.in_channels, unet.image_size, unet.image_size)
noise = torch.randn(noise_shape, generator=generator).to(torch_device)
# Configure the DDIM scheduler to run 200 inference steps.
scheduler.set_timesteps(num_inference_steps=200)

# Iteratively denoise, starting from pure noise. `image` always holds the
# current sample x_t and is replaced by x_t-1 at the end of each step.
image = noise
for t in tqdm.tqdm(scheduler.timesteps):
    # Predict the noise residual for the current sample; gradients are
    # not needed for inference.
    with torch.no_grad():
        residual = unet(image, t)["sample"]

    # Compute the previous sample x_t-1 from x_t with the DDIM step
    # (eta=0.0).
    prev_image = scheduler.step(residual, t, image, eta=0.0)["prev_sample"]

    # x_t-1 becomes the input of the next iteration.
    image = prev_image
# Decode the final denoised sample with the VQ model; gradients are not
# needed for inference.
with torch.no_grad():
    image = vqvae.decode(image)

# Convert the decoded tensor into a PIL image:
#   * move to CPU and put axis 1 last (presumably channels, i.e.
#     NCHW -> NHWC -- matches the layout of the noise tensor above),
#   * map -1 -> 0 and +1 -> 255, clamping anything outside that range,
#   * cast to uint8 (this truncates the fractional part).
image_processed = image.cpu().permute(0, 2, 3, 1)
image_processed = (image_processed + 1.0) * 127.5
image_processed = image_processed.clamp(0, 255).numpy().astype(np.uint8)

# Only the first (and only) sample of the batch is turned into an image.
# NOTE(review): this image is not written to disk in this section.
image_pil = PIL.Image.fromarray(image_processed[0])
# 2. Use pipeline
# ==================================================================
from diffusers import LatentDiffusionUncondPipeline
import torch
import PIL.Image
import numpy as np
import tqdm
# Load the complete pipeline from the current directory.
pipeline = LatentDiffusionUncondPipeline.from_pretrained("./")

# Generate an image by calling the pipeline with a generator re-seeded
# from `seed` and 200 inference steps.
generator = torch.manual_seed(seed)
image = pipeline(generator=generator, num_inference_steps=200)["sample"]

# Post-process: move to CPU, put axis 1 last, map -1 -> 0 and +1 -> 255,
# clamp to [0, 255] and cast to uint8.
image_processed = (image.cpu().permute(0, 2, 3, 1) + 1.0) * 127.5
image_processed = image_processed.clamp(0, 255)
image_processed = image_processed.numpy().astype(np.uint8)

# Save the first sample of the batch as a PNG named after the seed.
image_pil = PIL.Image.fromarray(image_processed[0])
image_pil.save(f"generated_image_{seed}.png")