"""Gradio demo: StyleAligned image generation with Stable Diffusion XL.

Up to five object prompts are each combined with one shared style prompt and
generated together as a single batch. ``sa_handler`` registers StyleAligned
attention sharing on the pipeline, which is intended to make the images in
the batch share a common style.
"""

import gradio as gr
from diffusers import StableDiffusionXLPipeline, DDIMScheduler
import torch
import mediapy  # noqa: F401  -- unused here; was used for notebook display
import sa_handler

# ---------------------------------------------------------------------------
# Model setup
# ---------------------------------------------------------------------------

# DDIM scheduler with explicit beta values instead of the checkpoint's own
# scheduler config.
scheduler = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
)

# NOTE: the pipeline is deliberately NOT moved to CUDA before enabling model
# CPU offload. Offloading moves each sub-model to the GPU only while it runs;
# calling ``.to("cuda")`` first loads the whole model onto the GPU (raising
# peak memory) only for offloading to move it back to the CPU.
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
    scheduler=scheduler,
)
pipeline.enable_model_cpu_offload()
# Decode the VAE output one image at a time to reduce memory for the batch.
pipeline.enable_vae_slicing()

# Register StyleAligned: attention sharing with AdaIN applied to queries and
# keys (not values); group-norm and layer-norm sharing are disabled.
handler = sa_handler.Handler(pipeline)
sa_args = sa_handler.StyleAlignedArgs(
    share_group_norm=False,
    share_layer_norm=False,
    share_attention=True,
    adain_queries=True,
    adain_keys=True,
    adain_values=False,
)
handler.register(sa_args)

# Example set of fully-formed prompts (same subjects as the first UI example).
sets_of_prompts = [
    "a toy train. macro photo. 3d game asset",
    "a toy airplane. macro photo. 3d game asset",
    "a toy bicycle. macro photo. 3d game asset",
    "a toy car. macro photo. 3d game asset",
    "a toy boat. macro photo. 3d game asset",
]


# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------

def _build_prompts(initial_prompts, style_prompt):
    """Combine each non-blank initial prompt with the style prompt.

    Returns ``"<prompt>. <style>"`` for every initial prompt that is not empty
    or whitespace-only. Blank boxes are skipped, so leaving some of the five
    boxes empty does not produce images from a bare style string. Input order
    is preserved. When the style prompt is blank, the stripped prompts are
    returned unchanged.
    """
    style = (style_prompt or "").strip()
    prompts = [p.strip() for p in initial_prompts if p and p.strip()]
    if not style:
        return prompts
    return [f"{p}. {style}" for p in prompts]


def style_aligned_sdxl(initial_prompt1, initial_prompt2, initial_prompt3,
                       initial_prompt4, initial_prompt5, style_prompt):
    """Generate one style-aligned SDXL image per non-blank initial prompt.

    Args:
        initial_prompt1..initial_prompt5: Object/subject prompts from the UI.
            Blank ones are ignored.
        style_prompt: Style description appended to every initial prompt.

    Returns:
        The ``images`` list produced by the pipeline for the combined prompts,
        in the same order as the non-blank inputs. These are presumably PIL
        images (the pipeline's default output type).
        NOTE(review): StyleAligned presumably uses the first batch entry as
        the style reference; confirm against ``sa_handler``.

    Raises:
        gr.Error: If every initial prompt is blank.
    """
    prompts = _build_prompts(
        [initial_prompt1, initial_prompt2, initial_prompt3,
         initial_prompt4, initial_prompt5],
        style_prompt,
    )
    if not prompts:
        # Shown to the user by Gradio instead of running an empty batch.
        raise gr.Error("Please enter at least one initial prompt.")
    return pipeline(prompts).images


# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------

with gr.Blocks() as demo:
    with gr.Group():
        with gr.Column():
            with gr.Accordion(label="Enter up to 5 different initial prompts", open=True):
                with gr.Row(variant="panel"):
                    initial_prompt1 = gr.Textbox(label="Initial prompt 1", value="", show_label=False,
                                                 container=False, placeholder="a toy train")
                    initial_prompt2 = gr.Textbox(label="Initial prompt 2", value="", show_label=False,
                                                 container=False, placeholder="a toy airplane")
                    initial_prompt3 = gr.Textbox(label="Initial prompt 3", value="", show_label=False,
                                                 container=False, placeholder="a toy bicycle")
                    initial_prompt4 = gr.Textbox(label="Initial prompt 4", value="", show_label=False,
                                                 container=False, placeholder="a toy car")
                    initial_prompt5 = gr.Textbox(label="Initial prompt 5", value="", show_label=False,
                                                 container=False, placeholder="a toy boat")
            with gr.Row():
                style_prompt = gr.Textbox(label="Enter a style prompt",
                                          placeholder="macro photo, 3d game asset")
            btn = gr.Button("Generate a set of Style-aligned SDXL images")
            output = gr.Gallery(label="Style-Aligned SDXL Images", elem_id="gallery", columns=5,
                                rows=1, object_fit="contain", height="auto")

    # Shared by the button and the examples so both feed the same inputs.
    prompt_inputs = [initial_prompt1, initial_prompt2, initial_prompt3,
                     initial_prompt4, initial_prompt5, style_prompt]

    btn.click(fn=style_aligned_sdxl, inputs=prompt_inputs, outputs=output,
              api_name="style_aligned_sdxl")

    gr.Examples(
        examples=[
            ["a toy train", "a toy airplane", "a toy bicycle", "a toy car", "a toy boat",
             "macro photo. 3d game asset."],
            ["a toy train", "a toy airplane", "a toy bicycle", "a toy car", "a toy boat",
             "BW logo. high contrast."],
            ["a cat", "a dog", "a bear", "a man on a bicycle", "a girl working on laptop",
             "minimal origami."],
            ["a firewoman", "a gardener", "a scientist", "a policewoman", "a saxophone player",
             "made of claymation, stop motion animation."],
            ["a firewoman", "a gardener", "a scientist", "a policewoman", "a saxophone player",
             "sketch, character sheet."],
        ],
        inputs=prompt_inputs,
        outputs=[output],
        fn=style_aligned_sdxl,
    )

demo.launch()