import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from diffusers import DiffusionPipeline


class LGMPipeline(DiffusionPipeline):
    """Pipeline that preprocesses four input views and feeds them to ``lgm``.

    The wrapped model is converted to half precision and moved to CUDA at
    construction time, and ``__call__`` places its inputs on CUDA as well, so
    a CUDA device is required.
    """

    def __init__(self, lgm):
        """Register ``lgm`` as a pipeline module.

        Args:
            lgm: Model to wrap. It is converted with ``.half().cuda()`` before
                being passed to ``register_modules``.
        """
        super().__init__()

        # Per-channel mean/std handed to ``TF.normalize`` in ``__call__``.
        self.imagenet_default_mean = (0.485, 0.456, 0.406)
        self.imagenet_default_std = (0.229, 0.224, 0.225)

        self.register_modules(lgm=lgm.half().cuda())

    def save_ply(self, gaussians, path):
        """Save ``gaussians`` to ``path`` via ``self.lgm.gs.save_ply``.

        Args:
            gaussians: Object forwarded unchanged to ``self.lgm.gs.save_ply``.
            path: Destination forwarded unchanged to ``self.lgm.gs.save_ply``.
        """
        self.lgm.gs.save_ply(gaussians, path)

    @torch.no_grad()
    def __call__(self, images):
        """Preprocess four views and run the wrapped model on them.

        Args:
            images: Indexable collection whose entries 0..3 are NumPy arrays of
                identical shape in channels-last layout ``(H, W, C)``; entries
                past index 3 are ignored. Presumably RGB with values in
                ``[0, 1]``, given the ImageNet statistics -- TODO confirm.

        Returns:
            The value returned by ``self.lgm`` for the prepared half-precision
            batch.

        Raises:
            IndexError: If ``images`` has fewer than four entries (for
                sequence inputs such as lists or arrays).
        """
        # Views are stacked in the order 1, 2, 3, 0 (the first view goes last).
        views = np.stack([images[i] for i in (1, 2, 3, 0)], axis=0)

        # (V, H, W, C) -> (V, C, H, W), as float32 on the GPU.
        views = torch.from_numpy(views).permute(0, 3, 1, 2).float().cuda()
        views = F.interpolate(
            views,
            size=(256, 256),
            mode="bilinear",
            align_corners=False,
        )
        views = TF.normalize(
            views, self.imagenet_default_mean, self.imagenet_default_std
        )

        rays_embeddings = self.lgm.prepare_default_rays("cuda", elevation=0)

        # Append the ray embeddings along the channel axis, then add a leading
        # batch axis. ``views`` is already on CUDA and ``torch.cat`` requires
        # both operands on one device, so only the dtype cast is needed here.
        batch = torch.cat([views, rays_embeddings], dim=1).unsqueeze(0).half()

        return self.lgm(batch)