Spaces:
Runtime error
Runtime error
enable zerogpu
Browse files
app.py
CHANGED
@@ -210,13 +210,13 @@ if NEW_MODEL:
|
|
210 |
learn_sigma=True,
|
211 |
).to(device)
|
212 |
# ckpt_state_dict = torch.load(model_path)['model_state_dict']
|
213 |
-
ckpt_state_dict = torch.load(model_path, map_location=
|
214 |
missing_keys, extra_keys = model.load_state_dict(ckpt_state_dict, strict=False)
|
215 |
model.eval()
|
216 |
print(missing_keys, extra_keys)
|
217 |
assert len(missing_keys) == 0
|
218 |
vae_state_dict = torch.load(vae_path)['state_dict']
|
219 |
-
autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False).to(device)
|
220 |
missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
|
221 |
autoencoder.eval()
|
222 |
assert len(missing_keys) == 0
|
@@ -243,7 +243,7 @@ else:
|
|
243 |
autoencoder.eval()
|
244 |
assert len(missing_keys) == 0 and len(extra_keys) == 0
|
245 |
sam_path = hf_hub_download(repo_id="Chaerin5/FoundHand-weights", filename="sam_vit_h_4b8939.pth", token=token)
|
246 |
-
sam_predictor = init_sam(ckpt_path=sam_path, device=
|
247 |
|
248 |
|
249 |
print("Mediapipe hand detector and SAM ready...")
|
@@ -254,7 +254,7 @@ hands = mp_hands.Hands(
|
|
254 |
min_detection_confidence=0.1,
|
255 |
)
|
256 |
|
257 |
-
|
258 |
def get_ref_anno(ref):
|
259 |
if ref is None:
|
260 |
return (
|
@@ -301,6 +301,7 @@ def get_ref_anno(ref):
|
|
301 |
elif keypts[21].sum() != 0:
|
302 |
input_point = np.array(keypts[21:22])
|
303 |
input_label = np.array([1])
|
|
|
304 |
masks, _, _ = sam_predictor.predict(
|
305 |
point_coords=input_point,
|
306 |
point_labels=input_label,
|
|
|
210 |
learn_sigma=True,
|
211 |
).to(device)
|
212 |
# ckpt_state_dict = torch.load(model_path)['model_state_dict']
|
213 |
+
ckpt_state_dict = torch.load(model_path, map_location='cpu')['ema_state_dict']
|
214 |
missing_keys, extra_keys = model.load_state_dict(ckpt_state_dict, strict=False)
|
215 |
model.eval()
|
216 |
print(missing_keys, extra_keys)
|
217 |
assert len(missing_keys) == 0
|
218 |
vae_state_dict = torch.load(vae_path)['state_dict']
|
219 |
+
autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False) # .to(device)
|
220 |
missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
|
221 |
autoencoder.eval()
|
222 |
assert len(missing_keys) == 0
|
|
|
243 |
autoencoder.eval()
|
244 |
assert len(missing_keys) == 0 and len(extra_keys) == 0
|
245 |
sam_path = hf_hub_download(repo_id="Chaerin5/FoundHand-weights", filename="sam_vit_h_4b8939.pth", token=token)
|
246 |
+
sam_predictor = init_sam(ckpt_path=sam_path, device='cpu')
|
247 |
|
248 |
|
249 |
print("Mediapipe hand detector and SAM ready...")
|
|
|
254 |
min_detection_confidence=0.1,
|
255 |
)
|
256 |
|
257 |
+
@spaces.GPU(duration=120)
|
258 |
def get_ref_anno(ref):
|
259 |
if ref is None:
|
260 |
return (
|
|
|
301 |
elif keypts[21].sum() != 0:
|
302 |
input_point = np.array(keypts[21:22])
|
303 |
input_label = np.array([1])
|
304 |
+
print("ready to run SAM")
|
305 |
masks, _, _ = sam_predictor.predict(
|
306 |
point_coords=input_point,
|
307 |
point_labels=input_label,
|