jhj0517
commited on
Commit
·
6ecdb23
1
Parent(s):
dffd539
Add progress during model loading
Browse files
modules/live_portrait/live_portrait_inferencer.py
CHANGED
@@ -56,7 +56,8 @@ class LivePortraitInferencer:
|
|
56 |
self.psi_list = None
|
57 |
self.d_info = None
|
58 |
|
59 |
-
def load_models(self
|
|
|
60 |
def filter_stitcher(checkpoint, prefix):
|
61 |
filtered_checkpoint = {key.replace(prefix + "_module.", ""): value for key, value in checkpoint.items() if
|
62 |
key.startswith(prefix)}
|
@@ -64,6 +65,8 @@ class LivePortraitInferencer:
|
|
64 |
|
65 |
self.download_if_no_models()
|
66 |
|
|
|
|
|
67 |
appearance_feat_config = self.model_config["appearance_feature_extractor_params"]
|
68 |
self.appearance_feature_extractor = AppearanceFeatureExtractor(**appearance_feat_config).to(self.device)
|
69 |
self.appearance_feature_extractor = self.load_safe_tensor(
|
@@ -71,6 +74,7 @@ class LivePortraitInferencer:
|
|
71 |
os.path.join(self.model_dir, "appearance_feature_extractor.safetensors")
|
72 |
)
|
73 |
|
|
|
74 |
motion_ext_config = self.model_config["motion_extractor_params"]
|
75 |
self.motion_extractor = MotionExtractor(**motion_ext_config).to(self.device)
|
76 |
self.motion_extractor = self.load_safe_tensor(
|
@@ -78,6 +82,7 @@ class LivePortraitInferencer:
|
|
78 |
os.path.join(self.model_dir, "motion_extractor.safetensors")
|
79 |
)
|
80 |
|
|
|
81 |
warping_module_config = self.model_config["warping_module_params"]
|
82 |
self.warping_module = WarpingNetwork(**warping_module_config).to(self.device)
|
83 |
self.warping_module = self.load_safe_tensor(
|
@@ -85,6 +90,7 @@ class LivePortraitInferencer:
|
|
85 |
os.path.join(self.model_dir, "warping_module.safetensors")
|
86 |
)
|
87 |
|
|
|
88 |
spaded_decoder_config = self.model_config["spade_generator_params"]
|
89 |
self.spade_generator = SPADEDecoder(**spaded_decoder_config).to(self.device)
|
90 |
self.spade_generator = self.load_safe_tensor(
|
@@ -92,6 +98,7 @@ class LivePortraitInferencer:
|
|
92 |
os.path.join(self.model_dir, "spade_generator.safetensors")
|
93 |
)
|
94 |
|
|
|
95 |
stitcher_config = self.model_config["stitching_retargeting_module_params"]
|
96 |
self.stitching_retargeting_module = StitchingRetargetingNetwork(**stitcher_config.get('stitching'))
|
97 |
stitcher_model_path = os.path.join(self.model_dir, "stitching_retargeting_module.safetensors")
|
|
|
56 |
self.psi_list = None
|
57 |
self.d_info = None
|
58 |
|
59 |
+
def load_models(self,
|
60 |
+
progress=gr.Progress()):
|
61 |
def filter_stitcher(checkpoint, prefix):
|
62 |
filtered_checkpoint = {key.replace(prefix + "_module.", ""): value for key, value in checkpoint.items() if
|
63 |
key.startswith(prefix)}
|
|
|
65 |
|
66 |
self.download_if_no_models()
|
67 |
|
68 |
+
total_models_num = 5
|
69 |
+
progress(0/total_models_num, desc="Loading Appearance Feature Extractor model...")
|
70 |
appearance_feat_config = self.model_config["appearance_feature_extractor_params"]
|
71 |
self.appearance_feature_extractor = AppearanceFeatureExtractor(**appearance_feat_config).to(self.device)
|
72 |
self.appearance_feature_extractor = self.load_safe_tensor(
|
|
|
74 |
os.path.join(self.model_dir, "appearance_feature_extractor.safetensors")
|
75 |
)
|
76 |
|
77 |
+
progress(1/total_models_num, desc="Loading Motion Extractor model...")
|
78 |
motion_ext_config = self.model_config["motion_extractor_params"]
|
79 |
self.motion_extractor = MotionExtractor(**motion_ext_config).to(self.device)
|
80 |
self.motion_extractor = self.load_safe_tensor(
|
|
|
82 |
os.path.join(self.model_dir, "motion_extractor.safetensors")
|
83 |
)
|
84 |
|
85 |
+
progress(2/total_models_num, desc="Loading Warping Module model...")
|
86 |
warping_module_config = self.model_config["warping_module_params"]
|
87 |
self.warping_module = WarpingNetwork(**warping_module_config).to(self.device)
|
88 |
self.warping_module = self.load_safe_tensor(
|
|
|
90 |
os.path.join(self.model_dir, "warping_module.safetensors")
|
91 |
)
|
92 |
|
93 |
+
progress(3/total_models_num, desc="Loading Spade generator model...")
|
94 |
spaded_decoder_config = self.model_config["spade_generator_params"]
|
95 |
self.spade_generator = SPADEDecoder(**spaded_decoder_config).to(self.device)
|
96 |
self.spade_generator = self.load_safe_tensor(
|
|
|
98 |
os.path.join(self.model_dir, "spade_generator.safetensors")
|
99 |
)
|
100 |
|
101 |
+
progress(4/total_models_num, desc="Loading Stitcher model...")
|
102 |
stitcher_config = self.model_config["stitching_retargeting_module_params"]
|
103 |
self.stitching_retargeting_module = StitchingRetargetingNetwork(**stitcher_config.get('stitching'))
|
104 |
stitcher_model_path = os.path.join(self.model_dir, "stitching_retargeting_module.safetensors")
|