vinthony commited on
Commit
0a11942
1 Parent(s): 1cd9fe9

Update src/face3d/extract_kp_videos.py

Browse files
Files changed (1) hide show
  1. src/face3d/extract_kp_videos.py +3 -3
src/face3d/extract_kp_videos.py CHANGED
@@ -12,8 +12,8 @@ from itertools import cycle
12
  from torch.multiprocessing import Pool, Process, set_start_method
13
 
14
  class KeypointExtractor():
15
- def __init__(self):
16
- self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D)
17
 
18
  def extract_keypoint(self, images, name=None, info=True):
19
  if isinstance(images, list):
@@ -71,7 +71,7 @@ def read_video(filename):
71
  def run(data):
72
  filename, opt, device = data
73
  os.environ['CUDA_VISIBLE_DEVICES'] = device
74
- kp_extractor = KeypointExtractor()
75
  images = read_video(filename)
76
  name = filename.split('/')[-2:]
77
  os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
 
12
  from torch.multiprocessing import Pool, Process, set_start_method
13
 
14
  class KeypointExtractor():
15
+ def __init__(self, device):
16
+ self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, device=device)
17
 
18
  def extract_keypoint(self, images, name=None, info=True):
19
  if isinstance(images, list):
 
71
  def run(data):
72
  filename, opt, device = data
73
  os.environ['CUDA_VISIBLE_DEVICES'] = device
74
+ kp_extractor = KeypointExtractor(device)
75
  images = read_video(filename)
76
  name = filename.split('/')[-2:]
77
  os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)