# Page header captured along with this file (not part of the program):
#   author: josuelmet
#   last commit: "Update app.py" (c483e0a)
import gradio as gr
from gradio.components import *
import torch
from torch import autocast
from torchvision import transforms as T
from boomerang import *
# Environment switch; presumably True when running on Google Colab
# (inferred from the name only -- confirm).
COLAB: bool = False

# When True, `main` casts the image tensor and the noised latents to
# half precision. It simply mirrors the COLAB flag.
HALF_PRECISION: bool = COLAB
def main(image, prompt, percent_noise):  # percent_noise = 0.5, 0.02, 0.999
    """Partially noise an image in latent space and denoise it with ``pipe``.

    The image is normalized, encoded by ``pipe.vae``, noised up to a timestep
    proportional to ``percent_noise`` via ``pipe.scheduler.add_noise``, and
    the noisy latents are passed to ``pipe`` together with ``prompt``.

    ``pipe`` is not defined in this block; it is presumably provided by the
    ``boomerang`` star import -- confirm.

    Args:
        image: PIL image to process.
        prompt: Text prompt forwarded to the pipeline.
        percent_noise: Fraction of the scheduler's training timesteps to
            noise up to (the UI slider supplies 0.02 to 0.999).

    Returns:
        The first image in the pipeline's output.
    """
    # Convert image to float and preprocess it.
    # From huggingface/diffusers/blob/main/examples/unconditional_image_generation/train_unconditional.py
    preprocess = T.Compose([
        T.PILToTensor(),
        T.ConvertImageDtype(torch.float),
        T.Normalize([0.5], [0.5]),
    ])
    pixels = preprocess(image)
    if HALF_PRECISION:
        pixels = pixels.half()
    pixels = pixels.to(pipe.device).unsqueeze(0)  # add the batch dimension

    # Map the noise fraction to a scheduler timestep index. The index is
    # clamped to [0, num_train_timesteps - 1] so that a fraction of 1.0 (or a
    # negative one) cannot index past the scheduler's noise schedule.
    num_train_timesteps = pipe.scheduler.config.num_train_timesteps
    step = int(num_train_timesteps * percent_noise)
    step = max(0, min(step, num_train_timesteps - 1))
    timestep = torch.tensor([step], dtype=torch.long, device=pipe.device)

    # Pure inference: disable autograd so encoding does not build (and keep
    # alive) a computation graph that is never used.
    with torch.no_grad():
        # Project image into the latent space.
        # From huggingface/diffusers/blob/main/src/diffusers/models/vae.py
        clean_z = pipe.vae.encode(pixels).latent_dist.mode()
        # Latent scaling constant; presumably the Stable Diffusion VAE
        # scaling factor -- confirm against the loaded model's config.
        clean_z = 0.18215 * clean_z

        # Add noise to the latent variable
        # (this is the forward diffusion process).
        noise = torch.randn(clean_z.shape).to(pipe.device)
        z = pipe.scheduler.add_noise(clean_z, noise, timestep)

    if HALF_PRECISION:
        z = z.half()

    # Run the diffusion model.
    # The 'num_inference_steps=100' argument means that, if you were running
    # the full diffusion model (i.e., percent_noise = 0.999), it would be
    # sampling at 2x frequency compared to the normal stable diffusion model
    # (which uses 50 steps). This way, running percent_noise = 0.5 yields 50
    # inference steps, and running percent_noise = 0.2 yields 20 inference
    # steps, etc.
    result = pipe(
        prompt=prompt,
        latents=z,
        num_inference_steps=100,
        percent_noise=percent_noise,
    )
    return result.images[0]
# Build the demo UI: an input image, a text prompt and a noise slider are
# passed to `main`, and the image it returns is shown as the output.
demo = gr.Interface(
    fn=main,
    inputs=[
        gr.Image(type="pil", shape=(512, 512)),
        "text",
        gr.Slider(0.02, 0.999, value=0.7, label='percent noise'),
    ],
    outputs=gr.Image(type="pil", shape=(512, 512)),
    # Each example row is [image file, prompt, percent noise].
    examples=[
        ["original.png", "person", 0.7],
        ['cat.png', 'cat', 0.9],
        ['bedroom.jpg', 'bathroom', 0.8],
        ['einstein.jpg', 'dog', 0.8],
        ['oprah.jpeg', 'pirate', 0.8],
    ],
)

# Start serving the interface.
demo.launch()