File size: 1,251 Bytes
5e1c565
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from diffusers import DiffusionPipeline


class LGMPipeline(DiffusionPipeline):
    """Diffusers pipeline wrapper that feeds four input views to an LGM model.

    The wrapped model is registered under the module name ``lgm``. At
    construction time it is cast to half precision and moved to CUDA.
    """

    def __init__(self, lgm):
        """Register ``lgm`` as the pipeline's module.

        Args:
            lgm: The model to wrap. It must support ``.half()`` and
                ``.cuda()``, be callable, and expose
                ``prepare_default_rays`` and ``gs.save_ply``, all of which
                this pipeline uses.
        """
        super().__init__()

        # Per-channel mean/std used to normalize the input views in __call__.
        self.imagenet_default_mean = (0.485, 0.456, 0.406)
        self.imagenet_default_std = (0.229, 0.224, 0.225)

        lgm = lgm.half().cuda()
        self.register_modules(lgm=lgm)

    def _model_device_and_dtype(self):
        """Return ``(device, dtype)`` of the wrapped model's first parameter.

        Falls back to ``cuda`` / ``float16`` (the placement applied in
        ``__init__``) when the model exposes no parameters.
        """
        parameters = getattr(self.lgm, "parameters", None)
        first_param = next(iter(parameters()), None) if callable(parameters) else None
        if first_param is None:
            return torch.device("cuda"), torch.float16
        return first_param.device, first_param.dtype

    def save_ply(self, gaussians, path):
        """Write ``gaussians`` to ``path`` by delegating to ``lgm.gs.save_ply``."""
        self.lgm.gs.save_ply(gaussians, path)

    @torch.no_grad()
    def __call__(self, images):
        """Preprocess four views and run the wrapped model on them.

        Args:
            images: Indexable collection holding at least four NumPy arrays
                of identical shape; only indices 0-3 are read. Each array is
                treated as channels-last (H, W, C).
                NOTE(review): presumably 3-channel floats in [0, 1], given
                the 3-element mean/std — confirm against the caller.

        Returns:
            Whatever ``self.lgm`` returns for the prepared batch.
        """
        # Follow the model's current placement instead of assuming "cuda" /
        # float16, so the pipeline keeps working if the model is moved or
        # re-cast after construction.
        device, dtype = self._model_device_and_dtype()

        # Views are stacked in the order 1, 2, 3, 0 along a new leading axis.
        views = np.stack([images[1], images[2], images[3], images[0]], axis=0)
        # (V, H, W, C) -> (V, C, H, W); preprocessing is done in float32.
        views = torch.from_numpy(views).permute(0, 3, 1, 2).float().to(device)
        views = F.interpolate(
            views,
            size=(256, 256),
            mode="bilinear",
            align_corners=False,
        )
        views = TF.normalize(
            views, self.imagenet_default_mean, self.imagenet_default_std
        )

        # The helper is handed a device string (e.g. "cuda:0").
        rays_embeddings = self.lgm.prepare_default_rays(str(device), elevation=0)
        # Append the ray embeddings as extra channels, then add a batch axis.
        batch = torch.cat([views, rays_embeddings], dim=1).unsqueeze(0)
        batch = batch.to(device=device, dtype=dtype)

        return self.lgm(batch)