import gradio as gr
import torch
import imageio
import imageio_ffmpeg
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from skimage.transform import resize
import warnings
import os
from model import load_checkpoints
from model import make_animation
from skimage import img_as_ubyte
from PIL import Image
import time

warnings.filterwarnings("ignore")

device = torch.device('cuda:0')
# device = torch.device('cpu')

dataset_name = 'vox'  # one of ['vox', 'taichi', 'ted', 'mgif']
source_image_path = './assets/source.png'
driving_video_path = './assets/driving.mp4'
output_video_path = './generated.mp4'
config_path = './config/vox-256.yaml'
checkpoint_path = './checkpoints/vox.pth.tar'
predict_mode = 'relative'  # one of ['standard', 'relative', 'avd']
find_best_frame = False  # when animating a face in relative mode, find_best_frame=True usually gives better results

pixel = 256  # vox, taichi and mgif use a 256x256 resolution
if dataset_name == 'ted':  # ted uses 384x384
    pixel = 384

if find_best_frame:
    # requires the face_alignment package (pip install face_alignment)
    pass


def create_video(tt):
    # Animate the uploaded source image with the uploaded driving video and
    # write the result to ./assets/output_{tt}.mp4.
    source_image = imageio.imread(f"assets/img_{tt}.jpg")
    reader = imageio.get_reader(f"assets/ref_{tt}.mp4")

    source_image = resize(source_image, (pixel, pixel))[..., :3]

    fps = reader.get_meta_data()['fps']
    driving_video = []
    try:
        for im in reader:
            driving_video.append(im)
    except RuntimeError:
        pass
    reader.close()

    driving_video = [resize(frame, (pixel, pixel))[..., :3] for frame in driving_video]

    def display(source, driving, generated=None):
        # Build a side-by-side matplotlib animation of source, driving and
        # (optionally) generated frames; only used for notebook-style preview.
        fig = plt.figure(figsize=(8 + 4 * (generated is not None), 6))
        ims = []
        for i in range(len(driving)):
            cols = [source]
            cols.append(driving[i])
            if generated is not None:
                cols.append(generated[i])
            im = plt.imshow(np.concatenate(cols, axis=1), animated=True)
            plt.axis('off')
            ims.append([im])
        ani = animation.ArtistAnimation(fig, ims, interval=50, repeat_delay=1000)
        plt.close()
        return ani

    # HTML(display(source_image, driving_video).to_html5_video())

    inpainting, kp_detector, dense_motion_network, avd_network = load_checkpoints(
        config_path=config_path, checkpoint_path=checkpoint_path, device=device)

    if predict_mode == 'relative' and find_best_frame:
        from model import find_best_frame as _find
        i = _find(source_image, driving_video, device.type == 'cpu')
        print("Best frame: " + str(i))
        driving_forward = driving_video[i:]
        driving_backward = driving_video[:(i + 1)][::-1]
        predictions_forward = make_animation(source_image, driving_forward, inpainting, kp_detector,
                                             dense_motion_network, avd_network, device=device, mode=predict_mode)
        predictions_backward = make_animation(source_image, driving_backward, inpainting, kp_detector,
                                              dense_motion_network, avd_network, device=device, mode=predict_mode)
        predictions = predictions_backward[::-1] + predictions_forward[1:]
    else:
        predictions = make_animation(source_image, driving_video, inpainting, kp_detector,
                                     dense_motion_network, avd_network, device=device, mode=predict_mode)

    # save the resulting video
    imageio.mimsave(f"./assets/output_{tt}.mp4", [img_as_ubyte(frame) for frame in predictions], fps=fps)


def greet(img, video):
    # Gradio callback: stash the uploaded image and video under a timestamped
    # name, run the animation pipeline, and return the generated video path.
    tt = str(time.time())
    os.replace(video, f"assets/ref_{tt}.mp4")
    img.save(f"assets/img_{tt}.jpg")
    create_video(tt)
    return f"./assets/output_{tt}.mp4"


# Legacy Gradio (pre-4.0) component namespaces; the output component must come
# from gr.outputs, not gr.inputs.
iface = gr.Interface(fn=greet,
                     inputs=[gr.inputs.Image(type="pil"), gr.inputs.Video()],
                     outputs=gr.outputs.Video())
iface.launch()
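
# --- Optional headless usage (a minimal sketch, not part of the demo UI) ---
# This mirrors what greet() does when a user submits files through Gradio,
# assuming a portrait at assets/my_face.jpg and a driving clip at
# assets/my_clip.mp4; both paths are hypothetical examples. Uncomment to run
# the pipeline directly without launching the interface.
#
# import shutil
# tt = str(time.time())
# shutil.copy("assets/my_clip.mp4", f"assets/ref_{tt}.mp4")
# Image.open("assets/my_face.jpg").convert("RGB").save(f"assets/img_{tt}.jpg")
# create_video(tt)
# print(f"generated ./assets/output_{tt}.mp4")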