Spaces:
Running
on
L40S
Running
on
L40S
import os | |
from typing import List, Optional | |
from PIL import Image | |
import imageio | |
import time | |
import torch | |
from pytorch3d.io import load_objs_as_meshes, load_obj, save_obj | |
from pytorch3d.ops import interpolate_face_attributes | |
from pytorch3d.common.datatypes import Device | |
from pytorch3d.structures import Meshes | |
from pytorch3d.vis.texture_vis import texturesuv_image_matplotlib | |
from pytorch3d.renderer import ( | |
look_at_view_transform, | |
FoVPerspectiveCameras, | |
PointLights, | |
DirectionalLights, | |
AmbientLights, | |
Materials, | |
RasterizationSettings, | |
MeshRenderer, | |
MeshRasterizer, | |
SoftPhongShader, | |
TexturesUV, | |
TexturesVertex, | |
camera_position_from_spherical_angles, | |
BlendParams, | |
) | |
def render_func( | |
obj_filename, | |
elev=0, | |
azim=None, | |
resolution=512, | |
gif_dst_path='', | |
n_views=120, | |
fps=30, | |
device="cuda:0", | |
rgb=False | |
): | |
''' | |
obj_filename: path to obj file | |
gif_dst_path: | |
if set a path, will render n_views frames, then save it to a gif file | |
if not set, will render single frame, then return PIL.Image instance | |
rgb: if set true, will convert result to rgb image/frame | |
''' | |
# load mesh | |
mesh = load_objs_as_meshes([obj_filename], device=device) | |
meshes = mesh.extend(n_views) | |
if azim is None: | |
elev = torch.linspace(elev, elev, n_views+1)[:-1] | |
azim = torch.linspace(0, 360, n_views+1)[:-1] | |
# prepare R,T then compute cameras | |
R, T = look_at_view_transform(dist=1.5, elev=elev, azim=azim) | |
cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=49.1) | |
# init pytorch3d renderer instance | |
renderer = MeshRenderer( | |
rasterizer=MeshRasterizer( | |
cameras=cameras, | |
raster_settings=RasterizationSettings( | |
image_size=resolution, | |
blur_radius=0.0, | |
faces_per_pixel=1, | |
), | |
), | |
shader=SoftPhongShader( | |
device=device, | |
cameras=cameras, | |
lights=AmbientLights(device=device), | |
blend_params=BlendParams(background_color=(1.0, 1.0, 1.0)), | |
) | |
) | |
images = renderer(meshes) | |
if gif_dst_path != '': | |
with imageio.get_writer(uri=gif_dst_path, mode='I', duration=1. / fps * 1000, loop=0) as writer: | |
for i in range(n_views): | |
frame = images[i, ..., :3] if rgb else images[i, ...] | |
frame = Image.fromarray((frame.cpu().squeeze(0) * 255).numpy().astype("uint8")) | |
writer.append_data(frame) | |
frame = images[..., :3] if rgb else images | |
frames = [Image.fromarray((fra.cpu().squeeze(0) * 255).numpy().astype("uint8")) for fra in frame] | |
return frames | |