import argparse
import os

os.environ['CUDA_HOME'] = '/usr/local/cuda'
os.environ['PATH'] = os.environ['PATH'] + ':/usr/local/cuda/bin'

from datetime import datetime

import gradio as gr
import spaces
import numpy as np
import torch
from diffusers.image_processor import VaeImageProcessor
from huggingface_hub import snapshot_download
from PIL import Image

# Disable torch.jit.script (some model components fail to script) before importing the model code.
torch.jit.script = lambda f: f

from model.cloth_masker import AutoMasker, vis_mask
from model.pipeline import CatVTONPipeline
from utils import init_weight_dtype, resize_and_crop, resize_and_padding


def parse_args():
    parser = argparse.ArgumentParser(description="Gradio demo for CatVTON virtual try-on.")
    parser.add_argument(
        "--base_model_path",
        type=str,
        default="booksforcharlie/stable-diffusion-inpainting",
        help=(
            "The path to the base model to use for evaluation. This can be a local path"
            " or a model identifier from the Model Hub."
        ),
    )
    parser.add_argument(
        "--resume_path",
        type=str,
        default="zhengchong/CatVTON",
        help="The path to the checkpoint of the trained try-on model.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="resource/demo/output",
        help="The output directory where the model predictions will be written.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=768,
        help="The width to which input person and garment images will be resized.",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=1024,
        help="The height to which input person and garment images will be resized.",
    )
    parser.add_argument(
        "--repaint",
        action="store_true",
        help="Whether to repaint the result image with the original background.",
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        default=True,
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up inference. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="bf16",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16)."
            " bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="For distributed runs: local rank (normally set via the LOCAL_RANK environment variable).",
    )

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args


def image_grid(imgs, rows, cols):
    """Paste `rows * cols` equally sized images into a single grid image."""
    assert len(imgs) == rows * cols
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid


args = parse_args()
repo_path = snapshot_download(repo_id=args.resume_path)

# Pipeline
pipeline = CatVTONPipeline(
    base_ckpt=args.base_model_path,
    attn_ckpt=repo_path,
    attn_ckpt_version="mix",
    weight_dtype=init_weight_dtype(args.mixed_precision),
    use_tf32=args.allow_tf32,
    device='cuda',
)

# AutoMasker
mask_processor = VaeImageProcessor(
    vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True
)
automasker = AutoMasker(
    densepose_ckpt=os.path.join(repo_path, "DensePose"),
    schp_ckpt=os.path.join(repo_path, "SCHP"),
    device='cuda',
)


@spaces.GPU(duration=120)
def submit_function(
    person_image,
    cloth_image,
    cloth_type,
    num_inference_steps,
    guidance_scale,
    seed,
    show_type,
):
    # Check if mask layers exist and are not empty.
    if "layers" in person_image and person_image["layers"]:
        person_image, mask = person_image["background"], person_image["layers"][0]
        mask = Image.open(mask).convert("L")
        if len(np.unique(np.array(mask))) == 1:
            # All mask values are the same (empty mask).
            mask = None
        else:
            mask = np.array(mask)
            mask[mask > 0] = 255  # Convert to a binary mask (0 or 255).
            mask = Image.fromarray(mask)
    else:
        person_image = person_image["background"]
        mask = None  # No mask provided; it will be auto-generated.

    tmp_folder = args.output_dir
    date_str = datetime.now().strftime("%Y%m%d%H%M%S")
    result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
    if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
        os.makedirs(os.path.join(tmp_folder, date_str[:8]))

    generator = None
    if seed != -1:
        generator = torch.Generator(device='cuda').manual_seed(seed)

    person_image = Image.open(person_image).convert("RGB")
    cloth_image = Image.open(cloth_image).convert("RGB")
    person_image = resize_and_crop(person_image, (args.width, args.height))
    cloth_image = resize_and_padding(cloth_image, (args.width, args.height))

    # Process mask
    if mask is not None:
        mask = resize_and_crop(mask, (args.width, args.height))
    else:
        mask = automasker(person_image, cloth_type)['mask']
    mask = mask_processor.blur(mask, blur_factor=9)

    # Inference
    result_image = pipeline(
        image=person_image,
        condition_image=cloth_image,
        mask=mask,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
    )[0]

    # Post-process
    masked_person = vis_mask(person_image, mask)
    save_result_image = image_grid([person_image, masked_person, cloth_image, result_image], 1, 4)
    save_result_image.save(result_save_path)

    if show_type == "result only":
        return result_image

    width, height = person_image.size
    if show_type == "input & result":
        condition_width = width // 2
        conditions = image_grid([person_image, cloth_image], 2, 1)
    else:
        condition_width = width // 3
        conditions = image_grid([person_image, masked_person, cloth_image], 3, 1)
    conditions = conditions.resize((condition_width, height), Image.NEAREST)
    new_result_image = Image.new("RGB", (width + condition_width + 5, height))
    new_result_image.paste(conditions, (0, 0))
    new_result_image.paste(result_image, (condition_width + 5, 0))
    return new_result_image


def person_example_fn(image_path):
    return image_path


HEADER = """

๐Ÿˆ CatVTON: Concatenation Is All You Need for Virtual Try-On with Diffusion Models

arXiv · HuggingFace · GitHub · Demo · Demo webpage · License

· This demo and our weights are only for non-commercial use.
· You can try CatVTON in our HuggingFace Space or our online demo (run on a 3090).
· Thanks to ZeroGPU for providing an A100 for our HuggingFace Space.
· The SafetyChecker is enabled to filter NSFW content, but it may also block normal results. Please adjust the `seed` if that happens.
""" def app_gradio(): custom_css = """ @media (max-width: 768px) { .gr-column { width: 100% !important; padding: 0.5rem; } .gr-row { flex-direction: column !important; } .container { margin: 0.5rem !important; padding: 1rem !important; } button.primary-btn { padding: 0.8rem 1rem; font-size: 1rem; } } @media (max-width: 480px) { .gr-slider, .gr-radio-group, .gr-markdown, .gr-accordion { font-size: 0.9rem !important; padding: 0.5rem; } button.primary-btn { font-size: 0.8rem; padding: 0.6rem 0.8rem; } .gr-form { margin: 0.5rem; } } button.primary-btn { background: linear-gradient(135deg, #2541b2 0%, #1a237e 100%); transition: all 0.3s ease; border: none; box-shadow: 0 2px 4px rgba(0,0,0,0.1); color: white !important; } button.primary-btn:hover { transform: translateY(-2px); box-shadow: 0 4px 8px rgba(0,0,0,0.2); } .gr-button { background: linear-gradient(135deg, #2541b2 0%, #1a237e 100%); color: white !important; border: none; transition: all 0.3s ease; } .gr-button:hover { opacity: 0.9; transform: translateY(-2px); } body { background: linear-gradient(135deg, #f8f9fa 0%, #e8eaf6 100%); } .container { border-radius: 12px; box-shadow: 0 4px 6px rgba(0,0,0,0.1); } .gr-form { border-radius: 8px; background: white; box-shadow: 0 2px 4px rgba(0,0,0,0.05); } .gr-radio-group { background: white; padding: 12px; border-radius: 8px; } .gr-accordion { border-radius: 8px; overflow: hidden; } /* Force white text in buttons */ button.primary-btn span { color: white !important; } .gr-button span { color: white !important; } """ with gr.Blocks(title="Deradh Virtual Try-On", css=custom_css) as demo: gr.Markdown( """

Deradh Virtual Try-On Experience

Visit Deradh.com
Experience the future of fashion with our AI-powered virtual try-on technology
""" ) with gr.Row(): with gr.Column(scale=1, min_width="auto"): with gr.Row(): image_path = gr.Image( type="filepath", interactive=True, visible=False, ) person_image = gr.ImageEditor( interactive=True, label="Upload Your Photo", type="filepath" ) with gr.Row(): with gr.Column(scale=1, min_width="auto"): cloth_image = gr.Image( interactive=True, label="Select Garment", type="filepath" ) with gr.Column(scale=1, min_width="auto"): gr.Markdown( '''

For Best Performance:

  1. Stand in front of a plain, contrasting background.
  2. Ensure your entire body is visible in the frame.
  3. Upload the highest quality image possible.
  4. Avoid cluttered or low-light environments.
  5. Wear minimal accessories for accurate results.
                            '''
                        )
                        cloth_type = gr.Radio(
                            label="(Important) Garment Type",
                            choices=["upper", "lower", "overall"],
                            # value="upper",
                        )
                submit = gr.Button("Try On", elem_classes="primary-btn")
                gr.Markdown(
                    '''
Important: please wait after clicking Try On; processing may take a moment.
                    '''
                )
                # gr.Markdown(
                #     '''
                #     Advanced Settings:
                #     '''
                # )
                with gr.Accordion("Developer Options", open=False):
                    num_inference_steps = gr.Slider(
                        label="Quality Steps", minimum=10, maximum=100, step=5, value=50
                    )
                    guidance_scale = gr.Slider(
                        label="Style Intensity", minimum=0.0, maximum=7.5, step=0.5, value=2.5
                    )
                    seed = gr.Slider(
                        label="Variation Seed", minimum=-1, maximum=10000, step=1, value=42
                    )
                    show_type = gr.Radio(
                        label="Display Options",
                        choices=["result only", "input & result", "input & mask & result"],
                        value="input & result",
                    )
            with gr.Column(scale=2, min_width="auto"):
                result_image = gr.Image(
                    interactive=False, label="Virtual Try-On Result"
                )

        with gr.Row():
            root_path = "resource/demo/example"
            with gr.Column():
                men_exm = gr.Examples(
                    examples=[
                        os.path.join(root_path, "person", "men", _)
                        for _ in os.listdir(os.path.join(root_path, "person", "men"))
                    ],
                    examples_per_page=4,
                    inputs=image_path,
                    label="Sample Photos - Men",
                )
                women_exm = gr.Examples(
                    examples=[
                        os.path.join(root_path, "person", "women", _)
                        for _ in os.listdir(os.path.join(root_path, "person", "women"))
                    ],
                    examples_per_page=4,
                    inputs=image_path,
                    label="Sample Photos - Women",
                )
            with gr.Column():
                condition_upper_exm = gr.Examples(
                    examples=[
                        os.path.join(root_path, "condition", "upper", _)
                        for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
                    ],
                    examples_per_page=4,
                    inputs=cloth_image,
                    label="Sample Upper Garments",
                )
                condition_overall_exm = gr.Examples(
                    examples=[
                        os.path.join(root_path, "condition", "overall", _)
                        for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
                    ],
                    examples_per_page=4,
                    inputs=cloth_image,
                    label="Sample Full Outfits",
                )
                condition_person_exm = gr.Examples(
                    examples=[
                        os.path.join(root_path, "condition", "person", _)
                        for _ in os.listdir(os.path.join(root_path, "condition", "person"))
                    ],
                    examples_per_page=4,
                    inputs=cloth_image,
                    label="Style Reference Photos",
                )

        image_path.change(
            person_example_fn, inputs=image_path, outputs=person_image
        )

        submit.click(
            submit_function,
            [
                person_image,
                cloth_image,
                cloth_type,
                num_inference_steps,
                guidance_scale,
                seed,
                show_type,
            ],
            result_image,
        )

    demo.queue().launch(share=True, show_error=True)


if __name__ == "__main__":
    app_gradio()
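# Usage sketch (assumption: this script is saved as app.py; the file name and flag
# values below are illustrative, not prescriptive). A typical local launch on a
# single CUDA GPU, using the argument names defined in parse_args() above:
#
#   python app.py --mixed_precision bf16 --width 768 --height 1024 \
#       --output_dir resource/demo/output
#
# The base inpainting model and the CatVTON checkpoint default to
# booksforcharlie/stable-diffusion-inpainting and zhengchong/CatVTON and are
# fetched from the Hugging Face Hub on first run via snapshot_download.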