Spanicin committed
Commit 5c012bf · verified · 1 parent: 1569822

Upload 77 files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. videoretalking/third_part/face3d/checkpoints/model_name/test_opt.txt +34 -0
  2. videoretalking/third_part/face3d/coeff_detector.py +118 -0
  3. videoretalking/third_part/face3d/data/__init__.py +116 -0
  4. videoretalking/third_part/face3d/data/base_dataset.py +125 -0
  5. videoretalking/third_part/face3d/data/flist_dataset.py +125 -0
  6. videoretalking/third_part/face3d/data/image_folder.py +66 -0
  7. videoretalking/third_part/face3d/data/template_dataset.py +75 -0
  8. videoretalking/third_part/face3d/data_preparation.py +45 -0
  9. videoretalking/third_part/face3d/extract_kp_videos.py +109 -0
  10. videoretalking/third_part/face3d/face_recon_videos.py +157 -0
  11. videoretalking/third_part/face3d/models/__init__.py +67 -0
  12. videoretalking/third_part/face3d/models/arcface_torch/README.md +164 -0
  13. videoretalking/third_part/face3d/models/arcface_torch/backbones/__init__.py +25 -0
  14. videoretalking/third_part/face3d/models/arcface_torch/backbones/iresnet.py +187 -0
  15. videoretalking/third_part/face3d/models/arcface_torch/backbones/iresnet2060.py +176 -0
  16. videoretalking/third_part/face3d/models/arcface_torch/backbones/mobilefacenet.py +130 -0
  17. videoretalking/third_part/face3d/models/arcface_torch/configs/3millions.py +23 -0
  18. videoretalking/third_part/face3d/models/arcface_torch/configs/3millions_pfc.py +23 -0
  19. videoretalking/third_part/face3d/models/arcface_torch/configs/__init__.py +0 -0
  20. videoretalking/third_part/face3d/models/arcface_torch/configs/base.py +56 -0
  21. videoretalking/third_part/face3d/models/arcface_torch/configs/glint360k_mbf.py +26 -0
  22. videoretalking/third_part/face3d/models/arcface_torch/configs/glint360k_r100.py +26 -0
  23. videoretalking/third_part/face3d/models/arcface_torch/configs/glint360k_r18.py +26 -0
  24. videoretalking/third_part/face3d/models/arcface_torch/configs/glint360k_r34.py +26 -0
  25. videoretalking/third_part/face3d/models/arcface_torch/configs/glint360k_r50.py +26 -0
  26. videoretalking/third_part/face3d/models/arcface_torch/configs/ms1mv3_mbf.py +26 -0
  27. videoretalking/third_part/face3d/models/arcface_torch/configs/ms1mv3_r18.py +26 -0
  28. videoretalking/third_part/face3d/models/arcface_torch/configs/ms1mv3_r2060.py +26 -0
  29. videoretalking/third_part/face3d/models/arcface_torch/configs/ms1mv3_r34.py +26 -0
  30. videoretalking/third_part/face3d/models/arcface_torch/configs/ms1mv3_r50.py +26 -0
  31. videoretalking/third_part/face3d/models/arcface_torch/configs/speed.py +23 -0
  32. videoretalking/third_part/face3d/models/arcface_torch/dataset.py +124 -0
  33. videoretalking/third_part/face3d/models/arcface_torch/docs/eval.md +31 -0
  34. videoretalking/third_part/face3d/models/arcface_torch/docs/install.md +51 -0
  35. videoretalking/third_part/face3d/models/arcface_torch/docs/modelzoo.md +0 -0
  36. videoretalking/third_part/face3d/models/arcface_torch/docs/speed_benchmark.md +93 -0
  37. videoretalking/third_part/face3d/models/arcface_torch/eval/__init__.py +0 -0
  38. videoretalking/third_part/face3d/models/arcface_torch/eval/verification.py +407 -0
  39. videoretalking/third_part/face3d/models/arcface_torch/eval_ijbc.py +483 -0
  40. videoretalking/third_part/face3d/models/arcface_torch/inference.py +35 -0
  41. videoretalking/third_part/face3d/models/arcface_torch/losses.py +42 -0
  42. videoretalking/third_part/face3d/models/arcface_torch/onnx_helper.py +250 -0
  43. videoretalking/third_part/face3d/models/arcface_torch/onnx_ijbc.py +267 -0
  44. videoretalking/third_part/face3d/models/arcface_torch/partial_fc.py +222 -0
  45. videoretalking/third_part/face3d/models/arcface_torch/requirement.txt +5 -0
  46. videoretalking/third_part/face3d/models/arcface_torch/run.sh +2 -0
  47. videoretalking/third_part/face3d/models/arcface_torch/torch2onnx.py +59 -0
  48. videoretalking/third_part/face3d/models/arcface_torch/train.py +141 -0
  49. videoretalking/third_part/face3d/models/arcface_torch/utils/__init__.py +0 -0
  50. videoretalking/third_part/face3d/models/arcface_torch/utils/plot.py +72 -0
videoretalking/third_part/face3d/checkpoints/model_name/test_opt.txt ADDED
@@ -0,0 +1,34 @@
+ ----------------- Options ---------------
+ add_image: True
+ bfm_folder: BFM
+ bfm_model: BFM_model_front.mat
+ camera_d: 10.0
+ center: 112.0
+ checkpoints_dir: ./checkpoints
+ dataset_mode: None
+ ddp_port: 12355
+ display_per_batch: True
+ epoch: 20 [default: latest]
+ eval_batch_nums: inf
+ focal: 1015.0
+ gpu_ids: 0
+ inference_batch_size: 8
+ init_path: checkpoints/init_model/resnet50-0676ba61.pth
+ input_dir: demo_video [default: None]
+ isTrain: False [default: None]
+ keypoint_dir: demo_cctv [default: None]
+ model: facerecon
+ name: model_name [default: face_recon]
+ net_recon: resnet50
+ output_dir: demo_cctv [default: mp4]
+ phase: test
+ save_split_files: False
+ suffix:
+ use_ddp: False [default: True]
+ use_last_fc: False
+ verbose: False
+ vis_batch_nums: 1
+ world_size: 1
+ z_far: 15.0
+ z_near: 5.0
+ ----------------- End -------------------
videoretalking/third_part/face3d/coeff_detector.py ADDED
@@ -0,0 +1,118 @@
1
+ import os
2
+ import glob
3
+ import numpy as np
4
+ from os import makedirs
5
+ from PIL import Image
6
+ from tqdm import tqdm
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from face3d.options.inference_options import InferenceOptions
12
+ from face3d.models import create_model
13
+ from face3d.util.preprocess import align_img
14
+ from face3d.util.load_mats import load_lm3d
15
+ from face3d.extract_kp_videos import KeypointExtractor
16
+
17
+
18
+ class CoeffDetector(nn.Module):
19
+ def __init__(self, opt):
20
+ super().__init__()
21
+
22
+ self.model = create_model(opt)
23
+ self.model.setup(opt)
24
+ self.model.device = 'cuda'
25
+ self.model.parallelize()
26
+ self.model.eval()
27
+
28
+ self.lm3d_std = load_lm3d(opt.bfm_folder)
29
+
30
+ def forward(self, img, lm):
31
+
32
+ img, trans_params = self.image_transform(img, lm)
33
+
34
+ data_input = {
35
+ 'imgs': img[None],
36
+ }
37
+ self.model.set_input(data_input)
38
+ self.model.test()
39
+ pred_coeff = {key:self.model.pred_coeffs_dict[key].cpu().numpy() for key in self.model.pred_coeffs_dict}
40
+ pred_coeff = np.concatenate([
41
+ pred_coeff['id'],
42
+ pred_coeff['exp'],
43
+ pred_coeff['tex'],
44
+ pred_coeff['angle'],
45
+ pred_coeff['gamma'],
46
+ pred_coeff['trans'],
47
+ trans_params[None],
48
+ ], 1)
49
+
50
+ return {'coeff_3dmm':pred_coeff,
51
+ 'crop_img': Image.fromarray((img.cpu().permute(1, 2, 0).numpy()*255).astype(np.uint8))}
52
+
53
+ def image_transform(self, images, lm):
54
+ """
55
+ param:
56
+ images: -- PIL image
57
+ lm: -- numpy array
58
+ """
59
+ W,H = images.size
60
+ if np.mean(lm) == -1:
61
+ lm = (self.lm3d_std[:, :2]+1)/2.
62
+ lm = np.concatenate(
63
+ [lm[:, :1]*W, lm[:, 1:2]*H], 1
64
+ )
65
+ else:
66
+ lm[:, -1] = H - 1 - lm[:, -1]
67
+
68
+ trans_params, img, lm, _ = align_img(images, lm, self.lm3d_std)
69
+ img = torch.tensor(np.array(img)/255., dtype=torch.float32).permute(2, 0, 1)
70
+ trans_params = np.array([float(item) for item in np.hsplit(trans_params, 5)])
71
+ trans_params = torch.tensor(trans_params.astype(np.float32))
72
+ return img, trans_params
73
+
74
+ def get_data_path(root, keypoint_root):
75
+ filenames = list()
76
+ keypoint_filenames = list()
77
+
78
+ IMAGE_EXTENSIONS_LOWERCASE = {'jpg', 'png', 'jpeg', 'webp'}
79
+ IMAGE_EXTENSIONS = IMAGE_EXTENSIONS_LOWERCASE.union({f.upper() for f in IMAGE_EXTENSIONS_LOWERCASE})
80
+ extensions = IMAGE_EXTENSIONS
81
+
82
+ for ext in extensions:
83
+ filenames += glob.glob(f'{root}/*.{ext}', recursive=True)
84
+ filenames = sorted(filenames)
85
+ for filename in filenames:
86
+ name = os.path.splitext(os.path.basename(filename))[0]
87
+ keypoint_filenames.append(
88
+ os.path.join(keypoint_root, name + '.txt')
89
+ )
90
+ return filenames, keypoint_filenames
91
+
92
+
93
+ if __name__ == "__main__":
94
+ opt = InferenceOptions().parse()
95
+ coeff_detector = CoeffDetector(opt)
96
+ kp_extractor = KeypointExtractor()
97
+ image_names, keypoint_names = get_data_path(opt.input_dir, opt.keypoint_dir)
98
+ makedirs(opt.keypoint_dir, exist_ok=True)
99
+ makedirs(opt.output_dir, exist_ok=True)
100
+
101
+ for image_name, keypoint_name in tqdm(zip(image_names, keypoint_names)):
102
+ image = Image.open(image_name)
103
+ if not os.path.isfile(keypoint_name):
104
+ lm = kp_extractor.extract_keypoint(image, keypoint_name)
105
+ else:
106
+ lm = np.loadtxt(keypoint_name).astype(np.float32)
107
+ lm = lm.reshape([-1, 2])
108
+ predicted = coeff_detector(image, lm)
109
+ name = os.path.splitext(os.path.basename(image_name))[0]
110
+ np.savetxt(
111
+ "{}/{}_3dmm_coeff.txt".format(opt.output_dir, name),
112
+ predicted['coeff_3dmm'].reshape(-1))
113
+
114
+
115
+
116
+
117
+
118
+
videoretalking/third_part/face3d/data/__init__.py ADDED
@@ -0,0 +1,116 @@
1
+ """This package includes all the modules related to data loading and preprocessing
2
+
3
+ To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
4
+ You need to implement four functions:
5
+ -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
6
+ -- <__len__>: return the size of dataset.
7
+ -- <__getitem__>: get a data point from data loader.
8
+ -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
9
+
10
+ Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
11
+ See our template dataset class 'template_dataset.py' for more details.
12
+ """
13
+ import numpy as np
14
+ import importlib
15
+ import torch.utils.data
16
+ from face3d.data.base_dataset import BaseDataset
17
+
18
+
19
+ def find_dataset_using_name(dataset_name):
20
+ """Import the module "data/[dataset_name]_dataset.py".
21
+
22
+ In the file, the class called DatasetNameDataset() will
23
+ be instantiated. It has to be a subclass of BaseDataset,
24
+ and it is case-insensitive.
25
+ """
26
+ dataset_filename = "data." + dataset_name + "_dataset"
27
+ datasetlib = importlib.import_module(dataset_filename)
28
+
29
+ dataset = None
30
+ target_dataset_name = dataset_name.replace('_', '') + 'dataset'
31
+ for name, cls in datasetlib.__dict__.items():
32
+ if name.lower() == target_dataset_name.lower() \
33
+ and issubclass(cls, BaseDataset):
34
+ dataset = cls
35
+
36
+ if dataset is None:
37
+ raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
38
+
39
+ return dataset
40
+
41
+
42
+ def get_option_setter(dataset_name):
43
+ """Return the static method <modify_commandline_options> of the dataset class."""
44
+ dataset_class = find_dataset_using_name(dataset_name)
45
+ return dataset_class.modify_commandline_options
46
+
47
+
48
+ def create_dataset(opt, rank=0):
49
+ """Create a dataset given the option.
50
+
51
+ This function wraps the class CustomDatasetDataLoader.
52
+ This is the main interface between this package and 'train.py'/'test.py'
53
+
54
+ Example:
55
+ >>> from data import create_dataset
56
+ >>> dataset = create_dataset(opt)
57
+ """
58
+ data_loader = CustomDatasetDataLoader(opt, rank=rank)
59
+ dataset = data_loader.load_data()
60
+ return dataset
61
+
62
+ class CustomDatasetDataLoader():
63
+ """Wrapper class of Dataset class that performs multi-threaded data loading"""
64
+
65
+ def __init__(self, opt, rank=0):
66
+ """Initialize this class
67
+
68
+ Step 1: create a dataset instance given the name [dataset_mode]
69
+ Step 2: create a multi-threaded data loader.
70
+ """
71
+ self.opt = opt
72
+ dataset_class = find_dataset_using_name(opt.dataset_mode)
73
+ self.dataset = dataset_class(opt)
74
+ self.sampler = None
75
+ print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__))
76
+ if opt.use_ddp and opt.isTrain:
77
+ world_size = opt.world_size
78
+ self.sampler = torch.utils.data.distributed.DistributedSampler(
79
+ self.dataset,
80
+ num_replicas=world_size,
81
+ rank=rank,
82
+ shuffle=not opt.serial_batches
83
+ )
84
+ self.dataloader = torch.utils.data.DataLoader(
85
+ self.dataset,
86
+ sampler=self.sampler,
87
+ num_workers=int(opt.num_threads / world_size),
88
+ batch_size=int(opt.batch_size / world_size),
89
+ drop_last=True)
90
+ else:
91
+ self.dataloader = torch.utils.data.DataLoader(
92
+ self.dataset,
93
+ batch_size=opt.batch_size,
94
+ shuffle=(not opt.serial_batches) and opt.isTrain,
95
+ num_workers=int(opt.num_threads),
96
+ drop_last=True
97
+ )
98
+
99
+ def set_epoch(self, epoch):
100
+ self.dataset.current_epoch = epoch
101
+ if self.sampler is not None:
102
+ self.sampler.set_epoch(epoch)
103
+
104
+ def load_data(self):
105
+ return self
106
+
107
+ def __len__(self):
108
+ """Return the number of data in the dataset"""
109
+ return min(len(self.dataset), self.opt.max_dataset_size)
110
+
111
+ def __iter__(self):
112
+ """Return a batch of data"""
113
+ for i, data in enumerate(self.dataloader):
114
+ if i * self.opt.batch_size >= self.opt.max_dataset_size:
115
+ break
116
+ yield data
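As a point of reference for the package docstring above, a minimal 'dummy' dataset might look like the sketch below. The file name, option name, and tensor shapes are hypothetical, and the import path assumes the module is resolvable the same way `find_dataset_using_name` resolves the other datasets in this package; this is an illustration, not part of the upload.

```python
# Hypothetical data/dummy_dataset.py -- illustrative sketch, not part of this commit.
import torch
from face3d.data.base_dataset import BaseDataset


class DummyDataset(BaseDataset):
    """Minimal dataset selectable with '--dataset_mode dummy'."""

    @staticmethod
    def modify_commandline_options(parser, is_train):
        # Optional dataset-specific flag (name is made up for this example).
        parser.add_argument('--dummy_size', type=int, default=8, help='number of synthetic samples')
        return parser

    def __init__(self, opt):
        BaseDataset.__init__(self, opt)
        self.size = getattr(opt, 'dummy_size', 8)

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        # Return a dictionary of named tensors/metadata, mirroring flist_dataset.py.
        return {'imgs': torch.zeros(3, 224, 224), 'im_paths': f'dummy_{index:05d}.png'}
```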
videoretalking/third_part/face3d/data/base_dataset.py ADDED
@@ -0,0 +1,125 @@
1
+ """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
2
+
3
+ It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
4
+ """
5
+ import random
6
+ import numpy as np
7
+ import torch.utils.data as data
8
+ from PIL import Image
9
+ import torchvision.transforms as transforms
10
+ from abc import ABC, abstractmethod
11
+
12
+
13
+ class BaseDataset(data.Dataset, ABC):
14
+ """This class is an abstract base class (ABC) for datasets.
15
+
16
+ To create a subclass, you need to implement the following four functions:
17
+ -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
18
+ -- <__len__>: return the size of dataset.
19
+ -- <__getitem__>: get a data point.
20
+ -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
21
+ """
22
+
23
+ def __init__(self, opt):
24
+ """Initialize the class; save the options in the class
25
+
26
+ Parameters:
27
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
28
+ """
29
+ self.opt = opt
30
+ # self.root = opt.dataroot
31
+ self.current_epoch = 0
32
+
33
+ @staticmethod
34
+ def modify_commandline_options(parser, is_train):
35
+ """Add new dataset-specific options, and rewrite default values for existing options.
36
+
37
+ Parameters:
38
+ parser -- original option parser
39
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
40
+
41
+ Returns:
42
+ the modified parser.
43
+ """
44
+ return parser
45
+
46
+ @abstractmethod
47
+ def __len__(self):
48
+ """Return the total number of images in the dataset."""
49
+ return 0
50
+
51
+ @abstractmethod
52
+ def __getitem__(self, index):
53
+ """Return a data point and its metadata information.
54
+
55
+ Parameters:
56
+ index -- a random integer for data indexing
57
+
58
+ Returns:
59
+ a dictionary of data with their names. It usually contains the data itself and its metadata information.
60
+ """
61
+ pass
62
+
63
+
64
+ def get_transform(grayscale=False):
65
+ transform_list = []
66
+ if grayscale:
67
+ transform_list.append(transforms.Grayscale(1))
68
+ transform_list += [transforms.ToTensor()]
69
+ return transforms.Compose(transform_list)
70
+
71
+ def get_affine_mat(opt, size):
72
+ shift_x, shift_y, scale, rot_angle, flip = 0., 0., 1., 0., False
73
+ w, h = size
74
+
75
+ if 'shift' in opt.preprocess:
76
+ shift_pixs = int(opt.shift_pixs)
77
+ shift_x = random.randint(-shift_pixs, shift_pixs)
78
+ shift_y = random.randint(-shift_pixs, shift_pixs)
79
+ if 'scale' in opt.preprocess:
80
+ scale = 1 + opt.scale_delta * (2 * random.random() - 1)
81
+ if 'rot' in opt.preprocess:
82
+ rot_angle = opt.rot_angle * (2 * random.random() - 1)
83
+ rot_rad = -rot_angle * np.pi/180
84
+ if 'flip' in opt.preprocess:
85
+ flip = random.random() > 0.5
86
+
87
+ shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3])
88
+ flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3])
89
+ shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3])
90
+ rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3])
91
+ scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3])
92
+ shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3])
93
+
94
+ affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin
95
+ affine_inv = np.linalg.inv(affine)
96
+ return affine, affine_inv, flip
97
+
98
+ def apply_img_affine(img, affine_inv, method=Image.BICUBIC):
99
+ return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=Image.BICUBIC)
100
+
101
+ def apply_lm_affine(landmark, affine, flip, size):
102
+ _, h = size
103
+ lm = landmark.copy()
104
+ lm[:, 1] = h - 1 - lm[:, 1]
105
+ lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1)
106
+ lm = lm @ np.transpose(affine)
107
+ lm[:, :2] = lm[:, :2] / lm[:, 2:]
108
+ lm = lm[:, :2]
109
+ lm[:, 1] = h - 1 - lm[:, 1]
110
+ if flip:
111
+ lm_ = lm.copy()
112
+ lm_[:17] = lm[16::-1]
113
+ lm_[17:22] = lm[26:21:-1]
114
+ lm_[22:27] = lm[21:16:-1]
115
+ lm_[31:36] = lm[35:30:-1]
116
+ lm_[36:40] = lm[45:41:-1]
117
+ lm_[40:42] = lm[47:45:-1]
118
+ lm_[42:46] = lm[39:35:-1]
119
+ lm_[46:48] = lm[41:39:-1]
120
+ lm_[48:55] = lm[54:47:-1]
121
+ lm_[55:60] = lm[59:54:-1]
122
+ lm_[60:65] = lm[64:59:-1]
123
+ lm_[65:68] = lm[67:64:-1]
124
+ lm = lm_
125
+ return lm
videoretalking/third_part/face3d/data/flist_dataset.py ADDED
@@ -0,0 +1,125 @@
1
+ """This script defines the custom dataset for Deep3DFaceRecon_pytorch
2
+ """
3
+
4
+ import os.path
5
+ from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine
6
+ from data.image_folder import make_dataset
7
+ from PIL import Image
8
+ import random
9
+ import util.util as util
10
+ import numpy as np
11
+ import json
12
+ import torch
13
+ from scipy.io import loadmat, savemat
14
+ import pickle
15
+ from util.preprocess import align_img, estimate_norm
16
+ from util.load_mats import load_lm3d
17
+
18
+
19
+ def default_flist_reader(flist):
20
+ """
21
+ flist format: impath label\nimpath label\n ...(same to caffe's filelist)
22
+ """
23
+ imlist = []
24
+ with open(flist, 'r') as rf:
25
+ for line in rf.readlines():
26
+ impath = line.strip()
27
+ imlist.append(impath)
28
+
29
+ return imlist
30
+
31
+ def jason_flist_reader(flist):
32
+ with open(flist, 'r') as fp:
33
+ info = json.load(fp)
34
+ return info
35
+
36
+ def parse_label(label):
37
+ return torch.tensor(np.array(label).astype(np.float32))
38
+
39
+
40
+ class FlistDataset(BaseDataset):
41
+ """
42
+ It requires one directory to host training images, e.g. '/path/to/data/train'.
43
+ You can train the model with the dataset flag '--dataroot /path/to/data'.
44
+ """
45
+
46
+ def __init__(self, opt):
47
+ """Initialize this dataset class.
48
+
49
+ Parameters:
50
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
51
+ """
52
+ BaseDataset.__init__(self, opt)
53
+
54
+ self.lm3d_std = load_lm3d(opt.bfm_folder)
55
+
56
+ msk_names = default_flist_reader(opt.flist)
57
+ self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names]
58
+
59
+ self.size = len(self.msk_paths)
60
+ self.opt = opt
61
+
62
+ self.name = 'train' if opt.isTrain else 'val'
63
+ if '_' in opt.flist:
64
+ self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0]
65
+
66
+
67
+ def __getitem__(self, index):
68
+ """Return a data point and its metadata information.
69
+
70
+ Parameters:
71
+ index (int) -- a random integer for data indexing
72
+
73
+ Returns a dictionary that contains A, B, A_paths and B_paths
74
+ img (tensor) -- an image in the input domain
75
+ msk (tensor) -- its corresponding attention mask
76
+ lm (tensor) -- its corresponding 3d landmarks
77
+ im_paths (str) -- image paths
78
+ aug_flag (bool) -- a flag used to tell whether its raw or augmented
79
+ """
80
+ msk_path = self.msk_paths[index % self.size] # make sure index is within the range
81
+ img_path = msk_path.replace('mask/', '')
82
+ lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt'
83
+
84
+ raw_img = Image.open(img_path).convert('RGB')
85
+ raw_msk = Image.open(msk_path).convert('RGB')
86
+ raw_lm = np.loadtxt(lm_path).astype(np.float32)
87
+
88
+ _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk)
89
+
90
+ aug_flag = self.opt.use_aug and self.opt.isTrain
91
+ if aug_flag:
92
+ img, lm, msk = self._augmentation(img, lm, self.opt, msk)
93
+
94
+ _, H = img.size
95
+ M = estimate_norm(lm, H)
96
+ transform = get_transform()
97
+ img_tensor = transform(img)
98
+ msk_tensor = transform(msk)[:1, ...]
99
+ lm_tensor = parse_label(lm)
100
+ M_tensor = parse_label(M)
101
+
102
+
103
+ return {'imgs': img_tensor,
104
+ 'lms': lm_tensor,
105
+ 'msks': msk_tensor,
106
+ 'M': M_tensor,
107
+ 'im_paths': img_path,
108
+ 'aug_flag': aug_flag,
109
+ 'dataset': self.name}
110
+
111
+ def _augmentation(self, img, lm, opt, msk=None):
112
+ affine, affine_inv, flip = get_affine_mat(opt, img.size)
113
+ img = apply_img_affine(img, affine_inv)
114
+ lm = apply_lm_affine(lm, affine, flip, img.size)
115
+ if msk is not None:
116
+ msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR)
117
+ return img, lm, msk
118
+
119
+
120
+
121
+
122
+ def __len__(self):
123
+ """Return the total number of images in the dataset.
124
+ """
125
+ return self.size
videoretalking/third_part/face3d/data/image_folder.py ADDED
@@ -0,0 +1,66 @@
1
+ """A modified image folder class
2
+
3
+ We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
4
+ so that this class can load images from both current directory and its subdirectories.
5
+ """
6
+ import numpy as np
7
+ import torch.utils.data as data
8
+
9
+ from PIL import Image
10
+ import os
11
+ import os.path
12
+
13
+ IMG_EXTENSIONS = [
14
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
15
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
16
+ '.tif', '.TIF', '.tiff', '.TIFF',
17
+ ]
18
+
19
+
20
+ def is_image_file(filename):
21
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
22
+
23
+
24
+ def make_dataset(dir, max_dataset_size=float("inf")):
25
+ images = []
26
+ assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir
27
+
28
+ for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
29
+ for fname in fnames:
30
+ if is_image_file(fname):
31
+ path = os.path.join(root, fname)
32
+ images.append(path)
33
+ return images[:min(max_dataset_size, len(images))]
34
+
35
+
36
+ def default_loader(path):
37
+ return Image.open(path).convert('RGB')
38
+
39
+
40
+ class ImageFolder(data.Dataset):
41
+
42
+ def __init__(self, root, transform=None, return_paths=False,
43
+ loader=default_loader):
44
+ imgs = make_dataset(root)
45
+ if len(imgs) == 0:
46
+ raise(RuntimeError("Found 0 images in: " + root + "\n"
47
+ "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
48
+
49
+ self.root = root
50
+ self.imgs = imgs
51
+ self.transform = transform
52
+ self.return_paths = return_paths
53
+ self.loader = loader
54
+
55
+ def __getitem__(self, index):
56
+ path = self.imgs[index]
57
+ img = self.loader(path)
58
+ if self.transform is not None:
59
+ img = self.transform(img)
60
+ if self.return_paths:
61
+ return img, path
62
+ else:
63
+ return img
64
+
65
+ def __len__(self):
66
+ return len(self.imgs)
videoretalking/third_part/face3d/data/template_dataset.py ADDED
@@ -0,0 +1,75 @@
1
+ """Dataset class template
2
+
3
+ This module provides a template for users to implement custom datasets.
4
+ You can specify '--dataset_mode template' to use this dataset.
5
+ The class name should be consistent with both the filename and its dataset_mode option.
6
+ The filename should be <dataset_mode>_dataset.py
7
+ The class name should be <Dataset_mode>Dataset.
8
+ You need to implement the following functions:
9
+ -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
10
+ -- <__init__>: Initialize this dataset class.
11
+ -- <__getitem__>: Return a data point and its metadata information.
12
+ -- <__len__>: Return the number of images.
13
+ """
14
+ from data.base_dataset import BaseDataset, get_transform
15
+ # from data.image_folder import make_dataset
16
+ # from PIL import Image
17
+
18
+
19
+ class TemplateDataset(BaseDataset):
20
+ """A template dataset class for you to implement custom datasets."""
21
+ @staticmethod
22
+ def modify_commandline_options(parser, is_train):
23
+ """Add new dataset-specific options, and rewrite default values for existing options.
24
+
25
+ Parameters:
26
+ parser -- original option parser
27
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
28
+
29
+ Returns:
30
+ the modified parser.
31
+ """
32
+ parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option')
33
+ parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values
34
+ return parser
35
+
36
+ def __init__(self, opt):
37
+ """Initialize this dataset class.
38
+
39
+ Parameters:
40
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
41
+
42
+ A few things can be done here.
43
+ - save the options (have been done in BaseDataset)
44
+ - get image paths and meta information of the dataset.
45
+ - define the image transformation.
46
+ """
47
+ # save the option and dataset root
48
+ BaseDataset.__init__(self, opt)
49
+ # get the image paths of your dataset;
50
+ self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
51
+ # define the default transform function. You can use <base_dataset.get_transform>; You can also define your custom transform function
52
+ self.transform = get_transform(opt)
53
+
54
+ def __getitem__(self, index):
55
+ """Return a data point and its metadata information.
56
+
57
+ Parameters:
58
+ index -- a random integer for data indexing
59
+
60
+ Returns:
61
+ a dictionary of data with their names. It usually contains the data itself and its metadata information.
62
+
63
+ Step 1: get a random image path: e.g., path = self.image_paths[index]
64
+ Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
65
+ Step 3: convert your data to a PyTorch tensor. You can use helper functions such as self.transform. e.g., data = self.transform(image)
66
+ Step 4: return a data point as a dictionary.
67
+ """
68
+ path = 'temp' # needs to be a string
69
+ data_A = None # needs to be a tensor
70
+ data_B = None # needs to be a tensor
71
+ return {'data_A': data_A, 'data_B': data_B, 'path': path}
72
+
73
+ def __len__(self):
74
+ """Return the total number of images."""
75
+ return len(self.image_paths)
videoretalking/third_part/face3d/data_preparation.py ADDED
@@ -0,0 +1,45 @@
+ """This script is the data preparation script for Deep3DFaceRecon_pytorch
+ """
+
+ import os
+ import numpy as np
+ import argparse
+ from util.detect_lm68 import detect_68p, load_lm_graph
+ from util.skin_mask import get_skin_mask
+ from util.generate_list import check_list, write_list
+ import warnings
+ warnings.filterwarnings("ignore")
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--data_root', type=str, default='datasets', help='root directory for training data')
+ parser.add_argument('--img_folder', nargs="+", required=True, help='folders of training images')
+ parser.add_argument('--mode', type=str, default='train', help='train or val')
+ opt = parser.parse_args()
+
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+ def data_prepare(folder_list, mode):
+
+     lm_sess, input_op, output_op = load_lm_graph('./checkpoints/lm_model/68lm_detector.pb')  # load a tensorflow version 68-landmark detector
+
+     for img_folder in folder_list:
+         detect_68p(img_folder, lm_sess, input_op, output_op)  # detect landmarks for images
+         get_skin_mask(img_folder)  # generate skin attention mask for images
+
+     # create files that record path to all training data
+     msks_list = []
+     for img_folder in folder_list:
+         path = os.path.join(img_folder, 'mask')
+         msks_list += ['/'.join([img_folder, 'mask', i]) for i in sorted(os.listdir(path)) if 'jpg' in i or
+                       'png' in i or 'jpeg' in i or 'PNG' in i]
+
+     imgs_list = [i.replace('mask/', '') for i in msks_list]
+     lms_list = [i.replace('mask', 'landmarks') for i in msks_list]
+     lms_list = ['.'.join(i.split('.')[:-1]) + '.txt' for i in lms_list]
+
+     lms_list_final, imgs_list_final, msks_list_final = check_list(lms_list, imgs_list, msks_list)  # check if the paths are valid
+     write_list(lms_list_final, imgs_list_final, msks_list_final, mode=mode)  # save files
+
+ if __name__ == '__main__':
+     print('Datasets:', opt.img_folder)
+     data_prepare([os.path.join(opt.data_root, folder) for folder in opt.img_folder], opt.mode)
videoretalking/third_part/face3d/extract_kp_videos.py ADDED
@@ -0,0 +1,109 @@
1
+ import os
2
+ import cv2
3
+ import time
4
+ import glob
5
+ import argparse
6
+ import face_alignment
7
+ import numpy as np
8
+ from PIL import Image
9
+ import torch
10
+ from tqdm import tqdm
11
+ from itertools import cycle
12
+
13
+ from torch.multiprocessing import Pool, Process, set_start_method
14
+
15
+ class KeypointExtractor():
16
+ def __init__(self):
17
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
18
+ self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, device=device)
19
+
20
+ def extract_keypoint(self, images, name=None, info=True):
21
+ if isinstance(images, list):
22
+ keypoints = []
23
+ if info:
24
+ i_range = tqdm(images,desc='landmark Det:')
25
+ else:
26
+ i_range = images
27
+
28
+ for image in i_range:
29
+ current_kp = self.extract_keypoint(image)
30
+ if np.mean(current_kp) == -1 and keypoints:
31
+ keypoints.append(keypoints[-1])
32
+ else:
33
+ keypoints.append(current_kp[None])
34
+
35
+ keypoints = np.concatenate(keypoints, 0)
36
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
37
+ return keypoints
38
+ else:
39
+ while True:
40
+ try:
41
+ keypoints = self.detector.get_landmarks_from_image(np.array(images))[0]
42
+ break
43
+ except RuntimeError as e:
44
+ if str(e).startswith('CUDA'):
45
+ print("Warning: out of memory, sleep for 1s")
46
+ time.sleep(1)
47
+ else:
48
+ print(e)
49
+ break
50
+ except TypeError:
51
+ print('No face detected in this image')
52
+ shape = [68, 2]
53
+ keypoints = -1. * np.ones(shape)
54
+ break
55
+ if name is not None:
56
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
57
+ return keypoints
58
+
59
+ def read_video(filename):
60
+ frames = []
61
+ cap = cv2.VideoCapture(filename)
62
+ while cap.isOpened():
63
+ ret, frame = cap.read()
64
+ if ret:
65
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
66
+ frame = Image.fromarray(frame)
67
+ frames.append(frame)
68
+ else:
69
+ break
70
+ cap.release()
71
+ return frames
72
+
73
+ def run(data):
74
+ filename, opt, device = data
75
+ os.environ['CUDA_VISIBLE_DEVICES'] = device
76
+ kp_extractor = KeypointExtractor()
77
+ images = read_video(filename)
78
+ name = filename.split('/')[-2:]
79
+ os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
80
+ kp_extractor.extract_keypoint(
81
+ images,
82
+ name=os.path.join(opt.output_dir, name[-2], name[-1])
83
+ )
84
+
85
+ if __name__ == '__main__':
86
+ set_start_method('spawn')
87
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
88
+ parser.add_argument('--input_dir', type=str, help='the folder of the input files')
89
+ parser.add_argument('--output_dir', type=str, help='the folder of the output files')
90
+ parser.add_argument('--device_ids', type=str, default='0,1')
91
+ parser.add_argument('--workers', type=int, default=4)
92
+
93
+ opt = parser.parse_args()
94
+ filenames = list()
95
+ VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
96
+ VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
97
+ extensions = VIDEO_EXTENSIONS
98
+
99
+ for ext in extensions:
100
+ os.listdir(f'{opt.input_dir}')
101
+ print(f'{opt.input_dir}/*.{ext}')
102
+ filenames = sorted(glob.glob(f'{opt.input_dir}/*.{ext}'))
103
+ print('Total number of videos:', len(filenames))
104
+ pool = Pool(opt.workers)
105
+ args_list = cycle([opt])
106
+ device_ids = opt.device_ids.split(",")
107
+ device_ids = cycle(device_ids)
108
+ for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
109
+ pass
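For single images, `KeypointExtractor` above can also be used directly; a small usage sketch (the image path is hypothetical, and the `face_alignment` package must be installed):

```python
# Illustrative use of KeypointExtractor on a single image (not part of the commit).
from PIL import Image
from face3d.extract_kp_videos import KeypointExtractor

extractor = KeypointExtractor()
image = Image.open('face.png')                       # hypothetical input image
lm = extractor.extract_keypoint(image, name='face')  # (68, 2) landmarks, also written to face.txt
print(lm.shape)                                      # entries are all -1 if no face was detected
```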
videoretalking/third_part/face3d/face_recon_videos.py ADDED
@@ -0,0 +1,157 @@
1
+ import os
2
+ import cv2
3
+ import glob
4
+ import numpy as np
5
+ from PIL import Image
6
+ from tqdm import tqdm
7
+ from scipy.io import savemat
8
+
9
+ import torch
10
+
11
+ from models import create_model
12
+ from options.inference_options import InferenceOptions
13
+ from util.preprocess import align_img
14
+ from util.load_mats import load_lm3d
15
+ from util.util import mkdirs, tensor2im, save_image
16
+
17
+
18
+ def get_data_path(root, keypoint_root):
19
+ filenames = list()
20
+ keypoint_filenames = list()
21
+
22
+ VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
23
+ VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
24
+ extensions = VIDEO_EXTENSIONS
25
+
26
+ for ext in extensions:
27
+ filenames += glob.glob(f'{root}/**/*.{ext}', recursive=True)
28
+ filenames = sorted(filenames)
29
+ keypoint_filenames = sorted(glob.glob(f'{keypoint_root}/**/*.txt', recursive=True))
30
+ assert len(filenames) == len(keypoint_filenames)
31
+
32
+ return filenames, keypoint_filenames
33
+
34
+ class VideoPathDataset(torch.utils.data.Dataset):
35
+ def __init__(self, filenames, txt_filenames, bfm_folder):
36
+ self.filenames = filenames
37
+ self.txt_filenames = txt_filenames
38
+ self.lm3d_std = load_lm3d(bfm_folder)
39
+
40
+ def __len__(self):
41
+ return len(self.filenames)
42
+
43
+ def __getitem__(self, index):
44
+ filename = self.filenames[index]
45
+ txt_filename = self.txt_filenames[index]
46
+ frames = self.read_video(filename)
47
+ lm = np.loadtxt(txt_filename).astype(np.float32)
48
+ lm = lm.reshape([len(frames), -1, 2])
49
+ out_images, out_trans_params = list(), list()
50
+ for i in range(len(frames)):
51
+ out_img, _, out_trans_param \
52
+ = self.image_transform(frames[i], lm[i])
53
+ out_images.append(out_img[None])
54
+ out_trans_params.append(out_trans_param[None])
55
+ return {
56
+ 'imgs': torch.cat(out_images, 0),
57
+ 'trans_param':torch.cat(out_trans_params, 0),
58
+ 'filename': filename
59
+ }
60
+
61
+ def read_video(self, filename):
62
+ frames = list()
63
+ cap = cv2.VideoCapture(filename)
64
+ while cap.isOpened():
65
+ ret, frame = cap.read()
66
+ if ret:
67
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
68
+ frame = Image.fromarray(frame)
69
+ frames.append(frame)
70
+ else:
71
+ break
72
+ cap.release()
73
+ return frames
74
+
75
+ def image_transform(self, images, lm):
76
+ W,H = images.size
77
+ if np.mean(lm) == -1:
78
+ lm = (self.lm3d_std[:, :2]+1)/2.
79
+ lm = np.concatenate(
80
+ [lm[:, :1]*W, lm[:, 1:2]*H], 1
81
+ )
82
+ else:
83
+ lm[:, -1] = H - 1 - lm[:, -1]
84
+
85
+ trans_params, img, lm, _ = align_img(images, lm, self.lm3d_std)
86
+ img = torch.tensor(np.array(img)/255., dtype=torch.float32).permute(2, 0, 1)
87
+ lm = torch.tensor(lm)
88
+ trans_params = np.array([float(item) for item in np.hsplit(trans_params, 5)])
89
+ trans_params = torch.tensor(trans_params.astype(np.float32))
90
+ return img, lm, trans_params
91
+
92
+ def main(opt, model):
93
+ # import torch.multiprocessing
94
+ # torch.multiprocessing.set_sharing_strategy('file_system')
95
+ filenames, keypoint_filenames = get_data_path(opt.input_dir, opt.keypoint_dir)
96
+ dataset = VideoPathDataset(filenames, keypoint_filenames, opt.bfm_folder)
97
+ dataloader = torch.utils.data.DataLoader(
98
+ dataset,
99
+ batch_size=1, # can only be set to one here!
100
+ shuffle=False,
101
+ drop_last=False,
102
+ num_workers=0,
103
+ )
104
+ batch_size = opt.inference_batch_size
105
+ for data in tqdm(dataloader):
106
+ num_batch = data['imgs'][0].shape[0] // batch_size + 1
107
+ pred_coeffs = list()
108
+ for index in range(num_batch):
109
+ data_input = {
110
+ 'imgs': data['imgs'][0,index*batch_size:(index+1)*batch_size],
111
+ }
112
+ model.set_input(data_input)
113
+ model.test()
114
+ pred_coeff = {key:model.pred_coeffs_dict[key].cpu().numpy() for key in model.pred_coeffs_dict}
115
+ pred_coeff = np.concatenate([
116
+ pred_coeff['id'],
117
+ pred_coeff['exp'],
118
+ pred_coeff['tex'],
119
+ pred_coeff['angle'],
120
+ pred_coeff['gamma'],
121
+ pred_coeff['trans']], 1)
122
+ pred_coeffs.append(pred_coeff)
123
+ visuals = model.get_current_visuals() # get image results
124
+ if False: # debug
125
+ for name in visuals:
126
+ images = visuals[name]
127
+ for i in range(images.shape[0]):
128
+ image_numpy = tensor2im(images[i])
129
+ save_image(
130
+ image_numpy,
131
+ os.path.join(
132
+ opt.output_dir,
133
+ os.path.basename(data['filename'][0])+str(i).zfill(5)+'.jpg')
134
+ )
135
+ exit()
136
+
137
+ pred_coeffs = np.concatenate(pred_coeffs, 0)
138
+ pred_trans_params = data['trans_param'][0].cpu().numpy()
139
+ name = data['filename'][0].split('/')[-2:]
140
+ name[-1] = os.path.splitext(name[-1])[0] + '.mat'
141
+ os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
142
+ savemat(
143
+ os.path.join(opt.output_dir, name[-2], name[-1]),
144
+ {'coeff':pred_coeffs, 'transform_params':pred_trans_params}
145
+ )
146
+
147
+ if __name__ == '__main__':
148
+ opt = InferenceOptions().parse() # get test options
149
+ model = create_model(opt)
150
+ model.setup(opt)
151
+ model.device = 'cuda:0'
152
+ model.parallelize()
153
+ model.eval()
154
+
155
+ main(opt, model)
156
+
157
+
videoretalking/third_part/face3d/models/__init__.py ADDED
@@ -0,0 +1,67 @@
+ """This package contains modules related to objective functions, optimizations, and network architectures.
+
+ To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
+ You need to implement the following five functions:
+     -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
+     -- <set_input>: unpack data from dataset and apply preprocessing.
+     -- <forward>: produce intermediate results.
+     -- <optimize_parameters>: calculate loss, gradients, and update network weights.
+     -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
+
+ In the function <__init__>, you need to define four lists:
+     -- self.loss_names (str list): specify the training losses that you want to plot and save.
+     -- self.model_names (str list): define networks used in our training.
+     -- self.visual_names (str list): specify the images that you want to display and save.
+     -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for a usage example.
+
+ Now you can use the model class by specifying flag '--model dummy'.
+ See our template model class 'template_model.py' for more details.
+ """
+
+ import importlib
+ from face3d.models.base_model import BaseModel
+
+
+ def find_model_using_name(model_name):
+     """Import the module "face3d/models/[model_name]_model.py".
+
+     In the file, the class called [ModelName]Model() will
+     be instantiated. It has to be a subclass of BaseModel,
+     and it is case-insensitive.
+     """
+     model_filename = "face3d.models." + model_name + "_model"
+     modellib = importlib.import_module(model_filename)
+     model = None
+     target_model_name = model_name.replace('_', '') + 'model'
+     for name, cls in modellib.__dict__.items():
+         if name.lower() == target_model_name.lower() \
+            and issubclass(cls, BaseModel):
+             model = cls
+
+     if model is None:
+         print("In %s.py, there should be a subclass of BaseModel with a class name that matches %s in lowercase." % (model_filename, target_model_name))
+         exit(0)
+
+     return model
+
+
+ def get_option_setter(model_name):
+     """Return the static method <modify_commandline_options> of the model class."""
+     model_class = find_model_using_name(model_name)
+     return model_class.modify_commandline_options
+
+
+ def create_model(opt):
+     """Create a model given the option.
+
+     This function wraps the model class selected by opt.model.
+     This is the main interface between this package and 'train.py'/'test.py'.
+
+     Example:
+         >>> from face3d.models import create_model
+         >>> model = create_model(opt)
+     """
+     model = find_model_using_name(opt.model)
+     instance = model(opt)
+     print("model [%s] was created" % type(instance).__name__)
+     return instance
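Analogously to the datasets, a new model following the package docstring above would be a skeleton like the one below. `BaseModel`'s exact interface is assumed from the docstring (base_model.py is not shown in this view), and the file and class names are hypothetical:

```python
# Hypothetical face3d/models/dummy_model.py -- skeleton only, following the package docstring.
from face3d.models.base_model import BaseModel


class DummyModel(BaseModel):
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        return parser  # add model-specific options here if needed

    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        self.loss_names = []    # training losses to plot and save
        self.model_names = []   # networks used in training
        self.visual_names = []  # images to display and save
        self.optimizers = []    # one optimizer per network (or itertools.chain them)

    def set_input(self, input):
        self.input = input      # unpack data from the dataloader and preprocess

    def forward(self):
        pass                    # produce intermediate results

    def optimize_parameters(self):
        pass                    # compute losses and gradients, update weights
```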
videoretalking/third_part/face3d/models/arcface_torch/README.md ADDED
@@ -0,0 +1,164 @@
+ # Distributed Arcface Training in PyTorch
+
+ This is a deep learning library that makes face recognition training efficient and effective, and that can train tens of millions of identities on a single server.
+
+ ## Requirements
+
+ - Install [PyTorch](http://pytorch.org) (torch>=1.6.0); see our doc [install.md](docs/install.md).
+ - `pip install -r requirements.txt`.
+ - Download the dataset from [https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_).
+
+ ## How to Train
+
+ To train a model, run `train.py` with the path to the config:
+
+ ### 1. Single node, 8 GPUs:
+
+ ```shell
+ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50
+ ```
+
+ ### 2. Multiple nodes, each node 8 GPUs:
+
+ Node 0:
+
+ ```shell
+ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=1234 train.py configs/ms1mv3_r50
+ ```
+
+ Node 1:
+
+ ```shell
+ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=1234 train.py configs/ms1mv3_r50
+ ```
+
+ ### 3. Training resnet2060 with 8 GPUs:
+
+ ```shell
+ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r2060.py
+ ```
+
+ ## Model Zoo
+
+ - The models are available for non-commercial research purposes only.
+ - All models can be found here:
+   - [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw
+   - [onedrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d)
+
+ ### Performance on [**ICCV2021-MFR**](http://iccv21-mfr.com/)
+
+ The ICCV2021-MFR test set consists of non-celebrities, so we can ensure that it has very little overlap with publicly available face recognition training sets such as MS1M and CASIA, which are mostly collected from online celebrities. As a result, we can evaluate the fair performance of different algorithms.
+
+ For the **ICCV2021-MFR-ALL** set, TAR is measured on an all-to-all 1:1 protocol, with FAR less than 0.000001 (1e-6). The globalised multi-racial test set contains 242,143 identities and 1,624,305 images.
+
+ For the **ICCV2021-MFR-MASK** set, TAR is measured on a mask-to-nonmask 1:1 protocol, with FAR less than 0.0001 (1e-4). The mask test set contains 6,964 identities, 6,964 masked images and 13,928 non-masked images. There are 13,928 positive pairs and 96,983,824 negative pairs in total.
+
+ | Datasets | backbone | Training throughput | Size / MB | **ICCV2021-MFR-MASK** | **ICCV2021-MFR-ALL** |
+ | :---: | :--- | :--- | :--- |:--- |:--- |
+ | MS1MV3 | r18 | - | 91 | **47.85** | **68.33** |
+ | Glint360k | r18 | 8536 | 91 | **53.32** | **72.07** |
+ | MS1MV3 | r34 | - | 130 | **58.72** | **77.36** |
+ | Glint360k | r34 | 6344 | 130 | **65.10** | **83.02** |
+ | MS1MV3 | r50 | 5500 | 166 | **63.85** | **80.53** |
+ | Glint360k | r50 | 5136 | 166 | **70.23** | **87.08** |
+ | MS1MV3 | r100 | - | 248 | **69.09** | **84.31** |
+ | Glint360k | r100 | 3332 | 248 | **75.57** | **90.66** |
+ | MS1MV3 | mobilefacenet | 12185 | 7.8 | **41.52** | **65.26** |
+ | Glint360k | mobilefacenet | 11197 | 7.8 | **44.52** | **66.48** |
+
+ ### Performance on IJB-C and Verification Datasets
+
+ | Datasets | backbone | IJBC(1e-05) | IJBC(1e-04) | agedb30 | cfp_fp | lfw | log |
+ | :---: | :--- | :--- | :--- | :--- |:--- |:--- |:--- |
+ | MS1MV3 | r18 | 92.07 | 94.66 | 97.77 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r18_fp16/training.log)|
+ | MS1MV3 | r34 | 94.10 | 95.90 | 98.10 | 98.67 | 99.80 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r34_fp16/training.log)|
+ | MS1MV3 | r50 | 94.79 | 96.46 | 98.35 | 98.96 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r50_fp16/training.log)|
+ | MS1MV3 | r100 | 95.31 | 96.81 | 98.48 | 99.06 | 99.85 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r100_fp16/training.log)|
+ | MS1MV3 | **r2060**| 95.34 | 97.11 | 98.67 | 99.24 | 99.87 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r2060_fp16/training.log)|
+ | Glint360k |r18-0.1 | 93.16 | 95.33 | 97.72 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r18_fp16_0.1/training.log)|
+ | Glint360k |r34-0.1 | 95.16 | 96.56 | 98.33 | 98.78 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r34_fp16_0.1/training.log)|
+ | Glint360k |r50-0.1 | 95.61 | 96.97 | 98.38 | 99.20 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r50_fp16_0.1/training.log)|
+ | Glint360k |r100-0.1 | 95.88 | 97.32 | 98.48 | 99.29 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r100_fp16_0.1/training.log)|
+
+ [comment]: <> (More details see [model.md]&#40;docs/modelzoo.md&#41; in docs.)
+
+ ## [Speed Benchmark](docs/speed_benchmark.md)
+
+ **Arcface Torch** can train large-scale face recognition training sets efficiently and quickly. When the number of classes in the training set is greater than 300K and training is sufficient, the partial FC sampling strategy reaches the same accuracy several times faster and with a smaller GPU memory footprint. Partial FC is a sparse variant of the model-parallel architecture for large-scale face recognition: it uses a sparse softmax, where each batch dynamically samples a subset of class centers for training, so in each iteration only a sparse part of the parameters is updated, which greatly reduces GPU memory and computation (a toy sketch of the per-batch sampling step is given after this README). With Partial FC, we can scale to a training set of 29 million identities, the largest to date. Partial FC also supports multi-machine distributed training and mixed-precision training.
+
+ ![Image text](https://github.com/anxiangsir/insightface_arcface_log/blob/master/partial_fc_v2.png)
+
+ More details: see [speed_benchmark.md](docs/speed_benchmark.md) in docs.
+
+ ### 1. Training speed of different parallel methods (samples / second), Tesla V100 32GB * 8. (Larger is better)
+
+ `-` means training failed because of GPU memory limitations.
+
+ | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
+ | :--- | :--- | :--- | :--- |
+ |125000 | 4681 | 4824 | 5004 |
+ |1400000 | **1672** | 3043 | 4738 |
+ |5500000 | **-** | **1389** | 3975 |
+ |8000000 | **-** | **-** | 3565 |
+ |16000000 | **-** | **-** | 2679 |
+ |29000000 | **-** | **-** | **1855** |
+
+ ### 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better)
+
+ | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
+ | :--- | :--- | :--- | :--- |
+ |125000 | 7358 | 5306 | 4868 |
+ |1400000 | 32252 | 11178 | 6056 |
+ |5500000 | **-** | 32188 | 9854 |
+ |8000000 | **-** | **-** | 12310 |
+ |16000000 | **-** | **-** | 19950 |
+ |29000000 | **-** | **-** | 32324 |
+
+ ## Evaluation on ICCV2021-MFR and IJB-C
+
+ More details: see [eval.md](docs/eval.md) in docs.
+
+ ## Test
+
+ We tested many versions of PyTorch. Please create an issue if you are having trouble.
+
+ - [x] torch 1.6.0
+ - [x] torch 1.7.1
+ - [x] torch 1.8.0
+ - [x] torch 1.9.0
+
+ ## Citation
+
+ ```
+ @inproceedings{deng2019arcface,
+   title={Arcface: Additive angular margin loss for deep face recognition},
+   author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos},
+   booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
+   pages={4690--4699},
+   year={2019}
+ }
+ @inproceedings{an2020partial_fc,
+   title={Partial FC: Training 10 Million Identities on a Single Machine},
+   author={An, Xiang and Zhu, Xuhan and Xiao, Yang and Wu, Lan and Zhang, Ming and Gao, Yuan and Qin, Bin and Zhang, Debing and Fu, Ying},
+   booktitle={Arxiv 2010.05222},
+   year={2020}
+ }
+ ```
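The core idea behind the Partial FC paragraph above (each batch only scores and updates a sampled subset of class centers) can be sketched in a few lines. This toy example is not the repository's `partial_fc.py`; the sizes, sampling rate, and variable names are arbitrary:

```python
# Toy sketch of partial-FC-style sampling: each batch scores only a subset of class centers.
# Not the repo's partial_fc.py; dimensions and sampling rate are arbitrary.
import torch
import torch.nn.functional as F

num_classes, emb_dim, sample_rate = 10_000, 512, 0.1
centers = torch.randn(num_classes, emb_dim)  # full class-center matrix (normally sharded across GPUs)

def sampled_logits(embeddings, labels):
    """Keep the batch's positive classes, pad with random negatives, and score only those."""
    positives = labels.unique()
    num_sample = max(int(num_classes * sample_rate), positives.numel())
    keep = torch.ones(num_classes, dtype=torch.bool)
    keep[positives] = False                            # mark positives so they are not re-sampled
    perm = torch.randperm(num_classes)
    negatives = perm[keep[perm]][: num_sample - positives.numel()]
    sampled = torch.cat([positives, negatives])        # class centers used for this batch
    logits = embeddings @ centers[sampled].t()         # only these columns receive gradients
    new_labels = (sampled.unsqueeze(0) == labels.unsqueeze(1)).float().argmax(dim=1)
    return logits, new_labels

embeddings = F.normalize(torch.randn(8, emb_dim), dim=1)
labels = torch.randint(0, num_classes, (8,))
logits, new_labels = sampled_logits(embeddings, labels)
loss = F.cross_entropy(logits, new_labels)             # softmax over the sampled centers only
```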
videoretalking/third_part/face3d/models/arcface_torch/backbones/__init__.py ADDED
@@ -0,0 +1,25 @@
+ from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200
+ from .mobilefacenet import get_mbf
+
+
+ def get_model(name, **kwargs):
+     # resnet
+     if name == "r18":
+         return iresnet18(False, **kwargs)
+     elif name == "r34":
+         return iresnet34(False, **kwargs)
+     elif name == "r50":
+         return iresnet50(False, **kwargs)
+     elif name == "r100":
+         return iresnet100(False, **kwargs)
+     elif name == "r200":
+         return iresnet200(False, **kwargs)
+     elif name == "r2060":
+         from .iresnet2060 import iresnet2060
+         return iresnet2060(False, **kwargs)
+     elif name == "mbf":
+         fp16 = kwargs.get("fp16", False)
+         num_features = kwargs.get("num_features", 512)
+         return get_mbf(fp16=fp16, num_features=num_features)
+     else:
+         raise ValueError()
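A quick usage sketch for `get_model`: the import path below assumes the repository root is on `sys.path`, and the 112x112 input size follows the 7x7 `fc_scale` used in `iresnet.py`; treat this as an illustration rather than a documented entry point.

```python
# Illustrative only: build an iresnet50 backbone and embed one dummy 112x112 face crop.
import torch
from face3d.models.arcface_torch.backbones import get_model  # import path is an assumption

net = get_model("r50", fp16=False, num_features=512).eval()
with torch.no_grad():
    emb = net(torch.randn(1, 3, 112, 112))
print(emb.shape)  # torch.Size([1, 512])
```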
videoretalking/third_part/face3d/models/arcface_torch/backbones/iresnet.py ADDED
@@ -0,0 +1,187 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ __all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200']
5
+
6
+
7
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
8
+ """3x3 convolution with padding"""
9
+ return nn.Conv2d(in_planes,
10
+ out_planes,
11
+ kernel_size=3,
12
+ stride=stride,
13
+ padding=dilation,
14
+ groups=groups,
15
+ bias=False,
16
+ dilation=dilation)
17
+
18
+
19
+ def conv1x1(in_planes, out_planes, stride=1):
20
+ """1x1 convolution"""
21
+ return nn.Conv2d(in_planes,
22
+ out_planes,
23
+ kernel_size=1,
24
+ stride=stride,
25
+ bias=False)
26
+
27
+
28
+ class IBasicBlock(nn.Module):
29
+ expansion = 1
30
+ def __init__(self, inplanes, planes, stride=1, downsample=None,
31
+ groups=1, base_width=64, dilation=1):
32
+ super(IBasicBlock, self).__init__()
33
+ if groups != 1 or base_width != 64:
34
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
35
+ if dilation > 1:
36
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
37
+ self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
38
+ self.conv1 = conv3x3(inplanes, planes)
39
+ self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
40
+ self.prelu = nn.PReLU(planes)
41
+ self.conv2 = conv3x3(planes, planes, stride)
42
+ self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
43
+ self.downsample = downsample
44
+ self.stride = stride
45
+
46
+ def forward(self, x):
47
+ identity = x
48
+ out = self.bn1(x)
49
+ out = self.conv1(out)
50
+ out = self.bn2(out)
51
+ out = self.prelu(out)
52
+ out = self.conv2(out)
53
+ out = self.bn3(out)
54
+ if self.downsample is not None:
55
+ identity = self.downsample(x)
56
+ out += identity
57
+ return out
58
+
59
+
60
+ class IResNet(nn.Module):
61
+ fc_scale = 7 * 7
62
+ def __init__(self,
63
+ block, layers, dropout=0, num_features=512, zero_init_residual=False,
64
+ groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
65
+ super(IResNet, self).__init__()
66
+ self.fp16 = fp16
67
+ self.inplanes = 64
68
+ self.dilation = 1
69
+ if replace_stride_with_dilation is None:
70
+ replace_stride_with_dilation = [False, False, False]
71
+ if len(replace_stride_with_dilation) != 3:
72
+ raise ValueError("replace_stride_with_dilation should be None "
73
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
74
+ self.groups = groups
75
+ self.base_width = width_per_group
76
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
77
+ self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
78
+ self.prelu = nn.PReLU(self.inplanes)
79
+ self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
80
+ self.layer2 = self._make_layer(block,
81
+ 128,
82
+ layers[1],
83
+ stride=2,
84
+ dilate=replace_stride_with_dilation[0])
85
+ self.layer3 = self._make_layer(block,
86
+ 256,
87
+ layers[2],
88
+ stride=2,
89
+ dilate=replace_stride_with_dilation[1])
90
+ self.layer4 = self._make_layer(block,
91
+ 512,
92
+ layers[3],
93
+ stride=2,
94
+ dilate=replace_stride_with_dilation[2])
95
+ self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
96
+ self.dropout = nn.Dropout(p=dropout, inplace=True)
97
+ self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
98
+ self.features = nn.BatchNorm1d(num_features, eps=1e-05)
99
+ nn.init.constant_(self.features.weight, 1.0)
100
+ self.features.weight.requires_grad = False
101
+
102
+ for m in self.modules():
103
+ if isinstance(m, nn.Conv2d):
104
+ nn.init.normal_(m.weight, 0, 0.1)
105
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
106
+ nn.init.constant_(m.weight, 1)
107
+ nn.init.constant_(m.bias, 0)
108
+
109
+ if zero_init_residual:
110
+ for m in self.modules():
111
+ if isinstance(m, IBasicBlock):
112
+ nn.init.constant_(m.bn2.weight, 0)
113
+
114
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
115
+ downsample = None
116
+ previous_dilation = self.dilation
117
+ if dilate:
118
+ self.dilation *= stride
119
+ stride = 1
120
+ if stride != 1 or self.inplanes != planes * block.expansion:
121
+ downsample = nn.Sequential(
122
+ conv1x1(self.inplanes, planes * block.expansion, stride),
123
+ nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
124
+ )
125
+ layers = []
126
+ layers.append(
127
+ block(self.inplanes, planes, stride, downsample, self.groups,
128
+ self.base_width, previous_dilation))
129
+ self.inplanes = planes * block.expansion
130
+ for _ in range(1, blocks):
131
+ layers.append(
132
+ block(self.inplanes,
133
+ planes,
134
+ groups=self.groups,
135
+ base_width=self.base_width,
136
+ dilation=self.dilation))
137
+
138
+ return nn.Sequential(*layers)
139
+
140
+ def forward(self, x):
141
+ with torch.cuda.amp.autocast(self.fp16):
142
+ x = self.conv1(x)
143
+ x = self.bn1(x)
144
+ x = self.prelu(x)
145
+ x = self.layer1(x)
146
+ x = self.layer2(x)
147
+ x = self.layer3(x)
148
+ x = self.layer4(x)
149
+ x = self.bn2(x)
150
+ x = torch.flatten(x, 1)
151
+ x = self.dropout(x)
152
+ x = self.fc(x.float() if self.fp16 else x)
153
+ x = self.features(x)
154
+ return x
155
+
156
+
157
+ def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
158
+ model = IResNet(block, layers, **kwargs)
159
+ if pretrained:
160
+ raise ValueError()
161
+ return model
162
+
163
+
164
+ def iresnet18(pretrained=False, progress=True, **kwargs):
165
+ return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
166
+ progress, **kwargs)
167
+
168
+
169
+ def iresnet34(pretrained=False, progress=True, **kwargs):
170
+ return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
171
+ progress, **kwargs)
172
+
173
+
174
+ def iresnet50(pretrained=False, progress=True, **kwargs):
175
+ return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
176
+ progress, **kwargs)
177
+
178
+
179
+ def iresnet100(pretrained=False, progress=True, **kwargs):
180
+ return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
181
+ progress, **kwargs)
182
+
183
+
184
+ def iresnet200(pretrained=False, progress=True, **kwargs):
185
+ return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained,
186
+ progress, **kwargs)
187
+
videoretalking/third_part/face3d/models/arcface_torch/backbones/iresnet2060.py ADDED
@@ -0,0 +1,176 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ assert torch.__version__ >= "1.8.1"
5
+ from torch.utils.checkpoint import checkpoint_sequential
6
+
7
+ __all__ = ['iresnet2060']
8
+
9
+
10
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
11
+ """3x3 convolution with padding"""
12
+ return nn.Conv2d(in_planes,
13
+ out_planes,
14
+ kernel_size=3,
15
+ stride=stride,
16
+ padding=dilation,
17
+ groups=groups,
18
+ bias=False,
19
+ dilation=dilation)
20
+
21
+
22
+ def conv1x1(in_planes, out_planes, stride=1):
23
+ """1x1 convolution"""
24
+ return nn.Conv2d(in_planes,
25
+ out_planes,
26
+ kernel_size=1,
27
+ stride=stride,
28
+ bias=False)
29
+
30
+
31
+ class IBasicBlock(nn.Module):
32
+ expansion = 1
33
+
34
+ def __init__(self, inplanes, planes, stride=1, downsample=None,
35
+ groups=1, base_width=64, dilation=1):
36
+ super(IBasicBlock, self).__init__()
37
+ if groups != 1 or base_width != 64:
38
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
39
+ if dilation > 1:
40
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
41
+ self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05, )
42
+ self.conv1 = conv3x3(inplanes, planes)
43
+ self.bn2 = nn.BatchNorm2d(planes, eps=1e-05, )
44
+ self.prelu = nn.PReLU(planes)
45
+ self.conv2 = conv3x3(planes, planes, stride)
46
+ self.bn3 = nn.BatchNorm2d(planes, eps=1e-05, )
47
+ self.downsample = downsample
48
+ self.stride = stride
49
+
50
+ def forward(self, x):
51
+ identity = x
52
+ out = self.bn1(x)
53
+ out = self.conv1(out)
54
+ out = self.bn2(out)
55
+ out = self.prelu(out)
56
+ out = self.conv2(out)
57
+ out = self.bn3(out)
58
+ if self.downsample is not None:
59
+ identity = self.downsample(x)
60
+ out += identity
61
+ return out
62
+
63
+
64
+ class IResNet(nn.Module):
65
+ fc_scale = 7 * 7
66
+
67
+ def __init__(self,
68
+ block, layers, dropout=0, num_features=512, zero_init_residual=False,
69
+ groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
70
+ super(IResNet, self).__init__()
71
+ self.fp16 = fp16
72
+ self.inplanes = 64
73
+ self.dilation = 1
74
+ if replace_stride_with_dilation is None:
75
+ replace_stride_with_dilation = [False, False, False]
76
+ if len(replace_stride_with_dilation) != 3:
77
+ raise ValueError("replace_stride_with_dilation should be None "
78
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
79
+ self.groups = groups
80
+ self.base_width = width_per_group
81
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
82
+ self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
83
+ self.prelu = nn.PReLU(self.inplanes)
84
+ self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
85
+ self.layer2 = self._make_layer(block,
86
+ 128,
87
+ layers[1],
88
+ stride=2,
89
+ dilate=replace_stride_with_dilation[0])
90
+ self.layer3 = self._make_layer(block,
91
+ 256,
92
+ layers[2],
93
+ stride=2,
94
+ dilate=replace_stride_with_dilation[1])
95
+ self.layer4 = self._make_layer(block,
96
+ 512,
97
+ layers[3],
98
+ stride=2,
99
+ dilate=replace_stride_with_dilation[2])
100
+ self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05, )
101
+ self.dropout = nn.Dropout(p=dropout, inplace=True)
102
+ self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
103
+ self.features = nn.BatchNorm1d(num_features, eps=1e-05)
104
+ nn.init.constant_(self.features.weight, 1.0)
105
+ self.features.weight.requires_grad = False
106
+
107
+ for m in self.modules():
108
+ if isinstance(m, nn.Conv2d):
109
+ nn.init.normal_(m.weight, 0, 0.1)
110
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
111
+ nn.init.constant_(m.weight, 1)
112
+ nn.init.constant_(m.bias, 0)
113
+
114
+ if zero_init_residual:
115
+ for m in self.modules():
116
+ if isinstance(m, IBasicBlock):
117
+ nn.init.constant_(m.bn2.weight, 0)
118
+
119
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
120
+ downsample = None
121
+ previous_dilation = self.dilation
122
+ if dilate:
123
+ self.dilation *= stride
124
+ stride = 1
125
+ if stride != 1 or self.inplanes != planes * block.expansion:
126
+ downsample = nn.Sequential(
127
+ conv1x1(self.inplanes, planes * block.expansion, stride),
128
+ nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
129
+ )
130
+ layers = []
131
+ layers.append(
132
+ block(self.inplanes, planes, stride, downsample, self.groups,
133
+ self.base_width, previous_dilation))
134
+ self.inplanes = planes * block.expansion
135
+ for _ in range(1, blocks):
136
+ layers.append(
137
+ block(self.inplanes,
138
+ planes,
139
+ groups=self.groups,
140
+ base_width=self.base_width,
141
+ dilation=self.dilation))
142
+
143
+ return nn.Sequential(*layers)
144
+
145
+ def checkpoint(self, func, num_seg, x):
146
+ if self.training:
147
+ return checkpoint_sequential(func, num_seg, x)
148
+ else:
149
+ return func(x)
150
+
151
+ def forward(self, x):
152
+ with torch.cuda.amp.autocast(self.fp16):
153
+ x = self.conv1(x)
154
+ x = self.bn1(x)
155
+ x = self.prelu(x)
156
+ x = self.layer1(x)
157
+ x = self.checkpoint(self.layer2, 20, x)
158
+ x = self.checkpoint(self.layer3, 100, x)
159
+ x = self.layer4(x)
160
+ x = self.bn2(x)
161
+ x = torch.flatten(x, 1)
162
+ x = self.dropout(x)
163
+ x = self.fc(x.float() if self.fp16 else x)
164
+ x = self.features(x)
165
+ return x
166
+
167
+
168
+ def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
169
+ model = IResNet(block, layers, **kwargs)
170
+ if pretrained:
171
+ raise ValueError()
172
+ return model
173
+
174
+
175
+ def iresnet2060(pretrained=False, progress=True, **kwargs):
176
+ return _iresnet('iresnet2060', IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs)
videoretalking/third_part/face3d/models/arcface_torch/backbones/mobilefacenet.py ADDED
@@ -0,0 +1,130 @@
1
+ '''
2
+ Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py
3
+ Original author cavalleria
4
+ '''
5
+
6
+ import torch.nn as nn
7
+ from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module
8
+ import torch
9
+
10
+
11
+ class Flatten(Module):
12
+ def forward(self, x):
13
+ return x.view(x.size(0), -1)
14
+
15
+
16
+ class ConvBlock(Module):
17
+ def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
18
+ super(ConvBlock, self).__init__()
19
+ self.layers = nn.Sequential(
20
+ Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False),
21
+ BatchNorm2d(num_features=out_c),
22
+ PReLU(num_parameters=out_c)
23
+ )
24
+
25
+ def forward(self, x):
26
+ return self.layers(x)
27
+
28
+
29
+ class LinearBlock(Module):
30
+ def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
31
+ super(LinearBlock, self).__init__()
32
+ self.layers = nn.Sequential(
33
+ Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False),
34
+ BatchNorm2d(num_features=out_c)
35
+ )
36
+
37
+ def forward(self, x):
38
+ return self.layers(x)
39
+
40
+
41
+ class DepthWise(Module):
42
+ def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
43
+ super(DepthWise, self).__init__()
44
+ self.residual = residual
45
+ self.layers = nn.Sequential(
46
+ ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)),
47
+ ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride),
48
+ LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
49
+ )
50
+
51
+ def forward(self, x):
52
+ short_cut = None
53
+ if self.residual:
54
+ short_cut = x
55
+ x = self.layers(x)
56
+ if self.residual:
57
+ output = short_cut + x
58
+ else:
59
+ output = x
60
+ return output
61
+
62
+
63
+ class Residual(Module):
64
+ def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
65
+ super(Residual, self).__init__()
66
+ modules = []
67
+ for _ in range(num_block):
68
+ modules.append(DepthWise(c, c, True, kernel, stride, padding, groups))
69
+ self.layers = Sequential(*modules)
70
+
71
+ def forward(self, x):
72
+ return self.layers(x)
73
+
74
+
75
+ class GDC(Module):
76
+ def __init__(self, embedding_size):
77
+ super(GDC, self).__init__()
78
+ self.layers = nn.Sequential(
79
+ LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)),
80
+ Flatten(),
81
+ Linear(512, embedding_size, bias=False),
82
+ BatchNorm1d(embedding_size))
83
+
84
+ def forward(self, x):
85
+ return self.layers(x)
86
+
87
+
88
+ class MobileFaceNet(Module):
89
+ def __init__(self, fp16=False, num_features=512):
90
+ super(MobileFaceNet, self).__init__()
91
+ scale = 2
92
+ self.fp16 = fp16
93
+ self.layers = nn.Sequential(
94
+ ConvBlock(3, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)),
95
+ ConvBlock(64 * scale, 64 * scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64),
96
+ DepthWise(64 * scale, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128),
97
+ Residual(64 * scale, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
98
+ DepthWise(64 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256),
99
+ Residual(128 * scale, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
100
+ DepthWise(128 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512),
101
+ Residual(128 * scale, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
102
+ )
103
+ self.conv_sep = ConvBlock(128 * scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
104
+ self.features = GDC(num_features)
105
+ self._initialize_weights()
106
+
107
+ def _initialize_weights(self):
108
+ for m in self.modules():
109
+ if isinstance(m, nn.Conv2d):
110
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
111
+ if m.bias is not None:
112
+ m.bias.data.zero_()
113
+ elif isinstance(m, nn.BatchNorm2d):
114
+ m.weight.data.fill_(1)
115
+ m.bias.data.zero_()
116
+ elif isinstance(m, nn.Linear):
117
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
118
+ if m.bias is not None:
119
+ m.bias.data.zero_()
120
+
121
+ def forward(self, x):
122
+ with torch.cuda.amp.autocast(self.fp16):
123
+ x = self.layers(x)
124
+ x = self.conv_sep(x.float() if self.fp16 else x)
125
+ x = self.features(x)
126
+ return x
127
+
128
+
129
+ def get_mbf(fp16, num_features):
130
+ return MobileFaceNet(fp16, num_features)
videoretalking/third_part/face3d/models/arcface_torch/configs/3millions.py ADDED
@@ -0,0 +1,23 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # configs for test speed
4
+
5
+ config = edict()
6
+ config.loss = "arcface"
7
+ config.network = "r50"
8
+ config.resume = False
9
+ config.output = None
10
+ config.embedding_size = 512
11
+ config.sample_rate = 1.0
12
+ config.fp16 = True
13
+ config.momentum = 0.9
14
+ config.weight_decay = 5e-4
15
+ config.batch_size = 128
16
+ config.lr = 0.1 # batch size is 512
17
+
18
+ config.rec = "synthetic"
19
+ config.num_classes = 300 * 10000
20
+ config.num_epoch = 30
21
+ config.warmup_epoch = -1
22
+ config.decay_epoch = [10, 16, 22]
23
+ config.val_targets = []
videoretalking/third_part/face3d/models/arcface_torch/configs/3millions_pfc.py ADDED
@@ -0,0 +1,23 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # configs for test speed
4
+
5
+ config = edict()
6
+ config.loss = "arcface"
7
+ config.network = "r50"
8
+ config.resume = False
9
+ config.output = None
10
+ config.embedding_size = 512
11
+ config.sample_rate = 0.1
12
+ config.fp16 = True
13
+ config.momentum = 0.9
14
+ config.weight_decay = 5e-4
15
+ config.batch_size = 128
16
+ config.lr = 0.1 # batch size is 512
17
+
18
+ config.rec = "synthetic"
19
+ config.num_classes = 300 * 10000
20
+ config.num_epoch = 30
21
+ config.warmup_epoch = -1
22
+ config.decay_epoch = [10, 16, 22]
23
+ config.val_targets = []
videoretalking/third_part/face3d/models/arcface_torch/configs/__init__.py ADDED
File without changes
videoretalking/third_part/face3d/models/arcface_torch/configs/base.py ADDED
@@ -0,0 +1,56 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # make training faster
4
+ # our RAM is 256G
5
+ # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
+
7
+ config = edict()
8
+ config.loss = "arcface"
9
+ config.network = "r50"
10
+ config.resume = False
11
+ config.output = "ms1mv3_arcface_r50"
12
+
13
+ config.dataset = "ms1m-retinaface-t1"
14
+ config.embedding_size = 512
15
+ config.sample_rate = 1
16
+ config.fp16 = False
17
+ config.momentum = 0.9
18
+ config.weight_decay = 5e-4
19
+ config.batch_size = 128
20
+ config.lr = 0.1 # batch size is 512
21
+
22
+ if config.dataset == "emore":
23
+ config.rec = "/train_tmp/faces_emore"
24
+ config.num_classes = 85742
25
+ config.num_image = 5822653
26
+ config.num_epoch = 16
27
+ config.warmup_epoch = -1
28
+ config.decay_epoch = [8, 14, ]
29
+ config.val_targets = ["lfw", ]
30
+
31
+ elif config.dataset == "ms1m-retinaface-t1":
32
+ config.rec = "/train_tmp/ms1m-retinaface-t1"
33
+ config.num_classes = 93431
34
+ config.num_image = 5179510
35
+ config.num_epoch = 25
36
+ config.warmup_epoch = -1
37
+ config.decay_epoch = [11, 17, 22]
38
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
39
+
40
+ elif config.dataset == "glint360k":
41
+ config.rec = "/train_tmp/glint360k"
42
+ config.num_classes = 360232
43
+ config.num_image = 17091657
44
+ config.num_epoch = 20
45
+ config.warmup_epoch = -1
46
+ config.decay_epoch = [8, 12, 15, 18]
47
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
48
+
49
+ elif config.dataset == "webface":
50
+ config.rec = "/train_tmp/faces_webface_112x112"
51
+ config.num_classes = 10572
52
+ config.num_image = "forget"
53
+ config.num_epoch = 34
54
+ config.warmup_epoch = -1
55
+ config.decay_epoch = [20, 28, 32]
56
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
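Each config file in this directory is a plain Python module that exposes an `easydict` named `config`; the training entry point picks one of them by name. A hedged sketch of loading such a config dynamically (the exact loader used by `train.py` is not shown in this diff, and `easydict` must be installed):

```python
import importlib

# Hypothetical loader: resolve a config module by name and read its fields.
cfg = importlib.import_module("configs.glint360k_r50").config
print(cfg.network, cfg.loss, cfg.batch_size, cfg.num_classes)  # r50 cosface 128 360232
```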
videoretalking/third_part/face3d/models/arcface_torch/configs/glint360k_mbf.py ADDED
@@ -0,0 +1,26 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # make training faster
4
+ # our RAM is 256G
5
+ # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
+
7
+ config = edict()
8
+ config.loss = "cosface"
9
+ config.network = "mbf"
10
+ config.resume = False
11
+ config.output = None
12
+ config.embedding_size = 512
13
+ config.sample_rate = 0.1
14
+ config.fp16 = True
15
+ config.momentum = 0.9
16
+ config.weight_decay = 2e-4
17
+ config.batch_size = 128
18
+ config.lr = 0.1 # batch size is 512
19
+
20
+ config.rec = "/train_tmp/glint360k"
21
+ config.num_classes = 360232
22
+ config.num_image = 17091657
23
+ config.num_epoch = 20
24
+ config.warmup_epoch = -1
25
+ config.decay_epoch = [8, 12, 15, 18]
26
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
videoretalking/third_part/face3d/models/arcface_torch/configs/glint360k_r100.py ADDED
@@ -0,0 +1,26 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # make training faster
4
+ # our RAM is 256G
5
+ # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
+
7
+ config = edict()
8
+ config.loss = "cosface"
9
+ config.network = "r100"
10
+ config.resume = False
11
+ config.output = None
12
+ config.embedding_size = 512
13
+ config.sample_rate = 1.0
14
+ config.fp16 = True
15
+ config.momentum = 0.9
16
+ config.weight_decay = 5e-4
17
+ config.batch_size = 128
18
+ config.lr = 0.1 # batch size is 512
19
+
20
+ config.rec = "/train_tmp/glint360k"
21
+ config.num_classes = 360232
22
+ config.num_image = 17091657
23
+ config.num_epoch = 20
24
+ config.warmup_epoch = -1
25
+ config.decay_epoch = [8, 12, 15, 18]
26
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
videoretalking/third_part/face3d/models/arcface_torch/configs/glint360k_r18.py ADDED
@@ -0,0 +1,26 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # make training faster
4
+ # our RAM is 256G
5
+ # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
+
7
+ config = edict()
8
+ config.loss = "cosface"
9
+ config.network = "r18"
10
+ config.resume = False
11
+ config.output = None
12
+ config.embedding_size = 512
13
+ config.sample_rate = 1.0
14
+ config.fp16 = True
15
+ config.momentum = 0.9
16
+ config.weight_decay = 5e-4
17
+ config.batch_size = 128
18
+ config.lr = 0.1 # batch size is 512
19
+
20
+ config.rec = "/train_tmp/glint360k"
21
+ config.num_classes = 360232
22
+ config.num_image = 17091657
23
+ config.num_epoch = 20
24
+ config.warmup_epoch = -1
25
+ config.decay_epoch = [8, 12, 15, 18]
26
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
videoretalking/third_part/face3d/models/arcface_torch/configs/glint360k_r34.py ADDED
@@ -0,0 +1,26 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # make training faster
4
+ # our RAM is 256G
5
+ # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
+
7
+ config = edict()
8
+ config.loss = "cosface"
9
+ config.network = "r34"
10
+ config.resume = False
11
+ config.output = None
12
+ config.embedding_size = 512
13
+ config.sample_rate = 1.0
14
+ config.fp16 = True
15
+ config.momentum = 0.9
16
+ config.weight_decay = 5e-4
17
+ config.batch_size = 128
18
+ config.lr = 0.1 # batch size is 512
19
+
20
+ config.rec = "/train_tmp/glint360k"
21
+ config.num_classes = 360232
22
+ config.num_image = 17091657
23
+ config.num_epoch = 20
24
+ config.warmup_epoch = -1
25
+ config.decay_epoch = [8, 12, 15, 18]
26
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
videoretalking/third_part/face3d/models/arcface_torch/configs/glint360k_r50.py ADDED
@@ -0,0 +1,26 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # make training faster
4
+ # our RAM is 256G
5
+ # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
+
7
+ config = edict()
8
+ config.loss = "cosface"
9
+ config.network = "r50"
10
+ config.resume = False
11
+ config.output = None
12
+ config.embedding_size = 512
13
+ config.sample_rate = 1.0
14
+ config.fp16 = True
15
+ config.momentum = 0.9
16
+ config.weight_decay = 5e-4
17
+ config.batch_size = 128
18
+ config.lr = 0.1 # batch size is 512
19
+
20
+ config.rec = "/train_tmp/glint360k"
21
+ config.num_classes = 360232
22
+ config.num_image = 17091657
23
+ config.num_epoch = 20
24
+ config.warmup_epoch = -1
25
+ config.decay_epoch = [8, 12, 15, 18]
26
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
videoretalking/third_part/face3d/models/arcface_torch/configs/ms1mv3_mbf.py ADDED
@@ -0,0 +1,26 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # make training faster
4
+ # our RAM is 256G
5
+ # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
+
7
+ config = edict()
8
+ config.loss = "arcface"
9
+ config.network = "mbf"
10
+ config.resume = False
11
+ config.output = None
12
+ config.embedding_size = 512
13
+ config.sample_rate = 1.0
14
+ config.fp16 = True
15
+ config.momentum = 0.9
16
+ config.weight_decay = 2e-4
17
+ config.batch_size = 128
18
+ config.lr = 0.1 # batch size is 512
19
+
20
+ config.rec = "/train_tmp/ms1m-retinaface-t1"
21
+ config.num_classes = 93431
22
+ config.num_image = 5179510
23
+ config.num_epoch = 30
24
+ config.warmup_epoch = -1
25
+ config.decay_epoch = [10, 20, 25]
26
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
videoretalking/third_part/face3d/models/arcface_torch/configs/ms1mv3_r18.py ADDED
@@ -0,0 +1,26 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # make training faster
4
+ # our RAM is 256G
5
+ # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
+
7
+ config = edict()
8
+ config.loss = "arcface"
9
+ config.network = "r18"
10
+ config.resume = False
11
+ config.output = None
12
+ config.embedding_size = 512
13
+ config.sample_rate = 1.0
14
+ config.fp16 = True
15
+ config.momentum = 0.9
16
+ config.weight_decay = 5e-4
17
+ config.batch_size = 128
18
+ config.lr = 0.1 # batch size is 512
19
+
20
+ config.rec = "/train_tmp/ms1m-retinaface-t1"
21
+ config.num_classes = 93431
22
+ config.num_image = 5179510
23
+ config.num_epoch = 25
24
+ config.warmup_epoch = -1
25
+ config.decay_epoch = [10, 16, 22]
26
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
videoretalking/third_part/face3d/models/arcface_torch/configs/ms1mv3_r2060.py ADDED
@@ -0,0 +1,26 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # make training faster
4
+ # our RAM is 256G
5
+ # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
+
7
+ config = edict()
8
+ config.loss = "arcface"
9
+ config.network = "r2060"
10
+ config.resume = False
11
+ config.output = None
12
+ config.embedding_size = 512
13
+ config.sample_rate = 1.0
14
+ config.fp16 = True
15
+ config.momentum = 0.9
16
+ config.weight_decay = 5e-4
17
+ config.batch_size = 64
18
+ config.lr = 0.1 # batch size is 512
19
+
20
+ config.rec = "/train_tmp/ms1m-retinaface-t1"
21
+ config.num_classes = 93431
22
+ config.num_image = 5179510
23
+ config.num_epoch = 25
24
+ config.warmup_epoch = -1
25
+ config.decay_epoch = [10, 16, 22]
26
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
videoretalking/third_part/face3d/models/arcface_torch/configs/ms1mv3_r34.py ADDED
@@ -0,0 +1,26 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # make training faster
4
+ # our RAM is 256G
5
+ # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
+
7
+ config = edict()
8
+ config.loss = "arcface"
9
+ config.network = "r34"
10
+ config.resume = False
11
+ config.output = None
12
+ config.embedding_size = 512
13
+ config.sample_rate = 1.0
14
+ config.fp16 = True
15
+ config.momentum = 0.9
16
+ config.weight_decay = 5e-4
17
+ config.batch_size = 128
18
+ config.lr = 0.1 # batch size is 512
19
+
20
+ config.rec = "/train_tmp/ms1m-retinaface-t1"
21
+ config.num_classes = 93431
22
+ config.num_image = 5179510
23
+ config.num_epoch = 25
24
+ config.warmup_epoch = -1
25
+ config.decay_epoch = [10, 16, 22]
26
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
videoretalking/third_part/face3d/models/arcface_torch/configs/ms1mv3_r50.py ADDED
@@ -0,0 +1,26 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # make training faster
4
+ # our RAM is 256G
5
+ # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
+
7
+ config = edict()
8
+ config.loss = "arcface"
9
+ config.network = "r50"
10
+ config.resume = False
11
+ config.output = None
12
+ config.embedding_size = 512
13
+ config.sample_rate = 1.0
14
+ config.fp16 = True
15
+ config.momentum = 0.9
16
+ config.weight_decay = 5e-4
17
+ config.batch_size = 128
18
+ config.lr = 0.1 # batch size is 512
19
+
20
+ config.rec = "/train_tmp/ms1m-retinaface-t1"
21
+ config.num_classes = 93431
22
+ config.num_image = 5179510
23
+ config.num_epoch = 25
24
+ config.warmup_epoch = -1
25
+ config.decay_epoch = [10, 16, 22]
26
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
videoretalking/third_part/face3d/models/arcface_torch/configs/speed.py ADDED
@@ -0,0 +1,23 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # configs for test speed
4
+
5
+ config = edict()
6
+ config.loss = "arcface"
7
+ config.network = "r50"
8
+ config.resume = False
9
+ config.output = None
10
+ config.embedding_size = 512
11
+ config.sample_rate = 1.0
12
+ config.fp16 = True
13
+ config.momentum = 0.9
14
+ config.weight_decay = 5e-4
15
+ config.batch_size = 128
16
+ config.lr = 0.1 # batch size is 512
17
+
18
+ config.rec = "synthetic"
19
+ config.num_classes = 100 * 10000
20
+ config.num_epoch = 30
21
+ config.warmup_epoch = -1
22
+ config.decay_epoch = [10, 16, 22]
23
+ config.val_targets = []
videoretalking/third_part/face3d/models/arcface_torch/dataset.py ADDED
@@ -0,0 +1,124 @@
1
+ import numbers
2
+ import os
3
+ import queue as Queue
4
+ import threading
5
+
6
+ import mxnet as mx
7
+ import numpy as np
8
+ import torch
9
+ from torch.utils.data import DataLoader, Dataset
10
+ from torchvision import transforms
11
+
12
+
13
+ class BackgroundGenerator(threading.Thread):
14
+ def __init__(self, generator, local_rank, max_prefetch=6):
15
+ super(BackgroundGenerator, self).__init__()
16
+ self.queue = Queue.Queue(max_prefetch)
17
+ self.generator = generator
18
+ self.local_rank = local_rank
19
+ self.daemon = True
20
+ self.start()
21
+
22
+ def run(self):
23
+ torch.cuda.set_device(self.local_rank)
24
+ for item in self.generator:
25
+ self.queue.put(item)
26
+ self.queue.put(None)
27
+
28
+ def next(self):
29
+ next_item = self.queue.get()
30
+ if next_item is None:
31
+ raise StopIteration
32
+ return next_item
33
+
34
+ def __next__(self):
35
+ return self.next()
36
+
37
+ def __iter__(self):
38
+ return self
39
+
40
+
41
+ class DataLoaderX(DataLoader):
42
+
43
+ def __init__(self, local_rank, **kwargs):
44
+ super(DataLoaderX, self).__init__(**kwargs)
45
+ self.stream = torch.cuda.Stream(local_rank)
46
+ self.local_rank = local_rank
47
+
48
+ def __iter__(self):
49
+ self.iter = super(DataLoaderX, self).__iter__()
50
+ self.iter = BackgroundGenerator(self.iter, self.local_rank)
51
+ self.preload()
52
+ return self
53
+
54
+ def preload(self):
55
+ self.batch = next(self.iter, None)
56
+ if self.batch is None:
57
+ return None
58
+ with torch.cuda.stream(self.stream):
59
+ for k in range(len(self.batch)):
60
+ self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True)
61
+
62
+ def __next__(self):
63
+ torch.cuda.current_stream().wait_stream(self.stream)
64
+ batch = self.batch
65
+ if batch is None:
66
+ raise StopIteration
67
+ self.preload()
68
+ return batch
69
+
70
+
71
+ class MXFaceDataset(Dataset):
72
+ def __init__(self, root_dir, local_rank):
73
+ super(MXFaceDataset, self).__init__()
74
+ self.transform = transforms.Compose(
75
+ [transforms.ToPILImage(),
76
+ transforms.RandomHorizontalFlip(),
77
+ transforms.ToTensor(),
78
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
79
+ ])
80
+ self.root_dir = root_dir
81
+ self.local_rank = local_rank
82
+ path_imgrec = os.path.join(root_dir, 'train.rec')
83
+ path_imgidx = os.path.join(root_dir, 'train.idx')
84
+ self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')
85
+ s = self.imgrec.read_idx(0)
86
+ header, _ = mx.recordio.unpack(s)
87
+ if header.flag > 0:
88
+ self.header0 = (int(header.label[0]), int(header.label[1]))
89
+ self.imgidx = np.array(range(1, int(header.label[0])))
90
+ else:
91
+ self.imgidx = np.array(list(self.imgrec.keys))
92
+
93
+ def __getitem__(self, index):
94
+ idx = self.imgidx[index]
95
+ s = self.imgrec.read_idx(idx)
96
+ header, img = mx.recordio.unpack(s)
97
+ label = header.label
98
+ if not isinstance(label, numbers.Number):
99
+ label = label[0]
100
+ label = torch.tensor(label, dtype=torch.long)
101
+ sample = mx.image.imdecode(img).asnumpy()
102
+ if self.transform is not None:
103
+ sample = self.transform(sample)
104
+ return sample, label
105
+
106
+ def __len__(self):
107
+ return len(self.imgidx)
108
+
109
+
110
+ class SyntheticDataset(Dataset):
111
+ def __init__(self, local_rank):
112
+ super(SyntheticDataset, self).__init__()
113
+ img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32)
114
+ img = np.transpose(img, (2, 0, 1))
115
+ img = torch.from_numpy(img).squeeze(0).float()
116
+ img = ((img / 255) - 0.5) / 0.5
117
+ self.img = img
118
+ self.label = 1
119
+
120
+ def __getitem__(self, index):
121
+ return self.img, self.label
122
+
123
+ def __len__(self):
124
+ return 1000000
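A hedged sketch of exercising these datasets outside distributed training. The real training loop wraps `MXFaceDataset` in `DataLoaderX`, which prefetches batches onto a CUDA stream and therefore requires a GPU; the GPU-free `SyntheticDataset` with a plain `DataLoader` is enough to check tensor shapes:

```python
from torch.utils.data import DataLoader

# SyntheticDataset ignores local_rank beyond the constructor signature, so 0 is fine on CPU.
dataset = SyntheticDataset(local_rank=0)
loader = DataLoader(dataset, batch_size=8, num_workers=0)

images, labels = next(iter(loader))
print(images.shape, labels.shape)  # torch.Size([8, 3, 112, 112]) torch.Size([8])
```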
videoretalking/third_part/face3d/models/arcface_torch/docs/eval.md ADDED
@@ -0,0 +1,31 @@
1
+ ## Eval on ICCV2021-MFR
2
+
3
+ Coming soon.
4
+
5
+
6
+ ## Eval on IJB-C
7
+ You can evaluate IJB-C with either PyTorch or ONNX.
8
+
9
+
10
+ 1. Eval IJB-C with ONNX
11
+ ```shell
12
+ CUDA_VISIBLE_DEVICES=0 python onnx_ijbc.py --model-root ms1mv3_arcface_r50 --image-path IJB_release/IJBC --result-dir ms1mv3_arcface_r50
13
+ ```
14
+
15
+ 2. Eval IJB-C with PyTorch
16
+ ```shell
17
+ CUDA_VISIBLE_DEVICES=0,1 python eval_ijbc.py \
18
+ --model-prefix ms1mv3_arcface_r50/backbone.pth \
19
+ --image-path IJB_release/IJBC \
20
+ --result-dir ms1mv3_arcface_r50 \
21
+ --batch-size 128 \
22
+ --job ms1mv3_arcface_r50 \
23
+ --target IJBC \
24
+ --network iresnet50
25
+ ```
26
+
27
+ ## Inference
28
+
29
+ ```shell
30
+ python inference.py --weight ms1mv3_arcface_r50/backbone.pth --network r50
31
+ ```
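A hedged sketch of what such an inference call does internally: build the backbone with `get_model`, load the checkpoint, and normalize a 112x112 aligned crop before the forward pass (this mirrors, but does not reproduce verbatim, the logic in `inference.py`; the image path is hypothetical):

```python
import cv2
import torch

from backbones import get_model

net = get_model("r50", fp16=False)
net.load_state_dict(torch.load("ms1mv3_arcface_r50/backbone.pth", map_location="cpu"))
net.eval()

img = cv2.imread("face.jpg")                      # hypothetical aligned face crop
img = cv2.resize(img, (112, 112))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).float()
img = img.div(255).sub(0.5).div(0.5)              # same normalization as training
with torch.no_grad():
    feat = net(img)                               # (1, 512) embedding
```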
videoretalking/third_part/face3d/models/arcface_torch/docs/install.md ADDED
@@ -0,0 +1,51 @@
1
+ ## v1.8.0
2
+ ### Linux and Windows
3
+ ```shell
4
+ # CUDA 11.0
5
+ pip --default-timeout=100 install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
6
+
7
+ # CUDA 10.2
8
+ pip --default-timeout=100 install torch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0
9
+
10
+ # CPU only
11
+ pip --default-timeout=100 install torch==1.8.0+cpu torchvision==0.9.0+cpu torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
12
+
13
+ ```
14
+
15
+
16
+ ## v1.7.1
17
+ ### Linux and Windows
18
+ ```shell
19
+ # CUDA 11.0
20
+ pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
21
+
22
+ # CUDA 10.2
23
+ pip install torch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2
24
+
25
+ # CUDA 10.1
26
+ pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
27
+
28
+ # CUDA 9.2
29
+ pip install torch==1.7.1+cu92 torchvision==0.8.2+cu92 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
30
+
31
+ # CPU only
32
+ pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
33
+ ```
34
+
35
+
36
+ ## v1.6.0
37
+
38
+ ### Linux and Windows
39
+ ```shell
40
+ # CUDA 10.2
41
+ pip install torch==1.6.0 torchvision==0.7.0
42
+
43
+ # CUDA 10.1
44
+ pip install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
45
+
46
+ # CUDA 9.2
47
+ pip install torch==1.6.0+cu92 torchvision==0.7.0+cu92 -f https://download.pytorch.org/whl/torch_stable.html
48
+
49
+ # CPU only
50
+ pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
51
+ ```
videoretalking/third_part/face3d/models/arcface_torch/docs/modelzoo.md ADDED
File without changes
videoretalking/third_part/face3d/models/arcface_torch/docs/speed_benchmark.md ADDED
@@ -0,0 +1,93 @@
1
+ ## Test Training Speed
2
+
3
+ - Test Commands
4
+
5
+ Use the following two commands to test Partial FC training performance.
6
+ The number of identities is **3 million** (synthetic data), mixed-precision training is enabled, the backbone is ResNet-50,
7
+ and the batch size is 1024.
8
+ ```shell
9
+ # Model Parallel
10
+ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions
11
+ # Partial FC 0.1
12
+ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions_pfc
13
+ ```
14
+
15
+ - GPU Memory
16
+
17
+ ```
18
+ # (Model Parallel) gpustat -i
19
+ [0] Tesla V100-SXM2-32GB | 64'C, 94 % | 30338 / 32510 MB
20
+ [1] Tesla V100-SXM2-32GB | 60'C, 99 % | 28876 / 32510 MB
21
+ [2] Tesla V100-SXM2-32GB | 60'C, 99 % | 28872 / 32510 MB
22
+ [3] Tesla V100-SXM2-32GB | 69'C, 99 % | 28872 / 32510 MB
23
+ [4] Tesla V100-SXM2-32GB | 66'C, 99 % | 28888 / 32510 MB
24
+ [5] Tesla V100-SXM2-32GB | 60'C, 99 % | 28932 / 32510 MB
25
+ [6] Tesla V100-SXM2-32GB | 68'C, 100 % | 28916 / 32510 MB
26
+ [7] Tesla V100-SXM2-32GB | 65'C, 99 % | 28860 / 32510 MB
27
+
28
+ # (Partial FC 0.1) gpustat -i
29
+ [0] Tesla V100-SXM2-32GB | 60'C, 95 % | 10488 / 32510 MB │·······················
30
+ [1] Tesla V100-SXM2-32GB | 60'C, 97 % | 10344 / 32510 MB │·······················
31
+ [2] Tesla V100-SXM2-32GB | 61'C, 95 % | 10340 / 32510 MB │·······················
32
+ [3] Tesla V100-SXM2-32GB | 66'C, 95 % | 10340 / 32510 MB │·······················
33
+ [4] Tesla V100-SXM2-32GB | 65'C, 94 % | 10356 / 32510 MB │·······················
34
+ [5] Tesla V100-SXM2-32GB | 61'C, 95 % | 10400 / 32510 MB │·······················
35
+ [6] Tesla V100-SXM2-32GB | 68'C, 96 % | 10384 / 32510 MB │·······················
36
+ [7] Tesla V100-SXM2-32GB | 64'C, 95 % | 10328 / 32510 MB │·······················
37
+ ```
38
+
39
+ - Training Speed
40
+
41
+ ```python
42
+ # (Model Parallel) training.log
43
+ Training: Speed 2271.33 samples/sec Loss 1.1624 LearningRate 0.2000 Epoch: 0 Global Step: 100
44
+ Training: Speed 2269.94 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150
45
+ Training: Speed 2272.67 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200
46
+ Training: Speed 2266.55 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250
47
+ Training: Speed 2272.54 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300
48
+
49
+ # (Partial FC 0.1) training.log
50
+ Training: Speed 5299.56 samples/sec Loss 1.0965 LearningRate 0.2000 Epoch: 0 Global Step: 100
51
+ Training: Speed 5296.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150
52
+ Training: Speed 5304.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200
53
+ Training: Speed 5274.43 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250
54
+ Training: Speed 5300.10 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300
55
+ ```
56
+
57
+ In this test case, Partial FC 0.1 uses only about 1/3 of the GPU memory of model parallel,
58
+ and its training speed is about 2.5 times that of model parallel.
59
+
60
+
61
+ ## Speed Benchmark
62
+
63
+ 1. Training speed of different parallel methods (samples/second), Tesla V100 32GB * 8. (Larger is better)
64
+
65
+ | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
66
+ | :--- | :--- | :--- | :--- |
67
+ |125000 | 4681 | 4824 | 5004 |
68
+ |250000 | 4047 | 4521 | 4976 |
69
+ |500000 | 3087 | 4013 | 4900 |
70
+ |1000000 | 2090 | 3449 | 4803 |
71
+ |1400000 | 1672 | 3043 | 4738 |
72
+ |2000000 | - | 2593 | 4626 |
73
+ |4000000 | - | 1748 | 4208 |
74
+ |5500000 | - | 1389 | 3975 |
75
+ |8000000 | - | - | 3565 |
76
+ |16000000 | - | - | 2679 |
77
+ |29000000 | - | - | 1855 |
78
+
79
+ 2. GPU memory cost of different parallel methods (GB per GPU), Tesla V100 32GB * 8. (Smaller is better)
80
+
81
+ | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
82
+ | :--- | :--- | :--- | :--- |
83
+ |125000 | 7358 | 5306 | 4868 |
84
+ |250000 | 9940 | 5826 | 5004 |
85
+ |500000 | 14220 | 7114 | 5202 |
86
+ |1000000 | 23708 | 9966 | 5620 |
87
+ |1400000 | 32252 | 11178 | 6056 |
88
+ |2000000 | - | 13978 | 6472 |
89
+ |4000000 | - | 23238 | 8284 |
90
+ |5500000 | - | 32188 | 9854 |
91
+ |8000000 | - | - | 12310 |
92
+ |16000000 | - | - | 19950 |
93
+ |29000000 | - | - | 32324 |
videoretalking/third_part/face3d/models/arcface_torch/eval/__init__.py ADDED
File without changes
videoretalking/third_part/face3d/models/arcface_torch/eval/verification.py ADDED
@@ -0,0 +1,407 @@
1
+ """Helper for evaluation on the Labeled Faces in the Wild dataset
2
+ """
3
+
4
+ # MIT License
5
+ #
6
+ # Copyright (c) 2016 David Sandberg
7
+ #
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to the following conditions:
14
+ #
15
+ # The above copyright notice and this permission notice shall be included in all
16
+ # copies or substantial portions of the Software.
17
+ #
18
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ # SOFTWARE.
25
+
26
+
27
+ import datetime
28
+ import os
29
+ import pickle
30
+
31
+ import mxnet as mx
32
+ import numpy as np
33
+ import sklearn
34
+ import torch
35
+ from mxnet import ndarray as nd
36
+ from scipy import interpolate
37
+ from sklearn.decomposition import PCA
38
+ from sklearn.model_selection import KFold
39
+
40
+
41
+ class LFold:
42
+ def __init__(self, n_splits=2, shuffle=False):
43
+ self.n_splits = n_splits
44
+ if self.n_splits > 1:
45
+ self.k_fold = KFold(n_splits=n_splits, shuffle=shuffle)
46
+
47
+ def split(self, indices):
48
+ if self.n_splits > 1:
49
+ return self.k_fold.split(indices)
50
+ else:
51
+ return [(indices, indices)]
52
+
53
+
54
+ def calculate_roc(thresholds,
55
+ embeddings1,
56
+ embeddings2,
57
+ actual_issame,
58
+ nrof_folds=10,
59
+ pca=0):
60
+ assert (embeddings1.shape[0] == embeddings2.shape[0])
61
+ assert (embeddings1.shape[1] == embeddings2.shape[1])
62
+ nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
63
+ nrof_thresholds = len(thresholds)
64
+ k_fold = LFold(n_splits=nrof_folds, shuffle=False)
65
+
66
+ tprs = np.zeros((nrof_folds, nrof_thresholds))
67
+ fprs = np.zeros((nrof_folds, nrof_thresholds))
68
+ accuracy = np.zeros((nrof_folds))
69
+ indices = np.arange(nrof_pairs)
70
+
71
+ if pca == 0:
72
+ diff = np.subtract(embeddings1, embeddings2)
73
+ dist = np.sum(np.square(diff), 1)
74
+
75
+ for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
76
+ if pca > 0:
77
+ print('doing pca on', fold_idx)
78
+ embed1_train = embeddings1[train_set]
79
+ embed2_train = embeddings2[train_set]
80
+ _embed_train = np.concatenate((embed1_train, embed2_train), axis=0)
81
+ pca_model = PCA(n_components=pca)
82
+ pca_model.fit(_embed_train)
83
+ embed1 = pca_model.transform(embeddings1)
84
+ embed2 = pca_model.transform(embeddings2)
85
+ embed1 = sklearn.preprocessing.normalize(embed1)
86
+ embed2 = sklearn.preprocessing.normalize(embed2)
87
+ diff = np.subtract(embed1, embed2)
88
+ dist = np.sum(np.square(diff), 1)
89
+
90
+ # Find the best threshold for the fold
91
+ acc_train = np.zeros((nrof_thresholds))
92
+ for threshold_idx, threshold in enumerate(thresholds):
93
+ _, _, acc_train[threshold_idx] = calculate_accuracy(
94
+ threshold, dist[train_set], actual_issame[train_set])
95
+ best_threshold_index = np.argmax(acc_train)
96
+ for threshold_idx, threshold in enumerate(thresholds):
97
+ tprs[fold_idx, threshold_idx], fprs[fold_idx, threshold_idx], _ = calculate_accuracy(
98
+ threshold, dist[test_set],
99
+ actual_issame[test_set])
100
+ _, _, accuracy[fold_idx] = calculate_accuracy(
101
+ thresholds[best_threshold_index], dist[test_set],
102
+ actual_issame[test_set])
103
+
104
+ tpr = np.mean(tprs, 0)
105
+ fpr = np.mean(fprs, 0)
106
+ return tpr, fpr, accuracy
107
+
108
+
109
+ def calculate_accuracy(threshold, dist, actual_issame):
110
+ predict_issame = np.less(dist, threshold)
111
+ tp = np.sum(np.logical_and(predict_issame, actual_issame))
112
+ fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))
113
+ tn = np.sum(
114
+ np.logical_and(np.logical_not(predict_issame),
115
+ np.logical_not(actual_issame)))
116
+ fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame))
117
+
118
+ tpr = 0 if (tp + fn == 0) else float(tp) / float(tp + fn)
119
+ fpr = 0 if (fp + tn == 0) else float(fp) / float(fp + tn)
120
+ acc = float(tp + tn) / dist.size
121
+ return tpr, fpr, acc
122
+
123
+
124
+ def calculate_val(thresholds,
125
+ embeddings1,
126
+ embeddings2,
127
+ actual_issame,
128
+ far_target,
129
+ nrof_folds=10):
130
+ assert (embeddings1.shape[0] == embeddings2.shape[0])
131
+ assert (embeddings1.shape[1] == embeddings2.shape[1])
132
+ nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
133
+ nrof_thresholds = len(thresholds)
134
+ k_fold = LFold(n_splits=nrof_folds, shuffle=False)
135
+
136
+ val = np.zeros(nrof_folds)
137
+ far = np.zeros(nrof_folds)
138
+
139
+ diff = np.subtract(embeddings1, embeddings2)
140
+ dist = np.sum(np.square(diff), 1)
141
+ indices = np.arange(nrof_pairs)
142
+
143
+ for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
144
+
145
+ # Find the threshold that gives FAR = far_target
146
+ far_train = np.zeros(nrof_thresholds)
147
+ for threshold_idx, threshold in enumerate(thresholds):
148
+ _, far_train[threshold_idx] = calculate_val_far(
149
+ threshold, dist[train_set], actual_issame[train_set])
150
+ if np.max(far_train) >= far_target:
151
+ f = interpolate.interp1d(far_train, thresholds, kind='slinear')
152
+ threshold = f(far_target)
153
+ else:
154
+ threshold = 0.0
155
+
156
+ val[fold_idx], far[fold_idx] = calculate_val_far(
157
+ threshold, dist[test_set], actual_issame[test_set])
158
+
159
+ val_mean = np.mean(val)
160
+ far_mean = np.mean(far)
161
+ val_std = np.std(val)
162
+ return val_mean, val_std, far_mean
163
+
164
+
165
+ def calculate_val_far(threshold, dist, actual_issame):
166
+ predict_issame = np.less(dist, threshold)
167
+ true_accept = np.sum(np.logical_and(predict_issame, actual_issame))
168
+ false_accept = np.sum(
169
+ np.logical_and(predict_issame, np.logical_not(actual_issame)))
170
+ n_same = np.sum(actual_issame)
171
+ n_diff = np.sum(np.logical_not(actual_issame))
172
+ # print(true_accept, false_accept)
173
+ # print(n_same, n_diff)
174
+ val = float(true_accept) / float(n_same)
175
+ far = float(false_accept) / float(n_diff)
176
+ return val, far
177
+
178
+
179
+ def evaluate(embeddings, actual_issame, nrof_folds=10, pca=0):
180
+ # Calculate evaluation metrics
181
+ thresholds = np.arange(0, 4, 0.01)
182
+ embeddings1 = embeddings[0::2]
183
+ embeddings2 = embeddings[1::2]
184
+ tpr, fpr, accuracy = calculate_roc(thresholds,
185
+ embeddings1,
186
+ embeddings2,
187
+ np.asarray(actual_issame),
188
+ nrof_folds=nrof_folds,
189
+ pca=pca)
190
+ thresholds = np.arange(0, 4, 0.001)
191
+ val, val_std, far = calculate_val(thresholds,
192
+ embeddings1,
193
+ embeddings2,
194
+ np.asarray(actual_issame),
195
+ 1e-3,
196
+ nrof_folds=nrof_folds)
197
+ return tpr, fpr, accuracy, val, val_std, far
198
+
199
+ @torch.no_grad()
200
+ def load_bin(path, image_size):
201
+ try:
202
+ with open(path, 'rb') as f:
203
+ bins, issame_list = pickle.load(f) # py2
204
+ except UnicodeDecodeError as e:
205
+ with open(path, 'rb') as f:
206
+ bins, issame_list = pickle.load(f, encoding='bytes') # py3
207
+ data_list = []
208
+ for flip in [0, 1]:
209
+ data = torch.empty((len(issame_list) * 2, 3, image_size[0], image_size[1]))
210
+ data_list.append(data)
211
+ for idx in range(len(issame_list) * 2):
212
+ _bin = bins[idx]
213
+ img = mx.image.imdecode(_bin)
214
+ if img.shape[1] != image_size[0]:
215
+ img = mx.image.resize_short(img, image_size[0])
216
+ img = nd.transpose(img, axes=(2, 0, 1))
217
+ for flip in [0, 1]:
218
+ if flip == 1:
219
+ img = mx.ndarray.flip(data=img, axis=2)
220
+ data_list[flip][idx][:] = torch.from_numpy(img.asnumpy())
221
+ if idx % 1000 == 0:
222
+ print('loading bin', idx)
223
+ print(data_list[0].shape)
224
+ return data_list, issame_list
225
+
226
+ @torch.no_grad()
227
+ def test(data_set, backbone, batch_size, nfolds=10):
228
+ print('testing verification..')
229
+ data_list = data_set[0]
230
+ issame_list = data_set[1]
231
+ embeddings_list = []
232
+ time_consumed = 0.0
233
+ for i in range(len(data_list)):
234
+ data = data_list[i]
235
+ embeddings = None
236
+ ba = 0
237
+ while ba < data.shape[0]:
238
+ bb = min(ba + batch_size, data.shape[0])
239
+ count = bb - ba
240
+ _data = data[bb - batch_size: bb]
241
+ time0 = datetime.datetime.now()
242
+ img = ((_data / 255) - 0.5) / 0.5
243
+ net_out: torch.Tensor = backbone(img)
244
+ _embeddings = net_out.detach().cpu().numpy()
245
+ time_now = datetime.datetime.now()
246
+ diff = time_now - time0
247
+ time_consumed += diff.total_seconds()
248
+ if embeddings is None:
249
+ embeddings = np.zeros((data.shape[0], _embeddings.shape[1]))
250
+ embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :]
251
+ ba = bb
252
+ embeddings_list.append(embeddings)
253
+
254
+ _xnorm = 0.0
255
+ _xnorm_cnt = 0
256
+ for embed in embeddings_list:
257
+ for i in range(embed.shape[0]):
258
+ _em = embed[i]
259
+ _norm = np.linalg.norm(_em)
260
+ _xnorm += _norm
261
+ _xnorm_cnt += 1
262
+ _xnorm /= _xnorm_cnt
263
+
264
+ acc1 = 0.0
265
+ std1 = 0.0
266
+ embeddings = embeddings_list[0] + embeddings_list[1]
267
+ embeddings = sklearn.preprocessing.normalize(embeddings)
268
+ print(embeddings.shape)
269
+ print('infer time', time_consumed)
270
+ _, _, accuracy, val, val_std, far = evaluate(embeddings, issame_list, nrof_folds=nfolds)
271
+ acc2, std2 = np.mean(accuracy), np.std(accuracy)
272
+ return acc1, std1, acc2, std2, _xnorm, embeddings_list
273
+
274
+
275
+ def dumpR(data_set,
276
+ backbone,
277
+ batch_size,
278
+ name='',
279
+ data_extra=None,
280
+ label_shape=None):
281
+ print('dump verification embedding..')
282
+ data_list = data_set[0]
283
+ issame_list = data_set[1]
284
+ embeddings_list = []
285
+ time_consumed = 0.0
286
+ for i in range(len(data_list)):
287
+ data = data_list[i]
288
+ embeddings = None
289
+ ba = 0
290
+ while ba < data.shape[0]:
291
+ bb = min(ba + batch_size, data.shape[0])
292
+ count = bb - ba
293
+
294
+ _data = nd.slice_axis(data, axis=0, begin=bb - batch_size, end=bb)
295
+ time0 = datetime.datetime.now()
296
+ if data_extra is None:
297
+ db = mx.io.DataBatch(data=(_data,), label=(_label,))
298
+ else:
299
+ db = mx.io.DataBatch(data=(_data, _data_extra),
300
+ label=(_label,))
301
+ model.forward(db, is_train=False)
302
+ net_out = model.get_outputs()
303
+ _embeddings = net_out[0].asnumpy()
304
+ time_now = datetime.datetime.now()
305
+ diff = time_now - time0
306
+ time_consumed += diff.total_seconds()
307
+ if embeddings is None:
308
+ embeddings = np.zeros((data.shape[0], _embeddings.shape[1]))
309
+ embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :]
310
+ ba = bb
311
+ embeddings_list.append(embeddings)
312
+ embeddings = embeddings_list[0] + embeddings_list[1]
313
+ embeddings = sklearn.preprocessing.normalize(embeddings)
314
+ actual_issame = np.asarray(issame_list)
315
+ outname = os.path.join('temp.bin')
316
+ with open(outname, 'wb') as f:
317
+ pickle.dump((embeddings, issame_list),
318
+ f,
319
+ protocol=pickle.HIGHEST_PROTOCOL)
320
+
321
+
322
+ # if __name__ == '__main__':
323
+ #
324
+ # parser = argparse.ArgumentParser(description='do verification')
325
+ # # general
326
+ # parser.add_argument('--data-dir', default='', help='')
327
+ # parser.add_argument('--model',
328
+ # default='../model/softmax,50',
329
+ # help='path to load model.')
330
+ # parser.add_argument('--target',
331
+ # default='lfw,cfp_ff,cfp_fp,agedb_30',
332
+ # help='test targets.')
333
+ # parser.add_argument('--gpu', default=0, type=int, help='gpu id')
334
+ # parser.add_argument('--batch-size', default=32, type=int, help='')
335
+ # parser.add_argument('--max', default='', type=str, help='')
336
+ # parser.add_argument('--mode', default=0, type=int, help='')
337
+ # parser.add_argument('--nfolds', default=10, type=int, help='')
338
+ # args = parser.parse_args()
339
+ # image_size = [112, 112]
340
+ # print('image_size', image_size)
341
+ # ctx = mx.gpu(args.gpu)
342
+ # nets = []
343
+ # vec = args.model.split(',')
344
+ # prefix = args.model.split(',')[0]
345
+ # epochs = []
346
+ # if len(vec) == 1:
347
+ # pdir = os.path.dirname(prefix)
348
+ # for fname in os.listdir(pdir):
349
+ # if not fname.endswith('.params'):
350
+ # continue
351
+ # _file = os.path.join(pdir, fname)
352
+ # if _file.startswith(prefix):
353
+ # epoch = int(fname.split('.')[0].split('-')[1])
354
+ # epochs.append(epoch)
355
+ # epochs = sorted(epochs, reverse=True)
356
+ # if len(args.max) > 0:
357
+ # _max = [int(x) for x in args.max.split(',')]
358
+ # assert len(_max) == 2
359
+ # if len(epochs) > _max[1]:
360
+ # epochs = epochs[_max[0]:_max[1]]
361
+ #
362
+ # else:
363
+ # epochs = [int(x) for x in vec[1].split('|')]
364
+ # print('model number', len(epochs))
365
+ # time0 = datetime.datetime.now()
366
+ # for epoch in epochs:
367
+ # print('loading', prefix, epoch)
368
+ # sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
369
+ # # arg_params, aux_params = ch_dev(arg_params, aux_params, ctx)
370
+ # all_layers = sym.get_internals()
371
+ # sym = all_layers['fc1_output']
372
+ # model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
373
+ # # model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))])
374
+ # model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0],
375
+ # image_size[1]))])
376
+ # model.set_params(arg_params, aux_params)
377
+ # nets.append(model)
378
+ # time_now = datetime.datetime.now()
379
+ # diff = time_now - time0
380
+ # print('model loading time', diff.total_seconds())
381
+ #
382
+ # ver_list = []
383
+ # ver_name_list = []
384
+ # for name in args.target.split(','):
385
+ # path = os.path.join(args.data_dir, name + ".bin")
386
+ # if os.path.exists(path):
387
+ # print('loading.. ', name)
388
+ # data_set = load_bin(path, image_size)
389
+ # ver_list.append(data_set)
390
+ # ver_name_list.append(name)
391
+ #
392
+ # if args.mode == 0:
393
+ # for i in range(len(ver_list)):
394
+ # results = []
395
+ # for model in nets:
396
+ # acc1, std1, acc2, std2, xnorm, embeddings_list = test(
397
+ # ver_list[i], model, args.batch_size, args.nfolds)
398
+ # print('[%s]XNorm: %f' % (ver_name_list[i], xnorm))
399
+ # print('[%s]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], acc1, std1))
400
+ # print('[%s]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], acc2, std2))
401
+ # results.append(acc2)
402
+ # print('Max of [%s] is %1.5f' % (ver_name_list[i], np.max(results)))
403
+ # elif args.mode == 1:
404
+ # raise ValueError
405
+ # else:
406
+ # model = nets[0]
407
+ # dumpR(ver_list[0], model, args.batch_size, args.target)
videoretalking/third_part/face3d/models/arcface_torch/eval_ijbc.py ADDED
@@ -0,0 +1,483 @@
1
+ # coding: utf-8
2
+
3
+ import os
4
+ import pickle
5
+
6
+ import matplotlib
7
+ import pandas as pd
8
+
9
+ matplotlib.use('Agg')
10
+ import matplotlib.pyplot as plt
11
+ import timeit
12
+ import sklearn
13
+ import argparse
14
+ import cv2
15
+ import numpy as np
16
+ import torch
17
+ from skimage import transform as trans
18
+ from backbones import get_model
19
+ from sklearn.metrics import roc_curve, auc
20
+
21
+ from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap
22
+ from prettytable import PrettyTable
23
+ from pathlib import Path
24
+
25
+ import sys
26
+ import warnings
27
+
28
+ sys.path.insert(0, "../")
29
+ warnings.filterwarnings("ignore")
30
+
31
+ parser = argparse.ArgumentParser(description='do ijb test')
32
+ # general
33
+ parser.add_argument('--model-prefix', default='', help='path to load model.')
34
+ parser.add_argument('--image-path', default='', type=str, help='')
35
+ parser.add_argument('--result-dir', default='.', type=str, help='')
36
+ parser.add_argument('--batch-size', default=128, type=int, help='')
37
+ parser.add_argument('--network', default='iresnet50', type=str, help='')
38
+ parser.add_argument('--job', default='insightface', type=str, help='job name')
39
+ parser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB')
40
+ args = parser.parse_args()
41
+
42
+ target = args.target
43
+ model_path = args.model_prefix
44
+ image_path = args.image_path
45
+ result_dir = args.result_dir
46
+ gpu_id = None
47
+ use_norm_score = True # if True, TestMode(N1)
48
+ use_detector_score = True # if True, TestMode(D1)
49
+ use_flip_test = True # if True, TestMode(F1)
50
+ job = args.job
51
+ batch_size = args.batch_size
52
+
53
+
54
+ class Embedding(object):
55
+ def __init__(self, prefix, data_shape, batch_size=1):
56
+ image_size = (112, 112)
57
+ self.image_size = image_size
58
+ weight = torch.load(prefix)
59
+ resnet = get_model(args.network, dropout=0, fp16=False).cuda()
60
+ resnet.load_state_dict(weight)
61
+ model = torch.nn.DataParallel(resnet)
62
+ self.model = model
63
+ self.model.eval()
64
+ src = np.array([
65
+ [30.2946, 51.6963],
66
+ [65.5318, 51.5014],
67
+ [48.0252, 71.7366],
68
+ [33.5493, 92.3655],
69
+ [62.7299, 92.2041]], dtype=np.float32)
70
+ src[:, 0] += 8.0
71
+ self.src = src
72
+ self.batch_size = batch_size
73
+ self.data_shape = data_shape
74
+
75
+ def get(self, rimg, landmark):
76
+
77
+ assert landmark.shape[0] == 68 or landmark.shape[0] == 5
78
+ assert landmark.shape[1] == 2
79
+ if landmark.shape[0] == 68:
80
+ landmark5 = np.zeros((5, 2), dtype=np.float32)
81
+ landmark5[0] = (landmark[36] + landmark[39]) / 2
82
+ landmark5[1] = (landmark[42] + landmark[45]) / 2
83
+ landmark5[2] = landmark[30]
84
+ landmark5[3] = landmark[48]
85
+ landmark5[4] = landmark[54]
86
+ else:
87
+ landmark5 = landmark
88
+ tform = trans.SimilarityTransform()
89
+ tform.estimate(landmark5, self.src)
90
+ M = tform.params[0:2, :]
91
+ img = cv2.warpAffine(rimg,
92
+ M, (self.image_size[1], self.image_size[0]),
93
+ borderValue=0.0)
94
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
95
+ img_flip = np.fliplr(img)
96
+ img = np.transpose(img, (2, 0, 1)) # 3*112*112, RGB
97
+ img_flip = np.transpose(img_flip, (2, 0, 1))
98
+ input_blob = np.zeros((2, 3, self.image_size[1], self.image_size[0]), dtype=np.uint8)
99
+ input_blob[0] = img
100
+ input_blob[1] = img_flip
101
+ return input_blob
102
+
103
+ @torch.no_grad()
104
+ def forward_db(self, batch_data):
105
+ imgs = torch.Tensor(batch_data).cuda()
106
+ imgs.div_(255).sub_(0.5).div_(0.5)
107
+ feat = self.model(imgs)
108
+ feat = feat.reshape([self.batch_size, 2 * feat.shape[1]])
109
+ return feat.cpu().numpy()
110
+
111
+
112
+ # Split a list into n parts as evenly as possible; len(result) == n, and if n exceeds the number of elements the extra parts are empty lists []
113
+ def divideIntoNstrand(listTemp, n):
114
+ twoList = [[] for i in range(n)]
115
+ for i, e in enumerate(listTemp):
116
+ twoList[i % n].append(e)
117
+ return twoList
118
+
119
+
120
+ def read_template_media_list(path):
121
+ # ijb_meta = np.loadtxt(path, dtype=str)
122
+ ijb_meta = pd.read_csv(path, sep=' ', header=None).values
123
+ templates = ijb_meta[:, 1].astype(int)
124
+ medias = ijb_meta[:, 2].astype(int)
125
+ return templates, medias
126
+
127
+
128
+ # In[ ]:
129
+
130
+
131
+ def read_template_pair_list(path):
132
+ # pairs = np.loadtxt(path, dtype=str)
133
+ pairs = pd.read_csv(path, sep=' ', header=None).values
134
+ # print(pairs.shape)
135
+ # print(pairs[:, 0].astype(np.int))
136
+ t1 = pairs[:, 0].astype(int)
137
+ t2 = pairs[:, 1].astype(int)
138
+ label = pairs[:, 2].astype(int)
139
+ return t1, t2, label
140
+
141
+
142
+ # In[ ]:
143
+
144
+
145
+ def read_image_feature(path):
146
+ with open(path, 'rb') as fid:
147
+ img_feats = pickle.load(fid)
148
+ return img_feats
149
+
150
+
151
+ # In[ ]:
152
+
153
+
154
+ def get_image_feature(img_path, files_list, model_path, epoch, gpu_id):
155
+ batch_size = args.batch_size
156
+ data_shape = (3, 112, 112)
157
+
158
+ files = files_list
159
+ print('files:', len(files))
160
+ rare_size = len(files) % batch_size
161
+ faceness_scores = []
162
+ batch = 0
163
+ img_feats = np.empty((len(files), 1024), dtype=np.float32)
164
+
165
+ batch_data = np.empty((2 * batch_size, 3, 112, 112))
166
+ embedding = Embedding(model_path, data_shape, batch_size)
167
+ for img_index, each_line in enumerate(files[:len(files) - rare_size]):
168
+ name_lmk_score = each_line.strip().split(' ')
169
+ img_name = os.path.join(img_path, name_lmk_score[0])
170
+ img = cv2.imread(img_name)
171
+ lmk = np.array([float(x) for x in name_lmk_score[1:-1]],
172
+ dtype=np.float32)
173
+ lmk = lmk.reshape((5, 2))
174
+ input_blob = embedding.get(img, lmk)
175
+
176
+ batch_data[2 * (img_index - batch * batch_size)][:] = input_blob[0]
177
+ batch_data[2 * (img_index - batch * batch_size) + 1][:] = input_blob[1]
178
+ if (img_index + 1) % batch_size == 0:
179
+ print('batch', batch)
180
+ img_feats[batch * batch_size:batch * batch_size +
181
+ batch_size][:] = embedding.forward_db(batch_data)
182
+ batch += 1
183
+ faceness_scores.append(name_lmk_score[-1])
184
+
185
+ batch_data = np.empty((2 * rare_size, 3, 112, 112))
186
+ embedding = Embedding(model_path, data_shape, rare_size)
187
+ for img_index, each_line in enumerate(files[len(files) - rare_size:]):
188
+ name_lmk_score = each_line.strip().split(' ')
189
+ img_name = os.path.join(img_path, name_lmk_score[0])
190
+ img = cv2.imread(img_name)
191
+ lmk = np.array([float(x) for x in name_lmk_score[1:-1]],
192
+ dtype=np.float32)
193
+ lmk = lmk.reshape((5, 2))
194
+ input_blob = embedding.get(img, lmk)
195
+ batch_data[2 * img_index][:] = input_blob[0]
196
+ batch_data[2 * img_index + 1][:] = input_blob[1]
197
+ if (img_index + 1) % rare_size == 0:
198
+ print('batch', batch)
199
+ img_feats[len(files) -
200
+ rare_size:][:] = embedding.forward_db(batch_data)
201
+ batch += 1
202
+ faceness_scores.append(name_lmk_score[-1])
203
+ faceness_scores = np.array(faceness_scores).astype(np.float32)
204
+ # img_feats = np.ones( (len(files), 1024), dtype=np.float32) * 0.01
205
+ # faceness_scores = np.ones( (len(files), ), dtype=np.float32 )
206
+ return img_feats, faceness_scores
207
+
208
+
209
+ # In[ ]:
210
+
211
+
212
+ def image2template_feature(img_feats=None, templates=None, medias=None):
213
+ # ==========================================================
214
+ # 1. face image feature l2 normalization. img_feats:[number_image x feats_dim]
215
+ # 2. compute media feature.
216
+ # 3. compute template feature.
217
+ # ==========================================================
218
+ unique_templates = np.unique(templates)
219
+ template_feats = np.zeros((len(unique_templates), img_feats.shape[1]))
220
+
221
+ for count_template, uqt in enumerate(unique_templates):
222
+
223
+ (ind_t,) = np.where(templates == uqt)
224
+ face_norm_feats = img_feats[ind_t]
225
+ face_medias = medias[ind_t]
226
+ unique_medias, unique_media_counts = np.unique(face_medias,
227
+ return_counts=True)
228
+ media_norm_feats = []
229
+ for u, ct in zip(unique_medias, unique_media_counts):
230
+ (ind_m,) = np.where(face_medias == u)
231
+ if ct == 1:
232
+ media_norm_feats += [face_norm_feats[ind_m]]
233
+ else: # image features from the same video will be aggregated into one feature
234
+ media_norm_feats += [
235
+ np.mean(face_norm_feats[ind_m], axis=0, keepdims=True)
236
+ ]
237
+ media_norm_feats = np.array(media_norm_feats)
238
+ # media_norm_feats = media_norm_feats / np.sqrt(np.sum(media_norm_feats ** 2, -1, keepdims=True))
239
+ template_feats[count_template] = np.sum(media_norm_feats, axis=0)
240
+ if count_template % 2000 == 0:
241
+ print('Finish Calculating {} template features.'.format(
242
+ count_template))
243
+ # template_norm_feats = template_feats / np.sqrt(np.sum(template_feats ** 2, -1, keepdims=True))
244
+ template_norm_feats = sklearn.preprocessing.normalize(template_feats)
245
+ # print(template_norm_feats.shape)
246
+ return template_norm_feats, unique_templates
247
+
248
+
249
+ # In[ ]:
250
+
251
+
252
+ def verification(template_norm_feats=None,
253
+ unique_templates=None,
254
+ p1=None,
255
+ p2=None):
256
+ # ==========================================================
257
+ # Compute set-to-set Similarity Score.
258
+ # ==========================================================
259
+ template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
260
+ for count_template, uqt in enumerate(unique_templates):
261
+ template2id[uqt] = count_template
262
+
263
+ score = np.zeros((len(p1),)) # save cosine distance between pairs
264
+
265
+ total_pairs = np.array(range(len(p1)))
266
+ batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limitation
267
+ sublists = [
268
+ total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)
269
+ ]
270
+ total_sublists = len(sublists)
271
+ for c, s in enumerate(sublists):
272
+ feat1 = template_norm_feats[template2id[p1[s]]]
273
+ feat2 = template_norm_feats[template2id[p2[s]]]
274
+ similarity_score = np.sum(feat1 * feat2, -1)
275
+ score[s] = similarity_score.flatten()
276
+ if c % 10 == 0:
277
+ print('Finish {}/{} pairs.'.format(c, total_sublists))
278
+ return score
279
+
280
+
281
+ # In[ ]:
282
+ def verification2(template_norm_feats=None,
283
+ unique_templates=None,
284
+ p1=None,
285
+ p2=None):
286
+ template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
287
+ for count_template, uqt in enumerate(unique_templates):
288
+ template2id[uqt] = count_template
289
+ score = np.zeros((len(p1),)) # save cosine distance between pairs
290
+ total_pairs = np.array(range(len(p1)))
291
+ batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limitation
292
+ sublists = [
293
+ total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)
294
+ ]
295
+ total_sublists = len(sublists)
296
+ for c, s in enumerate(sublists):
297
+ feat1 = template_norm_feats[template2id[p1[s]]]
298
+ feat2 = template_norm_feats[template2id[p2[s]]]
299
+ similarity_score = np.sum(feat1 * feat2, -1)
300
+ score[s] = similarity_score.flatten()
301
+ if c % 10 == 0:
302
+ print('Finish {}/{} pairs.'.format(c, total_sublists))
303
+ return score
304
+
305
+
306
+ def read_score(path):
307
+ with open(path, 'rb') as fid:
308
+ img_feats = pickle.load(fid)
309
+ return img_feats
310
+
311
+
312
+ # # Step1: Load Meta Data
313
+
314
+ # In[ ]:
315
+
316
+ assert target == 'IJBC' or target == 'IJBB'
317
+
318
+ # =============================================================
319
+ # load image and template relationships for template feature embedding
320
+ # tid --> template id, mid --> media id
321
+ # format:
322
+ # image_name tid mid
323
+ # =============================================================
324
+ start = timeit.default_timer()
325
+ templates, medias = read_template_media_list(
326
+ os.path.join('%s/meta' % image_path,
327
+ '%s_face_tid_mid.txt' % target.lower()))
328
+ stop = timeit.default_timer()
329
+ print('Time: %.2f s. ' % (stop - start))
330
+
331
+ # In[ ]:
332
+
333
+ # =============================================================
334
+ # load template pairs for template-to-template verification
335
+ # tid : template id, label : 1/0
336
+ # format:
337
+ # tid_1 tid_2 label
338
+ # =============================================================
339
+ start = timeit.default_timer()
340
+ p1, p2, label = read_template_pair_list(
341
+ os.path.join('%s/meta' % image_path,
342
+ '%s_template_pair_label.txt' % target.lower()))
343
+ stop = timeit.default_timer()
344
+ print('Time: %.2f s. ' % (stop - start))
345
+
346
+ # # Step 2: Get Image Features
347
+
348
+ # In[ ]:
349
+
350
+ # =============================================================
351
+ # load image features
352
+ # format:
353
+ # img_feats: [image_num x feats_dim] (227630, 512)
354
+ # =============================================================
355
+ start = timeit.default_timer()
356
+ img_path = '%s/loose_crop' % image_path
357
+ img_list_path = '%s/meta/%s_name_5pts_score.txt' % (image_path, target.lower())
358
+ img_list = open(img_list_path)
359
+ files = img_list.readlines()
360
+ # files_list = divideIntoNstrand(files, rank_size)
361
+ files_list = files
362
+
363
+ # img_feats
364
+ # for i in range(rank_size):
365
+ img_feats, faceness_scores = get_image_feature(img_path, files_list,
366
+ model_path, 0, gpu_id)
367
+ stop = timeit.default_timer()
368
+ print('Time: %.2f s. ' % (stop - start))
369
+ print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0],
370
+ img_feats.shape[1]))
371
+
372
+ # # Step3: Get Template Features
373
+
374
+ # In[ ]:
375
+
376
+ # =============================================================
377
+ # compute template features from image features.
378
+ # =============================================================
379
+ start = timeit.default_timer()
380
+ # ==========================================================
381
+ # Norm feature before aggregation into template feature?
382
+ # Feature norm from embedding network and faceness score are able to decrease weights for noise samples (not face).
383
+ # ==========================================================
384
+ # 1. FaceScore (Feature Norm)
385
+ # 2. FaceScore (Detector)
386
+
387
+ if use_flip_test:
388
+ # concat --- F1
389
+ # img_input_feats = img_feats
390
+ # add --- F2
391
+ img_input_feats = img_feats[:, 0:img_feats.shape[1] //
392
+ 2] + img_feats[:, img_feats.shape[1] // 2:]
393
+ else:
394
+ img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2]
395
+
396
+ if use_norm_score:
397
+ img_input_feats = img_input_feats
398
+ else:
399
+ # normalise features to remove norm information
400
+ img_input_feats = img_input_feats / np.sqrt(
401
+ np.sum(img_input_feats ** 2, -1, keepdims=True))
402
+
403
+ if use_detector_score:
404
+ print(img_input_feats.shape, faceness_scores.shape)
405
+ img_input_feats = img_input_feats * faceness_scores[:, np.newaxis]
406
+ else:
407
+ img_input_feats = img_input_feats
408
+
409
+ template_norm_feats, unique_templates = image2template_feature(
410
+ img_input_feats, templates, medias)
411
+ stop = timeit.default_timer()
412
+ print('Time: %.2f s. ' % (stop - start))
413
+
414
+ # # Step 4: Get Template Similarity Scores
415
+
416
+ # In[ ]:
417
+
418
+ # =============================================================
419
+ # compute verification scores between template pairs.
420
+ # =============================================================
421
+ start = timeit.default_timer()
422
+ score = verification(template_norm_feats, unique_templates, p1, p2)
423
+ stop = timeit.default_timer()
424
+ print('Time: %.2f s. ' % (stop - start))
425
+
426
+ # In[ ]:
427
+ save_path = os.path.join(result_dir, args.job)
428
+ # save_path = result_dir + '/%s_result' % target
429
+
430
+ if not os.path.exists(save_path):
431
+ os.makedirs(save_path)
432
+
433
+ score_save_file = os.path.join(save_path, "%s.npy" % target.lower())
434
+ np.save(score_save_file, score)
435
+
436
+ # # Step 5: Get ROC Curves and TPR@FPR Table
437
+
438
+ # In[ ]:
439
+
440
+ files = [score_save_file]
441
+ methods = []
442
+ scores = []
443
+ for file in files:
444
+ methods.append(Path(file).stem)
445
+ scores.append(np.load(file))
446
+
447
+ methods = np.array(methods)
448
+ scores = dict(zip(methods, scores))
449
+ colours = dict(
450
+ zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2')))
451
+ x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1]
452
+ tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels])
453
+ fig = plt.figure()
454
+ for method in methods:
455
+ fpr, tpr, _ = roc_curve(label, scores[method])
456
+ roc_auc = auc(fpr, tpr)
457
+ fpr = np.flipud(fpr)
458
+ tpr = np.flipud(tpr) # select largest tpr at same fpr
459
+ plt.plot(fpr,
460
+ tpr,
461
+ color=colours[method],
462
+ lw=1,
463
+ label=('[%s (AUC = %0.4f %%)]' %
464
+ (method.split('-')[-1], roc_auc * 100)))
465
+ tpr_fpr_row = []
466
+ tpr_fpr_row.append("%s-%s" % (method, target))
467
+ for fpr_iter in np.arange(len(x_labels)):
468
+ _, min_index = min(
469
+ list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
470
+ tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))
471
+ tpr_fpr_table.add_row(tpr_fpr_row)
472
+ plt.xlim([10 ** -6, 0.1])
473
+ plt.ylim([0.3, 1.0])
474
+ plt.grid(linestyle='--', linewidth=1)
475
+ plt.xticks(x_labels)
476
+ plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))
477
+ plt.xscale('log')
478
+ plt.xlabel('False Positive Rate')
479
+ plt.ylabel('True Positive Rate')
480
+ plt.title('ROC on IJB')
481
+ plt.legend(loc="lower right")
482
+ fig.savefig(os.path.join(save_path, '%s.pdf' % target.lower()))
483
+ print(tpr_fpr_table)
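
For readers tracing the template pooling in eval_ijbc.py above: image features within one media are mean-pooled, media features are summed per template, and the result is L2-normalized. The toy NumPy sketch below (an editorial illustration, not part of the original script) shows the same aggregation on made-up numbers:

import numpy as np
from sklearn.preprocessing import normalize

img_feats = np.array([[1.0, 0.0],   # image 0, media 0
                      [0.0, 1.0],   # image 1, media 0 (same video -> averaged)
                      [1.0, 1.0]])  # image 2, media 1
medias = np.array([0, 0, 1])        # assume all three images belong to one template

media_feats = [img_feats[medias == m].mean(axis=0, keepdims=True) for m in np.unique(medias)]
template_feat = np.concatenate(media_feats, axis=0).sum(axis=0, keepdims=True)  # [[1.5, 1.5]]
print(normalize(template_feat))     # unit-norm template descriptor, approx [[0.707, 0.707]]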
videoretalking/third_part/face3d/models/arcface_torch/inference.py ADDED
@@ -0,0 +1,35 @@
1
+ import argparse
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+
7
+ from backbones import get_model
8
+
9
+
10
+ @torch.no_grad()
11
+ def inference(weight, name, img):
12
+ if img is None:
13
+ img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.uint8)
14
+ else:
15
+ img = cv2.imread(img)
16
+ img = cv2.resize(img, (112, 112))
17
+
18
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
19
+ img = np.transpose(img, (2, 0, 1))
20
+ img = torch.from_numpy(img).unsqueeze(0).float()
21
+ img.div_(255).sub_(0.5).div_(0.5)
22
+ net = get_model(name, fp16=False)
23
+ net.load_state_dict(torch.load(weight))
24
+ net.eval()
25
+ feat = net(img).numpy()
26
+ print(feat)
27
+
28
+
29
+ if __name__ == "__main__":
30
+ parser = argparse.ArgumentParser(description='PyTorch ArcFace Training')
31
+ parser.add_argument('--network', type=str, default='r50', help='backbone network')
32
+ parser.add_argument('--weight', type=str, default='')
33
+ parser.add_argument('--img', type=str, default=None)
34
+ args = parser.parse_args()
35
+ inference(args.weight, args.network, args.img)
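
A typical invocation of inference.py, assuming a checkpoint produced by this repo's training code (the paths below are placeholders), would be:

python inference.py --network r50 --weight /path/to/backbone.pth --img /path/to/face.jpg

If --img is omitted, the script embeds a random 112x112 image, which is only useful as a smoke test.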
videoretalking/third_part/face3d/models/arcface_torch/losses.py ADDED
@@ -0,0 +1,42 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ def get_loss(name):
6
+ if name == "cosface":
7
+ return CosFace()
8
+ elif name == "arcface":
9
+ return ArcFace()
10
+ else:
11
+ raise ValueError()
12
+
13
+
14
+ class CosFace(nn.Module):
15
+ def __init__(self, s=64.0, m=0.40):
16
+ super(CosFace, self).__init__()
17
+ self.s = s
18
+ self.m = m
19
+
20
+ def forward(self, cosine, label):
21
+ index = torch.where(label != -1)[0]
22
+ m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device)
23
+ m_hot.scatter_(1, label[index, None], self.m)
24
+ cosine[index] -= m_hot
25
+ ret = cosine * self.s
26
+ return ret
27
+
28
+
29
+ class ArcFace(nn.Module):
30
+ def __init__(self, s=64.0, m=0.5):
31
+ super(ArcFace, self).__init__()
32
+ self.s = s
33
+ self.m = m
34
+
35
+ def forward(self, cosine: torch.Tensor, label):
36
+ index = torch.where(label != -1)[0]
37
+ m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device)
38
+ m_hot.scatter_(1, label[index, None], self.m)
39
+ cosine.acos_()
40
+ cosine[index] += m_hot
41
+ cosine.cos_().mul_(self.s)
42
+ return cosine
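
As a quick illustration of how these margin heads are meant to be used (an editorial sketch, not part of the file): both losses rewrite cosine logits in place, and a standard cross-entropy is applied afterwards.

import torch
import torch.nn.functional as F
from losses import get_loss

cosine = torch.tensor([[0.8, 0.1, -0.2],
                       [0.3, 0.7,  0.0]])   # cosine similarities in [-1, 1]
labels = torch.tensor([0, 1])

arcface = get_loss("arcface")
logits = arcface(cosine.clone(), labels)    # margin m=0.5 added to the target angle, scaled by s=64
loss = F.cross_entropy(logits, labels)      # clone() because forward modifies its input in place
print(logits.shape, loss.item())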
videoretalking/third_part/face3d/models/arcface_torch/onnx_helper.py ADDED
@@ -0,0 +1,250 @@
1
+ from __future__ import division
2
+ import datetime
3
+ import os
4
+ import os.path as osp
5
+ import glob
6
+ import numpy as np
7
+ import cv2
8
+ import sys
9
+ import onnxruntime
10
+ import onnx
11
+ import argparse
12
+ from onnx import numpy_helper
13
+ from insightface.data import get_image
14
+
15
+ class ArcFaceORT:
16
+ def __init__(self, model_path, cpu=False):
17
+ self.model_path = model_path
18
+ # providers = None will use available provider, for onnxruntime-gpu it will be "CUDAExecutionProvider"
19
+ self.providers = ['CPUExecutionProvider'] if cpu else None
20
+
21
+ #input_size is (w,h), return error message, return None if success
22
+ def check(self, track='cfat', test_img = None):
23
+ #default is cfat
24
+ max_model_size_mb=1024
25
+ max_feat_dim=512
26
+ max_time_cost=15
27
+ if track.startswith('ms1m'):
28
+ max_model_size_mb=1024
29
+ max_feat_dim=512
30
+ max_time_cost=10
31
+ elif track.startswith('glint'):
32
+ max_model_size_mb=1024
33
+ max_feat_dim=1024
34
+ max_time_cost=20
35
+ elif track.startswith('cfat'):
36
+ max_model_size_mb = 1024
37
+ max_feat_dim = 512
38
+ max_time_cost = 15
39
+ elif track.startswith('unconstrained'):
40
+ max_model_size_mb=1024
41
+ max_feat_dim=1024
42
+ max_time_cost=30
43
+ else:
44
+ return "track not found"
45
+
46
+ if not os.path.exists(self.model_path):
47
+ return "model_path not exists"
48
+ if not os.path.isdir(self.model_path):
49
+ return "model_path should be directory"
50
+ onnx_files = []
51
+ for _file in os.listdir(self.model_path):
52
+ if _file.endswith('.onnx'):
53
+ onnx_files.append(osp.join(self.model_path, _file))
54
+ if len(onnx_files)==0:
55
+ return "do not have onnx files"
56
+ self.model_file = sorted(onnx_files)[-1]
57
+ print('use onnx-model:', self.model_file)
58
+ try:
59
+ session = onnxruntime.InferenceSession(self.model_file, providers=self.providers)
60
+ except:
61
+ return "load onnx failed"
62
+ input_cfg = session.get_inputs()[0]
63
+ input_shape = input_cfg.shape
64
+ print('input-shape:', input_shape)
65
+ if len(input_shape)!=4:
66
+ return "length of input_shape should be 4"
67
+ if not isinstance(input_shape[0], str):
68
+ #return "input_shape[0] should be str to support batch-inference"
69
+ print('reset input-shape[0] to None')
70
+ model = onnx.load(self.model_file)
71
+ model.graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None'
72
+ new_model_file = osp.join(self.model_path, 'zzzzrefined.onnx')
73
+ onnx.save(model, new_model_file)
74
+ self.model_file = new_model_file
75
+ print('use new onnx-model:', self.model_file)
76
+ try:
77
+ session = onnxruntime.InferenceSession(self.model_file, providers=self.providers)
78
+ except:
79
+ return "load onnx failed"
80
+ input_cfg = session.get_inputs()[0]
81
+ input_shape = input_cfg.shape
82
+ print('new-input-shape:', input_shape)
83
+
84
+ self.image_size = tuple(input_shape[2:4][::-1])
85
+ #print('image_size:', self.image_size)
86
+ input_name = input_cfg.name
87
+ outputs = session.get_outputs()
88
+ output_names = []
89
+ for o in outputs:
90
+ output_names.append(o.name)
91
+ #print(o.name, o.shape)
92
+ if len(output_names)!=1:
93
+ return "number of output nodes should be 1"
94
+ self.session = session
95
+ self.input_name = input_name
96
+ self.output_names = output_names
97
+ #print(self.output_names)
98
+ model = onnx.load(self.model_file)
99
+ graph = model.graph
100
+ if len(graph.node)<8:
101
+ return "too small onnx graph"
102
+
103
+ input_size = (112,112)
104
+ self.crop = None
105
+ if track=='cfat':
106
+ crop_file = osp.join(self.model_path, 'crop.txt')
107
+ if osp.exists(crop_file):
108
+ lines = open(crop_file,'r').readlines()
109
+ if len(lines)!=6:
110
+ return "crop.txt should contain 6 lines"
111
+ lines = [int(x) for x in lines]
112
+ self.crop = lines[:4]
113
+ input_size = tuple(lines[4:6])
114
+ if input_size!=self.image_size:
115
+ return "input-size is inconsistent with onnx model input, %s vs %s"%(input_size, self.image_size)
116
+
117
+ self.model_size_mb = os.path.getsize(self.model_file) / float(1024*1024)
118
+ if self.model_size_mb > max_model_size_mb:
119
+ return "max model size exceed, given %.3f-MB"%self.model_size_mb
120
+
121
+ input_mean = None
122
+ input_std = None
123
+ if track=='cfat':
124
+ pn_file = osp.join(self.model_path, 'pixel_norm.txt')
125
+ if osp.exists(pn_file):
126
+ lines = open(pn_file,'r').readlines()
127
+ if len(lines)!=2:
128
+ return "pixel_norm.txt should contain 2 lines"
129
+ input_mean = float(lines[0])
130
+ input_std = float(lines[1])
131
+ if input_mean is not None or input_std is not None:
132
+ if input_mean is None or input_std is None:
133
+ return "please set input_mean and input_std simultaneously"
134
+ else:
135
+ find_sub = False
136
+ find_mul = False
137
+ for nid, node in enumerate(graph.node[:8]):
138
+ print(nid, node.name)
139
+ if node.name.startswith('Sub') or node.name.startswith('_minus'):
140
+ find_sub = True
141
+ if node.name.startswith('Mul') or node.name.startswith('_mul') or node.name.startswith('Div'):
142
+ find_mul = True
143
+ if find_sub and find_mul:
144
+ print("find sub and mul")
145
+ #mxnet arcface model
146
+ input_mean = 0.0
147
+ input_std = 1.0
148
+ else:
149
+ input_mean = 127.5
150
+ input_std = 127.5
151
+ self.input_mean = input_mean
152
+ self.input_std = input_std
153
+ for initn in graph.initializer:
154
+ weight_array = numpy_helper.to_array(initn)
155
+ dt = weight_array.dtype
156
+ if dt.itemsize<4:
157
+ return 'invalid weight type - (%s:%s)' % (initn.name, dt.name)
158
+ if test_img is None:
159
+ test_img = get_image('Tom_Hanks_54745')
160
+ test_img = cv2.resize(test_img, self.image_size)
161
+ else:
162
+ test_img = cv2.resize(test_img, self.image_size)
163
+ feat, cost = self.benchmark(test_img)
164
+ batch_result = self.check_batch(test_img)
165
+ batch_result_sum = float(np.sum(batch_result))
166
+ if batch_result_sum in [float('inf'), -float('inf')] or batch_result_sum != batch_result_sum:
167
+ print(batch_result)
168
+ print(batch_result_sum)
169
+ return "batch result output contains NaN!"
170
+
171
+ if len(feat.shape) < 2:
172
+ return "the shape of the feature must be two, but get {}".format(str(feat.shape))
173
+
174
+ if feat.shape[1] > max_feat_dim:
175
+ return "max feat dim exceed, given %d"%feat.shape[1]
176
+ self.feat_dim = feat.shape[1]
177
+ cost_ms = cost*1000
178
+ if cost_ms>max_time_cost:
179
+ return "max time cost exceed, given %.4f"%cost_ms
180
+ self.cost_ms = cost_ms
181
+ print('check stat:, model-size-mb: %.4f, feat-dim: %d, time-cost-ms: %.4f, input-mean: %.3f, input-std: %.3f'%(self.model_size_mb, self.feat_dim, self.cost_ms, self.input_mean, self.input_std))
182
+ return None
183
+
184
+ def check_batch(self, img):
185
+ if not isinstance(img, list):
186
+ imgs = [img, ] * 32
187
+ if self.crop is not None:
188
+ nimgs = []
189
+ for img in imgs:
190
+ nimg = img[self.crop[1]:self.crop[3], self.crop[0]:self.crop[2], :]
191
+ if nimg.shape[0] != self.image_size[1] or nimg.shape[1] != self.image_size[0]:
192
+ nimg = cv2.resize(nimg, self.image_size)
193
+ nimgs.append(nimg)
194
+ imgs = nimgs
195
+ blob = cv2.dnn.blobFromImages(
196
+ images=imgs, scalefactor=1.0 / self.input_std, size=self.image_size,
197
+ mean=(self.input_mean, self.input_mean, self.input_mean), swapRB=True)
198
+ net_out = self.session.run(self.output_names, {self.input_name: blob})[0]
199
+ return net_out
200
+
201
+
202
+ def meta_info(self):
203
+ return {'model-size-mb':self.model_size_mb, 'feature-dim':self.feat_dim, 'infer': self.cost_ms}
204
+
205
+
206
+ def forward(self, imgs):
207
+ if not isinstance(imgs, list):
208
+ imgs = [imgs]
209
+ input_size = self.image_size
210
+ if self.crop is not None:
211
+ nimgs = []
212
+ for img in imgs:
213
+ nimg = img[self.crop[1]:self.crop[3],self.crop[0]:self.crop[2],:]
214
+ if nimg.shape[0]!=input_size[1] or nimg.shape[1]!=input_size[0]:
215
+ nimg = cv2.resize(nimg, input_size)
216
+ nimgs.append(nimg)
217
+ imgs = nimgs
218
+ blob = cv2.dnn.blobFromImages(imgs, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
219
+ net_out = self.session.run(self.output_names, {self.input_name : blob})[0]
220
+ return net_out
221
+
222
+ def benchmark(self, img):
223
+ input_size = self.image_size
224
+ if self.crop is not None:
225
+ nimg = img[self.crop[1]:self.crop[3],self.crop[0]:self.crop[2],:]
226
+ if nimg.shape[0]!=input_size[1] or nimg.shape[1]!=input_size[0]:
227
+ nimg = cv2.resize(nimg, input_size)
228
+ img = nimg
229
+ blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
230
+ costs = []
231
+ for _ in range(50):
232
+ ta = datetime.datetime.now()
233
+ net_out = self.session.run(self.output_names, {self.input_name : blob})[0]
234
+ tb = datetime.datetime.now()
235
+ cost = (tb-ta).total_seconds()
236
+ costs.append(cost)
237
+ costs = sorted(costs)
238
+ cost = costs[5]
239
+ return net_out, cost
240
+
241
+
242
+ if __name__ == '__main__':
243
+ parser = argparse.ArgumentParser(description='')
244
+ # general
245
+ parser.add_argument('workdir', help='submitted work dir', type=str)
246
+ parser.add_argument('--track', help='track name, for different challenge', type=str, default='cfat')
247
+ args = parser.parse_args()
248
+ handler = ArcFaceORT(args.workdir)
249
+ err = handler.check(args.track)
250
+ print('err:', err)
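
The checker is pointed at a directory containing a single exported .onnx file (plus optional crop.txt / pixel_norm.txt); the directory name below is a placeholder, and insightface must be importable for the default test image:

python onnx_helper.py ./submission_dir --track cfat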
videoretalking/third_part/face3d/models/arcface_torch/onnx_ijbc.py ADDED
@@ -0,0 +1,267 @@
1
+ import argparse
2
+ import os
3
+ import pickle
4
+ import timeit
5
+
6
+ import cv2
7
+ import mxnet as mx
8
+ import numpy as np
9
+ import pandas as pd
10
+ import prettytable
11
+ import skimage.transform
12
+ from sklearn.metrics import roc_curve
13
+ from sklearn.preprocessing import normalize
14
+
15
+ from onnx_helper import ArcFaceORT
16
+
17
+ SRC = np.array(
18
+ [
19
+ [30.2946, 51.6963],
20
+ [65.5318, 51.5014],
21
+ [48.0252, 71.7366],
22
+ [33.5493, 92.3655],
23
+ [62.7299, 92.2041]]
24
+ , dtype=np.float32)
25
+ SRC[:, 0] += 8.0
26
+
27
+
28
+ class AlignedDataSet(mx.gluon.data.Dataset):
29
+ def __init__(self, root, lines, align=True):
30
+ self.lines = lines
31
+ self.root = root
32
+ self.align = align
33
+
34
+ def __len__(self):
35
+ return len(self.lines)
36
+
37
+ def __getitem__(self, idx):
38
+ each_line = self.lines[idx]
39
+ name_lmk_score = each_line.strip().split(' ')
40
+ name = os.path.join(self.root, name_lmk_score[0])
41
+ img = cv2.cvtColor(cv2.imread(name), cv2.COLOR_BGR2RGB)
42
+ landmark5 = np.array([float(x) for x in name_lmk_score[1:-1]], dtype=np.float32).reshape((5, 2))
43
+ st = skimage.transform.SimilarityTransform()
44
+ st.estimate(landmark5, SRC)
45
+ img = cv2.warpAffine(img, st.params[0:2, :], (112, 112), borderValue=0.0)
46
+ img_1 = np.expand_dims(img, 0)
47
+ img_2 = np.expand_dims(np.fliplr(img), 0)
48
+ output = np.concatenate((img_1, img_2), axis=0).astype(np.float32)
49
+ output = np.transpose(output, (0, 3, 1, 2))
50
+ output = mx.nd.array(output)
51
+ return output
52
+
53
+
54
+ def extract(model_root, dataset):
55
+ model = ArcFaceORT(model_path=model_root)
56
+ model.check()
57
+ feat_mat = np.zeros(shape=(len(dataset), 2 * model.feat_dim))
58
+
59
+ def batchify_fn(data):
60
+ return mx.nd.concat(*data, dim=0)
61
+
62
+ data_loader = mx.gluon.data.DataLoader(
63
+ dataset, 128, last_batch='keep', num_workers=4,
64
+ thread_pool=True, prefetch=16, batchify_fn=batchify_fn)
65
+ num_iter = 0
66
+ for batch in data_loader:
67
+ batch = batch.asnumpy()
68
+ batch = (batch - model.input_mean) / model.input_std
69
+ feat = model.session.run(model.output_names, {model.input_name: batch})[0]
70
+ feat = np.reshape(feat, (-1, model.feat_dim * 2))
71
+ feat_mat[128 * num_iter: 128 * num_iter + feat.shape[0], :] = feat
72
+ num_iter += 1
73
+ if num_iter % 50 == 0:
74
+ print(num_iter)
75
+ return feat_mat
76
+
77
+
78
+ def read_template_media_list(path):
79
+ ijb_meta = pd.read_csv(path, sep=' ', header=None).values
80
+ templates = ijb_meta[:, 1].astype(int)
81
+ medias = ijb_meta[:, 2].astype(int)
82
+ return templates, medias
83
+
84
+
85
+ def read_template_pair_list(path):
86
+ pairs = pd.read_csv(path, sep=' ', header=None).values
87
+ t1 = pairs[:, 0].astype(int)
88
+ t2 = pairs[:, 1].astype(int)
89
+ label = pairs[:, 2].astype(int)
90
+ return t1, t2, label
91
+
92
+
93
+ def read_image_feature(path):
94
+ with open(path, 'rb') as fid:
95
+ img_feats = pickle.load(fid)
96
+ return img_feats
97
+
98
+
99
+ def image2template_feature(img_feats=None,
100
+ templates=None,
101
+ medias=None):
102
+ unique_templates = np.unique(templates)
103
+ template_feats = np.zeros((len(unique_templates), img_feats.shape[1]))
104
+ for count_template, uqt in enumerate(unique_templates):
105
+ (ind_t,) = np.where(templates == uqt)
106
+ face_norm_feats = img_feats[ind_t]
107
+ face_medias = medias[ind_t]
108
+ unique_medias, unique_media_counts = np.unique(face_medias, return_counts=True)
109
+ media_norm_feats = []
110
+ for u, ct in zip(unique_medias, unique_media_counts):
111
+ (ind_m,) = np.where(face_medias == u)
112
+ if ct == 1:
113
+ media_norm_feats += [face_norm_feats[ind_m]]
114
+ else: # image features from the same video will be aggregated into one feature
115
+ media_norm_feats += [np.mean(face_norm_feats[ind_m], axis=0, keepdims=True), ]
116
+ media_norm_feats = np.array(media_norm_feats)
117
+ template_feats[count_template] = np.sum(media_norm_feats, axis=0)
118
+ if count_template % 2000 == 0:
119
+ print('Finish Calculating {} template features.'.format(
120
+ count_template))
121
+ template_norm_feats = normalize(template_feats)
122
+ return template_norm_feats, unique_templates
123
+
124
+
125
+ def verification(template_norm_feats=None,
126
+ unique_templates=None,
127
+ p1=None,
128
+ p2=None):
129
+ template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
130
+ for count_template, uqt in enumerate(unique_templates):
131
+ template2id[uqt] = count_template
132
+ score = np.zeros((len(p1),))
133
+ total_pairs = np.array(range(len(p1)))
134
+ batchsize = 100000
135
+ sublists = [total_pairs[i: i + batchsize] for i in range(0, len(p1), batchsize)]
136
+ total_sublists = len(sublists)
137
+ for c, s in enumerate(sublists):
138
+ feat1 = template_norm_feats[template2id[p1[s]]]
139
+ feat2 = template_norm_feats[template2id[p2[s]]]
140
+ similarity_score = np.sum(feat1 * feat2, -1)
141
+ score[s] = similarity_score.flatten()
142
+ if c % 10 == 0:
143
+ print('Finish {}/{} pairs.'.format(c, total_sublists))
144
+ return score
145
+
146
+
147
+ def verification2(template_norm_feats=None,
148
+ unique_templates=None,
149
+ p1=None,
150
+ p2=None):
151
+ template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
152
+ for count_template, uqt in enumerate(unique_templates):
153
+ template2id[uqt] = count_template
154
+ score = np.zeros((len(p1),)) # save cosine distance between pairs
155
+ total_pairs = np.array(range(len(p1)))
156
+ batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limitation
157
+ sublists = [total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)]
158
+ total_sublists = len(sublists)
159
+ for c, s in enumerate(sublists):
160
+ feat1 = template_norm_feats[template2id[p1[s]]]
161
+ feat2 = template_norm_feats[template2id[p2[s]]]
162
+ similarity_score = np.sum(feat1 * feat2, -1)
163
+ score[s] = similarity_score.flatten()
164
+ if c % 10 == 0:
165
+ print('Finish {}/{} pairs.'.format(c, total_sublists))
166
+ return score
167
+
168
+
169
+ def main(args):
170
+ use_norm_score = True # if True, TestMode(N1)
171
+ use_detector_score = True # if True, TestMode(D1)
172
+ use_flip_test = True # if True, TestMode(F1)
173
+ assert args.target == 'IJBC' or args.target == 'IJBB'
174
+
175
+ start = timeit.default_timer()
176
+ templates, medias = read_template_media_list(
177
+ os.path.join('%s/meta' % args.image_path, '%s_face_tid_mid.txt' % args.target.lower()))
178
+ stop = timeit.default_timer()
179
+ print('Time: %.2f s. ' % (stop - start))
180
+
181
+ start = timeit.default_timer()
182
+ p1, p2, label = read_template_pair_list(
183
+ os.path.join('%s/meta' % args.image_path,
184
+ '%s_template_pair_label.txt' % args.target.lower()))
185
+ stop = timeit.default_timer()
186
+ print('Time: %.2f s. ' % (stop - start))
187
+
188
+ start = timeit.default_timer()
189
+ img_path = '%s/loose_crop' % args.image_path
190
+ img_list_path = '%s/meta/%s_name_5pts_score.txt' % (args.image_path, args.target.lower())
191
+ img_list = open(img_list_path)
192
+ files = img_list.readlines()
193
+ dataset = AlignedDataSet(root=img_path, lines=files, align=True)
194
+ img_feats = extract(args.model_root, dataset)
195
+
196
+ faceness_scores = []
197
+ for each_line in files:
198
+ name_lmk_score = each_line.split()
199
+ faceness_scores.append(name_lmk_score[-1])
200
+ faceness_scores = np.array(faceness_scores).astype(np.float32)
201
+ stop = timeit.default_timer()
202
+ print('Time: %.2f s. ' % (stop - start))
203
+ print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0], img_feats.shape[1]))
204
+ start = timeit.default_timer()
205
+
206
+ if use_flip_test:
207
+ img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2] + img_feats[:, img_feats.shape[1] // 2:]
208
+ else:
209
+ img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2]
210
+
211
+ if use_norm_score:
212
+ img_input_feats = img_input_feats
213
+ else:
214
+ img_input_feats = img_input_feats / np.sqrt(np.sum(img_input_feats ** 2, -1, keepdims=True))
215
+
216
+ if use_detector_score:
217
+ print(img_input_feats.shape, faceness_scores.shape)
218
+ img_input_feats = img_input_feats * faceness_scores[:, np.newaxis]
219
+ else:
220
+ img_input_feats = img_input_feats
221
+
222
+ template_norm_feats, unique_templates = image2template_feature(
223
+ img_input_feats, templates, medias)
224
+ stop = timeit.default_timer()
225
+ print('Time: %.2f s. ' % (stop - start))
226
+
227
+ start = timeit.default_timer()
228
+ score = verification(template_norm_feats, unique_templates, p1, p2)
229
+ stop = timeit.default_timer()
230
+ print('Time: %.2f s. ' % (stop - start))
231
+ save_path = os.path.join(args.result_dir, "{}_result".format(args.target))
232
+ if not os.path.exists(save_path):
233
+ os.makedirs(save_path)
234
+ score_save_file = os.path.join(save_path, "{}.npy".format(args.model_root))
235
+ np.save(score_save_file, score)
236
+ files = [score_save_file]
237
+ methods = []
238
+ scores = []
239
+ for file in files:
240
+ methods.append(os.path.basename(file))
241
+ scores.append(np.load(file))
242
+ methods = np.array(methods)
243
+ scores = dict(zip(methods, scores))
244
+ x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1]
245
+ tpr_fpr_table = prettytable.PrettyTable(['Methods'] + [str(x) for x in x_labels])
246
+ for method in methods:
247
+ fpr, tpr, _ = roc_curve(label, scores[method])
248
+ fpr = np.flipud(fpr)
249
+ tpr = np.flipud(tpr)
250
+ tpr_fpr_row = []
251
+ tpr_fpr_row.append("%s-%s" % (method, args.target))
252
+ for fpr_iter in np.arange(len(x_labels)):
253
+ _, min_index = min(
254
+ list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
255
+ tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))
256
+ tpr_fpr_table.add_row(tpr_fpr_row)
257
+ print(tpr_fpr_table)
258
+
259
+
260
+ if __name__ == '__main__':
261
+ parser = argparse.ArgumentParser(description='do ijb test')
262
+ # general
263
+ parser.add_argument('--model-root', default='', help='path to load model.')
264
+ parser.add_argument('--image-path', default='', type=str, help='')
265
+ parser.add_argument('--result-dir', default='.', type=str, help='')
266
+ parser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB')
267
+ main(parser.parse_args())
videoretalking/third_part/face3d/models/arcface_torch/partial_fc.py ADDED
@@ -0,0 +1,222 @@
1
+ import logging
2
+ import os
3
+
4
+ import torch
5
+ import torch.distributed as dist
6
+ from torch.nn import Module
7
+ from torch.nn.functional import normalize, linear
8
+ from torch.nn.parameter import Parameter
9
+
10
+
11
+ class PartialFC(Module):
12
+ """
13
+ Author: {Xiang An, Yang Xiao, XuHan Zhu} in DeepGlint,
14
+ Partial FC: Training 10 Million Identities on a Single Machine
15
+ See the original paper:
16
+ https://arxiv.org/abs/2010.05222
17
+ """
18
+
19
+ @torch.no_grad()
20
+ def __init__(self, rank, local_rank, world_size, batch_size, resume,
21
+ margin_softmax, num_classes, sample_rate=1.0, embedding_size=512, prefix="./"):
22
+ """
23
+ rank: int
24
+ Unique process(GPU) ID from 0 to world_size - 1.
25
+ local_rank: int
26
+ Unique process(GPU) ID within the server from 0 to 7.
27
+ world_size: int
28
+ Number of GPUs.
29
+ batch_size: int
30
+ Batch size on current rank(GPU).
31
+ resume: bool
32
+ Select whether to restore the weight of softmax.
33
+ margin_softmax: callable
34
+ A function of margin softmax, eg: cosface, arcface.
35
+ num_classes: int
36
+ The number of class centers stored on the current rank (CPU/GPU), usually total_classes // world_size,
37
+ required.
38
+ sample_rate: float
39
+ The partial fc sampling rate, when the number of classes increases to more than 2 millions, Sampling
40
+ can greatly speed up training, and reduce a lot of GPU memory, default is 1.0.
41
+ embedding_size: int
42
+ The feature dimension, default is 512.
43
+ prefix: str
44
+ Path for save checkpoint, default is './'.
45
+ """
46
+ super(PartialFC, self).__init__()
47
+ #
48
+ self.num_classes: int = num_classes
49
+ self.rank: int = rank
50
+ self.local_rank: int = local_rank
51
+ self.device: torch.device = torch.device("cuda:{}".format(self.local_rank))
52
+ self.world_size: int = world_size
53
+ self.batch_size: int = batch_size
54
+ self.margin_softmax: callable = margin_softmax
55
+ self.sample_rate: float = sample_rate
56
+ self.embedding_size: int = embedding_size
57
+ self.prefix: str = prefix
58
+ self.num_local: int = num_classes // world_size + int(rank < num_classes % world_size)
59
+ self.class_start: int = num_classes // world_size * rank + min(rank, num_classes % world_size)
60
+ self.num_sample: int = int(self.sample_rate * self.num_local)
61
+
62
+ self.weight_name = os.path.join(self.prefix, "rank_{}_softmax_weight.pt".format(self.rank))
63
+ self.weight_mom_name = os.path.join(self.prefix, "rank_{}_softmax_weight_mom.pt".format(self.rank))
64
+
65
+ if resume:
66
+ try:
67
+ self.weight: torch.Tensor = torch.load(self.weight_name)
68
+ self.weight_mom: torch.Tensor = torch.load(self.weight_mom_name)
69
+ if self.weight.shape[0] != self.num_local or self.weight_mom.shape[0] != self.num_local:
70
+ raise IndexError
71
+ logging.info("softmax weight resume successfully!")
72
+ logging.info("softmax weight mom resume successfully!")
73
+ except (FileNotFoundError, KeyError, IndexError):
74
+ self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
75
+ self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
76
+ logging.info("softmax weight init!")
77
+ logging.info("softmax weight mom init!")
78
+ else:
79
+ self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
80
+ self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
81
+ logging.info("softmax weight init successfully!")
82
+ logging.info("softmax weight mom init successfully!")
83
+ self.stream: torch.cuda.Stream = torch.cuda.Stream(local_rank)
84
+
85
+ self.index = None
86
+ if int(self.sample_rate) == 1:
87
+ self.update = lambda: 0
88
+ self.sub_weight = Parameter(self.weight)
89
+ self.sub_weight_mom = self.weight_mom
90
+ else:
91
+ self.sub_weight = Parameter(torch.empty((0, 0)).cuda(local_rank))
92
+
93
+ def save_params(self):
94
+ """ Save softmax weight for each rank on prefix
95
+ """
96
+ torch.save(self.weight.data, self.weight_name)
97
+ torch.save(self.weight_mom, self.weight_mom_name)
98
+
99
+ @torch.no_grad()
100
+ def sample(self, total_label):
101
+ """
102
+ Sample all positive class centers in each rank, and random select neg class centers to filling a fixed
103
+ `num_sample`.
104
+
105
+ total_label: tensor
106
+ Label after all gather, which cross all GPUs.
107
+ """
108
+ index_positive = (self.class_start <= total_label) & (total_label < self.class_start + self.num_local)
109
+ total_label[~index_positive] = -1
110
+ total_label[index_positive] -= self.class_start
111
+ if int(self.sample_rate) != 1:
112
+ positive = torch.unique(total_label[index_positive], sorted=True)
113
+ if self.num_sample - positive.size(0) >= 0:
114
+ perm = torch.rand(size=[self.num_local], device=self.device)
115
+ perm[positive] = 2.0
116
+ index = torch.topk(perm, k=self.num_sample)[1]
117
+ index = index.sort()[0]
118
+ else:
119
+ index = positive
120
+ self.index = index
121
+ total_label[index_positive] = torch.searchsorted(index, total_label[index_positive])
122
+ self.sub_weight = Parameter(self.weight[index])
123
+ self.sub_weight_mom = self.weight_mom[index]
124
+
125
+ def forward(self, total_features, norm_weight):
126
+ """ Partial fc forward, `logits = X * sample(W)`
127
+ """
128
+ torch.cuda.current_stream().wait_stream(self.stream)
129
+ logits = linear(total_features, norm_weight)
130
+ return logits
131
+
132
+ @torch.no_grad()
133
+ def update(self):
134
+ """ Set updated weight and weight_mom to memory bank.
135
+ """
136
+ self.weight_mom[self.index] = self.sub_weight_mom
137
+ self.weight[self.index] = self.sub_weight
138
+
139
+ def prepare(self, label, optimizer):
140
+ """
141
+ get sampled class centers for cal softmax.
142
+
143
+ label: tensor
144
+ Label tensor on each rank.
145
+ optimizer: opt
146
+ Optimizer for partial fc, which need to get weight mom.
147
+ """
148
+ with torch.cuda.stream(self.stream):
149
+ total_label = torch.zeros(
150
+ size=[self.batch_size * self.world_size], device=self.device, dtype=torch.long)
151
+ dist.all_gather(list(total_label.chunk(self.world_size, dim=0)), label)
152
+ self.sample(total_label)
153
+ optimizer.state.pop(optimizer.param_groups[-1]['params'][0], None)
154
+ optimizer.param_groups[-1]['params'][0] = self.sub_weight
155
+ optimizer.state[self.sub_weight]['momentum_buffer'] = self.sub_weight_mom
156
+ norm_weight = normalize(self.sub_weight)
157
+ return total_label, norm_weight
158
+
159
+ def forward_backward(self, label, features, optimizer):
160
+ """
161
+ Partial fc forward and backward with model parallel
162
+
163
+ label: tensor
164
+ Label tensor on each rank(GPU)
165
+ features: tensor
166
+ Features tensor on each rank(GPU)
167
+ optimizer: optimizer
168
+ Optimizer for partial fc
169
+
170
+ Returns:
171
+ --------
172
+ x_grad: tensor
173
+ The gradient of features.
174
+ loss_v: tensor
175
+ Loss value for cross entropy.
176
+ """
177
+ total_label, norm_weight = self.prepare(label, optimizer)
178
+ total_features = torch.zeros(
179
+ size=[self.batch_size * self.world_size, self.embedding_size], device=self.device)
180
+ dist.all_gather(list(total_features.chunk(self.world_size, dim=0)), features.data)
181
+ total_features.requires_grad = True
182
+
183
+ logits = self.forward(total_features, norm_weight)
184
+ logits = self.margin_softmax(logits, total_label)
185
+
186
+ with torch.no_grad():
187
+ max_fc = torch.max(logits, dim=1, keepdim=True)[0]
188
+ dist.all_reduce(max_fc, dist.ReduceOp.MAX)
189
+
190
+ # calculate exp(logits) and all-reduce
191
+ logits_exp = torch.exp(logits - max_fc)
192
+ logits_sum_exp = logits_exp.sum(dim=1, keepdims=True)
193
+ dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)
194
+
195
+ # calculate prob
196
+ logits_exp.div_(logits_sum_exp)
197
+
198
+ # get one-hot
199
+ grad = logits_exp
200
+ index = torch.where(total_label != -1)[0]
201
+ one_hot = torch.zeros(size=[index.size()[0], grad.size()[1]], device=grad.device)
202
+ one_hot.scatter_(1, total_label[index, None], 1)
203
+
204
+ # calculate loss
205
+ loss = torch.zeros(grad.size()[0], 1, device=grad.device)
206
+ loss[index] = grad[index].gather(1, total_label[index, None])
207
+ dist.all_reduce(loss, dist.ReduceOp.SUM)
208
+ loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1)
209
+
210
+ # calculate grad
211
+ grad[index] -= one_hot
212
+ grad.div_(self.batch_size * self.world_size)
213
+
214
+ logits.backward(grad)
215
+ if total_features.grad is not None:
216
+ total_features.grad.detach_()
217
+ x_grad: torch.Tensor = torch.zeros_like(features, requires_grad=True)
218
+ # feature gradient all-reduce
219
+ dist.reduce_scatter(x_grad, list(total_features.grad.chunk(self.world_size, dim=0)))
220
+ x_grad = x_grad * self.world_size
221
+ # backward backbone
222
+ return x_grad, loss_v
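
The docstring above describes how class centers are sharded across ranks; the arithmetic behind num_local and class_start can be checked with a small standalone sketch (illustrative only, not part of the file):

num_classes, world_size = 10, 3
for rank in range(world_size):
    num_local = num_classes // world_size + int(rank < num_classes % world_size)
    class_start = num_classes // world_size * rank + min(rank, num_classes % world_size)
    print(f"rank {rank}: classes [{class_start}, {class_start + num_local})")
# rank 0: classes [0, 4)   rank 1: classes [4, 7)   rank 2: classes [7, 10)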
videoretalking/third_part/face3d/models/arcface_torch/requirement.txt ADDED
@@ -0,0 +1,5 @@
1
+ tensorboard
2
+ easydict
3
+ mxnet
4
+ onnx
5
+ scikit-learn
videoretalking/third_part/face3d/models/arcface_torch/run.sh ADDED
@@ -0,0 +1,2 @@
1
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50
2
+ ps -ef | grep "train" | grep -v grep | awk '{print "kill -9 "$2}' | sh
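
For debugging, the same entry point can presumably be launched on a single GPU by reducing the process count; everything else is unchanged from the line above:

CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --nproc_per_node=1 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50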
videoretalking/third_part/face3d/models/arcface_torch/torch2onnx.py ADDED
@@ -0,0 +1,59 @@
1
+ import numpy as np
2
+ import onnx
3
+ import torch
4
+
5
+
6
+ def convert_onnx(net, path_module, output, opset=11, simplify=False):
7
+ assert isinstance(net, torch.nn.Module)
8
+ img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32)
9
+ img = img.astype(float)
10
+ img = (img / 255. - 0.5) / 0.5 # torch style norm
11
+ img = img.transpose((2, 0, 1))
12
+ img = torch.from_numpy(img).unsqueeze(0).float()
13
+
14
+ weight = torch.load(path_module)
15
+ net.load_state_dict(weight)
16
+ net.eval()
17
+ torch.onnx.export(net, img, output, keep_initializers_as_inputs=False, verbose=False, opset_version=opset)
18
+ model = onnx.load(output)
19
+ graph = model.graph
20
+ graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None'
21
+ if simplify:
22
+ from onnxsim import simplify
23
+ model, check = simplify(model)
24
+ assert check, "Simplified ONNX model could not be validated"
25
+ onnx.save(model, output)
26
+
27
+
28
+ if __name__ == '__main__':
29
+ import os
30
+ import argparse
31
+ from backbones import get_model
32
+
33
+ parser = argparse.ArgumentParser(description='ArcFace PyTorch to onnx')
34
+ parser.add_argument('input', type=str, help='input backbone.pth file or path')
35
+ parser.add_argument('--output', type=str, default=None, help='output onnx path')
36
+ parser.add_argument('--network', type=str, default=None, help='backbone network')
37
+ parser.add_argument('--simplify', type=bool, default=False, help='onnx simplify')
38
+ args = parser.parse_args()
39
+ input_file = args.input
40
+ if os.path.isdir(input_file):
41
+ input_file = os.path.join(input_file, "backbone.pth")
42
+ assert os.path.exists(input_file)
43
+ model_name = os.path.basename(os.path.dirname(input_file)).lower()
44
+ params = model_name.split("_")
45
+ if len(params) >= 3 and params[1] in ('arcface', 'cosface'):
46
+ if args.network is None:
47
+ args.network = params[2]
48
+ assert args.network is not None
49
+ print(args)
50
+ backbone_onnx = get_model(args.network, dropout=0)
51
+
52
+ output_path = args.output
53
+ if output_path is None:
54
+ output_path = os.path.join(os.path.dirname(__file__), 'onnx')
55
+ if not os.path.exists(output_path):
56
+ os.makedirs(output_path)
57
+ assert os.path.isdir(output_path)
58
+ output_file = os.path.join(output_path, "%s.onnx" % model_name)
59
+ convert_onnx(backbone_onnx, input_file, output_file, simplify=args.simplify)
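
An example conversion call (the checkpoint directory name is a placeholder; when it follows the <dataset>_<loss>_<network> pattern the backbone is inferred from it, otherwise pass --network explicitly):

python torch2onnx.py ms1mv3_arcface_r50/backbone.pth --output onnx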
videoretalking/third_part/face3d/models/arcface_torch/train.py ADDED
@@ -0,0 +1,141 @@
+import argparse
+import logging
+import os
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+import torch.utils.data.distributed
+from torch.nn.utils import clip_grad_norm_
+
+import losses
+from backbones import get_model
+from dataset import MXFaceDataset, SyntheticDataset, DataLoaderX
+from partial_fc import PartialFC
+from utils.utils_amp import MaxClipGradScaler
+from utils.utils_callbacks import CallBackVerification, CallBackLogging, CallBackModelCheckpoint
+from utils.utils_config import get_config
+from utils.utils_logging import AverageMeter, init_logging
+
+
+def main(args):
+    cfg = get_config(args.config)
+    try:
+        world_size = int(os.environ['WORLD_SIZE'])
+        rank = int(os.environ['RANK'])
+        dist.init_process_group('nccl')
+    except KeyError:
+        world_size = 1
+        rank = 0
+        dist.init_process_group(backend='nccl', init_method="tcp://127.0.0.1:12584", rank=rank, world_size=world_size)
+
+    local_rank = args.local_rank
+    torch.cuda.set_device(local_rank)
+    os.makedirs(cfg.output, exist_ok=True)
+    init_logging(rank, cfg.output)
+
+    if cfg.rec == "synthetic":
+        train_set = SyntheticDataset(local_rank=local_rank)
+    else:
+        train_set = MXFaceDataset(root_dir=cfg.rec, local_rank=local_rank)
+
+    train_sampler = torch.utils.data.distributed.DistributedSampler(train_set, shuffle=True)
+    train_loader = DataLoaderX(
+        local_rank=local_rank, dataset=train_set, batch_size=cfg.batch_size,
+        sampler=train_sampler, num_workers=2, pin_memory=True, drop_last=True)
+    backbone = get_model(cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).to(local_rank)
+
+    if cfg.resume:
+        try:
+            backbone_pth = os.path.join(cfg.output, "backbone.pth")
+            backbone.load_state_dict(torch.load(backbone_pth, map_location=torch.device(local_rank)))
+            if rank == 0:
+                logging.info("backbone resume successfully!")
+        except (FileNotFoundError, KeyError, IndexError, RuntimeError):
+            if rank == 0:
+                logging.info("resume fail, backbone init successfully!")
+
+    backbone = torch.nn.parallel.DistributedDataParallel(
+        module=backbone, broadcast_buffers=False, device_ids=[local_rank])
+    backbone.train()
+    margin_softmax = losses.get_loss(cfg.loss)
+    module_partial_fc = PartialFC(
+        rank=rank, local_rank=local_rank, world_size=world_size, resume=cfg.resume,
+        batch_size=cfg.batch_size, margin_softmax=margin_softmax, num_classes=cfg.num_classes,
+        sample_rate=cfg.sample_rate, embedding_size=cfg.embedding_size, prefix=cfg.output)
+
+    opt_backbone = torch.optim.SGD(
+        params=[{'params': backbone.parameters()}],
+        lr=cfg.lr / 512 * cfg.batch_size * world_size,
+        momentum=0.9, weight_decay=cfg.weight_decay)
+    opt_pfc = torch.optim.SGD(
+        params=[{'params': module_partial_fc.parameters()}],
+        lr=cfg.lr / 512 * cfg.batch_size * world_size,
+        momentum=0.9, weight_decay=cfg.weight_decay)
+
+    num_image = len(train_set)
+    total_batch_size = cfg.batch_size * world_size
+    cfg.warmup_step = num_image // total_batch_size * cfg.warmup_epoch
+    cfg.total_step = num_image // total_batch_size * cfg.num_epoch
+
+    def lr_step_func(current_step):
+        cfg.decay_step = [x * num_image // total_batch_size for x in cfg.decay_epoch]
+        if current_step < cfg.warmup_step:
+            return current_step / cfg.warmup_step
+        else:
+            return 0.1 ** len([m for m in cfg.decay_step if m <= current_step])
+
+    scheduler_backbone = torch.optim.lr_scheduler.LambdaLR(
+        optimizer=opt_backbone, lr_lambda=lr_step_func)
+    scheduler_pfc = torch.optim.lr_scheduler.LambdaLR(
+        optimizer=opt_pfc, lr_lambda=lr_step_func)
+
+    for key, value in cfg.items():
+        num_space = 25 - len(key)
+        logging.info(": " + key + " " * num_space + str(value))
+
+    val_target = cfg.val_targets
+    callback_verification = CallBackVerification(2000, rank, val_target, cfg.rec)
+    callback_logging = CallBackLogging(50, rank, cfg.total_step, cfg.batch_size, world_size, None)
+    callback_checkpoint = CallBackModelCheckpoint(rank, cfg.output)
+
+    loss = AverageMeter()
+    start_epoch = 0
+    global_step = 0
+    grad_amp = MaxClipGradScaler(cfg.batch_size, 128 * cfg.batch_size, growth_interval=100) if cfg.fp16 else None
+    for epoch in range(start_epoch, cfg.num_epoch):
+        train_sampler.set_epoch(epoch)
+        for step, (img, label) in enumerate(train_loader):
+            global_step += 1
+            features = F.normalize(backbone(img))
+            x_grad, loss_v = module_partial_fc.forward_backward(label, features, opt_pfc)
+            if cfg.fp16:
+                features.backward(grad_amp.scale(x_grad))
+                grad_amp.unscale_(opt_backbone)
+                clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2)
+                grad_amp.step(opt_backbone)
+                grad_amp.update()
+            else:
+                features.backward(x_grad)
+                clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2)
+                opt_backbone.step()
+
+            opt_pfc.step()
+            module_partial_fc.update()
+            opt_backbone.zero_grad()
+            opt_pfc.zero_grad()
+            loss.update(loss_v, 1)
+            callback_logging(global_step, loss, epoch, cfg.fp16, scheduler_backbone.get_last_lr()[0], grad_amp)
+            callback_verification(global_step, backbone)
+        scheduler_backbone.step()
+        scheduler_pfc.step()
+        callback_checkpoint(global_step, backbone, module_partial_fc)
+    dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+    torch.backends.cudnn.benchmark = True
+    parser = argparse.ArgumentParser(description='PyTorch ArcFace Training')
+    parser.add_argument('config', type=str, help='py config file')
+    parser.add_argument('--local_rank', type=int, default=0, help='local_rank')
+    main(parser.parse_args())
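When this script is started by a distributed launcher (e.g. torchrun), `WORLD_SIZE` and `RANK` are present in the environment and the `try` branch joins the NCCL process group; otherwise it falls back to a single-process group on `127.0.0.1:12584`. The learning-rate factor returned by `lr_step_func` ramps linearly from 0 to 1 over the warmup steps and then drops by 10x at each decay milestone. The standalone sketch below reproduces that rule with made-up step counts, purely for illustration.

```python
# Standalone illustration of the warmup + step-decay rule used by lr_step_func.
# warmup_step / decay_step values are invented for demonstration only.
warmup_step = 1000
decay_step = [5000, 8000]

def lr_factor(current_step):
    if current_step < warmup_step:
        return current_step / warmup_step                             # linear warmup: 0 -> 1
    return 0.1 ** len([m for m in decay_step if m <= current_step])   # 10x drop per passed milestone

print(lr_factor(500))    # 0.5  (halfway through warmup)
print(lr_factor(6000))   # 0.1  (after the first milestone)
print(lr_factor(9000))   # ~0.01 (after both milestones)
```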
videoretalking/third_part/face3d/models/arcface_torch/utils/__init__.py ADDED
File without changes
videoretalking/third_part/face3d/models/arcface_torch/utils/plot.py ADDED
@@ -0,0 +1,72 @@
+# coding: utf-8
+
+import os
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap
+from prettytable import PrettyTable
+from sklearn.metrics import roc_curve, auc
+
+image_path = "/data/anxiang/IJB_release/IJBC"
+files = [
+    "./ms1mv3_arcface_r100/ms1mv3_arcface_r100/ijbc.npy"
+]
+
+
+def read_template_pair_list(path):
+    pairs = pd.read_csv(path, sep=' ', header=None).values
+    t1 = pairs[:, 0].astype(int)
+    t2 = pairs[:, 1].astype(int)
+    label = pairs[:, 2].astype(int)
+    return t1, t2, label
+
+
+p1, p2, label = read_template_pair_list(
+    os.path.join('%s/meta' % image_path,
+                 '%s_template_pair_label.txt' % 'ijbc'))
+
+methods = []
+scores = []
+for file in files:
+    methods.append(file.split('/')[-2])
+    scores.append(np.load(file))
+
+methods = np.array(methods)
+scores = dict(zip(methods, scores))
+colours = dict(
+    zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2')))
+x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1]
+tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels])
+fig = plt.figure()
+for method in methods:
+    fpr, tpr, _ = roc_curve(label, scores[method])
+    roc_auc = auc(fpr, tpr)
+    fpr = np.flipud(fpr)
+    tpr = np.flipud(tpr)  # select largest tpr at same fpr
+    plt.plot(fpr,
+             tpr,
+             color=colours[method],
+             lw=1,
+             label=('[%s (AUC = %0.4f %%)]' %
+                    (method.split('-')[-1], roc_auc * 100)))
+    tpr_fpr_row = []
+    tpr_fpr_row.append("%s-%s" % (method, "IJBC"))
+    for fpr_iter in np.arange(len(x_labels)):
+        _, min_index = min(
+            list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
+        tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))
+    tpr_fpr_table.add_row(tpr_fpr_row)
+plt.xlim([10 ** -6, 0.1])
+plt.ylim([0.3, 1.0])
+plt.grid(linestyle='--', linewidth=1)
+plt.xticks(x_labels)
+plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))
+plt.xscale('log')
+plt.xlabel('False Positive Rate')
+plt.ylabel('True Positive Rate')
+plt.title('ROC on IJB')
+plt.legend(loc="lower right")
+print(tpr_fpr_table)
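Note that plot.py builds the ROC figure and prints the TPR@FPR table, but never writes the figure to disk or calls `plt.show()`. If a saved plot is wanted, one extra line is enough; the output filename below is hypothetical.

```python
# Hypothetical addition: persist the ROC curve next to the score files.
fig.savefig("ijbc_roc.pdf", bbox_inches="tight")
```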