daddyjin committed on
Commit b04d4f9
1 Parent(s): c16827d

add pirenderer based FONT and edit requirements.txt.

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. Demo_TFR_Pirenderer/.idea/.gitignore +8 -0
  2. Demo_TFR_Pirenderer/.idea/deployment.xml +15 -0
  3. Demo_TFR_Pirenderer/.idea/inspectionProfiles/profiles_settings.xml +6 -0
  4. Demo_TFR_Pirenderer/.idea/modules.xml +8 -0
  5. Demo_TFR_Pirenderer/examples/driven_audio/RD_Radio31_000.wav +0 -0
  6. Demo_TFR_Pirenderer/examples/driven_audio/RD_Radio34_002.wav +0 -0
  7. Demo_TFR_Pirenderer/examples/driven_audio/RD_Radio36_000.wav +0 -0
  8. Demo_TFR_Pirenderer/examples/driven_audio/RD_Radio40_000.wav +0 -0
  9. Demo_TFR_Pirenderer/examples/source_image/ABOUT_00514.jpg +3 -0
  10. Demo_TFR_Pirenderer/examples/source_image/ABOUT_00994.jpg +3 -0
  11. Demo_TFR_Pirenderer/examples/source_image/ABOUT_test_00001.jpg +3 -0
  12. Demo_TFR_Pirenderer/examples/source_image/ABOUT_train_00001.jpg +3 -0
  13. Demo_TFR_Pirenderer/gradio_demo.py +142 -0
  14. Demo_TFR_Pirenderer/src/audio2exp_models/audio2exp.py +41 -0
  15. Demo_TFR_Pirenderer/src/audio2exp_models/networks.py +74 -0
  16. Demo_TFR_Pirenderer/src/audio2pose_models/audio2pose.py +94 -0
  17. Demo_TFR_Pirenderer/src/audio2pose_models/audio_encoder.py +64 -0
  18. Demo_TFR_Pirenderer/src/audio2pose_models/cvae.py +149 -0
  19. Demo_TFR_Pirenderer/src/audio2pose_models/discriminator.py +76 -0
  20. Demo_TFR_Pirenderer/src/audio2pose_models/networks.py +140 -0
  21. Demo_TFR_Pirenderer/src/audio2pose_models/res_unet.py +65 -0
  22. Demo_TFR_Pirenderer/src/config/auido2exp.yaml +58 -0
  23. Demo_TFR_Pirenderer/src/config/auido2pose.yaml +49 -0
  24. Demo_TFR_Pirenderer/src/config/face.yaml +83 -0
  25. Demo_TFR_Pirenderer/src/face3d/data/__init__.py +116 -0
  26. Demo_TFR_Pirenderer/src/face3d/data/base_dataset.py +125 -0
  27. Demo_TFR_Pirenderer/src/face3d/data/flist_dataset.py +125 -0
  28. Demo_TFR_Pirenderer/src/face3d/data/image_folder.py +66 -0
  29. Demo_TFR_Pirenderer/src/face3d/data/template_dataset.py +75 -0
  30. Demo_TFR_Pirenderer/src/face3d/extract_kp_videos.py +108 -0
  31. Demo_TFR_Pirenderer/src/face3d/extract_kp_videos_safe.py +138 -0
  32. Demo_TFR_Pirenderer/src/face3d/models/__init__.py +67 -0
  33. Demo_TFR_Pirenderer/src/face3d/models/arcface_torch/README.md +164 -0
  34. Demo_TFR_Pirenderer/src/face3d/models/arcface_torch/backbones/__init__.py +25 -0
  35. Demo_TFR_Pirenderer/src/face3d/models/arcface_torch/backbones/iresnet.py +187 -0
  36. Demo_TFR_Pirenderer/src/face3d/models/arcface_torch/backbones/iresnet2060.py +176 -0
  37. Demo_TFR_Pirenderer/src/face3d/models/arcface_torch/backbones/mobilefacenet.py +130 -0
  38. Demo_TFR_Pirenderer/src/face3d/models/arcface_torch/configs/3millions.py +23 -0
  39. Demo_TFR_Pirenderer/src/face3d/models/arcface_torch/configs/3millions_pfc.py +23 -0
  40. Demo_TFR_Pirenderer/src/face3d/models/arcface_torch/configs/__init__.py +0 -0
  41. Demo_TFR_Pirenderer/src/face3d/models/arcface_torch/configs/base.py +56 -0
  42. Demo_TFR_Pirenderer/src/face3d/models/arcface_torch/configs/glint360k_mbf.py +26 -0
  43. Demo_TFR_Pirenderer/src/face3d/models/arcface_torch/configs/glint360k_r100.py +26 -0
  44. Demo_TFR_Pirenderer/src/face3d/models/arcface_torch/configs/glint360k_r18.py +26 -0
  45. Demo_TFR_Pirenderer/src/face3d/models/arcface_torch/configs/glint360k_r34.py +26 -0
  46. Demo_TFR_Pirenderer/src/face3d/models/arcface_torch/configs/glint360k_r50.py +26 -0
  47. Demo_TFR_Pirenderer/src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py +26 -0
  48. Demo_TFR_Pirenderer/src/face3d/models/arcface_torch/configs/ms1mv3_r18.py +26 -0
  49. Demo_TFR_Pirenderer/src/face3d/models/arcface_torch/configs/ms1mv3_r2060.py +26 -0
  50. Demo_TFR_Pirenderer/src/face3d/models/arcface_torch/configs/ms1mv3_r34.py +26 -0
Demo_TFR_Pirenderer/.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
Demo_TFR_Pirenderer/.idea/deployment.xml ADDED
@@ -0,0 +1,15 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="PublishConfigData" autoUpload="Always" serverName="10.26.128.77" remoteFilesAllowedToDisappearOnAutoupload="false">
+    <serverData>
+      <paths name="10.26.128.77">
+        <serverdata>
+          <mappings>
+            <mapping deploy="/data/liujin/Demo_TFR_Pirenderer" local="$PROJECT_DIR$" web="/" />
+          </mappings>
+        </serverdata>
+      </paths>
+    </serverData>
+    <option name="myAutoUpload" value="ALWAYS" />
+  </component>
+</project>
Demo_TFR_Pirenderer/.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
+<component name="InspectionProjectProfileManager">
+  <settings>
+    <option name="USE_PROJECT_PROFILE" value="false" />
+    <version value="1.0" />
+  </settings>
+</component>
Demo_TFR_Pirenderer/.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/Demo_TFR_Pirenderer.iml" filepath="$PROJECT_DIR$/.idea/Demo_TFR_Pirenderer.iml" />
+    </modules>
+  </component>
+</project>
Demo_TFR_Pirenderer/examples/driven_audio/RD_Radio31_000.wav ADDED
Binary file (512 kB).
 
Demo_TFR_Pirenderer/examples/driven_audio/RD_Radio34_002.wav ADDED
Binary file (512 kB).
 
Demo_TFR_Pirenderer/examples/driven_audio/RD_Radio36_000.wav ADDED
Binary file (512 kB).
 
Demo_TFR_Pirenderer/examples/driven_audio/RD_Radio40_000.wav ADDED
Binary file (512 kB).
 
Demo_TFR_Pirenderer/examples/source_image/ABOUT_00514.jpg ADDED

Git LFS Details

  • SHA256: f7de79fd4ef5a83ec819b6e3482fafeec481ad077cbdc442ab27a244916156d1
  • Pointer size: 129 Bytes
  • Size of remote file: 9.54 kB
Demo_TFR_Pirenderer/examples/source_image/ABOUT_00994.jpg ADDED

Git LFS Details

  • SHA256: cc9f6bd9b1e474562bf499fd429acc8dc9ee6a2b80b0b3e2ad15006e006065e5
  • Pointer size: 129 Bytes
  • Size of remote file: 8.3 kB
Demo_TFR_Pirenderer/examples/source_image/ABOUT_test_00001.jpg ADDED

Git LFS Details

  • SHA256: d4fe157194a870eb083efb9c717ced2e7bfd6258b3a88139254ebc2f9ca20e12
  • Pointer size: 130 Bytes
  • Size of remote file: 10.5 kB
Demo_TFR_Pirenderer/examples/source_image/ABOUT_train_00001.jpg ADDED

Git LFS Details

  • SHA256: 43c913936f7afff514dc03dd61f23cd6595e3ad34110f4d213e8345ea850f6bd
  • Pointer size: 129 Bytes
  • Size of remote file: 8.72 kB
Demo_TFR_Pirenderer/gradio_demo.py ADDED
@@ -0,0 +1,142 @@
+import torch, uuid
+import os, sys, shutil
+from src.utils.preprocess import CropAndExtract
+from src.test_audio2coeff import Audio2Coeff
+from src.generate_batch import get_data
+from src.generate_facerender_batch import get_facerender_data
+from src.pirenderer.animate import AnimateFromCoeff
+
+from pydub import AudioSegment
+from scipy.io import savemat, loadmat
+
+def mp3_to_wav(mp3_filename,wav_filename,frame_rate):
+    mp3_file = AudioSegment.from_file(file=mp3_filename)
+    mp3_file.set_frame_rate(frame_rate).export(wav_filename,format="wav")
+
+
+class OPT():
+    def __init__(self, checkpoint_path='checkpoints', config_path='src/config', lazy_load=False):
+
+        if torch.cuda.is_available() :
+            device = "cuda"
+        else:
+            device = "cpu"
+
+        self.device = device
+
+        os.environ['TORCH_HOME']= checkpoint_path
+
+        self.checkpoint_path = checkpoint_path
+        self.config_path = config_path
+
+        self.path_of_lm_croper = os.path.join( checkpoint_path, 'shape_predictor_68_face_landmarks.dat')
+        self.path_of_net_recon_model = os.path.join( checkpoint_path, 'epoch_20.pth')
+        self.dir_of_BFM_fitting = os.path.join( checkpoint_path, 'BFM_Fitting')
+        self.wav2lip_checkpoint = os.path.join( checkpoint_path, 'wav2lip.pth')
+
+        self.audio2pose_checkpoint = os.path.join( checkpoint_path, 'auido2pose.pth')
+        self.audio2pose_yaml_path = os.path.join( config_path, 'auido2pose.yaml')
+
+        self.audio2exp_checkpoint = os.path.join( checkpoint_path, 'auido2exp.pth')
+        self.audio2exp_yaml_path = os.path.join( config_path, 'auido2exp.yaml')
+
+        self.pirenderer_checkpoint = os.path.join(checkpoint_path, 'epoch_00190_iteration_000400000_checkpoint.pt')
+        self.pirenderer_yaml_path = os.path.join(config_path, 'face.yaml')
+
+        self.lazy_load = lazy_load
+
+        if not self.lazy_load:
+            #init model
+
+            # print(self.audio2pose_checkpoint)
+            self.audio_to_coeff = Audio2Coeff(self.audio2pose_checkpoint, self.audio2pose_yaml_path,
+                                              self.audio2exp_checkpoint, self.audio2exp_yaml_path, self.wav2lip_checkpoint, self.device)
+
+            # print(self.path_of_lm_croper)
+            self.preprocess_model = CropAndExtract(self.path_of_lm_croper, self.path_of_net_recon_model, self.dir_of_BFM_fitting, self.device)
+
+    def test(self, source_image, driven_audio, preprocess='full', still_mode=False, result_dir='./results/'):
+
+        ### crop: only model,
+
+        if self.lazy_load:
+            #init model
+
+            # print(self.audio2pose_checkpoint)
+            self.audio_to_coeff = Audio2Coeff(self.audio2pose_checkpoint, self.audio2pose_yaml_path,
+                                              self.audio2exp_checkpoint, self.audio2exp_yaml_path, self.wav2lip_checkpoint, self.device)
+
+            # print(self.path_of_lm_croper)
+            self.preprocess_model = CropAndExtract(self.path_of_lm_croper, self.path_of_net_recon_model, self.dir_of_BFM_fitting, self.device)
+
+        self.pirender = AnimateFromCoeff(self.pirenderer_checkpoint, self.pirenderer_yaml_path, self.device)
+
+        time_tag = str(uuid.uuid4())
+        save_dir = os.path.join(result_dir, time_tag)
+        os.makedirs(save_dir, exist_ok=True)
+
+        input_dir = os.path.join(save_dir, 'input')
+        os.makedirs(input_dir, exist_ok=True)
+
+        # print(source_image)
+        pic_path = os.path.join(input_dir, os.path.basename(source_image))
+        shutil.copy(source_image, input_dir)
+
+        if os.path.isfile(driven_audio):
+            audio_path = os.path.join(input_dir, os.path.basename(driven_audio))
+
+            #### mp3 to wav
+            if '.mp3' in audio_path:
+                mp3_to_wav(driven_audio, audio_path.replace('.mp3', '.wav'), 16000)
+                audio_path = audio_path.replace('.mp3', '.wav')
+            else:
+                shutil.copy(driven_audio, input_dir)
+        else:
+            raise AttributeError("error audio")
+
+
+        os.makedirs(save_dir, exist_ok=True)
+        pose_style = 0
+        #crop image and extract 3dmm from image
+        first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
+        os.makedirs(first_frame_dir, exist_ok=True)
+        first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(pic_path, first_frame_dir, preprocess)
+
+        if first_coeff_path is None:
+            raise AttributeError("No face is detected")
+
+        #audio2ceoff
+        batch = get_data(first_coeff_path, audio_path, self.device, ref_eyeblink_coeff_path=None, still=still_mode) # longer audio?
+        coeff_path = self.audio_to_coeff.generate(batch, save_dir, pose_style)
+        # coeff_data = loadmat(coeff_path)
+        # print(coeff_data["coeff_3dmm"].shape) # B,70
+        # print(type(coeff_data["coeff_3dmm"])) # nd.array
+
+        # coeff2video
+        batch_size = 1
+        data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, batch_size,
+                                   still_mode=still_mode, preprocess=preprocess)
+        # print(data["source_image"].shape)
+        # print(data["source_semantics"].shape)
+        # print(data["target_semantics_list"].shape)
+
+        return_path = self.pirender.generate(data, save_dir)
+
+        #coeff2video
+
+
+
+        if self.lazy_load:
+            del self.preprocess_model
+            del self.audio_to_coeff
+
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+
+        import gc; gc.collect()
+
+        return return_path
+
+
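A minimal usage sketch of the OPT class above, assuming the working directory is Demo_TFR_Pirenderer/ and that the pretrained checkpoints referenced in __init__ already exist under checkpoints/; the source image and driving audio are example assets added in this same commit:

# Hedged sketch: exercises OPT.test() with the example assets committed here.
from gradio_demo import OPT

opt = OPT(checkpoint_path='checkpoints', config_path='src/config', lazy_load=True)
video_path = opt.test(
    source_image='examples/source_image/ABOUT_00514.jpg',
    driven_audio='examples/driven_audio/RD_Radio31_000.wav',
    preprocess='full',
    still_mode=False,
    result_dir='./results/',
)
print(video_path)  # path returned by AnimateFromCoeff.generate()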
Demo_TFR_Pirenderer/src/audio2exp_models/audio2exp.py ADDED
@@ -0,0 +1,41 @@
+from tqdm import tqdm
+import torch
+from torch import nn
+
+
+class Audio2Exp(nn.Module):
+    def __init__(self, netG, cfg, device, prepare_training_loss=False):
+        super(Audio2Exp, self).__init__()
+        self.cfg = cfg
+        self.device = device
+        self.netG = netG.to(device)
+
+    def test(self, batch):
+
+        mel_input = batch['indiv_mels']                  # bs T 1 80 16
+        bs = mel_input.shape[0]
+        T = mel_input.shape[1]
+
+        exp_coeff_pred = []
+
+        for i in tqdm(range(0, T, 10),'audio2exp:'): # every 10 frames
+
+            current_mel_input = mel_input[:,i:i+10]
+
+            #ref = batch['ref'][:, :, :64].repeat((1,current_mel_input.shape[1],1)) #bs T 64
+            ref = batch['ref'][:, :, :64][:, i:i+10]
+            ratio = batch['ratio_gt'][:, i:i+10]                                    #bs T
+
+            audiox = current_mel_input.view(-1, 1, 80, 16)                          # bs*T 1 80 16
+
+            curr_exp_coeff_pred = self.netG(audiox, ref, ratio)                     # bs T 64
+
+            exp_coeff_pred += [curr_exp_coeff_pred]
+
+        # BS x T x 64
+        results_dict = {
+            'exp_coeff_pred': torch.cat(exp_coeff_pred, axis=1)
+            }
+        return results_dict
+
+
Demo_TFR_Pirenderer/src/audio2exp_models/networks.py ADDED
@@ -0,0 +1,74 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+class Conv2d(nn.Module):
+    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act = True, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.conv_block = nn.Sequential(
+            nn.Conv2d(cin, cout, kernel_size, stride, padding),
+            nn.BatchNorm2d(cout)
+            )
+        self.act = nn.ReLU()
+        self.residual = residual
+        self.use_act = use_act
+
+    def forward(self, x):
+        out = self.conv_block(x)
+        if self.residual:
+            out += x
+
+        if self.use_act:
+            return self.act(out)
+        else:
+            return out
+
+class SimpleWrapperV2(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.audio_encoder = nn.Sequential(
+            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
+            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+
+            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
+            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+
+            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
+            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+
+            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
+            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+
+            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
+            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
+        )
+
+        #### load the pre-trained audio_encoder
+        #self.audio_encoder = self.audio_encoder.to(device)
+        '''
+        wav2lip_state_dict = torch.load('/apdcephfs_cq2/share_1290939/wenxuazhang/checkpoints/wav2lip.pth')['state_dict']
+        state_dict = self.audio_encoder.state_dict()
+
+        for k,v in wav2lip_state_dict.items():
+            if 'audio_encoder' in k:
+                print('init:', k)
+                state_dict[k.replace('module.audio_encoder.', '')] = v
+        self.audio_encoder.load_state_dict(state_dict)
+        '''
+
+        self.mapping1 = nn.Linear(512+64+1, 64)
+        #self.mapping2 = nn.Linear(30, 64)
+        #nn.init.constant_(self.mapping1.weight, 0.)
+        nn.init.constant_(self.mapping1.bias, 0.)
+
+    def forward(self, x, ref, ratio):
+        x = self.audio_encoder(x).view(x.size(0), -1)
+        ref_reshape = ref.reshape(x.size(0), -1)
+        ratio = ratio.reshape(x.size(0), -1)
+
+        y = self.mapping1(torch.cat([x, ref_reshape, ratio], dim=1))
+        out = y.reshape(ref.shape[0], ref.shape[1], -1) #+ ref # resudial
+        return out
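The shape comments above can be checked with a quick smoke test; the sizes below (batch of 2, a 10-frame window, 80x16 mel chunks, 64 expression coefficients) follow the comments in Audio2Exp.test and SimpleWrapperV2.forward and are chosen here only for illustration:

# Hedged sketch: random tensors shaped as the comments in this file describe.
import torch
from src.audio2exp_models.networks import SimpleWrapperV2

bs, T = 2, 10                            # Audio2Exp feeds 10 frames at a time
net = SimpleWrapperV2()
audiox = torch.randn(bs * T, 1, 80, 16)  # bs*T 1 80 16
ref = torch.randn(bs, T, 64)             # bs T 64 (first 64 reference coefficients)
ratio = torch.randn(bs, T)               # bs T (eye-blink ratio)
out = net(audiox, ref, ratio)
print(out.shape)                         # torch.Size([2, 10, 64])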
Demo_TFR_Pirenderer/src/audio2pose_models/audio2pose.py ADDED
@@ -0,0 +1,94 @@
+import torch
+from torch import nn
+from src.audio2pose_models.cvae import CVAE
+from src.audio2pose_models.discriminator import PoseSequenceDiscriminator
+from src.audio2pose_models.audio_encoder import AudioEncoder
+
+class Audio2Pose(nn.Module):
+    def __init__(self, cfg, wav2lip_checkpoint, device='cuda'):
+        super().__init__()
+        self.cfg = cfg
+        self.seq_len = cfg.MODEL.CVAE.SEQ_LEN
+        self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE
+        self.device = device
+
+        self.audio_encoder = AudioEncoder(wav2lip_checkpoint, device)
+        self.audio_encoder.eval()
+        for param in self.audio_encoder.parameters():
+            param.requires_grad = False
+
+        self.netG = CVAE(cfg)
+        self.netD_motion = PoseSequenceDiscriminator(cfg)
+
+
+    def forward(self, x):
+
+        batch = {}
+        coeff_gt = x['gt'].cuda().squeeze(0)           #bs frame_len+1 73
+        batch['pose_motion_gt'] = coeff_gt[:, 1:, -9:-3] - coeff_gt[:, :1, -9:-3]  #bs frame_len 6
+        batch['ref'] = coeff_gt[:, 0, -9:-3]           #bs 6
+        batch['class'] = x['class'].squeeze(0).cuda()  # bs
+        indiv_mels= x['indiv_mels'].cuda().squeeze(0)  # bs seq_len+1 80 16
+
+        # forward
+        audio_emb_list = []
+        audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2)) #bs seq_len 512
+        batch['audio_emb'] = audio_emb
+        batch = self.netG(batch)
+
+        pose_motion_pred = batch['pose_motion_pred']  # bs frame_len 6
+        pose_gt = coeff_gt[:, 1:, -9:-3].clone()      # bs frame_len 6
+        pose_pred = coeff_gt[:, :1, -9:-3] + pose_motion_pred  # bs frame_len 6
+
+        batch['pose_pred'] = pose_pred
+        batch['pose_gt'] = pose_gt
+
+        return batch
+
+    def test(self, x):
+
+        batch = {}
+        ref = x['ref']  #bs 1 70
+        batch['ref'] = x['ref'][:,0,-6:]
+        batch['class'] = x['class']
+        bs = ref.shape[0]
+
+        indiv_mels= x['indiv_mels']         # bs T 1 80 16
+        indiv_mels_use = indiv_mels[:, 1:]  # we regard the ref as the first frame
+        num_frames = x['num_frames']
+        num_frames = int(num_frames) - 1
+
+        #
+        div = num_frames//self.seq_len
+        re = num_frames%self.seq_len
+        audio_emb_list = []
+        pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape, dtype=batch['ref'].dtype,
+                                             device=batch['ref'].device)]
+
+        for i in range(div):
+            z = torch.randn(bs, self.latent_dim).to(ref.device)
+            batch['z'] = z
+            audio_emb = self.audio_encoder(indiv_mels_use[:, i*self.seq_len:(i+1)*self.seq_len,:,:,:]) #bs seq_len 512
+            batch['audio_emb'] = audio_emb
+            batch = self.netG.test(batch)
+            pose_motion_pred_list.append(batch['pose_motion_pred'])  #list of bs seq_len 6
+
+        if re != 0:
+            z = torch.randn(bs, self.latent_dim).to(ref.device)
+            batch['z'] = z
+            audio_emb = self.audio_encoder(indiv_mels_use[:, -1*self.seq_len:,:,:,:]) #bs seq_len 512
+            if audio_emb.shape[1] != self.seq_len:
+                pad_dim = self.seq_len-audio_emb.shape[1]
+                pad_audio_emb = audio_emb[:, :1].repeat(1, pad_dim, 1)
+                audio_emb = torch.cat([pad_audio_emb, audio_emb], 1)
+            batch['audio_emb'] = audio_emb
+            batch = self.netG.test(batch)
+            pose_motion_pred_list.append(batch['pose_motion_pred'][:,-1*re:,:])
+
+        pose_motion_pred = torch.cat(pose_motion_pred_list, dim = 1)
+        batch['pose_motion_pred'] = pose_motion_pred
+
+        pose_pred = ref[:, :1, -6:] + pose_motion_pred  # bs T 6
+
+        batch['pose_pred'] = pose_pred
+        return batch
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+ class Conv2d(nn.Module):
6
+ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
7
+ super().__init__(*args, **kwargs)
8
+ self.conv_block = nn.Sequential(
9
+ nn.Conv2d(cin, cout, kernel_size, stride, padding),
10
+ nn.BatchNorm2d(cout)
11
+ )
12
+ self.act = nn.ReLU()
13
+ self.residual = residual
14
+
15
+ def forward(self, x):
16
+ out = self.conv_block(x)
17
+ if self.residual:
18
+ out += x
19
+ return self.act(out)
20
+
21
+ class AudioEncoder(nn.Module):
22
+ def __init__(self, wav2lip_checkpoint, device):
23
+ super(AudioEncoder, self).__init__()
24
+
25
+ self.audio_encoder = nn.Sequential(
26
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
27
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
28
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
29
+
30
+ Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
31
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
32
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
33
+
34
+ Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
35
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
36
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
37
+
38
+ Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
39
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
40
+
41
+ Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
42
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
43
+
44
+ #### load the pre-trained audio_encoder
45
+ wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict']
46
+ state_dict = self.audio_encoder.state_dict()
47
+
48
+ for k,v in wav2lip_state_dict.items():
49
+ if 'audio_encoder' in k:
50
+ state_dict[k.replace('module.audio_encoder.', '')] = v
51
+ self.audio_encoder.load_state_dict(state_dict)
52
+
53
+
54
+ def forward(self, audio_sequences):
55
+ # audio_sequences = (B, T, 1, 80, 16)
56
+ B = audio_sequences.size(0)
57
+
58
+ audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
59
+
60
+ audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1
61
+ dim = audio_embedding.shape[1]
62
+ audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1))
63
+
64
+ return audio_embedding.squeeze(-1).squeeze(-1) #B seq_len+1 512
Demo_TFR_Pirenderer/src/audio2pose_models/cvae.py ADDED
@@ -0,0 +1,149 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+from src.audio2pose_models.res_unet import ResUnet
+
+def class2onehot(idx, class_num):
+
+    assert torch.max(idx).item() < class_num
+    onehot = torch.zeros(idx.size(0), class_num).to(idx.device)
+    onehot.scatter_(1, idx, 1)
+    return onehot
+
+class CVAE(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        encoder_layer_sizes = cfg.MODEL.CVAE.ENCODER_LAYER_SIZES
+        decoder_layer_sizes = cfg.MODEL.CVAE.DECODER_LAYER_SIZES
+        latent_size = cfg.MODEL.CVAE.LATENT_SIZE
+        num_classes = cfg.DATASET.NUM_CLASSES
+        audio_emb_in_size = cfg.MODEL.CVAE.AUDIO_EMB_IN_SIZE
+        audio_emb_out_size = cfg.MODEL.CVAE.AUDIO_EMB_OUT_SIZE
+        seq_len = cfg.MODEL.CVAE.SEQ_LEN
+
+        self.latent_size = latent_size
+
+        self.encoder = ENCODER(encoder_layer_sizes, latent_size, num_classes,
+                               audio_emb_in_size, audio_emb_out_size, seq_len)
+        self.decoder = DECODER(decoder_layer_sizes, latent_size, num_classes,
+                               audio_emb_in_size, audio_emb_out_size, seq_len)
+    def reparameterize(self, mu, logvar):
+        std = torch.exp(0.5 * logvar)
+        eps = torch.randn_like(std)
+        return mu + eps * std
+
+    def forward(self, batch):
+        batch = self.encoder(batch)
+        mu = batch['mu']
+        logvar = batch['logvar']
+        z = self.reparameterize(mu, logvar)
+        batch['z'] = z
+        return self.decoder(batch)
+
+    def test(self, batch):
+        '''
+        class_id = batch['class']
+        z = torch.randn([class_id.size(0), self.latent_size]).to(class_id.device)
+        batch['z'] = z
+        '''
+        return self.decoder(batch)
+
+class ENCODER(nn.Module):
+    def __init__(self, layer_sizes, latent_size, num_classes,
+                 audio_emb_in_size, audio_emb_out_size, seq_len):
+        super().__init__()
+
+        self.resunet = ResUnet()
+        self.num_classes = num_classes
+        self.seq_len = seq_len
+
+        self.MLP = nn.Sequential()
+        layer_sizes[0] += latent_size + seq_len*audio_emb_out_size + 6
+        for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
+            self.MLP.add_module(
+                name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
+            self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
+
+        self.linear_means = nn.Linear(layer_sizes[-1], latent_size)
+        self.linear_logvar = nn.Linear(layer_sizes[-1], latent_size)
+        self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
+
+        self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
+
+    def forward(self, batch):
+        class_id = batch['class']
+        pose_motion_gt = batch['pose_motion_gt'] #bs seq_len 6
+        ref = batch['ref'] #bs 6
+        bs = pose_motion_gt.shape[0]
+        audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size
+
+        #pose encode
+        pose_emb = self.resunet(pose_motion_gt.unsqueeze(1)) #bs 1 seq_len 6
+        pose_emb = pose_emb.reshape(bs, -1) #bs seq_len*6
+
+        #audio mapping
+        print(audio_in.shape)
+        audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size
+        audio_out = audio_out.reshape(bs, -1)
+
+        class_bias = self.classbias[class_id] #bs latent_size
+        x_in = torch.cat([ref, pose_emb, audio_out, class_bias], dim=-1) #bs seq_len*(audio_emb_out_size+6)+latent_size
+        x_out = self.MLP(x_in)
+
+        mu = self.linear_means(x_out)
+        logvar = self.linear_means(x_out) #bs latent_size
+
+        batch.update({'mu':mu, 'logvar':logvar})
+        return batch
+
+class DECODER(nn.Module):
+    def __init__(self, layer_sizes, latent_size, num_classes,
+                 audio_emb_in_size, audio_emb_out_size, seq_len):
+        super().__init__()
+
+        self.resunet = ResUnet()
+        self.num_classes = num_classes
+        self.seq_len = seq_len
+
+        self.MLP = nn.Sequential()
+        input_size = latent_size + seq_len*audio_emb_out_size + 6
+        for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)):
+            self.MLP.add_module(
+                name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
+            if i+1 < len(layer_sizes):
+                self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
+            else:
+                self.MLP.add_module(name="sigmoid", module=nn.Sigmoid())
+
+        self.pose_linear = nn.Linear(6, 6)
+        self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
+
+        self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
+
+    def forward(self, batch):
+
+        z = batch['z'] #bs latent_size
+        bs = z.shape[0]
+        class_id = batch['class']
+        ref = batch['ref'] #bs 6
+        audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size
+        #print('audio_in: ', audio_in[:, :, :10])
+
+        audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size
+        #print('audio_out: ', audio_out[:, :, :10])
+        audio_out = audio_out.reshape([bs, -1]) # bs seq_len*audio_emb_out_size
+        class_bias = self.classbias[class_id] #bs latent_size
+
+        z = z + class_bias
+        x_in = torch.cat([ref, z, audio_out], dim=-1)
+        x_out = self.MLP(x_in) # bs layer_sizes[-1]
+        x_out = x_out.reshape((bs, self.seq_len, -1))
+
+        #print('x_out: ', x_out)
+
+        pose_emb = self.resunet(x_out.unsqueeze(1)) #bs 1 seq_len 6
+
+        pose_motion_pred = self.pose_linear(pose_emb.squeeze(1)) #bs seq_len 6
+
+        batch.update({'pose_motion_pred':pose_motion_pred})
+        return batch
Demo_TFR_Pirenderer/src/audio2pose_models/discriminator.py ADDED
@@ -0,0 +1,76 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+class ConvNormRelu(nn.Module):
+    def __init__(self, conv_type='1d', in_channels=3, out_channels=64, downsample=False,
+                 kernel_size=None, stride=None, padding=None, norm='BN', leaky=False):
+        super().__init__()
+        if kernel_size is None:
+            if downsample:
+                kernel_size, stride, padding = 4, 2, 1
+            else:
+                kernel_size, stride, padding = 3, 1, 1
+
+        if conv_type == '2d':
+            self.conv = nn.Conv2d(
+                in_channels,
+                out_channels,
+                kernel_size,
+                stride,
+                padding,
+                bias=False,
+            )
+            if norm == 'BN':
+                self.norm = nn.BatchNorm2d(out_channels)
+            elif norm == 'IN':
+                self.norm = nn.InstanceNorm2d(out_channels)
+            else:
+                raise NotImplementedError
+        elif conv_type == '1d':
+            self.conv = nn.Conv1d(
+                in_channels,
+                out_channels,
+                kernel_size,
+                stride,
+                padding,
+                bias=False,
+            )
+            if norm == 'BN':
+                self.norm = nn.BatchNorm1d(out_channels)
+            elif norm == 'IN':
+                self.norm = nn.InstanceNorm1d(out_channels)
+            else:
+                raise NotImplementedError
+        nn.init.kaiming_normal_(self.conv.weight)
+
+        self.act = nn.LeakyReLU(negative_slope=0.2, inplace=False) if leaky else nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        x = self.conv(x)
+        if isinstance(self.norm, nn.InstanceNorm1d):
+            x = self.norm(x.permute((0, 2, 1))).permute((0, 2, 1))  # normalize on [C]
+        else:
+            x = self.norm(x)
+        x = self.act(x)
+        return x
+
+
+class PoseSequenceDiscriminator(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self.cfg = cfg
+        leaky = self.cfg.MODEL.DISCRIMINATOR.LEAKY_RELU
+
+        self.seq = nn.Sequential(
+            ConvNormRelu('1d', cfg.MODEL.DISCRIMINATOR.INPUT_CHANNELS, 256, downsample=True, leaky=leaky),  # B, 256, 64
+            ConvNormRelu('1d', 256, 512, downsample=True, leaky=leaky),  # B, 512, 32
+            ConvNormRelu('1d', 512, 1024, kernel_size=3, stride=1, padding=1, leaky=leaky),  # B, 1024, 16
+            nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1, bias=True)  # B, 1, 16
+        )
+
+    def forward(self, x):
+        x = x.reshape(x.size(0), x.size(1), -1).transpose(1, 2)
+        x = self.seq(x)
+        x = x.squeeze(1)
+        return x
Demo_TFR_Pirenderer/src/audio2pose_models/networks.py ADDED
@@ -0,0 +1,140 @@
+import torch.nn as nn
+import torch
+
+
+class ResidualConv(nn.Module):
+    def __init__(self, input_dim, output_dim, stride, padding):
+        super(ResidualConv, self).__init__()
+
+        self.conv_block = nn.Sequential(
+            nn.BatchNorm2d(input_dim),
+            nn.ReLU(),
+            nn.Conv2d(
+                input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
+            ),
+            nn.BatchNorm2d(output_dim),
+            nn.ReLU(),
+            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
+        )
+        self.conv_skip = nn.Sequential(
+            nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
+            nn.BatchNorm2d(output_dim),
+        )
+
+    def forward(self, x):
+
+        return self.conv_block(x) + self.conv_skip(x)
+
+
+class Upsample(nn.Module):
+    def __init__(self, input_dim, output_dim, kernel, stride):
+        super(Upsample, self).__init__()
+
+        self.upsample = nn.ConvTranspose2d(
+            input_dim, output_dim, kernel_size=kernel, stride=stride
+        )
+
+    def forward(self, x):
+        return self.upsample(x)
+
+
+class Squeeze_Excite_Block(nn.Module):
+    def __init__(self, channel, reduction=16):
+        super(Squeeze_Excite_Block, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Sequential(
+            nn.Linear(channel, channel // reduction, bias=False),
+            nn.ReLU(inplace=True),
+            nn.Linear(channel // reduction, channel, bias=False),
+            nn.Sigmoid(),
+        )
+
+    def forward(self, x):
+        b, c, _, _ = x.size()
+        y = self.avg_pool(x).view(b, c)
+        y = self.fc(y).view(b, c, 1, 1)
+        return x * y.expand_as(x)
+
+
+class ASPP(nn.Module):
+    def __init__(self, in_dims, out_dims, rate=[6, 12, 18]):
+        super(ASPP, self).__init__()
+
+        self.aspp_block1 = nn.Sequential(
+            nn.Conv2d(
+                in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0]
+            ),
+            nn.ReLU(inplace=True),
+            nn.BatchNorm2d(out_dims),
+        )
+        self.aspp_block2 = nn.Sequential(
+            nn.Conv2d(
+                in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1]
+            ),
+            nn.ReLU(inplace=True),
+            nn.BatchNorm2d(out_dims),
+        )
+        self.aspp_block3 = nn.Sequential(
+            nn.Conv2d(
+                in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2]
+            ),
+            nn.ReLU(inplace=True),
+            nn.BatchNorm2d(out_dims),
+        )
+
+        self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1)
+        self._init_weights()
+
+    def forward(self, x):
+        x1 = self.aspp_block1(x)
+        x2 = self.aspp_block2(x)
+        x3 = self.aspp_block3(x)
+        out = torch.cat([x1, x2, x3], dim=1)
+        return self.output(out)
+
+    def _init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight)
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+
+class Upsample_(nn.Module):
+    def __init__(self, scale=2):
+        super(Upsample_, self).__init__()
+
+        self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale)
+
+    def forward(self, x):
+        return self.upsample(x)
+
+
+class AttentionBlock(nn.Module):
+    def __init__(self, input_encoder, input_decoder, output_dim):
+        super(AttentionBlock, self).__init__()
+
+        self.conv_encoder = nn.Sequential(
+            nn.BatchNorm2d(input_encoder),
+            nn.ReLU(),
+            nn.Conv2d(input_encoder, output_dim, 3, padding=1),
+            nn.MaxPool2d(2, 2),
+        )
+
+        self.conv_decoder = nn.Sequential(
+            nn.BatchNorm2d(input_decoder),
+            nn.ReLU(),
+            nn.Conv2d(input_decoder, output_dim, 3, padding=1),
+        )
+
+        self.conv_attn = nn.Sequential(
+            nn.BatchNorm2d(output_dim),
+            nn.ReLU(),
+            nn.Conv2d(output_dim, 1, 1),
+        )
+
+    def forward(self, x1, x2):
+        out = self.conv_encoder(x1) + self.conv_decoder(x2)
+        out = self.conv_attn(out)
+        return out * x2
Demo_TFR_Pirenderer/src/audio2pose_models/res_unet.py ADDED
@@ -0,0 +1,65 @@
+import torch
+import torch.nn as nn
+from src.audio2pose_models.networks import ResidualConv, Upsample
+
+
+class ResUnet(nn.Module):
+    def __init__(self, channel=1, filters=[32, 64, 128, 256]):
+        super(ResUnet, self).__init__()
+
+        self.input_layer = nn.Sequential(
+            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
+            nn.BatchNorm2d(filters[0]),
+            nn.ReLU(),
+            nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
+        )
+        self.input_skip = nn.Sequential(
+            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
+        )
+
+        self.residual_conv_1 = ResidualConv(filters[0], filters[1], stride=(2,1), padding=1)
+        self.residual_conv_2 = ResidualConv(filters[1], filters[2], stride=(2,1), padding=1)
+
+        self.bridge = ResidualConv(filters[2], filters[3], stride=(2,1), padding=1)
+
+        self.upsample_1 = Upsample(filters[3], filters[3], kernel=(2,1), stride=(2,1))
+        self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], stride=1, padding=1)
+
+        self.upsample_2 = Upsample(filters[2], filters[2], kernel=(2,1), stride=(2,1))
+        self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], stride=1, padding=1)
+
+        self.upsample_3 = Upsample(filters[1], filters[1], kernel=(2,1), stride=(2,1))
+        self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], stride=1, padding=1)
+
+        self.output_layer = nn.Sequential(
+            nn.Conv2d(filters[0], 1, 1, 1),
+            nn.Sigmoid(),
+        )
+
+    def forward(self, x):
+        # Encode
+        x1 = self.input_layer(x) + self.input_skip(x)
+        x2 = self.residual_conv_1(x1)
+        x3 = self.residual_conv_2(x2)
+        # Bridge
+        x4 = self.bridge(x3)
+
+        # Decode
+        x4 = self.upsample_1(x4)
+        x5 = torch.cat([x4, x3], dim=1)
+
+        x6 = self.up_residual_conv1(x5)
+
+        x6 = self.upsample_2(x6)
+        x7 = torch.cat([x6, x2], dim=1)
+
+        x8 = self.up_residual_conv2(x7)
+
+        x8 = self.upsample_3(x8)
+        x9 = torch.cat([x8, x1], dim=1)
+
+        x10 = self.up_residual_conv3(x9)
+
+        output = self.output_layer(x10)
+
+        return output
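The stride-(2,1) down/up path only touches the temporal axis, so this network maps a (bs, 1, seq_len, 6) pose-motion tensor back to the same shape, which is how the CVAE in cvae.py uses it. A quick shape check, with illustrative sizes matching SEQ_LEN: 32 from the configs:

# Hedged sketch: verifies ResUnet input/output shapes for the CVAE use case.
import torch
from src.audio2pose_models.res_unet import ResUnet

net = ResUnet(channel=1)
pose_motion = torch.randn(2, 1, 32, 6)  # bs 1 seq_len 6, as in ENCODER/DECODER of cvae.py
out = net(pose_motion)
print(out.shape)                        # torch.Size([2, 1, 32, 6]); values in (0, 1) from the Sigmoid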
Demo_TFR_Pirenderer/src/config/auido2exp.yaml ADDED
@@ -0,0 +1,58 @@
+DATASET:
+  TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt
+  EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/val.txt
+  TRAIN_BATCH_SIZE: 32
+  EVAL_BATCH_SIZE: 32
+  EXP: True
+  EXP_DIM: 64
+  FRAME_LEN: 32
+  COEFF_LEN: 73
+  NUM_CLASSES: 46
+  AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
+  COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav2lip_3dmm
+  LMDB_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
+  DEBUG: True
+  NUM_REPEATS: 2
+  T: 40
+
+
+MODEL:
+  FRAMEWORK: V2
+  AUDIOENCODER:
+    LEAKY_RELU: True
+    NORM: 'IN'
+  DISCRIMINATOR:
+    LEAKY_RELU: False
+    INPUT_CHANNELS: 6
+  CVAE:
+    AUDIO_EMB_IN_SIZE: 512
+    AUDIO_EMB_OUT_SIZE: 128
+    SEQ_LEN: 32
+    LATENT_SIZE: 256
+    ENCODER_LAYER_SIZES: [192, 1024]
+    DECODER_LAYER_SIZES: [1024, 192]
+
+
+TRAIN:
+  MAX_EPOCH: 300
+  GENERATOR:
+    LR: 2.0e-5
+  DISCRIMINATOR:
+    LR: 1.0e-5
+  LOSS:
+    W_FEAT: 0
+    W_COEFF_EXP: 2
+    W_LM: 1.0e-2
+    W_LM_MOUTH: 0
+    W_REG: 0
+    W_SYNC: 0
+    W_COLOR: 0
+    W_EXPRESSION: 0
+    W_LIPREADING: 0.01
+    W_LIPREADING_VV: 0
+    W_EYE_BLINK: 4
+
+TAG:
+  NAME: small_dataset
+
+
Demo_TFR_Pirenderer/src/config/auido2pose.yaml ADDED
@@ -0,0 +1,49 @@
+DATASET:
+  TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt
+  EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/val.txt
+  TRAIN_BATCH_SIZE: 64
+  EVAL_BATCH_SIZE: 1
+  EXP: True
+  EXP_DIM: 64
+  FRAME_LEN: 32
+  COEFF_LEN: 73
+  NUM_CLASSES: 46
+  AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
+  COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
+  DEBUG: True
+
+
+MODEL:
+  AUDIOENCODER:
+    LEAKY_RELU: True
+    NORM: 'IN'
+  DISCRIMINATOR:
+    LEAKY_RELU: False
+    INPUT_CHANNELS: 6
+  CVAE:
+    AUDIO_EMB_IN_SIZE: 512
+    AUDIO_EMB_OUT_SIZE: 6
+    SEQ_LEN: 32
+    LATENT_SIZE: 64
+    ENCODER_LAYER_SIZES: [192, 128]
+    DECODER_LAYER_SIZES: [128, 192]
+
+
+TRAIN:
+  MAX_EPOCH: 150
+  GENERATOR:
+    LR: 1.0e-4
+  DISCRIMINATOR:
+    LR: 1.0e-4
+  LOSS:
+    LAMBDA_REG: 1
+    LAMBDA_LANDMARKS: 0
+    LAMBDA_VERTICES: 0
+    LAMBDA_GAN_MOTION: 0.7
+    LAMBDA_GAN_COEFF: 0
+    LAMBDA_KL: 1
+
+TAG:
+  NAME: cvae_UNET_useAudio_usewav2lipAudioEncoder
+
+
Demo_TFR_Pirenderer/src/config/face.yaml ADDED
@@ -0,0 +1,83 @@
+# How often do you want to log the training stats.
+# network_list:
+#   gen: gen_optimizer
+#   dis: dis_optimizer
+
+distributed: False
+image_to_tensorboard: True
+snapshot_save_iter: 40000
+snapshot_save_epoch: 20
+snapshot_save_start_iter: 20000
+snapshot_save_start_epoch: 10
+image_save_iter: 1000
+max_epoch: 200
+logging_iter: 100
+results_dir: ./eval_results
+
+gen_optimizer:
+  type: adam
+  lr: 0.0001
+  adam_beta1: 0.5
+  adam_beta2: 0.999
+  lr_policy:
+    iteration_mode: True
+    type: step
+    step_size: 300000
+    gamma: 0.2
+
+trainer:
+  type: trainers.face_trainer::FaceTrainer
+  pretrain_warp_iteration: 200000
+  loss_weight:
+    weight_perceptual_warp: 2.5
+    weight_perceptual_final: 4
+  vgg_param_warp:
+    network: vgg19
+    layers: ['relu_1_1', 'relu_2_1', 'relu_3_1', 'relu_4_1', 'relu_5_1']
+    use_style_loss: False
+    num_scales: 4
+  vgg_param_final:
+    network: vgg19
+    layers: ['relu_1_1', 'relu_2_1', 'relu_3_1', 'relu_4_1', 'relu_5_1']
+    use_style_loss: True
+    num_scales: 4
+    style_to_perceptual: 250
+  init:
+    type: 'normal'
+    gain: 0.02
+gen:
+  type: generators.face_model::FaceGenerator
+  param:
+    mapping_net:
+      coeff_nc: 73
+      descriptor_nc: 256
+      layer: 3
+    warpping_net:
+      encoder_layer: 5
+      decoder_layer: 3
+      base_nc: 32
+    editing_net:
+      layer: 3
+      num_res_blocks: 2
+      base_nc: 64
+    common:
+      image_nc: 3
+      descriptor_nc: 256
+      max_nc: 256
+      use_spect: False
+
+
+# Data options.
+data:
+  type: data.vox_dataset_liujin::VoxDataset
+  path: ./dataset/vox_lmdb
+  resolution: 256
+  semantic_radius: 13
+  train:
+    batch_size: 8
+    distributed: True
+  val:
+    batch_size: 8
+    distributed: True
+
+
Demo_TFR_Pirenderer/src/face3d/data/__init__.py ADDED
@@ -0,0 +1,116 @@
+"""This package includes all the modules related to data loading and preprocessing
+
+To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
+You need to implement four functions:
+    -- <__init__>:                      initialize the class, first call BaseDataset.__init__(self, opt).
+    -- <__len__>:                       return the size of dataset.
+    -- <__getitem__>:                   get a data point from data loader.
+    -- <modify_commandline_options>:    (optionally) add dataset-specific options and set default options.
+
+Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
+See our template dataset class 'template_dataset.py' for more details.
+"""
+import numpy as np
+import importlib
+import torch.utils.data
+from face3d.data.base_dataset import BaseDataset
+
+
+def find_dataset_using_name(dataset_name):
+    """Import the module "data/[dataset_name]_dataset.py".
+
+    In the file, the class called DatasetNameDataset() will
+    be instantiated. It has to be a subclass of BaseDataset,
+    and it is case-insensitive.
+    """
+    dataset_filename = "data." + dataset_name + "_dataset"
+    datasetlib = importlib.import_module(dataset_filename)
+
+    dataset = None
+    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
+    for name, cls in datasetlib.__dict__.items():
+        if name.lower() == target_dataset_name.lower() \
+           and issubclass(cls, BaseDataset):
+            dataset = cls
+
+    if dataset is None:
+        raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
+
+    return dataset
+
+
+def get_option_setter(dataset_name):
+    """Return the static method <modify_commandline_options> of the dataset class."""
+    dataset_class = find_dataset_using_name(dataset_name)
+    return dataset_class.modify_commandline_options
+
+
+def create_dataset(opt, rank=0):
+    """Create a dataset given the option.
+
+    This function wraps the class CustomDatasetDataLoader.
+    This is the main interface between this package and 'train.py'/'test.py'
+
+    Example:
+        >>> from data import create_dataset
+        >>> dataset = create_dataset(opt)
+    """
+    data_loader = CustomDatasetDataLoader(opt, rank=rank)
+    dataset = data_loader.load_data()
+    return dataset
+
+class CustomDatasetDataLoader():
+    """Wrapper class of Dataset class that performs multi-threaded data loading"""
+
+    def __init__(self, opt, rank=0):
+        """Initialize this class
+
+        Step 1: create a dataset instance given the name [dataset_mode]
+        Step 2: create a multi-threaded data loader.
+        """
+        self.opt = opt
+        dataset_class = find_dataset_using_name(opt.dataset_mode)
+        self.dataset = dataset_class(opt)
+        self.sampler = None
+        print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__))
+        if opt.use_ddp and opt.isTrain:
+            world_size = opt.world_size
+            self.sampler = torch.utils.data.distributed.DistributedSampler(
+                self.dataset,
+                num_replicas=world_size,
+                rank=rank,
+                shuffle=not opt.serial_batches
+            )
+            self.dataloader = torch.utils.data.DataLoader(
+                self.dataset,
+                sampler=self.sampler,
+                num_workers=int(opt.num_threads / world_size),
+                batch_size=int(opt.batch_size / world_size),
+                drop_last=True)
+        else:
+            self.dataloader = torch.utils.data.DataLoader(
+                self.dataset,
+                batch_size=opt.batch_size,
+                shuffle=(not opt.serial_batches) and opt.isTrain,
+                num_workers=int(opt.num_threads),
+                drop_last=True
+            )
+
+    def set_epoch(self, epoch):
+        self.dataset.current_epoch = epoch
+        if self.sampler is not None:
+            self.sampler.set_epoch(epoch)
+
+    def load_data(self):
+        return self
+
+    def __len__(self):
+        """Return the number of data in the dataset"""
+        return min(len(self.dataset), self.opt.max_dataset_size)
+
+    def __iter__(self):
+        """Return a batch of data"""
+        for i, data in enumerate(self.dataloader):
+            if i * self.opt.batch_size >= self.opt.max_dataset_size:
+                break
+            yield data
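Following the module docstring above, a custom dataset only needs a *_dataset.py file whose class name matches the --dataset_mode flag. A minimal, hypothetical dummy_dataset.py could look like the sketch below; the only option attributes it touches are ones the loader above already relies on, and the import path mirrors this package's own imports:

# Hedged sketch of data/dummy_dataset.py as described in the package docstring.
# Selected with '--dataset_mode dummy'; find_dataset_using_name() picks the class
# because 'DummyDataset'.lower() == 'dummy' + 'dataset'.
import torch
from face3d.data.base_dataset import BaseDataset


class DummyDataset(BaseDataset):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        # optionally add dataset-specific flags here
        return parser

    def __init__(self, opt):
        BaseDataset.__init__(self, opt)
        self.name = 'dummy'            # CustomDatasetDataLoader prints self.dataset.name
        self.items = list(range(100))  # placeholder samples for illustration

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        # return a dict, matching what the real datasets in this package do
        return {'imgs': torch.zeros(3, 256, 256), 'index': index}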
Demo_TFR_Pirenderer/src/face3d/data/base_dataset.py ADDED
@@ -0,0 +1,125 @@
+"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
+
+It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
+"""
+import random
+import numpy as np
+import torch.utils.data as data
+from PIL import Image
+import torchvision.transforms as transforms
+from abc import ABC, abstractmethod
+
+
+class BaseDataset(data.Dataset, ABC):
+    """This class is an abstract base class (ABC) for datasets.
+
+    To create a subclass, you need to implement the following four functions:
+    -- <__init__>:                      initialize the class, first call BaseDataset.__init__(self, opt).
+    -- <__len__>:                       return the size of dataset.
+    -- <__getitem__>:                   get a data point.
+    -- <modify_commandline_options>:    (optionally) add dataset-specific options and set default options.
+    """
+
+    def __init__(self, opt):
+        """Initialize the class; save the options in the class
+
+        Parameters:
+            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
+        """
+        self.opt = opt
+        # self.root = opt.dataroot
+        self.current_epoch = 0
+
+    @staticmethod
+    def modify_commandline_options(parser, is_train):
+        """Add new dataset-specific options, and rewrite default values for existing options.
+
+        Parameters:
+            parser          -- original option parser
+            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+        Returns:
+            the modified parser.
+        """
+        return parser
+
+    @abstractmethod
+    def __len__(self):
+        """Return the total number of images in the dataset."""
+        return 0
+
+    @abstractmethod
+    def __getitem__(self, index):
+        """Return a data point and its metadata information.
+
+        Parameters:
+            index - - a random integer for data indexing
+
+        Returns:
+            a dictionary of data with their names. It ususally contains the data itself and its metadata information.
+        """
+        pass
+
+
+def get_transform(grayscale=False):
+    transform_list = []
+    if grayscale:
+        transform_list.append(transforms.Grayscale(1))
+    transform_list += [transforms.ToTensor()]
+    return transforms.Compose(transform_list)
+
+def get_affine_mat(opt, size):
+    shift_x, shift_y, scale, rot_angle, flip = 0., 0., 1., 0., False
+    w, h = size
+
+    if 'shift' in opt.preprocess:
+        shift_pixs = int(opt.shift_pixs)
+        shift_x = random.randint(-shift_pixs, shift_pixs)
+        shift_y = random.randint(-shift_pixs, shift_pixs)
+    if 'scale' in opt.preprocess:
+        scale = 1 + opt.scale_delta * (2 * random.random() - 1)
+    if 'rot' in opt.preprocess:
+        rot_angle = opt.rot_angle * (2 * random.random() - 1)
+        rot_rad = -rot_angle * np.pi/180
+    if 'flip' in opt.preprocess:
+        flip = random.random() > 0.5
+
+    shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3])
+    flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3])
+    shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3])
+    rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3])
+    scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3])
+    shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3])
+
+    affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin
+    affine_inv = np.linalg.inv(affine)
+    return affine, affine_inv, flip
+
+def apply_img_affine(img, affine_inv, method=Image.BICUBIC):
+    return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=Image.BICUBIC)
+
+def apply_lm_affine(landmark, affine, flip, size):
+    _, h = size
+    lm = landmark.copy()
+    lm[:, 1] = h - 1 - lm[:, 1]
+    lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1)
+    lm = lm @ np.transpose(affine)
+    lm[:, :2] = lm[:, :2] / lm[:, 2:]
+    lm = lm[:, :2]
+    lm[:, 1] = h - 1 - lm[:, 1]
+    if flip:
+        lm_ = lm.copy()
+        lm_[:17] = lm[16::-1]
+        lm_[17:22] = lm[26:21:-1]
+        lm_[22:27] = lm[21:16:-1]
+        lm_[31:36] = lm[35:30:-1]
+        lm_[36:40] = lm[45:41:-1]
+        lm_[40:42] = lm[47:45:-1]
+        lm_[42:46] = lm[39:35:-1]
+        lm_[46:48] = lm[41:39:-1]
+        lm_[48:55] = lm[54:47:-1]
+        lm_[55:60] = lm[59:54:-1]
+        lm_[60:65] = lm[64:59:-1]
+        lm_[65:68] = lm[67:64:-1]
+        lm = lm_
+    return lm
Demo_TFR_Pirenderer/src/face3d/data/flist_dataset.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This script defines the custom dataset for Deep3DFaceRecon_pytorch
2
+ """
3
+
4
+ import os.path
5
+ from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine
6
+ from data.image_folder import make_dataset
7
+ from PIL import Image
8
+ import random
9
+ import util.util as util
10
+ import numpy as np
11
+ import json
12
+ import torch
13
+ from scipy.io import loadmat, savemat
14
+ import pickle
15
+ from util.preprocess import align_img, estimate_norm
16
+ from util.load_mats import load_lm3d
17
+
18
+
19
+ def default_flist_reader(flist):
20
+ """
21
+ flist format: impath label\nimpath label\n ...(same to caffe's filelist)
22
+ """
23
+ imlist = []
24
+ with open(flist, 'r') as rf:
25
+ for line in rf.readlines():
26
+ impath = line.strip()
27
+ imlist.append(impath)
28
+
29
+ return imlist
30
+
31
+ def jason_flist_reader(flist):
32
+ with open(flist, 'r') as fp:
33
+ info = json.load(fp)
34
+ return info
35
+
36
+ def parse_label(label):
37
+ return torch.tensor(np.array(label).astype(np.float32))
38
+
39
+
40
+ class FlistDataset(BaseDataset):
41
+ """
42
+ It requires one directories to host training images '/path/to/data/train'
43
+ You can train the model with the dataset flag '--dataroot /path/to/data'.
44
+ """
45
+
46
+ def __init__(self, opt):
47
+ """Initialize this dataset class.
48
+
49
+ Parameters:
50
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
51
+ """
52
+ BaseDataset.__init__(self, opt)
53
+
54
+ self.lm3d_std = load_lm3d(opt.bfm_folder)
55
+
56
+ msk_names = default_flist_reader(opt.flist)
57
+ self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names]
58
+
59
+ self.size = len(self.msk_paths)
60
+ self.opt = opt
61
+
62
+ self.name = 'train' if opt.isTrain else 'val'
63
+ if '_' in opt.flist:
64
+ self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0]
65
+
66
+
67
+ def __getitem__(self, index):
68
+ """Return a data point and its metadata information.
69
+
70
+ Parameters:
71
+ index (int) -- a random integer for data indexing
72
+
73
+ Returns a dictionary that contains A, B, A_paths and B_paths
74
+ img (tensor) -- an image in the input domain
75
+ msk (tensor) -- its corresponding attention mask
76
+ lm (tensor) -- its corresponding 3d landmarks
77
+ im_paths (str) -- image paths
78
+ aug_flag (bool) -- a flag used to tell whether its raw or augmented
79
+ """
80
+ msk_path = self.msk_paths[index % self.size] # make sure index is within then range
81
+ img_path = msk_path.replace('mask/', '')
82
+ lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt'
83
+
84
+ raw_img = Image.open(img_path).convert('RGB')
85
+ raw_msk = Image.open(msk_path).convert('RGB')
86
+ raw_lm = np.loadtxt(lm_path).astype(np.float32)
87
+
88
+ _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk)
89
+
90
+ aug_flag = self.opt.use_aug and self.opt.isTrain
91
+ if aug_flag:
92
+ img, lm, msk = self._augmentation(img, lm, self.opt, msk)
93
+
94
+ _, H = img.size
95
+ M = estimate_norm(lm, H)
96
+ transform = get_transform()
97
+ img_tensor = transform(img)
98
+ msk_tensor = transform(msk)[:1, ...]
99
+ lm_tensor = parse_label(lm)
100
+ M_tensor = parse_label(M)
101
+
102
+
103
+ return {'imgs': img_tensor,
104
+ 'lms': lm_tensor,
105
+ 'msks': msk_tensor,
106
+ 'M': M_tensor,
107
+ 'im_paths': img_path,
108
+ 'aug_flag': aug_flag,
109
+ 'dataset': self.name}
110
+
111
+ def _augmentation(self, img, lm, opt, msk=None):
112
+ affine, affine_inv, flip = get_affine_mat(opt, img.size)
113
+ img = apply_img_affine(img, affine_inv)
114
+ lm = apply_lm_affine(lm, affine, flip, img.size)
115
+ if msk is not None:
116
+ msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR)
117
+ return img, lm, msk
118
+
119
+
120
+
121
+
122
+ def __len__(self):
123
+ """Return the total number of images in the dataset.
124
+ """
125
+ return self.size
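The `__getitem__` above derives the image and landmark paths from each mask path purely by string substitution, so the on-disk layout has to follow that convention. A small hedged sketch (not part of this commit; all paths are illustrative) mirroring that derivation:

```python
# Hypothetical sketch mirroring the path derivation in FlistDataset.__getitem__.
def derive_paths(msk_path):
    img_path = msk_path.replace('mask/', '')
    lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt'
    return img_path, lm_path

msk = 'some_dataset/mask/000001.png'   # one flist entry joined with opt.data_root
print(derive_paths(msk))
# ('some_dataset/000001.png', 'some_dataset/landmarks/000001.txt')
```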
Demo_TFR_Pirenderer/src/face3d/data/image_folder.py ADDED
@@ -0,0 +1,66 @@
1
+ """A modified image folder class
2
+
3
+ We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
4
+ so that this class can load images from both the current directory and its subdirectories.
5
+ """
6
+ import numpy as np
7
+ import torch.utils.data as data
8
+
9
+ from PIL import Image
10
+ import os
11
+ import os.path
12
+
13
+ IMG_EXTENSIONS = [
14
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
15
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
16
+ '.tif', '.TIF', '.tiff', '.TIFF',
17
+ ]
18
+
19
+
20
+ def is_image_file(filename):
21
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
22
+
23
+
24
+ def make_dataset(dir, max_dataset_size=float("inf")):
25
+ images = []
26
+ assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir
27
+
28
+ for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
29
+ for fname in fnames:
30
+ if is_image_file(fname):
31
+ path = os.path.join(root, fname)
32
+ images.append(path)
33
+ return images[:min(max_dataset_size, len(images))]
34
+
35
+
36
+ def default_loader(path):
37
+ return Image.open(path).convert('RGB')
38
+
39
+
40
+ class ImageFolder(data.Dataset):
41
+
42
+ def __init__(self, root, transform=None, return_paths=False,
43
+ loader=default_loader):
44
+ imgs = make_dataset(root)
45
+ if len(imgs) == 0:
46
+ raise(RuntimeError("Found 0 images in: " + root + "\n"
47
+ "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
48
+
49
+ self.root = root
50
+ self.imgs = imgs
51
+ self.transform = transform
52
+ self.return_paths = return_paths
53
+ self.loader = loader
54
+
55
+ def __getitem__(self, index):
56
+ path = self.imgs[index]
57
+ img = self.loader(path)
58
+ if self.transform is not None:
59
+ img = self.transform(img)
60
+ if self.return_paths:
61
+ return img, path
62
+ else:
63
+ return img
64
+
65
+ def __len__(self):
66
+ return len(self.imgs)
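A minimal usage sketch of the class above; `make_dataset` walks the root recursively, so images in nested subfolders are included. The image directory and the import path are assumptions, not part of this commit.

```python
# Hedged usage sketch; '/path/to/images' and the import path are assumptions.
import torchvision.transforms as transforms
from src.face3d.data.image_folder import ImageFolder

transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
dataset = ImageFolder('/path/to/images', transform=transform, return_paths=True)
img, path = dataset[0]          # a CHW float tensor and the file it came from
print(len(dataset), img.shape, path)
```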
Demo_TFR_Pirenderer/src/face3d/data/template_dataset.py ADDED
@@ -0,0 +1,75 @@
1
+ """Dataset class template
2
+
3
+ This module provides a template for users to implement custom datasets.
4
+ You can specify '--dataset_mode template' to use this dataset.
5
+ The class name should be consistent with both the filename and its dataset_mode option.
6
+ The filename should be <dataset_mode>_dataset.py
7
+ The class name should be <Dataset_mode>Dataset.py
8
+ You need to implement the following functions:
9
+ -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
10
+ -- <__init__>: Initialize this dataset class.
11
+ -- <__getitem__>: Return a data point and its metadata information.
12
+ -- <__len__>: Return the number of images.
13
+ """
14
+ from data.base_dataset import BaseDataset, get_transform
15
+ # from data.image_folder import make_dataset
16
+ # from PIL import Image
17
+
18
+
19
+ class TemplateDataset(BaseDataset):
20
+ """A template dataset class for you to implement custom datasets."""
21
+ @staticmethod
22
+ def modify_commandline_options(parser, is_train):
23
+ """Add new dataset-specific options, and rewrite default values for existing options.
24
+
25
+ Parameters:
26
+ parser -- original option parser
27
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
28
+
29
+ Returns:
30
+ the modified parser.
31
+ """
32
+ parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option')
33
+ parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values
34
+ return parser
35
+
36
+ def __init__(self, opt):
37
+ """Initialize this dataset class.
38
+
39
+ Parameters:
40
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
41
+
42
+ A few things can be done here.
43
+ - save the options (have been done in BaseDataset)
44
+ - get image paths and meta information of the dataset.
45
+ - define the image transformation.
46
+ """
47
+ # save the option and dataset root
48
+ BaseDataset.__init__(self, opt)
49
+ # get the image paths of your dataset;
50
+ self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
51
+ # define the default transform function. You can use <base_dataset.get_transform>; You can also define your custom transform function
52
+ self.transform = get_transform(opt)
53
+
54
+ def __getitem__(self, index):
55
+ """Return a data point and its metadata information.
56
+
57
+ Parameters:
58
+ index -- a random integer for data indexing
59
+
60
+ Returns:
61
+ a dictionary of data with their names. It usually contains the data itself and its metadata information.
62
+
63
+ Step 1: get a random image path: e.g., path = self.image_paths[index]
64
+ Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
65
+         Step 3: convert your data to a PyTorch tensor. You can use helper functions such as self.transform. e.g., data = self.transform(image)
66
+ Step 4: return a data point as a dictionary.
67
+ """
68
+ path = 'temp' # needs to be a string
69
+ data_A = None # needs to be a tensor
70
+ data_B = None # needs to be a tensor
71
+ return {'data_A': data_A, 'data_B': data_B, 'path': path}
72
+
73
+ def __len__(self):
74
+ """Return the total number of images."""
75
+ return len(self.image_paths)
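The docstring above lists the four steps a custom dataset has to implement. The following hedged sketch fills them in for a hypothetical single-folder dataset; the class and option names are illustrative, and the `get_transform(opt)` / `make_dataset` calls simply mirror the template above rather than a file in this commit.

```python
# Hypothetical singlefolder_dataset.py following the four steps above; illustrative only.
from PIL import Image
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset

class SingleFolderDataset(BaseDataset):
    def __init__(self, opt):
        BaseDataset.__init__(self, opt)
        self.image_paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size))
        self.transform = get_transform(opt)

    def __getitem__(self, index):
        path = self.image_paths[index]              # Step 1: pick an image path
        image = Image.open(path).convert('RGB')     # Step 2: load it from disk
        data_A = self.transform(image)              # Step 3: convert to a tensor
        return {'data_A': data_A, 'path': path}     # Step 4: return a dictionary

    def __len__(self):
        return len(self.image_paths)
```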
Demo_TFR_Pirenderer/src/face3d/extract_kp_videos.py ADDED
@@ -0,0 +1,108 @@
1
+ import os
2
+ import cv2
3
+ import time
4
+ import glob
5
+ import argparse
6
+ import face_alignment
7
+ import numpy as np
8
+ from PIL import Image
9
+ from tqdm import tqdm
10
+ from itertools import cycle
11
+
12
+ from torch.multiprocessing import Pool, Process, set_start_method
13
+
14
+ class KeypointExtractor():
15
+ def __init__(self, device):
16
+ self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D,
17
+ device=device)
18
+
19
+ def extract_keypoint(self, images, name=None, info=True):
20
+ if isinstance(images, list):
21
+ keypoints = []
22
+ if info:
23
+ i_range = tqdm(images,desc='landmark Det:')
24
+ else:
25
+ i_range = images
26
+
27
+ for image in i_range:
28
+ current_kp = self.extract_keypoint(image)
29
+ if np.mean(current_kp) == -1 and keypoints:
30
+ keypoints.append(keypoints[-1])
31
+ else:
32
+ keypoints.append(current_kp[None])
33
+
34
+ keypoints = np.concatenate(keypoints, 0)
35
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
36
+ return keypoints
37
+ else:
38
+ while True:
39
+ try:
40
+ keypoints = self.detector.get_landmarks_from_image(np.array(images))[0]
41
+ break
42
+ except RuntimeError as e:
43
+ if str(e).startswith('CUDA'):
44
+ print("Warning: out of memory, sleep for 1s")
45
+ time.sleep(1)
46
+ else:
47
+ print(e)
48
+ break
49
+ except TypeError:
50
+ print('No face detected in this image')
51
+ shape = [68, 2]
52
+ keypoints = -1. * np.ones(shape)
53
+ break
54
+ if name is not None:
55
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
56
+ return keypoints
57
+
58
+ def read_video(filename):
59
+ frames = []
60
+ cap = cv2.VideoCapture(filename)
61
+ while cap.isOpened():
62
+ ret, frame = cap.read()
63
+ if ret:
64
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
65
+ frame = Image.fromarray(frame)
66
+ frames.append(frame)
67
+ else:
68
+ break
69
+ cap.release()
70
+ return frames
71
+
72
+ def run(data):
73
+ filename, opt, device = data
74
+ os.environ['CUDA_VISIBLE_DEVICES'] = device
75
+     kp_extractor = KeypointExtractor(device='cuda')  # the constructor requires a device argument
76
+ images = read_video(filename)
77
+ name = filename.split('/')[-2:]
78
+ os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
79
+ kp_extractor.extract_keypoint(
80
+ images,
81
+ name=os.path.join(opt.output_dir, name[-2], name[-1])
82
+ )
83
+
84
+ if __name__ == '__main__':
85
+ set_start_method('spawn')
86
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
87
+ parser.add_argument('--input_dir', type=str, help='the folder of the input files')
88
+ parser.add_argument('--output_dir', type=str, help='the folder of the output files')
89
+ parser.add_argument('--device_ids', type=str, default='0,1')
90
+ parser.add_argument('--workers', type=int, default=4)
91
+
92
+ opt = parser.parse_args()
93
+ filenames = list()
94
+ VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
95
+ VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
96
+ extensions = VIDEO_EXTENSIONS
97
+
98
+     # collect videos for every extension instead of overwriting the list each iteration
+     for ext in extensions:
+         filenames.extend(glob.glob(f'{opt.input_dir}/*.{ext}'))
+     filenames = sorted(filenames)
102
+ print('Total number of videos:', len(filenames))
103
+ pool = Pool(opt.workers)
104
+ args_list = cycle([opt])
105
+ device_ids = opt.device_ids.split(",")
106
+ device_ids = cycle(device_ids)
107
+ for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
108
+         pass
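A hedged sketch of calling the extractor directly instead of going through the multiprocessing entry point above; the paths and the import path are assumptions. The saved `.txt` holds the flattened `(num_frames, 68, 2)` landmark array.

```python
# Hedged usage sketch; paths and the import path are assumptions.
import os
import numpy as np
from src.face3d.extract_kp_videos import KeypointExtractor, read_video

os.makedirs('output/clip', exist_ok=True)
extractor = KeypointExtractor(device='cuda')       # or 'cpu'
frames = read_video('videos/clip/000.mp4')
kps = extractor.extract_keypoint(frames, name='output/clip/000.mp4')  # writes output/clip/000.txt
print(kps.shape)                                                      # (num_frames, 68, 2)
kps_reloaded = np.loadtxt('output/clip/000.txt').reshape(-1, 68, 2)
```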
Demo_TFR_Pirenderer/src/face3d/extract_kp_videos_safe.py ADDED
@@ -0,0 +1,138 @@
1
+ import os
2
+ import cv2
3
+ import time
4
+ import glob
5
+ import argparse
6
+ import numpy as np
7
+ from PIL import Image
8
+ import torch
9
+ from tqdm import tqdm
10
+ from itertools import cycle
11
+ from facexlib.alignment import init_alignment_model, landmark_98_to_68
12
+ from facexlib.detection import init_detection_model
13
+ from torch.multiprocessing import Pool, Process, set_start_method
14
+
15
+
16
+ class KeypointExtractor():
17
+ def __init__(self, device='cuda'):
18
+
19
+ ### gfpgan/weights
20
+ try:
21
+ import webui # in webui
22
+ root_path = 'extensions/SadTalker/gfpgan/weights'
23
+
24
+ except:
25
+ root_path = 'gfpgan/weights'
26
+
27
+ self.detector = init_alignment_model('awing_fan',device=device, model_rootpath=root_path)
28
+ self.det_net = init_detection_model('retinaface_resnet50', half=False,device=device, model_rootpath=root_path)
29
+
30
+ def extract_keypoint(self, images, name=None, info=True):
31
+ if isinstance(images, list):
32
+ keypoints = []
33
+ if info:
34
+ i_range = tqdm(images,desc='landmark Det:')
35
+ else:
36
+ i_range = images
37
+
38
+ for image in i_range:
39
+ current_kp = self.extract_keypoint(image)
40
+ # current_kp = self.detector.get_landmarks(np.array(image))
41
+ if np.mean(current_kp) == -1 and keypoints:
42
+ keypoints.append(keypoints[-1])
43
+ else:
44
+ keypoints.append(current_kp[None])
45
+
46
+ keypoints = np.concatenate(keypoints, 0)
47
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
48
+ return keypoints
49
+ else:
50
+ while True:
51
+ try:
52
+ with torch.no_grad():
53
+ # face detection -> face alignment.
54
+ img = np.array(images)
55
+ bboxes = self.det_net.detect_faces(images, 0.97)
56
+
57
+ bboxes = bboxes[0]
58
+
59
+ # bboxes[0] -= 100
60
+ # bboxes[1] -= 100
61
+ # bboxes[2] += 100
62
+ # bboxes[3] += 100
63
+ img = img[int(bboxes[1]):int(bboxes[3]), int(bboxes[0]):int(bboxes[2]), :]
64
+
65
+ keypoints = landmark_98_to_68(self.detector.get_landmarks(img)) # [0]
66
+
67
+ #### keypoints to the original location
68
+ keypoints[:,0] += int(bboxes[0])
69
+ keypoints[:,1] += int(bboxes[1])
70
+
71
+ break
72
+ except RuntimeError as e:
73
+ if str(e).startswith('CUDA'):
74
+ print("Warning: out of memory, sleep for 1s")
75
+ time.sleep(1)
76
+ else:
77
+ print(e)
78
+ break
79
+ except TypeError:
80
+ print('No face detected in this image')
81
+ shape = [68, 2]
82
+ keypoints = -1. * np.ones(shape)
83
+ break
84
+ if name is not None:
85
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
86
+ return keypoints
87
+
88
+ def read_video(filename):
89
+ frames = []
90
+ cap = cv2.VideoCapture(filename)
91
+ while cap.isOpened():
92
+ ret, frame = cap.read()
93
+ if ret:
94
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
95
+ frame = Image.fromarray(frame)
96
+ frames.append(frame)
97
+ else:
98
+ break
99
+ cap.release()
100
+ return frames
101
+
102
+ def run(data):
103
+ filename, opt, device = data
104
+ os.environ['CUDA_VISIBLE_DEVICES'] = device
105
+ kp_extractor = KeypointExtractor()
106
+ images = read_video(filename)
107
+ name = filename.split('/')[-2:]
108
+ os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
109
+ kp_extractor.extract_keypoint(
110
+ images,
111
+ name=os.path.join(opt.output_dir, name[-2], name[-1])
112
+ )
113
+
114
+ if __name__ == '__main__':
115
+ set_start_method('spawn')
116
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
117
+ parser.add_argument('--input_dir', type=str, help='the folder of the input files')
118
+ parser.add_argument('--output_dir', type=str, help='the folder of the output files')
119
+ parser.add_argument('--device_ids', type=str, default='0,1')
120
+ parser.add_argument('--workers', type=int, default=4)
121
+
122
+ opt = parser.parse_args()
123
+ filenames = list()
124
+ VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
125
+ VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
126
+ extensions = VIDEO_EXTENSIONS
127
+
128
+ for ext in extensions:
129
+ os.listdir(f'{opt.input_dir}')
130
+ print(f'{opt.input_dir}/*.{ext}')
131
+ filenames = sorted(glob.glob(f'{opt.input_dir}/*.{ext}'))
132
+ print('Total number of videos:', len(filenames))
133
+ pool = Pool(opt.workers)
134
+ args_list = cycle([opt])
135
+ device_ids = opt.device_ids.split(",")
136
+ device_ids = cycle(device_ids)
137
+ for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
138
+ None
Demo_TFR_Pirenderer/src/face3d/models/__init__.py ADDED
@@ -0,0 +1,67 @@
1
+ """This package contains modules related to objective functions, optimizations, and network architectures.
2
+
3
+ To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4
+ You need to implement the following five functions:
5
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6
+ -- <set_input>: unpack data from dataset and apply preprocessing.
7
+ -- <forward>: produce intermediate results.
8
+ -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9
+ -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10
+
11
+ In the function <__init__>, you need to define four lists:
12
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
13
+ -- self.model_names (str list): define networks used in our training.
14
+ -- self.visual_names (str list): specify the images that you want to display and save.
15
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage.
16
+
17
+ Now you can use the model class by specifying flag '--model dummy'.
18
+ See our template model class 'template_model.py' for more details.
19
+ """
20
+
21
+ import importlib
22
+ from src.face3d.models.base_model import BaseModel
23
+
24
+
25
+ def find_model_using_name(model_name):
26
+ """Import the module "models/[model_name]_model.py".
27
+
28
+ In the file, the class called DatasetNameModel() will
29
+ be instantiated. It has to be a subclass of BaseModel,
30
+ and it is case-insensitive.
31
+ """
32
+ model_filename = "face3d.models." + model_name + "_model"
33
+ modellib = importlib.import_module(model_filename)
34
+ model = None
35
+ target_model_name = model_name.replace('_', '') + 'model'
36
+ for name, cls in modellib.__dict__.items():
37
+ if name.lower() == target_model_name.lower() \
38
+ and issubclass(cls, BaseModel):
39
+ model = cls
40
+
41
+ if model is None:
42
+ print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
43
+ exit(0)
44
+
45
+ return model
46
+
47
+
48
+ def get_option_setter(model_name):
49
+ """Return the static method <modify_commandline_options> of the model class."""
50
+ model_class = find_model_using_name(model_name)
51
+ return model_class.modify_commandline_options
52
+
53
+
54
+ def create_model(opt):
55
+ """Create a model given the option.
56
+
57
+     This function instantiates the model class found via <find_model_using_name>.
58
+ This is the main interface between this package and 'train.py'/'test.py'
59
+
60
+ Example:
61
+ >>> from models import create_model
62
+ >>> model = create_model(opt)
63
+ """
64
+ model = find_model_using_name(opt.model)
65
+ instance = model(opt)
66
+ print("model [%s] was created" % type(instance).__name__)
67
+ return instance
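The package docstring above spells out what a new model module must provide. Below is a hedged minimal sketch of a hypothetical `dummy_model.py` that `find_model_using_name('dummy')` could resolve; the option name and the trivial loss are illustrative, not part of this commit.

```python
# Hypothetical src/face3d/models/dummy_model.py; illustrative only.
import torch
from src.face3d.models.base_model import BaseModel

class DummyModel(BaseModel):
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        parser.add_argument('--dummy_weight', type=float, default=1.0, help='illustrative option')
        return parser

    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        self.loss_names = ['dummy']      # losses to plot/save
        self.model_names = []            # networks used in training
        self.visual_names = []           # images to display/save

    def set_input(self, input):
        self.data = input                # unpack data from the dataloader

    def forward(self):
        self.output = self.data          # produce intermediate results

    def optimize_parameters(self):
        self.loss_dummy = torch.tensor(0.0)   # a real model would backprop and step here
```

With this file in place, `--model dummy` would select `DummyModel` through `find_model_using_name`.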
Demo_TFR_Pirenderer/src/face3d/models/arcface_torch/README.md ADDED
@@ -0,0 +1,164 @@
1
+ # Distributed Arcface Training in Pytorch
2
+
3
+ This is a deep learning library that makes face recognition efficient and effective, and that can train tens of millions
4
+ of identities on a single server.
5
+
6
+ ## Requirements
7
+
8
+ - Install [pytorch](http://pytorch.org) (torch>=1.6.0); see our [install.md](docs/install.md).
9
+ - `pip install -r requirements.txt`.
10
+ - Download the dataset
11
+   from [https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_).
12
+
13
+
14
+ ## How to Train
15
+
16
+ To train a model, run `train.py` with the path to the configs:
17
+
18
+ ### 1. Single node, 8 GPUs:
19
+
20
+ ```shell
21
+ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50
22
+ ```
23
+
24
+ ### 2. Multiple nodes, each node 8 GPUs:
25
+
26
+ Node 0:
27
+
28
+ ```shell
29
+ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=1234 train.py configs/ms1mv3_r50
30
+ ```
31
+
32
+ Node 1:
33
+
34
+ ```shell
35
+ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=1234 train.py configs/ms1mv3_r50
36
+ ```
37
+
38
+ ### 3. Training resnet2060 with 8 GPUs:
39
+
40
+ ```shell
41
+ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r2060.py
42
+ ```
43
+
44
+ ## Model Zoo
45
+
46
+ - The models are available for non-commercial research purposes only.
47
+ - All models can be found here:
48
+ - [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw
49
+ - [onedrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d)
50
+
51
+ ### Performance on [**ICCV2021-MFR**](http://iccv21-mfr.com/)
52
+
53
+ The ICCV2021-MFR test set consists of non-celebrities, so we can ensure that it has very little overlap with publicly available face
54
+ recognition training sets such as MS1M and CASIA, which are mostly collected from online celebrities.
55
+ As a result, we can fairly evaluate the performance of different algorithms.
56
+
57
+ For the **ICCV2021-MFR-ALL** set, TAR is measured on an all-to-all 1:1 protocol, with FAR less than 0.000001 (1e-6). The
58
+ globalised multi-racial test set contains 242,143 identities and 1,624,305 images.
59
+
60
+ For the **ICCV2021-MFR-MASK** set, TAR is measured on a mask-to-nonmask 1:1 protocol, with FAR less than 0.0001 (1e-4).
61
+ The mask test set contains 6,964 identities, 6,964 masked images and 13,928 non-masked images.
62
+ In total there are 13,928 positive pairs and 96,983,824 negative pairs.
63
+
64
+ | Datasets | backbone | Training throughput | Size / MB | **ICCV2021-MFR-MASK** | **ICCV2021-MFR-ALL** |
65
+ | :---: | :--- | :--- | :--- |:--- |:--- |
66
+ | MS1MV3 | r18 | - | 91 | **47.85** | **68.33** |
67
+ | Glint360k | r18 | 8536 | 91 | **53.32** | **72.07** |
68
+ | MS1MV3 | r34 | - | 130 | **58.72** | **77.36** |
69
+ | Glint360k | r34 | 6344 | 130 | **65.10** | **83.02** |
70
+ | MS1MV3 | r50 | 5500 | 166 | **63.85** | **80.53** |
71
+ | Glint360k | r50 | 5136 | 166 | **70.23** | **87.08** |
72
+ | MS1MV3 | r100 | - | 248 | **69.09** | **84.31** |
73
+ | Glint360k | r100 | 3332 | 248 | **75.57** | **90.66** |
74
+ | MS1MV3 | mobilefacenet | 12185 | 7.8 | **41.52** | **65.26** |
75
+ | Glint360k | mobilefacenet | 11197 | 7.8 | **44.52** | **66.48** |
76
+
77
+ ### Performance on IJB-C and Verification Datasets
78
+
79
+ | Datasets | backbone | IJBC(1e-05) | IJBC(1e-04) | agedb30 | cfp_fp | lfw | log |
80
+ | :---: | :--- | :--- | :--- | :--- |:--- |:--- |:--- |
81
+ | MS1MV3 | r18 | 92.07 | 94.66 | 97.77 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r18_fp16/training.log)|
82
+ | MS1MV3 | r34 | 94.10 | 95.90 | 98.10 | 98.67 | 99.80 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r34_fp16/training.log)|
83
+ | MS1MV3 | r50 | 94.79 | 96.46 | 98.35 | 98.96 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r50_fp16/training.log)|
84
+ | MS1MV3 | r100 | 95.31 | 96.81 | 98.48 | 99.06 | 99.85 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r100_fp16/training.log)|
85
+ | MS1MV3 | **r2060**| 95.34 | 97.11 | 98.67 | 99.24 | 99.87 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r2060_fp16/training.log)|
86
+ | Glint360k |r18-0.1 | 93.16 | 95.33 | 97.72 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r18_fp16_0.1/training.log)|
87
+ | Glint360k |r34-0.1 | 95.16 | 96.56 | 98.33 | 98.78 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r34_fp16_0.1/training.log)|
88
+ | Glint360k |r50-0.1 | 95.61 | 96.97 | 98.38 | 99.20 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r50_fp16_0.1/training.log)|
89
+ | Glint360k |r100-0.1 | 95.88 | 97.32 | 98.48 | 99.29 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r100_fp16_0.1/training.log)|
90
+
91
+ [comment]: <> (More details see [model.md]&#40;docs/modelzoo.md&#41; in docs.)
92
+
93
+
94
+ ## [Speed Benchmark](docs/speed_benchmark.md)
95
+
96
+ **Arcface Torch** can train large-scale face recognition training sets efficiently and quickly. When the number of
97
+ classes in the training set is greater than 300K and training is run long enough, the partial FC sampling strategy reaches the same
98
+ accuracy with several times faster training and a smaller GPU memory footprint.
99
+ Partial FC is a sparse variant of the model-parallel architecture for large-scale face recognition. Partial FC uses a
100
+ sparse softmax, where each batch dynamically samples a subset of class centers for training. In each iteration, only a
101
+ sparse part of the parameters is updated, which saves a large amount of GPU memory and computation. With Partial FC,
102
+ we can scale to a training set of 29 million identities, the largest to date. Partial FC also supports multi-machine distributed
103
+ training and mixed precision training. A minimal sketch of the sampling idea is given below the figure.
104
+
105
+ ![Image text](https://github.com/anxiangsir/insightface_arcface_log/blob/master/partial_fc_v2.png)
106
+
107
+ For more details, see
108
+ [speed_benchmark.md](docs/speed_benchmark.md) in docs.
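The paragraph above describes Partial FC as sampling a subset of class centers per batch. The following is a simplified, hedged sketch of that sampling idea only (margin, model parallelism and distributed logic omitted); it is not the repository's `partial_fc` implementation, and all shapes are illustrative.

```python
# Simplified sketch of the Partial FC sampling idea; illustrative only.
import torch
import torch.nn.functional as F

def sample_partial_fc(features, labels, centers, sample_rate=0.1):
    """Compute cosine logits over a sampled subset of class centers; return them with remapped labels."""
    num_classes = centers.shape[0]
    positive = torch.unique(labels)                            # class centers that must be kept
    num_sample = max(int(sample_rate * num_classes), positive.numel())
    mask = torch.ones(num_classes, dtype=torch.bool, device=centers.device)
    mask[positive] = False
    negatives = mask.nonzero(as_tuple=True)[0]
    picked = torch.randperm(negatives.numel(), device=centers.device)[:num_sample - positive.numel()]
    sampled = torch.cat([positive, negatives[picked]])         # positives first, then random negatives
    logits = F.linear(F.normalize(features), F.normalize(centers[sampled]))
    remapped = torch.searchsorted(positive, labels)            # each label's column index inside `sampled`
    return logits, remapped

# illustrative shapes: 8 samples, 512-d features, a hypothetical 100k-class problem
feats = torch.randn(8, 512)
labels = torch.randint(0, 100_000, (8,))
centers = torch.randn(100_000, 512)
logits, remapped = sample_partial_fc(feats, labels, centers, sample_rate=0.1)
loss = F.cross_entropy(64 * logits, remapped)                  # margin omitted for brevity
```

In practice the ArcFace/CosFace margin would be applied at the `remapped` positions before the scaled cross-entropy, and the centers would be sharded across GPUs.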
109
+
110
+ ### 1. Training speed of different parallel methods (samples / second), Tesla V100 32GB * 8. (Larger is better)
111
+
112
+ `-` means training failed because of gpu memory limitations.
113
+
114
+ | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
115
+ | :--- | :--- | :--- | :--- |
116
+ |125000 | 4681 | 4824 | 5004 |
117
+ |1400000 | **1672** | 3043 | 4738 |
118
+ |5500000 | **-** | **1389** | 3975 |
119
+ |8000000 | **-** | **-** | 3565 |
120
+ |16000000 | **-** | **-** | 2679 |
121
+ |29000000 | **-** | **-** | **1855** |
122
+
123
+ ### 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better)
124
+
125
+ | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
126
+ | :--- | :--- | :--- | :--- |
127
+ |125000 | 7358 | 5306 | 4868 |
128
+ |1400000 | 32252 | 11178 | 6056 |
129
+ |5500000 | **-** | 32188 | 9854 |
130
+ |8000000 | **-** | **-** | 12310 |
131
+ |16000000 | **-** | **-** | 19950 |
132
+ |29000000 | **-** | **-** | 32324 |
133
+
134
+ ## Evaluation ICCV2021-MFR and IJB-C
135
+
136
+ For more details, see [eval.md](docs/eval.md) in docs.
137
+
138
+ ## Test
139
+
140
+ We tested many versions of PyTorch. Please create an issue if you are having trouble.
141
+
142
+ - [x] torch 1.6.0
143
+ - [x] torch 1.7.1
144
+ - [x] torch 1.8.0
145
+ - [x] torch 1.9.0
146
+
147
+ ## Citation
148
+
149
+ ```
150
+ @inproceedings{deng2019arcface,
151
+ title={Arcface: Additive angular margin loss for deep face recognition},
152
+ author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos},
153
+ booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
154
+ pages={4690--4699},
155
+ year={2019}
156
+ }
157
+ @inproceedings{an2020partical_fc,
158
+ title={Partial FC: Training 10 Million Identities on a Single Machine},
159
+ author={An, Xiang and Zhu, Xuhan and Xiao, Yang and Wu, Lan and Zhang, Ming and Gao, Yuan and Qin, Bin and
160
+ Zhang, Debing and Fu Ying},
161
+ booktitle={Arxiv 2010.05222},
162
+ year={2020}
163
+ }
164
+ ```
Demo_TFR_Pirenderer/src/face3d/models/arcface_torch/backbones/__init__.py ADDED
@@ -0,0 +1,25 @@
1
+ from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200
2
+ from .mobilefacenet import get_mbf
3
+
4
+
5
+ def get_model(name, **kwargs):
6
+ # resnet
7
+ if name == "r18":
8
+ return iresnet18(False, **kwargs)
9
+ elif name == "r34":
10
+ return iresnet34(False, **kwargs)
11
+ elif name == "r50":
12
+ return iresnet50(False, **kwargs)
13
+ elif name == "r100":
14
+ return iresnet100(False, **kwargs)
15
+ elif name == "r200":
16
+ return iresnet200(False, **kwargs)
17
+ elif name == "r2060":
18
+ from .iresnet2060 import iresnet2060
19
+ return iresnet2060(False, **kwargs)
20
+ elif name == "mbf":
21
+ fp16 = kwargs.get("fp16", False)
22
+ num_features = kwargs.get("num_features", 512)
23
+ return get_mbf(fp16=fp16, num_features=num_features)
24
+ else:
25
+ raise ValueError()
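A hedged usage sketch of `get_model` above: build a backbone by name and run a dummy 112x112 batch through it. The import path is assumed from this repo layout; the 112x112 crop size follows the `fc_scale = 7 * 7` assumption in the backbones.

```python
# Hedged usage sketch; the import path is an assumption.
import torch
from src.face3d.models.arcface_torch.backbones import get_model

net = get_model("r50", fp16=False, num_features=512)
net.eval()
with torch.no_grad():
    emb = net(torch.randn(2, 3, 112, 112))   # ArcFace backbones expect 112x112 aligned crops
print(emb.shape)                             # torch.Size([2, 512])
```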
Demo_TFR_Pirenderer/src/face3d/models/arcface_torch/backbones/iresnet.py ADDED
@@ -0,0 +1,187 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ __all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200']
5
+
6
+
7
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
8
+ """3x3 convolution with padding"""
9
+ return nn.Conv2d(in_planes,
10
+ out_planes,
11
+ kernel_size=3,
12
+ stride=stride,
13
+ padding=dilation,
14
+ groups=groups,
15
+ bias=False,
16
+ dilation=dilation)
17
+
18
+
19
+ def conv1x1(in_planes, out_planes, stride=1):
20
+ """1x1 convolution"""
21
+ return nn.Conv2d(in_planes,
22
+ out_planes,
23
+ kernel_size=1,
24
+ stride=stride,
25
+ bias=False)
26
+
27
+
28
+ class IBasicBlock(nn.Module):
29
+ expansion = 1
30
+ def __init__(self, inplanes, planes, stride=1, downsample=None,
31
+ groups=1, base_width=64, dilation=1):
32
+ super(IBasicBlock, self).__init__()
33
+ if groups != 1 or base_width != 64:
34
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
35
+ if dilation > 1:
36
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
37
+ self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
38
+ self.conv1 = conv3x3(inplanes, planes)
39
+ self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
40
+ self.prelu = nn.PReLU(planes)
41
+ self.conv2 = conv3x3(planes, planes, stride)
42
+ self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
43
+ self.downsample = downsample
44
+ self.stride = stride
45
+
46
+ def forward(self, x):
47
+ identity = x
48
+ out = self.bn1(x)
49
+ out = self.conv1(out)
50
+ out = self.bn2(out)
51
+ out = self.prelu(out)
52
+ out = self.conv2(out)
53
+ out = self.bn3(out)
54
+ if self.downsample is not None:
55
+ identity = self.downsample(x)
56
+ out += identity
57
+ return out
58
+
59
+
60
+ class IResNet(nn.Module):
61
+ fc_scale = 7 * 7
62
+ def __init__(self,
63
+ block, layers, dropout=0, num_features=512, zero_init_residual=False,
64
+ groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
65
+ super(IResNet, self).__init__()
66
+ self.fp16 = fp16
67
+ self.inplanes = 64
68
+ self.dilation = 1
69
+ if replace_stride_with_dilation is None:
70
+ replace_stride_with_dilation = [False, False, False]
71
+ if len(replace_stride_with_dilation) != 3:
72
+ raise ValueError("replace_stride_with_dilation should be None "
73
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
74
+ self.groups = groups
75
+ self.base_width = width_per_group
76
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
77
+ self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
78
+ self.prelu = nn.PReLU(self.inplanes)
79
+ self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
80
+ self.layer2 = self._make_layer(block,
81
+ 128,
82
+ layers[1],
83
+ stride=2,
84
+ dilate=replace_stride_with_dilation[0])
85
+ self.layer3 = self._make_layer(block,
86
+ 256,
87
+ layers[2],
88
+ stride=2,
89
+ dilate=replace_stride_with_dilation[1])
90
+ self.layer4 = self._make_layer(block,
91
+ 512,
92
+ layers[3],
93
+ stride=2,
94
+ dilate=replace_stride_with_dilation[2])
95
+ self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
96
+ self.dropout = nn.Dropout(p=dropout, inplace=True)
97
+ self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
98
+ self.features = nn.BatchNorm1d(num_features, eps=1e-05)
99
+ nn.init.constant_(self.features.weight, 1.0)
100
+ self.features.weight.requires_grad = False
101
+
102
+ for m in self.modules():
103
+ if isinstance(m, nn.Conv2d):
104
+ nn.init.normal_(m.weight, 0, 0.1)
105
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
106
+ nn.init.constant_(m.weight, 1)
107
+ nn.init.constant_(m.bias, 0)
108
+
109
+ if zero_init_residual:
110
+ for m in self.modules():
111
+ if isinstance(m, IBasicBlock):
112
+ nn.init.constant_(m.bn2.weight, 0)
113
+
114
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
115
+ downsample = None
116
+ previous_dilation = self.dilation
117
+ if dilate:
118
+ self.dilation *= stride
119
+ stride = 1
120
+ if stride != 1 or self.inplanes != planes * block.expansion:
121
+ downsample = nn.Sequential(
122
+ conv1x1(self.inplanes, planes * block.expansion, stride),
123
+ nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
124
+ )
125
+ layers = []
126
+ layers.append(
127
+ block(self.inplanes, planes, stride, downsample, self.groups,
128
+ self.base_width, previous_dilation))
129
+ self.inplanes = planes * block.expansion
130
+ for _ in range(1, blocks):
131
+ layers.append(
132
+ block(self.inplanes,
133
+ planes,
134
+ groups=self.groups,
135
+ base_width=self.base_width,
136
+ dilation=self.dilation))
137
+
138
+ return nn.Sequential(*layers)
139
+
140
+ def forward(self, x):
141
+ with torch.cuda.amp.autocast(self.fp16):
142
+ x = self.conv1(x)
143
+ x = self.bn1(x)
144
+ x = self.prelu(x)
145
+ x = self.layer1(x)
146
+ x = self.layer2(x)
147
+ x = self.layer3(x)
148
+ x = self.layer4(x)
149
+ x = self.bn2(x)
150
+ x = torch.flatten(x, 1)
151
+ x = self.dropout(x)
152
+ x = self.fc(x.float() if self.fp16 else x)
153
+ x = self.features(x)
154
+ return x
155
+
156
+
157
+ def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
158
+ model = IResNet(block, layers, **kwargs)
159
+ if pretrained:
160
+ raise ValueError()
161
+ return model
162
+
163
+
164
+ def iresnet18(pretrained=False, progress=True, **kwargs):
165
+ return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
166
+ progress, **kwargs)
167
+
168
+
169
+ def iresnet34(pretrained=False, progress=True, **kwargs):
170
+ return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
171
+ progress, **kwargs)
172
+
173
+
174
+ def iresnet50(pretrained=False, progress=True, **kwargs):
175
+ return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
176
+ progress, **kwargs)
177
+
178
+
179
+ def iresnet100(pretrained=False, progress=True, **kwargs):
180
+ return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
181
+ progress, **kwargs)
182
+
183
+
184
+ def iresnet200(pretrained=False, progress=True, **kwargs):
185
+ return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained,
186
+ progress, **kwargs)
187
+
Demo_TFR_Pirenderer/src/face3d/models/arcface_torch/backbones/iresnet2060.py ADDED
@@ -0,0 +1,176 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ assert torch.__version__ >= "1.8.1"
5
+ from torch.utils.checkpoint import checkpoint_sequential
6
+
7
+ __all__ = ['iresnet2060']
8
+
9
+
10
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
11
+ """3x3 convolution with padding"""
12
+ return nn.Conv2d(in_planes,
13
+ out_planes,
14
+ kernel_size=3,
15
+ stride=stride,
16
+ padding=dilation,
17
+ groups=groups,
18
+ bias=False,
19
+ dilation=dilation)
20
+
21
+
22
+ def conv1x1(in_planes, out_planes, stride=1):
23
+ """1x1 convolution"""
24
+ return nn.Conv2d(in_planes,
25
+ out_planes,
26
+ kernel_size=1,
27
+ stride=stride,
28
+ bias=False)
29
+
30
+
31
+ class IBasicBlock(nn.Module):
32
+ expansion = 1
33
+
34
+ def __init__(self, inplanes, planes, stride=1, downsample=None,
35
+ groups=1, base_width=64, dilation=1):
36
+ super(IBasicBlock, self).__init__()
37
+ if groups != 1 or base_width != 64:
38
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
39
+ if dilation > 1:
40
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
41
+ self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05, )
42
+ self.conv1 = conv3x3(inplanes, planes)
43
+ self.bn2 = nn.BatchNorm2d(planes, eps=1e-05, )
44
+ self.prelu = nn.PReLU(planes)
45
+ self.conv2 = conv3x3(planes, planes, stride)
46
+ self.bn3 = nn.BatchNorm2d(planes, eps=1e-05, )
47
+ self.downsample = downsample
48
+ self.stride = stride
49
+
50
+ def forward(self, x):
51
+ identity = x
52
+ out = self.bn1(x)
53
+ out = self.conv1(out)
54
+ out = self.bn2(out)
55
+ out = self.prelu(out)
56
+ out = self.conv2(out)
57
+ out = self.bn3(out)
58
+ if self.downsample is not None:
59
+ identity = self.downsample(x)
60
+ out += identity
61
+ return out
62
+
63
+
64
+ class IResNet(nn.Module):
65
+ fc_scale = 7 * 7
66
+
67
+ def __init__(self,
68
+ block, layers, dropout=0, num_features=512, zero_init_residual=False,
69
+ groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
70
+ super(IResNet, self).__init__()
71
+ self.fp16 = fp16
72
+ self.inplanes = 64
73
+ self.dilation = 1
74
+ if replace_stride_with_dilation is None:
75
+ replace_stride_with_dilation = [False, False, False]
76
+ if len(replace_stride_with_dilation) != 3:
77
+ raise ValueError("replace_stride_with_dilation should be None "
78
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
79
+ self.groups = groups
80
+ self.base_width = width_per_group
81
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
82
+ self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
83
+ self.prelu = nn.PReLU(self.inplanes)
84
+ self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
85
+ self.layer2 = self._make_layer(block,
86
+ 128,
87
+ layers[1],
88
+ stride=2,
89
+ dilate=replace_stride_with_dilation[0])
90
+ self.layer3 = self._make_layer(block,
91
+ 256,
92
+ layers[2],
93
+ stride=2,
94
+ dilate=replace_stride_with_dilation[1])
95
+ self.layer4 = self._make_layer(block,
96
+ 512,
97
+ layers[3],
98
+ stride=2,
99
+ dilate=replace_stride_with_dilation[2])
100
+ self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05, )
101
+ self.dropout = nn.Dropout(p=dropout, inplace=True)
102
+ self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
103
+ self.features = nn.BatchNorm1d(num_features, eps=1e-05)
104
+ nn.init.constant_(self.features.weight, 1.0)
105
+ self.features.weight.requires_grad = False
106
+
107
+ for m in self.modules():
108
+ if isinstance(m, nn.Conv2d):
109
+ nn.init.normal_(m.weight, 0, 0.1)
110
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
111
+ nn.init.constant_(m.weight, 1)
112
+ nn.init.constant_(m.bias, 0)
113
+
114
+ if zero_init_residual:
115
+ for m in self.modules():
116
+ if isinstance(m, IBasicBlock):
117
+ nn.init.constant_(m.bn2.weight, 0)
118
+
119
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
120
+ downsample = None
121
+ previous_dilation = self.dilation
122
+ if dilate:
123
+ self.dilation *= stride
124
+ stride = 1
125
+ if stride != 1 or self.inplanes != planes * block.expansion:
126
+ downsample = nn.Sequential(
127
+ conv1x1(self.inplanes, planes * block.expansion, stride),
128
+ nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
129
+ )
130
+ layers = []
131
+ layers.append(
132
+ block(self.inplanes, planes, stride, downsample, self.groups,
133
+ self.base_width, previous_dilation))
134
+ self.inplanes = planes * block.expansion
135
+ for _ in range(1, blocks):
136
+ layers.append(
137
+ block(self.inplanes,
138
+ planes,
139
+ groups=self.groups,
140
+ base_width=self.base_width,
141
+ dilation=self.dilation))
142
+
143
+ return nn.Sequential(*layers)
144
+
145
+ def checkpoint(self, func, num_seg, x):
146
+ if self.training:
147
+ return checkpoint_sequential(func, num_seg, x)
148
+ else:
149
+ return func(x)
150
+
151
+ def forward(self, x):
152
+ with torch.cuda.amp.autocast(self.fp16):
153
+ x = self.conv1(x)
154
+ x = self.bn1(x)
155
+ x = self.prelu(x)
156
+ x = self.layer1(x)
157
+ x = self.checkpoint(self.layer2, 20, x)
158
+ x = self.checkpoint(self.layer3, 100, x)
159
+ x = self.layer4(x)
160
+ x = self.bn2(x)
161
+ x = torch.flatten(x, 1)
162
+ x = self.dropout(x)
163
+ x = self.fc(x.float() if self.fp16 else x)
164
+ x = self.features(x)
165
+ return x
166
+
167
+
168
+ def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
169
+ model = IResNet(block, layers, **kwargs)
170
+ if pretrained:
171
+ raise ValueError()
172
+ return model
173
+
174
+
175
+ def iresnet2060(pretrained=False, progress=True, **kwargs):
176
+ return _iresnet('iresnet2060', IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs)
Demo_TFR_Pirenderer/src/face3d/models/arcface_torch/backbones/mobilefacenet.py ADDED
@@ -0,0 +1,130 @@
1
+ '''
2
+ Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py
3
+ Original author cavalleria
4
+ '''
5
+
6
+ import torch.nn as nn
7
+ from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module
8
+ import torch
9
+
10
+
11
+ class Flatten(Module):
12
+ def forward(self, x):
13
+ return x.view(x.size(0), -1)
14
+
15
+
16
+ class ConvBlock(Module):
17
+ def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
18
+ super(ConvBlock, self).__init__()
19
+ self.layers = nn.Sequential(
20
+ Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False),
21
+ BatchNorm2d(num_features=out_c),
22
+ PReLU(num_parameters=out_c)
23
+ )
24
+
25
+ def forward(self, x):
26
+ return self.layers(x)
27
+
28
+
29
+ class LinearBlock(Module):
30
+ def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
31
+ super(LinearBlock, self).__init__()
32
+ self.layers = nn.Sequential(
33
+ Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False),
34
+ BatchNorm2d(num_features=out_c)
35
+ )
36
+
37
+ def forward(self, x):
38
+ return self.layers(x)
39
+
40
+
41
+ class DepthWise(Module):
42
+ def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
43
+ super(DepthWise, self).__init__()
44
+ self.residual = residual
45
+ self.layers = nn.Sequential(
46
+ ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)),
47
+ ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride),
48
+ LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
49
+ )
50
+
51
+ def forward(self, x):
52
+ short_cut = None
53
+ if self.residual:
54
+ short_cut = x
55
+ x = self.layers(x)
56
+ if self.residual:
57
+ output = short_cut + x
58
+ else:
59
+ output = x
60
+ return output
61
+
62
+
63
+ class Residual(Module):
64
+ def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
65
+ super(Residual, self).__init__()
66
+ modules = []
67
+ for _ in range(num_block):
68
+ modules.append(DepthWise(c, c, True, kernel, stride, padding, groups))
69
+ self.layers = Sequential(*modules)
70
+
71
+ def forward(self, x):
72
+ return self.layers(x)
73
+
74
+
75
+ class GDC(Module):
76
+ def __init__(self, embedding_size):
77
+ super(GDC, self).__init__()
78
+ self.layers = nn.Sequential(
79
+ LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)),
80
+ Flatten(),
81
+ Linear(512, embedding_size, bias=False),
82
+ BatchNorm1d(embedding_size))
83
+
84
+ def forward(self, x):
85
+ return self.layers(x)
86
+
87
+
88
+ class MobileFaceNet(Module):
89
+ def __init__(self, fp16=False, num_features=512):
90
+ super(MobileFaceNet, self).__init__()
91
+ scale = 2
92
+ self.fp16 = fp16
93
+ self.layers = nn.Sequential(
94
+ ConvBlock(3, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)),
95
+ ConvBlock(64 * scale, 64 * scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64),
96
+ DepthWise(64 * scale, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128),
97
+ Residual(64 * scale, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
98
+ DepthWise(64 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256),
99
+ Residual(128 * scale, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
100
+ DepthWise(128 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512),
101
+ Residual(128 * scale, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
102
+ )
103
+ self.conv_sep = ConvBlock(128 * scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
104
+ self.features = GDC(num_features)
105
+ self._initialize_weights()
106
+
107
+ def _initialize_weights(self):
108
+ for m in self.modules():
109
+ if isinstance(m, nn.Conv2d):
110
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
111
+ if m.bias is not None:
112
+ m.bias.data.zero_()
113
+ elif isinstance(m, nn.BatchNorm2d):
114
+ m.weight.data.fill_(1)
115
+ m.bias.data.zero_()
116
+ elif isinstance(m, nn.Linear):
117
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
118
+ if m.bias is not None:
119
+ m.bias.data.zero_()
120
+
121
+ def forward(self, x):
122
+ with torch.cuda.amp.autocast(self.fp16):
123
+ x = self.layers(x)
124
+ x = self.conv_sep(x.float() if self.fp16 else x)
125
+ x = self.features(x)
126
+ return x
127
+
128
+
129
+ def get_mbf(fp16, num_features):
130
+ return MobileFaceNet(fp16, num_features)
Demo_TFR_Pirenderer/src/face3d/models/arcface_torch/configs/3millions.py ADDED
@@ -0,0 +1,23 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # configs for test speed
4
+
5
+ config = edict()
6
+ config.loss = "arcface"
7
+ config.network = "r50"
8
+ config.resume = False
9
+ config.output = None
10
+ config.embedding_size = 512
11
+ config.sample_rate = 1.0
12
+ config.fp16 = True
13
+ config.momentum = 0.9
14
+ config.weight_decay = 5e-4
15
+ config.batch_size = 128
16
+ config.lr = 0.1 # batch size is 512
17
+
18
+ config.rec = "synthetic"
19
+ config.num_classes = 300 * 10000
20
+ config.num_epoch = 30
21
+ config.warmup_epoch = -1
22
+ config.decay_epoch = [10, 16, 22]
23
+ config.val_targets = []
Demo_TFR_Pirenderer/src/face3d/models/arcface_torch/configs/3millions_pfc.py ADDED
@@ -0,0 +1,23 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # configs for test speed
4
+
5
+ config = edict()
6
+ config.loss = "arcface"
7
+ config.network = "r50"
8
+ config.resume = False
9
+ config.output = None
10
+ config.embedding_size = 512
11
+ config.sample_rate = 0.1
12
+ config.fp16 = True
13
+ config.momentum = 0.9
14
+ config.weight_decay = 5e-4
15
+ config.batch_size = 128
16
+ config.lr = 0.1 # batch size is 512
17
+
18
+ config.rec = "synthetic"
19
+ config.num_classes = 300 * 10000
20
+ config.num_epoch = 30
21
+ config.warmup_epoch = -1
22
+ config.decay_epoch = [10, 16, 22]
23
+ config.val_targets = []
Demo_TFR_Pirenderer/src/face3d/models/arcface_torch/configs/__init__.py ADDED
File without changes
Demo_TFR_Pirenderer/src/face3d/models/arcface_torch/configs/base.py ADDED
@@ -0,0 +1,56 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # make training faster
4
+ # our RAM is 256G
5
+ # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
+
7
+ config = edict()
8
+ config.loss = "arcface"
9
+ config.network = "r50"
10
+ config.resume = False
11
+ config.output = "ms1mv3_arcface_r50"
12
+
13
+ config.dataset = "ms1m-retinaface-t1"
14
+ config.embedding_size = 512
15
+ config.sample_rate = 1
16
+ config.fp16 = False
17
+ config.momentum = 0.9
18
+ config.weight_decay = 5e-4
19
+ config.batch_size = 128
20
+ config.lr = 0.1 # batch size is 512
21
+
22
+ if config.dataset == "emore":
23
+ config.rec = "/train_tmp/faces_emore"
24
+ config.num_classes = 85742
25
+ config.num_image = 5822653
26
+ config.num_epoch = 16
27
+ config.warmup_epoch = -1
28
+ config.decay_epoch = [8, 14, ]
29
+ config.val_targets = ["lfw", ]
30
+
31
+ elif config.dataset == "ms1m-retinaface-t1":
32
+ config.rec = "/train_tmp/ms1m-retinaface-t1"
33
+ config.num_classes = 93431
34
+ config.num_image = 5179510
35
+ config.num_epoch = 25
36
+ config.warmup_epoch = -1
37
+ config.decay_epoch = [11, 17, 22]
38
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
39
+
40
+ elif config.dataset == "glint360k":
41
+ config.rec = "/train_tmp/glint360k"
42
+ config.num_classes = 360232
43
+ config.num_image = 17091657
44
+ config.num_epoch = 20
45
+ config.warmup_epoch = -1
46
+ config.decay_epoch = [8, 12, 15, 18]
47
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
48
+
49
+ elif config.dataset == "webface":
50
+ config.rec = "/train_tmp/faces_webface_112x112"
51
+ config.num_classes = 10572
52
+ config.num_image = "forget"
53
+ config.num_epoch = 34
54
+ config.warmup_epoch = -1
55
+ config.decay_epoch = [20, 28, 32]
56
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
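The configs in this folder are plain modules exposing an `edict` named `config`, with `base.py` switching its dataset-specific fields on `config.dataset`. A hedged sketch of how one of them could be loaded by dotted name is shown below; the actual resolution lives in `train.py`, which is not part of this diff and may differ.

```python
# Hedged sketch; the real loading logic is in train.py and may differ.
import importlib

def load_config(name):                          # e.g. "configs.glint360k_r50"
    return importlib.import_module(name).config

cfg = load_config("configs.glint360k_r50")
print(cfg.network, cfg.loss, cfg.batch_size, cfg.num_classes)
# r50 cosface 128 360232
```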
Demo_TFR_Pirenderer/src/face3d/models/arcface_torch/configs/glint360k_mbf.py ADDED
@@ -0,0 +1,26 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # make training faster
4
+ # our RAM is 256G
5
+ # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
+
7
+ config = edict()
8
+ config.loss = "cosface"
9
+ config.network = "mbf"
10
+ config.resume = False
11
+ config.output = None
12
+ config.embedding_size = 512
13
+ config.sample_rate = 0.1
14
+ config.fp16 = True
15
+ config.momentum = 0.9
16
+ config.weight_decay = 2e-4
17
+ config.batch_size = 128
18
+ config.lr = 0.1 # batch size is 512
19
+
20
+ config.rec = "/train_tmp/glint360k"
21
+ config.num_classes = 360232
22
+ config.num_image = 17091657
23
+ config.num_epoch = 20
24
+ config.warmup_epoch = -1
25
+ config.decay_epoch = [8, 12, 15, 18]
26
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
Demo_TFR_Pirenderer/src/face3d/models/arcface_torch/configs/glint360k_r100.py ADDED
@@ -0,0 +1,26 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # make training faster
4
+ # our RAM is 256G
5
+ # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
+
7
+ config = edict()
8
+ config.loss = "cosface"
9
+ config.network = "r100"
10
+ config.resume = False
11
+ config.output = None
12
+ config.embedding_size = 512
13
+ config.sample_rate = 1.0
14
+ config.fp16 = True
15
+ config.momentum = 0.9
16
+ config.weight_decay = 5e-4
17
+ config.batch_size = 128
18
+ config.lr = 0.1 # batch size is 512
19
+
20
+ config.rec = "/train_tmp/glint360k"
21
+ config.num_classes = 360232
22
+ config.num_image = 17091657
23
+ config.num_epoch = 20
24
+ config.warmup_epoch = -1
25
+ config.decay_epoch = [8, 12, 15, 18]
26
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
Demo_TFR_Pirenderer/src/face3d/models/arcface_torch/configs/glint360k_r18.py ADDED
@@ -0,0 +1,26 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # make training faster
4
+ # our RAM is 256G
5
+ # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
+
7
+ config = edict()
8
+ config.loss = "cosface"
9
+ config.network = "r18"
10
+ config.resume = False
11
+ config.output = None
12
+ config.embedding_size = 512
13
+ config.sample_rate = 1.0
14
+ config.fp16 = True
15
+ config.momentum = 0.9
16
+ config.weight_decay = 5e-4
17
+ config.batch_size = 128
18
+ config.lr = 0.1 # batch size is 512
19
+
20
+ config.rec = "/train_tmp/glint360k"
21
+ config.num_classes = 360232
22
+ config.num_image = 17091657
23
+ config.num_epoch = 20
24
+ config.warmup_epoch = -1
25
+ config.decay_epoch = [8, 12, 15, 18]
26
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
Demo_TFR_Pirenderer/src/face3d/models/arcface_torch/configs/glint360k_r34.py ADDED
@@ -0,0 +1,26 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # make training faster
4
+ # our RAM is 256G
5
+ # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
+
7
+ config = edict()
8
+ config.loss = "cosface"
9
+ config.network = "r34"
10
+ config.resume = False
11
+ config.output = None
12
+ config.embedding_size = 512
13
+ config.sample_rate = 1.0
14
+ config.fp16 = True
15
+ config.momentum = 0.9
16
+ config.weight_decay = 5e-4
17
+ config.batch_size = 128
18
+ config.lr = 0.1 # batch size is 512
19
+
20
+ config.rec = "/train_tmp/glint360k"
21
+ config.num_classes = 360232
22
+ config.num_image = 17091657
23
+ config.num_epoch = 20
24
+ config.warmup_epoch = -1
25
+ config.decay_epoch = [8, 12, 15, 18]
26
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
Demo_TFR_Pirenderer/src/face3d/models/arcface_torch/configs/glint360k_r50.py ADDED
@@ -0,0 +1,26 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # make training faster
4
+ # our RAM is 256G
5
+ # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
+
7
+ config = edict()
8
+ config.loss = "cosface"
9
+ config.network = "r50"
10
+ config.resume = False
11
+ config.output = None
12
+ config.embedding_size = 512
13
+ config.sample_rate = 1.0
14
+ config.fp16 = True
15
+ config.momentum = 0.9
16
+ config.weight_decay = 5e-4
17
+ config.batch_size = 128
18
+ config.lr = 0.1 # batch size is 512
19
+
20
+ config.rec = "/train_tmp/glint360k"
21
+ config.num_classes = 360232
22
+ config.num_image = 17091657
23
+ config.num_epoch = 20
24
+ config.warmup_epoch = -1
25
+ config.decay_epoch = [8, 12, 15, 18]
26
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
Demo_TFR_Pirenderer/src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py ADDED
@@ -0,0 +1,26 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # make training faster
4
+ # our RAM is 256G
5
+ # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
+
7
+ config = edict()
8
+ config.loss = "arcface"
9
+ config.network = "mbf"
10
+ config.resume = False
11
+ config.output = None
12
+ config.embedding_size = 512
13
+ config.sample_rate = 1.0
14
+ config.fp16 = True
15
+ config.momentum = 0.9
16
+ config.weight_decay = 2e-4
17
+ config.batch_size = 128
18
+ config.lr = 0.1 # batch size is 512
19
+
20
+ config.rec = "/train_tmp/ms1m-retinaface-t1"
21
+ config.num_classes = 93431
22
+ config.num_image = 5179510
23
+ config.num_epoch = 30
24
+ config.warmup_epoch = -1
25
+ config.decay_epoch = [10, 20, 25]
26
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
Demo_TFR_Pirenderer/src/face3d/models/arcface_torch/configs/ms1mv3_r18.py ADDED
@@ -0,0 +1,26 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # make training faster
4
+ # our RAM is 256G
5
+ # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
+
7
+ config = edict()
8
+ config.loss = "arcface"
9
+ config.network = "r18"
10
+ config.resume = False
11
+ config.output = None
12
+ config.embedding_size = 512
13
+ config.sample_rate = 1.0
14
+ config.fp16 = True
15
+ config.momentum = 0.9
16
+ config.weight_decay = 5e-4
17
+ config.batch_size = 128
18
+ config.lr = 0.1 # batch size is 512
19
+
20
+ config.rec = "/train_tmp/ms1m-retinaface-t1"
21
+ config.num_classes = 93431
22
+ config.num_image = 5179510
23
+ config.num_epoch = 25
24
+ config.warmup_epoch = -1
25
+ config.decay_epoch = [10, 16, 22]
26
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
Demo_TFR_Pirenderer/src/face3d/models/arcface_torch/configs/ms1mv3_r2060.py ADDED
@@ -0,0 +1,26 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # make training faster
4
+ # our RAM is 256G
5
+ # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
+
7
+ config = edict()
8
+ config.loss = "arcface"
9
+ config.network = "r2060"
10
+ config.resume = False
11
+ config.output = None
12
+ config.embedding_size = 512
13
+ config.sample_rate = 1.0
14
+ config.fp16 = True
15
+ config.momentum = 0.9
16
+ config.weight_decay = 5e-4
17
+ config.batch_size = 64
18
+ config.lr = 0.1 # batch size is 512
19
+
20
+ config.rec = "/train_tmp/ms1m-retinaface-t1"
21
+ config.num_classes = 93431
22
+ config.num_image = 5179510
23
+ config.num_epoch = 25
24
+ config.warmup_epoch = -1
25
+ config.decay_epoch = [10, 16, 22]
26
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
Demo_TFR_Pirenderer/src/face3d/models/arcface_torch/configs/ms1mv3_r34.py ADDED
@@ -0,0 +1,26 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # make training faster
4
+ # our RAM is 256G
5
+ # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
+
7
+ config = edict()
8
+ config.loss = "arcface"
9
+ config.network = "r34"
10
+ config.resume = False
11
+ config.output = None
12
+ config.embedding_size = 512
13
+ config.sample_rate = 1.0
14
+ config.fp16 = True
15
+ config.momentum = 0.9
16
+ config.weight_decay = 5e-4
17
+ config.batch_size = 128
18
+ config.lr = 0.1 # batch size is 512
19
+
20
+ config.rec = "/train_tmp/ms1m-retinaface-t1"
21
+ config.num_classes = 93431
22
+ config.num_image = 5179510
23
+ config.num_epoch = 25
24
+ config.warmup_epoch = -1
25
+ config.decay_epoch = [10, 16, 22]
26
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]