# retrAIced / app.py
# Author: JavierGon12 — commit 757b45e ("Generate images code"), 1.14 kB.
import streamlit as st
from diffusers import DDPMScheduler, UNet2DModel
from PIL import Image
import torch
import numpy as np
def generate_image(
    model_id="google/ddpm-cat-256",
    num_inference_steps=50,
    device=None,
    generator=None,
):
    """Sample one image from a pretrained DDPM model and return it as a PIL image.

    Args:
        model_id: Hub id (or local path) passed to both
            ``DDPMScheduler.from_pretrained`` and ``UNet2DModel.from_pretrained``.
        num_inference_steps: Number of denoising steps given to
            ``scheduler.set_timesteps``.
        device: Torch device to run the model on. When ``None``, uses
            ``"cuda"`` if a GPU is available and ``"cpu"`` otherwise.
        generator: Optional CPU ``torch.Generator`` for reproducible sampling.
            It seeds the initial noise and is forwarded to ``scheduler.step``.

    Returns:
        A ``PIL.Image.Image`` built from the single generated sample.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    scheduler = DDPMScheduler.from_pretrained(model_id)
    model = UNet2DModel.from_pretrained(model_id).to(device)
    scheduler.set_timesteps(num_inference_steps)

    size = model.config.sample_size
    # Start from pure Gaussian noise: one sample, 3 channels, size x size.
    # The noise is drawn on the CPU (so a CPU generator works) and then moved.
    sample = torch.randn((1, 3, size, size), generator=generator).to(device)

    # Inference only: no autograd graph is needed for any step of the loop.
    with torch.no_grad():
        for t in scheduler.timesteps:
            # ``.sample`` is a tensor attribute of the UNet output, not a method.
            noise_pred = model(sample, t).sample
            sample = scheduler.step(
                noise_pred, t, sample, generator=generator
            ).prev_sample

    # Rescale from [-1, 1] to [0, 1]; values outside are clipped.
    # Then reorder NCHW -> NHWC and drop the batch axis for PIL.
    image = (sample / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
    return Image.fromarray((image * 255).round().astype("uint8"))
# ---------------------------------------------------------------------------
# Streamlit page: a header, a one-line description, then a single sampled
# image. Sampling runs at module level, so every execution of this script
# generates and displays a brand-new image.
# ---------------------------------------------------------------------------
st.title("DDPM Image Generation")
st.write("Generating and displaying an image using DDPM.")

generated_image = generate_image()
st.image(
    generated_image,
    caption="Generated Image",
    use_column_width=True,
)