Spaces:
Runtime error
Runtime error
jiaweir
commited on
Commit
•
7c6a4e6
1
Parent(s):
b73b3dd
optimize
Browse files- lgm/core/models.py +6 -5
- main_4d_demo.py +1 -1
lgm/core/models.py
CHANGED
@@ -128,14 +128,16 @@ class LGM(nn.Module):
|
|
128 |
|
129 |
x_orig_res = x.clone()
|
130 |
|
131 |
-
|
132 |
-
|
|
|
|
|
133 |
|
134 |
x = x.permute(0, 1, 3, 4, 2).reshape(B, -1, 14)
|
135 |
|
136 |
pos = self.pos_act(x[..., 0:3]) # [B, N, 3]
|
137 |
opacity = self.opacity_act(x[..., 3:4])
|
138 |
-
scale = self.scale_act(x[..., 4:7]) *
|
139 |
rotation = self.rot_act(x[..., 7:11])
|
140 |
rgbs = self.rgb_act(x[..., 11:])
|
141 |
|
@@ -155,8 +157,7 @@ class LGM(nn.Module):
|
|
155 |
gaussians_orig_res = torch.cat([pos, opacity, scale, rotation, rgbs], dim=-1) # [B, N, 14]
|
156 |
|
157 |
|
158 |
-
|
159 |
-
return gaussians_orig_res, gaussians_orig_res
|
160 |
|
161 |
|
162 |
def forward(self, data, step_ratio=1):
|
|
|
128 |
|
129 |
x_orig_res = x.clone()
|
130 |
|
131 |
+
dowsample_rate = 2
|
132 |
+
|
133 |
+
x = F.interpolate(x, (self.opt.splat_size // dowsample_rate, self.opt.splat_size//dowsample_rate), mode='nearest')
|
134 |
+
x = x.reshape(B, 4, 14, self.opt.splat_size//dowsample_rate, self.opt.splat_size//dowsample_rate)
|
135 |
|
136 |
x = x.permute(0, 1, 3, 4, 2).reshape(B, -1, 14)
|
137 |
|
138 |
pos = self.pos_act(x[..., 0:3]) # [B, N, 3]
|
139 |
opacity = self.opacity_act(x[..., 3:4])
|
140 |
+
scale = self.scale_act(x[..., 4:7]) * dowsample_rate
|
141 |
rotation = self.rot_act(x[..., 7:11])
|
142 |
rgbs = self.rgb_act(x[..., 11:])
|
143 |
|
|
|
157 |
gaussians_orig_res = torch.cat([pos, opacity, scale, rotation, rgbs], dim=-1) # [B, N, 14]
|
158 |
|
159 |
|
160 |
+
return gaussians, gaussians_orig_res
|
|
|
161 |
|
162 |
|
163 |
def forward(self, data, step_ratio=1):
|
main_4d_demo.py
CHANGED
@@ -540,7 +540,7 @@ class GUI:
|
|
540 |
|
541 |
# render eval
|
542 |
image_list =[]
|
543 |
-
fps =
|
544 |
delta_time = 1 / 30
|
545 |
self.renderer.prepare_render_4x()
|
546 |
time = 0
|
|
|
540 |
|
541 |
# render eval
|
542 |
image_list =[]
|
543 |
+
fps = 28
|
544 |
delta_time = 1 / 30
|
545 |
self.renderer.prepare_render_4x()
|
546 |
time = 0
|