|
|
|
|
|
import os
from io import BytesIO

import numpy as np
import requests
import torch
from diffusers import DiffusionPipeline
from PIL import Image
from segment_anything import sam_model_registry, SamPredictor
|
|
|
|
|
""" |
|
Step 1: Download and preprocess example demo images |
|
""" |
|
def download_image(url, *, timeout=30):
    """Download an image over HTTP(S) and return it as an RGB ``PIL.Image``.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up. ``requests``
            applies no timeout by default, so without one a stalled server
            would hang the script indefinitely.

    Returns:
        The decoded image, converted to RGB.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status. This
            reports a failed download as such, instead of handing an error
            page to PIL and failing with a confusing decode error.
        requests.RequestException: On connection failures or timeouts.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")
|
|
|
|
|
# Demo assets: the scene to be edited and the example image to paint into it.
img_url = "https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/paint_by_example/input_image.png?raw=true"
example_url = "https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/paint_by_example/labrador_example.jpg?raw=true"

# Fetch the scene first, then the example, resizing each to 512x512.
init_image, example_image = (
    download_image(url).resize((512, 512)) for url in (img_url, example_url)
)
|
|
|
|
|
""" |
|
Step 2: Initialize SAM and PaintByExample models |
|
""" |
|
|
|
DEVICE = "cuda:1" |
|
|
|
|
|
SAM_ENCODER_VERSION = "vit_h" |
|
SAM_CHECKPOINT_PATH = "/comp_robot/rentianhe/code/Grounded-Segment-Anything/sam_vit_h_4b8939.pth" |
|
sam = sam_model_registry[SAM_ENCODER_VERSION](checkpoint=SAM_CHECKPOINT_PATH).to(device=DEVICE) |
|
sam_predictor = SamPredictor(sam) |
|
sam_predictor.set_image(np.array(init_image)) |
|
|
|
|
|
CACHE_DIR = "/comp_robot/rentianhe/weights/diffusers/" |
|
pipe = DiffusionPipeline.from_pretrained( |
|
"Fantasy-Studio/Paint-by-Example", |
|
torch_dtype=torch.float16, |
|
cache_dir=CACHE_DIR, |
|
) |
|
pipe = pipe.to(DEVICE) |
|
|
|
|
|
""" |
|
Step 3: Get masks with SAM by prompt (box or point) and inpaint the mask region by example image. |
|
""" |
|
|
|
input_point = np.array([[350, 256]]) |
|
input_label = np.array([1]) |
|
|
|
masks, _, _ = sam_predictor.predict( |
|
point_coords=input_point, |
|
point_labels=input_label, |
|
multimask_output=False |
|
) |
|
mask = masks[0] |
|
mask_pil = Image.fromarray(mask) |
|
|
|
mask_pil.save("./mask.jpg") |
|
|
|
image = pipe( |
|
image=init_image, |
|
mask_image=mask_pil, |
|
example_image=example_image, |
|
num_inference_steps=500, |
|
guidance_scale=9.0 |
|
).images[0] |
|
|
|
image.save("./paint_by_example_demo.jpg") |
|
|