|
|
|
import os |
|
import json |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from pathlib import Path |
|
from pprint import pprint |
|
from omegaconf import OmegaConf |
|
from PIL import Image, ImageDraw |
|
import streamlit as st |
|
|
|
|
|
# Publish the visualization root through the process environment; the rest of
# this file resolves annotation files and images relative to os.environ['ROOT'].
os.environ.update(ROOT='/mnt/Documents/traffic_var_server/visualization')

# Echo the configured root on stdout.
print("os.environ['ROOT'] :", os.environ.get('ROOT'))
|
|
|
class ImageRetriever:
    """Find annotated images under a root directory and mark their boxes.

    The annotation file is a JSON object keyed by image key. For each key,
    the last element of its ``'details'`` list is used, and that element's
    ``'bounding_box'`` mapping is read for the slots ``'1'``, ``'2'`` and
    ``'3'``. Each box is a mapping with ``'left'``, ``'top'``, ``'width'``
    and ``'height'``.

    Attributes:
        root_path: ``Path`` under which image files are searched.
        anno: Parsed annotation JSON.
        hide_true_bbox: Rendering mode used by :meth:`hide_region`.
    """

    # Semi-transparent fill colours (``#RRGGBBAA``), one per box slot.
    _BOX_FILL_COLORS = ('#ff05cd3c', '#00F1E83c', '#F2D4003c')
    # Modes that draw tinted, outlined boxes on a transparent overlay.
    _HIGHLIGHT_MODES = frozenset({2, 5, 7, 8, 9})
    # Modes that start from an opaque grey overlay covering the whole image.
    _GREY_OVERLAY_MODES = frozenset({3, 6})

    # Class-level default so instances created without __init__ still render.
    hide_true_bbox = 2

    def __init__(self, root_path, anno_path, hide_true_bbox=2):
        """Load the annotations and remember where images live.

        Args:
            root_path: Directory searched by :meth:`key2img_path`.
            anno_path: Path of the annotation JSON file.
            hide_true_bbox: Rendering mode for :meth:`hide_region`
                (default ``2``, the mode that used to be hard-coded).
        """
        self.root_path = Path(root_path)
        # Context manager so the file handle is closed deterministically.
        with open(anno_path, encoding="utf-8") as anno_file:
            self.anno = json.load(anno_file)
        self.hide_true_bbox = hide_true_bbox

    def key2img_path(self, key):
        """Return the first existing image path for ``key``, else ``None``.

        Candidates are probed in a fixed order; the ``img/<split>/`` layouts
        use the part of ``key`` before its first underscore as a sub-folder.
        """
        root = self.root_path
        prefix = key.split('_')[0]
        file_paths = [
            root / f"var_images/{key}.jpg",
            root / f"var_images/{key}.png",
            root / f"images/{key}.jpg",
            root / f"img/train/{prefix}/{key}.png",
            root / f"img/val/{prefix}/{key}.png",
            root / f"img/test/{prefix}/{key}.png",
            root / f"img/{key}.png",
            root / f"img/{key}.jpg",
            root / f"{key}.png",
            root / f"{key}.jpg",
        ]
        # Debug trace kept from the original implementation.
        print("file_paths!!!!!!!!", file_paths)
        for file_path in file_paths:
            if file_path.exists():
                return file_path
        return None

    def key2img(self, key, draw_bbox=True):
        """Open the image for ``key``, optionally with its boxes rendered.

        Args:
            key: Image key; must also be a key of ``self.anno`` when
                ``draw_bbox`` is true.
            draw_bbox: When true, pass the image and the boxes in slots
                ``'1'``..``'3'`` (``None`` for a missing slot) through
                :meth:`hide_region`.

        Returns:
            The opened image; converted to RGBA only when ``draw_bbox`` is
            true.

        Raises:
            FileNotFoundError: If no candidate file exists for ``key``.
        """
        file_path = self.key2img_path(key)

        # Debug trace kept from the original implementation.
        print("file_path!!@@@@", key, file_path)

        if file_path is None:
            # Fail with a readable message instead of handing None to PIL.
            raise FileNotFoundError(
                f"No image file found for key {key!r} under {self.root_path}"
            )

        image = Image.open(file_path)
        if draw_bbox:
            meta = self.anno[key]['details'][-1]
            boxes = meta['bounding_box']
            bboxes = [boxes.get(str(slot)) for slot in range(1, 4)]
            image = self.hide_region(image, bboxes)
        return image

    def hide_region(self, image, bboxes):
        """Render ``bboxes`` onto ``image`` according to ``hide_true_bbox``.

        Modes:
            * ``1``: paint each box solid grey directly on the image.
            * ``2, 5, 7, 8, 9``: tint each box with its slot colour and a
              green 3-px outline on a transparent overlay.
            * ``3``: cover the image in grey and punch transparent holes in
              the overlay at the boxes.
            * ``6``: cover the image in grey and tint the boxes on top.
            * anything else: no drawing.

        Args:
            image: PIL image; it is converted to RGBA first.
            bboxes: Sequence of box mappings or ``None`` entries (skipped).

        Returns:
            The RGBA image, composited with the overlay when one is used.
        """
        mode = self.hide_true_bbox
        image = image.convert('RGBA')

        # Pick the drawing surface: the image itself or a separate overlay.
        overlay = None
        draw = None
        if mode == 1:
            draw = ImageDraw.Draw(image, 'RGBA')
        elif mode in self._HIGHLIGHT_MODES:
            overlay = Image.new('RGBA', image.size, '#00000000')
            draw = ImageDraw.Draw(overlay, 'RGBA')
        elif mode in self._GREY_OVERLAY_MODES:
            overlay = Image.new('RGBA', image.size, '#7B7575ff')
            draw = ImageDraw.Draw(overlay, 'RGBA')

        if draw is not None:
            fills = self._BOX_FILL_COLORS
            for idx, bbox in enumerate(bboxes):
                if bbox is None:
                    continue

                # Cycle colours so more boxes than colours cannot raise.
                color_fill = fills[idx % len(fills)]
                x, y = bbox['left'], bbox['top']
                corners = [(x, y), (x + bbox['width'], y + bbox['height'])]

                if mode == 1:
                    draw.rectangle(corners, fill='#7B7575')
                elif mode in self._HIGHLIGHT_MODES:
                    draw.rectangle(corners, fill=color_fill,
                                   outline='#05ff37ff', width=3)
                elif mode == 3:
                    draw.rectangle(corners, fill='#00000000')
                else:  # mode == 6
                    draw.rectangle(corners, fill=color_fill)

        if overlay is not None:
            image = Image.alpha_composite(image, overlay)
        return image
|
|
|
def retrive_data(file_index, mode='direct'):
    """Load one validation sample: its image with boxes drawn and its details.

    The annotation file ``data/anno_val_<mode>.json`` is read from the
    directory named by ``os.environ['ROOT']``. The sample is the
    ``file_index``-th key of that JSON object, and the last element of its
    ``'details'`` list supplies the returned fields.

    Args:
        file_index: Position of the sample among the annotation keys.
        mode: Annotation variant; lower-cased before it is put in the file
            name (the UI passes ``'Direct'`` or ``'Indirect'``).

    Returns:
        A ``(img, message_dict)`` tuple: the image rendered by
        ``ImageRetriever.key2img`` and resized to half width and height, and
        a dict with ``'img'``, ``'plausible_speed'``, ``'bounding_box'``,
        ``'hazard'`` and ``'Entity #1'``..``'Entity #3'``.

    Raises:
        IndexError: If ``file_index`` is outside the annotation file.
        KeyError: If a required field is missing from the sample.
    """
    split = 'val'
    mode = mode.lower()
    anno_path = os.path.join(os.environ['ROOT'], f"data/anno_{split}_{mode}.json")

    # Context manager so the file handle is closed deterministically.
    with open(anno_path, encoding="utf-8") as anno_file:
        main_dataset = json.load(anno_file)

    keys = list(main_dataset)
    try:
        key = keys[file_index]
    except IndexError:
        raise IndexError(
            f"file_index {file_index} is out of range: "
            f"{anno_path} holds {len(keys)} entries"
        ) from None
    temp_data = main_dataset[key]['details'][-1]

    # Some samples carry the explanation under 'rationale' instead of
    # 'hazard'; fall back explicitly rather than via a bare except.
    hazard = temp_data['hazard'] if 'hazard' in temp_data else temp_data['rationale']

    message_dict = {
        'img': temp_data['img'],
        'plausible_speed': temp_data['plausible_speed'],
        'bounding_box': temp_data['bounding_box'],
        'hazard': hazard,
        'Entity #1': temp_data['Entity #1'],
        'Entity #2': temp_data['Entity #2'],
        'Entity #3': temp_data['Entity #3'],
    }

    img_retriever = ImageRetriever(
        root_path=os.path.join(os.environ['ROOT'], ''),
        anno_path=anno_path,
    )
    img = img_retriever.key2img(key)
    # Halve both dimensions for display.
    img = img.resize((img.width // 2, img.height // 2))

    return img, message_dict
|
|
|
|
|
if __name__ == '__main__':
    # ---- Page header ----
    st.title("DHPR: Driving Hazard Prediction and Reasoning")
    st.subheader("Data Visualization")

    # ---- Controls: annotation variant and sample index ----
    hazard_types = ('Direct', 'Indirect')
    option = st.selectbox('Select the hazard type', hazard_types)
    st.write('You selected:', option)

    image_index = st.slider(
        'Please Select The Image Index',
        min_value=0,
        max_value=999,
        value=0,
    )
    st.write(
        "You select the data index of ",
        image_index,
        " for visualization from the validation set",
    )

    # ---- Fetch the selected sample ----
    img, message_dict = retrive_data(image_index, option)

    # ---- Render the image followed by its annotation details ----
    st.write("---")
    st.image(img)
    st.subheader("Annotation Details")
    st.json(message_dict)
    st.write('---')
|
|
|
|
|
|
|
|
|
|
|
|