dylanebert HF staff committed on
Commit
1bad10f
1 Parent(s): 5e1c565

Correct final rotation

Browse files
Files changed (2) hide show
  1. lgm/lgm.py +25 -0
  2. pipeline.py +0 -16
lgm/lgm.py CHANGED
@@ -285,6 +285,31 @@ class LGM(ModelMixin, ConfigMixin):
285
  rotation = self.rot_act(x[..., 7:11])
286
  rgbs = self.rgb_act(x[..., 11:])
287
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
  gaussians = torch.cat([pos, opacity, scale, rotation, rgbs], dim=-1)
289
 
290
  return gaussians
 
285
  rotation = self.rot_act(x[..., 7:11])
286
  rgbs = self.rgb_act(x[..., 11:])
287
 
288
+ q = torch.tensor([0, 0, 1, 0], dtype=pos.dtype, device=pos.device)
289
+ R = torch.tensor(
290
+ [
291
+ [-1, 0, 0],
292
+ [0, -1, 0],
293
+ [0, 0, 1],
294
+ ],
295
+ dtype=pos.dtype,
296
+ device=pos.device,
297
+ )
298
+
299
+ pos = torch.matmul(pos, R.T)
300
+
301
+ def multiply_quat(q1, q2):
302
+ w1, x1, y1, z1 = q1.unbind(-1)
303
+ w2, x2, y2, z2 = q2.unbind(-1)
304
+ w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
305
+ x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
306
+ y = w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2
307
+ z = w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2
308
+ return torch.stack([w, x, y, z], dim=-1)
309
+
310
+ for i in range(B):
311
+ rotation[i, :] = multiply_quat(q, rotation[i, :])
312
+
313
  gaussians = torch.cat([pos, opacity, scale, rotation, rgbs], dim=-1)
314
 
315
  return gaussians
pipeline.py CHANGED
@@ -1,5 +1,4 @@
1
  import numpy as np
2
- import rembg
3
  import torch
4
  import torch.nn.functional as F
5
  import torchvision.transforms.functional as TF
@@ -10,8 +9,6 @@ class LGMPipeline(DiffusionPipeline):
10
  def __init__(self, lgm):
11
  super().__init__()
12
 
13
- self.bg_remover = rembg.new_session()
14
-
15
  self.imagenet_default_mean = (0.485, 0.456, 0.406)
16
  self.imagenet_default_std = (0.229, 0.224, 0.225)
17
 
@@ -23,19 +20,6 @@ class LGMPipeline(DiffusionPipeline):
23
 
24
  @torch.no_grad()
25
  def __call__(self, images):
26
- unstacked = []
27
- for i in range(4):
28
- image = rembg.remove(images[i], session=self.bg_remover)
29
- image = images.astype(np.float32) / 255.0
30
- image = image[..., :3] * image[..., -1:] + (1 - image[..., -1:])
31
- unstacked.append(image)
32
- images = np.concatenate(
33
- [
34
- np.concatenate([unstacked[1], unstacked[2]], axis=1),
35
- np.concatenate([unstacked[3], unstacked[0]], axis=1),
36
- ],
37
- axis=0,
38
- )
39
  images = np.stack([images[1], images[2], images[3], images[0]], axis=0)
40
  images = torch.from_numpy(images).permute(0, 3, 1, 2).float().cuda()
41
  images = F.interpolate(
 
1
  import numpy as np
 
2
  import torch
3
  import torch.nn.functional as F
4
  import torchvision.transforms.functional as TF
 
9
  def __init__(self, lgm):
10
  super().__init__()
11
 
 
 
12
  self.imagenet_default_mean = (0.485, 0.456, 0.406)
13
  self.imagenet_default_std = (0.229, 0.224, 0.225)
14
 
 
20
 
21
  @torch.no_grad()
22
  def __call__(self, images):
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  images = np.stack([images[1], images[2], images[3], images[0]], axis=0)
24
  images = torch.from_numpy(images).permute(0, 3, 1, 2).float().cuda()
25
  images = F.interpolate(