"""Consistent initial noise for video generation.

Samples a guided-diffusion talking-face model; when --sampling_seed is set,
a single initial noise tensor is shared across all frames so the generated
video is temporally stable.
"""

import argparse
import os
import random
import shutil
import subprocess
from os.path import basename, join, splitext

import cv2
import numpy as np
import torch
import torchvision
from tqdm import tqdm

import face_detection
from audio import audio

from guided_diffusion import dist_util, logger
from guided_diffusion.script_util import (
    tfg_model_and_diffusion_defaults,
    tfg_create_model_and_diffusion,
    args_to_dict,
    add_dict_to_argparser,
)
from guided_diffusion.tfg_data_util import (
    tfg_process_batch,
)
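# Example invocations (a sketch; the script filename is a placeholder):
#
#   # single video/audio pair
#   python generate.py --model_path checkpoints/model.pt \
#       --video_path face.mp4 --audio_path speech.mp4 --sample_path d2l_gen
#
#   # batch mode driven by a filelist (see generate_from_filelist below)
#   python generate.py --generate_from_filelist True \
#       --test_video_dir test_videos --filelist test_filelist.txt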

def get_frame_id(frame):
    """Extract the integer frame index from a filename like '00012.jpg'."""
    return int(basename(frame).split('.')[0])


def crop_audio_window(spec, start_frame, args):
    """Crop a syncnet_mel_step_size-long window from the mel spectrogram,
    starting at the mel step aligned with `start_frame`."""
    if isinstance(start_frame, int):
        start_frame_num = start_frame
    else:
        start_frame_num = get_frame_id(start_frame)
    start_idx = int(args.mel_steps_per_sec * (start_frame_num / float(args.video_fps)))
    end_idx = start_idx + args.syncnet_mel_step_size
    return spec[start_idx:end_idx, :]
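# Worked example for crop_audio_window: with video_fps=25 and
# mel_steps_per_sec=80, start_frame 10 gives start_idx = int(80 * 10 / 25) = 32,
# so the crop covers mel steps [32, 48) -- 16 steps, i.e. 0.2 s of audio.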

def load_all_indiv_mels(path, args):
    """Extract audio from `path` with ffmpeg and return per-frame mel windows.

    Returns a list of mel windows (one per video frame, offset back by
    syncnet_T // 2 frames) together with the raw waveform.
    """
    in_path = path
    out_dir = join(args.sample_path, "temp", basename(in_path).replace(".mp4", ""))
    os.makedirs(out_dir, exist_ok=True)
    out_path = join(out_dir, "audio.wav")
    command = 'ffmpeg -loglevel error -y -i {} -strict -2 {}'.format(in_path, out_path)
    subprocess.call(command, shell=True)
    wav = audio.load_wav(out_path, args.sample_rate)
    orig_mel = audio.melspectrogram(wav).T

    all_indiv_mels = []

    # Collect one mel window per frame until the window runs off the end of
    # the spectrogram.
    i = 1
    while True:
        m = crop_audio_window(orig_mel.copy(), max(i - args.syncnet_T // 2, 0), args)
        if m.shape[0] != args.syncnet_mel_step_size:
            break
        all_indiv_mels.append(m.T)
        i += 1

    shutil.rmtree(join(args.sample_path, "temp"))

    return all_indiv_mels, wav

def load_video_frames(path, args):
    """Decode `path` into a list of RGB frames at args.video_fps via ffmpeg."""
    in_path = path
    out_dir = join(args.sample_path, "temp", basename(in_path).replace(".mp4", ""), "image")
    os.makedirs(out_dir, exist_ok=True)

    command = "ffmpeg -loglevel error -y -i {} -vf fps={} -q:v 2 -qmin 1 {}/%05d.jpg".format(
        in_path, args.video_fps, out_dir)
    subprocess.call(command, shell=True)

    video_frames = []
    for img_name in sorted(os.listdir(out_dir)):
        img_path = join(out_dir, img_name)
        img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
        video_frames.append(img)

    shutil.rmtree(join(args.sample_path, "temp"))

    return video_frames

def get_smoothened_boxes(boxes, T):
    """Smooth face boxes with a moving average over windows of T frames to
    reduce detection jitter."""
    for i in range(len(boxes)):
        if i + T > len(boxes):
            window = boxes[len(boxes) - T:]
        else:
            window = boxes[i:i + T]
        boxes[i] = np.mean(window, axis=0)
    return boxes


def my_voxceleb2_crop(img):
    """Crop a VoxCeleb2 frame to the face region using fixed proportions."""
    return img[:-int(img.shape[0] * 2.36 / 8), int(img.shape[1] * 1.8 / 8):-int(img.shape[1] * 1.8 / 8)]


def my_voxceleb2_crop_bboxs(img):
    """Return the (y1, y2, x1, x2) bounds of my_voxceleb2_crop."""
    return 0, img.shape[0] - int(img.shape[0] * 2.36 / 8), int(img.shape[1] * 1.8 / 8), img.shape[1] - int(img.shape[1] * 1.8 / 8)
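# Example of the fixed-proportion crop: for a 224x224 VoxCeleb2 frame,
# int(224 * 2.36 / 8) = 66 rows are cut from the bottom and
# int(224 * 1.8 / 8) = 50 columns from each side, leaving a 158x124 crop.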

def face_detect(images, detector, args, resize=False):
    """Run batched face detection, halving the batch size on CUDA OOM, and
    return one [face_crop, bbox, True] triple per frame."""
    batch_size = args.face_det_batch_size

    while True:
        predictions = []
        try:
            for i in range(0, len(images), batch_size):
                predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
        except RuntimeError:
            if batch_size == 1:
                raise RuntimeError('Image too big to run face detection on GPU')
            batch_size //= 2
            args.face_det_batch_size = batch_size
            print('Recovering from OOM error; new batch size: {}'.format(batch_size))
            continue
        break

    results = []
    if isinstance(args.pads, str):
        args.pads = [int(x) for x in args.pads.split(",")]
    pady1, pady2, padx1, padx2 = args.pads
    for rect, image in zip(predictions, images):
        if rect is None:
            raise ValueError('Face not detected!')

        # Pad the detected box, clamped to the image bounds.
        y1 = max(0, rect[1] - pady1)
        y2 = min(image.shape[0], rect[3] + pady2)
        x1 = max(0, rect[0] - padx1)
        x2 = min(image.shape[1], rect[2] + padx2)

        results.append([x1, y1, x2, y2])

    boxes = get_smoothened_boxes(np.array(results), T=5)

    if resize:
        if args.is_voxceleb2:
            results = [[cv2.resize(my_voxceleb2_crop(image), (args.image_size, args.image_size)), my_voxceleb2_crop_bboxs(image), True]
                       for image, (x1, y1, x2, y2) in zip(images, boxes)]
        else:
            results = [[cv2.resize(image[y1:y2, x1:x2], (args.image_size, args.image_size)), (y1, y2, x1, x2), True]
                       for image, (x1, y1, x2, y2) in zip(images, boxes)]
    else:
        results = [[image[y1:y2, x1:x2], (y1, y2, x1, x2), True]
                   for image, (x1, y1, x2, y2) in zip(images, boxes)]
    return results

def normalise(tensor):
    """Map [-1, 1] to [0, 1], clamping to the target range."""
    return ((tensor + 1) * 0.5).clamp(0, 1)


def normalise2(tensor):
    """Map [0, 1] to [-1, 1], clamping to the target range."""
    return (tensor * 2 - 1).clamp(-1, 1)

def sample_batch(batch, model, diffusion, args):
    """Sample one batch of frames from the diffusion model."""
    B, F, C, H, W = batch['image'].shape
    sample_shape = (B * F, C, H, W)

    # "Consistent initial noise": draw a single noise image under the given
    # seed and repeat it across all B*F frames, so every frame starts its
    # reverse diffusion from the same point. The global RNG state is saved
    # and restored so the seeding does not perturb the rest of the pipeline.
    init_noise = None
    if args.sampling_seed:
        state = torch.get_rng_state()
        torch.manual_seed(args.sampling_seed)
        torch.cuda.manual_seed_all(args.sampling_seed)
        init_noise = torch.randn((1, C, H, W))
        init_noise = init_noise.repeat(B * F, 1, 1, 1)
        torch.set_rng_state(state)

    img_batch, model_kwargs = tfg_process_batch(batch, args.face_hide_percentage,
                                                use_ref=args.use_ref,
                                                use_audio=args.use_audio,
                                                noise=init_noise)

    img_batch = img_batch.to(dist_util.dev())
    model_kwargs = {k: v.to(dist_util.dev()) for k, v in model_kwargs.items()}
    init_noise = init_noise.to(dist_util.dev()) if init_noise is not None else None

    sample_fn = (
        diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
    )
    sample = sample_fn(
        model,
        sample_shape,
        clip_denoised=args.clip_denoised,
        model_kwargs=model_kwargs,
        noise=init_noise,
    )
    return sample, img_batch, model_kwargs
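# Sketch of the shared-noise trick used in sample_batch (illustrative, not
# executed): one Gaussian draw is repeated across all B*F frames, so every
# frame starts the reverse process from the same point; with noise=None the
# sampler would instead draw independent noise per frame.
#
#   shared = torch.randn(1, C, H, W)            # one draw ...
#   init_noise = shared.repeat(B * F, 1, 1, 1)  # ... shared by all frames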

def generate(video_path, audio_path, model, diffusion, detector, args, out_path=None, save_orig=True):
    video_frames = load_video_frames(video_path, args)
    try:
        face_det_results = face_detect(video_frames.copy(), detector, args, resize=True)
    except Exception as e:
        print("Error:", e, video_path, audio_path)
        import traceback
        print(traceback.format_exc())
        return  # face detection failed; nothing to generate
    wrong_all_indiv_mels, wrong_audio_wavform = load_all_indiv_mels(audio_path, args)

    # Align streams to the shorter of video and audio, then collect face
    # crops, bounding boxes and mel windows as tensors.
    min_frames = min(len(video_frames), len(wrong_all_indiv_mels))
    video_frames = video_frames[:min_frames]
    face_det_results = face_det_results[:min_frames]
    face_bboxes = [face_det_results[i][1] for i in range(min_frames)]
    face_frames = torch.FloatTensor(np.transpose(
        np.asarray([face_det_results[i][0] for i in range(min_frames)], dtype=np.float32) / 255.,
        (0, 3, 1, 2)))
    wrong_all_indiv_mels = torch.FloatTensor(np.asarray(wrong_all_indiv_mels[:min_frames])).unsqueeze(1)
    if save_orig:
        if out_path is None:
            out_path_orig = os.path.join(args.sample_path, splitext(basename(video_path))[0] + "_" + splitext(basename(audio_path))[0] + "_orig.mp4")
        else:
            out_path_orig = out_path.replace(".mp4", "_orig.mp4")
        torchvision.io.write_video(
            out_path_orig,
            video_array=torch.from_numpy(np.array(video_frames)), fps=args.video_fps, video_codec='libx264',
            audio_array=torch.from_numpy(wrong_audio_wavform).unsqueeze(0), audio_fps=args.sample_rate, audio_codec='aac',
        )
    # Choose the reference frames that condition the model.
    if args.sampling_ref_type == 'gt':
        ref_frames = face_frames.clone()
    elif args.sampling_ref_type == 'first_frame':
        ref_frames = face_frames[0:1].repeat(len(face_frames), 1, 1, 1)
    elif args.sampling_ref_type == 'random':
        rand_idx = random.Random(args.sampling_seed).randint(0, len(face_frames) - 1)
        ref_frames = face_frames[rand_idx:rand_idx + 1].repeat(len(face_frames), 1, 1, 1)

    # Optionally animate a single still frame instead of the full video.
    if args.sampling_input_type == 'first_frame':
        face_frames = face_frames[0:1].repeat(len(face_frames), 1, 1, 1)
        video_frames = np.array(video_frames[0:1] * len(video_frames))
        face_bboxes = np.array(face_bboxes[0:1] * len(face_bboxes))
    generated_video_frames = []
    b_s = args.sampling_batch_size
    for i in range(0, min_frames, b_s * args.nframes):
        video_frames_batch = video_frames[i:i + b_s * args.nframes]
        face_bboxes_batch = face_bboxes[i:i + b_s * args.nframes]

        try:
            img_batch = face_frames[i:i + b_s * args.nframes]
            img_batch = img_batch.reshape(-1, args.nframes, img_batch.size(-3), img_batch.size(-2), img_batch.size(-1))
            ref_batch = ref_frames[i:i + b_s * args.nframes]
            ref_batch = ref_batch.reshape(-1, args.nframes, ref_batch.size(-3), ref_batch.size(-2), ref_batch.size(-1))
            wrong_indiv_mel_batch = wrong_all_indiv_mels[i:i + b_s * args.nframes]
            wrong_indiv_mel_batch = wrong_indiv_mel_batch.reshape(-1, args.nframes, wrong_indiv_mel_batch.size(-3), wrong_indiv_mel_batch.size(-2), wrong_indiv_mel_batch.size(-1))
        except RuntimeError:
            # The tail of the video may not divide evenly into clips of
            # args.nframes; fall back to one frame per clip for this batch.
            img_batch = face_frames[i:i + b_s * args.nframes]
            img_batch = img_batch.reshape(-1, 1, img_batch.size(-3), img_batch.size(-2), img_batch.size(-1))
            ref_batch = ref_frames[i:i + b_s * args.nframes]
            ref_batch = ref_batch.reshape(-1, 1, ref_batch.size(-3), ref_batch.size(-2), ref_batch.size(-1))
            wrong_indiv_mel_batch = wrong_all_indiv_mels[i:i + b_s * args.nframes]
            wrong_indiv_mel_batch = wrong_indiv_mel_batch.reshape(-1, 1, wrong_indiv_mel_batch.size(-3), wrong_indiv_mel_batch.size(-2), wrong_indiv_mel_batch.size(-1))

        batch = {"image": img_batch,
                 "ref_img": ref_batch,
                 "indiv_mels": wrong_indiv_mel_batch}

        sample, img_batch, model_kwargs = sample_batch(batch, model, diffusion, args)

        # Composite the generated (masked) region back onto the input face.
        mask = model_kwargs['mask']
        recon_batch = sample * mask + (1. - mask) * img_batch
        recon_batch = (normalise(recon_batch) * 255).cpu().numpy().transpose(0, 2, 3, 1)

        # Paste each generated face crop back into its full frame.
        for g, v, b in zip(recon_batch, video_frames_batch, face_bboxes_batch):
            y1, y2, x1, x2 = b
            g = cv2.resize(g.astype(np.uint8), (x2 - x1, y2 - y1))
            v[y1:y2, x1:x2] = g
            generated_video_frames.append(v)
    # Trim the waveform to the generated video's duration, then mux.
    print(wrong_audio_wavform.shape, np.array(generated_video_frames).shape)
    min_time = len(generated_video_frames) / args.video_fps
    wrong_audio_wavform = wrong_audio_wavform[:int(min_time * args.sample_rate)]
    print(wrong_audio_wavform.shape, np.array(generated_video_frames).shape)
    if out_path is None:
        out_path = os.path.join(args.sample_path, splitext(basename(video_path))[0] + "_" + splitext(basename(audio_path))[0] + ".mp4")
    torchvision.io.write_video(
        out_path,
        video_array=torch.from_numpy(np.array(generated_video_frames)), fps=args.video_fps, video_codec='libx264',
        audio_array=torch.from_numpy(wrong_audio_wavform).unsqueeze(0), audio_fps=args.sample_rate, audio_codec='aac',
    )

def generate_from_filelist(test_video_dir, filelist, model, diffusion, detector, args):
    """Generate one video per filelist line; each line holds
    "audio_name video_name" (relative to test_video_dir, no .mp4 extension)."""
    with open(filelist, "r") as f:
        lines = f.readlines()
    for line in tqdm(lines):
        try:
            audio_name, video_name = line.strip().split()
            audio_path = join(test_video_dir, audio_name + '.mp4')
            video_path = join(test_video_dir, video_name + '.mp4')
            out_path = join(args.sample_path, audio_name.replace('/', '.') + "_" + video_name.replace('/', '.') + ".mp4")
            generate(video_path, audio_path, model, diffusion, detector, args, out_path=out_path, save_orig=args.save_orig)
        except Exception as e:
            print("Error:", e, line.strip())
            import traceback
            print(traceback.format_exc())

def main():
    args = create_argparser().parse_args()
    dist_util.setup_dist()
    logger.configure(dir=args.sample_path, format_strs=["stdout", "log"])

    logger.log("creating model...")
    model, diffusion = tfg_create_model_and_diffusion(
        **args_to_dict(args, tfg_model_and_diffusion_defaults().keys())
    )
    logger.log("loading checkpoint: {}".format(args.model_path))
    model.load_state_dict(
        dist_util.load_state_dict(args.model_path, map_location='cpu')
    )
    model.to(dist_util.dev())
    if args.use_fp16:
        model.convert_to_fp16()
    model.eval()

    detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False,
                                            device='cuda' if torch.cuda.is_available() else 'cpu')

    if args.generate_from_filelist:
        generate_from_filelist(args.test_video_dir, args.filelist, model, diffusion, detector, args)
    else:
        generate(args.video_path, args.audio_path, model, diffusion, detector, args, out_path=args.out_path, save_orig=args.save_orig)

def create_argparser():
    defaults = dict(
        # Single-pair generation.
        generate_from_filelist=False,
        video_path="",
        audio_path="",
        out_path=None,
        save_orig=True,

        # Batch generation from a filelist.
        test_video_dir="test_videos",
        filelist="test_filelist.txt",

        use_fp16=True,

        # Conditioning.
        face_hide_percentage=0.5,
        use_ref=False,
        use_audio=False,
        audio_as_style=False,
        audio_as_style_encoder_mlp=False,

        # Clip and audio geometry.
        nframes=1,
        nrefer=0,
        image_size=128,
        syncnet_T=5,
        syncnet_mel_step_size=16,
        audio_frames_per_video=16,
        audio_dim=80,
        is_voxceleb2=True,

        video_fps=25,
        sample_rate=16000,
        mel_steps_per_sec=80.,

        # Sampling.
        clip_denoised=True,
        sampling_batch_size=2,
        use_ddim=False,
        model_path="",
        sample_path="d2l_gen",
        sample_partition="",
        sampling_seed=None,
        sampling_use_gt_for_ref=False,
        sampling_ref_type='gt',    # 'gt' | 'first_frame' | 'random'
        sampling_input_type='gt',  # 'gt' | 'first_frame'

        # Face detection.
        face_det_batch_size=64,
        pads="0,0,0,0",
    )
    defaults.update(tfg_model_and_diffusion_defaults())
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)
    return parser


if __name__ == "__main__":
    main()