VictorMais committed
Commit e355904 · verified · 1 parent: 9098160

Added the ZeroGPU processing

Files changed (1)
  1. webgui.py +250 -293
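The commit message is terse, so a quick orientation helps: on a ZeroGPU Space, a GPU is attached to the process only while a function decorated with `@spaces.GPU` is running, so model loading and inference have to be reachable from such a function instead of running unconditionally at import time. A minimal sketch of the pattern this diff applies, with illustrative names only (the real functions are `initialize_models` and `process_video` further down):

```python
# Minimal ZeroGPU sketch; names here are illustrative, not the Space's code.
import spaces
import torch

model = None  # filled lazily, once a GPU is actually attached

@spaces.GPU(duration=120)  # borrow a GPU for up to ~120 s per call
def run_inference(batch: torch.Tensor) -> torch.Tensor:
    global model
    if model is None:
        # placeholder loader; the real Space builds its diffusion pipeline here
        model = torch.nn.Linear(8, 8).to("cuda")
    with torch.no_grad():
        return model(batch.to("cuda")).cpu()
```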
webgui.py CHANGED
@@ -3,11 +3,8 @@
 '''
 webui
 '''
- import spaces
- import os

- os.system('pip install scikit-image')
- os.system('pip install IPython')
 import random
 from datetime import datetime
 from pathlib import Path
@@ -29,22 +26,17 @@ from facenet_pytorch import MTCNN
 import argparse

 import gradio as gr
-
- import huggingface_hub
-
- import pickle
- from src.utils.draw_utils import FaceMeshVisualizer
- from src.utils.motion_utils import motion_sync
- from src.utils.mp_utils import LMKExtractor
-

 huggingface_hub.snapshot_download(
 repo_id='BadToBest/EchoMimic',
- local_dir='./pretrained_weights',
- local_dir_use_symlinks=False,
 )

- is_shared_ui = True if "fffiloni/EchoMimic" in os.environ['SPACE_ID'] else False
 available_property = False if is_shared_ui else True
 advanced_settings_label = "Advanced Configuration (only for duplicated spaces)" if is_shared_ui else "Advanced Configuration"
 
@@ -67,7 +59,7 @@ default_values = {
 ffmpeg_path = os.getenv('FFMPEG_PATH')
 if ffmpeg_path is None:
 print("please download ffmpeg-static and export to FFMPEG_PATH. \nFor example: export FFMPEG_PATH=/musetalk/ffmpeg-4.4-amd64-static")
- elif ffmpeg_path not in os.getenv('PATH'):
 print("add ffmpeg to path")
 os.environ["PATH"] = f"{ffmpeg_path}:{os.environ['PATH']}"
 
@@ -86,64 +78,91 @@ if not torch.cuda.is_available():
 inference_config_path = config.inference_config
 infer_config = OmegaConf.load(inference_config_path)

- ############# model_init started #############
- ## vae init
- vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path).to("cuda", dtype=weight_dtype)
-
- ## reference net init
- reference_unet = UNet2DConditionModel.from_pretrained(
- config.pretrained_base_model_path,
- subfolder="unet",
- ).to(dtype=weight_dtype, device=device)
- reference_unet.load_state_dict(torch.load(config.reference_unet_path, map_location="cpu"))
-
- ## denoising net init
- if os.path.exists(config.motion_module_path):
- ### stage1 + stage2
- denoising_unet = EchoUNet3DConditionModel.from_pretrained_2d(
- config.pretrained_base_model_path,
- config.motion_module_path,
- subfolder="unet",
- unet_additional_kwargs=infer_config.unet_additional_kwargs,
- ).to(dtype=weight_dtype, device=device)
- else:
- ### only stage1
- denoising_unet = EchoUNet3DConditionModel.from_pretrained_2d(
 config.pretrained_base_model_path,
- "",
 subfolder="unet",
- unet_additional_kwargs={
- "use_motion_module": False,
- "unet_use_temporal_attention": False,
- "cross_attention_dim": infer_config.unet_additional_kwargs.cross_attention_dim
- }
 ).to(dtype=weight_dtype, device=device)
-
- denoising_unet.load_state_dict(torch.load(config.denoising_unet_path, map_location="cpu"), strict=False)
-
- ## face locator init
- face_locator = FaceLocator(320, conditioning_channels=1, block_out_channels=(16, 32, 96, 256)).to(dtype=weight_dtype, device="cuda")
- face_locator.load_state_dict(torch.load(config.face_locator_path, map_location='cpu'))
-
- ## load audio processor params
- audio_processor = load_audio_model(model_path=config.audio_model_path, device=device)
-
- ## load face detector params
- face_detector = MTCNN(image_size=320, margin=0, min_face_size=20, thresholds=[0.6, 0.7, 0.7], factor=0.709, post_process=True, device=device)
-
- ############# model_init finished #############
-
- sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
- scheduler = DDIMScheduler(**sched_kwargs)
-
- pipe = Audio2VideoPipeline(
- vae=vae,
- reference_unet=reference_unet,
- denoising_unet=denoising_unet,
- audio_guider=audio_processor,
- face_locator=face_locator,
- scheduler=scheduler,
- ).to("cuda", dtype=weight_dtype)

 def select_face(det_bboxes, probs):
 ## max face from faces that the prob is above 0.8
@@ -159,25 +178,58 @@ def select_face(det_bboxes, probs):
 sorted_bboxes = sorted(filtered_bboxes, key=lambda x:(x[3]-x[1]) * (x[2] - x[0]), reverse=True)
 return sorted_bboxes[0]

- lmk_extractor = LMKExtractor()

- def face_detection(uploaded_img, facemask_dilation_ratio, facecrop_dilation_ratio, width, height):
 face_img = cv2.imread(uploaded_img)
- if face_img is None:
- raise gr.Error("input image should be uploaded or selected.")
 face_mask = np.zeros((face_img.shape[0], face_img.shape[1])).astype('uint8')
 det_bboxes, probs = face_detector.detect(face_img)
 select_bbox = select_face(det_bboxes, probs)
 if select_bbox is None:
 face_mask[:, :] = 255
 else:
 xyxy = select_bbox[:4]
 xyxy = np.round(xyxy).astype('int')
 rb, re, cb, ce = xyxy[1], xyxy[3], xyxy[0], xyxy[2]
 r_pad = int((re - rb) * facemask_dilation_ratio)
 c_pad = int((ce - cb) * facemask_dilation_ratio)
 face_mask[rb - r_pad : re + r_pad, cb - c_pad : ce + c_pad] = 255
-
 r_pad_crop = int((re - rb) * facecrop_dilation_ratio)
 c_pad_crop = int((ce - cb) * facecrop_dilation_ratio)
 crop_rect = [max(0, cb - c_pad_crop), max(0, rb - r_pad_crop), min(ce + c_pad_crop, face_img.shape[1]), min(re + r_pad_crop, face_img.shape[0])]
@@ -185,15 +237,10 @@ def face_detection(uploaded_img, facemask_dilation_ratio, facecrop_dilation_rati
 face_mask = crop_and_pad(face_mask, crop_rect)
 face_img = cv2.resize(face_img, (width, height))
 face_mask = cv2.resize(face_mask, (width, height))
-
- print('face detect done.')
- return face_img, face_mask

- @spaces.GPU(duration=300)
- def video_pipe(face_img, face_mask, uploaded_audio, width, height, length, context_frames, context_overlap, cfg, steps, sample_rate, fps, device):
- face_mask_tensor = torch.Tensor(face_mask).to(dtype=weight_dtype, device="cuda").unsqueeze(0).unsqueeze(0).unsqueeze(0) / 255.0
 ref_image_pil = Image.fromarray(face_img[:, :, [2, 1, 0]])
-
 video = pipe(
 ref_image_pil,
 uploaded_audio,
@@ -203,12 +250,12 @@ def video_pipe(face_img, face_mask, uploaded_audio, width, height, length, conte
 length,
 steps,
 cfg,
 audio_sample_rate=sample_rate,
 context_frames=context_frames,
 fps=fps,
 context_overlap=context_overlap
 ).videos
- print('video pipe done.')

 save_dir = Path("output/tmp")
 save_dir.mkdir(exist_ok=True, parents=True)
@@ -223,107 +270,27 @@ def video_pipe(face_img, face_mask, uploaded_audio, width, height, length, conte

 return final_output_path

- def process_video(uploaded_img, uploaded_audio, width, height, length, facemask_dilation_ratio, facecrop_dilation_ratio, context_frames, context_overlap, cfg, steps, sample_rate, fps, device):
- face_img, face_mask = face_detection(uploaded_img, facemask_dilation_ratio, facecrop_dilation_ratio, width, height)
- final_output_path = video_pipe(face_img, face_mask, uploaded_audio, width, height, length, context_frames, context_overlap, cfg, steps, sample_rate, fps, device)
- return final_output_path
-
-
- # @spaces.GPU
- # def process_video(uploaded_img, uploaded_audio, width, height, length, facemask_dilation_ratio, facecrop_dilation_ratio, context_frames, context_overlap, cfg, steps, sample_rate, fps, device):
- # #### face musk prepare
- # face_img = cv2.imread(uploaded_img)
- # if face_img is None:
- # raise gr.Error("input image should be uploaded or selected.")
- # face_mask = np.zeros((face_img.shape[0], face_img.shape[1])).astype('uint8')
- # det_bboxes, probs = face_detector.detect(face_img)
- # select_bbox = select_face(det_bboxes, probs)
- # if select_bbox is None:
- # face_mask[:, :] = 255
- # else:
- # xyxy = select_bbox[:4]
- # xyxy = np.round(xyxy).astype('int')
- # rb, re, cb, ce = xyxy[1], xyxy[3], xyxy[0], xyxy[2]
- # r_pad = int((re - rb) * facemask_dilation_ratio)
- # c_pad = int((ce - cb) * facemask_dilation_ratio)
- # face_mask[rb - r_pad : re + r_pad, cb - c_pad : ce + c_pad] = 255
-
- # #### face crop
- # r_pad_crop = int((re - rb) * facecrop_dilation_ratio)
- # c_pad_crop = int((ce - cb) * facecrop_dilation_ratio)
- # crop_rect = [max(0, cb - c_pad_crop), max(0, rb - r_pad_crop), min(ce + c_pad_crop, face_img.shape[1]), min(re + r_pad_crop, face_img.shape[0])]
- # face_img = crop_and_pad(face_img, crop_rect)
- # face_mask = crop_and_pad(face_mask, crop_rect)
- # face_img = cv2.resize(face_img, (width, height))
- # face_mask = cv2.resize(face_mask, (width, height))
- # print('face detect done.')
- # # ==================== face_locator =====================
- # '''
- # driver_video = "./assets/driven_videos/c.mp4"
-
- # input_frames_cv2 = [cv2.resize(center_crop_cv2(pil_to_cv2(i)), (512, 512)) for i in pils_from_video(driver_video)]
- # ref_det = lmk_extractor(face_img)
-
- # visualizer = FaceMeshVisualizer(draw_iris=False, draw_mouse=False)

- # pose_list = []
- # sequence_driver_det = []
- # try:
- # for frame in input_frames_cv2:
- # result = lmk_extractor(frame)
- # assert result is not None, "{}, bad video, face not detected".format(driver_video)
- # sequence_driver_det.append(result)
- # except:
- # print("face detection failed")
- # exit()
-
- # sequence_det_ms = motion_sync(sequence_driver_det, ref_det)
- # for p in sequence_det_ms:
- # tgt_musk = visualizer.draw_landmarks((width, height), p)
- # tgt_musk_pil = Image.fromarray(np.array(tgt_musk).astype(np.uint8)).convert('RGB')
- # pose_list.append(torch.Tensor(np.array(tgt_musk_pil)).to(dtype=weight_dtype, device="cuda").permute(2,0,1) / 255.0)
- # '''
- # # face_mask_tensor = torch.stack(pose_list, dim=1).unsqueeze(0)
- # face_mask_tensor = torch.Tensor(face_mask).to(dtype=weight_dtype, device="cuda").unsqueeze(0).unsqueeze(0).unsqueeze(0) / 255.0
-
- # ref_image_pil = Image.fromarray(face_img[:, :, [2, 1, 0]])
-
- # #del pose_list, sequence_det_ms, sequence_driver_det, input_frames_cv2
-
- # video = pipe(
- # ref_image_pil,
- # uploaded_audio,
- # face_mask_tensor,
- # width,
- # height,
- # length,
- # steps,
- # cfg,
- # #generator=generator,
- # audio_sample_rate=sample_rate,
- # context_frames=context_frames,
- # fps=fps,
- # context_overlap=context_overlap
- # ).videos
- # print('video pipe done.')
-
- # save_dir = Path("output/tmp")
- # save_dir.mkdir(exist_ok=True, parents=True)
- # output_video_path = save_dir / "output_video.mp4"
- # save_videos_grid(video, str(output_video_path), n_rows=1, fps=fps)
-
- # video_clip = VideoFileClip(str(output_video_path))
- # audio_clip = AudioFileClip(uploaded_audio)
- # final_output_path = save_dir / "output_video_with_audio.mp4"
- # video_clip = video_clip.set_audio(audio_clip)
- # video_clip.write_videofile(str(final_output_path), codec="libx264", audio_codec="aac")
-
- # return final_output_path

 with gr.Blocks() as demo:
 gr.Markdown('# EchoMimic')
 gr.Markdown('## Lifelike Audio-Driven Portrait Animations through Editable Landmark Conditioning')
- gr.Markdown('Inference time: from ~7mins/240frames to ~50s/240frames on V100 GPU')
 gr.HTML("""
 <div style="display:flex;column-gap:4px;">
 <a href='https://badtobest.github.io/echomimic.html'><img src='https://img.shields.io/badge/Project-Page-blue'></a>
@@ -331,12 +298,24 @@ with gr.Blocks() as demo:
 <a href='https://arxiv.org/abs/2407.08136'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a>
 </div>
 """)
-
 with gr.Row():
- with gr.Column(min_width=250):
 uploaded_img = gr.Image(type="filepath", label="Reference Image")
- with gr.Column(min_width=250):
- uploaded_audio = gr.Audio(type="filepath", label="Input Audio")
 with gr.Accordion(label=advanced_settings_label, open=False):
 with gr.Row():
 width = gr.Slider(label="Width", minimum=128, maximum=1024, value=default_values["width"], interactive=available_property)
@@ -357,132 +336,110 @@ with gr.Blocks() as demo:
 sample_rate = gr.Slider(label="Sample Rate", minimum=8000, maximum=48000, step=1000, value=default_values["sample_rate"], interactive=available_property)
 fps = gr.Slider(label="FPS", minimum=1, maximum=60, step=1, value=default_values["fps"], interactive=available_property)
 device = gr.Radio(label="Device", choices=["cuda", "cpu"], value=default_values["device"], interactive=available_property)
-
- with gr.Column(min_width=250):
 generate_button = gr.Button("Generate Video")
 output_video = gr.Video()
- with gr.Row():
-
- gr.Examples(
- label = "Portrait examples",
- examples = [
- ['assets/test_imgs/a.png'],
- ['assets/test_imgs/b.png'],
- ['assets/test_imgs/c.png'],
- ['assets/test_imgs/d.png'],
- ['assets/test_imgs/e.png']
- ],
- inputs = [uploaded_img]
- )
- gr.Examples(
- label = "Audio examples",
- examples = [
- ['assets/test_audios/chunnuanhuakai.wav'],
- ['assets/test_audios/chunwang.wav'],
- ['assets/test_audios/echomimic_en_girl.wav'],
- ['assets/test_audios/echomimic_en.wav'],
- ['assets/test_audios/echomimic_girl.wav'],
- ['assets/test_audios/echomimic.wav'],
- ['assets/test_audios/jane.wav'],
- ['assets/test_audios/mei.wav'],
- ['assets/test_audios/walden.wav'],
- ['assets/test_audios/yun.wav'],
- ],
- inputs = [uploaded_audio]
- )
- # gr.HTML("""
- # <div style="display:flex;column-gap:4px;">
- # <a href="https://huggingface.co/spaces/fffiloni/EchoMimic?duplicate=true">
- # <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-xl.svg" alt="Duplicate this Space">
- # </a>
- # <a href="https://huggingface.co/fffiloni">
- # <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-xl-dark.svg" alt="Follow me on HF">
- # </a>
- # </div>
- # """)
 
- # def generate_video(uploaded_img, uploaded_audio, facemask_dilation_ratio=default_values["facemask_dilation_ratio"],
- # facecrop_dilation_ratio=default_values["facecrop_dilation_ratio"],
- # context_frames=default_values["context_frames"],
- # context_overlap=default_values["context_overlap"],
- # cfg=default_values["cfg"],
- # steps=default_values["steps"],
- # sample_rate=default_values["sample_rate"],
- # fps=default_values["fps"],
- # device=default_values["device"],
- # width=default_values["width"],
- # height=default_values["height"],
- # length=default_values["length"] ):
-
- # final_output_path = process_video(
- # uploaded_img, uploaded_audio, width, height, length, seed, facemask_dilation_ratio, facecrop_dilation_ratio, context_frames, context_overlap, cfg, steps, sample_rate, fps, device
- # )
- # output_video= final_output_path
- # return final_output_path
-
- # generate_button.click(
- # generate_video,
- # inputs=[
- # uploaded_img,
- # uploaded_audio,
- # # width,
- # # height,
- # # length,
- # # seed,
- # # facemask_dilation_ratio,
- # # facecrop_dilation_ratio,
- # # context_frames,
- # # context_overlap,
- # # cfg,
- # # steps,
- # # sample_rate,
- # # fps,
- # # device
- # ],
- # outputs=output_video,
- # show_api=False
- # )
- def generate_video(uploaded_img, uploaded_audio,
- facemask_dilation_ratio=default_values["facemask_dilation_ratio"],
- facecrop_dilation_ratio=default_values["facecrop_dilation_ratio"],
- context_frames=default_values["context_frames"],
- context_overlap=default_values["context_overlap"],
- cfg=default_values["cfg"],
- steps=default_values["steps"],
- sample_rate=default_values["sample_rate"],
- fps=default_values["fps"],
- device=default_values["device"],
- width=default_values["width"],
- height=default_values["height"],
- length=default_values["length"] ):
-
 final_output_path = process_video(
- uploaded_img,
- uploaded_audio, width, height,
- length, facemask_dilation_ratio,
- facecrop_dilation_ratio, context_frames,
- context_overlap, cfg, steps,
- sample_rate, fps, device
 )
- output_video = final_output_path
 return final_output_path

 generate_button.click(
 generate_video,
 inputs=[
 uploaded_img,
- uploaded_audio
 ],
 outputs=output_video,
- show_progress=True
 )
- parser = argparse.ArgumentParser(description='EchoMimic')
 parser.add_argument('--server_name', type=str, default='0.0.0.0', help='Server name')
 parser.add_argument('--server_port', type=int, default=7680, help='Server port')
 args = parser.parse_args()

- # demo.launch(server_name=args.server_name, server_port=args.server_port, inbrowser=True)
-
 if __name__ == '__main__':
- demo.queue(max_size=3).launch(show_api=False, show_error=True)
 #demo.launch(server_name=args.server_name, server_port=args.server_port, inbrowser=True)
 
 '''
 webui
 '''

+ import os
 import random
 from datetime import datetime
 from pathlib import Path
 
 import argparse

 import gradio as gr
+ from gradio_client import Client, handle_file
+ from pydub import AudioSegment
+ import huggingface_hub
+ import spaces # Import spaces module for ZeroGPU support

 huggingface_hub.snapshot_download(
 repo_id='BadToBest/EchoMimic',
+ local_dir='./pretrained_weights'
 )

+ is_shared_ui = True if "fffiloni/EchoMimic" in os.environ.get('SPACE_ID', '') else False
 available_property = False if is_shared_ui else True
 advanced_settings_label = "Advanced Configuration (only for duplicated spaces)" if is_shared_ui else "Advanced Configuration"

 ffmpeg_path = os.getenv('FFMPEG_PATH')
 if ffmpeg_path is None:
 print("please download ffmpeg-static and export to FFMPEG_PATH. \nFor example: export FFMPEG_PATH=/musetalk/ffmpeg-4.4-amd64-static")
+ elif ffmpeg_path not in os.getenv('PATH', ''):
 print("add ffmpeg to path")
 os.environ["PATH"] = f"{ffmpeg_path}:{os.environ['PATH']}"
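The second half of the page lists the new side of webgui.py. The import block now pulls in `gradio_client`, `pydub`, and `spaces`, and the FFMPEG_PATH check guards `os.getenv('PATH', '')` with a default so a missing PATH variable no longer raises. Since moviepy and pydub both shell out to an ffmpeg binary, a small availability check can turn a late, cryptic muxing error into an early warning; a sketch under that assumption (the helper is ours, not part of the commit):

```python
import os
import shutil

def ffmpeg_available() -> bool:
    """True if an ffmpeg binary is reachable via PATH or FFMPEG_PATH."""
    if shutil.which("ffmpeg"):
        return True
    ffmpeg_dir = os.getenv("FFMPEG_PATH", "")
    # FFMPEG_PATH is expected to point at the directory holding the binary
    return os.path.isfile(os.path.join(ffmpeg_dir, "ffmpeg"))

if not ffmpeg_available():
    print("warning: ffmpeg not found; pydub/moviepy export will fail")
```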
 
 
 inference_config_path = config.inference_config
 infer_config = OmegaConf.load(inference_config_path)

+ # Model initialization is performed on-demand with ZeroGPU
+
+ # Function to initialize models when needed
+ @spaces.GPU
+ def initialize_models():
+ global vae, reference_unet, denoising_unet, face_locator, audio_processor, face_detector, pipe
+
+ ## vae init
+ vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path).to(device, dtype=weight_dtype)
+
+ ## reference net init
+ reference_unet = UNet2DConditionModel.from_pretrained(
 config.pretrained_base_model_path,
 subfolder="unet",
 ).to(dtype=weight_dtype, device=device)
+ reference_unet.load_state_dict(torch.load(config.reference_unet_path, map_location="cpu"))
+
+ ## denoising net init
+ if os.path.exists(config.motion_module_path):
+ ### stage1 + stage2
+ denoising_unet = EchoUNet3DConditionModel.from_pretrained_2d(
+ config.pretrained_base_model_path,
+ config.motion_module_path,
+ subfolder="unet",
+ unet_additional_kwargs=infer_config.unet_additional_kwargs,
+ ).to(dtype=weight_dtype, device=device)
+ else:
+ ### only stage1
+ denoising_unet = EchoUNet3DConditionModel.from_pretrained_2d(
+ config.pretrained_base_model_path,
+ "",
+ subfolder="unet",
+ unet_additional_kwargs={
+ "use_motion_module": False,
+ "unet_use_temporal_attention": False,
+ "cross_attention_dim": infer_config.unet_additional_kwargs.cross_attention_dim
+ }
+ ).to(dtype=weight_dtype, device=device)
+
+ denoising_unet.load_state_dict(torch.load(config.denoising_unet_path, map_location="cpu"), strict=False)
+
+ ## face locator init
+ face_locator = FaceLocator(320, conditioning_channels=1, block_out_channels=(16, 32, 96, 256)).to(dtype=weight_dtype, device=device)
+ face_locator.load_state_dict(torch.load(config.face_locator_path))
+
+ ## load audio processor params
+ audio_processor = load_audio_model(model_path=config.audio_model_path, device=device)
+
+ ## load face detector params
+ face_detector = MTCNN(image_size=320, margin=0, min_face_size=20, thresholds=[0.6, 0.7, 0.7], factor=0.709, post_process=True, device=device)
+
+ sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
+ scheduler = DDIMScheduler(**sched_kwargs)
+
+ pipe = Audio2VideoPipeline(
+ vae=vae,
+ reference_unet=reference_unet,
+ denoising_unet=denoising_unet,
+ audio_guider=audio_processor,
+ face_locator=face_locator,
+ scheduler=scheduler,
+ ).to(device, dtype=weight_dtype)
+
+ # Global variables for models
+ vae = None
+ reference_unet = None
+ denoising_unet = None
+ face_locator = None
+ audio_processor = None
+ face_detector = None
+ pipe = None
+
+ def ensure_png(image_path):
+ # Load the image with Pillow
+ with Image.open(image_path) as img:
+ # Check if the image is already a PNG
+ if img.format != "PNG":
+ # Convert and save as PNG
+ png_path = os.path.splitext(image_path)[0] + ".png"
+ img.save(png_path, format="PNG")
+ print(f"Image converted to PNG and saved as {png_path}")
+ return png_path
+ else:
+ print("Image is already a PNG.")
+ return image_path
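Everything that used to run at import time on "cuda" now lives inside `initialize_models()`, which is decorated with `@spaces.GPU` and writes into module-level globals; the globals start as `None` and act as a cache, and `process_video()` below only calls the initializer when `vae is None`. The same guard can be written as a single accessor, sketched here as a condensed illustration (not the code the commit ships):

```python
# Condensed illustration of the lazy-initialisation guard used above.
_pipe = None

def get_pipe():
    """Build the heavy pipeline on first use, then reuse the cached object."""
    global _pipe
    if _pipe is None:
        _pipe = build_pipeline()  # stands in for initialize_models()
    return _pipe
```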
 
 def select_face(det_bboxes, probs):
 ## max face from faces that the prob is above 0.8

 sorted_bboxes = sorted(filtered_bboxes, key=lambda x:(x[3]-x[1]) * (x[2] - x[0]), reverse=True)
 return sorted_bboxes[0]

+ @spaces.GPU(duration=120) # Allow up to 2 minutes for video processing (maximum allowed)
+ def process_video(uploaded_img, uploaded_audio, width, height, length, seed, facemask_dilation_ratio, facecrop_dilation_ratio, context_frames, context_overlap, cfg, steps, sample_rate, fps, device):
+ # Ensure models are initialized
+ if vae is None:
+ initialize_models()

+ if seed is not None and seed > -1:
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.manual_seed(random.randint(100, 1000000))
+
+ uploaded_img = ensure_png(uploaded_img)
+
+ #### face mask prepare
 face_img = cv2.imread(uploaded_img)
+
+ # Get the original dimensions
+ original_height, original_width = face_img.shape[:2]
+
+ # Set the new width to 512 pixels
+ new_width = 512
+
+ # Calculate the new height with the same aspect ratio
+ new_height = int(original_height * (new_width / original_width))
+
+ # Ensure both width and height are divisible by 8
+ new_width = (new_width // 8) * 8 # Force target width to be divisible by 8
+ new_height = (new_height // 8) * 8 # Floor the height to the nearest multiple of 8
+
+ # Resize the image to the calculated dimensions
+ face_img = cv2.resize(face_img, (new_width, new_height))
+
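Before detection, the reference image is rescaled to a 512-pixel width and both sides are floored to a multiple of 8, presumably because the VAE/UNet stack downsamples by factors of 8. A tiny worked example of that rounding, standalone and only for illustration:

```python
def snap_down(value: int, multiple: int = 8) -> int:
    """Floor `value` to the nearest lower multiple of `multiple`."""
    return (value // multiple) * multiple

# A 1280x853 portrait scaled to width 512:
#   new_height = int(853 * (512 / 1280)) = 341
#   snap_down(341) -> 336 and snap_down(512) -> 512
assert snap_down(341) == 336
assert snap_down(512) == 512
```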
 face_mask = np.zeros((face_img.shape[0], face_img.shape[1])).astype('uint8')
 det_bboxes, probs = face_detector.detect(face_img)
 select_bbox = select_face(det_bboxes, probs)
 if select_bbox is None:
+ print("SELECT_BBOX IS NONE")
 face_mask[:, :] = 255
+ face_img = cv2.resize(face_img, (width, height))
+ face_mask = cv2.resize(face_mask, (width, height))
+ raise gr.Error("Face Detector could not detect a face in your image. Try with a 512 squared image where the face is clearly visible.")
 else:
+ print("SELECT_BBOX IS NOT NONE")
 xyxy = select_bbox[:4]
 xyxy = np.round(xyxy).astype('int')
 rb, re, cb, ce = xyxy[1], xyxy[3], xyxy[0], xyxy[2]
 r_pad = int((re - rb) * facemask_dilation_ratio)
 c_pad = int((ce - cb) * facemask_dilation_ratio)
 face_mask[rb - r_pad : re + r_pad, cb - c_pad : ce + c_pad] = 255
+
+ #### face crop
 r_pad_crop = int((re - rb) * facecrop_dilation_ratio)
 c_pad_crop = int((ce - cb) * facecrop_dilation_ratio)
 crop_rect = [max(0, cb - c_pad_crop), max(0, rb - r_pad_crop), min(ce + c_pad_crop, face_img.shape[1]), min(re + r_pad_crop, face_img.shape[0])]

 face_mask = crop_and_pad(face_mask, crop_rect)
 face_img = cv2.resize(face_img, (width, height))
 face_mask = cv2.resize(face_mask, (width, height))

 ref_image_pil = Image.fromarray(face_img[:, :, [2, 1, 0]])
+ face_mask_tensor = torch.Tensor(face_mask).to(dtype=weight_dtype, device=device).unsqueeze(0).unsqueeze(0).unsqueeze(0) / 255.0
+
 video = pipe(
 ref_image_pil,
 uploaded_audio,

 length,
 steps,
 cfg,
+ generator=generator,
 audio_sample_rate=sample_rate,
 context_frames=context_frames,
 fps=fps,
 context_overlap=context_overlap
 ).videos

 save_dir = Path("output/tmp")
 save_dir.mkdir(exist_ok=True, parents=True)

 return final_output_path

+ @spaces.GPU(duration=60) # Allow 1 minute for voice cloning
+ def get_maskGCT_TTS(prompt_audio_maskGCT, audio_to_clone):
+ try:
+ client = Client("amphion/maskgct")
+ except:
+ raise gr.Error(f"amphion/maskgct space's api might not be ready, please wait, or upload an audio instead.")

+ result = client.predict(
+ prompt_wav = handle_file(audio_to_clone),
+ target_text = prompt_audio_maskGCT,
+ target_len=-1,
+ n_timesteps=25,
+ api_name="/predict"
+ )
+ print(result)
+ return result, gr.update(value=result, visible=True)
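The `get_maskGCT_TTS` helper above drives the remote amphion/maskgct Space through `gradio_client`; the keyword names (`prompt_wav`, `target_text`, `target_len`, `n_timesteps`) have to match whatever `/predict` currently exposes, and hosted Spaces do change their signatures. When in doubt, the client can describe the endpoint before anything is hardcoded; a small inspection-only sketch:

```python
from gradio_client import Client

# Print the remote Space's named endpoints and their parameters
# before relying on specific keyword arguments.
client = Client("amphion/maskgct")
client.view_api()
```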
 
 with gr.Blocks() as demo:
 gr.Markdown('# EchoMimic')
 gr.Markdown('## Lifelike Audio-Driven Portrait Animations through Editable Landmark Conditioning')
+ gr.Markdown('Running on Spaces ZeroGPU: Dynamic GPU allocation for optimal resource usage')
 gr.HTML("""
 <div style="display:flex;column-gap:4px;">
 <a href='https://badtobest.github.io/echomimic.html'><img src='https://img.shields.io/badge/Project-Page-blue'></a>

 <a href='https://arxiv.org/abs/2407.08136'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a>
 </div>
 """)
 with gr.Row():
+ with gr.Column():
 uploaded_img = gr.Image(type="filepath", label="Reference Image")
+ uploaded_audio = gr.Audio(type="filepath", label="Input Audio", format="wav")
+ preprocess_audio_file = gr.File(visible=False)
+ with gr.Accordion(label="Voice cloning with MaskGCT", open=False):
+ prompt_audio_maskGCT = gr.Textbox(
+ label = "Text to synthetize",
+ lines = 2,
+ max_lines = 2,
+ elem_id = "text-synth-maskGCT"
+ )
+ audio_to_clone_maskGCT = gr.Audio(
+ label = "Voice to clone",
+ type = "filepath",
+ elem_id = "audio-clone-elm-maskGCT"
+ )
+ gen_maskGCT_voice_btn = gr.Button("Generate voice clone (optional)")
 with gr.Accordion(label=advanced_settings_label, open=False):
 with gr.Row():
 width = gr.Slider(label="Width", minimum=128, maximum=1024, value=default_values["width"], interactive=available_property)

 sample_rate = gr.Slider(label="Sample Rate", minimum=8000, maximum=48000, step=1000, value=default_values["sample_rate"], interactive=available_property)
 fps = gr.Slider(label="FPS", minimum=1, maximum=60, step=1, value=default_values["fps"], interactive=available_property)
 device = gr.Radio(label="Device", choices=["cuda", "cpu"], value=default_values["device"], interactive=available_property)
 generate_button = gr.Button("Generate Video")
+ with gr.Column():
 output_video = gr.Video()
+ gr.Examples(
+ label = "Portrait examples",
+ examples = [
+ ['assets/test_imgs/a.png'],
+ ['assets/test_imgs/b.png'],
+ ['assets/test_imgs/c.png'],
+ ['assets/test_imgs/d.png'],
+ ['assets/test_imgs/e.png']
+ ],
+ inputs = [uploaded_img]
+ )
+ gr.Examples(
+ label = "Audio examples",
+ examples = [
+ ['assets/test_audios/chunnuanhuakai.wav'],
+ ['assets/test_audios/chunwang.wav'],
+ ['assets/test_audios/echomimic_en_girl.wav'],
+ ['assets/test_audios/echomimic_en.wav'],
+ ['assets/test_audios/echomimic_girl.wav'],
+ ['assets/test_audios/echomimic.wav'],
+ ['assets/test_audios/jane.wav'],
+ ['assets/test_audios/mei.wav'],
+ ['assets/test_audios/walden.wav'],
+ ['assets/test_audios/yun.wav'],
+ ],
+ inputs = [uploaded_audio]
+ )
+ gr.HTML("""
+ <div style="display:flex;column-gap:4px;">
+ <a href="https://huggingface.co/spaces/fffiloni/EchoMimic?duplicate=true">
+ <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-xl.svg" alt="Duplicate this Space">
+ </a>
+ <a href="https://huggingface.co/fffiloni">
+ <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-xl-dark.svg" alt="Follow me on HF">
+ </a>
+ </div>
+ """)
+
+ def trim_audio(file_path, output_path, max_duration=10):
+ # Load the audio file
+ audio = AudioSegment.from_wav(file_path)

+ # Convert max duration to milliseconds
+ max_duration_ms = max_duration * 1000
+
+ # Trim the audio if it's longer than max_duration
+ if len(audio) > max_duration_ms:
+ audio = audio[:max_duration_ms]
+
+ # Export the trimmed audio
+ audio.export(output_path, format="wav")
+ print(f"Audio trimmed and saved as {output_path}")
+ return output_path
+
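`trim_audio()` reads the clip with `AudioSegment.from_wav`, which is sufficient here because the `gr.Audio` input is pinned to `format="wav"`. If other containers ever had to be accepted, pydub's general entry point is `AudioSegment.from_file`, which decodes through ffmpeg; a two-line sketch with a hypothetical filename:

```python
from pydub import AudioSegment

# Decode any ffmpeg-supported container, then keep the first 10 seconds.
clip = AudioSegment.from_file("voice_sample.m4a")[:10 * 1000]
clip.export("voice_sample_trimmed.wav", format="wav")
```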
+ def generate_video(uploaded_img, uploaded_audio, width, height, length, seed, facemask_dilation_ratio, facecrop_dilation_ratio, context_frames, context_overlap, cfg, steps, sample_rate, fps, device, progress=gr.Progress(track_tqdm=True)):
+ # First, check and trim audio if needed
+ if is_shared_ui:
+ gr.Info("Trimming audio to max 10 seconds. Duplicate the space for unlimited audio length.")
+ uploaded_audio = trim_audio(uploaded_audio, "trimmed_audio.wav")
+
+ # Process the video with ZeroGPU support
 final_output_path = process_video(
+ uploaded_img, uploaded_audio, width, height, length, seed, facemask_dilation_ratio, facecrop_dilation_ratio, context_frames, context_overlap, cfg, steps, sample_rate, fps, device
 )
 return final_output_path

+ gen_maskGCT_voice_btn.click(
+ fn = get_maskGCT_TTS,
+ inputs = [prompt_audio_maskGCT, audio_to_clone_maskGCT],
+ outputs = [uploaded_audio, preprocess_audio_file],
+ queue = False,
+ show_api = False
+ )
+
 generate_button.click(
 generate_video,
 inputs=[
 uploaded_img,
+ uploaded_audio,
+ width,
+ height,
+ length,
+ seed,
+ facemask_dilation_ratio,
+ facecrop_dilation_ratio,
+ context_frames,
+ context_overlap,
+ cfg,
+ steps,
+ sample_rate,
+ fps,
+ device
 ],
 outputs=output_video,
+ show_api=False
 )
+ parser = argparse.ArgumentParser(description='EchoMimic with ZeroGPU Support')
 parser.add_argument('--server_name', type=str, default='0.0.0.0', help='Server name')
 parser.add_argument('--server_port', type=int, default=7680, help='Server port')
 args = parser.parse_args()

 if __name__ == '__main__':
+ demo.queue(max_size=3).launch(show_api=False, show_error=True, ssr_mode=False)
 #demo.launch(server_name=args.server_name, server_port=args.server_port, inbrowser=True)