# Spaces:
# Configuration error
import argparse | |
import os | |
from huggingface_hub import snapshot_download | |
from tqdm import tqdm | |
from model.cloth_masker import AutoMasker | |
def parse_args(argv=None):
    """Parse command-line options for agnostic-mask preprocessing.

    Args:
        argv: optional list of argument strings; ``None`` (the default) means
            ``sys.argv[1:]``, exactly as ``argparse`` does.

    Returns:
        argparse.Namespace with ``data_root_path``, ``repo_path`` and
        ``local_rank``. A ``LOCAL_RANK`` environment variable (other than -1)
        overrides ``--local_rank``.
    """
    parser = argparse.ArgumentParser(description="Simple example of Preprocess Agnostic Mask")
    parser.add_argument(
        "--data_root_path",
        type=str,
        required=True,
        help="Path to the dataset to evaluate."
    )
    parser.add_argument(
        "--repo_path",
        type=str,
        default="zhengchong/CatVTON",
        help=(
            "The Path or repo name of CatVTON. "
        ),
    )
    # local_rank was previously read below without ever being declared, so any
    # run with LOCAL_RANK set crashed with AttributeError. Both spellings map to
    # dest "local_rank" (distributed launchers have used either form).
    parser.add_argument(
        "--local_rank",
        "--local-rank",
        type=int,
        default=-1,
        help="Local process rank for distributed launches (overridden by LOCAL_RANK).",
    )
    args = parser.parse_args(argv)
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank
    return args
def main(args):
    """Generate and save an agnostic mask for every person image in the test pairs.

    For each of the ``upper_body``, ``lower_body`` and ``dresses`` sub-folders
    of ``args.data_root_path``, reads ``test_pairs_paired.txt`` (one
    "<person_image> <cloth_image>" pair per line) and runs ``AutoMasker`` on
    ``images/<person_image>``. The returned ``'mask'`` is saved as
    ``agnostic_masks/<stem>.png``. Masks that already exist are skipped, so an
    interrupted run can simply be restarted.

    Args:
        args: namespace from ``parse_args``. ``repo_path`` may be a local
            directory holding the ``DensePose`` and ``SCHP`` checkpoints or a
            Hugging Face repo id. A repo id is downloaded, and ``args.repo_path``
            is replaced by the local snapshot path.

    Raises:
        FileNotFoundError: if a sub-folder or its pair file is missing.
        ValueError: if a non-blank pair-file line does not have exactly two fields.
    """
    # --repo_path is documented as "Path or repo name": only go to the Hub
    # when it is not already a local directory.
    if not os.path.isdir(args.repo_path):
        args.repo_path = snapshot_download(repo_id=args.repo_path)
    automasker = AutoMasker(
        densepose_ckpt=os.path.join(args.repo_path, "DensePose"),
        schp_ckpt=os.path.join(args.repo_path, "SCHP"),
        device='cuda',
    )
    # Sub-folder -> cloth-type label passed to AutoMasker (processed in this order).
    cloth_types = {'upper_body': 'upper', 'lower_body': 'lower', 'dresses': 'overall'}
    for sub_folder, cloth_type in cloth_types.items():
        sub_dir = os.path.join(args.data_root_path, sub_folder)
        # Explicit raises rather than assert: asserts are stripped under `python -O`.
        if not os.path.exists(sub_dir):
            raise FileNotFoundError(f"Folder {sub_folder} does not exist.")
        pair_txt = os.path.join(sub_dir, 'test_pairs_paired.txt')
        if not os.path.exists(pair_txt):
            raise FileNotFoundError(f"File {pair_txt} does not exist.")
        with open(pair_txt, 'r') as f:
            # Ignore blank lines (e.g. a trailing newline), which would otherwise
            # break the two-field unpack below.
            pairs = [line.split() for line in f if line.strip()]
        output_dir = os.path.join(sub_dir, 'agnostic_masks')
        os.makedirs(output_dir, exist_ok=True)
        for person_img, _ in tqdm(pairs, desc=f"Processing {sub_folder}"):
            # Swap whatever extension the image has for .png. str.replace('.jpg', ...)
            # left non-.jpg names unchanged, so the mask was saved under the image's
            # own extension. It also rewrote every '.jpg' occurrence in the name.
            mask_path = os.path.join(output_dir, os.path.splitext(person_img)[0] + '.png')
            if os.path.exists(mask_path):
                continue  # already generated on a previous run
            mask = automasker(
                os.path.join(sub_dir, 'images', person_img),
                cloth_type,
            )['mask']
            mask.save(mask_path)
if __name__ == "__main__":
    # Script entry point: read the CLI options and hand them straight to main().
    main(parse_args())