import gradio as gr
import spaces
from PIL import Image
from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
from src.unet_hacked_tryon import UNet2DConditionModel
from transformers import (
    CLIPImageProcessor,
    CLIPVisionModelWithProjection,
    CLIPTextModel,
    CLIPTextModelWithProjection,
)
from diffusers import DDPMScheduler, AutoencoderKL
from typing import List
import torch
import os
from transformers import AutoTokenizer
import numpy as np
from utils_mask import get_mask_location
from torchvision import transforms
import apply_net
from preprocess.humanparsing.run_parsing import Parsing
from preprocess.openpose.run_openpose import OpenPose
from detectron2.data.detection_utils import convert_PIL_to_numpy, _apply_exif_orientation
from torchvision.transforms.functional import to_pil_image

custom_css = """
.container {
    max-width: 1200px;
    margin: 0 auto;
    padding: 20px;
}
.header {
    display: flex;
    align-items: center;
    justify-content: space-between;
    padding: 20px 0;
    margin-bottom: 20px;              /* Adds space below the header */
}
.main-content {
    margin-top: 20px;                 /* Adds space above the main content row */
}
.upload-box {
    border: 2px dashed #E5E7EB !important;  /* Dashed border */
    border-radius: 8px !important;          /* Rounded corners */
    min-height: 300px !important;           /* Fixed height */
    display: flex !important;               /* Flexbox for alignment */
    flex-direction: column !important;      /* Vertical layout */
    align-items: center !important;         /* Center content horizontally */
    justify-content: center !important;     /* Center content vertically */
    background: #F9FAFB !important;         /* Light gray background */
    margin: 10px 0 !important;              /* Spacing */
    text-align: center;                     /* Center text */
}
.preview-container {
    background: #F9FAFB;
    border-radius: 8px;
    padding: 15px;
    text-align: center;
}
.browser-header {
    background: #F3F4F6;
    padding: 8px;
    border-radius: 8px 8px 0 0;
    margin-bottom: 10px;
}
.browser-dots {
    display: flex;
    gap: 6px;
}
#contact-button {
    background: #7E22CE !important;   /* Purple background */
    color: white !important;          /* White text */
    font-size: 12px;                  /* Adjust font size */
    padding: 10px 16px;               /* Larger padding for height */
    border-radius: 15px;              /* Rounded corners for consistent styling */
    cursor: pointer;                  /* Pointer cursor on hover */
    text-transform: uppercase;        /* Uppercase text */
    border: none;                     /* No border */
    box-shadow: 0px 2px 4px rgba(0, 0, 0, 0.1);  /* Subtle shadow */
    width: 60px;                      /* Smaller width */
    text-align: center;               /* Centers the text */
    display: inline-block;            /* Prevents button from stretching */
    margin: 0 auto;                   /* Centers the button horizontally */
}
#contact-button:hover {
    background: #5A189A !important;   /* Slightly darker purple on hover */
    box-shadow: 0px 4px 6px rgba(0, 0, 0, 0.15); /* Enhanced shadow on hover */
}
.url-bar {
    background: white;
    border: 1px solid #E5E7EB;
    border-radius: 4px;
    padding: 6px 12px;
    margin: 8px 0;
    display: flex;
    align-items: center;
}
.generate-btn {
    background: #7E22CE !important;
    color: white !important;
    padding: 12px 24px !important;
    border-radius: 8px !important;
    width: 100% !important;
    margin-top: 20px !important;
}
#cloud-icon {
    font-size: 48px;                  /* Larger icon size */
    margin-bottom: 10px;              /* Spacing below icon */
}
.supported-formats {
    color: #6B7280;                   /* Gray text color */
    font-size: 12px;                  /* Smaller font size */
    text-align: center;               /* Center text */
    margin-top: 10px;                 /* Spacing above */
}
.sample-btn {
    background: #E3D8FF !important;   /* Light purple background */
    color: #7E22CE !important;        /* Purple text color */
    padding: 6px 12px !important;     /* Button padding */
    border-radius: 4px !important;    /* Rounded button */
    font-size: 14px;                  /* Button font size */
    text-transform: uppercase;        /* Uppercase text */
    margin-top: 10px !important;      /* Spacing above button */
    cursor: pointer;                  /* Pointer cursor */
}
"""

def pil_to_binary_mask(pil_image, threshold=0):
    # Convert an arbitrary PIL image (e.g. an ImageEditor drawing layer) into a
    # strict black/white mask: every pixel above `threshold` becomes white (255).
    np_image = np.array(pil_image)
    grayscale_image = Image.fromarray(np_image).convert("L")
    binary_mask = np.array(grayscale_image) > threshold
    mask = binary_mask.astype(np.uint8) * 255
    return Image.fromarray(mask)


base_path = 'Roopansh/Ailusion-VTON-DEMO-v1.1'
example_path = os.path.join(os.path.dirname(__file__), 'example')

# Main try-on (denoising) UNet.
unet = UNet2DConditionModel.from_pretrained(
    base_path,
    subfolder="unet",
    torch_dtype=torch.float16,
)
unet.requires_grad_(False)

tokenizer_one = AutoTokenizer.from_pretrained(
    base_path,
    subfolder="tokenizer",
    revision=None,
    use_fast=False,
)
tokenizer_two = AutoTokenizer.from_pretrained(
    base_path,
    subfolder="tokenizer_2",
    revision=None,
    use_fast=False,
)
noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")

text_encoder_one = CLIPTextModel.from_pretrained(
    base_path,
    subfolder="text_encoder",
    torch_dtype=torch.float16,
)
text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
    base_path,
    subfolder="text_encoder_2",
    torch_dtype=torch.float16,
)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    base_path,
    subfolder="image_encoder",
    torch_dtype=torch.float16,
)
vae = AutoencoderKL.from_pretrained(
    base_path,
    subfolder="vae",
    torch_dtype=torch.float16,
)
# "stabilityai/stable-diffusion-xl-base-1.0",

# Reference UNet that encodes the garment image.
UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
    base_path,
    subfolder="unet_encoder",
    torch_dtype=torch.float16,
)

# Human-parsing and OpenPose preprocessors (GPU id 0).
parsing_model = Parsing(0)
openpose_model = OpenPose(0)

# Inference only: freeze all weights.
UNet_Encoder.requires_grad_(False)
image_encoder.requires_grad_(False)
vae.requires_grad_(False)
unet.requires_grad_(False)
text_encoder_one.requires_grad_(False)
text_encoder_two.requires_grad_(False)

tensor_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

pipe = TryonPipeline.from_pretrained(
    base_path,
    unet=unet,
    vae=vae,
    feature_extractor=CLIPImageProcessor(),
    text_encoder=text_encoder_one,
    text_encoder_2=text_encoder_two,
    tokenizer=tokenizer_one,
    tokenizer_2=tokenizer_two,
    scheduler=noise_scheduler,
    image_encoder=image_encoder,
    torch_dtype=torch.float16,
)
pipe.unet_encoder = UNet_Encoder
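
# Optional memory savers (a sketch, not part of the original app). If the GPU is
# tight on VRAM, the standard diffusers slicing toggles below should be available
# as long as TryonPipeline keeps the stock StableDiffusionXLInpaintPipeline mixins;
# uncomment to trade a little speed for a lower peak memory footprint.
# pipe.enable_attention_slicing()  # compute attention in chunks
# pipe.enable_vae_slicing()        # decode latents slice by slice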

@spaces.GPU(duration=120)
def start_tryon(human_dict, garm_img, garment_des="", is_checked=True,
                is_checked_crop=False, denoise_steps=30, seed=42):
    """Run the virtual try-on pipeline.

    human_dict is the gr.ImageEditor value ({'background', 'layers', 'composite'}).
    The simplified UI below only wires the two image inputs, so the remaining
    parameters fall back to these defaults: auto-masking on, no cropping,
    30 denoising steps, fixed seed.
    """
    device = "cuda"
    openpose_model.preprocessor.body_estimation.model.to(device)
    pipe.to(device)
    pipe.unet_encoder.to(device)

    garm_img = garm_img.convert("RGB").resize((768, 1024))
    human_img_orig = human_dict["background"].convert("RGB")

    # Optionally center-crop the person image to a 3:4 aspect ratio.
    if is_checked_crop:
        width, height = human_img_orig.size
        target_width = int(min(width, height * (3 / 4)))
        target_height = int(min(height, width * (4 / 3)))
        left = (width - target_width) / 2
        top = (height - target_height) / 2
        right = (width + target_width) / 2
        bottom = (height + target_height) / 2
        cropped_img = human_img_orig.crop((left, top, right, bottom))
        crop_size = cropped_img.size
        human_img = cropped_img.resize((768, 1024))
    else:
        human_img = human_img_orig.resize((768, 1024))

    # Build the inpainting mask: either automatically (pose + human parsing)
    # or from the mask the user painted in the ImageEditor.
    if is_checked:
        keypoints = openpose_model(human_img.resize((384, 512)))
        model_parse, _ = parsing_model(human_img.resize((384, 512)))
        mask, mask_gray = get_mask_location('hd', "upper_body", model_parse, keypoints)
        mask = mask.resize((768, 1024))
    else:
        mask = pil_to_binary_mask(human_dict['layers'][0].convert("RGB").resize((768, 1024)))
    mask_gray = (1 - transforms.ToTensor()(mask)) * tensor_transform(human_img)
    mask_gray = to_pil_image((mask_gray + 1.0) / 2.0)

    # DensePose estimation for the pose-conditioning image.
    human_img_arg = _apply_exif_orientation(human_img.resize((384, 512)))
    human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")
    args = apply_net.create_argument_parser().parse_args(
        ('show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml',
         './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v',
         '--opts', 'MODEL.DEVICE', 'cuda')
    )
    pose_img = args.func(args, human_img_arg)
    pose_img = pose_img[:, :, ::-1]  # flip channel order for PIL
    pose_img = Image.fromarray(pose_img).resize((768, 1024))

    with torch.no_grad(), torch.cuda.amp.autocast():
        # Encode the try-on prompt with classifier-free guidance.
        prompt = "model is wearing " + garment_des
        negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
        with torch.inference_mode():
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = pipe.encode_prompt(
                prompt,
                num_images_per_prompt=1,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )

            # Encode the garment-only prompt for the cloth (reference UNet) branch.
            prompt = "a photo of " + garment_des
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
            if not isinstance(prompt, list):
                prompt = [prompt]
            if not isinstance(negative_prompt, list):
                negative_prompt = [negative_prompt]
            with torch.inference_mode():
                (
                    prompt_embeds_c,
                    _,
                    _,
                    _,
                ) = pipe.encode_prompt(
                    prompt,
                    num_images_per_prompt=1,
                    do_classifier_free_guidance=False,
                    negative_prompt=negative_prompt,
                )

            pose_tensor = tensor_transform(pose_img).unsqueeze(0).to(device, torch.float16)
            garm_tensor = tensor_transform(garm_img).unsqueeze(0).to(device, torch.float16)
            generator = torch.Generator(device).manual_seed(seed) if seed is not None else None

            images = pipe(
                prompt_embeds=prompt_embeds.to(device, torch.float16),
                negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
                pooled_prompt_embeds=pooled_prompt_embeds.to(device, torch.float16),
                negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device, torch.float16),
                num_inference_steps=denoise_steps,
                generator=generator,
                strength=1.0,
                pose_img=pose_tensor.to(device, torch.float16),
                text_embeds_cloth=prompt_embeds_c.to(device, torch.float16),
                cloth=garm_tensor.to(device, torch.float16),
                mask_image=mask,
                image=human_img,
                height=1024,
                width=768,
                ip_adapter_image=garm_img.resize((768, 1024)),
                guidance_scale=2.0,
            )[0]

    if is_checked_crop:
        # Paste the generated crop back into the original full-size image.
        out_img = images[0].resize(crop_size)
        human_img_orig.paste(out_img, (int(left), int(top)))
        return human_img_orig  # previously also returned mask_gray
    else:
        return images[0]  # previously also returned mask_gray
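
# Example (hypothetical paths, for quick local testing outside the Gradio UI):
# start_tryon can be called directly with a dict shaped like the value produced
# by gr.ImageEditor. With is_checked=True the drawn 'layers' are ignored and the
# mask comes from the parsing + OpenPose models, so None is fine there.
#
# person = Image.open("example/human/person.jpg")    # placeholder path
# garment = Image.open("example/cloth/shirt.jpg")    # placeholder path
# result = start_tryon(
#     {"background": person, "layers": None, "composite": None},
#     garment,
#     garment_des="short sleeve round neck t-shirt",
#     denoise_steps=30,
#     seed=42,
# )
# result.save("tryon_result.png")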

# Build example galleries from the bundled example images.
garm_list = os.listdir(os.path.join(example_path, "cloth"))
garm_list_path = [os.path.join(example_path, "cloth", garm) for garm in garm_list]

human_list = os.listdir(os.path.join(example_path, "human"))
human_list_path = [os.path.join(example_path, "human", human) for human in human_list]

# Default human examples in the dict format expected by gr.ImageEditor.
human_ex_list = []
for ex_human in human_list_path:
    ex_dict = {}
    ex_dict['background'] = ex_human
    ex_dict['layers'] = None
    ex_dict['composite'] = None
    human_ex_list.append(ex_dict)

with gr.Blocks(css=custom_css) as demo:
    with gr.Column():
        # Header
        with gr.Row(elem_classes="header"):
            gr.Markdown("# AILUSION", elem_id="component-0")
            gr.Markdown("Feature Trail")
            gr.Button("Contact us", variant="primary", elem_id="contact-button")

        # Main content
        with gr.Row(elem_classes="main-content"):
            # Upload Human
            with gr.Column():
                gr.Markdown("### Upload Human", elem_id="upload-human-title")
                with gr.Column(elem_classes="upload-box"):
                    human_input = gr.ImageEditor(sources="upload", type="pil", interactive=True)
                    gr.Markdown(
                        "Supported formats: JPEG, PNG",
                        elem_classes="supported-formats",
                    )
                    gr.Examples(
                        inputs=human_input,
                        examples_per_page=5,
                        examples=human_ex_list,
                    )

            # Upload Garment
            with gr.Column():
                gr.Markdown("### Upload Garment", elem_classes="upload-title")
                with gr.Column(elem_classes="upload-box"):
                    garment_input = gr.Image(type="pil", interactive=True)
                    gr.Markdown(
                        "Supported formats: JPEG, PNG",
                        elem_classes="supported-formats",
                    )
                    gr.Examples(
                        inputs=garment_input,
                        examples_per_page=5,
                        examples=garm_list_path,
                    )

            # Preview Section
            with gr.Column(elem_classes="preview-container"):
                preview_output = gr.Image(type="pil")
                gr.Markdown("₹998")
                with gr.Row():
                    gr.Button("Add to cart", variant="secondary")
                    gr.Button("♡", variant="secondary")

        # Generate Output button
        with gr.Row():
            generate_btn = gr.Button("Generate Output", elem_classes="generate-btn", scale=2)

    # Wire the button to the try-on function. Only the two image inputs are
    # exposed in this UI; garment description, masking/cropping toggles, step
    # count and seed use the defaults defined on start_tryon.
    generate_btn.click(
        fn=start_tryon,
        inputs=[human_input, garment_input],
        outputs=preview_output,
        api_name='tryon',
    )

demo.launch(share=True)