fffiloni commited on
Commit
f89e4d2
1 Parent(s): b7d7c65

Update text2vid_torch2.py

Browse files
Files changed (1) hide show
  1. text2vid_torch2.py +1 -68
text2vid_torch2.py CHANGED
@@ -168,7 +168,7 @@ class AttnProcessor2_0:
168
 
169
  return hidden_states
170
 
171
- '''
172
  def get_qk(
173
  self, query, key):
174
  r"""
@@ -222,73 +222,6 @@ class AttnProcessor2_0:
222
 
223
 
224
  return query, key, dynamic_lambda, key1
225
- '''
226
-
227
- def get_qk(self, query, key):
228
- r"""
229
- Compute the attention scores.
230
-
231
- Args:
232
- query (`torch.Tensor`): The query tensor.
233
- key (`torch.Tensor`): The key tensor.
234
- attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
235
-
236
- Returns:
237
- `torch.Tensor`: The attention probabilities/scores.
238
- """
239
- try:
240
- q_old = query.clone()
241
- k_old = key.clone()
242
- dynamic_lambda = None
243
- key1 = None
244
-
245
- if self.use_last_attn_slice:
246
- if self.last_attn_slice is not None:
247
- query_list = self.last_attn_slice[0]
248
- key_list = self.last_attn_slice[1]
249
-
250
- # Ensure that shapes are compatible before performing assignments
251
- if query.shape[1] == self.num_frames and query.shape == key.shape:
252
- key1 = key.clone()
253
-
254
- # Safety check: ensure key1 can receive the value from key_list without causing size mismatch
255
- if key1.shape[0] >= key_list.shape[0]:
256
- key1[:, :1, :key_list.shape[2]] = key_list[:, :1]
257
- else:
258
- raise RuntimeError(f"Shape mismatch: key1 has {key1.shape[0]} batches, but key_list has {key_list.shape[0]} batches.")
259
-
260
- # Dynamic lambda scaling
261
- dynamic_lambda = torch.tensor([1 + self.LAMBDA * (i / 50) for i in range(self.num_frames)]).to(key.dtype).cuda()
262
-
263
- if q_old.shape == k_old.shape and q_old.shape[1] != self.num_frames:
264
- # Ensure batch size division is valid
265
- batch_dim = query_list.shape[0] // self.bs
266
- all_dim = query.shape[0] // self.bs
267
-
268
- for i in range(self.bs):
269
- # Safety check for slicing indices to avoid memory access errors
270
- query_slice = query[i * all_dim:(i * all_dim) + batch_dim, :query_list.shape[1], :query_list.shape[2]]
271
- target_slice = query_list[i * batch_dim:(i + 1) * batch_dim]
272
-
273
- # Validate dimensions match before assignment
274
- if query_slice.shape == target_slice.shape:
275
- query_slice[:] = target_slice
276
- else:
277
- raise RuntimeError(f"Shape mismatch during slicing: query slice shape {query_slice.shape}, target slice shape {target_slice.shape}")
278
-
279
- if self.save_last_attn_slice:
280
- self.last_attn_slice = [query, key]
281
- self.save_last_attn_slice = False
282
-
283
- except RuntimeError as e:
284
- # If a RuntimeError happens, catch it and clean CUDA memory
285
- print(f"RuntimeError occurred: {e}. Cleaning up CUDA memory...")
286
- torch.cuda.empty_cache() # Free up CUDA memory to avoid further issues
287
- raise # Re-raise the error to propagate it if needed
288
-
289
- return query, key, dynamic_lambda, key1
290
-
291
-
292
 
293
 
294
  def init_attention_func(unet):
 
168
 
169
  return hidden_states
170
 
171
+
172
  def get_qk(
173
  self, query, key):
174
  r"""
 
222
 
223
 
224
  return query, key, dynamic_lambda, key1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
 
226
 
227
  def init_attention_func(unet):