# Viewer metadata (not part of the program): Spaces status "Runtime error"; file size 4,557 bytes; ref fc16538.
# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
import json
import numpy as np
from camviz import BBox3D
from camviz import Camera as CameraCV
from camviz import Draw
from vidar.geometry.camera import Camera
from vidar.geometry.pose import Pose
from vidar.utils.data import make_batch, fold_batch, modrem
from vidar.utils.flip import flip_batch
from vidar.utils.viz import viz_depth, viz_optical_flow, viz_semantic
def change_key(dic, c, n):
steps = sorted(dic.keys())
return steps[(steps.index(c) + n) % len(steps)]
def display_sample(data, flip=False):
    """Open an interactive viewer for a single dataset sample.

    The left half of the window shows one image per camera for the currently
    selected task; the right half shows a 3D world with the point clouds
    reconstructed from the depth maps, the camera frustums and, when present
    in ``data``, 3D bounding boxes and a point cache.

    Keyboard controls (polled once per frame):
        SPACE       toggle coloring the point clouds with the RGB values
        RIGHT/LEFT  cycle through the available tasks
        UP/DOWN     cycle through the keys of the current task
                    (presumably temporal context indices -- TODO confirm)

    Parameters
    ----------
    data : dict
        Sample to display. Must contain 'rgb', 'intrinsics', 'depth' and
        'pose'. Optional entries: 'fwd_optical_flow', 'bwd_optical_flow',
        'semantic' (which also requires 'tag'), 'bbox3d' and 'pointcache'.
    flip : bool
        If True, the batched sample is passed through ``flip_batch`` before
        being displayed.
    """
    tasks = ['rgb', 'depth', 'fwd_optical_flow', 'bwd_optical_flow', 'semantic']
    # Indexed modulo its length below, so any number of cameras gets a color.
    cam_colors = ['red', 'blu', 'gre', 'yel', 'mag', 'cya']

    data = make_batch(data)
    if flip:
        data = flip_batch(data)
    data = fold_batch(data)

    rgb = data['rgb']
    intrinsics = data['intrinsics']
    depth = data['depth']
    pose = Pose.from_dict(data['pose'], to_global=True)
    cam = Camera.from_dict(intrinsics, rgb, pose)

    num_cams = rgb[0].shape[0]
    # Last two dimensions of the image, reversed
    # (presumably (width, height) -- TODO confirm tensor layout).
    wh = rgb[0].shape[-2:][::-1]

    # Only the tasks that are actually present in this sample can be shown.
    keys = [task for task in tasks if task in data]

    # World-frame points per camera key, rearranged to (num_cams, N, 3).
    points = {}
    for cam_key in cam.keys():
        points[cam_key] = cam[cam_key].reconstruct_depth_map(
            depth[cam_key], to_world=True).reshape(num_cams, 3, -1).permute(0, 2, 1)

    draw = Draw((wh[0] * 4, wh[1] * 3), width=2100)
    # NOTE(review): the image grid is fixed at n=(3, 2) and addressed through
    # modrem(i, 2) below; confirm the behavior with more than six cameras.
    draw.add2DimageGrid('cam', (0.0, 0.0, 0.5, 1.0), n=(3, 2), res=wh)
    draw.add3Dworld('wld', (0.5, 0.0, 1.0, 1.0), pose=cam[0].Tcw.T[0])

    draw.addTexture('cam', n=num_cams)
    draw.addBuffer3f('lidar', 1000000, n=num_cams)
    draw.addBuffer3f('color', 1000000, n=num_cams)

    with_bbox3d = 'bbox3d' in data
    if with_bbox3d:
        bbox3d_corners = [[BBox3D(b) for b in bb] for bb in data['bbox3d']['corners']]

    with_pointcache = 'pointcache' in data
    if with_pointcache:
        pointcache = np.concatenate(
            [np.concatenate(pp, 0) for pp in data['pointcache']['points']], 0)
        draw.addBufferf('pointcache', pointcache[:, :3])

    # One {cam_key: CameraCV} dictionary per physical camera.
    camcv = [
        {cam_key: CameraCV.from_vidar(cam_val, i) for cam_key, cam_val in cam.items()}
        for i in range(num_cams)
    ]

    t, k = 0, 0
    key = keys[k]  # invariant maintained below: key == keys[k]

    change = True  # textures/buffers need to be re-uploaded
    color = True   # color the point clouds with RGB values

    # Loaded lazily the first time the semantic task is displayed, then reused
    # (assumes viz_semantic does not mutate the ontology -- TODO confirm).
    ontology = None

    while draw.input():
        if draw.SPACE:
            color = not color
            change = True
        if draw.RIGHT:
            change = True
            k = (k + 1) % len(keys)
            # Skip tasks that have no entry for the current t. This terminates
            # because the task we started from does contain t.
            while t not in data[keys[k]]:
                k = (k + 1) % len(keys)
            key = keys[k]
        if draw.LEFT:
            change = True
            k = (k - 1) % len(keys)
            while t not in data[keys[k]]:
                k = (k - 1) % len(keys)
            key = keys[k]
        # change_key always returns a key of data[key], and key == keys[k], so
        # no additional membership check is needed after stepping t.
        if draw.UP:
            change = True
            t = change_key(data[key], t, 1)
        if draw.DOWN:
            change = True
            t = change_key(data[key], t, -1)

        if change:
            change = False
            for i in range(num_cams):
                img = data[key][t][i]
                if key == 'depth':
                    img = viz_depth(img, filter_zeros=True)
                elif key in ['fwd_optical_flow', 'bwd_optical_flow']:
                    img = viz_optical_flow(img)
                elif key == 'semantic':
                    if ontology is None:
                        # Context manager so the file handle is closed.
                        with open('vidar/datasets/ontologies/%s.json' % data['tag'][0]) as f:
                            ontology = json.load(f)
                    img = viz_semantic(img, ontology)
                draw.updTexture('cam%d' % i, img)
                draw.updBufferf('lidar%d' % i, points[t][i])
                draw.updBufferf('color%d' % i, data['rgb'][t][i])

        draw.clear()
        for i in range(num_cams):
            cam_color = cam_colors[i % len(cam_colors)]
            draw['cam%d%d' % modrem(i, 2)].image('cam%d' % i)
            draw['wld'].size(1).color(cam_color).points(
                'lidar%d' % i, ('color%d' % i) if color else None)
            for cam_key, cam_val in camcv[i].items():
                # Only the camera at the selected t is highlighted and textured.
                selected = cam_key == t
                clr = cam_color if selected else 'gra'
                tex = 'cam%d' % i if selected else None
                draw['wld'].object(cam_val, color=clr, tex=tex)

        if with_bbox3d:
            for boxes in bbox3d_corners:
                for box in boxes:
                    draw['wld'].object(box)
        if with_pointcache:
            draw['wld'].color('whi').points('pointcache')

        draw.update(30)
|