zejunyang committed
Commit
2de857a
1 Parent(s): bf4c058
Files changed (4)
  1. app.py +50 -46
  2. src/audio2vid.py +73 -67
  3. src/utils/crop_face_single.py +45 -0
  4. src/vid2vid.py +69 -67
app.py CHANGED
@@ -17,68 +17,72 @@ with gr.Blocks() as demo:
     gr.Markdown(description)
 
     with gr.Tab("Audio2video"):
-        with gr.Column():
-            with gr.Row():
-                a2v_input_audio = gr.Audio(sources=["upload", "microphone"], type="filepath", editable=True, label="Input audio", interactive=True)
-                # with gr.Column():
-                #     a2v_ref_img = gr.Image(label="Upload reference image", sources="upload")
-                #     a2v_img_trans_real_botton = gr.Button("Translate to realistic style")
-                a2v_ref_img = gr.Image(label="Upload reference image", sources="upload")
-                a2v_headpose_video = gr.Video(label="Option: upload head pose reference video", sources="upload")
+        with gr.Row():
+            with gr.Column():
+                with gr.Row():
+                    a2v_input_audio = gr.Audio(sources=["upload", "microphone"], type="filepath", editable=True, label="Input audio", interactive=True)
+                    a2v_ref_img = gr.Image(label="Upload reference image", sources="upload")
+                    a2v_headpose_video = gr.Video(label="Option: upload head pose reference video", sources="upload")
 
-            with gr.Row():
-                a2v_size_slider = gr.Slider(minimum=256, maximum=1024, step=8, value=512, label="Video size (-W & -H)")
-                a2v_step_slider = gr.Slider(minimum=5, maximum=50, value=25, label="Steps (--steps)")
-
-            with gr.Row():
-                a2v_length = gr.Number(value=150, label="Length (-L) (Set 0 to automatically calculate video length.)")
-                a2v_seed = gr.Number(value=42, label="Seed (--seed)")
-
-            a2v_botton = gr.Button("Generate", variant="primary")
+                with gr.Row():
+                    a2v_size_slider = gr.Slider(minimum=256, maximum=1024, step=8, value=512, label="Video size (-W & -H)")
+                    a2v_step_slider = gr.Slider(minimum=5, maximum=50, step=1, value=20, label="Steps (--steps)")
+
+                with gr.Row():
+                    a2v_length = gr.Slider(minimum=0, maximum=300, step=1, value=150, label="Length (-L) (Set 0 to automatically calculate video length.)")
+                    a2v_seed = gr.Number(value=42, label="Seed (--seed)")
+
+                a2v_botton = gr.Button("Generate", variant="primary")
             a2v_output_video = gr.PlayableVideo(label="Result", interactive=False)
+
+        gr.Examples(
+            examples=[
+                ["configs/inference/audio/lyl.wav", "configs/inference/ref_images/Aragaki.png", None],
+                ["configs/inference/audio/lyl.wav", "configs/inference/ref_images/solo.png", None],
+                ["configs/inference/audio/lyl.wav", "configs/inference/ref_images/lyl.png", "configs/inference/head_pose_temp/pose_ref_video.mp4"],
+            ],
+            inputs=[a2v_input_audio, a2v_ref_img, a2v_headpose_video],
+        )
 
 
     with gr.Tab("Video2video"):
-        with gr.Column():
-            with gr.Row():
-                # with gr.Column():
-                #     v2v_ref_img = gr.Image(label="Upload reference image", sources="upload")
-                #     v2v_img_trans_real_botton = gr.Button("Translate to realistic style")
-                v2v_ref_img = gr.Image(label="Upload reference image", sources="upload")
-                v2v_source_video = gr.Video(label="Upload source video", sources="upload")
-
-            with gr.Row():
-                v2v_size_slider = gr.Slider(minimum=256, maximum=1024, step=8, value=512, label="Video size (-W & -H)")
-                v2v_step_slider = gr.Slider(minimum=5, maximum=50, value=25, label="Steps (--steps)")
-
-            with gr.Row():
-                v2v_length = gr.Number(value=150, label="Length (-L) (Set 0 to automatically calculate video length.)")
-                v2v_seed = gr.Number(value=42, label="Seed (--seed)")
-
-            v2v_botton = gr.Button("Generate", variant="primary")
+        with gr.Row():
+            with gr.Column():
+                with gr.Row():
+                    v2v_ref_img = gr.Image(label="Upload reference image", sources="upload")
+                    v2v_source_video = gr.Video(label="Upload source video", sources="upload")
+
+                with gr.Row():
+                    v2v_size_slider = gr.Slider(minimum=256, maximum=1024, step=8, value=512, label="Video size (-W & -H)")
+                    v2v_step_slider = gr.Slider(minimum=5, maximum=50, step=1, value=20, label="Steps (--steps)")
+
+                with gr.Row():
+                    v2v_length = gr.Slider(minimum=0, maximum=300, step=1, value=150, label="Length (-L) (Set 0 to automatically calculate video length.)")
+                    v2v_seed = gr.Number(value=42, label="Seed (--seed)")
+
+                v2v_botton = gr.Button("Generate", variant="primary")
             v2v_output_video = gr.PlayableVideo(label="Result", interactive=False)
 
+        gr.Examples(
+            examples=[
+                ["configs/inference/ref_images/Aragaki.png", "configs/inference/video/Aragaki_song.mp4"],
+                ["configs/inference/ref_images/solo.png", "configs/inference/video/Aragaki_song.mp4"],
+                ["configs/inference/ref_images/lyl.png", "configs/inference/head_pose_temp/pose_ref_video.mp4"],
+            ],
+            inputs=[v2v_ref_img, v2v_source_video, a2v_headpose_video],
+        )
+
     a2v_botton.click(
         fn=audio2video,
         inputs=[a2v_input_audio, a2v_ref_img, a2v_headpose_video,
                 a2v_size_slider, a2v_step_slider, a2v_length, a2v_seed],
-        outputs=[a2v_output_video]
+        outputs=[a2v_output_video, a2v_ref_img]
     )
-    # a2v_img_trans_real_botton.click(
-    #     fn=sd_img2real,
-    #     inputs=[a2v_ref_img],
-    #     outputs=[a2v_ref_img]
-    # )
     v2v_botton.click(
         fn=video2video,
         inputs=[v2v_ref_img, v2v_source_video,
                v2v_size_slider, v2v_step_slider, v2v_length, v2v_seed],
-        outputs=[v2v_output_video]
+        outputs=[v2v_output_video, v2v_ref_img]
     )
-    # v2v_img_trans_real_botton.click(
-    #     fn=sd_img2real,
-    #     inputs=[v2v_ref_img],
-    #     outputs=[v2v_ref_img]
-    # )
 
     demo.launch()
src/audio2vid.py CHANGED
@@ -9,25 +9,27 @@ import spaces
 from scipy.spatial.transform import Rotation as R
 from scipy.interpolate import interp1d
 
-from diffusers import AutoencoderKL, DDIMScheduler
-from einops import repeat
+# from diffusers import AutoencoderKL, DDIMScheduler
+# from einops import repeat
 from omegaconf import OmegaConf
 from PIL import Image
 from torchvision import transforms
-from transformers import CLIPVisionModelWithProjection
+# from transformers import CLIPVisionModelWithProjection
 
 
-from src.models.pose_guider import PoseGuider
-from src.models.unet_2d_condition import UNet2DConditionModel
-from src.models.unet_3d import UNet3DConditionModel
-from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
+# from src.models.pose_guider import PoseGuider
+# from src.models.unet_2d_condition import UNet2DConditionModel
+# from src.models.unet_3d import UNet3DConditionModel
+# from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
 from src.utils.util import save_videos_grid
 
-from src.audio_models.model import Audio2MeshModel
+# from src.audio_models.model import Audio2MeshModel
 from src.utils.audio_util import prepare_audio_feature
-from src.utils.mp_utils import LMKExtractor
-from src.utils.draw_util import FaceMeshVisualizer
+# from src.utils.mp_utils import LMKExtractor
+# from src.utils.draw_util import FaceMeshVisualizer
 from src.utils.pose_util import project_points
+from src.utils.crop_face_single import crop_face
+from src.create_modules import lmk_extractor, vis, a2m_model, pipe
 
 
 def matrix_to_euler_and_translation(matrix):
@@ -49,7 +51,7 @@ def smooth_pose_seq(pose_seq, window_size=5):
     return smoothed_pose_seq
 
 def get_headpose_temp(input_video):
-    lmk_extractor = LMKExtractor()
+    # lmk_extractor = LMKExtractor()
     cap = cv2.VideoCapture(input_video)
 
     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
@@ -98,70 +100,70 @@ def audio2video(input_audio, ref_img, headpose_video=None, size=512, steps=25, length=150, seed=42):
 
     config = OmegaConf.load('./configs/prompts/animation_audio.yaml')
 
-    if config.weight_dtype == "fp16":
-        weight_dtype = torch.float16
-    else:
-        weight_dtype = torch.float32
+    # if config.weight_dtype == "fp16":
+    #     weight_dtype = torch.float16
+    # else:
+    #     weight_dtype = torch.float32
 
     audio_infer_config = OmegaConf.load(config.audio_inference_config)
-    # prepare model
-    a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
-    a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt']), strict=False)
-    a2m_model.cuda().eval()
+    # # prepare model
+    # a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
+    # a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt']), strict=False)
+    # a2m_model.cuda().eval()
 
-    vae = AutoencoderKL.from_pretrained(
-        config.pretrained_vae_path,
-    ).to("cuda", dtype=weight_dtype)
+    # vae = AutoencoderKL.from_pretrained(
+    #     config.pretrained_vae_path,
+    # ).to("cuda", dtype=weight_dtype)
 
-    reference_unet = UNet2DConditionModel.from_pretrained(
-        config.pretrained_base_model_path,
-        subfolder="unet",
-    ).to(dtype=weight_dtype, device="cuda")
+    # reference_unet = UNet2DConditionModel.from_pretrained(
+    #     config.pretrained_base_model_path,
+    #     subfolder="unet",
+    # ).to(dtype=weight_dtype, device="cuda")
 
-    inference_config_path = config.inference_config
-    infer_config = OmegaConf.load(inference_config_path)
-    denoising_unet = UNet3DConditionModel.from_pretrained_2d(
-        config.pretrained_base_model_path,
-        config.motion_module_path,
-        subfolder="unet",
-        unet_additional_kwargs=infer_config.unet_additional_kwargs,
-    ).to(dtype=weight_dtype, device="cuda")
+    # inference_config_path = config.inference_config
+    # infer_config = OmegaConf.load(inference_config_path)
+    # denoising_unet = UNet3DConditionModel.from_pretrained_2d(
+    #     config.pretrained_base_model_path,
+    #     config.motion_module_path,
+    #     subfolder="unet",
+    #     unet_additional_kwargs=infer_config.unet_additional_kwargs,
+    # ).to(dtype=weight_dtype, device="cuda")
 
 
-    pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype) # not use cross attention
+    # pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype) # not use cross attention
 
-    image_enc = CLIPVisionModelWithProjection.from_pretrained(
-        config.image_encoder_path
-    ).to(dtype=weight_dtype, device="cuda")
+    # image_enc = CLIPVisionModelWithProjection.from_pretrained(
+    #     config.image_encoder_path
+    # ).to(dtype=weight_dtype, device="cuda")
 
-    sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
-    scheduler = DDIMScheduler(**sched_kwargs)
+    # sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
+    # scheduler = DDIMScheduler(**sched_kwargs)
 
     generator = torch.manual_seed(seed)
 
     width, height = size, size
 
-    # load pretrained weights
-    denoising_unet.load_state_dict(
-        torch.load(config.denoising_unet_path, map_location="cpu"),
-        strict=False,
-    )
-    reference_unet.load_state_dict(
-        torch.load(config.reference_unet_path, map_location="cpu"),
-    )
-    pose_guider.load_state_dict(
-        torch.load(config.pose_guider_path, map_location="cpu"),
-    )
-
-    pipe = Pose2VideoPipeline(
-        vae=vae,
-        image_encoder=image_enc,
-        reference_unet=reference_unet,
-        denoising_unet=denoising_unet,
-        pose_guider=pose_guider,
-        scheduler=scheduler,
-    )
-    pipe = pipe.to("cuda", dtype=weight_dtype)
+    # # load pretrained weights
+    # denoising_unet.load_state_dict(
+    #     torch.load(config.denoising_unet_path, map_location="cpu"),
+    #     strict=False,
+    # )
+    # reference_unet.load_state_dict(
+    #     torch.load(config.reference_unet_path, map_location="cpu"),
+    # )
+    # pose_guider.load_state_dict(
+    #     torch.load(config.pose_guider_path, map_location="cpu"),
+    # )
+
+    # pipe = Pose2VideoPipeline(
+    #     vae=vae,
+    #     image_encoder=image_enc,
+    #     reference_unet=reference_unet,
+    #     denoising_unet=denoising_unet,
+    #     pose_guider=pose_guider,
+    #     scheduler=scheduler,
+    # )
+    # pipe = pipe.to("cuda", dtype=weight_dtype)
 
     date_str = datetime.now().strftime("%Y%m%d")
     time_str = datetime.now().strftime("%H%M")
@@ -170,17 +172,20 @@ def audio2video(input_audio, ref_img, headpose_video=None, size=512, steps=25, length=150, seed=42):
     save_dir = Path(f"output/{date_str}/{save_dir_name}")
     save_dir.mkdir(exist_ok=True, parents=True)
 
-    lmk_extractor = LMKExtractor()
-    vis = FaceMeshVisualizer(forehead_edge=False)
+    # lmk_extractor = LMKExtractor()
+    # vis = FaceMeshVisualizer(forehead_edge=False)
 
     ref_image_np = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)
-    # TODO: face detection + cropping
+    ref_image_np = crop_face(ref_image_np, lmk_extractor)
+    if ref_image_np is None:
+        return None, Image.fromarray(ref_img)
+
     ref_image_np = cv2.resize(ref_image_np, (size, size))
     ref_image_pil = Image.fromarray(cv2.cvtColor(ref_image_np, cv2.COLOR_BGR2RGB))
 
     face_result = lmk_extractor(ref_image_np)
     if face_result is None:
-        return None
+        return None, ref_image_pil
 
     lmks = face_result['lmks'].astype(np.float32)
     ref_pose = vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True)
@@ -217,6 +222,7 @@ def audio2video(input_audio, ref_img, headpose_video=None, size=512, steps=25, length=150, seed=42):
         [transforms.Resize((height, width)), transforms.ToTensor()]
     )
     args_L = len(pose_images) if length==0 or length > len(pose_images) else length
+    args_L = min(args_L, 300)
    for pose_image_np in pose_images[: args_L]:
        pose_image_pil = Image.fromarray(cv2.cvtColor(pose_image_np, cv2.COLOR_BGR2RGB))
        pose_tensor_list.append(pose_transform(pose_image_pil))
@@ -249,7 +255,7 @@ def audio2video(input_audio, ref_img, headpose_video=None, size=512, steps=25, length=150, seed=42):
 
     stream = ffmpeg.input(save_path)
     audio = ffmpeg.input(input_audio)
-    ffmpeg.output(stream.video, audio.audio, save_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac').run()
+    ffmpeg.output(stream.video, audio.audio, save_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac', shortest=None).run()
     os.remove(save_path)
 
-    return save_path.replace('_noaudio.mp4', '.mp4')
+    return save_path.replace('_noaudio.mp4', '.mp4'), ref_image_pil
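
Note on the muxing call above: ffmpeg-python emits a keyword argument whose value is None as a bare output flag, so shortest=None corresponds to ffmpeg's -shortest option and trims the merged file to the shorter of the video and audio streams. A minimal sketch with hypothetical file names, shown only to illustrate that mapping:

import ffmpeg

video_in = ffmpeg.input("result_noaudio.mp4")   # hypothetical paths, for illustration only
audio_in = ffmpeg.input("input_audio.wav")
out = ffmpeg.output(
    video_in.video, audio_in.audio, "result.mp4",
    vcodec="copy",     # keep the already-encoded video stream
    acodec="aac",      # encode the audio track as AAC
    shortest=None,     # None-valued kwarg -> bare "-shortest" flag
)
print(out.get_args())  # inspect the generated ffmpeg arguments without running them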
src/utils/crop_face_single.py ADDED
@@ -0,0 +1,45 @@
+import numpy as np
+import cv2
+
+
+def crop_face(img, lmk_extractor, expand=1.5):
+    result = lmk_extractor(img)  # cv2 BGR
+
+    if result is None:
+        return None
+
+    H, W, _ = img.shape
+    lmks = result['lmks']
+    lmks[:, 0] *= W
+    lmks[:, 1] *= H
+
+    x_min = np.min(lmks[:, 0])
+    x_max = np.max(lmks[:, 0])
+    y_min = np.min(lmks[:, 1])
+    y_max = np.max(lmks[:, 1])
+
+    width = x_max - x_min
+    height = y_max - y_min
+
+    center_x = x_min + width / 2
+    center_y = y_min + height / 2
+
+    width *= expand
+    height *= expand
+
+    size = max(width, height)
+
+    x_min = int(center_x - size / 2)
+    x_max = int(center_x + size / 2)
+    y_min = int(center_y - size / 2)
+    y_max = int(center_y + size / 2)
+
+    top = max(0, -y_min)
+    bottom = max(0, y_max - img.shape[0])
+    left = max(0, -x_min)
+    right = max(0, x_max - img.shape[1])
+    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0)
+
+    cropped_img = img[y_min + top:y_max + top, x_min + left:x_max + left]
+
+    return cropped_img
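
A minimal usage sketch of the new helper, mirroring how the updated scripts call it (the image path is hypothetical; the extractor is the repo's MediaPipe-based LMKExtractor, and crop_face returns None when no face is detected):

import cv2
from src.utils.mp_utils import LMKExtractor
from src.utils.crop_face_single import crop_face

lmk_extractor = LMKExtractor()
img = cv2.imread("reference.png")           # BGR image, as OpenCV loads it
cropped = crop_face(img, lmk_extractor)     # square crop around the detected landmarks, expanded 1.5x
if cropped is None:
    print("no face detected")               # the caller is expected to handle this case
else:
    cropped = cv2.resize(cropped, (512, 512))
    cv2.imwrite("reference_cropped.png", cropped)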
src/vid2vid.py CHANGED
@@ -1,4 +1,3 @@
-import argparse
 import os
 import shutil
 import ffmpeg
@@ -8,88 +7,89 @@ import numpy as np
 import cv2
 import torch
 import spaces
-from diffusers import AutoencoderKL, DDIMScheduler
-from einops import repeat
-from omegaconf import OmegaConf
+# from diffusers import AutoencoderKL, DDIMScheduler
+# from einops import repeat
+# from omegaconf import OmegaConf
 from PIL import Image
 from torchvision import transforms
-from transformers import CLIPVisionModelWithProjection
+# from transformers import CLIPVisionModelWithProjection
 
-from src.models.pose_guider import PoseGuider
-from src.models.unet_2d_condition import UNet2DConditionModel
-from src.models.unet_3d import UNet3DConditionModel
-from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
+# from src.models.pose_guider import PoseGuider
+# from src.models.unet_2d_condition import UNet2DConditionModel
+# from src.models.unet_3d import UNet3DConditionModel
+# from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
 from src.utils.util import get_fps, read_frames, save_videos_grid
 
-from src.utils.mp_utils import LMKExtractor
-from src.utils.draw_util import FaceMeshVisualizer
+# from src.utils.mp_utils import LMKExtractor
+# from src.utils.draw_util import FaceMeshVisualizer
 from src.utils.pose_util import project_points_with_trans, matrix_to_euler_and_translation, euler_and_translation_to_matrix
 from src.audio2vid import smooth_pose_seq
-
+from src.utils.crop_face_single import crop_face
+from src.create_modules import lmk_extractor, vis, pipe
 
 @spaces.GPU
 def video2video(ref_img, source_video, size=512, steps=25, length=150, seed=42):
     cfg = 3.5
 
-    config = OmegaConf.load('./configs/prompts/animation_facereenac.yaml')
+    # config = OmegaConf.load('./configs/prompts/animation_facereenac.yaml')
 
-    if config.weight_dtype == "fp16":
-        weight_dtype = torch.float16
-    else:
-        weight_dtype = torch.float32
+    # if config.weight_dtype == "fp16":
+    #     weight_dtype = torch.float16
+    # else:
+    #     weight_dtype = torch.float32
 
-    vae = AutoencoderKL.from_pretrained(
-        config.pretrained_vae_path,
-    ).to("cuda", dtype=weight_dtype)
+    # vae = AutoencoderKL.from_pretrained(
+    #     config.pretrained_vae_path,
+    # ).to("cuda", dtype=weight_dtype)
 
-    reference_unet = UNet2DConditionModel.from_pretrained(
-        config.pretrained_base_model_path,
-        subfolder="unet",
-    ).to(dtype=weight_dtype, device="cuda")
+    # reference_unet = UNet2DConditionModel.from_pretrained(
+    #     config.pretrained_base_model_path,
+    #     subfolder="unet",
+    # ).to(dtype=weight_dtype, device="cuda")
 
-    inference_config_path = config.inference_config
-    infer_config = OmegaConf.load(inference_config_path)
-    denoising_unet = UNet3DConditionModel.from_pretrained_2d(
-        config.pretrained_base_model_path,
-        config.motion_module_path,
-        subfolder="unet",
-        unet_additional_kwargs=infer_config.unet_additional_kwargs,
-    ).to(dtype=weight_dtype, device="cuda")
+    # inference_config_path = config.inference_config
+    # infer_config = OmegaConf.load(inference_config_path)
+    # denoising_unet = UNet3DConditionModel.from_pretrained_2d(
+    #     config.pretrained_base_model_path,
+    #     config.motion_module_path,
+    #     subfolder="unet",
+    #     unet_additional_kwargs=infer_config.unet_additional_kwargs,
+    # ).to(dtype=weight_dtype, device="cuda")
 
-    pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype) # not use cross attention
+    # pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype) # not use cross attention
 
-    image_enc = CLIPVisionModelWithProjection.from_pretrained(
-        config.image_encoder_path
-    ).to(dtype=weight_dtype, device="cuda")
+    # image_enc = CLIPVisionModelWithProjection.from_pretrained(
+    #     config.image_encoder_path
+    # ).to(dtype=weight_dtype, device="cuda")
 
-    sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
-    scheduler = DDIMScheduler(**sched_kwargs)
+    # sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
+    # scheduler = DDIMScheduler(**sched_kwargs)
 
     generator = torch.manual_seed(seed)
 
     width, height = size, size
 
-    # load pretrained weights
-    denoising_unet.load_state_dict(
-        torch.load(config.denoising_unet_path, map_location="cpu"),
-        strict=False,
-    )
-    reference_unet.load_state_dict(
-        torch.load(config.reference_unet_path, map_location="cpu"),
-    )
-    pose_guider.load_state_dict(
-        torch.load(config.pose_guider_path, map_location="cpu"),
-    )
-
-    pipe = Pose2VideoPipeline(
-        vae=vae,
-        image_encoder=image_enc,
-        reference_unet=reference_unet,
-        denoising_unet=denoising_unet,
-        pose_guider=pose_guider,
-        scheduler=scheduler,
-    )
-    pipe = pipe.to("cuda", dtype=weight_dtype)
+    # # load pretrained weights
+    # denoising_unet.load_state_dict(
+    #     torch.load(config.denoising_unet_path, map_location="cpu"),
+    #     strict=False,
+    # )
+    # reference_unet.load_state_dict(
+    #     torch.load(config.reference_unet_path, map_location="cpu"),
+    # )
+    # pose_guider.load_state_dict(
+    #     torch.load(config.pose_guider_path, map_location="cpu"),
+    # )
+
+    # pipe = Pose2VideoPipeline(
+    #     vae=vae,
+    #     image_encoder=image_enc,
+    #     reference_unet=reference_unet,
+    #     denoising_unet=denoising_unet,
+    #     pose_guider=pose_guider,
+    #     scheduler=scheduler,
+    # )
+    # pipe = pipe.to("cuda", dtype=weight_dtype)
 
     date_str = datetime.now().strftime("%Y%m%d")
     time_str = datetime.now().strftime("%H%M")
@@ -99,24 +99,25 @@ def video2video(ref_img, source_video, size=512, steps=25, length=150, seed=42):
     save_dir.mkdir(exist_ok=True, parents=True)
 
 
-    lmk_extractor = LMKExtractor()
-    vis = FaceMeshVisualizer(forehead_edge=False)
+    # lmk_extractor = LMKExtractor()
+    # vis = FaceMeshVisualizer(forehead_edge=False)
 
 
 
     ref_image_np = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)
-    # TODO: face detection + cropping
+    ref_image_np = crop_face(ref_image_np, lmk_extractor)
+    if ref_image_np is None:
+        return None, Image.fromarray(ref_img)
+
     ref_image_np = cv2.resize(ref_image_np, (size, size))
     ref_image_pil = Image.fromarray(cv2.cvtColor(ref_image_np, cv2.COLOR_BGR2RGB))
 
     face_result = lmk_extractor(ref_image_np)
     if face_result is None:
-        return None
+        return None, ref_image_pil
 
     lmks = face_result['lmks'].astype(np.float32)
     ref_pose = vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True)
-
-
 
     source_images = read_frames(source_video)
     src_fps = get_fps(source_video)
@@ -134,6 +135,7 @@ def video2video(ref_img, source_video, size=512, steps=25, length=150, seed=42):
     bs_list = []
     src_tensor_list = []
     args_L = len(source_images) if length==0 or length*step > len(source_images) else length*step
+    args_L = min(args_L, 300*step)
     for src_image_pil in source_images[: args_L: step]:
         src_tensor_list.append(pose_transform(src_image_pil))
         src_img_np = cv2.cvtColor(np.array(src_image_pil), cv2.COLOR_RGB2BGR)
@@ -209,7 +211,7 @@ def video2video(ref_img, source_video, size=512, steps=25, length=150, seed=42):
     # merge audio and video
     stream = ffmpeg.input(save_path)
     audio = ffmpeg.input(audio_output)
-    ffmpeg.output(stream.video, audio.audio, save_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac').run()
+    ffmpeg.output(stream.video, audio.audio, save_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac', shortest=None).run()
 
     os.remove(save_path)
     os.remove(audio_output)
@@ -219,4 +221,4 @@ def video2video(ref_img, source_video, size=512, steps=25, length=150, seed=42):
         save_path.replace('_noaudio.mp4', '.mp4')
     )
 
-    return save_path.replace('_noaudio.mp4', '.mp4')
+    return save_path.replace('_noaudio.mp4', '.mp4'), ref_image_pil
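
Both src/audio2vid.py and src/vid2vid.py now import lmk_extractor, vis, a2m_model, and pipe from src.create_modules, which is not part of this changeset. A minimal sketch of what that module presumably provides, reconstructed from the per-call initialization commented out above; the module layout and the choice of config file are assumptions, while the individual calls are taken from that commented code:

# src/create_modules.py -- hypothetical reconstruction, not included in this commit
import torch
from diffusers import AutoencoderKL, DDIMScheduler
from omegaconf import OmegaConf
from transformers import CLIPVisionModelWithProjection

from src.audio_models.model import Audio2MeshModel
from src.models.pose_guider import PoseGuider
from src.models.unet_2d_condition import UNet2DConditionModel
from src.models.unet_3d import UNet3DConditionModel
from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
from src.utils.mp_utils import LMKExtractor
from src.utils.draw_util import FaceMeshVisualizer

config = OmegaConf.load('./configs/prompts/animation_audio.yaml')  # assumed shared config
infer_config = OmegaConf.load(config.inference_config)
weight_dtype = torch.float16 if config.weight_dtype == "fp16" else torch.float32

# Shared helpers, created once at import time instead of on every request.
lmk_extractor = LMKExtractor()
vis = FaceMeshVisualizer(forehead_edge=False)

audio_infer_config = OmegaConf.load(config.audio_inference_config)
a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt']), strict=False)
a2m_model.cuda().eval()

vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path).to("cuda", dtype=weight_dtype)
reference_unet = UNet2DConditionModel.from_pretrained(
    config.pretrained_base_model_path, subfolder="unet",
).to(dtype=weight_dtype, device="cuda")
denoising_unet = UNet3DConditionModel.from_pretrained_2d(
    config.pretrained_base_model_path,
    config.motion_module_path,
    subfolder="unet",
    unet_additional_kwargs=infer_config.unet_additional_kwargs,
).to(dtype=weight_dtype, device="cuda")
pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype)
image_enc = CLIPVisionModelWithProjection.from_pretrained(config.image_encoder_path).to(dtype=weight_dtype, device="cuda")
scheduler = DDIMScheduler(**OmegaConf.to_container(infer_config.noise_scheduler_kwargs))

denoising_unet.load_state_dict(torch.load(config.denoising_unet_path, map_location="cpu"), strict=False)
reference_unet.load_state_dict(torch.load(config.reference_unet_path, map_location="cpu"))
pose_guider.load_state_dict(torch.load(config.pose_guider_path, map_location="cpu"))

pipe = Pose2VideoPipeline(
    vae=vae,
    image_encoder=image_enc,
    reference_unet=reference_unet,
    denoising_unet=denoising_unet,
    pose_guider=pose_guider,
    scheduler=scheduler,
).to("cuda", dtype=weight_dtype)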