Upload 77 files
This view is limited to 50 files because it contains too many changes.
- videoretalking/third_part/face3d/checkpoints/model_name/test_opt.txt +34 -0
- videoretalking/third_part/face3d/coeff_detector.py +118 -0
- videoretalking/third_part/face3d/data/__init__.py +116 -0
- videoretalking/third_part/face3d/data/base_dataset.py +125 -0
- videoretalking/third_part/face3d/data/flist_dataset.py +125 -0
- videoretalking/third_part/face3d/data/image_folder.py +66 -0
- videoretalking/third_part/face3d/data/template_dataset.py +75 -0
- videoretalking/third_part/face3d/data_preparation.py +45 -0
- videoretalking/third_part/face3d/extract_kp_videos.py +109 -0
- videoretalking/third_part/face3d/face_recon_videos.py +157 -0
- videoretalking/third_part/face3d/models/__init__.py +67 -0
- videoretalking/third_part/face3d/models/arcface_torch/README.md +164 -0
- videoretalking/third_part/face3d/models/arcface_torch/backbones/__init__.py +25 -0
- videoretalking/third_part/face3d/models/arcface_torch/backbones/iresnet.py +187 -0
- videoretalking/third_part/face3d/models/arcface_torch/backbones/iresnet2060.py +176 -0
- videoretalking/third_part/face3d/models/arcface_torch/backbones/mobilefacenet.py +130 -0
- videoretalking/third_part/face3d/models/arcface_torch/configs/3millions.py +23 -0
- videoretalking/third_part/face3d/models/arcface_torch/configs/3millions_pfc.py +23 -0
- videoretalking/third_part/face3d/models/arcface_torch/configs/__init__.py +0 -0
- videoretalking/third_part/face3d/models/arcface_torch/configs/base.py +56 -0
- videoretalking/third_part/face3d/models/arcface_torch/configs/glint360k_mbf.py +26 -0
- videoretalking/third_part/face3d/models/arcface_torch/configs/glint360k_r100.py +26 -0
- videoretalking/third_part/face3d/models/arcface_torch/configs/glint360k_r18.py +26 -0
- videoretalking/third_part/face3d/models/arcface_torch/configs/glint360k_r34.py +26 -0
- videoretalking/third_part/face3d/models/arcface_torch/configs/glint360k_r50.py +26 -0
- videoretalking/third_part/face3d/models/arcface_torch/configs/ms1mv3_mbf.py +26 -0
- videoretalking/third_part/face3d/models/arcface_torch/configs/ms1mv3_r18.py +26 -0
- videoretalking/third_part/face3d/models/arcface_torch/configs/ms1mv3_r2060.py +26 -0
- videoretalking/third_part/face3d/models/arcface_torch/configs/ms1mv3_r34.py +26 -0
- videoretalking/third_part/face3d/models/arcface_torch/configs/ms1mv3_r50.py +26 -0
- videoretalking/third_part/face3d/models/arcface_torch/configs/speed.py +23 -0
- videoretalking/third_part/face3d/models/arcface_torch/dataset.py +124 -0
- videoretalking/third_part/face3d/models/arcface_torch/docs/eval.md +31 -0
- videoretalking/third_part/face3d/models/arcface_torch/docs/install.md +51 -0
- videoretalking/third_part/face3d/models/arcface_torch/docs/modelzoo.md +0 -0
- videoretalking/third_part/face3d/models/arcface_torch/docs/speed_benchmark.md +93 -0
- videoretalking/third_part/face3d/models/arcface_torch/eval/__init__.py +0 -0
- videoretalking/third_part/face3d/models/arcface_torch/eval/verification.py +407 -0
- videoretalking/third_part/face3d/models/arcface_torch/eval_ijbc.py +483 -0
- videoretalking/third_part/face3d/models/arcface_torch/inference.py +35 -0
- videoretalking/third_part/face3d/models/arcface_torch/losses.py +42 -0
- videoretalking/third_part/face3d/models/arcface_torch/onnx_helper.py +250 -0
- videoretalking/third_part/face3d/models/arcface_torch/onnx_ijbc.py +267 -0
- videoretalking/third_part/face3d/models/arcface_torch/partial_fc.py +222 -0
- videoretalking/third_part/face3d/models/arcface_torch/requirement.txt +5 -0
- videoretalking/third_part/face3d/models/arcface_torch/run.sh +2 -0
- videoretalking/third_part/face3d/models/arcface_torch/torch2onnx.py +59 -0
- videoretalking/third_part/face3d/models/arcface_torch/train.py +141 -0
- videoretalking/third_part/face3d/models/arcface_torch/utils/__init__.py +0 -0
- videoretalking/third_part/face3d/models/arcface_torch/utils/plot.py +72 -0
videoretalking/third_part/face3d/checkpoints/model_name/test_opt.txt
ADDED
@@ -0,0 +1,34 @@
----------------- Options ---------------
add_image: True
bfm_folder: BFM
bfm_model: BFM_model_front.mat
camera_d: 10.0
center: 112.0
checkpoints_dir: ./checkpoints
dataset_mode: None
ddp_port: 12355
display_per_batch: True
epoch: 20 [default: latest]
eval_batch_nums: inf
focal: 1015.0
gpu_ids: 0
inference_batch_size: 8
init_path: checkpoints/init_model/resnet50-0676ba61.pth
input_dir: demo_video [default: None]
isTrain: False [default: None]
keypoint_dir: demo_cctv [default: None]
model: facerecon
name: model_name [default: face_recon]
net_recon: resnet50
output_dir: demo_cctv [default: mp4]
phase: test
save_split_files: False
suffix:
use_ddp: False [default: True]
use_last_fc: False
verbose: False
vis_batch_nums: 1
world_size: 1
z_far: 15.0
z_near: 5.0
----------------- End -------------------
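The dump above is the option snapshot written out by the inference run; entries marked `[default: ...]` were overridden on the command line. If these values are needed programmatically, a minimal sketch for parsing such a dump back into a dict (the helper name is illustrative, not part of the upload):

```python
import re

def parse_opt_dump(path):
    """Parse a '----- Options -----' dump like test_opt.txt into a dict of strings."""
    opts = {}
    with open(path) as f:
        for line in f:
            line = line.strip()
            # skip blanks and the Options/End banner lines
            if not line or line.startswith('-----'):
                continue
            key, _, value = line.partition(':')
            # drop the '[default: ...]' annotation, keeping only the active value
            value = re.sub(r'\s*\[default:.*\]\s*$', '', value.strip())
            opts[key.strip()] = value
    return opts

# e.g. parse_opt_dump('checkpoints/model_name/test_opt.txt')['net_recon'] == 'resnet50'
```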
videoretalking/third_part/face3d/coeff_detector.py
ADDED
@@ -0,0 +1,118 @@
import os
import glob
import numpy as np
from os import makedirs, name
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn

from face3d.options.inference_options import InferenceOptions
from face3d.models import create_model
from face3d.util.preprocess import align_img
from face3d.util.load_mats import load_lm3d
from face3d.extract_kp_videos import KeypointExtractor


class CoeffDetector(nn.Module):
    def __init__(self, opt):
        super().__init__()

        self.model = create_model(opt)
        self.model.setup(opt)
        self.model.device = 'cuda'
        self.model.parallelize()
        self.model.eval()

        self.lm3d_std = load_lm3d(opt.bfm_folder)

    def forward(self, img, lm):

        img, trans_params = self.image_transform(img, lm)

        data_input = {
            'imgs': img[None],
        }
        self.model.set_input(data_input)
        self.model.test()
        pred_coeff = {key: self.model.pred_coeffs_dict[key].cpu().numpy() for key in self.model.pred_coeffs_dict}
        pred_coeff = np.concatenate([
            pred_coeff['id'],
            pred_coeff['exp'],
            pred_coeff['tex'],
            pred_coeff['angle'],
            pred_coeff['gamma'],
            pred_coeff['trans'],
            trans_params[None],
        ], 1)

        return {'coeff_3dmm': pred_coeff,
                'crop_img': Image.fromarray((img.cpu().permute(1, 2, 0).numpy()*255).astype(np.uint8))}

    def image_transform(self, images, lm):
        """
        param:
            images: -- PIL image
            lm:     -- numpy array
        """
        W, H = images.size
        if np.mean(lm) == -1:
            lm = (self.lm3d_std[:, :2]+1)/2.
            lm = np.concatenate(
                [lm[:, :1]*W, lm[:, 1:2]*H], 1
            )
        else:
            lm[:, -1] = H - 1 - lm[:, -1]

        trans_params, img, lm, _ = align_img(images, lm, self.lm3d_std)
        img = torch.tensor(np.array(img)/255., dtype=torch.float32).permute(2, 0, 1)
        trans_params = np.array([float(item) for item in np.hsplit(trans_params, 5)])
        trans_params = torch.tensor(trans_params.astype(np.float32))
        return img, trans_params

def get_data_path(root, keypoint_root):
    filenames = list()
    keypoint_filenames = list()

    IMAGE_EXTENSIONS_LOWERCASE = {'jpg', 'png', 'jpeg', 'webp'}
    IMAGE_EXTENSIONS = IMAGE_EXTENSIONS_LOWERCASE.union({f.upper() for f in IMAGE_EXTENSIONS_LOWERCASE})
    extensions = IMAGE_EXTENSIONS

    for ext in extensions:
        filenames += glob.glob(f'{root}/*.{ext}', recursive=True)
    filenames = sorted(filenames)
    for filename in filenames:
        name = os.path.splitext(os.path.basename(filename))[0]
        keypoint_filenames.append(
            os.path.join(keypoint_root, name + '.txt')
        )
    return filenames, keypoint_filenames


if __name__ == "__main__":
    opt = InferenceOptions().parse()
    coeff_detector = CoeffDetector(opt)
    kp_extractor = KeypointExtractor()
    image_names, keypoint_names = get_data_path(opt.input_dir, opt.keypoint_dir)
    makedirs(opt.keypoint_dir, exist_ok=True)
    makedirs(opt.output_dir, exist_ok=True)

    for image_name, keypoint_name in tqdm(zip(image_names, keypoint_names)):
        image = Image.open(image_name)
        if not os.path.isfile(keypoint_name):
            lm = kp_extractor.extract_keypoint(image, keypoint_name)
        else:
            lm = np.loadtxt(keypoint_name).astype(np.float32)
            lm = lm.reshape([-1, 2])
        predicted = coeff_detector(image, lm)
        name = os.path.splitext(os.path.basename(image_name))[0]
        np.savetxt(
            "{}/{}_3dmm_coeff.txt".format(opt.output_dir, name),
            predicted['coeff_3dmm'].reshape(-1))
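The `coeff_3dmm` vector saved above concatenates the predicted identity, expression, texture, pose, lighting and translation coefficients plus the five crop parameters. A sketch for splitting a saved vector back into named parts; the slice sizes assume the usual Deep3DFaceRecon BFM layout (80/64/80/3/27/3, plus 5 crop params) and should be checked against the actual checkpoint:

```python
import numpy as np

def split_coeff_3dmm(vec):
    # assumed sizes; verify against the face-recon model actually used
    sizes = {'id': 80, 'exp': 64, 'tex': 80, 'angle': 3, 'gamma': 27,
             'trans': 3, 'trans_params': 5}
    parts, start = {}, 0
    for key, n in sizes.items():
        parts[key] = vec[start:start + n]
        start += n
    return parts

coeff = np.loadtxt('demo_cctv/example_3dmm_coeff.txt')  # path is illustrative
parts = split_coeff_3dmm(coeff)
```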
videoretalking/third_part/face3d/data/__init__.py
ADDED
@@ -0,0 +1,116 @@
"""This package includes all the modules related to data loading and preprocessing

To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
You need to implement four functions:
    -- <__init__>:                      initialize the class, first call BaseDataset.__init__(self, opt).
    -- <__len__>:                       return the size of dataset.
    -- <__getitem__>:                   get a data point from data loader.
    -- <modify_commandline_options>:    (optionally) add dataset-specific options and set default options.

Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
See our template dataset class 'template_dataset.py' for more details.
"""
import numpy as np
import importlib
import torch.utils.data
from face3d.data.base_dataset import BaseDataset


def find_dataset_using_name(dataset_name):
    """Import the module "data/[dataset_name]_dataset.py".

    In the file, the class called DatasetNameDataset() will
    be instantiated. It has to be a subclass of BaseDataset,
    and it is case-insensitive.
    """
    dataset_filename = "data." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    dataset = None
    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    for name, cls in datasetlib.__dict__.items():
        if name.lower() == target_dataset_name.lower() \
           and issubclass(cls, BaseDataset):
            dataset = cls

    if dataset is None:
        raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))

    return dataset


def get_option_setter(dataset_name):
    """Return the static method <modify_commandline_options> of the dataset class."""
    dataset_class = find_dataset_using_name(dataset_name)
    return dataset_class.modify_commandline_options


def create_dataset(opt, rank=0):
    """Create a dataset given the option.

    This function wraps the class CustomDatasetDataLoader.
    This is the main interface between this package and 'train.py'/'test.py'

    Example:
        >>> from data import create_dataset
        >>> dataset = create_dataset(opt)
    """
    data_loader = CustomDatasetDataLoader(opt, rank=rank)
    dataset = data_loader.load_data()
    return dataset

class CustomDatasetDataLoader():
    """Wrapper class of Dataset class that performs multi-threaded data loading"""

    def __init__(self, opt, rank=0):
        """Initialize this class

        Step 1: create a dataset instance given the name [dataset_mode]
        Step 2: create a multi-threaded data loader.
        """
        self.opt = opt
        dataset_class = find_dataset_using_name(opt.dataset_mode)
        self.dataset = dataset_class(opt)
        self.sampler = None
        print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__))
        if opt.use_ddp and opt.isTrain:
            world_size = opt.world_size
            self.sampler = torch.utils.data.distributed.DistributedSampler(
                self.dataset,
                num_replicas=world_size,
                rank=rank,
                shuffle=not opt.serial_batches
            )
            self.dataloader = torch.utils.data.DataLoader(
                self.dataset,
                sampler=self.sampler,
                num_workers=int(opt.num_threads / world_size),
                batch_size=int(opt.batch_size / world_size),
                drop_last=True)
        else:
            self.dataloader = torch.utils.data.DataLoader(
                self.dataset,
                batch_size=opt.batch_size,
                shuffle=(not opt.serial_batches) and opt.isTrain,
                num_workers=int(opt.num_threads),
                drop_last=True
            )

    def set_epoch(self, epoch):
        self.dataset.current_epoch = epoch
        if self.sampler is not None:
            self.sampler.set_epoch(epoch)

    def load_data(self):
        return self

    def __len__(self):
        """Return the number of data in the dataset"""
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        """Return a batch of data"""
        for i, data in enumerate(self.dataloader):
            if i * self.opt.batch_size >= self.opt.max_dataset_size:
                break
            yield data
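The package docstring above describes the plug-in convention: a file `data/dummy_dataset.py` defining `DummyDataset` is picked up by `find_dataset_using_name('dummy')`. A minimal sketch of such a file, with placeholder contents (the tensor shape and the extra option are illustrative only):

```python
# data/dummy_dataset.py -- selected with '--dataset_mode dummy'
import torch
from face3d.data.base_dataset import BaseDataset  # same BaseDataset the loader checks against

class DummyDataset(BaseDataset):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser.add_argument('--dummy_size', type=int, default=8, help='number of fake samples')
        return parser

    def __init__(self, opt):
        BaseDataset.__init__(self, opt)
        self.size = getattr(opt, 'dummy_size', 8)

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        # return a dict of named tensors/metadata, as the training loop expects
        return {'imgs': torch.zeros(3, 224, 224), 'index': index}
```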
videoretalking/third_part/face3d/data/base_dataset.py
ADDED
@@ -0,0 +1,125 @@
"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.

It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
"""
import random
import numpy as np
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
from abc import ABC, abstractmethod


class BaseDataset(data.Dataset, ABC):
    """This class is an abstract base class (ABC) for datasets.

    To create a subclass, you need to implement the following four functions:
    -- <__init__>:                      initialize the class, first call BaseDataset.__init__(self, opt).
    -- <__len__>:                       return the size of dataset.
    -- <__getitem__>:                   get a data point.
    -- <modify_commandline_options>:    (optionally) add dataset-specific options and set default options.
    """

    def __init__(self, opt):
        """Initialize the class; save the options in the class

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        self.opt = opt
        # self.root = opt.dataroot
        self.current_epoch = 0

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new dataset-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        return parser

    @abstractmethod
    def __len__(self):
        """Return the total number of images in the dataset."""
        return 0

    @abstractmethod
    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index -- a random integer for data indexing

        Returns:
            a dictionary of data with their names. It usually contains the data itself and its metadata information.
        """
        pass


def get_transform(grayscale=False):
    transform_list = []
    if grayscale:
        transform_list.append(transforms.Grayscale(1))
    transform_list += [transforms.ToTensor()]
    return transforms.Compose(transform_list)

def get_affine_mat(opt, size):
    shift_x, shift_y, scale, rot_angle, flip = 0., 0., 1., 0., False
    rot_rad = 0.  # default so rot_mat below is the identity when 'rot' is not requested
    w, h = size

    if 'shift' in opt.preprocess:
        shift_pixs = int(opt.shift_pixs)
        shift_x = random.randint(-shift_pixs, shift_pixs)
        shift_y = random.randint(-shift_pixs, shift_pixs)
    if 'scale' in opt.preprocess:
        scale = 1 + opt.scale_delta * (2 * random.random() - 1)
    if 'rot' in opt.preprocess:
        rot_angle = opt.rot_angle * (2 * random.random() - 1)
        rot_rad = -rot_angle * np.pi/180
    if 'flip' in opt.preprocess:
        flip = random.random() > 0.5

    shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3])
    flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3])
    shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3])
    rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3])
    scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3])
    shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3])

    affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin
    affine_inv = np.linalg.inv(affine)
    return affine, affine_inv, flip

def apply_img_affine(img, affine_inv, method=Image.BICUBIC):
    # honor the requested resampling filter (callers pass e.g. Image.BILINEAR for masks)
    return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=method)

def apply_lm_affine(landmark, affine, flip, size):
    _, h = size
    lm = landmark.copy()
    lm[:, 1] = h - 1 - lm[:, 1]
    lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1)
    lm = lm @ np.transpose(affine)
    lm[:, :2] = lm[:, :2] / lm[:, 2:]
    lm = lm[:, :2]
    lm[:, 1] = h - 1 - lm[:, 1]
    if flip:
        lm_ = lm.copy()
        lm_[:17] = lm[16::-1]
        lm_[17:22] = lm[26:21:-1]
        lm_[22:27] = lm[21:16:-1]
        lm_[31:36] = lm[35:30:-1]
        lm_[36:40] = lm[45:41:-1]
        lm_[40:42] = lm[47:45:-1]
        lm_[42:46] = lm[39:35:-1]
        lm_[46:48] = lm[41:39:-1]
        lm_[48:55] = lm[54:47:-1]
        lm_[55:60] = lm[59:54:-1]
        lm_[60:65] = lm[64:59:-1]
        lm_[65:68] = lm[67:64:-1]
        lm = lm_
    return lm
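`get_affine_mat` composes shift, scale, rotation and flip into a single 3x3 homogeneous matrix (and its inverse), which `apply_img_affine` and `apply_lm_affine` then apply to the image and the 68 landmarks. A standalone sketch of that flow; `opt` here is a stand-in namespace rather than the real option parser, and the values are illustrative:

```python
from types import SimpleNamespace
import numpy as np
from PIL import Image
from face3d.data.base_dataset import get_affine_mat, apply_img_affine, apply_lm_affine

opt = SimpleNamespace(preprocess='shift_scale_rot_flip', shift_pixs=10,
                      scale_delta=0.1, rot_angle=10)       # illustrative augmentation settings
img = Image.new('RGB', (224, 224))                         # placeholder image
lm = np.zeros([68, 2], dtype=np.float32)                   # 68 (x, y) landmarks

affine, affine_inv, flip = get_affine_mat(opt, img.size)
img_aug = apply_img_affine(img, affine_inv)                 # warp the image
lm_aug = apply_lm_affine(lm, affine, flip, img.size)        # warp (and possibly flip) the landmarks
```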
videoretalking/third_part/face3d/data/flist_dataset.py
ADDED
@@ -0,0 +1,125 @@
"""This script defines the custom dataset for Deep3DFaceRecon_pytorch
"""

import os.path
from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine
from data.image_folder import make_dataset
from PIL import Image
import random
import util.util as util
import numpy as np
import json
import torch
from scipy.io import loadmat, savemat
import pickle
from util.preprocess import align_img, estimate_norm
from util.load_mats import load_lm3d


def default_flist_reader(flist):
    """
    flist format: impath label\nimpath label\n ... (same as caffe's filelist)
    """
    imlist = []
    with open(flist, 'r') as rf:
        for line in rf.readlines():
            impath = line.strip()
            imlist.append(impath)

    return imlist

def jason_flist_reader(flist):
    with open(flist, 'r') as fp:
        info = json.load(fp)
    return info

def parse_label(label):
    return torch.tensor(np.array(label).astype(np.float32))


class FlistDataset(BaseDataset):
    """
    It requires one directory to host training images '/path/to/data/train'
    You can train the model with the dataset flag '--dataroot /path/to/data'.
    """

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseDataset.__init__(self, opt)

        self.lm3d_std = load_lm3d(opt.bfm_folder)

        msk_names = default_flist_reader(opt.flist)
        self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names]

        self.size = len(self.msk_paths)
        self.opt = opt

        self.name = 'train' if opt.isTrain else 'val'
        if '_' in opt.flist:
            self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0]


    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index (int) -- a random integer for data indexing

        Returns a dictionary that contains A, B, A_paths and B_paths
            img (tensor)       -- an image in the input domain
            msk (tensor)       -- its corresponding attention mask
            lm  (tensor)       -- its corresponding 3d landmarks
            im_paths (str)     -- image paths
            aug_flag (bool)    -- a flag used to tell whether it is raw or augmented
        """
        msk_path = self.msk_paths[index % self.size]  # make sure index is within the range
        img_path = msk_path.replace('mask/', '')
        lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt'

        raw_img = Image.open(img_path).convert('RGB')
        raw_msk = Image.open(msk_path).convert('RGB')
        raw_lm = np.loadtxt(lm_path).astype(np.float32)

        _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk)

        aug_flag = self.opt.use_aug and self.opt.isTrain
        if aug_flag:
            img, lm, msk = self._augmentation(img, lm, self.opt, msk)

        _, H = img.size
        M = estimate_norm(lm, H)
        transform = get_transform()
        img_tensor = transform(img)
        msk_tensor = transform(msk)[:1, ...]
        lm_tensor = parse_label(lm)
        M_tensor = parse_label(M)


        return {'imgs': img_tensor,
                'lms': lm_tensor,
                'msks': msk_tensor,
                'M': M_tensor,
                'im_paths': img_path,
                'aug_flag': aug_flag,
                'dataset': self.name}

    def _augmentation(self, img, lm, opt, msk=None):
        affine, affine_inv, flip = get_affine_mat(opt, img.size)
        img = apply_img_affine(img, affine_inv)
        lm = apply_lm_affine(lm, affine, flip, img.size)
        if msk is not None:
            msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR)
        return img, lm, msk




    def __len__(self):
        """Return the total number of images in the dataset.
        """
        return self.size
videoretalking/third_part/face3d/data/image_folder.py
ADDED
@@ -0,0 +1,66 @@
"""A modified image folder class

We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
so that this class can load images from both current directory and its subdirectories.
"""
import numpy as np
import torch.utils.data as data

from PIL import Image
import os
import os.path

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
    '.tif', '.TIF', '.tiff', '.TIFF',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir, max_dataset_size=float("inf")):
    images = []
    assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)
    return images[:min(max_dataset_size, len(images))]


def default_loader(path):
    return Image.open(path).convert('RGB')


class ImageFolder(data.Dataset):

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        imgs = make_dataset(root)
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.return_paths:
            return img, path
        else:
            return img

    def __len__(self):
        return len(self.imgs)
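A minimal usage sketch for the modified `ImageFolder` above, combined with `get_transform` from `base_dataset.py`; the directory path is a placeholder and `face3d` is assumed to be importable as a package:

```python
from face3d.data.image_folder import ImageFolder
from face3d.data.base_dataset import get_transform

dataset = ImageFolder('datasets/example_images', transform=get_transform(), return_paths=True)
img, path = dataset[0]          # CHW float tensor in [0, 1] plus the source file path
print(len(dataset), path, img.shape)
```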
videoretalking/third_part/face3d/data/template_dataset.py
ADDED
@@ -0,0 +1,75 @@
"""Dataset class template

This module provides a template for users to implement custom datasets.
You can specify '--dataset_mode template' to use this dataset.
The class name should be consistent with both the filename and its dataset_mode option.
The filename should be <dataset_mode>_dataset.py
The class name should be <Dataset_mode>Dataset
You need to implement the following functions:
    -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
    -- <__init__>: Initialize this dataset class.
    -- <__getitem__>: Return a data point and its metadata information.
    -- <__len__>: Return the number of images.
"""
from data.base_dataset import BaseDataset, get_transform
# from data.image_folder import make_dataset
# from PIL import Image


class TemplateDataset(BaseDataset):
    """A template dataset class for you to implement custom datasets."""
    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new dataset-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option')
        parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0)  # specify dataset-specific default values
        return parser

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions

        A few things can be done here.
        - save the options (have been done in BaseDataset)
        - get image paths and meta information of the dataset.
        - define the image transformation.
        """
        # save the option and dataset root
        BaseDataset.__init__(self, opt)
        # get the image paths of your dataset;
        self.image_paths = []  # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
        # define the default transform function. You can use <base_dataset.get_transform>; You can also define your custom transform function
        self.transform = get_transform(opt)

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index -- a random integer for data indexing

        Returns:
            a dictionary of data with their names. It usually contains the data itself and its metadata information.

        Step 1: get a random image path: e.g., path = self.image_paths[index]
        Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
        Step 3: convert your data to a PyTorch tensor. You can use helper functions such as self.transform. e.g., data = self.transform(image)
        Step 4: return a data point as a dictionary.
        """
        path = 'temp'    # needs to be a string
        data_A = None    # needs to be a tensor
        data_B = None    # needs to be a tensor
        return {'data_A': data_A, 'data_B': data_B, 'path': path}

    def __len__(self):
        """Return the total number of images."""
        return len(self.image_paths)
videoretalking/third_part/face3d/data_preparation.py
ADDED
@@ -0,0 +1,45 @@
"""This script is the data preparation script for Deep3DFaceRecon_pytorch
"""

import os
import numpy as np
import argparse
from util.detect_lm68 import detect_68p, load_lm_graph
from util.skin_mask import get_skin_mask
from util.generate_list import check_list, write_list
import warnings
warnings.filterwarnings("ignore")

parser = argparse.ArgumentParser()
parser.add_argument('--data_root', type=str, default='datasets', help='root directory for training data')
parser.add_argument('--img_folder', nargs="+", required=True, help='folders of training images')
parser.add_argument('--mode', type=str, default='train', help='train or val')
opt = parser.parse_args()

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

def data_prepare(folder_list, mode):

    lm_sess, input_op, output_op = load_lm_graph('./checkpoints/lm_model/68lm_detector.pb')  # load a tensorflow version 68-landmark detector

    for img_folder in folder_list:
        detect_68p(img_folder, lm_sess, input_op, output_op)  # detect landmarks for images
        get_skin_mask(img_folder)  # generate skin attention mask for images

    # create files that record path to all training data
    msks_list = []
    for img_folder in folder_list:
        path = os.path.join(img_folder, 'mask')
        msks_list += ['/'.join([img_folder, 'mask', i]) for i in sorted(os.listdir(path)) if 'jpg' in i or
                      'png' in i or 'jpeg' in i or 'PNG' in i]

    imgs_list = [i.replace('mask/', '') for i in msks_list]
    lms_list = [i.replace('mask', 'landmarks') for i in msks_list]
    lms_list = ['.'.join(i.split('.')[:-1]) + '.txt' for i in lms_list]

    lms_list_final, imgs_list_final, msks_list_final = check_list(lms_list, imgs_list, msks_list)  # check if the path is valid
    write_list(lms_list_final, imgs_list_final, msks_list_final, mode=mode)  # save files

if __name__ == '__main__':
    print('Datasets:', opt.img_folder)
    data_prepare([os.path.join(opt.data_root, folder) for folder in opt.img_folder], opt.mode)
videoretalking/third_part/face3d/extract_kp_videos.py
ADDED
@@ -0,0 +1,109 @@
import os
import cv2
import time
import glob
import argparse
import face_alignment
import numpy as np
from PIL import Image
import torch
from tqdm import tqdm
from itertools import cycle

from torch.multiprocessing import Pool, Process, set_start_method

class KeypointExtractor():
    def __init__(self):
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, device=device)

    def extract_keypoint(self, images, name=None, info=True):
        if isinstance(images, list):
            keypoints = []
            if info:
                i_range = tqdm(images, desc='landmark Det:')
            else:
                i_range = images

            for image in i_range:
                current_kp = self.extract_keypoint(image)
                if np.mean(current_kp) == -1 and keypoints:
                    keypoints.append(keypoints[-1])
                else:
                    keypoints.append(current_kp[None])

            keypoints = np.concatenate(keypoints, 0)
            np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
            return keypoints
        else:
            while True:
                try:
                    keypoints = self.detector.get_landmarks_from_image(np.array(images))[0]
                    break
                except RuntimeError as e:
                    if str(e).startswith('CUDA'):
                        print("Warning: out of memory, sleep for 1s")
                        time.sleep(1)
                    else:
                        print(e)
                        break
                except TypeError:
                    print('No face detected in this image')
                    shape = [68, 2]
                    keypoints = -1. * np.ones(shape)
                    break
            if name is not None:
                np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
            return keypoints

def read_video(filename):
    frames = []
    cap = cv2.VideoCapture(filename)
    while cap.isOpened():
        ret, frame = cap.read()
        if ret:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = Image.fromarray(frame)
            frames.append(frame)
        else:
            break
    cap.release()
    return frames

def run(data):
    filename, opt, device = data
    os.environ['CUDA_VISIBLE_DEVICES'] = device
    kp_extractor = KeypointExtractor()
    images = read_video(filename)
    name = filename.split('/')[-2:]
    os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
    kp_extractor.extract_keypoint(
        images,
        name=os.path.join(opt.output_dir, name[-2], name[-1])
    )

if __name__ == '__main__':
    set_start_method('spawn')
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--input_dir', type=str, help='the folder of the input files')
    parser.add_argument('--output_dir', type=str, help='the folder of the output files')
    parser.add_argument('--device_ids', type=str, default='0,1')
    parser.add_argument('--workers', type=int, default=4)

    opt = parser.parse_args()
    filenames = list()
    VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
    VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
    extensions = VIDEO_EXTENSIONS

    for ext in extensions:
        os.listdir(f'{opt.input_dir}')
        print(f'{opt.input_dir}/*.{ext}')
        filenames = sorted(glob.glob(f'{opt.input_dir}/*.{ext}'))
    print('Total number of videos:', len(filenames))
    pool = Pool(opt.workers)
    args_list = cycle([opt])
    device_ids = opt.device_ids.split(",")
    device_ids = cycle(device_ids)
    for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
        None
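Besides the multi-process video CLI above, `KeypointExtractor` can be called directly on a single PIL image (this is how `coeff_detector.py` uses it). A short sketch; the paths are placeholders and the output directory must already exist:

```python
from PIL import Image
from face3d.extract_kp_videos import KeypointExtractor

kp_extractor = KeypointExtractor()
image = Image.open('demo_video/frame_00000.jpg').convert('RGB')    # placeholder frame
lm = kp_extractor.extract_keypoint(image, name='demo_cctv/frame_00000.txt')
print(lm.shape)   # (68, 2); all entries are -1 when no face is detected
```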
videoretalking/third_part/face3d/face_recon_videos.py
ADDED
@@ -0,0 +1,157 @@
import os
import cv2
import glob
import numpy as np
from PIL import Image
from tqdm import tqdm
from scipy.io import savemat

import torch

from models import create_model
from options.inference_options import InferenceOptions
from util.preprocess import align_img
from util.load_mats import load_lm3d
from util.util import mkdirs, tensor2im, save_image


def get_data_path(root, keypoint_root):
    filenames = list()
    keypoint_filenames = list()

    VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
    VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
    extensions = VIDEO_EXTENSIONS

    for ext in extensions:
        filenames += glob.glob(f'{root}/**/*.{ext}', recursive=True)
    filenames = sorted(filenames)
    keypoint_filenames = sorted(glob.glob(f'{keypoint_root}/**/*.txt', recursive=True))
    assert len(filenames) == len(keypoint_filenames)

    return filenames, keypoint_filenames

class VideoPathDataset(torch.utils.data.Dataset):
    def __init__(self, filenames, txt_filenames, bfm_folder):
        self.filenames = filenames
        self.txt_filenames = txt_filenames
        self.lm3d_std = load_lm3d(bfm_folder)

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, index):
        filename = self.filenames[index]
        txt_filename = self.txt_filenames[index]
        frames = self.read_video(filename)
        lm = np.loadtxt(txt_filename).astype(np.float32)
        lm = lm.reshape([len(frames), -1, 2])
        out_images, out_trans_params = list(), list()
        for i in range(len(frames)):
            out_img, _, out_trans_param \
                = self.image_transform(frames[i], lm[i])
            out_images.append(out_img[None])
            out_trans_params.append(out_trans_param[None])
        return {
            'imgs': torch.cat(out_images, 0),
            'trans_param': torch.cat(out_trans_params, 0),
            'filename': filename
        }

    def read_video(self, filename):
        frames = list()
        cap = cv2.VideoCapture(filename)
        while cap.isOpened():
            ret, frame = cap.read()
            if ret:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = Image.fromarray(frame)
                frames.append(frame)
            else:
                break
        cap.release()
        return frames

    def image_transform(self, images, lm):
        W, H = images.size
        if np.mean(lm) == -1:
            lm = (self.lm3d_std[:, :2]+1)/2.
            lm = np.concatenate(
                [lm[:, :1]*W, lm[:, 1:2]*H], 1
            )
        else:
            lm[:, -1] = H - 1 - lm[:, -1]

        trans_params, img, lm, _ = align_img(images, lm, self.lm3d_std)
        img = torch.tensor(np.array(img)/255., dtype=torch.float32).permute(2, 0, 1)
        lm = torch.tensor(lm)
        trans_params = np.array([float(item) for item in np.hsplit(trans_params, 5)])
        trans_params = torch.tensor(trans_params.astype(np.float32))
        return img, lm, trans_params

def main(opt, model):
    # import torch.multiprocessing
    # torch.multiprocessing.set_sharing_strategy('file_system')
    filenames, keypoint_filenames = get_data_path(opt.input_dir, opt.keypoint_dir)
    dataset = VideoPathDataset(filenames, keypoint_filenames, opt.bfm_folder)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,  # can only be set to 1 here!
        shuffle=False,
        drop_last=False,
        num_workers=0,
    )
    batch_size = opt.inference_batch_size
    for data in tqdm(dataloader):
        num_batch = data['imgs'][0].shape[0] // batch_size + 1
        pred_coeffs = list()
        for index in range(num_batch):
            data_input = {
                'imgs': data['imgs'][0, index*batch_size:(index+1)*batch_size],
            }
            model.set_input(data_input)
            model.test()
            pred_coeff = {key: model.pred_coeffs_dict[key].cpu().numpy() for key in model.pred_coeffs_dict}
            pred_coeff = np.concatenate([
                pred_coeff['id'],
                pred_coeff['exp'],
                pred_coeff['tex'],
                pred_coeff['angle'],
                pred_coeff['gamma'],
                pred_coeff['trans']], 1)
            pred_coeffs.append(pred_coeff)
            visuals = model.get_current_visuals()  # get image results
            if False:  # debug
                for name in visuals:
                    images = visuals[name]
                    for i in range(images.shape[0]):
                        image_numpy = tensor2im(images[i])
                        save_image(
                            image_numpy,
                            os.path.join(
                                opt.output_dir,
                                os.path.basename(data['filename'][0])+str(i).zfill(5)+'.jpg')
                        )
                exit()

        pred_coeffs = np.concatenate(pred_coeffs, 0)
        pred_trans_params = data['trans_param'][0].cpu().numpy()
        name = data['filename'][0].split('/')[-2:]
        name[-1] = os.path.splitext(name[-1])[0] + '.mat'
        os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
        savemat(
            os.path.join(opt.output_dir, name[-2], name[-1]),
            {'coeff': pred_coeffs, 'transform_params': pred_trans_params}
        )

if __name__ == '__main__':
    opt = InferenceOptions().parse()  # get test options
    model = create_model(opt)
    model.setup(opt)
    model.device = 'cuda:0'
    model.parallelize()
    model.eval()

    main(opt, model)
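Each processed video ends up as a `.mat` file holding the per-frame coefficients and crop parameters written by `savemat` above. A sketch for reading one back; the path is illustrative, and the 257-dim coefficient size assumes the standard BFM parameterization used by the model:

```python
from scipy.io import loadmat

data = loadmat('demo_cctv/demo_video/clip.mat')        # illustrative output path
coeff = data['coeff']                                  # (num_frames, 257) per-frame 3DMM coefficients
trans_params = data['transform_params']                # (num_frames, 5) crop/alignment parameters
print(coeff.shape, trans_params.shape)
```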
videoretalking/third_part/face3d/models/__init__.py
ADDED
@@ -0,0 +1,67 @@
"""This package contains modules related to objective functions, optimizations, and network architectures.

To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
You need to implement the following five functions:
    -- <__init__>:                      initialize the class; first call BaseModel.__init__(self, opt).
    -- <set_input>:                     unpack data from dataset and apply preprocessing.
    -- <forward>:                       produce intermediate results.
    -- <optimize_parameters>:           calculate loss, gradients, and update network weights.
    -- <modify_commandline_options>:    (optionally) add model-specific options and set default options.

In the function <__init__>, you need to define four lists:
    -- self.loss_names (str list):          specify the training losses that you want to plot and save.
    -- self.model_names (str list):         define networks used in our training.
    -- self.visual_names (str list):        specify the images that you want to display and save.
    -- self.optimizers (optimizer list):    define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for a usage example.

Now you can use the model class by specifying flag '--model dummy'.
See our template model class 'template_model.py' for more details.
"""

import importlib
from face3d.models.base_model import BaseModel


def find_model_using_name(model_name):
    """Import the module "models/[model_name]_model.py".

    In the file, the class called DatasetNameModel() will
    be instantiated. It has to be a subclass of BaseModel,
    and it is case-insensitive.
    """
    model_filename = "face3d.models." + model_name + "_model"
    modellib = importlib.import_module(model_filename)
    model = None
    target_model_name = model_name.replace('_', '') + 'model'
    for name, cls in modellib.__dict__.items():
        if name.lower() == target_model_name.lower() \
           and issubclass(cls, BaseModel):
            model = cls

    if model is None:
        print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
        exit(0)

    return model


def get_option_setter(model_name):
    """Return the static method <modify_commandline_options> of the model class."""
    model_class = find_model_using_name(model_name)
    return model_class.modify_commandline_options


def create_model(opt):
    """Create a model given the option.

    This function wraps the model class found by <find_model_using_name>.
    This is the main interface between this package and 'train.py'/'test.py'

    Example:
        >>> from models import create_model
        >>> model = create_model(opt)
    """
    model = find_model_using_name(opt.model)
    instance = model(opt)
    print("model [%s] was created" % type(instance).__name__)
    return instance
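The same naming convention as in `data/__init__.py` applies here: `--model facerecon` (as recorded in `test_opt.txt` above) makes `find_model_using_name` import `face3d.models.facerecon_model` and pick the `BaseModel` subclass whose lowercased name is `facereconmodel`. A tiny sketch, assuming that module ships with the rest of the upload:

```python
from face3d.models import create_model, find_model_using_name

model_cls = find_model_using_name('facerecon')   # expected to resolve to the face reconstruction model
# create_model(opt) then instantiates model_cls with the parsed options, as in face_recon_videos.py
```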
videoretalking/third_part/face3d/models/arcface_torch/README.md
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Distributed Arcface Training in Pytorch
|
2 |
+
|
3 |
+
This is a deep learning library that makes face recognition efficient, and effective, which can train tens of millions
|
4 |
+
identity on a single server.
|
5 |
+
|
6 |
+
## Requirements
|
7 |
+
|
8 |
+
- Install [pytorch](http://pytorch.org) (torch>=1.6.0), our doc for [install.md](docs/install.md).
|
9 |
+
- `pip install -r requirements.txt`.
|
10 |
+
- Download the dataset
|
11 |
+
from [https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_)
|
12 |
+
.
|
13 |
+
|
14 |
+
## How to Training
|
15 |
+
|
16 |
+
To train a model, run `train.py` with the path to the configs:
|
17 |
+
|
18 |
+
### 1. Single node, 8 GPUs:
|
19 |
+
|
20 |
+
```shell
|
21 |
+
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50
|
22 |
+
```
|
23 |
+
|
24 |
+
### 2. Multiple nodes, each node 8 GPUs:
|
25 |
+
|
26 |
+
Node 0:
|
27 |
+
|
28 |
+
```shell
|
29 |
+
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=1234 train.py train.py configs/ms1mv3_r50
|
30 |
+
```
|
31 |
+
|
32 |
+
Node 1:
|
33 |
+
|
34 |
+
```shell
|
35 |
+
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=1234 train.py train.py configs/ms1mv3_r50
|
36 |
+
```
|
37 |
+
|
38 |
+
### 3.Training resnet2060 with 8 GPUs:
|
39 |
+
|
40 |
+
```shell
|
41 |
+
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r2060.py
|
42 |
+
```
|
43 |
+
|
44 |
+
## Model Zoo
|
45 |
+
|
46 |
+
- The models are available for non-commercial research purposes only.
|
47 |
+
- All models can be found in here.
|
48 |
+
- [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw
|
49 |
+
- [onedrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d)
|
50 |
+
|
51 |
+
### Performance on [**ICCV2021-MFR**](http://iccv21-mfr.com/)
|
52 |
+
|
53 |
+
ICCV2021-MFR testset consists of non-celebrities so we can ensure that it has very few overlap with public available face
|
54 |
+
recognition training set, such as MS1M and CASIA as they mostly collected from online celebrities.
|
55 |
+
As the result, we can evaluate the FAIR performance for different algorithms.
|
56 |
+
|
57 |
+
For **ICCV2021-MFR-ALL** set, TAR is measured on all-to-all 1:1 protocal, with FAR less than 0.000001(e-6). The
|
58 |
+
globalised multi-racial testset contains 242,143 identities and 1,624,305 images.
|
59 |
+
|
60 |
+
For **ICCV2021-MFR-MASK** set, TAR is measured on mask-to-nonmask 1:1 protocal, with FAR less than 0.0001(e-4).
|
61 |
+
Mask testset contains 6,964 identities, 6,964 masked images and 13,928 non-masked images.
|
62 |
+
There are totally 13,928 positive pairs and 96,983,824 negative pairs.
|
63 |
+
|
64 |
+
| Datasets | backbone | Training throughout | Size / MB | **ICCV2021-MFR-MASK** | **ICCV2021-MFR-ALL** |
|
65 |
+
| :---: | :--- | :--- | :--- |:--- |:--- |
|
66 |
+
| MS1MV3 | r18 | - | 91 | **47.85** | **68.33** |
|
67 |
+
| Glint360k | r18 | 8536 | 91 | **53.32** | **72.07** |
|
68 |
+
| MS1MV3 | r34 | - | 130 | **58.72** | **77.36** |
|
69 |
+
| Glint360k | r34 | 6344 | 130 | **65.10** | **83.02** |
|
70 |
+
| MS1MV3 | r50 | 5500 | 166 | **63.85** | **80.53** |
|
71 |
+
| Glint360k | r50 | 5136 | 166 | **70.23** | **87.08** |
|
72 |
+
| MS1MV3 | r100 | - | 248 | **69.09** | **84.31** |
|
73 |
+
| Glint360k | r100 | 3332 | 248 | **75.57** | **90.66** |
|
74 |
+
| MS1MV3 | mobilefacenet | 12185 | 7.8 | **41.52** | **65.26** |
|
75 |
+
| Glint360k | mobilefacenet | 11197 | 7.8 | **44.52** | **66.48** |
|
76 |
+
|
77 |
+
### Performance on IJB-C and Verification Datasets

| Datasets | backbone | IJBC(1e-05) | IJBC(1e-04) | agedb30 | cfp_fp | lfw | log |
| :---: | :--- | :--- | :--- | :--- |:--- |:--- |:--- |
| MS1MV3 | r18 | 92.07 | 94.66 | 97.77 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r18_fp16/training.log)|
| MS1MV3 | r34 | 94.10 | 95.90 | 98.10 | 98.67 | 99.80 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r34_fp16/training.log)|
| MS1MV3 | r50 | 94.79 | 96.46 | 98.35 | 98.96 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r50_fp16/training.log)|
| MS1MV3 | r100 | 95.31 | 96.81 | 98.48 | 99.06 | 99.85 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r100_fp16/training.log)|
| MS1MV3 | **r2060**| 95.34 | 97.11 | 98.67 | 99.24 | 99.87 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r2060_fp16/training.log)|
| Glint360k |r18-0.1 | 93.16 | 95.33 | 97.72 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r18_fp16_0.1/training.log)|
| Glint360k |r34-0.1 | 95.16 | 96.56 | 98.33 | 98.78 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r34_fp16_0.1/training.log)|
| Glint360k |r50-0.1 | 95.61 | 96.97 | 98.38 | 99.20 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r50_fp16_0.1/training.log)|
| Glint360k |r100-0.1 | 95.88 | 97.32 | 98.48 | 99.29 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r100_fp16_0.1/training.log)|

[comment]: <> (More details see [model.md](docs/modelzoo.md) in docs.)


## [Speed Benchmark](docs/speed_benchmark.md)

**Arcface Torch** can train large-scale face recognition training sets efficiently and quickly. When the number of
classes in the training set is greater than 300K and training is sufficient, the partial fc sampling strategy achieves the same
accuracy with several times faster training and smaller GPU memory usage.
Partial FC is a sparse variant of the model-parallel architecture for large-scale face recognition. Partial FC uses a
sparse softmax, where each batch dynamically samples a subset of class centers for training. In each iteration, only a
sparse part of the parameters is updated, which greatly reduces GPU memory usage and computation. With Partial FC,
we can scale to a training set of 29 million identities, the largest to date. Partial FC also supports multi-machine distributed
training and mixed-precision training. A simplified sketch of the sampling idea is shown below.
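The following is a greatly simplified, single-GPU sketch of a sampled softmax head. It is illustrative only: it omits the margin-based loss, model parallelism and the distributed gather that the real partial_fc.py performs, and all names below are ours:

```python
import torch
import torch.nn.functional as F

class SampledSoftmaxHead(torch.nn.Module):
    """Toy sketch of the Partial FC idea: keep the class centers of the
    identities present in the batch plus a random subset of the others."""

    def __init__(self, embedding_size, num_classes, sample_rate=0.1):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.normal(0, 0.01, (num_classes, embedding_size)))
        self.num_classes = num_classes
        self.num_sample = int(sample_rate * num_classes)

    def forward(self, embeddings, labels):
        positives = torch.unique(labels)
        num_sample = max(self.num_sample, positives.numel())
        # sample negative class centers to fill the remaining slots
        perm = torch.randperm(self.num_classes, device=labels.device)
        keep = torch.ones(self.num_classes, dtype=torch.bool, device=labels.device)
        keep[positives] = False
        negatives = perm[keep[perm]][: num_sample - positives.numel()]
        index = torch.cat([positives, negatives])
        # remap the original labels onto the sampled sub-matrix
        remap = torch.full((self.num_classes,), -1, dtype=torch.long, device=labels.device)
        remap[index] = torch.arange(index.numel(), device=labels.device)
        # only self.weight[index] receives gradients in this iteration
        logits = F.linear(F.normalize(embeddings), F.normalize(self.weight[index]))
        return F.cross_entropy(logits, remap[labels])
```

Because only the selected rows of the class-center matrix enter the loss, the memory and compute of the classification layer scale with the sample rate rather than with the total number of identities.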

![Image text](https://github.com/anxiangsir/insightface_arcface_log/blob/master/partial_fc_v2.png)

For more details, see [speed_benchmark.md](docs/speed_benchmark.md) in docs.

### 1. Training speed of different parallel methods (samples / second), Tesla V100 32GB * 8. (Larger is better)

`-` means training failed because of GPU memory limitations.

| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
| :--- | :--- | :--- | :--- |
|125000 | 4681 | 4824 | 5004 |
|1400000 | **1672** | 3043 | 4738 |
|5500000 | **-** | **1389** | 3975 |
|8000000 | **-** | **-** | 3565 |
|16000000 | **-** | **-** | 2679 |
|29000000 | **-** | **-** | **1855** |

### 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better)

| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
| :--- | :--- | :--- | :--- |
|125000 | 7358 | 5306 | 4868 |
|1400000 | 32252 | 11178 | 6056 |
|5500000 | **-** | 32188 | 9854 |
|8000000 | **-** | **-** | 12310 |
|16000000 | **-** | **-** | 19950 |
|29000000 | **-** | **-** | 32324 |

## Evaluation ICCV2021-MFR and IJB-C

For more details, see [eval.md](docs/eval.md) in docs.

## Test

We tested many versions of PyTorch. Please create an issue if you are having trouble. A quick smoke test is shown after the list below.

- [x] torch 1.6.0
- [x] torch 1.7.1
- [x] torch 1.8.0
- [x] torch 1.9.0

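As a quick sanity check that your installed version works with the backbones in this repository, here is a minimal sketch (assuming it is run from the arcface_torch directory so that the `backbones` package is importable; it is not part of the test suite):

```python
import torch
from backbones import get_model

net = get_model("r18", fp16=False, num_features=512)  # smallest iresnet backbone
net.eval()
x = torch.randn(2, 3, 112, 112)                        # two dummy aligned 112x112 face crops
with torch.no_grad():
    emb = net(x)
assert emb.shape == (2, 512)
print("forward pass OK on torch", torch.__version__)
```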
## Citation

```
@inproceedings{deng2019arcface,
  title={Arcface: Additive angular margin loss for deep face recognition},
  author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
  pages={4690--4699},
  year={2019}
}
@inproceedings{an2020partical_fc,
  title={Partial FC: Training 10 Million Identities on a Single Machine},
  author={An, Xiang and Zhu, Xuhan and Xiao, Yang and Wu, Lan and Zhang, Ming and Gao, Yuan and Qin, Bin and Zhang, Debing and Fu, Ying},
  booktitle={Arxiv 2010.05222},
  year={2020}
}
```
videoretalking/third_part/face3d/models/arcface_torch/backbones/__init__.py
ADDED
@@ -0,0 +1,25 @@
from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200
from .mobilefacenet import get_mbf


def get_model(name, **kwargs):
    # resnet
    if name == "r18":
        return iresnet18(False, **kwargs)
    elif name == "r34":
        return iresnet34(False, **kwargs)
    elif name == "r50":
        return iresnet50(False, **kwargs)
    elif name == "r100":
        return iresnet100(False, **kwargs)
    elif name == "r200":
        return iresnet200(False, **kwargs)
    elif name == "r2060":
        from .iresnet2060 import iresnet2060
        return iresnet2060(False, **kwargs)
    elif name == "mbf":
        fp16 = kwargs.get("fp16", False)
        num_features = kwargs.get("num_features", 512)
        return get_mbf(fp16=fp16, num_features=num_features)
    else:
        raise ValueError()
videoretalking/third_part/face3d/models/arcface_torch/backbones/iresnet.py
ADDED
@@ -0,0 +1,187 @@
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
__all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200']
|
5 |
+
|
6 |
+
|
7 |
+
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
8 |
+
"""3x3 convolution with padding"""
|
9 |
+
return nn.Conv2d(in_planes,
|
10 |
+
out_planes,
|
11 |
+
kernel_size=3,
|
12 |
+
stride=stride,
|
13 |
+
padding=dilation,
|
14 |
+
groups=groups,
|
15 |
+
bias=False,
|
16 |
+
dilation=dilation)
|
17 |
+
|
18 |
+
|
19 |
+
def conv1x1(in_planes, out_planes, stride=1):
|
20 |
+
"""1x1 convolution"""
|
21 |
+
return nn.Conv2d(in_planes,
|
22 |
+
out_planes,
|
23 |
+
kernel_size=1,
|
24 |
+
stride=stride,
|
25 |
+
bias=False)
|
26 |
+
|
27 |
+
|
28 |
+
class IBasicBlock(nn.Module):
|
29 |
+
expansion = 1
|
30 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None,
|
31 |
+
groups=1, base_width=64, dilation=1):
|
32 |
+
super(IBasicBlock, self).__init__()
|
33 |
+
if groups != 1 or base_width != 64:
|
34 |
+
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
|
35 |
+
if dilation > 1:
|
36 |
+
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
37 |
+
self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
|
38 |
+
self.conv1 = conv3x3(inplanes, planes)
|
39 |
+
self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
|
40 |
+
self.prelu = nn.PReLU(planes)
|
41 |
+
self.conv2 = conv3x3(planes, planes, stride)
|
42 |
+
self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
|
43 |
+
self.downsample = downsample
|
44 |
+
self.stride = stride
|
45 |
+
|
46 |
+
def forward(self, x):
|
47 |
+
identity = x
|
48 |
+
out = self.bn1(x)
|
49 |
+
out = self.conv1(out)
|
50 |
+
out = self.bn2(out)
|
51 |
+
out = self.prelu(out)
|
52 |
+
out = self.conv2(out)
|
53 |
+
out = self.bn3(out)
|
54 |
+
if self.downsample is not None:
|
55 |
+
identity = self.downsample(x)
|
56 |
+
out += identity
|
57 |
+
return out
|
58 |
+
|
59 |
+
|
60 |
+
class IResNet(nn.Module):
|
61 |
+
fc_scale = 7 * 7
|
62 |
+
def __init__(self,
|
63 |
+
block, layers, dropout=0, num_features=512, zero_init_residual=False,
|
64 |
+
groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
|
65 |
+
super(IResNet, self).__init__()
|
66 |
+
self.fp16 = fp16
|
67 |
+
self.inplanes = 64
|
68 |
+
self.dilation = 1
|
69 |
+
if replace_stride_with_dilation is None:
|
70 |
+
replace_stride_with_dilation = [False, False, False]
|
71 |
+
if len(replace_stride_with_dilation) != 3:
|
72 |
+
raise ValueError("replace_stride_with_dilation should be None "
|
73 |
+
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
74 |
+
self.groups = groups
|
75 |
+
self.base_width = width_per_group
|
76 |
+
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
|
77 |
+
self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
|
78 |
+
self.prelu = nn.PReLU(self.inplanes)
|
79 |
+
self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
|
80 |
+
self.layer2 = self._make_layer(block,
|
81 |
+
128,
|
82 |
+
layers[1],
|
83 |
+
stride=2,
|
84 |
+
dilate=replace_stride_with_dilation[0])
|
85 |
+
self.layer3 = self._make_layer(block,
|
86 |
+
256,
|
87 |
+
layers[2],
|
88 |
+
stride=2,
|
89 |
+
dilate=replace_stride_with_dilation[1])
|
90 |
+
self.layer4 = self._make_layer(block,
|
91 |
+
512,
|
92 |
+
layers[3],
|
93 |
+
stride=2,
|
94 |
+
dilate=replace_stride_with_dilation[2])
|
95 |
+
self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
|
96 |
+
self.dropout = nn.Dropout(p=dropout, inplace=True)
|
97 |
+
self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
|
98 |
+
self.features = nn.BatchNorm1d(num_features, eps=1e-05)
|
99 |
+
nn.init.constant_(self.features.weight, 1.0)
|
100 |
+
self.features.weight.requires_grad = False
|
101 |
+
|
102 |
+
for m in self.modules():
|
103 |
+
if isinstance(m, nn.Conv2d):
|
104 |
+
nn.init.normal_(m.weight, 0, 0.1)
|
105 |
+
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
106 |
+
nn.init.constant_(m.weight, 1)
|
107 |
+
nn.init.constant_(m.bias, 0)
|
108 |
+
|
109 |
+
if zero_init_residual:
|
110 |
+
for m in self.modules():
|
111 |
+
if isinstance(m, IBasicBlock):
|
112 |
+
nn.init.constant_(m.bn2.weight, 0)
|
113 |
+
|
114 |
+
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
115 |
+
downsample = None
|
116 |
+
previous_dilation = self.dilation
|
117 |
+
if dilate:
|
118 |
+
self.dilation *= stride
|
119 |
+
stride = 1
|
120 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
121 |
+
downsample = nn.Sequential(
|
122 |
+
conv1x1(self.inplanes, planes * block.expansion, stride),
|
123 |
+
nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
|
124 |
+
)
|
125 |
+
layers = []
|
126 |
+
layers.append(
|
127 |
+
block(self.inplanes, planes, stride, downsample, self.groups,
|
128 |
+
self.base_width, previous_dilation))
|
129 |
+
self.inplanes = planes * block.expansion
|
130 |
+
for _ in range(1, blocks):
|
131 |
+
layers.append(
|
132 |
+
block(self.inplanes,
|
133 |
+
planes,
|
134 |
+
groups=self.groups,
|
135 |
+
base_width=self.base_width,
|
136 |
+
dilation=self.dilation))
|
137 |
+
|
138 |
+
return nn.Sequential(*layers)
|
139 |
+
|
140 |
+
def forward(self, x):
|
141 |
+
with torch.cuda.amp.autocast(self.fp16):
|
142 |
+
x = self.conv1(x)
|
143 |
+
x = self.bn1(x)
|
144 |
+
x = self.prelu(x)
|
145 |
+
x = self.layer1(x)
|
146 |
+
x = self.layer2(x)
|
147 |
+
x = self.layer3(x)
|
148 |
+
x = self.layer4(x)
|
149 |
+
x = self.bn2(x)
|
150 |
+
x = torch.flatten(x, 1)
|
151 |
+
x = self.dropout(x)
|
152 |
+
x = self.fc(x.float() if self.fp16 else x)
|
153 |
+
x = self.features(x)
|
154 |
+
return x
|
155 |
+
|
156 |
+
|
157 |
+
def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
|
158 |
+
model = IResNet(block, layers, **kwargs)
|
159 |
+
if pretrained:
|
160 |
+
raise ValueError()
|
161 |
+
return model
|
162 |
+
|
163 |
+
|
164 |
+
def iresnet18(pretrained=False, progress=True, **kwargs):
|
165 |
+
return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
|
166 |
+
progress, **kwargs)
|
167 |
+
|
168 |
+
|
169 |
+
def iresnet34(pretrained=False, progress=True, **kwargs):
|
170 |
+
return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
|
171 |
+
progress, **kwargs)
|
172 |
+
|
173 |
+
|
174 |
+
def iresnet50(pretrained=False, progress=True, **kwargs):
|
175 |
+
return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
|
176 |
+
progress, **kwargs)
|
177 |
+
|
178 |
+
|
179 |
+
def iresnet100(pretrained=False, progress=True, **kwargs):
|
180 |
+
return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
|
181 |
+
progress, **kwargs)
|
182 |
+
|
183 |
+
|
184 |
+
def iresnet200(pretrained=False, progress=True, **kwargs):
|
185 |
+
return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained,
|
186 |
+
progress, **kwargs)
|
187 |
+
|
videoretalking/third_part/face3d/models/arcface_torch/backbones/iresnet2060.py
ADDED
@@ -0,0 +1,176 @@
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
assert torch.__version__ >= "1.8.1"
|
5 |
+
from torch.utils.checkpoint import checkpoint_sequential
|
6 |
+
|
7 |
+
__all__ = ['iresnet2060']
|
8 |
+
|
9 |
+
|
10 |
+
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
11 |
+
"""3x3 convolution with padding"""
|
12 |
+
return nn.Conv2d(in_planes,
|
13 |
+
out_planes,
|
14 |
+
kernel_size=3,
|
15 |
+
stride=stride,
|
16 |
+
padding=dilation,
|
17 |
+
groups=groups,
|
18 |
+
bias=False,
|
19 |
+
dilation=dilation)
|
20 |
+
|
21 |
+
|
22 |
+
def conv1x1(in_planes, out_planes, stride=1):
|
23 |
+
"""1x1 convolution"""
|
24 |
+
return nn.Conv2d(in_planes,
|
25 |
+
out_planes,
|
26 |
+
kernel_size=1,
|
27 |
+
stride=stride,
|
28 |
+
bias=False)
|
29 |
+
|
30 |
+
|
31 |
+
class IBasicBlock(nn.Module):
|
32 |
+
expansion = 1
|
33 |
+
|
34 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None,
|
35 |
+
groups=1, base_width=64, dilation=1):
|
36 |
+
super(IBasicBlock, self).__init__()
|
37 |
+
if groups != 1 or base_width != 64:
|
38 |
+
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
|
39 |
+
if dilation > 1:
|
40 |
+
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
41 |
+
self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05, )
|
42 |
+
self.conv1 = conv3x3(inplanes, planes)
|
43 |
+
self.bn2 = nn.BatchNorm2d(planes, eps=1e-05, )
|
44 |
+
self.prelu = nn.PReLU(planes)
|
45 |
+
self.conv2 = conv3x3(planes, planes, stride)
|
46 |
+
self.bn3 = nn.BatchNorm2d(planes, eps=1e-05, )
|
47 |
+
self.downsample = downsample
|
48 |
+
self.stride = stride
|
49 |
+
|
50 |
+
def forward(self, x):
|
51 |
+
identity = x
|
52 |
+
out = self.bn1(x)
|
53 |
+
out = self.conv1(out)
|
54 |
+
out = self.bn2(out)
|
55 |
+
out = self.prelu(out)
|
56 |
+
out = self.conv2(out)
|
57 |
+
out = self.bn3(out)
|
58 |
+
if self.downsample is not None:
|
59 |
+
identity = self.downsample(x)
|
60 |
+
out += identity
|
61 |
+
return out
|
62 |
+
|
63 |
+
|
64 |
+
class IResNet(nn.Module):
|
65 |
+
fc_scale = 7 * 7
|
66 |
+
|
67 |
+
def __init__(self,
|
68 |
+
block, layers, dropout=0, num_features=512, zero_init_residual=False,
|
69 |
+
groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
|
70 |
+
super(IResNet, self).__init__()
|
71 |
+
self.fp16 = fp16
|
72 |
+
self.inplanes = 64
|
73 |
+
self.dilation = 1
|
74 |
+
if replace_stride_with_dilation is None:
|
75 |
+
replace_stride_with_dilation = [False, False, False]
|
76 |
+
if len(replace_stride_with_dilation) != 3:
|
77 |
+
raise ValueError("replace_stride_with_dilation should be None "
|
78 |
+
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
79 |
+
self.groups = groups
|
80 |
+
self.base_width = width_per_group
|
81 |
+
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
|
82 |
+
self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
|
83 |
+
self.prelu = nn.PReLU(self.inplanes)
|
84 |
+
self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
|
85 |
+
self.layer2 = self._make_layer(block,
|
86 |
+
128,
|
87 |
+
layers[1],
|
88 |
+
stride=2,
|
89 |
+
dilate=replace_stride_with_dilation[0])
|
90 |
+
self.layer3 = self._make_layer(block,
|
91 |
+
256,
|
92 |
+
layers[2],
|
93 |
+
stride=2,
|
94 |
+
dilate=replace_stride_with_dilation[1])
|
95 |
+
self.layer4 = self._make_layer(block,
|
96 |
+
512,
|
97 |
+
layers[3],
|
98 |
+
stride=2,
|
99 |
+
dilate=replace_stride_with_dilation[2])
|
100 |
+
self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05, )
|
101 |
+
self.dropout = nn.Dropout(p=dropout, inplace=True)
|
102 |
+
self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
|
103 |
+
self.features = nn.BatchNorm1d(num_features, eps=1e-05)
|
104 |
+
nn.init.constant_(self.features.weight, 1.0)
|
105 |
+
self.features.weight.requires_grad = False
|
106 |
+
|
107 |
+
for m in self.modules():
|
108 |
+
if isinstance(m, nn.Conv2d):
|
109 |
+
nn.init.normal_(m.weight, 0, 0.1)
|
110 |
+
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
111 |
+
nn.init.constant_(m.weight, 1)
|
112 |
+
nn.init.constant_(m.bias, 0)
|
113 |
+
|
114 |
+
if zero_init_residual:
|
115 |
+
for m in self.modules():
|
116 |
+
if isinstance(m, IBasicBlock):
|
117 |
+
nn.init.constant_(m.bn2.weight, 0)
|
118 |
+
|
119 |
+
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
120 |
+
downsample = None
|
121 |
+
previous_dilation = self.dilation
|
122 |
+
if dilate:
|
123 |
+
self.dilation *= stride
|
124 |
+
stride = 1
|
125 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
126 |
+
downsample = nn.Sequential(
|
127 |
+
conv1x1(self.inplanes, planes * block.expansion, stride),
|
128 |
+
nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
|
129 |
+
)
|
130 |
+
layers = []
|
131 |
+
layers.append(
|
132 |
+
block(self.inplanes, planes, stride, downsample, self.groups,
|
133 |
+
self.base_width, previous_dilation))
|
134 |
+
self.inplanes = planes * block.expansion
|
135 |
+
for _ in range(1, blocks):
|
136 |
+
layers.append(
|
137 |
+
block(self.inplanes,
|
138 |
+
planes,
|
139 |
+
groups=self.groups,
|
140 |
+
base_width=self.base_width,
|
141 |
+
dilation=self.dilation))
|
142 |
+
|
143 |
+
return nn.Sequential(*layers)
|
144 |
+
|
145 |
+
def checkpoint(self, func, num_seg, x):
|
146 |
+
if self.training:
|
147 |
+
return checkpoint_sequential(func, num_seg, x)
|
148 |
+
else:
|
149 |
+
return func(x)
|
150 |
+
|
151 |
+
def forward(self, x):
|
152 |
+
with torch.cuda.amp.autocast(self.fp16):
|
153 |
+
x = self.conv1(x)
|
154 |
+
x = self.bn1(x)
|
155 |
+
x = self.prelu(x)
|
156 |
+
x = self.layer1(x)
|
157 |
+
x = self.checkpoint(self.layer2, 20, x)
|
158 |
+
x = self.checkpoint(self.layer3, 100, x)
|
159 |
+
x = self.layer4(x)
|
160 |
+
x = self.bn2(x)
|
161 |
+
x = torch.flatten(x, 1)
|
162 |
+
x = self.dropout(x)
|
163 |
+
x = self.fc(x.float() if self.fp16 else x)
|
164 |
+
x = self.features(x)
|
165 |
+
return x
|
166 |
+
|
167 |
+
|
168 |
+
def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
|
169 |
+
model = IResNet(block, layers, **kwargs)
|
170 |
+
if pretrained:
|
171 |
+
raise ValueError()
|
172 |
+
return model
|
173 |
+
|
174 |
+
|
175 |
+
def iresnet2060(pretrained=False, progress=True, **kwargs):
|
176 |
+
return _iresnet('iresnet2060', IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs)
|
videoretalking/third_part/face3d/models/arcface_torch/backbones/mobilefacenet.py
ADDED
@@ -0,0 +1,130 @@
1 |
+
'''
|
2 |
+
Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py
|
3 |
+
Original author cavalleria
|
4 |
+
'''
|
5 |
+
|
6 |
+
import torch.nn as nn
|
7 |
+
from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module
|
8 |
+
import torch
|
9 |
+
|
10 |
+
|
11 |
+
class Flatten(Module):
|
12 |
+
def forward(self, x):
|
13 |
+
return x.view(x.size(0), -1)
|
14 |
+
|
15 |
+
|
16 |
+
class ConvBlock(Module):
|
17 |
+
def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
|
18 |
+
super(ConvBlock, self).__init__()
|
19 |
+
self.layers = nn.Sequential(
|
20 |
+
Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False),
|
21 |
+
BatchNorm2d(num_features=out_c),
|
22 |
+
PReLU(num_parameters=out_c)
|
23 |
+
)
|
24 |
+
|
25 |
+
def forward(self, x):
|
26 |
+
return self.layers(x)
|
27 |
+
|
28 |
+
|
29 |
+
class LinearBlock(Module):
|
30 |
+
def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
|
31 |
+
super(LinearBlock, self).__init__()
|
32 |
+
self.layers = nn.Sequential(
|
33 |
+
Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False),
|
34 |
+
BatchNorm2d(num_features=out_c)
|
35 |
+
)
|
36 |
+
|
37 |
+
def forward(self, x):
|
38 |
+
return self.layers(x)
|
39 |
+
|
40 |
+
|
41 |
+
class DepthWise(Module):
|
42 |
+
def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
|
43 |
+
super(DepthWise, self).__init__()
|
44 |
+
self.residual = residual
|
45 |
+
self.layers = nn.Sequential(
|
46 |
+
ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)),
|
47 |
+
ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride),
|
48 |
+
LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
|
49 |
+
)
|
50 |
+
|
51 |
+
def forward(self, x):
|
52 |
+
short_cut = None
|
53 |
+
if self.residual:
|
54 |
+
short_cut = x
|
55 |
+
x = self.layers(x)
|
56 |
+
if self.residual:
|
57 |
+
output = short_cut + x
|
58 |
+
else:
|
59 |
+
output = x
|
60 |
+
return output
|
61 |
+
|
62 |
+
|
63 |
+
class Residual(Module):
|
64 |
+
def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
|
65 |
+
super(Residual, self).__init__()
|
66 |
+
modules = []
|
67 |
+
for _ in range(num_block):
|
68 |
+
modules.append(DepthWise(c, c, True, kernel, stride, padding, groups))
|
69 |
+
self.layers = Sequential(*modules)
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
return self.layers(x)
|
73 |
+
|
74 |
+
|
75 |
+
class GDC(Module):
|
76 |
+
def __init__(self, embedding_size):
|
77 |
+
super(GDC, self).__init__()
|
78 |
+
self.layers = nn.Sequential(
|
79 |
+
LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)),
|
80 |
+
Flatten(),
|
81 |
+
Linear(512, embedding_size, bias=False),
|
82 |
+
BatchNorm1d(embedding_size))
|
83 |
+
|
84 |
+
def forward(self, x):
|
85 |
+
return self.layers(x)
|
86 |
+
|
87 |
+
|
88 |
+
class MobileFaceNet(Module):
|
89 |
+
def __init__(self, fp16=False, num_features=512):
|
90 |
+
super(MobileFaceNet, self).__init__()
|
91 |
+
scale = 2
|
92 |
+
self.fp16 = fp16
|
93 |
+
self.layers = nn.Sequential(
|
94 |
+
ConvBlock(3, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)),
|
95 |
+
ConvBlock(64 * scale, 64 * scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64),
|
96 |
+
DepthWise(64 * scale, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128),
|
97 |
+
Residual(64 * scale, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
|
98 |
+
DepthWise(64 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256),
|
99 |
+
Residual(128 * scale, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
|
100 |
+
DepthWise(128 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512),
|
101 |
+
Residual(128 * scale, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
|
102 |
+
)
|
103 |
+
self.conv_sep = ConvBlock(128 * scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
|
104 |
+
self.features = GDC(num_features)
|
105 |
+
self._initialize_weights()
|
106 |
+
|
107 |
+
def _initialize_weights(self):
|
108 |
+
for m in self.modules():
|
109 |
+
if isinstance(m, nn.Conv2d):
|
110 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
111 |
+
if m.bias is not None:
|
112 |
+
m.bias.data.zero_()
|
113 |
+
elif isinstance(m, nn.BatchNorm2d):
|
114 |
+
m.weight.data.fill_(1)
|
115 |
+
m.bias.data.zero_()
|
116 |
+
elif isinstance(m, nn.Linear):
|
117 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
118 |
+
if m.bias is not None:
|
119 |
+
m.bias.data.zero_()
|
120 |
+
|
121 |
+
def forward(self, x):
|
122 |
+
with torch.cuda.amp.autocast(self.fp16):
|
123 |
+
x = self.layers(x)
|
124 |
+
x = self.conv_sep(x.float() if self.fp16 else x)
|
125 |
+
x = self.features(x)
|
126 |
+
return x
|
127 |
+
|
128 |
+
|
129 |
+
def get_mbf(fp16, num_features):
|
130 |
+
return MobileFaceNet(fp16, num_features)
|
videoretalking/third_part/face3d/models/arcface_torch/configs/3millions.py
ADDED
@@ -0,0 +1,23 @@
1 |
+
from easydict import EasyDict as edict
|
2 |
+
|
3 |
+
# configs for test speed
|
4 |
+
|
5 |
+
config = edict()
|
6 |
+
config.loss = "arcface"
|
7 |
+
config.network = "r50"
|
8 |
+
config.resume = False
|
9 |
+
config.output = None
|
10 |
+
config.embedding_size = 512
|
11 |
+
config.sample_rate = 1.0
|
12 |
+
config.fp16 = True
|
13 |
+
config.momentum = 0.9
|
14 |
+
config.weight_decay = 5e-4
|
15 |
+
config.batch_size = 128
|
16 |
+
config.lr = 0.1 # batch size is 512
|
17 |
+
|
18 |
+
config.rec = "synthetic"
|
19 |
+
config.num_classes = 300 * 10000
|
20 |
+
config.num_epoch = 30
|
21 |
+
config.warmup_epoch = -1
|
22 |
+
config.decay_epoch = [10, 16, 22]
|
23 |
+
config.val_targets = []
|
videoretalking/third_part/face3d/models/arcface_torch/configs/3millions_pfc.py
ADDED
@@ -0,0 +1,23 @@
1 |
+
from easydict import EasyDict as edict
|
2 |
+
|
3 |
+
# configs for test speed
|
4 |
+
|
5 |
+
config = edict()
|
6 |
+
config.loss = "arcface"
|
7 |
+
config.network = "r50"
|
8 |
+
config.resume = False
|
9 |
+
config.output = None
|
10 |
+
config.embedding_size = 512
|
11 |
+
config.sample_rate = 0.1
|
12 |
+
config.fp16 = True
|
13 |
+
config.momentum = 0.9
|
14 |
+
config.weight_decay = 5e-4
|
15 |
+
config.batch_size = 128
|
16 |
+
config.lr = 0.1 # batch size is 512
|
17 |
+
|
18 |
+
config.rec = "synthetic"
|
19 |
+
config.num_classes = 300 * 10000
|
20 |
+
config.num_epoch = 30
|
21 |
+
config.warmup_epoch = -1
|
22 |
+
config.decay_epoch = [10, 16, 22]
|
23 |
+
config.val_targets = []
|
videoretalking/third_part/face3d/models/arcface_torch/configs/__init__.py
ADDED
File without changes
|
videoretalking/third_part/face3d/models/arcface_torch/configs/base.py
ADDED
@@ -0,0 +1,56 @@
1 |
+
from easydict import EasyDict as edict
|
2 |
+
|
3 |
+
# make training faster
|
4 |
+
# our RAM is 256G
|
5 |
+
# mount -t tmpfs -o size=140G tmpfs /train_tmp
|
6 |
+
|
7 |
+
config = edict()
|
8 |
+
config.loss = "arcface"
|
9 |
+
config.network = "r50"
|
10 |
+
config.resume = False
|
11 |
+
config.output = "ms1mv3_arcface_r50"
|
12 |
+
|
13 |
+
config.dataset = "ms1m-retinaface-t1"
|
14 |
+
config.embedding_size = 512
|
15 |
+
config.sample_rate = 1
|
16 |
+
config.fp16 = False
|
17 |
+
config.momentum = 0.9
|
18 |
+
config.weight_decay = 5e-4
|
19 |
+
config.batch_size = 128
|
20 |
+
config.lr = 0.1 # batch size is 512
|
21 |
+
|
22 |
+
if config.dataset == "emore":
|
23 |
+
config.rec = "/train_tmp/faces_emore"
|
24 |
+
config.num_classes = 85742
|
25 |
+
config.num_image = 5822653
|
26 |
+
config.num_epoch = 16
|
27 |
+
config.warmup_epoch = -1
|
28 |
+
config.decay_epoch = [8, 14, ]
|
29 |
+
config.val_targets = ["lfw", ]
|
30 |
+
|
31 |
+
elif config.dataset == "ms1m-retinaface-t1":
|
32 |
+
config.rec = "/train_tmp/ms1m-retinaface-t1"
|
33 |
+
config.num_classes = 93431
|
34 |
+
config.num_image = 5179510
|
35 |
+
config.num_epoch = 25
|
36 |
+
config.warmup_epoch = -1
|
37 |
+
config.decay_epoch = [11, 17, 22]
|
38 |
+
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
|
39 |
+
|
40 |
+
elif config.dataset == "glint360k":
|
41 |
+
config.rec = "/train_tmp/glint360k"
|
42 |
+
config.num_classes = 360232
|
43 |
+
config.num_image = 17091657
|
44 |
+
config.num_epoch = 20
|
45 |
+
config.warmup_epoch = -1
|
46 |
+
config.decay_epoch = [8, 12, 15, 18]
|
47 |
+
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
|
48 |
+
|
49 |
+
elif config.dataset == "webface":
|
50 |
+
config.rec = "/train_tmp/faces_webface_112x112"
|
51 |
+
config.num_classes = 10572
|
52 |
+
config.num_image = "forget"
|
53 |
+
config.num_epoch = 34
|
54 |
+
config.warmup_epoch = -1
|
55 |
+
config.decay_epoch = [20, 28, 32]
|
56 |
+
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
|
videoretalking/third_part/face3d/models/arcface_torch/configs/glint360k_mbf.py
ADDED
@@ -0,0 +1,26 @@
1 |
+
from easydict import EasyDict as edict
|
2 |
+
|
3 |
+
# make training faster
|
4 |
+
# our RAM is 256G
|
5 |
+
# mount -t tmpfs -o size=140G tmpfs /train_tmp
|
6 |
+
|
7 |
+
config = edict()
|
8 |
+
config.loss = "cosface"
|
9 |
+
config.network = "mbf"
|
10 |
+
config.resume = False
|
11 |
+
config.output = None
|
12 |
+
config.embedding_size = 512
|
13 |
+
config.sample_rate = 0.1
|
14 |
+
config.fp16 = True
|
15 |
+
config.momentum = 0.9
|
16 |
+
config.weight_decay = 2e-4
|
17 |
+
config.batch_size = 128
|
18 |
+
config.lr = 0.1 # batch size is 512
|
19 |
+
|
20 |
+
config.rec = "/train_tmp/glint360k"
|
21 |
+
config.num_classes = 360232
|
22 |
+
config.num_image = 17091657
|
23 |
+
config.num_epoch = 20
|
24 |
+
config.warmup_epoch = -1
|
25 |
+
config.decay_epoch = [8, 12, 15, 18]
|
26 |
+
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
|
videoretalking/third_part/face3d/models/arcface_torch/configs/glint360k_r100.py
ADDED
@@ -0,0 +1,26 @@
1 |
+
from easydict import EasyDict as edict
|
2 |
+
|
3 |
+
# make training faster
|
4 |
+
# our RAM is 256G
|
5 |
+
# mount -t tmpfs -o size=140G tmpfs /train_tmp
|
6 |
+
|
7 |
+
config = edict()
|
8 |
+
config.loss = "cosface"
|
9 |
+
config.network = "r100"
|
10 |
+
config.resume = False
|
11 |
+
config.output = None
|
12 |
+
config.embedding_size = 512
|
13 |
+
config.sample_rate = 1.0
|
14 |
+
config.fp16 = True
|
15 |
+
config.momentum = 0.9
|
16 |
+
config.weight_decay = 5e-4
|
17 |
+
config.batch_size = 128
|
18 |
+
config.lr = 0.1 # batch size is 512
|
19 |
+
|
20 |
+
config.rec = "/train_tmp/glint360k"
|
21 |
+
config.num_classes = 360232
|
22 |
+
config.num_image = 17091657
|
23 |
+
config.num_epoch = 20
|
24 |
+
config.warmup_epoch = -1
|
25 |
+
config.decay_epoch = [8, 12, 15, 18]
|
26 |
+
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
|
videoretalking/third_part/face3d/models/arcface_torch/configs/glint360k_r18.py
ADDED
@@ -0,0 +1,26 @@
1 |
+
from easydict import EasyDict as edict
|
2 |
+
|
3 |
+
# make training faster
|
4 |
+
# our RAM is 256G
|
5 |
+
# mount -t tmpfs -o size=140G tmpfs /train_tmp
|
6 |
+
|
7 |
+
config = edict()
|
8 |
+
config.loss = "cosface"
|
9 |
+
config.network = "r18"
|
10 |
+
config.resume = False
|
11 |
+
config.output = None
|
12 |
+
config.embedding_size = 512
|
13 |
+
config.sample_rate = 1.0
|
14 |
+
config.fp16 = True
|
15 |
+
config.momentum = 0.9
|
16 |
+
config.weight_decay = 5e-4
|
17 |
+
config.batch_size = 128
|
18 |
+
config.lr = 0.1 # batch size is 512
|
19 |
+
|
20 |
+
config.rec = "/train_tmp/glint360k"
|
21 |
+
config.num_classes = 360232
|
22 |
+
config.num_image = 17091657
|
23 |
+
config.num_epoch = 20
|
24 |
+
config.warmup_epoch = -1
|
25 |
+
config.decay_epoch = [8, 12, 15, 18]
|
26 |
+
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
|
videoretalking/third_part/face3d/models/arcface_torch/configs/glint360k_r34.py
ADDED
@@ -0,0 +1,26 @@
1 |
+
from easydict import EasyDict as edict
|
2 |
+
|
3 |
+
# make training faster
|
4 |
+
# our RAM is 256G
|
5 |
+
# mount -t tmpfs -o size=140G tmpfs /train_tmp
|
6 |
+
|
7 |
+
config = edict()
|
8 |
+
config.loss = "cosface"
|
9 |
+
config.network = "r34"
|
10 |
+
config.resume = False
|
11 |
+
config.output = None
|
12 |
+
config.embedding_size = 512
|
13 |
+
config.sample_rate = 1.0
|
14 |
+
config.fp16 = True
|
15 |
+
config.momentum = 0.9
|
16 |
+
config.weight_decay = 5e-4
|
17 |
+
config.batch_size = 128
|
18 |
+
config.lr = 0.1 # batch size is 512
|
19 |
+
|
20 |
+
config.rec = "/train_tmp/glint360k"
|
21 |
+
config.num_classes = 360232
|
22 |
+
config.num_image = 17091657
|
23 |
+
config.num_epoch = 20
|
24 |
+
config.warmup_epoch = -1
|
25 |
+
config.decay_epoch = [8, 12, 15, 18]
|
26 |
+
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
|
videoretalking/third_part/face3d/models/arcface_torch/configs/glint360k_r50.py
ADDED
@@ -0,0 +1,26 @@
1 |
+
from easydict import EasyDict as edict
|
2 |
+
|
3 |
+
# make training faster
|
4 |
+
# our RAM is 256G
|
5 |
+
# mount -t tmpfs -o size=140G tmpfs /train_tmp
|
6 |
+
|
7 |
+
config = edict()
|
8 |
+
config.loss = "cosface"
|
9 |
+
config.network = "r50"
|
10 |
+
config.resume = False
|
11 |
+
config.output = None
|
12 |
+
config.embedding_size = 512
|
13 |
+
config.sample_rate = 1.0
|
14 |
+
config.fp16 = True
|
15 |
+
config.momentum = 0.9
|
16 |
+
config.weight_decay = 5e-4
|
17 |
+
config.batch_size = 128
|
18 |
+
config.lr = 0.1 # batch size is 512
|
19 |
+
|
20 |
+
config.rec = "/train_tmp/glint360k"
|
21 |
+
config.num_classes = 360232
|
22 |
+
config.num_image = 17091657
|
23 |
+
config.num_epoch = 20
|
24 |
+
config.warmup_epoch = -1
|
25 |
+
config.decay_epoch = [8, 12, 15, 18]
|
26 |
+
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
|
videoretalking/third_part/face3d/models/arcface_torch/configs/ms1mv3_mbf.py
ADDED
@@ -0,0 +1,26 @@
1 |
+
from easydict import EasyDict as edict
|
2 |
+
|
3 |
+
# make training faster
|
4 |
+
# our RAM is 256G
|
5 |
+
# mount -t tmpfs -o size=140G tmpfs /train_tmp
|
6 |
+
|
7 |
+
config = edict()
|
8 |
+
config.loss = "arcface"
|
9 |
+
config.network = "mbf"
|
10 |
+
config.resume = False
|
11 |
+
config.output = None
|
12 |
+
config.embedding_size = 512
|
13 |
+
config.sample_rate = 1.0
|
14 |
+
config.fp16 = True
|
15 |
+
config.momentum = 0.9
|
16 |
+
config.weight_decay = 2e-4
|
17 |
+
config.batch_size = 128
|
18 |
+
config.lr = 0.1 # batch size is 512
|
19 |
+
|
20 |
+
config.rec = "/train_tmp/ms1m-retinaface-t1"
|
21 |
+
config.num_classes = 93431
|
22 |
+
config.num_image = 5179510
|
23 |
+
config.num_epoch = 30
|
24 |
+
config.warmup_epoch = -1
|
25 |
+
config.decay_epoch = [10, 20, 25]
|
26 |
+
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
|
videoretalking/third_part/face3d/models/arcface_torch/configs/ms1mv3_r18.py
ADDED
@@ -0,0 +1,26 @@
1 |
+
from easydict import EasyDict as edict
|
2 |
+
|
3 |
+
# make training faster
|
4 |
+
# our RAM is 256G
|
5 |
+
# mount -t tmpfs -o size=140G tmpfs /train_tmp
|
6 |
+
|
7 |
+
config = edict()
|
8 |
+
config.loss = "arcface"
|
9 |
+
config.network = "r18"
|
10 |
+
config.resume = False
|
11 |
+
config.output = None
|
12 |
+
config.embedding_size = 512
|
13 |
+
config.sample_rate = 1.0
|
14 |
+
config.fp16 = True
|
15 |
+
config.momentum = 0.9
|
16 |
+
config.weight_decay = 5e-4
|
17 |
+
config.batch_size = 128
|
18 |
+
config.lr = 0.1 # batch size is 512
|
19 |
+
|
20 |
+
config.rec = "/train_tmp/ms1m-retinaface-t1"
|
21 |
+
config.num_classes = 93431
|
22 |
+
config.num_image = 5179510
|
23 |
+
config.num_epoch = 25
|
24 |
+
config.warmup_epoch = -1
|
25 |
+
config.decay_epoch = [10, 16, 22]
|
26 |
+
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
|
videoretalking/third_part/face3d/models/arcface_torch/configs/ms1mv3_r2060.py
ADDED
@@ -0,0 +1,26 @@
1 |
+
from easydict import EasyDict as edict
|
2 |
+
|
3 |
+
# make training faster
|
4 |
+
# our RAM is 256G
|
5 |
+
# mount -t tmpfs -o size=140G tmpfs /train_tmp
|
6 |
+
|
7 |
+
config = edict()
|
8 |
+
config.loss = "arcface"
|
9 |
+
config.network = "r2060"
|
10 |
+
config.resume = False
|
11 |
+
config.output = None
|
12 |
+
config.embedding_size = 512
|
13 |
+
config.sample_rate = 1.0
|
14 |
+
config.fp16 = True
|
15 |
+
config.momentum = 0.9
|
16 |
+
config.weight_decay = 5e-4
|
17 |
+
config.batch_size = 64
|
18 |
+
config.lr = 0.1 # batch size is 512
|
19 |
+
|
20 |
+
config.rec = "/train_tmp/ms1m-retinaface-t1"
|
21 |
+
config.num_classes = 93431
|
22 |
+
config.num_image = 5179510
|
23 |
+
config.num_epoch = 25
|
24 |
+
config.warmup_epoch = -1
|
25 |
+
config.decay_epoch = [10, 16, 22]
|
26 |
+
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
|
videoretalking/third_part/face3d/models/arcface_torch/configs/ms1mv3_r34.py
ADDED
@@ -0,0 +1,26 @@
1 |
+
from easydict import EasyDict as edict
|
2 |
+
|
3 |
+
# make training faster
|
4 |
+
# our RAM is 256G
|
5 |
+
# mount -t tmpfs -o size=140G tmpfs /train_tmp
|
6 |
+
|
7 |
+
config = edict()
|
8 |
+
config.loss = "arcface"
|
9 |
+
config.network = "r34"
|
10 |
+
config.resume = False
|
11 |
+
config.output = None
|
12 |
+
config.embedding_size = 512
|
13 |
+
config.sample_rate = 1.0
|
14 |
+
config.fp16 = True
|
15 |
+
config.momentum = 0.9
|
16 |
+
config.weight_decay = 5e-4
|
17 |
+
config.batch_size = 128
|
18 |
+
config.lr = 0.1 # batch size is 512
|
19 |
+
|
20 |
+
config.rec = "/train_tmp/ms1m-retinaface-t1"
|
21 |
+
config.num_classes = 93431
|
22 |
+
config.num_image = 5179510
|
23 |
+
config.num_epoch = 25
|
24 |
+
config.warmup_epoch = -1
|
25 |
+
config.decay_epoch = [10, 16, 22]
|
26 |
+
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
|
videoretalking/third_part/face3d/models/arcface_torch/configs/ms1mv3_r50.py
ADDED
@@ -0,0 +1,26 @@
1 |
+
from easydict import EasyDict as edict
|
2 |
+
|
3 |
+
# make training faster
|
4 |
+
# our RAM is 256G
|
5 |
+
# mount -t tmpfs -o size=140G tmpfs /train_tmp
|
6 |
+
|
7 |
+
config = edict()
|
8 |
+
config.loss = "arcface"
|
9 |
+
config.network = "r50"
|
10 |
+
config.resume = False
|
11 |
+
config.output = None
|
12 |
+
config.embedding_size = 512
|
13 |
+
config.sample_rate = 1.0
|
14 |
+
config.fp16 = True
|
15 |
+
config.momentum = 0.9
|
16 |
+
config.weight_decay = 5e-4
|
17 |
+
config.batch_size = 128
|
18 |
+
config.lr = 0.1 # batch size is 512
|
19 |
+
|
20 |
+
config.rec = "/train_tmp/ms1m-retinaface-t1"
|
21 |
+
config.num_classes = 93431
|
22 |
+
config.num_image = 5179510
|
23 |
+
config.num_epoch = 25
|
24 |
+
config.warmup_epoch = -1
|
25 |
+
config.decay_epoch = [10, 16, 22]
|
26 |
+
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
|
videoretalking/third_part/face3d/models/arcface_torch/configs/speed.py
ADDED
@@ -0,0 +1,23 @@
1 |
+
from easydict import EasyDict as edict
|
2 |
+
|
3 |
+
# configs for test speed
|
4 |
+
|
5 |
+
config = edict()
|
6 |
+
config.loss = "arcface"
|
7 |
+
config.network = "r50"
|
8 |
+
config.resume = False
|
9 |
+
config.output = None
|
10 |
+
config.embedding_size = 512
|
11 |
+
config.sample_rate = 1.0
|
12 |
+
config.fp16 = True
|
13 |
+
config.momentum = 0.9
|
14 |
+
config.weight_decay = 5e-4
|
15 |
+
config.batch_size = 128
|
16 |
+
config.lr = 0.1 # batch size is 512
|
17 |
+
|
18 |
+
config.rec = "synthetic"
|
19 |
+
config.num_classes = 100 * 10000
|
20 |
+
config.num_epoch = 30
|
21 |
+
config.warmup_epoch = -1
|
22 |
+
config.decay_epoch = [10, 16, 22]
|
23 |
+
config.val_targets = []
|
videoretalking/third_part/face3d/models/arcface_torch/dataset.py
ADDED
@@ -0,0 +1,124 @@
1 |
+
import numbers
|
2 |
+
import os
|
3 |
+
import queue as Queue
|
4 |
+
import threading
|
5 |
+
|
6 |
+
import mxnet as mx
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from torch.utils.data import DataLoader, Dataset
|
10 |
+
from torchvision import transforms
|
11 |
+
|
12 |
+
|
13 |
+
class BackgroundGenerator(threading.Thread):
|
14 |
+
def __init__(self, generator, local_rank, max_prefetch=6):
|
15 |
+
super(BackgroundGenerator, self).__init__()
|
16 |
+
self.queue = Queue.Queue(max_prefetch)
|
17 |
+
self.generator = generator
|
18 |
+
self.local_rank = local_rank
|
19 |
+
self.daemon = True
|
20 |
+
self.start()
|
21 |
+
|
22 |
+
def run(self):
|
23 |
+
torch.cuda.set_device(self.local_rank)
|
24 |
+
for item in self.generator:
|
25 |
+
self.queue.put(item)
|
26 |
+
self.queue.put(None)
|
27 |
+
|
28 |
+
def next(self):
|
29 |
+
next_item = self.queue.get()
|
30 |
+
if next_item is None:
|
31 |
+
raise StopIteration
|
32 |
+
return next_item
|
33 |
+
|
34 |
+
def __next__(self):
|
35 |
+
return self.next()
|
36 |
+
|
37 |
+
def __iter__(self):
|
38 |
+
return self
|
39 |
+
|
40 |
+
|
41 |
+
class DataLoaderX(DataLoader):
|
42 |
+
|
43 |
+
def __init__(self, local_rank, **kwargs):
|
44 |
+
super(DataLoaderX, self).__init__(**kwargs)
|
45 |
+
self.stream = torch.cuda.Stream(local_rank)
|
46 |
+
self.local_rank = local_rank
|
47 |
+
|
48 |
+
def __iter__(self):
|
49 |
+
self.iter = super(DataLoaderX, self).__iter__()
|
50 |
+
self.iter = BackgroundGenerator(self.iter, self.local_rank)
|
51 |
+
self.preload()
|
52 |
+
return self
|
53 |
+
|
54 |
+
def preload(self):
|
55 |
+
self.batch = next(self.iter, None)
|
56 |
+
if self.batch is None:
|
57 |
+
return None
|
58 |
+
with torch.cuda.stream(self.stream):
|
59 |
+
for k in range(len(self.batch)):
|
60 |
+
self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True)
|
61 |
+
|
62 |
+
def __next__(self):
|
63 |
+
torch.cuda.current_stream().wait_stream(self.stream)
|
64 |
+
batch = self.batch
|
65 |
+
if batch is None:
|
66 |
+
raise StopIteration
|
67 |
+
self.preload()
|
68 |
+
return batch
|
69 |
+
|
70 |
+
|
71 |
+
class MXFaceDataset(Dataset):
|
72 |
+
def __init__(self, root_dir, local_rank):
|
73 |
+
super(MXFaceDataset, self).__init__()
|
74 |
+
self.transform = transforms.Compose(
|
75 |
+
[transforms.ToPILImage(),
|
76 |
+
transforms.RandomHorizontalFlip(),
|
77 |
+
transforms.ToTensor(),
|
78 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
79 |
+
])
|
80 |
+
self.root_dir = root_dir
|
81 |
+
self.local_rank = local_rank
|
82 |
+
path_imgrec = os.path.join(root_dir, 'train.rec')
|
83 |
+
path_imgidx = os.path.join(root_dir, 'train.idx')
|
84 |
+
self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')
|
85 |
+
s = self.imgrec.read_idx(0)
|
86 |
+
header, _ = mx.recordio.unpack(s)
|
87 |
+
if header.flag > 0:
|
88 |
+
self.header0 = (int(header.label[0]), int(header.label[1]))
|
89 |
+
self.imgidx = np.array(range(1, int(header.label[0])))
|
90 |
+
else:
|
91 |
+
self.imgidx = np.array(list(self.imgrec.keys))
|
92 |
+
|
93 |
+
def __getitem__(self, index):
|
94 |
+
idx = self.imgidx[index]
|
95 |
+
s = self.imgrec.read_idx(idx)
|
96 |
+
header, img = mx.recordio.unpack(s)
|
97 |
+
label = header.label
|
98 |
+
if not isinstance(label, numbers.Number):
|
99 |
+
label = label[0]
|
100 |
+
label = torch.tensor(label, dtype=torch.long)
|
101 |
+
sample = mx.image.imdecode(img).asnumpy()
|
102 |
+
if self.transform is not None:
|
103 |
+
sample = self.transform(sample)
|
104 |
+
return sample, label
|
105 |
+
|
106 |
+
def __len__(self):
|
107 |
+
return len(self.imgidx)
|
108 |
+
|
109 |
+
|
110 |
+
class SyntheticDataset(Dataset):
|
111 |
+
def __init__(self, local_rank):
|
112 |
+
super(SyntheticDataset, self).__init__()
|
113 |
+
img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32)
|
114 |
+
img = np.transpose(img, (2, 0, 1))
|
115 |
+
img = torch.from_numpy(img).squeeze(0).float()
|
116 |
+
img = ((img / 255) - 0.5) / 0.5
|
117 |
+
self.img = img
|
118 |
+
self.label = 1
|
119 |
+
|
120 |
+
def __getitem__(self, index):
|
121 |
+
return self.img, self.label
|
122 |
+
|
123 |
+
def __len__(self):
|
124 |
+
return 1000000
|
videoretalking/third_part/face3d/models/arcface_torch/docs/eval.md
ADDED
@@ -0,0 +1,31 @@
## Eval on ICCV2021-MFR

coming soon.


## Eval IJBC
You can evaluate IJB-C with PyTorch or ONNX.


1. Eval IJBC With Onnx
```shell
CUDA_VISIBLE_DEVICES=0 python onnx_ijbc.py --model-root ms1mv3_arcface_r50 --image-path IJB_release/IJBC --result-dir ms1mv3_arcface_r50
```

2. Eval IJBC With Pytorch
```shell
CUDA_VISIBLE_DEVICES=0,1 python eval_ijbc.py \
--model-prefix ms1mv3_arcface_r50/backbone.pth \
--image-path IJB_release/IJBC \
--result-dir ms1mv3_arcface_r50 \
--batch-size 128 \
--job ms1mv3_arcface_r50 \
--target IJBC \
--network iresnet50
```

## Inference

```shell
python inference.py --weight ms1mv3_arcface_r50/backbone.pth --network r50
```
videoretalking/third_part/face3d/models/arcface_torch/docs/install.md
ADDED
@@ -0,0 +1,51 @@
1 |
+
## v1.8.0
|
2 |
+
### Linux and Windows
|
3 |
+
```shell
|
4 |
+
# CUDA 11.0
|
5 |
+
pip --default-timeout=100 install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
|
6 |
+
|
7 |
+
# CUDA 10.2
|
8 |
+
pip --default-timeout=100 install torch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0
|
9 |
+
|
10 |
+
# CPU only
|
11 |
+
pip --default-timeout=100 install torch==1.8.0+cpu torchvision==0.9.0+cpu torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
|
12 |
+
|
13 |
+
```
|
14 |
+
|
15 |
+
|
16 |
+
## v1.7.1
|
17 |
+
### Linux and Windows
|
18 |
+
```shell
|
19 |
+
# CUDA 11.0
|
20 |
+
pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
|
21 |
+
|
22 |
+
# CUDA 10.2
|
23 |
+
pip install torch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2
|
24 |
+
|
25 |
+
# CUDA 10.1
|
26 |
+
pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
|
27 |
+
|
28 |
+
# CUDA 9.2
|
29 |
+
pip install torch==1.7.1+cu92 torchvision==0.8.2+cu92 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
|
30 |
+
|
31 |
+
# CPU only
|
32 |
+
pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
|
33 |
+
```
|
34 |
+
|
35 |
+
|
36 |
+
## v1.6.0
|
37 |
+
|
38 |
+
### Linux and Windows
|
39 |
+
```shell
|
40 |
+
# CUDA 10.2
|
41 |
+
pip install torch==1.6.0 torchvision==0.7.0
|
42 |
+
|
43 |
+
# CUDA 10.1
|
44 |
+
pip install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
|
45 |
+
|
46 |
+
# CUDA 9.2
|
47 |
+
pip install torch==1.6.0+cu92 torchvision==0.7.0+cu92 -f https://download.pytorch.org/whl/torch_stable.html
|
48 |
+
|
49 |
+
# CPU only
|
50 |
+
pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
51 |
+
```
|
videoretalking/third_part/face3d/models/arcface_torch/docs/modelzoo.md
ADDED
File without changes
|
videoretalking/third_part/face3d/models/arcface_torch/docs/speed_benchmark.md
ADDED
@@ -0,0 +1,93 @@
## Test Training Speed

- Test Commands

You need to use the following two commands to test the Partial FC training performance.
The number of identities is **3 million** (synthetic data), mixed precision training is turned on, the backbone is resnet50,
and the batch size is 1024.
```shell
# Model Parallel
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions
# Partial FC 0.1
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions_pfc
```

- GPU Memory

```
# (Model Parallel) gpustat -i
[0] Tesla V100-SXM2-32GB | 64'C, 94 % | 30338 / 32510 MB
[1] Tesla V100-SXM2-32GB | 60'C, 99 % | 28876 / 32510 MB
[2] Tesla V100-SXM2-32GB | 60'C, 99 % | 28872 / 32510 MB
[3] Tesla V100-SXM2-32GB | 69'C, 99 % | 28872 / 32510 MB
[4] Tesla V100-SXM2-32GB | 66'C, 99 % | 28888 / 32510 MB
[5] Tesla V100-SXM2-32GB | 60'C, 99 % | 28932 / 32510 MB
[6] Tesla V100-SXM2-32GB | 68'C, 100 % | 28916 / 32510 MB
[7] Tesla V100-SXM2-32GB | 65'C, 99 % | 28860 / 32510 MB

# (Partial FC 0.1) gpustat -i
[0] Tesla V100-SXM2-32GB | 60'C, 95 % | 10488 / 32510 MB
[1] Tesla V100-SXM2-32GB | 60'C, 97 % | 10344 / 32510 MB
[2] Tesla V100-SXM2-32GB | 61'C, 95 % | 10340 / 32510 MB
[3] Tesla V100-SXM2-32GB | 66'C, 95 % | 10340 / 32510 MB
[4] Tesla V100-SXM2-32GB | 65'C, 94 % | 10356 / 32510 MB
[5] Tesla V100-SXM2-32GB | 61'C, 95 % | 10400 / 32510 MB
[6] Tesla V100-SXM2-32GB | 68'C, 96 % | 10384 / 32510 MB
[7] Tesla V100-SXM2-32GB | 64'C, 95 % | 10328 / 32510 MB
```

- Training Speed

```python
# (Model Parallel) training.log
Training: Speed 2271.33 samples/sec Loss 1.1624 LearningRate 0.2000 Epoch: 0 Global Step: 100
Training: Speed 2269.94 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150
Training: Speed 2272.67 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200
Training: Speed 2266.55 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250
Training: Speed 2272.54 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300

# (Partial FC 0.1) training.log
Training: Speed 5299.56 samples/sec Loss 1.0965 LearningRate 0.2000 Epoch: 0 Global Step: 100
Training: Speed 5296.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150
Training: Speed 5304.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200
Training: Speed 5274.43 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250
Training: Speed 5300.10 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300
```

In this test case, Partial FC 0.1 only uses 1/3 of the GPU memory of model parallel,
and its training speed is 2.5 times faster than model parallel.


## Speed Benchmark

1. Training speed of different parallel methods (samples/second), Tesla V100 32GB * 8. (Larger is better)

| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
| :--- | :--- | :--- | :--- |
|125000 | 4681 | 4824 | 5004 |
|250000 | 4047 | 4521 | 4976 |
|500000 | 3087 | 4013 | 4900 |
|1000000 | 2090 | 3449 | 4803 |
|1400000 | 1672 | 3043 | 4738 |
|2000000 | - | 2593 | 4626 |
|4000000 | - | 1748 | 4208 |
|5500000 | - | 1389 | 3975 |
|8000000 | - | - | 3565 |
|16000000 | - | - | 2679 |
|29000000 | - | - | 1855 |

2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better)

| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
| :--- | :--- | :--- | :--- |
|125000 | 7358 | 5306 | 4868 |
|250000 | 9940 | 5826 | 5004 |
|500000 | 14220 | 7114 | 5202 |
|1000000 | 23708 | 9966 | 5620 |
|1400000 | 32252 | 11178 | 6056 |
|2000000 | - | 13978 | 6472 |
|4000000 | - | 23238 | 8284 |
|5500000 | - | 32188 | 9854 |
|8000000 | - | - | 12310 |
|16000000 | - | - | 19950 |
|29000000 | - | - | 32324 |
videoretalking/third_part/face3d/models/arcface_torch/eval/__init__.py
ADDED
File without changes
|
videoretalking/third_part/face3d/models/arcface_torch/eval/verification.py
ADDED
@@ -0,0 +1,407 @@
1 |
+
"""Helper for evaluation on the Labeled Faces in the Wild dataset
|
2 |
+
"""
|
3 |
+
|
4 |
+
# MIT License
|
5 |
+
#
|
6 |
+
# Copyright (c) 2016 David Sandberg
|
7 |
+
#
|
8 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
9 |
+
# of this software and associated documentation files (the "Software"), to deal
|
10 |
+
# in the Software without restriction, including without limitation the rights
|
11 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
12 |
+
# copies of the Software, and to permit persons to whom the Software is
|
13 |
+
# furnished to do so, subject to the following conditions:
|
14 |
+
#
|
15 |
+
# The above copyright notice and this permission notice shall be included in all
|
16 |
+
# copies or substantial portions of the Software.
|
17 |
+
#
|
18 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
19 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
20 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
21 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
22 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
23 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
24 |
+
# SOFTWARE.
|
25 |
+
|
26 |
+
|
27 |
+
import datetime
|
28 |
+
import os
|
29 |
+
import pickle
|
30 |
+
|
31 |
+
import mxnet as mx
|
32 |
+
import numpy as np
|
33 |
+
import sklearn
|
34 |
+
import torch
|
35 |
+
from mxnet import ndarray as nd
|
36 |
+
from scipy import interpolate
|
37 |
+
from sklearn.decomposition import PCA
|
38 |
+
from sklearn.model_selection import KFold
|
39 |
+
|
40 |
+
|
41 |
+
class LFold:
|
42 |
+
def __init__(self, n_splits=2, shuffle=False):
|
43 |
+
self.n_splits = n_splits
|
44 |
+
if self.n_splits > 1:
|
45 |
+
self.k_fold = KFold(n_splits=n_splits, shuffle=shuffle)
|
46 |
+
|
47 |
+
def split(self, indices):
|
48 |
+
if self.n_splits > 1:
|
49 |
+
return self.k_fold.split(indices)
|
50 |
+
else:
|
51 |
+
return [(indices, indices)]
|
52 |
+
|
53 |
+
|
54 |
+
def calculate_roc(thresholds,
|
55 |
+
embeddings1,
|
56 |
+
embeddings2,
|
57 |
+
actual_issame,
|
58 |
+
nrof_folds=10,
|
59 |
+
pca=0):
|
60 |
+
assert (embeddings1.shape[0] == embeddings2.shape[0])
|
61 |
+
assert (embeddings1.shape[1] == embeddings2.shape[1])
|
62 |
+
nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
|
63 |
+
nrof_thresholds = len(thresholds)
|
64 |
+
k_fold = LFold(n_splits=nrof_folds, shuffle=False)
|
65 |
+
|
66 |
+
tprs = np.zeros((nrof_folds, nrof_thresholds))
|
67 |
+
fprs = np.zeros((nrof_folds, nrof_thresholds))
|
68 |
+
accuracy = np.zeros((nrof_folds))
|
69 |
+
indices = np.arange(nrof_pairs)
|
70 |
+
|
71 |
+
if pca == 0:
|
72 |
+
diff = np.subtract(embeddings1, embeddings2)
|
73 |
+
dist = np.sum(np.square(diff), 1)
|
74 |
+
|
75 |
+
for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
|
76 |
+
if pca > 0:
|
77 |
+
print('doing pca on', fold_idx)
|
78 |
+
embed1_train = embeddings1[train_set]
|
79 |
+
embed2_train = embeddings2[train_set]
|
80 |
+
_embed_train = np.concatenate((embed1_train, embed2_train), axis=0)
|
81 |
+
pca_model = PCA(n_components=pca)
|
82 |
+
pca_model.fit(_embed_train)
|
83 |
+
embed1 = pca_model.transform(embeddings1)
|
84 |
+
embed2 = pca_model.transform(embeddings2)
|
85 |
+
embed1 = sklearn.preprocessing.normalize(embed1)
|
86 |
+
embed2 = sklearn.preprocessing.normalize(embed2)
|
87 |
+
diff = np.subtract(embed1, embed2)
|
88 |
+
dist = np.sum(np.square(diff), 1)
|
89 |
+
|
90 |
+
# Find the best threshold for the fold
|
91 |
+
acc_train = np.zeros((nrof_thresholds))
|
92 |
+
for threshold_idx, threshold in enumerate(thresholds):
|
93 |
+
_, _, acc_train[threshold_idx] = calculate_accuracy(
|
94 |
+
threshold, dist[train_set], actual_issame[train_set])
|
95 |
+
best_threshold_index = np.argmax(acc_train)
|
96 |
+
for threshold_idx, threshold in enumerate(thresholds):
|
97 |
+
tprs[fold_idx, threshold_idx], fprs[fold_idx, threshold_idx], _ = calculate_accuracy(
|
98 |
+
threshold, dist[test_set],
|
99 |
+
actual_issame[test_set])
|
100 |
+
_, _, accuracy[fold_idx] = calculate_accuracy(
|
101 |
+
thresholds[best_threshold_index], dist[test_set],
|
102 |
+
actual_issame[test_set])
|
103 |
+
|
104 |
+
tpr = np.mean(tprs, 0)
|
105 |
+
fpr = np.mean(fprs, 0)
|
106 |
+
return tpr, fpr, accuracy
|
107 |
+
|
108 |
+
|
109 |
+
def calculate_accuracy(threshold, dist, actual_issame):
|
110 |
+
predict_issame = np.less(dist, threshold)
|
111 |
+
tp = np.sum(np.logical_and(predict_issame, actual_issame))
|
112 |
+
fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))
|
113 |
+
tn = np.sum(
|
114 |
+
np.logical_and(np.logical_not(predict_issame),
|
115 |
+
np.logical_not(actual_issame)))
|
116 |
+
fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame))
|
117 |
+
|
118 |
+
tpr = 0 if (tp + fn == 0) else float(tp) / float(tp + fn)
|
119 |
+
fpr = 0 if (fp + tn == 0) else float(fp) / float(fp + tn)
|
120 |
+
acc = float(tp + tn) / dist.size
|
121 |
+
return tpr, fpr, acc
|
122 |
+
|
123 |
+
|
124 |
+
def calculate_val(thresholds,
|
125 |
+
embeddings1,
|
126 |
+
embeddings2,
|
127 |
+
actual_issame,
|
128 |
+
far_target,
|
129 |
+
nrof_folds=10):
|
130 |
+
assert (embeddings1.shape[0] == embeddings2.shape[0])
|
131 |
+
assert (embeddings1.shape[1] == embeddings2.shape[1])
|
132 |
+
nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
|
133 |
+
nrof_thresholds = len(thresholds)
|
134 |
+
k_fold = LFold(n_splits=nrof_folds, shuffle=False)
|
135 |
+
|
136 |
+
val = np.zeros(nrof_folds)
|
137 |
+
far = np.zeros(nrof_folds)
|
138 |
+
|
139 |
+
diff = np.subtract(embeddings1, embeddings2)
|
140 |
+
dist = np.sum(np.square(diff), 1)
|
141 |
+
indices = np.arange(nrof_pairs)
|
142 |
+
|
143 |
+
for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
|
144 |
+
|
145 |
+
# Find the threshold that gives FAR = far_target
|
146 |
+
far_train = np.zeros(nrof_thresholds)
|
147 |
+
for threshold_idx, threshold in enumerate(thresholds):
|
148 |
+
_, far_train[threshold_idx] = calculate_val_far(
|
149 |
+
threshold, dist[train_set], actual_issame[train_set])
|
150 |
+
if np.max(far_train) >= far_target:
|
151 |
+
f = interpolate.interp1d(far_train, thresholds, kind='slinear')
|
152 |
+
threshold = f(far_target)
|
153 |
+
else:
|
154 |
+
threshold = 0.0
|
155 |
+
|
156 |
+
val[fold_idx], far[fold_idx] = calculate_val_far(
|
157 |
+
threshold, dist[test_set], actual_issame[test_set])
|
158 |
+
|
159 |
+
val_mean = np.mean(val)
|
160 |
+
far_mean = np.mean(far)
|
161 |
+
val_std = np.std(val)
|
162 |
+
return val_mean, val_std, far_mean
|
163 |
+
|
164 |
+
|
165 |
+
def calculate_val_far(threshold, dist, actual_issame):
|
166 |
+
predict_issame = np.less(dist, threshold)
|
167 |
+
true_accept = np.sum(np.logical_and(predict_issame, actual_issame))
|
168 |
+
false_accept = np.sum(
|
169 |
+
np.logical_and(predict_issame, np.logical_not(actual_issame)))
|
170 |
+
n_same = np.sum(actual_issame)
|
171 |
+
n_diff = np.sum(np.logical_not(actual_issame))
|
172 |
+
# print(true_accept, false_accept)
|
173 |
+
# print(n_same, n_diff)
|
174 |
+
val = float(true_accept) / float(n_same)
|
175 |
+
far = float(false_accept) / float(n_diff)
|
176 |
+
return val, far
|
177 |
+
|
178 |
+
|
179 |
+
def evaluate(embeddings, actual_issame, nrof_folds=10, pca=0):
|
180 |
+
# Calculate evaluation metrics
|
181 |
+
thresholds = np.arange(0, 4, 0.01)
|
182 |
+
embeddings1 = embeddings[0::2]
|
183 |
+
embeddings2 = embeddings[1::2]
|
184 |
+
tpr, fpr, accuracy = calculate_roc(thresholds,
|
185 |
+
embeddings1,
|
186 |
+
embeddings2,
|
187 |
+
np.asarray(actual_issame),
|
188 |
+
nrof_folds=nrof_folds,
|
189 |
+
pca=pca)
|
190 |
+
thresholds = np.arange(0, 4, 0.001)
|
191 |
+
val, val_std, far = calculate_val(thresholds,
|
192 |
+
embeddings1,
|
193 |
+
embeddings2,
|
194 |
+
np.asarray(actual_issame),
|
195 |
+
1e-3,
|
196 |
+
nrof_folds=nrof_folds)
|
197 |
+
return tpr, fpr, accuracy, val, val_std, far
|
198 |
+
|
199 |
+
@torch.no_grad()
|
200 |
+
def load_bin(path, image_size):
|
201 |
+
try:
|
202 |
+
with open(path, 'rb') as f:
|
203 |
+
bins, issame_list = pickle.load(f) # py2
|
204 |
+
except UnicodeDecodeError as e:
|
205 |
+
with open(path, 'rb') as f:
|
206 |
+
bins, issame_list = pickle.load(f, encoding='bytes') # py3
|
207 |
+
data_list = []
|
208 |
+
for flip in [0, 1]:
|
209 |
+
data = torch.empty((len(issame_list) * 2, 3, image_size[0], image_size[1]))
|
210 |
+
data_list.append(data)
|
211 |
+
for idx in range(len(issame_list) * 2):
|
212 |
+
_bin = bins[idx]
|
213 |
+
img = mx.image.imdecode(_bin)
|
214 |
+
if img.shape[1] != image_size[0]:
|
215 |
+
img = mx.image.resize_short(img, image_size[0])
|
216 |
+
img = nd.transpose(img, axes=(2, 0, 1))
|
217 |
+
for flip in [0, 1]:
|
218 |
+
if flip == 1:
|
219 |
+
img = mx.ndarray.flip(data=img, axis=2)
|
220 |
+
data_list[flip][idx][:] = torch.from_numpy(img.asnumpy())
|
221 |
+
if idx % 1000 == 0:
|
222 |
+
print('loading bin', idx)
|
223 |
+
print(data_list[0].shape)
|
224 |
+
return data_list, issame_list
|
225 |
+
|
226 |
+
@torch.no_grad()
|
227 |
+
def test(data_set, backbone, batch_size, nfolds=10):
|
228 |
+
print('testing verification..')
|
229 |
+
data_list = data_set[0]
|
230 |
+
issame_list = data_set[1]
|
231 |
+
embeddings_list = []
|
232 |
+
time_consumed = 0.0
|
233 |
+
for i in range(len(data_list)):
|
234 |
+
data = data_list[i]
|
235 |
+
embeddings = None
|
236 |
+
ba = 0
|
237 |
+
while ba < data.shape[0]:
|
238 |
+
bb = min(ba + batch_size, data.shape[0])
|
239 |
+
count = bb - ba
|
240 |
+
_data = data[bb - batch_size: bb]
|
241 |
+
time0 = datetime.datetime.now()
|
242 |
+
img = ((_data / 255) - 0.5) / 0.5
|
243 |
+
net_out: torch.Tensor = backbone(img)
|
244 |
+
_embeddings = net_out.detach().cpu().numpy()
|
245 |
+
time_now = datetime.datetime.now()
|
246 |
+
diff = time_now - time0
|
247 |
+
time_consumed += diff.total_seconds()
|
248 |
+
if embeddings is None:
|
249 |
+
embeddings = np.zeros((data.shape[0], _embeddings.shape[1]))
|
250 |
+
embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :]
|
251 |
+
ba = bb
|
252 |
+
embeddings_list.append(embeddings)
|
253 |
+
|
254 |
+
_xnorm = 0.0
|
255 |
+
_xnorm_cnt = 0
|
256 |
+
for embed in embeddings_list:
|
257 |
+
for i in range(embed.shape[0]):
|
258 |
+
_em = embed[i]
|
259 |
+
_norm = np.linalg.norm(_em)
|
260 |
+
_xnorm += _norm
|
261 |
+
_xnorm_cnt += 1
|
262 |
+
_xnorm /= _xnorm_cnt
|
263 |
+
|
264 |
+
acc1 = 0.0
|
265 |
+
std1 = 0.0
|
266 |
+
embeddings = embeddings_list[0] + embeddings_list[1]
|
267 |
+
embeddings = sklearn.preprocessing.normalize(embeddings)
|
268 |
+
print(embeddings.shape)
|
269 |
+
print('infer time', time_consumed)
|
270 |
+
_, _, accuracy, val, val_std, far = evaluate(embeddings, issame_list, nrof_folds=nfolds)
|
271 |
+
acc2, std2 = np.mean(accuracy), np.std(accuracy)
|
272 |
+
return acc1, std1, acc2, std2, _xnorm, embeddings_list
|
273 |
+
|
274 |
+
|
275 |
+
def dumpR(data_set,
|
276 |
+
backbone,
|
277 |
+
batch_size,
|
278 |
+
name='',
|
279 |
+
data_extra=None,
|
280 |
+
label_shape=None):
|
281 |
+
print('dump verification embedding..')
|
282 |
+
data_list = data_set[0]
|
283 |
+
issame_list = data_set[1]
|
284 |
+
embeddings_list = []
|
285 |
+
time_consumed = 0.0
|
286 |
+
for i in range(len(data_list)):
|
287 |
+
data = data_list[i]
|
288 |
+
embeddings = None
|
289 |
+
ba = 0
|
290 |
+
while ba < data.shape[0]:
|
291 |
+
bb = min(ba + batch_size, data.shape[0])
|
292 |
+
count = bb - ba
|
293 |
+
|
294 |
+
_data = nd.slice_axis(data, axis=0, begin=bb - batch_size, end=bb)
|
295 |
+
time0 = datetime.datetime.now()
|
296 |
+
if data_extra is None:
|
297 |
+
db = mx.io.DataBatch(data=(_data,), label=(_label,))
|
298 |
+
else:
|
299 |
+
db = mx.io.DataBatch(data=(_data, _data_extra),
|
300 |
+
label=(_label,))
|
301 |
+
model.forward(db, is_train=False)
|
302 |
+
net_out = model.get_outputs()
|
303 |
+
_embeddings = net_out[0].asnumpy()
|
304 |
+
time_now = datetime.datetime.now()
|
305 |
+
diff = time_now - time0
|
306 |
+
time_consumed += diff.total_seconds()
|
307 |
+
if embeddings is None:
|
308 |
+
embeddings = np.zeros((data.shape[0], _embeddings.shape[1]))
|
309 |
+
embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :]
|
310 |
+
ba = bb
|
311 |
+
embeddings_list.append(embeddings)
|
312 |
+
embeddings = embeddings_list[0] + embeddings_list[1]
|
313 |
+
embeddings = sklearn.preprocessing.normalize(embeddings)
|
314 |
+
actual_issame = np.asarray(issame_list)
|
315 |
+
outname = os.path.join('temp.bin')
|
316 |
+
with open(outname, 'wb') as f:
|
317 |
+
pickle.dump((embeddings, issame_list),
|
318 |
+
f,
|
319 |
+
protocol=pickle.HIGHEST_PROTOCOL)
|
320 |
+
|
321 |
+
|
322 |
+
# if __name__ == '__main__':
|
323 |
+
#
|
324 |
+
# parser = argparse.ArgumentParser(description='do verification')
|
325 |
+
# # general
|
326 |
+
# parser.add_argument('--data-dir', default='', help='')
|
327 |
+
# parser.add_argument('--model',
|
328 |
+
# default='../model/softmax,50',
|
329 |
+
# help='path to load model.')
|
330 |
+
# parser.add_argument('--target',
|
331 |
+
# default='lfw,cfp_ff,cfp_fp,agedb_30',
|
332 |
+
# help='test targets.')
|
333 |
+
# parser.add_argument('--gpu', default=0, type=int, help='gpu id')
|
334 |
+
# parser.add_argument('--batch-size', default=32, type=int, help='')
|
335 |
+
# parser.add_argument('--max', default='', type=str, help='')
|
336 |
+
# parser.add_argument('--mode', default=0, type=int, help='')
|
337 |
+
# parser.add_argument('--nfolds', default=10, type=int, help='')
|
338 |
+
# args = parser.parse_args()
|
339 |
+
# image_size = [112, 112]
|
340 |
+
# print('image_size', image_size)
|
341 |
+
# ctx = mx.gpu(args.gpu)
|
342 |
+
# nets = []
|
343 |
+
# vec = args.model.split(',')
|
344 |
+
# prefix = args.model.split(',')[0]
|
345 |
+
# epochs = []
|
346 |
+
# if len(vec) == 1:
|
347 |
+
# pdir = os.path.dirname(prefix)
|
348 |
+
# for fname in os.listdir(pdir):
|
349 |
+
# if not fname.endswith('.params'):
|
350 |
+
# continue
|
351 |
+
# _file = os.path.join(pdir, fname)
|
352 |
+
# if _file.startswith(prefix):
|
353 |
+
# epoch = int(fname.split('.')[0].split('-')[1])
|
354 |
+
# epochs.append(epoch)
|
355 |
+
# epochs = sorted(epochs, reverse=True)
|
356 |
+
# if len(args.max) > 0:
|
357 |
+
# _max = [int(x) for x in args.max.split(',')]
|
358 |
+
# assert len(_max) == 2
|
359 |
+
# if len(epochs) > _max[1]:
|
360 |
+
# epochs = epochs[_max[0]:_max[1]]
|
361 |
+
#
|
362 |
+
# else:
|
363 |
+
# epochs = [int(x) for x in vec[1].split('|')]
|
364 |
+
# print('model number', len(epochs))
|
365 |
+
# time0 = datetime.datetime.now()
|
366 |
+
# for epoch in epochs:
|
367 |
+
# print('loading', prefix, epoch)
|
368 |
+
# sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
|
369 |
+
# # arg_params, aux_params = ch_dev(arg_params, aux_params, ctx)
|
370 |
+
# all_layers = sym.get_internals()
|
371 |
+
# sym = all_layers['fc1_output']
|
372 |
+
# model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
|
373 |
+
# # model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))])
|
374 |
+
# model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0],
|
375 |
+
# image_size[1]))])
|
376 |
+
# model.set_params(arg_params, aux_params)
|
377 |
+
# nets.append(model)
|
378 |
+
# time_now = datetime.datetime.now()
|
379 |
+
# diff = time_now - time0
|
380 |
+
# print('model loading time', diff.total_seconds())
|
381 |
+
#
|
382 |
+
# ver_list = []
|
383 |
+
# ver_name_list = []
|
384 |
+
# for name in args.target.split(','):
|
385 |
+
# path = os.path.join(args.data_dir, name + ".bin")
|
386 |
+
# if os.path.exists(path):
|
387 |
+
# print('loading.. ', name)
|
388 |
+
# data_set = load_bin(path, image_size)
|
389 |
+
# ver_list.append(data_set)
|
390 |
+
# ver_name_list.append(name)
|
391 |
+
#
|
392 |
+
# if args.mode == 0:
|
393 |
+
# for i in range(len(ver_list)):
|
394 |
+
# results = []
|
395 |
+
# for model in nets:
|
396 |
+
# acc1, std1, acc2, std2, xnorm, embeddings_list = test(
|
397 |
+
# ver_list[i], model, args.batch_size, args.nfolds)
|
398 |
+
# print('[%s]XNorm: %f' % (ver_name_list[i], xnorm))
|
399 |
+
# print('[%s]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], acc1, std1))
|
400 |
+
# print('[%s]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], acc2, std2))
|
401 |
+
# results.append(acc2)
|
402 |
+
# print('Max of [%s] is %1.5f' % (ver_name_list[i], np.max(results)))
|
403 |
+
# elif args.mode == 1:
|
404 |
+
# raise ValueError
|
405 |
+
# else:
|
406 |
+
# model = nets[0]
|
407 |
+
# dumpR(ver_list[0], model, args.batch_size, args.target)
|
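Taken together, `load_bin` and `test` above form the LFW-style verification loop: `load_bin` decodes a packed `.bin` of image pairs (plus flipped copies) and `test` runs a backbone over them and reports 10-fold accuracy. A minimal sketch of driving them directly; the paths are placeholders and the `backbones.get_model` factory is assumed from this package:

```
# Sketch: run LFW-style verification with the helpers defined above.
# "faces_emore/lfw.bin" and "backbone.pth" are placeholder paths.
import torch

from backbones import get_model
from eval import verification

backbone = get_model("r50", fp16=False)
backbone.load_state_dict(torch.load("backbone.pth", map_location="cpu"))
backbone.eval()

data_set = verification.load_bin("faces_emore/lfw.bin", image_size=[112, 112])
_, _, acc, std, xnorm, _ = verification.test(data_set, backbone, batch_size=64, nfolds=10)
print("accuracy: %1.5f+-%1.5f  xnorm: %.3f" % (acc, std, xnorm))
```
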
videoretalking/third_part/face3d/models/arcface_torch/eval_ijbc.py
ADDED
@@ -0,0 +1,483 @@
1 |
+
# coding: utf-8
|
2 |
+
|
3 |
+
import os
|
4 |
+
import pickle
|
5 |
+
|
6 |
+
import matplotlib
|
7 |
+
import pandas as pd
|
8 |
+
|
9 |
+
matplotlib.use('Agg')
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
import timeit
|
12 |
+
import sklearn
|
13 |
+
import argparse
|
14 |
+
import cv2
|
15 |
+
import numpy as np
|
16 |
+
import torch
|
17 |
+
from skimage import transform as trans
|
18 |
+
from backbones import get_model
|
19 |
+
from sklearn.metrics import roc_curve, auc
|
20 |
+
|
21 |
+
from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap
|
22 |
+
from prettytable import PrettyTable
|
23 |
+
from pathlib import Path
|
24 |
+
|
25 |
+
import sys
|
26 |
+
import warnings
|
27 |
+
|
28 |
+
sys.path.insert(0, "../")
|
29 |
+
warnings.filterwarnings("ignore")
|
30 |
+
|
31 |
+
parser = argparse.ArgumentParser(description='do ijb test')
|
32 |
+
# general
|
33 |
+
parser.add_argument('--model-prefix', default='', help='path to load model.')
|
34 |
+
parser.add_argument('--image-path', default='', type=str, help='')
|
35 |
+
parser.add_argument('--result-dir', default='.', type=str, help='')
|
36 |
+
parser.add_argument('--batch-size', default=128, type=int, help='')
|
37 |
+
parser.add_argument('--network', default='iresnet50', type=str, help='')
|
38 |
+
parser.add_argument('--job', default='insightface', type=str, help='job name')
|
39 |
+
parser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB')
|
40 |
+
args = parser.parse_args()
|
41 |
+
|
42 |
+
target = args.target
|
43 |
+
model_path = args.model_prefix
|
44 |
+
image_path = args.image_path
|
45 |
+
result_dir = args.result_dir
|
46 |
+
gpu_id = None
|
47 |
+
use_norm_score = True # if True, TestMode(N1)
|
48 |
+
use_detector_score = True # if True, TestMode(D1)
|
49 |
+
use_flip_test = True # if True, TestMode(F1)
|
50 |
+
job = args.job
|
51 |
+
batch_size = args.batch_size
|
52 |
+
|
53 |
+
|
54 |
+
class Embedding(object):
|
55 |
+
def __init__(self, prefix, data_shape, batch_size=1):
|
56 |
+
image_size = (112, 112)
|
57 |
+
self.image_size = image_size
|
58 |
+
weight = torch.load(prefix)
|
59 |
+
resnet = get_model(args.network, dropout=0, fp16=False).cuda()
|
60 |
+
resnet.load_state_dict(weight)
|
61 |
+
model = torch.nn.DataParallel(resnet)
|
62 |
+
self.model = model
|
63 |
+
self.model.eval()
|
64 |
+
src = np.array([
|
65 |
+
[30.2946, 51.6963],
|
66 |
+
[65.5318, 51.5014],
|
67 |
+
[48.0252, 71.7366],
|
68 |
+
[33.5493, 92.3655],
|
69 |
+
[62.7299, 92.2041]], dtype=np.float32)
|
70 |
+
src[:, 0] += 8.0
|
71 |
+
self.src = src
|
72 |
+
self.batch_size = batch_size
|
73 |
+
self.data_shape = data_shape
|
74 |
+
|
75 |
+
def get(self, rimg, landmark):
|
76 |
+
|
77 |
+
assert landmark.shape[0] == 68 or landmark.shape[0] == 5
|
78 |
+
assert landmark.shape[1] == 2
|
79 |
+
if landmark.shape[0] == 68:
|
80 |
+
landmark5 = np.zeros((5, 2), dtype=np.float32)
|
81 |
+
landmark5[0] = (landmark[36] + landmark[39]) / 2
|
82 |
+
landmark5[1] = (landmark[42] + landmark[45]) / 2
|
83 |
+
landmark5[2] = landmark[30]
|
84 |
+
landmark5[3] = landmark[48]
|
85 |
+
landmark5[4] = landmark[54]
|
86 |
+
else:
|
87 |
+
landmark5 = landmark
|
88 |
+
tform = trans.SimilarityTransform()
|
89 |
+
tform.estimate(landmark5, self.src)
|
90 |
+
M = tform.params[0:2, :]
|
91 |
+
img = cv2.warpAffine(rimg,
|
92 |
+
M, (self.image_size[1], self.image_size[0]),
|
93 |
+
borderValue=0.0)
|
94 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
95 |
+
img_flip = np.fliplr(img)
|
96 |
+
img = np.transpose(img, (2, 0, 1)) # 3*112*112, RGB
|
97 |
+
img_flip = np.transpose(img_flip, (2, 0, 1))
|
98 |
+
input_blob = np.zeros((2, 3, self.image_size[1], self.image_size[0]), dtype=np.uint8)
|
99 |
+
input_blob[0] = img
|
100 |
+
input_blob[1] = img_flip
|
101 |
+
return input_blob
|
102 |
+
|
103 |
+
@torch.no_grad()
|
104 |
+
def forward_db(self, batch_data):
|
105 |
+
imgs = torch.Tensor(batch_data).cuda()
|
106 |
+
imgs.div_(255).sub_(0.5).div_(0.5)
|
107 |
+
feat = self.model(imgs)
|
108 |
+
feat = feat.reshape([self.batch_size, 2 * feat.shape[1]])
|
109 |
+
return feat.cpu().numpy()
|
110 |
+
|
111 |
+
|
112 |
+
# Split a list into n parts as evenly as possible; len(result) == n, and if n exceeds the number of elements, the extra parts are empty lists []
|
113 |
+
def divideIntoNstrand(listTemp, n):
|
114 |
+
twoList = [[] for i in range(n)]
|
115 |
+
for i, e in enumerate(listTemp):
|
116 |
+
twoList[i % n].append(e)
|
117 |
+
return twoList
|
118 |
+
|
119 |
+
|
120 |
+
def read_template_media_list(path):
|
121 |
+
# ijb_meta = np.loadtxt(path, dtype=str)
|
122 |
+
ijb_meta = pd.read_csv(path, sep=' ', header=None).values
|
123 |
+
templates = ijb_meta[:, 1].astype(np.int)
|
124 |
+
medias = ijb_meta[:, 2].astype(np.int)
|
125 |
+
return templates, medias
|
126 |
+
|
127 |
+
|
128 |
+
# In[ ]:
|
129 |
+
|
130 |
+
|
131 |
+
def read_template_pair_list(path):
|
132 |
+
# pairs = np.loadtxt(path, dtype=str)
|
133 |
+
pairs = pd.read_csv(path, sep=' ', header=None).values
|
134 |
+
# print(pairs.shape)
|
135 |
+
# print(pairs[:, 0].astype(np.int))
|
136 |
+
t1 = pairs[:, 0].astype(np.int)
|
137 |
+
t2 = pairs[:, 1].astype(np.int)
|
138 |
+
label = pairs[:, 2].astype(np.int)
|
139 |
+
return t1, t2, label
|
140 |
+
|
141 |
+
|
142 |
+
# In[ ]:
|
143 |
+
|
144 |
+
|
145 |
+
def read_image_feature(path):
|
146 |
+
with open(path, 'rb') as fid:
|
147 |
+
img_feats = pickle.load(fid)
|
148 |
+
return img_feats
|
149 |
+
|
150 |
+
|
151 |
+
# In[ ]:
|
152 |
+
|
153 |
+
|
154 |
+
def get_image_feature(img_path, files_list, model_path, epoch, gpu_id):
|
155 |
+
batch_size = args.batch_size
|
156 |
+
data_shape = (3, 112, 112)
|
157 |
+
|
158 |
+
files = files_list
|
159 |
+
print('files:', len(files))
|
160 |
+
rare_size = len(files) % batch_size
|
161 |
+
faceness_scores = []
|
162 |
+
batch = 0
|
163 |
+
img_feats = np.empty((len(files), 1024), dtype=np.float32)
|
164 |
+
|
165 |
+
batch_data = np.empty((2 * batch_size, 3, 112, 112))
|
166 |
+
embedding = Embedding(model_path, data_shape, batch_size)
|
167 |
+
for img_index, each_line in enumerate(files[:len(files) - rare_size]):
|
168 |
+
name_lmk_score = each_line.strip().split(' ')
|
169 |
+
img_name = os.path.join(img_path, name_lmk_score[0])
|
170 |
+
img = cv2.imread(img_name)
|
171 |
+
lmk = np.array([float(x) for x in name_lmk_score[1:-1]],
|
172 |
+
dtype=np.float32)
|
173 |
+
lmk = lmk.reshape((5, 2))
|
174 |
+
input_blob = embedding.get(img, lmk)
|
175 |
+
|
176 |
+
batch_data[2 * (img_index - batch * batch_size)][:] = input_blob[0]
|
177 |
+
batch_data[2 * (img_index - batch * batch_size) + 1][:] = input_blob[1]
|
178 |
+
if (img_index + 1) % batch_size == 0:
|
179 |
+
print('batch', batch)
|
180 |
+
img_feats[batch * batch_size:batch * batch_size +
|
181 |
+
batch_size][:] = embedding.forward_db(batch_data)
|
182 |
+
batch += 1
|
183 |
+
faceness_scores.append(name_lmk_score[-1])
|
184 |
+
|
185 |
+
batch_data = np.empty((2 * rare_size, 3, 112, 112))
|
186 |
+
embedding = Embedding(model_path, data_shape, rare_size)
|
187 |
+
for img_index, each_line in enumerate(files[len(files) - rare_size:]):
|
188 |
+
name_lmk_score = each_line.strip().split(' ')
|
189 |
+
img_name = os.path.join(img_path, name_lmk_score[0])
|
190 |
+
img = cv2.imread(img_name)
|
191 |
+
lmk = np.array([float(x) for x in name_lmk_score[1:-1]],
|
192 |
+
dtype=np.float32)
|
193 |
+
lmk = lmk.reshape((5, 2))
|
194 |
+
input_blob = embedding.get(img, lmk)
|
195 |
+
batch_data[2 * img_index][:] = input_blob[0]
|
196 |
+
batch_data[2 * img_index + 1][:] = input_blob[1]
|
197 |
+
if (img_index + 1) % rare_size == 0:
|
198 |
+
print('batch', batch)
|
199 |
+
img_feats[len(files) -
|
200 |
+
rare_size:][:] = embedding.forward_db(batch_data)
|
201 |
+
batch += 1
|
202 |
+
faceness_scores.append(name_lmk_score[-1])
|
203 |
+
faceness_scores = np.array(faceness_scores).astype(np.float32)
|
204 |
+
# img_feats = np.ones( (len(files), 1024), dtype=np.float32) * 0.01
|
205 |
+
# faceness_scores = np.ones( (len(files), ), dtype=np.float32 )
|
206 |
+
return img_feats, faceness_scores
|
207 |
+
|
208 |
+
|
209 |
+
# In[ ]:
|
210 |
+
|
211 |
+
|
212 |
+
def image2template_feature(img_feats=None, templates=None, medias=None):
|
213 |
+
# ==========================================================
|
214 |
+
# 1. face image feature l2 normalization. img_feats:[number_image x feats_dim]
|
215 |
+
# 2. compute media feature.
|
216 |
+
# 3. compute template feature.
|
217 |
+
# ==========================================================
|
218 |
+
unique_templates = np.unique(templates)
|
219 |
+
template_feats = np.zeros((len(unique_templates), img_feats.shape[1]))
|
220 |
+
|
221 |
+
for count_template, uqt in enumerate(unique_templates):
|
222 |
+
|
223 |
+
(ind_t,) = np.where(templates == uqt)
|
224 |
+
face_norm_feats = img_feats[ind_t]
|
225 |
+
face_medias = medias[ind_t]
|
226 |
+
unique_medias, unique_media_counts = np.unique(face_medias,
|
227 |
+
return_counts=True)
|
228 |
+
media_norm_feats = []
|
229 |
+
for u, ct in zip(unique_medias, unique_media_counts):
|
230 |
+
(ind_m,) = np.where(face_medias == u)
|
231 |
+
if ct == 1:
|
232 |
+
media_norm_feats += [face_norm_feats[ind_m]]
|
233 |
+
else: # image features from the same video will be aggregated into one feature
|
234 |
+
media_norm_feats += [
|
235 |
+
np.mean(face_norm_feats[ind_m], axis=0, keepdims=True)
|
236 |
+
]
|
237 |
+
media_norm_feats = np.array(media_norm_feats)
|
238 |
+
# media_norm_feats = media_norm_feats / np.sqrt(np.sum(media_norm_feats ** 2, -1, keepdims=True))
|
239 |
+
template_feats[count_template] = np.sum(media_norm_feats, axis=0)
|
240 |
+
if count_template % 2000 == 0:
|
241 |
+
print('Finish Calculating {} template features.'.format(
|
242 |
+
count_template))
|
243 |
+
# template_norm_feats = template_feats / np.sqrt(np.sum(template_feats ** 2, -1, keepdims=True))
|
244 |
+
template_norm_feats = sklearn.preprocessing.normalize(template_feats)
|
245 |
+
# print(template_norm_feats.shape)
|
246 |
+
return template_norm_feats, unique_templates
|
247 |
+
|
248 |
+
|
249 |
+
# In[ ]:
|
250 |
+
|
251 |
+
|
252 |
+
def verification(template_norm_feats=None,
|
253 |
+
unique_templates=None,
|
254 |
+
p1=None,
|
255 |
+
p2=None):
|
256 |
+
# ==========================================================
|
257 |
+
# Compute set-to-set Similarity Score.
|
258 |
+
# ==========================================================
|
259 |
+
template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
|
260 |
+
for count_template, uqt in enumerate(unique_templates):
|
261 |
+
template2id[uqt] = count_template
|
262 |
+
|
263 |
+
score = np.zeros((len(p1),)) # save cosine distance between pairs
|
264 |
+
|
265 |
+
total_pairs = np.array(range(len(p1)))
|
266 |
+
batchsize = 100000 # small batch size instead of all pairs in one batch due to the memory limitation
|
267 |
+
sublists = [
|
268 |
+
total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)
|
269 |
+
]
|
270 |
+
total_sublists = len(sublists)
|
271 |
+
for c, s in enumerate(sublists):
|
272 |
+
feat1 = template_norm_feats[template2id[p1[s]]]
|
273 |
+
feat2 = template_norm_feats[template2id[p2[s]]]
|
274 |
+
similarity_score = np.sum(feat1 * feat2, -1)
|
275 |
+
score[s] = similarity_score.flatten()
|
276 |
+
if c % 10 == 0:
|
277 |
+
print('Finish {}/{} pairs.'.format(c, total_sublists))
|
278 |
+
return score
|
279 |
+
|
280 |
+
|
281 |
+
# In[ ]:
|
282 |
+
def verification2(template_norm_feats=None,
|
283 |
+
unique_templates=None,
|
284 |
+
p1=None,
|
285 |
+
p2=None):
|
286 |
+
template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
|
287 |
+
for count_template, uqt in enumerate(unique_templates):
|
288 |
+
template2id[uqt] = count_template
|
289 |
+
score = np.zeros((len(p1),)) # save cosine distance between pairs
|
290 |
+
total_pairs = np.array(range(len(p1)))
|
291 |
+
batchsize = 100000 # small batch size instead of all pairs in one batch due to the memory limitation
|
292 |
+
sublists = [
|
293 |
+
total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)
|
294 |
+
]
|
295 |
+
total_sublists = len(sublists)
|
296 |
+
for c, s in enumerate(sublists):
|
297 |
+
feat1 = template_norm_feats[template2id[p1[s]]]
|
298 |
+
feat2 = template_norm_feats[template2id[p2[s]]]
|
299 |
+
similarity_score = np.sum(feat1 * feat2, -1)
|
300 |
+
score[s] = similarity_score.flatten()
|
301 |
+
if c % 10 == 0:
|
302 |
+
print('Finish {}/{} pairs.'.format(c, total_sublists))
|
303 |
+
return score
|
304 |
+
|
305 |
+
|
306 |
+
def read_score(path):
|
307 |
+
with open(path, 'rb') as fid:
|
308 |
+
img_feats = pickle.load(fid)
|
309 |
+
return img_feats
|
310 |
+
|
311 |
+
|
312 |
+
# # Step1: Load Meta Data
|
313 |
+
|
314 |
+
# In[ ]:
|
315 |
+
|
316 |
+
assert target == 'IJBC' or target == 'IJBB'
|
317 |
+
|
318 |
+
# =============================================================
|
319 |
+
# load image and template relationships for template feature embedding
|
320 |
+
# tid --> template id, mid --> media id
|
321 |
+
# format:
|
322 |
+
# image_name tid mid
|
323 |
+
# =============================================================
|
324 |
+
start = timeit.default_timer()
|
325 |
+
templates, medias = read_template_media_list(
|
326 |
+
os.path.join('%s/meta' % image_path,
|
327 |
+
'%s_face_tid_mid.txt' % target.lower()))
|
328 |
+
stop = timeit.default_timer()
|
329 |
+
print('Time: %.2f s. ' % (stop - start))
|
330 |
+
|
331 |
+
# In[ ]:
|
332 |
+
|
333 |
+
# =============================================================
|
334 |
+
# load template pairs for template-to-template verification
|
335 |
+
# tid : template id, label : 1/0
|
336 |
+
# format:
|
337 |
+
# tid_1 tid_2 label
|
338 |
+
# =============================================================
|
339 |
+
start = timeit.default_timer()
|
340 |
+
p1, p2, label = read_template_pair_list(
|
341 |
+
os.path.join('%s/meta' % image_path,
|
342 |
+
'%s_template_pair_label.txt' % target.lower()))
|
343 |
+
stop = timeit.default_timer()
|
344 |
+
print('Time: %.2f s. ' % (stop - start))
|
345 |
+
|
346 |
+
# # Step 2: Get Image Features
|
347 |
+
|
348 |
+
# In[ ]:
|
349 |
+
|
350 |
+
# =============================================================
|
351 |
+
# load image features
|
352 |
+
# format:
|
353 |
+
# img_feats: [image_num x feats_dim] (227630, 512)
|
354 |
+
# =============================================================
|
355 |
+
start = timeit.default_timer()
|
356 |
+
img_path = '%s/loose_crop' % image_path
|
357 |
+
img_list_path = '%s/meta/%s_name_5pts_score.txt' % (image_path, target.lower())
|
358 |
+
img_list = open(img_list_path)
|
359 |
+
files = img_list.readlines()
|
360 |
+
# files_list = divideIntoNstrand(files, rank_size)
|
361 |
+
files_list = files
|
362 |
+
|
363 |
+
# img_feats
|
364 |
+
# for i in range(rank_size):
|
365 |
+
img_feats, faceness_scores = get_image_feature(img_path, files_list,
|
366 |
+
model_path, 0, gpu_id)
|
367 |
+
stop = timeit.default_timer()
|
368 |
+
print('Time: %.2f s. ' % (stop - start))
|
369 |
+
print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0],
|
370 |
+
img_feats.shape[1]))
|
371 |
+
|
372 |
+
# # Step3: Get Template Features
|
373 |
+
|
374 |
+
# In[ ]:
|
375 |
+
|
376 |
+
# =============================================================
|
377 |
+
# compute template features from image features.
|
378 |
+
# =============================================================
|
379 |
+
start = timeit.default_timer()
|
380 |
+
# ==========================================================
|
381 |
+
# Norm feature before aggregation into template feature?
|
382 |
+
# Feature norm from embedding network and faceness score are able to decrease weights for noise samples (not face).
|
383 |
+
# ==========================================================
|
384 |
+
# 1. FaceScore (Feature Norm)
|
385 |
+
# 2. FaceScore (Detector)
|
386 |
+
|
387 |
+
if use_flip_test:
|
388 |
+
# concat --- F1
|
389 |
+
# img_input_feats = img_feats
|
390 |
+
# add --- F2
|
391 |
+
img_input_feats = img_feats[:, 0:img_feats.shape[1] //
|
392 |
+
2] + img_feats[:, img_feats.shape[1] // 2:]
|
393 |
+
else:
|
394 |
+
img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2]
|
395 |
+
|
396 |
+
if use_norm_score:
|
397 |
+
img_input_feats = img_input_feats
|
398 |
+
else:
|
399 |
+
# normalise features to remove norm information
|
400 |
+
img_input_feats = img_input_feats / np.sqrt(
|
401 |
+
np.sum(img_input_feats ** 2, -1, keepdims=True))
|
402 |
+
|
403 |
+
if use_detector_score:
|
404 |
+
print(img_input_feats.shape, faceness_scores.shape)
|
405 |
+
img_input_feats = img_input_feats * faceness_scores[:, np.newaxis]
|
406 |
+
else:
|
407 |
+
img_input_feats = img_input_feats
|
408 |
+
|
409 |
+
template_norm_feats, unique_templates = image2template_feature(
|
410 |
+
img_input_feats, templates, medias)
|
411 |
+
stop = timeit.default_timer()
|
412 |
+
print('Time: %.2f s. ' % (stop - start))
|
413 |
+
|
414 |
+
# # Step 4: Get Template Similarity Scores
|
415 |
+
|
416 |
+
# In[ ]:
|
417 |
+
|
418 |
+
# =============================================================
|
419 |
+
# compute verification scores between template pairs.
|
420 |
+
# =============================================================
|
421 |
+
start = timeit.default_timer()
|
422 |
+
score = verification(template_norm_feats, unique_templates, p1, p2)
|
423 |
+
stop = timeit.default_timer()
|
424 |
+
print('Time: %.2f s. ' % (stop - start))
|
425 |
+
|
426 |
+
# In[ ]:
|
427 |
+
save_path = os.path.join(result_dir, args.job)
|
428 |
+
# save_path = result_dir + '/%s_result' % target
|
429 |
+
|
430 |
+
if not os.path.exists(save_path):
|
431 |
+
os.makedirs(save_path)
|
432 |
+
|
433 |
+
score_save_file = os.path.join(save_path, "%s.npy" % target.lower())
|
434 |
+
np.save(score_save_file, score)
|
435 |
+
|
436 |
+
# # Step 5: Get ROC Curves and TPR@FPR Table
|
437 |
+
|
438 |
+
# In[ ]:
|
439 |
+
|
440 |
+
files = [score_save_file]
|
441 |
+
methods = []
|
442 |
+
scores = []
|
443 |
+
for file in files:
|
444 |
+
methods.append(Path(file).stem)
|
445 |
+
scores.append(np.load(file))
|
446 |
+
|
447 |
+
methods = np.array(methods)
|
448 |
+
scores = dict(zip(methods, scores))
|
449 |
+
colours = dict(
|
450 |
+
zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2')))
|
451 |
+
x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1]
|
452 |
+
tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels])
|
453 |
+
fig = plt.figure()
|
454 |
+
for method in methods:
|
455 |
+
fpr, tpr, _ = roc_curve(label, scores[method])
|
456 |
+
roc_auc = auc(fpr, tpr)
|
457 |
+
fpr = np.flipud(fpr)
|
458 |
+
tpr = np.flipud(tpr) # select largest tpr at same fpr
|
459 |
+
plt.plot(fpr,
|
460 |
+
tpr,
|
461 |
+
color=colours[method],
|
462 |
+
lw=1,
|
463 |
+
label=('[%s (AUC = %0.4f %%)]' %
|
464 |
+
(method.split('-')[-1], roc_auc * 100)))
|
465 |
+
tpr_fpr_row = []
|
466 |
+
tpr_fpr_row.append("%s-%s" % (method, target))
|
467 |
+
for fpr_iter in np.arange(len(x_labels)):
|
468 |
+
_, min_index = min(
|
469 |
+
list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
|
470 |
+
tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))
|
471 |
+
tpr_fpr_table.add_row(tpr_fpr_row)
|
472 |
+
plt.xlim([10 ** -6, 0.1])
|
473 |
+
plt.ylim([0.3, 1.0])
|
474 |
+
plt.grid(linestyle='--', linewidth=1)
|
475 |
+
plt.xticks(x_labels)
|
476 |
+
plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))
|
477 |
+
plt.xscale('log')
|
478 |
+
plt.xlabel('False Positive Rate')
|
479 |
+
plt.ylabel('True Positive Rate')
|
480 |
+
plt.title('ROC on IJB')
|
481 |
+
plt.legend(loc="lower right")
|
482 |
+
fig.savefig(os.path.join(save_path, '%s.pdf' % target.lower()))
|
483 |
+
print(tpr_fpr_table)
|
videoretalking/third_part/face3d/models/arcface_torch/inference.py
ADDED
@@ -0,0 +1,35 @@
import argparse

import cv2
import numpy as np
import torch

from backbones import get_model


@torch.no_grad()
def inference(weight, name, img):
    if img is None:
        img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.uint8)
    else:
        img = cv2.imread(img)
        img = cv2.resize(img, (112, 112))

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = np.transpose(img, (2, 0, 1))
    img = torch.from_numpy(img).unsqueeze(0).float()
    img.div_(255).sub_(0.5).div_(0.5)
    net = get_model(name, fp16=False)
    net.load_state_dict(torch.load(weight))
    net.eval()
    feat = net(img).numpy()
    print(feat)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='PyTorch ArcFace Training')
    parser.add_argument('--network', type=str, default='r50', help='backbone network')
    parser.add_argument('--weight', type=str, default='')
    parser.add_argument('--img', type=str, default=None)
    args = parser.parse_args()
    inference(args.weight, args.network, args.img)
videoretalking/third_part/face3d/models/arcface_torch/losses.py
ADDED
@@ -0,0 +1,42 @@
import torch
from torch import nn


def get_loss(name):
    if name == "cosface":
        return CosFace()
    elif name == "arcface":
        return ArcFace()
    else:
        raise ValueError()


class CosFace(nn.Module):
    def __init__(self, s=64.0, m=0.40):
        super(CosFace, self).__init__()
        self.s = s
        self.m = m

    def forward(self, cosine, label):
        index = torch.where(label != -1)[0]
        m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device)
        m_hot.scatter_(1, label[index, None], self.m)
        cosine[index] -= m_hot
        ret = cosine * self.s
        return ret


class ArcFace(nn.Module):
    def __init__(self, s=64.0, m=0.5):
        super(ArcFace, self).__init__()
        self.s = s
        self.m = m

    def forward(self, cosine: torch.Tensor, label):
        index = torch.where(label != -1)[0]
        m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device)
        m_hot.scatter_(1, label[index, None], self.m)
        cosine.acos_()
        cosine[index] += m_hot
        cosine.cos_().mul_(self.s)
        return cosine
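Both margin heads operate on cosine logits (the output of a weight- and feature-normalized linear layer) together with integer class labels, where a label of -1 marks rows whose positive class is not in this block of logits (the `label != -1` mask above). A minimal, self-contained sketch on dummy data, assuming `losses.py` is importable from the working directory:

```
# Sketch: exercise the ArcFace margin head above on dummy cosine logits.
import torch
import torch.nn.functional as F

from losses import ArcFace  # assumes losses.py above is on the path

margin = ArcFace(s=64.0, m=0.5)
cosine = torch.randn(4, 10).clamp(-0.999, 0.999)  # cosine logits must stay in [-1, 1] for acos()
label = torch.tensor([3, 7, -1, 0])               # -1: positive class not in this logit block
logits = margin(cosine, label)                    # margin added to target angles, then scaled by s
mask = label != -1
loss = F.cross_entropy(logits[mask], label[mask])
print(loss.item())
```
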
videoretalking/third_part/face3d/models/arcface_torch/onnx_helper.py
ADDED
@@ -0,0 +1,250 @@
1 |
+
from __future__ import division
|
2 |
+
import datetime
|
3 |
+
import os
|
4 |
+
import os.path as osp
|
5 |
+
import glob
|
6 |
+
import numpy as np
|
7 |
+
import cv2
|
8 |
+
import sys
|
9 |
+
import onnxruntime
|
10 |
+
import onnx
|
11 |
+
import argparse
|
12 |
+
from onnx import numpy_helper
|
13 |
+
from insightface.data import get_image
|
14 |
+
|
15 |
+
class ArcFaceORT:
|
16 |
+
def __init__(self, model_path, cpu=False):
|
17 |
+
self.model_path = model_path
|
18 |
+
# providers = None will use available provider, for onnxruntime-gpu it will be "CUDAExecutionProvider"
|
19 |
+
self.providers = ['CPUExecutionProvider'] if cpu else None
|
20 |
+
|
21 |
+
#input_size is (w,h), return error message, return None if success
|
22 |
+
def check(self, track='cfat', test_img = None):
|
23 |
+
#default is cfat
|
24 |
+
max_model_size_mb=1024
|
25 |
+
max_feat_dim=512
|
26 |
+
max_time_cost=15
|
27 |
+
if track.startswith('ms1m'):
|
28 |
+
max_model_size_mb=1024
|
29 |
+
max_feat_dim=512
|
30 |
+
max_time_cost=10
|
31 |
+
elif track.startswith('glint'):
|
32 |
+
max_model_size_mb=1024
|
33 |
+
max_feat_dim=1024
|
34 |
+
max_time_cost=20
|
35 |
+
elif track.startswith('cfat'):
|
36 |
+
max_model_size_mb = 1024
|
37 |
+
max_feat_dim = 512
|
38 |
+
max_time_cost = 15
|
39 |
+
elif track.startswith('unconstrained'):
|
40 |
+
max_model_size_mb=1024
|
41 |
+
max_feat_dim=1024
|
42 |
+
max_time_cost=30
|
43 |
+
else:
|
44 |
+
return "track not found"
|
45 |
+
|
46 |
+
if not os.path.exists(self.model_path):
|
47 |
+
return "model_path not exists"
|
48 |
+
if not os.path.isdir(self.model_path):
|
49 |
+
return "model_path should be directory"
|
50 |
+
onnx_files = []
|
51 |
+
for _file in os.listdir(self.model_path):
|
52 |
+
if _file.endswith('.onnx'):
|
53 |
+
onnx_files.append(osp.join(self.model_path, _file))
|
54 |
+
if len(onnx_files)==0:
|
55 |
+
return "do not have onnx files"
|
56 |
+
self.model_file = sorted(onnx_files)[-1]
|
57 |
+
print('use onnx-model:', self.model_file)
|
58 |
+
try:
|
59 |
+
session = onnxruntime.InferenceSession(self.model_file, providers=self.providers)
|
60 |
+
except:
|
61 |
+
return "load onnx failed"
|
62 |
+
input_cfg = session.get_inputs()[0]
|
63 |
+
input_shape = input_cfg.shape
|
64 |
+
print('input-shape:', input_shape)
|
65 |
+
if len(input_shape)!=4:
|
66 |
+
return "length of input_shape should be 4"
|
67 |
+
if not isinstance(input_shape[0], str):
|
68 |
+
#return "input_shape[0] should be str to support batch-inference"
|
69 |
+
print('reset input-shape[0] to None')
|
70 |
+
model = onnx.load(self.model_file)
|
71 |
+
model.graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None'
|
72 |
+
new_model_file = osp.join(self.model_path, 'zzzzrefined.onnx')
|
73 |
+
onnx.save(model, new_model_file)
|
74 |
+
self.model_file = new_model_file
|
75 |
+
print('use new onnx-model:', self.model_file)
|
76 |
+
try:
|
77 |
+
session = onnxruntime.InferenceSession(self.model_file, providers=self.providers)
|
78 |
+
except:
|
79 |
+
return "load onnx failed"
|
80 |
+
input_cfg = session.get_inputs()[0]
|
81 |
+
input_shape = input_cfg.shape
|
82 |
+
print('new-input-shape:', input_shape)
|
83 |
+
|
84 |
+
self.image_size = tuple(input_shape[2:4][::-1])
|
85 |
+
#print('image_size:', self.image_size)
|
86 |
+
input_name = input_cfg.name
|
87 |
+
outputs = session.get_outputs()
|
88 |
+
output_names = []
|
89 |
+
for o in outputs:
|
90 |
+
output_names.append(o.name)
|
91 |
+
#print(o.name, o.shape)
|
92 |
+
if len(output_names)!=1:
|
93 |
+
return "number of output nodes should be 1"
|
94 |
+
self.session = session
|
95 |
+
self.input_name = input_name
|
96 |
+
self.output_names = output_names
|
97 |
+
#print(self.output_names)
|
98 |
+
model = onnx.load(self.model_file)
|
99 |
+
graph = model.graph
|
100 |
+
if len(graph.node)<8:
|
101 |
+
return "too small onnx graph"
|
102 |
+
|
103 |
+
input_size = (112,112)
|
104 |
+
self.crop = None
|
105 |
+
if track=='cfat':
|
106 |
+
crop_file = osp.join(self.model_path, 'crop.txt')
|
107 |
+
if osp.exists(crop_file):
|
108 |
+
lines = open(crop_file,'r').readlines()
|
109 |
+
if len(lines)!=6:
|
110 |
+
return "crop.txt should contain 6 lines"
|
111 |
+
lines = [int(x) for x in lines]
|
112 |
+
self.crop = lines[:4]
|
113 |
+
input_size = tuple(lines[4:6])
|
114 |
+
if input_size!=self.image_size:
|
115 |
+
return "input-size is inconsistant with onnx model input, %s vs %s"%(input_size, self.image_size)
|
116 |
+
|
117 |
+
self.model_size_mb = os.path.getsize(self.model_file) / float(1024*1024)
|
118 |
+
if self.model_size_mb > max_model_size_mb:
|
119 |
+
return "max model size exceed, given %.3f-MB"%self.model_size_mb
|
120 |
+
|
121 |
+
input_mean = None
|
122 |
+
input_std = None
|
123 |
+
if track=='cfat':
|
124 |
+
pn_file = osp.join(self.model_path, 'pixel_norm.txt')
|
125 |
+
if osp.exists(pn_file):
|
126 |
+
lines = open(pn_file,'r').readlines()
|
127 |
+
if len(lines)!=2:
|
128 |
+
return "pixel_norm.txt should contain 2 lines"
|
129 |
+
input_mean = float(lines[0])
|
130 |
+
input_std = float(lines[1])
|
131 |
+
if input_mean is not None or input_std is not None:
|
132 |
+
if input_mean is None or input_std is None:
|
133 |
+
return "please set input_mean and input_std simultaneously"
|
134 |
+
else:
|
135 |
+
find_sub = False
|
136 |
+
find_mul = False
|
137 |
+
for nid, node in enumerate(graph.node[:8]):
|
138 |
+
print(nid, node.name)
|
139 |
+
if node.name.startswith('Sub') or node.name.startswith('_minus'):
|
140 |
+
find_sub = True
|
141 |
+
if node.name.startswith('Mul') or node.name.startswith('_mul') or node.name.startswith('Div'):
|
142 |
+
find_mul = True
|
143 |
+
if find_sub and find_mul:
|
144 |
+
print("find sub and mul")
|
145 |
+
#mxnet arcface model
|
146 |
+
input_mean = 0.0
|
147 |
+
input_std = 1.0
|
148 |
+
else:
|
149 |
+
input_mean = 127.5
|
150 |
+
input_std = 127.5
|
151 |
+
self.input_mean = input_mean
|
152 |
+
self.input_std = input_std
|
153 |
+
for initn in graph.initializer:
|
154 |
+
weight_array = numpy_helper.to_array(initn)
|
155 |
+
dt = weight_array.dtype
|
156 |
+
if dt.itemsize<4:
|
157 |
+
return 'invalid weight type - (%s:%s)' % (initn.name, dt.name)
|
158 |
+
if test_img is None:
|
159 |
+
test_img = get_image('Tom_Hanks_54745')
|
160 |
+
test_img = cv2.resize(test_img, self.image_size)
|
161 |
+
else:
|
162 |
+
test_img = cv2.resize(test_img, self.image_size)
|
163 |
+
feat, cost = self.benchmark(test_img)
|
164 |
+
batch_result = self.check_batch(test_img)
|
165 |
+
batch_result_sum = float(np.sum(batch_result))
|
166 |
+
if batch_result_sum in [float('inf'), -float('inf')] or batch_result_sum != batch_result_sum:
|
167 |
+
print(batch_result)
|
168 |
+
print(batch_result_sum)
|
169 |
+
return "batch result output contains NaN!"
|
170 |
+
|
171 |
+
if len(feat.shape) < 2:
|
172 |
+
return "the shape of the feature must be two, but get {}".format(str(feat.shape))
|
173 |
+
|
174 |
+
if feat.shape[1] > max_feat_dim:
|
175 |
+
return "max feat dim exceed, given %d"%feat.shape[1]
|
176 |
+
self.feat_dim = feat.shape[1]
|
177 |
+
cost_ms = cost*1000
|
178 |
+
if cost_ms>max_time_cost:
|
179 |
+
return "max time cost exceed, given %.4f"%cost_ms
|
180 |
+
self.cost_ms = cost_ms
|
181 |
+
print('check stat:, model-size-mb: %.4f, feat-dim: %d, time-cost-ms: %.4f, input-mean: %.3f, input-std: %.3f'%(self.model_size_mb, self.feat_dim, self.cost_ms, self.input_mean, self.input_std))
|
182 |
+
return None
|
183 |
+
|
184 |
+
def check_batch(self, img):
|
185 |
+
if not isinstance(img, list):
|
186 |
+
imgs = [img, ] * 32
|
187 |
+
if self.crop is not None:
|
188 |
+
nimgs = []
|
189 |
+
for img in imgs:
|
190 |
+
nimg = img[self.crop[1]:self.crop[3], self.crop[0]:self.crop[2], :]
|
191 |
+
if nimg.shape[0] != self.image_size[1] or nimg.shape[1] != self.image_size[0]:
|
192 |
+
nimg = cv2.resize(nimg, self.image_size)
|
193 |
+
nimgs.append(nimg)
|
194 |
+
imgs = nimgs
|
195 |
+
blob = cv2.dnn.blobFromImages(
|
196 |
+
images=imgs, scalefactor=1.0 / self.input_std, size=self.image_size,
|
197 |
+
mean=(self.input_mean, self.input_mean, self.input_mean), swapRB=True)
|
198 |
+
net_out = self.session.run(self.output_names, {self.input_name: blob})[0]
|
199 |
+
return net_out
|
200 |
+
|
201 |
+
|
202 |
+
def meta_info(self):
|
203 |
+
return {'model-size-mb':self.model_size_mb, 'feature-dim':self.feat_dim, 'infer': self.cost_ms}
|
204 |
+
|
205 |
+
|
206 |
+
def forward(self, imgs):
|
207 |
+
if not isinstance(imgs, list):
|
208 |
+
imgs = [imgs]
|
209 |
+
input_size = self.image_size
|
210 |
+
if self.crop is not None:
|
211 |
+
nimgs = []
|
212 |
+
for img in imgs:
|
213 |
+
nimg = img[self.crop[1]:self.crop[3],self.crop[0]:self.crop[2],:]
|
214 |
+
if nimg.shape[0]!=input_size[1] or nimg.shape[1]!=input_size[0]:
|
215 |
+
nimg = cv2.resize(nimg, input_size)
|
216 |
+
nimgs.append(nimg)
|
217 |
+
imgs = nimgs
|
218 |
+
blob = cv2.dnn.blobFromImages(imgs, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
|
219 |
+
net_out = self.session.run(self.output_names, {self.input_name : blob})[0]
|
220 |
+
return net_out
|
221 |
+
|
222 |
+
def benchmark(self, img):
|
223 |
+
input_size = self.image_size
|
224 |
+
if self.crop is not None:
|
225 |
+
nimg = img[self.crop[1]:self.crop[3],self.crop[0]:self.crop[2],:]
|
226 |
+
if nimg.shape[0]!=input_size[1] or nimg.shape[1]!=input_size[0]:
|
227 |
+
nimg = cv2.resize(nimg, input_size)
|
228 |
+
img = nimg
|
229 |
+
blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
|
230 |
+
costs = []
|
231 |
+
for _ in range(50):
|
232 |
+
ta = datetime.datetime.now()
|
233 |
+
net_out = self.session.run(self.output_names, {self.input_name : blob})[0]
|
234 |
+
tb = datetime.datetime.now()
|
235 |
+
cost = (tb-ta).total_seconds()
|
236 |
+
costs.append(cost)
|
237 |
+
costs = sorted(costs)
|
238 |
+
cost = costs[5]
|
239 |
+
return net_out, cost
|
240 |
+
|
241 |
+
|
242 |
+
if __name__ == '__main__':
|
243 |
+
parser = argparse.ArgumentParser(description='')
|
244 |
+
# general
|
245 |
+
parser.add_argument('workdir', help='submitted work dir', type=str)
|
246 |
+
parser.add_argument('--track', help='track name, for different challenge', type=str, default='cfat')
|
247 |
+
args = parser.parse_args()
|
248 |
+
handler = ArcFaceORT(args.workdir)
|
249 |
+
err = handler.check(args.track)
|
250 |
+
print('err:', err)
|
videoretalking/third_part/face3d/models/arcface_torch/onnx_ijbc.py
ADDED
@@ -0,0 +1,267 @@
import argparse
import os
import pickle
import timeit

import cv2
import mxnet as mx
import numpy as np
import pandas as pd
import prettytable
import skimage.transform
from sklearn.metrics import roc_curve
from sklearn.preprocessing import normalize

from onnx_helper import ArcFaceORT

SRC = np.array(
    [
        [30.2946, 51.6963],
        [65.5318, 51.5014],
        [48.0252, 71.7366],
        [33.5493, 92.3655],
        [62.7299, 92.2041]]
    , dtype=np.float32)
SRC[:, 0] += 8.0


class AlignedDataSet(mx.gluon.data.Dataset):
    def __init__(self, root, lines, align=True):
        self.lines = lines
        self.root = root
        self.align = align

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        each_line = self.lines[idx]
        name_lmk_score = each_line.strip().split(' ')
        name = os.path.join(self.root, name_lmk_score[0])
        img = cv2.cvtColor(cv2.imread(name), cv2.COLOR_BGR2RGB)
        landmark5 = np.array([float(x) for x in name_lmk_score[1:-1]], dtype=np.float32).reshape((5, 2))
        st = skimage.transform.SimilarityTransform()
        st.estimate(landmark5, SRC)
        img = cv2.warpAffine(img, st.params[0:2, :], (112, 112), borderValue=0.0)
        img_1 = np.expand_dims(img, 0)
        img_2 = np.expand_dims(np.fliplr(img), 0)
        output = np.concatenate((img_1, img_2), axis=0).astype(np.float32)
        output = np.transpose(output, (0, 3, 1, 2))
        output = mx.nd.array(output)
        return output


def extract(model_root, dataset):
    model = ArcFaceORT(model_path=model_root)
    model.check()
    feat_mat = np.zeros(shape=(len(dataset), 2 * model.feat_dim))

    def batchify_fn(data):
        return mx.nd.concat(*data, dim=0)

    data_loader = mx.gluon.data.DataLoader(
        dataset, 128, last_batch='keep', num_workers=4,
        thread_pool=True, prefetch=16, batchify_fn=batchify_fn)
    num_iter = 0
    for batch in data_loader:
        batch = batch.asnumpy()
        batch = (batch - model.input_mean) / model.input_std
        feat = model.session.run(model.output_names, {model.input_name: batch})[0]
        feat = np.reshape(feat, (-1, model.feat_dim * 2))
        feat_mat[128 * num_iter: 128 * num_iter + feat.shape[0], :] = feat
        num_iter += 1
        if num_iter % 50 == 0:
            print(num_iter)
    return feat_mat


def read_template_media_list(path):
    ijb_meta = pd.read_csv(path, sep=' ', header=None).values
    templates = ijb_meta[:, 1].astype(np.int)
    medias = ijb_meta[:, 2].astype(np.int)
    return templates, medias


def read_template_pair_list(path):
    pairs = pd.read_csv(path, sep=' ', header=None).values
    t1 = pairs[:, 0].astype(np.int)
    t2 = pairs[:, 1].astype(np.int)
    label = pairs[:, 2].astype(np.int)
    return t1, t2, label


def read_image_feature(path):
    with open(path, 'rb') as fid:
        img_feats = pickle.load(fid)
    return img_feats


def image2template_feature(img_feats=None,
                           templates=None,
                           medias=None):
    unique_templates = np.unique(templates)
    template_feats = np.zeros((len(unique_templates), img_feats.shape[1]))
    for count_template, uqt in enumerate(unique_templates):
        (ind_t,) = np.where(templates == uqt)
        face_norm_feats = img_feats[ind_t]
        face_medias = medias[ind_t]
        unique_medias, unique_media_counts = np.unique(face_medias, return_counts=True)
        media_norm_feats = []
        for u, ct in zip(unique_medias, unique_media_counts):
            (ind_m,) = np.where(face_medias == u)
            if ct == 1:
                media_norm_feats += [face_norm_feats[ind_m]]
            else:  # image features from the same video will be aggregated into one feature
                media_norm_feats += [np.mean(face_norm_feats[ind_m], axis=0, keepdims=True), ]
        media_norm_feats = np.array(media_norm_feats)
        template_feats[count_template] = np.sum(media_norm_feats, axis=0)
        if count_template % 2000 == 0:
            print('Finish Calculating {} template features.'.format(
                count_template))
    template_norm_feats = normalize(template_feats)
    return template_norm_feats, unique_templates


def verification(template_norm_feats=None,
                 unique_templates=None,
                 p1=None,
                 p2=None):
    template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
    for count_template, uqt in enumerate(unique_templates):
        template2id[uqt] = count_template
    score = np.zeros((len(p1),))
    total_pairs = np.array(range(len(p1)))
    batchsize = 100000
    sublists = [total_pairs[i: i + batchsize] for i in range(0, len(p1), batchsize)]
    total_sublists = len(sublists)
    for c, s in enumerate(sublists):
        feat1 = template_norm_feats[template2id[p1[s]]]
        feat2 = template_norm_feats[template2id[p2[s]]]
        similarity_score = np.sum(feat1 * feat2, -1)
        score[s] = similarity_score.flatten()
        if c % 10 == 0:
            print('Finish {}/{} pairs.'.format(c, total_sublists))
    return score


def verification2(template_norm_feats=None,
                  unique_templates=None,
                  p1=None,
                  p2=None):
    template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
    for count_template, uqt in enumerate(unique_templates):
        template2id[uqt] = count_template
    score = np.zeros((len(p1),))  # save cosine distance between pairs
    total_pairs = np.array(range(len(p1)))
    batchsize = 100000  # small batchsize instead of all pairs in one batch due to the memory limitation
    sublists = [total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)]
    total_sublists = len(sublists)
    for c, s in enumerate(sublists):
        feat1 = template_norm_feats[template2id[p1[s]]]
        feat2 = template_norm_feats[template2id[p2[s]]]
        similarity_score = np.sum(feat1 * feat2, -1)
        score[s] = similarity_score.flatten()
        if c % 10 == 0:
            print('Finish {}/{} pairs.'.format(c, total_sublists))
    return score


def main(args):
    use_norm_score = True  # if True, TestMode(N1)
    use_detector_score = True  # if True, TestMode(D1)
    use_flip_test = True  # if True, TestMode(F1)
    assert args.target == 'IJBC' or args.target == 'IJBB'

    start = timeit.default_timer()
    templates, medias = read_template_media_list(
        os.path.join('%s/meta' % args.image_path, '%s_face_tid_mid.txt' % args.target.lower()))
    stop = timeit.default_timer()
    print('Time: %.2f s. ' % (stop - start))

    start = timeit.default_timer()
    p1, p2, label = read_template_pair_list(
        os.path.join('%s/meta' % args.image_path,
                     '%s_template_pair_label.txt' % args.target.lower()))
    stop = timeit.default_timer()
    print('Time: %.2f s. ' % (stop - start))

    start = timeit.default_timer()
    img_path = '%s/loose_crop' % args.image_path
    img_list_path = '%s/meta/%s_name_5pts_score.txt' % (args.image_path, args.target.lower())
    img_list = open(img_list_path)
    files = img_list.readlines()
    dataset = AlignedDataSet(root=img_path, lines=files, align=True)
    img_feats = extract(args.model_root, dataset)

    faceness_scores = []
    for each_line in files:
        name_lmk_score = each_line.split()
        faceness_scores.append(name_lmk_score[-1])
    faceness_scores = np.array(faceness_scores).astype(np.float32)
    stop = timeit.default_timer()
    print('Time: %.2f s. ' % (stop - start))
    print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0], img_feats.shape[1]))
    start = timeit.default_timer()

    if use_flip_test:
        img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2] + img_feats[:, img_feats.shape[1] // 2:]
    else:
        img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2]

    if use_norm_score:
        img_input_feats = img_input_feats
    else:
        img_input_feats = img_input_feats / np.sqrt(np.sum(img_input_feats ** 2, -1, keepdims=True))

    if use_detector_score:
        print(img_input_feats.shape, faceness_scores.shape)
        img_input_feats = img_input_feats * faceness_scores[:, np.newaxis]
    else:
        img_input_feats = img_input_feats

    template_norm_feats, unique_templates = image2template_feature(
        img_input_feats, templates, medias)
    stop = timeit.default_timer()
    print('Time: %.2f s. ' % (stop - start))

    start = timeit.default_timer()
    score = verification(template_norm_feats, unique_templates, p1, p2)
    stop = timeit.default_timer()
    print('Time: %.2f s. ' % (stop - start))
    save_path = os.path.join(args.result_dir, "{}_result".format(args.target))
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    score_save_file = os.path.join(save_path, "{}.npy".format(args.model_root))
    np.save(score_save_file, score)
    files = [score_save_file]
    methods = []
    scores = []
    for file in files:
        methods.append(os.path.basename(file))
        scores.append(np.load(file))
    methods = np.array(methods)
    scores = dict(zip(methods, scores))
    x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1]
    tpr_fpr_table = prettytable.PrettyTable(['Methods'] + [str(x) for x in x_labels])
    for method in methods:
        fpr, tpr, _ = roc_curve(label, scores[method])
        fpr = np.flipud(fpr)
        tpr = np.flipud(tpr)
        tpr_fpr_row = []
        tpr_fpr_row.append("%s-%s" % (method, args.target))
        for fpr_iter in np.arange(len(x_labels)):
            _, min_index = min(
                list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
            tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))
        tpr_fpr_table.add_row(tpr_fpr_row)
    print(tpr_fpr_table)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='do ijb test')
    # general
    parser.add_argument('--model-root', default='', help='path to load model.')
    parser.add_argument('--image-path', default='', type=str, help='')
    parser.add_argument('--result-dir', default='.', type=str, help='')
    parser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB')
    main(parser.parse_args())
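Note: to make the scoring path above concrete, the toy sketch below (synthetic features and made-up template/media IDs, not real IJB data) mirrors how image2template_feature pools per-media features into L2-normalized template vectors and how verification scores each template pair with a dot product.

# Toy example with synthetic data; shapes and IDs are illustrative only.
import numpy as np
from sklearn.preprocessing import normalize

img_feats = np.random.rand(6, 4).astype(np.float32)  # 6 images, 4-dim features
templates = np.array([0, 0, 0, 1, 1, 2])             # template id per image
medias = np.array([0, 0, 1, 2, 3, 4])                 # media id per image (frames of one video share an id)

# Pool per media, then sum per template, as image2template_feature does.
template_feats = []
for t in np.unique(templates):
    idx = np.where(templates == t)[0]
    per_media = [img_feats[idx[medias[idx] == m]].mean(axis=0) for m in np.unique(medias[idx])]
    template_feats.append(np.sum(per_media, axis=0))
template_norm_feats = normalize(np.array(template_feats))

# Pair score = dot product of the two normalized template vectors (cosine similarity).
p1, p2 = np.array([0, 1]), np.array([1, 2])
score = np.sum(template_norm_feats[p1] * template_norm_feats[p2], axis=-1)
print(score)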
videoretalking/third_part/face3d/models/arcface_torch/partial_fc.py
ADDED
@@ -0,0 +1,222 @@
import logging
import os

import torch
import torch.distributed as dist
from torch.nn import Module
from torch.nn.functional import normalize, linear
from torch.nn.parameter import Parameter


class PartialFC(Module):
    """
    Author: {Xiang An, Yang Xiao, XuHan Zhu} in DeepGlint,
    Partial FC: Training 10 Million Identities on a Single Machine
    See the original paper:
    https://arxiv.org/abs/2010.05222
    """

    @torch.no_grad()
    def __init__(self, rank, local_rank, world_size, batch_size, resume,
                 margin_softmax, num_classes, sample_rate=1.0, embedding_size=512, prefix="./"):
        """
        rank: int
            Unique process (GPU) ID, from 0 to world_size - 1.
        local_rank: int
            Unique process (GPU) ID within the server, from 0 to 7.
        world_size: int
            Number of GPUs.
        batch_size: int
            Batch size on the current rank (GPU).
        resume: bool
            Select whether to restore the softmax weight.
        margin_softmax: callable
            A margin softmax function, e.g. cosface, arcface.
        num_classes: int
            The number of class centers stored on the current rank (CPU/GPU), usually total_classes // world_size,
            required.
        sample_rate: float
            The partial fc sampling rate; when the number of classes grows beyond 2 million, sampling
            can greatly speed up training and save a lot of GPU memory. Default is 1.0.
        embedding_size: int
            The feature dimension, default is 512.
        prefix: str
            Path for saving checkpoints, default is './'.
        """
        super(PartialFC, self).__init__()
        #
        self.num_classes: int = num_classes
        self.rank: int = rank
        self.local_rank: int = local_rank
        self.device: torch.device = torch.device("cuda:{}".format(self.local_rank))
        self.world_size: int = world_size
        self.batch_size: int = batch_size
        self.margin_softmax: callable = margin_softmax
        self.sample_rate: float = sample_rate
        self.embedding_size: int = embedding_size
        self.prefix: str = prefix
        self.num_local: int = num_classes // world_size + int(rank < num_classes % world_size)
        self.class_start: int = num_classes // world_size * rank + min(rank, num_classes % world_size)
        self.num_sample: int = int(self.sample_rate * self.num_local)

        self.weight_name = os.path.join(self.prefix, "rank_{}_softmax_weight.pt".format(self.rank))
        self.weight_mom_name = os.path.join(self.prefix, "rank_{}_softmax_weight_mom.pt".format(self.rank))

        if resume:
            try:
                self.weight: torch.Tensor = torch.load(self.weight_name)
                self.weight_mom: torch.Tensor = torch.load(self.weight_mom_name)
                if self.weight.shape[0] != self.num_local or self.weight_mom.shape[0] != self.num_local:
                    raise IndexError
                logging.info("softmax weight resume successfully!")
                logging.info("softmax weight mom resume successfully!")
            except (FileNotFoundError, KeyError, IndexError):
                self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
                self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
                logging.info("softmax weight init!")
                logging.info("softmax weight mom init!")
        else:
            self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
            self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
            logging.info("softmax weight init successfully!")
            logging.info("softmax weight mom init successfully!")
        self.stream: torch.cuda.Stream = torch.cuda.Stream(local_rank)

        self.index = None
        if int(self.sample_rate) == 1:
            self.update = lambda: 0
            self.sub_weight = Parameter(self.weight)
            self.sub_weight_mom = self.weight_mom
        else:
            self.sub_weight = Parameter(torch.empty((0, 0)).cuda(local_rank))

    def save_params(self):
        """ Save the softmax weight for each rank under prefix
        """
        torch.save(self.weight.data, self.weight_name)
        torch.save(self.weight_mom, self.weight_mom_name)

    @torch.no_grad()
    def sample(self, total_label):
        """
        Sample all positive class centers on each rank, and randomly select negative class centers to fill a fixed
        `num_sample`.

        total_label: tensor
            Label after all_gather, across all GPUs.
        """
        index_positive = (self.class_start <= total_label) & (total_label < self.class_start + self.num_local)
        total_label[~index_positive] = -1
        total_label[index_positive] -= self.class_start
        if int(self.sample_rate) != 1:
            positive = torch.unique(total_label[index_positive], sorted=True)
            if self.num_sample - positive.size(0) >= 0:
                perm = torch.rand(size=[self.num_local], device=self.device)
                perm[positive] = 2.0
                index = torch.topk(perm, k=self.num_sample)[1]
                index = index.sort()[0]
            else:
                index = positive
            self.index = index
            total_label[index_positive] = torch.searchsorted(index, total_label[index_positive])
            self.sub_weight = Parameter(self.weight[index])
            self.sub_weight_mom = self.weight_mom[index]

    def forward(self, total_features, norm_weight):
        """ Partial fc forward, `logits = X * sample(W)`
        """
        torch.cuda.current_stream().wait_stream(self.stream)
        logits = linear(total_features, norm_weight)
        return logits

    @torch.no_grad()
    def update(self):
        """ Set the updated weight and weight_mom back into the memory bank.
        """
        self.weight_mom[self.index] = self.sub_weight_mom
        self.weight[self.index] = self.sub_weight

    def prepare(self, label, optimizer):
        """
        Get the sampled class centers for calculating softmax.

        label: tensor
            Label tensor on each rank.
        optimizer: opt
            Optimizer for partial fc, which needs access to the weight momentum.
        """
        with torch.cuda.stream(self.stream):
            total_label = torch.zeros(
                size=[self.batch_size * self.world_size], device=self.device, dtype=torch.long)
            dist.all_gather(list(total_label.chunk(self.world_size, dim=0)), label)
            self.sample(total_label)
            optimizer.state.pop(optimizer.param_groups[-1]['params'][0], None)
            optimizer.param_groups[-1]['params'][0] = self.sub_weight
            optimizer.state[self.sub_weight]['momentum_buffer'] = self.sub_weight_mom
            norm_weight = normalize(self.sub_weight)
            return total_label, norm_weight

    def forward_backward(self, label, features, optimizer):
        """
        Partial fc forward and backward with model parallelism

        label: tensor
            Label tensor on each rank (GPU)
        features: tensor
            Features tensor on each rank (GPU)
        optimizer: optimizer
            Optimizer for partial fc

        Returns:
        --------
        x_grad: tensor
            The gradient of features.
        loss_v: tensor
            Loss value of the cross entropy.
        """
        total_label, norm_weight = self.prepare(label, optimizer)
        total_features = torch.zeros(
            size=[self.batch_size * self.world_size, self.embedding_size], device=self.device)
        dist.all_gather(list(total_features.chunk(self.world_size, dim=0)), features.data)
        total_features.requires_grad = True

        logits = self.forward(total_features, norm_weight)
        logits = self.margin_softmax(logits, total_label)

        with torch.no_grad():
            max_fc = torch.max(logits, dim=1, keepdim=True)[0]
            dist.all_reduce(max_fc, dist.ReduceOp.MAX)

            # calculate exp(logits) and all-reduce
            logits_exp = torch.exp(logits - max_fc)
            logits_sum_exp = logits_exp.sum(dim=1, keepdims=True)
            dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)

            # calculate prob
            logits_exp.div_(logits_sum_exp)

            # get one-hot
            grad = logits_exp
            index = torch.where(total_label != -1)[0]
            one_hot = torch.zeros(size=[index.size()[0], grad.size()[1]], device=grad.device)
            one_hot.scatter_(1, total_label[index, None], 1)

            # calculate loss
            loss = torch.zeros(grad.size()[0], 1, device=grad.device)
            loss[index] = grad[index].gather(1, total_label[index, None])
            dist.all_reduce(loss, dist.ReduceOp.SUM)
            loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1)

            # calculate grad
            grad[index] -= one_hot
            grad.div_(self.batch_size * self.world_size)

        logits.backward(grad)
        if total_features.grad is not None:
            total_features.grad.detach_()
        x_grad: torch.Tensor = torch.zeros_like(features, requires_grad=True)
        # feature gradient all-reduce
        dist.reduce_scatter(x_grad, list(total_features.grad.chunk(self.world_size, dim=0)))
        x_grad = x_grad * self.world_size
        # backward backbone
        return x_grad, loss_v
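Note: as a worked example of the partitioning arithmetic in __init__ above (the numbers are illustrative and not taken from any shipped config), with num_classes=10, world_size=4 and sample_rate=0.5 each rank holds the following slice of class centers.

# Illustrative numbers only; reproduces the num_local / class_start / num_sample formulas above.
num_classes, world_size, sample_rate = 10, 4, 0.5
for rank in range(world_size):
    num_local = num_classes // world_size + int(rank < num_classes % world_size)
    class_start = num_classes // world_size * rank + min(rank, num_classes % world_size)
    num_sample = int(sample_rate * num_local)
    print(rank, num_local, class_start, num_sample)
# rank 0: 3 classes starting at 0, 1 sampled per step
# rank 1: 3 classes starting at 3, 1 sampled per step
# rank 2: 2 classes starting at 6, 1 sampled per step
# rank 3: 2 classes starting at 8, 1 sampled per step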
videoretalking/third_part/face3d/models/arcface_torch/requirement.txt
ADDED
@@ -0,0 +1,5 @@
tensorboard
easydict
mxnet
onnx
sklearn
videoretalking/third_part/face3d/models/arcface_torch/run.sh
ADDED
@@ -0,0 +1,2 @@
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50
ps -ef | grep "train" | grep -v grep | awk '{print "kill -9 "$2}' | sh
videoretalking/third_part/face3d/models/arcface_torch/torch2onnx.py
ADDED
@@ -0,0 +1,59 @@
import numpy as np
import onnx
import torch


def convert_onnx(net, path_module, output, opset=11, simplify=False):
    assert isinstance(net, torch.nn.Module)
    img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32)
    img = img.astype(np.float)
    img = (img / 255. - 0.5) / 0.5  # torch style norm
    img = img.transpose((2, 0, 1))
    img = torch.from_numpy(img).unsqueeze(0).float()

    weight = torch.load(path_module)
    net.load_state_dict(weight)
    net.eval()
    torch.onnx.export(net, img, output, keep_initializers_as_inputs=False, verbose=False, opset_version=opset)
    model = onnx.load(output)
    graph = model.graph
    graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None'
    if simplify:
        from onnxsim import simplify
        model, check = simplify(model)
        assert check, "Simplified ONNX model could not be validated"
    onnx.save(model, output)


if __name__ == '__main__':
    import os
    import argparse
    from backbones import get_model

    parser = argparse.ArgumentParser(description='ArcFace PyTorch to onnx')
    parser.add_argument('input', type=str, help='input backbone.pth file or path')
    parser.add_argument('--output', type=str, default=None, help='output onnx path')
    parser.add_argument('--network', type=str, default=None, help='backbone network')
    parser.add_argument('--simplify', type=bool, default=False, help='onnx simplify')
    args = parser.parse_args()
    input_file = args.input
    if os.path.isdir(input_file):
        input_file = os.path.join(input_file, "backbone.pth")
    assert os.path.exists(input_file)
    model_name = os.path.basename(os.path.dirname(input_file)).lower()
    params = model_name.split("_")
    if len(params) >= 3 and params[1] in ('arcface', 'cosface'):
        if args.network is None:
            args.network = params[2]
    assert args.network is not None
    print(args)
    backbone_onnx = get_model(args.network, dropout=0)

    output_path = args.output
    if output_path is None:
        output_path = os.path.join(os.path.dirname(__file__), 'onnx')
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    assert os.path.isdir(output_path)
    output_file = os.path.join(output_path, "%s.onnx" % model_name)
    convert_onnx(backbone_onnx, input_file, output_file, simplify=args.simplify)
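Note: a minimal programmatic use of convert_onnx, assuming an r50 checkpoint trained with this repo exists at the placeholder path below; the CLI entry point above instead infers the network name from the checkpoint's directory name.

# Sketch only; the checkpoint and output paths are placeholders.
from backbones import get_model
from torch2onnx import convert_onnx

net = get_model('r50', dropout=0)  # same constructor call the CLI uses
convert_onnx(net,
             'ms1mv3_arcface_r50/backbone.pth',   # hypothetical checkpoint location
             'onnx/ms1mv3_arcface_r50.onnx',
             opset=11, simplify=False)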
videoretalking/third_part/face3d/models/arcface_torch/train.py
ADDED
@@ -0,0 +1,141 @@
import argparse
import logging
import os

import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch.utils.data.distributed
from torch.nn.utils import clip_grad_norm_

import losses
from backbones import get_model
from dataset import MXFaceDataset, SyntheticDataset, DataLoaderX
from partial_fc import PartialFC
from utils.utils_amp import MaxClipGradScaler
from utils.utils_callbacks import CallBackVerification, CallBackLogging, CallBackModelCheckpoint
from utils.utils_config import get_config
from utils.utils_logging import AverageMeter, init_logging


def main(args):
    cfg = get_config(args.config)
    try:
        world_size = int(os.environ['WORLD_SIZE'])
        rank = int(os.environ['RANK'])
        dist.init_process_group('nccl')
    except KeyError:
        world_size = 1
        rank = 0
        dist.init_process_group(backend='nccl', init_method="tcp://127.0.0.1:12584", rank=rank, world_size=world_size)

    local_rank = args.local_rank
    torch.cuda.set_device(local_rank)
    os.makedirs(cfg.output, exist_ok=True)
    init_logging(rank, cfg.output)

    if cfg.rec == "synthetic":
        train_set = SyntheticDataset(local_rank=local_rank)
    else:
        train_set = MXFaceDataset(root_dir=cfg.rec, local_rank=local_rank)

    train_sampler = torch.utils.data.distributed.DistributedSampler(train_set, shuffle=True)
    train_loader = DataLoaderX(
        local_rank=local_rank, dataset=train_set, batch_size=cfg.batch_size,
        sampler=train_sampler, num_workers=2, pin_memory=True, drop_last=True)
    backbone = get_model(cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).to(local_rank)

    if cfg.resume:
        try:
            backbone_pth = os.path.join(cfg.output, "backbone.pth")
            backbone.load_state_dict(torch.load(backbone_pth, map_location=torch.device(local_rank)))
            if rank == 0:
                logging.info("backbone resume successfully!")
        except (FileNotFoundError, KeyError, IndexError, RuntimeError):
            if rank == 0:
                logging.info("resume fail, backbone init successfully!")

    backbone = torch.nn.parallel.DistributedDataParallel(
        module=backbone, broadcast_buffers=False, device_ids=[local_rank])
    backbone.train()
    margin_softmax = losses.get_loss(cfg.loss)
    module_partial_fc = PartialFC(
        rank=rank, local_rank=local_rank, world_size=world_size, resume=cfg.resume,
        batch_size=cfg.batch_size, margin_softmax=margin_softmax, num_classes=cfg.num_classes,
        sample_rate=cfg.sample_rate, embedding_size=cfg.embedding_size, prefix=cfg.output)

    opt_backbone = torch.optim.SGD(
        params=[{'params': backbone.parameters()}],
        lr=cfg.lr / 512 * cfg.batch_size * world_size,
        momentum=0.9, weight_decay=cfg.weight_decay)
    opt_pfc = torch.optim.SGD(
        params=[{'params': module_partial_fc.parameters()}],
        lr=cfg.lr / 512 * cfg.batch_size * world_size,
        momentum=0.9, weight_decay=cfg.weight_decay)

    num_image = len(train_set)
    total_batch_size = cfg.batch_size * world_size
    cfg.warmup_step = num_image // total_batch_size * cfg.warmup_epoch
    cfg.total_step = num_image // total_batch_size * cfg.num_epoch

    def lr_step_func(current_step):
        cfg.decay_step = [x * num_image // total_batch_size for x in cfg.decay_epoch]
        if current_step < cfg.warmup_step:
            return current_step / cfg.warmup_step
        else:
            return 0.1 ** len([m for m in cfg.decay_step if m <= current_step])

    scheduler_backbone = torch.optim.lr_scheduler.LambdaLR(
        optimizer=opt_backbone, lr_lambda=lr_step_func)
    scheduler_pfc = torch.optim.lr_scheduler.LambdaLR(
        optimizer=opt_pfc, lr_lambda=lr_step_func)

    for key, value in cfg.items():
        num_space = 25 - len(key)
        logging.info(": " + key + " " * num_space + str(value))

    val_target = cfg.val_targets
    callback_verification = CallBackVerification(2000, rank, val_target, cfg.rec)
    callback_logging = CallBackLogging(50, rank, cfg.total_step, cfg.batch_size, world_size, None)
    callback_checkpoint = CallBackModelCheckpoint(rank, cfg.output)

    loss = AverageMeter()
    start_epoch = 0
    global_step = 0
    grad_amp = MaxClipGradScaler(cfg.batch_size, 128 * cfg.batch_size, growth_interval=100) if cfg.fp16 else None
    for epoch in range(start_epoch, cfg.num_epoch):
        train_sampler.set_epoch(epoch)
        for step, (img, label) in enumerate(train_loader):
            global_step += 1
            features = F.normalize(backbone(img))
            x_grad, loss_v = module_partial_fc.forward_backward(label, features, opt_pfc)
            if cfg.fp16:
                features.backward(grad_amp.scale(x_grad))
                grad_amp.unscale_(opt_backbone)
                clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2)
                grad_amp.step(opt_backbone)
                grad_amp.update()
            else:
                features.backward(x_grad)
                clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2)
                opt_backbone.step()

            opt_pfc.step()
            module_partial_fc.update()
            opt_backbone.zero_grad()
            opt_pfc.zero_grad()
            loss.update(loss_v, 1)
            callback_logging(global_step, loss, epoch, cfg.fp16, scheduler_backbone.get_last_lr()[0], grad_amp)
            callback_verification(global_step, backbone)
        scheduler_backbone.step()
        scheduler_pfc.step()
        callback_checkpoint(global_step, backbone, module_partial_fc)
    dist.destroy_process_group()


if __name__ == "__main__":
    torch.backends.cudnn.benchmark = True
    parser = argparse.ArgumentParser(description='PyTorch ArcFace Training')
    parser.add_argument('config', type=str, help='py config file')
    parser.add_argument('--local_rank', type=int, default=0, help='local_rank')
    main(parser.parse_args())
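Note: for reference, the warmup-then-step schedule encoded by lr_step_func above behaves as sketched below; the step counts are illustrative (not from any shipped config), and the returned value is the multiplier applied to the base learning rate cfg.lr / 512 * batch_size * world_size.

# Illustrative numbers only: 1000 warmup steps, decay at steps 10000 and 16000.
warmup_step = 1000
decay_step = [10000, 16000]

def lr_step_func(current_step):
    if current_step < warmup_step:
        return current_step / warmup_step  # linear warmup from 0 to 1
    return 0.1 ** len([m for m in decay_step if m <= current_step])  # x0.1 after each decay step

for s in [0, 500, 1000, 12000, 17000]:
    print(s, lr_step_func(s))  # 0.0, 0.5, 1.0, 0.1, 0.01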
videoretalking/third_part/face3d/models/arcface_torch/utils/__init__.py
ADDED
File without changes
|
videoretalking/third_part/face3d/models/arcface_torch/utils/plot.py
ADDED
@@ -0,0 +1,72 @@
# coding: utf-8

import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap
from prettytable import PrettyTable
from sklearn.metrics import roc_curve, auc

image_path = "/data/anxiang/IJB_release/IJBC"
files = [
    "./ms1mv3_arcface_r100/ms1mv3_arcface_r100/ijbc.npy"
]


def read_template_pair_list(path):
    pairs = pd.read_csv(path, sep=' ', header=None).values
    t1 = pairs[:, 0].astype(np.int)
    t2 = pairs[:, 1].astype(np.int)
    label = pairs[:, 2].astype(np.int)
    return t1, t2, label


p1, p2, label = read_template_pair_list(
    os.path.join('%s/meta' % image_path,
                 '%s_template_pair_label.txt' % 'ijbc'))

methods = []
scores = []
for file in files:
    methods.append(file.split('/')[-2])
    scores.append(np.load(file))

methods = np.array(methods)
scores = dict(zip(methods, scores))
colours = dict(
    zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2')))
x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1]
tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels])
fig = plt.figure()
for method in methods:
    fpr, tpr, _ = roc_curve(label, scores[method])
    roc_auc = auc(fpr, tpr)
    fpr = np.flipud(fpr)
    tpr = np.flipud(tpr)  # select largest tpr at same fpr
    plt.plot(fpr,
             tpr,
             color=colours[method],
             lw=1,
             label=('[%s (AUC = %0.4f %%)]' %
                    (method.split('-')[-1], roc_auc * 100)))
    tpr_fpr_row = []
    tpr_fpr_row.append("%s-%s" % (method, "IJBC"))
    for fpr_iter in np.arange(len(x_labels)):
        _, min_index = min(
            list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
        tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))
    tpr_fpr_table.add_row(tpr_fpr_row)
plt.xlim([10 ** -6, 0.1])
plt.ylim([0.3, 1.0])
plt.grid(linestyle='--', linewidth=1)
plt.xticks(x_labels)
plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))
plt.xscale('log')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC on IJB')
plt.legend(loc="lower right")
print(tpr_fpr_table)
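Note: the nearest-FPR lookup used to fill the table above can be hard to read at a glance. The small sketch below uses synthetic ROC points and np.argmin, which is essentially equivalent to the min-over-zipped-pairs idiom in the script: pick the operating point whose FPR is closest to the target, then report its TPR.

# Synthetic ROC points for illustration; mirrors the table-filling loop above.
import numpy as np

fpr = np.array([1e-6, 1e-4, 1e-2, 1e-1, 1.0])
tpr = np.array([0.40, 0.72, 0.90, 0.96, 1.00])

for target in [1e-5, 1e-3, 1e-1]:
    min_index = np.argmin(np.abs(fpr - target))  # closest operating point to the target FPR
    print('TPR @ FPR=%g: %.2f%%' % (target, tpr[min_index] * 100))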