Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,606 Bytes
a891a57 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
# coding: utf-8
"""
Benchmark the inference speed of each module in LivePortrait.
TODO: this script is written in a heavy GPT-generated style and needs refactoring.
"""
import yaml
import torch
import time
import numpy as np
from src.utils.helper import load_model, concat_feat
from src.config.inference_config import InferenceConfig
def initialize_inputs(batch_size=1):
    """
    Build the random half-precision CUDA tensors fed to the benchmarked modules.

    Args:
        batch_size: Size of the leading dimension of every generated tensor.

    Returns:
        dict with the keys 'feature_3d', 'kp_source', 'kp_driving',
        'source_image', 'generator_input', 'feat_stitching', 'feat_eye' and
        'feat_lip'. The eye/lip close-ratio tensors are only used to build
        'feat_eye' and 'feat_lip' and are not part of the result.
    """
    def _random_half(*trailing_dims):
        # Sample with torch.randn, then move to the GPU and cast to float16.
        return torch.randn(batch_size, *trailing_dims).cuda().half()

    # The sampling order is kept fixed so a given RNG seed yields the same tensors.
    feature_3d = _random_half(32, 16, 64, 64)
    kp_source = _random_half(21, 3)
    kp_driving = _random_half(21, 3)
    source_image = _random_half(3, 256, 256)
    generator_input = _random_half(256, 64, 64)
    eye_close_ratio = _random_half(3)
    lip_close_ratio = _random_half(2)

    return {
        'feature_3d': feature_3d,
        'kp_source': kp_source,
        'kp_driving': kp_driving,
        'source_image': source_image,
        'generator_input': generator_input,
        # Retargeting inputs pair the source keypoints with a second tensor.
        'feat_stitching': concat_feat(kp_source, kp_driving).half(),
        'feat_eye': concat_feat(kp_source, eye_close_ratio).half(),
        'feat_lip': concat_feat(kp_source, lip_close_ratio).half(),
    }
def load_and_compile_models(cfg, model_config):
    """
    Load every LivePortrait module and prepare it for half-precision inference.

    Args:
        cfg: Config object providing checkpoint_F / checkpoint_M / checkpoint_W /
            checkpoint_G / checkpoint_S and device_id.
        model_config: Parsed model configuration forwarded to load_model.

    Returns:
        (compiled_models, stitching_retargeting_module), where compiled_models
        maps a display name to the compiled module and
        stitching_retargeting_module is the loaded container whose
        'stitching', 'eye' and 'lip' entries were replaced by compiled modules.
    """
    def _load(checkpoint, model_type):
        return load_model(checkpoint, model_config, cfg.device_id, model_type)

    def _optimize(module):
        # Cast to fp16, compile with max-autotune, then switch to evaluation mode.
        optimized = torch.compile(module.half(), mode='max-autotune')
        optimized.eval()
        return optimized

    # All checkpoints are loaded first; compilation happens afterwards.
    loaded = [
        ('Appearance Feature Extractor', _load(cfg.checkpoint_F, 'appearance_feature_extractor')),
        ('Motion Extractor', _load(cfg.checkpoint_M, 'motion_extractor')),
        ('Warping Network', _load(cfg.checkpoint_W, 'warping_module')),
        ('SPADE Decoder', _load(cfg.checkpoint_G, 'spade_generator')),
    ]
    stitching_retargeting_module = _load(cfg.checkpoint_S, 'stitching_retargeting_module')

    compiled_models = {label: _optimize(module) for label, module in loaded}

    # The retargeting container is updated in place, one sub-module at a time.
    for part in ('stitching', 'eye', 'lip'):
        stitching_retargeting_module[part] = _optimize(stitching_retargeting_module[part])

    return compiled_models, stitching_retargeting_module
def warm_up_models(compiled_models, stitching_retargeting_module, inputs):
    """
    Run every module 10 times without gradients before timing starts.

    Args:
        compiled_models: Mapping of display name to callable module.
        stitching_retargeting_module: Mapping with 'stitching', 'eye' and 'lip' callables.
        inputs: Mapping of input name to the tensor passed to each module.
    """
    # (container, key of the callable, names of the inputs passed positionally)
    schedule = (
        (compiled_models, 'Appearance Feature Extractor', ('source_image',)),
        (compiled_models, 'Motion Extractor', ('source_image',)),
        (compiled_models, 'Warping Network', ('feature_3d', 'kp_driving', 'kp_source')),
        (compiled_models, 'SPADE Decoder', ('generator_input',)),
        (stitching_retargeting_module, 'stitching', ('feat_stitching',)),
        (stitching_retargeting_module, 'eye', ('feat_eye',)),
        (stitching_retargeting_module, 'lip', ('feat_lip',)),
    )

    print("Warm up start!")
    with torch.no_grad():
        for _ in range(10):
            for registry, key, arg_names in schedule:
                module = registry[key]
                module(*[inputs[arg_name] for arg_name in arg_names])
    print("Warm up end!")
def _synchronize():
    """Wait for queued CUDA work to finish; do nothing when CUDA is unavailable."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def _timed_call(fn, *args):
    """Run fn(*args) and return the elapsed seconds, including the device sync."""
    start = time.perf_counter()
    fn(*args)
    _synchronize()
    return time.perf_counter() - start


def measure_inference_times(compiled_models, stitching_retargeting_module, inputs, num_runs=100):
    """
    Measure per-module and overall inference times.

    Args:
        compiled_models: Mapping of display name to callable module. Must contain
            'Appearance Feature Extractor', 'Motion Extractor', 'Warping Network'
            and 'SPADE Decoder'.
        stitching_retargeting_module: Mapping with 'stitching', 'eye' and 'lip' callables.
        inputs: Mapping of input name to the tensor passed to each module.
        num_runs: Number of timed iterations (defaults to 100).

    Returns:
        (times, overall_times). times maps every key of compiled_models plus
        'Retargeting Models' to a list of per-run durations in seconds;
        overall_times holds the duration of each complete iteration in seconds.
    """
    times = {name: [] for name in compiled_models}
    times['Retargeting Models'] = []
    overall_times = []

    def _run_retargeting():
        # The three retargeting heads are timed together as a single stage.
        stitching_retargeting_module['stitching'](inputs['feat_stitching'])
        stitching_retargeting_module['eye'](inputs['feat_eye'])
        stitching_retargeting_module['lip'](inputs['feat_lip'])

    with torch.no_grad():
        for _ in range(num_runs):
            # Drain pending GPU work so it is not charged to the first stage.
            _synchronize()
            # perf_counter is monotonic and high-resolution, unlike time.time().
            overall_start = time.perf_counter()

            times['Appearance Feature Extractor'].append(
                _timed_call(compiled_models['Appearance Feature Extractor'], inputs['source_image'])
            )
            times['Motion Extractor'].append(
                _timed_call(compiled_models['Motion Extractor'], inputs['source_image'])
            )
            times['Warping Network'].append(
                _timed_call(
                    compiled_models['Warping Network'],
                    inputs['feature_3d'], inputs['kp_driving'], inputs['kp_source'],
                )
            )
            times['SPADE Decoder'].append(
                _timed_call(compiled_models['SPADE Decoder'], inputs['generator_input'])
            )
            times['Retargeting Models'].append(_timed_call(_run_retargeting))

            overall_times.append(time.perf_counter() - overall_start)

    return times, overall_times
def print_benchmark_results(compiled_models, stitching_retargeting_module, retargeting_models, times, overall_times):
    """
    Print parameter counts and the mean / standard deviation of the measured times.

    Args:
        compiled_models: Mapping of display name to module (must expose .parameters()).
        stitching_retargeting_module: Mapping of retargeting part name to module.
        retargeting_models: Part names to look up in stitching_retargeting_module.
        times: Mapping of stage name to a list of per-run durations in seconds.
        overall_times: Per-run durations of a complete iteration, in seconds.
    """
    def _params_in_millions(module):
        return sum(p.numel() for p in module.parameters()) / 1e6

    for name, model in compiled_models.items():
        print(f"Number of parameters for {name}: {_params_in_millions(model):.2f} M")

    for index, retarget in enumerate(retargeting_models):
        part_params = _params_in_millions(stitching_retargeting_module[retarget])
        print(f"Number of parameters for part_{index} in Stitching and Retargeting Modules: {part_params:.2f} M")

    for name, samples in times.items():
        # Seconds -> milliseconds; the run count comes from the samples themselves
        # so the message stays correct when the number of runs is not 100.
        avg_time = np.mean(samples) * 1000
        std_time = np.std(samples) * 1000
        print(f"Average inference time for {name} over {len(samples)} runs: {avg_time:.2f} ms (std: {std_time:.2f} ms)")

    # len() rather than truthiness so a NumPy array is accepted as well as a list.
    if len(overall_times) > 0:
        overall_avg = np.mean(overall_times) * 1000
        overall_std = np.std(overall_times) * 1000
        print(f"Average overall inference time over {len(overall_times)} runs: {overall_avg:.2f} ms (std: {overall_std:.2f} ms)")
def main():
    """
    Run the whole benchmark: build inputs, load the configuration and models,
    warm up, time each module, and print the results.
    """
    # Random input tensors (default batch size).
    benchmark_inputs = initialize_inputs()

    # Inference configuration and the model config file it points to.
    inference_cfg = InferenceConfig(device_id=0)
    with open(inference_cfg.models_config, 'r') as config_file:
        model_config = yaml.safe_load(config_file)

    models, retargeting = load_and_compile_models(inference_cfg, model_config)

    warm_up_models(models, retargeting, benchmark_inputs)

    stage_times, total_times = measure_inference_times(models, retargeting, benchmark_inputs)

    print_benchmark_results(
        models,
        retargeting,
        ['stitching', 'eye', 'lip'],
        stage_times,
        total_times,
    )
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|