Xin Liu
test
2492d81 unverified
raw
history blame contribute delete
No virus
2.93 kB
import os
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from logger import Logger, Visualizer
import numpy as np
import imageio
def reconstruction(config, inpainting_network, kp_detector, bg_predictor, dense_motion_network, checkpoint, log_dir, dataset):
png_dir = os.path.join(log_dir, 'reconstruction/png')
log_dir = os.path.join(log_dir, 'reconstruction')
if checkpoint is not None:
Logger.load_cpk(checkpoint, inpainting_network=inpainting_network, kp_detector=kp_detector,
bg_predictor=bg_predictor, dense_motion_network=dense_motion_network)
else:
raise AttributeError("Checkpoint should be specified for mode='reconstruction'.")
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
if not os.path.exists(log_dir):
os.makedirs(log_dir)
if not os.path.exists(png_dir):
os.makedirs(png_dir)
loss_list = []
inpainting_network.eval()
kp_detector.eval()
dense_motion_network.eval()
if bg_predictor:
bg_predictor.eval()
for it, x in tqdm(enumerate(dataloader)):
with torch.no_grad():
predictions = []
visualizations = []
if torch.cuda.is_available():
x['video'] = x['video'].cuda()
kp_source = kp_detector(x['video'][:, :, 0])
for frame_idx in range(x['video'].shape[2]):
source = x['video'][:, :, 0]
driving = x['video'][:, :, frame_idx]
kp_driving = kp_detector(driving)
bg_params = None
if bg_predictor:
bg_params = bg_predictor(source, driving)
dense_motion = dense_motion_network(source_image=source, kp_driving=kp_driving,
kp_source=kp_source, bg_param = bg_params,
dropout_flag = False)
out = inpainting_network(source, dense_motion)
out['kp_source'] = kp_source
out['kp_driving'] = kp_driving
predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
visualization = Visualizer(**config['visualizer_params']).visualize(source=source,
driving=driving, out=out)
visualizations.append(visualization)
loss = torch.abs(out['prediction'] - driving).mean().cpu().numpy()
loss_list.append(loss)
# print(np.mean(loss_list))
predictions = np.concatenate(predictions, axis=1)
imageio.imsave(os.path.join(png_dir, x['name'][0] + '.png'), (255 * predictions).astype(np.uint8))
print("Reconstruction loss: %s" % np.mean(loss_list))