# LivePortrait / stf_utils.py
# Author: yerang
# Last commit: f55c77d "Update stf_utils.py" (verified)
# Size: 7.06 kB
import torch
import os
from concurrent.futures import ThreadPoolExecutor
from pydub import AudioSegment
import cv2
from pathlib import Path
import subprocess
from pathlib import Path
import av
import imageio
import numpy as np
from rich.progress import track
from tqdm import tqdm
import stf_alternative
import spaces
def exec_cmd(cmd):
    """Run ``cmd`` through the shell and wait for it to finish.

    Output is captured rather than shown: stderr is folded into stdout and
    both go to a pipe whose contents are not returned. A non-zero exit
    status raises ``subprocess.CalledProcessError``; otherwise the function
    returns ``None``.
    """
    run_options = {
        "shell": True,
        "check": True,
        "stdout": subprocess.PIPE,
        "stderr": subprocess.STDOUT,
    }
    subprocess.run(cmd, **run_options)
def images2video(images, wfp, **kwargs):
    """Encode a sequence of frames into a video file at ``wfp`` via imageio.

    Recognised keyword arguments (defaults in parentheses):
      - ``fps`` (24)
      - ``format`` ("mp4")
      - ``codec`` ("libx264")
      - ``quality`` (None)
      - ``pixelformat`` ("yuv420p")
      - ``image_mode`` ("rgb"): with "bgr" (case-insensitive) the last axis of
        every frame is reversed before writing
      - ``macro_block_size`` (2)
      - ``crf`` (18): passed in ``ffmpeg_params`` as ``-crf <value>``

    The writer is closed even if writing a frame fails, so a partial failure
    does not leak the open writer.
    """
    fps = kwargs.get("fps", 24)
    video_format = kwargs.get("format", "mp4")  # default is mp4 format
    codec = kwargs.get("codec", "libx264")  # default is libx264 encoding
    quality = kwargs.get("quality")  # video quality
    pixelformat = kwargs.get("pixelformat", "yuv420p")  # video pixel format
    # Decide the channel handling once, instead of re-checking it per frame.
    reverse_channels = kwargs.get("image_mode", "rgb").lower() == "bgr"
    macro_block_size = kwargs.get("macro_block_size", 2)
    ffmpeg_params = ["-crf", str(kwargs.get("crf", 18))]

    writer = imageio.get_writer(
        wfp,
        fps=fps,
        format=video_format,
        codec=codec,
        quality=quality,
        ffmpeg_params=ffmpeg_params,
        pixelformat=pixelformat,
        macro_block_size=macro_block_size,
    )
    try:
        for frame in track(images, description="writing", transient=True, total=len(images)):
            writer.append_data(frame[..., ::-1] if reverse_channels else frame)
    finally:
        writer.close()
    print(f"Dump to {wfp}\n")
def merge_audio_video(video_fp, audio_fp, wfp):
    """Mux ``audio_fp`` into ``video_fp`` with ffmpeg and write the result to ``wfp``.

    The video stream is copied unchanged (``-c:v copy``), the audio is encoded
    as AAC (``-c:a aac``), and an existing ``wfp`` is overwritten (``-y``).
    If either input file does not exist, a message is printed and nothing runs.
    """
    import shlex  # function-scope import keeps this helper self-contained

    if os.path.exists(video_fp) and os.path.exists(audio_fp):
        # exec_cmd runs the string through a shell, so quote every path; otherwise
        # spaces or shell metacharacters in a filename break or alter the command.
        video_q, audio_q, out_q = (shlex.quote(str(p)) for p in (video_fp, audio_fp, wfp))
        cmd = f"ffmpeg -i {video_q} -i {audio_q} -c:v copy -c:a aac {out_q} -y"
        exec_cmd(cmd)
        print(f"merge {video_fp} and {audio_fp} to {wfp}")
    else:
        print(f"video_fp: {video_fp} or audio_fp: {audio_fp} not exists!")
class STFPipeline:
    """Generate a video from an audio file using an ``stf_alternative`` model.

    The constructor eagerly builds the model (``load_model``) and then the
    template built around it (``create_template``). ``execute`` renders one
    audio file to ``dubbing/<audio stem>--lip.mp4``.
    """

    def __init__(
        self,
        stf_path: str = "/home/user/app/stf/",
        template_video_path: str = "templates/front_one_piece_dress_nodded_cut.webm",
        config_path: str = "front_config.json",
        checkpoint_path: str = "089.pth",
        root_path: str = "works",
        wavlm_path: str = "microsoft/wavlm-large",
        device: str = "cuda:0"
    ):
        """Resolve paths, then build the model and the template.

        Args:
            stf_path: Base directory; ``config_path``, ``checkpoint_path`` and
                ``root_path`` are joined onto it.
            template_video_path: Template video passed to ``stf_alternative.Template``.
                Unlike the paths above it is used as given (not joined with
                ``stf_path``).
            config_path: Config file name, relative to ``stf_path``.
            checkpoint_path: Checkpoint file name, relative to ``stf_path``.
            root_path: Work directory name, relative to ``stf_path``.
            wavlm_path: Forwarded to ``stf_alternative.create_model`` as ``wavlm_path``.
            device: Forwarded to ``stf_alternative.create_model`` as ``device``.
        """
        self.device = device
        self.stf_path = stf_path
        self.config_path = os.path.join(stf_path, config_path)
        self.checkpoint_path = os.path.join(stf_path, checkpoint_path)
        self.work_root_path = os.path.join(stf_path, root_path)
        self.wavlm_path = wavlm_path
        self.template_video_path = template_video_path
        # Both steps run synchronously, in this order: create_template reads self.model.
        self.model = self.load_model()
        self.template = self.create_template()

    # NOTE(review): presumably requests a ZeroGPU allocation for up to 120 s — confirm.
    @spaces.GPU(duration=120)
    def load_model(self):
        """Create the model from the configured paths, passing ``self.device`` through."""
        model = stf_alternative.create_model(
            config_path=self.config_path,
            checkpoint_path=self.checkpoint_path,
            work_root_path=self.work_root_path,
            device=self.device,
            wavlm_path=self.wavlm_path
        )
        return model

    @spaces.GPU(duration=120)
    def create_template(self):
        """Create the ``stf_alternative.Template`` for ``self.model`` and the template video."""
        template = stf_alternative.Template(
            model=self.model,
            config_path=self.config_path,
            template_video_path=self.template_video_path
        )
        return template

    def execute(self, audio: str) -> str:
        """Generate a video for the ``audio`` file and return the path it was saved to.

        The output is ``dubbing/<audio stem>--lip.mp4``, relative to the current
        working directory. The ``dubbing`` folder is created if missing.
        Generation stops at whichever runs out first: the inference stream for
        the audio or the frames of the template video.
        """
        # Create the output folder.
        Path("dubbing").mkdir(exist_ok=True)
        save_path = os.path.join("dubbing", Path(audio).stem + "--lip.mp4")
        reader = iter(self.template._get_reader(num_skip_frames=0))
        audio_segment = AudioSegment.from_file(audio)
        results = []
        # gen_infer_concurrent is handed a 4-worker thread pool to produce frames.
        with ThreadPoolExecutor(max_workers=4) as executor:
            try:
                gen_infer = self.template.gen_infer_concurrent(
                    executor, audio_segment, 0
                )
                for idx, (it, _) in enumerate(gen_infer):
                    # Raises StopIteration once the template video has no more frames.
                    frame = next(reader)
                    # The composed result is not used below; the call is kept in case
                    # compose has side effects. NOTE(review): confirm that it["pred"],
                    # not the composed frame, is what should be written.
                    self.template.compose(idx, frame, it)
                    results.append(it["pred"])
            except StopIteration:
                # Template video exhausted before the audio: keep the frames gathered so far.
                pass
        self.images_to_video(results, save_path)
        return save_path

    @staticmethod
    def images_to_video(images, output_path, fps=24):
        """Encode a sequence of frames to an mp4 file (libx264) at ``fps``.

        The writer is closed even if appending a frame fails.
        """
        writer = imageio.get_writer(output_path, fps=fps, format="mp4", codec="libx264")
        try:
            for frame in track(images, description="비디오 생성 중", total=len(images)):
                writer.append_data(frame)
        finally:
            writer.close()
        print(f"비디오 저장 완료: {output_path}")
# class STFPipeline:
# def __init__(self,
# stf_path: str = "/home/user/app/stf/",
# device: str = "cuda:0",
# template_video_path: str = "templates/front_one_piece_dress_nodded_cut.webm",
# config_path: str = "front_config.json",
# checkpoint_path: str = "089.pth",
# root_path: str = "works"
# ):
# config_path = os.path.join(stf_path, config_path)
# checkpoint_path = os.path.join(stf_path, checkpoint_path)
# work_root_path = os.path.join(stf_path, root_path)
# model = stf_alternative.create_model(
# config_path=config_path,
# checkpoint_path=checkpoint_path,
# work_root_path=work_root_path,
# device=device,
# wavlm_path="microsoft/wavlm-large",
# )
# self.template = stf_alternative.Template(
# model=model,
# config_path=config_path,
# template_video_path=template_video_path,
# )
# def execute(self, audio: str):
# Path("dubbing").mkdir(exist_ok=True)
# save_path = os.path.join("dubbing", Path(audio).stem+"--lip.mp4")
# reader = iter(self.template._get_reader(num_skip_frames=0))
# audio_segment = AudioSegment.from_file(audio)
# pivot = 0
# results = []
# with ThreadPoolExecutor(4) as p:
# try:
# gen_infer = self.template.gen_infer_concurrent(
# p,
# audio_segment,
# pivot,
# )
# for idx, (it, chunk) in enumerate(gen_infer, pivot):
# frame = next(reader)
# composed = self.template.compose(idx, frame, it)
# frame_name = f"{idx}".zfill(5)+".jpg"
# results.append(it['pred'])
# pivot = idx + 1
# except StopIteration as e:
# pass
# images2video(results, save_path)
# return save_path