Chaerin5 committed
Commit 32fa016 · 1 Parent(s): 0ae1eb4

fix vae nan bug

Files changed (1):
  app.py  +26 -22
app.py CHANGED
@@ -228,29 +228,32 @@ if NEW_MODEL:
     print(f"encoder after load_state_dict parameters max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
     autoencoder = autoencoder.to(device)
     autoencoder.eval()
+    print(f"encoder after eval() min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
+    print(f"encoder after eval() max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
+    print(f"autoencoder encoder after eval() dtype: {next(autoencoder.encoder.parameters()).dtype}")
     assert len(missing_keys) == 0
-else:
-    opts = HandDiffOpts()
-    model_path = './finetune_epoch=5-step=130000.ckpt'
-    sd_path = './sd-v1-4.ckpt'
-    print('Load diffusion model...')
-    diffusion = create_diffusion(str(opts.test_sampling_steps))
-    model = vit.DiT_XL_2(
-        input_size=opts.latent_size[0],
-        latent_dim=opts.latent_dim,
-        in_channels=opts.latent_dim+opts.n_keypoints+opts.n_mask,
-        learn_sigma=True,
-    ).to(device)
-    ckpt_state_dict = torch.load(model_path)['state_dict']
-    dit_state_dict = {remove_prefix(k, 'diffusion_backbone.'): v for k, v in ckpt_state_dict.items() if k.startswith('diffusion_backbone')}
-    vae_state_dict = {remove_prefix(k, 'autoencoder.'): v for k, v in ckpt_state_dict.items() if k.startswith('autoencoder')}
-    missing_keys, extra_keys = model.load_state_dict(dit_state_dict, strict=False)
-    model.eval()
-    assert len(missing_keys) == 0 and len(extra_keys) == 0
-    autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False).to(device)
-    missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
-    autoencoder.eval()
-    assert len(missing_keys) == 0 and len(extra_keys) == 0
+# else:
+#     opts = HandDiffOpts()
+#     model_path = './finetune_epoch=5-step=130000.ckpt'
+#     sd_path = './sd-v1-4.ckpt'
+#     print('Load diffusion model...')
+#     diffusion = create_diffusion(str(opts.test_sampling_steps))
+#     model = vit.DiT_XL_2(
+#         input_size=opts.latent_size[0],
+#         latent_dim=opts.latent_dim,
+#         in_channels=opts.latent_dim+opts.n_keypoints+opts.n_mask,
+#         learn_sigma=True,
+#     ).to(device)
+#     ckpt_state_dict = torch.load(model_path)['state_dict']
+#     dit_state_dict = {remove_prefix(k, 'diffusion_backbone.'): v for k, v in ckpt_state_dict.items() if k.startswith('diffusion_backbone')}
+#     vae_state_dict = {remove_prefix(k, 'autoencoder.'): v for k, v in ckpt_state_dict.items() if k.startswith('autoencoder')}
+#     missing_keys, extra_keys = model.load_state_dict(dit_state_dict, strict=False)
+#     model.eval()
+#     assert len(missing_keys) == 0 and len(extra_keys) == 0
+#     autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False).to(device)
+#     missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
+#     autoencoder.eval()
+#     assert len(missing_keys) == 0 and len(extra_keys) == 0
 sam_path = hf_hub_download(repo_id="Chaerin5/FoundHand-weights", filename="sam_vit_h_4b8939.pth", token=token)
 sam_predictor = init_sam(ckpt_path=sam_path, device='cpu')
@@ -492,6 +495,7 @@ def get_ref_anno(ref):
     print(f"opts.latent_scaling_factor: {opts.latent_scaling_factor}")
     print(f"autoencoder encoder before operating min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
     print(f"autoencoder encoder before operating max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
+    print(f"autoencoder encoder before operating dtype: {next(autoencoder.encoder.parameters()).dtype}")
     latent = opts.latent_scaling_factor * autoencoder.encode(image).sample()
     print(f"latent.max(): {latent.max()}, latent.min(): {latent.min()}")
     if not REF_POSE_MASK:
 
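The min/max/dtype prints added in the loader hunk are one-off probes for the NaN hunt. A reusable version of the same check might look like the sketch below; audit_parameters is a hypothetical helper, not part of app.py, and it assumes a standard torch.nn.Module.

import torch
import torch.nn as nn

def audit_parameters(module: nn.Module, tag: str = "module") -> None:
    # Print dtype and value range of each parameter; flag NaN/Inf entries.
    for name, p in module.named_parameters():
        n_bad = (~torch.isfinite(p)).sum().item()
        if n_bad:
            print(f"[{tag}] {name}: {n_bad} non-finite values, dtype={p.dtype}")
        else:
            print(f"[{tag}] {name}: min={p.min().item():.4g}, max={p.max().item():.4g}, dtype={p.dtype}")

# Usage mirroring the commit's prints, e.g.:
# audit_parameters(autoencoder.encoder, tag="vae.encoder")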
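The now-commented branch loaded a single checkpoint and split it into per-module state dicts by key prefix. A minimal sketch of that splitting pattern, with split_by_prefix standing in for the repo's remove_prefix helper:

import torch

def split_by_prefix(state_dict: dict, prefix: str) -> dict:
    # Keep only keys under `prefix` and strip the prefix from each key,
    # as the dict comprehensions in the disabled branch do via remove_prefix.
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}

# e.g. (paths as in the disabled branch):
# ckpt = torch.load('./finetune_epoch=5-step=130000.ckpt', map_location='cpu')['state_dict']
# dit_state_dict = split_by_prefix(ckpt, 'diffusion_backbone.')
# vae_state_dict = split_by_prefix(ckpt, 'autoencoder.')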
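The dtype probe added in get_ref_anno points at the encode step as the suspected source of the NaN latents. One way to make that failure explicit is to wrap the step with a finiteness check; the wrapper below is a sketch, and the float32 cast is an assumption about avoiding half-precision overflow, not a fix taken from this commit.

import torch

@torch.no_grad()
def encode_latent_checked(autoencoder, image: torch.Tensor, scaling_factor: float) -> torch.Tensor:
    # Same computation as `latent = opts.latent_scaling_factor * autoencoder.encode(image).sample()`,
    # plus an explicit non-finite check so NaNs fail loudly instead of propagating.
    latent = scaling_factor * autoencoder.encode(image.float()).sample()
    if not torch.isfinite(latent).all():
        raise RuntimeError(f"non-finite latent: max={latent.max().item()}, min={latent.min().item()}")
    return latent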