# Source: vierundvi/playground — PaintByExample/sam_paint_by_example.py
# (commit 2cd560a "m" by mart9992; 2.38 kB)
# !pip install diffusers transformers git+https://github.com/facebookresearch/segment-anything.git
import os
from io import BytesIO

import numpy as np
import requests
import torch
from diffusers import DiffusionPipeline
from PIL import Image
from segment_anything import sam_model_registry, SamPredictor
"""
Step 1: Download and preprocess example demo images
"""
def download_image(url):
response = requests.get(url)
return Image.open(BytesIO(response.content)).convert("RGB")
# Scene to edit (init image) and the reference image whose content is painted
# into the masked region. Both are resized to 512x512 before use.
img_url = "https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/paint_by_example/input_image.png?raw=true"
# Alternative reference images that can be swapped in:
# example_url = "https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/paint_by_example/pomeranian_example.jpg?raw=True"
# example_url = "https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/paint_by_example/example_image.jpg?raw=true"
example_url = "https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/paint_by_example/labrador_example.jpg?raw=true"
init_image, example_image = [
    download_image(url).resize((512, 512)) for url in (img_url, example_url)
]
"""
Step 2: Initialize SAM and PaintByExample models
"""
DEVICE = "cuda:1"
# SAM
SAM_ENCODER_VERSION = "vit_h"
SAM_CHECKPOINT_PATH = "/comp_robot/rentianhe/code/Grounded-Segment-Anything/sam_vit_h_4b8939.pth"
sam = sam_model_registry[SAM_ENCODER_VERSION](checkpoint=SAM_CHECKPOINT_PATH).to(device=DEVICE)
sam_predictor = SamPredictor(sam)
sam_predictor.set_image(np.array(init_image))
# PaintByExample Pipeline
# The default cache location is machine-specific. Set PAINT_BY_EXAMPLE_CACHE_DIR
# to download/reuse the weights elsewhere.
CACHE_DIR = os.environ.get(
    "PAINT_BY_EXAMPLE_CACHE_DIR", "/comp_robot/rentianhe/weights/diffusers/"
)
pipe = DiffusionPipeline.from_pretrained(
    "Fantasy-Studio/Paint-by-Example",
    # Half precision is kept for GPUs (the original setting). Many fp16 ops
    # are unsupported on CPU in PyTorch, so fall back to float32 there.
    torch_dtype=torch.float32 if torch.device(DEVICE).type == "cpu" else torch.float16,
    cache_dir=CACHE_DIR,
)
pipe = pipe.to(DEVICE)
"""
Step 3: Get masks with SAM by prompt (box or point) and inpaint the mask region by example image.
"""
input_point = np.array([[350, 256]])
input_label = np.array([1]) # positive label
masks, _, _ = sam_predictor.predict(
point_coords=input_point,
point_labels=input_label,
multimask_output=False
)
mask = masks[0] # [1, 512, 512] to [512, 512] np.ndarray
mask_pil = Image.fromarray(mask)
mask_pil.save("./mask.jpg")
image = pipe(
image=init_image,
mask_image=mask_pil,
example_image=example_image,
num_inference_steps=500,
guidance_scale=9.0
).images[0]
image.save("./paint_by_example_demo.jpg")