NIRVANALAN committed
Commit 89337c5 • 1 parent: a0afc03

update dep

Files changed (4)
  1. README.md +1 -2
  2. app.py +1 -1
  3. dit/dit_models_xformers.py +2 -2
  4. ldm/modules/attention.py +4 -4
README.md CHANGED
@@ -3,8 +3,7 @@ title: LN3DIff-I23D
 emoji: 🌖
 colorFrom: green
 colorTo: green
-# sdk: gradio
-sdk: docker
+sdk: gradio
 sdk_version: 4.25.0
 app_file: app.py
 pinned: false
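
For reference, the hunk above edits the Hugging Face Spaces configuration block (the YAML front matter at the top of README.md), switching the Space from the Docker runtime back to the Gradio SDK. Reassembled from the diff context, and assuming the standard `---` delimiters that Spaces front matter uses, the header after this commit reads:

```yaml
---
title: LN3DIff-I23D
emoji: 🌖
colorFrom: green
colorTo: green
sdk: gradio
sdk_version: 4.25.0
app_file: app.py
pinned: false
---
```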
app.py CHANGED
@@ -49,7 +49,7 @@ th.backends.cuda.matmul.allow_tf32 = True
 th.backends.cudnn.allow_tf32 = True
 th.backends.cudnn.enabled = True
 
-install_dependency()
+# install_dependency()
 
 from guided_diffusion import dist_util, logger
 from guided_diffusion.script_util import (
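
The `install_dependency()` call disabled above is a helper defined elsewhere in app.py; its body is not part of this diff. As a purely hypothetical sketch of the runtime pip-install pattern such a helper usually wraps (the function body and package names below are assumptions, not the repo's actual code):

```python
import subprocess
import sys

def install_dependency():
    # Hypothetical sketch only: install extra packages when the Space boots.
    # The real helper in app.py is not shown in this diff.
    for pkg in ["ninja", "xformers"]:  # assumed example packages
        subprocess.check_call([sys.executable, "-m", "pip", "install", pkg])
```

Disabling the call lines up with the README change: a Gradio-SDK Space installs its dependencies from requirements.txt at build time, so installing them again at runtime is redundant.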
dit/dit_models_xformers.py CHANGED
@@ -28,8 +28,8 @@ try:
     from apex.normalization import FusedRMSNorm as RMSNorm
     from apex.normalization import FusedLayerNorm as LayerNorm
 except:
-    from torch.nn import LayerNorm as LayerNorm
-    from .norm import RMSNorm
+    from torch.nn import LayerNorm
+    from torch.nn import RMSNorm  # requires torch2.4
 
 # from torch.nn import LayerNorm
 # from xformers import triton
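
This hunk replaces the repo-local `.norm.RMSNorm` fallback with `torch.nn.RMSNorm`, which ships with PyTorch 2.4 and later. A minimal sketch of the resulting import-fallback pattern, assuming only the documented constructor arguments shared by both classes:

```python
import torch

try:
    # Fused CUDA kernels, available when NVIDIA apex is installed.
    from apex.normalization import FusedRMSNorm as RMSNorm
except ImportError:
    # Built into PyTorch 2.4+; a functional drop-in at this call site.
    from torch.nn import RMSNorm

# Both take a normalized shape and an eps, so downstream code is unchanged.
norm = RMSNorm(512, eps=1e-6)
x = torch.randn(2, 16, 512)
y = norm(x)  # output shape matches the input: (2, 16, 512)
```

Note the sketch narrows the bare `except:` to `except ImportError:`; the committed code keeps the bare clause.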
ldm/modules/attention.py CHANGED
@@ -14,10 +14,10 @@ from ldm.modules.diffusionmodules.util import checkpoint
 import os
 _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
 from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
-# try:
-from apex.normalization import FusedRMSNorm as RMSNorm
-# except:
-#     from dit.norm import RMSNorm
+try:
+    from apex.normalization import FusedRMSNorm as RMSNorm
+except:
+    from torch.nn import RMSNorm  # requires torch2.4
 
 
 def exists(val):
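
attention.py now gains the same fallback, so `RMSNorm` resolves to apex's fused kernel when it is installed and to the PyTorch 2.4+ module otherwise. A quick sanity check, assuming the standard RMSNorm definition (x normalized by sqrt(mean(x²) + eps), then scaled by a learned weight, which is the formula both implementations are understood to compute):

```python
import torch
from torch.nn import RMSNorm  # requires PyTorch 2.4+

dim, eps = 64, 1e-6
norm = RMSNorm(dim, eps=eps)
x = torch.randn(4, dim)

# Manual RMS normalization for comparison.
manual = x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * norm.weight

assert torch.allclose(norm(x), manual, atol=1e-6)
```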