# dgoot's picture
# Update app.py
# 75ce7d0 verified
import shutil
from pathlib import Path
import gradio as gr
import requests
import spaces
import torch
from diffusers import (
AutoencoderKL,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLPipeline,
)
from loguru import logger
from PIL import Image
from tqdm import tqdm
def download(file: str, url: str, timeout: float = 60.0):
    """Download ``url`` to ``file`` unless ``file`` already exists.

    The body is streamed in 1 MiB chunks into a ``<file>.part`` sibling and is
    only moved to ``file`` after the transfer finished. An interrupted run
    therefore never leaves a truncated file at ``file``, which the existence
    check would otherwise accept as complete on the next start.

    Args:
        file: Destination path.
        url: URL to fetch.
        timeout: Connect/read timeout in seconds passed to ``requests.get``,
            so a stalled server cannot block start-up forever.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    file_path = Path(file)
    if file_path.exists():
        return

    # Stage next to the destination (not in a hard-coded /tmp) so the final
    # move is a same-filesystem rename and works on any OS.
    part_path = file_path.with_name(file_path.name + ".part")

    # The context manager closes the streamed response and releases its
    # connection even if we stop reading part-way through.
    with requests.get(url, stream=True, timeout=timeout) as r:
        r.raise_for_status()
        # Content-Length is optional (e.g. chunked transfer encoding);
        # total=None gives tqdm an open-ended progress bar instead of a KeyError.
        content_length = r.headers.get("content-length")
        total = int(content_length) if content_length else None
        try:
            with tqdm(
                desc=file, total=total, unit="B", unit_scale=True
            ) as pbar, open(part_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)
                    pbar.update(len(chunk))
        except BaseException:
            # Don't leave a half-written staging file behind (also on Ctrl-C).
            part_path.unlink(missing_ok=True)
            raise

    shutil.move(part_path, file_path)
# Local filenames for the checkpoint and its separately published VAE.
model_path = "pony-diffusion-v6-xl.safetensors"
vae_path = "pony-diffusion-v6-xl.vae.safetensors"

# Fetch both files from Civitai (/api/download/models/290640); download()
# returns immediately when the local file already exists.
download(
    file=model_path,
    url="https://civitai.com/api/download/models/290640?type=Model&format=SafeTensor&size=pruned&fp=fp16",
)
download(
    file=vae_path,
    url="https://civitai.com/api/download/models/290640?type=VAE&format=SafeTensor",
)
# Load the standalone VAE in half precision and plug it into the SDXL
# text-to-image pipeline, then move the pipeline to the GPU once at start-up.
# Image-to-image alternative: StableDiffusionXLImg2ImgPipeline.from_single_file(...)
vae = AutoencoderKL.from_single_file(vae_path, torch_dtype=torch.float16)
pipe = StableDiffusionXLPipeline.from_single_file(
    model_path,
    vae=vae,
    torch_dtype=torch.float16,
).to("cuda")
@logger.catch(reraise=True)
@spaces.GPU
def generate(
    prompt: str,
    # init_image: Image.Image,  # disabled together with the img2img pipeline
    strength: float,
    num_inference_steps: int,
    guidance_scale: float,
    progress=gr.Progress(track_tqdm=True),
):
    """Run the module-level ``pipe`` on ``prompt`` and return the first image.

    Any numeric argument that is falsy (0, 0.0 or None) is left out of the
    pipeline call, so the pipeline's own default is used for it.

    Args:
        prompt: Text prompt.
        strength: Forwarded as ``strength`` when non-zero.
            NOTE(review): this is an image-to-image parameter (see the disabled
            ``init_image`` input); confirm the active pipeline actually uses it.
        num_inference_steps: Denoising step count; 0 means "pipeline default".
            Rounded to an int because slider values may arrive as floats.
        guidance_scale: Guidance scale; 0 means "pipeline default".
        progress: Gradio progress tracker with ``track_tqdm=True`` so Gradio can
            follow tqdm progress bars; not referenced directly here.

    Returns:
        The first element of the pipeline output's ``images``.
    """
    logger.info(f"Starting image generation: {dict(prompt=prompt, strength=strength)}")

    # # Downscale the image
    # init_image.thumbnail((1024, 1024))

    requested = dict(
        strength=strength,
        # Sliders can deliver e.g. 30.0; the step count is presumably
        # expected to be an integer by the pipeline/scheduler.
        num_inference_steps=(
            None if num_inference_steps is None else round(num_inference_steps)
        ),
        guidance_scale=guidance_scale,
    )
    # Falsy slider values mean "use the pipeline default".
    additional_args = {k: v for k, v in requested.items() if v}

    images = pipe(
        prompt=prompt,
        # image=init_image,
        **additional_args,
    ).images
    return images[0]
# Text-to-image UI. The input order must match generate()'s positional
# parameters. generate() drops slider values of 0, so 0 selects the pipeline
# default; the `info` hints make that visible to users.
demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Text(label="Prompt"),
        # gr.Image(label="Init image", type="pil"),
        gr.Slider(
            label="Strength",
            minimum=0.0,
            maximum=1.0,
            value=0.0,
            info="0 = pipeline default",
        ),
        gr.Slider(
            label="Number of inference steps",
            minimum=0,
            maximum=100,
            value=0,
            # Step counts are integers; don't let the slider produce fractions.
            step=1,
            info="0 = pipeline default",
        ),
        gr.Slider(
            label="Guidance scale",
            minimum=0.0,
            maximum=100.0,
            value=0.0,
            info="0 = pipeline default",
        ),
    ],
    outputs=[gr.Image(label="Output")],
)
demo.launch()