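# Gradio demo: animate a still source image with the motion of a driving
# video, using the pretrained image-animation model wrapped by model.py.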
import gradio as gr
import torch
import imageio
import imageio_ffmpeg  # provides the ffmpeg backend imageio uses to read/write mp4
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from skimage.transform import resize
import warnings
import os
from model import load_checkpoints
from model import make_animation
from skimage import img_as_ubyte
from PIL import Image
import time

warnings.filterwarnings("ignore")
# Use the GPU when one is available; fall back to CPU otherwise.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
dataset_name = 'vox'  # one of ['vox', 'taichi', 'ted', 'mgif']
source_image_path = './assets/source.png'
driving_video_path = './assets/driving.mp4'
output_video_path = './generated.mp4'
config_path = './config/vox-256.yaml'
checkpoint_path = './checkpoints/vox.pth.tar'
predict_mode = 'relative'  # one of ['standard', 'relative', 'avd']
find_best_frame = False  # in 'relative' mode, find_best_frame=True usually gives better results when animating a face
pixel = 256  # the vox, taichi and mgif checkpoints use a 256x256 resolution
if dataset_name == 'ted':  # the ted checkpoint uses 384x384
    pixel = 384
if find_best_frame:
    # requires the face_alignment package (pip install face_alignment)
    pass
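
# Run the full animation pipeline for one request: read the uploaded image and
# reference video (keyed by timestamp tt), animate, and write the output mp4.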
def create_video(tt):
    source_image = imageio.imread(f"assets/img_{tt}.jpg")
    reader = imageio.get_reader(f"assets/ref_{tt}.mp4")
    source_image = resize(source_image, (pixel, pixel))[..., :3]
    fps = reader.get_meta_data()['fps']
    driving_video = []
    try:
        # collect all frames; some readers raise RuntimeError at the end of the stream
        for im in reader:
            driving_video.append(im)
    except RuntimeError:
        pass
    reader.close()
    driving_video = [resize(frame, (pixel, pixel))[..., :3] for frame in driving_video]

    def display(source, driving, generated=None):
        fig = plt.figure(figsize=(8 + 4 * (generated is not None), 6))
        ims = []
        for i in range(len(driving)):
            cols = [source]
            cols.append(driving[i])
            if generated is not None:
                cols.append(generated[i])
            im = plt.imshow(np.concatenate(cols, axis=1), animated=True)
            plt.axis('off')
            ims.append([im])
        ani = animation.ArtistAnimation(fig, ims, interval=50, repeat_delay=1000)
        plt.close()
        return ani
    # HTML(display(source_image, driving_video).to_html5_video())  # notebook-only preview; unused in the app
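
    # Load the pretrained networks, then animate. In 'relative' mode with
    # find_best_frame enabled, the driving video is split at the frame that
    # best matches the source pose, animated in both directions, and the two
    # halves are stitched back together.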
    inpainting, kp_detector, dense_motion_network, avd_network = load_checkpoints(
        config_path=config_path, checkpoint_path=checkpoint_path, device=device)

    if predict_mode == 'relative' and find_best_frame:
        from model import find_best_frame as _find
        i = _find(source_image, driving_video, device.type == 'cpu')
        print("Best frame:", i)
        driving_forward = driving_video[i:]
        driving_backward = driving_video[:(i + 1)][::-1]
        predictions_forward = make_animation(source_image, driving_forward, inpainting, kp_detector,
                                             dense_motion_network, avd_network, device=device, mode=predict_mode)
        predictions_backward = make_animation(source_image, driving_backward, inpainting, kp_detector,
                                              dense_motion_network, avd_network, device=device, mode=predict_mode)
        predictions = predictions_backward[::-1] + predictions_forward[1:]
    else:
        predictions = make_animation(source_image, driving_video, inpainting, kp_detector,
                                     dense_motion_network, avd_network, device=device, mode=predict_mode)

    # save the resulting video
    imageio.mimsave(f"./assets/output_{tt}.mp4", [img_as_ubyte(frame) for frame in predictions], fps=fps)
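
# Gradio handler: stash the uploaded image and video under a unique
# timestamp key, run the pipeline, and return the generated video's path.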
def greet(img, video):
    tt = str(time.time())
    os.replace(video, f"assets/ref_{tt}.mp4")
    img.save(f"assets/img_{tt}.jpg")
    create_video(tt)
    return f"./assets/output_{tt}.mp4"

# gr.inputs/gr.outputs are deprecated; use the top-level components, and note
# the output component must be a Video output, not an input.
iface = gr.Interface(fn=greet, inputs=[gr.Image(type="pil"), gr.Video()], outputs=gr.Video())
iface.launch()