nervn / playground /RePaint /repaint.py
mart9992's picture
m
b793f0c
raw
history blame
1.36 kB
from io import BytesIO
import torch
import PIL
import requests
from diffusers import RePaintPipeline, RePaintScheduler
def download_image(url):
response = requests.get(url)
return PIL.Image.open(BytesIO(response.content)).convert("RGB")
img_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/celeba_hq_256.png"
mask_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/mask_256.png"
# Load the original image and the mask as PIL images
original_image = download_image(img_url).resize((256, 256))
mask_image = download_image(mask_url).resize((256, 256))
# Load the RePaint scheduler and pipeline based on a pretrained DDPM model
DEVICE = "cuda:1"
CACHE_DIR = "/comp_robot/rentianhe/weights/diffusers/"
scheduler = RePaintScheduler.from_pretrained("google/ddpm-ema-celebahq-256", cache_dir=CACHE_DIR)
pipe = RePaintPipeline.from_pretrained("google/ddpm-ema-celebahq-256", scheduler=scheduler, cache_dir=CACHE_DIR)
pipe = pipe.to(DEVICE)
generator = torch.Generator(device=DEVICE).manual_seed(0)
output = pipe(
image=original_image,
mask_image=mask_image,
num_inference_steps=250,
eta=0.0,
jump_length=10,
jump_n_sample=10,
generator=generator,
)
inpainted_image = output.images[0]
inpainted_image.save("./repaint_demo.jpg")