import gradio as gr
from gradio.components import *
import torch
from torch import autocast
from torchvision import transforms as T
from boomerang import *  # provides the Stable Diffusion `pipe` used below

COLAB = False
HALF_PRECISION = COLAB


def main(image, prompt, percent_noise):
    # Example percent_noise values: 0.5, 0.02, 0.999

    # Convert the image to a float tensor and preprocess it.
    # From huggingface/diffusers/blob/main/examples/unconditional_image_generation/train_unconditional.py
    transform = T.Compose([
        T.PILToTensor(),
        T.ConvertImageDtype(torch.float),
        T.Normalize([0.5], [0.5]),
    ])

    if HALF_PRECISION:
        tensor = transform(image).half().to(pipe.device)
    else:
        tensor = transform(image).to(pipe.device)
    tensor = torch.unsqueeze(tensor, 0)

    # Project the image into the latent space.
    clean_z = pipe.vae.encode(tensor).latent_dist.mode()
    # Scale factor from huggingface/diffusers/blob/main/src/diffusers/models/vae.py
    clean_z = 0.18215 * clean_z

    # Add noise to the latent variable
    # (this is the forward diffusion process).
    noise = torch.randn(clean_z.shape).to(pipe.device)
    timestep = torch.Tensor([int(pipe.scheduler.config.num_train_timesteps * percent_noise)]).to(pipe.device).long()
    z = pipe.scheduler.add_noise(clean_z, noise, timestep)
    if HALF_PRECISION:
        z = z.half()

    # Run the diffusion model.
    # with autocast('cuda'):

    # The 'num_inference_steps=100' argument means that, if you were running
    # the full diffusion model (i.e., percent_noise = 0.999), it would be sampling
    # at 2x frequency compared to the normal Stable Diffusion model (which uses 50 steps).
    # This way, running percent_noise = 0.5 yields 50 inference steps,
    # running percent_noise = 0.2 yields 20 inference steps, etc.
    return pipe(prompt=prompt, latents=z, num_inference_steps=100, percent_noise=percent_noise).images[0]


gr.Interface(
    fn=main,
    inputs=[
        gr.Image(type="pil", shape=(512, 512)),
        "text",
        gr.Slider(0.02, 0.999, value=0.7, label='percent noise'),
    ],
    outputs=gr.Image(type="pil", shape=(512, 512)),
    examples=[
        ["original.png", "person", 0.7],
        ['cat.png', 'cat', 0.9],
        ['bedroom.jpg', 'bathroom', 0.8],
        ['einstein.jpg', 'dog', 0.8],
        ['oprah.jpeg', 'pirate', 0.8],
    ],
).launch()