nyanko7 committed on
Commit
5e461b1
·
1 Parent(s): 569f71e

Update modules/model.py

Browse files
Files changed (1) hide show
  1. modules/model.py +238 -176
modules/model.py CHANGED
@@ -26,8 +26,9 @@ import modules.safe as _
26
  from safetensors.torch import load_file
27
 
28
  xformers_available = False
29
- try:
30
  import xformers
 
31
  xformers_available = True
32
  except ImportError:
33
  pass
@@ -37,6 +38,7 @@ exists = lambda val: val is not None
37
  default = lambda val, d: val if exists(val) else d
38
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
 
 
40
  def get_attention_scores(attn, query, key, attention_mask=None):
41
 
42
  if attn.upcast_attention:
@@ -65,72 +67,89 @@ def get_attention_scores(attn, query, key, attention_mask=None):
65
 
66
  return attention_scores
67
 
68
-
69
  def load_lora_attn_procs(model_file, unet, scale=1.0):
70
-
71
- if Path(model_file).suffix == ".pt":
72
- state_dict = torch.load(model_file, map_location="cpu")
73
- else:
74
- state_dict = load_file(model_file, device="cpu")
75
-
76
- # 'lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight'
77
- # 'down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight'
78
- if any("lora_unet_down_blocks"in k for k in state_dict.keys()):
79
- # extract ldm format lora
80
- df_lora = {}
81
- attn_numlayer = re.compile(r'_attn(\d)_to_([qkv]|out).lora_')
82
- alpha_numlayer = re.compile(r'_attn(\d)_to_([qkv]|out).alpha')
83
- for k, v in state_dict.items():
84
- if "attn" not in k or "lora_te" in k:
85
- # currently not support: ff, clip-attn
86
- continue
87
- k = k.replace("lora_unet_down_blocks_", "down_blocks.")
88
- k = k.replace("lora_unet_up_blocks_", "up_blocks.")
89
- k = k.replace("lora_unet_mid_block_", "mid_block_")
90
- k = k.replace("_attentions_", ".attentions.")
91
- k = k.replace("_transformer_blocks_", ".transformer_blocks.")
92
- k = k.replace("to_out_0", "to_out")
93
- k = attn_numlayer.sub(r'.attn\1.processor.to_\2_lora.', k)
94
- k = alpha_numlayer.sub(r'.attn\1.processor.to_\2_lora.alpha', k)
95
- df_lora[k] = v
96
- state_dict = df_lora
97
-
98
- # fill attn processors
99
- attn_processors = {}
100
-
101
- is_lora = all("lora" in k for k in state_dict.keys())
102
-
103
- if is_lora:
104
- lora_grouped_dict = defaultdict(dict)
105
- for key, value in state_dict.items():
106
- if "alpha" in key:
107
- attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
108
- else:
109
- attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
110
- lora_grouped_dict[attn_processor_key][sub_key] = value
111
-
112
- for key, value_dict in lora_grouped_dict.items():
113
- rank = value_dict["to_k_lora.down.weight"].shape[0]
114
- cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
115
- hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
116
-
117
- attn_processors[key] = LoRACrossAttnProcessor(
118
- hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank, scale=scale
119
  )
120
- attn_processors[key].load_state_dict(value_dict, strict=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
- else:
123
- raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.")
 
 
124
 
125
- # set correct dtype & device
126
- attn_processors = {k: v.to(device=unet.device, dtype=unet.dtype) for k, v in attn_processors.items()}
 
 
 
127
 
128
- # set layers
129
- unet.set_attn_processor(attn_processors)
130
 
131
 
132
- class CrossAttnProcessor(nn.Module):
133
- def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, qkvo_bias=None):
 
 
 
 
 
 
 
134
  batch_size, sequence_length, _ = hidden_states.shape
135
  attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
136
 
@@ -146,12 +165,12 @@ class CrossAttnProcessor(nn.Module):
146
  query = attn.to_q(hidden_states)
147
  key = attn.to_k(encoder_states)
148
  value = attn.to_v(encoder_states)
149
-
150
  if qkvo_bias is not None:
151
  query += qkvo_bias["q"](hidden_states)
152
  key += qkvo_bias["k"](encoder_states)
153
  value += qkvo_bias["v"](encoder_states)
154
-
155
  query = attn.head_to_batch_dim(query)
156
  key = attn.head_to_batch_dim(key)
157
  value = attn.head_to_batch_dim(value)
@@ -161,56 +180,74 @@ class CrossAttnProcessor(nn.Module):
161
  attention_scores = get_attention_scores(attn, query, key, attention_mask)
162
  w = img_state[sequence_length].to(query.device)
163
  cross_attention_weight = weight_func(w, sigma, attention_scores)
164
- attention_scores += torch.repeat_interleave(cross_attention_weight, repeats=attn.heads, dim=0)
165
-
 
 
166
  # calc probs
167
  attention_probs = attention_scores.softmax(dim=-1)
168
  attention_probs = attention_probs.to(query.dtype)
169
  hidden_states = torch.bmm(attention_probs, value)
170
-
171
  elif xformers_available:
172
  hidden_states = xformers.ops.memory_efficient_attention(
173
- query.contiguous(), key.contiguous(), value.contiguous(), attn_bias=attention_mask
 
 
 
174
  )
175
  hidden_states = hidden_states.to(query.dtype)
176
-
177
  else:
178
  q_bucket_size = 512
179
  k_bucket_size = 1024
180
-
181
  # use flash-attention
182
- hidden_states = FlashAttentionFunction(
183
- query.contiguous(), key.contiguous(), value.contiguous(),
184
- attention_mask, causal=False, q_bucket_size=q_bucket_size, k_bucket_size=k_bucket_size
 
 
 
 
 
185
  )
186
  hidden_states = hidden_states.to(query.dtype)
187
-
188
  hidden_states = attn.batch_to_head_dim(hidden_states)
189
 
190
  # linear proj
191
  hidden_states = attn.to_out[0](hidden_states)
192
-
193
  if qkvo_bias is not None:
194
  hidden_states += qkvo_bias["o"](hidden_states)
195
-
196
  # dropout
197
  hidden_states = attn.to_out[1](hidden_states)
198
 
199
  return hidden_states
200
-
201
 
202
  class LoRACrossAttnProcessor(CrossAttnProcessor):
203
  def __init__(self, hidden_size, cross_attention_dim=None, rank=4, scale=1.0):
204
  super().__init__()
205
 
206
  self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
207
- self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
208
- self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
 
 
 
 
209
  self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
210
  self.scale = scale
211
-
212
  def __call__(
213
- self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None,
 
 
 
 
214
  ):
215
  scale = self.scale
216
  qkvo_bias = {
@@ -219,33 +256,37 @@ class LoRACrossAttnProcessor(CrossAttnProcessor):
219
  "v": lambda inputs: scale * self.to_v_lora(inputs),
220
  "o": lambda inputs: scale * self.to_out_lora(inputs),
221
  }
222
- return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask, qkvo_bias)
 
 
223
 
224
 
225
  class LoRALinearLayer(nn.Module):
226
- def __init__(self, in_features, out_features, rank=4):
227
- super().__init__()
228
 
229
- if rank > min(in_features, out_features):
230
- raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
 
 
231
 
232
- self.down = nn.Linear(in_features, rank, bias=False)
233
- self.up = nn.Linear(rank, out_features, bias=False)
234
- self.scale = 1.0
235
- self.alpha = rank
236
 
237
- nn.init.normal_(self.down.weight, std=1 / rank)
238
- nn.init.zeros_(self.up.weight)
239
 
240
- def forward(self, hidden_states):
241
- orig_dtype = hidden_states.dtype
242
- dtype = self.down.weight.dtype
243
- rank = self.down.out_features
244
 
245
- down_hidden_states = self.down(hidden_states.to(dtype))
246
- up_hidden_states = self.up(down_hidden_states) * (self.alpha / rank)
247
 
248
- return up_hidden_states.to(orig_dtype)
249
 
250
 
251
  class ModelWrapper:
@@ -287,8 +328,13 @@ class StableDiffusionPipeline(DiffusionPipeline):
287
  scheduler=scheduler,
288
  )
289
  self.setup_unet(self.unet)
290
- self.prompt_parser = FrozenCLIPEmbedderWithCustomWords(self.tokenizer, self.text_encoder)
291
-
 
 
 
 
 
292
  def setup_unet(self, unet):
293
  unet = unet.to(self.device)
294
  model = ModelWrapper(unet, self.scheduler.alphas_cumprod)
@@ -301,14 +347,14 @@ class StableDiffusionPipeline(DiffusionPipeline):
301
  library = importlib.import_module("k_diffusion")
302
  sampling = getattr(library, "sampling")
303
  return getattr(sampling, scheduler_type)
304
-
305
  def encode_sketchs(self, state, scale_ratio=8, g_strength=1.0, text_ids=None):
306
  uncond, cond = text_ids[0], text_ids[1]
307
-
308
  img_state = []
309
  if state is None:
310
  return torch.FloatTensor(0)
311
-
312
  for k, v in state.items():
313
  if v["map"] is None:
314
  continue
@@ -319,14 +365,16 @@ class StableDiffusionPipeline(DiffusionPipeline):
319
  truncation=True,
320
  add_special_tokens=False,
321
  ).input_ids
322
-
323
  dotmap = v["map"] < 255
324
- arr = torch.from_numpy(dotmap.astype(float) * float(v["weight"]) * g_strength)
 
 
325
  img_state.append((v_input, arr))
326
-
327
  if len(img_state) == 0:
328
  return torch.FloatTensor(0)
329
-
330
  w_tensors = dict()
331
  cond = cond.tolist()
332
  uncond = uncond.tolist()
@@ -341,26 +389,31 @@ class StableDiffusionPipeline(DiffusionPipeline):
341
  for v_as_tokens, img_where_color in img_state:
342
  is_in = 0
343
 
344
- ret = F.interpolate(
345
- img_where_color.unsqueeze(0).unsqueeze(1),
346
- scale_factor=1 / scale_ratio,
347
- mode="bilinear",
348
- align_corners=True,
349
- ).squeeze().reshape(-1, 1).repeat(1, len(v_as_tokens))
350
-
 
 
 
 
 
351
  for idx, tok in enumerate(cond):
352
  if cond[idx : idx + len(v_as_tokens)] == v_as_tokens:
353
  is_in = 1
354
- ret_cond_tensor[0, :, idx : idx + len(v_as_tokens)] += (ret)
355
-
356
  for idx, tok in enumerate(uncond):
357
  if uncond[idx : idx + len(v_as_tokens)] == v_as_tokens:
358
- is_in = 1
359
- ret_uncond_tensor[0, :, idx : idx + len(v_as_tokens)] += (ret)
360
 
361
  if not is_in == 1:
362
  print(f"tokens {v_as_tokens} not found in text")
363
-
364
  w_tensors[w_r * h_r] = torch.cat([ret_uncond_tensor, ret_cond_tensor])
365
  scale_ratio *= 2
366
 
@@ -432,7 +485,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
432
  ):
433
  return torch.device(module._hf_hook.execution_device)
434
  return self.device
435
-
436
  def decode_latents(self, latents):
437
  latents = latents.to(self.device, dtype=self.vae.dtype)
438
  latents = 1 / 0.18215 * latents
@@ -533,7 +586,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
533
  pww_attn_weight=1.0,
534
  sampler_name="",
535
  sampler_opt={},
536
- scale_ratio=8.0
537
  ):
538
  sampler = self.get_scheduler(sampler_name)
539
  if image is not None:
@@ -556,8 +609,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
556
  # 3. Encode input prompt
557
  text_ids, text_embeddings = self.prompt_parser([negative_prompt, prompt])
558
  text_embeddings = text_embeddings.to(self.unet.dtype)
559
-
560
- init_timestep = int(num_inference_steps / min(strength, 0.999)) if strength > 0 else 0
 
 
561
  sigmas = self.get_sigmas(init_timestep, sampler_opt).to(
562
  text_embeddings.device, dtype=text_embeddings.dtype
563
  )
@@ -581,17 +636,15 @@ class StableDiffusionPipeline(DiffusionPipeline):
581
  )
582
 
583
  img_state = self.encode_sketchs(
584
- pww_state,
585
  g_strength=pww_attn_weight,
586
  text_ids=text_ids,
587
  )
588
-
589
  def model_fn(x, sigma):
590
-
591
  latent_model_input = torch.cat([x] * 2)
592
- weight_func = (
593
- lambda w, sigma, qk: w * math.log(1 + sigma) * qk.max()
594
- )
595
  encoder_state = {
596
  "img_state": img_state,
597
  "states": text_embeddings,
@@ -744,19 +797,17 @@ class StableDiffusionPipeline(DiffusionPipeline):
744
  self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(
745
  latents.device
746
  )
747
-
748
  img_state = self.encode_sketchs(
749
- pww_state,
750
  g_strength=pww_attn_weight,
751
  text_ids=text_ids,
752
  )
753
 
754
  def model_fn(x, sigma):
755
-
756
  latent_model_input = torch.cat([x] * 2)
757
- weight_func = (
758
- lambda w, sigma, qk: w * math.log(1 + sigma) * qk.max()
759
- )
760
  encoder_state = {
761
  "img_state": img_state,
762
  "states": text_embeddings,
@@ -802,7 +853,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
802
  sampler_name=sampler_name,
803
  sampler_opt=sampler_opt,
804
  pww_state=None,
805
- pww_attn_weight=pww_attn_weight/2,
806
  )
807
 
808
  # 8. Post-processing
@@ -816,76 +867,83 @@ class StableDiffusionPipeline(DiffusionPipeline):
816
 
817
 
818
  class FlashAttentionFunction(Function):
819
-
820
-
821
  @staticmethod
822
  @torch.no_grad()
823
  def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
824
- """ Algorithm 2 in the paper """
825
 
826
  device = q.device
827
  max_neg_value = -torch.finfo(q.dtype).max
828
  qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
829
 
830
  o = torch.zeros_like(q)
831
- all_row_sums = torch.zeros((*q.shape[:-1], 1), device = device)
832
- all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, device = device)
833
 
834
- scale = (q.shape[-1] ** -0.5)
835
 
836
  if not exists(mask):
837
  mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
838
  else:
839
- mask = rearrange(mask, 'b n -> b 1 1 n')
840
- mask = mask.split(q_bucket_size, dim = -1)
841
 
842
  row_splits = zip(
843
- q.split(q_bucket_size, dim = -2),
844
- o.split(q_bucket_size, dim = -2),
845
  mask,
846
- all_row_sums.split(q_bucket_size, dim = -2),
847
- all_row_maxes.split(q_bucket_size, dim = -2),
848
  )
849
 
850
  for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
851
  q_start_index = ind * q_bucket_size - qk_len_diff
852
 
853
  col_splits = zip(
854
- k.split(k_bucket_size, dim = -2),
855
- v.split(k_bucket_size, dim = -2),
856
  )
857
 
858
  for k_ind, (kc, vc) in enumerate(col_splits):
859
  k_start_index = k_ind * k_bucket_size
860
 
861
- attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
862
 
863
  if exists(row_mask):
864
  attn_weights.masked_fill_(~row_mask, max_neg_value)
865
 
866
  if causal and q_start_index < (k_start_index + k_bucket_size - 1):
867
- causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
 
 
868
  attn_weights.masked_fill_(causal_mask, max_neg_value)
869
 
870
- block_row_maxes = attn_weights.amax(dim = -1, keepdims = True)
871
  attn_weights -= block_row_maxes
872
  exp_weights = torch.exp(attn_weights)
873
 
874
  if exists(row_mask):
875
- exp_weights.masked_fill_(~row_mask, 0.)
876
 
877
- block_row_sums = exp_weights.sum(dim = -1, keepdims = True).clamp(min = EPSILON)
 
 
878
 
879
  new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
880
 
881
- exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)
882
 
883
  exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
884
  exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
885
 
886
- new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
 
 
 
887
 
888
- oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
 
 
889
 
890
  row_maxes.copy_(new_row_maxes)
891
  row_sums.copy_(new_row_sums)
@@ -900,7 +958,7 @@ class FlashAttentionFunction(Function):
900
  @staticmethod
901
  @torch.no_grad()
902
  def backward(ctx, do):
903
- """ Algorithm 4 in the paper """
904
 
905
  causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
906
  q, k, v, o, lse = ctx.saved_tensors
@@ -915,49 +973,53 @@ class FlashAttentionFunction(Function):
915
  dv = torch.zeros_like(v)
916
 
917
  row_splits = zip(
918
- q.split(q_bucket_size, dim = -2),
919
- o.split(q_bucket_size, dim = -2),
920
- do.split(q_bucket_size, dim = -2),
921
  mask,
922
- lse.split(q_bucket_size, dim = -2),
923
- dq.split(q_bucket_size, dim = -2)
924
  )
925
 
926
  for ind, (qc, oc, doc, row_mask, lsec, dqc) in enumerate(row_splits):
927
  q_start_index = ind * q_bucket_size - qk_len_diff
928
 
929
  col_splits = zip(
930
- k.split(k_bucket_size, dim = -2),
931
- v.split(k_bucket_size, dim = -2),
932
- dk.split(k_bucket_size, dim = -2),
933
- dv.split(k_bucket_size, dim = -2),
934
  )
935
 
936
  for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
937
  k_start_index = k_ind * k_bucket_size
938
 
939
- attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
940
 
941
  if causal and q_start_index < (k_start_index + k_bucket_size - 1):
942
- causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
 
 
943
  attn_weights.masked_fill_(causal_mask, max_neg_value)
944
 
945
  p = torch.exp(attn_weights - lsec)
946
 
947
  if exists(row_mask):
948
- p.masked_fill_(~row_mask, 0.)
949
 
950
- dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
951
- dp = einsum('... i d, ... j d -> ... i j', doc, vc)
952
 
953
- D = (doc * oc).sum(dim = -1, keepdims = True)
954
  ds = p * scale * (dp - D)
955
 
956
- dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
957
- dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)
958
 
959
  dqc.add_(dq_chunk)
960
  dkc.add_(dk_chunk)
961
  dvc.add_(dv_chunk)
962
 
963
- return dq, dk, dv, None, None, None, None
 
 
 
26
  from safetensors.torch import load_file
27
 
28
  xformers_available = False
29
+ try:
30
  import xformers
31
+
32
  xformers_available = True
33
  except ImportError:
34
  pass
 
38
  default = lambda val, d: val if exists(val) else d
39
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
40
 
41
+
42
  def get_attention_scores(attn, query, key, attention_mask=None):
43
 
44
  if attn.upcast_attention:
 
67
 
68
  return attention_scores
69
 
70
+
71
  def load_lora_attn_procs(model_file, unet, scale=1.0):
72
+
73
+ if Path(model_file).suffix == ".pt":
74
+ state_dict = torch.load(model_file, map_location="cpu")
75
+ else:
76
+ state_dict = load_file(model_file, device="cpu")
77
+
78
+ if any("lora_unet_down_blocks" in k for k in state_dict.keys()):
79
+ # convert ldm format lora
80
+ df_lora = {}
81
+ attn_numlayer = re.compile(r"_attn(\d)_to_([qkv]|out).lora_")
82
+ alpha_numlayer = re.compile(r"_attn(\d)_to_([qkv]|out).alpha")
83
+ for k, v in state_dict.items():
84
+ if "attn" not in k or "lora_te" in k:
85
+ # currently not support: ff, clip-attn
86
+ continue
87
+ k = k.replace("lora_unet_down_blocks_", "down_blocks.")
88
+ k = k.replace("lora_unet_up_blocks_", "up_blocks.")
89
+ k = k.replace("lora_unet_mid_block_", "mid_block_")
90
+ k = k.replace("_attentions_", ".attentions.")
91
+ k = k.replace("_transformer_blocks_", ".transformer_blocks.")
92
+ k = k.replace("to_out_0", "to_out")
93
+ k = attn_numlayer.sub(r".attn\1.processor.to_\2_lora.", k)
94
+ k = alpha_numlayer.sub(r".attn\1.processor.to_\2_lora.alpha", k)
95
+ df_lora[k] = v
96
+ state_dict = df_lora
97
+
98
+ # fill attn processors
99
+ attn_processors = {}
100
+
101
+ is_lora = all("lora" in k for k in state_dict.keys())
102
+
103
+ if is_lora:
104
+ lora_grouped_dict = defaultdict(dict)
105
+ for key, value in state_dict.items():
106
+ if "alpha" in key:
107
+ attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(
108
+ key.split(".")[-2:]
109
+ )
110
+ else:
111
+ attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(
112
+ key.split(".")[-3:]
 
 
 
 
 
 
 
 
113
  )
114
+ lora_grouped_dict[attn_processor_key][sub_key] = value
115
+
116
+ for key, value_dict in lora_grouped_dict.items():
117
+ rank = value_dict["to_k_lora.down.weight"].shape[0]
118
+ cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
119
+ hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
120
+
121
+ attn_processors[key] = LoRACrossAttnProcessor(
122
+ hidden_size=hidden_size,
123
+ cross_attention_dim=cross_attention_dim,
124
+ rank=rank,
125
+ scale=scale,
126
+ )
127
+ attn_processors[key].load_state_dict(value_dict, strict=False)
128
 
129
+ else:
130
+ raise ValueError(
131
+ f"{model_file} does not seem to be in the correct format expected by LoRA training."
132
+ )
133
 
134
+ # set correct dtype & device
135
+ attn_processors = {
136
+ k: v.to(device=unet.device, dtype=unet.dtype)
137
+ for k, v in attn_processors.items()
138
+ }
139
 
140
+ # set layers
141
+ unet.set_attn_processor(attn_processors)
142
 
143
 
144
+ class CrossAttnProcessor(nn.Module):
145
+ def __call__(
146
+ self,
147
+ attn,
148
+ hidden_states,
149
+ encoder_hidden_states=None,
150
+ attention_mask=None,
151
+ qkvo_bias=None,
152
+ ):
153
  batch_size, sequence_length, _ = hidden_states.shape
154
  attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
155
 
 
165
  query = attn.to_q(hidden_states)
166
  key = attn.to_k(encoder_states)
167
  value = attn.to_v(encoder_states)
168
+
169
  if qkvo_bias is not None:
170
  query += qkvo_bias["q"](hidden_states)
171
  key += qkvo_bias["k"](encoder_states)
172
  value += qkvo_bias["v"](encoder_states)
173
+
174
  query = attn.head_to_batch_dim(query)
175
  key = attn.head_to_batch_dim(key)
176
  value = attn.head_to_batch_dim(value)
 
180
  attention_scores = get_attention_scores(attn, query, key, attention_mask)
181
  w = img_state[sequence_length].to(query.device)
182
  cross_attention_weight = weight_func(w, sigma, attention_scores)
183
+ attention_scores += torch.repeat_interleave(
184
+ cross_attention_weight, repeats=attn.heads, dim=0
185
+ )
186
+
187
  # calc probs
188
  attention_probs = attention_scores.softmax(dim=-1)
189
  attention_probs = attention_probs.to(query.dtype)
190
  hidden_states = torch.bmm(attention_probs, value)
191
+
192
  elif xformers_available:
193
  hidden_states = xformers.ops.memory_efficient_attention(
194
+ query.contiguous(),
195
+ key.contiguous(),
196
+ value.contiguous(),
197
+ attn_bias=attention_mask,
198
  )
199
  hidden_states = hidden_states.to(query.dtype)
200
+
201
  else:
202
  q_bucket_size = 512
203
  k_bucket_size = 1024
204
+
205
  # use flash-attention
206
+ hidden_states = FlashAttn.apply(
207
+ query.contiguous(),
208
+ key.contiguous(),
209
+ value.contiguous(),
210
+ attention_mask,
211
+ causal=False,
212
+ q_bucket_size=q_bucket_size,
213
+ k_bucket_size=k_bucket_size,
214
  )
215
  hidden_states = hidden_states.to(query.dtype)
216
+
217
  hidden_states = attn.batch_to_head_dim(hidden_states)
218
 
219
  # linear proj
220
  hidden_states = attn.to_out[0](hidden_states)
221
+
222
  if qkvo_bias is not None:
223
  hidden_states += qkvo_bias["o"](hidden_states)
224
+
225
  # dropout
226
  hidden_states = attn.to_out[1](hidden_states)
227
 
228
  return hidden_states
229
+
230
 
231
  class LoRACrossAttnProcessor(CrossAttnProcessor):
232
  def __init__(self, hidden_size, cross_attention_dim=None, rank=4, scale=1.0):
233
  super().__init__()
234
 
235
  self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
236
+ self.to_k_lora = LoRALinearLayer(
237
+ cross_attention_dim or hidden_size, hidden_size, rank
238
+ )
239
+ self.to_v_lora = LoRALinearLayer(
240
+ cross_attention_dim or hidden_size, hidden_size, rank
241
+ )
242
  self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
243
  self.scale = scale
244
+
245
  def __call__(
246
+ self,
247
+ attn,
248
+ hidden_states,
249
+ encoder_hidden_states=None,
250
+ attention_mask=None,
251
  ):
252
  scale = self.scale
253
  qkvo_bias = {
 
256
  "v": lambda inputs: scale * self.to_v_lora(inputs),
257
  "o": lambda inputs: scale * self.to_out_lora(inputs),
258
  }
259
+ return super().__call__(
260
+ attn, hidden_states, encoder_hidden_states, attention_mask, qkvo_bias
261
+ )
262
 
263
 
264
  class LoRALinearLayer(nn.Module):
265
+ def __init__(self, in_features, out_features, rank=4):
266
+ super().__init__()
267
 
268
+ if rank > min(in_features, out_features):
269
+ raise ValueError(
270
+ f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}"
271
+ )
272
 
273
+ self.down = nn.Linear(in_features, rank, bias=False)
274
+ self.up = nn.Linear(rank, out_features, bias=False)
275
+ self.scale = 1.0
276
+ self.alpha = rank
277
 
278
+ nn.init.normal_(self.down.weight, std=1 / rank)
279
+ nn.init.zeros_(self.up.weight)
280
 
281
+ def forward(self, hidden_states):
282
+ orig_dtype = hidden_states.dtype
283
+ dtype = self.down.weight.dtype
284
+ rank = self.down.out_features
285
 
286
+ down_hidden_states = self.down(hidden_states.to(dtype))
287
+ up_hidden_states = self.up(down_hidden_states) * (self.alpha / rank)
288
 
289
+ return up_hidden_states.to(orig_dtype)
290
 
291
 
292
  class ModelWrapper:
 
328
  scheduler=scheduler,
329
  )
330
  self.setup_unet(self.unet)
331
+ self.prompt_parser = FrozenCLIPEmbedderWithCustomWords(
332
+ self.tokenizer, self.text_encoder
333
+ )
334
+
335
def set_clip_skip(self, n):
    """Store ``n`` as the prompt parser's ``CLIP_stop_at_last_layers`` setting."""
    parser = self.prompt_parser
    parser.CLIP_stop_at_last_layers = n
337
+
338
  def setup_unet(self, unet):
339
  unet = unet.to(self.device)
340
  model = ModelWrapper(unet, self.scheduler.alphas_cumprod)
 
347
  library = importlib.import_module("k_diffusion")
348
  sampling = getattr(library, "sampling")
349
  return getattr(sampling, scheduler_type)
350
+
351
  def encode_sketchs(self, state, scale_ratio=8, g_strength=1.0, text_ids=None):
352
  uncond, cond = text_ids[0], text_ids[1]
353
+
354
  img_state = []
355
  if state is None:
356
  return torch.FloatTensor(0)
357
+
358
  for k, v in state.items():
359
  if v["map"] is None:
360
  continue
 
365
  truncation=True,
366
  add_special_tokens=False,
367
  ).input_ids
368
+
369
  dotmap = v["map"] < 255
370
+ arr = torch.from_numpy(
371
+ dotmap.astype(float) * float(v["weight"]) * g_strength
372
+ )
373
  img_state.append((v_input, arr))
374
+
375
  if len(img_state) == 0:
376
  return torch.FloatTensor(0)
377
+
378
  w_tensors = dict()
379
  cond = cond.tolist()
380
  uncond = uncond.tolist()
 
389
  for v_as_tokens, img_where_color in img_state:
390
  is_in = 0
391
 
392
+ ret = (
393
+ F.interpolate(
394
+ img_where_color.unsqueeze(0).unsqueeze(1),
395
+ scale_factor=1 / scale_ratio,
396
+ mode="bilinear",
397
+ align_corners=True,
398
+ )
399
+ .squeeze()
400
+ .reshape(-1, 1)
401
+ .repeat(1, len(v_as_tokens))
402
+ )
403
+
404
  for idx, tok in enumerate(cond):
405
  if cond[idx : idx + len(v_as_tokens)] == v_as_tokens:
406
  is_in = 1
407
+ ret_cond_tensor[0, :, idx : idx + len(v_as_tokens)] += ret
408
+
409
  for idx, tok in enumerate(uncond):
410
  if uncond[idx : idx + len(v_as_tokens)] == v_as_tokens:
411
+ is_in = 1
412
+ ret_uncond_tensor[0, :, idx : idx + len(v_as_tokens)] += ret
413
 
414
  if not is_in == 1:
415
  print(f"tokens {v_as_tokens} not found in text")
416
+
417
  w_tensors[w_r * h_r] = torch.cat([ret_uncond_tensor, ret_cond_tensor])
418
  scale_ratio *= 2
419
 
 
485
  ):
486
  return torch.device(module._hf_hook.execution_device)
487
  return self.device
488
+
489
  def decode_latents(self, latents):
490
  latents = latents.to(self.device, dtype=self.vae.dtype)
491
  latents = 1 / 0.18215 * latents
 
586
  pww_attn_weight=1.0,
587
  sampler_name="",
588
  sampler_opt={},
589
+ scale_ratio=8.0,
590
  ):
591
  sampler = self.get_scheduler(sampler_name)
592
  if image is not None:
 
609
  # 3. Encode input prompt
610
  text_ids, text_embeddings = self.prompt_parser([negative_prompt, prompt])
611
  text_embeddings = text_embeddings.to(self.unet.dtype)
612
+
613
+ init_timestep = (
614
+ int(num_inference_steps / min(strength, 0.999)) if strength > 0 else 0
615
+ )
616
  sigmas = self.get_sigmas(init_timestep, sampler_opt).to(
617
  text_embeddings.device, dtype=text_embeddings.dtype
618
  )
 
636
  )
637
 
638
  img_state = self.encode_sketchs(
639
+ pww_state,
640
  g_strength=pww_attn_weight,
641
  text_ids=text_ids,
642
  )
643
+
644
  def model_fn(x, sigma):
645
+
646
  latent_model_input = torch.cat([x] * 2)
647
+ weight_func = lambda w, sigma, qk: w * math.log(1 + sigma) * qk.max()
 
 
648
  encoder_state = {
649
  "img_state": img_state,
650
  "states": text_embeddings,
 
797
  self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(
798
  latents.device
799
  )
800
+
801
  img_state = self.encode_sketchs(
802
+ pww_state,
803
  g_strength=pww_attn_weight,
804
  text_ids=text_ids,
805
  )
806
 
807
  def model_fn(x, sigma):
808
+
809
  latent_model_input = torch.cat([x] * 2)
810
+ weight_func = lambda w, sigma, qk: w * math.log(1 + sigma) * qk.max()
 
 
811
  encoder_state = {
812
  "img_state": img_state,
813
  "states": text_embeddings,
 
853
  sampler_name=sampler_name,
854
  sampler_opt=sampler_opt,
855
  pww_state=None,
856
+ pww_attn_weight=pww_attn_weight / 2,
857
  )
858
 
859
  # 8. Post-processing
 
867
 
868
 
869
  class FlashAttentionFunction(Function):
 
 
870
  @staticmethod
871
  @torch.no_grad()
872
  def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
873
+ """Algorithm 2 in the paper"""
874
 
875
  device = q.device
876
  max_neg_value = -torch.finfo(q.dtype).max
877
  qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
878
 
879
  o = torch.zeros_like(q)
880
+ all_row_sums = torch.zeros((*q.shape[:-1], 1), device=device)
881
+ all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, device=device)
882
 
883
+ scale = q.shape[-1] ** -0.5
884
 
885
  if not exists(mask):
886
  mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
887
  else:
888
+ mask = rearrange(mask, "b n -> b 1 1 n")
889
+ mask = mask.split(q_bucket_size, dim=-1)
890
 
891
  row_splits = zip(
892
+ q.split(q_bucket_size, dim=-2),
893
+ o.split(q_bucket_size, dim=-2),
894
  mask,
895
+ all_row_sums.split(q_bucket_size, dim=-2),
896
+ all_row_maxes.split(q_bucket_size, dim=-2),
897
  )
898
 
899
  for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
900
  q_start_index = ind * q_bucket_size - qk_len_diff
901
 
902
  col_splits = zip(
903
+ k.split(k_bucket_size, dim=-2),
904
+ v.split(k_bucket_size, dim=-2),
905
  )
906
 
907
  for k_ind, (kc, vc) in enumerate(col_splits):
908
  k_start_index = k_ind * k_bucket_size
909
 
910
+ attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale
911
 
912
  if exists(row_mask):
913
  attn_weights.masked_fill_(~row_mask, max_neg_value)
914
 
915
  if causal and q_start_index < (k_start_index + k_bucket_size - 1):
916
+ causal_mask = torch.ones(
917
+ (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
918
+ ).triu(q_start_index - k_start_index + 1)
919
  attn_weights.masked_fill_(causal_mask, max_neg_value)
920
 
921
+ block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
922
  attn_weights -= block_row_maxes
923
  exp_weights = torch.exp(attn_weights)
924
 
925
  if exists(row_mask):
926
+ exp_weights.masked_fill_(~row_mask, 0.0)
927
 
928
+ block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(
929
+ min=EPSILON
930
+ )
931
 
932
  new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
933
 
934
+ exp_values = einsum("... i j, ... j d -> ... i d", exp_weights, vc)
935
 
936
  exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
937
  exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
938
 
939
+ new_row_sums = (
940
+ exp_row_max_diff * row_sums
941
+ + exp_block_row_max_diff * block_row_sums
942
+ )
943
 
944
+ oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_(
945
+ (exp_block_row_max_diff / new_row_sums) * exp_values
946
+ )
947
 
948
  row_maxes.copy_(new_row_maxes)
949
  row_sums.copy_(new_row_sums)
 
958
  @staticmethod
959
  @torch.no_grad()
960
  def backward(ctx, do):
961
+ """Algorithm 4 in the paper"""
962
 
963
  causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
964
  q, k, v, o, lse = ctx.saved_tensors
 
973
  dv = torch.zeros_like(v)
974
 
975
  row_splits = zip(
976
+ q.split(q_bucket_size, dim=-2),
977
+ o.split(q_bucket_size, dim=-2),
978
+ do.split(q_bucket_size, dim=-2),
979
  mask,
980
+ lse.split(q_bucket_size, dim=-2),
981
+ dq.split(q_bucket_size, dim=-2),
982
  )
983
 
984
  for ind, (qc, oc, doc, row_mask, lsec, dqc) in enumerate(row_splits):
985
  q_start_index = ind * q_bucket_size - qk_len_diff
986
 
987
  col_splits = zip(
988
+ k.split(k_bucket_size, dim=-2),
989
+ v.split(k_bucket_size, dim=-2),
990
+ dk.split(k_bucket_size, dim=-2),
991
+ dv.split(k_bucket_size, dim=-2),
992
  )
993
 
994
  for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
995
  k_start_index = k_ind * k_bucket_size
996
 
997
+ attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale
998
 
999
  if causal and q_start_index < (k_start_index + k_bucket_size - 1):
1000
+ causal_mask = torch.ones(
1001
+ (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
1002
+ ).triu(q_start_index - k_start_index + 1)
1003
  attn_weights.masked_fill_(causal_mask, max_neg_value)
1004
 
1005
  p = torch.exp(attn_weights - lsec)
1006
 
1007
  if exists(row_mask):
1008
+ p.masked_fill_(~row_mask, 0.0)
1009
 
1010
+ dv_chunk = einsum("... i j, ... i d -> ... j d", p, doc)
1011
+ dp = einsum("... i d, ... j d -> ... i j", doc, vc)
1012
 
1013
+ D = (doc * oc).sum(dim=-1, keepdims=True)
1014
  ds = p * scale * (dp - D)
1015
 
1016
+ dq_chunk = einsum("... i j, ... j d -> ... i d", ds, kc)
1017
+ dk_chunk = einsum("... i j, ... i d -> ... j d", ds, qc)
1018
 
1019
  dqc.add_(dq_chunk)
1020
  dkc.add_(dk_chunk)
1021
  dvc.add_(dv_chunk)
1022
 
1023
+ return dq, dk, dv, None, None, None, None
1024
+
1025
# Alias the class rather than an instance: ``apply`` is a classmethod of
# torch.autograd.Function and forward/backward are static, so no instance is
# needed, and instantiating a Function subclass is deprecated in PyTorch.
# Callers keep using ``FlashAttn.apply(...)`` unchanged.
FlashAttn = FlashAttentionFunction