IDM-VTON
update IDM-VTON Demo
938e515
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import pickle
from functools import lru_cache
from typing import Dict, Optional, Tuple
import torch
from detectron2.utils.file_io import PathManager
from densepose.data.meshes.catalog import MeshCatalog, MeshInfo
def _maybe_copy_to_device(
    attribute: Optional[torch.Tensor], device: torch.device
) -> Optional[torch.Tensor]:
    """
    Move an optional tensor to ``device``.

    Args:
        attribute: tensor to move, or ``None``
        device: target device
    Returns:
        ``attribute.to(device)``, or ``None`` when ``attribute`` is ``None``
    """
    return None if attribute is None else attribute.to(device)
class Mesh:
    """
    Triangle mesh container.

    Attributes may be passed in directly as tensors, or left as ``None`` and
    loaded lazily on first property access from the files referenced by
    ``mesh_info``. All tensors held by one Mesh are expected to live on
    ``self.device``.
    """

    def __init__(
        self,
        vertices: Optional[torch.Tensor] = None,
        faces: Optional[torch.Tensor] = None,
        geodists: Optional[torch.Tensor] = None,
        symmetry: Optional[Dict[str, torch.Tensor]] = None,
        texcoords: Optional[torch.Tensor] = None,
        mesh_info: Optional[MeshInfo] = None,
        device: Optional[torch.device] = None,
    ):
        """
        Args:
            vertices (tensor [N, 3] of float32): vertex coordinates in 3D
            faces (tensor [M, 3] of long): triangular face represented as 3
                vertex indices
            geodists (tensor [N, N] of float32): geodesic distances from
                vertex `i` to vertex `j` (optional, default: None)
            symmetry (dict: str -> tensor): various mesh symmetry data:
                - "vertex_transforms": vertex mapping under horizontal flip,
                  tensor of size [N] of type long; vertex `i` is mapped to
                  vertex `tensor[i]` (optional, default: None)
            texcoords (tensor [N, 2] of float32): texture coordinates, i.e. global
                and normalized mesh UVs (optional, default: None)
            mesh_info (MeshInfo type): necessary to load the attributes on-the-go,
                can be used instead of passing all the variables one by one
            device (torch.device): device of the Mesh. If not provided, will use
                the device of the first tensor given (fields first, then
                symmetry data), falling back to CPU
        """
        self._vertices = vertices
        self._faces = faces
        self._geodists = geodists
        self._symmetry = symmetry
        self._texcoords = texcoords
        self.mesh_info = mesh_info
        self.device = device
        # Either explicit geometry or a catalog entry to load it from is required.
        assert self._vertices is not None or self.mesh_info is not None
        all_fields = [self._vertices, self._faces, self._geodists, self._texcoords]
        # Infer the device from the first provided tensor: plain fields first,
        # then symmetry data, defaulting to CPU when nothing was passed.
        if self.device is None:
            self.device = next(
                (field.device for field in all_fields if field is not None), None
            )
        if self.device is None and symmetry:
            self.device = next(iter(symmetry.values())).device
        if self.device is None:
            self.device = torch.device("cpu")
        assert all(
            field.device == self.device for field in all_fields if field is not None
        )
        if symmetry:
            assert all(value.device == self.device for value in symmetry.values())
        # Explicit ``is not None`` checks: truth-testing a tensor with more than
        # one element raises RuntimeError, so ``if texcoords and vertices`` crashed
        # whenever texcoords were supplied (e.g. by ``to()`` after lazy loading).
        if texcoords is not None and vertices is not None:
            assert len(vertices) == len(texcoords)

    def to(self, device: torch.device) -> "Mesh":
        """
        Return a new Mesh whose already-materialized tensors are copied to
        ``device``.

        Attributes that are still ``None`` stay ``None`` in the copy. They are
        loaded lazily from the shared ``mesh_info`` onto the copy's device.
        """
        device_symmetry = self._symmetry
        if device_symmetry:
            device_symmetry = {
                key: value.to(device) for key, value in device_symmetry.items()
            }
        return Mesh(
            vertices=_maybe_copy_to_device(self._vertices, device),
            faces=_maybe_copy_to_device(self._faces, device),
            geodists=_maybe_copy_to_device(self._geodists, device),
            symmetry=device_symmetry,
            texcoords=_maybe_copy_to_device(self._texcoords, device),
            mesh_info=self.mesh_info,
            device=device,
        )

    @property
    def vertices(self):
        """Vertex coordinates; loaded from ``mesh_info.data`` on first access if absent."""
        if self._vertices is None and self.mesh_info is not None:
            self._vertices = load_mesh_data(self.mesh_info.data, "vertices", self.device)
        return self._vertices

    @property
    def faces(self):
        """Face vertex indices; loaded from ``mesh_info.data`` on first access if absent."""
        # NOTE(review): load_mesh_data converts to torch.float, so lazily loaded
        # faces are float rather than long as documented in __init__ — confirm
        # callers expect this.
        if self._faces is None and self.mesh_info is not None:
            self._faces = load_mesh_data(self.mesh_info.data, "faces", self.device)
        return self._faces

    @property
    def geodists(self):
        """Geodesic distances; loaded from ``mesh_info.geodists`` on first access if absent."""
        if self._geodists is None and self.mesh_info is not None:
            self._geodists = load_mesh_auxiliary_data(self.mesh_info.geodists, self.device)
        return self._geodists

    @property
    def symmetry(self):
        """Symmetry data dict; loaded from ``mesh_info.symmetry`` on first access if absent."""
        if self._symmetry is None and self.mesh_info is not None:
            self._symmetry = load_mesh_symmetry(self.mesh_info.symmetry, self.device)
        return self._symmetry

    @property
    def texcoords(self):
        """Texture coordinates; loaded from ``mesh_info.texcoords`` on first access if absent."""
        if self._texcoords is None and self.mesh_info is not None:
            self._texcoords = load_mesh_auxiliary_data(self.mesh_info.texcoords, self.device)
        return self._texcoords

    def get_geodists(self):
        """
        Return geodesic distances.

        Falls back to ``_compute_geodists`` when the ``geodists`` property
        yields ``None``, and caches that result.
        """
        geodists = self.geodists
        if geodists is None:
            # ``geodists`` is a read-only property (no setter), so assigning to
            # ``self.geodists`` raised AttributeError; cache via the backing field.
            geodists = self._compute_geodists()
            self._geodists = geodists
        return geodists

    def _compute_geodists(self):
        """Placeholder for geodesic distance computation; currently returns None."""
        # TODO: compute using Laplace-Beltrami
        return None
def load_mesh_data(
    mesh_fpath: str, field: str, device: Optional[torch.device] = None
) -> torch.Tensor:
    """
    Load a single field of a pickled mesh file as a ``torch.float`` tensor.

    Args:
        mesh_fpath (str): path to the pickle file, opened through PathManager
        field (str): key to extract from the unpickled object (``Mesh`` uses
            "vertices" and "faces")
        device (torch.device): device to move the resulting tensor to
    Returns:
        ``pickle.load(file)[field]`` converted to a ``torch.float`` tensor and
        moved to ``device``
    """
    # NOTE: pickle can execute arbitrary code while loading; mesh_fpath must
    # point to a trusted file.
    with PathManager.open(mesh_fpath, "rb") as handle:
        mesh_data = pickle.load(handle)
    # NOTE(review): every field, including "faces", is cast to float here even
    # though face indices are integral; kept as-is for existing callers.
    return torch.as_tensor(mesh_data[field], dtype=torch.float).to(device)
def load_mesh_auxiliary_data(
    fpath: str, device: Optional[torch.device] = None
) -> Optional[torch.Tensor]:
    """
    Load a pickled array-like file (e.g. geodesic distances or texture
    coordinates) as a ``torch.float`` tensor.

    Args:
        fpath (str): path resolved to a local file via ``PathManager.get_local_path``
        device (torch.device): device to move the resulting tensor to
    Returns:
        the unpickled data converted to ``torch.float`` and moved to ``device``
    """
    local_path = PathManager.get_local_path(fpath)
    # NOTE: pickle can execute arbitrary code; only load trusted files.
    with PathManager.open(local_path, "rb") as handle:
        payload = pickle.load(handle)
    return torch.as_tensor(payload, dtype=torch.float).to(device)
@lru_cache()
def load_mesh_symmetry(
    symmetry_fpath: str, device: Optional[torch.device] = None
) -> Optional[Dict[str, torch.Tensor]]:
    """
    Load mesh symmetry data from a pickle file.

    Results are memoized per ``(symmetry_fpath, device)`` by ``lru_cache``.
    Repeated calls therefore return the same dict object, and callers must
    not mutate it.

    Args:
        symmetry_fpath (str): path to the pickle file, opened through PathManager
        device (torch.device): device to move the tensors to
    Returns:
        dict with a single key "vertex_transforms" mapped to a ``torch.long``
        tensor on ``device``
    """
    with PathManager.open(symmetry_fpath, "rb") as handle:
        raw = pickle.load(handle)
    vertex_transforms = torch.as_tensor(raw["vertex_transforms"], dtype=torch.long)
    return {"vertex_transforms": vertex_transforms.to(device)}
@lru_cache()
def create_mesh(mesh_name: str, device: Optional[torch.device] = None) -> Mesh:
    """
    Build a lazily-loading Mesh for the catalog entry ``mesh_name``.

    Memoized by ``lru_cache``: the same Mesh instance is returned for repeated
    ``(mesh_name, device)`` pairs. Attributes it loads lazily are therefore
    shared by all callers.
    """
    info = MeshCatalog[mesh_name]
    return Mesh(mesh_info=info, device=device)