Spaces:
Running
Running
import os

# Pin every BLAS / OpenMP / numexpr backend to a single thread.
# These variables are read when the native numeric libraries are loaded, so
# they must be set BEFORE numpy / torch are imported. Setting them after the
# import (as this module previously did) has no effect on libraries that
# have already initialised their thread pools.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"

import sys
import glob
import argparse
from pathlib import Path

import numpy as np
import torch
import yaml
from omegaconf import OmegaConf

# Put the ``lama`` directory that sits next to this file at the front of
# sys.path so the ``saicinpainting`` imports below resolve from it.
sys.path.insert(0, str(Path(__file__).resolve().parent / "lama"))

from saicinpainting.evaluation.utils import move_to_device
from saicinpainting.training.trainers import load_checkpoint
from saicinpainting.evaluation.data import pad_tensor_to_modulo
from saicinpainting.evaluation.refinement import refine_predict
from utils import load_img_to_array, save_array_to_img
def inpaint_img_with_lama(
    img: np.ndarray, mask: np.ndarray, config_p: str, ckpt_p: str, mod=8, device="cuda"
):
    """Load a LaMa model from ``ckpt_p`` and inpaint ``img`` inside ``mask``.

    This is a one-shot convenience wrapper: the model is rebuilt on every
    call. To process several images with the same checkpoint, call
    ``build_lama_model`` once and then ``inpaint_img_with_builded_lama``.

    If the prediction config sets ``refine`` to a truthy value, the result
    comes from ``refine_predict`` using the config's ``refiner`` section.
    Otherwise a single forward pass is run via ``inpaint_img_with_builded_lama``.

    Args:
        img: Image array indexed as (H, W, C). It is divided by 255, so it is
            presumably uint8 RGB (TODO confirm channel order with callers).
        mask: 2-D array, presumably with the same H and W as ``img``.
            Non-zero pixels are treated as the hole. A mask whose maximum is
            exactly 1 is multiplied by 255 first.
        config_p: Path to the LaMa prediction config (loaded with OmegaConf).
        ckpt_p: Checkpoint directory. It must contain ``config.yaml`` and
            ``models/<predict_config.model.checkpoint>``.
        mod: Image and mask are padded with ``pad_tensor_to_modulo`` to a
            multiple of ``mod``. The output is cropped back to H x W.
        device: Torch device for the model (and for the batch on the
            non-refine path).

    Returns:
        uint8 array of shape (H, W, C) holding the inpainted image.

    Raises:
        AssertionError: if ``mask`` is not 2-D.
    """
    assert len(mask.shape) == 2
    predict_config = OmegaConf.load(config_p)
    model = build_lama_model(config_p, ckpt_p, device=device)

    if not predict_config.get("refine", False):
        # Plain forward pass: exactly what the pre-built-model entry point does.
        return inpaint_img_with_builded_lama(
            model, img, mask, config_p, mod=mod, device=device
        )

    # Refinement path. The batch is left on the CPU here, as before;
    # presumably refine_predict places tensors on the devices named in the
    # ``refiner`` config (TODO confirm).
    if np.max(mask) == 1:
        mask = mask * 255
    image_t = torch.from_numpy(img).float().div(255.0)
    mask_t = torch.from_numpy(mask).float()
    batch = {
        "image": image_t.permute(2, 0, 1).unsqueeze(0),
        "mask": mask_t[None, None],
    }
    orig_height, orig_width = batch["image"].shape[2], batch["image"].shape[3]
    batch["image"] = pad_tensor_to_modulo(batch["image"], mod)
    batch["mask"] = pad_tensor_to_modulo(batch["mask"], mod)
    batch["unpad_to_size"] = [torch.tensor([orig_height]), torch.tensor([orig_width])]

    cur_res = refine_predict(batch, model, **predict_config.refiner)
    cur_res = cur_res[0].permute(1, 2, 0).detach().cpu().numpy()
    # Crop away the modulo padding.
    cur_res = cur_res[:orig_height, :orig_width]
    return np.clip(cur_res * 255, 0, 255).astype("uint8")
def build_lama_model(config_p: str, ckpt_p: str, device="cuda"):
    """Construct a frozen LaMa model from a checkpoint directory.

    Args:
        config_p: Path to the LaMa prediction config. Its
            ``model.checkpoint`` entry names the weights file.
        ckpt_p: Checkpoint directory. ``config.yaml`` (the training config)
            and ``models/<checkpoint>`` are read from it.
        device: Anything ``torch.device`` accepts. Used as the checkpoint's
            ``map_location`` and as the final ``.to()`` target.

    Returns:
        The object returned by ``load_checkpoint``, after ``freeze()`` and
        ``to(device)`` have been called on it.
    """
    predict_config = OmegaConf.load(config_p)
    predict_config.model.path = ckpt_p
    target_device = torch.device(device)
    model_dir = predict_config.model.path

    with open(os.path.join(model_dir, "config.yaml"), "r") as train_cfg_file:
        train_config = OmegaConf.create(yaml.safe_load(train_cfg_file))
    # Flags read by the LaMa trainer. Presumably they select inference-only
    # construction with a no-op visualiser (TODO confirm against saicinpainting).
    train_config.training_model.predict_only = True
    train_config.visualizer.kind = "noop"

    weights_path = os.path.join(model_dir, "models", predict_config.model.checkpoint)
    lama = load_checkpoint(
        train_config, weights_path, strict=False, map_location=target_device
    )
    lama.freeze()
    lama.to(target_device)
    return lama
def inpaint_img_with_builded_lama(
    model, img: np.ndarray, mask: np.ndarray, config_p: str, mod=8, device="cuda"
):
    """Inpaint ``img`` with an already-built LaMa ``model`` in one forward pass.

    The ``refine`` option of the config is not consulted here. Only
    ``out_key`` is read from the config, to pick the output tensor.

    Args:
        model: Callable taking a batch dict and returning a dict, such as the
            object produced by ``build_lama_model``.
        img: Image array indexed as (H, W, C). It is divided by 255, so it is
            presumably uint8 (TODO confirm).
        mask: 2-D array. It is binarised with ``> 0`` before the forward pass.
            A mask whose maximum is exactly 1 is multiplied by 255 first.
        config_p: Path to the prediction config. Only ``out_key`` is used.
        mod: Image and mask are padded with ``pad_tensor_to_modulo`` to a
            multiple of ``mod``. The output is cropped back to H x W.
        device: Device the batch is moved to. Presumably it must match the
            model's device.

    Returns:
        uint8 array of shape (H, W, C).
    """
    assert len(mask.shape) == 2
    scaled_mask = mask * 255 if np.max(mask) == 1 else mask

    image_chw = torch.from_numpy(img).float().div(255.0).permute(2, 0, 1)
    mask_t = torch.from_numpy(scaled_mask).float()
    predict_config = OmegaConf.load(config_p)

    height, width = image_chw.shape[1], image_chw.shape[2]
    padded = {
        "image": pad_tensor_to_modulo(image_chw.unsqueeze(0), mod),
        "mask": pad_tensor_to_modulo(mask_t[None, None], mod),
    }
    padded = move_to_device(padded, device)
    padded["mask"] = (padded["mask"] > 0) * 1

    outputs = model(padded)
    result = outputs[predict_config.out_key][0].permute(1, 2, 0).detach().cpu().numpy()
    # Drop the modulo padding before converting back to 8-bit.
    result = result[:height, :width]
    return np.clip(result * 255, 0, 255).astype("uint8")
def setup_args(parser):
    """Register this script's command-line options on ``parser``.

    Options are added in this order: ``--input_img``, ``--input_mask_glob``,
    ``--output_dir``, ``--lama_config``, ``--lama_ckpt``. All are strings.
    Only ``--lama_config`` is optional.
    """
    required_path_options = (
        ("--input_img", "Path to a single input img"),
        ("--input_mask_glob", "Glob to input masks"),
        ("--output_dir", "Output path to the directory with results."),
    )
    for flag, help_text in required_path_options:
        parser.add_argument(flag, type=str, required=True, help=help_text)

    # NOTE(review): this default points at ./third_party/lama, while the module
    # puts <script dir>/lama on sys.path. Confirm which layout is intended.
    parser.add_argument(
        "--lama_config",
        type=str,
        default="./third_party/lama/configs/prediction/default.yaml",
        help="The path to the config file of lama model. Default: the config of big-lama",
    )
    parser.add_argument(
        "--lama_ckpt",
        type=str,
        required=True,
        help="The path to the lama checkpoint.",
    )
if __name__ == "__main__":
    # Example usage:
    #   python lama_inpaint.py \
    #       --input_img FA_demo/FA1_dog.png \
    #       --input_mask_glob "results/FA1_dog/mask*.png" \
    #       --output_dir results \
    #       --lama_config lama/configs/prediction/default.yaml \
    #       --lama_ckpt big-lama
    #
    # Each mask matched by --input_mask_glob is written to
    # <output_dir>/<input image stem>/inpainted_with_<mask file name>.
    parser = argparse.ArgumentParser()
    setup_args(parser)
    args = parser.parse_args(sys.argv[1:])
    device = "cuda" if torch.cuda.is_available() else "cpu"

    img_stem = Path(args.input_img).stem
    mask_ps = sorted(glob.glob(args.input_mask_glob))
    out_dir = Path(args.output_dir) / img_stem
    out_dir.mkdir(parents=True, exist_ok=True)
    img = load_img_to_array(args.input_img)

    # Loading the checkpoint is the expensive step, so build the model once and
    # reuse it for every mask. Refinement is only implemented by
    # inpaint_img_with_lama, so that path still rebuilds the model per call
    # (model stays None).
    model = None
    if not mask_ps:
        print(
            f"No masks matched {args.input_mask_glob!r}; nothing to inpaint.",
            file=sys.stderr,
        )
    elif not OmegaConf.load(args.lama_config).get("refine", False):
        model = build_lama_model(args.lama_config, args.lama_ckpt, device=device)

    for mask_p in mask_ps:
        mask = load_img_to_array(mask_p)
        if model is None:
            img_inpainted = inpaint_img_with_lama(
                img, mask, args.lama_config, args.lama_ckpt, device=device
            )
        else:
            img_inpainted = inpaint_img_with_builded_lama(
                model, img, mask, args.lama_config, device=device
            )
        img_inpainted_p = out_dir / f"inpainted_with_{Path(mask_p).name}"
        save_array_to_img(img_inpainted, img_inpainted_p)