Shivam Singh committed
Commit 88bddf8
• 1 parent: 53b7f16

update changes

Files changed (2):
  1. app.py (+206, -102)
  2. src/tryon_pipeline.py (+4, -11)
app.py CHANGED
@@ -10,20 +10,23 @@ from transformers import (
    CLIPTextModel,
    CLIPTextModelWithProjection,
)
- from diffusers import DDPMScheduler, AutoencoderKL
+ from diffusers import DDPMScheduler,AutoencoderKL
from typing import List
+
import torch
import os
from transformers import AutoTokenizer
+
import numpy as np
+ from utils_mask import get_mask_location
from torchvision import transforms
import apply_net
from preprocess.humanparsing.run_parsing import Parsing
from preprocess.openpose.run_openpose import OpenPose
- from detectron2.data.detection_utils import convert_PIL_to_numpy, _apply_exif_orientation
+ from detectron2.data.detection_utils import convert_PIL_to_numpy,_apply_exif_orientation
from torchvision.transforms.functional import to_pil_image

- # Function Definitions and Initialization (Unchanged Sections)
+
def pil_to_binary_mask(pil_image, threshold=0):
    np_image = np.array(pil_image)
    grayscale_image = Image.fromarray(np_image).convert("L")
@@ -31,9 +34,9 @@ def pil_to_binary_mask(pil_image, threshold=0):
    mask = np.zeros(binary_mask.shape, dtype=np.uint8)
    for i in range(binary_mask.shape[0]):
        for j in range(binary_mask.shape[1]):
-             if binary_mask[i, j]:
-                 mask[i, j] = 1
-     mask = (mask * 255).astype(np.uint8)
+             if binary_mask[i,j] == True :
+                 mask[i,j] = 1
+     mask = (mask*255).astype(np.uint8)
    output_mask = Image.fromarray(mask)
    return output_mask

@@ -41,20 +44,52 @@ def pil_to_binary_mask(pil_image, threshold=0):
base_path = 'yisol/IDM-VTON'
example_path = os.path.join(os.path.dirname(__file__), 'example')

- # Loading models
- unet = UNet2DConditionModel.from_pretrained(base_path, subfolder="unet", torch_dtype=torch.float16)
+ unet = UNet2DConditionModel.from_pretrained(
+     base_path,
+     subfolder="unet",
+     torch_dtype=torch.float16,
+ )
unet.requires_grad_(False)
-
- tokenizer_one = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer", revision=None, use_fast=False)
- tokenizer_two = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer_2", revision=None, use_fast=False)
-
+ tokenizer_one = AutoTokenizer.from_pretrained(
+     base_path,
+     subfolder="tokenizer",
+     revision=None,
+     use_fast=False,
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+     base_path,
+     subfolder="tokenizer_2",
+     revision=None,
+     use_fast=False,
+ )
noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
- text_encoder_one = CLIPTextModel.from_pretrained(base_path, subfolder="text_encoder", torch_dtype=torch.float16)
- text_encoder_two = CLIPTextModelWithProjection.from_pretrained(base_path, subfolder="text_encoder_2", torch_dtype=torch.float16)
- image_encoder = CLIPVisionModelWithProjection.from_pretrained(base_path, subfolder="image_encoder", torch_dtype=torch.float16)
- vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae", torch_dtype=torch.float16)

- UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(base_path, subfolder="unet_encoder", torch_dtype=torch.float16)
+ text_encoder_one = CLIPTextModel.from_pretrained(
+     base_path,
+     subfolder="text_encoder",
+     torch_dtype=torch.float16,
+ )
+ text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
+     base_path,
+     subfolder="text_encoder_2",
+     torch_dtype=torch.float16,
+ )
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+     base_path,
+     subfolder="image_encoder",
+     torch_dtype=torch.float16,
+ )
+ vae = AutoencoderKL.from_pretrained(base_path,
+     subfolder="vae",
+     torch_dtype=torch.float16,
+ )
+
+ # "stabilityai/stable-diffusion-xl-base-1.0",
+ UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
+     base_path,
+     subfolder="unet_encoder",
+     torch_dtype=torch.float16,
+ )

parsing_model = Parsing(0)
openpose_model = OpenPose(0)
@@ -65,35 +100,38 @@ vae.requires_grad_(False)
unet.requires_grad_(False)
text_encoder_one.requires_grad_(False)
text_encoder_two.requires_grad_(False)
-
- tensor_transfrom = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
+ tensor_transfrom = transforms.Compose(
+     [
+         transforms.ToTensor(),
+         transforms.Normalize([0.5], [0.5]),
+     ]
+ )

pipe = TryonPipeline.from_pretrained(
-     base_path,
-     unet=unet,
-     vae=vae,
-     feature_extractor=CLIPImageProcessor(),
-     text_encoder=text_encoder_one,
-     text_encoder_2=text_encoder_two,
-     tokenizer=tokenizer_one,
-     tokenizer_2=tokenizer_two,
-     scheduler=noise_scheduler,
-     image_encoder=image_encoder,
-     torch_dtype=torch.float16,
+     base_path,
+     unet=unet,
+     vae=vae,
+     feature_extractor= CLIPImageProcessor(),
+     text_encoder = text_encoder_one,
+     text_encoder_2 = text_encoder_two,
+     tokenizer = tokenizer_one,
+     tokenizer_2 = tokenizer_two,
+     scheduler = noise_scheduler,
+     image_encoder=image_encoder,
+     torch_dtype=torch.float16,
)
pipe.unet_encoder = UNet_Encoder

- # Function for try-on functionality
@spaces.GPU
- def start_tryon(dict, garm_img, garment_des, is_checked_crop, denoise_steps, seed):
+ def start_tryon(dict,garm_img,garment_des,is_checked,is_checked_crop,denoise_steps,seed):
    device = "cuda"

    openpose_model.preprocessor.body_estimation.model.to(device)
    pipe.to(device)
    pipe.unet_encoder.to(device)

-     garm_img = garm_img.convert("RGB").resize((768,1024))
-     human_img_orig = dict["background"].convert("RGB")
+     garm_img= garm_img.convert("RGB").resize((768,1024))
+     human_img_orig = dict["background"].convert("RGB")

    if is_checked_crop:
        width, height = human_img_orig.size
@@ -109,97 +147,163 @@ def start_tryon(dict, garm_img, garment_des, is_checked_crop, denoise_steps, see
    else:
        human_img = human_img_orig.resize((768,1024))

+
+     if is_checked:
+         keypoints = openpose_model(human_img.resize((384,512)))
+         model_parse, _ = parsing_model(human_img.resize((384,512)))
+         mask, mask_gray = get_mask_location('hd', "upper_body", model_parse, keypoints)
+         mask = mask.resize((768,1024))
+     else:
+         mask = pil_to_binary_mask(dict['layers'][0].convert("RGB").resize((768, 1024)))
+         # mask = transforms.ToTensor()(mask)
+         # mask = mask.unsqueeze(0)
+     mask_gray = (1-transforms.ToTensor()(mask)) * tensor_transfrom(human_img)
+     mask_gray = to_pil_image((mask_gray+1.0)/2.0)
+
+
+     human_img_arg = _apply_exif_orientation(human_img.resize((384,512)))
+     human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")
+
+
+
+     args = apply_net.create_argument_parser().parse_args(('show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml', './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v', '--opts', 'MODEL.DEVICE', 'cuda'))
+     # verbosity = getattr(args, "verbosity", None)
+     pose_img = args.func(args,human_img_arg)
+     pose_img = pose_img[:,:,::-1]
+     pose_img = Image.fromarray(pose_img).resize((768,1024))
+
    with torch.no_grad():
        # Extract the images
        with torch.cuda.amp.autocast():
-             prompt = "model is wearing " + garment_des
-             negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-             (
-                 prompt_embeds,
-                 negative_prompt_embeds,
-                 pooled_prompt_embeds,
-                 negative_pooled_prompt_embeds,
-             ) = pipe.encode_prompt(
-                 prompt,
-                 num_images_per_prompt=1,
-                 do_classifier_free_guidance=True,
-                 negative_prompt=negative_prompt,
-             )
+             with torch.no_grad():
+                 prompt = "model is wearing " + garment_des
+                 negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+                 with torch.inference_mode():
+                     (
+                         prompt_embeds,
+                         negative_prompt_embeds,
+                         pooled_prompt_embeds,
+                         negative_pooled_prompt_embeds,
+                     ) = pipe.encode_prompt(
+                         prompt,
+                         num_images_per_prompt=1,
+                         do_classifier_free_guidance=True,
+                         negative_prompt=negative_prompt,
+                     )
+
+                 prompt = "a photo of " + garment_des
+                 negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+                 if not isinstance(prompt, List):
+                     prompt = [prompt] * 1
+                 if not isinstance(negative_prompt, List):
+                     negative_prompt = [negative_prompt] * 1
+                 with torch.inference_mode():
+                     (
+                         prompt_embeds_c,
+                         _,
+                         _,
+                         _,
+                     ) = pipe.encode_prompt(
+                         prompt,
+                         num_images_per_prompt=1,
+                         do_classifier_free_guidance=False,
+                         negative_prompt=negative_prompt,
+                     )

-             # No mask is used, set mask_image=None
-             pose_img = tensor_transfrom(human_img).unsqueeze(0).to(device, torch.float16)
-             garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device, torch.float16)
-             generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
-             images = pipe(
-                 prompt_embeds=prompt_embeds.to(device, torch.float16),
-                 negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
-                 pooled_prompt_embeds=pooled_prompt_embeds.to(device, torch.float16),
-                 negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device, torch.float16),
-                 num_inference_steps=denoise_steps,
-                 generator=generator,
-                 strength=1.0,
-                 pose_img=pose_img.to(device, torch.float16),
-                 cloth=garm_tensor.to(device, torch.float16),
-                 image=human_img,
-                 mask_image=None, # Bypassing the mask
-                 height=1024,
-                 width=768,
-                 ip_adapter_image=garm_img.resize((768, 1024)),
-                 guidance_scale=2.0,
-             )[0]
+
+
+                 pose_img = tensor_transfrom(pose_img).unsqueeze(0).to(device,torch.float16)
+                 garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device,torch.float16)
+                 generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
+                 images = pipe(
+                     prompt_embeds=prompt_embeds.to(device,torch.float16),
+                     negative_prompt_embeds=negative_prompt_embeds.to(device,torch.float16),
+                     pooled_prompt_embeds=pooled_prompt_embeds.to(device,torch.float16),
+                     negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device,torch.float16),
+                     num_inference_steps=denoise_steps,
+                     generator=generator,
+                     strength = 1.0,
+                     pose_img = pose_img.to(device,torch.float16),
+                     text_embeds_cloth=prompt_embeds_c.to(device,torch.float16),
+                     cloth = garm_tensor.to(device,torch.float16),
+                     mask_image=mask,
+                     image=human_img,
+                     height=1024,
+                     width=768,
+                     ip_adapter_image = garm_img.resize((768,1024)),
+                     guidance_scale=2.0,
+                 )[0]

    if is_checked_crop:
-         out_img = images[0].resize(crop_size)
-         human_img_orig.paste(out_img, (int(left), int(top)))
-         return human_img_orig
+         out_img = images[0].resize(crop_size)
+         human_img_orig.paste(out_img, (int(left), int(top)))
+         return human_img_orig, mask_gray
    else:
-         return images[0]
+         return images[0], mask_gray
+     # return images[0], mask_gray

- # Loading example garments and human images
- garm_list = os.listdir(os.path.join(example_path, "cloth"))
- garm_list_path = [os.path.join(example_path, "cloth", garm) for garm in garm_list]
+ garm_list = os.listdir(os.path.join(example_path,"cloth"))
+ garm_list_path = [os.path.join(example_path,"cloth",garm) for garm in garm_list]

- human_list = os.listdir(os.path.join(example_path, "human"))
- human_list_path = [os.path.join(example_path, "human", human) for human in human_list]
+ human_list = os.listdir(os.path.join(example_path,"human"))
+ human_list_path = [os.path.join(example_path,"human",human) for human in human_list]

human_ex_list = []
for ex_human in human_list_path:
-     ex_dict = {"background": ex_human, "layers": None, "composite": None}
+     ex_dict= {}
+     ex_dict['background'] = ex_human
+     ex_dict['layers'] = None
+     ex_dict['composite'] = None
    human_ex_list.append(ex_dict)

- # Building the Gradio interface
+ ##default human
+
+
image_blocks = gr.Blocks().queue()
with image_blocks as demo:
    with gr.Row():
-         # Column 1: Upload person image
        with gr.Column():
-             gr.Markdown("### Step 1: Upload a person image")
-             imgs = gr.ImageEditor(sources="upload", type="pil", label="Human image", interactive=True)
-             is_checked_crop = gr.Checkbox(label="Use auto-crop & resizing", value=False)
+             imgs = gr.ImageEditor(sources='upload', type="pil", label='Human. Mask with pen or use auto-masking', interactive=True)
+             with gr.Row():
+                 is_checked = gr.Checkbox(label="Yes", info="Use auto-generated mask (Takes 5 seconds)",value=True)
+             with gr.Row():
+                 is_checked_crop = gr.Checkbox(label="Yes", info="Use auto-crop & resizing",value=False)

-             gr.Examples(inputs=imgs, examples_per_page=10, examples=human_ex_list, label="Person image examples")
+             example = gr.Examples(
+                 inputs=imgs,
+                 examples_per_page=10,
+                 examples=human_ex_list
+             )

-         # Column 2: Upload garment image
        with gr.Column():
-             gr.Markdown("### Step 2: Upload a garment image")
-             garm_img = gr.Image(label="Garment", sources="upload", type="pil")
-             prompt = gr.Textbox(placeholder="Description of garment (e.g., Short Sleeve Round Neck T-shirts)", show_label=False)
-             gr.Examples(inputs=garm_img, examples_per_page=8, examples=garm_list_path, label="Garment examples")
-
-         # Column 3: Display Results
+             garm_img = gr.Image(label="Garment", sources='upload', type="pil")
+             with gr.Row(elem_id="prompt-container"):
+                 with gr.Row():
+                     prompt = gr.Textbox(placeholder="Description of garment ex) Short Sleeve Round Neck T-shirts", show_label=False, elem_id="prompt")
+             example = gr.Examples(
+                 inputs=garm_img,
+                 examples_per_page=8,
+                 examples=garm_list_path)
        with gr.Column():
-             gr.Markdown("### Step 3: Try-On Results")
-             image_out = gr.Image(label="Final Try-On Output", elem_id="output-img", show_share_button=False)
+             # image_out = gr.Image(label="Output", elem_id="output-img", height=400)
+             image_out = gr.Image(label="Output", elem_id="output-img",show_share_button=False)
+
+

-     # Advanced settings and run button
-     with gr.Row():
-         try_button = gr.Button(value="Run Try-On")
-         with gr.Accordion("Advanced Settings", open=False):
-             denoise_steps = gr.Number(label="Denoising Steps", minimum=20, maximum=40, value=30, step=1)
-             seed = gr.Number(label="Seed", minimum=-1, maximum=2147483647, step=1, value=42)

-     # Button click event to trigger the try-on function
-     try_button.click(fn=start_tryon, inputs=[imgs, garm_img, prompt, is_checked_crop, denoise_steps, seed], outputs=[image_out])
+     with gr.Column():
+         try_button = gr.Button(value="Try-on")
+         with gr.Accordion(label="Advanced Settings", open=False):
+             with gr.Row():
+                 denoise_steps = gr.Number(label="Denoising Steps", minimum=20, maximum=40, value=30, step=1)
+                 seed = gr.Number(label="Seed", minimum=-1, maximum=2147483647, step=1, value=42)
+
+
+
+     try_button.click(fn=start_tryon, inputs=[imgs, garm_img, prompt, is_checked,is_checked_crop, denoise_steps, seed], outputs=[image_out], api_name='tryon')
+
+
+

- # Launch the Gradio interface
image_blocks.launch()
+
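Aside: the per-pixel double loop in pil_to_binary_mask (kept by this commit, now with an explicit `== True` comparison) can be expressed as one vectorized NumPy operation. The sketch below is illustrative only and not part of the commit; the name pil_to_binary_mask_fast is hypothetical.

# Hypothetical vectorized equivalent of pil_to_binary_mask (illustrative, not in this commit)
import numpy as np
from PIL import Image

def pil_to_binary_mask_fast(pil_image, threshold=0):
    gray = np.array(pil_image.convert("L"))              # same grayscale conversion as the original helper
    binary = (gray > threshold).astype(np.uint8) * 255   # 255 where the pixel exceeds the threshold, else 0
    return Image.fromarray(binary)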
src/tryon_pipeline.py CHANGED
@@ -1587,23 +1587,16 @@ class StableDiffusionXLInpaintPipeline(
        )
        init_image = init_image.to(dtype=torch.float32)

-         if mask_image is not None:
-             mask = self.mask_processor.preprocess(
-                 mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
-             )
-         else:
-             mask = None # No mask provided
-
+         mask = self.mask_processor.preprocess(
+             mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+         )
        if masked_image_latents is not None:
            masked_image = masked_image_latents
        elif init_image.shape[1] == 4:
            # if images are in latent space, we can't mask it
            masked_image = None
        else:
-             if mask is not None:
-                 masked_image = init_image * (mask < 0.5)
-             else:
-                 masked_image = None
+             masked_image = init_image * (mask < 0.5)

        # 6. Prepare latent variables
        num_channels_latents = self.vae.config.latent_channels
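With the `mask_image is None` branch removed, the pipeline again preprocesses the mask unconditionally, and app.py now supplies it via `mask_image=mask`. A minimal sketch of the restored masking semantics on dummy tensors, purely illustrative (the shapes below are assumptions, not the pipeline's real preprocessing):

# Illustrative only: dummy tensors, not the pipeline's real inputs
import torch

init_image = torch.randn(1, 3, 8, 8)            # stand-in for the preprocessed person image
mask = (torch.rand(1, 1, 8, 8) > 0.5).float()   # 1.0 marks the region to be inpainted with the garment
masked_image = init_image * (mask < 0.5)        # zero out the masked region, keep everything else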