adamelliotfields commited on
Commit
c62ffd9
·
verified ·
1 Parent(s): 9ae9087

Remove T-GATE

Browse files
Files changed (6) hide show
  1. app.py +7 -23
  2. cli.py +0 -2
  3. config.py +0 -2
  4. generate.py +3 -25
  5. requirements.txt +0 -1
  6. usage.md +0 -4
app.py CHANGED
@@ -180,6 +180,12 @@ with gr.Blocks(
180
  with gr.TabItem("🛠️ Advanced"):
181
  with gr.Group():
182
  with gr.Row():
 
 
 
 
 
 
183
  deepcache_interval = gr.Slider(
184
  value=cfg.DEEPCACHE_INTERVAL,
185
  label="DeepCache Interval",
@@ -187,21 +193,6 @@ with gr.Blocks(
187
  maximum=4,
188
  step=1,
189
  )
190
- tgate_step = gr.Slider(
191
- maximum=cfg.INFERENCE_STEPS,
192
- value=cfg.TGATE_STEP,
193
- label="T-GATE Step",
194
- minimum=0,
195
- step=1,
196
- )
197
-
198
- with gr.Row():
199
- file_format = gr.Dropdown(
200
- choices=["png", "jpeg", "webp"],
201
- label="File Format",
202
- filterable=False,
203
- value="png",
204
- )
205
  tome_ratio = gr.Slider(
206
  value=cfg.TOME_RATIO,
207
  label="ToMe Ratio",
@@ -227,7 +218,7 @@ with gr.Blocks(
227
  elem_classes=["checkbox"],
228
  label="Truncate prompts",
229
  value=False,
230
- scale=3,
231
  )
232
 
233
  with gr.TabItem("ℹ️ Usage"):
@@ -288,12 +279,6 @@ with gr.Blocks(
288
  outputs=[output_images],
289
  )
290
 
291
- inference_steps.change(
292
- lambda max, step: gr.Slider(maximum=max, value=min(max, step)),
293
- inputs=[inference_steps, tgate_step],
294
- outputs=[tgate_step],
295
- )
296
-
297
  gr.on(
298
  triggers=[generate_btn.click, prompt.submit],
299
  fn=handle_generate,
@@ -318,7 +303,6 @@ with gr.Blocks(
318
  truncate_prompts,
319
  increment_seed,
320
  deepcache_interval,
321
- tgate_step,
322
  tome_ratio,
323
  ],
324
  )
 
180
  with gr.TabItem("🛠️ Advanced"):
181
  with gr.Group():
182
  with gr.Row():
183
+ file_format = gr.Dropdown(
184
+ choices=["png", "jpeg", "webp"],
185
+ label="File Format",
186
+ filterable=False,
187
+ value="png",
188
+ )
189
  deepcache_interval = gr.Slider(
190
  value=cfg.DEEPCACHE_INTERVAL,
191
  label="DeepCache Interval",
 
193
  maximum=4,
194
  step=1,
195
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  tome_ratio = gr.Slider(
197
  value=cfg.TOME_RATIO,
198
  label="ToMe Ratio",
 
218
  elem_classes=["checkbox"],
219
  label="Truncate prompts",
220
  value=False,
221
+ scale=1,
222
  )
223
 
224
  with gr.TabItem("ℹ️ Usage"):
 
279
  outputs=[output_images],
280
  )
281
 
 
 
 
 
 
 
282
  gr.on(
283
  triggers=[generate_btn.click, prompt.submit],
284
  fn=handle_generate,
 
303
  truncate_prompts,
304
  increment_seed,
305
  deepcache_interval,
 
306
  tome_ratio,
307
  ],
308
  )
cli.py CHANGED
@@ -22,7 +22,6 @@ def main():
22
  parser.add_argument("-h", "--height", type=int, metavar="INT", default=cfg.HEIGHT)
23
  parser.add_argument("-m", "--model", type=str, metavar="STR", default=cfg.MODEL)
24
  parser.add_argument("-d", "--deepcache", type=int, metavar="INT", default=cfg.DEEPCACHE_INTERVAL)
25
- parser.add_argument("-t", "--tgate", type=int, metavar="INT", default=cfg.TGATE_STEP)
26
  parser.add_argument("--style", type=str, metavar="STR", default=cfg.STYLE)
27
  parser.add_argument("--scheduler", type=str, metavar="STR", default=cfg.SCHEDULER)
28
  parser.add_argument("--guidance", type=float, metavar="FLOAT", default=cfg.GUIDANCE_SCALE)
@@ -54,7 +53,6 @@ def main():
54
  args.truncate,
55
  args.no_increment,
56
  args.deepcache,
57
- args.tgate,
58
  args.tome,
59
  )
60
  save_images(images, args.filename)
 
22
  parser.add_argument("-h", "--height", type=int, metavar="INT", default=cfg.HEIGHT)
23
  parser.add_argument("-m", "--model", type=str, metavar="STR", default=cfg.MODEL)
24
  parser.add_argument("-d", "--deepcache", type=int, metavar="INT", default=cfg.DEEPCACHE_INTERVAL)
 
25
  parser.add_argument("--style", type=str, metavar="STR", default=cfg.STYLE)
26
  parser.add_argument("--scheduler", type=str, metavar="STR", default=cfg.SCHEDULER)
27
  parser.add_argument("--guidance", type=float, metavar="FLOAT", default=cfg.GUIDANCE_SCALE)
 
53
  args.truncate,
54
  args.no_increment,
55
  args.deepcache,
 
56
  args.tome,
57
  )
58
  save_images(images, args.filename)
config.py CHANGED
@@ -49,6 +49,4 @@ INFERENCE_STEPS = 30
49
 
50
  DEEPCACHE_INTERVAL = 2
51
 
52
- TGATE_STEP = 0
53
-
54
  TOME_RATIO = 0.0
 
49
 
50
  DEEPCACHE_INTERVAL = 2
51
 
 
 
52
  TOME_RATIO = 0.0
generate.py CHANGED
@@ -25,8 +25,6 @@ from diffusers import (
25
  StableDiffusionPipeline,
26
  )
27
  from diffusers.models import AutoencoderKL, AutoencoderTiny
28
- from tgate.SD import tgate as tgate_sd
29
- from tgate.SD_DeepCache import tgate as tgate_sd_deepcache
30
  from torch._dynamo import OptimizedModule
31
 
32
  # some models use the deprecated CLIPFeatureExtractor class (should use CLIPImageProcessor)
@@ -77,17 +75,6 @@ class Loader:
77
  self.pipe.deepcache.enable()
78
  return self.pipe.deepcache
79
 
80
- def _load_tgate(self):
81
- has_tgate = hasattr(self.pipe, "tgate")
82
- has_deepcache = hasattr(self.pipe, "deepcache")
83
-
84
- if not has_tgate:
85
- self.pipe.tgate = MethodType(
86
- tgate_sd_deepcache if has_deepcache else tgate_sd,
87
- self.pipe,
88
- )
89
- return self.pipe.tgate
90
-
91
  def _load_vae(self, model_name=None, taesd=False, dtype=None):
92
  vae_type = type(self.pipe.vae)
93
  is_kl = issubclass(vae_type, (AutoencoderKL, OptimizedModule))
@@ -172,7 +159,6 @@ class Loader:
172
 
173
  self._load_vae(model_lower, taesd, dtype)
174
  self._load_deepcache(interval=deepcache_interval)
175
- self._load_tgate()
176
  return self.pipe
177
  else:
178
  print(f"Unloading {model_name.lower()}...")
@@ -189,13 +175,12 @@ class Loader:
189
 
190
  print(f"Loading {model_lower} with {'Tiny' if taesd else 'KL'} VAE...")
191
  self.pipe = StableDiffusionPipeline.from_pretrained(**pipe_kwargs).to(self.gpu)
192
- self._load_vae(model_lower, taesd, dtype)
193
- self._load_deepcache(interval=deepcache_interval)
194
- self._load_tgate()
195
  self.pipe.load_textual_inversion(
196
  pretrained_model_name_or_path=list(EMBEDDINGS.keys()),
197
  tokens=list(EMBEDDINGS.values()),
198
  )
 
 
199
  return self.pipe
200
 
201
 
@@ -262,7 +247,6 @@ def generate(
262
  truncate_prompts=False,
263
  increment_seed=True,
264
  deepcache_interval=1,
265
- tgate_step=0,
266
  tome_ratio=0,
267
  log: Callable[[str], None] = None,
268
  Error=Exception,
@@ -328,17 +312,11 @@ def generate(
328
  raise Error("ParsingException: Invalid prompt")
329
 
330
  with token_merging(pipe, tome_ratio=tome_ratio):
331
- # cap the tgate step
332
- gate_step = min(
333
- tgate_step if tgate_step > 0 else inference_steps,
334
- inference_steps,
335
- )
336
- result = pipe.tgate(
337
  num_inference_steps=inference_steps,
338
  negative_prompt_embeds=neg_embeds,
339
  guidance_scale=guidance_scale,
340
  prompt_embeds=pos_embeds,
341
- gate_step=gate_step,
342
  generator=generator,
343
  height=height,
344
  width=width,
 
25
  StableDiffusionPipeline,
26
  )
27
  from diffusers.models import AutoencoderKL, AutoencoderTiny
 
 
28
  from torch._dynamo import OptimizedModule
29
 
30
  # some models use the deprecated CLIPFeatureExtractor class (should use CLIPImageProcessor)
 
75
  self.pipe.deepcache.enable()
76
  return self.pipe.deepcache
77
 
 
 
 
 
 
 
 
 
 
 
 
78
  def _load_vae(self, model_name=None, taesd=False, dtype=None):
79
  vae_type = type(self.pipe.vae)
80
  is_kl = issubclass(vae_type, (AutoencoderKL, OptimizedModule))
 
159
 
160
  self._load_vae(model_lower, taesd, dtype)
161
  self._load_deepcache(interval=deepcache_interval)
 
162
  return self.pipe
163
  else:
164
  print(f"Unloading {model_name.lower()}...")
 
175
 
176
  print(f"Loading {model_lower} with {'Tiny' if taesd else 'KL'} VAE...")
177
  self.pipe = StableDiffusionPipeline.from_pretrained(**pipe_kwargs).to(self.gpu)
 
 
 
178
  self.pipe.load_textual_inversion(
179
  pretrained_model_name_or_path=list(EMBEDDINGS.keys()),
180
  tokens=list(EMBEDDINGS.values()),
181
  )
182
+ self._load_vae(model_lower, taesd, dtype)
183
+ self._load_deepcache(interval=deepcache_interval)
184
  return self.pipe
185
 
186
 
 
247
  truncate_prompts=False,
248
  increment_seed=True,
249
  deepcache_interval=1,
 
250
  tome_ratio=0,
251
  log: Callable[[str], None] = None,
252
  Error=Exception,
 
312
  raise Error("ParsingException: Invalid prompt")
313
 
314
  with token_merging(pipe, tome_ratio=tome_ratio):
315
+ result = pipe(
 
 
 
 
 
316
  num_inference_steps=inference_steps,
317
  negative_prompt_embeds=neg_embeds,
318
  guidance_scale=guidance_scale,
319
  prompt_embeds=pos_embeds,
 
320
  generator=generator,
321
  height=height,
322
  width=width,
requirements.txt CHANGED
@@ -7,7 +7,6 @@ gradio==4.39.0
7
  ruff
8
  scipy # for LMS scheduler
9
  spaces
10
- tgate==0.1.2
11
  tomesd==0.1.3
12
  torch
13
  torchvision
 
7
  ruff
8
  scipy # for LMS scheduler
9
  spaces
 
10
  tomesd==0.1.3
11
  torch
12
  torchvision
usage.md CHANGED
@@ -65,10 +65,6 @@ Optionally, the [Karras](https://arxiv.org/abs/2206.00364) noise schedule can be
65
  * `3`: balanced
66
  * `4`: more speed
67
 
68
- #### T-GATE
69
-
70
- [Temporal gating](https://github.com/HaozheLiu-ST/T-GATE) (Zhang et al. 2024) caches self and cross attention computations up to `Step`. Afterwards, attention is no longer computed and the cache is used, resulting in a noticeable speedup.
71
-
72
  #### ToMe
73
 
74
  [Token merging](https://arxiv.org/abs/2303.17604) (Bolya & Hoffman 2023) reduces the number of tokens processed by the model. Set `Ratio` to the desired reduction factor. ToMe's impact is more noticeable on larger images.
 
65
  * `3`: balanced
66
  * `4`: more speed
67
 
 
 
 
 
68
  #### ToMe
69
 
70
  [Token merging](https://arxiv.org/abs/2303.17604) (Bolya & Hoffman 2023) reduces the number of tokens processed by the model. Set `Ratio` to the desired reduction factor. ToMe's impact is more noticeable on larger images.