"""Gradio text-to-image demo built on diffusers' ``AutoPipelineForText2Image``.

The model id is read from the ``MODEL_ID`` environment variable (falling back
to ``DEFAULT_MODEL``). The pipeline is loaded and moved to CUDA at import
time, before the UI is constructed.
"""

import os

import gradio as gr
import spaces
import torch
from diffusers import AutoPipelineForText2Image
from loguru import logger

# Known model ids. MODEL_ID is not checked against this list below, so any
# other id is passed straight to the loader.
SUPPORTED_MODELS = [
    "stabilityai/sdxl-turbo",
    "stabilityai/stable-diffusion-3-medium-diffusers",
    "stabilityai/stable-diffusion-xl-base-1.0",
]
DEFAULT_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"

# Runtime configuration, both overridable through the environment.
model = os.getenv("MODEL_ID", DEFAULT_MODEL)
gpu_duration = int(os.getenv("GPU_DURATION", 60))


def load_pipeline(model):
    """Build a text-to-image pipeline for ``model``.

    Args:
        model: Model identifier handed to ``from_pretrained``.

    Returns:
        The pipeline created by ``AutoPipelineForText2Image.from_pretrained``,
        requested with float16 weights, safetensors files and the ``fp16``
        variant. The pipeline is not moved to any device here.
    """
    pipeline = AutoPipelineForText2Image.from_pretrained(
        model,
        torch_dtype=torch.float16,
        use_safetensors=True,
        variant="fp16",
    )
    return pipeline


logger.debug(f"Loading pipeline: {dict(model=model)}")
pipe = load_pipeline(model)
pipe = pipe.to("cuda")


@logger.catch(reraise=True)  # log any exception from inference, then re-raise it
@spaces.GPU(duration=gpu_duration)  # duration comes from GPU_DURATION (default 60)
def infer(
    prompt: str,
    negative_prompt: str | None,
    num_inference_steps: int,
    guidance_scale: float,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image for ``prompt`` with the module-level pipeline.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        negative_prompt: Negative prompt, forwarded to the pipeline unchanged.
        num_inference_steps: Step count; forwarded only when truthy.
        guidance_scale: Guidance scale; forwarded only when truthy.
        progress: Gradio progress tracker created with ``track_tqdm=True``.
            It is not referenced in the body.

    Returns:
        The first element of the pipeline output's ``images``.
    """
    logger.info(f"Starting image generation: {dict(model=model, prompt=prompt)}")

    # Falsy values (0 / 0.0 / None) are dropped rather than forwarded,
    # presumably so the pipeline falls back to its own defaults.
    # NOTE(review): this also means an explicit 0 can never be sent to the
    # pipeline - confirm that is intended for every supported model.
    requested = {
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
    }
    additional_args = {name: value for name, value in requested.items() if value}

    logger.debug(f"Generating image: {dict(prompt=prompt, **additional_args)}")

    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        **additional_args,
    )
    return output.images[0]


with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Text-to-Image")
        gr.Markdown(f"## Model: `{model}`")

        # Prompt box and run button share a single row.
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0, variant="primary")

        result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
            )

            # Both sliders start at 0, which `infer` treats as "not set".
            with gr.Row():
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=0,
                    maximum=100,
                    step=1,
                    value=0,
                )
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=100.0,
                    step=0.1,
                    value=0.0,
                )

    # Clicking "Run" and submitting the prompt box both trigger generation.
    infer_inputs = [
        prompt,
        negative_prompt,
        num_inference_steps,
        guidance_scale,
    ]
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=infer_inputs,
        outputs=[result],
    )


if __name__ == "__main__":
    demo.launch()