# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import os
import argparse
import mcubes
import trimesh
import numpy as np
from PIL import Image
from omegaconf import OmegaConf
from tqdm.auto import tqdm
from accelerate.logging import get_logger

from .base_inferrer import Inferrer
from openlrm.datasets.cam_utils import build_camera_principle, build_camera_standard, surrounding_views_linspace, create_intrinsics
from openlrm.utils.logging import configure_logger
from openlrm.runners import REGISTRY_RUNNERS
from openlrm.utils.video import images_to_video
from openlrm.utils.hf_hub import wrap_model_hub


logger = get_logger(__name__)

def parse_configs():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str)
    parser.add_argument('--infer', type=str)
    args, unknown = parser.parse_known_args()

    cfg = OmegaConf.create()
    cli_cfg = OmegaConf.from_cli(unknown)

    # parse from ENV
    if os.environ.get('APP_INFER') is not None:
        args.infer = os.environ.get('APP_INFER')
    if os.environ.get('APP_MODEL_NAME') is not None:
        cli_cfg.model_name = os.environ.get('APP_MODEL_NAME')

    if args.config is not None:
        cfg_train = OmegaConf.load(args.config)
        cfg.source_size = cfg_train.dataset.source_image_res
        cfg.render_size = cfg_train.dataset.render_image.high
        _relative_path = os.path.join(cfg_train.experiment.parent, cfg_train.experiment.child, os.path.basename(cli_cfg.model_name).split('_')[-1])
        cfg.video_dump = os.path.join("exps", 'videos', _relative_path)
        cfg.mesh_dump = os.path.join("exps", 'meshes', _relative_path)

    if args.infer is not None:
        cfg_infer = OmegaConf.load(args.infer)
        cfg.merge_with(cfg_infer)
        cfg.setdefault('video_dump', os.path.join("dumps", cli_cfg.model_name, 'videos'))
        cfg.setdefault('mesh_dump', os.path.join("dumps", cli_cfg.model_name, 'meshes'))

    cfg.merge_with(cli_cfg)

    """
    [required]
    model_name: str
    image_input: str
    export_video: bool
    export_mesh: bool

    [special]
    source_size: int
    render_size: int
    video_dump: str
    mesh_dump: str

    [default]
    render_views: int
    render_fps: int
    mesh_size: int
    mesh_thres: float
    frame_size: int
    logger: str
    """

    cfg.setdefault('logger', 'INFO')

    # assert not (args.config is not None and args.infer is not None), "Only one of config and infer should be provided"
    assert cfg.model_name is not None, "model_name is required"
    if not os.environ.get('APP_ENABLED', None):
        assert cfg.image_input is not None, "image_input is required"
        assert cfg.export_video or cfg.export_mesh, \
            "At least one of export_video or export_mesh should be True"
        cfg.app_enabled = False
    else:
        cfg.app_enabled = True

    return cfg
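
# A minimal sketch of what an `--infer` YAML could contain; the keys mirror the docstring
# inside parse_configs(), but every value below is an illustrative assumption, not a
# project default:
#
#   source_size: 336        # [special] input resolution fed to the image encoder
#   render_size: 288        # [special] per-view render resolution
#   render_views: 160       # [default] number of orbit views in the turntable video
#   render_fps: 40          # [default] video frame rate
#   frame_size: 2           # [default] views rendered per synthesizer call (memory knob)
#   mesh_size: 384          # [default] marching-cubes grid resolution
#   mesh_thres: 3.0         # [default] density threshold for surface extraction
#   export_video: true
#   export_mesh: true
#
# Anything not set in the YAML can be passed as an OmegaConf dotlist argument on the
# command line (e.g. `model_name=... image_input=...`), which takes precedence via the
# final cfg.merge_with(cli_cfg).
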

@REGISTRY_RUNNERS.register('infer.lrm')
class LRMInferrer(Inferrer):

    EXP_TYPE: str = 'lrm'

    def __init__(self):
        super().__init__()

        self.cfg = parse_configs()
        configure_logger(
            stream_level=self.cfg.logger,
            log_level=self.cfg.logger,
        )

        self.model = self._build_model(self.cfg).to(self.device)

    def _build_model(self, cfg):
        from openlrm.models import model_dict
        hf_model_cls = wrap_model_hub(model_dict[self.EXP_TYPE])
        model = hf_model_cls.from_pretrained(cfg.model_name)
        return model
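
    # Note: `cfg.model_name` can be either a local checkpoint directory or a Hugging Face
    # Hub repo id; `from_pretrained()` resolves both. A repo id such as
    # "zxhezexin/openlrm-mix-base-1.1" is mentioned here only as an assumed example.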

    def _default_source_camera(self, dist_to_center: float = 2.0, batch_size: int = 1, device: torch.device = torch.device('cpu')):
        # return: (N, D_cam_raw)
        canonical_camera_extrinsics = torch.tensor([[
            [1, 0, 0, 0],
            [0, 0, -1, -dist_to_center],
            [0, 1, 0, 0],
        ]], dtype=torch.float32, device=device)
        canonical_camera_intrinsics = create_intrinsics(
            f=0.75,
            c=0.5,
            device=device,
        ).unsqueeze(0)
        source_camera = build_camera_principle(canonical_camera_extrinsics, canonical_camera_intrinsics)
        return source_camera.repeat(batch_size, 1)

    def _default_render_cameras(self, n_views: int, batch_size: int = 1, device: torch.device = torch.device('cpu')):
        # return: (N, M, D_cam_render)
        render_camera_extrinsics = surrounding_views_linspace(n_views=n_views, device=device)
        render_camera_intrinsics = create_intrinsics(
            f=0.75,
            c=0.5,
            device=device,
        ).unsqueeze(0).repeat(render_camera_extrinsics.shape[0], 1, 1)
        render_cameras = build_camera_standard(render_camera_extrinsics, render_camera_intrinsics)
        return render_cameras.unsqueeze(0).repeat(batch_size, 1, 1)
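
    # Shape sketch for the two camera helpers above (assuming the conventions of
    # openlrm.datasets.cam_utils):
    #   _default_source_camera  -> (batch_size, D_cam_raw): one flattened
    #                              extrinsics+intrinsics vector for the conditioning view.
    #   _default_render_cameras -> (batch_size, n_views, D_cam_render): one camera per
    #                              frame along the surrounding orbit, e.g. (1, 160, D_cam_render)
    #                              for batch_size=1 and n_views=160.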

    def infer_planes(self, image: torch.Tensor, source_cam_dist: float):
        N = image.shape[0]
        source_camera = self._default_source_camera(dist_to_center=source_cam_dist, batch_size=N, device=self.device)
        planes = self.model.forward_planes(image, source_camera)
        assert N == planes.shape[0]
        return planes

    def infer_video(self, planes: torch.Tensor, frame_size: int, render_size: int, render_views: int, render_fps: int, dump_video_path: str):
        N = planes.shape[0]
        render_cameras = self._default_render_cameras(n_views=render_views, batch_size=N, device=self.device)
        render_anchors = torch.zeros(N, render_cameras.shape[1], 2, device=self.device)
        render_resolutions = torch.ones(N, render_cameras.shape[1], 1, device=self.device) * render_size
        render_bg_colors = torch.ones(N, render_cameras.shape[1], 1, device=self.device, dtype=torch.float32) * 1.

        frames = []
        for i in range(0, render_cameras.shape[1], frame_size):
            frames.append(
                self.model.synthesizer(
                    planes=planes,
                    cameras=render_cameras[:, i:i+frame_size],
                    anchors=render_anchors[:, i:i+frame_size],
                    resolutions=render_resolutions[:, i:i+frame_size],
                    bg_colors=render_bg_colors[:, i:i+frame_size],
                    region_size=render_size,
                )
            )

        # merge frames
        frames = {
            k: torch.cat([r[k] for r in frames], dim=1)
            for k in frames[0].keys()
        }

        # dump
        os.makedirs(os.path.dirname(dump_video_path), exist_ok=True)
        for k, v in frames.items():
            if k == 'images_rgb':
                images_to_video(
                    images=v[0],
                    output_path=dump_video_path,
                    fps=render_fps,
                    gradio_codec=self.cfg.app_enabled,
                )
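
    # Rendering above is chunked along the view axis: each synthesizer call handles
    # `frame_size` cameras, so peak activation memory scales with frame_size rather than
    # with render_views. Lowering frame_size trades throughput for memory headroom
    # (a qualitative note, not a measured benchmark).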

    def infer_mesh(self, planes: torch.Tensor, mesh_size: int, mesh_thres: float, dump_mesh_path: str):
        grid_out = self.model.synthesizer.forward_grid(
            planes=planes,
            grid_size=mesh_size,
        )

        vtx, faces = mcubes.marching_cubes(grid_out['sigma'].squeeze(0).squeeze(-1).cpu().numpy(), mesh_thres)
        vtx = vtx / (mesh_size - 1) * 2 - 1

        vtx_tensor = torch.tensor(vtx, dtype=torch.float32, device=self.device).unsqueeze(0)
        vtx_colors = self.model.synthesizer.forward_points(planes, vtx_tensor)['rgb'].squeeze(0).cpu().numpy()  # (0, 1)
        vtx_colors = (vtx_colors * 255).astype(np.uint8)

        mesh = trimesh.Trimesh(vertices=vtx, faces=faces, vertex_colors=vtx_colors)

        # dump
        os.makedirs(os.path.dirname(dump_mesh_path), exist_ok=True)
        mesh.export(dump_mesh_path)
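
    # Marching cubes returns vertices in grid-index space [0, mesh_size - 1];
    # `vtx / (mesh_size - 1) * 2 - 1` remaps them into the synthesizer's [-1, 1] cube.
    # For example, with mesh_size=384: index 0 -> -1.0, index 191.5 -> 0.0, index 383 -> 1.0.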

    def infer_single(self, image_path: str, source_cam_dist: float, export_video: bool, export_mesh: bool, dump_video_path: str, dump_mesh_path: str):
        source_size = self.cfg.source_size
        render_size = self.cfg.render_size
        render_views = self.cfg.render_views
        render_fps = self.cfg.render_fps
        mesh_size = self.cfg.mesh_size
        mesh_thres = self.cfg.mesh_thres
        frame_size = self.cfg.frame_size
        source_cam_dist = self.cfg.source_cam_dist if source_cam_dist is None else source_cam_dist

        # prepare image: [1, C_img, H_img, W_img], 0-1 scale
        image = torch.from_numpy(np.array(Image.open(image_path))).to(self.device)
        image = image.permute(2, 0, 1).unsqueeze(0) / 255.0
        if image.shape[1] == 4:  # RGBA
            image = image[:, :3, ...] * image[:, 3:, ...] + (1 - image[:, 3:, ...])
        image = torch.nn.functional.interpolate(image, size=(source_size, source_size), mode='bicubic', align_corners=True)
        image = torch.clamp(image, 0, 1)

        with torch.no_grad():
            planes = self.infer_planes(image, source_cam_dist=source_cam_dist)

            results = {}
            if export_video:
                frames = self.infer_video(planes, frame_size=frame_size, render_size=render_size, render_views=render_views, render_fps=render_fps, dump_video_path=dump_video_path)
                results.update({
                    'frames': frames,
                })
            if export_mesh:
                mesh = self.infer_mesh(planes, mesh_size=mesh_size, mesh_thres=mesh_thres, dump_mesh_path=dump_mesh_path)
                results.update({
                    'mesh': mesh,
                })

    def infer(self):
        image_paths = []
        if os.path.isfile(self.cfg.image_input):
            omit_prefix = os.path.dirname(self.cfg.image_input)
            image_paths.append(self.cfg.image_input)
        else:
            omit_prefix = self.cfg.image_input
            for root, dirs, files in os.walk(self.cfg.image_input):
                for file in files:
                    if file.endswith('.png'):
                        image_paths.append(os.path.join(root, file))
            image_paths.sort()

        # alloc to each DDP worker
        image_paths = image_paths[self.accelerator.process_index::self.accelerator.num_processes]

        for image_path in tqdm(image_paths, disable=not self.accelerator.is_local_main_process):

            # prepare dump paths
            image_name = os.path.basename(image_path)
            uid = image_name.split('.')[0]
            subdir_path = os.path.dirname(image_path).replace(omit_prefix, '')
            subdir_path = subdir_path[1:] if subdir_path.startswith('/') else subdir_path
            dump_video_path = os.path.join(
                self.cfg.video_dump,
                subdir_path,
                f'{uid}.mov',
            )
            dump_mesh_path = os.path.join(
                self.cfg.mesh_dump,
                subdir_path,
                f'{uid}.ply',
            )

            self.infer_single(
                image_path,
                source_cam_dist=None,
                export_video=self.cfg.export_video,
                export_mesh=self.cfg.export_mesh,
                dump_video_path=dump_video_path,
                dump_mesh_path=dump_mesh_path,
            )
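
# A sketch of how this runner is typically launched; the module entry point and the file
# paths below are assumptions for illustration (check the project README for the exact
# command):
#
#   python -m openlrm.launch infer.lrm --infer <infer_config.yaml> \
#       model_name=<model_id_or_path> image_input=<image_or_directory> \
#       export_video=true export_mesh=true
#
# Under Accelerate/DDP, `infer()` shards the input images across processes via
# `image_paths[self.accelerator.process_index::self.accelerator.num_processes]`.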