import gradio as gr
import numpy as np
import random
import spaces # [uncomment to use ZeroGPU]
from diffusers import DiffusionPipeline
import torch
from PIL import Image
# Hugging Face Hub repositories offered in the UI's model dropdown.
model_ids = [
    "Prgckwb/trpfrog-sd3.5-large",
    "Prgckwb/trpfrog-diffusion",
]

# Run in half precision on a CUDA device, full precision on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32

# One entry per model id. Weights are only loaded when a GPU is available;
# on CPU every entry is None and inference falls back to a placeholder image.
if device == "cuda":
    pipelines = {
        name: DiffusionPipeline.from_pretrained(name, torch_dtype=torch_dtype)
        for name in model_ids
    }
else:
    pipelines = dict.fromkeys(model_ids)
@spaces.GPU()
@torch.inference_mode()
def inference(
    model_id: str,
    prompt: str,
    width: int,
    height: int,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a single image for ``prompt`` with the selected model.

    Args:
        model_id: Key into the module-level ``pipelines`` mapping.
        prompt: Text prompt forwarded to the diffusion pipeline.
        width: Output image width in pixels.
        height: Output image height in pixels.
        progress: Gradio progress tracker. It is not referenced in the body;
            it is declared with ``track_tqdm=True`` so that tqdm progress is
            surfaced in the UI.

    Returns:
        The first image produced by the pipeline, or a solid black
        ``width`` x ``height`` RGB placeholder when no CUDA device is
        available.

    Raises:
        KeyError: If ``model_id`` is not a key of ``pipelines`` (GPU path only).
    """
    # UI widgets may hand sizes over as floats; image dimensions must be ints.
    width, height = int(width), int(height)

    if device != 'cuda':
        # No GPU: skip the model entirely and return a pure black image.
        # (Image.new fills with 0 by default, i.e. black for RGB.)
        return Image.new('RGB', (width, height))

    pipe = pipelines[model_id].to(device)
    return pipe(
        prompt=prompt,
        width=width,
        height=height,
    ).images[0]
def create_interface() -> gr.Blocks:
    """Build the Gradio UI for the TrpFrog diffusion demo.

    The layout is a title, a left column with the inputs (model dropdown,
    prompt textbox, width/height sliders, clear and generate buttons), a
    right column with the output image, and a list of cached examples.
    Both the generate button and pressing Enter in the prompt box run
    ``inference``.

    Returns:
        The assembled ``gr.Blocks`` app; the caller is responsible for
        launching it.
    """
    theme = gr.themes.Ocean()

    with gr.Blocks(theme=theme) as demo:
        with gr.Column():
            # Written as a single valid literal; a plain quoted string cannot
            # span multiple source lines.
            gr.Markdown("\nTrpFrog Diffusion Demo\n")

            with gr.Row():
                with gr.Column():
                    input_model_id = gr.Dropdown(
                        label="Model", choices=model_ids, value=model_ids[0]
                    )
                    input_prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="an icon of trpfrog",
                        value="an icon of trpfrog",
                    )

                    with gr.Row():
                        # Slider positions are minimum + k * step. A step of 64
                        # keeps every size on the grid, including the 1024
                        # default and the 2048 maximum.
                        input_width = gr.Slider(
                            label="Width", minimum=64, maximum=2048, step=64, value=1024
                        )
                        input_height = gr.Slider(
                            label="Height", minimum=64, maximum=2048, step=64, value=1024
                        )

                    with gr.Row():
                        # The clear button only resets the prompt.
                        gr.ClearButton(components=[input_prompt])
                        submit_btn = gr.Button('Generate', variant='primary')

                with gr.Column():
                    output_image = gr.Image(label="Output")

            all_inputs = [input_model_id, input_prompt, input_width, input_height]
            all_outputs = [output_image]

            # Example rows follow the order of all_inputs.
            gr.Examples(
                examples=[
                    ['Prgckwb/trpfrog-sd3.5-large', 'an icon of trpfrog eating ramen', 1024, 1024],
                    ['Prgckwb/trpfrog-sd3.5-large', 'an icon of trpfrog with a gun', 1024, 1024],
                ],
                inputs=all_inputs,
                outputs=all_outputs,
                fn=inference,
                cache_mode='eager',
                cache_examples=True,
            )

        # Same handler for the button click and for Enter in the prompt box.
        submit_btn.click(inference, inputs=all_inputs, outputs=all_outputs)
        input_prompt.submit(inference, inputs=all_inputs, outputs=all_outputs)

    return demo
if __name__ == "__main__":
    # Startup failures are left to propagate unchanged. gr.Error is intended
    # for reporting errors from event handlers in the UI; wrapping a launch
    # failure in it would only hide the original exception type.
    demo = create_interface()
    demo.queue().launch()