adamelliotfields committed
Commit 48c31e7
1 Parent(s): 321e262

Add DeepCache and T-GATE

Files changed (4):
  1. app.css +1 -1
  2. app.py +109 -51
  3. generate.py +137 -68
  4. requirements.txt +2 -0
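
In short: DeepCache speeds up diffusion inference by caching the U-Net's high-level feature maps and reusing them across nearby denoising steps (the interval and branch sliders below map to its cache_interval and cache_branch_id parameters), while T-GATE stops recomputing text cross-attention after a chosen gate step, once the image composition has largely converged. The commit also replaces the aspect-ratio dropdown with width/height sliders, makes the Tiny VAE (TAESD) opt-in instead of always-on, loads several negative textual-inversion embeddings, and adds Clip-skip and prompt-truncation toggles.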
app.css CHANGED
@@ -28,7 +28,7 @@
   margin-left: 8px;
 }
 
-#gallery {
+.gallery {
   --block-border-width: 0px;
   background-color: transparent;
 }
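
The selector change tracks app.py below, where the gallery component now sets elem_classes=["gallery"] instead of elem_id="gallery".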
app.py CHANGED
@@ -4,9 +4,11 @@ import gradio as gr
 
 from generate import generate
 
+DEFAULT_NEGATIVE_PROMPT = "<bad_prompt>, ugly, unattractive, deformed, disfigured, mutated, blurry, distorted, noisy, grainy, glitch, worst quality"
+
 # base font stacks
-mono_fonts = ["monospace"]
-sans_fonts = [
+MONO_FONTS = ["monospace"]
+SANS_FONTS = [
     "sans-serif",
     "Apple Color Emoji",
     "Segoe UI Emoji",
@@ -42,96 +44,109 @@ def generate_btn_click(*args, **kwargs):
 
 
 with gr.Blocks(
-    head=read_file("head.html"),
+    head=read_file("./partials/head.html"),
     css="./app.css",
     js="./app.js",
     theme=gr.themes.Default(
         # colors
+        neutral_hue=gr.themes.colors.gray,
         primary_hue=gr.themes.colors.orange,
         secondary_hue=gr.themes.colors.blue,
-        neutral_hue=gr.themes.colors.gray,
         # sizing
         text_size=gr.themes.sizes.text_md,
-        spacing_size=gr.themes.sizes.spacing_md,
         radius_size=gr.themes.sizes.radius_sm,
+        spacing_size=gr.themes.sizes.spacing_md,
         # fonts
-        font=[gr.themes.GoogleFont("Inter"), *sans_fonts],
-        font_mono=[gr.themes.GoogleFont("Ubuntu Mono"), *mono_fonts],
+        font=[gr.themes.GoogleFont("Inter"), *SANS_FONTS],
+        font_mono=[gr.themes.GoogleFont("Ubuntu Mono"), *MONO_FONTS],
     ).set(
-        block_background_fill=gr.themes.colors.gray.c50,
-        block_background_fill_dark=gr.themes.colors.gray.c900,
         block_shadow="0 0 #0000",
         block_shadow_dark="0 0 #0000",
+        block_background_fill=gr.themes.colors.gray.c50,
+        block_background_fill_dark=gr.themes.colors.gray.c900,
     ),
 ) as demo:
-    gr.HTML(read_file("intro.html"))
+    gr.HTML(read_file("./partials/intro.html"))
     output_images = gr.Gallery(
-        label="Output",
-        show_label=False,
-        columns=1,
-        interactive=False,
+        elem_classes=["gallery"],
         show_share_button=False,
-        elem_id="gallery",
+        interactive=False,
+        show_label=False,
+        label="Output",
+        format="png",
+        columns=2,
     )
     prompt = gr.Textbox(
-        label="Prompt",
+        placeholder="corgi, at the beach, cute, 8k",
         show_label=False,
-        lines=2,
-        placeholder="corgi, at the beach, cute",
+        label="Prompt",
         value=None,
+        lines=2,
     )
     generate_btn = gr.Button("Generate", variant="primary", elem_classes=[])
 
     with gr.Accordion(
+        elem_classes=["accordion"],
+        elem_id="menu",
         label="Menu",
         open=False,
-        elem_id="menu",
-        elem_classes=["accordion"],
    ):
        with gr.Tabs():
            with gr.TabItem("⚙️ Settings"):
                with gr.Group():
                    negative_prompt = gr.Textbox(
                        label="Negative Prompt",
+                        value=DEFAULT_NEGATIVE_PROMPT,
+                        placeholder="",
                        lines=1,
-                        placeholder="ugly",
-                        value="",
                    )
 
                    with gr.Row():
                        num_images = gr.Dropdown(
-                            label="Images",
                            choices=[1, 2, 3, 4],
-                            value=1,
                            filterable=False,
+                            label="Images",
+                            value=1,
+                            scale=1,
                        )
-                        aspect_ratio = gr.Dropdown(
-                            label="Aspect Ratio",
-                            choices=["1:1", "4:3", "3:4", "16:9", "9:16"],
-                            value="1:1",
-                            filterable=False,
+                        width = gr.Slider(
+                            label="Width",
+                            minimum=256,
+                            maximum=1024,
+                            value=512,
+                            step=32,
+                            scale=2,
+                        )
+                        height = gr.Slider(
+                            label="Height",
+                            minimum=256,
+                            maximum=1024,
+                            value=512,
+                            step=32,
+                            scale=2,
                        )
-                    seed = gr.Number(label="Seed", value=0)
 
                    with gr.Row():
                        guidance_scale = gr.Slider(
                            label="Guidance Scale",
                            minimum=1.0,
                            maximum=15.0,
-                            step=0.1,
                            value=7.5,
+                            step=0.1,
                        )
                        inference_steps = gr.Slider(
                            label="Inference Steps",
                            minimum=1,
                            maximum=50,
-                            step=1,
                            value=30,
+                            step=1,
                        )
 
                    with gr.Row():
                        model = gr.Dropdown(
+                            value="Lykon/dreamshaper-8",
                            label="Model",
+                            scale=2,
                            choices=[
                                "fluently/Fluently-v4",
                                "Linaqruf/anything-v3-1",
@@ -140,10 +155,12 @@ with gr.Blocks(
                                "runwayml/stable-diffusion-v1-5",
                                "SG161222/Realistic_Vision_V5.1_Novae",
                            ],
-                            value="Lykon/dreamshaper-8",
                        )
                        scheduler = gr.Dropdown(
+                            elem_id="scheduler",
                            label="Scheduler",
+                            value="DEIS 2M",
+                            scale=2,
                            choices=[
                                "DEIS 2M",
                                "DPM++ 2M",
@@ -153,22 +170,20 @@ with gr.Blocks(
                                "LMS",
                                "PNDM",
                            ],
-                            value="DEIS 2M",
-                            elem_id="scheduler",
                        )
+                    seed = gr.Number(label="Seed", value=42)
 
                    with gr.Row():
                        use_karras = gr.Checkbox(
-                            label="Use Karras σ",
-                            value=True,
                            elem_classes=["checkbox"],
-                            scale=2,
+                            label="Karras σ",
+                            value=True,
+                            scale=1,
                        )
                        increment_seed = gr.Checkbox(
-                            label="Autoincrement seed",
-                            value=True,
                            elem_classes=["checkbox"],
-                            elem_id="increment-seed",
+                            label="Autoincrement",
+                            value=True,
                            scale=2,
                        )
                        random_seed_btn = gr.Button(
@@ -179,21 +194,57 @@ with gr.Blocks(
                        )
 
            with gr.TabItem("🛠️ Advanced"):
-                gr.Markdown("_Coming soon..._", elem_classes=["markdown"])
+                with gr.Group():
+                    with gr.Row():
+                        deep_cache_interval = gr.Slider(
+                            label="DeepCache Interval",
+                            minimum=1,
+                            maximum=4,
+                            value=0,
+                            step=1,
+                        )
+                        deep_cache_branch = gr.Slider(
+                            label="DeepCache Branch",
+                            minimum=0,
+                            maximum=3,
+                            value=0,
+                            step=1,
+                        )
+                        tgate_step = gr.Slider(
+                            label="T-GATE Step",
+                            minimum=0,
+                            maximum=50,
+                            value=0,
+                            step=1,
+                        )
+
+                    with gr.Row():
+                        use_taesd = gr.Checkbox(
+                            elem_classes=["checkbox"],
+                            label="Tiny VAE",
+                            value=False,
+                            scale=1,
+                        )
+                        use_clip_skip = gr.Checkbox(
+                            elem_classes=["checkbox"],
+                            label="Clip skip",
+                            value=False,
+                            scale=1,
+                        )
+                        truncate_prompts = gr.Checkbox(
+                            elem_classes=["checkbox"],
+                            label="Truncate prompts",
+                            value=False,
+                            scale=3,
+                        )
 
            with gr.TabItem("ℹ️ Info"):
                gr.Markdown(read_file("info.md"), elem_classes=["markdown"])
 
-    # change gallery columns when num_images changes
-    num_images.change(
-        lambda n: gr.Gallery(columns=n),
-        inputs=[num_images],
-        outputs=[output_images],
-    )
-
    # update the random seed using JavaScript
    random_seed_btn.click(None, outputs=[seed], js="() => Math.floor(Math.random() * 2**32)")
 
+    # ensure correct argument order
    generate_btn.click(
        generate_btn_click,
        api_name="generate",
@@ -205,12 +256,19 @@ with gr.Blocks(
            seed,
            model,
            scheduler,
-            aspect_ratio,
+            width,
+            height,
            guidance_scale,
            inference_steps,
-            use_karras,
            num_images,
+            use_karras,
+            use_taesd,
+            use_clip_skip,
+            truncate_prompts,
            increment_seed,
+            deep_cache_interval,
+            deep_cache_branch,
+            tgate_step,
        ],
    )
 
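The new "# ensure correct argument order" comment matters because Gradio passes component values to the click handler positionally, so the inputs list must mirror the parameter order of generate() exactly. A minimal sketch with hypothetical components:

import gradio as gr

def greet(name, excited):  # stands in for generate_btn_click
    return f"Hello, {name}{'!' if excited else '.'}"

with gr.Blocks() as demo:
    name = gr.Textbox(value="world")
    excited = gr.Checkbox(value=True)
    out = gr.Textbox()
    btn = gr.Button("Greet")
    # [name, excited] maps positionally to greet(name, excited);
    # reordering the list silently reorders the arguments
    btn.click(greet, inputs=[name, excited], outputs=[out])

demo.launch()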
generate.py CHANGED
@@ -1,12 +1,15 @@
 import re
+from contextlib import contextmanager
 from datetime import datetime
 from itertools import product
 from os import environ
+from types import MethodType
 from warnings import filterwarnings
 
 import spaces
 import torch
-from compel import Compel
+from compel import Compel, DiffusersTextualInversionManager, ReturnedEmbeddingsType
+from DeepCache import DeepCacheSDHelper
 from diffusers import (
     DEISMultistepScheduler,
     DPMSolverMultistepScheduler,
@@ -17,18 +20,23 @@ from diffusers import (
     PNDMScheduler,
     StableDiffusionPipeline,
 )
-from diffusers.models import AutoencoderTiny
+from diffusers.models import AutoencoderKL, AutoencoderTiny
+from tgate.SD import tgate as tgate_sd
+from tgate.SD_DeepCache import tgate as tgate_sd_deepcache
+from torch._dynamo import OptimizedModule
 
 ZERO_GPU = (
     environ.get("SPACES_ZERO_GPU", "").lower() == "true"
     or environ.get("SPACES_ZERO_GPU", "") == "1"
 )
 
-TORCH_DTYPE = (
-    torch.bfloat16
-    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
-    else torch.float16
-)
+EMBEDDINGS = {
+    "./embeddings/bad_prompt_version2.pt": "<bad_prompt>",
+    "./embeddings/BadDream.pt": "<bad_dream>",
+    "./embeddings/FastNegativeV2.pt": "<fast_negative>",
+    "./embeddings/negative_hand.pt": "<negative_hand>",
+    "./embeddings/UnrealisticDream.pt": "<unrealistic_dream>",
+}
 
 # some models use the deprecated CLIPFeatureExtractor class
 # should use CLIPImageProcessor instead
@@ -46,7 +54,27 @@ class Loader:
         cls._instance.pipe = None
         return cls._instance
 
-    def load(self, model, scheduler, karras):
+    def _load_vae(self, model_name=None, taesd=False, dtype=None):
+        if taesd:
+            # can't compile tiny VAE
+            return AutoencoderTiny.from_pretrained(
+                pretrained_model_name_or_path="madebyollin/taesd",
+                use_safetensors=True,
+                torch_dtype=dtype,
+            ).to(self.gpu)
+
+        return torch.compile(
+            fullgraph=True,
+            mode="reduce-overhead",
+            model=AutoencoderKL.from_pretrained(
+                pretrained_model_name_or_path=model_name,
+                use_safetensors=True,
+                torch_dtype=dtype,
+                subfolder="vae",
+            ).to(self.gpu),
+        )
+
+    def load(self, model, scheduler, karras, taesd, dtype=None):
         model_lower = model.lower()
 
         schedulers = {
@@ -60,24 +88,24 @@ class Loader:
         }
 
         scheduler_kwargs = {
-            "beta_start": 0.00085,
-            "beta_end": 0.012,
             "beta_schedule": "scaled_linear",
             "timestep_spacing": "leading",
-            "steps_offset": 1,
             "use_karras_sigmas": karras,
+            "beta_start": 0.00085,
+            "beta_end": 0.012,
+            "steps_offset": 1,
         }
 
         if scheduler == "PNDM" or scheduler == "Euler a":
             del scheduler_kwargs["use_karras_sigmas"]
 
         pipe_kwargs = {
+            "scheduler": schedulers[scheduler](**scheduler_kwargs),
             "pretrained_model_name_or_path": model_lower,
             "requires_safety_checker": False,
-            "safety_checker": None,
-            "scheduler": schedulers[scheduler](**scheduler_kwargs),
-            "torch_dtype": TORCH_DTYPE,
             "use_safetensors": True,
+            "safety_checker": None,
+            "torch_dtype": dtype,
         }
 
         # already loaded
@@ -92,11 +120,19 @@ class Loader:
 
         if same_model:
             if not same_scheduler:
-                print(f"Swapping scheduler to {scheduler}...")
-            elif not same_karras:
+                print(f"Switching to {scheduler}...")
+            if not same_karras:
                 print(f"{'Enabling' if karras else 'Disabling'} Karras sigmas...")
-            elif not (same_scheduler and same_karras):
+            if not same_scheduler or not same_karras:
                 self.pipe.scheduler = schedulers[scheduler](**scheduler_kwargs)
+
+            # if compiled will be an OptimizedModule
+            vae_type = type(self.pipe.vae)
+            if (issubclass(vae_type, (AutoencoderKL, OptimizedModule)) and taesd) or (
+                issubclass(vae_type, AutoencoderTiny) and not taesd
+            ):
+                print(f"Switching to {'Tiny' if taesd else 'KL'} VAE...")
+                self.pipe.vae = self._load_vae(model_lower, taesd, dtype)
             return self.pipe
         else:
             print(f"Unloading {model_name.lower()}...")
@@ -111,39 +147,51 @@ class Loader:
         ]:
             pipe_kwargs["variant"] = "fp16"
 
-        # uses special VAE
-        if model_lower not in ["linaqruf/anything-v3-1"]:
-            pipe_kwargs["vae"] = AutoencoderTiny.from_pretrained(
-                "madebyollin/taesd",
-                torch_dtype=TORCH_DTYPE,
-                use_safetensors=True,
-            )
-
-        print(f"Loading {model_lower}...")
+        print(f"Loading {model_lower} with {'Tiny' if taesd else 'KL'} VAE...")
         self.pipe = StableDiffusionPipeline.from_pretrained(**pipe_kwargs).to(self.gpu)
+        self.pipe.vae = self._load_vae(model_lower, taesd, dtype)
+        self.pipe.load_textual_inversion(
+            pretrained_model_name_or_path=list(EMBEDDINGS.keys()),
+            tokens=list(EMBEDDINGS.values()),
+        )
         return self.pipe
 
 
-# prepare prompts for Compel
-def join_prompt(prompt: str) -> str:
-    lines = prompt.strip().splitlines()
-    return '("' + '", "'.join(lines) + '").and()' if len(lines) > 1 else prompt
+@contextmanager
+def deep_cache(pipe, interval=1, branch=0, tgate_step=0):
+    if interval > 1:
+        helper = DeepCacheSDHelper(pipe=pipe)
+        helper.set_params(cache_interval=interval, cache_branch_id=branch)
+        helper.enable()
+
+        if tgate_step > 0:
+            pipe.deepcache = helper
+            pipe.tgate = MethodType(tgate_sd_deepcache, pipe)
+
+        try:
+            yield helper
+        finally:
+            helper.disable()
+    elif interval < 2 and tgate_step > 0:
+        pipe.tgate = MethodType(tgate_sd, pipe)
+        yield None
+    else:
+        yield None
 
 
 # parse prompts with arrays
 def parse_prompt(prompt: str) -> list[str]:
-    joined_prompt = join_prompt(prompt)
-    arrays = re.findall(r"\[\[(.*?)\]\]", joined_prompt)
+    arrays = re.findall(r"\[\[(.*?)\]\]", prompt)
 
     if not arrays:
-        return [joined_prompt]
+        return [prompt]
 
     tokens = [item.split(",") for item in arrays]
     combinations = list(product(*tokens))
     prompts = []
 
     for combo in combinations:
-        current_prompt = joined_prompt
+        current_prompt = prompt
         for i, token in enumerate(combo):
             current_prompt = current_prompt.replace(f"[[{arrays[i]}]]", token.strip(), 1)
 
@@ -156,55 +204,65 @@ def generate(
     positive_prompt,
     negative_prompt="",
     seed=None,
-    model="lykon/dreamshaper-8",
+    model="Lykon/dreamshaper-8",
     scheduler="DEIS 2M",
-    aspect_ratio="1:1",
+    width=512,
+    height=512,
     guidance_scale=7.5,
     inference_steps=30,
-    karras=True,
     num_images=1,
+    karras=True,
+    taesd=False,
+    clip_skip=False,
+    truncate_prompts=False,
     increment_seed=True,
+    deep_cache_interval=1,
+    deep_cache_branch=0,
+    tgate_step=0,
     Error=Exception,
 ):
     if not torch.cuda.is_available():
         raise Error("CUDA not available")
 
-    # image dimensions
-    aspect_ratios = {
-        "16:9": (640, 360),
-        "4:3": (576, 432),
-        "1:1": (512, 512),
-        "3:4": (432, 576),
-        "9:16": (360, 640),
-    }
-    width, height = aspect_ratios[aspect_ratio]
+    if seed is None:
+        seed = int(datetime.now().timestamp())
+
+    TORCH_DTYPE = (
+        torch.bfloat16
+        if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
+        else torch.float16
+    )
+
+    EMBEDDINGS_TYPE = (
+        ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NORMALIZED
+        if clip_skip
+        else ReturnedEmbeddingsType.LAST_HIDDEN_STATES_NORMALIZED
+    )
 
     with torch.inference_mode():
         loader = Loader()
-        pipe = loader.load(model, scheduler, karras)
+        pipe = loader.load(model, scheduler, karras, taesd, dtype=TORCH_DTYPE)
 
         # prompt embeds
         compel = Compel(
-            tokenizer=pipe.tokenizer,
+            textual_inversion_manager=DiffusersTextualInversionManager(pipe),
+            dtype_for_device_getter=lambda _: TORCH_DTYPE,
+            returned_embeddings_type=EMBEDDINGS_TYPE,
+            truncate_long_prompts=truncate_prompts,
             text_encoder=pipe.text_encoder,
-            truncate_long_prompts=False,
+            tokenizer=pipe.tokenizer,
             device=pipe.device,
-            dtype_for_device_getter=lambda _: TORCH_DTYPE,
         )
 
-        neg_prompt = join_prompt(negative_prompt)
-        neg_embeds = compel(neg_prompt)
-
-        if seed is None:
-            seed = int(datetime.now().timestamp())
-
-        current_seed = seed
         images = []
+        current_seed = seed
+        neg_embeds = compel(negative_prompt)
 
         for i in range(num_images):
+            # seeded generator for each iteration
             generator = torch.Generator(device=pipe.device).manual_seed(current_seed)
 
-            # run the prompt for this iteration
+            # get the prompt for this iteration
             all_positive_prompts = parse_prompt(positive_prompt)
             prompt_index = i % len(all_positive_prompts)
             pos_prompt = all_positive_prompts[prompt_index]
@@ -213,16 +271,27 @@ def generate(
                 [pos_embeds, neg_embeds]
             )
 
-            result = pipe(
-                width=width,
-                height=height,
-                prompt_embeds=pos_embeds,
-                negative_prompt_embeds=neg_embeds,
-                num_inference_steps=inference_steps,
-                guidance_scale=guidance_scale,
-                generator=generator,
-            )
-            images.append((result.images[0], str(current_seed)))
+            with deep_cache(
+                pipe,
+                interval=deep_cache_interval,
+                branch=deep_cache_branch,
+                tgate_step=tgate_step,
+            ):
+                pipe_kwargs = {
+                    "num_inference_steps": inference_steps,
+                    "negative_prompt_embeds": neg_embeds,
+                    "guidance_scale": guidance_scale,
+                    "prompt_embeds": pos_embeds,
+                    "generator": generator,
+                    "height": height,
+                    "width": width,
+                }
+                result = (
+                    pipe.tgate(**pipe_kwargs, gate_step=tgate_step)
+                    if tgate_step > 0
+                    else pipe(**pipe_kwargs)
+                )
+                images.append((result.images[0], str(current_seed)))
 
         if increment_seed:
             current_seed += 1
 
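For reference, the [[...]] syntax handled by parse_prompt expands a prompt into the Cartesian product of each group's comma-separated tokens, and generate() cycles through the results across num_images. A standalone sketch of the same logic:

import re
from itertools import product

def parse_prompt(prompt: str) -> list[str]:
    # find every [[a,b,...]] group in the prompt
    arrays = re.findall(r"\[\[(.*?)\]\]", prompt)
    if not arrays:
        return [prompt]
    tokens = [item.split(",") for item in arrays]
    prompts = []
    # one prompt per combination, substituting groups left to right
    for combo in product(*tokens):
        current = prompt
        for i, token in enumerate(combo):
            current = current.replace(f"[[{arrays[i]}]]", token.strip(), 1)
        prompts.append(current)
    return prompts

print(parse_prompt("a [[red,blue]] corgi at the [[beach,park]]"))
# ['a red corgi at the beach', 'a red corgi at the park',
#  'a blue corgi at the beach', 'a blue corgi at the park']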
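The deep_cache() context manager above only wires the helper into this app; the same DeepCache calls work on any Stable Diffusion pipeline. A minimal sketch (model and prompt are illustrative):

import torch
from DeepCache import DeepCacheSDHelper
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "Lykon/dreamshaper-8", torch_dtype=torch.float16
).to("cuda")

helper = DeepCacheSDHelper(pipe=pipe)
# reuse cached U-Net features every 3 steps, caching at branch 0
helper.set_params(cache_interval=3, cache_branch_id=0)
helper.enable()
image = pipe("corgi, at the beach, cute, 8k").images[0]
helper.disable()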
requirements.txt CHANGED
@@ -1,11 +1,13 @@
 accelerate
 compel
+deepcache
 diffusers
 hf-transfer
 gradio
 ruff
 scipy # for LMS scheduler
 spaces
+tgate
 torch
 torchvision
 transformers