ysharma HF staff commited on
Commit
fd1c028
1 Parent(s): 1b9e540

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -0
app.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from diffusers import StableDiffusionXLPipeline, DDIMScheduler
3
+ import torch
4
+ import mediapy
5
+ import sa_handler
6
+
7
+ # init models
8
+ scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False,
9
+ set_alpha_to_one=False)
10
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
11
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True,
12
+ scheduler=scheduler
13
+ ).to("cuda")
14
+
15
+ handler = sa_handler.Handler(pipeline)
16
+ sa_args = sa_handler.StyleAlignedArgs(share_group_norm=False,
17
+ share_layer_norm=False,
18
+ share_attention=True,
19
+ adain_queries=True,
20
+ adain_keys=True,
21
+ adain_values=False,
22
+ )
23
+
24
+ handler.register(sa_args, )
25
+
26
+
27
+ # run StyleAligned
28
+ sets_of_prompts = [
29
+ "a toy train. macro photo. 3d game asset",
30
+ "a toy airplane. macro photo. 3d game asset",
31
+ "a toy bicycle. macro photo. 3d game asset",
32
+ "a toy car. macro photo. 3d game asset",
33
+ "a toy boat. macro photo. 3d game asset",
34
+ ]
35
+
36
+ def style_aligned_sdxl(prompt):
37
+ images = pipeline([prompts],).images
38
+ #mediapy.show_images(images)
39
+ print(images)
40
+ return images
41
+
42
+ with gr.Blocks() as demo:
43
+ with gr.Group():
44
+ with gr.Row():
45
+ prompt = gr.Textbox(label="Prompt", scale=8)
46
+ btn = gr.Button("Greet", scale=2)
47
+ output = gr.Image(label="Style-Aligned SDXL")
48
+
49
+ btn.click(fn=style_aligned_sdxl, inputs=prompt, outputs=output, api_name="style_aligned_sdxl")
50
+
51
+ demo.launch()
52
+