# SAMSD / app.py
# Goodnight7's picture
# Update app.py
# e1f8fab verified
# raw
# history blame
# 1.98 kB
import gradio as gr
import numpy as np
import torch
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image
from segment_anything import SamPredictor, sam_model_registry
# Prefer the GPU, but fall back to the CPU so the module can still be
# imported and run on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Segment Anything: the registry key must name the same backbone as the
# checkpoint. "sam_vit_b_*.pth" holds ViT-B weights, so the key is "vit_b";
# building a "vit_h" model would not match the checkpoint's state dict.
sam_checkpoint = "weights/sam_vit_b_01ec64.pth"
model_type = "vit_b"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device)
predictor = SamPredictor(sam)

# Stable Diffusion 2 inpainting pipeline. Half precision is only requested
# when running on CUDA; on CPU the pipeline is loaded in float32.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
pipe = pipe.to(device)

# Click coordinates collected by the UI's select handler. This list is
# module-level state and is never cleared, so points accumulate for the
# lifetime of the process.
selected_pixels = []
# UI definition: three image panes (input, mask, output), a prompt box and a
# submit button. The handlers are defined inside the Blocks context so the
# event wiring at the bottom is registered on `demo`.
with gr.Blocks() as demo:
    with gr.Row():
        input_img = gr.Image(label="Input")
        mask_img = gr.Image(label="Mask")
        output_img = gr.Image(label="Output")
    with gr.Row():
        prompt_text = gr.Textbox(lines=1, label="Prompt")
    with gr.Row():
        submit = gr.Button("Submit")

    def generate_mask(image, evt: gr.SelectData):
        """Segment the object under the clicked point(s) and return its mask.

        Args:
            image: The input image as delivered by the ``gr.Image`` component.
            evt: Gradio select event; ``evt.index`` is the clicked pixel.

        Returns:
            A PIL image built from the predicted mask, shown in the "Mask"
            pane.
        """
        # Every click is treated as a foreground point and added to the
        # shared list, so repeated clicks refine the same selection.
        # NOTE(review): the list is never reset, so points also carry over
        # when a new image is uploaded — confirm this is intended.
        selected_pixels.append(evt.index)
        predictor.set_image(image)
        input_points = np.array(selected_pixels)
        input_labels = np.ones(input_points.shape[0])
        mask, _, _ = predictor.predict(
            point_coords=input_points,
            point_labels=input_labels,
            multimask_output=False,
        )
        # With multimask_output=False the mask has a leading axis of size 1;
        # drop it so PIL receives a 2-D array.
        return Image.fromarray(mask[0, :, :])

    def inpaint(image, mask, prompt):
        """Repaint the masked region of ``image`` according to ``prompt``.

        Args:
            image: Input image array from the "Input" pane.
            mask: Mask array from the "Mask" pane.
            prompt: Text prompt describing the desired fill.

        Returns:
            The first image produced by the inpainting pipeline.
        """
        # The image and the mask must have the same size, so both are
        # resized to 512x512 before being handed to the pipeline.
        image = Image.fromarray(image).resize((512, 512))
        mask = Image.fromarray(mask).resize((512, 512))
        output = pipe(
            prompt=prompt,
            image=image,
            mask_image=mask,
        ).images[0]
        return output

    # Clicking on the input image produces a mask; the button runs inpainting.
    input_img.select(generate_mask, [input_img], [mask_img])
    submit.click(
        inpaint,
        inputs=[input_img, mask_img, prompt_text],
        outputs=[output_img],
    )
# Start the Gradio server only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    demo.launch()