diff --git a/videoretalking/third_part/face3d/checkpoints/model_name/test_opt.txt b/videoretalking/third_part/face3d/checkpoints/model_name/test_opt.txt new file mode 100644 index 0000000000000000000000000000000000000000..08f477c61d0a7e5506d21dad3ea62e049f545400 --- /dev/null +++ b/videoretalking/third_part/face3d/checkpoints/model_name/test_opt.txt @@ -0,0 +1,34 @@ +----------------- Options --------------- + add_image: True + bfm_folder: BFM + bfm_model: BFM_model_front.mat + camera_d: 10.0 + center: 112.0 + checkpoints_dir: ./checkpoints + dataset_mode: None + ddp_port: 12355 + display_per_batch: True + epoch: 20 [default: latest] + eval_batch_nums: inf + focal: 1015.0 + gpu_ids: 0 + inference_batch_size: 8 + init_path: checkpoints/init_model/resnet50-0676ba61.pth + input_dir: demo_video [default: None] + isTrain: False [default: None] + keypoint_dir: demo_cctv [default: None] + model: facerecon + name: model_name [default: face_recon] + net_recon: resnet50 + output_dir: demo_cctv [default: mp4] + phase: test + save_split_files: False + suffix: + use_ddp: False [default: True] + use_last_fc: False + verbose: False + vis_batch_nums: 1 + world_size: 1 + z_far: 15.0 + z_near: 5.0 +----------------- End ------------------- diff --git a/videoretalking/third_part/face3d/coeff_detector.py b/videoretalking/third_part/face3d/coeff_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..e4617cb5be3a532509ea1116aa8d0a127c2d9b03 --- /dev/null +++ b/videoretalking/third_part/face3d/coeff_detector.py @@ -0,0 +1,118 @@ +import os +import glob +import numpy as np +from os import makedirs, name +from PIL import Image +from tqdm import tqdm + +import torch +import torch.nn as nn + +from face3d.options.inference_options import InferenceOptions +from face3d.models import create_model +from face3d.util.preprocess import align_img +from face3d.util.load_mats import load_lm3d +from face3d.extract_kp_videos import KeypointExtractor + + +class CoeffDetector(nn.Module): + def __init__(self, opt): + super().__init__() + + self.model = create_model(opt) + self.model.setup(opt) + self.model.device = 'cuda' + self.model.parallelize() + self.model.eval() + + self.lm3d_std = load_lm3d(opt.bfm_folder) + + def forward(self, img, lm): + + img, trans_params = self.image_transform(img, lm) + + data_input = { + 'imgs': img[None], + } + self.model.set_input(data_input) + self.model.test() + pred_coeff = {key:self.model.pred_coeffs_dict[key].cpu().numpy() for key in self.model.pred_coeffs_dict} + pred_coeff = np.concatenate([ + pred_coeff['id'], + pred_coeff['exp'], + pred_coeff['tex'], + pred_coeff['angle'], + pred_coeff['gamma'], + pred_coeff['trans'], + trans_params[None], + ], 1) + + return {'coeff_3dmm':pred_coeff, + 'crop_img': Image.fromarray((img.cpu().permute(1, 2, 0).numpy()*255).astype(np.uint8))} + + def image_transform(self, images, lm): + """ + param: + images: -- PIL image + lm: -- numpy array + """ + W,H = images.size + if np.mean(lm) == -1: + lm = (self.lm3d_std[:, :2]+1)/2. + lm = np.concatenate( + [lm[:, :1]*W, lm[:, 1:2]*H], 1 + ) + else: + lm[:, -1] = H - 1 - lm[:, -1] + + trans_params, img, lm, _ = align_img(images, lm, self.lm3d_std) + img = torch.tensor(np.array(img)/255., dtype=torch.float32).permute(2, 0, 1) + trans_params = np.array([float(item) for item in np.hsplit(trans_params, 5)]) + trans_params = torch.tensor(trans_params.astype(np.float32)) + return img, trans_params + +def get_data_path(root, keypoint_root): + filenames = list() + keypoint_filenames = list() + + IMAGE_EXTENSIONS_LOWERCASE = {'jpg', 'png', 'jpeg', 'webp'} + IMAGE_EXTENSIONS = IMAGE_EXTENSIONS_LOWERCASE.union({f.upper() for f in IMAGE_EXTENSIONS_LOWERCASE}) + extensions = IMAGE_EXTENSIONS + + for ext in extensions: + filenames += glob.glob(f'{root}/*.{ext}', recursive=True) + filenames = sorted(filenames) + for filename in filenames: + name = os.path.splitext(os.path.basename(filename))[0] + keypoint_filenames.append( + os.path.join(keypoint_root, name + '.txt') + ) + return filenames, keypoint_filenames + + +if __name__ == "__main__": + opt = InferenceOptions().parse() + coeff_detector = CoeffDetector(opt) + kp_extractor = KeypointExtractor() + image_names, keypoint_names = get_data_path(opt.input_dir, opt.keypoint_dir) + makedirs(opt.keypoint_dir, exist_ok=True) + makedirs(opt.output_dir, exist_ok=True) + + for image_name, keypoint_name in tqdm(zip(image_names, keypoint_names)): + image = Image.open(image_name) + if not os.path.isfile(keypoint_name): + lm = kp_extractor.extract_keypoint(image, keypoint_name) + else: + lm = np.loadtxt(keypoint_name).astype(np.float32) + lm = lm.reshape([-1, 2]) + predicted = coeff_detector(image, lm) + name = os.path.splitext(os.path.basename(image_name))[0] + np.savetxt( + "{}/{}_3dmm_coeff.txt".format(opt.output_dir, name), + predicted['coeff_3dmm'].reshape(-1)) + + + + + + \ No newline at end of file diff --git a/videoretalking/third_part/face3d/data/__init__.py b/videoretalking/third_part/face3d/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..be2378c5877af8e749db18d8a67a382f3eb0912b --- /dev/null +++ b/videoretalking/third_part/face3d/data/__init__.py @@ -0,0 +1,116 @@ +"""This package includes all the modules related to data loading and preprocessing + + To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. + You need to implement four functions: + -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). + -- <__len__>: return the size of dataset. + -- <__getitem__>: get a data point from data loader. + -- : (optionally) add dataset-specific options and set default options. + +Now you can use the dataset class by specifying flag '--dataset_mode dummy'. +See our template dataset class 'template_dataset.py' for more details. +""" +import numpy as np +import importlib +import torch.utils.data +from face3d.data.base_dataset import BaseDataset + + +def find_dataset_using_name(dataset_name): + """Import the module "data/[dataset_name]_dataset.py". + + In the file, the class called DatasetNameDataset() will + be instantiated. It has to be a subclass of BaseDataset, + and it is case-insensitive. + """ + dataset_filename = "data." + dataset_name + "_dataset" + datasetlib = importlib.import_module(dataset_filename) + + dataset = None + target_dataset_name = dataset_name.replace('_', '') + 'dataset' + for name, cls in datasetlib.__dict__.items(): + if name.lower() == target_dataset_name.lower() \ + and issubclass(cls, BaseDataset): + dataset = cls + + if dataset is None: + raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) + + return dataset + + +def get_option_setter(dataset_name): + """Return the static method of the dataset class.""" + dataset_class = find_dataset_using_name(dataset_name) + return dataset_class.modify_commandline_options + + +def create_dataset(opt, rank=0): + """Create a dataset given the option. + + This function wraps the class CustomDatasetDataLoader. + This is the main interface between this package and 'train.py'/'test.py' + + Example: + >>> from data import create_dataset + >>> dataset = create_dataset(opt) + """ + data_loader = CustomDatasetDataLoader(opt, rank=rank) + dataset = data_loader.load_data() + return dataset + +class CustomDatasetDataLoader(): + """Wrapper class of Dataset class that performs multi-threaded data loading""" + + def __init__(self, opt, rank=0): + """Initialize this class + + Step 1: create a dataset instance given the name [dataset_mode] + Step 2: create a multi-threaded data loader. + """ + self.opt = opt + dataset_class = find_dataset_using_name(opt.dataset_mode) + self.dataset = dataset_class(opt) + self.sampler = None + print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__)) + if opt.use_ddp and opt.isTrain: + world_size = opt.world_size + self.sampler = torch.utils.data.distributed.DistributedSampler( + self.dataset, + num_replicas=world_size, + rank=rank, + shuffle=not opt.serial_batches + ) + self.dataloader = torch.utils.data.DataLoader( + self.dataset, + sampler=self.sampler, + num_workers=int(opt.num_threads / world_size), + batch_size=int(opt.batch_size / world_size), + drop_last=True) + else: + self.dataloader = torch.utils.data.DataLoader( + self.dataset, + batch_size=opt.batch_size, + shuffle=(not opt.serial_batches) and opt.isTrain, + num_workers=int(opt.num_threads), + drop_last=True + ) + + def set_epoch(self, epoch): + self.dataset.current_epoch = epoch + if self.sampler is not None: + self.sampler.set_epoch(epoch) + + def load_data(self): + return self + + def __len__(self): + """Return the number of data in the dataset""" + return min(len(self.dataset), self.opt.max_dataset_size) + + def __iter__(self): + """Return a batch of data""" + for i, data in enumerate(self.dataloader): + if i * self.opt.batch_size >= self.opt.max_dataset_size: + break + yield data diff --git a/videoretalking/third_part/face3d/data/base_dataset.py b/videoretalking/third_part/face3d/data/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..34a7ea5024206e6e58c2f404ac6a1bf0987f5fd4 --- /dev/null +++ b/videoretalking/third_part/face3d/data/base_dataset.py @@ -0,0 +1,125 @@ +"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets. + +It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. +""" +import random +import numpy as np +import torch.utils.data as data +from PIL import Image +import torchvision.transforms as transforms +from abc import ABC, abstractmethod + + +class BaseDataset(data.Dataset, ABC): + """This class is an abstract base class (ABC) for datasets. + + To create a subclass, you need to implement the following four functions: + -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). + -- <__len__>: return the size of dataset. + -- <__getitem__>: get a data point. + -- : (optionally) add dataset-specific options and set default options. + """ + + def __init__(self, opt): + """Initialize the class; save the options in the class + + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + self.opt = opt + # self.root = opt.dataroot + self.current_epoch = 0 + + @staticmethod + def modify_commandline_options(parser, is_train): + """Add new dataset-specific options, and rewrite default values for existing options. + + Parameters: + parser -- original option parser + is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. + """ + return parser + + @abstractmethod + def __len__(self): + """Return the total number of images in the dataset.""" + return 0 + + @abstractmethod + def __getitem__(self, index): + """Return a data point and its metadata information. + + Parameters: + index - - a random integer for data indexing + + Returns: + a dictionary of data with their names. It ususally contains the data itself and its metadata information. + """ + pass + + +def get_transform(grayscale=False): + transform_list = [] + if grayscale: + transform_list.append(transforms.Grayscale(1)) + transform_list += [transforms.ToTensor()] + return transforms.Compose(transform_list) + +def get_affine_mat(opt, size): + shift_x, shift_y, scale, rot_angle, flip = 0., 0., 1., 0., False + w, h = size + + if 'shift' in opt.preprocess: + shift_pixs = int(opt.shift_pixs) + shift_x = random.randint(-shift_pixs, shift_pixs) + shift_y = random.randint(-shift_pixs, shift_pixs) + if 'scale' in opt.preprocess: + scale = 1 + opt.scale_delta * (2 * random.random() - 1) + if 'rot' in opt.preprocess: + rot_angle = opt.rot_angle * (2 * random.random() - 1) + rot_rad = -rot_angle * np.pi/180 + if 'flip' in opt.preprocess: + flip = random.random() > 0.5 + + shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3]) + flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3]) + shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3]) + rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3]) + scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3]) + shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3]) + + affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin + affine_inv = np.linalg.inv(affine) + return affine, affine_inv, flip + +def apply_img_affine(img, affine_inv, method=Image.BICUBIC): + return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=Image.BICUBIC) + +def apply_lm_affine(landmark, affine, flip, size): + _, h = size + lm = landmark.copy() + lm[:, 1] = h - 1 - lm[:, 1] + lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1) + lm = lm @ np.transpose(affine) + lm[:, :2] = lm[:, :2] / lm[:, 2:] + lm = lm[:, :2] + lm[:, 1] = h - 1 - lm[:, 1] + if flip: + lm_ = lm.copy() + lm_[:17] = lm[16::-1] + lm_[17:22] = lm[26:21:-1] + lm_[22:27] = lm[21:16:-1] + lm_[31:36] = lm[35:30:-1] + lm_[36:40] = lm[45:41:-1] + lm_[40:42] = lm[47:45:-1] + lm_[42:46] = lm[39:35:-1] + lm_[46:48] = lm[41:39:-1] + lm_[48:55] = lm[54:47:-1] + lm_[55:60] = lm[59:54:-1] + lm_[60:65] = lm[64:59:-1] + lm_[65:68] = lm[67:64:-1] + lm = lm_ + return lm diff --git a/videoretalking/third_part/face3d/data/flist_dataset.py b/videoretalking/third_part/face3d/data/flist_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..63b49caa8020f8e9aedb73a839b7112320cad68a --- /dev/null +++ b/videoretalking/third_part/face3d/data/flist_dataset.py @@ -0,0 +1,125 @@ +"""This script defines the custom dataset for Deep3DFaceRecon_pytorch +""" + +import os.path +from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine +from data.image_folder import make_dataset +from PIL import Image +import random +import util.util as util +import numpy as np +import json +import torch +from scipy.io import loadmat, savemat +import pickle +from util.preprocess import align_img, estimate_norm +from util.load_mats import load_lm3d + + +def default_flist_reader(flist): + """ + flist format: impath label\nimpath label\n ...(same to caffe's filelist) + """ + imlist = [] + with open(flist, 'r') as rf: + for line in rf.readlines(): + impath = line.strip() + imlist.append(impath) + + return imlist + +def jason_flist_reader(flist): + with open(flist, 'r') as fp: + info = json.load(fp) + return info + +def parse_label(label): + return torch.tensor(np.array(label).astype(np.float32)) + + +class FlistDataset(BaseDataset): + """ + It requires one directories to host training images '/path/to/data/train' + You can train the model with the dataset flag '--dataroot /path/to/data'. + """ + + def __init__(self, opt): + """Initialize this dataset class. + + Parameters: + opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + BaseDataset.__init__(self, opt) + + self.lm3d_std = load_lm3d(opt.bfm_folder) + + msk_names = default_flist_reader(opt.flist) + self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names] + + self.size = len(self.msk_paths) + self.opt = opt + + self.name = 'train' if opt.isTrain else 'val' + if '_' in opt.flist: + self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0] + + + def __getitem__(self, index): + """Return a data point and its metadata information. + + Parameters: + index (int) -- a random integer for data indexing + + Returns a dictionary that contains A, B, A_paths and B_paths + img (tensor) -- an image in the input domain + msk (tensor) -- its corresponding attention mask + lm (tensor) -- its corresponding 3d landmarks + im_paths (str) -- image paths + aug_flag (bool) -- a flag used to tell whether its raw or augmented + """ + msk_path = self.msk_paths[index % self.size] # make sure index is within then range + img_path = msk_path.replace('mask/', '') + lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt' + + raw_img = Image.open(img_path).convert('RGB') + raw_msk = Image.open(msk_path).convert('RGB') + raw_lm = np.loadtxt(lm_path).astype(np.float32) + + _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk) + + aug_flag = self.opt.use_aug and self.opt.isTrain + if aug_flag: + img, lm, msk = self._augmentation(img, lm, self.opt, msk) + + _, H = img.size + M = estimate_norm(lm, H) + transform = get_transform() + img_tensor = transform(img) + msk_tensor = transform(msk)[:1, ...] + lm_tensor = parse_label(lm) + M_tensor = parse_label(M) + + + return {'imgs': img_tensor, + 'lms': lm_tensor, + 'msks': msk_tensor, + 'M': M_tensor, + 'im_paths': img_path, + 'aug_flag': aug_flag, + 'dataset': self.name} + + def _augmentation(self, img, lm, opt, msk=None): + affine, affine_inv, flip = get_affine_mat(opt, img.size) + img = apply_img_affine(img, affine_inv) + lm = apply_lm_affine(lm, affine, flip, img.size) + if msk is not None: + msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR) + return img, lm, msk + + + + + def __len__(self): + """Return the total number of images in the dataset. + """ + return self.size diff --git a/videoretalking/third_part/face3d/data/image_folder.py b/videoretalking/third_part/face3d/data/image_folder.py new file mode 100644 index 0000000000000000000000000000000000000000..07ef069029b0db1fc40b9b5f9a6f52a48c1cd162 --- /dev/null +++ b/videoretalking/third_part/face3d/data/image_folder.py @@ -0,0 +1,66 @@ +"""A modified image folder class + +We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) +so that this class can load images from both current directory and its subdirectories. +""" +import numpy as np +import torch.utils.data as data + +from PIL import Image +import os +import os.path + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', + '.tif', '.TIF', '.tiff', '.TIFF', +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def make_dataset(dir, max_dataset_size=float("inf")): + images = [] + assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir + + for root, _, fnames in sorted(os.walk(dir, followlinks=True)): + for fname in fnames: + if is_image_file(fname): + path = os.path.join(root, fname) + images.append(path) + return images[:min(max_dataset_size, len(images))] + + +def default_loader(path): + return Image.open(path).convert('RGB') + + +class ImageFolder(data.Dataset): + + def __init__(self, root, transform=None, return_paths=False, + loader=default_loader): + imgs = make_dataset(root) + if len(imgs) == 0: + raise(RuntimeError("Found 0 images in: " + root + "\n" + "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) + + self.root = root + self.imgs = imgs + self.transform = transform + self.return_paths = return_paths + self.loader = loader + + def __getitem__(self, index): + path = self.imgs[index] + img = self.loader(path) + if self.transform is not None: + img = self.transform(img) + if self.return_paths: + return img, path + else: + return img + + def __len__(self): + return len(self.imgs) diff --git a/videoretalking/third_part/face3d/data/template_dataset.py b/videoretalking/third_part/face3d/data/template_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..693b6b09085ad424e53f26e0938b61eea30ed644 --- /dev/null +++ b/videoretalking/third_part/face3d/data/template_dataset.py @@ -0,0 +1,75 @@ +"""Dataset class template + +This module provides a template for users to implement custom datasets. +You can specify '--dataset_mode template' to use this dataset. +The class name should be consistent with both the filename and its dataset_mode option. +The filename should be _dataset.py +The class name should be Dataset.py +You need to implement the following functions: + -- : Add dataset-specific options and rewrite default values for existing options. + -- <__init__>: Initialize this dataset class. + -- <__getitem__>: Return a data point and its metadata information. + -- <__len__>: Return the number of images. +""" +from data.base_dataset import BaseDataset, get_transform +# from data.image_folder import make_dataset +# from PIL import Image + + +class TemplateDataset(BaseDataset): + """A template dataset class for you to implement custom datasets.""" + @staticmethod + def modify_commandline_options(parser, is_train): + """Add new dataset-specific options, and rewrite default values for existing options. + + Parameters: + parser -- original option parser + is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. + """ + parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option') + parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values + return parser + + def __init__(self, opt): + """Initialize this dataset class. + + Parameters: + opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions + + A few things can be done here. + - save the options (have been done in BaseDataset) + - get image paths and meta information of the dataset. + - define the image transformation. + """ + # save the option and dataset root + BaseDataset.__init__(self, opt) + # get the image paths of your dataset; + self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root + # define the default transform function. You can use ; You can also define your custom transform function + self.transform = get_transform(opt) + + def __getitem__(self, index): + """Return a data point and its metadata information. + + Parameters: + index -- a random integer for data indexing + + Returns: + a dictionary of data with their names. It usually contains the data itself and its metadata information. + + Step 1: get a random image path: e.g., path = self.image_paths[index] + Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB'). + Step 3: convert your data to a PyTorch tensor. You can use helpder functions such as self.transform. e.g., data = self.transform(image) + Step 4: return a data point as a dictionary. + """ + path = 'temp' # needs to be a string + data_A = None # needs to be a tensor + data_B = None # needs to be a tensor + return {'data_A': data_A, 'data_B': data_B, 'path': path} + + def __len__(self): + """Return the total number of images.""" + return len(self.image_paths) diff --git a/videoretalking/third_part/face3d/data_preparation.py b/videoretalking/third_part/face3d/data_preparation.py new file mode 100644 index 0000000000000000000000000000000000000000..7d0a9cfa609cdc3631d82d1ec696c381238b3296 --- /dev/null +++ b/videoretalking/third_part/face3d/data_preparation.py @@ -0,0 +1,45 @@ +"""This script is the data preparation script for Deep3DFaceRecon_pytorch +""" + +import os +import numpy as np +import argparse +from util.detect_lm68 import detect_68p,load_lm_graph +from util.skin_mask import get_skin_mask +from util.generate_list import check_list, write_list +import warnings +warnings.filterwarnings("ignore") + +parser = argparse.ArgumentParser() +parser.add_argument('--data_root', type=str, default='datasets', help='root directory for training data') +parser.add_argument('--img_folder', nargs="+", required=True, help='folders of training images') +parser.add_argument('--mode', type=str, default='train', help='train or val') +opt = parser.parse_args() + +os.environ['CUDA_VISIBLE_DEVICES'] = '0' + +def data_prepare(folder_list,mode): + + lm_sess,input_op,output_op = load_lm_graph('./checkpoints/lm_model/68lm_detector.pb') # load a tensorflow version 68-landmark detector + + for img_folder in folder_list: + detect_68p(img_folder,lm_sess,input_op,output_op) # detect landmarks for images + get_skin_mask(img_folder) # generate skin attention mask for images + + # create files that record path to all training data + msks_list = [] + for img_folder in folder_list: + path = os.path.join(img_folder, 'mask') + msks_list += ['/'.join([img_folder, 'mask', i]) for i in sorted(os.listdir(path)) if 'jpg' in i or + 'png' in i or 'jpeg' in i or 'PNG' in i] + + imgs_list = [i.replace('mask/', '') for i in msks_list] + lms_list = [i.replace('mask', 'landmarks') for i in msks_list] + lms_list = ['.'.join(i.split('.')[:-1]) + '.txt' for i in lms_list] + + lms_list_final, imgs_list_final, msks_list_final = check_list(lms_list, imgs_list, msks_list) # check if the path is valid + write_list(lms_list_final, imgs_list_final, msks_list_final, mode=mode) # save files + +if __name__ == '__main__': + print('Datasets:',opt.img_folder) + data_prepare([os.path.join(opt.data_root,folder) for folder in opt.img_folder],opt.mode) diff --git a/videoretalking/third_part/face3d/extract_kp_videos.py b/videoretalking/third_part/face3d/extract_kp_videos.py new file mode 100644 index 0000000000000000000000000000000000000000..7b62192469fb708a508bbebec439b7fddc36558b --- /dev/null +++ b/videoretalking/third_part/face3d/extract_kp_videos.py @@ -0,0 +1,109 @@ +import os +import cv2 +import time +import glob +import argparse +import face_alignment +import numpy as np +from PIL import Image +import torch +from tqdm import tqdm +from itertools import cycle + +from torch.multiprocessing import Pool, Process, set_start_method + +class KeypointExtractor(): + def __init__(self): + device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, device=device) + + def extract_keypoint(self, images, name=None, info=True): + if isinstance(images, list): + keypoints = [] + if info: + i_range = tqdm(images,desc='landmark Det:') + else: + i_range = images + + for image in i_range: + current_kp = self.extract_keypoint(image) + if np.mean(current_kp) == -1 and keypoints: + keypoints.append(keypoints[-1]) + else: + keypoints.append(current_kp[None]) + + keypoints = np.concatenate(keypoints, 0) + np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1)) + return keypoints + else: + while True: + try: + keypoints = self.detector.get_landmarks_from_image(np.array(images))[0] + break + except RuntimeError as e: + if str(e).startswith('CUDA'): + print("Warning: out of memory, sleep for 1s") + time.sleep(1) + else: + print(e) + break + except TypeError: + print('No face detected in this image') + shape = [68, 2] + keypoints = -1. * np.ones(shape) + break + if name is not None: + np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1)) + return keypoints + +def read_video(filename): + frames = [] + cap = cv2.VideoCapture(filename) + while cap.isOpened(): + ret, frame = cap.read() + if ret: + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frame = Image.fromarray(frame) + frames.append(frame) + else: + break + cap.release() + return frames + +def run(data): + filename, opt, device = data + os.environ['CUDA_VISIBLE_DEVICES'] = device + kp_extractor = KeypointExtractor() + images = read_video(filename) + name = filename.split('/')[-2:] + os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True) + kp_extractor.extract_keypoint( + images, + name=os.path.join(opt.output_dir, name[-2], name[-1]) + ) + +if __name__ == '__main__': + set_start_method('spawn') + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('--input_dir', type=str, help='the folder of the input files') + parser.add_argument('--output_dir', type=str, help='the folder of the output files') + parser.add_argument('--device_ids', type=str, default='0,1') + parser.add_argument('--workers', type=int, default=4) + + opt = parser.parse_args() + filenames = list() + VIDEO_EXTENSIONS_LOWERCASE = {'mp4'} + VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE}) + extensions = VIDEO_EXTENSIONS + + for ext in extensions: + os.listdir(f'{opt.input_dir}') + print(f'{opt.input_dir}/*.{ext}') + filenames = sorted(glob.glob(f'{opt.input_dir}/*.{ext}')) + print('Total number of videos:', len(filenames)) + pool = Pool(opt.workers) + args_list = cycle([opt]) + device_ids = opt.device_ids.split(",") + device_ids = cycle(device_ids) + for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))): + None diff --git a/videoretalking/third_part/face3d/face_recon_videos.py b/videoretalking/third_part/face3d/face_recon_videos.py new file mode 100644 index 0000000000000000000000000000000000000000..7dc49331ee51b765592d1d751b05bde15ff02c5b --- /dev/null +++ b/videoretalking/third_part/face3d/face_recon_videos.py @@ -0,0 +1,157 @@ +import os +import cv2 +import glob +import numpy as np +from PIL import Image +from tqdm import tqdm +from scipy.io import savemat + +import torch + +from models import create_model +from options.inference_options import InferenceOptions +from util.preprocess import align_img +from util.load_mats import load_lm3d +from util.util import mkdirs, tensor2im, save_image + + +def get_data_path(root, keypoint_root): + filenames = list() + keypoint_filenames = list() + + VIDEO_EXTENSIONS_LOWERCASE = {'mp4'} + VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE}) + extensions = VIDEO_EXTENSIONS + + for ext in extensions: + filenames += glob.glob(f'{root}/**/*.{ext}', recursive=True) + filenames = sorted(filenames) + keypoint_filenames = sorted(glob.glob(f'{keypoint_root}/**/*.txt', recursive=True)) + assert len(filenames) == len(keypoint_filenames) + + return filenames, keypoint_filenames + +class VideoPathDataset(torch.utils.data.Dataset): + def __init__(self, filenames, txt_filenames, bfm_folder): + self.filenames = filenames + self.txt_filenames = txt_filenames + self.lm3d_std = load_lm3d(bfm_folder) + + def __len__(self): + return len(self.filenames) + + def __getitem__(self, index): + filename = self.filenames[index] + txt_filename = self.txt_filenames[index] + frames = self.read_video(filename) + lm = np.loadtxt(txt_filename).astype(np.float32) + lm = lm.reshape([len(frames), -1, 2]) + out_images, out_trans_params = list(), list() + for i in range(len(frames)): + out_img, _, out_trans_param \ + = self.image_transform(frames[i], lm[i]) + out_images.append(out_img[None]) + out_trans_params.append(out_trans_param[None]) + return { + 'imgs': torch.cat(out_images, 0), + 'trans_param':torch.cat(out_trans_params, 0), + 'filename': filename + } + + def read_video(self, filename): + frames = list() + cap = cv2.VideoCapture(filename) + while cap.isOpened(): + ret, frame = cap.read() + if ret: + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frame = Image.fromarray(frame) + frames.append(frame) + else: + break + cap.release() + return frames + + def image_transform(self, images, lm): + W,H = images.size + if np.mean(lm) == -1: + lm = (self.lm3d_std[:, :2]+1)/2. + lm = np.concatenate( + [lm[:, :1]*W, lm[:, 1:2]*H], 1 + ) + else: + lm[:, -1] = H - 1 - lm[:, -1] + + trans_params, img, lm, _ = align_img(images, lm, self.lm3d_std) + img = torch.tensor(np.array(img)/255., dtype=torch.float32).permute(2, 0, 1) + lm = torch.tensor(lm) + trans_params = np.array([float(item) for item in np.hsplit(trans_params, 5)]) + trans_params = torch.tensor(trans_params.astype(np.float32)) + return img, lm, trans_params + +def main(opt, model): + # import torch.multiprocessing + # torch.multiprocessing.set_sharing_strategy('file_system') + filenames, keypoint_filenames = get_data_path(opt.input_dir, opt.keypoint_dir) + dataset = VideoPathDataset(filenames, keypoint_filenames, opt.bfm_folder) + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=1, # can noly set to one here! + shuffle=False, + drop_last=False, + num_workers=0, + ) + batch_size = opt.inference_batch_size + for data in tqdm(dataloader): + num_batch = data['imgs'][0].shape[0] // batch_size + 1 + pred_coeffs = list() + for index in range(num_batch): + data_input = { + 'imgs': data['imgs'][0,index*batch_size:(index+1)*batch_size], + } + model.set_input(data_input) + model.test() + pred_coeff = {key:model.pred_coeffs_dict[key].cpu().numpy() for key in model.pred_coeffs_dict} + pred_coeff = np.concatenate([ + pred_coeff['id'], + pred_coeff['exp'], + pred_coeff['tex'], + pred_coeff['angle'], + pred_coeff['gamma'], + pred_coeff['trans']], 1) + pred_coeffs.append(pred_coeff) + visuals = model.get_current_visuals() # get image results + if False: # debug + for name in visuals: + images = visuals[name] + for i in range(images.shape[0]): + image_numpy = tensor2im(images[i]) + save_image( + image_numpy, + os.path.join( + opt.output_dir, + os.path.basename(data['filename'][0])+str(i).zfill(5)+'.jpg') + ) + exit() + + pred_coeffs = np.concatenate(pred_coeffs, 0) + pred_trans_params = data['trans_param'][0].cpu().numpy() + name = data['filename'][0].split('/')[-2:] + name[-1] = os.path.splitext(name[-1])[0] + '.mat' + os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True) + savemat( + os.path.join(opt.output_dir, name[-2], name[-1]), + {'coeff':pred_coeffs, 'transform_params':pred_trans_params} + ) + +if __name__ == '__main__': + opt = InferenceOptions().parse() # get test options + model = create_model(opt) + model.setup(opt) + model.device = 'cuda:0' + model.parallelize() + model.eval() + + main(opt, model) + + diff --git a/videoretalking/third_part/face3d/models/__init__.py b/videoretalking/third_part/face3d/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..11dec40a863670ee332c29d0b99e77515ebdf56c --- /dev/null +++ b/videoretalking/third_part/face3d/models/__init__.py @@ -0,0 +1,67 @@ +"""This package contains modules related to objective functions, optimizations, and network architectures. + +To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. +You need to implement the following five functions: + -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). + -- : unpack data from dataset and apply preprocessing. + -- : produce intermediate results. + -- : calculate loss, gradients, and update network weights. + -- : (optionally) add model-specific options and set default options. + +In the function <__init__>, you need to define four lists: + -- self.loss_names (str list): specify the training losses that you want to plot and save. + -- self.model_names (str list): define networks used in our training. + -- self.visual_names (str list): specify the images that you want to display and save. + -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage. + +Now you can use the model class by specifying flag '--model dummy'. +See our template model class 'template_model.py' for more details. +""" + +import importlib +from face3d.models.base_model import BaseModel + + +def find_model_using_name(model_name): + """Import the module "models/[model_name]_model.py". + + In the file, the class called DatasetNameModel() will + be instantiated. It has to be a subclass of BaseModel, + and it is case-insensitive. + """ + model_filename = "face3d.models." + model_name + "_model" + modellib = importlib.import_module(model_filename) + model = None + target_model_name = model_name.replace('_', '') + 'model' + for name, cls in modellib.__dict__.items(): + if name.lower() == target_model_name.lower() \ + and issubclass(cls, BaseModel): + model = cls + + if model is None: + print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) + exit(0) + + return model + + +def get_option_setter(model_name): + """Return the static method of the model class.""" + model_class = find_model_using_name(model_name) + return model_class.modify_commandline_options + + +def create_model(opt): + """Create a model given the option. + + This function warps the class CustomDatasetDataLoader. + This is the main interface between this package and 'train.py'/'test.py' + + Example: + >>> from models import create_model + >>> model = create_model(opt) + """ + model = find_model_using_name(opt.model) + instance = model(opt) + print("model [%s] was created" % type(instance).__name__) + return instance diff --git a/videoretalking/third_part/face3d/models/arcface_torch/README.md b/videoretalking/third_part/face3d/models/arcface_torch/README.md new file mode 100644 index 0000000000000000000000000000000000000000..cc7f1d45f2f5e4b752c42dc81d3e2879c1459c6e --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/README.md @@ -0,0 +1,164 @@ +# Distributed Arcface Training in Pytorch + +This is a deep learning library that makes face recognition efficient, and effective, which can train tens of millions +identity on a single server. + +## Requirements + +- Install [pytorch](http://pytorch.org) (torch>=1.6.0), our doc for [install.md](docs/install.md). +- `pip install -r requirements.txt`. +- Download the dataset + from [https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_) + . + +## How to Training + +To train a model, run `train.py` with the path to the configs: + +### 1. Single node, 8 GPUs: + +```shell +python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50 +``` + +### 2. Multiple nodes, each node 8 GPUs: + +Node 0: + +```shell +python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=1234 train.py train.py configs/ms1mv3_r50 +``` + +Node 1: + +```shell +python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=1234 train.py train.py configs/ms1mv3_r50 +``` + +### 3.Training resnet2060 with 8 GPUs: + +```shell +python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r2060.py +``` + +## Model Zoo + +- The models are available for non-commercial research purposes only. +- All models can be found in here. +- [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw +- [onedrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d) + +### Performance on [**ICCV2021-MFR**](http://iccv21-mfr.com/) + +ICCV2021-MFR testset consists of non-celebrities so we can ensure that it has very few overlap with public available face +recognition training set, such as MS1M and CASIA as they mostly collected from online celebrities. +As the result, we can evaluate the FAIR performance for different algorithms. + +For **ICCV2021-MFR-ALL** set, TAR is measured on all-to-all 1:1 protocal, with FAR less than 0.000001(e-6). The +globalised multi-racial testset contains 242,143 identities and 1,624,305 images. + +For **ICCV2021-MFR-MASK** set, TAR is measured on mask-to-nonmask 1:1 protocal, with FAR less than 0.0001(e-4). +Mask testset contains 6,964 identities, 6,964 masked images and 13,928 non-masked images. +There are totally 13,928 positive pairs and 96,983,824 negative pairs. + +| Datasets | backbone | Training throughout | Size / MB | **ICCV2021-MFR-MASK** | **ICCV2021-MFR-ALL** | +| :---: | :--- | :--- | :--- |:--- |:--- | +| MS1MV3 | r18 | - | 91 | **47.85** | **68.33** | +| Glint360k | r18 | 8536 | 91 | **53.32** | **72.07** | +| MS1MV3 | r34 | - | 130 | **58.72** | **77.36** | +| Glint360k | r34 | 6344 | 130 | **65.10** | **83.02** | +| MS1MV3 | r50 | 5500 | 166 | **63.85** | **80.53** | +| Glint360k | r50 | 5136 | 166 | **70.23** | **87.08** | +| MS1MV3 | r100 | - | 248 | **69.09** | **84.31** | +| Glint360k | r100 | 3332 | 248 | **75.57** | **90.66** | +| MS1MV3 | mobilefacenet | 12185 | 7.8 | **41.52** | **65.26** | +| Glint360k | mobilefacenet | 11197 | 7.8 | **44.52** | **66.48** | + +### Performance on IJB-C and Verification Datasets + +| Datasets | backbone | IJBC(1e-05) | IJBC(1e-04) | agedb30 | cfp_fp | lfw | log | +| :---: | :--- | :--- | :--- | :--- |:--- |:--- |:--- | +| MS1MV3 | r18 | 92.07 | 94.66 | 97.77 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r18_fp16/training.log)| +| MS1MV3 | r34 | 94.10 | 95.90 | 98.10 | 98.67 | 99.80 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r34_fp16/training.log)| +| MS1MV3 | r50 | 94.79 | 96.46 | 98.35 | 98.96 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r50_fp16/training.log)| +| MS1MV3 | r100 | 95.31 | 96.81 | 98.48 | 99.06 | 99.85 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r100_fp16/training.log)| +| MS1MV3 | **r2060**| 95.34 | 97.11 | 98.67 | 99.24 | 99.87 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r2060_fp16/training.log)| +| Glint360k |r18-0.1 | 93.16 | 95.33 | 97.72 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r18_fp16_0.1/training.log)| +| Glint360k |r34-0.1 | 95.16 | 96.56 | 98.33 | 98.78 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r34_fp16_0.1/training.log)| +| Glint360k |r50-0.1 | 95.61 | 96.97 | 98.38 | 99.20 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r50_fp16_0.1/training.log)| +| Glint360k |r100-0.1 | 95.88 | 97.32 | 98.48 | 99.29 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r100_fp16_0.1/training.log)| + +[comment]: <> (More details see [model.md](docs/modelzoo.md) in docs.) + + +## [Speed Benchmark](docs/speed_benchmark.md) + +**Arcface Torch** can train large-scale face recognition training set efficiently and quickly. When the number of +classes in training sets is greater than 300K and the training is sufficient, partial fc sampling strategy will get same +accuracy with several times faster training performance and smaller GPU memory. +Partial FC is a sparse variant of the model parallel architecture for large sacle face recognition. Partial FC use a +sparse softmax, where each batch dynamicly sample a subset of class centers for training. In each iteration, only a +sparse part of the parameters will be updated, which can reduce a lot of GPU memory and calculations. With Partial FC, +we can scale trainset of 29 millions identities, the largest to date. Partial FC also supports multi-machine distributed +training and mixed precision training. + +![Image text](https://github.com/anxiangsir/insightface_arcface_log/blob/master/partial_fc_v2.png) + +More details see +[speed_benchmark.md](docs/speed_benchmark.md) in docs. + +### 1. Training speed of different parallel methods (samples / second), Tesla V100 32GB * 8. (Larger is better) + +`-` means training failed because of gpu memory limitations. + +| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 | +| :--- | :--- | :--- | :--- | +|125000 | 4681 | 4824 | 5004 | +|1400000 | **1672** | 3043 | 4738 | +|5500000 | **-** | **1389** | 3975 | +|8000000 | **-** | **-** | 3565 | +|16000000 | **-** | **-** | 2679 | +|29000000 | **-** | **-** | **1855** | + +### 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better) + +| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 | +| :--- | :--- | :--- | :--- | +|125000 | 7358 | 5306 | 4868 | +|1400000 | 32252 | 11178 | 6056 | +|5500000 | **-** | 32188 | 9854 | +|8000000 | **-** | **-** | 12310 | +|16000000 | **-** | **-** | 19950 | +|29000000 | **-** | **-** | 32324 | + +## Evaluation ICCV2021-MFR and IJB-C + +More details see [eval.md](docs/eval.md) in docs. + +## Test + +We tested many versions of PyTorch. Please create an issue if you are having trouble. + +- [x] torch 1.6.0 +- [x] torch 1.7.1 +- [x] torch 1.8.0 +- [x] torch 1.9.0 + +## Citation + +``` +@inproceedings{deng2019arcface, + title={Arcface: Additive angular margin loss for deep face recognition}, + author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos}, + booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, + pages={4690--4699}, + year={2019} +} +@inproceedings{an2020partical_fc, + title={Partial FC: Training 10 Million Identities on a Single Machine}, + author={An, Xiang and Zhu, Xuhan and Xiao, Yang and Wu, Lan and Zhang, Ming and Gao, Yuan and Qin, Bin and + Zhang, Debing and Fu Ying}, + booktitle={Arxiv 2010.05222}, + year={2020} +} +``` diff --git a/videoretalking/third_part/face3d/models/arcface_torch/backbones/__init__.py b/videoretalking/third_part/face3d/models/arcface_torch/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5650187b4fdea84c5a23e0445440901690ab682a --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/backbones/__init__.py @@ -0,0 +1,25 @@ +from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200 +from .mobilefacenet import get_mbf + + +def get_model(name, **kwargs): + # resnet + if name == "r18": + return iresnet18(False, **kwargs) + elif name == "r34": + return iresnet34(False, **kwargs) + elif name == "r50": + return iresnet50(False, **kwargs) + elif name == "r100": + return iresnet100(False, **kwargs) + elif name == "r200": + return iresnet200(False, **kwargs) + elif name == "r2060": + from .iresnet2060 import iresnet2060 + return iresnet2060(False, **kwargs) + elif name == "mbf": + fp16 = kwargs.get("fp16", False) + num_features = kwargs.get("num_features", 512) + return get_mbf(fp16=fp16, num_features=num_features) + else: + raise ValueError() \ No newline at end of file diff --git a/videoretalking/third_part/face3d/models/arcface_torch/backbones/iresnet.py b/videoretalking/third_part/face3d/models/arcface_torch/backbones/iresnet.py new file mode 100644 index 0000000000000000000000000000000000000000..d29f5f2bfbd444273717c4bc8aa20ba7edd08f80 --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/backbones/iresnet.py @@ -0,0 +1,187 @@ +import torch +from torch import nn + +__all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200'] + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, + out_planes, + kernel_size=1, + stride=stride, + bias=False) + + +class IBasicBlock(nn.Module): + expansion = 1 + def __init__(self, inplanes, planes, stride=1, downsample=None, + groups=1, base_width=64, dilation=1): + super(IBasicBlock, self).__init__() + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,) + self.conv1 = conv3x3(inplanes, planes) + self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,) + self.prelu = nn.PReLU(planes) + self.conv2 = conv3x3(planes, planes, stride) + self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + out = self.bn1(x) + out = self.conv1(out) + out = self.bn2(out) + out = self.prelu(out) + out = self.conv2(out) + out = self.bn3(out) + if self.downsample is not None: + identity = self.downsample(x) + out += identity + return out + + +class IResNet(nn.Module): + fc_scale = 7 * 7 + def __init__(self, + block, layers, dropout=0, num_features=512, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False): + super(IResNet, self).__init__() + self.fp16 = fp16 + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05) + self.prelu = nn.PReLU(self.inplanes) + self.layer1 = self._make_layer(block, 64, layers[0], stride=2) + self.layer2 = self._make_layer(block, + 128, + layers[1], + stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, + 256, + layers[2], + stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, + 512, + layers[3], + stride=2, + dilate=replace_stride_with_dilation[2]) + self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,) + self.dropout = nn.Dropout(p=dropout, inplace=True) + self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features) + self.features = nn.BatchNorm1d(num_features, eps=1e-05) + nn.init.constant_(self.features.weight, 1.0) + self.features.weight.requires_grad = False + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight, 0, 0.1) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + if zero_init_residual: + for m in self.modules(): + if isinstance(m, IBasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ), + ) + layers = [] + layers.append( + block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block(self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation)) + + return nn.Sequential(*layers) + + def forward(self, x): + with torch.cuda.amp.autocast(self.fp16): + x = self.conv1(x) + x = self.bn1(x) + x = self.prelu(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.bn2(x) + x = torch.flatten(x, 1) + x = self.dropout(x) + x = self.fc(x.float() if self.fp16 else x) + x = self.features(x) + return x + + +def _iresnet(arch, block, layers, pretrained, progress, **kwargs): + model = IResNet(block, layers, **kwargs) + if pretrained: + raise ValueError() + return model + + +def iresnet18(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained, + progress, **kwargs) + + +def iresnet34(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained, + progress, **kwargs) + + +def iresnet50(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained, + progress, **kwargs) + + +def iresnet100(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained, + progress, **kwargs) + + +def iresnet200(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained, + progress, **kwargs) + diff --git a/videoretalking/third_part/face3d/models/arcface_torch/backbones/iresnet2060.py b/videoretalking/third_part/face3d/models/arcface_torch/backbones/iresnet2060.py new file mode 100644 index 0000000000000000000000000000000000000000..39bb4335716b653bd5924e20d616d825ef48339f --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/backbones/iresnet2060.py @@ -0,0 +1,176 @@ +import torch +from torch import nn + +assert torch.__version__ >= "1.8.1" +from torch.utils.checkpoint import checkpoint_sequential + +__all__ = ['iresnet2060'] + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, + out_planes, + kernel_size=1, + stride=stride, + bias=False) + + +class IBasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, + groups=1, base_width=64, dilation=1): + super(IBasicBlock, self).__init__() + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05, ) + self.conv1 = conv3x3(inplanes, planes) + self.bn2 = nn.BatchNorm2d(planes, eps=1e-05, ) + self.prelu = nn.PReLU(planes) + self.conv2 = conv3x3(planes, planes, stride) + self.bn3 = nn.BatchNorm2d(planes, eps=1e-05, ) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + out = self.bn1(x) + out = self.conv1(out) + out = self.bn2(out) + out = self.prelu(out) + out = self.conv2(out) + out = self.bn3(out) + if self.downsample is not None: + identity = self.downsample(x) + out += identity + return out + + +class IResNet(nn.Module): + fc_scale = 7 * 7 + + def __init__(self, + block, layers, dropout=0, num_features=512, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False): + super(IResNet, self).__init__() + self.fp16 = fp16 + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05) + self.prelu = nn.PReLU(self.inplanes) + self.layer1 = self._make_layer(block, 64, layers[0], stride=2) + self.layer2 = self._make_layer(block, + 128, + layers[1], + stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, + 256, + layers[2], + stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, + 512, + layers[3], + stride=2, + dilate=replace_stride_with_dilation[2]) + self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05, ) + self.dropout = nn.Dropout(p=dropout, inplace=True) + self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features) + self.features = nn.BatchNorm1d(num_features, eps=1e-05) + nn.init.constant_(self.features.weight, 1.0) + self.features.weight.requires_grad = False + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight, 0, 0.1) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + if zero_init_residual: + for m in self.modules(): + if isinstance(m, IBasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ), + ) + layers = [] + layers.append( + block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block(self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation)) + + return nn.Sequential(*layers) + + def checkpoint(self, func, num_seg, x): + if self.training: + return checkpoint_sequential(func, num_seg, x) + else: + return func(x) + + def forward(self, x): + with torch.cuda.amp.autocast(self.fp16): + x = self.conv1(x) + x = self.bn1(x) + x = self.prelu(x) + x = self.layer1(x) + x = self.checkpoint(self.layer2, 20, x) + x = self.checkpoint(self.layer3, 100, x) + x = self.layer4(x) + x = self.bn2(x) + x = torch.flatten(x, 1) + x = self.dropout(x) + x = self.fc(x.float() if self.fp16 else x) + x = self.features(x) + return x + + +def _iresnet(arch, block, layers, pretrained, progress, **kwargs): + model = IResNet(block, layers, **kwargs) + if pretrained: + raise ValueError() + return model + + +def iresnet2060(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet2060', IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs) diff --git a/videoretalking/third_part/face3d/models/arcface_torch/backbones/mobilefacenet.py b/videoretalking/third_part/face3d/models/arcface_torch/backbones/mobilefacenet.py new file mode 100644 index 0000000000000000000000000000000000000000..c02c6c1e4fa6a6ddf09f5b01dec96971427cb110 --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/backbones/mobilefacenet.py @@ -0,0 +1,130 @@ +''' +Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py +Original author cavalleria +''' + +import torch.nn as nn +from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module +import torch + + +class Flatten(Module): + def forward(self, x): + return x.view(x.size(0), -1) + + +class ConvBlock(Module): + def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): + super(ConvBlock, self).__init__() + self.layers = nn.Sequential( + Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False), + BatchNorm2d(num_features=out_c), + PReLU(num_parameters=out_c) + ) + + def forward(self, x): + return self.layers(x) + + +class LinearBlock(Module): + def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): + super(LinearBlock, self).__init__() + self.layers = nn.Sequential( + Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False), + BatchNorm2d(num_features=out_c) + ) + + def forward(self, x): + return self.layers(x) + + +class DepthWise(Module): + def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1): + super(DepthWise, self).__init__() + self.residual = residual + self.layers = nn.Sequential( + ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)), + ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride), + LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1)) + ) + + def forward(self, x): + short_cut = None + if self.residual: + short_cut = x + x = self.layers(x) + if self.residual: + output = short_cut + x + else: + output = x + return output + + +class Residual(Module): + def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)): + super(Residual, self).__init__() + modules = [] + for _ in range(num_block): + modules.append(DepthWise(c, c, True, kernel, stride, padding, groups)) + self.layers = Sequential(*modules) + + def forward(self, x): + return self.layers(x) + + +class GDC(Module): + def __init__(self, embedding_size): + super(GDC, self).__init__() + self.layers = nn.Sequential( + LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)), + Flatten(), + Linear(512, embedding_size, bias=False), + BatchNorm1d(embedding_size)) + + def forward(self, x): + return self.layers(x) + + +class MobileFaceNet(Module): + def __init__(self, fp16=False, num_features=512): + super(MobileFaceNet, self).__init__() + scale = 2 + self.fp16 = fp16 + self.layers = nn.Sequential( + ConvBlock(3, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)), + ConvBlock(64 * scale, 64 * scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64), + DepthWise(64 * scale, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128), + Residual(64 * scale, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), + DepthWise(64 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256), + Residual(128 * scale, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), + DepthWise(128 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512), + Residual(128 * scale, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), + ) + self.conv_sep = ConvBlock(128 * scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0)) + self.features = GDC(num_features) + self._initialize_weights() + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + with torch.cuda.amp.autocast(self.fp16): + x = self.layers(x) + x = self.conv_sep(x.float() if self.fp16 else x) + x = self.features(x) + return x + + +def get_mbf(fp16, num_features): + return MobileFaceNet(fp16, num_features) \ No newline at end of file diff --git a/videoretalking/third_part/face3d/models/arcface_torch/configs/3millions.py b/videoretalking/third_part/face3d/models/arcface_torch/configs/3millions.py new file mode 100644 index 0000000000000000000000000000000000000000..3bee7cb4236e8b842a1bd1e8c26de7a11df0bf43 --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/configs/3millions.py @@ -0,0 +1,23 @@ +from easydict import EasyDict as edict + +# configs for test speed + +config = edict() +config.loss = "arcface" +config.network = "r50" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "synthetic" +config.num_classes = 300 * 10000 +config.num_epoch = 30 +config.warmup_epoch = -1 +config.decay_epoch = [10, 16, 22] +config.val_targets = [] diff --git a/videoretalking/third_part/face3d/models/arcface_torch/configs/3millions_pfc.py b/videoretalking/third_part/face3d/models/arcface_torch/configs/3millions_pfc.py new file mode 100644 index 0000000000000000000000000000000000000000..bf7df5f04e2509e5dcc14adebbb9302a18f03f2b --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/configs/3millions_pfc.py @@ -0,0 +1,23 @@ +from easydict import EasyDict as edict + +# configs for test speed + +config = edict() +config.loss = "arcface" +config.network = "r50" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 0.1 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "synthetic" +config.num_classes = 300 * 10000 +config.num_epoch = 30 +config.warmup_epoch = -1 +config.decay_epoch = [10, 16, 22] +config.val_targets = [] diff --git a/videoretalking/third_part/face3d/models/arcface_torch/configs/__init__.py b/videoretalking/third_part/face3d/models/arcface_torch/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/videoretalking/third_part/face3d/models/arcface_torch/configs/base.py b/videoretalking/third_part/face3d/models/arcface_torch/configs/base.py new file mode 100644 index 0000000000000000000000000000000000000000..f98c62fed44afde276dcbacecd9da0a8f474963c --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/configs/base.py @@ -0,0 +1,56 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "arcface" +config.network = "r50" +config.resume = False +config.output = "ms1mv3_arcface_r50" + +config.dataset = "ms1m-retinaface-t1" +config.embedding_size = 512 +config.sample_rate = 1 +config.fp16 = False +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +if config.dataset == "emore": + config.rec = "/train_tmp/faces_emore" + config.num_classes = 85742 + config.num_image = 5822653 + config.num_epoch = 16 + config.warmup_epoch = -1 + config.decay_epoch = [8, 14, ] + config.val_targets = ["lfw", ] + +elif config.dataset == "ms1m-retinaface-t1": + config.rec = "/train_tmp/ms1m-retinaface-t1" + config.num_classes = 93431 + config.num_image = 5179510 + config.num_epoch = 25 + config.warmup_epoch = -1 + config.decay_epoch = [11, 17, 22] + config.val_targets = ["lfw", "cfp_fp", "agedb_30"] + +elif config.dataset == "glint360k": + config.rec = "/train_tmp/glint360k" + config.num_classes = 360232 + config.num_image = 17091657 + config.num_epoch = 20 + config.warmup_epoch = -1 + config.decay_epoch = [8, 12, 15, 18] + config.val_targets = ["lfw", "cfp_fp", "agedb_30"] + +elif config.dataset == "webface": + config.rec = "/train_tmp/faces_webface_112x112" + config.num_classes = 10572 + config.num_image = "forget" + config.num_epoch = 34 + config.warmup_epoch = -1 + config.decay_epoch = [20, 28, 32] + config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/videoretalking/third_part/face3d/models/arcface_torch/configs/glint360k_mbf.py b/videoretalking/third_part/face3d/models/arcface_torch/configs/glint360k_mbf.py new file mode 100644 index 0000000000000000000000000000000000000000..44ee5e8d96249d57196df43418f6fda4ab339877 --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/configs/glint360k_mbf.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "cosface" +config.network = "mbf" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 0.1 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 2e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/glint360k" +config.num_classes = 360232 +config.num_image = 17091657 +config.num_epoch = 20 +config.warmup_epoch = -1 +config.decay_epoch = [8, 12, 15, 18] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/videoretalking/third_part/face3d/models/arcface_torch/configs/glint360k_r100.py b/videoretalking/third_part/face3d/models/arcface_torch/configs/glint360k_r100.py new file mode 100644 index 0000000000000000000000000000000000000000..f8f8ef745c0efb9d5ea67409edc8c904def8a9d9 --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/configs/glint360k_r100.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "cosface" +config.network = "r100" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/glint360k" +config.num_classes = 360232 +config.num_image = 17091657 +config.num_epoch = 20 +config.warmup_epoch = -1 +config.decay_epoch = [8, 12, 15, 18] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/videoretalking/third_part/face3d/models/arcface_torch/configs/glint360k_r18.py b/videoretalking/third_part/face3d/models/arcface_torch/configs/glint360k_r18.py new file mode 100644 index 0000000000000000000000000000000000000000..473b59a954fffcaddca132fb6e0f32cbe70c70f4 --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/configs/glint360k_r18.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "cosface" +config.network = "r18" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/glint360k" +config.num_classes = 360232 +config.num_image = 17091657 +config.num_epoch = 20 +config.warmup_epoch = -1 +config.decay_epoch = [8, 12, 15, 18] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/videoretalking/third_part/face3d/models/arcface_torch/configs/glint360k_r34.py b/videoretalking/third_part/face3d/models/arcface_torch/configs/glint360k_r34.py new file mode 100644 index 0000000000000000000000000000000000000000..d9c22ff0c82cc98bbbe81c9a1c26c9b3fc186105 --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/configs/glint360k_r34.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "cosface" +config.network = "r34" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/glint360k" +config.num_classes = 360232 +config.num_image = 17091657 +config.num_epoch = 20 +config.warmup_epoch = -1 +config.decay_epoch = [8, 12, 15, 18] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/videoretalking/third_part/face3d/models/arcface_torch/configs/glint360k_r50.py b/videoretalking/third_part/face3d/models/arcface_torch/configs/glint360k_r50.py new file mode 100644 index 0000000000000000000000000000000000000000..8ecbfda06730e3842e7b347db366e82f0714912f --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/configs/glint360k_r50.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "cosface" +config.network = "r50" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/glint360k" +config.num_classes = 360232 +config.num_image = 17091657 +config.num_epoch = 20 +config.warmup_epoch = -1 +config.decay_epoch = [8, 12, 15, 18] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/videoretalking/third_part/face3d/models/arcface_torch/configs/ms1mv3_mbf.py b/videoretalking/third_part/face3d/models/arcface_torch/configs/ms1mv3_mbf.py new file mode 100644 index 0000000000000000000000000000000000000000..47c87a99867db55c7f689574c331c14cda23ea96 --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/configs/ms1mv3_mbf.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "arcface" +config.network = "mbf" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 2e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/ms1m-retinaface-t1" +config.num_classes = 93431 +config.num_image = 5179510 +config.num_epoch = 30 +config.warmup_epoch = -1 +config.decay_epoch = [10, 20, 25] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/videoretalking/third_part/face3d/models/arcface_torch/configs/ms1mv3_r18.py b/videoretalking/third_part/face3d/models/arcface_torch/configs/ms1mv3_r18.py new file mode 100644 index 0000000000000000000000000000000000000000..1aeb851b05ea22e01da87b3d387812f0253989f8 --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/configs/ms1mv3_r18.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "arcface" +config.network = "r18" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/ms1m-retinaface-t1" +config.num_classes = 93431 +config.num_image = 5179510 +config.num_epoch = 25 +config.warmup_epoch = -1 +config.decay_epoch = [10, 16, 22] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/videoretalking/third_part/face3d/models/arcface_torch/configs/ms1mv3_r2060.py b/videoretalking/third_part/face3d/models/arcface_torch/configs/ms1mv3_r2060.py new file mode 100644 index 0000000000000000000000000000000000000000..8693e67080dac7e7b84da08a62df326c7b12d465 --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/configs/ms1mv3_r2060.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "arcface" +config.network = "r2060" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 64 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/ms1m-retinaface-t1" +config.num_classes = 93431 +config.num_image = 5179510 +config.num_epoch = 25 +config.warmup_epoch = -1 +config.decay_epoch = [10, 16, 22] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/videoretalking/third_part/face3d/models/arcface_torch/configs/ms1mv3_r34.py b/videoretalking/third_part/face3d/models/arcface_torch/configs/ms1mv3_r34.py new file mode 100644 index 0000000000000000000000000000000000000000..52bff483db179045c0e3acc8e2975477182b0756 --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/configs/ms1mv3_r34.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "arcface" +config.network = "r34" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/ms1m-retinaface-t1" +config.num_classes = 93431 +config.num_image = 5179510 +config.num_epoch = 25 +config.warmup_epoch = -1 +config.decay_epoch = [10, 16, 22] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/videoretalking/third_part/face3d/models/arcface_torch/configs/ms1mv3_r50.py b/videoretalking/third_part/face3d/models/arcface_torch/configs/ms1mv3_r50.py new file mode 100644 index 0000000000000000000000000000000000000000..de81ffdd84edd6fcea7fcb4d3594db031b9e4e26 --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/configs/ms1mv3_r50.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "arcface" +config.network = "r50" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/ms1m-retinaface-t1" +config.num_classes = 93431 +config.num_image = 5179510 +config.num_epoch = 25 +config.warmup_epoch = -1 +config.decay_epoch = [10, 16, 22] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/videoretalking/third_part/face3d/models/arcface_torch/configs/speed.py b/videoretalking/third_part/face3d/models/arcface_torch/configs/speed.py new file mode 100644 index 0000000000000000000000000000000000000000..c172f9d44d39b534f2253630471e91cf78e6fba7 --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/configs/speed.py @@ -0,0 +1,23 @@ +from easydict import EasyDict as edict + +# configs for test speed + +config = edict() +config.loss = "arcface" +config.network = "r50" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "synthetic" +config.num_classes = 100 * 10000 +config.num_epoch = 30 +config.warmup_epoch = -1 +config.decay_epoch = [10, 16, 22] +config.val_targets = [] diff --git a/videoretalking/third_part/face3d/models/arcface_torch/dataset.py b/videoretalking/third_part/face3d/models/arcface_torch/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..8bead250243237c650fa3138f6aa172d4f98535f --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/dataset.py @@ -0,0 +1,124 @@ +import numbers +import os +import queue as Queue +import threading + +import mxnet as mx +import numpy as np +import torch +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +class BackgroundGenerator(threading.Thread): + def __init__(self, generator, local_rank, max_prefetch=6): + super(BackgroundGenerator, self).__init__() + self.queue = Queue.Queue(max_prefetch) + self.generator = generator + self.local_rank = local_rank + self.daemon = True + self.start() + + def run(self): + torch.cuda.set_device(self.local_rank) + for item in self.generator: + self.queue.put(item) + self.queue.put(None) + + def next(self): + next_item = self.queue.get() + if next_item is None: + raise StopIteration + return next_item + + def __next__(self): + return self.next() + + def __iter__(self): + return self + + +class DataLoaderX(DataLoader): + + def __init__(self, local_rank, **kwargs): + super(DataLoaderX, self).__init__(**kwargs) + self.stream = torch.cuda.Stream(local_rank) + self.local_rank = local_rank + + def __iter__(self): + self.iter = super(DataLoaderX, self).__iter__() + self.iter = BackgroundGenerator(self.iter, self.local_rank) + self.preload() + return self + + def preload(self): + self.batch = next(self.iter, None) + if self.batch is None: + return None + with torch.cuda.stream(self.stream): + for k in range(len(self.batch)): + self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True) + + def __next__(self): + torch.cuda.current_stream().wait_stream(self.stream) + batch = self.batch + if batch is None: + raise StopIteration + self.preload() + return batch + + +class MXFaceDataset(Dataset): + def __init__(self, root_dir, local_rank): + super(MXFaceDataset, self).__init__() + self.transform = transforms.Compose( + [transforms.ToPILImage(), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + self.root_dir = root_dir + self.local_rank = local_rank + path_imgrec = os.path.join(root_dir, 'train.rec') + path_imgidx = os.path.join(root_dir, 'train.idx') + self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r') + s = self.imgrec.read_idx(0) + header, _ = mx.recordio.unpack(s) + if header.flag > 0: + self.header0 = (int(header.label[0]), int(header.label[1])) + self.imgidx = np.array(range(1, int(header.label[0]))) + else: + self.imgidx = np.array(list(self.imgrec.keys)) + + def __getitem__(self, index): + idx = self.imgidx[index] + s = self.imgrec.read_idx(idx) + header, img = mx.recordio.unpack(s) + label = header.label + if not isinstance(label, numbers.Number): + label = label[0] + label = torch.tensor(label, dtype=torch.long) + sample = mx.image.imdecode(img).asnumpy() + if self.transform is not None: + sample = self.transform(sample) + return sample, label + + def __len__(self): + return len(self.imgidx) + + +class SyntheticDataset(Dataset): + def __init__(self, local_rank): + super(SyntheticDataset, self).__init__() + img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32) + img = np.transpose(img, (2, 0, 1)) + img = torch.from_numpy(img).squeeze(0).float() + img = ((img / 255) - 0.5) / 0.5 + self.img = img + self.label = 1 + + def __getitem__(self, index): + return self.img, self.label + + def __len__(self): + return 1000000 diff --git a/videoretalking/third_part/face3d/models/arcface_torch/docs/eval.md b/videoretalking/third_part/face3d/models/arcface_torch/docs/eval.md new file mode 100644 index 0000000000000000000000000000000000000000..4d29c855fc6e4245ed264216c1f96ab2efc57248 --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/docs/eval.md @@ -0,0 +1,31 @@ +## Eval on ICCV2021-MFR + +coming soon. + + +## Eval IJBC +You can eval ijbc with pytorch or onnx. + + +1. Eval IJBC With Onnx +```shell +CUDA_VISIBLE_DEVICES=0 python onnx_ijbc.py --model-root ms1mv3_arcface_r50 --image-path IJB_release/IJBC --result-dir ms1mv3_arcface_r50 +``` + +2. Eval IJBC With Pytorch +```shell +CUDA_VISIBLE_DEVICES=0,1 python eval_ijbc.py \ +--model-prefix ms1mv3_arcface_r50/backbone.pth \ +--image-path IJB_release/IJBC \ +--result-dir ms1mv3_arcface_r50 \ +--batch-size 128 \ +--job ms1mv3_arcface_r50 \ +--target IJBC \ +--network iresnet50 +``` + +## Inference + +```shell +python inference.py --weight ms1mv3_arcface_r50/backbone.pth --network r50 +``` diff --git a/videoretalking/third_part/face3d/models/arcface_torch/docs/install.md b/videoretalking/third_part/face3d/models/arcface_torch/docs/install.md new file mode 100644 index 0000000000000000000000000000000000000000..b1b770a0d93dac1f160185b5bbf4da2f414f21f6 --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/docs/install.md @@ -0,0 +1,51 @@ +## v1.8.0 +### Linux and Windows +```shell +# CUDA 11.0 +pip --default-timeout=100 install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html + +# CUDA 10.2 +pip --default-timeout=100 install torch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 + +# CPU only +pip --default-timeout=100 install torch==1.8.0+cpu torchvision==0.9.0+cpu torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html + +``` + + +## v1.7.1 +### Linux and Windows +```shell +# CUDA 11.0 +pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html + +# CUDA 10.2 +pip install torch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 + +# CUDA 10.1 +pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html + +# CUDA 9.2 +pip install torch==1.7.1+cu92 torchvision==0.8.2+cu92 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html + +# CPU only +pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html +``` + + +## v1.6.0 + +### Linux and Windows +```shell +# CUDA 10.2 +pip install torch==1.6.0 torchvision==0.7.0 + +# CUDA 10.1 +pip install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html + +# CUDA 9.2 +pip install torch==1.6.0+cu92 torchvision==0.7.0+cu92 -f https://download.pytorch.org/whl/torch_stable.html + +# CPU only +pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html +``` \ No newline at end of file diff --git a/videoretalking/third_part/face3d/models/arcface_torch/docs/modelzoo.md b/videoretalking/third_part/face3d/models/arcface_torch/docs/modelzoo.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/videoretalking/third_part/face3d/models/arcface_torch/docs/speed_benchmark.md b/videoretalking/third_part/face3d/models/arcface_torch/docs/speed_benchmark.md new file mode 100644 index 0000000000000000000000000000000000000000..d54904587df4e13784dc68d5709b4d7d97490890 --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/docs/speed_benchmark.md @@ -0,0 +1,93 @@ +## Test Training Speed + +- Test Commands + +You need to use the following two commands to test the Partial FC training performance. +The number of identites is **3 millions** (synthetic data), turn mixed precision training on, backbone is resnet50, +batch size is 1024. +```shell +# Model Parallel +python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions +# Partial FC 0.1 +python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions_pfc +``` + +- GPU Memory + +``` +# (Model Parallel) gpustat -i +[0] Tesla V100-SXM2-32GB | 64'C, 94 % | 30338 / 32510 MB +[1] Tesla V100-SXM2-32GB | 60'C, 99 % | 28876 / 32510 MB +[2] Tesla V100-SXM2-32GB | 60'C, 99 % | 28872 / 32510 MB +[3] Tesla V100-SXM2-32GB | 69'C, 99 % | 28872 / 32510 MB +[4] Tesla V100-SXM2-32GB | 66'C, 99 % | 28888 / 32510 MB +[5] Tesla V100-SXM2-32GB | 60'C, 99 % | 28932 / 32510 MB +[6] Tesla V100-SXM2-32GB | 68'C, 100 % | 28916 / 32510 MB +[7] Tesla V100-SXM2-32GB | 65'C, 99 % | 28860 / 32510 MB + +# (Partial FC 0.1) gpustat -i +[0] Tesla V100-SXM2-32GB | 60'C, 95 % | 10488 / 32510 MB │······················· +[1] Tesla V100-SXM2-32GB | 60'C, 97 % | 10344 / 32510 MB │······················· +[2] Tesla V100-SXM2-32GB | 61'C, 95 % | 10340 / 32510 MB │······················· +[3] Tesla V100-SXM2-32GB | 66'C, 95 % | 10340 / 32510 MB │······················· +[4] Tesla V100-SXM2-32GB | 65'C, 94 % | 10356 / 32510 MB │······················· +[5] Tesla V100-SXM2-32GB | 61'C, 95 % | 10400 / 32510 MB │······················· +[6] Tesla V100-SXM2-32GB | 68'C, 96 % | 10384 / 32510 MB │······················· +[7] Tesla V100-SXM2-32GB | 64'C, 95 % | 10328 / 32510 MB │······················· +``` + +- Training Speed + +```python +# (Model Parallel) trainging.log +Training: Speed 2271.33 samples/sec Loss 1.1624 LearningRate 0.2000 Epoch: 0 Global Step: 100 +Training: Speed 2269.94 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150 +Training: Speed 2272.67 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200 +Training: Speed 2266.55 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250 +Training: Speed 2272.54 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300 + +# (Partial FC 0.1) trainging.log +Training: Speed 5299.56 samples/sec Loss 1.0965 LearningRate 0.2000 Epoch: 0 Global Step: 100 +Training: Speed 5296.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150 +Training: Speed 5304.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200 +Training: Speed 5274.43 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250 +Training: Speed 5300.10 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300 +``` + +In this test case, Partial FC 0.1 only use1 1/3 of the GPU memory of the model parallel, +and the training speed is 2.5 times faster than the model parallel. + + +## Speed Benchmark + +1. Training speed of different parallel methods (samples/second), Tesla V100 32GB * 8. (Larger is better) + +| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 | +| :--- | :--- | :--- | :--- | +|125000 | 4681 | 4824 | 5004 | +|250000 | 4047 | 4521 | 4976 | +|500000 | 3087 | 4013 | 4900 | +|1000000 | 2090 | 3449 | 4803 | +|1400000 | 1672 | 3043 | 4738 | +|2000000 | - | 2593 | 4626 | +|4000000 | - | 1748 | 4208 | +|5500000 | - | 1389 | 3975 | +|8000000 | - | - | 3565 | +|16000000 | - | - | 2679 | +|29000000 | - | - | 1855 | + +2. GPU memory cost of different parallel methods (GB per GPU), Tesla V100 32GB * 8. (Smaller is better) + +| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 | +| :--- | :--- | :--- | :--- | +|125000 | 7358 | 5306 | 4868 | +|250000 | 9940 | 5826 | 5004 | +|500000 | 14220 | 7114 | 5202 | +|1000000 | 23708 | 9966 | 5620 | +|1400000 | 32252 | 11178 | 6056 | +|2000000 | - | 13978 | 6472 | +|4000000 | - | 23238 | 8284 | +|5500000 | - | 32188 | 9854 | +|8000000 | - | - | 12310 | +|16000000 | - | - | 19950 | +|29000000 | - | - | 32324 | diff --git a/videoretalking/third_part/face3d/models/arcface_torch/eval/__init__.py b/videoretalking/third_part/face3d/models/arcface_torch/eval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/videoretalking/third_part/face3d/models/arcface_torch/eval/verification.py b/videoretalking/third_part/face3d/models/arcface_torch/eval/verification.py new file mode 100644 index 0000000000000000000000000000000000000000..5b1f5618184effae64895847af1a65d43d2e4418 --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/eval/verification.py @@ -0,0 +1,407 @@ +"""Helper for evaluation on the Labeled Faces in the Wild dataset +""" + +# MIT License +# +# Copyright (c) 2016 David Sandberg +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +import datetime +import os +import pickle + +import mxnet as mx +import numpy as np +import sklearn +import torch +from mxnet import ndarray as nd +from scipy import interpolate +from sklearn.decomposition import PCA +from sklearn.model_selection import KFold + + +class LFold: + def __init__(self, n_splits=2, shuffle=False): + self.n_splits = n_splits + if self.n_splits > 1: + self.k_fold = KFold(n_splits=n_splits, shuffle=shuffle) + + def split(self, indices): + if self.n_splits > 1: + return self.k_fold.split(indices) + else: + return [(indices, indices)] + + +def calculate_roc(thresholds, + embeddings1, + embeddings2, + actual_issame, + nrof_folds=10, + pca=0): + assert (embeddings1.shape[0] == embeddings2.shape[0]) + assert (embeddings1.shape[1] == embeddings2.shape[1]) + nrof_pairs = min(len(actual_issame), embeddings1.shape[0]) + nrof_thresholds = len(thresholds) + k_fold = LFold(n_splits=nrof_folds, shuffle=False) + + tprs = np.zeros((nrof_folds, nrof_thresholds)) + fprs = np.zeros((nrof_folds, nrof_thresholds)) + accuracy = np.zeros((nrof_folds)) + indices = np.arange(nrof_pairs) + + if pca == 0: + diff = np.subtract(embeddings1, embeddings2) + dist = np.sum(np.square(diff), 1) + + for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)): + if pca > 0: + print('doing pca on', fold_idx) + embed1_train = embeddings1[train_set] + embed2_train = embeddings2[train_set] + _embed_train = np.concatenate((embed1_train, embed2_train), axis=0) + pca_model = PCA(n_components=pca) + pca_model.fit(_embed_train) + embed1 = pca_model.transform(embeddings1) + embed2 = pca_model.transform(embeddings2) + embed1 = sklearn.preprocessing.normalize(embed1) + embed2 = sklearn.preprocessing.normalize(embed2) + diff = np.subtract(embed1, embed2) + dist = np.sum(np.square(diff), 1) + + # Find the best threshold for the fold + acc_train = np.zeros((nrof_thresholds)) + for threshold_idx, threshold in enumerate(thresholds): + _, _, acc_train[threshold_idx] = calculate_accuracy( + threshold, dist[train_set], actual_issame[train_set]) + best_threshold_index = np.argmax(acc_train) + for threshold_idx, threshold in enumerate(thresholds): + tprs[fold_idx, threshold_idx], fprs[fold_idx, threshold_idx], _ = calculate_accuracy( + threshold, dist[test_set], + actual_issame[test_set]) + _, _, accuracy[fold_idx] = calculate_accuracy( + thresholds[best_threshold_index], dist[test_set], + actual_issame[test_set]) + + tpr = np.mean(tprs, 0) + fpr = np.mean(fprs, 0) + return tpr, fpr, accuracy + + +def calculate_accuracy(threshold, dist, actual_issame): + predict_issame = np.less(dist, threshold) + tp = np.sum(np.logical_and(predict_issame, actual_issame)) + fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame))) + tn = np.sum( + np.logical_and(np.logical_not(predict_issame), + np.logical_not(actual_issame))) + fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame)) + + tpr = 0 if (tp + fn == 0) else float(tp) / float(tp + fn) + fpr = 0 if (fp + tn == 0) else float(fp) / float(fp + tn) + acc = float(tp + tn) / dist.size + return tpr, fpr, acc + + +def calculate_val(thresholds, + embeddings1, + embeddings2, + actual_issame, + far_target, + nrof_folds=10): + assert (embeddings1.shape[0] == embeddings2.shape[0]) + assert (embeddings1.shape[1] == embeddings2.shape[1]) + nrof_pairs = min(len(actual_issame), embeddings1.shape[0]) + nrof_thresholds = len(thresholds) + k_fold = LFold(n_splits=nrof_folds, shuffle=False) + + val = np.zeros(nrof_folds) + far = np.zeros(nrof_folds) + + diff = np.subtract(embeddings1, embeddings2) + dist = np.sum(np.square(diff), 1) + indices = np.arange(nrof_pairs) + + for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)): + + # Find the threshold that gives FAR = far_target + far_train = np.zeros(nrof_thresholds) + for threshold_idx, threshold in enumerate(thresholds): + _, far_train[threshold_idx] = calculate_val_far( + threshold, dist[train_set], actual_issame[train_set]) + if np.max(far_train) >= far_target: + f = interpolate.interp1d(far_train, thresholds, kind='slinear') + threshold = f(far_target) + else: + threshold = 0.0 + + val[fold_idx], far[fold_idx] = calculate_val_far( + threshold, dist[test_set], actual_issame[test_set]) + + val_mean = np.mean(val) + far_mean = np.mean(far) + val_std = np.std(val) + return val_mean, val_std, far_mean + + +def calculate_val_far(threshold, dist, actual_issame): + predict_issame = np.less(dist, threshold) + true_accept = np.sum(np.logical_and(predict_issame, actual_issame)) + false_accept = np.sum( + np.logical_and(predict_issame, np.logical_not(actual_issame))) + n_same = np.sum(actual_issame) + n_diff = np.sum(np.logical_not(actual_issame)) + # print(true_accept, false_accept) + # print(n_same, n_diff) + val = float(true_accept) / float(n_same) + far = float(false_accept) / float(n_diff) + return val, far + + +def evaluate(embeddings, actual_issame, nrof_folds=10, pca=0): + # Calculate evaluation metrics + thresholds = np.arange(0, 4, 0.01) + embeddings1 = embeddings[0::2] + embeddings2 = embeddings[1::2] + tpr, fpr, accuracy = calculate_roc(thresholds, + embeddings1, + embeddings2, + np.asarray(actual_issame), + nrof_folds=nrof_folds, + pca=pca) + thresholds = np.arange(0, 4, 0.001) + val, val_std, far = calculate_val(thresholds, + embeddings1, + embeddings2, + np.asarray(actual_issame), + 1e-3, + nrof_folds=nrof_folds) + return tpr, fpr, accuracy, val, val_std, far + +@torch.no_grad() +def load_bin(path, image_size): + try: + with open(path, 'rb') as f: + bins, issame_list = pickle.load(f) # py2 + except UnicodeDecodeError as e: + with open(path, 'rb') as f: + bins, issame_list = pickle.load(f, encoding='bytes') # py3 + data_list = [] + for flip in [0, 1]: + data = torch.empty((len(issame_list) * 2, 3, image_size[0], image_size[1])) + data_list.append(data) + for idx in range(len(issame_list) * 2): + _bin = bins[idx] + img = mx.image.imdecode(_bin) + if img.shape[1] != image_size[0]: + img = mx.image.resize_short(img, image_size[0]) + img = nd.transpose(img, axes=(2, 0, 1)) + for flip in [0, 1]: + if flip == 1: + img = mx.ndarray.flip(data=img, axis=2) + data_list[flip][idx][:] = torch.from_numpy(img.asnumpy()) + if idx % 1000 == 0: + print('loading bin', idx) + print(data_list[0].shape) + return data_list, issame_list + +@torch.no_grad() +def test(data_set, backbone, batch_size, nfolds=10): + print('testing verification..') + data_list = data_set[0] + issame_list = data_set[1] + embeddings_list = [] + time_consumed = 0.0 + for i in range(len(data_list)): + data = data_list[i] + embeddings = None + ba = 0 + while ba < data.shape[0]: + bb = min(ba + batch_size, data.shape[0]) + count = bb - ba + _data = data[bb - batch_size: bb] + time0 = datetime.datetime.now() + img = ((_data / 255) - 0.5) / 0.5 + net_out: torch.Tensor = backbone(img) + _embeddings = net_out.detach().cpu().numpy() + time_now = datetime.datetime.now() + diff = time_now - time0 + time_consumed += diff.total_seconds() + if embeddings is None: + embeddings = np.zeros((data.shape[0], _embeddings.shape[1])) + embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :] + ba = bb + embeddings_list.append(embeddings) + + _xnorm = 0.0 + _xnorm_cnt = 0 + for embed in embeddings_list: + for i in range(embed.shape[0]): + _em = embed[i] + _norm = np.linalg.norm(_em) + _xnorm += _norm + _xnorm_cnt += 1 + _xnorm /= _xnorm_cnt + + acc1 = 0.0 + std1 = 0.0 + embeddings = embeddings_list[0] + embeddings_list[1] + embeddings = sklearn.preprocessing.normalize(embeddings) + print(embeddings.shape) + print('infer time', time_consumed) + _, _, accuracy, val, val_std, far = evaluate(embeddings, issame_list, nrof_folds=nfolds) + acc2, std2 = np.mean(accuracy), np.std(accuracy) + return acc1, std1, acc2, std2, _xnorm, embeddings_list + + +def dumpR(data_set, + backbone, + batch_size, + name='', + data_extra=None, + label_shape=None): + print('dump verification embedding..') + data_list = data_set[0] + issame_list = data_set[1] + embeddings_list = [] + time_consumed = 0.0 + for i in range(len(data_list)): + data = data_list[i] + embeddings = None + ba = 0 + while ba < data.shape[0]: + bb = min(ba + batch_size, data.shape[0]) + count = bb - ba + + _data = nd.slice_axis(data, axis=0, begin=bb - batch_size, end=bb) + time0 = datetime.datetime.now() + if data_extra is None: + db = mx.io.DataBatch(data=(_data,), label=(_label,)) + else: + db = mx.io.DataBatch(data=(_data, _data_extra), + label=(_label,)) + model.forward(db, is_train=False) + net_out = model.get_outputs() + _embeddings = net_out[0].asnumpy() + time_now = datetime.datetime.now() + diff = time_now - time0 + time_consumed += diff.total_seconds() + if embeddings is None: + embeddings = np.zeros((data.shape[0], _embeddings.shape[1])) + embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :] + ba = bb + embeddings_list.append(embeddings) + embeddings = embeddings_list[0] + embeddings_list[1] + embeddings = sklearn.preprocessing.normalize(embeddings) + actual_issame = np.asarray(issame_list) + outname = os.path.join('temp.bin') + with open(outname, 'wb') as f: + pickle.dump((embeddings, issame_list), + f, + protocol=pickle.HIGHEST_PROTOCOL) + + +# if __name__ == '__main__': +# +# parser = argparse.ArgumentParser(description='do verification') +# # general +# parser.add_argument('--data-dir', default='', help='') +# parser.add_argument('--model', +# default='../model/softmax,50', +# help='path to load model.') +# parser.add_argument('--target', +# default='lfw,cfp_ff,cfp_fp,agedb_30', +# help='test targets.') +# parser.add_argument('--gpu', default=0, type=int, help='gpu id') +# parser.add_argument('--batch-size', default=32, type=int, help='') +# parser.add_argument('--max', default='', type=str, help='') +# parser.add_argument('--mode', default=0, type=int, help='') +# parser.add_argument('--nfolds', default=10, type=int, help='') +# args = parser.parse_args() +# image_size = [112, 112] +# print('image_size', image_size) +# ctx = mx.gpu(args.gpu) +# nets = [] +# vec = args.model.split(',') +# prefix = args.model.split(',')[0] +# epochs = [] +# if len(vec) == 1: +# pdir = os.path.dirname(prefix) +# for fname in os.listdir(pdir): +# if not fname.endswith('.params'): +# continue +# _file = os.path.join(pdir, fname) +# if _file.startswith(prefix): +# epoch = int(fname.split('.')[0].split('-')[1]) +# epochs.append(epoch) +# epochs = sorted(epochs, reverse=True) +# if len(args.max) > 0: +# _max = [int(x) for x in args.max.split(',')] +# assert len(_max) == 2 +# if len(epochs) > _max[1]: +# epochs = epochs[_max[0]:_max[1]] +# +# else: +# epochs = [int(x) for x in vec[1].split('|')] +# print('model number', len(epochs)) +# time0 = datetime.datetime.now() +# for epoch in epochs: +# print('loading', prefix, epoch) +# sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch) +# # arg_params, aux_params = ch_dev(arg_params, aux_params, ctx) +# all_layers = sym.get_internals() +# sym = all_layers['fc1_output'] +# model = mx.mod.Module(symbol=sym, context=ctx, label_names=None) +# # model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))]) +# model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], +# image_size[1]))]) +# model.set_params(arg_params, aux_params) +# nets.append(model) +# time_now = datetime.datetime.now() +# diff = time_now - time0 +# print('model loading time', diff.total_seconds()) +# +# ver_list = [] +# ver_name_list = [] +# for name in args.target.split(','): +# path = os.path.join(args.data_dir, name + ".bin") +# if os.path.exists(path): +# print('loading.. ', name) +# data_set = load_bin(path, image_size) +# ver_list.append(data_set) +# ver_name_list.append(name) +# +# if args.mode == 0: +# for i in range(len(ver_list)): +# results = [] +# for model in nets: +# acc1, std1, acc2, std2, xnorm, embeddings_list = test( +# ver_list[i], model, args.batch_size, args.nfolds) +# print('[%s]XNorm: %f' % (ver_name_list[i], xnorm)) +# print('[%s]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], acc1, std1)) +# print('[%s]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], acc2, std2)) +# results.append(acc2) +# print('Max of [%s] is %1.5f' % (ver_name_list[i], np.max(results))) +# elif args.mode == 1: +# raise ValueError +# else: +# model = nets[0] +# dumpR(ver_list[0], model, args.batch_size, args.target) diff --git a/videoretalking/third_part/face3d/models/arcface_torch/eval_ijbc.py b/videoretalking/third_part/face3d/models/arcface_torch/eval_ijbc.py new file mode 100644 index 0000000000000000000000000000000000000000..64844c4723a88b4b160d2fee9a7b626b987981d9 --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/eval_ijbc.py @@ -0,0 +1,483 @@ +# coding: utf-8 + +import os +import pickle + +import matplotlib +import pandas as pd + +matplotlib.use('Agg') +import matplotlib.pyplot as plt +import timeit +import sklearn +import argparse +import cv2 +import numpy as np +import torch +from skimage import transform as trans +from backbones import get_model +from sklearn.metrics import roc_curve, auc + +from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap +from prettytable import PrettyTable +from pathlib import Path + +import sys +import warnings + +sys.path.insert(0, "../") +warnings.filterwarnings("ignore") + +parser = argparse.ArgumentParser(description='do ijb test') +# general +parser.add_argument('--model-prefix', default='', help='path to load model.') +parser.add_argument('--image-path', default='', type=str, help='') +parser.add_argument('--result-dir', default='.', type=str, help='') +parser.add_argument('--batch-size', default=128, type=int, help='') +parser.add_argument('--network', default='iresnet50', type=str, help='') +parser.add_argument('--job', default='insightface', type=str, help='job name') +parser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB') +args = parser.parse_args() + +target = args.target +model_path = args.model_prefix +image_path = args.image_path +result_dir = args.result_dir +gpu_id = None +use_norm_score = True # if Ture, TestMode(N1) +use_detector_score = True # if Ture, TestMode(D1) +use_flip_test = True # if Ture, TestMode(F1) +job = args.job +batch_size = args.batch_size + + +class Embedding(object): + def __init__(self, prefix, data_shape, batch_size=1): + image_size = (112, 112) + self.image_size = image_size + weight = torch.load(prefix) + resnet = get_model(args.network, dropout=0, fp16=False).cuda() + resnet.load_state_dict(weight) + model = torch.nn.DataParallel(resnet) + self.model = model + self.model.eval() + src = np.array([ + [30.2946, 51.6963], + [65.5318, 51.5014], + [48.0252, 71.7366], + [33.5493, 92.3655], + [62.7299, 92.2041]], dtype=np.float32) + src[:, 0] += 8.0 + self.src = src + self.batch_size = batch_size + self.data_shape = data_shape + + def get(self, rimg, landmark): + + assert landmark.shape[0] == 68 or landmark.shape[0] == 5 + assert landmark.shape[1] == 2 + if landmark.shape[0] == 68: + landmark5 = np.zeros((5, 2), dtype=np.float32) + landmark5[0] = (landmark[36] + landmark[39]) / 2 + landmark5[1] = (landmark[42] + landmark[45]) / 2 + landmark5[2] = landmark[30] + landmark5[3] = landmark[48] + landmark5[4] = landmark[54] + else: + landmark5 = landmark + tform = trans.SimilarityTransform() + tform.estimate(landmark5, self.src) + M = tform.params[0:2, :] + img = cv2.warpAffine(rimg, + M, (self.image_size[1], self.image_size[0]), + borderValue=0.0) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img_flip = np.fliplr(img) + img = np.transpose(img, (2, 0, 1)) # 3*112*112, RGB + img_flip = np.transpose(img_flip, (2, 0, 1)) + input_blob = np.zeros((2, 3, self.image_size[1], self.image_size[0]), dtype=np.uint8) + input_blob[0] = img + input_blob[1] = img_flip + return input_blob + + @torch.no_grad() + def forward_db(self, batch_data): + imgs = torch.Tensor(batch_data).cuda() + imgs.div_(255).sub_(0.5).div_(0.5) + feat = self.model(imgs) + feat = feat.reshape([self.batch_size, 2 * feat.shape[1]]) + return feat.cpu().numpy() + + +# 将一个list尽量均分成n份,限制len(list)==n,份数大于原list内元素个数则分配空list[] +def divideIntoNstrand(listTemp, n): + twoList = [[] for i in range(n)] + for i, e in enumerate(listTemp): + twoList[i % n].append(e) + return twoList + + +def read_template_media_list(path): + # ijb_meta = np.loadtxt(path, dtype=str) + ijb_meta = pd.read_csv(path, sep=' ', header=None).values + templates = ijb_meta[:, 1].astype(np.int) + medias = ijb_meta[:, 2].astype(np.int) + return templates, medias + + +# In[ ]: + + +def read_template_pair_list(path): + # pairs = np.loadtxt(path, dtype=str) + pairs = pd.read_csv(path, sep=' ', header=None).values + # print(pairs.shape) + # print(pairs[:, 0].astype(np.int)) + t1 = pairs[:, 0].astype(np.int) + t2 = pairs[:, 1].astype(np.int) + label = pairs[:, 2].astype(np.int) + return t1, t2, label + + +# In[ ]: + + +def read_image_feature(path): + with open(path, 'rb') as fid: + img_feats = pickle.load(fid) + return img_feats + + +# In[ ]: + + +def get_image_feature(img_path, files_list, model_path, epoch, gpu_id): + batch_size = args.batch_size + data_shape = (3, 112, 112) + + files = files_list + print('files:', len(files)) + rare_size = len(files) % batch_size + faceness_scores = [] + batch = 0 + img_feats = np.empty((len(files), 1024), dtype=np.float32) + + batch_data = np.empty((2 * batch_size, 3, 112, 112)) + embedding = Embedding(model_path, data_shape, batch_size) + for img_index, each_line in enumerate(files[:len(files) - rare_size]): + name_lmk_score = each_line.strip().split(' ') + img_name = os.path.join(img_path, name_lmk_score[0]) + img = cv2.imread(img_name) + lmk = np.array([float(x) for x in name_lmk_score[1:-1]], + dtype=np.float32) + lmk = lmk.reshape((5, 2)) + input_blob = embedding.get(img, lmk) + + batch_data[2 * (img_index - batch * batch_size)][:] = input_blob[0] + batch_data[2 * (img_index - batch * batch_size) + 1][:] = input_blob[1] + if (img_index + 1) % batch_size == 0: + print('batch', batch) + img_feats[batch * batch_size:batch * batch_size + + batch_size][:] = embedding.forward_db(batch_data) + batch += 1 + faceness_scores.append(name_lmk_score[-1]) + + batch_data = np.empty((2 * rare_size, 3, 112, 112)) + embedding = Embedding(model_path, data_shape, rare_size) + for img_index, each_line in enumerate(files[len(files) - rare_size:]): + name_lmk_score = each_line.strip().split(' ') + img_name = os.path.join(img_path, name_lmk_score[0]) + img = cv2.imread(img_name) + lmk = np.array([float(x) for x in name_lmk_score[1:-1]], + dtype=np.float32) + lmk = lmk.reshape((5, 2)) + input_blob = embedding.get(img, lmk) + batch_data[2 * img_index][:] = input_blob[0] + batch_data[2 * img_index + 1][:] = input_blob[1] + if (img_index + 1) % rare_size == 0: + print('batch', batch) + img_feats[len(files) - + rare_size:][:] = embedding.forward_db(batch_data) + batch += 1 + faceness_scores.append(name_lmk_score[-1]) + faceness_scores = np.array(faceness_scores).astype(np.float32) + # img_feats = np.ones( (len(files), 1024), dtype=np.float32) * 0.01 + # faceness_scores = np.ones( (len(files), ), dtype=np.float32 ) + return img_feats, faceness_scores + + +# In[ ]: + + +def image2template_feature(img_feats=None, templates=None, medias=None): + # ========================================================== + # 1. face image feature l2 normalization. img_feats:[number_image x feats_dim] + # 2. compute media feature. + # 3. compute template feature. + # ========================================================== + unique_templates = np.unique(templates) + template_feats = np.zeros((len(unique_templates), img_feats.shape[1])) + + for count_template, uqt in enumerate(unique_templates): + + (ind_t,) = np.where(templates == uqt) + face_norm_feats = img_feats[ind_t] + face_medias = medias[ind_t] + unique_medias, unique_media_counts = np.unique(face_medias, + return_counts=True) + media_norm_feats = [] + for u, ct in zip(unique_medias, unique_media_counts): + (ind_m,) = np.where(face_medias == u) + if ct == 1: + media_norm_feats += [face_norm_feats[ind_m]] + else: # image features from the same video will be aggregated into one feature + media_norm_feats += [ + np.mean(face_norm_feats[ind_m], axis=0, keepdims=True) + ] + media_norm_feats = np.array(media_norm_feats) + # media_norm_feats = media_norm_feats / np.sqrt(np.sum(media_norm_feats ** 2, -1, keepdims=True)) + template_feats[count_template] = np.sum(media_norm_feats, axis=0) + if count_template % 2000 == 0: + print('Finish Calculating {} template features.'.format( + count_template)) + # template_norm_feats = template_feats / np.sqrt(np.sum(template_feats ** 2, -1, keepdims=True)) + template_norm_feats = sklearn.preprocessing.normalize(template_feats) + # print(template_norm_feats.shape) + return template_norm_feats, unique_templates + + +# In[ ]: + + +def verification(template_norm_feats=None, + unique_templates=None, + p1=None, + p2=None): + # ========================================================== + # Compute set-to-set Similarity Score. + # ========================================================== + template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int) + for count_template, uqt in enumerate(unique_templates): + template2id[uqt] = count_template + + score = np.zeros((len(p1),)) # save cosine distance between pairs + + total_pairs = np.array(range(len(p1))) + batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation + sublists = [ + total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize) + ] + total_sublists = len(sublists) + for c, s in enumerate(sublists): + feat1 = template_norm_feats[template2id[p1[s]]] + feat2 = template_norm_feats[template2id[p2[s]]] + similarity_score = np.sum(feat1 * feat2, -1) + score[s] = similarity_score.flatten() + if c % 10 == 0: + print('Finish {}/{} pairs.'.format(c, total_sublists)) + return score + + +# In[ ]: +def verification2(template_norm_feats=None, + unique_templates=None, + p1=None, + p2=None): + template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int) + for count_template, uqt in enumerate(unique_templates): + template2id[uqt] = count_template + score = np.zeros((len(p1),)) # save cosine distance between pairs + total_pairs = np.array(range(len(p1))) + batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation + sublists = [ + total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize) + ] + total_sublists = len(sublists) + for c, s in enumerate(sublists): + feat1 = template_norm_feats[template2id[p1[s]]] + feat2 = template_norm_feats[template2id[p2[s]]] + similarity_score = np.sum(feat1 * feat2, -1) + score[s] = similarity_score.flatten() + if c % 10 == 0: + print('Finish {}/{} pairs.'.format(c, total_sublists)) + return score + + +def read_score(path): + with open(path, 'rb') as fid: + img_feats = pickle.load(fid) + return img_feats + + +# # Step1: Load Meta Data + +# In[ ]: + +assert target == 'IJBC' or target == 'IJBB' + +# ============================================================= +# load image and template relationships for template feature embedding +# tid --> template id, mid --> media id +# format: +# image_name tid mid +# ============================================================= +start = timeit.default_timer() +templates, medias = read_template_media_list( + os.path.join('%s/meta' % image_path, + '%s_face_tid_mid.txt' % target.lower())) +stop = timeit.default_timer() +print('Time: %.2f s. ' % (stop - start)) + +# In[ ]: + +# ============================================================= +# load template pairs for template-to-template verification +# tid : template id, label : 1/0 +# format: +# tid_1 tid_2 label +# ============================================================= +start = timeit.default_timer() +p1, p2, label = read_template_pair_list( + os.path.join('%s/meta' % image_path, + '%s_template_pair_label.txt' % target.lower())) +stop = timeit.default_timer() +print('Time: %.2f s. ' % (stop - start)) + +# # Step 2: Get Image Features + +# In[ ]: + +# ============================================================= +# load image features +# format: +# img_feats: [image_num x feats_dim] (227630, 512) +# ============================================================= +start = timeit.default_timer() +img_path = '%s/loose_crop' % image_path +img_list_path = '%s/meta/%s_name_5pts_score.txt' % (image_path, target.lower()) +img_list = open(img_list_path) +files = img_list.readlines() +# files_list = divideIntoNstrand(files, rank_size) +files_list = files + +# img_feats +# for i in range(rank_size): +img_feats, faceness_scores = get_image_feature(img_path, files_list, + model_path, 0, gpu_id) +stop = timeit.default_timer() +print('Time: %.2f s. ' % (stop - start)) +print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0], + img_feats.shape[1])) + +# # Step3: Get Template Features + +# In[ ]: + +# ============================================================= +# compute template features from image features. +# ============================================================= +start = timeit.default_timer() +# ========================================================== +# Norm feature before aggregation into template feature? +# Feature norm from embedding network and faceness score are able to decrease weights for noise samples (not face). +# ========================================================== +# 1. FaceScore (Feature Norm) +# 2. FaceScore (Detector) + +if use_flip_test: + # concat --- F1 + # img_input_feats = img_feats + # add --- F2 + img_input_feats = img_feats[:, 0:img_feats.shape[1] // + 2] + img_feats[:, img_feats.shape[1] // 2:] +else: + img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2] + +if use_norm_score: + img_input_feats = img_input_feats +else: + # normalise features to remove norm information + img_input_feats = img_input_feats / np.sqrt( + np.sum(img_input_feats ** 2, -1, keepdims=True)) + +if use_detector_score: + print(img_input_feats.shape, faceness_scores.shape) + img_input_feats = img_input_feats * faceness_scores[:, np.newaxis] +else: + img_input_feats = img_input_feats + +template_norm_feats, unique_templates = image2template_feature( + img_input_feats, templates, medias) +stop = timeit.default_timer() +print('Time: %.2f s. ' % (stop - start)) + +# # Step 4: Get Template Similarity Scores + +# In[ ]: + +# ============================================================= +# compute verification scores between template pairs. +# ============================================================= +start = timeit.default_timer() +score = verification(template_norm_feats, unique_templates, p1, p2) +stop = timeit.default_timer() +print('Time: %.2f s. ' % (stop - start)) + +# In[ ]: +save_path = os.path.join(result_dir, args.job) +# save_path = result_dir + '/%s_result' % target + +if not os.path.exists(save_path): + os.makedirs(save_path) + +score_save_file = os.path.join(save_path, "%s.npy" % target.lower()) +np.save(score_save_file, score) + +# # Step 5: Get ROC Curves and TPR@FPR Table + +# In[ ]: + +files = [score_save_file] +methods = [] +scores = [] +for file in files: + methods.append(Path(file).stem) + scores.append(np.load(file)) + +methods = np.array(methods) +scores = dict(zip(methods, scores)) +colours = dict( + zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2'))) +x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1] +tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels]) +fig = plt.figure() +for method in methods: + fpr, tpr, _ = roc_curve(label, scores[method]) + roc_auc = auc(fpr, tpr) + fpr = np.flipud(fpr) + tpr = np.flipud(tpr) # select largest tpr at same fpr + plt.plot(fpr, + tpr, + color=colours[method], + lw=1, + label=('[%s (AUC = %0.4f %%)]' % + (method.split('-')[-1], roc_auc * 100))) + tpr_fpr_row = [] + tpr_fpr_row.append("%s-%s" % (method, target)) + for fpr_iter in np.arange(len(x_labels)): + _, min_index = min( + list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr))))) + tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100)) + tpr_fpr_table.add_row(tpr_fpr_row) +plt.xlim([10 ** -6, 0.1]) +plt.ylim([0.3, 1.0]) +plt.grid(linestyle='--', linewidth=1) +plt.xticks(x_labels) +plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True)) +plt.xscale('log') +plt.xlabel('False Positive Rate') +plt.ylabel('True Positive Rate') +plt.title('ROC on IJB') +plt.legend(loc="lower right") +fig.savefig(os.path.join(save_path, '%s.pdf' % target.lower())) +print(tpr_fpr_table) diff --git a/videoretalking/third_part/face3d/models/arcface_torch/inference.py b/videoretalking/third_part/face3d/models/arcface_torch/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..1929d4abb640d040398dda57b491b9bd96deac9d --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/inference.py @@ -0,0 +1,35 @@ +import argparse + +import cv2 +import numpy as np +import torch + +from backbones import get_model + + +@torch.no_grad() +def inference(weight, name, img): + if img is None: + img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.uint8) + else: + img = cv2.imread(img) + img = cv2.resize(img, (112, 112)) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = np.transpose(img, (2, 0, 1)) + img = torch.from_numpy(img).unsqueeze(0).float() + img.div_(255).sub_(0.5).div_(0.5) + net = get_model(name, fp16=False) + net.load_state_dict(torch.load(weight)) + net.eval() + feat = net(img).numpy() + print(feat) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='PyTorch ArcFace Training') + parser.add_argument('--network', type=str, default='r50', help='backbone network') + parser.add_argument('--weight', type=str, default='') + parser.add_argument('--img', type=str, default=None) + args = parser.parse_args() + inference(args.weight, args.network, args.img) diff --git a/videoretalking/third_part/face3d/models/arcface_torch/losses.py b/videoretalking/third_part/face3d/models/arcface_torch/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..7bfdd8c6b7f6b0d465928f19c554e62340e5ad7b --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/losses.py @@ -0,0 +1,42 @@ +import torch +from torch import nn + + +def get_loss(name): + if name == "cosface": + return CosFace() + elif name == "arcface": + return ArcFace() + else: + raise ValueError() + + +class CosFace(nn.Module): + def __init__(self, s=64.0, m=0.40): + super(CosFace, self).__init__() + self.s = s + self.m = m + + def forward(self, cosine, label): + index = torch.where(label != -1)[0] + m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device) + m_hot.scatter_(1, label[index, None], self.m) + cosine[index] -= m_hot + ret = cosine * self.s + return ret + + +class ArcFace(nn.Module): + def __init__(self, s=64.0, m=0.5): + super(ArcFace, self).__init__() + self.s = s + self.m = m + + def forward(self, cosine: torch.Tensor, label): + index = torch.where(label != -1)[0] + m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device) + m_hot.scatter_(1, label[index, None], self.m) + cosine.acos_() + cosine[index] += m_hot + cosine.cos_().mul_(self.s) + return cosine diff --git a/videoretalking/third_part/face3d/models/arcface_torch/onnx_helper.py b/videoretalking/third_part/face3d/models/arcface_torch/onnx_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..4a01a46621dc0ea695bd903de5d1e212d424c860 --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/onnx_helper.py @@ -0,0 +1,250 @@ +from __future__ import division +import datetime +import os +import os.path as osp +import glob +import numpy as np +import cv2 +import sys +import onnxruntime +import onnx +import argparse +from onnx import numpy_helper +from insightface.data import get_image + +class ArcFaceORT: + def __init__(self, model_path, cpu=False): + self.model_path = model_path + # providers = None will use available provider, for onnxruntime-gpu it will be "CUDAExecutionProvider" + self.providers = ['CPUExecutionProvider'] if cpu else None + + #input_size is (w,h), return error message, return None if success + def check(self, track='cfat', test_img = None): + #default is cfat + max_model_size_mb=1024 + max_feat_dim=512 + max_time_cost=15 + if track.startswith('ms1m'): + max_model_size_mb=1024 + max_feat_dim=512 + max_time_cost=10 + elif track.startswith('glint'): + max_model_size_mb=1024 + max_feat_dim=1024 + max_time_cost=20 + elif track.startswith('cfat'): + max_model_size_mb = 1024 + max_feat_dim = 512 + max_time_cost = 15 + elif track.startswith('unconstrained'): + max_model_size_mb=1024 + max_feat_dim=1024 + max_time_cost=30 + else: + return "track not found" + + if not os.path.exists(self.model_path): + return "model_path not exists" + if not os.path.isdir(self.model_path): + return "model_path should be directory" + onnx_files = [] + for _file in os.listdir(self.model_path): + if _file.endswith('.onnx'): + onnx_files.append(osp.join(self.model_path, _file)) + if len(onnx_files)==0: + return "do not have onnx files" + self.model_file = sorted(onnx_files)[-1] + print('use onnx-model:', self.model_file) + try: + session = onnxruntime.InferenceSession(self.model_file, providers=self.providers) + except: + return "load onnx failed" + input_cfg = session.get_inputs()[0] + input_shape = input_cfg.shape + print('input-shape:', input_shape) + if len(input_shape)!=4: + return "length of input_shape should be 4" + if not isinstance(input_shape[0], str): + #return "input_shape[0] should be str to support batch-inference" + print('reset input-shape[0] to None') + model = onnx.load(self.model_file) + model.graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None' + new_model_file = osp.join(self.model_path, 'zzzzrefined.onnx') + onnx.save(model, new_model_file) + self.model_file = new_model_file + print('use new onnx-model:', self.model_file) + try: + session = onnxruntime.InferenceSession(self.model_file, providers=self.providers) + except: + return "load onnx failed" + input_cfg = session.get_inputs()[0] + input_shape = input_cfg.shape + print('new-input-shape:', input_shape) + + self.image_size = tuple(input_shape[2:4][::-1]) + #print('image_size:', self.image_size) + input_name = input_cfg.name + outputs = session.get_outputs() + output_names = [] + for o in outputs: + output_names.append(o.name) + #print(o.name, o.shape) + if len(output_names)!=1: + return "number of output nodes should be 1" + self.session = session + self.input_name = input_name + self.output_names = output_names + #print(self.output_names) + model = onnx.load(self.model_file) + graph = model.graph + if len(graph.node)<8: + return "too small onnx graph" + + input_size = (112,112) + self.crop = None + if track=='cfat': + crop_file = osp.join(self.model_path, 'crop.txt') + if osp.exists(crop_file): + lines = open(crop_file,'r').readlines() + if len(lines)!=6: + return "crop.txt should contain 6 lines" + lines = [int(x) for x in lines] + self.crop = lines[:4] + input_size = tuple(lines[4:6]) + if input_size!=self.image_size: + return "input-size is inconsistant with onnx model input, %s vs %s"%(input_size, self.image_size) + + self.model_size_mb = os.path.getsize(self.model_file) / float(1024*1024) + if self.model_size_mb > max_model_size_mb: + return "max model size exceed, given %.3f-MB"%self.model_size_mb + + input_mean = None + input_std = None + if track=='cfat': + pn_file = osp.join(self.model_path, 'pixel_norm.txt') + if osp.exists(pn_file): + lines = open(pn_file,'r').readlines() + if len(lines)!=2: + return "pixel_norm.txt should contain 2 lines" + input_mean = float(lines[0]) + input_std = float(lines[1]) + if input_mean is not None or input_std is not None: + if input_mean is None or input_std is None: + return "please set input_mean and input_std simultaneously" + else: + find_sub = False + find_mul = False + for nid, node in enumerate(graph.node[:8]): + print(nid, node.name) + if node.name.startswith('Sub') or node.name.startswith('_minus'): + find_sub = True + if node.name.startswith('Mul') or node.name.startswith('_mul') or node.name.startswith('Div'): + find_mul = True + if find_sub and find_mul: + print("find sub and mul") + #mxnet arcface model + input_mean = 0.0 + input_std = 1.0 + else: + input_mean = 127.5 + input_std = 127.5 + self.input_mean = input_mean + self.input_std = input_std + for initn in graph.initializer: + weight_array = numpy_helper.to_array(initn) + dt = weight_array.dtype + if dt.itemsize<4: + return 'invalid weight type - (%s:%s)' % (initn.name, dt.name) + if test_img is None: + test_img = get_image('Tom_Hanks_54745') + test_img = cv2.resize(test_img, self.image_size) + else: + test_img = cv2.resize(test_img, self.image_size) + feat, cost = self.benchmark(test_img) + batch_result = self.check_batch(test_img) + batch_result_sum = float(np.sum(batch_result)) + if batch_result_sum in [float('inf'), -float('inf')] or batch_result_sum != batch_result_sum: + print(batch_result) + print(batch_result_sum) + return "batch result output contains NaN!" + + if len(feat.shape) < 2: + return "the shape of the feature must be two, but get {}".format(str(feat.shape)) + + if feat.shape[1] > max_feat_dim: + return "max feat dim exceed, given %d"%feat.shape[1] + self.feat_dim = feat.shape[1] + cost_ms = cost*1000 + if cost_ms>max_time_cost: + return "max time cost exceed, given %.4f"%cost_ms + self.cost_ms = cost_ms + print('check stat:, model-size-mb: %.4f, feat-dim: %d, time-cost-ms: %.4f, input-mean: %.3f, input-std: %.3f'%(self.model_size_mb, self.feat_dim, self.cost_ms, self.input_mean, self.input_std)) + return None + + def check_batch(self, img): + if not isinstance(img, list): + imgs = [img, ] * 32 + if self.crop is not None: + nimgs = [] + for img in imgs: + nimg = img[self.crop[1]:self.crop[3], self.crop[0]:self.crop[2], :] + if nimg.shape[0] != self.image_size[1] or nimg.shape[1] != self.image_size[0]: + nimg = cv2.resize(nimg, self.image_size) + nimgs.append(nimg) + imgs = nimgs + blob = cv2.dnn.blobFromImages( + images=imgs, scalefactor=1.0 / self.input_std, size=self.image_size, + mean=(self.input_mean, self.input_mean, self.input_mean), swapRB=True) + net_out = self.session.run(self.output_names, {self.input_name: blob})[0] + return net_out + + + def meta_info(self): + return {'model-size-mb':self.model_size_mb, 'feature-dim':self.feat_dim, 'infer': self.cost_ms} + + + def forward(self, imgs): + if not isinstance(imgs, list): + imgs = [imgs] + input_size = self.image_size + if self.crop is not None: + nimgs = [] + for img in imgs: + nimg = img[self.crop[1]:self.crop[3],self.crop[0]:self.crop[2],:] + if nimg.shape[0]!=input_size[1] or nimg.shape[1]!=input_size[0]: + nimg = cv2.resize(nimg, input_size) + nimgs.append(nimg) + imgs = nimgs + blob = cv2.dnn.blobFromImages(imgs, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) + net_out = self.session.run(self.output_names, {self.input_name : blob})[0] + return net_out + + def benchmark(self, img): + input_size = self.image_size + if self.crop is not None: + nimg = img[self.crop[1]:self.crop[3],self.crop[0]:self.crop[2],:] + if nimg.shape[0]!=input_size[1] or nimg.shape[1]!=input_size[0]: + nimg = cv2.resize(nimg, input_size) + img = nimg + blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) + costs = [] + for _ in range(50): + ta = datetime.datetime.now() + net_out = self.session.run(self.output_names, {self.input_name : blob})[0] + tb = datetime.datetime.now() + cost = (tb-ta).total_seconds() + costs.append(cost) + costs = sorted(costs) + cost = costs[5] + return net_out, cost + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='') + # general + parser.add_argument('workdir', help='submitted work dir', type=str) + parser.add_argument('--track', help='track name, for different challenge', type=str, default='cfat') + args = parser.parse_args() + handler = ArcFaceORT(args.workdir) + err = handler.check(args.track) + print('err:', err) diff --git a/videoretalking/third_part/face3d/models/arcface_torch/onnx_ijbc.py b/videoretalking/third_part/face3d/models/arcface_torch/onnx_ijbc.py new file mode 100644 index 0000000000000000000000000000000000000000..aa96b96745e23d4d6642d99f71456c10af5e4e4e --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/onnx_ijbc.py @@ -0,0 +1,267 @@ +import argparse +import os +import pickle +import timeit + +import cv2 +import mxnet as mx +import numpy as np +import pandas as pd +import prettytable +import skimage.transform +from sklearn.metrics import roc_curve +from sklearn.preprocessing import normalize + +from onnx_helper import ArcFaceORT + +SRC = np.array( + [ + [30.2946, 51.6963], + [65.5318, 51.5014], + [48.0252, 71.7366], + [33.5493, 92.3655], + [62.7299, 92.2041]] + , dtype=np.float32) +SRC[:, 0] += 8.0 + + +class AlignedDataSet(mx.gluon.data.Dataset): + def __init__(self, root, lines, align=True): + self.lines = lines + self.root = root + self.align = align + + def __len__(self): + return len(self.lines) + + def __getitem__(self, idx): + each_line = self.lines[idx] + name_lmk_score = each_line.strip().split(' ') + name = os.path.join(self.root, name_lmk_score[0]) + img = cv2.cvtColor(cv2.imread(name), cv2.COLOR_BGR2RGB) + landmark5 = np.array([float(x) for x in name_lmk_score[1:-1]], dtype=np.float32).reshape((5, 2)) + st = skimage.transform.SimilarityTransform() + st.estimate(landmark5, SRC) + img = cv2.warpAffine(img, st.params[0:2, :], (112, 112), borderValue=0.0) + img_1 = np.expand_dims(img, 0) + img_2 = np.expand_dims(np.fliplr(img), 0) + output = np.concatenate((img_1, img_2), axis=0).astype(np.float32) + output = np.transpose(output, (0, 3, 1, 2)) + output = mx.nd.array(output) + return output + + +def extract(model_root, dataset): + model = ArcFaceORT(model_path=model_root) + model.check() + feat_mat = np.zeros(shape=(len(dataset), 2 * model.feat_dim)) + + def batchify_fn(data): + return mx.nd.concat(*data, dim=0) + + data_loader = mx.gluon.data.DataLoader( + dataset, 128, last_batch='keep', num_workers=4, + thread_pool=True, prefetch=16, batchify_fn=batchify_fn) + num_iter = 0 + for batch in data_loader: + batch = batch.asnumpy() + batch = (batch - model.input_mean) / model.input_std + feat = model.session.run(model.output_names, {model.input_name: batch})[0] + feat = np.reshape(feat, (-1, model.feat_dim * 2)) + feat_mat[128 * num_iter: 128 * num_iter + feat.shape[0], :] = feat + num_iter += 1 + if num_iter % 50 == 0: + print(num_iter) + return feat_mat + + +def read_template_media_list(path): + ijb_meta = pd.read_csv(path, sep=' ', header=None).values + templates = ijb_meta[:, 1].astype(np.int) + medias = ijb_meta[:, 2].astype(np.int) + return templates, medias + + +def read_template_pair_list(path): + pairs = pd.read_csv(path, sep=' ', header=None).values + t1 = pairs[:, 0].astype(np.int) + t2 = pairs[:, 1].astype(np.int) + label = pairs[:, 2].astype(np.int) + return t1, t2, label + + +def read_image_feature(path): + with open(path, 'rb') as fid: + img_feats = pickle.load(fid) + return img_feats + + +def image2template_feature(img_feats=None, + templates=None, + medias=None): + unique_templates = np.unique(templates) + template_feats = np.zeros((len(unique_templates), img_feats.shape[1])) + for count_template, uqt in enumerate(unique_templates): + (ind_t,) = np.where(templates == uqt) + face_norm_feats = img_feats[ind_t] + face_medias = medias[ind_t] + unique_medias, unique_media_counts = np.unique(face_medias, return_counts=True) + media_norm_feats = [] + for u, ct in zip(unique_medias, unique_media_counts): + (ind_m,) = np.where(face_medias == u) + if ct == 1: + media_norm_feats += [face_norm_feats[ind_m]] + else: # image features from the same video will be aggregated into one feature + media_norm_feats += [np.mean(face_norm_feats[ind_m], axis=0, keepdims=True), ] + media_norm_feats = np.array(media_norm_feats) + template_feats[count_template] = np.sum(media_norm_feats, axis=0) + if count_template % 2000 == 0: + print('Finish Calculating {} template features.'.format( + count_template)) + template_norm_feats = normalize(template_feats) + return template_norm_feats, unique_templates + + +def verification(template_norm_feats=None, + unique_templates=None, + p1=None, + p2=None): + template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int) + for count_template, uqt in enumerate(unique_templates): + template2id[uqt] = count_template + score = np.zeros((len(p1),)) + total_pairs = np.array(range(len(p1))) + batchsize = 100000 + sublists = [total_pairs[i: i + batchsize] for i in range(0, len(p1), batchsize)] + total_sublists = len(sublists) + for c, s in enumerate(sublists): + feat1 = template_norm_feats[template2id[p1[s]]] + feat2 = template_norm_feats[template2id[p2[s]]] + similarity_score = np.sum(feat1 * feat2, -1) + score[s] = similarity_score.flatten() + if c % 10 == 0: + print('Finish {}/{} pairs.'.format(c, total_sublists)) + return score + + +def verification2(template_norm_feats=None, + unique_templates=None, + p1=None, + p2=None): + template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int) + for count_template, uqt in enumerate(unique_templates): + template2id[uqt] = count_template + score = np.zeros((len(p1),)) # save cosine distance between pairs + total_pairs = np.array(range(len(p1))) + batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation + sublists = [total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)] + total_sublists = len(sublists) + for c, s in enumerate(sublists): + feat1 = template_norm_feats[template2id[p1[s]]] + feat2 = template_norm_feats[template2id[p2[s]]] + similarity_score = np.sum(feat1 * feat2, -1) + score[s] = similarity_score.flatten() + if c % 10 == 0: + print('Finish {}/{} pairs.'.format(c, total_sublists)) + return score + + +def main(args): + use_norm_score = True # if Ture, TestMode(N1) + use_detector_score = True # if Ture, TestMode(D1) + use_flip_test = True # if Ture, TestMode(F1) + assert args.target == 'IJBC' or args.target == 'IJBB' + + start = timeit.default_timer() + templates, medias = read_template_media_list( + os.path.join('%s/meta' % args.image_path, '%s_face_tid_mid.txt' % args.target.lower())) + stop = timeit.default_timer() + print('Time: %.2f s. ' % (stop - start)) + + start = timeit.default_timer() + p1, p2, label = read_template_pair_list( + os.path.join('%s/meta' % args.image_path, + '%s_template_pair_label.txt' % args.target.lower())) + stop = timeit.default_timer() + print('Time: %.2f s. ' % (stop - start)) + + start = timeit.default_timer() + img_path = '%s/loose_crop' % args.image_path + img_list_path = '%s/meta/%s_name_5pts_score.txt' % (args.image_path, args.target.lower()) + img_list = open(img_list_path) + files = img_list.readlines() + dataset = AlignedDataSet(root=img_path, lines=files, align=True) + img_feats = extract(args.model_root, dataset) + + faceness_scores = [] + for each_line in files: + name_lmk_score = each_line.split() + faceness_scores.append(name_lmk_score[-1]) + faceness_scores = np.array(faceness_scores).astype(np.float32) + stop = timeit.default_timer() + print('Time: %.2f s. ' % (stop - start)) + print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0], img_feats.shape[1])) + start = timeit.default_timer() + + if use_flip_test: + img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2] + img_feats[:, img_feats.shape[1] // 2:] + else: + img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2] + + if use_norm_score: + img_input_feats = img_input_feats + else: + img_input_feats = img_input_feats / np.sqrt(np.sum(img_input_feats ** 2, -1, keepdims=True)) + + if use_detector_score: + print(img_input_feats.shape, faceness_scores.shape) + img_input_feats = img_input_feats * faceness_scores[:, np.newaxis] + else: + img_input_feats = img_input_feats + + template_norm_feats, unique_templates = image2template_feature( + img_input_feats, templates, medias) + stop = timeit.default_timer() + print('Time: %.2f s. ' % (stop - start)) + + start = timeit.default_timer() + score = verification(template_norm_feats, unique_templates, p1, p2) + stop = timeit.default_timer() + print('Time: %.2f s. ' % (stop - start)) + save_path = os.path.join(args.result_dir, "{}_result".format(args.target)) + if not os.path.exists(save_path): + os.makedirs(save_path) + score_save_file = os.path.join(save_path, "{}.npy".format(args.model_root)) + np.save(score_save_file, score) + files = [score_save_file] + methods = [] + scores = [] + for file in files: + methods.append(os.path.basename(file)) + scores.append(np.load(file)) + methods = np.array(methods) + scores = dict(zip(methods, scores)) + x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1] + tpr_fpr_table = prettytable.PrettyTable(['Methods'] + [str(x) for x in x_labels]) + for method in methods: + fpr, tpr, _ = roc_curve(label, scores[method]) + fpr = np.flipud(fpr) + tpr = np.flipud(tpr) + tpr_fpr_row = [] + tpr_fpr_row.append("%s-%s" % (method, args.target)) + for fpr_iter in np.arange(len(x_labels)): + _, min_index = min( + list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr))))) + tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100)) + tpr_fpr_table.add_row(tpr_fpr_row) + print(tpr_fpr_table) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='do ijb test') + # general + parser.add_argument('--model-root', default='', help='path to load model.') + parser.add_argument('--image-path', default='', type=str, help='') + parser.add_argument('--result-dir', default='.', type=str, help='') + parser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB') + main(parser.parse_args()) diff --git a/videoretalking/third_part/face3d/models/arcface_torch/partial_fc.py b/videoretalking/third_part/face3d/models/arcface_torch/partial_fc.py new file mode 100644 index 0000000000000000000000000000000000000000..e0286dd437319c920ecb61f4eb3a32333dcf49eb --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/partial_fc.py @@ -0,0 +1,222 @@ +import logging +import os + +import torch +import torch.distributed as dist +from torch.nn import Module +from torch.nn.functional import normalize, linear +from torch.nn.parameter import Parameter + + +class PartialFC(Module): + """ + Author: {Xiang An, Yang Xiao, XuHan Zhu} in DeepGlint, + Partial FC: Training 10 Million Identities on a Single Machine + See the original paper: + https://arxiv.org/abs/2010.05222 + """ + + @torch.no_grad() + def __init__(self, rank, local_rank, world_size, batch_size, resume, + margin_softmax, num_classes, sample_rate=1.0, embedding_size=512, prefix="./"): + """ + rank: int + Unique process(GPU) ID from 0 to world_size - 1. + local_rank: int + Unique process(GPU) ID within the server from 0 to 7. + world_size: int + Number of GPU. + batch_size: int + Batch size on current rank(GPU). + resume: bool + Select whether to restore the weight of softmax. + margin_softmax: callable + A function of margin softmax, eg: cosface, arcface. + num_classes: int + The number of class center storage in current rank(CPU/GPU), usually is total_classes // world_size, + required. + sample_rate: float + The partial fc sampling rate, when the number of classes increases to more than 2 millions, Sampling + can greatly speed up training, and reduce a lot of GPU memory, default is 1.0. + embedding_size: int + The feature dimension, default is 512. + prefix: str + Path for save checkpoint, default is './'. + """ + super(PartialFC, self).__init__() + # + self.num_classes: int = num_classes + self.rank: int = rank + self.local_rank: int = local_rank + self.device: torch.device = torch.device("cuda:{}".format(self.local_rank)) + self.world_size: int = world_size + self.batch_size: int = batch_size + self.margin_softmax: callable = margin_softmax + self.sample_rate: float = sample_rate + self.embedding_size: int = embedding_size + self.prefix: str = prefix + self.num_local: int = num_classes // world_size + int(rank < num_classes % world_size) + self.class_start: int = num_classes // world_size * rank + min(rank, num_classes % world_size) + self.num_sample: int = int(self.sample_rate * self.num_local) + + self.weight_name = os.path.join(self.prefix, "rank_{}_softmax_weight.pt".format(self.rank)) + self.weight_mom_name = os.path.join(self.prefix, "rank_{}_softmax_weight_mom.pt".format(self.rank)) + + if resume: + try: + self.weight: torch.Tensor = torch.load(self.weight_name) + self.weight_mom: torch.Tensor = torch.load(self.weight_mom_name) + if self.weight.shape[0] != self.num_local or self.weight_mom.shape[0] != self.num_local: + raise IndexError + logging.info("softmax weight resume successfully!") + logging.info("softmax weight mom resume successfully!") + except (FileNotFoundError, KeyError, IndexError): + self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device) + self.weight_mom: torch.Tensor = torch.zeros_like(self.weight) + logging.info("softmax weight init!") + logging.info("softmax weight mom init!") + else: + self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device) + self.weight_mom: torch.Tensor = torch.zeros_like(self.weight) + logging.info("softmax weight init successfully!") + logging.info("softmax weight mom init successfully!") + self.stream: torch.cuda.Stream = torch.cuda.Stream(local_rank) + + self.index = None + if int(self.sample_rate) == 1: + self.update = lambda: 0 + self.sub_weight = Parameter(self.weight) + self.sub_weight_mom = self.weight_mom + else: + self.sub_weight = Parameter(torch.empty((0, 0)).cuda(local_rank)) + + def save_params(self): + """ Save softmax weight for each rank on prefix + """ + torch.save(self.weight.data, self.weight_name) + torch.save(self.weight_mom, self.weight_mom_name) + + @torch.no_grad() + def sample(self, total_label): + """ + Sample all positive class centers in each rank, and random select neg class centers to filling a fixed + `num_sample`. + + total_label: tensor + Label after all gather, which cross all GPUs. + """ + index_positive = (self.class_start <= total_label) & (total_label < self.class_start + self.num_local) + total_label[~index_positive] = -1 + total_label[index_positive] -= self.class_start + if int(self.sample_rate) != 1: + positive = torch.unique(total_label[index_positive], sorted=True) + if self.num_sample - positive.size(0) >= 0: + perm = torch.rand(size=[self.num_local], device=self.device) + perm[positive] = 2.0 + index = torch.topk(perm, k=self.num_sample)[1] + index = index.sort()[0] + else: + index = positive + self.index = index + total_label[index_positive] = torch.searchsorted(index, total_label[index_positive]) + self.sub_weight = Parameter(self.weight[index]) + self.sub_weight_mom = self.weight_mom[index] + + def forward(self, total_features, norm_weight): + """ Partial fc forward, `logits = X * sample(W)` + """ + torch.cuda.current_stream().wait_stream(self.stream) + logits = linear(total_features, norm_weight) + return logits + + @torch.no_grad() + def update(self): + """ Set updated weight and weight_mom to memory bank. + """ + self.weight_mom[self.index] = self.sub_weight_mom + self.weight[self.index] = self.sub_weight + + def prepare(self, label, optimizer): + """ + get sampled class centers for cal softmax. + + label: tensor + Label tensor on each rank. + optimizer: opt + Optimizer for partial fc, which need to get weight mom. + """ + with torch.cuda.stream(self.stream): + total_label = torch.zeros( + size=[self.batch_size * self.world_size], device=self.device, dtype=torch.long) + dist.all_gather(list(total_label.chunk(self.world_size, dim=0)), label) + self.sample(total_label) + optimizer.state.pop(optimizer.param_groups[-1]['params'][0], None) + optimizer.param_groups[-1]['params'][0] = self.sub_weight + optimizer.state[self.sub_weight]['momentum_buffer'] = self.sub_weight_mom + norm_weight = normalize(self.sub_weight) + return total_label, norm_weight + + def forward_backward(self, label, features, optimizer): + """ + Partial fc forward and backward with model parallel + + label: tensor + Label tensor on each rank(GPU) + features: tensor + Features tensor on each rank(GPU) + optimizer: optimizer + Optimizer for partial fc + + Returns: + -------- + x_grad: tensor + The gradient of features. + loss_v: tensor + Loss value for cross entropy. + """ + total_label, norm_weight = self.prepare(label, optimizer) + total_features = torch.zeros( + size=[self.batch_size * self.world_size, self.embedding_size], device=self.device) + dist.all_gather(list(total_features.chunk(self.world_size, dim=0)), features.data) + total_features.requires_grad = True + + logits = self.forward(total_features, norm_weight) + logits = self.margin_softmax(logits, total_label) + + with torch.no_grad(): + max_fc = torch.max(logits, dim=1, keepdim=True)[0] + dist.all_reduce(max_fc, dist.ReduceOp.MAX) + + # calculate exp(logits) and all-reduce + logits_exp = torch.exp(logits - max_fc) + logits_sum_exp = logits_exp.sum(dim=1, keepdims=True) + dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM) + + # calculate prob + logits_exp.div_(logits_sum_exp) + + # get one-hot + grad = logits_exp + index = torch.where(total_label != -1)[0] + one_hot = torch.zeros(size=[index.size()[0], grad.size()[1]], device=grad.device) + one_hot.scatter_(1, total_label[index, None], 1) + + # calculate loss + loss = torch.zeros(grad.size()[0], 1, device=grad.device) + loss[index] = grad[index].gather(1, total_label[index, None]) + dist.all_reduce(loss, dist.ReduceOp.SUM) + loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1) + + # calculate grad + grad[index] -= one_hot + grad.div_(self.batch_size * self.world_size) + + logits.backward(grad) + if total_features.grad is not None: + total_features.grad.detach_() + x_grad: torch.Tensor = torch.zeros_like(features, requires_grad=True) + # feature gradient all-reduce + dist.reduce_scatter(x_grad, list(total_features.grad.chunk(self.world_size, dim=0))) + x_grad = x_grad * self.world_size + # backward backbone + return x_grad, loss_v diff --git a/videoretalking/third_part/face3d/models/arcface_torch/requirement.txt b/videoretalking/third_part/face3d/models/arcface_torch/requirement.txt new file mode 100644 index 0000000000000000000000000000000000000000..99aef673e30b99cbe56ce82a564c1df9df24ba21 --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/requirement.txt @@ -0,0 +1,5 @@ +tensorboard +easydict +mxnet +onnx +sklearn diff --git a/videoretalking/third_part/face3d/models/arcface_torch/run.sh b/videoretalking/third_part/face3d/models/arcface_torch/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..67b25fd63ef3921733d81d5be844aacc5a5c84ed --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/run.sh @@ -0,0 +1,2 @@ +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50 +ps -ef | grep "train" | grep -v grep | awk '{print "kill -9 "$2}' | sh diff --git a/videoretalking/third_part/face3d/models/arcface_torch/torch2onnx.py b/videoretalking/third_part/face3d/models/arcface_torch/torch2onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..458660df7cc7f9a567aaf492c45f232e776a9ef0 --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/torch2onnx.py @@ -0,0 +1,59 @@ +import numpy as np +import onnx +import torch + + +def convert_onnx(net, path_module, output, opset=11, simplify=False): + assert isinstance(net, torch.nn.Module) + img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32) + img = img.astype(np.float) + img = (img / 255. - 0.5) / 0.5 # torch style norm + img = img.transpose((2, 0, 1)) + img = torch.from_numpy(img).unsqueeze(0).float() + + weight = torch.load(path_module) + net.load_state_dict(weight) + net.eval() + torch.onnx.export(net, img, output, keep_initializers_as_inputs=False, verbose=False, opset_version=opset) + model = onnx.load(output) + graph = model.graph + graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None' + if simplify: + from onnxsim import simplify + model, check = simplify(model) + assert check, "Simplified ONNX model could not be validated" + onnx.save(model, output) + + +if __name__ == '__main__': + import os + import argparse + from backbones import get_model + + parser = argparse.ArgumentParser(description='ArcFace PyTorch to onnx') + parser.add_argument('input', type=str, help='input backbone.pth file or path') + parser.add_argument('--output', type=str, default=None, help='output onnx path') + parser.add_argument('--network', type=str, default=None, help='backbone network') + parser.add_argument('--simplify', type=bool, default=False, help='onnx simplify') + args = parser.parse_args() + input_file = args.input + if os.path.isdir(input_file): + input_file = os.path.join(input_file, "backbone.pth") + assert os.path.exists(input_file) + model_name = os.path.basename(os.path.dirname(input_file)).lower() + params = model_name.split("_") + if len(params) >= 3 and params[1] in ('arcface', 'cosface'): + if args.network is None: + args.network = params[2] + assert args.network is not None + print(args) + backbone_onnx = get_model(args.network, dropout=0) + + output_path = args.output + if output_path is None: + output_path = os.path.join(os.path.dirname(__file__), 'onnx') + if not os.path.exists(output_path): + os.makedirs(output_path) + assert os.path.isdir(output_path) + output_file = os.path.join(output_path, "%s.onnx" % model_name) + convert_onnx(backbone_onnx, input_file, output_file, simplify=args.simplify) diff --git a/videoretalking/third_part/face3d/models/arcface_torch/train.py b/videoretalking/third_part/face3d/models/arcface_torch/train.py new file mode 100644 index 0000000000000000000000000000000000000000..0c5491de9af8fc7a2f3d0648c53b89584864f20e --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/train.py @@ -0,0 +1,141 @@ +import argparse +import logging +import os + +import torch +import torch.distributed as dist +import torch.nn.functional as F +import torch.utils.data.distributed +from torch.nn.utils import clip_grad_norm_ + +import losses +from backbones import get_model +from dataset import MXFaceDataset, SyntheticDataset, DataLoaderX +from partial_fc import PartialFC +from utils.utils_amp import MaxClipGradScaler +from utils.utils_callbacks import CallBackVerification, CallBackLogging, CallBackModelCheckpoint +from utils.utils_config import get_config +from utils.utils_logging import AverageMeter, init_logging + + +def main(args): + cfg = get_config(args.config) + try: + world_size = int(os.environ['WORLD_SIZE']) + rank = int(os.environ['RANK']) + dist.init_process_group('nccl') + except KeyError: + world_size = 1 + rank = 0 + dist.init_process_group(backend='nccl', init_method="tcp://127.0.0.1:12584", rank=rank, world_size=world_size) + + local_rank = args.local_rank + torch.cuda.set_device(local_rank) + os.makedirs(cfg.output, exist_ok=True) + init_logging(rank, cfg.output) + + if cfg.rec == "synthetic": + train_set = SyntheticDataset(local_rank=local_rank) + else: + train_set = MXFaceDataset(root_dir=cfg.rec, local_rank=local_rank) + + train_sampler = torch.utils.data.distributed.DistributedSampler(train_set, shuffle=True) + train_loader = DataLoaderX( + local_rank=local_rank, dataset=train_set, batch_size=cfg.batch_size, + sampler=train_sampler, num_workers=2, pin_memory=True, drop_last=True) + backbone = get_model(cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).to(local_rank) + + if cfg.resume: + try: + backbone_pth = os.path.join(cfg.output, "backbone.pth") + backbone.load_state_dict(torch.load(backbone_pth, map_location=torch.device(local_rank))) + if rank == 0: + logging.info("backbone resume successfully!") + except (FileNotFoundError, KeyError, IndexError, RuntimeError): + if rank == 0: + logging.info("resume fail, backbone init successfully!") + + backbone = torch.nn.parallel.DistributedDataParallel( + module=backbone, broadcast_buffers=False, device_ids=[local_rank]) + backbone.train() + margin_softmax = losses.get_loss(cfg.loss) + module_partial_fc = PartialFC( + rank=rank, local_rank=local_rank, world_size=world_size, resume=cfg.resume, + batch_size=cfg.batch_size, margin_softmax=margin_softmax, num_classes=cfg.num_classes, + sample_rate=cfg.sample_rate, embedding_size=cfg.embedding_size, prefix=cfg.output) + + opt_backbone = torch.optim.SGD( + params=[{'params': backbone.parameters()}], + lr=cfg.lr / 512 * cfg.batch_size * world_size, + momentum=0.9, weight_decay=cfg.weight_decay) + opt_pfc = torch.optim.SGD( + params=[{'params': module_partial_fc.parameters()}], + lr=cfg.lr / 512 * cfg.batch_size * world_size, + momentum=0.9, weight_decay=cfg.weight_decay) + + num_image = len(train_set) + total_batch_size = cfg.batch_size * world_size + cfg.warmup_step = num_image // total_batch_size * cfg.warmup_epoch + cfg.total_step = num_image // total_batch_size * cfg.num_epoch + + def lr_step_func(current_step): + cfg.decay_step = [x * num_image // total_batch_size for x in cfg.decay_epoch] + if current_step < cfg.warmup_step: + return current_step / cfg.warmup_step + else: + return 0.1 ** len([m for m in cfg.decay_step if m <= current_step]) + + scheduler_backbone = torch.optim.lr_scheduler.LambdaLR( + optimizer=opt_backbone, lr_lambda=lr_step_func) + scheduler_pfc = torch.optim.lr_scheduler.LambdaLR( + optimizer=opt_pfc, lr_lambda=lr_step_func) + + for key, value in cfg.items(): + num_space = 25 - len(key) + logging.info(": " + key + " " * num_space + str(value)) + + val_target = cfg.val_targets + callback_verification = CallBackVerification(2000, rank, val_target, cfg.rec) + callback_logging = CallBackLogging(50, rank, cfg.total_step, cfg.batch_size, world_size, None) + callback_checkpoint = CallBackModelCheckpoint(rank, cfg.output) + + loss = AverageMeter() + start_epoch = 0 + global_step = 0 + grad_amp = MaxClipGradScaler(cfg.batch_size, 128 * cfg.batch_size, growth_interval=100) if cfg.fp16 else None + for epoch in range(start_epoch, cfg.num_epoch): + train_sampler.set_epoch(epoch) + for step, (img, label) in enumerate(train_loader): + global_step += 1 + features = F.normalize(backbone(img)) + x_grad, loss_v = module_partial_fc.forward_backward(label, features, opt_pfc) + if cfg.fp16: + features.backward(grad_amp.scale(x_grad)) + grad_amp.unscale_(opt_backbone) + clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2) + grad_amp.step(opt_backbone) + grad_amp.update() + else: + features.backward(x_grad) + clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2) + opt_backbone.step() + + opt_pfc.step() + module_partial_fc.update() + opt_backbone.zero_grad() + opt_pfc.zero_grad() + loss.update(loss_v, 1) + callback_logging(global_step, loss, epoch, cfg.fp16, scheduler_backbone.get_last_lr()[0], grad_amp) + callback_verification(global_step, backbone) + scheduler_backbone.step() + scheduler_pfc.step() + callback_checkpoint(global_step, backbone, module_partial_fc) + dist.destroy_process_group() + + +if __name__ == "__main__": + torch.backends.cudnn.benchmark = True + parser = argparse.ArgumentParser(description='PyTorch ArcFace Training') + parser.add_argument('config', type=str, help='py config file') + parser.add_argument('--local_rank', type=int, default=0, help='local_rank') + main(parser.parse_args()) diff --git a/videoretalking/third_part/face3d/models/arcface_torch/utils/__init__.py b/videoretalking/third_part/face3d/models/arcface_torch/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/videoretalking/third_part/face3d/models/arcface_torch/utils/plot.py b/videoretalking/third_part/face3d/models/arcface_torch/utils/plot.py new file mode 100644 index 0000000000000000000000000000000000000000..4fce6cc0ae526d5aebc8e7a1550300ceae3a2034 --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/utils/plot.py @@ -0,0 +1,72 @@ +# coding: utf-8 + +import os +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap +from prettytable import PrettyTable +from sklearn.metrics import roc_curve, auc + +image_path = "/data/anxiang/IJB_release/IJBC" +files = [ + "./ms1mv3_arcface_r100/ms1mv3_arcface_r100/ijbc.npy" +] + + +def read_template_pair_list(path): + pairs = pd.read_csv(path, sep=' ', header=None).values + t1 = pairs[:, 0].astype(np.int) + t2 = pairs[:, 1].astype(np.int) + label = pairs[:, 2].astype(np.int) + return t1, t2, label + + +p1, p2, label = read_template_pair_list( + os.path.join('%s/meta' % image_path, + '%s_template_pair_label.txt' % 'ijbc')) + +methods = [] +scores = [] +for file in files: + methods.append(file.split('/')[-2]) + scores.append(np.load(file)) + +methods = np.array(methods) +scores = dict(zip(methods, scores)) +colours = dict( + zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2'))) +x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1] +tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels]) +fig = plt.figure() +for method in methods: + fpr, tpr, _ = roc_curve(label, scores[method]) + roc_auc = auc(fpr, tpr) + fpr = np.flipud(fpr) + tpr = np.flipud(tpr) # select largest tpr at same fpr + plt.plot(fpr, + tpr, + color=colours[method], + lw=1, + label=('[%s (AUC = %0.4f %%)]' % + (method.split('-')[-1], roc_auc * 100))) + tpr_fpr_row = [] + tpr_fpr_row.append("%s-%s" % (method, "IJBC")) + for fpr_iter in np.arange(len(x_labels)): + _, min_index = min( + list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr))))) + tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100)) + tpr_fpr_table.add_row(tpr_fpr_row) +plt.xlim([10 ** -6, 0.1]) +plt.ylim([0.3, 1.0]) +plt.grid(linestyle='--', linewidth=1) +plt.xticks(x_labels) +plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True)) +plt.xscale('log') +plt.xlabel('False Positive Rate') +plt.ylabel('True Positive Rate') +plt.title('ROC on IJB') +plt.legend(loc="lower right") +print(tpr_fpr_table) diff --git a/videoretalking/third_part/face3d/models/arcface_torch/utils/utils_amp.py b/videoretalking/third_part/face3d/models/arcface_torch/utils/utils_amp.py new file mode 100644 index 0000000000000000000000000000000000000000..a6d5bcbb540ff8b04535e71c0057e124338df5bd --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/utils/utils_amp.py @@ -0,0 +1,88 @@ +from typing import Dict, List + +import torch + +if torch.__version__ < '1.9': + Iterable = torch._six.container_abcs.Iterable +else: + import collections + + Iterable = collections.abc.Iterable +from torch.cuda.amp import GradScaler + + +class _MultiDeviceReplicator(object): + """ + Lazily serves copies of a tensor to requested devices. Copies are cached per-device. + """ + + def __init__(self, master_tensor: torch.Tensor) -> None: + assert master_tensor.is_cuda + self.master = master_tensor + self._per_device_tensors: Dict[torch.device, torch.Tensor] = {} + + def get(self, device) -> torch.Tensor: + retval = self._per_device_tensors.get(device, None) + if retval is None: + retval = self.master.to(device=device, non_blocking=True, copy=True) + self._per_device_tensors[device] = retval + return retval + + +class MaxClipGradScaler(GradScaler): + def __init__(self, init_scale, max_scale: float, growth_interval=100): + GradScaler.__init__(self, init_scale=init_scale, growth_interval=growth_interval) + self.max_scale = max_scale + + def scale_clip(self): + if self.get_scale() == self.max_scale: + self.set_growth_factor(1) + elif self.get_scale() < self.max_scale: + self.set_growth_factor(2) + elif self.get_scale() > self.max_scale: + self._scale.fill_(self.max_scale) + self.set_growth_factor(1) + + def scale(self, outputs): + """ + Multiplies ('scales') a tensor or list of tensors by the scale factor. + + Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned + unmodified. + + Arguments: + outputs (Tensor or iterable of Tensors): Outputs to scale. + """ + if not self._enabled: + return outputs + self.scale_clip() + # Short-circuit for the common case. + if isinstance(outputs, torch.Tensor): + assert outputs.is_cuda + if self._scale is None: + self._lazy_init_scale_growth_tracker(outputs.device) + assert self._scale is not None + return outputs * self._scale.to(device=outputs.device, non_blocking=True) + + # Invoke the more complex machinery only if we're treating multiple outputs. + stash: List[_MultiDeviceReplicator] = [] # holds a reference that can be overwritten by apply_scale + + def apply_scale(val): + if isinstance(val, torch.Tensor): + assert val.is_cuda + if len(stash) == 0: + if self._scale is None: + self._lazy_init_scale_growth_tracker(val.device) + assert self._scale is not None + stash.append(_MultiDeviceReplicator(self._scale)) + return val * stash[0].get(val.device) + elif isinstance(val, Iterable): + iterable = map(apply_scale, val) + if isinstance(val, list) or isinstance(val, tuple): + return type(val)(iterable) + else: + return iterable + else: + raise ValueError("outputs must be a Tensor or an iterable of Tensors") + + return apply_scale(outputs) diff --git a/videoretalking/third_part/face3d/models/arcface_torch/utils/utils_callbacks.py b/videoretalking/third_part/face3d/models/arcface_torch/utils/utils_callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..748923b36358bd118efa0532a6f512b6ca96ff34 --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/utils/utils_callbacks.py @@ -0,0 +1,117 @@ +import logging +import os +import time +from typing import List + +import torch + +from eval import verification +from utils.utils_logging import AverageMeter + + +class CallBackVerification(object): + def __init__(self, frequent, rank, val_targets, rec_prefix, image_size=(112, 112)): + self.frequent: int = frequent + self.rank: int = rank + self.highest_acc: float = 0.0 + self.highest_acc_list: List[float] = [0.0] * len(val_targets) + self.ver_list: List[object] = [] + self.ver_name_list: List[str] = [] + if self.rank is 0: + self.init_dataset(val_targets=val_targets, data_dir=rec_prefix, image_size=image_size) + + def ver_test(self, backbone: torch.nn.Module, global_step: int): + results = [] + for i in range(len(self.ver_list)): + acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test( + self.ver_list[i], backbone, 10, 10) + logging.info('[%s][%d]XNorm: %f' % (self.ver_name_list[i], global_step, xnorm)) + logging.info('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (self.ver_name_list[i], global_step, acc2, std2)) + if acc2 > self.highest_acc_list[i]: + self.highest_acc_list[i] = acc2 + logging.info( + '[%s][%d]Accuracy-Highest: %1.5f' % (self.ver_name_list[i], global_step, self.highest_acc_list[i])) + results.append(acc2) + + def init_dataset(self, val_targets, data_dir, image_size): + for name in val_targets: + path = os.path.join(data_dir, name + ".bin") + if os.path.exists(path): + data_set = verification.load_bin(path, image_size) + self.ver_list.append(data_set) + self.ver_name_list.append(name) + + def __call__(self, num_update, backbone: torch.nn.Module): + if self.rank is 0 and num_update > 0 and num_update % self.frequent == 0: + backbone.eval() + self.ver_test(backbone, num_update) + backbone.train() + + +class CallBackLogging(object): + def __init__(self, frequent, rank, total_step, batch_size, world_size, writer=None): + self.frequent: int = frequent + self.rank: int = rank + self.time_start = time.time() + self.total_step: int = total_step + self.batch_size: int = batch_size + self.world_size: int = world_size + self.writer = writer + + self.init = False + self.tic = 0 + + def __call__(self, + global_step: int, + loss: AverageMeter, + epoch: int, + fp16: bool, + learning_rate: float, + grad_scaler: torch.cuda.amp.GradScaler): + if self.rank == 0 and global_step > 0 and global_step % self.frequent == 0: + if self.init: + try: + speed: float = self.frequent * self.batch_size / (time.time() - self.tic) + speed_total = speed * self.world_size + except ZeroDivisionError: + speed_total = float('inf') + + time_now = (time.time() - self.time_start) / 3600 + time_total = time_now / ((global_step + 1) / self.total_step) + time_for_end = time_total - time_now + if self.writer is not None: + self.writer.add_scalar('time_for_end', time_for_end, global_step) + self.writer.add_scalar('learning_rate', learning_rate, global_step) + self.writer.add_scalar('loss', loss.avg, global_step) + if fp16: + msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.4f Epoch: %d Global Step: %d " \ + "Fp16 Grad Scale: %2.f Required: %1.f hours" % ( + speed_total, loss.avg, learning_rate, epoch, global_step, + grad_scaler.get_scale(), time_for_end + ) + else: + msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.4f Epoch: %d Global Step: %d " \ + "Required: %1.f hours" % ( + speed_total, loss.avg, learning_rate, epoch, global_step, time_for_end + ) + logging.info(msg) + loss.reset() + self.tic = time.time() + else: + self.init = True + self.tic = time.time() + + +class CallBackModelCheckpoint(object): + def __init__(self, rank, output="./"): + self.rank: int = rank + self.output: str = output + + def __call__(self, global_step, backbone, partial_fc, ): + if global_step > 100 and self.rank == 0: + path_module = os.path.join(self.output, "backbone.pth") + torch.save(backbone.module.state_dict(), path_module) + logging.info("Pytorch Model Saved in '{}'".format(path_module)) + + if global_step > 100 and partial_fc is not None: + partial_fc.save_params() diff --git a/videoretalking/third_part/face3d/models/arcface_torch/utils/utils_config.py b/videoretalking/third_part/face3d/models/arcface_torch/utils/utils_config.py new file mode 100644 index 0000000000000000000000000000000000000000..b60a1e5a2e860ce5511a2d3863c8b57a4df292d7 --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/utils/utils_config.py @@ -0,0 +1,16 @@ +import importlib +import os.path as osp + + +def get_config(config_file): + assert config_file.startswith('configs/'), 'config file setting must start with configs/' + temp_config_name = osp.basename(config_file) + temp_module_name = osp.splitext(temp_config_name)[0] + config = importlib.import_module("configs.base") + cfg = config.config + config = importlib.import_module("configs.%s" % temp_module_name) + job_cfg = config.config + cfg.update(job_cfg) + if cfg.output is None: + cfg.output = osp.join('work_dirs', temp_module_name) + return cfg \ No newline at end of file diff --git a/videoretalking/third_part/face3d/models/arcface_torch/utils/utils_logging.py b/videoretalking/third_part/face3d/models/arcface_torch/utils/utils_logging.py new file mode 100644 index 0000000000000000000000000000000000000000..f2b43b851c9e06230abd94c73a1f64cfa1b6f3ac --- /dev/null +++ b/videoretalking/third_part/face3d/models/arcface_torch/utils/utils_logging.py @@ -0,0 +1,41 @@ +import logging +import os +import sys + + +class AverageMeter(object): + """Computes and stores the average and current value + """ + + def __init__(self): + self.val = None + self.avg = None + self.sum = None + self.count = None + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def init_logging(rank, models_root): + if rank == 0: + log_root = logging.getLogger() + log_root.setLevel(logging.INFO) + formatter = logging.Formatter("Training: %(asctime)s-%(message)s") + handler_file = logging.FileHandler(os.path.join(models_root, "training.log")) + handler_stream = logging.StreamHandler(sys.stdout) + handler_file.setFormatter(formatter) + handler_stream.setFormatter(formatter) + log_root.addHandler(handler_file) + log_root.addHandler(handler_stream) + log_root.info('rank_id: %d' % rank) diff --git a/videoretalking/third_part/face3d/models/arcface_torch/utils/utils_os.py b/videoretalking/third_part/face3d/models/arcface_torch/utils/utils_os.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/videoretalking/third_part/face3d/models/base_model.py b/videoretalking/third_part/face3d/models/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..7580bb5329b31532462a3b98250cd6b976bc952d --- /dev/null +++ b/videoretalking/third_part/face3d/models/base_model.py @@ -0,0 +1,316 @@ +"""This script defines the base network model for Deep3DFaceRecon_pytorch +""" + +import os +import numpy as np +import torch +from collections import OrderedDict +from abc import ABC, abstractmethod +from . import networks + + +class BaseModel(ABC): + """This class is an abstract base class (ABC) for models. + To create a subclass, you need to implement the following five functions: + -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). + -- : unpack data from dataset and apply preprocessing. + -- : produce intermediate results. + -- : calculate losses, gradients, and update network weights. + -- : (optionally) add model-specific options and set default options. + """ + + def __init__(self, opt): + """Initialize the BaseModel class. + + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + + When creating your custom class, you need to implement your own initialization. + In this fucntion, you should first call + Then, you need to define four lists: + -- self.loss_names (str list): specify the training losses that you want to plot and save. + -- self.model_names (str list): specify the images that you want to display and save. + -- self.visual_names (str list): define networks used in our training. + -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. + """ + self.opt = opt + self.isTrain = opt.isTrain + self.device = torch.device('cpu') + self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir + self.loss_names = [] + self.model_names = [] + self.visual_names = [] + self.parallel_names = [] + self.optimizers = [] + self.image_paths = [] + self.metric = 0 # used for learning rate policy 'plateau' + + @staticmethod + def dict_grad_hook_factory(add_func=lambda x: x): + saved_dict = dict() + + def hook_gen(name): + def grad_hook(grad): + saved_vals = add_func(grad) + saved_dict[name] = saved_vals + return grad_hook + return hook_gen, saved_dict + + @staticmethod + def modify_commandline_options(parser, is_train): + """Add new model-specific options, and rewrite default values for existing options. + + Parameters: + parser -- original option parser + is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. + """ + return parser + + @abstractmethod + def set_input(self, input): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + + Parameters: + input (dict): includes the data itself and its metadata information. + """ + pass + + @abstractmethod + def forward(self): + """Run forward pass; called by both functions and .""" + pass + + @abstractmethod + def optimize_parameters(self): + """Calculate losses, gradients, and update network weights; called in every training iteration""" + pass + + def setup(self, opt): + """Load and print networks; create schedulers + + Parameters: + opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + if self.isTrain: + self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] + + if not self.isTrain or opt.continue_train: + load_suffix = opt.epoch + self.load_networks(load_suffix) + + + # self.print_networks(opt.verbose) + + def parallelize(self, convert_sync_batchnorm=True): + if not self.opt.use_ddp: + for name in self.parallel_names: + if isinstance(name, str): + module = getattr(self, name) + setattr(self, name, module.to(self.device)) + else: + for name in self.model_names: + if isinstance(name, str): + module = getattr(self, name) + if convert_sync_batchnorm: + module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module) + setattr(self, name, torch.nn.parallel.DistributedDataParallel(module.to(self.device), + device_ids=[self.device.index], + find_unused_parameters=True, broadcast_buffers=True)) + + # DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient. + for name in self.parallel_names: + if isinstance(name, str) and name not in self.model_names: + module = getattr(self, name) + setattr(self, name, module.to(self.device)) + + # put state_dict of optimizer to gpu device + if self.opt.phase != 'test': + if self.opt.continue_train: + for optim in self.optimizers: + for state in optim.state.values(): + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state[k] = v.to(self.device) + + def data_dependent_initialize(self, data): + pass + + def train(self): + """Make models train mode""" + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, name) + net.train() + + def eval(self): + """Make models eval mode""" + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, name) + net.eval() + + def test(self): + """Forward function used in test time. + + This function wraps function in no_grad() so we don't save intermediate steps for backprop + It also calls to produce additional visualization results + """ + with torch.no_grad(): + self.forward() + self.compute_visuals() + + def compute_visuals(self): + """Calculate additional output images for visdom and HTML visualization""" + pass + + def get_image_paths(self, name='A'): + """ Return image paths that are used to load current data""" + return self.image_paths if name =='A' else self.image_paths_B + + def update_learning_rate(self): + """Update learning rates for all the networks; called at the end of every epoch""" + for scheduler in self.schedulers: + if self.opt.lr_policy == 'plateau': + scheduler.step(self.metric) + else: + scheduler.step() + + lr = self.optimizers[0].param_groups[0]['lr'] + print('learning rate = %.7f' % lr) + + def get_current_visuals(self): + """Return visualization images. train.py will display these images with visdom, and save the images to a HTML""" + visual_ret = OrderedDict() + for name in self.visual_names: + if isinstance(name, str): + visual_ret[name] = getattr(self, name)[:, :3, ...] + return visual_ret + + def get_current_losses(self): + """Return traning losses / errors. train.py will print out these errors on console, and save them to a file""" + errors_ret = OrderedDict() + for name in self.loss_names: + if isinstance(name, str): + errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number + return errors_ret + + def save_networks(self, epoch): + """Save all the networks to the disk. + + Parameters: + epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) + """ + if not os.path.isdir(self.save_dir): + os.makedirs(self.save_dir) + + save_filename = 'epoch_%s.pth' % (epoch) + save_path = os.path.join(self.save_dir, save_filename) + + save_dict = {} + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, name) + if isinstance(net, torch.nn.DataParallel) or isinstance(net, + torch.nn.parallel.DistributedDataParallel): + net = net.module + save_dict[name] = net.state_dict() + + + for i, optim in enumerate(self.optimizers): + save_dict['opt_%02d'%i] = optim.state_dict() + + for i, sched in enumerate(self.schedulers): + save_dict['sched_%02d'%i] = sched.state_dict() + + torch.save(save_dict, save_path) + + def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): + """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)""" + key = keys[i] + if i + 1 == len(keys): # at the end, pointing to a parameter/buffer + if module.__class__.__name__.startswith('InstanceNorm') and \ + (key == 'running_mean' or key == 'running_var'): + if getattr(module, key) is None: + state_dict.pop('.'.join(keys)) + if module.__class__.__name__.startswith('InstanceNorm') and \ + (key == 'num_batches_tracked'): + state_dict.pop('.'.join(keys)) + else: + self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) + + def load_networks(self, epoch): + """Load all the networks from the disk. + + Parameters: + epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) + """ + if self.opt.isTrain and self.opt.pretrained_name is not None: + load_dir = os.path.join(self.opt.checkpoints_dir, self.opt.pretrained_name) + else: + load_dir = self.save_dir + load_filename = 'epoch_%s.pth' % (epoch) + load_path = os.path.join(load_dir, load_filename) + state_dict = torch.load(load_path, map_location=self.device) + print('loading the model from %s' % load_path) + + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, name) + if isinstance(net, torch.nn.DataParallel): + net = net.module + net.load_state_dict(state_dict[name]) + + if self.opt.phase != 'test': + if self.opt.continue_train: + print('loading the optim from %s' % load_path) + for i, optim in enumerate(self.optimizers): + optim.load_state_dict(state_dict['opt_%02d'%i]) + + try: + print('loading the sched from %s' % load_path) + for i, sched in enumerate(self.schedulers): + sched.load_state_dict(state_dict['sched_%02d'%i]) + except: + print('Failed to load schedulers, set schedulers according to epoch count manually') + for i, sched in enumerate(self.schedulers): + sched.last_epoch = self.opt.epoch_count - 1 + + + + + def print_networks(self, verbose): + """Print the total number of parameters in the network and (if verbose) network architecture + + Parameters: + verbose (bool) -- if verbose: print the network architecture + """ + print('---------- Networks initialized -------------') + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, name) + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + if verbose: + print(net) + print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) + print('-----------------------------------------------') + + def set_requires_grad(self, nets, requires_grad=False): + """Set requies_grad=Fasle for all the networks to avoid unnecessary computations + Parameters: + nets (network list) -- a list of networks + requires_grad (bool) -- whether the networks require gradients or not + """ + if not isinstance(nets, list): + nets = [nets] + for net in nets: + if net is not None: + for param in net.parameters(): + param.requires_grad = requires_grad + + def generate_visuals_for_evaluation(self, data, mode): + return {} diff --git a/videoretalking/third_part/face3d/models/bfm.py b/videoretalking/third_part/face3d/models/bfm.py new file mode 100644 index 0000000000000000000000000000000000000000..74d47fe20766447e4363acad2220a5d87f943b52 --- /dev/null +++ b/videoretalking/third_part/face3d/models/bfm.py @@ -0,0 +1,303 @@ +"""This script defines the parametric 3d face model for Deep3DFaceRecon_pytorch +""" + +import numpy as np +import torch +import torch.nn.functional as F +from scipy.io import loadmat +from face3d.util.load_mats import transferBFM09 +import os + +def perspective_projection(focal, center): + # return p.T (N, 3) @ (3, 3) + return np.array([ + focal, 0, center, + 0, focal, center, + 0, 0, 1 + ]).reshape([3, 3]).astype(np.float32).transpose() + +class SH: + def __init__(self): + self.a = [np.pi, 2 * np.pi / np.sqrt(3.), 2 * np.pi / np.sqrt(8.)] + self.c = [1/np.sqrt(4 * np.pi), np.sqrt(3.) / np.sqrt(4 * np.pi), 3 * np.sqrt(5.) / np.sqrt(12 * np.pi)] + + + +class ParametricFaceModel: + def __init__(self, + bfm_folder='./BFM', + recenter=True, + camera_distance=10., + init_lit=np.array([ + 0.8, 0, 0, 0, 0, 0, 0, 0, 0 + ]), + focal=1015., + center=112., + is_train=True, + default_name='BFM_model_front.mat'): + + if not os.path.isfile(os.path.join(bfm_folder, default_name)): + transferBFM09(bfm_folder) + model = loadmat(os.path.join(bfm_folder, default_name)) + # mean face shape. [3*N,1] + self.mean_shape = model['meanshape'].astype(np.float32) + # identity basis. [3*N,80] + self.id_base = model['idBase'].astype(np.float32) + # expression basis. [3*N,64] + self.exp_base = model['exBase'].astype(np.float32) + # mean face texture. [3*N,1] (0-255) + self.mean_tex = model['meantex'].astype(np.float32) + # texture basis. [3*N,80] + self.tex_base = model['texBase'].astype(np.float32) + # face indices for each vertex that lies in. starts from 0. [N,8] + self.point_buf = model['point_buf'].astype(np.int64) - 1 + # vertex indices for each face. starts from 0. [F,3] + self.face_buf = model['tri'].astype(np.int64) - 1 + # vertex indices for 68 landmarks. starts from 0. [68,1] + self.keypoints = np.squeeze(model['keypoints']).astype(np.int64) - 1 + + if is_train: + # vertex indices for small face region to compute photometric error. starts from 0. + self.front_mask = np.squeeze(model['frontmask2_idx']).astype(np.int64) - 1 + # vertex indices for each face from small face region. starts from 0. [f,3] + self.front_face_buf = model['tri_mask2'].astype(np.int64) - 1 + # vertex indices for pre-defined skin region to compute reflectance loss + self.skin_mask = np.squeeze(model['skinmask']) + + if recenter: + mean_shape = self.mean_shape.reshape([-1, 3]) + mean_shape = mean_shape - np.mean(mean_shape, axis=0, keepdims=True) + self.mean_shape = mean_shape.reshape([-1, 1]) + + self.persc_proj = perspective_projection(focal, center) + self.device = 'cpu' + self.camera_distance = camera_distance + self.SH = SH() + self.init_lit = init_lit.reshape([1, 1, -1]).astype(np.float32) + + + def to(self, device): + self.device = device + for key, value in self.__dict__.items(): + if type(value).__module__ == np.__name__: + setattr(self, key, torch.tensor(value).to(device)) + + + def compute_shape(self, id_coeff, exp_coeff): + """ + Return: + face_shape -- torch.tensor, size (B, N, 3) + + Parameters: + id_coeff -- torch.tensor, size (B, 80), identity coeffs + exp_coeff -- torch.tensor, size (B, 64), expression coeffs + """ + batch_size = id_coeff.shape[0] + id_part = torch.einsum('ij,aj->ai', self.id_base, id_coeff) + exp_part = torch.einsum('ij,aj->ai', self.exp_base, exp_coeff) + face_shape = id_part + exp_part + self.mean_shape.reshape([1, -1]) + return face_shape.reshape([batch_size, -1, 3]) + + + def compute_texture(self, tex_coeff, normalize=True): + """ + Return: + face_texture -- torch.tensor, size (B, N, 3), in RGB order, range (0, 1.) + + Parameters: + tex_coeff -- torch.tensor, size (B, 80) + """ + batch_size = tex_coeff.shape[0] + face_texture = torch.einsum('ij,aj->ai', self.tex_base, tex_coeff) + self.mean_tex + if normalize: + face_texture = face_texture / 255. + return face_texture.reshape([batch_size, -1, 3]) + + + def compute_norm(self, face_shape): + """ + Return: + vertex_norm -- torch.tensor, size (B, N, 3) + + Parameters: + face_shape -- torch.tensor, size (B, N, 3) + """ + + v1 = face_shape[:, self.face_buf[:, 0]] + v2 = face_shape[:, self.face_buf[:, 1]] + v3 = face_shape[:, self.face_buf[:, 2]] + e1 = v1 - v2 + e2 = v2 - v3 + face_norm = torch.cross(e1, e2, dim=-1) + face_norm = F.normalize(face_norm, dim=-1, p=2) + face_norm = torch.cat([face_norm, torch.zeros(face_norm.shape[0], 1, 3).to(self.device)], dim=1) + + vertex_norm = torch.sum(face_norm[:, self.point_buf], dim=2) + vertex_norm = F.normalize(vertex_norm, dim=-1, p=2) + return vertex_norm + + + def compute_color(self, face_texture, face_norm, gamma): + """ + Return: + face_color -- torch.tensor, size (B, N, 3), range (0, 1.) + + Parameters: + face_texture -- torch.tensor, size (B, N, 3), from texture model, range (0, 1.) + face_norm -- torch.tensor, size (B, N, 3), rotated face normal + gamma -- torch.tensor, size (B, 27), SH coeffs + """ + batch_size = gamma.shape[0] + v_num = face_texture.shape[1] + a, c = self.SH.a, self.SH.c + gamma = gamma.reshape([batch_size, 3, 9]) + gamma = gamma + self.init_lit + gamma = gamma.permute(0, 2, 1) + Y = torch.cat([ + a[0] * c[0] * torch.ones_like(face_norm[..., :1]).to(self.device), + -a[1] * c[1] * face_norm[..., 1:2], + a[1] * c[1] * face_norm[..., 2:], + -a[1] * c[1] * face_norm[..., :1], + a[2] * c[2] * face_norm[..., :1] * face_norm[..., 1:2], + -a[2] * c[2] * face_norm[..., 1:2] * face_norm[..., 2:], + 0.5 * a[2] * c[2] / np.sqrt(3.) * (3 * face_norm[..., 2:] ** 2 - 1), + -a[2] * c[2] * face_norm[..., :1] * face_norm[..., 2:], + 0.5 * a[2] * c[2] * (face_norm[..., :1] ** 2 - face_norm[..., 1:2] ** 2) + ], dim=-1) + r = Y @ gamma[..., :1] + g = Y @ gamma[..., 1:2] + b = Y @ gamma[..., 2:] + face_color = torch.cat([r, g, b], dim=-1) * face_texture + return face_color + + + def compute_rotation(self, angles): + """ + Return: + rot -- torch.tensor, size (B, 3, 3) pts @ trans_mat + + Parameters: + angles -- torch.tensor, size (B, 3), radian + """ + + batch_size = angles.shape[0] + ones = torch.ones([batch_size, 1]).to(self.device) + zeros = torch.zeros([batch_size, 1]).to(self.device) + x, y, z = angles[:, :1], angles[:, 1:2], angles[:, 2:], + + rot_x = torch.cat([ + ones, zeros, zeros, + zeros, torch.cos(x), -torch.sin(x), + zeros, torch.sin(x), torch.cos(x) + ], dim=1).reshape([batch_size, 3, 3]) + + rot_y = torch.cat([ + torch.cos(y), zeros, torch.sin(y), + zeros, ones, zeros, + -torch.sin(y), zeros, torch.cos(y) + ], dim=1).reshape([batch_size, 3, 3]) + + rot_z = torch.cat([ + torch.cos(z), -torch.sin(z), zeros, + torch.sin(z), torch.cos(z), zeros, + zeros, zeros, ones + ], dim=1).reshape([batch_size, 3, 3]) + + rot = rot_z @ rot_y @ rot_x + return rot.permute(0, 2, 1) + + + def to_camera(self, face_shape): + face_shape[..., -1] = self.camera_distance - face_shape[..., -1] + return face_shape + + def to_image(self, face_shape): + """ + Return: + face_proj -- torch.tensor, size (B, N, 2), y direction is opposite to v direction + + Parameters: + face_shape -- torch.tensor, size (B, N, 3) + """ + # to image_plane + face_proj = face_shape @ self.persc_proj + face_proj = face_proj[..., :2] / face_proj[..., 2:] + + return face_proj + + + def transform(self, face_shape, rot, trans): + """ + Return: + face_shape -- torch.tensor, size (B, N, 3) pts @ rot + trans + + Parameters: + face_shape -- torch.tensor, size (B, N, 3) + rot -- torch.tensor, size (B, 3, 3) + trans -- torch.tensor, size (B, 3) + """ + return face_shape @ rot + trans.unsqueeze(1) + + + def get_landmarks(self, face_proj): + """ + Return: + face_lms -- torch.tensor, size (B, 68, 2) + + Parameters: + face_proj -- torch.tensor, size (B, N, 2) + """ + return face_proj[:, self.keypoints] + + def split_coeff(self, coeffs): + """ + Return: + coeffs_dict -- a dict of torch.tensors + + Parameters: + coeffs -- torch.tensor, size (B, 256) + """ + id_coeffs = coeffs[:, :80] + exp_coeffs = coeffs[:, 80: 144] + tex_coeffs = coeffs[:, 144: 224] + angles = coeffs[:, 224: 227] + gammas = coeffs[:, 227: 254] + translations = coeffs[:, 254:] + return { + 'id': id_coeffs, + 'exp': exp_coeffs, + 'tex': tex_coeffs, + 'angle': angles, + 'gamma': gammas, + 'trans': translations + } + def compute_for_render(self, coeffs): + """ + Return: + face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate + face_color -- torch.tensor, size (B, N, 3), in RGB order + landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction + Parameters: + coeffs -- torch.tensor, size (B, 257) + """ + coef_dict = self.split_coeff(coeffs) + face_shape = self.compute_shape(coef_dict['id'], coef_dict['exp']) + rotation = self.compute_rotation(coef_dict['angle']) + + + face_shape_transformed = self.transform(face_shape, rotation, coef_dict['trans']) + face_vertex = self.to_camera(face_shape_transformed) + + face_proj = self.to_image(face_vertex) + landmark = self.get_landmarks(face_proj) + + face_texture = self.compute_texture(coef_dict['tex']) + face_norm = self.compute_norm(face_shape) + face_norm_roted = face_norm @ rotation + face_color = self.compute_color(face_texture, face_norm_roted, coef_dict['gamma']) + + return face_vertex, face_texture, face_color, landmark + + +if __name__ == '__main__': + transferBFM09() \ No newline at end of file diff --git a/videoretalking/third_part/face3d/models/facerecon_model.py b/videoretalking/third_part/face3d/models/facerecon_model.py new file mode 100644 index 0000000000000000000000000000000000000000..585025565a53eccab887ae7c77f06d225c72d647 --- /dev/null +++ b/videoretalking/third_part/face3d/models/facerecon_model.py @@ -0,0 +1,227 @@ +"""This script defines the face reconstruction model for Deep3DFaceRecon_pytorch +""" + +import numpy as np +import torch +from face3d.models.base_model import BaseModel +from face3d.models import networks +from face3d.models.bfm import ParametricFaceModel +from face3d.models.losses import perceptual_loss, photo_loss, reg_loss, reflectance_loss, landmark_loss +from face3d.util import util +from face3d.util.nvdiffrast import MeshRenderer +from face3d.util.preprocess import estimate_norm_torch + +import trimesh +from scipy.io import savemat + +class FaceReconModel(BaseModel): + + @staticmethod + def modify_commandline_options(parser, is_train=True): + """ Configures options specific for CUT model + """ + # net structure and parameters + parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='network structure') + parser.add_argument('--init_path', type=str, default='checkpoints/init_model/resnet50-0676ba61.pth') + parser.add_argument('--use_last_fc', type=util.str2bool, nargs='?', const=True, default=False, help='zero initialize the last fc') + parser.add_argument('--bfm_folder', type=str, default='BFM') + parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model') + + # renderer parameters + parser.add_argument('--focal', type=float, default=1015.) + parser.add_argument('--center', type=float, default=112.) + parser.add_argument('--camera_d', type=float, default=10.) + parser.add_argument('--z_near', type=float, default=5.) + parser.add_argument('--z_far', type=float, default=15.) + + if is_train: + # training parameters + parser.add_argument('--net_recog', type=str, default='r50', choices=['r18', 'r43', 'r50'], help='face recog network structure') + parser.add_argument('--net_recog_path', type=str, default='checkpoints/recog_model/ms1mv3_arcface_r50_fp16/backbone.pth') + parser.add_argument('--use_crop_face', type=util.str2bool, nargs='?', const=True, default=False, help='use crop mask for photo loss') + parser.add_argument('--use_predef_M', type=util.str2bool, nargs='?', const=True, default=False, help='use predefined M for predicted face') + + + # augmentation parameters + parser.add_argument('--shift_pixs', type=float, default=10., help='shift pixels') + parser.add_argument('--scale_delta', type=float, default=0.1, help='delta scale factor') + parser.add_argument('--rot_angle', type=float, default=10., help='rot angles, degree') + + # loss weights + parser.add_argument('--w_feat', type=float, default=0.2, help='weight for feat loss') + parser.add_argument('--w_color', type=float, default=1.92, help='weight for loss loss') + parser.add_argument('--w_reg', type=float, default=3.0e-4, help='weight for reg loss') + parser.add_argument('--w_id', type=float, default=1.0, help='weight for id_reg loss') + parser.add_argument('--w_exp', type=float, default=0.8, help='weight for exp_reg loss') + parser.add_argument('--w_tex', type=float, default=1.7e-2, help='weight for tex_reg loss') + parser.add_argument('--w_gamma', type=float, default=10.0, help='weight for gamma loss') + parser.add_argument('--w_lm', type=float, default=1.6e-3, help='weight for lm loss') + parser.add_argument('--w_reflc', type=float, default=5.0, help='weight for reflc loss') + + + + opt, _ = parser.parse_known_args() + parser.set_defaults( + focal=1015., center=112., camera_d=10., use_last_fc=False, z_near=5., z_far=15. + ) + if is_train: + parser.set_defaults( + use_crop_face=True, use_predef_M=False + ) + return parser + + def __init__(self, opt): + """Initialize this model class. + + Parameters: + opt -- training/test options + + A few things can be done here. + - (required) call the initialization function of BaseModel + - define loss function, visualization images, model names, and optimizers + """ + BaseModel.__init__(self, opt) # call the initialization method of BaseModel + + self.visual_names = ['output_vis'] + self.model_names = ['net_recon'] + self.parallel_names = self.model_names + ['renderer'] + + self.net_recon = networks.define_net_recon( + net_recon=opt.net_recon, use_last_fc=opt.use_last_fc, init_path=opt.init_path + ) + + self.facemodel = ParametricFaceModel( + bfm_folder=opt.bfm_folder, camera_distance=opt.camera_d, focal=opt.focal, center=opt.center, + is_train=self.isTrain, default_name=opt.bfm_model + ) + + fov = 2 * np.arctan(opt.center / opt.focal) * 180 / np.pi + self.renderer = MeshRenderer( + rasterize_fov=fov, znear=opt.z_near, zfar=opt.z_far, rasterize_size=int(2 * opt.center) + ) + + if self.isTrain: + self.loss_names = ['all', 'feat', 'color', 'lm', 'reg', 'gamma', 'reflc'] + + self.net_recog = networks.define_net_recog( + net_recog=opt.net_recog, pretrained_path=opt.net_recog_path + ) + # loss func name: (compute_%s_loss) % loss_name + self.compute_feat_loss = perceptual_loss + self.comupte_color_loss = photo_loss + self.compute_lm_loss = landmark_loss + self.compute_reg_loss = reg_loss + self.compute_reflc_loss = reflectance_loss + + self.optimizer = torch.optim.Adam(self.net_recon.parameters(), lr=opt.lr) + self.optimizers = [self.optimizer] + self.parallel_names += ['net_recog'] + # Our program will automatically call to define schedulers, load networks, and print networks + + def set_input(self, input): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + + Parameters: + input: a dictionary that contains the data itself and its metadata information. + """ + self.input_img = input['imgs'].to(self.device) + self.atten_mask = input['msks'].to(self.device) if 'msks' in input else None + self.gt_lm = input['lms'].to(self.device) if 'lms' in input else None + self.trans_m = input['M'].to(self.device) if 'M' in input else None + self.image_paths = input['im_paths'] if 'im_paths' in input else None + + def forward(self): + output_coeff = self.net_recon(self.input_img) + self.facemodel.to(self.device) + self.pred_vertex, self.pred_tex, self.pred_color, self.pred_lm = \ + self.facemodel.compute_for_render(output_coeff) + self.pred_mask, _, self.pred_face = self.renderer( + self.pred_vertex, self.facemodel.face_buf, feat=self.pred_color) + + self.pred_coeffs_dict = self.facemodel.split_coeff(output_coeff) + + + def compute_losses(self): + """Calculate losses, gradients, and update network weights; called in every training iteration""" + + assert self.net_recog.training == False + trans_m = self.trans_m + if not self.opt.use_predef_M: + trans_m = estimate_norm_torch(self.pred_lm, self.input_img.shape[-2]) + + pred_feat = self.net_recog(self.pred_face, trans_m) + gt_feat = self.net_recog(self.input_img, self.trans_m) + self.loss_feat = self.opt.w_feat * self.compute_feat_loss(pred_feat, gt_feat) + + face_mask = self.pred_mask + if self.opt.use_crop_face: + face_mask, _, _ = self.renderer(self.pred_vertex, self.facemodel.front_face_buf) + + face_mask = face_mask.detach() + self.loss_color = self.opt.w_color * self.comupte_color_loss( + self.pred_face, self.input_img, self.atten_mask * face_mask) + + loss_reg, loss_gamma = self.compute_reg_loss(self.pred_coeffs_dict, self.opt) + self.loss_reg = self.opt.w_reg * loss_reg + self.loss_gamma = self.opt.w_gamma * loss_gamma + + self.loss_lm = self.opt.w_lm * self.compute_lm_loss(self.pred_lm, self.gt_lm) + + self.loss_reflc = self.opt.w_reflc * self.compute_reflc_loss(self.pred_tex, self.facemodel.skin_mask) + + self.loss_all = self.loss_feat + self.loss_color + self.loss_reg + self.loss_gamma \ + + self.loss_lm + self.loss_reflc + + + def optimize_parameters(self, isTrain=True): + self.forward() + self.compute_losses() + """Update network weights; it will be called in every training iteration.""" + if isTrain: + self.optimizer.zero_grad() + self.loss_all.backward() + self.optimizer.step() + + def compute_visuals(self): + with torch.no_grad(): + input_img_numpy = 255. * self.input_img.detach().cpu().permute(0, 2, 3, 1).numpy() + output_vis = self.pred_face * self.pred_mask + (1 - self.pred_mask) * self.input_img + output_vis_numpy_raw = 255. * output_vis.detach().cpu().permute(0, 2, 3, 1).numpy() + + if self.gt_lm is not None: + gt_lm_numpy = self.gt_lm.cpu().numpy() + pred_lm_numpy = self.pred_lm.detach().cpu().numpy() + output_vis_numpy = util.draw_landmarks(output_vis_numpy_raw, gt_lm_numpy, 'b') + output_vis_numpy = util.draw_landmarks(output_vis_numpy, pred_lm_numpy, 'r') + + output_vis_numpy = np.concatenate((input_img_numpy, + output_vis_numpy_raw, output_vis_numpy), axis=-2) + else: + output_vis_numpy = np.concatenate((input_img_numpy, + output_vis_numpy_raw), axis=-2) + + self.output_vis = torch.tensor( + output_vis_numpy / 255., dtype=torch.float32 + ).permute(0, 3, 1, 2).to(self.device) + + def save_mesh(self, name): + + recon_shape = self.pred_vertex # get reconstructed shape + recon_shape[..., -1] = 10 - recon_shape[..., -1] # from camera space to world space + recon_shape = recon_shape.cpu().numpy()[0] + recon_color = self.pred_color + recon_color = recon_color.cpu().numpy()[0] + tri = self.facemodel.face_buf.cpu().numpy() + mesh = trimesh.Trimesh(vertices=recon_shape, faces=tri, vertex_colors=np.clip(255. * recon_color, 0, 255).astype(np.uint8)) + mesh.export(name) + + def save_coeff(self,name): + + pred_coeffs = {key:self.pred_coeffs_dict[key].cpu().numpy() for key in self.pred_coeffs_dict} + pred_lm = self.pred_lm.cpu().numpy() + pred_lm = np.stack([pred_lm[:,:,0],self.input_img.shape[2]-1-pred_lm[:,:,1]],axis=2) # transfer to image coordinate + pred_coeffs['lm68'] = pred_lm + savemat(name,pred_coeffs) + + + diff --git a/videoretalking/third_part/face3d/models/losses.py b/videoretalking/third_part/face3d/models/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..43ee47f0ca601822f40f9772001f71a2c1aba71d --- /dev/null +++ b/videoretalking/third_part/face3d/models/losses.py @@ -0,0 +1,113 @@ +import numpy as np +import torch +import torch.nn as nn +from kornia.geometry import warp_affine +import torch.nn.functional as F + +def resize_n_crop(image, M, dsize=112): + # image: (b, c, h, w) + # M : (b, 2, 3) + return warp_affine(image, M, dsize=(dsize, dsize)) + +### perceptual level loss +class PerceptualLoss(nn.Module): + def __init__(self, recog_net, input_size=112): + super(PerceptualLoss, self).__init__() + self.recog_net = recog_net + self.preprocess = lambda x: 2 * x - 1 + self.input_size=input_size + def forward(imageA, imageB, M): + """ + 1 - cosine distance + Parameters: + imageA --torch.tensor (B, 3, H, W), range (0, 1) , RGB order + imageB --same as imageA + """ + + imageA = self.preprocess(resize_n_crop(imageA, M, self.input_size)) + imageB = self.preprocess(resize_n_crop(imageB, M, self.input_size)) + + # freeze bn + self.recog_net.eval() + + id_featureA = F.normalize(self.recog_net(imageA), dim=-1, p=2) + id_featureB = F.normalize(self.recog_net(imageB), dim=-1, p=2) + cosine_d = torch.sum(id_featureA * id_featureB, dim=-1) + # assert torch.sum((cosine_d > 1).float()) == 0 + return torch.sum(1 - cosine_d) / cosine_d.shape[0] + +def perceptual_loss(id_featureA, id_featureB): + cosine_d = torch.sum(id_featureA * id_featureB, dim=-1) + # assert torch.sum((cosine_d > 1).float()) == 0 + return torch.sum(1 - cosine_d) / cosine_d.shape[0] + +### image level loss +def photo_loss(imageA, imageB, mask, eps=1e-6): + """ + l2 norm (with sqrt, to ensure backward stabililty, use eps, otherwise Nan may occur) + Parameters: + imageA --torch.tensor (B, 3, H, W), range (0, 1), RGB order + imageB --same as imageA + """ + loss = torch.sqrt(eps + torch.sum((imageA - imageB) ** 2, dim=1, keepdims=True)) * mask + loss = torch.sum(loss) / torch.max(torch.sum(mask), torch.tensor(1.0).to(mask.device)) + return loss + +def landmark_loss(predict_lm, gt_lm, weight=None): + """ + weighted mse loss + Parameters: + predict_lm --torch.tensor (B, 68, 2) + gt_lm --torch.tensor (B, 68, 2) + weight --numpy.array (1, 68) + """ + if not weight: + weight = np.ones([68]) + weight[28:31] = 20 + weight[-8:] = 20 + weight = np.expand_dims(weight, 0) + weight = torch.tensor(weight).to(predict_lm.device) + loss = torch.sum((predict_lm - gt_lm)**2, dim=-1) * weight + loss = torch.sum(loss) / (predict_lm.shape[0] * predict_lm.shape[1]) + return loss + + +### regulization +def reg_loss(coeffs_dict, opt=None): + """ + l2 norm without the sqrt, from yu's implementation (mse) + tf.nn.l2_loss https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss + Parameters: + coeffs_dict -- a dict of torch.tensors , keys: id, exp, tex, angle, gamma, trans + + """ + # coefficient regularization to ensure plausible 3d faces + if opt: + w_id, w_exp, w_tex = opt.w_id, opt.w_exp, opt.w_tex + else: + w_id, w_exp, w_tex = 1, 1, 1, 1 + creg_loss = w_id * torch.sum(coeffs_dict['id'] ** 2) + \ + w_exp * torch.sum(coeffs_dict['exp'] ** 2) + \ + w_tex * torch.sum(coeffs_dict['tex'] ** 2) + creg_loss = creg_loss / coeffs_dict['id'].shape[0] + + # gamma regularization to ensure a nearly-monochromatic light + gamma = coeffs_dict['gamma'].reshape([-1, 3, 9]) + gamma_mean = torch.mean(gamma, dim=1, keepdims=True) + gamma_loss = torch.mean((gamma - gamma_mean) ** 2) + + return creg_loss, gamma_loss + +def reflectance_loss(texture, mask): + """ + minimize texture variance (mse), albedo regularization to ensure an uniform skin albedo + Parameters: + texture --torch.tensor, (B, N, 3) + mask --torch.tensor, (N), 1 or 0 + + """ + mask = mask.reshape([1, mask.shape[0], 1]) + texture_mean = torch.sum(mask * texture, dim=1, keepdims=True) / torch.sum(mask) + loss = torch.sum(((texture - texture_mean) * mask)**2) / (texture.shape[0] * torch.sum(mask)) + return loss + diff --git a/videoretalking/third_part/face3d/models/networks.py b/videoretalking/third_part/face3d/models/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..30c0250f500cd2326abb6e0a19692029db30a653 --- /dev/null +++ b/videoretalking/third_part/face3d/models/networks.py @@ -0,0 +1,521 @@ +"""This script defines deep neural networks for Deep3DFaceRecon_pytorch +""" + +import os +import numpy as np +import torch.nn.functional as F +from torch.nn import init +import functools +from torch.optim import lr_scheduler +import torch +from torch import Tensor +import torch.nn as nn +try: + from torch.hub import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url +from typing import Type, Any, Callable, Union, List, Optional +from .arcface_torch.backbones import get_model +from kornia.geometry import warp_affine + +def resize_n_crop(image, M, dsize=112): + # image: (b, c, h, w) + # M : (b, 2, 3) + return warp_affine(image, M, dsize=(dsize, dsize)) + +def filter_state_dict(state_dict, remove_name='fc'): + new_state_dict = {} + for key in state_dict: + if remove_name in key: + continue + new_state_dict[key] = state_dict[key] + return new_state_dict + +def get_scheduler(optimizer, opt): + """Return a learning rate scheduler + + Parameters: + optimizer -- the optimizer of the network + opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.  + opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine + + For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. + See https://pytorch.org/docs/stable/optim.html for more details. + """ + if opt.lr_policy == 'linear': + def lambda_rule(epoch): + lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs + 1) + return lr_l + scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) + elif opt.lr_policy == 'step': + scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_epochs, gamma=0.2) + elif opt.lr_policy == 'plateau': + scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) + elif opt.lr_policy == 'cosine': + scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0) + else: + return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) + return scheduler + + +def define_net_recon(net_recon, use_last_fc=False, init_path=None): + return ReconNetWrapper(net_recon, use_last_fc=use_last_fc, init_path=init_path) + +def define_net_recog(net_recog, pretrained_path=None): + net = RecogNetWrapper(net_recog=net_recog, pretrained_path=pretrained_path) + net.eval() + return net + +class ReconNetWrapper(nn.Module): + fc_dim=257 + def __init__(self, net_recon, use_last_fc=False, init_path=None): + super(ReconNetWrapper, self).__init__() + self.use_last_fc = use_last_fc + if net_recon not in func_dict: + return NotImplementedError('network [%s] is not implemented', net_recon) + func, last_dim = func_dict[net_recon] + backbone = func(use_last_fc=use_last_fc, num_classes=self.fc_dim) + if init_path and os.path.isfile(init_path): + state_dict = filter_state_dict(torch.load(init_path, map_location='cpu')) + backbone.load_state_dict(state_dict) + print("loading init net_recon %s from %s" %(net_recon, init_path)) + self.backbone = backbone + if not use_last_fc: + self.final_layers = nn.ModuleList([ + conv1x1(last_dim, 80, bias=True), # id layer + conv1x1(last_dim, 64, bias=True), # exp layer + conv1x1(last_dim, 80, bias=True), # tex layer + conv1x1(last_dim, 3, bias=True), # angle layer + conv1x1(last_dim, 27, bias=True), # gamma layer + conv1x1(last_dim, 2, bias=True), # tx, ty + conv1x1(last_dim, 1, bias=True) # tz + ]) + for m in self.final_layers: + nn.init.constant_(m.weight, 0.) + nn.init.constant_(m.bias, 0.) + + def forward(self, x): + x = self.backbone(x) + if not self.use_last_fc: + output = [] + for layer in self.final_layers: + output.append(layer(x)) + x = torch.flatten(torch.cat(output, dim=1), 1) + return x + + +class RecogNetWrapper(nn.Module): + def __init__(self, net_recog, pretrained_path=None, input_size=112): + super(RecogNetWrapper, self).__init__() + net = get_model(name=net_recog, fp16=False) + if pretrained_path: + state_dict = torch.load(pretrained_path, map_location='cpu') + net.load_state_dict(state_dict) + print("loading pretrained net_recog %s from %s" %(net_recog, pretrained_path)) + for param in net.parameters(): + param.requires_grad = False + self.net = net + self.preprocess = lambda x: 2 * x - 1 + self.input_size=input_size + + def forward(self, image, M): + image = self.preprocess(resize_n_crop(image, M, self.input_size)) + id_feature = F.normalize(self.net(image), dim=-1, p=2) + return id_feature + + +# adapted from https://github.com/pytorch/vision/edit/master/torchvision/models/resnet.py +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', + 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', + 'wide_resnet50_2', 'wide_resnet101_2'] + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth', + 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', + 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', + 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', + 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', +} + + +def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes: int, out_planes: int, stride: int = 1, bias: bool = False) -> nn.Conv2d: + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias) + + +class BasicBlock(nn.Module): + expansion: int = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. + # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. + + expansion: int = 4 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__( + self, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + num_classes: int = 1000, + zero_init_residual: bool = False, + use_last_fc: bool = False, + groups: int = 1, + width_per_group: int = 64, + replace_stride_with_dilation: Optional[List[bool]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.use_last_fc = use_last_fc + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + dilate=replace_stride_with_dilation[2]) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + + if self.use_last_fc: + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] + + def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, + stride: int = 1, dilate: bool = False) -> nn.Sequential: + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def _forward_impl(self, x: Tensor) -> Tensor: + # See note [TorchScript super()] + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + if self.use_last_fc: + x = torch.flatten(x, 1) + x = self.fc(x) + return x + + def forward(self, x: Tensor) -> Tensor: + return self._forward_impl(x) + + +def _resnet( + arch: str, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + pretrained: bool, + progress: bool, + **kwargs: Any +) -> ResNet: + model = ResNet(block, layers, **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], + progress=progress) + model.load_state_dict(state_dict) + return model + + +def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-18 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, + **kwargs) + + +def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-34 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + +def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-50 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + +def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-101 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, + **kwargs) + + +def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-152 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, + **kwargs) + + +def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNeXt-50 32x4d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 4 + return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNeXt-101 32x8d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 8 + return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs) + + +def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""Wide ResNet-50-2 model from + `"Wide Residual Networks" `_. + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""Wide ResNet-101-2 model from + `"Wide Residual Networks" `_. + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs) + + +func_dict = { + 'resnet18': (resnet18, 512), + 'resnet50': (resnet50, 2048) +} diff --git a/videoretalking/third_part/face3d/models/template_model.py b/videoretalking/third_part/face3d/models/template_model.py new file mode 100644 index 0000000000000000000000000000000000000000..791b0b3c7eea62b9b5e7fbf046f81f64d6aa72c3 --- /dev/null +++ b/videoretalking/third_part/face3d/models/template_model.py @@ -0,0 +1,100 @@ +"""Model class template + +This module provides a template for users to implement custom models. +You can specify '--model template' to use this model. +The class name should be consistent with both the filename and its model option. +The filename should be _dataset.py +The class name should be Dataset.py +It implements a simple image-to-image translation baseline based on regression loss. +Given input-output pairs (data_A, data_B), it learns a network netG that can minimize the following L1 loss: + min_ ||netG(data_A) - data_B||_1 +You need to implement the following functions: + : Add model-specific options and rewrite default values for existing options. + <__init__>: Initialize this model class. + : Unpack input data and perform data pre-processing. + : Run forward pass. This will be called by both and . + : Update network weights; it will be called in every training iteration. +""" +import numpy as np +import torch +from .base_model import BaseModel +from . import networks + + +class TemplateModel(BaseModel): + @staticmethod + def modify_commandline_options(parser, is_train=True): + """Add new model-specific options and rewrite default values for existing options. + + Parameters: + parser -- the option parser + is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. + """ + parser.set_defaults(dataset_mode='aligned') # You can rewrite default values for this model. For example, this model usually uses aligned dataset as its dataset. + if is_train: + parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the regression loss') # You can define new arguments for this model. + + return parser + + def __init__(self, opt): + """Initialize this model class. + + Parameters: + opt -- training/test options + + A few things can be done here. + - (required) call the initialization function of BaseModel + - define loss function, visualization images, model names, and optimizers + """ + BaseModel.__init__(self, opt) # call the initialization method of BaseModel + # specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk. + self.loss_names = ['loss_G'] + # specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images. + self.visual_names = ['data_A', 'data_B', 'output'] + # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks to save and load networks. + # you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them. + self.model_names = ['G'] + # define networks; you can use opt.isTrain to specify different behaviors for training and test. + self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids) + if self.isTrain: # only defined during training time + # define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss. + # We also provide a GANLoss class "networks.GANLoss". self.criterionGAN = networks.GANLoss().to(self.device) + self.criterionLoss = torch.nn.L1Loss() + # define and initialize optimizers. You can define one optimizer for each network. + # If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. + self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + self.optimizers = [self.optimizer] + + # Our program will automatically call to define schedulers, load networks, and print networks + + def set_input(self, input): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + + Parameters: + input: a dictionary that contains the data itself and its metadata information. + """ + AtoB = self.opt.direction == 'AtoB' # use to swap data_A and data_B + self.data_A = input['A' if AtoB else 'B'].to(self.device) # get image data A + self.data_B = input['B' if AtoB else 'A'].to(self.device) # get image data B + self.image_paths = input['A_paths' if AtoB else 'B_paths'] # get image paths + + def forward(self): + """Run forward pass. This will be called by both functions and .""" + self.output = self.netG(self.data_A) # generate output image given the input data_A + + def backward(self): + """Calculate losses, gradients, and update network weights; called in every training iteration""" + # calculate the intermediate results if necessary; here self.output has been computed during function + # calculate loss given the input and intermediate results + self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression + self.loss_G.backward() # calculate gradients of network G w.r.t. loss_G + + def optimize_parameters(self): + """Update network weights; it will be called in every training iteration.""" + self.forward() # first call forward to calculate intermediate results + self.optimizer.zero_grad() # clear network G's existing gradients + self.backward() # calculate gradients for network G + self.optimizer.step() # update gradients for network G diff --git a/videoretalking/third_part/face3d/options/base_options.py b/videoretalking/third_part/face3d/options/base_options.py new file mode 100644 index 0000000000000000000000000000000000000000..afb0f0e6e09255c6be56c38d9116eab6e00dd6dd --- /dev/null +++ b/videoretalking/third_part/face3d/options/base_options.py @@ -0,0 +1,169 @@ +"""This script contains base options for Deep3DFaceRecon_pytorch +""" + +import argparse +import os +from util import util +import numpy as np +import torch +import face3d.models as models +import face3d.data as data + + +class BaseOptions(): + """This class defines options used during both training and test time. + + It also implements several helper functions such as parsing, printing, and saving the options. + It also gathers additional options defined in functions in both dataset class and model class. + """ + + def __init__(self, cmd_line=None): + """Reset the class; indicates the class hasn't been initialized""" + self.initialized = False + self.cmd_line = None + if cmd_line is not None: + self.cmd_line = cmd_line.split() + + def initialize(self, parser): + """Define the common options that are used in both training and test.""" + # basic parameters + parser.add_argument('--name', type=str, default='face_recon', help='name of the experiment. It decides where to store samples and models') + parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') + parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') + parser.add_argument('--vis_batch_nums', type=float, default=1, help='batch nums of images for visulization') + parser.add_argument('--eval_batch_nums', type=float, default=float('inf'), help='batch nums of images for evaluation') + parser.add_argument('--use_ddp', type=util.str2bool, nargs='?', const=True, default=True, help='whether use distributed data parallel') + parser.add_argument('--ddp_port', type=str, default='12355', help='ddp port') + parser.add_argument('--display_per_batch', type=util.str2bool, nargs='?', const=True, default=True, help='whether use batch to show losses') + parser.add_argument('--add_image', type=util.str2bool, nargs='?', const=True, default=True, help='whether add image to tensorboard') + parser.add_argument('--world_size', type=int, default=1, help='batch nums of images for evaluation') + + # model parameters + parser.add_argument('--model', type=str, default='facerecon', help='chooses which model to use.') + + # additional parameters + parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') + parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') + parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') + + self.initialized = True + return parser + + def gather_options(self): + """Initialize our parser with basic options(only once). + Add additional model-specific and dataset-specific options. + These options are defined in the function + in model and dataset classes. + """ + if not self.initialized: # check if it has been initialized + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser = self.initialize(parser) + + # get the basic options + if self.cmd_line is None: + opt, _ = parser.parse_known_args() + else: + opt, _ = parser.parse_known_args(self.cmd_line) + + # set cuda visible devices + os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_ids + + # modify model-related parser options + model_name = opt.model + model_option_setter = models.get_option_setter(model_name) + parser = model_option_setter(parser, self.isTrain) + if self.cmd_line is None: + opt, _ = parser.parse_known_args() # parse again with new defaults + else: + opt, _ = parser.parse_known_args(self.cmd_line) # parse again with new defaults + + # modify dataset-related parser options + if opt.dataset_mode: + dataset_name = opt.dataset_mode + dataset_option_setter = data.get_option_setter(dataset_name) + parser = dataset_option_setter(parser, self.isTrain) + + # save and return the parser + self.parser = parser + if self.cmd_line is None: + return parser.parse_args() + else: + return parser.parse_args(self.cmd_line) + + def print_options(self, opt): + """Print and save options + + It will print both current options and default values(if different). + It will save options into a text file / [checkpoints_dir] / opt.txt + """ + message = '' + message += '----------------- Options ---------------\n' + for k, v in sorted(vars(opt).items()): + comment = '' + default = self.parser.get_default(k) + if v != default: + comment = '\t[default: %s]' % str(default) + message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) + message += '----------------- End -------------------' + print(message) + + # save to the disk + expr_dir = os.path.join(opt.checkpoints_dir, opt.name) + util.mkdirs(expr_dir) + file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase)) + try: + with open(file_name, 'wt') as opt_file: + opt_file.write(message) + opt_file.write('\n') + except PermissionError as error: + print("permission error {}".format(error)) + pass + + def parse(self): + """Parse our options, create checkpoints directory suffix, and set up gpu device.""" + opt = self.gather_options() + opt.isTrain = self.isTrain # train or test + + # process opt.suffix + if opt.suffix: + suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' + opt.name = opt.name + suffix + + + # set gpu ids + str_ids = opt.gpu_ids.split(',') + gpu_ids = [] + for str_id in str_ids: + id = int(str_id) + if id >= 0: + gpu_ids.append(id) + opt.world_size = len(gpu_ids) + # if len(opt.gpu_ids) > 0: + # torch.cuda.set_device(gpu_ids[0]) + if opt.world_size == 1: + opt.use_ddp = False + + if opt.phase != 'test': + # set continue_train automatically + if opt.pretrained_name is None: + model_dir = os.path.join(opt.checkpoints_dir, opt.name) + else: + model_dir = os.path.join(opt.checkpoints_dir, opt.pretrained_name) + if os.path.isdir(model_dir): + model_pths = [i for i in os.listdir(model_dir) if i.endswith('pth')] + if os.path.isdir(model_dir) and len(model_pths) != 0: + opt.continue_train= True + + # update the latest epoch count + if opt.continue_train: + if opt.epoch == 'latest': + epoch_counts = [int(i.split('.')[0].split('_')[-1]) for i in model_pths if 'latest' not in i] + if len(epoch_counts) != 0: + opt.epoch_count = max(epoch_counts) + 1 + else: + opt.epoch_count = int(opt.epoch) + 1 + + + self.print_options(opt) + self.opt = opt + return self.opt diff --git a/videoretalking/third_part/face3d/options/inference_options.py b/videoretalking/third_part/face3d/options/inference_options.py new file mode 100644 index 0000000000000000000000000000000000000000..80b9466776e120e0fe3d164217df5071c2114cef --- /dev/null +++ b/videoretalking/third_part/face3d/options/inference_options.py @@ -0,0 +1,23 @@ +from face3d.options.base_options import BaseOptions + + +class InferenceOptions(BaseOptions): + """This class includes test options. + + It also includes shared options defined in BaseOptions. + """ + + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) # define shared options + parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') + parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. [None | flist]') + + parser.add_argument('--input_dir', type=str, help='the folder of the input files') + parser.add_argument('--keypoint_dir', type=str, help='the folder of the keypoint files') + parser.add_argument('--output_dir', type=str, default='mp4', help='the output dir to save the extracted coefficients') + parser.add_argument('--save_split_files', action='store_true', help='save split files or not') + parser.add_argument('--inference_batch_size', type=int, default=8) + + # Dropout and Batchnorm has different behavior during training and test. + self.isTrain = False + return parser diff --git a/videoretalking/third_part/face3d/options/test_options.py b/videoretalking/third_part/face3d/options/test_options.py new file mode 100644 index 0000000000000000000000000000000000000000..f81c0c6eee0549e6fa8762dc4fc4b8573b887fe4 --- /dev/null +++ b/videoretalking/third_part/face3d/options/test_options.py @@ -0,0 +1,21 @@ +"""This script contains the test options for Deep3DFaceRecon_pytorch +""" + +from .base_options import BaseOptions + + +class TestOptions(BaseOptions): + """This class includes test options. + + It also includes shared options defined in BaseOptions. + """ + + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) # define shared options + parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') + parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. [None | flist]') + parser.add_argument('--img_folder', type=str, default='examples', help='folder for test images.') + + # Dropout and Batchnorm has different behavior during training and test. + self.isTrain = False + return parser diff --git a/videoretalking/third_part/face3d/options/train_options.py b/videoretalking/third_part/face3d/options/train_options.py new file mode 100644 index 0000000000000000000000000000000000000000..1100b0e35cc8ef563f41f6b8219510edbef53233 --- /dev/null +++ b/videoretalking/third_part/face3d/options/train_options.py @@ -0,0 +1,53 @@ +"""This script contains the training options for Deep3DFaceRecon_pytorch +""" + +from .base_options import BaseOptions +from util import util + +class TrainOptions(BaseOptions): + """This class includes training options. + + It also includes shared options defined in BaseOptions. + """ + + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) + # dataset parameters + # for train + parser.add_argument('--data_root', type=str, default='./', help='dataset root') + parser.add_argument('--flist', type=str, default='datalist/train/masks.txt', help='list of mask names of training set') + parser.add_argument('--batch_size', type=int, default=32) + parser.add_argument('--dataset_mode', type=str, default='flist', help='chooses how datasets are loaded. [None | flist]') + parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') + parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data') + parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') + parser.add_argument('--preprocess', type=str, default='shift_scale_rot_flip', help='scaling and cropping of images at load time [shift_scale_rot_flip | shift_scale | shift | shift_rot_flip ]') + parser.add_argument('--use_aug', type=util.str2bool, nargs='?', const=True, default=True, help='whether use data augmentation') + + # for val + parser.add_argument('--flist_val', type=str, default='datalist/val/masks.txt', help='list of mask names of val set') + parser.add_argument('--batch_size_val', type=int, default=32) + + + # visualization parameters + parser.add_argument('--display_freq', type=int, default=1000, help='frequency of showing training results on screen') + parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') + + # network saving and loading parameters + parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') + parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs') + parser.add_argument('--evaluation_freq', type=int, default=5000, help='evaluation freq') + parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration') + parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') + parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') + parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') + parser.add_argument('--pretrained_name', type=str, default=None, help='resume training from another checkpoint') + + # training parameters + parser.add_argument('--n_epochs', type=int, default=20, help='number of epochs with the initial learning rate') + parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam') + parser.add_argument('--lr_policy', type=str, default='step', help='learning rate policy. [linear | step | plateau | cosine]') + parser.add_argument('--lr_decay_epochs', type=int, default=10, help='multiply by a gamma every lr_decay_epochs epoches') + + self.isTrain = True + return parser diff --git a/videoretalking/third_part/face3d/util/BBRegressorParam_r.mat b/videoretalking/third_part/face3d/util/BBRegressorParam_r.mat new file mode 100644 index 0000000000000000000000000000000000000000..1430a94ed2ab570a09f9d980d3585e8aaa933084 Binary files /dev/null and b/videoretalking/third_part/face3d/util/BBRegressorParam_r.mat differ diff --git a/videoretalking/third_part/face3d/util/__init__.py b/videoretalking/third_part/face3d/util/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6433dfd5d5b80b7b6f5ca4218a7725e853c17843 --- /dev/null +++ b/videoretalking/third_part/face3d/util/__init__.py @@ -0,0 +1,2 @@ +"""This package includes a miscellaneous collection of useful helper functions.""" +from face3d.util import * diff --git a/videoretalking/third_part/face3d/util/detect_lm68.py b/videoretalking/third_part/face3d/util/detect_lm68.py new file mode 100644 index 0000000000000000000000000000000000000000..8a2cfd22b342de5c872ff07fc1c2a9920c2985b7 --- /dev/null +++ b/videoretalking/third_part/face3d/util/detect_lm68.py @@ -0,0 +1,106 @@ +import os +import cv2 +import numpy as np +from scipy.io import loadmat +import tensorflow as tf +from util.preprocess import align_for_lm +from shutil import move + +mean_face = np.loadtxt('util/test_mean_face.txt') +mean_face = mean_face.reshape([68, 2]) + +def save_label(labels, save_path): + np.savetxt(save_path, labels) + +def draw_landmarks(img, landmark, save_name): + landmark = landmark + lm_img = np.zeros([img.shape[0], img.shape[1], 3]) + lm_img[:] = img.astype(np.float32) + landmark = np.round(landmark).astype(np.int32) + + for i in range(len(landmark)): + for j in range(-1, 1): + for k in range(-1, 1): + if img.shape[0] - 1 - landmark[i, 1]+j > 0 and \ + img.shape[0] - 1 - landmark[i, 1]+j < img.shape[0] and \ + landmark[i, 0]+k > 0 and \ + landmark[i, 0]+k < img.shape[1]: + lm_img[img.shape[0] - 1 - landmark[i, 1]+j, landmark[i, 0]+k, + :] = np.array([0, 0, 255]) + lm_img = lm_img.astype(np.uint8) + + cv2.imwrite(save_name, lm_img) + + +def load_data(img_name, txt_name): + return cv2.imread(img_name), np.loadtxt(txt_name) + +# create tensorflow graph for landmark detector +def load_lm_graph(graph_filename): + with tf.gfile.GFile(graph_filename, 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + + with tf.Graph().as_default() as graph: + tf.import_graph_def(graph_def, name='net') + img_224 = graph.get_tensor_by_name('net/input_imgs:0') + output_lm = graph.get_tensor_by_name('net/lm:0') + lm_sess = tf.Session(graph=graph) + + return lm_sess,img_224,output_lm + +# landmark detection +def detect_68p(img_path,sess,input_op,output_op): + print('detecting landmarks......') + names = [i for i in sorted(os.listdir( + img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i] + vis_path = os.path.join(img_path, 'vis') + remove_path = os.path.join(img_path, 'remove') + save_path = os.path.join(img_path, 'landmarks') + if not os.path.isdir(vis_path): + os.makedirs(vis_path) + if not os.path.isdir(remove_path): + os.makedirs(remove_path) + if not os.path.isdir(save_path): + os.makedirs(save_path) + + for i in range(0, len(names)): + name = names[i] + print('%05d' % (i), ' ', name) + full_image_name = os.path.join(img_path, name) + txt_name = '.'.join(name.split('.')[:-1]) + '.txt' + full_txt_name = os.path.join(img_path, 'detections', txt_name) # 5 facial landmark path for each image + + # if an image does not have detected 5 facial landmarks, remove it from the training list + if not os.path.isfile(full_txt_name): + move(full_image_name, os.path.join(remove_path, name)) + continue + + # load data + img, five_points = load_data(full_image_name, full_txt_name) + input_img, scale, bbox = align_for_lm(img, five_points) # align for 68 landmark detection + + # if the alignment fails, remove corresponding image from the training list + if scale == 0: + move(full_txt_name, os.path.join( + remove_path, txt_name)) + move(full_image_name, os.path.join(remove_path, name)) + continue + + # detect landmarks + input_img = np.reshape( + input_img, [1, 224, 224, 3]).astype(np.float32) + landmark = sess.run( + output_op, feed_dict={input_op: input_img}) + + # transform back to original image coordinate + landmark = landmark.reshape([68, 2]) + mean_face + landmark[:, 1] = 223 - landmark[:, 1] + landmark = landmark / scale + landmark[:, 0] = landmark[:, 0] + bbox[0] + landmark[:, 1] = landmark[:, 1] + bbox[1] + landmark[:, 1] = img.shape[0] - 1 - landmark[:, 1] + + if i % 100 == 0: + draw_landmarks(img, landmark, os.path.join(vis_path, name)) + save_label(landmark, os.path.join(save_path, txt_name)) diff --git a/videoretalking/third_part/face3d/util/generate_list.py b/videoretalking/third_part/face3d/util/generate_list.py new file mode 100644 index 0000000000000000000000000000000000000000..ebe93fcc5c61fbc79f4cd004a8d1bdd10ece16eb --- /dev/null +++ b/videoretalking/third_part/face3d/util/generate_list.py @@ -0,0 +1,34 @@ +"""This script is to generate training list files for Deep3DFaceRecon_pytorch +""" + +import os + +# save path to training data +def write_list(lms_list, imgs_list, msks_list, mode='train',save_folder='datalist', save_name=''): + save_path = os.path.join(save_folder, mode) + if not os.path.isdir(save_path): + os.makedirs(save_path) + with open(os.path.join(save_path, save_name + 'landmarks.txt'), 'w') as fd: + fd.writelines([i + '\n' for i in lms_list]) + + with open(os.path.join(save_path, save_name + 'images.txt'), 'w') as fd: + fd.writelines([i + '\n' for i in imgs_list]) + + with open(os.path.join(save_path, save_name + 'masks.txt'), 'w') as fd: + fd.writelines([i + '\n' for i in msks_list]) + +# check if the path is valid +def check_list(rlms_list, rimgs_list, rmsks_list): + lms_list, imgs_list, msks_list = [], [], [] + for i in range(len(rlms_list)): + flag = 'false' + lm_path = rlms_list[i] + im_path = rimgs_list[i] + msk_path = rmsks_list[i] + if os.path.isfile(lm_path) and os.path.isfile(im_path) and os.path.isfile(msk_path): + flag = 'true' + lms_list.append(rlms_list[i]) + imgs_list.append(rimgs_list[i]) + msks_list.append(rmsks_list[i]) + print(i, rlms_list[i], flag) + return lms_list, imgs_list, msks_list diff --git a/videoretalking/third_part/face3d/util/html.py b/videoretalking/third_part/face3d/util/html.py new file mode 100644 index 0000000000000000000000000000000000000000..318c9ecea5fb72433ac90f8b72ad863f3337ed9c --- /dev/null +++ b/videoretalking/third_part/face3d/util/html.py @@ -0,0 +1,86 @@ +import dominate +from dominate.tags import meta, h3, table, tr, td, p, a, img, br +import os + + +class HTML: + """This HTML class allows us to save images and write texts into a single HTML file. + + It consists of functions such as (add a text header to the HTML file), + (add a row of images to the HTML file), and (save the HTML to the disk). + It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API. + """ + + def __init__(self, web_dir, title, refresh=0): + """Initialize the HTML classes + + Parameters: + web_dir (str) -- a directory that stores the webpage. HTML file will be created at /index.html; images will be saved at 0: + with self.doc.head: + meta(http_equiv="refresh", content=str(refresh)) + + def get_image_dir(self): + """Return the directory that stores images""" + return self.img_dir + + def add_header(self, text): + """Insert a header to the HTML file + + Parameters: + text (str) -- the header text + """ + with self.doc: + h3(text) + + def add_images(self, ims, txts, links, width=400): + """add images to the HTML file + + Parameters: + ims (str list) -- a list of image paths + txts (str list) -- a list of image names shown on the website + links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page + """ + self.t = table(border=1, style="table-layout: fixed;") # Insert a table + self.doc.add(self.t) + with self.t: + with tr(): + for im, txt, link in zip(ims, txts, links): + with td(style="word-wrap: break-word;", halign="center", valign="top"): + with p(): + with a(href=os.path.join('images', link)): + img(style="width:%dpx" % width, src=os.path.join('images', im)) + br() + p(txt) + + def save(self): + """save the current content to the HTML file""" + html_file = '%s/index.html' % self.web_dir + f = open(html_file, 'wt') + f.write(self.doc.render()) + f.close() + + +if __name__ == '__main__': # we show an example usage here. + html = HTML('web/', 'test_html') + html.add_header('hello world') + + ims, txts, links = [], [], [] + for n in range(4): + ims.append('image_%d.png' % n) + txts.append('text_%d' % n) + links.append('image_%d.png' % n) + html.add_images(ims, txts, links) + html.save() diff --git a/videoretalking/third_part/face3d/util/load_mats.py b/videoretalking/third_part/face3d/util/load_mats.py new file mode 100644 index 0000000000000000000000000000000000000000..e42da746a2f64b4faeed9793033feaa0ab05af64 --- /dev/null +++ b/videoretalking/third_part/face3d/util/load_mats.py @@ -0,0 +1,120 @@ +"""This script is to load 3D face model for Deep3DFaceRecon_pytorch +""" + +import numpy as np +from PIL import Image +from scipy.io import loadmat, savemat +from array import array +import os.path as osp + +# load expression basis +def LoadExpBasis(bfm_folder='BFM'): + n_vertex = 53215 + Expbin = open(osp.join(bfm_folder, 'Exp_Pca.bin'), 'rb') + exp_dim = array('i') + exp_dim.fromfile(Expbin, 1) + expMU = array('f') + expPC = array('f') + expMU.fromfile(Expbin, 3*n_vertex) + expPC.fromfile(Expbin, 3*exp_dim[0]*n_vertex) + Expbin.close() + + expPC = np.array(expPC) + expPC = np.reshape(expPC, [exp_dim[0], -1]) + expPC = np.transpose(expPC) + + expEV = np.loadtxt(osp.join(bfm_folder, 'std_exp.txt')) + + return expPC, expEV + + +# transfer original BFM09 to our face model +def transferBFM09(bfm_folder='BFM'): + print('Transfer BFM09 to BFM_model_front......') + original_BFM = loadmat(osp.join(bfm_folder, '01_MorphableModel.mat')) + shapePC = original_BFM['shapePC'] # shape basis + shapeEV = original_BFM['shapeEV'] # corresponding eigen value + shapeMU = original_BFM['shapeMU'] # mean face + texPC = original_BFM['texPC'] # texture basis + texEV = original_BFM['texEV'] # eigen value + texMU = original_BFM['texMU'] # mean texture + + expPC, expEV = LoadExpBasis() + + # transfer BFM09 to our face model + + idBase = shapePC*np.reshape(shapeEV, [-1, 199]) + idBase = idBase/1e5 # unify the scale to decimeter + idBase = idBase[:, :80] # use only first 80 basis + + exBase = expPC*np.reshape(expEV, [-1, 79]) + exBase = exBase/1e5 # unify the scale to decimeter + exBase = exBase[:, :64] # use only first 64 basis + + texBase = texPC*np.reshape(texEV, [-1, 199]) + texBase = texBase[:, :80] # use only first 80 basis + + # our face model is cropped along face landmarks and contains only 35709 vertex. + # original BFM09 contains 53490 vertex, and expression basis provided by Guo et al. contains 53215 vertex. + # thus we select corresponding vertex to get our face model. + + index_exp = loadmat(osp.join(bfm_folder, 'BFM_front_idx.mat')) + index_exp = index_exp['idx'].astype(np.int32) - 1 # starts from 0 (to 53215) + + index_shape = loadmat(osp.join(bfm_folder, 'BFM_exp_idx.mat')) + index_shape = index_shape['trimIndex'].astype( + np.int32) - 1 # starts from 0 (to 53490) + index_shape = index_shape[index_exp] + + idBase = np.reshape(idBase, [-1, 3, 80]) + idBase = idBase[index_shape, :, :] + idBase = np.reshape(idBase, [-1, 80]) + + texBase = np.reshape(texBase, [-1, 3, 80]) + texBase = texBase[index_shape, :, :] + texBase = np.reshape(texBase, [-1, 80]) + + exBase = np.reshape(exBase, [-1, 3, 64]) + exBase = exBase[index_exp, :, :] + exBase = np.reshape(exBase, [-1, 64]) + + meanshape = np.reshape(shapeMU, [-1, 3])/1e5 + meanshape = meanshape[index_shape, :] + meanshape = np.reshape(meanshape, [1, -1]) + + meantex = np.reshape(texMU, [-1, 3]) + meantex = meantex[index_shape, :] + meantex = np.reshape(meantex, [1, -1]) + + # other info contains triangles, region used for computing photometric loss, + # region used for skin texture regularization, and 68 landmarks index etc. + other_info = loadmat(osp.join(bfm_folder, 'facemodel_info.mat')) + frontmask2_idx = other_info['frontmask2_idx'] + skinmask = other_info['skinmask'] + keypoints = other_info['keypoints'] + point_buf = other_info['point_buf'] + tri = other_info['tri'] + tri_mask2 = other_info['tri_mask2'] + + # save our face model + savemat(osp.join(bfm_folder, 'BFM_model_front.mat'), {'meanshape': meanshape, 'meantex': meantex, 'idBase': idBase, 'exBase': exBase, 'texBase': texBase, + 'tri': tri, 'point_buf': point_buf, 'tri_mask2': tri_mask2, 'keypoints': keypoints, 'frontmask2_idx': frontmask2_idx, 'skinmask': skinmask}) + + +# load landmarks for standard face, which is used for image preprocessing +def load_lm3d(bfm_folder): + + Lm3D = loadmat(osp.join(bfm_folder, 'similarity_Lm3D_all.mat')) + Lm3D = Lm3D['lm'] + + # calculate 5 facial landmarks using 68 landmarks + lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1 + Lm3D = np.stack([Lm3D[lm_idx[0], :], np.mean(Lm3D[lm_idx[[1, 2]], :], 0), np.mean( + Lm3D[lm_idx[[3, 4]], :], 0), Lm3D[lm_idx[5], :], Lm3D[lm_idx[6], :]], axis=0) + Lm3D = Lm3D[[1, 2, 0, 3, 4], :] + + return Lm3D + + +if __name__ == '__main__': + transferBFM09() \ No newline at end of file diff --git a/videoretalking/third_part/face3d/util/nvdiffrast.py b/videoretalking/third_part/face3d/util/nvdiffrast.py new file mode 100644 index 0000000000000000000000000000000000000000..a6d82af0c33df2181293a5ae3e87020e2b55ac3d --- /dev/null +++ b/videoretalking/third_part/face3d/util/nvdiffrast.py @@ -0,0 +1,89 @@ +"""This script is the differentiable renderer for Deep3DFaceRecon_pytorch + Attention, antialiasing step is missing in current version. +""" + +import torch +import torch.nn.functional as F +import kornia +from kornia.geometry.camera import pixel2cam +import numpy as np +from typing import List +import nvdiffrast.torch as dr +from scipy.io import loadmat +from torch import nn + +def ndc_projection(x=0.1, n=1.0, f=50.0): + return np.array([[n/x, 0, 0, 0], + [ 0, n/-x, 0, 0], + [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)], + [ 0, 0, -1, 0]]).astype(np.float32) + +class MeshRenderer(nn.Module): + def __init__(self, + rasterize_fov, + znear=0.1, + zfar=10, + rasterize_size=224): + super(MeshRenderer, self).__init__() + + x = np.tan(np.deg2rad(rasterize_fov * 0.5)) * znear + self.ndc_proj = torch.tensor(ndc_projection(x=x, n=znear, f=zfar)).matmul( + torch.diag(torch.tensor([1., -1, -1, 1]))) + self.rasterize_size = rasterize_size + self.glctx = None + + def forward(self, vertex, tri, feat=None): + """ + Return: + mask -- torch.tensor, size (B, 1, H, W) + depth -- torch.tensor, size (B, 1, H, W) + features(optional) -- torch.tensor, size (B, C, H, W) if feat is not None + + Parameters: + vertex -- torch.tensor, size (B, N, 3) + tri -- torch.tensor, size (B, M, 3) or (M, 3), triangles + feat(optional) -- torch.tensor, size (B, C), features + """ + device = vertex.device + rsize = int(self.rasterize_size) + ndc_proj = self.ndc_proj.to(device) + # trans to homogeneous coordinates of 3d vertices, the direction of y is the same as v + if vertex.shape[-1] == 3: + vertex = torch.cat([vertex, torch.ones([*vertex.shape[:2], 1]).to(device)], dim=-1) + vertex[..., 1] = -vertex[..., 1] + + + vertex_ndc = vertex @ ndc_proj.t() + if self.glctx is None: + self.glctx = dr.RasterizeGLContext(device=device) + print("create glctx on device cuda:%d"%device.index) + + ranges = None + if isinstance(tri, List) or len(tri.shape) == 3: + vum = vertex_ndc.shape[1] + fnum = torch.tensor([f.shape[0] for f in tri]).unsqueeze(1).to(device) + fstartidx = torch.cumsum(fnum, dim=0) - fnum + ranges = torch.cat([fstartidx, fnum], axis=1).type(torch.int32).cpu() + for i in range(tri.shape[0]): + tri[i] = tri[i] + i*vum + vertex_ndc = torch.cat(vertex_ndc, dim=0) + tri = torch.cat(tri, dim=0) + + # for range_mode vetex: [B*N, 4], tri: [B*M, 3], for instance_mode vetex: [B, N, 4], tri: [M, 3] + tri = tri.type(torch.int32).contiguous() + rast_out, _ = dr.rasterize(self.glctx, vertex_ndc.contiguous(), tri, resolution=[rsize, rsize], ranges=ranges) + + depth, _ = dr.interpolate(vertex.reshape([-1,4])[...,2].unsqueeze(1).contiguous(), rast_out, tri) + depth = depth.permute(0, 3, 1, 2) + mask = (rast_out[..., 3] > 0).float().unsqueeze(1) + depth = mask * depth + + + image = None + if feat is not None: + image, _ = dr.interpolate(feat, rast_out, tri) + image = image.permute(0, 3, 1, 2) + image = mask * image + + return mask, depth, image + diff --git a/videoretalking/third_part/face3d/util/preprocess.py b/videoretalking/third_part/face3d/util/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..ea92da58593115a8e5f72a3cf98b02250d195d80 --- /dev/null +++ b/videoretalking/third_part/face3d/util/preprocess.py @@ -0,0 +1,230 @@ +""" +This script contains the image preprocessing code for Deep3DFaceRecon_pytorch +""" + +import numpy as np +from scipy.io import loadmat +from PIL import Image +import cv2 +import os +from skimage import transform as trans +import torch +import warnings +warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning) +warnings.filterwarnings("ignore", category=FutureWarning) + + +# calculating least square problem for image alignment +def POS(xp, x): + npts = xp.shape[1] + + A = np.zeros([2*npts, 8]) + + A[0:2*npts-1:2, 0:3] = x.transpose() + A[0:2*npts-1:2, 3] = 1 + + A[1:2*npts:2, 4:7] = x.transpose() + A[1:2*npts:2, 7] = 1 + + b = np.reshape(xp.transpose(), [2*npts, 1]) + + k, _, _, _ = np.linalg.lstsq(A, b) + + R1 = k[0:3] + R2 = k[4:7] + sTx = k[3] + sTy = k[7] + s = (np.linalg.norm(R1) + np.linalg.norm(R2))/2 + t = np.stack([sTx, sTy], axis=0) + + return t, s + +# bounding box for 68 landmark detection +def BBRegression(points, params): + + w1 = params['W1'] + b1 = params['B1'] + w2 = params['W2'] + b2 = params['B2'] + data = points.copy() + data = data.reshape([5, 2]) + data_mean = np.mean(data, axis=0) + x_mean = data_mean[0] + y_mean = data_mean[1] + data[:, 0] = data[:, 0] - x_mean + data[:, 1] = data[:, 1] - y_mean + + rms = np.sqrt(np.sum(data ** 2)/5) + data = data / rms + data = data.reshape([1, 10]) + data = np.transpose(data) + inputs = np.matmul(w1, data) + b1 + inputs = 2 / (1 + np.exp(-2 * inputs)) - 1 + inputs = np.matmul(w2, inputs) + b2 + inputs = np.transpose(inputs) + x = inputs[:, 0] * rms + x_mean + y = inputs[:, 1] * rms + y_mean + w = 224/inputs[:, 2] * rms + rects = [x, y, w, w] + return np.array(rects).reshape([4]) + +# utils for landmark detection +def img_padding(img, box): + success = True + bbox = box.copy() + res = np.zeros([2*img.shape[0], 2*img.shape[1], 3]) + res[img.shape[0] // 2: img.shape[0] + img.shape[0] // + 2, img.shape[1] // 2: img.shape[1] + img.shape[1]//2] = img + + bbox[0] = bbox[0] + img.shape[1] // 2 + bbox[1] = bbox[1] + img.shape[0] // 2 + if bbox[0] < 0 or bbox[1] < 0: + success = False + return res, bbox, success + +# utils for landmark detection +def crop(img, bbox): + padded_img, padded_bbox, flag = img_padding(img, bbox) + if flag: + crop_img = padded_img[padded_bbox[1]: padded_bbox[1] + + padded_bbox[3], padded_bbox[0]: padded_bbox[0] + padded_bbox[2]] + crop_img = cv2.resize(crop_img.astype(np.uint8), + (224, 224), interpolation=cv2.INTER_CUBIC) + scale = 224 / padded_bbox[3] + return crop_img, scale + else: + return padded_img, 0 + +# utils for landmark detection +def scale_trans(img, lm, t, s): + imgw = img.shape[1] + imgh = img.shape[0] + M_s = np.array([[1, 0, -t[0] + imgw//2 + 0.5], [0, 1, -imgh//2 + t[1]]], + dtype=np.float32) + img = cv2.warpAffine(img, M_s, (imgw, imgh)) + w = int(imgw / s * 100) + h = int(imgh / s * 100) + img = cv2.resize(img, (w, h)) + lm = np.stack([lm[:, 0] - t[0] + imgw // 2, lm[:, 1] - + t[1] + imgh // 2], axis=1) / s * 100 + + left = w//2 - 112 + up = h//2 - 112 + bbox = [left, up, 224, 224] + cropped_img, scale2 = crop(img, bbox) + assert(scale2!=0) + t1 = np.array([bbox[0], bbox[1]]) + + # back to raw img s * crop + s * t1 + t2 + t1 = np.array([w//2 - 112, h//2 - 112]) + scale = s / 100 + t2 = np.array([t[0] - imgw/2, t[1] - imgh / 2]) + inv = (scale/scale2, scale * t1 + t2.reshape([2])) + return cropped_img, inv + +# utils for landmark detection +def align_for_lm(img, five_points): + five_points = np.array(five_points).reshape([1, 10]) + params = loadmat('util/BBRegressorParam_r.mat') + bbox = BBRegression(five_points, params) + assert(bbox[2] != 0) + bbox = np.round(bbox).astype(np.int32) + crop_img, scale = crop(img, bbox) + return crop_img, scale, bbox + + +# resize and crop images for face reconstruction +def resize_n_crop_img(img, lm, t, s, target_size=224., mask=None): + w0, h0 = img.size + w = (w0*s).astype(np.int32) + h = (h0*s).astype(np.int32) + left = (w/2 - target_size/2 + float((t[0] - w0/2)*s)).astype(np.int32) + right = left + target_size + up = (h/2 - target_size/2 + float((h0/2 - t[1])*s)).astype(np.int32) + below = up + target_size + + img = img.resize((w, h), resample=Image.BICUBIC) + img = img.crop((left, up, right, below)) + + if mask is not None: + mask = mask.resize((w, h), resample=Image.BICUBIC) + mask = mask.crop((left, up, right, below)) + + lm = np.stack([lm[:, 0] - t[0] + w0/2, lm[:, 1] - + t[1] + h0/2], axis=1)*s + lm = lm - np.reshape( + np.array([(w/2 - target_size/2), (h/2-target_size/2)]), [1, 2]) + + return img, lm, mask + +# utils for face reconstruction +def extract_5p(lm): + lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1 + lm5p = np.stack([lm[lm_idx[0], :], np.mean(lm[lm_idx[[1, 2]], :], 0), np.mean( + lm[lm_idx[[3, 4]], :], 0), lm[lm_idx[5], :], lm[lm_idx[6], :]], axis=0) + lm5p = lm5p[[1, 2, 0, 3, 4], :] + return lm5p + +# utils for face reconstruction +def align_img(img, lm, lm3D, mask=None, target_size=224., rescale_factor=102.): + """ + Return: + transparams --numpy.array (raw_W, raw_H, scale, tx, ty) + img_new --PIL.Image (target_size, target_size, 3) + lm_new --numpy.array (68, 2), y direction is opposite to v direction + mask_new --PIL.Image (target_size, target_size) + + Parameters: + img --PIL.Image (raw_H, raw_W, 3) + lm --numpy.array (68, 2), y direction is opposite to v direction + lm3D --numpy.array (5, 3) + mask --PIL.Image (raw_H, raw_W, 3) + """ + + w0, h0 = img.size + if lm.shape[0] != 5: + lm5p = extract_5p(lm) + else: + lm5p = lm + + # calculate translation and scale factors using 5 facial landmarks and standard landmarks of a 3D face + t, s = POS(lm5p.transpose(), lm3D.transpose()) + s = rescale_factor/s + + # processing the image + img_new, lm_new, mask_new = resize_n_crop_img(img, lm, t, s, target_size=target_size, mask=mask) + trans_params = np.array([w0, h0, s, t[0], t[1]]) + + return trans_params, img_new, lm_new, mask_new + +# utils for face recognition model +def estimate_norm(lm_68p, H): + # from https://github.com/deepinsight/insightface/blob/c61d3cd208a603dfa4a338bd743b320ce3e94730/recognition/common/face_align.py#L68 + """ + Return: + trans_m --numpy.array (2, 3) + Parameters: + lm --numpy.array (68, 2), y direction is opposite to v direction + H --int/float , image height + """ + lm = extract_5p(lm_68p) + lm[:, -1] = H - 1 - lm[:, -1] + tform = trans.SimilarityTransform() + src = np.array( + [[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366], + [41.5493, 92.3655], [70.7299, 92.2041]], + dtype=np.float32) + tform.estimate(lm, src) + M = tform.params + if np.linalg.det(M) == 0: + M = np.eye(3) + + return M[0:2, :] + +def estimate_norm_torch(lm_68p, H): + lm_68p_ = lm_68p.detach().cpu().numpy() + M = [] + for i in range(lm_68p_.shape[0]): + M.append(estimate_norm(lm_68p_[i], H)) + M = torch.tensor(np.array(M), dtype=torch.float32).to(lm_68p.device) + return M diff --git a/videoretalking/third_part/face3d/util/skin_mask.py b/videoretalking/third_part/face3d/util/skin_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..ed764759038f77b35d45448b344d4347498ca427 --- /dev/null +++ b/videoretalking/third_part/face3d/util/skin_mask.py @@ -0,0 +1,125 @@ +"""This script is to generate skin attention mask for Deep3DFaceRecon_pytorch +""" + +import math +import numpy as np +import os +import cv2 + +class GMM: + def __init__(self, dim, num, w, mu, cov, cov_det, cov_inv): + self.dim = dim # feature dimension + self.num = num # number of Gaussian components + self.w = w # weights of Gaussian components (a list of scalars) + self.mu= mu # mean of Gaussian components (a list of 1xdim vectors) + self.cov = cov # covariance matrix of Gaussian components (a list of dimxdim matrices) + self.cov_det = cov_det # pre-computed determinet of covariance matrices (a list of scalars) + self.cov_inv = cov_inv # pre-computed inverse covariance matrices (a list of dimxdim matrices) + + self.factor = [0]*num + for i in range(self.num): + self.factor[i] = (2*math.pi)**(self.dim/2) * self.cov_det[i]**0.5 + + def likelihood(self, data): + assert(data.shape[1] == self.dim) + N = data.shape[0] + lh = np.zeros(N) + + for i in range(self.num): + data_ = data - self.mu[i] + + tmp = np.matmul(data_,self.cov_inv[i]) * data_ + tmp = np.sum(tmp,axis=1) + power = -0.5 * tmp + + p = np.array([math.exp(power[j]) for j in range(N)]) + p = p/self.factor[i] + lh += p*self.w[i] + + return lh + + +def _rgb2ycbcr(rgb): + m = np.array([[65.481, 128.553, 24.966], + [-37.797, -74.203, 112], + [112, -93.786, -18.214]]) + shape = rgb.shape + rgb = rgb.reshape((shape[0] * shape[1], 3)) + ycbcr = np.dot(rgb, m.transpose() / 255.) + ycbcr[:, 0] += 16. + ycbcr[:, 1:] += 128. + return ycbcr.reshape(shape) + + +def _bgr2ycbcr(bgr): + rgb = bgr[..., ::-1] + return _rgb2ycbcr(rgb) + + +gmm_skin_w = [0.24063933, 0.16365987, 0.26034665, 0.33535415] +gmm_skin_mu = [np.array([113.71862, 103.39613, 164.08226]), + np.array([150.19858, 105.18467, 155.51428]), + np.array([183.92976, 107.62468, 152.71820]), + np.array([114.90524, 113.59782, 151.38217])] +gmm_skin_cov_det = [5692842.5, 5851930.5, 2329131., 1585971.] +gmm_skin_cov_inv = [np.array([[0.0019472069, 0.0020450759, -0.00060243998],[0.0020450759, 0.017700525, 0.0051420014],[-0.00060243998, 0.0051420014, 0.0081308950]]), + np.array([[0.0027110141, 0.0011036990, 0.0023122299],[0.0011036990, 0.010707724, 0.010742856],[0.0023122299, 0.010742856, 0.017481629]]), + np.array([[0.0048026871, 0.00022935172, 0.0077668377],[0.00022935172, 0.011729696, 0.0081661865],[0.0077668377, 0.0081661865, 0.025374353]]), + np.array([[0.0011989699, 0.0022453172, -0.0010748957],[0.0022453172, 0.047758564, 0.020332102],[-0.0010748957, 0.020332102, 0.024502251]])] + +gmm_skin = GMM(3, 4, gmm_skin_w, gmm_skin_mu, [], gmm_skin_cov_det, gmm_skin_cov_inv) + +gmm_nonskin_w = [0.12791070, 0.31130761, 0.34245777, 0.21832393] +gmm_nonskin_mu = [np.array([99.200851, 112.07533, 140.20602]), + np.array([110.91392, 125.52969, 130.19237]), + np.array([129.75864, 129.96107, 126.96808]), + np.array([112.29587, 128.85121, 129.05431])] +gmm_nonskin_cov_det = [458703648., 6466488., 90611376., 133097.63] +gmm_nonskin_cov_inv = [np.array([[0.00085371657, 0.00071197288, 0.00023958916],[0.00071197288, 0.0025935620, 0.00076557708],[0.00023958916, 0.00076557708, 0.0015042332]]), + np.array([[0.00024650150, 0.00045542428, 0.00015019422],[0.00045542428, 0.026412144, 0.018419769],[0.00015019422, 0.018419769, 0.037497383]]), + np.array([[0.00037054974, 0.00038146760, 0.00040408765],[0.00038146760, 0.0085505722, 0.0079136286],[0.00040408765, 0.0079136286, 0.010982352]]), + np.array([[0.00013709733, 0.00051228428, 0.00012777430],[0.00051228428, 0.28237113, 0.10528370],[0.00012777430, 0.10528370, 0.23468947]])] + +gmm_nonskin = GMM(3, 4, gmm_nonskin_w, gmm_nonskin_mu, [], gmm_nonskin_cov_det, gmm_nonskin_cov_inv) + +prior_skin = 0.8 +prior_nonskin = 1 - prior_skin + + +# calculate skin attention mask +def skinmask(imbgr): + im = _bgr2ycbcr(imbgr) + + data = im.reshape((-1,3)) + + lh_skin = gmm_skin.likelihood(data) + lh_nonskin = gmm_nonskin.likelihood(data) + + tmp1 = prior_skin * lh_skin + tmp2 = prior_nonskin * lh_nonskin + post_skin = tmp1 / (tmp1+tmp2) # posterior probability + + post_skin = post_skin.reshape((im.shape[0],im.shape[1])) + + post_skin = np.round(post_skin*255) + post_skin = post_skin.astype(np.uint8) + post_skin = np.tile(np.expand_dims(post_skin,2),[1,1,3]) # reshape to H*W*3 + + return post_skin + + +def get_skin_mask(img_path): + print('generating skin masks......') + names = [i for i in sorted(os.listdir( + img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i] + save_path = os.path.join(img_path, 'mask') + if not os.path.isdir(save_path): + os.makedirs(save_path) + + for i in range(0, len(names)): + name = names[i] + print('%05d' % (i), ' ', name) + full_image_name = os.path.join(img_path, name) + img = cv2.imread(full_image_name).astype(np.float32) + skin_img = skinmask(img) + cv2.imwrite(os.path.join(save_path, name), skin_img.astype(np.uint8)) diff --git a/videoretalking/third_part/face3d/util/test_mean_face.txt b/videoretalking/third_part/face3d/util/test_mean_face.txt new file mode 100644 index 0000000000000000000000000000000000000000..1637648acf5a61cbc71b317c845414bb16d0150c --- /dev/null +++ b/videoretalking/third_part/face3d/util/test_mean_face.txt @@ -0,0 +1,136 @@ +-5.228591537475585938e+01 +2.078247070312500000e-01 +-5.064269638061523438e+01 +-1.315765380859375000e+01 +-4.952939224243164062e+01 +-2.592591094970703125e+01 +-4.793047332763671875e+01 +-3.832135772705078125e+01 +-4.512159729003906250e+01 +-5.059623336791992188e+01 +-3.917720794677734375e+01 +-6.043736648559570312e+01 +-2.929953765869140625e+01 +-6.861183166503906250e+01 +-1.719801330566406250e+01 +-7.572736358642578125e+01 +-1.961936950683593750e+00 +-7.862001037597656250e+01 +1.467941284179687500e+01 +-7.607844543457031250e+01 +2.744073486328125000e+01 +-6.915261840820312500e+01 +3.855677795410156250e+01 +-5.950350570678710938e+01 +4.478240966796875000e+01 +-4.867547225952148438e+01 +4.714337158203125000e+01 +-3.800830078125000000e+01 +4.940315246582031250e+01 +-2.496297454833984375e+01 +5.117234802246093750e+01 +-1.241538238525390625e+01 +5.190507507324218750e+01 +8.244247436523437500e-01 +-4.150688934326171875e+01 +2.386329650878906250e+01 +-3.570307159423828125e+01 +3.017010498046875000e+01 +-2.790358734130859375e+01 +3.212951660156250000e+01 +-1.941773223876953125e+01 +3.156523132324218750e+01 +-1.138106536865234375e+01 +2.841992187500000000e+01 +5.993263244628906250e+00 +2.895182800292968750e+01 +1.343590545654296875e+01 +3.189880371093750000e+01 +2.203153991699218750e+01 +3.302221679687500000e+01 +2.992478942871093750e+01 +3.099150085449218750e+01 +3.628388977050781250e+01 +2.765748596191406250e+01 +-1.933914184570312500e+00 +1.405374145507812500e+01 +-2.153038024902343750e+00 +5.772636413574218750e+00 +-2.270050048828125000e+00 +-2.121643066406250000e+00 +-2.218330383300781250e+00 +-1.068978118896484375e+01 +-1.187252044677734375e+01 +-1.997912597656250000e+01 +-6.879402160644531250e+00 +-2.143579864501953125e+01 +-1.227821350097656250e+00 +-2.193494415283203125e+01 +4.623237609863281250e+00 +-2.152721405029296875e+01 +9.721397399902343750e+00 +-1.953671264648437500e+01 +-3.648714447021484375e+01 +9.811126708984375000e+00 +-3.130242919921875000e+01 +1.422447967529296875e+01 +-2.212834930419921875e+01 +1.493019866943359375e+01 +-1.500880432128906250e+01 +1.073588562011718750e+01 +-2.095037078857421875e+01 +9.054298400878906250e+00 +-3.050099182128906250e+01 +8.704177856445312500e+00 +1.173237609863281250e+01 +1.054329681396484375e+01 +1.856353759765625000e+01 +1.535009765625000000e+01 +2.893331909179687500e+01 +1.451992797851562500e+01 +3.452944946289062500e+01 +1.065280151367187500e+01 +2.875990295410156250e+01 +8.654792785644531250e+00 +1.942100524902343750e+01 +9.422447204589843750e+00 +-2.204488372802734375e+01 +-3.983994293212890625e+01 +-1.324458312988281250e+01 +-3.467377471923828125e+01 +-6.749649047851562500e+00 +-3.092894744873046875e+01 +-9.183349609375000000e-01 +-3.196458435058593750e+01 +4.220649719238281250e+00 +-3.090406036376953125e+01 +1.089889526367187500e+01 +-3.497008514404296875e+01 +1.874589538574218750e+01 +-4.065438079833984375e+01 +1.124106597900390625e+01 +-4.438417816162109375e+01 +5.181709289550781250e+00 +-4.649170684814453125e+01 +-1.158607482910156250e+00 +-4.680406951904296875e+01 +-7.918922424316406250e+00 +-4.671575164794921875e+01 +-1.452505493164062500e+01 +-4.416526031494140625e+01 +-2.005007171630859375e+01 +-3.997841644287109375e+01 +-1.054919433593750000e+01 +-3.849683380126953125e+01 +-1.051826477050781250e+00 +-3.794863128662109375e+01 +6.412681579589843750e+00 +-3.804645538330078125e+01 +1.627674865722656250e+01 +-4.039697265625000000e+01 +6.373878479003906250e+00 +-4.087213897705078125e+01 +-8.551712036132812500e-01 +-4.157129669189453125e+01 +-1.014953613281250000e+01 +-4.128469085693359375e+01 diff --git a/videoretalking/third_part/face3d/util/util.py b/videoretalking/third_part/face3d/util/util.py new file mode 100644 index 0000000000000000000000000000000000000000..05f510f12e047666066a620e10376d4e2f910804 --- /dev/null +++ b/videoretalking/third_part/face3d/util/util.py @@ -0,0 +1,208 @@ +"""This script contains basic utilities for Deep3DFaceRecon_pytorch +""" +from __future__ import print_function +import numpy as np +import torch +from PIL import Image +import os +import importlib +import argparse +from argparse import Namespace +import torchvision + + +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +def copyconf(default_opt, **kwargs): + conf = Namespace(**vars(default_opt)) + for key in kwargs: + setattr(conf, key, kwargs[key]) + return conf + +def genvalconf(train_opt, **kwargs): + conf = Namespace(**vars(train_opt)) + attr_dict = train_opt.__dict__ + for key, value in attr_dict.items(): + if 'val' in key and key.split('_')[0] in attr_dict: + setattr(conf, key.split('_')[0], value) + + for key in kwargs: + setattr(conf, key, kwargs[key]) + + return conf + +def find_class_in_module(target_cls_name, module): + target_cls_name = target_cls_name.replace('_', '').lower() + clslib = importlib.import_module(module) + cls = None + for name, clsobj in clslib.__dict__.items(): + if name.lower() == target_cls_name: + cls = clsobj + + assert cls is not None, "In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name) + + return cls + + +def tensor2im(input_image, imtype=np.uint8): + """"Converts a Tensor array into a numpy image array. + + Parameters: + input_image (tensor) -- the input image tensor array, range(0, 1) + imtype (type) -- the desired type of the converted numpy array + """ + if not isinstance(input_image, np.ndarray): + if isinstance(input_image, torch.Tensor): # get the data from a variable + image_tensor = input_image.data + else: + return input_image + image_numpy = image_tensor.clamp(0.0, 1.0).cpu().float().numpy() # convert it into a numpy array + if image_numpy.shape[0] == 1: # grayscale to RGB + image_numpy = np.tile(image_numpy, (3, 1, 1)) + image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 # post-processing: transpose and scaling + else: # if it is a numpy array, do nothing + image_numpy = input_image + return image_numpy.astype(imtype) + + +def diagnose_network(net, name='network'): + """Calculate and print the mean of average absolute(gradients) + + Parameters: + net (torch network) -- Torch network + name (str) -- the name of the network + """ + mean = 0.0 + count = 0 + for param in net.parameters(): + if param.grad is not None: + mean += torch.mean(torch.abs(param.grad.data)) + count += 1 + if count > 0: + mean = mean / count + print(name) + print(mean) + + +def save_image(image_numpy, image_path, aspect_ratio=1.0): + """Save a numpy image to the disk + + Parameters: + image_numpy (numpy array) -- input numpy array + image_path (str) -- the path of the image + """ + + image_pil = Image.fromarray(image_numpy) + h, w, _ = image_numpy.shape + + if aspect_ratio is None: + pass + elif aspect_ratio > 1.0: + image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC) + elif aspect_ratio < 1.0: + image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC) + image_pil.save(image_path) + + +def print_numpy(x, val=True, shp=False): + """Print the mean, min, max, median, std, and size of a numpy array + + Parameters: + val (bool) -- if print the values of the numpy array + shp (bool) -- if print the shape of the numpy array + """ + x = x.astype(np.float64) + if shp: + print('shape,', x.shape) + if val: + x = x.flatten() + print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( + np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) + + +def mkdirs(paths): + """create empty directories if they don't exist + + Parameters: + paths (str list) -- a list of directory paths + """ + if isinstance(paths, list) and not isinstance(paths, str): + for path in paths: + mkdir(path) + else: + mkdir(paths) + + +def mkdir(path): + """create a single empty directory if it didn't exist + + Parameters: + path (str) -- a single directory path + """ + if not os.path.exists(path): + os.makedirs(path) + + +def correct_resize_label(t, size): + device = t.device + t = t.detach().cpu() + resized = [] + for i in range(t.size(0)): + one_t = t[i, :1] + one_np = np.transpose(one_t.numpy().astype(np.uint8), (1, 2, 0)) + one_np = one_np[:, :, 0] + one_image = Image.fromarray(one_np).resize(size, Image.NEAREST) + resized_t = torch.from_numpy(np.array(one_image)).long() + resized.append(resized_t) + return torch.stack(resized, dim=0).to(device) + + +def correct_resize(t, size, mode=Image.BICUBIC): + device = t.device + t = t.detach().cpu() + resized = [] + for i in range(t.size(0)): + one_t = t[i:i + 1] + one_image = Image.fromarray(tensor2im(one_t)).resize(size, Image.BICUBIC) + resized_t = torchvision.transforms.functional.to_tensor(one_image) * 2 - 1.0 + resized.append(resized_t) + return torch.stack(resized, dim=0).to(device) + +def draw_landmarks(img, landmark, color='r', step=2): + """ + Return: + img -- numpy.array, (B, H, W, 3) img with landmark, RGB order, range (0, 255) + + + Parameters: + img -- numpy.array, (B, H, W, 3), RGB order, range (0, 255) + landmark -- numpy.array, (B, 68, 2), y direction is opposite to v direction + color -- str, 'r' or 'b' (red or blue) + """ + if color =='r': + c = np.array([255., 0, 0]) + else: + c = np.array([0, 0, 255.]) + + _, H, W, _ = img.shape + img, landmark = img.copy(), landmark.copy() + landmark[..., 1] = H - 1 - landmark[..., 1] + landmark = np.round(landmark).astype(np.int32) + for i in range(landmark.shape[1]): + x, y = landmark[:, i, 0], landmark[:, i, 1] + for j in range(-step, step): + for k in range(-step, step): + u = np.clip(x + j, 0, W - 1) + v = np.clip(y + k, 0, H - 1) + for m in range(landmark.shape[0]): + img[m, v[m], u[m]] = c + return img diff --git a/videoretalking/third_part/face3d/util/visualizer.py b/videoretalking/third_part/face3d/util/visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..5c7f2a93781e62381e057e14063d55240a1de227 --- /dev/null +++ b/videoretalking/third_part/face3d/util/visualizer.py @@ -0,0 +1,227 @@ +"""This script defines the visualizer for Deep3DFaceRecon_pytorch +""" + +import numpy as np +import os +import sys +import ntpath +import time +from . import util, html +from subprocess import Popen, PIPE +from torch.utils.tensorboard import SummaryWriter + +def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256): + """Save images to the disk. + + Parameters: + webpage (the HTML class) -- the HTML webpage class that stores these imaegs (see html.py for more details) + visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs + image_path (str) -- the string is used to create image paths + aspect_ratio (float) -- the aspect ratio of saved images + width (int) -- the images will be resized to width x width + + This function will save images stored in 'visuals' to the HTML file specified by 'webpage'. + """ + image_dir = webpage.get_image_dir() + short_path = ntpath.basename(image_path[0]) + name = os.path.splitext(short_path)[0] + + webpage.add_header(name) + ims, txts, links = [], [], [] + + for label, im_data in visuals.items(): + im = util.tensor2im(im_data) + image_name = '%s/%s.png' % (label, name) + os.makedirs(os.path.join(image_dir, label), exist_ok=True) + save_path = os.path.join(image_dir, image_name) + util.save_image(im, save_path, aspect_ratio=aspect_ratio) + ims.append(image_name) + txts.append(label) + links.append(image_name) + webpage.add_images(ims, txts, links, width=width) + + +class Visualizer(): + """This class includes several functions that can display/save images and print/save logging information. + + It uses a Python library tensprboardX for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images. + """ + + def __init__(self, opt): + """Initialize the Visualizer class + + Parameters: + opt -- stores all the experiment flags; needs to be a subclass of BaseOptions + Step 1: Cache the training/test options + Step 2: create a tensorboard writer + Step 3: create an HTML object for saving HTML filters + Step 4: create a logging file to store training losses + """ + self.opt = opt # cache the option + self.use_html = opt.isTrain and not opt.no_html + self.writer = SummaryWriter(os.path.join(opt.checkpoints_dir, 'logs', opt.name)) + self.win_size = opt.display_winsize + self.name = opt.name + self.saved = False + if self.use_html: # create an HTML object at /web/; images will be saved under /web/images/ + self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') + self.img_dir = os.path.join(self.web_dir, 'images') + print('create web directory %s...' % self.web_dir) + util.mkdirs([self.web_dir, self.img_dir]) + # create a logging file to store training losses + self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') + with open(self.log_name, "a") as log_file: + now = time.strftime("%c") + log_file.write('================ Training Loss (%s) ================\n' % now) + + def reset(self): + """Reset the self.saved status""" + self.saved = False + + + def display_current_results(self, visuals, total_iters, epoch, save_result): + """Display current results on tensorboad; save current results to an HTML file. + + Parameters: + visuals (OrderedDict) - - dictionary of images to display or save + total_iters (int) -- total iterations + epoch (int) - - the current epoch + save_result (bool) - - if save the current results to an HTML file + """ + for label, image in visuals.items(): + self.writer.add_image(label, util.tensor2im(image), total_iters, dataformats='HWC') + + if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved. + self.saved = True + # save images to the disk + for label, image in visuals.items(): + image_numpy = util.tensor2im(image) + img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) + util.save_image(image_numpy, img_path) + + # update website + webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=0) + for n in range(epoch, 0, -1): + webpage.add_header('epoch [%d]' % n) + ims, txts, links = [], [], [] + + for label, image_numpy in visuals.items(): + image_numpy = util.tensor2im(image) + img_path = 'epoch%.3d_%s.png' % (n, label) + ims.append(img_path) + txts.append(label) + links.append(img_path) + webpage.add_images(ims, txts, links, width=self.win_size) + webpage.save() + + def plot_current_losses(self, total_iters, losses): + # G_loss_collection = {} + # D_loss_collection = {} + # for name, value in losses.items(): + # if 'G' in name or 'NCE' in name or 'idt' in name: + # G_loss_collection[name] = value + # else: + # D_loss_collection[name] = value + # self.writer.add_scalars('G_collec', G_loss_collection, total_iters) + # self.writer.add_scalars('D_collec', D_loss_collection, total_iters) + for name, value in losses.items(): + self.writer.add_scalar(name, value, total_iters) + + # losses: same format as |losses| of plot_current_losses + def print_current_losses(self, epoch, iters, losses, t_comp, t_data): + """print current losses on console; also save the losses to the disk + + Parameters: + epoch (int) -- current epoch + iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch) + losses (OrderedDict) -- training losses stored in the format of (name, float) pairs + t_comp (float) -- computational time per data point (normalized by batch_size) + t_data (float) -- data loading time per data point (normalized by batch_size) + """ + message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data) + for k, v in losses.items(): + message += '%s: %.3f ' % (k, v) + + print(message) # print the message + with open(self.log_name, "a") as log_file: + log_file.write('%s\n' % message) # save the message + + +class MyVisualizer: + def __init__(self, opt): + """Initialize the Visualizer class + + Parameters: + opt -- stores all the experiment flags; needs to be a subclass of BaseOptions + Step 1: Cache the training/test options + Step 2: create a tensorboard writer + Step 3: create an HTML object for saving HTML filters + Step 4: create a logging file to store training losses + """ + self.opt = opt # cache the option + self.name = opt.name + self.img_dir = os.path.join(opt.checkpoints_dir, opt.name, 'results') + + if opt.phase != 'test': + self.writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, 'logs')) + # create a logging file to store training losses + self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') + with open(self.log_name, "a") as log_file: + now = time.strftime("%c") + log_file.write('================ Training Loss (%s) ================\n' % now) + + + def display_current_results(self, visuals, total_iters, epoch, dataset='train', save_results=False, count=0, name=None, + add_image=True): + """Display current results on tensorboad; save current results to an HTML file. + + Parameters: + visuals (OrderedDict) - - dictionary of images to display or save + total_iters (int) -- total iterations + epoch (int) - - the current epoch + dataset (str) - - 'train' or 'val' or 'test' + """ + # if (not add_image) and (not save_results): return + + for label, image in visuals.items(): + for i in range(image.shape[0]): + image_numpy = util.tensor2im(image[i]) + if add_image: + self.writer.add_image(label + '%s_%02d'%(dataset, i + count), + image_numpy, total_iters, dataformats='HWC') + + if save_results: + save_path = os.path.join(self.img_dir, dataset, 'epoch_%s_%06d'%(epoch, total_iters)) + if not os.path.isdir(save_path): + os.makedirs(save_path) + + if name is not None: + img_path = os.path.join(save_path, '%s.png' % name) + else: + img_path = os.path.join(save_path, '%s_%03d.png' % (label, i + count)) + util.save_image(image_numpy, img_path) + + + def plot_current_losses(self, total_iters, losses, dataset='train'): + for name, value in losses.items(): + self.writer.add_scalar(name + '/%s'%dataset, value, total_iters) + + # losses: same format as |losses| of plot_current_losses + def print_current_losses(self, epoch, iters, losses, t_comp, t_data, dataset='train'): + """print current losses on console; also save the losses to the disk + + Parameters: + epoch (int) -- current epoch + iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch) + losses (OrderedDict) -- training losses stored in the format of (name, float) pairs + t_comp (float) -- computational time per data point (normalized by batch_size) + t_data (float) -- data loading time per data point (normalized by batch_size) + """ + message = '(dataset: %s, epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % ( + dataset, epoch, iters, t_comp, t_data) + for k, v in losses.items(): + message += '%s: %.3f ' % (k, v) + + print(message) # print the message + with open(self.log_name, "a") as log_file: + log_file.write('%s\n' % message) # save the message