|
|
|
import os |
|
import json |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from pathlib import Path |
|
from pprint import pprint |
|
from omegaconf import OmegaConf |
|
from PIL import Image, ImageDraw |
|
import streamlit as st |
|
import random |
|
|
|
# Absolute directory containing this script, exported as ROOT so that the data/ and
# img/ paths built elsewhere resolve relative to the script rather than the CWD.
os.environ['ROOT'] = os.path.split(os.path.realpath(__file__))[0]
|
|
|
|
|
|
|
def get_list_folder(PATH):
    """Return the names of the entries directly inside ``PATH`` that are directories.

    Directory-ness is decided by ``os.path.isdir``; the order is whatever
    ``os.listdir`` yields.
    """
    def _is_directory(entry_name):
        return os.path.isdir(os.path.join(PATH, entry_name))

    return list(filter(_is_directory, os.listdir(PATH)))
|
|
|
def get_file_only(PATH):
    """Return the names of the entries directly inside ``PATH`` that are files.

    Directory entries are skipped (decided by ``os.path.isfile``); the order is
    whatever ``os.listdir`` yields.
    """
    file_names = []
    for entry_name in os.listdir(PATH):
        full_path = os.path.join(PATH, entry_name)
        if os.path.isfile(full_path):
            file_names.append(entry_name)
    return file_names
|
|
|
|
|
class ImageRetriever:
    """Locate a sample image on disk by key and optionally overlay its bounding boxes.

    Args:
        root_path: Directory under which the candidate image layouts are searched
            (see ``key2img_path``).
        anno_path: Path to an annotation file. It is stored but not read by this class.
        hide_true_bbox: Rendering mode used by ``hide_region``. The default, 2, is the
            mode the original code always forced. See ``hide_region`` for supported values.
    """

    # Semi-transparent (alpha 0x3c) fill colours for boxes 1, 2 and 3; wraps if more boxes are given.
    COLOR_FILL_LIST = ('#ff05cd3c', '#00F1E83c', '#F2D4003c')
    # Modes that draw a tinted fill plus a green outline on a transparent overlay.
    _OUTLINE_MODES = (2, 5, 7, 8, 9)
    _SUPPORTED_MODES = (1, 2, 3, 5, 6, 7, 8, 9)

    def __init__(self, root_path, anno_path, hide_true_bbox=2):
        self.root_path = Path(root_path)
        self.anno_path = Path(anno_path)
        self.hide_true_bbox = hide_true_bbox

    def key2img_path(self, key):
        """Return the first existing image path for ``key``, or ``None`` if none exists.

        Candidates are checked in a fixed priority order. Split sub-folders are keyed by
        the part of ``key`` before its first underscore.
        """
        prefix = key.split('_')[0]
        candidates = (
            self.root_path / f"var_images/{key}.jpg",
            self.root_path / f"var_images/{key}.png",
            self.root_path / f"images/{key}.jpg",
            self.root_path / f"img/train/{prefix}/{key}.png",
            self.root_path / f"img/val/{prefix}/{key}.png",
            self.root_path / f"img/test/{prefix}/{key}.png",
            self.root_path / f"img/{key}.png",
            self.root_path / f"img/{key}.jpg",
            self.root_path / f"{key}.png",
            self.root_path / f"{key}.jpg",
        )
        for candidate in candidates:
            if candidate.exists():
                return candidate
        return None

    def key2img(self, key, temp_data, draw_bbox=True):
        """Open the image for ``key`` and, if ``draw_bbox``, overlay boxes '1'..'3'.

        ``temp_data['bounding_box']`` is read only when ``draw_bbox`` is true. It maps
        '1', '2', '3' to box dicts; missing entries are skipped.

        Raises:
            FileNotFoundError: if no candidate file exists for ``key``.
        """
        file_path = self.key2img_path(key)
        if file_path is None:
            raise FileNotFoundError(f"No image found for key {key!r} under {self.root_path}")

        # Load pixel data eagerly so the file handle is closed before returning;
        # the loaded image stays usable after the with-block exits.
        with Image.open(file_path) as image:
            image.load()

        if draw_bbox:
            bboxes = [temp_data['bounding_box'].get(str(box_idx + 1)) for box_idx in range(3)]
            image = self.hide_region(image, bboxes)
        return image

    def hide_region(self, image, bboxes):
        """Render ``bboxes`` onto ``image`` according to ``self.hide_true_bbox``.

        Each bbox is a dict with 'left', 'top', 'width' and 'height', or ``None`` to skip it.
        The units are presumably pixels; confirm against the annotation source.

        Modes:
            1: paint each box solid gray directly on the image.
            2, 5, 7, 8, 9: tint each box (per-index colour) and draw a 3 px green outline.
            3: cover the image with opaque gray and clear the box areas to transparent,
               so only the boxed regions stay visible.
            6: cover the image with opaque gray and replace the box areas with the
               per-index tint.

        Returns:
            An RGBA image.

        Raises:
            ValueError: if ``self.hide_true_bbox`` is not a supported mode (e.g. 4).
        """
        mode = self.hide_true_bbox
        if mode not in self._SUPPORTED_MODES:
            raise ValueError(f"Unsupported hide_true_bbox mode: {mode!r}")

        image = image.convert('RGBA')

        if mode == 1:
            overlay = None
            draw = ImageDraw.Draw(image, 'RGBA')
        elif mode in (3, 6):
            overlay = Image.new('RGBA', image.size, '#7B7575ff')
            draw = ImageDraw.Draw(overlay, 'RGBA')
        else:  # outline modes
            overlay = Image.new('RGBA', image.size, '#00000000')
            draw = ImageDraw.Draw(overlay, 'RGBA')

        for idx, bbox in enumerate(bboxes):
            if bbox is None:
                continue
            x, y = bbox['left'], bbox['top']
            corners = [(x, y), (x + bbox['width'], y + bbox['height'])]
            color_fill = self.COLOR_FILL_LIST[idx % len(self.COLOR_FILL_LIST)]

            if mode == 1:
                draw.rectangle(corners, fill='#7B7575')
            elif mode in self._OUTLINE_MODES:
                draw.rectangle(corners, fill=color_fill, outline='#05ff37ff', width=3)
            elif mode == 3:
                draw.rectangle(corners, fill='#00000000')
            else:  # mode 6
                draw.rectangle(corners, fill=color_fill)

        # Every mode except 1 draws on a separate overlay that must be composited back.
        if overlay is not None:
            image = Image.alpha_composite(image, overlay)
        return image
|
|
|
def retrive_data(temp_data, img_key, mode='direct', split='val'):
    """Build the display payload (image + annotation fields) for one sample.

    Args:
        temp_data: Annotation record (dict). Must provide 'img', 'plausible_speed',
            'bounding_box', 'Entity #1', 'Entity #2', 'Entity #3', and either
            'hazard' or 'rationale'.
        img_key: Key used to locate the image file under ``os.environ['ROOT']``.
        mode: Annotation variant ('direct' / 'indirect'). It is only used to build
            the annotation path handed to ``ImageRetriever``.
        split: Dataset split. It is only used to build the annotation path. The
            default 'val' matches the split the script uses.

    Returns:
        ``(image, message_dict)``. ``image`` has its bounding boxes drawn and is
        downscaled to half size, clamped to at least 1x1 px. ``message_dict`` holds
        the annotation fields to display, in display order.

    Raises:
        KeyError: if a required field is missing from ``temp_data``.
    """
    message_dict = {
        'img': temp_data['img'],
        'plausible_speed': temp_data['plausible_speed'],
        'bounding_box': temp_data['bounding_box'],
    }
    try:
        message_dict['hazard'] = temp_data['hazard']
    except KeyError:
        # Records without a 'hazard' field carry the explanation under 'rationale'.
        message_dict['hazard'] = temp_data['rationale']
    for entity in ('Entity #1', 'Entity #2', 'Entity #3'):
        message_dict[entity] = temp_data[entity]

    root = os.environ['ROOT']
    img_retriever = ImageRetriever(
        root_path=os.path.join(root, ''),
        anno_path=os.path.join(root, f'data/anno_{split}_{mode}.json'),
    )
    img = img_retriever.key2img(img_key, temp_data)
    # Halve the resolution for display; clamp so tiny images never get a zero dimension.
    img = img.resize((max(1, img.width // 2), max(1, img.height // 2)))

    return img, message_dict
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    st.title("DHPR: Driving Hazard Prediction and Reasoning")

    img_path = os.path.join(os.environ['ROOT'], 'img/')
    img_path_list = get_file_only(img_path)

    # NOTE(review): kept as a module-level name (this block runs at module scope);
    # check whether retrive_data still reads a global ``split`` before moving it.
    split = 'val'
    rand_index = 0

    # Annotation files live at data/anno_<split>_<direct|indirect>.json under ROOT.
    with open(os.path.join(os.environ['ROOT'], f"data/anno_{split}_direct.json"),
              encoding='utf-8') as fh:
        main_direct_dataset = json.load(fh)
    with open(os.path.join(os.environ['ROOT'], f"data/anno_{split}_indirect.json"),
              encoding='utf-8') as fh:
        main_indirect_dataset = json.load(fh)

    if not img_path_list:
        st.error(f"No image files found in {img_path}")
        st.stop()

    # st.button is True only on the run triggered by the click; otherwise the first listed image is shown.
    if st.button('Random Data Sample'):
        # randrange excludes its upper bound, so the index is always valid
        # (randint(0, len(...)) could return len(...) and raise IndexError).
        rand_index = random.randrange(len(img_path_list))

    st.subheader("Data Visualization")

    img_key = img_path_list[rand_index].split('.')[0]

    # The direct annotations take precedence; the last entry of 'details' is displayed.
    if img_key in main_direct_dataset:
        temp_data = main_direct_dataset[img_key]['details'][-1]
    elif img_key in main_indirect_dataset:
        temp_data = main_indirect_dataset[img_key]['details'][-1]
    else:
        st.error(f"No annotation found for image key {img_key!r}")
        st.stop()

    img, message_dict = retrive_data(temp_data, img_key)

    st.write("---")

    st.image(img)
    st.subheader("Annotation Details")
    st.json(message_dict)
    st.write('---')
|
|
|
|
|
|
|
|
|
|
|
|