Chaerin5 commited on
Commit
735c5d1
·
1 Parent(s): 349b8de

enable zerogpu

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -210,13 +210,13 @@ if NEW_MODEL:
210
  learn_sigma=True,
211
  ).to(device)
212
  # ckpt_state_dict = torch.load(model_path)['model_state_dict']
213
- ckpt_state_dict = torch.load(model_path, map_location=torch.device(device))['ema_state_dict']
214
  missing_keys, extra_keys = model.load_state_dict(ckpt_state_dict, strict=False)
215
  model.eval()
216
  print(missing_keys, extra_keys)
217
  assert len(missing_keys) == 0
218
  vae_state_dict = torch.load(vae_path)['state_dict']
219
- autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False).to(device)
220
  missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
221
  autoencoder.eval()
222
  assert len(missing_keys) == 0
@@ -243,7 +243,7 @@ else:
243
  autoencoder.eval()
244
  assert len(missing_keys) == 0 and len(extra_keys) == 0
245
  sam_path = hf_hub_download(repo_id="Chaerin5/FoundHand-weights", filename="sam_vit_h_4b8939.pth", token=token)
246
- sam_predictor = init_sam(ckpt_path=sam_path, device=device)
247
 
248
 
249
  print("Mediapipe hand detector and SAM ready...")
@@ -254,7 +254,7 @@ hands = mp_hands.Hands(
254
  min_detection_confidence=0.1,
255
  )
256
 
257
- # @spaces.GPU(duration=120)
258
  def get_ref_anno(ref):
259
  if ref is None:
260
  return (
@@ -301,6 +301,7 @@ def get_ref_anno(ref):
301
  elif keypts[21].sum() != 0:
302
  input_point = np.array(keypts[21:22])
303
  input_label = np.array([1])
 
304
  masks, _, _ = sam_predictor.predict(
305
  point_coords=input_point,
306
  point_labels=input_label,
 
210
  learn_sigma=True,
211
  ).to(device)
212
  # ckpt_state_dict = torch.load(model_path)['model_state_dict']
213
+ ckpt_state_dict = torch.load(model_path, map_location='cpu')['ema_state_dict']
214
  missing_keys, extra_keys = model.load_state_dict(ckpt_state_dict, strict=False)
215
  model.eval()
216
  print(missing_keys, extra_keys)
217
  assert len(missing_keys) == 0
218
  vae_state_dict = torch.load(vae_path)['state_dict']
219
+ autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False) # .to(device)
220
  missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
221
  autoencoder.eval()
222
  assert len(missing_keys) == 0
 
243
  autoencoder.eval()
244
  assert len(missing_keys) == 0 and len(extra_keys) == 0
245
  sam_path = hf_hub_download(repo_id="Chaerin5/FoundHand-weights", filename="sam_vit_h_4b8939.pth", token=token)
246
+ sam_predictor = init_sam(ckpt_path=sam_path, device='cpu')
247
 
248
 
249
  print("Mediapipe hand detector and SAM ready...")
 
254
  min_detection_confidence=0.1,
255
  )
256
 
257
+ @spaces.GPU(duration=120)
258
  def get_ref_anno(ref):
259
  if ref is None:
260
  return (
 
301
  elif keypts[21].sum() != 0:
302
  input_point = np.array(keypts[21:22])
303
  input_label = np.array([1])
304
+ print("ready to run SAM")
305
  masks, _, _ = sam_predictor.predict(
306
  point_coords=input_point,
307
  point_labels=input_label,