Spaces:
Running
on
A10G
Running
on
A10G
Update src/face3d/extract_kp_videos.py
Browse files
src/face3d/extract_kp_videos.py
CHANGED
@@ -12,8 +12,8 @@ from itertools import cycle
|
|
12 |
from torch.multiprocessing import Pool, Process, set_start_method
|
13 |
|
14 |
class KeypointExtractor():
|
15 |
-
def __init__(self):
|
16 |
-
self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D)
|
17 |
|
18 |
def extract_keypoint(self, images, name=None, info=True):
|
19 |
if isinstance(images, list):
|
@@ -71,7 +71,7 @@ def read_video(filename):
|
|
71 |
def run(data):
|
72 |
filename, opt, device = data
|
73 |
os.environ['CUDA_VISIBLE_DEVICES'] = device
|
74 |
-
kp_extractor = KeypointExtractor()
|
75 |
images = read_video(filename)
|
76 |
name = filename.split('/')[-2:]
|
77 |
os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
|
|
|
12 |
from torch.multiprocessing import Pool, Process, set_start_method
|
13 |
|
14 |
class KeypointExtractor():
|
15 |
+
def __init__(self, device):
|
16 |
+
self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, device=device)
|
17 |
|
18 |
def extract_keypoint(self, images, name=None, info=True):
|
19 |
if isinstance(images, list):
|
|
|
71 |
def run(data):
|
72 |
filename, opt, device = data
|
73 |
os.environ['CUDA_VISIBLE_DEVICES'] = device
|
74 |
+
kp_extractor = KeypointExtractor(device)
|
75 |
images = read_video(filename)
|
76 |
name = filename.split('/')[-2:]
|
77 |
os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
|