Spaces:
Running
Running
jhj0517
commited on
Commit
·
ab6505a
1
Parent(s):
cfa21c8
refactor download models script
Browse files- downloading_weights.py +14 -9
- musepose_inference.py +3 -0
- pose_align.py +3 -3
downloading_weights.py
CHANGED
@@ -26,8 +26,15 @@ def download_models(
|
|
26 |
os.makedirs(dir, exist_ok=True)
|
27 |
|
28 |
for url, path in tqdm(zip(urls, paths)):
|
29 |
-
|
30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
config_urls = ['https://huggingface.co/lambdalabs/sd-image-variations-diffusers/resolve/main/unet/config.json',
|
33 |
'https://huggingface.co/lambdalabs/sd-image-variations-diffusers/resolve/main/image_encoder/config.json',
|
@@ -37,10 +44,8 @@ def download_models(
|
|
37 |
|
38 |
# saving config files
|
39 |
for url, path in tqdm(zip(config_urls, config_paths)):
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
correct_file_path = os.path.join("pretrained_weights", "dwpose", "yolox_l_8x8_300e_coco.pth")
|
46 |
-
os.rename(wrong_file_path, correct_file_path)
|
|
|
26 |
os.makedirs(dir, exist_ok=True)
|
27 |
|
28 |
for url, path in tqdm(zip(urls, paths)):
|
29 |
+
filename = os.path.basename(url)
|
30 |
+
if filename == "yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth":
|
31 |
+
filename = "yolox_l_8x8_300e_coco.pth"
|
32 |
+
|
33 |
+
full_file_path = os.path.join(model_dir, path, filename)
|
34 |
+
|
35 |
+
if not os.path.exists(full_file_path):
|
36 |
+
print(f"Model '{filename}' does not exists. Downloading to '{full_file_path}'..")
|
37 |
+
wget.download(url, full_file_path)
|
38 |
|
39 |
config_urls = ['https://huggingface.co/lambdalabs/sd-image-variations-diffusers/resolve/main/unet/config.json',
|
40 |
'https://huggingface.co/lambdalabs/sd-image-variations-diffusers/resolve/main/image_encoder/config.json',
|
|
|
44 |
|
45 |
# saving config files
|
46 |
for url, path in tqdm(zip(config_urls, config_paths)):
|
47 |
+
filename = os.path.basename(url)
|
48 |
+
full_file_path = os.path.join(model_dir, path, filename)
|
49 |
+
if not os.path.exists(full_file_path):
|
50 |
+
print(f"Model '{filename}' does not exists. Downloading to '{full_file_path}'..")
|
51 |
+
wget.download(url, full_file_path)
|
|
|
|
musepose_inference.py
CHANGED
@@ -18,6 +18,7 @@ from musepose.models.unet_2d_condition import UNet2DConditionModel
|
|
18 |
from musepose.models.unet_3d import UNet3DConditionModel
|
19 |
from musepose.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
|
20 |
from musepose.utils.util import get_fps, read_frames, save_videos_grid
|
|
|
21 |
|
22 |
|
23 |
class MusePoseInference:
|
@@ -42,6 +43,7 @@ class MusePoseInference:
|
|
42 |
self.pose_guider = None
|
43 |
self.image_enc = None
|
44 |
self.pipe = None
|
|
|
45 |
self.output_dir = output_dir
|
46 |
if not os.path.exists(self.output_dir):
|
47 |
os.makedirs(self.output_dir)
|
@@ -62,6 +64,7 @@ class MusePoseInference:
|
|
62 |
fps: int,
|
63 |
skip: int
|
64 |
):
|
|
|
65 |
print(f"Model Paths: {self.musepose_model_paths}\n{self.image_gen_model_paths}\n{self.inference_config_path}")
|
66 |
print(f"Input Image Path: {ref_image_path}")
|
67 |
print(f"Pose Video Path: {pose_video_path}")
|
|
|
18 |
from musepose.models.unet_3d import UNet3DConditionModel
|
19 |
from musepose.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
|
20 |
from musepose.utils.util import get_fps, read_frames, save_videos_grid
|
21 |
+
from downloading_weights import download_models
|
22 |
|
23 |
|
24 |
class MusePoseInference:
|
|
|
43 |
self.pose_guider = None
|
44 |
self.image_enc = None
|
45 |
self.pipe = None
|
46 |
+
self.model_dir = model_dir
|
47 |
self.output_dir = output_dir
|
48 |
if not os.path.exists(self.output_dir):
|
49 |
os.makedirs(self.output_dir)
|
|
|
64 |
fps: int,
|
65 |
skip: int
|
66 |
):
|
67 |
+
download_models(model_dir=self.model_dir)
|
68 |
print(f"Model Paths: {self.musepose_model_paths}\n{self.image_gen_model_paths}\n{self.inference_config_path}")
|
69 |
print(f"Input Image Path: {ref_image_path}")
|
70 |
print(f"Pose Video Path: {pose_video_path}")
|
pose_align.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
import numpy as np
|
2 |
-
import argparse
|
3 |
import torch
|
4 |
import copy
|
5 |
import cv2
|
@@ -7,10 +6,10 @@ import os
|
|
7 |
import moviepy.video.io.ImageSequenceClip
|
8 |
from datetime import datetime
|
9 |
import gc
|
10 |
-
from huggingface_hub import hf_hub_download
|
11 |
|
12 |
from pose.script.dwpose import DWposeDetector, draw_pose
|
13 |
from pose.script.util import size_calculate, warpAffine_kps
|
|
|
14 |
|
15 |
|
16 |
'''
|
@@ -32,6 +31,7 @@ class PoseAlignmentInference:
|
|
32 |
"pose_config": os.path.join("pose", "config", "dwpose-l_384x288.py"),
|
33 |
"det_config": os.path.join("pose", "config", "yolox_l_8xb8-300e_coco.py"),
|
34 |
}
|
|
|
35 |
self.output_dir = output_dir
|
36 |
if not os.path.exists(self.output_dir):
|
37 |
os.makedirs(self.output_dir)
|
@@ -45,7 +45,7 @@ class PoseAlignmentInference:
|
|
45 |
align_frame: int,
|
46 |
max_frame: int,
|
47 |
):
|
48 |
-
self.
|
49 |
dt_file_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
50 |
outfn=os.path.abspath(os.path.join(self.output_dir, f'{dt_file_name}_demo.mp4'))
|
51 |
outfn_align_pose_video=os.path.abspath(os.path.join(self.output_dir, f'{dt_file_name}.mp4'))
|
|
|
1 |
import numpy as np
|
|
|
2 |
import torch
|
3 |
import copy
|
4 |
import cv2
|
|
|
6 |
import moviepy.video.io.ImageSequenceClip
|
7 |
from datetime import datetime
|
8 |
import gc
|
|
|
9 |
|
10 |
from pose.script.dwpose import DWposeDetector, draw_pose
|
11 |
from pose.script.util import size_calculate, warpAffine_kps
|
12 |
+
from downloading_weights import download_models
|
13 |
|
14 |
|
15 |
'''
|
|
|
31 |
"pose_config": os.path.join("pose", "config", "dwpose-l_384x288.py"),
|
32 |
"det_config": os.path.join("pose", "config", "yolox_l_8xb8-300e_coco.py"),
|
33 |
}
|
34 |
+
self.model_dir = model_dir
|
35 |
self.output_dir = output_dir
|
36 |
if not os.path.exists(self.output_dir):
|
37 |
os.makedirs(self.output_dir)
|
|
|
45 |
align_frame: int,
|
46 |
max_frame: int,
|
47 |
):
|
48 |
+
download_models(model_dir=self.model_dir)
|
49 |
dt_file_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
50 |
outfn=os.path.abspath(os.path.join(self.output_dir, f'{dt_file_name}_demo.mp4'))
|
51 |
outfn_align_pose_video=os.path.abspath(os.path.join(self.output_dir, f'{dt_file_name}.mp4'))
|