# adirik's picture
# update repo
# f4c3c2b
import argparse
import torch
import numpy as np
import sys
import os
import dlib
sys.path.append(".")
sys.path.append("..")
from configs import data_configs, paths_config
from datasets.inference_dataset import InferenceDataset
from torch.utils.data import DataLoader
from utils.model_utils import setup_model
from utils.common import tensor2im
from utils.alignment import align_face
from PIL import Image
def main(args):
    """Invert images into latent codes and optionally decode them back to images.

    Loads the model from ``args.ckpt``, builds the inference data loader, and
    either reuses ``<save_dir>/latents.pt`` if it exists or encodes the images
    and caches the result there. Unless ``--latents_only`` was given, every
    latent is decoded and written under ``<save_dir>/inversions``.

    Relies on the module-level ``device`` assigned in the ``__main__`` block.
    """
    net, opts = setup_model(args.ckpt, device)
    is_cars = 'cars_' in opts.dataset_type
    generator = net.decoder
    generator.eval()
    args, data_loader = setup_data_loader(args, opts)

    # --save_dir is documented to default to the images directory. Without this
    # fallback os.path.join(None, ...) below raises TypeError.
    if args.save_dir is None:
        if args.images_dir is not None:
            args.save_dir = args.images_dir
        else:
            args.save_dir = data_configs.DATASETS[opts.dataset_type]['test_source_root']
    # torch.save does not create missing parent directories.
    os.makedirs(args.save_dir, exist_ok=True)

    # Reuse previously computed latents if they were cached.
    latents_file_path = os.path.join(args.save_dir, 'latents.pt')
    if os.path.exists(latents_file_path):
        # map_location lets a cache written on one device load on another.
        latent_codes = torch.load(latents_file_path, map_location=device)
    else:
        latent_codes = get_all_latents(net, data_loader, args.n_sample, is_cars=is_cars)
        torch.save(latent_codes, latents_file_path)

    # A stale cache may hold fewer codes than requested. Never ask for more than exist.
    args.n_sample = min(args.n_sample, len(latent_codes))

    if not args.latents_only:
        generate_inversions(args, generator, latent_codes, is_cars=is_cars)
def setup_data_loader(args, opts):
    """Build the DataLoader over the images to be inverted.

    The image root is ``args.images_dir`` when given, otherwise the dataset's
    ``test_source_root`` from ``data_configs``. When ``args.align`` is set,
    each image is passed through ``run_alignment`` as a preprocess step.

    Returns:
        ``(args, data_loader)``. ``args.n_sample`` is filled in with the dataset
        length when unset, and clamped to it when larger.
    """
    dataset_args = data_configs.DATASETS[opts.dataset_type]
    transforms_dict = dataset_args['transforms'](opts).get_transforms()
    images_path = args.images_dir if args.images_dir is not None else dataset_args['test_source_root']
    print(f"images path: {images_path}")
    align_function = run_alignment if args.align else None
    test_dataset = InferenceDataset(root=images_path,
                                    transform=transforms_dict['transform_test'],
                                    preprocess=align_function,
                                    opts=opts)
    # Keep the final partial batch. With drop_last=True and --batch > 1 the tail
    # images were silently skipped, while n_sample still counted them. That made
    # generate_inversions index past the end of the latent tensor.
    data_loader = DataLoader(test_dataset,
                             batch_size=args.batch,
                             shuffle=False,
                             num_workers=2,
                             drop_last=False)
    print(f'dataset length: {len(test_dataset)}')
    if args.n_sample is None or args.n_sample > len(test_dataset):
        args.n_sample = len(test_dataset)
    return args, data_loader
def get_latents(net, x, is_cars=False):
    """Encode a batch ``x`` with ``net.encoder`` and post-process the codes.

    If ``net.opts.start_from_latent_avg`` is set, ``net.latent_avg`` is tiled
    over the batch and added to the codes. For 2-D codes only the first row of
    the tiled average is used. For car models, codes with 18 entries along dim 1
    are cut down to the first 16.
    """
    latent_codes = net.encoder(x)
    if net.opts.start_from_latent_avg:
        tiled_avg = net.latent_avg.repeat(latent_codes.shape[0], 1, 1)
        offset = tiled_avg[:, 0, :] if latent_codes.ndim == 2 else tiled_avg
        latent_codes = latent_codes + offset
    needs_trim = latent_codes.shape[1] == 18 and is_cars
    return latent_codes[:, :16, :] if needs_trim else latent_codes
def get_all_latents(net, data_loader, n_images=None, is_cars=False):
    """Encode images from ``data_loader`` and concatenate their latent codes.

    Args:
        net: model whose encoder is applied via ``get_latents``.
        data_loader: iterable of input batches. Each batch is moved to the
            module-level ``device`` and cast to float.
        n_images: if given, stop once this many images have been encoded and
            return only the first ``n_images`` codes.
        is_cars: forwarded to ``get_latents``.

    Returns:
        A tensor of latents concatenated along dim 0.

    Raises:
        RuntimeError: if no batch was encoded (empty loader or ``n_images <= 0``).
    """
    all_latents = []
    n_done = 0
    with torch.no_grad():
        for batch in data_loader:
            # Stop as soon as enough images are encoded. The old `>` check
            # encoded one extra batch past the requested count.
            if n_images is not None and n_done >= n_images:
                break
            inputs = batch.to(device).float()
            latents = get_latents(net, inputs, is_cars)
            all_latents.append(latents)
            n_done += len(latents)
    if not all_latents:
        raise RuntimeError("No latents were computed: the data loader yielded no batches "
                           "(empty image directory or n_images <= 0).")
    stacked = torch.cat(all_latents)
    # Batches can overshoot n_images; trim so the result has exactly n_images codes.
    return stacked if n_images is None else stacked[:n_images]
def save_image(img, save_dir, idx):
    """Write one image to ``save_dir`` as a zero-padded JPEG, e.g. ``00007.jpg``.

    Args:
        img: a single image in whatever form ``tensor2im`` accepts. The caller
            passes one element of the generator output; presumably a CHW
            tensor, TODO confirm.
        save_dir: directory to write into. PIL does not create it, so it must exist.
        idx: integer used for the file name, formatted to five digits.
    """
    pixels = np.array(tensor2im(img))
    destination = os.path.join(save_dir, f"{idx:05d}.jpg")
    Image.fromarray(pixels).save(destination)
@torch.no_grad()
def generate_inversions(args, g, latent_codes, is_cars):
    """Decode latent codes with generator ``g`` and save each result as a JPEG.

    Images go to ``<args.save_dir>/inversions`` and are named from 1 upward
    (``00001.jpg``, ...). At most ``args.n_sample`` codes are decoded, and never
    more than ``latent_codes`` contains.

    Args:
        args: parsed CLI args; reads ``save_dir`` and ``n_sample``.
        g: generator called as ``g([w], input_is_latent=True, ...)`` that returns
            an ``(images, latents)`` pair.
        latent_codes: stacked latents, indexed along dim 0.
        is_cars: if true, keep only rows 64:448 of each output image.
    """
    print('Saving inversion images')
    inversions_directory_path = os.path.join(args.save_dir, 'inversions')
    os.makedirs(inversions_directory_path, exist_ok=True)
    # Bound by the available codes so a short latent tensor (e.g. a stale cache
    # or a dropped tail batch) cannot cause an IndexError.
    n_images = len(latent_codes) if args.n_sample is None else min(args.n_sample, len(latent_codes))
    for i in range(n_images):
        imgs, _ = g([latent_codes[i].unsqueeze(0)], input_is_latent=True, randomize_noise=False, return_latents=True)
        if is_cars:
            # NOTE(review): presumably strips padding around the car images; confirm 64:448 against the dataset.
            imgs = imgs[:, :, 64:448, :]
        save_image(imgs[0], inversions_directory_path, i + 1)
# Lazily loaded dlib landmark predictor, shared by all run_alignment calls in this process.
_shape_predictor = None


def _get_shape_predictor():
    """Load the dlib shape predictor on first use and return the cached instance."""
    global _shape_predictor
    if _shape_predictor is None:
        _shape_predictor = dlib.shape_predictor(paths_config.model_paths['shape_predictor'])
    return _shape_predictor


def run_alignment(image_path):
    """Align the face in the image at ``image_path`` with ``align_face``.

    Used as the dataset ``preprocess`` hook when ``--align`` is given. The
    predictor model used to be reloaded from disk for every image. It is now
    loaded once per process. Each DataLoader worker loads its own copy on first use.
    """
    aligned_image = align_face(filepath=image_path, predictor=_get_shape_predictor())
    print("Aligned image has shape: {}".format(aligned_image.size))
    return aligned_image
if __name__ == "__main__":
    # Module-level global read by main() and get_all_latents(). Fall back to the
    # CPU when no CUDA device is present instead of failing on the first transfer.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    parser = argparse.ArgumentParser(description="Inference")
    parser.add_argument("--images_dir", type=str, default=None,
                        help="The directory of the images to be inverted")
    parser.add_argument("--save_dir", type=str, default=None,
                        help="The directory to save the latent codes and inversion images. (default: images_dir)")
    parser.add_argument("--batch", type=int, default=1, help="batch size for encoding images")
    parser.add_argument("--n_sample", type=int, default=None, help="number of the samples to infer.")
    parser.add_argument("--latents_only", action="store_true", help="infer only the latent codes of the directory")
    parser.add_argument("--align", action="store_true", help="align face images before inference")
    parser.add_argument("ckpt", metavar="CHECKPOINT", help="path to generator checkpoint")
    args = parser.parse_args()
    main(args)