# vfusion3d / processor.py
# jadechoghari's picture
# Update processor.py
# 1ce37cf verified
# raw
# history blame
# 3.36 kB
import subprocess
import importlib
import sys
import logging
from transformers import ProcessorMixin
import torch
import numpy as np
from rembg import remove, new_session
from functools import partial
from torchvision.utils import save_image
from PIL import Image
from kiui.op import recenter
import kiui
# NOTE: The package-installation step has been disabled: it is not necessary
# for this fix and might cause issues if it is run repeatedly.
logger = logging.getLogger(__name__)


class LRMImageProcessor(ProcessorMixin):
    """Prepare a single image, plus a default source camera, for the model.

    Calling an instance runs background removal, recentering, optional alpha
    compositing and resizing on the input image, and pairs the result with a
    fixed camera vector built from hard-coded extrinsics and intrinsics.
    """

    def __init__(self, source_size=512, *args, **kwargs):
        """Set up the processor and the shared rembg session.

        Args:
            source_size: Side length of the square image produced by
                ``preprocess_image``.
            *args: Forwarded unchanged to ``ProcessorMixin.__init__``.
            **kwargs: Forwarded unchanged to ``ProcessorMixin.__init__``.
        """
        logger.debug("add super")
        super().__init__(*args, **kwargs)
        logger.debug("Add source")
        self.source_size = source_size
        # Create the rembg session once and bind it, so every call to
        # ``self.rembg_remove`` reuses the same session object.
        self.session = new_session("isnet-general-use")
        self.rembg_remove = partial(remove, session=self.session)

    def preprocess_image(self, image):
        """Convert an input image into a normalized, square image tensor.

        Args:
            image: Any object ``np.array`` can convert into an image array
                (presumably a PIL image or an H x W x C uint8 array; confirm
                against callers).

        Returns:
            A tensor of shape ``(1, C, source_size, source_size)`` with values
            clamped to ``[0, 1]``. When the recentered image has four
            channels, it is composited down to three.
        """
        logger.debug("preprocess_image")
        pixels = np.array(image)
        # Background removal runs twice on purpose: the first pass produces
        # the cut-out image, the second extracts a mask from that cut-out.
        pixels = self.rembg_remove(pixels)
        mask = self.rembg_remove(pixels, only_mask=True)
        pixels = recenter(pixels, mask, border_ratio=0.20)

        # H x W x C in 0..255  ->  1 x C x H x W in 0..1
        tensor = torch.tensor(pixels).permute(2, 0, 1).unsqueeze(0) / 255.0
        if tensor.shape[1] == 4:
            # Blend the colour channels over a white (value 1) background
            # using the fourth channel as alpha.
            rgb = tensor[:, :3, ...]
            alpha = tensor[:, 3:, ...]
            tensor = rgb * alpha + (1 - alpha)

        tensor = torch.nn.functional.interpolate(
            tensor,
            size=(self.source_size, self.source_size),
            mode='bicubic',
            align_corners=True,
        )
        # Bicubic interpolation can overshoot, so restrict back to [0, 1].
        return torch.clamp(tensor, 0, 1)

    def get_normalized_camera_intrinsics(self, intrinsics: torch.Tensor):
        """Divide focal lengths and principal point by the image size.

        Args:
            intrinsics: Tensor indexed as ``(N, 3, 2)`` whose rows are
                ``[fx, fy]``, ``[cx, cy]`` and ``[width, height]``.

        Returns:
            A tuple ``(fx, fy, cx, cy)`` of tensors of shape ``(N,)``, where
            the x components are divided by width and the y components by
            height.
        """
        logger.debug("get_normalized_camera_intrinsics")
        focal = intrinsics[:, 0]
        principal = intrinsics[:, 1]
        size = intrinsics[:, 2]
        width, height = size[:, 0], size[:, 1]
        return (
            focal[:, 0] / width,
            focal[:, 1] / height,
            principal[:, 0] / width,
            principal[:, 1] / height,
        )

    def build_camera_principle(self, RT: torch.Tensor, intrinsics: torch.Tensor):
        """Pack extrinsics and normalized intrinsics into one vector per camera.

        Args:
            RT: Extrinsics tensor that is flattened to 12 values per camera
                (``(N, 3, 4)`` in this class's own usage).
            intrinsics: Tensor in the layout expected by
                ``get_normalized_camera_intrinsics``.

        Returns:
            A tensor of shape ``(N, 16)``: the 12 flattened extrinsic values
            followed by normalized ``fx, fy, cx, cy``.
        """
        logger.debug("build_camera_principle")
        normalized = self.get_normalized_camera_intrinsics(intrinsics)
        columns = [RT.reshape(-1, 12)]
        columns.extend(value.unsqueeze(-1) for value in normalized)
        return torch.cat(columns, dim=-1)

    def _default_intrinsics(self):
        """Return the fixed default intrinsics as a ``(3, 2)`` float32 tensor.

        Rows are ``[fx, fy]``, ``[cx, cy]`` and ``[width, height]``.
        """
        logger.debug("_default_intrinsics")
        focal = 384
        center = 256
        side = 512
        return torch.tensor(
            [
                [focal, focal],
                [center, center],
                [side, side],
            ],
            dtype=torch.float32,
        )

    def _default_source_camera(self, batch_size: int = 1):
        """Build the default source-camera vector, repeated per batch item.

        Args:
            batch_size: Number of identical rows to return.

        Returns:
            A tensor of shape ``(batch_size, 16)`` built from the fixed
            extrinsics below and the default intrinsics.
        """
        logger.debug("_default_source_camera")
        # Fixed 3 x 4 extrinsics with a leading batch dimension of 1.
        canonical_camera_extrinsics = torch.tensor(
            [[
                [0, 0, 1, 1],
                [1, 0, 0, 0],
                [0, 1, 0, 0],
            ]],
            dtype=torch.float32,
        )
        canonical_camera_intrinsics = self._default_intrinsics().unsqueeze(0)
        source_camera = self.build_camera_principle(
            canonical_camera_extrinsics,
            canonical_camera_intrinsics,
        )
        return source_camera.repeat(batch_size, 1)

    def __call__(self, image, *args, **kwargs):
        """Preprocess ``image`` and pair it with the default source camera.

        Args:
            image: Input image, passed straight to ``preprocess_image``.
            *args: Accepted but not used.
            **kwargs: Accepted but not used.

        Returns:
            A tuple ``(processed_image, source_camera)``; the camera is built
            with a batch size of 1.
        """
        logger.debug("processing...")
        processed_image = self.preprocess_image(image)
        logger.debug("processed done")
        source_camera = self._default_source_camera(batch_size=1)
        return processed_image, source_camera