File size: 3,880 Bytes
c24da45
31ec2b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c24da45
 
31ec2b7
c24da45
 
 
 
 
 
 
 
 
 
31ec2b7
 
 
c24da45
 
31ec2b7
c24da45
31ec2b7
c24da45
 
 
 
31ec2b7
c24da45
 
 
 
 
 
 
31ec2b7
 
 
 
 
 
c24da45
 
 
31ec2b7
 
c24da45
 
 
 
 
 
 
 
31ec2b7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104

# import torch
# import torch.nn as nn
# import nvdiffrast.torch as dr
# from util.flexicubes_geometry import FlexiCubesGeometry

# class Renderer(nn.Module):
#     def __init__(self, tet_grid_size, camera_angle_num, scale, geo_type):
#         super().__init__()

#         self.tet_grid_size = tet_grid_size
#         self.camera_angle_num = camera_angle_num
#         self.scale = scale
#         self.geo_type = geo_type
#         self.glctx = dr.RasterizeCudaContext()

#         if self.geo_type == "flex":
#             self.flexicubes = FlexiCubesGeometry(grid_res = self.tet_grid_size)   

#     def forward(self, data, sdf, deform, verts, tets, training=False, weight = None):

#         results = {}

#         deform = torch.tanh(deform) / self.tet_grid_size * self.scale / 0.95
#         if self.geo_type == "flex":
#             deform = deform *0.5

#             v_deformed = verts + deform

#             verts_list = []
#             faces_list = []
#             reg_list = []
#             n_shape = verts.shape[0]
#             for i in range(n_shape): 
#                 verts_i, faces_i, reg_i = self.flexicubes.get_mesh(v_deformed[i], sdf[i].squeeze(dim=-1),
#                 with_uv=False, indices=tets, weight_n=weight[i], is_training=training)

#                 verts_list.append(verts_i)
#                 faces_list.append(faces_i)
#                 reg_list.append(reg_i)       
#             verts = verts_list
#             faces = faces_list

#             flexicubes_surface_reg = torch.cat(reg_list).mean()
#             flexicubes_weight_reg = (weight ** 2).mean()
#             results["flex_surf_loss"] = flexicubes_surface_reg
#             results["flex_weight_loss"] = flexicubes_weight_reg

#         return results, verts, faces

import torch
import torch.nn as nn
# import nvdiffrast.torch as dr  # Comentado porque no se usará en CPU
from util.flexicubes_geometry import FlexiCubesGeometry

class Renderer(nn.Module):
    """Turn predicted SDF / deformation / weight fields into triangle meshes.

    Only the FlexiCubes geometry backend (``geo_type == "flex"``) is
    implemented. Unlike the GPU version of this class, no nvdiffrast CUDA
    rasterization context is created, so the module can be constructed on a
    machine without CUDA.
    """

    def __init__(self, tet_grid_size, camera_angle_num, scale, geo_type):
        """
        Args:
            tet_grid_size: Grid resolution. Forwarded to ``FlexiCubesGeometry``
                as ``grid_res`` and also used to scale the vertex deformation
                in ``forward``.
            camera_angle_num: Stored on the instance; not used by this class.
            scale: Multiplier applied to the vertex deformation in ``forward``.
            geo_type: Geometry backend name. Only ``"flex"`` constructs a
                ``FlexiCubesGeometry``; ``forward`` rejects any other value.
        """
        super().__init__()

        self.tet_grid_size = tet_grid_size
        self.camera_angle_num = camera_angle_num
        self.scale = scale
        self.geo_type = geo_type

        # Intentionally no nvdiffrast ``RasterizeCudaContext`` (``self.glctx``):
        # it requires a CUDA device and nothing in this class uses it.
        if self.geo_type == "flex":
            self.flexicubes = FlexiCubesGeometry(grid_res=self.tet_grid_size)

    def forward(self, data, sdf, deform, verts, tets, training=False, weight=None):
        """Extract one FlexiCubes mesh per shape and compute regularizers.

        Args:
            data: Unused here; kept for interface compatibility.
            sdf: Per-shape SDF values; ``sdf[i].squeeze(dim=-1)`` is passed to
                ``get_mesh`` (presumably a trailing singleton channel — confirm
                against the caller).
            deform: Raw (unbounded) vertex offsets, added to ``verts`` after
                squashing and scaling.
            verts: Base grid vertices. ``verts.shape[0]`` determines how many
                meshes are extracted.
            tets: Forwarded to ``get_mesh`` as ``indices``.
            training: Forwarded to ``get_mesh`` as ``is_training``.
            weight: Per-shape FlexiCubes weights. ``weight[i]`` is passed to
                ``get_mesh`` as ``weight_n`` and ``weight`` is also used for a
                regularization term, so it is required despite the ``None``
                default.

        Returns:
            ``(results, verts_list, faces_list)`` where ``results`` holds the
            ``"flex_surf_loss"`` and ``"flex_weight_loss"`` scalars and the two
            lists hold the per-shape vertices / faces returned by ``get_mesh``.

        Raises:
            NotImplementedError: if ``geo_type`` is not ``"flex"`` (previously
                this crashed with ``UnboundLocalError`` on ``faces``).
            ValueError: if ``weight`` is ``None``.
        """
        if self.geo_type != "flex":
            raise NotImplementedError(
                f"Renderer.forward only supports geo_type='flex', got {self.geo_type!r}"
            )
        if weight is None:
            raise ValueError("Renderer.forward requires `weight` when geo_type='flex'")

        # tanh bounds every offset component to (-1, 1); the division keeps the
        # magnitude below scale / (0.95 * tet_grid_size), and FlexiCubes then
        # halves it. Operation order matches the original exactly.
        deform = torch.tanh(deform) / self.tet_grid_size * self.scale / 0.95
        deform = deform * 0.5
        v_deformed = verts + deform

        verts_list, faces_list, reg_list = [], [], []
        # NOTE(review): the mesh count comes from verts.shape[0], not from
        # deform/sdf/weight. If verts is passed with a batch dimension of 1
        # while deform is batched, only the first shape is meshed — confirm
        # this is intended by the callers.
        for i in range(verts.shape[0]):
            # TODO: confirm FlexiCubesGeometry.get_mesh works on CPU-only
            # machines; it may still need adapting to run without a GPU.
            verts_i, faces_i, reg_i = self.flexicubes.get_mesh(
                v_deformed[i], sdf[i].squeeze(dim=-1),
                with_uv=False, indices=tets, weight_n=weight[i], is_training=training,
            )
            verts_list.append(verts_i)
            faces_list.append(faces_i)
            reg_list.append(reg_i)

        results = {
            "flex_surf_loss": torch.cat(reg_list).mean(),
            "flex_weight_loss": (weight ** 2).mean(),
        }
        return results, verts_list, faces_list