import gradio as gr
import numpy as np
import spaces  # needed for the @spaces.GPU decorator on ZeroGPU Spaces
import torch
from diffusers import DiffusionPipeline
from PIL import Image

model_ids = [
    "Prgckwb/trpfrog-sd3.5-large",
    "Prgckwb/trpfrog-diffusion",
]

if torch.cuda.is_available():
    torch_dtype = torch.float16
    device = "cuda"
else:
    torch_dtype = torch.float32
    device = "cpu"

# Load every pipeline up front on GPU machines; keep placeholders on CPU-only machines.
pipelines = {
    model_id: (
        DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
        if device == "cuda"
        else None
    )
    for model_id in model_ids
}


@spaces.GPU()
@torch.inference_mode()
def inference(
    model_id: str,
    prompt: str,
    width: int,
    height: int,
    progress=gr.Progress(track_tqdm=True),
):
    if device == "cuda":
        pipe = pipelines[model_id].to(device)
        image = pipe(
            prompt=prompt,
            width=width,
            height=height,
        ).images[0]
    else:
        # No GPU available: return a dummy (near-black) placeholder image
        image = Image.fromarray(np.random.randn(height, width, 3).astype(np.uint8))
    return image


def create_interface():
    theme = gr.themes.Ocean()

    with gr.Blocks(theme=theme) as demo:
        with gr.Column():
            gr.Markdown("# TrpFrog Diffusion Demo")

            with gr.Row():
                with gr.Column():
                    input_model_id = gr.Dropdown(
                        label="Model", choices=model_ids, value=model_ids[0]
                    )
                    input_prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="an icon of trpfrog",
                        value="an icon of trpfrog",
                    )
                    with gr.Row():
                        input_width = gr.Slider(
                            label="Width", minimum=64, maximum=2056, step=128, value=1024
                        )
                        input_height = gr.Slider(
                            label="Height", minimum=64, maximum=2056, step=128, value=1024
                        )
                    with gr.Row():
                        clear_btn = gr.ClearButton(components=[input_prompt])
                        submit_btn = gr.Button("Generate", variant="primary")

                with gr.Column():
                    output_image = gr.Image(label="Output")

            all_inputs = [input_model_id, input_prompt, input_width, input_height]
            all_outputs = [output_image]

            gr.Examples(
                examples=[
                    ["Prgckwb/trpfrog-sd3.5-large", "an icon of trpfrog eating ramen", 1024, 1024],
                    ["Prgckwb/trpfrog-sd3.5-large", "an icon of trpfrog with a gun", 1024, 1024],
                ],
                inputs=all_inputs,
                outputs=all_outputs,
                fn=inference,
                cache_mode="eager",
                cache_examples=True,
            )

            submit_btn.click(inference, inputs=all_inputs, outputs=all_outputs)
            input_prompt.submit(inference, inputs=all_inputs, outputs=all_outputs)

    return demo


if __name__ == "__main__":
    try:
        demo = create_interface()
        demo.queue().launch()
    except Exception as e:
        raise gr.Error(str(e)) from e