Abijith committed on
Commit
05a002a
1 Parent(s): c02b12d

Create app.py

Files changed (1)
  1. app.py +132 -0
app.py ADDED
@@ -0,0 +1,132 @@
+ import gradio as gr
+ import numpy as np
+ import torch
+ from PIL import Image
+ from transformers import SamModel, SamProcessor
+ from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
+ from typing import List
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Load CLIPSeg model and processor (locates the damaged area and the car from text prompts)
+ clip_processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
+ clip_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(device)
+
+ # Load SAM model and processor (refines the damage box into a pixel mask)
+ model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
+ processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
+
+ # Cache of [pixel_values, image_embedding] so repeated runs on the same image
+ # do not recompute the SAM image embedding
+ cache_data = None
+
+ # Text prompts used to segment the damaged area and the car
+ prompts = ['damaged', 'car']
+ damage_threshold = 0.4
+ vehicle_threshold = 0.5
+
+ def bbox_normalization(bbox, width, height):
+     # CLIPSeg logits live on a 352x352 grid; rescale the box back to the
+     # original image resolution
+     height_coeff = height / 352
+     width_coeff = width / 352
+     normalized_bbox = [int(bbox[0] * width_coeff), int(bbox[1] * height_coeff),
+                        int(bbox[2] * width_coeff), int(bbox[3] * height_coeff)]
+     return normalized_bbox
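+ # Example: for a 704x704 input (scale factor 2 on each axis), the 352-grid
+ # box [100, 50, 200, 150] maps back to [200, 100, 400, 300].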
+
+ def bbox_area(bbox):
+     # Area of an [x_min, y_min, x_max, y_max] box
+     return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
+
+ def segment_to_bbox(segment_indices):
+     # Gather every pixel labelled 1 and wrap them all in a single bounding box
+     x_points = []
+     y_points = []
+     for y, row in enumerate(segment_indices):
+         for x, val in enumerate(row):
+             if val == 1:
+                 x_points.append(x)
+                 y_points.append(y)
+     if not x_points:
+         # Nothing cleared the threshold; return a degenerate zero-area box
+         return [0, 0, 0, 0]
+     return [np.min(x_points), np.min(y_points), np.max(x_points), np.max(y_points)]
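+ # An equivalent vectorized form (a sketch; assumes segment_indices is a
+ # 2-D array/tensor of 0/1 labels):
+ #     ys, xs = np.where(np.asarray(segment_indices) == 1)
+ #     return [xs.min(), ys.min(), xs.max(), ys.max()]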
+
+ def clipseg_prediction(image):
+     height, width = image.shape[0], image.shape[1]
+     inputs = clip_processor(text=prompts, images=[image] * len(prompts), padding="max_length", return_tensors="pt").to(device)
+     # Predict one heatmap per prompt
+     with torch.no_grad():
+         outputs = clip_model(**inputs)
+     preds = outputs.logits.unsqueeze(1)
+     # Flatten the per-prompt heatmaps so each pixel can be thresholded
+     flat_preds = torch.sigmoid(preds.squeeze()).reshape((preds.shape[0], -1)).cpu()
+
+     # Stack a dummy "unlabeled" row holding the threshold under each prediction;
+     # topk over the two rows then labels a pixel 1 only where the prediction
+     # beats its threshold
+     flat_damage_preds_with_threshold = torch.full((2, flat_preds.shape[-1]), damage_threshold)
+     flat_vehicle_preds_with_threshold = torch.full((2, flat_preds.shape[-1]), vehicle_threshold)
+     flat_damage_preds_with_threshold[1:2, :] = flat_preds[0]   # damage
+     flat_vehicle_preds_with_threshold[1:2, :] = flat_preds[1]  # vehicle
+
+     # Get the winning row index for each pixel
+     damage_inds = torch.topk(flat_damage_preds_with_threshold, 1, dim=0).indices.reshape((preds.shape[-2], preds.shape[-1]))
+     vehicle_inds = torch.topk(flat_vehicle_preds_with_threshold, 1, dim=0).indices.reshape((preds.shape[-2], preds.shape[-1]))
+
+     # Turn both segmentations into bounding boxes on the 352x352 grid
+     damage_bbox = segment_to_bbox(damage_inds)
+     vehicle_bbox = segment_to_bbox(vehicle_inds)
+
+     # Accept only if the vehicle region dominates the damage region,
+     # i.e. the image plausibly shows a car with localized damage
+     if bbox_area(vehicle_bbox) > bbox_area(damage_bbox) > 0:
+         return True, bbox_normalization(damage_bbox, width, height)
+     return False, []
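+ # E.g. with damage_threshold = 0.4, a pixel scoring 0.55 gives the column
+ # [0.4, 0.55] -> topk index 1 ("damaged"), while a pixel scoring 0.3 gives
+ # [0.4, 0.3] -> index 0 ("unlabeled").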
+
+
+ @torch.no_grad()
+ def forward_pass(image_input: np.ndarray, bbox: List[int]) -> np.ndarray:
+     global cache_data
+     image_input = Image.fromarray(image_input)
+     # Prompt SAM with the damage bounding box found by CLIPSeg
+     inputs = processor(image_input, input_boxes=[[bbox]], return_tensors="pt").to(device)
+     # Recompute the image embedding only when the image actually changed
+     if not cache_data or not torch.equal(inputs['pixel_values'], cache_data[0]):
+         embedding = model.get_image_embeddings(inputs["pixel_values"])
+         cache_data = [inputs["pixel_values"], embedding]
+     # SAM takes either pixel_values or a precomputed embedding, not both
+     del inputs["pixel_values"]
+
+     outputs = model(image_embeddings=cache_data[1], **inputs)
+     masks = processor.image_processor.post_process_masks(
+         outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
+     )
+     # (num_masks, H, W) -> (H, W, num_masks) boolean array
+     masks = masks[0].squeeze(0).numpy().transpose(1, 2, 0)
+     return masks
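+ # Standalone usage (a sketch; "car.jpg" and the box values are assumptions):
+ #     img = np.array(Image.open("car.jpg").convert("RGB"))
+ #     masks = forward_pass(img, [120, 80, 360, 240])
+ #     print(masks.shape)  # (H, W, 3): SAM proposes three candidate masks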
+
+
+ def main_func(image_input: np.ndarray) -> Image.Image:
+     # gr.Image hands the uploaded image over as a numpy array
+     classification, bbox = clipseg_prediction(image_input)
+     if not classification:
+         # No vehicle found; return the input unchanged
+         return Image.fromarray(image_input)
+
+     masks = forward_pass(image_input, bbox)
+     # Keep the first of SAM's candidate masks
+     final_mask = masks[:, :, 0]
+     # Paint the mask dark red and blend it over the original image
+     mask_colors = np.zeros((final_mask.shape[0], final_mask.shape[1], 3), dtype=np.uint8)
+     mask_colors[final_mask, :] = np.array([128, 0, 0])
+     blended = (mask_colors * 0.6 + image_input * 0.4).astype('uint8')
+     return Image.fromarray(blended, 'RGB')
+
+
+ def reset_data():
+     # Drop the cached SAM embedding when a new image is uploaded
+     global cache_data
+     cache_data = None
+
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# Vehicle damage detection demo")
+     gr.Markdown("""This app uses the CLIPSeg model to locate the damaged area of a vehicle and the SAM model to segment it.""")
+     with gr.Row():
+         image_input = gr.Image()
+         image_output = gr.Image()
+
+     image_button = gr.Button("Segment Image", variant='primary')
+
+     image_button.click(main_func, inputs=image_input, outputs=image_output)
+     image_input.upload(reset_data)
+
+ demo.launch()
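+ # To try it locally (a sketch; the package list is an assumption):
+ #     pip install gradio torch transformers pillow numpy
+ #     python app.py
+ # demo.launch() serves the UI on http://127.0.0.1:7860 by default.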