Spaces:
Runtime error
Runtime error
import itertools | |
import json | |
import os | |
from collections import Counter | |
import numpy as np | |
import trimesh | |
from matplotlib.path import Path | |
from sklearn.neighbors import KNeighborsClassifier | |
def load_segmentation(path, shape):
    """
    Load the labelled garment polygons stored in a segmentation json file.

    Arguments:
        path: path to the segmentation json file
        shape: unused; kept only so existing callers keep working
    Returns:
        A list with one dict per top-level key that starts with "item".
        Each dict holds:
            "type": the item's "category_name"
            "type_id": the item's "category_id"
            "coordinates": a list with one (k, 2) array of (x, y) points
                per polygon in the item's "segmentation" entry
    """
    with open(path, encoding="utf-8") as json_file:
        data = json.load(json_file)

    segmentations = []
    for key, item in data.items():
        # Skip top-level keys that are not "item*" entries.
        if not key.startswith("item"):
            continue
        # An item may consist of several polygons; each one is stored flat as
        # [x1, y1, x2, y2, ...] and is converted to an (k, 2) array of points.
        coordinates = [
            np.vstack((polygon[::2], polygon[1::2])).T
            for polygon in item["segmentation"]
        ]
        segmentations.append({
            "type": item["category_name"],
            "type_id": item["category_id"],
            "coordinates": coordinates,
        })
    return segmentations
def smpl_to_recon_labels(recon, smpl, k=1):
    """
    Transfer the SMPL per-vertex body-part labels onto a reconstructed mesh.

    A k-nearest-neighbour classifier is fitted on the SMPL vertex positions
    (labels read from "smpl_vert_segmentation.json" next to this file) and
    used to predict a body-part label for every vertex of `recon`.

    Arguments:
        recon: trimesh object (fully clothed model)
        smpl: trimesh object (smpl model)
        k: number of nearest neighbours to use (default 1)
    Returns:
        A dictionary mapping each body-part name found in the json file to the
        list of `recon` vertex indices predicted to belong to that part
    """
    seg_path = os.path.join(os.path.dirname(__file__), "smpl_vert_segmentation.json")
    # Use a context manager so the file handle is closed deterministically.
    with open(seg_path, encoding="utf-8") as seg_file:
        smpl_vert_segmentation = json.load(seg_file)

    # One label per SMPL vertex; vertices not listed in the json keep None.
    y = np.array([None] * smpl.vertices.shape[0])
    for part, vertex_ids in smpl_vert_segmentation.items():
        y[vertex_ids] = part

    # Honour the `k` argument (it was previously ignored in favour of 1).
    classifier = KNeighborsClassifier(n_neighbors=k)
    classifier.fit(smpl.vertices, y)
    y_pred = classifier.predict(recon.vertices)

    return {
        part: list(np.argwhere(y_pred == part).flatten().astype(int))
        for part in smpl_vert_segmentation
    }
def _select_segmented_vertices(vertices, seg_coord, P_inv):
    """Return indices of `vertices` whose (x, y) lie inside any back-projected polygon."""
    selected = []
    # Each segmentation can contain multiple polygons; check them separately.
    for polygon in seg_coord:
        n = len(polygon)
        coords_h = np.hstack((polygon, np.ones((n, 1))))
        # Apply the inverse projection to the homogeneous 2D coordinates to get
        # homogeneous 3D points, then dehomogenise.
        XYZ = P_inv @ coords_h[:, :, None]
        XYZ = XYZ.reshape((XYZ.shape[0], XYZ.shape[1]))
        XYZ = XYZ[:, :3] / XYZ[:, 3, None]
        # NOTE(review): only x/y of the back-projected polygon are compared with
        # the vertex x/y, which presumably assumes an orthographic-like camera
        # looking along z — confirm.
        inside = Path(XYZ[:, :2]).contains_points(vertices[:, :2])
        selected.extend(np.flatnonzero(inside))
    return selected


def _body_parts_to_remove(type_id):
    """Return the SMPL body-part names to strip for a garment category id."""
    # Body parts that are most likely not part of the segmentation but might
    # intersect it (e.g. a hand in front of the torso).
    parts = [
        "rightHand",
        "leftHand",
        "leftToeBase",
        "rightToeBase",
        "leftFoot",
        "rightFoot",
        "head",
        "leftHandIndex1",
        "rightHandIndex1",
    ]
    forearms = ["leftForeArm", "rightForeArm"]
    arms = ["leftArm", "rightArm"]
    # Category ids follow https://github.com/switchablenorms/DeepFashion2
    if type_id in (1, 3, 10):
        # Short sleeve clothes
        parts += forearms
    elif type_id in (5, 6, 8, 9, 12, 13):
        # No sleeves at all or lower body clothes
        parts += forearms + arms
    elif type_id == 7:
        # Shorts
        parts += ["leftLeg", "rightLeg"] + forearms + arms
    return parts


def extract_cloth(recon, segmentation, K, R, t, smpl=None):
    """
    Extract the portion of a clothed mesh covered by a 2D garment segmentation.

    Every polygon in segmentation["coord_normalized"] is mapped through the
    pseudo-inverse of the projection P = K[:3, :3] @ [R | t]. Vertices of
    `recon` whose (x, y) fall inside any resulting polygon are selected. Then
    the vertices that SMPL body-part labels assign to parts unlikely to belong
    to this garment type are dropped. Every face touching a remaining vertex is
    kept.

    Arguments:
        recon: fully clothed mesh (its `vertices` and `faces` are used)
        segmentation: dict with "coord_normalized" (list of 2D polygons,
            described by the original author as NDC) and "type_id" (garment
            category id, interpreted as a DeepFashion2 category)
        K: intrinsic matrix of the projection (only K[:3, :3] is used)
        R: rotation matrix of the projection
        t: translation vector of the projection
        smpl: SMPL mesh used to label body parts of `recon`
    Returns:
        A trimesh.Trimesh submesh of `recon`, or None when `smpl` is None or
        no face survives the filtering.
    """
    seg_coord = segmentation["coord_normalized"]
    mesh = trimesh.Trimesh(recon.vertices, recon.faces)

    extrinsic = np.zeros((3, 4))
    extrinsic[:3, :3] = R
    extrinsic[:, 3] = t
    P = K[:3, :3] @ extrinsic
    P_inv = np.linalg.pinv(P)

    points_so_far = _select_segmented_vertices(recon.vertices, seg_coord, P_inv)

    if smpl is None:
        # NOTE(review): without an SMPL body the selected vertices are discarded
        # and no mesh is returned — kept as-is, confirm this is intended.
        return None

    num_verts = recon.vertices.shape[0]
    recon_labels = smpl_to_recon_labels(recon, smpl)
    verts_to_remove = list(
        itertools.chain.from_iterable(
            recon_labels[part]
            for part in _body_parts_to_remove(segmentation["type_id"])
        )
    )

    label_mask = np.zeros(num_verts, dtype=bool)
    label_mask[verts_to_remove] = True
    seg_mask = np.zeros(num_verts, dtype=bool)
    seg_mask[points_so_far] = True

    # Keep segmented vertices that do not belong to a removed body part
    # (element-wise mask logic, not Python `and` on lists).
    keep_vertex = seg_mask & ~label_mask

    # A face is kept when at least one of its vertices is kept.
    faces = np.asarray(recon.faces)
    face_mask = keep_vertex[faces].any(axis=1)
    if not face_mask.any():
        return None

    mesh.update_faces(face_mask)
    mesh.remove_unreferenced_vertices()
    return mesh