jhj0517 commited on
Commit
3c0f460
·
1 Parent(s): e7f3a8b

refactor output dir and models dir

Browse files
Files changed (4) hide show
  1. app.py +18 -5
  2. downloading_weights.py +1 -1
  3. musepose_inference.py +11 -9
  4. pose_align.py +6 -4
app.py CHANGED
@@ -1,15 +1,23 @@
1
  import gradio as gr
 
2
  import os
3
- from huggingface_hub import hf_hub_download
4
 
5
  from musepose_inference import MusePoseInference
6
  from pose_align import PoseAlignmentInference
 
7
 
8
 
9
  class App:
10
- def __init__(self):
11
- self.pose_alignment_infer = PoseAlignmentInference()
12
- self.musepose_infer = MusePoseInference()
 
 
 
 
 
 
 
13
 
14
  def musepose_demo(self):
15
  with gr.Blocks() as demo:
@@ -82,5 +90,10 @@ class App:
82
 
83
 
84
  if __name__ == "__main__":
85
- app = App()
 
 
 
 
 
86
  app.launch()
 
1
  import gradio as gr
2
+ import argparse
3
  import os
 
4
 
5
  from musepose_inference import MusePoseInference
6
  from pose_align import PoseAlignmentInference
7
+ from downloading_weights import download_models
8
 
9
 
10
  class App:
11
+ def __init__(self, args):
12
+ self.pose_alignment_infer = PoseAlignmentInference(
13
+ model_dir=args.model_dir,
14
+ output_dir=args.output_dir
15
+ )
16
+ self.musepose_infer = MusePoseInference(
17
+ model_dir=args.model_dir,
18
+ output_dir=args.output_dir
19
+ )
20
+ download_models(args.model_dir)
21
 
22
  def musepose_demo(self):
23
  with gr.Blocks() as demo:
 
90
 
91
 
92
  if __name__ == "__main__":
93
+ parser = argparse.ArgumentParser()
94
+ parser.add_argument('--model_dir', type=str, default=os.path.join("pretrained_weights"), help='Pretrained models directory for MusePose')
95
+ parser.add_argument('--output_dir', type=str, default=os.path.join("assets", "videos"), help='Output directory for the result')
96
+ args = parser.parse_args()
97
+
98
+ app = App(args=args)
99
  app.launch()
downloading_weights.py CHANGED
@@ -4,7 +4,7 @@ from tqdm import tqdm
4
 
5
 
6
  def download_models(
7
- models_dir:str = os.makedirs('pretrained_weights', exist_ok=True)
8
  ):
9
  os.makedirs(models_dir, exist_ok=True)
10
 
 
4
 
5
 
6
  def download_models(
7
+ models_dir: str = os.makedirs('pretrained_weights', exist_ok=True)
8
  ):
9
  os.makedirs(models_dir, exist_ok=True)
10
 
musepose_inference.py CHANGED
@@ -21,17 +21,19 @@ from musepose.utils.util import get_fps, read_frames, save_videos_grid
21
 
22
 
23
  class MusePoseInference:
24
- def __init__(self):
 
 
25
  self.image_gen_model_paths = {
26
- "pretrained_base_model": os.path.join("pretrained_weights", "sd-image-variations-diffusers"),
27
- "pretrained_vae": os.path.join("pretrained_weights", "sd-vae-ft-mse"),
28
- "image_encoder": os.path.join("pretrained_weights", "image_encoder"),
29
  }
30
  self.musepose_model_paths = {
31
- "denoising_unet": os.path.join("pretrained_weights", "MusePose", "denoising_unet.pth"),
32
- "reference_unet": os.path.join("pretrained_weights", "MusePose", "reference_unet.pth"),
33
- "pose_guider": os.path.join("pretrained_weights", "MusePose", "pose_guider.pth"),
34
- "motion_module": os.path.join("pretrained_weights", "MusePose", "motion_module.pth"),
35
  }
36
  self.inference_config_path = os.path.join("configs", "inference_v2.yaml")
37
  self.vae = None
@@ -40,7 +42,7 @@ class MusePoseInference:
40
  self.pose_guider = None
41
  self.image_enc = None
42
  self.pipe = None
43
- self.output_dir = os.path.join("assets", "videos")
44
  if not os.path.exists(self.output_dir):
45
  os.makedirs(self.output_dir)
46
 
 
21
 
22
 
23
  class MusePoseInference:
24
+ def __init__(self,
25
+ model_dir,
26
+ output_dir):
27
  self.image_gen_model_paths = {
28
+ "pretrained_base_model": os.path.join(model_dir, "sd-image-variations-diffusers"),
29
+ "pretrained_vae": os.path.join(model_dir, "sd-vae-ft-mse"),
30
+ "image_encoder": os.path.join(model_dir, "image_encoder"),
31
  }
32
  self.musepose_model_paths = {
33
+ "denoising_unet": os.path.join(model_dir, "MusePose", "denoising_unet.pth"),
34
+ "reference_unet": os.path.join(model_dir, "MusePose", "reference_unet.pth"),
35
+ "pose_guider": os.path.join(model_dir, "MusePose", "pose_guider.pth"),
36
+ "motion_module": os.path.join(model_dir, "MusePose", "motion_module.pth"),
37
  }
38
  self.inference_config_path = os.path.join("configs", "inference_v2.yaml")
39
  self.vae = None
 
42
  self.pose_guider = None
43
  self.image_enc = None
44
  self.pipe = None
45
+ self.output_dir = output_dir
46
  if not os.path.exists(self.output_dir):
47
  os.makedirs(self.output_dir)
48
 
pose_align.py CHANGED
@@ -20,17 +20,19 @@ from pose.script.util import size_calculate, warpAffine_kps
20
  scales: scale parameters
21
  '''
22
  class PoseAlignmentInference:
23
- def __init__(self):
 
 
24
  self.detector = None
25
  self.model_paths = {
26
- "det_ckpt": os.path.join("pretrained_weights", "dwpose", "yolox_l_8x8_300e_coco.pth"),
27
- "pose_ckpt": os.path.join("pretrained_weights", "dwpose", "dw-ll_ucoco_384.pth")
28
  }
29
  self.config_paths = {
30
  "pose_config": os.path.join("pose", "config", "dwpose-l_384x288.py"),
31
  "det_config": os.path.join("pose", "config", "yolox_l_8xb8-300e_coco.py"),
32
  }
33
- self.output_dir = os.path.join("assets", "videos")
34
  if not os.path.exists(self.output_dir):
35
  os.makedirs(self.output_dir)
36
 
 
20
  scales: scale parameters
21
  '''
22
  class PoseAlignmentInference:
23
+ def __init__(self,
24
+ model_dir,
25
+ output_dir):
26
  self.detector = None
27
  self.model_paths = {
28
+ "det_ckpt": os.path.join(model_dir, "dwpose", "yolox_l_8x8_300e_coco.pth"),
29
+ "pose_ckpt": os.path.join(model_dir, "dwpose", "dw-ll_ucoco_384.pth")
30
  }
31
  self.config_paths = {
32
  "pose_config": os.path.join("pose", "config", "dwpose-l_384x288.py"),
33
  "det_config": os.path.join("pose", "config", "yolox_l_8xb8-300e_coco.py"),
34
  }
35
+ self.output_dir = output_dir
36
  if not os.path.exists(self.output_dir):
37
  os.makedirs(self.output_dir)
38