"""Inpaint masked image regions with a pretrained LaMa model.

The module exposes helpers to load a LaMa checkpoint and run it on an
image/mask pair, plus a small command-line interface that inpaints one image
once per mask matched by a glob pattern.
"""
import os
import sys
import numpy as np
import torch
import yaml
import glob
import argparse
from omegaconf import OmegaConf
from pathlib import Path

os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"

# The vendored "lama" directory next to this file must be importable before
# the saicinpainting imports below are resolved.
sys.path.insert(0, str(Path(__file__).resolve().parent / "lama"))

from saicinpainting.evaluation.utils import move_to_device
from saicinpainting.training.trainers import load_checkpoint
from saicinpainting.evaluation.data import pad_tensor_to_modulo
from saicinpainting.evaluation.refinement import refine_predict

from utils import load_img_to_array, save_array_to_img


def _load_predict_config(config_p: str, ckpt_p: str):
    """Load the prediction config and point ``model.path`` at ``ckpt_p``."""
    predict_config = OmegaConf.load(config_p)
    predict_config.model.path = ckpt_p
    return predict_config


def _load_model(predict_config, device):
    """Load, freeze and move to ``device`` the checkpoint named by the config.

    The training config is read from ``<model.path>/config.yaml`` and the
    weights from ``<model.path>/models/<model.checkpoint>``.
    """
    train_config_path = os.path.join(predict_config.model.path, "config.yaml")
    with open(train_config_path, "r") as f:
        train_config = OmegaConf.create(yaml.safe_load(f))

    train_config.training_model.predict_only = True
    train_config.visualizer.kind = "noop"

    checkpoint_path = os.path.join(
        predict_config.model.path, "models", predict_config.model.checkpoint
    )
    model = load_checkpoint(
        train_config, checkpoint_path, strict=False, map_location=device
    )
    model.freeze()
    model.to(device)
    return model


def _make_batch(img: np.ndarray, mask: np.ndarray, mod: int):
    """Turn an image/mask pair into a padded model batch.

    Returns:
        A ``(batch, unpad_to_size)`` tuple. ``batch`` holds the ``"image"``
        and ``"mask"`` tensors after ``pad_tensor_to_modulo``;
        ``unpad_to_size`` is ``[height, width]`` of the image before padding.
    """
    assert len(mask.shape) == 2
    # A mask whose maximum is 1 is rescaled so its foreground becomes 255.
    if np.max(mask) == 1:
        mask = mask * 255
    img_t = torch.from_numpy(img).float().div(255.0)
    mask_t = torch.from_numpy(mask).float()

    image = img_t.permute(2, 0, 1).unsqueeze(0)
    unpad_to_size = [image.shape[2], image.shape[3]]
    batch = {
        "image": pad_tensor_to_modulo(image, mod),
        "mask": pad_tensor_to_modulo(mask_t[None, None], mod),
    }
    return batch, unpad_to_size


def _forward(model, batch, out_key, device, unpad_to_size) -> np.ndarray:
    """Run one plain forward pass and crop the result to the unpadded size."""
    batch = move_to_device(batch, device)
    batch["mask"] = (batch["mask"] > 0) * 1
    batch = model(batch)
    cur_res = batch[out_key][0].permute(1, 2, 0)
    cur_res = cur_res.detach().cpu().numpy()

    orig_height, orig_width = unpad_to_size
    return cur_res[:orig_height, :orig_width]


def _to_uint8(cur_res: np.ndarray) -> np.ndarray:
    """Scale a result by 255, clip to [0, 255] and cast to ``uint8``."""
    return np.clip(cur_res * 255, 0, 255).astype("uint8")


@torch.no_grad()
def inpaint_img_with_lama(
    img: np.ndarray,
    mask: np.ndarray,
    config_p: str,
    ckpt_p: str,
    mod=8,
    device="cuda",
) -> np.ndarray:
    """Load a LaMa checkpoint and inpaint ``img`` where ``mask`` is non-zero.

    The checkpoint is loaded on every call; use ``build_lama_model`` together
    with ``inpaint_img_with_builded_lama`` to reuse a loaded model.

    Args:
        img: Image array; it is divided by 255 and permuted with
            ``permute(2, 0, 1)`` before being fed to the model.
        mask: Two-dimensional mask array. A mask whose maximum is 1 is
            rescaled to 255.
        config_p: Path to the LaMa prediction config.
        ckpt_p: Directory containing ``config.yaml`` and ``models/``.
        mod: Image and mask are padded with ``pad_tensor_to_modulo`` using
            this modulo.
        device: Anything accepted by ``torch.device``.

    Returns:
        The inpainted image as a ``uint8`` array.
    """
    batch, unpad_to_size = _make_batch(img, mask, mod)

    predict_config = _load_predict_config(config_p, ckpt_p)
    device = torch.device(device)
    model = _load_model(predict_config, device)

    if predict_config.get("refine", False):
        batch["unpad_to_size"] = [torch.tensor([size]) for size in unpad_to_size]
        # This function runs under @torch.no_grad(), but the upstream LaMa
        # refiner optimises by back-propagation and upstream predict.py calls
        # it outside no_grad, so gradients are re-enabled for this call only.
        # NOTE(review): confirm against the vendored refinement.py.
        with torch.enable_grad():
            cur_res = refine_predict(batch, model, **predict_config.refiner)
        # No crop here: upstream notes the refiner unpads its own output.
        cur_res = cur_res[0].permute(1, 2, 0).detach().cpu().numpy()
    else:
        cur_res = _forward(
            model, batch, predict_config.out_key, device, unpad_to_size
        )

    return _to_uint8(cur_res)


def build_lama_model(config_p: str, ckpt_p: str, device="cuda"):
    """Load a frozen LaMa model onto ``device``.

    Args:
        config_p: Path to the LaMa prediction config.
        ckpt_p: Directory containing ``config.yaml`` and ``models/``.
        device: Anything accepted by ``torch.device``.

    Returns:
        The loaded model, frozen and moved to ``device``.
    """
    predict_config = _load_predict_config(config_p, ckpt_p)
    return _load_model(predict_config, torch.device(device))


@torch.no_grad()
def inpaint_img_with_builded_lama(
    model,
    img: np.ndarray,
    mask: np.ndarray,
    config_p: str,
    mod=8,
    device="cuda",
) -> np.ndarray:
    """Inpaint ``img`` with an already loaded model (no refinement step).

    Args:
        model: Model as returned by ``build_lama_model``.
        img: Image array; it is divided by 255 and permuted with
            ``permute(2, 0, 1)`` before being fed to the model.
        mask: Two-dimensional mask array. A mask whose maximum is 1 is
            rescaled to 255.
        config_p: Path to the LaMa prediction config; only ``out_key`` is
            read from it here.
        mod: Image and mask are padded with ``pad_tensor_to_modulo`` using
            this modulo.
        device: Device the batch is moved to before the forward pass.

    Returns:
        The inpainted image as a ``uint8`` array.
    """
    batch, unpad_to_size = _make_batch(img, mask, mod)
    predict_config = OmegaConf.load(config_p)
    cur_res = _forward(
        model, batch, predict_config.out_key, device, unpad_to_size
    )
    return _to_uint8(cur_res)


def setup_args(parser):
    """Register this script's command-line options on ``parser``."""
    parser.add_argument(
        "--input_img",
        type=str,
        required=True,
        help="Path to a single input img",
    )
    parser.add_argument(
        "--input_mask_glob",
        type=str,
        required=True,
        help="Glob to input masks",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Output path to the directory with results.",
    )
    parser.add_argument(
        "--lama_config",
        type=str,
        default="./third_party/lama/configs/prediction/default.yaml",
        help="The path to the config file of lama model. "
        "Default: the config of big-lama",
    )
    parser.add_argument(
        "--lama_ckpt",
        type=str,
        required=True,
        help="The path to the lama checkpoint.",
    )


def main():
    r"""Inpaint the input image once for every mask matched by the glob.

    Example usage:
        python lama_inpaint.py \
            --input_img FA_demo/FA1_dog.png \
            --input_mask_glob "results/FA1_dog/mask*.png" \
            --output_dir results \
            --lama_config lama/configs/prediction/default.yaml \
            --lama_ckpt big-lama
    """
    parser = argparse.ArgumentParser()
    setup_args(parser)
    args = parser.parse_args(sys.argv[1:])
    device = "cuda" if torch.cuda.is_available() else "cpu"

    img_stem = Path(args.input_img).stem
    mask_ps = sorted(glob.glob(args.input_mask_glob))
    out_dir = Path(args.output_dir) / img_stem
    out_dir.mkdir(parents=True, exist_ok=True)

    img = load_img_to_array(args.input_img)
    for mask_p in mask_ps:
        mask = load_img_to_array(mask_p)
        img_inpainted_p = out_dir / f"inpainted_with_{Path(mask_p).name}"
        # NOTE: inpaint_img_with_lama reloads the checkpoint on each call.
        img_inpainted = inpaint_img_with_lama(
            img, mask, args.lama_config, args.lama_ckpt, device=device
        )
        save_array_to_img(img_inpainted, img_inpainted_p)


if __name__ == "__main__":
    main()