Spaces:
Runtime error
Runtime error
LeoXing1996
commited on
Commit
•
053817b
1
Parent(s):
9121982
update memory efficient attention
Browse files
animatediff/models/motion_module.py
CHANGED
@@ -458,6 +458,10 @@ class CrossAttention(nn.Module):
|
|
458 |
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
459 |
return hidden_states
|
460 |
|
|
|
|
|
|
|
|
|
461 |
def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
|
462 |
# TODO attention_mask
|
463 |
query = query.contiguous()
|
|
|
458 |
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
459 |
return hidden_states
|
460 |
|
461 |
+
def set_use_memory_efficient_attention_xformers(self, *args, **kwargs):
|
462 |
+
print('Set Xformers for MotionModule\'s Attention.')
|
463 |
+
self._use_memory_efficient_attention_xformers = True
|
464 |
+
|
465 |
def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
|
466 |
# TODO attention_mask
|
467 |
query = query.contiguous()
|