jhj0517
commited on
Commit
·
3b34b71
1
Parent(s):
3f75a4f
Refactor model loading
Browse files
modules/live_portrait/live_portrait_inferencer.py
CHANGED
@@ -58,11 +58,6 @@ class LivePortraitInferencer:
|
|
58 |
|
59 |
def load_models(self,
|
60 |
progress=gr.Progress()):
|
61 |
-
def filter_stitcher(checkpoint, prefix):
|
62 |
-
filtered_checkpoint = {key.replace(prefix + "_module.", ""): value for key, value in checkpoint.items() if
|
63 |
-
key.startswith(prefix)}
|
64 |
-
return filtered_checkpoint
|
65 |
-
|
66 |
self.download_if_no_models()
|
67 |
|
68 |
total_models_num = 5
|
@@ -100,11 +95,12 @@ class LivePortraitInferencer:
|
|
100 |
|
101 |
progress(4/total_models_num, desc="Loading Stitcher model...")
|
102 |
stitcher_config = self.model_config["stitching_retargeting_module_params"]
|
103 |
-
self.stitching_retargeting_module = StitchingRetargetingNetwork(**stitcher_config.get('stitching'))
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
|
|
108 |
self.stitching_retargeting_module = {"stitching": self.stitching_retargeting_module}
|
109 |
|
110 |
if self.pipeline is None:
|
@@ -350,8 +346,16 @@ class LivePortraitInferencer:
|
|
350 |
download_model(model_path, model_url)
|
351 |
|
352 |
@staticmethod
|
353 |
-
def load_safe_tensor(model, file_path):
|
354 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
355 |
model.eval()
|
356 |
return model
|
357 |
|
|
|
58 |
|
59 |
def load_models(self,
|
60 |
progress=gr.Progress()):
|
|
|
|
|
|
|
|
|
|
|
61 |
self.download_if_no_models()
|
62 |
|
63 |
total_models_num = 5
|
|
|
95 |
|
96 |
progress(4/total_models_num, desc="Loading Stitcher model...")
|
97 |
stitcher_config = self.model_config["stitching_retargeting_module_params"]
|
98 |
+
self.stitching_retargeting_module = StitchingRetargetingNetwork(**stitcher_config.get('stitching')).to(self.device)
|
99 |
+
self.stitching_retargeting_module = self.load_safe_tensor(
|
100 |
+
self.stitching_retargeting_module,
|
101 |
+
os.path.join(self.model_dir, "stitching_retargeting_module.safetensors"),
|
102 |
+
True
|
103 |
+
)
|
104 |
self.stitching_retargeting_module = {"stitching": self.stitching_retargeting_module}
|
105 |
|
106 |
if self.pipeline is None:
|
|
|
346 |
download_model(model_path, model_url)
|
347 |
|
348 |
@staticmethod
|
349 |
+
def load_safe_tensor(model, file_path, is_stitcher=False):
|
350 |
+
def filter_stitcher(checkpoint, prefix):
|
351 |
+
filtered_checkpoint = {key.replace(prefix + "_module.", ""): value for key, value in checkpoint.items() if
|
352 |
+
key.startswith(prefix)}
|
353 |
+
return filtered_checkpoint
|
354 |
+
|
355 |
+
if is_stitcher:
|
356 |
+
model.load_state_dict(filter_stitcher(safetensors.torch.load_file(file_path), 'retarget_shoulder'))
|
357 |
+
else:
|
358 |
+
model.load_state_dict(safetensors.torch.load_file(file_path))
|
359 |
model.eval()
|
360 |
return model
|
361 |
|