jhj0517 commited on
Commit
3b34b71
·
1 Parent(s): 3f75a4f

Refactor model loading

Browse files
modules/live_portrait/live_portrait_inferencer.py CHANGED
@@ -58,11 +58,6 @@ class LivePortraitInferencer:
58
 
59
  def load_models(self,
60
  progress=gr.Progress()):
61
- def filter_stitcher(checkpoint, prefix):
62
- filtered_checkpoint = {key.replace(prefix + "_module.", ""): value for key, value in checkpoint.items() if
63
- key.startswith(prefix)}
64
- return filtered_checkpoint
65
-
66
  self.download_if_no_models()
67
 
68
  total_models_num = 5
@@ -100,11 +95,12 @@ class LivePortraitInferencer:
100
 
101
  progress(4/total_models_num, desc="Loading Stitcher model...")
102
  stitcher_config = self.model_config["stitching_retargeting_module_params"]
103
- self.stitching_retargeting_module = StitchingRetargetingNetwork(**stitcher_config.get('stitching'))
104
- stitcher_model_path = os.path.join(self.model_dir, "stitching_retargeting_module.safetensors")
105
- ckpt = safetensors.torch.load_file(stitcher_model_path)
106
- self.stitching_retargeting_module.load_state_dict(filter_stitcher(ckpt, 'retarget_shoulder'))
107
- self.stitching_retargeting_module.to(self.device).eval()
 
108
  self.stitching_retargeting_module = {"stitching": self.stitching_retargeting_module}
109
 
110
  if self.pipeline is None:
@@ -350,8 +346,16 @@ class LivePortraitInferencer:
350
  download_model(model_path, model_url)
351
 
352
  @staticmethod
353
- def load_safe_tensor(model, file_path):
354
- model.load_state_dict(safetensors.torch.load_file(file_path))
 
 
 
 
 
 
 
 
355
  model.eval()
356
  return model
357
 
 
58
 
59
  def load_models(self,
60
  progress=gr.Progress()):
 
 
 
 
 
61
  self.download_if_no_models()
62
 
63
  total_models_num = 5
 
95
 
96
  progress(4/total_models_num, desc="Loading Stitcher model...")
97
  stitcher_config = self.model_config["stitching_retargeting_module_params"]
98
+ self.stitching_retargeting_module = StitchingRetargetingNetwork(**stitcher_config.get('stitching')).to(self.device)
99
+ self.stitching_retargeting_module = self.load_safe_tensor(
100
+ self.stitching_retargeting_module,
101
+ os.path.join(self.model_dir, "stitching_retargeting_module.safetensors"),
102
+ True
103
+ )
104
  self.stitching_retargeting_module = {"stitching": self.stitching_retargeting_module}
105
 
106
  if self.pipeline is None:
 
346
  download_model(model_path, model_url)
347
 
348
  @staticmethod
349
+ def load_safe_tensor(model, file_path, is_stitcher=False):
350
+ def filter_stitcher(checkpoint, prefix):
351
+ filtered_checkpoint = {key.replace(prefix + "_module.", ""): value for key, value in checkpoint.items() if
352
+ key.startswith(prefix)}
353
+ return filtered_checkpoint
354
+
355
+ if is_stitcher:
356
+ model.load_state_dict(filter_stitcher(safetensors.torch.load_file(file_path), 'retarget_shoulder'))
357
+ else:
358
+ model.load_state_dict(safetensors.torch.load_file(file_path))
359
  model.eval()
360
  return model
361