Inpaint-Maething / lama_inpaint.py
pg56714's picture
Update lama_inpaint.py
77c9ee7 verified
import os
import sys
import numpy as np
import torch
import yaml
import glob
import argparse
from omegaconf import OmegaConf
from pathlib import Path
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
sys.path.insert(0, str(Path(__file__).resolve().parent / "lama"))
from saicinpainting.evaluation.utils import move_to_device
from saicinpainting.training.trainers import load_checkpoint
from saicinpainting.evaluation.data import pad_tensor_to_modulo
from saicinpainting.evaluation.refinement import refine_predict
from utils import load_img_to_array, save_array_to_img
@torch.no_grad()
def inpaint_img_with_lama(
    img: np.ndarray, mask: np.ndarray, config_p: str, ckpt_p: str, mod=8, device="cuda"
):
    """Load a LaMa checkpoint and inpaint ``img`` using ``mask``.

    The checkpoint is loaded on every call. To inpaint several images with
    the same weights, build the model once with ``build_lama_model`` and use
    ``inpaint_img_with_builded_lama`` instead.

    Args:
        img: Image array. It is permuted with ``(2, 0, 1)``, so it must be
            3-D (presumably H x W x C with values in 0..255 -- confirm
            against callers). Values are divided by 255 before inference.
        mask: 2-D array. A mask whose maximum is exactly 1 is rescaled by
            255 first. In the non-refine path every value greater than zero
            is treated as masked.
        config_p: Path of the prediction config loaded with ``OmegaConf``.
            The keys read here are ``model.checkpoint``, ``out_key`` and the
            optional ``refine`` / ``refiner`` entries.
        ckpt_p: Checkpoint directory; it must contain ``config.yaml`` and
            ``models/<model.checkpoint>``.
        mod: Modulo passed to ``pad_tensor_to_modulo`` for image and mask.
        device: Anything accepted by ``torch.device``.

    Returns:
        The model output multiplied by 255, clipped to ``[0, 255]``, cast to
        ``uint8`` and cropped to the height and width of ``img``.
    """
    assert len(mask.shape) == 2
    if np.max(mask) == 1:
        mask = mask * 255
    img = torch.from_numpy(img).float().div(255.0)
    mask = torch.from_numpy(mask).float()

    predict_config = OmegaConf.load(config_p)
    predict_config.model.path = ckpt_p
    device = torch.device(device)

    # The training config stored next to the checkpoint describes the model.
    train_config_path = os.path.join(predict_config.model.path, "config.yaml")
    with open(train_config_path, "r") as f:
        train_config = OmegaConf.create(yaml.safe_load(f))

    train_config.training_model.predict_only = True
    train_config.visualizer.kind = "noop"

    checkpoint_path = os.path.join(
        predict_config.model.path, "models", predict_config.model.checkpoint
    )
    model = load_checkpoint(
        train_config, checkpoint_path, strict=False, map_location=device
    )
    model.freeze()
    model.to(device)

    batch = {}
    batch["image"] = img.permute(2, 0, 1).unsqueeze(0)
    batch["mask"] = mask[None, None]
    # Remember the un-padded spatial size so the padding can be cropped away.
    orig_height, orig_width = batch["image"].shape[2], batch["image"].shape[3]
    batch["image"] = pad_tensor_to_modulo(batch["image"], mod)
    batch["mask"] = pad_tensor_to_modulo(batch["mask"], mod)

    if predict_config.get("refine", False):
        batch["unpad_to_size"] = [
            torch.tensor([size]) for size in (orig_height, orig_width)
        ]
        # This function is decorated with torch.no_grad(), which would also
        # disable autograd inside the refiner.
        # NOTE(review): upstream LaMa's refine_predict optimises intermediate
        # features with backward() and is called outside no_grad in upstream
        # predict.py -- confirm against the vendored copy under ./lama.
        with torch.enable_grad():
            cur_res = refine_predict(batch, model, **predict_config.refiner)
        cur_res = cur_res[0].permute(1, 2, 0).detach().cpu().numpy()
    else:
        batch = move_to_device(batch, device)
        batch["mask"] = (batch["mask"] > 0) * 1
        batch = model(batch)
        cur_res = batch[predict_config.out_key][0].permute(1, 2, 0)
        cur_res = cur_res.detach().cpu().numpy()

    # Crop away the padding added above; the slice is a no-op when the result
    # already has the original size.
    cur_res = cur_res[:orig_height, :orig_width]

    cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
    return cur_res
def build_lama_model(config_p: str, ckpt_p: str, device="cuda"):
    """Load a frozen LaMa model from a checkpoint directory.

    Args:
        config_p: Path of the prediction config loaded with ``OmegaConf``;
            its ``model.checkpoint`` entry names the weights file.
        ckpt_p: Checkpoint directory; it must contain ``config.yaml`` and
            ``models/<model.checkpoint>``.
        device: Anything accepted by ``torch.device``.

    Returns:
        The object returned by ``load_checkpoint``, after ``freeze()`` and
        ``to(device)`` have been called on it.
    """
    target_device = torch.device(device)

    prediction_cfg = OmegaConf.load(config_p)
    prediction_cfg.model.path = ckpt_p
    model_dir = prediction_cfg.model.path

    # The training config shipped with the checkpoint describes the model.
    with open(os.path.join(model_dir, "config.yaml"), "r") as cfg_file:
        training_cfg = OmegaConf.create(yaml.safe_load(cfg_file))
    training_cfg.training_model.predict_only = True
    training_cfg.visualizer.kind = "noop"

    weights_path = os.path.join(model_dir, "models", prediction_cfg.model.checkpoint)
    lama = load_checkpoint(
        training_cfg, weights_path, strict=False, map_location=target_device
    )
    lama.freeze()
    lama.to(target_device)
    return lama
@torch.no_grad()
def inpaint_img_with_builded_lama(
    model, img: np.ndarray, mask: np.ndarray, config_p: str, mod=8, device="cuda"
):
    """Inpaint ``img`` with an already loaded LaMa ``model``.

    Args:
        model: Callable that takes the batch dict and returns a dict holding
            the prediction under the config's ``out_key``.
        img: Image array. It is permuted with ``(2, 0, 1)``, so it must be
            3-D (presumably H x W x C with values in 0..255 -- confirm
            against callers). Values are divided by 255 before inference.
        mask: 2-D array. A mask whose maximum is exactly 1 is rescaled by
            255; every value greater than zero is then treated as masked.
        config_p: Path of the prediction config; only ``out_key`` is read.
        mod: Modulo passed to ``pad_tensor_to_modulo`` for image and mask.
        device: Device handed to ``move_to_device`` for the batch.

    Returns:
        The model output multiplied by 255, clipped to ``[0, 255]``, cast to
        ``uint8`` and cropped to the height and width of ``img``.
    """
    assert len(mask.shape) == 2
    if np.max(mask) == 1:
        mask = mask * 255

    image_t = torch.from_numpy(img).float().div(255.0)
    mask_t = torch.from_numpy(mask).float()
    predict_config = OmegaConf.load(config_p)

    # Add a batch axis and move the last image axis to the front.
    image_bchw = image_t.permute(2, 0, 1).unsqueeze(0)
    _, _, orig_height, orig_width = image_bchw.shape

    inputs = {
        "image": pad_tensor_to_modulo(image_bchw, mod),
        "mask": pad_tensor_to_modulo(mask_t[None, None], mod),
    }
    inputs = move_to_device(inputs, device)
    inputs["mask"] = (inputs["mask"] > 0) * 1

    outputs = model(inputs)
    prediction = outputs[predict_config.out_key][0].permute(1, 2, 0)
    prediction = prediction.detach().cpu().numpy()

    # Drop the rows and columns that were only added as padding.
    prediction = prediction[:orig_height, :orig_width]
    return np.clip(prediction * 255, 0, 255).astype("uint8")
def setup_args(parser):
    """Register the command-line options of this script on ``parser``.

    Every option is a string. ``--lama_config`` has a default value; the
    other four options are required.
    """
    # (flag, extra keyword arguments), in the order they appear in --help.
    option_table = (
        (
            "--input_img",
            {"required": True, "help": "Path to a single input img"},
        ),
        (
            "--input_mask_glob",
            {"required": True, "help": "Glob to input masks"},
        ),
        (
            "--output_dir",
            {
                "required": True,
                "help": "Output path to the directory with results.",
            },
        ),
        (
            "--lama_config",
            {
                "default": "./third_party/lama/configs/prediction/default.yaml",
                "help": "The path to the config file of lama model. "
                "Default: the config of big-lama",
            },
        ),
        (
            "--lama_ckpt",
            {"required": True, "help": "The path to the lama checkpoint."},
        ),
    )
    for flag, extra in option_table:
        parser.add_argument(flag, type=str, **extra)
if __name__ == "__main__":
    # Example usage:
    #   python lama_inpaint.py \
    #       --input_img FA_demo/FA1_dog.png \
    #       --input_mask_glob "results/FA1_dog/mask*.png" \
    #       --output_dir results \
    #       --lama_config lama/configs/prediction/default.yaml \
    #       --lama_ckpt big-lama
    cli = argparse.ArgumentParser()
    setup_args(cli)
    args = cli.parse_args(sys.argv[1:])

    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"

    # Resolve the mask list before creating the output directory, so that a
    # newly created directory can never be picked up by the glob.
    mask_ps = sorted(glob.glob(args.input_mask_glob))

    # Results go to <output_dir>/<input image stem>/.
    out_dir = Path(args.output_dir) / Path(args.input_img).stem
    out_dir.mkdir(parents=True, exist_ok=True)

    img = load_img_to_array(args.input_img)
    for mask_p in mask_ps:
        mask = load_img_to_array(mask_p)
        img_inpainted = inpaint_img_with_lama(
            img, mask, args.lama_config, args.lama_ckpt, device=device
        )
        save_array_to_img(
            img_inpainted, out_dir / f"inpainted_with_{Path(mask_p).name}"
        )