adamelliotfields commited on
Commit
5c4e8c1
1 Parent(s): 2e278ad
Files changed (5) hide show
  1. README.md +5 -2
  2. app.py +41 -12
  3. cli.py +59 -0
  4. generate.py +6 -5
  5. usage.md +5 -5
README.md CHANGED
@@ -56,8 +56,11 @@ python -m venv .venv
56
  source .venv/bin/activate
57
  pip install -r requirements.txt torch==2.4.0 torchvision==0.19.0
58
 
59
- # http://localhost:7860
60
- python app.py
 
 
 
61
  ```
62
 
63
  ## TODO
 
56
  source .venv/bin/activate
57
  pip install -r requirements.txt torch==2.4.0 torchvision==0.19.0
58
 
59
+ # gradio
60
+ python app.py --port 7860
61
+
62
+ # cli
63
+ python cli.py 'an astronaut riding a horse on mars'
64
  ```
65
 
66
  ## TODO
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import gradio as gr
2
 
3
  from generate import generate
@@ -40,7 +42,7 @@ def generate_btn_click(*args):
40
  prompt = None
41
  if prompt is None or prompt.strip() == "":
42
  raise gr.Error("You must enter a prompt")
43
- return generate(*args)
44
 
45
 
46
  with gr.Blocks(
@@ -87,10 +89,10 @@ with gr.Blocks(
87
 
88
  with gr.Row():
89
  num_images = gr.Dropdown(
90
- choices=list(range(1, 9)),
91
  filterable=False,
92
  label="Images",
93
- value=1,
94
  scale=1,
95
  )
96
  width = gr.Slider(
@@ -129,6 +131,7 @@ with gr.Blocks(
129
  with gr.Row():
130
  model = gr.Dropdown(
131
  value="Lykon/dreamshaper-8",
 
132
  min_width=200,
133
  label="Model",
134
  scale=2,
@@ -144,6 +147,7 @@ with gr.Blocks(
144
  scheduler = gr.Dropdown(
145
  elem_id="scheduler",
146
  label="Scheduler",
 
147
  value="DEIS 2M",
148
  min_width=200,
149
  scale=2,
@@ -186,14 +190,22 @@ with gr.Blocks(
186
  tgate_step = gr.Slider(
187
  label="T-GATE Step",
188
  minimum=0,
189
- maximum=50,
190
- value=20,
191
  step=1,
192
  )
 
 
 
 
 
 
 
 
193
  tome_ratio = gr.Slider(
194
  label="ToMe Ratio",
195
  minimum=0.0,
196
- maximum=1.0,
197
  value=0.0,
198
  step=0.01,
199
  )
@@ -263,7 +275,18 @@ with gr.Blocks(
263
  # update the random seed using JavaScript
264
  random_btn.click(None, outputs=[seed], js=SEED_JS)
265
 
266
- # ensure correct argument order
 
 
 
 
 
 
 
 
 
 
 
267
  generate_btn.click(
268
  generate_btn_click,
269
  api_name="api",
@@ -291,8 +314,14 @@ with gr.Blocks(
291
  ],
292
  )
293
 
294
- # https://www.gradio.app/docs/gradio/interface#interface-queue
295
- demo.queue().launch(
296
- server_name="0.0.0.0",
297
- server_port=7860,
298
- )
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
  import gradio as gr
4
 
5
  from generate import generate
 
42
  prompt = None
43
  if prompt is None or prompt.strip() == "":
44
  raise gr.Error("You must enter a prompt")
45
+ return generate(*args, log=gr.Info, Error=gr.Error)
46
 
47
 
48
  with gr.Blocks(
 
89
 
90
  with gr.Row():
91
  num_images = gr.Dropdown(
92
+ choices=list(range(1, 5)),
93
  filterable=False,
94
  label="Images",
95
+ value=4,
96
  scale=1,
97
  )
98
  width = gr.Slider(
 
131
  with gr.Row():
132
  model = gr.Dropdown(
133
  value="Lykon/dreamshaper-8",
134
+ filterable=False,
135
  min_width=200,
136
  label="Model",
137
  scale=2,
 
147
  scheduler = gr.Dropdown(
148
  elem_id="scheduler",
149
  label="Scheduler",
150
+ filterable=False,
151
  value="DEIS 2M",
152
  min_width=200,
153
  scale=2,
 
190
  tgate_step = gr.Slider(
191
  label="T-GATE Step",
192
  minimum=0,
193
+ maximum=30,
194
+ value=0,
195
  step=1,
196
  )
197
+
198
+ with gr.Row():
199
+ file_format = gr.Dropdown(
200
+ choices=["png", "jpeg", "webp"],
201
+ label="File Format",
202
+ filterable=False,
203
+ value="png",
204
+ )
205
  tome_ratio = gr.Slider(
206
  label="ToMe Ratio",
207
  minimum=0.0,
208
+ maximum=0.5,
209
  value=0.0,
210
  step=0.01,
211
  )
 
275
  # update the random seed using JavaScript
276
  random_btn.click(None, outputs=[seed], js=SEED_JS)
277
 
278
+ file_format.change(
279
+ lambda f: gr.Gallery(format=f),
280
+ inputs=[file_format],
281
+ outputs=[output_images],
282
+ )
283
+
284
+ inference_steps.change(
285
+ lambda max, step: gr.Slider(maximum=max, value=min(max, step)),
286
+ inputs=[inference_steps, tgate_step],
287
+ outputs=[tgate_step],
288
+ )
289
+
290
  generate_btn.click(
291
  generate_btn_click,
292
  api_name="api",
 
314
  ],
315
  )
316
 
317
+ if __name__ == "__main__":
318
+ parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
319
+ parser.add_argument("-s", "--server", type=str, metavar="STR", default="0.0.0.0")
320
+ parser.add_argument("-p", "--port", type=int, metavar="INT", default=7860)
321
+ args = parser.parse_args()
322
+
323
+ # https://www.gradio.app/docs/gradio/interface#interface-queue
324
+ demo.queue().launch(
325
+ server_name=args.server,
326
+ server_port=args.port,
327
+ )
cli.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ from generate import generate
4
+
5
+
6
+ def save_images(images, filename="image.png"):
7
+ for i, (img, _) in enumerate(images):
8
+ name, ext = filename.rsplit(".", 1)
9
+ img.save(f"{name}.{ext}" if len(images) == 1 else f"{name}_{i}.{ext}")
10
+
11
+
12
+ def main():
13
+ parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
14
+ parser.add_argument("prompt", type=str, metavar="PROMPT")
15
+ parser.add_argument("-n", "--negative", type=str, metavar="STR", default="<fast_negative>")
16
+ parser.add_argument("-s", "--seed", type=int, metavar="INT")
17
+ parser.add_argument("-i", "--images", type=int, metavar="INT", default=1)
18
+ parser.add_argument("-f", "--filename", type=str, metavar="STR", default="image.png")
19
+ parser.add_argument("-w", "--width", type=int, metavar="INT", default=448)
20
+ parser.add_argument("-h", "--height", type=int, metavar="INT", default=576)
21
+ parser.add_argument("-m", "--model", type=str, metavar="STR", default="Lykon/dreamshaper-8")
22
+ parser.add_argument("-d", "--deepcache", type=int, metavar="INT", default=2)
23
+ parser.add_argument("-t", "--tgate", type=int, metavar="INT", default=20)
24
+ parser.add_argument("--scheduler", type=str, metavar="STR", default="DEIS 2M")
25
+ parser.add_argument("--guidance", type=float, metavar="FLOAT", default=7)
26
+ parser.add_argument("--steps", type=int, metavar="INT", default=30)
27
+ parser.add_argument("--tome", type=float, metavar="FLOAT", default=0.0)
28
+ parser.add_argument("--taesd", action="store_true")
29
+ parser.add_argument("--clip-skip", action="store_true")
30
+ parser.add_argument("--truncate", action="store_true")
31
+ parser.add_argument("--no-karras", action="store_false")
32
+ parser.add_argument("--no-increment", action="store_false")
33
+
34
+ args = parser.parse_args()
35
+ images = generate(
36
+ args.prompt,
37
+ args.negative,
38
+ args.seed,
39
+ args.model,
40
+ args.scheduler,
41
+ args.width,
42
+ args.height,
43
+ args.guidance,
44
+ args.steps,
45
+ args.images,
46
+ args.no_karras,
47
+ args.taesd,
48
+ args.clip_skip,
49
+ args.truncate,
50
+ args.no_increment,
51
+ args.deepcache,
52
+ args.tgate,
53
+ args.tome,
54
+ )
55
+ save_images(images, args.filename)
56
+
57
+
58
+ if __name__ == "__main__":
59
+ main()
generate.py CHANGED
@@ -5,9 +5,9 @@ from datetime import datetime
5
  from itertools import product
6
  from os import environ
7
  from types import MethodType
 
8
  from warnings import filterwarnings
9
 
10
- import gradio as gr
11
  import spaces
12
  import tomesd
13
  import torch
@@ -80,7 +80,6 @@ class Loader:
80
  tgate_sd_deepcache if has_deepcache else tgate_sd,
81
  self.pipe,
82
  )
83
-
84
  return self.pipe.tgate
85
 
86
  def _load_vae(self, model_name=None, taesd=False, dtype=None):
@@ -244,10 +243,11 @@ def generate(
244
  deepcache_interval=1,
245
  tgate_step=0,
246
  tome_ratio=0,
247
- progress=gr.Progress(track_tqdm=True),
 
248
  ):
249
  if not torch.cuda.is_available():
250
- raise gr.Error("CUDA not available")
251
 
252
  if seed is None:
253
  seed = int(datetime.now().timestamp())
@@ -324,5 +324,6 @@ def generate(
324
 
325
  end = time.perf_counter()
326
  diff = end - start
327
- gr.Info(f"Generated {len(images)} image{'s' if len(images) > 1 else ''} in {diff:.2f}s")
 
328
  return images
 
5
  from itertools import product
6
  from os import environ
7
  from types import MethodType
8
+ from typing import Callable
9
  from warnings import filterwarnings
10
 
 
11
  import spaces
12
  import tomesd
13
  import torch
 
80
  tgate_sd_deepcache if has_deepcache else tgate_sd,
81
  self.pipe,
82
  )
 
83
  return self.pipe.tgate
84
 
85
  def _load_vae(self, model_name=None, taesd=False, dtype=None):
 
243
  deepcache_interval=1,
244
  tgate_step=0,
245
  tome_ratio=0,
246
+ log: Callable[[str], None] = None,
247
+ Error=Exception,
248
  ):
249
  if not torch.cuda.is_available():
250
+ raise Error("CUDA not available")
251
 
252
  if seed is None:
253
  seed = int(datetime.now().timestamp())
 
324
 
325
  end = time.perf_counter()
326
  diff = end - start
327
+ if log:
328
+ log(f"Generated {len(images)} image{'s' if len(images) > 1 else ''} in {diff:.2f}s")
329
  return images
usage.md CHANGED
@@ -41,7 +41,7 @@ When using arrays, you should disable `Autoincrement` so the same seed is used f
41
 
42
  #### Schedulers
43
 
44
- All are based on [k_diffusion](https://github.com/crowsonkb/k-diffusion) except [DEIS](https://github.com/qsh-zh/deis) and [DPM++](https://github.com/LuChengTHU/dpm-solver). Optionally, the [Karras](https://arxiv.org/abs/2206.00364) noise schedule can be used:
45
 
46
  * [DEIS 2M](https://huggingface.co/docs/diffusers/en/api/schedulers/deis) (default)
47
  * [DPM++ 2M](https://huggingface.co/docs/diffusers/en/api/schedulers/multistep_dpm_solver)
@@ -63,11 +63,11 @@ All are based on [k_diffusion](https://github.com/crowsonkb/k-diffusion) except
63
 
64
  #### T-GATE
65
 
66
- [T-GATE](https://github.com/HaozheLiu-ST/T-GATE) (Zhang et al. 2024) caches self and cross attention computations up to `Step`. Afterwards, attention is no longer computed and the cache is used, resulting in a noticeable speedup. Defaults to `20`.
67
 
68
- #### ToME
69
 
70
- [ToMe](https://arxiv.org/abs/2303.17604) (Bolya & Hoffman 2023) reduces the number of tokens processed by the model. Set `Ratio` to the desired reduction factor. ToMe's impact is more noticeable on larger images.
71
 
72
  #### Tiny VAE
73
 
@@ -79,4 +79,4 @@ When enabled, the last CLIP layer is skipped. This _can_ improve image quality w
79
 
80
  #### Prompt Truncation
81
 
82
- When enabled, prompts will be truncated to CLIP's limit of 77 tokens. By default this is disabled, so Compel will chunk prompts into segments rather than cutting them off.
 
41
 
42
  #### Schedulers
43
 
44
+ Optionally, the [Karras](https://arxiv.org/abs/2206.00364) noise schedule can be used:
45
 
46
  * [DEIS 2M](https://huggingface.co/docs/diffusers/en/api/schedulers/deis) (default)
47
  * [DPM++ 2M](https://huggingface.co/docs/diffusers/en/api/schedulers/multistep_dpm_solver)
 
63
 
64
  #### T-GATE
65
 
66
+ [Temporal gating](https://github.com/HaozheLiu-ST/T-GATE) (Zhang et al. 2024) caches self and cross attention computations up to `Step`. Afterwards, attention is no longer computed and the cache is used, resulting in a noticeable speedup.
67
 
68
+ #### ToMe
69
 
70
+ [Token merging](https://arxiv.org/abs/2303.17604) (Bolya & Hoffman 2023) reduces the number of tokens processed by the model. Set `Ratio` to the desired reduction factor. ToMe's impact is more noticeable on larger images.
71
 
72
  #### Tiny VAE
73
 
 
79
 
80
  #### Prompt Truncation
81
 
82
+ When enabled, prompts will be truncated to CLIP's limit of 77 tokens. By default this is _disabled_, so Compel will chunk prompts into segments rather than cutting them off.