NIRVANALAN committed
Commit: 615f2be
Parent(s): cb12c31
app.py CHANGED
@@ -175,8 +175,8 @@ def main(args):
     # denoise_model.load_state_dict(
     #     dist_util.load_state_dict(args.ddpm_model_path, map_location="cpu"))
     denoise_model.to(dist_util.dev())
-    denoise_model = denoise_model.to(th.bfloat16)
-    auto_encoder = auto_encoder.to(th.bfloat16)
+    # denoise_model = denoise_model.to(th.bfloat16)
+    # auto_encoder = auto_encoder.to(th.bfloat16)
     # if args.use_fp16:
     #     denoise_model.convert_to_fp16()
     denoise_model.eval()
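With these casts commented out (and use_amp turned off in the config below), both models now stay in float32 for inference. Purely as a hedged sketch, assuming one later wants bfloat16 speed without casting the stored weights the way the removed lines did, activations could instead be downcast on the fly with autocast; run_sampling is a hypothetical placeholder for whatever sampling call app.py actually makes:

import torch as th

# Sketch only (not part of this commit): keep weights in fp32 and let
# autocast run the forward passes in bfloat16. `run_sampling` stands in
# for the real sampling entry point in app.py.
def bf16_inference(denoise_model, auto_encoder, run_sampling, *args, **kwargs):
    denoise_model.eval()
    auto_encoder.eval()
    with th.no_grad(), th.autocast(device_type="cuda", dtype=th.bfloat16):
        return run_sampling(denoise_model, auto_encoder, *args, **kwargs)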
configs/i23d_args.json CHANGED
@@ -3,7 +3,7 @@
     "triplane_scaling_divider": 0.96806,
     "diffusion_input_size": 32,
     "trainer_name": "flow_matching",
-    "use_amp": true,
+    "use_amp": false,
     "clip_denoised": false,
     "num_samples": 1,
     "num_instances": 10,
dit/dit_i23d.py CHANGED
@@ -16,7 +16,7 @@ try:
     from apex.normalization import FusedRMSNorm as RMSNorm
 except:
     from torch.nn import LayerNorm
-    from diffusers.models.normalization import RMSNorm
+    from dit.norm import RMSNorm
 
 # from vit.vit_triplane import XYZPosEmbed
 
dit/dit_models_xformers.py CHANGED
@@ -29,7 +29,7 @@ try:
     from apex.normalization import FusedLayerNorm as LayerNorm
 except:
     from torch.nn import LayerNorm
-    from diffusers.models.normalization import RMSNorm
+    from dit.norm import RMSNorm
 
 # from torch.nn import LayerNorm
 # from xformers import triton
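Both files above (and the two below) switch the same fallback import from diffusers to the local module added in this commit. Spelled out, the pattern is roughly as follows (a sketch, not verbatim file contents):

# Prefer apex's fused kernels when installed; otherwise fall back to
# torch's LayerNorm and the pure-PyTorch RMSNorm from dit/norm.py.
try:
    from apex.normalization import FusedLayerNorm as LayerNorm
    from apex.normalization import FusedRMSNorm as RMSNorm
except ImportError:
    from torch.nn import LayerNorm
    from dit.norm import RMSNorm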
dit/norm.py ADDED
@@ -0,0 +1,40 @@
+import numbers
+from typing import Dict, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from pdb import set_trace as st
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, dim, eps: float, elementwise_affine: bool = True):
+        super().__init__()
+
+        self.eps = eps
+
+        if isinstance(dim, numbers.Integral):
+            dim = (dim,)
+
+        self.dim = torch.Size(dim)
+
+        if elementwise_affine:
+            self.weight = nn.Parameter(torch.ones(dim))
+        else:
+            self.weight = None
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+
+        if self.weight is not None:
+            # convert into half-precision if necessary
+            if self.weight.dtype in [torch.float16, torch.bfloat16]:
+                hidden_states = hidden_states.to(self.weight.dtype)
+            hidden_states = hidden_states * self.weight
+        else:
+            hidden_states = hidden_states.to(input_dtype)
+
+        return hidden_states.to(input_dtype)
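A small usage sketch of the class added above (shapes and eps are illustrative, not values from the repository):

import torch
from dit.norm import RMSNorm

# Illustrative usage: RMSNorm keeps the input's shape and dtype,
# computing the variance in float32 internally for stability.
norm = RMSNorm(dim=768, eps=1e-6, elementwise_affine=True)
x = torch.randn(2, 196, 768, dtype=torch.bfloat16)
y = norm(x)
assert y.shape == x.shape and y.dtype == x.dtype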
ldm/modules/attention.py CHANGED
@@ -18,7 +18,7 @@ try:
     from apex.normalization import FusedRMSNorm as RMSNorm
 except:
     # from dit.norm import RMSNorm
-    from diffusers.models.normalization import RMSNorm
+    from dit.norm import RMSNorm
 
 
 def exists(val):
vit/vision_transformer.py CHANGED
@@ -38,7 +38,7 @@ try:
     from apex.normalization import FusedRMSNorm as RMSNorm
 except:
     # from dit.norm import RMSNorm
-    from diffusers.models.normalization import RMSNorm
+    from dit.norm import RMSNorm
 
 # from apex.normalization import FusedLayerNorm as LayerNorm
 