radames commited on
Commit
4771ff8
1 Parent(s): 50a87b5

add randomize seed and sfast

Browse files
Files changed (2) hide show
  1. app.py +44 -27
  2. requirements.txt +3 -0
app.py CHANGED
@@ -12,10 +12,13 @@ from PIL import Image
12
  import numpy as np
13
  import gradio as gr
14
  import psutil
 
 
 
 
15
 
16
 
17
  SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
18
- TORCH_COMPILE = os.environ.get("TORCH_COMPILE", None)
19
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
20
  # check if MPS is available OSX only M1/M2/M3 chips
21
  mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
@@ -27,7 +30,6 @@ torch_device = device
27
  torch_dtype = torch.float16
28
 
29
  print(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
30
- print(f"TORCH_COMPILE: {TORCH_COMPILE}")
31
  print(f"device: {device}")
32
 
33
  if mps_available:
@@ -43,24 +45,21 @@ else:
43
  pipe = DiffusionPipeline.from_pretrained(model_id, safety_checker=None)
44
 
45
  pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
46
- pipe.to(device=torch_device, dtype=torch_dtype).to(device)
47
- pipe.unet.to(memory_format=torch.channels_last)
48
-
49
- # check if computer has less than 64GB of RAM using sys or os
50
- if psutil.virtual_memory().total < 64 * 1024**3:
51
- pipe.enable_attention_slicing()
52
-
53
- if TORCH_COMPILE:
54
- pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
55
- pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=True)
56
-
57
- pipe(prompt="warmup", num_inference_steps=1, guidance_scale=8.0)
58
-
59
- # Load LCM LoRA
60
  pipe.load_lora_weights(
61
  "latent-consistency/lcm-lora-sdxl",
62
  use_auth_token=HF_TOKEN,
63
  )
 
 
 
 
 
 
 
 
 
 
 
64
 
65
  compel_proc = Compel(
66
  tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
@@ -71,8 +70,15 @@ compel_proc = Compel(
71
 
72
 
73
  def predict(
74
- prompt, guidance, steps, seed=1231231, progress=gr.Progress(track_tqdm=True)
 
 
 
 
 
75
  ):
 
 
76
  generator = torch.manual_seed(seed)
77
  prompt_embeds, pooled_prompt_embeds = compel_proc(prompt)
78
 
@@ -94,7 +100,7 @@ def predict(
94
  )
95
  if nsfw_content_detected:
96
  raise gr.Error("NSFW content detected.")
97
- return results.images[0]
98
 
99
 
100
  css = """
@@ -122,18 +128,28 @@ with gr.Blocks(css=css) as demo:
122
  placeholder="Insert your prompt here:", scale=5, container=False
123
  )
124
  generate_bt = gr.Button("Generate", scale=1)
125
-
126
  image = gr.Image(type="filepath")
127
  with gr.Accordion("Advanced options", open=False):
128
  guidance = gr.Slider(
129
  label="Guidance", minimum=0.0, maximum=5, value=0.3, step=0.001
130
  )
131
  steps = gr.Slider(label="Steps", value=4, minimum=2, maximum=10, step=1)
132
- seed = gr.Slider(
133
- randomize=True, minimum=0, maximum=12013012031030, label="Seed", step=1
134
- )
 
 
 
 
 
 
 
 
 
135
  with gr.Accordion("Run with diffusers"):
136
- gr.Markdown('''## Running LCM-LoRAs it with `diffusers`
 
137
  ```bash
138
  pip install diffusers==0.23.0
139
  ```
@@ -151,10 +167,11 @@ with gr.Blocks(css=css) as demo:
151
  )
152
  results.images[0]
153
  ```
154
- ''')
155
-
156
- inputs = [prompt, guidance, steps, seed]
157
- generate_bt.click(fn=predict, inputs=inputs, outputs=image)
 
158
 
159
  demo.queue()
160
  demo.launch()
 
12
  import numpy as np
13
  import gradio as gr
14
  import psutil
15
+ from sfast.compilers.stable_diffusion_pipeline_compiler import (
16
+ compile,
17
+ CompilationConfig,
18
+ )
19
 
20
 
21
  SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
 
22
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
23
  # check if MPS is available OSX only M1/M2/M3 chips
24
  mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
 
30
  torch_dtype = torch.float16
31
 
32
  print(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
 
33
  print(f"device: {device}")
34
 
35
  if mps_available:
 
45
  pipe = DiffusionPipeline.from_pretrained(model_id, safety_checker=None)
46
 
47
  pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  pipe.load_lora_weights(
49
  "latent-consistency/lcm-lora-sdxl",
50
  use_auth_token=HF_TOKEN,
51
  )
52
+ if device.type != "mps":
53
+ pipe.unet.to(memory_format=torch.channels_last)
54
+ pipe.to(device=torch_device, dtype=torch_dtype).to(device)
55
+
56
+ # Load LCM LoRA
57
+
58
+ config = CompilationConfig.Default()
59
+ config.enable_xformers = True
60
+ config.enable_triton = True
61
+ config.enable_cuda_graph = True
62
+ pipe = compile(pipe, config=config)
63
 
64
  compel_proc = Compel(
65
  tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
 
70
 
71
 
72
  def predict(
73
+ prompt,
74
+ guidance,
75
+ steps,
76
+ seed=1231231,
77
+ randomize_bt=False,
78
+ progress=gr.Progress(track_tqdm=True),
79
  ):
80
+ if randomize_bt:
81
+ seed = np.random.randint(0, 2**32 - 1)
82
  generator = torch.manual_seed(seed)
83
  prompt_embeds, pooled_prompt_embeds = compel_proc(prompt)
84
 
 
100
  )
101
  if nsfw_content_detected:
102
  raise gr.Error("NSFW content detected.")
103
+ return results.images[0], seed
104
 
105
 
106
  css = """
 
128
  placeholder="Insert your prompt here:", scale=5, container=False
129
  )
130
  generate_bt = gr.Button("Generate", scale=1)
131
+
132
  image = gr.Image(type="filepath")
133
  with gr.Accordion("Advanced options", open=False):
134
  guidance = gr.Slider(
135
  label="Guidance", minimum=0.0, maximum=5, value=0.3, step=0.001
136
  )
137
  steps = gr.Slider(label="Steps", value=4, minimum=2, maximum=10, step=1)
138
+ with gr.Row():
139
+ seed = gr.Slider(
140
+ randomize=True,
141
+ minimum=0,
142
+ maximum=12013012031030,
143
+ label="Seed",
144
+ step=1,
145
+ scale=5,
146
+ )
147
+ with gr.Group():
148
+ randomize_bt = gr.Checkbox(label="Randomize", value=False)
149
+ random_seed = gr.Textbox(show_label=False)
150
  with gr.Accordion("Run with diffusers"):
151
+ gr.Markdown(
152
+ """## Running LCM-LoRAs it with `diffusers`
153
  ```bash
154
  pip install diffusers==0.23.0
155
  ```
 
167
  )
168
  results.images[0]
169
  ```
170
+ """
171
+ )
172
+
173
+ inputs = [prompt, guidance, steps, seed, randomize_bt]
174
+ generate_bt.click(fn=predict, inputs=inputs, outputs=[image, random_seed])
175
 
176
  demo.queue()
177
  demo.launch()
requirements.txt CHANGED
@@ -11,3 +11,6 @@ accelerate==0.24.0
11
  compel==2.0.2
12
  controlnet-aux==0.0.7
13
  peft==0.6.0
 
 
 
 
11
  compel==2.0.2
12
  controlnet-aux==0.0.7
13
  peft==0.6.0
14
+ stable_fast @ https://github.com/chengzeyi/stable-fast/releases/download/v0.0.15.post1/stable_fast-0.0.15.post1+torch211cu121-cp310-cp310-manylinux2014_x86_64.whl
15
+ xformers
16
+ triton