Spaces:
Runtime error
Runtime error
LeoXing1996
commited on
Commit
•
9121982
1
Parent(s):
00a4e2b
update memory efficient attention
Browse files- animatediff/models/motion_module.py +14 -1
- requirements.txt +1 -0
animatediff/models/motion_module.py
CHANGED
@@ -467,6 +467,14 @@ class CrossAttention(nn.Module):
|
|
467 |
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
468 |
return hidden_states
|
469 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
470 |
|
471 |
class VersatileAttention(CrossAttention):
|
472 |
def __init__(
|
@@ -532,7 +540,12 @@ class VersatileAttention(CrossAttention):
|
|
532 |
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
|
533 |
|
534 |
# attention, what we cannot get enough of
|
535 |
-
if
|
|
|
|
|
|
|
|
|
|
|
536 |
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
|
537 |
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
|
538 |
hidden_states = hidden_states.to(query.dtype)
|
|
|
467 |
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
468 |
return hidden_states
|
469 |
|
470 |
+
def _memory_efficient_attention_pt20(self, query, key, value, attention_mask):
|
471 |
+
query = query.contiguous()
|
472 |
+
key = key.contiguous()
|
473 |
+
value = value.contiguous()
|
474 |
+
hidden_states = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0, is_causal=False)
|
475 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
476 |
+
return hidden_states
|
477 |
+
|
478 |
|
479 |
class VersatileAttention(CrossAttention):
|
480 |
def __init__(
|
|
|
540 |
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
|
541 |
|
542 |
# attention, what we cannot get enough of
|
543 |
+
if hasattr(F, 'scaled_dot_product_attention'):
|
544 |
+
# NOTE: pt20's scaled_dot_product_attention seems more memory efficient than
|
545 |
+
# xformers' memory_efficient_attention, set it as the first class citizen
|
546 |
+
hidden_states = self._memory_efficient_attention_pt20(query, key, value, attention_mask)
|
547 |
+
hidden_states = hidden_states.to(query.dtype)
|
548 |
+
elif self._use_memory_efficient_attention_xformers:
|
549 |
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
|
550 |
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
|
551 |
hidden_states = hidden_states.to(query.dtype)
|
requirements.txt
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
opencv-python
|
|
|
2 |
torchvision==0.14.1
|
3 |
diffusers==0.24.0
|
4 |
transformers==4.25.1
|
|
|
1 |
opencv-python
|
2 |
+
torch>=2.0.0
|
3 |
torchvision==0.14.1
|
4 |
diffusers==0.24.0
|
5 |
transformers==4.25.1
|