Spaces:
Running
on
Zero
Running
on
Zero
# coding: utf-8 | |
""" | |
Benchmark the inference speed of each module in LivePortrait. | |
TODO: heavy GPT style, need to refactor | |
""" | |
import yaml | |
import torch | |
import time | |
import numpy as np | |
from src.utils.helper import load_model, concat_feat | |
from src.config.inference_config import InferenceConfig | |
def initialize_inputs(batch_size=1):
    """
    Create random fp16 CUDA tensors that stand in for real module inputs.

    Args:
        batch_size: leading dimension of every generated tensor (default 1).

    Returns:
        dict with keys 'feature_3d', 'kp_source', 'kp_driving', 'source_image',
        'generator_input', 'feat_stitching', 'feat_eye' and 'feat_lip'.
        The first five are raw random tensors; the 'feat_*' entries are built
        from kp_source and a second tensor via concat_feat, then cast to fp16.
    """
    def rand_half_cuda(*shape):
        # Sampled on the CPU, moved to the default CUDA device, then cast to fp16.
        return torch.randn(batch_size, *shape).cuda().half()

    # Tensors are drawn in a fixed order, so a seeded RNG reproduces the same inputs.
    feature_3d = rand_half_cuda(32, 16, 64, 64)
    kp_source = rand_half_cuda(21, 3)
    kp_driving = rand_half_cuda(21, 3)
    source_image = rand_half_cuda(3, 256, 256)
    generator_input = rand_half_cuda(256, 64, 64)
    eye_close_ratio = rand_half_cuda(3)
    lip_close_ratio = rand_half_cuda(2)

    inputs = {
        'feature_3d': feature_3d,
        'kp_source': kp_source,
        'kp_driving': kp_driving,
        'source_image': source_image,
        'generator_input': generator_input,
    }

    # concat_feat (src.utils.helper) presumably concatenates kp_source with the
    # second tensor into the retargeting-module input layout — confirm there.
    paired_with_source = (
        ('feat_stitching', kp_driving),
        ('feat_eye', eye_close_ratio),
        ('feat_lip', lip_close_ratio),
    )
    for key, other in paired_with_source:
        inputs[key] = concat_feat(kp_source, other).half()

    return inputs
def load_and_compile_models(cfg, model_config):
    """
    Load every LivePortrait module from its checkpoint and prepare it for benchmarking.

    Each network is cast to fp16, wrapped with torch.compile(mode='max-autotune')
    and then switched to eval mode.

    Args:
        cfg: inference config supplying checkpoint_F/M/W/G/S paths and device_id.
        model_config: parsed model configuration, passed through to load_model.

    Returns:
        (compiled_models, stitching_retargeting_module):
        compiled_models maps a display name to the compiled module for the four
        core networks; stitching_retargeting_module is the container returned by
        load_model, with its 'stitching', 'eye' and 'lip' entries replaced in
        place by their compiled versions.
    """
    def _half_compile_eval(module):
        # Same preparation for every network: fp16 cast -> compile -> eval mode.
        optimized = torch.compile(module.half(), mode='max-autotune')
        optimized.eval()
        return optimized

    device = cfg.device_id
    appearance = load_model(cfg.checkpoint_F, model_config, device, 'appearance_feature_extractor')
    motion = load_model(cfg.checkpoint_M, model_config, device, 'motion_extractor')
    warping = load_model(cfg.checkpoint_W, model_config, device, 'warping_module')
    decoder = load_model(cfg.checkpoint_G, model_config, device, 'spade_generator')
    stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, device, 'stitching_retargeting_module')

    compiled_models = {
        display_name: _half_compile_eval(network)
        for display_name, network in (
            ('Appearance Feature Extractor', appearance),
            ('Motion Extractor', motion),
            ('Warping Network', warping),
            ('SPADE Decoder', decoder),
        )
    }

    # The retargeting sub-modules are swapped in place inside the loaded container.
    for part in ('stitching', 'eye', 'lip'):
        stitching_retargeting_module[part] = _half_compile_eval(stitching_retargeting_module[part])

    return compiled_models, stitching_retargeting_module
def warm_up_models(compiled_models, stitching_retargeting_module, inputs):
    """
    Run every module 10 times without gradients before timing starts.

    torch.compile compiles on first call, so these iterations also absorb
    compilation/autotuning cost that would otherwise pollute the measurements.

    Args:
        compiled_models: dict with 'Appearance Feature Extractor', 'Motion Extractor',
            'Warping Network' and 'SPADE Decoder' callables.
        stitching_retargeting_module: mapping with 'stitching', 'eye' and 'lip' callables.
        inputs: dict of input tensors (see initialize_inputs).
    """
    print("Warm up start!")
    # (container, key, names of inputs passed positionally), in execution order.
    call_plan = (
        (compiled_models, 'Appearance Feature Extractor', ('source_image',)),
        (compiled_models, 'Motion Extractor', ('source_image',)),
        (compiled_models, 'Warping Network', ('feature_3d', 'kp_driving', 'kp_source')),
        (compiled_models, 'SPADE Decoder', ('generator_input',)),  # Adjust input as required
        (stitching_retargeting_module, 'stitching', ('feat_stitching',)),
        (stitching_retargeting_module, 'eye', ('feat_eye',)),
        (stitching_retargeting_module, 'lip', ('feat_lip',)),
    )
    with torch.no_grad():
        for _iteration in range(10):
            for container, key, arg_names in call_plan:
                container[key](*(inputs[arg] for arg in arg_names))
    print("Warm up end!")
def measure_inference_times(compiled_models, stitching_retargeting_module, inputs, num_runs=100):
    """
    Measure per-module and end-to-end inference latency.

    Each run executes every module once on the same inputs. Durations are read
    from time.perf_counter() (monotonic, high resolution) and each measured
    section ends with a device synchronization, so queued CUDA kernels are
    included in the time rather than just their launch overhead.

    Args:
        compiled_models: dict with 'Appearance Feature Extractor', 'Motion Extractor',
            'Warping Network' and 'SPADE Decoder' callables.
        stitching_retargeting_module: mapping with 'stitching', 'eye' and 'lip' callables.
        inputs: dict of input tensors (see initialize_inputs).
        num_runs: number of timed iterations (default 100).

    Returns:
        (times, overall_times): times maps every key of compiled_models plus
        'Retargeting Models' to a list of per-run durations in seconds;
        overall_times lists the per-run end-to-end durations in seconds.
    """
    # CUDA launches are asynchronous, so wait for the device before reading the
    # clock. Without CUDA there is no device queue to drain.
    if torch.cuda.is_available():
        synchronize = torch.cuda.synchronize
    else:
        def synchronize():
            pass
    clock = time.perf_counter

    times = {name: [] for name in compiled_models}
    times['Retargeting Models'] = []
    overall_times = []

    def run_timed(key, *calls):
        # Time one or more (callable, args) pairs as a single section under `key`.
        start = clock()
        for fn, args in calls:
            fn(*args)
        synchronize()
        times[key].append(clock() - start)

    with torch.no_grad():
        for _ in range(num_runs):
            synchronize()
            overall_start = clock()

            run_timed('Appearance Feature Extractor',
                      (compiled_models['Appearance Feature Extractor'], (inputs['source_image'],)))
            run_timed('Motion Extractor',
                      (compiled_models['Motion Extractor'], (inputs['source_image'],)))
            run_timed('Warping Network',
                      (compiled_models['Warping Network'],
                       (inputs['feature_3d'], inputs['kp_driving'], inputs['kp_source'])))
            run_timed('SPADE Decoder',  # Adjust input as required
                      (compiled_models['SPADE Decoder'], (inputs['generator_input'],)))
            # The three small retargeting MLPs are timed together as one section.
            run_timed('Retargeting Models',
                      (stitching_retargeting_module['stitching'], (inputs['feat_stitching'],)),
                      (stitching_retargeting_module['eye'], (inputs['feat_eye'],)),
                      (stitching_retargeting_module['lip'], (inputs['feat_lip'],)))

            overall_times.append(clock() - overall_start)
    return times, overall_times
def print_benchmark_results(compiled_models, stitching_retargeting_module, retargeting_models, times, overall_times):
    """
    Print parameter counts and latency statistics for every benchmarked module.

    Args:
        compiled_models: dict mapping display name to a module exposing parameters().
        stitching_retargeting_module: mapping from retargeting part name to a module.
        retargeting_models: part names to report, in order; printed as part_<index>.
        times: dict mapping a section name to per-run durations in seconds.
        overall_times: per-run end-to-end durations in seconds.

    Latencies are reported in milliseconds as mean and population standard
    deviation (np.std, ddof=0). The run count shown is the actual number of
    samples, and sections with no samples are reported as such instead of
    printing NaN.
    """
    def _latency_stats_ms(samples):
        # Seconds -> milliseconds.
        return np.mean(samples) * 1000, np.std(samples) * 1000

    for name, model in compiled_models.items():
        num_params = sum(p.numel() for p in model.parameters())
        print(f"Number of parameters for {name}: {num_params / 1e6:.2f} M")

    for index, retarget in enumerate(retargeting_models):
        num_params = sum(p.numel() for p in stitching_retargeting_module[retarget].parameters())
        print(f"Number of parameters for part_{index} in Stitching and Retargeting Modules: {num_params / 1e6:.2f} M")

    for name, samples in times.items():
        if not samples:
            print(f"No timing samples recorded for {name}")
            continue
        avg_time, std_time = _latency_stats_ms(samples)
        print(f"Average inference time for {name} over {len(samples)} runs: {avg_time:.2f} ms (std: {std_time:.2f} ms)")

    # End-to-end latency per run (all modules back to back).
    if overall_times:
        avg_time, std_time = _latency_stats_ms(overall_times)
        print(f"Average overall inference time over {len(overall_times)} runs: {avg_time:.2f} ms (std: {std_time:.2f} ms)")
def main():
    """
    Benchmark LivePortrait module speed and report parameter counts.

    Builds random fp16 inputs, reads the model configuration file referenced
    by InferenceConfig(device_id=0), loads and compiles the modules, warms
    them up, times them and prints the results.
    """
    # Sample input tensors shared by warm-up and timing
    inputs = initialize_inputs()

    # Load configuration; an explicit encoding keeps YAML parsing independent of
    # the platform's locale default.
    cfg = InferenceConfig(device_id=0)
    with open(cfg.models_config, 'r', encoding='utf-8') as file:
        model_config = yaml.safe_load(file)

    # Load and compile models
    compiled_models, stitching_retargeting_module = load_and_compile_models(cfg, model_config)

    # Warm up models (also triggers torch.compile's lazy compilation)
    warm_up_models(compiled_models, stitching_retargeting_module, inputs)

    # Measure inference times
    times, overall_times = measure_inference_times(compiled_models, stitching_retargeting_module, inputs)

    # Print benchmark results
    retargeting_models = ['stitching', 'eye', 'lip']
    print_benchmark_results(compiled_models, stitching_retargeting_module, retargeting_models, times, overall_times)
# Run the benchmark only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()