# Inpaint-Anything / stable_diffusion_inpaint.py
# Committed by RysonFeng in cdb26a4 ("Add source code").
import os
import sys
import glob
import argparse
import torch
import numpy as np
import PIL.Image as Image
from pathlib import Path
from diffusers import StableDiffusionInpaintPipeline
from utils.mask_processing import crop_for_filling_pre, crop_for_filling_post
from utils.crop_for_replacing import recover_size, resize_and_pad
from utils import load_img_to_array, save_array_to_img
_SD_INPAINT_MODEL_ID = "stabilityai/stable-diffusion-2-inpainting"

# Inpainting pipelines that have already been loaded, keyed by the ``device``
# argument they were moved to. Entries live for the rest of the process.
_SD_INPAINT_PIPES = {}


def _get_sd_inpaint_pipeline(device="cuda"):
    """Return the SD-2 inpainting pipeline on ``device``, loading it at most once.

    ``from_pretrained`` loads the model weights every time it is called; the
    script calls the inpainting helpers once per mask, so the pipeline is
    memoised per device instead of being rebuilt on every call.
    """
    pipe = _SD_INPAINT_PIPES.get(device)
    if pipe is None:
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            _SD_INPAINT_MODEL_ID,
            torch_dtype=torch.float32,
        ).to(device)
        _SD_INPAINT_PIPES[device] = pipe
    return pipe


def fill_img_with_sd(
    img: np.ndarray,
    mask: np.ndarray,
    text_prompt: str,
    device="cuda",
    step: int = 50,
) -> np.ndarray:
    """Fill the masked region of ``img`` using Stable Diffusion 2 inpainting.

    The image and mask are cropped with ``crop_for_filling_pre``. The crop is
    inpainted under ``text_prompt``, and ``crop_for_filling_post`` merges the
    generated crop back into the full-size image.

    Args:
        img: Input image array (presumably HxWx3 uint8 -- TODO confirm).
        mask: Mask of the region to fill. Assumed 0/255; the pipeline repaints
            the white pixels of ``mask_image``. Verify against the caller.
        text_prompt: Prompt describing what to paint into the masked region.
        device: Torch device the pipeline runs on.
        step: Number of denoising steps (``num_inference_steps``). The default
            of 50 matches ``replace_img_with_sd`` and the pipeline's default.

    Returns:
        Whatever ``crop_for_filling_post`` returns for the filled crop
        (presumably the full-size filled image).
    """
    pipe = _get_sd_inpaint_pipeline(device)
    img_crop, mask_crop = crop_for_filling_pre(img, mask)
    img_crop_filled = pipe(
        prompt=text_prompt,
        image=Image.fromarray(img_crop),
        mask_image=Image.fromarray(mask_crop),
        num_inference_steps=step,
    ).images[0]
    return crop_for_filling_post(img, mask, np.array(img_crop_filled))


def replace_img_with_sd(
    img: np.ndarray,
    mask: np.ndarray,
    text_prompt: str,
    step: int = 50,
    device="cuda",
) -> np.ndarray:
    """Keep the masked object and regenerate everything around it.

    ``img`` and ``mask`` are resized and padded with ``resize_and_pad``. The
    pipeline receives the *inverted* mask, so it repaints the area outside the
    object. The output is mapped back to the input size with ``recover_size``
    and blended so that pixels inside ``mask`` come from the original ``img``.

    Args:
        img: Input image array, unpacked below as ``(height, width, channels)``.
        mask: Object mask (assumed 0/255 uint8, since it is inverted via
            ``255 - mask``).
        text_prompt: Prompt describing the new surroundings.
        step: Number of denoising steps (``num_inference_steps``).
        device: Torch device the pipeline runs on.

    Returns:
        The blended image as a floating-point array (the ``/ 255`` below makes
        the blend float). Callers may need to cast to ``uint8``.
    """
    pipe = _get_sd_inpaint_pipeline(device)
    img_padded, mask_padded, padding_factors = resize_and_pad(img, mask)
    # The pipeline repaints white mask pixels; invert so the background is
    # regenerated and the object region is left for the blend below.
    img_generated = pipe(
        prompt=text_prompt,
        image=Image.fromarray(img_padded),
        mask_image=Image.fromarray(255 - mask_padded),
        num_inference_steps=step,
    ).images[0]
    height, width, _ = img.shape
    img_resized, mask_resized = recover_size(
        np.array(img_generated), mask_padded, (height, width), padding_factors)
    # Scale to [0, 1] and add a trailing channel axis so the mask broadcasts
    # over the colour channels (assumes recover_size returns a 2-D mask).
    mask_resized = np.expand_dims(mask_resized, -1) / 255
    # Generated pixels outside the mask, original pixels inside it.
    return img_resized * (1 - mask_resized) + img * mask_resized
def setup_args(parser):
    """Register this script's command-line options on ``parser``.

    Four required string options (input image, text prompt, mask glob and
    output directory) are added first. They are followed by an optional
    integer ``--seed`` and a boolean ``--deterministic`` flag.
    """
    required_str_options = (
        ("--input_img", "Path to a single input img"),
        ("--text_prompt", "Text prompt"),
        ("--input_mask_glob", "Glob to input masks"),
        ("--output_dir", "Output path to the directory with results."),
    )
    for flag, help_text in required_str_options:
        parser.add_argument(flag, type=str, required=True, help=help_text)

    parser.add_argument(
        "--seed", type=int, help="Specify seed for reproducibility.")
    parser.add_argument(
        "--deterministic", action="store_true",
        help="Use deterministic algorithms for reproducibility.")
def main(argv=None):
    r"""Inpaint every mask matched by ``--input_mask_glob`` with Stable Diffusion.

    Example usage:
        python stable_diffusion_inpaint.py \
            --input_img FA_demo/FA1_dog.png \
            --input_mask_glob "results/FA1_dog/mask*.png" \
            --text_prompt "a teddy bear on a bench" \
            --output_dir results

    For each mask, the filled image is written to
    ``<output_dir>/<input image stem>/filled_with_<mask file name>``.

    Args:
        argv: Command-line arguments without the program name. ``None``
            means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    setup_args(parser)
    args = parser.parse_args(sys.argv[1:] if argv is None else argv)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    if args.deterministic:
        # cuBLAS needs this workspace setting for deterministic kernels; it
        # must be in the environment before cuBLAS is first used.
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        torch.use_deterministic_algorithms(True)

    mask_ps = sorted(glob.glob(args.input_mask_glob))
    if not mask_ps:
        # Fail loudly instead of silently producing an empty output directory.
        parser.error(
            f"no mask files match --input_mask_glob {args.input_mask_glob!r}")

    img_stem = Path(args.input_img).stem
    out_dir = Path(args.output_dir) / img_stem
    out_dir.mkdir(parents=True, exist_ok=True)

    img = load_img_to_array(args.input_img)
    for mask_p in mask_ps:
        # Re-seed per mask so each output is reproducible on its own.
        if args.seed is not None:
            torch.manual_seed(args.seed)
        mask = load_img_to_array(mask_p)
        img_filled_p = out_dir / f"filled_with_{Path(mask_p).name}"
        img_filled = fill_img_with_sd(
            img, mask, args.text_prompt, device=device)
        save_array_to_img(img_filled, img_filled_p)


if __name__ == "__main__":
    main()