"""Gradio demo: text-to-image generation with SDXL plus the StyleAligned attention handler."""
import gradio as gr
from diffusers import StableDiffusionXLPipeline, DDIMScheduler
import torch
import mediapy
import sa_handler

# init models
scheduler = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
    scheduler=scheduler,
).to("cuda")

# Patch the pipeline with StyleAligned: share attention (with AdaIN on queries
# and keys) but not group/layer norms.
handler = sa_handler.Handler(pipeline)
sa_args = sa_handler.StyleAlignedArgs(
    share_group_norm=False,
    share_layer_norm=False,
    share_attention=True,
    adain_queries=True,
    adain_keys=True,
    adain_values=False,
)
handler.register(sa_args)

# run StyleAligned
# NOTE(review): not referenced anywhere in this script; presumably meant as a
# style-sharing batch for the pipeline — confirm intent or remove.
sets_of_prompts = [
    "a toy train. macro photo. 3d game asset",
    "a toy airplane. macro photo. 3d game asset",
    "a toy bicycle. macro photo. 3d game asset",
    "a toy car. macro photo. 3d game asset",
    "a toy boat. macro photo. 3d game asset",
]


def style_aligned_sdxl(prompt):
    """Generate one image for ``prompt`` with the StyleAligned-patched SDXL pipeline.

    Args:
        prompt: Text prompt from the Gradio textbox.

    Returns:
        The single generated image, suitable for the ``gr.Image`` output.
    """
    # Previously referenced the undefined name ``prompts`` (NameError).
    images = pipeline([prompt]).images
    # ``.images`` is a list; the output component is a single gr.Image, so
    # return the one image generated for the one-prompt batch.
    return images[0]


with gr.Blocks() as demo:
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(label="Prompt", scale=8)
            btn = gr.Button("Greet", scale=2)
    output = gr.Image(label="Style-Aligned SDXL")
    btn.click(
        fn=style_aligned_sdxl,
        inputs=prompt,
        outputs=output,
        api_name="style_aligned_sdxl",
    )

demo.launch()