Commit 1bad10f
Parent(s): 5e1c565

Correct final rotation

- lgm/lgm.py +25 -0
- pipeline.py +0 -16
lgm/lgm.py
CHANGED
@@ -285,6 +285,31 @@ class LGM(ModelMixin, ConfigMixin):
         rotation = self.rot_act(x[..., 7:11])
         rgbs = self.rgb_act(x[..., 11:])
 
+        q = torch.tensor([0, 0, 1, 0], dtype=pos.dtype, device=pos.device)
+        R = torch.tensor(
+            [
+                [-1, 0, 0],
+                [0, -1, 0],
+                [0, 0, 1],
+            ],
+            dtype=pos.dtype,
+            device=pos.device,
+        )
+
+        pos = torch.matmul(pos, R.T)
+
+        def multiply_quat(q1, q2):
+            w1, x1, y1, z1 = q1.unbind(-1)
+            w2, x2, y2, z2 = q2.unbind(-1)
+            w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
+            x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
+            y = w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2
+            z = w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2
+            return torch.stack([w, x, y, z], dim=-1)
+
+        for i in range(B):
+            rotation[i, :] = multiply_quat(q, rotation[i, :])
+
         gaussians = torch.cat([pos, opacity, scale, rotation, rgbs], dim=-1)
 
         return gaussians
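Note (not part of the commit): multiply_quat unbinds along the last dimension and relies on broadcasting, so the per-sample loop over B could likely be collapsed into a single batched call. A minimal sketch, assuming rotation has shape [..., 4] with the same component ordering the commit uses; the shapes below are hypothetical and only for illustration:

    import torch

    def multiply_quat(q1, q2):
        # Hamilton product, same formula as the helper added above;
        # broadcasts over any leading batch dimensions.
        w1, x1, y1, z1 = q1.unbind(-1)
        w2, x2, y2, z2 = q2.unbind(-1)
        w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
        x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
        y = w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2
        z = w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2
        return torch.stack([w, x, y, z], dim=-1)

    # Hypothetical shapes: q is [4], rotation is [B, N, 4].
    q = torch.tensor([0.0, 0.0, 1.0, 0.0])
    rotation = torch.nn.functional.normalize(torch.randn(2, 5, 4), dim=-1)

    # One broadcasted call applies q to every quaternion at once,
    # matching the result of the commit's per-batch loop.
    rotation = multiply_quat(q, rotation)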
pipeline.py
CHANGED
@@ -1,5 +1,4 @@
 import numpy as np
-import rembg
 import torch
 import torch.nn.functional as F
 import torchvision.transforms.functional as TF
@@ -10,8 +9,6 @@ class LGMPipeline(DiffusionPipeline):
     def __init__(self, lgm):
         super().__init__()
 
-        self.bg_remover = rembg.new_session()
-
         self.imagenet_default_mean = (0.485, 0.456, 0.406)
         self.imagenet_default_std = (0.229, 0.224, 0.225)
 
@@ -23,19 +20,6 @@ class LGMPipeline(DiffusionPipeline):
 
     @torch.no_grad()
     def __call__(self, images):
-        unstacked = []
-        for i in range(4):
-            image = rembg.remove(images[i], session=self.bg_remover)
-            image = images.astype(np.float32) / 255.0
-            image = image[..., :3] * image[..., -1:] + (1 - image[..., -1:])
-            unstacked.append(image)
-        images = np.concatenate(
-            [
-                np.concatenate([unstacked[1], unstacked[2]], axis=1),
-                np.concatenate([unstacked[3], unstacked[0]], axis=1),
-            ],
-            axis=0,
-        )
         images = np.stack([images[1], images[2], images[3], images[0]], axis=0)
         images = torch.from_numpy(images).permute(0, 3, 1, 2).float().cuda()
         images = F.interpolate(
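Note (not part of the commit): with the rembg step removed, LGMPipeline.__call__ presumably expects its four input views to arrive with backgrounds already removed and composited onto white. A minimal caller-side sketch that mirrors the deleted logic; the helper name preprocess_views is hypothetical, and where the deleted loop read images.astype, image.astype appears to have been intended, so the sketch uses the latter:

    import numpy as np
    import rembg

    def preprocess_views(views):
        # Per-view background removal plus alpha compositing onto a white
        # background, mirroring what this commit deletes from __call__.
        session = rembg.new_session()
        out = []
        for view in views:  # each view: H x W x 3 uint8 array
            rgba = rembg.remove(view, session=session)  # H x W x 4 uint8
            rgba = rgba.astype(np.float32) / 255.0
            rgb = rgba[..., :3] * rgba[..., -1:] + (1 - rgba[..., -1:])
            out.append(rgb)
        return out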