hysts HF staff commited on
Commit
86f1fe1
·
1 Parent(s): 384ccc2
Files changed (1) hide show
  1. app.py +30 -2
app.py CHANGED
@@ -6,12 +6,12 @@ import os
6
  import pathlib
7
  import sys
8
  import urllib.request
9
- from typing import Union
10
 
11
  import cv2
12
  import gradio as gr
13
  import numpy as np
14
  import torch
 
15
 
16
  sys.path.insert(0, "face_detection")
17
 
@@ -20,7 +20,35 @@ from ibug.face_detection import RetinaFacePredictor, S3FDPredictor
20
  DESCRIPTION = "# [ibug-group/face_detection](https://github.com/ibug-group/face_detection)"
21
 
22
 
23
- def load_model(model_name: str, threshold: float, device: torch.device) -> Union[RetinaFacePredictor, S3FDPredictor]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  if model_name == "s3fd":
25
  model = S3FDPredictor(threshold=threshold, device=device)
26
  else:
 
6
  import pathlib
7
  import sys
8
  import urllib.request
 
9
 
10
  import cv2
11
  import gradio as gr
12
  import numpy as np
13
  import torch
14
+ from huggingface_hub import hf_hub_download
15
 
16
  sys.path.insert(0, "face_detection")
17
 
 
20
  DESCRIPTION = "# [ibug-group/face_detection](https://github.com/ibug-group/face_detection)"
21
 
22
 
23
+ def is_lfs_pointer_file(path: pathlib.Path) -> bool:
24
+ try:
25
+ with open(path, "r") as f:
26
+ # Git LFS pointer files usually start with version line
27
+ version_line = f.readline()
28
+ if version_line.startswith("version https://git-lfs.github.com/spec/"):
29
+ # Check for the presence of oid and size lines
30
+ oid_line = f.readline()
31
+ size_line = f.readline()
32
+ if oid_line.startswith("oid sha256:") and size_line.startswith("size "):
33
+ return True
34
+ except Exception as e:
35
+ print(f"Error reading file {path}: {e}")
36
+ return False
37
+
38
+
39
+ lfs_model_path = pathlib.Path("face_detection/ibug/face_detection/retina_face/weights/Resnet50_Final.pth")
40
+ if is_lfs_pointer_file(lfs_model_path):
41
+ os.remove(lfs_model_path)
42
+ out_path = hf_hub_download(
43
+ "public-data/ibug-face-detection",
44
+ filename=lfs_model_path.name,
45
+ repo_type="model",
46
+ subfolder="retina_face",
47
+ )
48
+ os.symlink(out_path, lfs_model_path)
49
+
50
+
51
+ def load_model(model_name: str, threshold: float, device: torch.device) -> RetinaFacePredictor | S3FDPredictor:
52
  if model_name == "s3fd":
53
  model = S3FDPredictor(threshold=threshold, device=device)
54
  else: