Spaces:
Build error
Build error
import os | |
from tqdm import tqdm | |
import torch | |
from torch.utils.data import DataLoader | |
from frames_dataset import PairedDataset | |
from logger import Logger, Visualizer | |
import imageio | |
from scipy.spatial import ConvexHull | |
import numpy as np | |
from sync_batchnorm import DataParallelWithCallback | |
def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False,
                 use_relative_movement=False, use_relative_jacobian=False):
    """Re-express driving keypoints relative to the source keypoints.

    Each ``kp_*`` argument is a dict holding at least a ``'value'`` tensor and,
    when ``use_relative_jacobian`` is set, a ``'jacobian'`` tensor.
    NOTE(review): ``kp['value'][0]`` is handed to ``ConvexHull``, so it
    presumably has shape (batch, num_kp, 2) -- confirm against kp_detector.

    Args:
        kp_source: keypoints of the source frame.
        kp_driving: keypoints of the current driving frame.
        kp_driving_initial: keypoints of the first driving frame.
        adapt_movement_scale: if true, scale the driving displacement by
            sqrt(hull area of source kps) / sqrt(hull area of initial driving
            kps), both measured on batch element 0.
        use_relative_movement: if true, the output ``'value'`` is
            ``kp_source['value'] + scale * (kp_driving - kp_driving_initial)``.
        use_relative_jacobian: only honoured together with
            ``use_relative_movement``; the output ``'jacobian'`` becomes
            ``J_driving @ inv(J_driving_initial) @ J_source``.

    Returns:
        A new dict (shallow copy of ``kp_driving``) with ``'value'`` and
        possibly ``'jacobian'`` replaced. The input dicts are not modified.
    """
    if adapt_movement_scale:
        # For 2-D input, ConvexHull.volume is the enclosed area.
        hull_source = ConvexHull(kp_source['value'][0].data.cpu().numpy())
        hull_driving = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy())
        scale = np.sqrt(hull_source.volume) / np.sqrt(hull_driving.volume)
    else:
        scale = 1

    normalized = dict(kp_driving)
    if not use_relative_movement:
        return normalized

    # The subtraction yields a fresh tensor, so scaling it in place cannot
    # touch the caller's keypoints.
    displacement = kp_driving['value'] - kp_driving_initial['value']
    displacement *= scale
    normalized['value'] = displacement + kp_source['value']

    if use_relative_jacobian:
        relative_jacobian = torch.matmul(kp_driving['jacobian'],
                                         torch.inverse(kp_driving_initial['jacobian']))
        normalized['jacobian'] = torch.matmul(relative_jacobian, kp_source['jacobian'])

    return normalized
def animate(config, generator, kp_detector, checkpoint, log_dir, dataset):
    """Animate source frames with driving videos and write the results to disk.

    For every (source, driving) pair yielded by ``PairedDataset``, the first
    frame of the source video is animated with each frame of the driving video.
    Outputs are written under ``<log_dir>/animation``:

    * ``png/<driving_name>-<source_name>.png``: all generated frames
      concatenated along axis 1 (side by side).
    * ``<driving_name>-<source_name><format>``: per-frame visualizations
      written with ``imageio.mimsave``.

    Args:
        config: configuration mapping. Reads ``config['animate_params']``
            (keys ``'num_pairs'``, ``'normalization_params'``, ``'format'``)
            and ``config['visualizer_params']``.
        generator: model called as
            ``generator(source_frame, kp_source=..., kp_driving=...)``; must
            return a dict containing ``'prediction'``.
        kp_detector: model mapping a single frame to a keypoint dict.
        checkpoint: checkpoint passed to ``Logger.load_cpk``; must not be None.
        log_dir: base output directory.
        dataset: dataset wrapped by ``PairedDataset``.

    Raises:
        AttributeError: if ``checkpoint`` is None.
    """
    log_dir = os.path.join(log_dir, 'animation')
    png_dir = os.path.join(log_dir, 'png')
    animate_params = config['animate_params']

    dataset = PairedDataset(initial_dataset=dataset, number_of_pairs=animate_params['num_pairs'])
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

    if checkpoint is None:
        raise AttributeError("Checkpoint should be specified for mode='animate'.")
    Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector)

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(png_dir, exist_ok=True)

    if torch.cuda.is_available():
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    # The visualizer's parameters never change, so build it once rather than
    # once per frame.
    # NOTE(review): assumes Visualizer.visualize keeps no per-call state -- confirm.
    visualizer = Visualizer(**config['visualizer_params'])
    normalization_params = animate_params['normalization_params']

    # Iterate the DataLoader directly so tqdm can read its length (the index
    # from enumerate() was never used).
    for x in tqdm(dataloader):
        with torch.no_grad():
            predictions = []
            visualizations = []

            driving_video = x['driving_video']
            # Axis 2 indexes frames: frame t is video[:, :, t].
            source_frame = x['source_video'][:, :, 0, :, :]

            kp_source = kp_detector(source_frame)
            kp_driving_initial = kp_detector(driving_video[:, :, 0])

            for frame_idx in range(driving_video.shape[2]):
                driving_frame = driving_video[:, :, frame_idx]
                kp_driving = kp_detector(driving_frame)
                kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
                                       kp_driving_initial=kp_driving_initial, **normalization_params)
                out = generator(source_frame, kp_source=kp_source, kp_driving=kp_norm)

                out['kp_driving'] = kp_driving
                out['kp_source'] = kp_source
                out['kp_norm'] = kp_norm
                # Drop 'sparse_deformed' before visualizing; pop() tolerates a
                # generator that does not produce it (del raised KeyError).
                out.pop('sparse_deformed', None)

                # Move axis 1 to the end and keep batch element 0
                # (presumably (B, C, H, W) -> (H, W, C) -- confirm).
                predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])

                visualizations.append(visualizer.visualize(source=source_frame,
                                                           driving=driving_frame, out=out))

            predictions = np.concatenate(predictions, axis=1)
            result_name = "-".join([x['driving_name'][0], x['source_name'][0]])
            # Clip before the uint8 cast: out-of-range floats would otherwise wrap.
            png_image = np.clip(255 * predictions, 0, 255).astype(np.uint8)
            imageio.imsave(os.path.join(png_dir, result_name + '.png'), png_image)

            image_name = result_name + animate_params['format']
            imageio.mimsave(os.path.join(log_dir, image_name), visualizations)