import os
import sys
import glob
import argparse
import functools
import torch
import numpy as np
import PIL.Image as Image
from pathlib import Path
from diffusers import StableDiffusionInpaintPipeline
from utils.mask_processing import crop_for_filling_pre, crop_for_filling_post
from utils.crop_for_replacing import recover_size, resize_and_pad
from utils import load_img_to_array, save_array_to_img

# Hugging Face model id used by both the fill and the replace entry points.
_SD_INPAINT_MODEL = "stabilityai/stable-diffusion-2-inpainting"


@functools.lru_cache(maxsize=1)
def _load_inpaint_pipeline(device="cuda"):
    """Return the SD-2 inpainting pipeline on ``device``, loading it at most once.

    Loading the weights dominates the cost of a single inpainting call, so the
    pipeline is memoized per ``device``. ``maxsize=1`` keeps at most one copy of
    the (large) model alive; switching devices evicts the previous one.
    """
    return StableDiffusionInpaintPipeline.from_pretrained(
        _SD_INPAINT_MODEL,
        torch_dtype=torch.float32,
    ).to(device)


def fill_img_with_sd(
    img: np.ndarray, mask: np.ndarray, text_prompt: str, device="cuda"
):
    """Inpaint the masked region of ``img`` with Stable Diffusion.

    The image and mask are cropped by ``crop_for_filling_pre``, the crop is
    inpainted according to ``text_prompt``, and ``crop_for_filling_post``
    merges the generated crop back into the full image.

    Args:
        img: Input image array (presumably HxWx3 uint8 -- TODO confirm against
            ``crop_for_filling_pre``).
        mask: Mask passed to the pipeline as ``mask_image`` without inversion;
            per the diffusers inpainting convention, white pixels are repainted.
        text_prompt: Prompt describing what to generate in the masked region.
        device: Torch device the pipeline runs on.

    Returns:
        The array produced by ``crop_for_filling_post``.
    """
    pipe = _load_inpaint_pipeline(device)
    img_crop, mask_crop = crop_for_filling_pre(img, mask)
    img_crop_filled = pipe(
        prompt=text_prompt,
        image=Image.fromarray(img_crop),
        mask_image=Image.fromarray(mask_crop),
    ).images[0]
    img_filled = crop_for_filling_post(img, mask, np.array(img_crop_filled))
    return img_filled


def replace_img_with_sd(
    img: np.ndarray,
    mask: np.ndarray,
    text_prompt: str,
    step: int = 50,
    device="cuda",
) -> np.ndarray:
    """Keep the masked content of ``img`` and regenerate everything else.

    Args:
        img: Input image; must be 3-D (height, width, channels).
        mask: Mask whose 255-valued pixels mark content to keep.
        text_prompt: Prompt describing the new surroundings.
        step: Number of diffusion inference steps.
        device: Torch device the pipeline runs on.

    Returns:
        An image blended as ``generated * (1 - m) + img * m`` with
        ``m = mask / 255``: original pixels where the mask is 255, generated
        pixels where it is 0. The division promotes the result to a
        floating-point array; cast it before treating it as uint8.
    """
    pipe = _load_inpaint_pipeline(device)
    img_padded, mask_padded, padding_factors = resize_and_pad(img, mask)
    # The pipeline repaints white mask pixels, so the mask is inverted to make
    # it regenerate the background while leaving the masked object alone.
    img_padded = pipe(
        prompt=text_prompt,
        image=Image.fromarray(img_padded),
        mask_image=Image.fromarray(255 - mask_padded),
        num_inference_steps=step,
    ).images[0]
    height, width, _ = img.shape
    img_resized, mask_resized = recover_size(
        np.array(img_padded), mask_padded, (height, width), padding_factors
    )
    # Add a channel axis so the mask broadcasts over the colour channels.
    mask_resized = np.expand_dims(mask_resized, -1) / 255
    img_resized = img_resized * (1 - mask_resized) + img * mask_resized
    return img_resized


def setup_args(parser):
    """Register this script's command-line options on ``parser`` (in place)."""
    parser.add_argument(
        "--input_img", type=str, required=True,
        help="Path to a single input img",
    )
    parser.add_argument(
        "--text_prompt", type=str, required=True,
        help="Text prompt",
    )
    parser.add_argument(
        "--input_mask_glob", type=str, required=True,
        help="Glob to input masks",
    )
    parser.add_argument(
        "--output_dir", type=str, required=True,
        help="Output path to the directory with results.",
    )
    parser.add_argument(
        "--seed", type=int,
        help="Specify seed for reproducibility.",
    )
    parser.add_argument(
        "--deterministic", action="store_true",
        help="Use deterministic algorithms for reproducibility.",
    )


if __name__ == "__main__":
    # Example usage:
    #   python <this_script>.py \
    #       --input_img FA_demo/FA1_dog.png \
    #       --input_mask_glob "results/FA1_dog/mask*.png" \
    #       --text_prompt "a teddy bear on a bench" \
    #       --output_dir results
    parser = argparse.ArgumentParser()
    setup_args(parser)
    args = parser.parse_args(sys.argv[1:])
    device = "cuda" if torch.cuda.is_available() else "cpu"

    if args.deterministic:
        # PyTorch requires this cuBLAS workspace setting for deterministic
        # CUDA matmuls when deterministic algorithms are enforced.
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        torch.use_deterministic_algorithms(True)

    img_stem = Path(args.input_img).stem
    mask_ps = sorted(glob.glob(args.input_mask_glob))
    out_dir = Path(args.output_dir) / img_stem
    out_dir.mkdir(parents=True, exist_ok=True)

    img = load_img_to_array(args.input_img)
    if mask_ps:
        # Load the weights once, before any seeding, so every mask starts
        # sampling from the same RNG state when --seed is given.
        _load_inpaint_pipeline(device)
    else:
        print(
            f"warning: no masks matched {args.input_mask_glob!r}",
            file=sys.stderr,
        )

    for mask_p in mask_ps:
        if args.seed is not None:
            torch.manual_seed(args.seed)
        mask = load_img_to_array(mask_p)
        img_filled_p = out_dir / f"filled_with_{Path(mask_p).name}"
        img_filled = fill_img_with_sd(img, mask, args.text_prompt, device=device)
        save_array_to_img(img_filled, img_filled_p)