# Source: Hugging Face Space page (status shown: "Runtime error").
import gradio as gr
import numpy as np
import torch
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image
from segment_anything import SamPredictor, sam_model_registry

# Fall back to CPU when no GPU is available. Half precision is only used on
# CUDA, because float16 inference is poorly supported on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# The registry key must match the checkpoint's architecture. This is a ViT-B
# checkpoint, so building the "vit_h" model and loading it would fail.
sam_checkpoint = "weights/sam_vit_b_01ec64.pth"
model_type = "vit_b"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device)
predictor = SamPredictor(sam)

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=dtype,
)
pipe = pipe.to(device)

# Click points collected on the current input image.
# NOTE(review): this is module-level state, so it is shared by all concurrent
# users of the app; per-session gr.State would isolate them.
selected_pixels = []
def generate_mask(image, evt: gr.SelectData):
    """Segment the object under the clicked point(s) with SAM.

    Each click on the input image appends ``evt.index`` to the module-level
    ``selected_pixels`` list. All collected points are sent to SAM as
    foreground prompts, and the single predicted mask is returned.

    Args:
        image: The input image delivered by ``gr.Image``. Presumably an RGB
            ``uint8`` array of shape (H, W, 3); TODO confirm against the
            Gradio version in use.
        evt: Gradio select event. ``evt.index`` holds the clicked pixel
            coordinates, presumably in (x, y) order as SAM expects; verify.

    Returns:
        A grayscale ``PIL.Image``: 255 inside the mask, 0 outside.
    """
    selected_pixels.append(evt.index)
    # SAM must embed the image before it can answer point prompts.
    predictor.set_image(image)
    input_points = np.array(selected_pixels)
    # Label 1 marks every clicked point as foreground.
    input_labels = np.ones(input_points.shape[0], dtype=int)
    masks, _, _ = predictor.predict(
        point_coords=input_points,
        point_labels=input_labels,
        multimask_output=False,
    )
    # With multimask_output=False the result has shape (1, H, W). Take the
    # single binary mask and scale it to 0/255 so PIL builds an 8-bit image.
    mask = Image.fromarray(masks[0].astype(np.uint8) * 255)
    return mask


def inpaint(image, mask, prompt):
    """Repaint the masked region of ``image`` according to ``prompt``.

    Args:
        image: Input image array from the "Input" component.
        mask: Mask image array from the "Mask" component, as produced by
            ``generate_mask``. White pixels mark the area to repaint, per the
            diffusers inpainting convention.
        prompt: Text prompt for the Stable Diffusion inpainting pipeline.

    Returns:
        The first image generated by the pipeline.

    Raises:
        gr.Error: If no input image or no mask is available yet.
    """
    if image is None or mask is None:
        raise gr.Error("Upload an image and click on it to create a mask first.")
    image = Image.fromarray(image).resize((512, 512))
    # The mask must match the image size. NEAREST keeps its edges hard
    # instead of blending them into grey.
    mask = Image.fromarray(mask).resize((512, 512), Image.NEAREST)
    output = pipe(
        prompt=prompt,
        image=image,
        mask_image=mask,
    ).images[0]
    return output


def _reset_points():
    """Forget the click prompts collected for the previous input image."""
    selected_pixels.clear()


with gr.Blocks() as demo:
    with gr.Row():
        input_img = gr.Image(label="Input")
        mask_img = gr.Image(label="Mask")
        output_img = gr.Image(label="Output")
    with gr.Row():
        prompt_text = gr.Textbox(lines=1, label="Prompt")
    with gr.Row():
        submit = gr.Button("Submit")

    input_img.select(generate_mask, [input_img], [mask_img])
    # A new or cleared input image invalidates previously clicked points.
    input_img.change(_reset_points, None, None)
    submit.click(
        inpaint,
        inputs=[input_img, mask_img, prompt_text],
        outputs=[output_img],
    )
if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    demo.launch()