jhj0517 commited on
Commit
6ecdb23
·
1 Parent(s): dffd539

Add progress during model loading

Browse files
modules/live_portrait/live_portrait_inferencer.py CHANGED
@@ -56,7 +56,8 @@ class LivePortraitInferencer:
56
  self.psi_list = None
57
  self.d_info = None
58
 
59
- def load_models(self):
 
60
  def filter_stitcher(checkpoint, prefix):
61
  filtered_checkpoint = {key.replace(prefix + "_module.", ""): value for key, value in checkpoint.items() if
62
  key.startswith(prefix)}
@@ -64,6 +65,8 @@ class LivePortraitInferencer:
64
 
65
  self.download_if_no_models()
66
 
 
 
67
  appearance_feat_config = self.model_config["appearance_feature_extractor_params"]
68
  self.appearance_feature_extractor = AppearanceFeatureExtractor(**appearance_feat_config).to(self.device)
69
  self.appearance_feature_extractor = self.load_safe_tensor(
@@ -71,6 +74,7 @@ class LivePortraitInferencer:
71
  os.path.join(self.model_dir, "appearance_feature_extractor.safetensors")
72
  )
73
 
 
74
  motion_ext_config = self.model_config["motion_extractor_params"]
75
  self.motion_extractor = MotionExtractor(**motion_ext_config).to(self.device)
76
  self.motion_extractor = self.load_safe_tensor(
@@ -78,6 +82,7 @@ class LivePortraitInferencer:
78
  os.path.join(self.model_dir, "motion_extractor.safetensors")
79
  )
80
 
 
81
  warping_module_config = self.model_config["warping_module_params"]
82
  self.warping_module = WarpingNetwork(**warping_module_config).to(self.device)
83
  self.warping_module = self.load_safe_tensor(
@@ -85,6 +90,7 @@ class LivePortraitInferencer:
85
  os.path.join(self.model_dir, "warping_module.safetensors")
86
  )
87
 
 
88
  spaded_decoder_config = self.model_config["spade_generator_params"]
89
  self.spade_generator = SPADEDecoder(**spaded_decoder_config).to(self.device)
90
  self.spade_generator = self.load_safe_tensor(
@@ -92,6 +98,7 @@ class LivePortraitInferencer:
92
  os.path.join(self.model_dir, "spade_generator.safetensors")
93
  )
94
 
 
95
  stitcher_config = self.model_config["stitching_retargeting_module_params"]
96
  self.stitching_retargeting_module = StitchingRetargetingNetwork(**stitcher_config.get('stitching'))
97
  stitcher_model_path = os.path.join(self.model_dir, "stitching_retargeting_module.safetensors")
 
56
  self.psi_list = None
57
  self.d_info = None
58
 
59
+ def load_models(self,
60
+ progress=gr.Progress()):
61
  def filter_stitcher(checkpoint, prefix):
62
  filtered_checkpoint = {key.replace(prefix + "_module.", ""): value for key, value in checkpoint.items() if
63
  key.startswith(prefix)}
 
65
 
66
  self.download_if_no_models()
67
 
68
+ total_models_num = 5
69
+ progress(0/total_models_num, desc="Loading Appearance Feature Extractor model...")
70
  appearance_feat_config = self.model_config["appearance_feature_extractor_params"]
71
  self.appearance_feature_extractor = AppearanceFeatureExtractor(**appearance_feat_config).to(self.device)
72
  self.appearance_feature_extractor = self.load_safe_tensor(
 
74
  os.path.join(self.model_dir, "appearance_feature_extractor.safetensors")
75
  )
76
 
77
+ progress(1/total_models_num, desc="Loading Motion Extractor model...")
78
  motion_ext_config = self.model_config["motion_extractor_params"]
79
  self.motion_extractor = MotionExtractor(**motion_ext_config).to(self.device)
80
  self.motion_extractor = self.load_safe_tensor(
 
82
  os.path.join(self.model_dir, "motion_extractor.safetensors")
83
  )
84
 
85
+ progress(2/total_models_num, desc="Loading Warping Module model...")
86
  warping_module_config = self.model_config["warping_module_params"]
87
  self.warping_module = WarpingNetwork(**warping_module_config).to(self.device)
88
  self.warping_module = self.load_safe_tensor(
 
90
  os.path.join(self.model_dir, "warping_module.safetensors")
91
  )
92
 
93
+ progress(3/total_models_num, desc="Loading Spade generator model...")
94
  spaded_decoder_config = self.model_config["spade_generator_params"]
95
  self.spade_generator = SPADEDecoder(**spaded_decoder_config).to(self.device)
96
  self.spade_generator = self.load_safe_tensor(
 
98
  os.path.join(self.model_dir, "spade_generator.safetensors")
99
  )
100
 
101
+ progress(4/total_models_num, desc="Loading Stitcher model...")
102
  stitcher_config = self.model_config["stitching_retargeting_module_params"]
103
  self.stitching_retargeting_module = StitchingRetargetingNetwork(**stitcher_config.get('stitching'))
104
  stitcher_model_path = os.path.join(self.model_dir, "stitching_retargeting_module.safetensors")