LeoXing1996 committed on
Commit
053817b
1 Parent(s): 9121982

update memory efficient attention

Browse files
animatediff/models/motion_module.py CHANGED
@@ -458,6 +458,10 @@ class CrossAttention(nn.Module):
458
  hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
459
  return hidden_states
460
 
 
 
 
 
461
  def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
462
  # TODO attention_mask
463
  query = query.contiguous()
 
458
  hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
459
  return hidden_states
460
 
461
+ def set_use_memory_efficient_attention_xformers(self, *args, **kwargs):
462
+ print('Set Xformers for MotionModule\'s Attention.')
463
+ self._use_memory_efficient_attention_xformers = True
464
+
465
  def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
466
  # TODO attention_mask
467
  query = query.contiguous()