LeoXing1996 committed on
Commit
9121982
1 Parent(s): 00a4e2b

update memory efficient attention

Browse files
animatediff/models/motion_module.py CHANGED
@@ -467,6 +467,14 @@ class CrossAttention(nn.Module):
467
  hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
468
  return hidden_states
469
 
 
 
 
 
 
 
 
 
470
 
471
  class VersatileAttention(CrossAttention):
472
  def __init__(
@@ -532,7 +540,12 @@ class VersatileAttention(CrossAttention):
532
  attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
533
 
534
  # attention, what we cannot get enough of
535
- if self._use_memory_efficient_attention_xformers:
 
 
 
 
 
536
  hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
537
  # Some versions of xformers return output in fp32, cast it back to the dtype of the input
538
  hidden_states = hidden_states.to(query.dtype)
 
467
  hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
468
  return hidden_states
469
 
470
+ def _memory_efficient_attention_pt20(self, query, key, value, attention_mask):
471
+ query = query.contiguous()
472
+ key = key.contiguous()
473
+ value = value.contiguous()
474
+ hidden_states = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0, is_causal=False)
475
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
476
+ return hidden_states
477
+
478
 
479
  class VersatileAttention(CrossAttention):
480
  def __init__(
 
540
  attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
541
 
542
  # attention, what we cannot get enough of
543
+ if hasattr(F, 'scaled_dot_product_attention'):
544
+ # NOTE: pt20's scaled_dot_product_attention seems more memory efficient than
545
+ # xformers' memory_efficient_attention, set it as the first class citizen
546
+ hidden_states = self._memory_efficient_attention_pt20(query, key, value, attention_mask)
547
+ hidden_states = hidden_states.to(query.dtype)
548
+ elif self._use_memory_efficient_attention_xformers:
549
  hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
550
  # Some versions of xformers return output in fp32, cast it back to the dtype of the input
551
  hidden_states = hidden_states.to(query.dtype)
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  opencv-python
 
2
  torchvision==0.14.1
3
  diffusers==0.24.0
4
  transformers==4.25.1
 
1
  opencv-python
2
+ torch>=2.0.0
3
  torchvision==0.14.1
4
  diffusers==0.24.0
5
  transformers==4.25.1