Spaces:
Runtime error
Runtime error
import gradio as gr | |
from diffusers import StableDiffusionXLPipeline, DDIMScheduler | |
import torch | |
import mediapy | |
import sa_handler | |
# init models
# DDIM scheduler handed to the SDXL pipeline below.
scheduler = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
)

# SDXL base checkpoint, loaded as fp16 safetensors weights and moved to the GPU.
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
    scheduler=scheduler,
)
pipeline = pipeline.to("cuda")

# Wrap the pipeline with the StyleAligned handler. Only attention sharing and
# AdaIN on queries/keys are switched on; norm-layer sharing and value AdaIN
# stay off.
handler = sa_handler.Handler(pipeline)
sa_args = sa_handler.StyleAlignedArgs(
    share_group_norm=False,
    share_layer_norm=False,
    share_attention=True,
    adain_queries=True,
    adain_keys=True,
    adain_values=False,
)
handler.register(sa_args)
# run StyleAligned
# Example prompt set: one prompt per toy subject, all sharing the same style
# suffix.
# NOTE(review): nothing in the visible code reads this list — confirm whether
# it is still needed.
sets_of_prompts = [
    f"a toy {subject}. macro photo. 3d game asset"
    for subject in ("train", "airplane", "bicycle", "car", "boat")
]
def style_aligned_sdxl(prompt: str):
    """Generate one image for ``prompt`` with the module-level SDXL pipeline.

    The pipeline is called with a single-element prompt batch, so its
    ``.images`` result holds one image per prompt given.

    Args:
        prompt: Text prompt entered by the user.

    Returns:
        The first (and only) generated image, so the value can be fed
        directly to a single ``gr.Image`` output component.
    """
    # The original body referenced an undefined name (`prompts`) and raised
    # NameError on every call; the parameter is `prompt`.
    images = pipeline([prompt]).images
    # `.images` is a list; a single-image output component needs one element.
    return images[0]
# Gradio UI: a prompt textbox and a button, plus an image output that shows
# the result of style_aligned_sdxl.
# NOTE(review): the `with` nesting below was reconstructed from a flattened
# paste — confirm the intended layout.
with gr.Blocks() as demo:
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(label="Prompt", scale=8)
            # Label describes the action; "Greet" was a leftover template label.
            btn = gr.Button("Generate", scale=2)
        output = gr.Image(label="Style-Aligned SDXL")
    # Clicking the button sends the textbox value to the generator and shows
    # the returned image; api_name names the corresponding API endpoint.
    btn.click(
        fn=style_aligned_sdxl,
        inputs=prompt,
        outputs=output,
        api_name="style_aligned_sdxl",
    )

demo.launch()