Goodnight7 commited on
Commit
e1f8fab
1 Parent(s): dc730ff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -0
app.py CHANGED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+ from diffusers import StableDiffusionInpaintPipeline
5
+ from PIL import Image
6
+ from segment_anything import SamPredictor, sam_model_registry
7
+
8
+
9
+ device= "cuda"
10
+ sam_checkpoint = "weights/sam_vit_b_01ec64.pth"
11
+ model_type = "vit_h"
12
+ sam = sam_model_registry[model_type](checkpoint= sam_checkpoint)
13
+ sam.to(device)
14
+
15
+ predictor= SamPredictor(sam)
16
+
17
+ pipe = StableDiffusionInpaintPipeline.from_pretrained(
18
+ "stabilityai/stable-diffusion-2-inpainting",
19
+ torch_dtype = torch.float16,
20
+ )
21
+
22
+
23
+ pipe = pipe.to(device)
24
+
25
+ selected_pixels = []
26
+ with gr.Blocks() as demo:
27
+ with gr.Row():
28
+ input_img= gr.Image(label= "Input")
29
+ mask_img = gr.Image(label = "Mask")
30
+ output_img= gr.Image(label = "Output")
31
+
32
+ with gr.Row():
33
+ prompt_text = gr.Textbox(lines=1, label= "Prompt")
34
+
35
+ with gr.Row():
36
+ submit = gr.Button("Submit")
37
+
38
+ def generate_mask(image, evt: gr.SelectData ):
39
+ selected_pixels.append(evt.index)
40
+ predictor.set(image)
41
+ input_points = np.array(selected_pixels)
42
+ input_label= np.ones(input_points.shape[0])
43
+ mask, _ , _ = predictor.predict(
44
+ point_coords= input_points,
45
+ point_label= input_label,
46
+ multimask_output = False,
47
+
48
+ )
49
+ #(1, szn sz) shape of mask
50
+ mask= Image.fromarray(mask[0 : , : ])
51
+
52
+
53
+ def inpaint(image, mask, prompt):
54
+ image = Image.fromarray(image)
55
+ mask = Image.fromarray(mask)
56
+
57
+ image= image.resize((512, 512))
58
+ image= image.resize((512, 512))
59
+ output = pipe (
60
+ prompt = prompt,
61
+ image= image,
62
+ mask_image= mask,
63
+
64
+ ).images[0]
65
+
66
+ return output
67
+
68
+ input_img.select(generate_mask, [input_img], [mask_img])
69
+
70
+ submit.click(inpaint, inputs= [input_img, mask_img, prompt_text],
71
+ outputs=[output_img],
72
+ )
73
+
74
+ if __name__ == "__main__":
75
+ demo.launch()
76
+
77
+