ahmetyaylalioglu commited on
Commit
9f8e4ba
1 Parent(s): 2b8eec6

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +61 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import numpy as np
4
+ from transformers import SamModel, SamProcessor
5
+ from diffusers import AutoPipelineForInpainting
6
+ import torch
7
+
8
+ # Model setup
9
+ device = "cuda"
10
+ model_name = "facebook/sam-vit-huge"
11
+ model = SamModel.from_pretrained(model_name).to(device)
12
+ processor = SamProcessor.from_pretrained(model_name)
13
+
14
+ def mask_to_rgb(mask):
15
+ bg_transparent = np.zeros(mask.shape + (4,), dtype=np.uint8)
16
+ bg_transparent[mask == 1] = [0, 255, 0, 127]
17
+ return bg_transparent
18
+
19
+ def get_processed_inputs(image, points):
20
+ input_points = [[list(map(int, point.split(',')))] for point in points.split('|') if point]
21
+ inputs = processor(image, input_points, return_tensors="pt").to(device)
22
+ with torch.no_grad():
23
+ outputs = model(**inputs)
24
+ masks = processor.image_processor.post_process_masks(
25
+ outputs.pred_masks.cpu(),
26
+ inputs["original_sizes"].cpu(),
27
+ inputs["reshaped_input_sizes"].cpu()
28
+ )
29
+ best_mask = masks[0][0][outputs.iou_scores.argmax()]
30
+ return ~best_mask.cpu().numpy()
31
+
32
+ def inpaint(raw_image, input_mask, prompt, negative_prompt=None, seed=74294536, cfgs=7):
33
+ mask_image = Image.fromarray(input_mask)
34
+ rand_gen = torch.manual_seed(seed)
35
+ pipeline = AutoPipelineForInpainting.from_pretrained(
36
+ "diffusers/stable-diffusion-xl-1.0-inpainting-0.1", torch_dtype=torch.float16
37
+ )
38
+ pipeline.enable_model_cpu_offload()
39
+ image = pipeline(
40
+ prompt=prompt,
41
+ image=raw_image,
42
+ mask_image=mask_image,
43
+ guidance_scale=cfgs,
44
+ negative_prompt=negative_prompt,
45
+ generator=rand_gen
46
+ ).images[0]
47
+ return image
48
+
49
+ # Gradio Interface with Click Events
50
+ def gradio_interface(image, points):
51
+ raw_image = Image.fromarray(image).convert("RGB").resize((512, 512))
52
+ mask = get_processed_inputs(raw_image, points)
53
+ processed_image = inpaint(raw_image, mask, "a car driving on Mars. Studio lights, 1970s", "artifacts, low quality, distortion")
54
+ return processed_image, mask_to_rgb(mask)
55
+
56
+ iface = gr.Interface(
57
+ fn=gradio_interface,
58
+ inputs=["image", gr.Image(shape=(512, 512), image_mode='RGB', source="canvas", tool="sketch")],
59
+ outputs=["image", "image"]
60
+ )
61
+ iface.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ transformers
4
+ diffusers
5
+ numpy
6
+ Pillow