Demo / visualize_data.py
DHPR's picture
Upload 3 files
f704fbe
raw
history blame
5.28 kB
# %%
import os
import json
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from pprint import pprint
from omegaconf import OmegaConf
from PIL import Image, ImageDraw
import streamlit as st
# %%
# Data root used to locate `data/anno_*.json` and the image folders. An already
# exported ROOT environment variable takes precedence over this machine-specific
# default instead of being silently overwritten.
os.environ.setdefault('ROOT', '/mnt/Documents/traffic_var_server/visualization')
print("os.environ['ROOT'] :",os.environ['ROOT'])
# %%
class ImageRetriever:
    """Look up images by annotation key and optionally render their bounding boxes.

    ``hide_true_bbox`` selects how :meth:`hide_region` renders the boxes:

    * ``1`` -- hide: paint each box opaque grey directly onto the image.
    * ``2, 5, 7, 8, 9`` -- highlight: translucent per-box colour with a green outline.
    * ``3`` -- blackout: grey out everything except the box regions.
    * ``6`` -- position only: grey out everything; box regions get a translucent
      colour tint.
    """

    # Class-level default so instances created without __init__ still have a mode.
    hide_true_bbox = 2

    _HIGHLIGHT_MODES = frozenset({2, 5, 7, 8, 9})
    # Modes that start from an opaque grey overlay layer.
    _GREY_OVERLAY_MODES = frozenset({3, 6})
    _VALID_MODES = frozenset({1}) | _HIGHLIGHT_MODES | _GREY_OVERLAY_MODES

    # Translucent fills for boxes 1, 2, 3: pink/magenta, cyan, yellow. The trailing
    # "3c" is the alpha byte (0x3c = 60/255, roughly 24% opacity). Colours cycle
    # if more than three boxes are given.
    _FILL_COLORS = ('#ff05cd3c', '#00F1E83c', '#F2D4003c')

    def __init__(self, root_path, anno_path, hide_true_bbox=2):
        """Create a retriever rooted at ``root_path``.

        Args:
            root_path: directory that contains the image folders.
            anno_path: path to the annotation JSON file; it is parsed eagerly into
                ``self.anno``.
            hide_true_bbox: box rendering mode used by :meth:`hide_region`
                (see the class docstring). Defaults to 2 (highlight).
        """
        self.root_path = Path(root_path)
        # Use a context manager so the handle is closed; JSON is UTF-8 by spec.
        with open(anno_path, encoding='utf-8') as anno_file:
            self.anno = json.load(anno_file)
        self.hide_true_bbox = hide_true_bbox

    def key2img_path(self, key):
        """Return the first existing image file for ``key``.

        Several directory layouts are probed in a fixed priority order. For the
        ``img/<split>/`` layouts, the sub-directory is the part of ``key`` before
        the first underscore.

        Raises:
            FileNotFoundError: if none of the candidate paths exists.
        """
        prefix = key.split('_')[0]
        candidates = [
            self.root_path / f"var_images/{key}.jpg",
            self.root_path / f"var_images/{key}.png",
            self.root_path / f"images/{key}.jpg",
            self.root_path / f"img/train/{prefix}/{key}.png",
            self.root_path / f"img/val/{prefix}/{key}.png",
            self.root_path / f"img/test/{prefix}/{key}.png",
            self.root_path / f"img/{key}.png",
            self.root_path / f"img/{key}.jpg",
            self.root_path / f"{key}.png",
            self.root_path / f"{key}.jpg",
        ]
        for file_path in candidates:
            if file_path.exists():
                return file_path
        # Previously this fell through and returned None, which made Image.open
        # fail with an unrelated error far from the real cause.
        raise FileNotFoundError(
            f"No image found for key {key!r} under {self.root_path}"
        )

    def key2img(self, key, draw_bbox=True):
        """Open the image for ``key``, optionally rendering its bounding boxes.

        When ``draw_bbox`` is true, the boxes come from the last entry of
        ``self.anno[key]['details']``. Up to three boxes are read from its
        ``'bounding_box'`` mapping under the string keys ``'1'``, ``'2'`` and
        ``'3'``; missing ones are passed on as ``None`` and skipped.
        """
        image = Image.open(self.key2img_path(key))
        if draw_bbox:
            meta = self.anno[key]['details'][-1]
            bboxes = [meta['bounding_box'].get(str(box_no), None) for box_no in (1, 2, 3)]
            image = self.hide_region(image, bboxes)
        return image

    def hide_region(self, image, bboxes):
        """Render ``bboxes`` onto ``image`` according to ``self.hide_true_bbox``.

        Args:
            image: a PIL image. Drawing happens on an RGBA conversion of it.
            bboxes: sequence of dicts with ``'left'``, ``'top'``, ``'width'`` and
                ``'height'`` keys, or ``None`` entries for absent boxes. The fill
                colour is chosen by position in the sequence.

        Returns:
            The rendered RGBA image.

        Raises:
            ValueError: if ``self.hide_true_bbox`` is not a supported mode.
        """
        mode = self.hide_true_bbox
        if mode not in self._VALID_MODES:
            raise ValueError(f"Unsupported hide_true_bbox mode: {mode!r}")

        image = image.convert('RGBA')
        if mode == 1:
            # Hide mode paints straight onto the converted image.
            overlay = None
            draw = ImageDraw.Draw(image, 'RGBA')
        else:
            # All other modes draw on a separate layer that is alpha-composited at
            # the end: fully transparent for highlight modes, opaque grey otherwise.
            background = '#00000000' if mode in self._HIGHLIGHT_MODES else '#7B7575ff'
            overlay = Image.new('RGBA', image.size, background)
            draw = ImageDraw.Draw(overlay, 'RGBA')

        for idx, bbox in enumerate(bboxes):
            if bbox is None:
                continue
            x, y = bbox['left'], bbox['top']
            rect = [(x, y), (x + bbox['width'], y + bbox['height'])]
            color_fill = self._FILL_COLORS[idx % len(self._FILL_COLORS)]
            if mode == 1:
                draw.rectangle(rect, fill='#7B7575')
            elif mode in self._HIGHLIGHT_MODES:
                draw.rectangle(rect, fill=color_fill, outline='#05ff37ff', width=3)
            elif mode == 3:
                # Drawing on an RGBA layer writes the value directly (no blending),
                # so this punches a transparent hole in the grey overlay and the
                # original pixels show through after compositing.
                draw.rectangle(rect, fill='#00000000')
            else:  # mode == 6
                draw.rectangle(rect, fill=color_fill)

        if overlay is not None:
            image = Image.alpha_composite(image, overlay)
        return image
def retrive_data(file_index, mode='direct', split='val'):
    """Load one annotated sample for display.

    Reads ``$ROOT/data/anno_{split}_{mode}.json``, picks the entry at position
    ``file_index`` in the file's key order, and returns its image (bounding boxes
    rendered, downscaled by 2) together with a summary of its last annotation.

    Args:
        file_index: position of the sample among the annotation keys.
        mode: hazard type, case-insensitive (e.g. ``'Direct'`` or ``'Indirect'``).
        split: dataset split used in the annotation filename; defaults to ``'val'``.

    Returns:
        ``(img, message_dict)``. ``message_dict`` holds ``img``,
        ``plausible_speed``, ``bounding_box``, ``hazard`` and
        ``Entity #1`` .. ``Entity #3``.

    Raises:
        IndexError: if ``file_index`` is outside the annotation file's entries.
    """
    mode = mode.lower()
    root = os.environ['ROOT']
    anno_path = os.path.join(root, f"data/anno_{split}_{mode}.json")
    # ImageRetriever already parses the annotation file; reuse its copy instead
    # of reading and parsing the same JSON a second time.
    img_retriever = ImageRetriever(
        root_path=os.path.join(root, ''),
        anno_path=anno_path,
    )
    main_dataset = img_retriever.anno
    keys = list(main_dataset)
    try:
        key = keys[file_index]
    except IndexError:
        raise IndexError(
            f"file_index {file_index} is out of range for {anno_path} "
            f"({len(keys)} entries)"
        ) from None

    temp_data = main_dataset[key]['details'][-1]
    message_dict = {
        'img': temp_data['img'],
        'plausible_speed': temp_data['plausible_speed'],
        'bounding_box': temp_data['bounding_box'],
        # Some annotation files store this text under 'rationale' instead.
        'hazard': temp_data['hazard'] if 'hazard' in temp_data else temp_data['rationale'],
        'Entity #1': temp_data['Entity #1'],
        'Entity #2': temp_data['Entity #2'],
        'Entity #3': temp_data['Entity #3'],
    }

    img = img_retriever.key2img(key)
    # Halve the resolution for display, never shrinking a side below 1 px.
    img = img.resize((max(1, img.width // 2), max(1, img.height // 2)))
    return img, message_dict
# %%
if __name__ == '__main__':
    # Streamlit UI; launch with the command noted at the end of this block.
    st.title("DHPR: Driving Hazard Prediction and Reasoning")
    st.subheader("Data Visualization")
    option = st.selectbox(
        'Select the hazard type',
        ('Direct', 'Indirect'))
    st.write('You selected:', option)
    # The slider's upper bound is fixed and may exceed the number of samples in
    # the annotation file, so load failures are reported instead of crashing.
    image_index = st.slider('Please Select The Image Index', 0, 999, 0)
    st.write("You select the data index of ", image_index," for visualization from the validation set")
    try:
        img, message_dict = retrive_data(image_index, option)
    except (IndexError, FileNotFoundError) as err:
        st.error(f"Could not load sample {image_index}: {err}")
        st.stop()
    st.write("---")
    st.image(img)
    st.subheader("Annotation Details")
    st.json(message_dict)
    st.write('---')
    # !streamlit run visualize_data.py --server.fileWatcherType none
# %%