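# TrpFrog Diffusion demo: a Gradio app that generates "trpfrog" icons with
# diffusers pipelines, running on ZeroGPU when a GPU is available and
# returning a placeholder image on CPU-only hardware.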
import gradio as gr
import numpy as np
import random
import spaces  # needed for the ZeroGPU @spaces.GPU decorator below
from diffusers import DiffusionPipeline
import torch
from PIL import Image
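# Checkpoints selectable from the "Model" dropdown in the UI.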
model_ids = [
"Prgckwb/trpfrog-sd3.5-large",
"Prgckwb/trpfrog-diffusion",
]
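# Prefer fp16 on CUDA for speed and memory; fall back to fp32 on CPU.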
if torch.cuda.is_available():
    torch_dtype = torch.float16
    device = "cuda"
else:
    torch_dtype = torch.float32
    device = "cpu"
pipelines = {
    model_id: DiffusionPipeline.from_pretrained(
        model_id, torch_dtype=torch_dtype
    ) if device == 'cuda' else None
    for model_id in model_ids
}
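# spaces.GPU() requests a ZeroGPU worker for the duration of each call;
# torch.inference_mode() disables autograd bookkeeping during generation.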
@spaces.GPU()
@torch.inference_mode()
def inference(
    model_id: str,
    prompt: str,
    width: int,
    height: int,
    progress=gr.Progress(track_tqdm=True),
):
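    # Run the selected pipeline; gr.Progress(track_tqdm=True) mirrors the
    # diffusers tqdm progress bars in the Gradio UI.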
    if device == 'cuda':
        pipe = pipelines[model_id].to(device)
        image = pipe(
            prompt=prompt,
            width=width,
            height=height,
        ).images[0]
    else:
        # No GPU available: generate a pitch-black placeholder image
        image = Image.fromarray(np.zeros((height, width, 3), dtype=np.uint8))
    return image
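# Build the Blocks UI: model/prompt/size controls on the left, output image on the right.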
def create_interface():
    theme = gr.themes.Ocean()
    with gr.Blocks(theme=theme) as demo:
        with gr.Column():
            gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>TrpFrog Diffusion Demo</h1>")
            with gr.Row():
                with gr.Column():
                    input_model_id = gr.Dropdown(label="Model", choices=model_ids, value=model_ids[0])
                    input_prompt = gr.Textbox(label="Prompt", placeholder="an icon of trpfrog", value="an icon of trpfrog")
                    with gr.Row():
                        input_width = gr.Slider(label="Width", minimum=128, maximum=2048, step=128, value=1024)
                        input_height = gr.Slider(label="Height", minimum=128, maximum=2048, step=128, value=1024)
                    with gr.Row():
                        clear_btn = gr.ClearButton(components=[input_prompt])
                        submit_btn = gr.Button('Generate', variant='primary')
                with gr.Column():
                    output_image = gr.Image(label="Output")
        all_inputs = [input_model_id, input_prompt, input_width, input_height]
        all_outputs = [output_image]
        examples = gr.Examples(
            examples=[
                ['Prgckwb/trpfrog-sd3.5-large', 'an icon of trpfrog eating ramen', 1024, 1024],
                ['Prgckwb/trpfrog-sd3.5-large', 'an icon of trpfrog with a gun', 1024, 1024],
            ],
            inputs=all_inputs,
            outputs=all_outputs,
            fn=inference,
            cache_mode='eager',
            cache_examples=True,
        )
        submit_btn.click(inference, inputs=all_inputs, outputs=all_outputs)
        input_prompt.submit(inference, inputs=all_inputs, outputs=all_outputs)
    return demo
if __name__ == "__main__":
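    # queue() enables request queuing so generation requests are processed in order.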
    try:
        demo = create_interface()
        demo.queue().launch()
    except Exception as e:
        raise gr.Error(str(e)) from e