# Virtual-Try-On / app.py
# Author: parokshsaxena
# Last commit 91c7a78: "commenting out code for enhanced garment net generated from claude as it was failing the flow"
# (Copied from the Hugging Face file view: raw · history · blame · 16.1 kB)
import spaces
import logging
import math
import gradio as gr
from PIL import Image
from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
from src.unet_hacked_tryon import UNet2DConditionModel
from src.enhanced_garment_net import EnhancedGarmentNetWithTimestep
from transformers import (
CLIPImageProcessor,
CLIPVisionModelWithProjection,
CLIPTextModel,
CLIPTextModelWithProjection,
)
from diffusers import DDPMScheduler,AutoencoderKL
from typing import List
import torch
import os
from transformers import AutoTokenizer
import numpy as np
from utils_mask import get_mask_location
from torchvision import transforms
import apply_net
from preprocess.humanparsing.run_parsing import Parsing
from preprocess.openpose.run_openpose import OpenPose
from detectron2.data.detection_utils import convert_PIL_to_numpy,_apply_exif_orientation
from torchvision.transforms.functional import to_pil_image
from src.background_processor import BackgroundProcessor
def pil_to_binary_mask(pil_image, threshold=0):
    """Convert an image into a black/white single-channel PIL mask.

    The image is converted to 8-bit grayscale ("L"). Every pixel whose gray
    value is strictly greater than ``threshold`` becomes 255; every other pixel
    becomes 0.

    Args:
        pil_image: a PIL image (anything ``np.array`` can turn into image data).
        threshold: gray level (0-255) a pixel must exceed to belong to the mask.

    Returns:
        A ``uint8`` PIL image (mode "L") containing only the values 0 and 255.
    """
    grayscale_image = Image.fromarray(np.array(pil_image)).convert("L")
    # One vectorised comparison instead of a per-pixel Python double loop,
    # which cost H*W interpreter iterations for a 768x1024 mask.
    binary_mask = np.array(grayscale_image) > threshold
    mask = np.where(binary_mask, 255, 0).astype(np.uint8)
    return Image.fromarray(mask)
# Hugging Face Hub repository from which every IDM-VTON component below is loaded.
base_path = 'yisol/IDM-VTON'
# Local folder with the example images shown in the UI (its `cloth` and `human`
# subfolders are listed further down in this file).
example_path = os.path.join(os.path.dirname(__file__), 'example')
# Try-on denoising UNet (the `unet_hacked_tryon` variant), loaded in fp16 and
# frozen because this app only runs inference.
unet = UNet2DConditionModel.from_pretrained(
    base_path,
    subfolder="unet",
    torch_dtype=torch.float16,
)
unet.requires_grad_(False)
# Disabled: experimental EnhancedGarmentNetWithTimestep (suggested by Claude).
# It was commented out because it broke the try-on flow; re-enable together with
# the matching `pipe.garment_net` lines.
#enhancedGarmentNet = EnhancedGarmentNetWithTimestep()
#enhancedGarmentNet.to(dtype=torch.float16)
# The SDXL-based pipeline takes two tokenizers and two text encoders; both pairs
# come from subfolders of the IDM-VTON repository (slow tokenizers, fixed revision).
tokenizer_one = AutoTokenizer.from_pretrained(
    base_path,
    subfolder="tokenizer",
    revision=None,
    use_fast=False,
)
tokenizer_two = AutoTokenizer.from_pretrained(
    base_path,
    subfolder="tokenizer_2",
    revision=None,
    use_fast=False,
)
# DDPM noise scheduler configuration shipped with the model.
noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
text_encoder_one = CLIPTextModel.from_pretrained(
    base_path,
    subfolder="text_encoder",
    torch_dtype=torch.float16,
)
text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
    base_path,
    subfolder="text_encoder_2",
    torch_dtype=torch.float16,
)
# CLIP vision encoder; handed to the pipeline as `image_encoder`.
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    base_path,
    subfolder="image_encoder",
    torch_dtype=torch.float16,
)
# Latent autoencoder used by the pipeline (fp16).
vae = AutoencoderKL.from_pretrained(base_path,
    subfolder="vae",
    torch_dtype=torch.float16,
)
# NOTE(review): leftover reference to the SDXL base repo; the loader below uses base_path.
# "stabilityai/stable-diffusion-xl-base-1.0",
# Garment-encoding UNet (the `unet_hacked_garmnet` variant); attached to the
# pipeline afterwards as `pipe.unet_encoder`.
UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
    base_path,
    subfolder="unet_encoder",
    torch_dtype=torch.float16,
)
# Human-parsing and OpenPose preprocessors used by the auto-masking path.
# NOTE(review): the `0` argument presumably selects GPU index 0 — confirm in preprocess/.
parsing_model = Parsing(0)
openpose_model = OpenPose(0)
# Inference only: switch off gradient tracking on every network (same order as before;
# `unet` is frozen again here, which is harmless).
for _frozen_net in (
    UNet_Encoder,
    image_encoder,
    vae,
    unet,
    text_encoder_one,
    text_encoder_two,
):
    _frozen_net.requires_grad_(False)
del _frozen_net
# PIL image -> float tensor normalised with mean 0.5 / std 0.5, i.e. x -> (x - 0.5) / 0.5.
# start_tryon undoes it with (t + 1) / 2 when building the masked preview.
# The misspelled name is kept because the rest of the file refers to it.
tensor_transfrom = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
# Assemble the try-on pipeline (project StableDiffusionXLInpaintPipeline variant)
# from the fp16 components loaded above instead of letting it load its own copies.
pipe = TryonPipeline.from_pretrained(
    base_path,
    unet=unet,
    vae=vae,
    feature_extractor= CLIPImageProcessor(),
    text_encoder = text_encoder_one,
    text_encoder_2 = text_encoder_two,
    tokenizer = tokenizer_one,
    tokenizer_2 = tokenizer_two,
    scheduler = noise_scheduler,
    image_encoder=image_encoder,
    torch_dtype=torch.float16,
)
# The garment UNet is attached as an attribute after construction; start_tryon
# moves it to the target device separately from `pipe.to(device)`.
pipe.unet_encoder = UNet_Encoder
# Disabled together with the experimental EnhancedGarmentNetWithTimestep:
# pipe.garment_net = enhancedGarmentNet
# Standard size of Shein images (2:3), kept for reference only:
# WIDTH = int(4160/5)
# HEIGHT = int(6240/5)
# Resolution the try-on model was trained on (3:4 aspect ratio).
WIDTH, HEIGHT = 768, 1024
# Half resolution, fed to the OpenPose / human-parsing / DensePose preprocessors.
POSE_WIDTH, POSE_HEIGHT = WIDTH // 2, HEIGHT // 2
# Sleeve mode for get_mask_location: "hd" -> full sleeve, "dc" -> half sleeve.
ARM_WIDTH = "dc"
# Garment category used for masking ("lower_body" is the other option noted here).
CATEGORY = "upper_body"
def is_cropping_required(width, height):
    """Return True unless a ``width`` x ``height`` image already has a 3:4 aspect ratio.

    The height/width ratio is rounded to two decimals and compared with 1.33,
    the ratio of the model's native 768x1024 resolution. Near-3:4 sizes such
    as 600x800 therefore need no crop.
    """
    return round(height / width, 2) != 1.33
def _model_aspect_crop_box(width, height):
    """Return an integer ``(left, top, right, bottom)`` box that crops a ``width`` x ``height`` image to 3:4.

    The box is centred horizontally. Vertically it keeps the bottom edge and trims
    only from the top. Trimming from the bottom sometimes removed the garment on
    model photos (e.g. Landmark's 594x879 shots).
    """
    target_width = int(min(width, height * (3 / 4)))
    target_height = int(min(height, width * (4 / 3)))
    # Integer offsets, so the crop and the later paste-back use the exact same position.
    # The previous float box was rounded by PIL's crop() but truncated with int() for
    # paste(), which could shift the pasted result by one pixel.
    left = (width - target_width) // 2
    top = height - target_height
    return left, top, left + target_width, height


def _build_inpaint_mask(human_img_dict, human_img, use_auto_mask):
    """Return the WIDTH x HEIGHT inpainting mask for ``human_img``.

    With ``use_auto_mask`` the mask comes from OpenPose keypoints plus a human parse,
    both computed at POSE_WIDTH x POSE_HEIGHT. Otherwise the first layer drawn in the
    ImageEditor is binarised with ``pil_to_binary_mask``.

    Raises:
        gr.Error: manual masking was requested but the editor has no drawn layer.
    """
    if use_auto_mask:
        # Internally openpose_model resizes to resolution 384 if none is passed.
        keypoints = openpose_model(human_img.resize((POSE_WIDTH, POSE_HEIGHT)))
        model_parse, _ = parsing_model(human_img.resize((POSE_WIDTH, POSE_HEIGHT)))
        # Internally get_mask_location resizes model_parse to 384x512 if no size is passed.
        mask, _ = get_mask_location(ARM_WIDTH, CATEGORY, model_parse, keypoints)
        mask = mask.resize((WIDTH, HEIGHT))
        logging.info("Mask location on model identified")
        return mask
    layers = human_img_dict.get("layers")
    if not layers:
        # The example entries carry layers=None. Without this check the request
        # failed with an opaque TypeError/IndexError.
        raise gr.Error("No mask drawn: paint over the garment area or enable the auto-generated mask.")
    return pil_to_binary_mask(layers[0].convert("RGB").resize((WIDTH, HEIGHT)))


def _densepose_image(human_img, device):
    """Return the DensePose segmentation ('dp_segm') of ``human_img`` as a WIDTH x HEIGHT PIL image.

    detectron2's DensePose R50-FPN model runs through ``apply_net`` on a
    POSE_WIDTH x POSE_HEIGHT BGR copy of the image, on ``device``.
    """
    pose_input = _apply_exif_orientation(human_img.resize((POSE_WIDTH, POSE_HEIGHT)))
    pose_input = convert_PIL_to_numpy(pose_input, format="BGR")
    args = apply_net.create_argument_parser().parse_args(
        ('show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml', './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v', '--opts', 'MODEL.DEVICE', device)
    )
    pose_array = args.func(args, pose_input)
    # Reverse the channel axis (the model was fed BGR) before building the PIL image.
    return Image.fromarray(pose_array[:, :, ::-1]).resize((WIDTH, HEIGHT))


@spaces.GPU
def start_tryon(human_img_dict, garm_img, garment_des, background_img, is_checked, is_checked_crop, denoise_steps, seed):
    """Run one virtual try-on request and return ``(final_image, mask_gray)``.

    Args:
        human_img_dict: value of the ``gr.ImageEditor``. ``"background"`` is the person
            photo; when auto-masking is off, ``"layers"[0]`` is the hand-drawn mask.
        garm_img: PIL image of the garment.
        garment_des: free-text garment description inserted into both prompts
            (None is treated as an empty description).
        background_img: optional PIL image. When given, the result's background is
            replaced via ``BackgroundProcessor.replace_background_with_removebg``.
        is_checked: build the mask automatically (OpenPose + human parsing) instead
            of using the drawn layer.
        is_checked_crop: ignored. It is recomputed from the photo's aspect ratio and
            kept only so the ``tryon`` API signature stays unchanged.
        denoise_steps: number of diffusion inference steps (cast to int).
        seed: seed for the diffusion generator (cast to int), or None for no fixed seed.

    Returns:
        ``final_image``: the try-on result. It is pasted back into the full-size photo
        when a crop was needed; otherwise it is the raw pipeline output generated at
        WIDTH x HEIGHT. ``mask_gray``: a preview of the person with the masked area
        set to mid-gray.

    Raises:
        gr.Error: auto-masking is off and no mask layer was drawn.
    """
    logging.info("Starting try on")
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    openpose_model.preprocessor.body_estimation.model.to(device)
    pipe.to(device)
    pipe.unet_encoder.to(device)
    # pipe.garment_net.to(device)  # disabled with the experimental enhanced garment net

    human_img_orig = human_img_dict["background"].convert("RGB")  # ImageEditor value

    # Deliberately override the UI checkbox: crop only when the photo is not already
    # 3:4 (the 768x1024 aspect ratio the model expects).
    is_checked_crop = is_cropping_required(*human_img_orig.size)
    garm_img = garm_img.convert("RGB").resize((WIDTH, HEIGHT))

    if is_checked_crop:
        crop_box = _model_aspect_crop_box(*human_img_orig.size)
        cropped_img = human_img_orig.crop(crop_box)
        crop_size = cropped_img.size
        human_img = cropped_img.resize((WIDTH, HEIGHT))
    else:
        human_img = human_img_orig.resize((WIDTH, HEIGHT))

    # Naive colour-transfer harmonisation (BackgroundProcessor.intensity_transfer) is
    # disabled until a deep-learning harmonisation method is integrated.

    mask = _build_inpaint_mask(human_img_dict, human_img, is_checked)
    # Zero the masked region in [-1, 1] space, then map back to [0, 1] for display.
    mask_gray = (1 - transforms.ToTensor()(mask)) * tensor_transfrom(human_img)
    mask_gray = to_pil_image((mask_gray + 1.0) / 2.0)

    pose_img = _densepose_image(human_img, device)

    garment_des = garment_des or ""
    negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
    with torch.no_grad(), torch.cuda.amp.autocast():
        with torch.inference_mode():
            # Prompt for the try-on UNet, with classifier-free guidance.
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = pipe.encode_prompt(
                "model is wearing " + garment_des,
                num_images_per_prompt=1,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            # Prompt describing the garment itself, consumed as `text_embeds_cloth`.
            prompt_embeds_c, _, _, _ = pipe.encode_prompt(
                ["a photo of " + garment_des],
                num_images_per_prompt=1,
                do_classifier_free_guidance=False,
                negative_prompt=[negative_prompt],
            )

        pose_tensor = tensor_transfrom(pose_img).unsqueeze(0).to(device, torch.float16)
        garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device, torch.float16)
        generator = torch.Generator(device).manual_seed(int(seed)) if seed is not None else None
        images = pipe(
            prompt_embeds=prompt_embeds.to(device, torch.float16),
            negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
            pooled_prompt_embeds=pooled_prompt_embeds.to(device, torch.float16),
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device, torch.float16),
            num_inference_steps=int(denoise_steps),
            generator=generator,
            strength=1.0,
            pose_img=pose_tensor,
            text_embeds_cloth=prompt_embeds_c.to(device, torch.float16),
            cloth=garm_tensor,
            mask_image=mask,
            image=human_img,
            height=HEIGHT,
            width=WIDTH,
            ip_adapter_image=garm_img,  # already WIDTH x HEIGHT
            guidance_scale=2.0,
        )[0]

    if is_checked_crop:
        # Scale the result back to the crop's size and paste it at the same offset
        # inside the original photo.
        out_img = images[0].resize(crop_size)
        human_img_orig.paste(out_img, crop_box[:2])
        final_image = human_img_orig
    else:
        final_image = images[0]

    if background_img:
        logging.info("Adding background")
        final_image = BackgroundProcessor.replace_background_with_removebg(final_image, background_img)
    return final_image, mask_gray
# Example galleries shown under the input widgets, in os.listdir order.
garm_list = os.listdir(os.path.join(example_path, "cloth"))
garm_list_path = [os.path.join(example_path, "cloth", name) for name in garm_list]
human_list = os.listdir(os.path.join(example_path, "human"))
human_list_path = [os.path.join(example_path, "human", name) for name in human_list]
# The human input is a gr.ImageEditor, whose value is a dict, so each example
# provides only the background image with no drawn layers or composite.
# (If the input is switched to a plain gr.Image, use `human_ex_list = human_list_path`.)
human_ex_list = [
    {'background': path, 'layers': None, 'composite': None}
    for path in human_list_path
]
# Gradio UI. api_open=True lets the queued `tryon` endpoint be called directly (e.g. with curl).
image_blocks = gr.Blocks().queue(api_open=True)
with image_blocks as demo:
    gr.Markdown("## Virtual Try-On 👕👔👚")
    gr.Markdown("Upload an image of a person and an image of a garment ✨.")
    with gr.Row():
        # Person photo and masking / cropping options.
        with gr.Column():
            # An ImageEditor (not a plain gr.Image) so a mask can be drawn by hand.
            # start_tryon reads its dict value ("background", "layers"). Switching to
            # the commented gr.Image line also requires the matching change in start_tryon.
            imgs = gr.ImageEditor(sources='upload', type="pil", label='Human. Mask with pen or use auto-masking', interactive=True)
            #imgs = gr.Image(sources='upload', type='pil',label='Human. Mask with pen or use auto-masking')
            with gr.Row():
                is_checked = gr.Checkbox(label="Yes", info="Use auto-generated mask (Takes 5 seconds)", value=True)
            with gr.Row():
                # NOTE: start_tryon recomputes this from the photo's aspect ratio.
                is_checked_crop = gr.Checkbox(label="Yes", info="Use auto-crop & resizing", value=False)
            example = gr.Examples(
                inputs=imgs,
                examples_per_page=10,
                examples=human_ex_list,
            )
        # Garment image and its text description.
        with gr.Column():
            garm_img = gr.Image(label="Garment", sources='upload', type="pil")
            with gr.Row(elem_id="prompt-container"):
                with gr.Row():
                    prompt = gr.Textbox(placeholder="Description of garment ex) Short Sleeve Round Neck T-shirts", show_label=False, elem_id="prompt")
            example = gr.Examples(
                inputs=garm_img,
                examples_per_page=8,
                examples=garm_list_path,
            )
        # Optional replacement background for the result.
        with gr.Column():
            background_img = gr.Image(label="Background", sources='upload', type="pil")
        # Outputs: try-on result above the masked-input preview.
        with gr.Column():
            with gr.Row():
                image_out = gr.Image(label="Output", elem_id="output-img", show_share_button=False)
            with gr.Row():
                masked_img = gr.Image(label="Masked image output", elem_id="masked-img", show_share_button=False)

    with gr.Column():
        try_button = gr.Button(value="Try-on")
        with gr.Accordion(label="Advanced Settings", open=False):
            with gr.Row():
                denoise_steps = gr.Number(label="Denoising Steps", minimum=20, maximum=40, value=30, step=1)
                seed = gr.Number(label="Seed", minimum=-1, maximum=2147483647, step=1, value=42)

    try_button.click(fn=start_tryon, inputs=[imgs, garm_img, prompt, background_img, is_checked, is_checked_crop, denoise_steps, seed], outputs=[image_out, masked_img], api_name='tryon')

image_blocks.launch()