jhj0517 commited on
Commit
ab6505a
1 Parent(s): cfa21c8

refactor download models script

Browse files
Files changed (3) hide show
  1. downloading_weights.py +14 -9
  2. musepose_inference.py +3 -0
  3. pose_align.py +3 -3
downloading_weights.py CHANGED
@@ -26,8 +26,15 @@ def download_models(
26
  os.makedirs(dir, exist_ok=True)
27
 
28
  for url, path in tqdm(zip(urls, paths)):
29
- local_file_path = os.path.join("pretrained_weights", path)
30
- wget.download(url, local_file_path)
 
 
 
 
 
 
 
31
 
32
  config_urls = ['https://huggingface.co/lambdalabs/sd-image-variations-diffusers/resolve/main/unet/config.json',
33
  'https://huggingface.co/lambdalabs/sd-image-variations-diffusers/resolve/main/image_encoder/config.json',
@@ -37,10 +44,8 @@ def download_models(
37
 
38
  # saving config files
39
  for url, path in tqdm(zip(config_urls, config_paths)):
40
- local_file_path = os.path.join("pretrained_weights", path)
41
- wget.download(url, local_file_path)
42
-
43
- # renaming model name as given in readme
44
- wrong_file_path = os.path.join("pretrained_weights", "dwpose", "yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth")
45
- correct_file_path = os.path.join("pretrained_weights", "dwpose", "yolox_l_8x8_300e_coco.pth")
46
- os.rename(wrong_file_path, correct_file_path)
 
26
  os.makedirs(dir, exist_ok=True)
27
 
28
  for url, path in tqdm(zip(urls, paths)):
29
+ filename = os.path.basename(url)
30
+ if filename == "yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth":
31
+ filename = "yolox_l_8x8_300e_coco.pth"
32
+
33
+ full_file_path = os.path.join(model_dir, path, filename)
34
+
35
+ if not os.path.exists(full_file_path):
36
+ print(f"Model '{filename}' does not exists. Downloading to '{full_file_path}'..")
37
+ wget.download(url, full_file_path)
38
 
39
  config_urls = ['https://huggingface.co/lambdalabs/sd-image-variations-diffusers/resolve/main/unet/config.json',
40
  'https://huggingface.co/lambdalabs/sd-image-variations-diffusers/resolve/main/image_encoder/config.json',
 
44
 
45
  # saving config files
46
  for url, path in tqdm(zip(config_urls, config_paths)):
47
+ filename = os.path.basename(url)
48
+ full_file_path = os.path.join(model_dir, path, filename)
49
+ if not os.path.exists(full_file_path):
50
+ print(f"Model '{filename}' does not exists. Downloading to '{full_file_path}'..")
51
+ wget.download(url, full_file_path)
 
 
musepose_inference.py CHANGED
@@ -18,6 +18,7 @@ from musepose.models.unet_2d_condition import UNet2DConditionModel
18
  from musepose.models.unet_3d import UNet3DConditionModel
19
  from musepose.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
20
  from musepose.utils.util import get_fps, read_frames, save_videos_grid
 
21
 
22
 
23
  class MusePoseInference:
@@ -42,6 +43,7 @@ class MusePoseInference:
42
  self.pose_guider = None
43
  self.image_enc = None
44
  self.pipe = None
 
45
  self.output_dir = output_dir
46
  if not os.path.exists(self.output_dir):
47
  os.makedirs(self.output_dir)
@@ -62,6 +64,7 @@ class MusePoseInference:
62
  fps: int,
63
  skip: int
64
  ):
 
65
  print(f"Model Paths: {self.musepose_model_paths}\n{self.image_gen_model_paths}\n{self.inference_config_path}")
66
  print(f"Input Image Path: {ref_image_path}")
67
  print(f"Pose Video Path: {pose_video_path}")
 
18
  from musepose.models.unet_3d import UNet3DConditionModel
19
  from musepose.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
20
  from musepose.utils.util import get_fps, read_frames, save_videos_grid
21
+ from downloading_weights import download_models
22
 
23
 
24
  class MusePoseInference:
 
43
  self.pose_guider = None
44
  self.image_enc = None
45
  self.pipe = None
46
+ self.model_dir = model_dir
47
  self.output_dir = output_dir
48
  if not os.path.exists(self.output_dir):
49
  os.makedirs(self.output_dir)
 
64
  fps: int,
65
  skip: int
66
  ):
67
+ download_models(model_dir=self.model_dir)
68
  print(f"Model Paths: {self.musepose_model_paths}\n{self.image_gen_model_paths}\n{self.inference_config_path}")
69
  print(f"Input Image Path: {ref_image_path}")
70
  print(f"Pose Video Path: {pose_video_path}")
pose_align.py CHANGED
@@ -1,5 +1,4 @@
1
  import numpy as np
2
- import argparse
3
  import torch
4
  import copy
5
  import cv2
@@ -7,10 +6,10 @@ import os
7
  import moviepy.video.io.ImageSequenceClip
8
  from datetime import datetime
9
  import gc
10
- from huggingface_hub import hf_hub_download
11
 
12
  from pose.script.dwpose import DWposeDetector, draw_pose
13
  from pose.script.util import size_calculate, warpAffine_kps
 
14
 
15
 
16
  '''
@@ -32,6 +31,7 @@ class PoseAlignmentInference:
32
  "pose_config": os.path.join("pose", "config", "dwpose-l_384x288.py"),
33
  "det_config": os.path.join("pose", "config", "yolox_l_8xb8-300e_coco.py"),
34
  }
 
35
  self.output_dir = output_dir
36
  if not os.path.exists(self.output_dir):
37
  os.makedirs(self.output_dir)
@@ -45,7 +45,7 @@ class PoseAlignmentInference:
45
  align_frame: int,
46
  max_frame: int,
47
  ):
48
- self.download_models()
49
  dt_file_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
50
  outfn=os.path.abspath(os.path.join(self.output_dir, f'{dt_file_name}_demo.mp4'))
51
  outfn_align_pose_video=os.path.abspath(os.path.join(self.output_dir, f'{dt_file_name}.mp4'))
 
1
  import numpy as np
 
2
  import torch
3
  import copy
4
  import cv2
 
6
  import moviepy.video.io.ImageSequenceClip
7
  from datetime import datetime
8
  import gc
 
9
 
10
  from pose.script.dwpose import DWposeDetector, draw_pose
11
  from pose.script.util import size_calculate, warpAffine_kps
12
+ from downloading_weights import download_models
13
 
14
 
15
  '''
 
31
  "pose_config": os.path.join("pose", "config", "dwpose-l_384x288.py"),
32
  "det_config": os.path.join("pose", "config", "yolox_l_8xb8-300e_coco.py"),
33
  }
34
+ self.model_dir = model_dir
35
  self.output_dir = output_dir
36
  if not os.path.exists(self.output_dir):
37
  os.makedirs(self.output_dir)
 
45
  align_frame: int,
46
  max_frame: int,
47
  ):
48
+ download_models(model_dir=self.model_dir)
49
  dt_file_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
50
  outfn=os.path.abspath(os.path.join(self.output_dir, f'{dt_file_name}_demo.mp4'))
51
  outfn_align_pose_video=os.path.abspath(os.path.join(self.output_dir, f'{dt_file_name}.mp4'))