Spaces:
Runtime error
Runtime error
import pdb | |
from pathlib import Path | |
import sys | |
PROJECT_ROOT = Path(__file__).absolute().parents[0].absolute() | |
sys.path.insert(0, str(PROJECT_ROOT)) | |
import os | |
import torch | |
import numpy as np | |
import cv2 | |
import torchvision.transforms as transforms | |
from torch.utils.data import DataLoader | |
from datasets.simple_extractor_dataset import SimpleFolderDataset | |
from utils.transforms import transform_logits | |
from tqdm import tqdm | |
from PIL import Image | |
def colorize_parsing(parsing_result): | |
label_map = { | |
0: "background", 1: "hat", 2: "hair", 3: "sunglasses", 4: "upper_clothes", | |
5: "skirt", 6: "pants", 7: "dress", 8: "belt", 9: "left_shoe", | |
10: "right_shoe", 11: "head", 12: "left_leg", 13: "right_leg", | |
14: "left_arm", 15: "right_arm", 16: "bag", 17: "scarf" | |
} | |
# Define colors for each part (RGB) | |
color_map = { | |
0: (0, 0, 0), # Background | |
1: (128, 0, 0), # Hat | |
2: (255, 0, 0), # Hair | |
3: (0, 255, 0), # Sunglasses | |
4: (0, 0, 255), # Upper-clothes | |
5: (255, 255, 0), # Skirt | |
6: (255, 0, 255), # Pants | |
7: (0, 255, 255), # Dress | |
8: (128, 128, 0), # Belt | |
9: (0, 128, 128), # Left-shoe | |
10: (128, 0, 128), # Right-shoe | |
11: (128, 128, 128),# Head | |
12: (64, 0, 0), # Left-leg | |
13: (192, 0, 0), # Right-leg | |
14: (64, 128, 0), # Left-arm | |
15: (192, 128, 0), # Right-arm | |
16: (64, 0, 128), # Bag | |
17: (192, 0, 128), # Scarf | |
} | |
height, width = parsing_result.shape | |
colored_parsing = np.zeros((height, width, 3), dtype=np.uint8) | |
for label, color in color_map.items(): | |
colored_parsing[parsing_result == label] = color | |
return colored_parsing | |
def add_numbers_to_image(colored_parsing, parsing_result): | |
label_map = { | |
0: "background", 1: "hat", 2: "hair", 3: "sunglasses", 4: "upper_clothes", | |
5: "skirt", 6: "pants", 7: "dress", 8: "belt", 9: "left_shoe", | |
10: "right_shoe", 11: "head", 12: "left_leg", 13: "right_leg", | |
14: "left_arm", 15: "right_arm", 16: "bag", 17: "scarf" | |
} | |
height, width = parsing_result.shape | |
numbered_image = colored_parsing.copy() | |
for label in range(18): # 0 to 17 | |
mask = (parsing_result == label) | |
if np.any(mask): | |
y, x = np.where(mask) | |
center_y, center_x = int(np.mean(y)), int(np.mean(x)) | |
cv2.putText(numbered_image, str(label), (center_x, center_y), | |
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA) | |
return numbered_image | |
def get_palette(num_cls): | |
""" Returns the color map for visualizing the segmentation mask. | |
Args: | |
num_cls: Number of classes | |
Returns: | |
The color map | |
""" | |
n = num_cls | |
palette = [0] * (n * 3) | |
for j in range(0, n): | |
lab = j | |
palette[j * 3 + 0] = 0 | |
palette[j * 3 + 1] = 0 | |
palette[j * 3 + 2] = 0 | |
i = 0 | |
while lab: | |
palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i)) | |
palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i)) | |
palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i)) | |
i += 1 | |
lab >>= 3 | |
return palette | |
def delete_irregular(logits_result): | |
parsing_result = np.argmax(logits_result, axis=2) | |
upper_cloth = np.where(parsing_result == 4, 255, 0) | |
contours, hierarchy = cv2.findContours(upper_cloth.astype(np.uint8), | |
cv2.RETR_CCOMP, cv2.CHAIN_APPROX_TC89_L1) | |
area = [] | |
for i in range(len(contours)): | |
a = cv2.contourArea(contours[i], True) | |
area.append(abs(a)) | |
if len(area) != 0: | |
top = area.index(max(area)) | |
M = cv2.moments(contours[top]) | |
cY = int(M["m01"] / M["m00"]) | |
dresses = np.where(parsing_result == 7, 255, 0) | |
contours_dress, hierarchy_dress = cv2.findContours(dresses.astype(np.uint8), | |
cv2.RETR_CCOMP, cv2.CHAIN_APPROX_TC89_L1) | |
area_dress = [] | |
for j in range(len(contours_dress)): | |
a_d = cv2.contourArea(contours_dress[j], True) | |
area_dress.append(abs(a_d)) | |
if len(area_dress) != 0: | |
top_dress = area_dress.index(max(area_dress)) | |
M_dress = cv2.moments(contours_dress[top_dress]) | |
cY_dress = int(M_dress["m01"] / M_dress["m00"]) | |
wear_type = "dresses" | |
if len(area) != 0: | |
if len(area_dress) != 0 and cY_dress > cY: | |
irregular_list = np.array([4, 5, 6]) | |
logits_result[:, :, irregular_list] = -1 | |
else: | |
irregular_list = np.array([5, 6, 7, 8, 9, 10, 12, 13]) | |
logits_result[:cY, :, irregular_list] = -1 | |
wear_type = "cloth_pant" | |
parsing_result = np.argmax(logits_result, axis=2) | |
# pad border | |
parsing_result = np.pad(parsing_result, pad_width=1, mode='constant', constant_values=0) | |
return parsing_result, wear_type | |
def hole_fill(img): | |
img_copy = img.copy() | |
mask = np.zeros((img.shape[0] + 2, img.shape[1] + 2), dtype=np.uint8) | |
cv2.floodFill(img, mask, (0, 0), 255) | |
img_inverse = cv2.bitwise_not(img) | |
dst = cv2.bitwise_or(img_copy, img_inverse) | |
return dst | |
def refine_mask(mask): | |
contours, hierarchy = cv2.findContours(mask.astype(np.uint8), | |
cv2.RETR_CCOMP, cv2.CHAIN_APPROX_TC89_L1) | |
area = [] | |
for j in range(len(contours)): | |
a_d = cv2.contourArea(contours[j], True) | |
area.append(abs(a_d)) | |
refine_mask = np.zeros_like(mask).astype(np.uint8) | |
if len(area) != 0: | |
i = area.index(max(area)) | |
cv2.drawContours(refine_mask, contours, i, color=255, thickness=-1) | |
# keep large area in skin case | |
for j in range(len(area)): | |
if j != i and area[i] > 2000: | |
cv2.drawContours(refine_mask, contours, j, color=255, thickness=-1) | |
return refine_mask | |
def refine_hole(parsing_result_filled, parsing_result, arm_mask): | |
filled_hole = cv2.bitwise_and(np.where(parsing_result_filled == 4, 255, 0), | |
np.where(parsing_result != 4, 255, 0)) - arm_mask * 255 | |
contours, hierarchy = cv2.findContours(filled_hole, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_TC89_L1) | |
refine_hole_mask = np.zeros_like(parsing_result).astype(np.uint8) | |
for i in range(len(contours)): | |
a = cv2.contourArea(contours[i], True) | |
# keep hole > 2000 pixels | |
if abs(a) > 2000: | |
cv2.drawContours(refine_hole_mask, contours, i, color=255, thickness=-1) | |
return refine_hole_mask + arm_mask | |
def onnx_inference(session, lip_session, input_dir): | |
transform = transforms.Compose([ | |
transforms.ToTensor(), | |
transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229]) | |
]) | |
dataset = SimpleFolderDataset(root=input_dir, input_size=[512, 512], transform=transform) | |
dataloader = DataLoader(dataset) | |
with torch.no_grad(): | |
for _, batch in enumerate(tqdm(dataloader)): | |
image, meta = batch | |
c = meta['center'].numpy()[0] | |
s = meta['scale'].numpy()[0] | |
w = meta['width'].numpy()[0] | |
h = meta['height'].numpy()[0] | |
output = session.run(None, {"input.1": image.numpy().astype(np.float32)}) | |
upsample = torch.nn.Upsample(size=[512, 512], mode='bilinear', align_corners=True) | |
upsample_output = upsample(torch.from_numpy(output[1][0]).unsqueeze(0)) | |
upsample_output = upsample_output.squeeze() | |
upsample_output = upsample_output.permute(1, 2, 0) # CHW -> HWC | |
logits_result = transform_logits(upsample_output.data.cpu().numpy(), c, s, w, h, input_size=[512, 512]) | |
parsing_result = np.argmax(logits_result, axis=2) | |
parsing_result = np.pad(parsing_result, pad_width=1, mode='constant', constant_values=0) | |
# try holefilling the clothes part | |
arm_mask = (parsing_result == 14).astype(np.float32) \ | |
+ (parsing_result == 15).astype(np.float32) | |
upper_cloth_mask = (parsing_result == 4).astype(np.float32) + arm_mask | |
img = np.where(upper_cloth_mask, 255, 0) | |
dst = hole_fill(img.astype(np.uint8)) | |
parsing_result_filled = dst / 255 * 4 | |
parsing_result_woarm = np.where(parsing_result_filled == 4, parsing_result_filled, parsing_result) | |
# add back arm and refined hole between arm and cloth | |
refine_hole_mask = refine_hole(parsing_result_filled.astype(np.uint8), parsing_result.astype(np.uint8), | |
arm_mask.astype(np.uint8)) | |
parsing_result = np.where(refine_hole_mask, parsing_result, parsing_result_woarm) | |
# remove padding | |
parsing_result = parsing_result[1:-1, 1:-1] | |
dataset_lip = SimpleFolderDataset(root=input_dir, input_size=[473, 473], transform=transform) | |
dataloader_lip = DataLoader(dataset_lip) | |
with torch.no_grad(): | |
for _, batch in enumerate(tqdm(dataloader_lip)): | |
image, meta = batch | |
c = meta['center'].numpy()[0] | |
s = meta['scale'].numpy()[0] | |
w = meta['width'].numpy()[0] | |
h = meta['height'].numpy()[0] | |
output_lip = lip_session.run(None, {"input.1": image.numpy().astype(np.float32)}) | |
upsample = torch.nn.Upsample(size=[473, 473], mode='bilinear', align_corners=True) | |
upsample_output_lip = upsample(torch.from_numpy(output_lip[1][0]).unsqueeze(0)) | |
upsample_output_lip = upsample_output_lip.squeeze() | |
upsample_output_lip = upsample_output_lip.permute(1, 2, 0) # CHW -> HWC | |
logits_result_lip = transform_logits(upsample_output_lip.data.cpu().numpy(), c, s, w, h, | |
input_size=[473, 473]) | |
parsing_result_lip = np.argmax(logits_result_lip, axis=2) | |
# add neck parsing result | |
neck_mask = np.logical_and(np.logical_not((parsing_result_lip == 13).astype(np.float32)), | |
(parsing_result == 11).astype(np.float32)) | |
parsing_result = np.where(neck_mask, 18, parsing_result) | |
palette = get_palette(19) | |
output_img = Image.fromarray(np.asarray(parsing_result, dtype=np.uint8)) | |
output_img.putpalette(palette) | |
face_mask = torch.from_numpy((parsing_result == 11).astype(np.float32)) | |
# Colorize the parsing result | |
colored_parsing = colorize_parsing(parsing_result) | |
# Add numbers to the colorized image | |
numbered_parsing = add_numbers_to_image(colored_parsing, parsing_result) | |
# Save the numbered parsing result | |
output_filename = "colored_parsing.png" | |
cv2.imwrite(output_filename, cv2.cvtColor(numbered_parsing, cv2.COLOR_RGB2BGR)) | |
return output_img, face_mask | |