cocktailpeanut committed on
Commit
ec0e0d5
•
1 Parent(s): d1b3d39
Files changed (1)
  1. app.py +155 -72
app.py CHANGED
@@ -27,6 +27,7 @@ from torchvision.transforms.functional import to_pil_image
 
 import devicetorch
 
+torch_dtype = devicetorch.dtype(torch)
 
 def pil_to_binary_mask(pil_image, threshold=0):
     np_image = np.array(pil_image)
@@ -45,10 +46,12 @@ def pil_to_binary_mask(pil_image, threshold=0):
 base_path = 'yisol/IDM-VTON'
 example_path = os.path.join(os.path.dirname(__file__), 'example')
 
+dtype = devicetorch.dtype(torch)
 unet = UNet2DConditionModel.from_pretrained(
     base_path,
     subfolder="unet",
-    torch_dtype=torch.float16,
+    #torch_dtype=torch.float16,
+    torch_dtype=dtype,
 )
 unet.requires_grad_(False)
 tokenizer_one = AutoTokenizer.from_pretrained(
@@ -68,28 +71,33 @@ noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
 text_encoder_one = CLIPTextModel.from_pretrained(
     base_path,
     subfolder="text_encoder",
-    torch_dtype=torch.float16,
+    #torch_dtype=torch.float16,
+    torch_dtype=dtype,
 )
 text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
     base_path,
     subfolder="text_encoder_2",
-    torch_dtype=torch.float16,
+    #torch_dtype=torch.float16,
+    torch_dtype=dtype,
 )
 image_encoder = CLIPVisionModelWithProjection.from_pretrained(
     base_path,
     subfolder="image_encoder",
-    torch_dtype=torch.float16,
+    #torch_dtype=torch.float16,
+    torch_dtype=dtype,
 )
 vae = AutoencoderKL.from_pretrained(base_path,
     subfolder="vae",
-    torch_dtype=torch.float16,
+    #torch_dtype=torch.float16,
+    torch_dtype=dtype,
 )
 
 # "stabilityai/stable-diffusion-xl-base-1.0",
 UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
     base_path,
     subfolder="unet_encoder",
-    torch_dtype=torch.float16,
+    #torch_dtype=torch.float16,
+    torch_dtype=dtype,
 )
 
 parsing_model = Parsing(0)
@@ -119,7 +127,8 @@ pipe = TryonPipeline.from_pretrained(
     tokenizer_2 = tokenizer_two,
     scheduler = noise_scheduler,
     image_encoder=image_encoder,
-    torch_dtype=torch.float16,
+    #torch_dtype=torch.float16,
+    torch_dtype=dtype,
 )
 pipe.unet_encoder = UNet_Encoder
 
@@ -127,14 +136,12 @@ pipe.unet_encoder = UNet_Encoder
 def start_tryon(dict,garm_img,garment_des,is_checked,is_checked_crop,denoise_steps,seed):
     #device = "cuda"
     device = devicetorch.get(torch)
-
     openpose_model.preprocessor.body_estimation.model.to(device)
     pipe.to(device)
     pipe.unet_encoder.to(device)
 
     garm_img= garm_img.convert("RGB").resize((768,1024))
-    human_img_orig = dict["background"].convert("RGB")
-
+    human_img_orig = dict["background"].convert("RGB")
     if is_checked_crop:
         width, height = human_img_orig.size
         target_width = int(min(width, height * (3 / 4)))
@@ -148,8 +155,6 @@ def start_tryon(dict,garm_img,garment_des,is_checked,is_checked_crop,denoise_steps,seed):
         human_img = cropped_img.resize((768,1024))
     else:
         human_img = human_img_orig.resize((768,1024))
-
-
     if is_checked:
         keypoints = openpose_model(human_img.resize((384,512)))
         model_parse, _ = parsing_model(human_img.resize((384,512)))
@@ -165,82 +170,161 @@ def start_tryon(dict,garm_img,garment_des,is_checked,is_checked_crop,denoise_steps,seed):
 
     human_img_arg = _apply_exif_orientation(human_img.resize((384,512)))
     human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")
-
-
-
     #args = apply_net.create_argument_parser().parse_args(('show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml', './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v', '--opts', 'MODEL.DEVICE', 'cuda'))
-    args = apply_net.create_argument_parser().parse_args(('show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml', './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v', '--opts', 'MODEL.DEVICE', device))
+
+    model_device = "cpu"
+    if device == "cuda":
+        model_device = "cuda"
+    args = apply_net.create_argument_parser().parse_args(('show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml', './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v', '--opts', 'MODEL.DEVICE', model_device))
     # verbosity = getattr(args, "verbosity", None)
-    pose_img = args.func(args,human_img_arg)
-    pose_img = pose_img[:,:,::-1]
+    pose_img = args.func(args,human_img_arg)
+    pose_img = pose_img[:,:,::-1]
     pose_img = Image.fromarray(pose_img).resize((768,1024))
-
+    #pose_img = Image.fromarray(pose_img).resize((512, 768))
+
     with torch.no_grad():
         # Extract the images
-        with torch.autocast(device_type=device):
-        #with torch.cuda.amp.autocast():
-            with torch.no_grad():
-                prompt = "model is wearing " + garment_des
-                negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-                with torch.inference_mode():
-                    (
-                        prompt_embeds,
-                        negative_prompt_embeds,
-                        pooled_prompt_embeds,
-                        negative_pooled_prompt_embeds,
-                    ) = pipe.encode_prompt(
-                        prompt,
-                        num_images_per_prompt=1,
-                        do_classifier_free_guidance=True,
-                        negative_prompt=negative_prompt,
-                    )
-
-                    prompt = "a photo of " + garment_des
-                    negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-                    if not isinstance(prompt, List):
-                        prompt = [prompt] * 1
-                    if not isinstance(negative_prompt, List):
-                        negative_prompt = [negative_prompt] * 1
-                    with torch.inference_mode():
-                        (
-                            prompt_embeds_c,
-                            _,
-                            _,
-                            _,
-                        ) = pipe.encode_prompt(
-                            prompt,
-                            num_images_per_prompt=1,
-                            do_classifier_free_guidance=False,
-                            negative_prompt=negative_prompt,
-                        )
-
-                    pose_img = tensor_transfrom(pose_img).unsqueeze(0).to(device,torch.float16)
-                    garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device,torch.float16)
-                    generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
-                    images = pipe(
-                        prompt_embeds=prompt_embeds.to(device,torch.float16),
-                        negative_prompt_embeds=negative_prompt_embeds.to(device,torch.float16),
-                        pooled_prompt_embeds=pooled_prompt_embeds.to(device,torch.float16),
-                        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device,torch.float16),
-                        num_inference_steps=denoise_steps,
-                        generator=generator,
-                        strength = 1.0,
-                        pose_img = pose_img.to(device,torch.float16),
-                        text_embeds_cloth=prompt_embeds_c.to(device,torch.float16),
-                        cloth = garm_tensor.to(device,torch.float16),
-                        mask_image=mask,
-                        image=human_img,
-                        height=1024,
-                        width=768,
-                        ip_adapter_image = garm_img.resize((768,1024)),
-                        guidance_scale=2.0,
-                    )[0]
+
+        if device == "cuda":
+            with torch.cuda.amp.autocast():
+                with torch.no_grad():
+                    prompt = "model is wearing " + garment_des
+                    negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+                    with torch.inference_mode():
+                        (
+                            prompt_embeds,
+                            negative_prompt_embeds,
+                            pooled_prompt_embeds,
+                            negative_pooled_prompt_embeds,
+                        ) = pipe.encode_prompt(
+                            prompt,
+                            num_images_per_prompt=1,
+                            do_classifier_free_guidance=True,
+                            negative_prompt=negative_prompt,
+                        )
+
+                        prompt = "a photo of " + garment_des
+                        negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+                        if not isinstance(prompt, List):
+                            prompt = [prompt] * 1
+                        if not isinstance(negative_prompt, List):
+                            negative_prompt = [negative_prompt] * 1
+                        with torch.inference_mode():
+                            (
+                                prompt_embeds_c,
+                                _,
+                                _,
+                                _,
+                            ) = pipe.encode_prompt(
+                                prompt,
+                                num_images_per_prompt=1,
+                                do_classifier_free_guidance=False,
+                                negative_prompt=negative_prompt,
+                            )
+
+                        #pose_img = tensor_transfrom(pose_img).unsqueeze(0).to(device,torch.float16)
+                        pose_img = tensor_transfrom(pose_img).unsqueeze(0).to(device,dtype)
+                        #garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device,torch.float16)
+                        garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device,dtype)
+                        generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
+                        images = pipe(
+                            prompt_embeds=prompt_embeds.to(device,dtype),
+                            #prompt_embeds=prompt_embeds.to(device,torch.float16),
+                            negative_prompt_embeds=negative_prompt_embeds.to(device,dtype),
+                            #negative_prompt_embeds=negative_prompt_embeds.to(device,torch.float16),
+                            pooled_prompt_embeds=pooled_prompt_embeds.to(device,dtype),
+                            #pooled_prompt_embeds=pooled_prompt_embeds.to(device,torch.float16),
+                            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device,dtype),
+                            #negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device,torch.float16),
+                            num_inference_steps=denoise_steps,
+                            generator=generator,
+                            strength = 1.0,
+                            #pose_img = pose_img.to(device,torch.float16),
+                            pose_img = pose_img.to(device,dtype),
+                            #text_embeds_cloth=prompt_embeds_c.to(device,torch.float16),
+                            text_embeds_cloth=prompt_embeds_c.to(device,dtype),
+                            #cloth = garm_tensor.to(device,torch.float16),
+                            cloth = garm_tensor.to(device,dtype),
+                            mask_image=mask,
+                            image=human_img,
+                            height=1024,
+                            width=768,
+                            ip_adapter_image = garm_img.resize((768,1024)),
+                            guidance_scale=2.0,
+                        )[0]
+        else:
+            prompt = "model is wearing " + garment_des
+            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+            with torch.inference_mode():
+                (
+                    prompt_embeds,
+                    negative_prompt_embeds,
+                    pooled_prompt_embeds,
+                    negative_pooled_prompt_embeds,
+                ) = pipe.encode_prompt(
+                    prompt,
+                    num_images_per_prompt=1,
+                    do_classifier_free_guidance=True,
+                    negative_prompt=negative_prompt,
+                )
+
+                prompt = "a photo of " + garment_des
+                negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+                if not isinstance(prompt, List):
+                    prompt = [prompt] * 1
+                if not isinstance(negative_prompt, List):
+                    negative_prompt = [negative_prompt] * 1
+                with torch.inference_mode():
+                    (
+                        prompt_embeds_c,
+                        _,
+                        _,
+                        _,
+                    ) = pipe.encode_prompt(
+                        prompt,
+                        num_images_per_prompt=1,
+                        do_classifier_free_guidance=False,
+                        negative_prompt=negative_prompt,
+                    )
+
+                #pose_img = tensor_transfrom(pose_img).unsqueeze(0).to(device,torch.float16)
+                pose_img = tensor_transfrom(pose_img).unsqueeze(0).to(device,dtype)
+                #garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device,torch.float16)
+                garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device,dtype)
+                generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
+                images = pipe(
+                    prompt_embeds=prompt_embeds.to(device,dtype),
+                    #prompt_embeds=prompt_embeds.to(device,torch.float16),
+                    negative_prompt_embeds=negative_prompt_embeds.to(device,dtype),
+                    #negative_prompt_embeds=negative_prompt_embeds.to(device,torch.float16),
+                    pooled_prompt_embeds=pooled_prompt_embeds.to(device,dtype),
+                    #pooled_prompt_embeds=pooled_prompt_embeds.to(device,torch.float16),
+                    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device,dtype),
+                    #negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device,torch.float16),
+                    num_inference_steps=denoise_steps,
+                    generator=generator,
+                    strength = 1.0,
+                    #pose_img = pose_img.to(device,torch.float16),
+                    pose_img = pose_img.to(device,dtype),
+                    #text_embeds_cloth=prompt_embeds_c.to(device,torch.float16),
+                    text_embeds_cloth=prompt_embeds_c.to(device,dtype),
+                    #cloth = garm_tensor.to(device,torch.float16),
+                    cloth = garm_tensor.to(device,dtype),
+                    mask_image=mask,
+                    image=human_img,
+                    height=1024,
+                    width=768,
+                    ip_adapter_image = garm_img.resize((768,1024)),
+                    guidance_scale=2.0,
+                )[0]
 
     if is_checked_crop:
-        out_img = images[0].resize(crop_size)
-        human_img_orig.paste(out_img, (int(left), int(top)))
+        out_img = images[0].resize(crop_size)
+        human_img_orig.paste(out_img, (int(left), int(top)))
         return human_img_orig, mask_gray
     else:
         return images[0], mask_gray
@@ -311,8 +395,7 @@ with image_blocks as demo:
 
     try_button.click(fn=start_tryon, inputs=[imgs, garm_img, prompt, is_checked,is_checked_crop, denoise_steps, seed], outputs=[image_out,masked_img], api_name='tryon')
 
-
 
 
-image_blocks.launch()
 
+image_blocks.launch()
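
Note: the commit replaces every hard-coded torch_dtype=torch.float16 with the dtype returned by devicetorch.dtype(torch) and resolves the device through devicetorch.get(torch), so the same app.py can load on CUDA, Apple Silicon (MPS), or CPU. A minimal sketch of that kind of device/dtype selection follows; the helper names get_device and get_dtype and the exact fallback policy (half precision on GPU backends, float32 on CPU) are illustrative assumptions, not the devicetorch source.

    import torch

    def get_device(t=torch) -> str:
        # Prefer CUDA, then Apple Silicon (MPS), then fall back to CPU.
        if t.cuda.is_available():
            return "cuda"
        mps = getattr(t.backends, "mps", None)
        if mps is not None and mps.is_available():
            return "mps"
        return "cpu"

    def get_dtype(t=torch):
        # Assumed policy: half precision where a GPU backend is present,
        # full precision on CPU, where float16 kernels are poorly supported.
        return t.float16 if get_device(t) in ("cuda", "mps") else t.float32

    device = get_device()
    dtype = get_dtype()
    print(device, dtype)

Loading the UNet, text encoders, image encoder, VAE, and the TryonPipeline itself with this dtype keeps every later .to(device, dtype) cast in start_tryon() consistent with how the weights were loaded.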
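The rewritten start_tryon() applies the same idea to mixed precision: torch.cuda.amp.autocast() is entered only when the resolved device is "cuda", the full inference body is duplicated in an else branch for every other device, and the DensePose call is pinned to MODEL.DEVICE "cpu" unless CUDA is present. A shorter equivalent of that branching is sketched below, with a hypothetical run_inference() standing in for the pipe(...) call; it reuses a single body and swaps in contextlib.nullcontext() when autocast is unavailable, whereas the commit keeps two explicit branches.

    import contextlib
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"

    def run_inference():
        # Hypothetical stand-in for the pipe(...) call in app.py.
        x = torch.randn(4, 4, device=device)
        return x @ x

    # Enter autocast only on CUDA; elsewhere use a no-op context manager.
    autocast_ctx = torch.cuda.amp.autocast() if device == "cuda" else contextlib.nullcontext()

    with torch.no_grad(), autocast_ctx:
        out = run_inference()

    print(out.dtype)  # float16 under CUDA autocast, float32 otherwise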