File size: 2,824 Bytes
e3e5f9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
from typing import List, Optional
from PIL import Image
import imageio
import time
import torch
from pytorch3d.io import load_objs_as_meshes, load_obj, save_obj
from pytorch3d.ops import interpolate_face_attributes
from pytorch3d.common.datatypes import Device
from pytorch3d.structures import Meshes
from pytorch3d.vis.texture_vis import texturesuv_image_matplotlib
from pytorch3d.renderer import (
    look_at_view_transform,
    FoVPerspectiveCameras,
    PointLights,
    DirectionalLights,
    AmbientLights,
    Materials,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    SoftPhongShader,
    TexturesUV,
    TexturesVertex,
    camera_position_from_spherical_angles,
    BlendParams,
)


def _frame_to_pil(image, rgb=False):
    '''
        Convert one rendered frame (a single element of the renderer output batch) to a PIL.Image.
        image: tensor indexed as (H, W, C); values assumed to lie in [0, 1] — TODO confirm for other shaders
        rgb: if True, keep only the first 3 channels (drop alpha)
    '''
    if rgb:
        image = image[..., :3]
    # clamp so small over/undershoots cannot wrap around in the uint8 cast;
    # no squeeze here, so a frame with H == 1 keeps its 3-D shape
    array = (image.detach().cpu().clamp(0.0, 1.0) * 255).numpy().astype("uint8")
    return Image.fromarray(array)


def render(
    obj_filename, 
    elev=0, 
    azim=0, 
    resolution=512, 
    gif_dst_path='', 
    n_views=120, 
    fps=30, 
    device="cuda:0", 
    rgb=False
):
    '''
        Render an OBJ mesh with PyTorch3D, either as a single frame or as an orbit GIF.

        obj_filename: path to obj file
        elev: camera elevation (degrees); in GIF mode the same elevation is used for every frame
        azim: camera azimuth (degrees); only used for single-frame rendering,
            GIF mode always orbits over azimuth [0, 360)
        resolution: rasterized image size passed to RasterizationSettings(image_size=...)
        gif_dst_path: 
            if set to a path, will render n_views frames, save them to a gif file and return the path
            if empty (or None), will render a single frame, then return a PIL.Image instance
        n_views: number of orbit frames (GIF mode only)
        fps: GIF frame rate; imageio receives duration=1000/fps per frame
            NOTE(review): whether imageio interprets this as ms or s depends on its version/plugin — confirm
        device: torch device used for loading and rendering
        rgb: if set true, will convert result to rgb image/frame (drops alpha)
    '''
    orbit = bool(gif_dst_path)

    # load mesh
    mesh = load_objs_as_meshes([obj_filename], device=device)

    if orbit:
        # fixed elevation; azimuth evenly spaced over [0, 360) so the loop does not repeat a frame
        elev = torch.full((n_views,), float(elev))
        azim = torch.linspace(0, 360, n_views + 1)[:-1]

    # prepare R, T, then compute cameras (one per view)
    R, T = look_at_view_transform(dist=1.5, elev=elev, azim=azim)
    cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=49.1)

    # replicate the mesh once per camera: the single-frame path renders one image
    # instead of n_views identical copies of which only the first was kept
    meshes = mesh.extend(R.shape[0])

    # init pytorch3d renderer instance
    renderer = MeshRenderer(
        rasterizer=MeshRasterizer(
            cameras=cameras,
            raster_settings=RasterizationSettings(
                image_size=resolution,
                blur_radius=0.0,
                faces_per_pixel=1,
            ),
        ),
        shader=SoftPhongShader(
            device=device,
            cameras=cameras,
            lights=AmbientLights(device=device),
            blend_params=BlendParams(background_color=(1.0, 1.0, 1.0)),
        )
    )
    images = renderer(meshes)

    # single frame rendering
    if not orbit:
        return _frame_to_pil(images[0], rgb)

    # orbit frames rendering
    with imageio.get_writer(uri=gif_dst_path, mode='I', duration=1. / fps * 1000, loop=0) as writer:
        for frame in images:
            writer.append_data(_frame_to_pil(frame, rgb))
    return gif_dst_path