Create app.py
app.py
ADDED
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamProcessor
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from typing import List

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load CLIPSeg model and processor (locates the damage and the vehicle from text prompts)
clip_processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
clip_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(device)

# Load SAM model and processor (segments the damaged area from point prompts)
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# Cache of the last SAM image embedding so repeated requests on the same image skip the image encoder
cache_data = None

# Text prompts used to segment the damaged area and the car
prompts = ['damaged', 'car']
damage_threshold = 0.4
vehicle_threshold = 0.5


def bbox_normalization(bbox, width, height):
    # Rescale a bbox from CLIPSeg's 352x352 output grid to the original image size
    height_coeff = height / 352
    width_coeff = width / 352
    normalized_bbox = [int(bbox[0] * width_coeff), int(bbox[1] * height_coeff),
                       int(bbox[2] * width_coeff), int(bbox[3] * height_coeff)]
    return normalized_bbox


def bbox_area(bbox):
    return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])


def segment_to_bbox(segment_indexes):
    # Tight bounding box [x_min, y_min, x_max, y_max] around all pixels labelled 1
    x_points = []
    y_points = []
    for y, row in enumerate(segment_indexes):
        for x, val in enumerate(row):
            if val == 1:
                x_points.append(x)
                y_points.append(y)
    if not x_points:
        # Nothing exceeded the threshold
        return [0, 0, 0, 0]
    return [np.min(x_points), np.min(y_points), np.max(x_points), np.max(y_points)]


def clipseg_prediction(image):
    # Returns (vehicle_found, SAM point prompts derived from the damage bbox)
    height, width = image.shape[0], image.shape[1]
    inputs = clip_processor(text=prompts, images=[image] * len(prompts),
                            padding="max_length", return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = clip_model(**inputs)
    preds = outputs.logits.unsqueeze(1)

    # Per-pixel probabilities for each prompt, flattened to (num_prompts, H*W)
    flat_preds = torch.sigmoid(preds.squeeze()).reshape((preds.shape[0], -1)).cpu()

    # Row 0 holds the threshold, row 1 the prediction; topk over the rows then
    # labels a pixel 1 only where the prediction beats its threshold
    flat_damage_preds_with_threshold = torch.full((2, flat_preds.shape[-1]), damage_threshold)
    flat_vehicle_preds_with_threshold = torch.full((2, flat_preds.shape[-1]), vehicle_threshold)
    flat_damage_preds_with_threshold[1:2, :] = flat_preds[0]   # damage
    flat_vehicle_preds_with_threshold[1:2, :] = flat_preds[1]  # vehicle

    damage_inds = torch.topk(flat_damage_preds_with_threshold, 1, dim=0).indices.reshape((preds.shape[-2], preds.shape[-1]))
    vehicle_inds = torch.topk(flat_vehicle_preds_with_threshold, 1, dim=0).indices.reshape((preds.shape[-2], preds.shape[-1]))

    # Bounding boxes in the 352x352 CLIPSeg grid
    damage_bbox = segment_to_bbox(damage_inds)
    vehicle_bbox = segment_to_bbox(vehicle_inds)

    # Vehicle check: the vehicle should cover a larger area than the damage,
    # and the damage region must be non-empty
    if bbox_area(vehicle_bbox) > bbox_area(damage_bbox) > 0:
        damage_bbox = bbox_normalization(damage_bbox, width, height)
        # Use the two corners of the damage bbox as SAM point prompts
        return True, [[damage_bbox[0], damage_bbox[1]], [damage_bbox[2], damage_bbox[3]]]
    else:
        return False, []


@torch.no_grad()
def forward_pass(image_input: np.ndarray, points: List[List[int]]) -> np.ndarray:
    # Run SAM with the given point prompts; returns boolean masks of shape (H, W, num_masks)
    global cache_data
    image_input = Image.fromarray(image_input)
    inputs = processor(image_input, input_points=[points], return_tensors="pt").to(device)
    # Recompute the image embedding only when the image changes
    if not cache_data or not torch.equal(inputs['pixel_values'], cache_data[0]):
        embedding = model.get_image_embeddings(inputs["pixel_values"])
        cache_data = [inputs["pixel_values"], embedding]
    del inputs["pixel_values"]

    outputs = model(image_embeddings=cache_data[1], **inputs)
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
    )
    # (num_masks, H, W) -> (H, W, num_masks)
    masks = masks[0].squeeze(0).numpy().transpose(1, 2, 0)

    return masks


def main_func(image_input):
    # gr.Image passes the uploaded image as a numpy array
    classification, points = clipseg_prediction(image_input)
    if classification:
        masks = forward_pass(image_input, points)

        # Overlay the first predicted mask on the original image
        final_mask = masks[:, :, 0]
        mask_colors = np.zeros((final_mask.shape[0], final_mask.shape[1], 3), dtype=np.uint8)
        mask_colors[final_mask, :] = np.array([[128, 0, 0]])
        overlay = (mask_colors * 0.6 + image_input * 0.4).astype('uint8')
        return Image.fromarray(overlay, 'RGB')
    else:
        # No vehicle (or no damage) detected: return the input unchanged
        return Image.fromarray(image_input)


def reset_data():
    # Clear the cached SAM embedding when a new image is uploaded
    global cache_data
    cache_data = None


with gr.Blocks() as demo:
    gr.Markdown("# Vehicle damage detection demo")
    gr.Markdown("""This app uses the CLIPSeg and SAM models to segment the damaged area of a vehicle in an image.""")
    with gr.Row():
        image_input = gr.Image()
        image_output = gr.Image()

    image_button = gr.Button("Segment Image", variant='primary')

    image_button.click(main_func, inputs=image_input, outputs=image_output)
    image_input.upload(reset_data)

demo.launch()
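
For a quick smoke test outside the Gradio UI, the pipeline can be called directly with the functions defined above. A minimal sketch, assuming it runs in the same interpreter session as the code above (before demo.launch()) and that a local test photo exists at test.jpg (hypothetical path):

# Smoke test for the pipeline above; "test.jpg" and "overlay.png" are hypothetical paths.
from PIL import Image
import numpy as np

test_image = np.array(Image.open("test.jpg").convert("RGB"))
result = main_func(test_image)   # PIL.Image: damage mask blended over the photo,
                                 # or the unchanged photo if no vehicle/damage is found
result.save("overlay.png")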