Aastha committed
Commit
f1bc325
β€’
1 Parent(s): 909e5f1

Ensemble Kosmos2

Files changed (1)
  1. app.py +375 -96
app.py CHANGED
@@ -1,105 +1,384 @@
  import torch
  from efficientnet_pytorch import EfficientNet
  from torchvision import transforms
  from PIL import Image
  import gradio as gr
  from super_gradients.training import models
- import cv2
- import numpy as np

- device = torch.device("cpu")
-
- # Load the YOLO-NAS model
- yolo_nas_l = models.get("yolo_nas_l", pretrained_weights="coco")
-
- def bounding_boxes_overlap(box1, box2):
-     """Check if two bounding boxes overlap or touch."""
-     x1, y1, x2, y2 = box1
-     x3, y3, x4, y4 = box2
-     return not (x3 > x2 or x4 < x1 or y3 > y2 or y4 < y1)
-
- def merge_boxes(box1, box2):
-     """Return the encompassing bounding box of two boxes."""
-     x1, y1, x2, y2 = box1
-     x3, y3, x4, y4 = box2
-     x = min(x1, x3)
-     y = min(y1, y3)
-     w = max(x2, x4)
-     h = max(y2, y4)
-     return (x, y, w, h)
-
- def save_merged_boxes(predictions, image_np):
-     """Save merged bounding boxes as separate images."""
-     processed_boxes = set()
-     roi = None  # Initialize roi to None
-
-     for image_prediction in predictions:
-         bboxes = image_prediction.prediction.bboxes_xyxy
-         for box1 in bboxes:
-             for box2 in bboxes:
-                 if np.array_equal(box1, box2):
-                     continue
-                 if bounding_boxes_overlap(box1, box2) and tuple(box1) not in processed_boxes and tuple(box2) not in processed_boxes:
-                     merged_box = merge_boxes(box1, box2)
-                     roi = image_np[int(merged_box[1]):int(merged_box[3]), int(merged_box[0]):int(merged_box[2])]
-                     processed_boxes.add(tuple(box1))
-                     processed_boxes.add(tuple(box2))
-                     break  # Exit the inner loop once a match is found
-             if roi is not None:
-                 break  # Exit the outer loop once a match is found
-
-     return roi
-
- # Load the EfficientNet model
- def load_model(model_path):
-     model = torch.load(model_path)
-     model = model.to(device)
-     model.eval()  # Set the model to evaluation mode
-     return model
-
- # Perform inference on an image
- def predict_image(image, model):
-     # First, get the ROI using YOLO-NAS
-     image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
-     predictions = yolo_nas_l.predict(image_np, iou=0.3, conf=0.35)
-     roi_new = save_merged_boxes(predictions, image_np)

-     if roi_new is None:
-         roi_new = image_np  # Use the original image if no ROI is found
-
-     # Convert ROI back to PIL Image for EfficientNet
-     roi_image = Image.fromarray(cv2.cvtColor(roi_new, cv2.COLOR_BGR2RGB))
-
-     # Define the image transformations
-     transform = transforms.Compose([
-         transforms.Resize(256),
-         transforms.CenterCrop(224),
-         transforms.ToTensor(),
-         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
-     ])
-
-     # Convert PIL Image to Tensor
-     roi_image_tensor = transform(roi_image).unsqueeze(0).to(device)
-
-     with torch.no_grad():
-         outputs = model(roi_image_tensor)
-         _, predicted = outputs.max(1)
-         prediction_text = 'Accident' if predicted.item() == 0 else 'No accident'

-     return roi_image, prediction_text  # Return both the roi_image and the prediction text
-
- # Load the EfficientNet model outside the function to avoid loading it multiple times
- model_path = 'vehicle.pt'
- model = load_model(model_path)
-
- # Gradio UI
- title = "Vehicle Collision Classification"
- description = "Upload an image to determine if it depicts a vehicle accident. Powered by EfficientNet."
- examples = [["roi_none.png"], ["test2.jpeg"]]
-
- gr.Interface(fn=lambda img: predict_image(img, model),
-              inputs=gr.inputs.Image(type="pil"),
-              outputs=[gr.outputs.Image(type="pil"), "text"],  # Updated outputs to handle both image and text
-              title=title,
-              description=description,
-              examples=examples).launch()
+ import gradio as gr
+ import random
+ import numpy as np
+ import os
+ import requests
+ import torch
+ import torchvision.transforms as T
+ from PIL import Image
+ from transformers import AutoProcessor, AutoModelForVision2Seq
+ import cv2
+ import ast
  import torch
  from efficientnet_pytorch import EfficientNet
  from torchvision import transforms
  from PIL import Image
  import gradio as gr
  from super_gradients.training import models

+
+ class Kosmos2:
+     def __init__(self):
+         self.colors = [
+             (0, 255, 0),
+             (0, 0, 255),
+             (255, 255, 0),
+             (255, 0, 255),
+             (0, 255, 255),
+             (114, 128, 250),
+             (0, 165, 255),
+             (0, 128, 0),
+             (144, 238, 144),
+             (238, 238, 175),
+             (255, 191, 0),
+             (0, 128, 0),
+             (226, 43, 138),
+             (255, 0, 255),
+             (0, 215, 255),
+             (255, 0, 0),
+         ]
+
+         self.color_map = {
+             f"{color_id}": f"#{hex(color[2])[2:].zfill(2)}{hex(color[1])[2:].zfill(2)}{hex(color[0])[2:].zfill(2)}" for color_id, color in enumerate(self.colors)
+         }
+
+         self.ckpt = "ydshieh/kosmos-2-patch14-224"
+         self.model = AutoModelForVision2Seq.from_pretrained(self.ckpt, trust_remote_code=True).to("cuda")
+         self.processor = AutoProcessor.from_pretrained(self.ckpt, trust_remote_code=True)
+
+     def is_overlapping(self, rect1, rect2):
+         x1, y1, x2, y2 = rect1
+         x3, y3, x4, y4 = rect2
+         return not (x2 < x3 or x1 > x4 or y2 < y3 or y1 > y4)
+
+     def draw_entity_boxes_on_image(self, image, entities, show=False, save_path=None, entity_index=-1):
+         """_summary_
+         Args:
+             image (_type_): image or image path
+             collect_entity_location (_type_): _description_
+         """
+         if isinstance(image, Image.Image):
+             image_h = image.height
+             image_w = image.width
+             image = np.array(image)[:, :, [2, 1, 0]]
+         elif isinstance(image, str):
+             if os.path.exists(image):
+                 pil_img = Image.open(image).convert("RGB")
+                 image = np.array(pil_img)[:, :, [2, 1, 0]]
+                 image_h = pil_img.height
+                 image_w = pil_img.width
+             else:
+                 raise ValueError(f"invalid image path, {image}")
+         elif isinstance(image, torch.Tensor):
+             # pdb.set_trace()
+             image_tensor = image.cpu()
+             reverse_norm_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])[:, None, None]
+             reverse_norm_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])[:, None, None]
+             image_tensor = image_tensor * reverse_norm_std + reverse_norm_mean
+             pil_img = T.ToPILImage()(image_tensor)
+             image_h = pil_img.height
+             image_w = pil_img.width
+             image = np.array(pil_img)[:, :, [2, 1, 0]]
+         else:
+             raise ValueError(f"invalid image format, {type(image)} for {image}")
+
+         if len(entities) == 0:
+             return image
+
+         indices = list(range(len(entities)))
+         if entity_index >= 0:
+             indices = [entity_index]
+
+         # Not to show too many bboxes
+         entities = entities[:len(self.color_map)]
+
+         new_image = image.copy()
+         previous_bboxes = []
+         # size of text
+         text_size = 1
+         # thickness of text
+         text_line = 1  # int(max(1 * min(image_h, image_w) / 512, 1))
+         box_line = 3
+         (c_width, text_height), _ = cv2.getTextSize("F", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line)
+         base_height = int(text_height * 0.675)
+         text_offset_original = text_height - base_height
+         text_spaces = 3
+
+         # num_bboxes = sum(len(x[-1]) for x in entities)
+         used_colors = self.colors  # random.sample(colors, k=num_bboxes)
+
+         color_id = -1
+         for entity_idx, (entity_name, (start, end), bboxes) in enumerate(entities):
+             color_id += 1
+             if entity_idx not in indices:
+                 continue
+             for bbox_id, (x1_norm, y1_norm, x2_norm, y2_norm) in enumerate(bboxes):
+                 # if start is None and bbox_id > 0:
+                 #     color_id += 1
+                 orig_x1, orig_y1, orig_x2, orig_y2 = int(x1_norm * image_w), int(y1_norm * image_h), int(x2_norm * image_w), int(y2_norm * image_h)
+
+                 # draw bbox
+                 # random color
+                 color = used_colors[color_id]  # tuple(np.random.randint(0, 255, size=3).tolist())
+                 new_image = cv2.rectangle(new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line)
+
+                 l_o, r_o = box_line // 2 + box_line % 2, box_line // 2 + box_line % 2 + 1
+
+                 x1 = orig_x1 - l_o
+                 y1 = orig_y1 - l_o
+
+                 if y1 < text_height + text_offset_original + 2 * text_spaces:
+                     y1 = orig_y1 + r_o + text_height + text_offset_original + 2 * text_spaces
+                     x1 = orig_x1 + r_o
+
+                 # add text background
+                 (text_width, text_height), _ = cv2.getTextSize(f" {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line)
+                 text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = x1, y1 - (text_height + text_offset_original + 2 * text_spaces), x1 + text_width, y1
+
+                 for prev_bbox in previous_bboxes:
+                     while self.is_overlapping((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox):
+                         text_bg_y1 += (text_height + text_offset_original + 2 * text_spaces)
+                         text_bg_y2 += (text_height + text_offset_original + 2 * text_spaces)
+                         y1 += (text_height + text_offset_original + 2 * text_spaces)
+
+                         if text_bg_y2 >= image_h:
+                             text_bg_y1 = max(0, image_h - (text_height + text_offset_original + 2 * text_spaces))
+                             text_bg_y2 = image_h
+                             y1 = image_h
+                             break
+
+                 alpha = 0.5
+                 for i in range(text_bg_y1, text_bg_y2):
+                     for j in range(text_bg_x1, text_bg_x2):
+                         if i < image_h and j < image_w:
+                             if j < text_bg_x1 + 1.35 * c_width:
+                                 # original color
+                                 bg_color = color
+                             else:
+                                 # white
+                                 bg_color = [255, 255, 255]
+                             new_image[i, j] = (alpha * new_image[i, j] + (1 - alpha) * np.array(bg_color)).astype(np.uint8)
+
+                 cv2.putText(
+                     new_image, f" {entity_name}", (x1, y1 - text_offset_original - 1 * text_spaces), cv2.FONT_HERSHEY_COMPLEX, text_size, (0, 0, 0), text_line, cv2.LINE_AA
+                 )
+                 # previous_locations.append((x1, y1))
+                 previous_bboxes.append((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2))
+
+         pil_image = Image.fromarray(new_image[:, :, [2, 1, 0]])
+         if save_path:
+             pil_image.save(save_path)
+         if show:
+             pil_image.show()
+
+         return pil_image
+
+     def generate_predictions(self, image_input, text_input):
+
+         # Save the image and load it again to match the original Kosmos-2 demo.
+         # (https://github.com/microsoft/unilm/blob/f4695ed0244a275201fff00bee495f76670fbe70/kosmos-2/demo/gradio_app.py#L345-L346)
+         user_image_path = "/tmp/user_input_test_image.jpg"
+         image_input.save(user_image_path)
+         # This might give different results from the original argument `image_input`
+         image_input = Image.open(user_image_path)
+
+         if text_input == "Brief":
+             text_input = "<grounding>An image of"
+         elif text_input == "Detailed":
+             text_input = "<grounding>Describe this image in detail:"
+         else:
+             text_input = f"<grounding>{text_input}"
+
+         inputs = self.processor(text=text_input, images=image_input, return_tensors="pt")
+
+         generated_ids = self.model.generate(
+             pixel_values=inputs["pixel_values"].to("cuda"),
+             input_ids=inputs["input_ids"][:, :-1].to("cuda"),
+             attention_mask=inputs["attention_mask"][:, :-1].to("cuda"),
+             img_features=None,
+             img_attn_mask=inputs["img_attn_mask"][:, :-1].to("cuda"),
+             use_cache=True,
+             max_new_tokens=128,
+         )
+         generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+         # By default, the generated text is cleaned up and the entities are extracted.
+         processed_text, entities = self.processor.post_process_generation(generated_text)
+
+         annotated_image = self.draw_entity_boxes_on_image(image_input, entities, show=False)
+
+         color_id = -1
+         entity_info = []
+         filtered_entities = []
+         for entity in entities:
+             entity_name, (start, end), bboxes = entity
+             if start == end:
+                 # skip bounding bbox without a `phrase` associated
+                 continue
+             color_id += 1
+             # for bbox_id, _ in enumerate(bboxes):
+             #     if start is None and bbox_id > 0:
+             #         color_id += 1
+             entity_info.append(((start, end), color_id))
+             filtered_entities.append(entity)
+
+         colored_text = []
+         prev_start = 0
+         end = 0
+         for idx, ((start, end), color_id) in enumerate(entity_info):
+             if start > prev_start:
+                 colored_text.append((processed_text[prev_start:start], None))
+             colored_text.append((processed_text[start:end], f"{color_id}"))
+             prev_start = end
+
+         if end < len(processed_text):
+             colored_text.append((processed_text[end:len(processed_text)], None))
+
+         return annotated_image, colored_text, str(filtered_entities)
+
+ class VehiclePredictor:
+     def __init__(self, model_path):
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.yolo_nas_l = models.get("yolo_nas_l", pretrained_weights="coco")
+         self.classifier_model = torch.load(model_path)
+         self.classifier_model = self.classifier_model.to(self.device)
+         self.classifier_model.eval()  # Set the model to evaluation mode
+
+     def bounding_boxes_overlap(self, box1, box2):
+         """Check if two bounding boxes overlap or touch."""
+         x1, y1, x2, y2 = box1
+         x3, y3, x4, y4 = box2
+         return not (x3 > x2 or x4 < x1 or y3 > y2 or y4 < y1)

+     def merge_boxes(self, box1, box2):
+         """Return the encompassing bounding box of two boxes."""
+         x1, y1, x2, y2 = box1
+         x3, y3, x4, y4 = box2
+         x = min(x1, x3)
+         y = min(y1, y3)
+         w = max(x2, x4)
+         h = max(y2, y4)
+         return (x, y, w, h)

+     def save_merged_boxes(self, predictions, image_np):
+         """Save merged bounding boxes as separate images."""
+         processed_boxes = set()
+         roi = None  # Initialize roi to None
+
+         for image_prediction in predictions:
+             bboxes = image_prediction.prediction.bboxes_xyxy
+             for box1 in bboxes:
+                 for box2 in bboxes:
+                     if np.array_equal(box1, box2):
+                         continue
+                     if self.bounding_boxes_overlap(box1, box2) and tuple(box1) not in processed_boxes and tuple(box2) not in processed_boxes:
+                         merged_box = self.merge_boxes(box1, box2)
+                         roi = image_np[int(merged_box[1]):int(merged_box[3]), int(merged_box[0]):int(merged_box[2])]
+                         processed_boxes.add(tuple(box1))
+                         processed_boxes.add(tuple(box2))
+                         break  # Exit the inner loop once a match is found
+                 if roi is not None:
+                     break  # Exit the outer loop once a match is found
+
+         return roi
+
+     # Perform inference on an image
+     def predict_image(self, image, model):
+         # First, get the ROI using YOLO-NAS
+         image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
+         predictions = self.yolo_nas_l.predict(image_np, iou=0.3, conf=0.35)
+         roi_new = self.save_merged_boxes(predictions, image_np)
+
+         if roi_new is None:
+             roi_new = image_np  # Use the original image if no ROI is found
+
+         # Convert ROI back to PIL Image for EfficientNet
+         roi_image = Image.fromarray(cv2.cvtColor(roi_new, cv2.COLOR_BGR2RGB))
+
+         # Define the image transformations
+         transform = transforms.Compose([
+             transforms.Resize(256),
+             transforms.CenterCrop(224),
+             transforms.ToTensor(),
+             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+         ])
+
+         # Convert PIL Image to Tensor
+         roi_image_tensor = transform(roi_image).unsqueeze(0).to(self.device)
+
+         with torch.no_grad():
+             outputs = self.classifier_model(roi_image_tensor)
+             _, predicted = outputs.max(1)
+             prediction_text = 'Accident' if predicted.item() == 0 else 'No accident'
+
+         return roi_image, prediction_text  # Return both the roi_image and the prediction text
+
+
+ def main():
+     kosmos2 = Kosmos2()
+     vehicle_predictor = VehiclePredictor('vehicle.pt')
+
+     with gr.Blocks(title="Advanced Vehicle Contextualization & Collision Prediction", theme=gr.themes.Base()).queue() as demo:
+         gr.Markdown(("""
+         # Models used -
+         Kosmos-2: Grounding Multimodal Large Language Models to the World
+         [[Paper]](https://arxiv.org/abs/2306.14824) [[Code]](https://github.com/microsoft/unilm/blob/master/kosmos-2)
+         YOLO-NAS [[Code]](https://github.com/Deci-AI/super-gradients/blob/master/YOLONAS.md)
+         EfficientNet-b0
+         """))
+         with gr.Row():
+             with gr.Column():
+                 image_input = gr.Image(type="pil", label="Test Image")
+                 text_input = gr.Radio(["Brief", "Detailed"], label="Description Type", value="Brief")
+                 run_button = gr.Button(label="Run", visible=True)
+
+             with gr.Column():
+                 image_output_kosmos = gr.Image(type="pil", label="Kosmos-2 Output Image")
+                 text_output_kosmos = gr.HighlightedText(
+                     label="Generated Description by Kosmos-2",
+                     combine_adjacent=False,
+                     show_legend=True,
+                 ).style(color_map=kosmos2.color_map)
+
+                 image_output_vehicle = gr.Image(type="pil", label="Collision Predictor Output Image", size=(112, 112))
+                 text_output_vehicle = gr.Textbox(label="Collision Predictor Result")
+
+         # record which text span (label) is selected
+         selected = gr.Number(-1, show_label=False, placeholder="Selected", visible=False)
+
+         # record the current `entities`
+         entity_output = gr.Textbox(visible=False)
+
+         # get the current selected span label
+         def get_text_span_label(evt: gr.SelectData):
+             if evt.value[-1] is None:
+                 return -1
+             return int(evt.value[-1])
+         # and set this information to `selected`
+         text_output_kosmos.select(get_text_span_label, None, selected)
+
+         # update output image when we change the span (entity) selection
+         def update_output_image(img_input, image_output, entities, idx):
+             entities = ast.literal_eval(entities)
+             updated_image = kosmos2.draw_entity_boxes_on_image(img_input, entities, entity_index=idx)
+             return updated_image
+         selected.change(update_output_image, [image_input, image_output_kosmos, entity_output, selected], [image_output_kosmos])
+
+         def combined_predictions(img, description_type):
+             # Kosmos2 predictions
+             kosmos_image, kosmos_text, entities = kosmos2.generate_predictions(img, description_type)
+
+             # VehiclePredictor predictions
+             vehicle_image, vehicle_text = vehicle_predictor.predict_image(img, vehicle_predictor.classifier_model)
+
+             return kosmos_image, kosmos_text, entities, vehicle_image, vehicle_text
+
+         run_button.click(fn=combined_predictions,
+                          inputs=[image_input, text_input],
+                          outputs=[image_output_kosmos, text_output_kosmos, entity_output, image_output_vehicle, text_output_vehicle],
+                          show_progress=True, queue=True)
+
+     demo.launch(share=True)
+
+ if __name__ == "__main__":
+     main()
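
For reference, a minimal sketch of how the two classes added in this commit compose outside the Gradio UI, mirroring what `combined_predictions` does. It assumes a CUDA-capable GPU (the Kosmos-2 wrapper moves its model to "cuda"), the `vehicle.pt` checkpoint next to app.py, and a hypothetical local image path; none of this is part of the committed file.

    from PIL import Image
    from app import Kosmos2, VehiclePredictor  # assumes this file is importable as app.py

    # Hypothetical smoke test; "some_test_image.jpg" stands in for any local image.
    kosmos2 = Kosmos2()                                  # loads ydshieh/kosmos-2-patch14-224 onto the GPU
    vehicle_predictor = VehiclePredictor("vehicle.pt")   # YOLO-NAS ROI finder + EfficientNet collision classifier

    img = Image.open("some_test_image.jpg").convert("RGB")

    annotated, colored_text, entities = kosmos2.generate_predictions(img, "Brief")
    roi_image, verdict = vehicle_predictor.predict_image(img, vehicle_predictor.classifier_model)

    print(verdict)        # 'Accident' or 'No accident'
    print(colored_text)   # (text span, color id) pairs backing the highlighted description
    annotated.save("kosmos2_annotated.jpg")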