|
"""Visual Iterative Prompting functions. |
|
|
|
Code to implement visual iterative prompting, an approach for querying VLMs. |
|
""" |
|
|
|
import copy |
|
import dataclasses |
|
import enum |
|
import io |
|
from typing import Optional, Tuple |
|
import cv2 |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import scipy.stats |
|
import vip_utils |
|
|
|
|
|
@enum.unique |
|
class SupportedEmbodiments(str, enum.Enum): |
|
"""Embodiments supported by VIP.""" |
|
|
|
HF_DEMO = 'hf_demo' |
|
|
|
|
|
@dataclasses.dataclass()
class Coordinate:
  """Coordinate with necessary information for visualizing annotation."""

  xy: Tuple[int, int]
  color: Optional[float] = None
  radius: Optional[int] = None


@dataclasses.dataclass()
class Sample:
  """Single Sample mapping one action to its Coordinates and label."""

  action: np.ndarray
  coord: Coordinate
  text_coord: Coordinate
  label: str

|
class VisualIterativePrompter:
  """Visual Iterative Prompting class."""

  def __init__(self, style, action_spec, embodiment):
    self.embodiment = embodiment
    self.style = style
    self.action_spec = action_spec
    self.fig_scale_size = None

  def action_to_coord(self, action, image, arm_xy, do_project=False):
    """Converts candidate action to image coordinate."""
    return self.navigation_action_to_coord(
        action=action, image=image, center_xy=arm_xy, do_project=do_project
    )
|
  def navigation_action_to_coord(
      self, action, image, center_xy, do_project=False
  ):
    """Converts a ZXY or XY action to an image coordinate.

    Conversion is done based on style['focal_offset'] and action_spec['scale'].

    Args:
      action: z, y, x action in robot action space
      image: image
      center_xy: x, y in image space
      do_project: whether or not to project actions sampled outside the image
        to the edge of the image

    Returns:
      Coordinate with image x, y, arrow color, and circle radius.
    """
    if self.action_spec['scale'][0] == 0:
      # A zero z-scale means the action space is 2D (XY only).
      norm_action = [
          (action[d] - self.action_spec['loc'][d])
          / (2 * self.action_spec['scale'][d])
          for d in range(1, 3)
      ]
      norm_action_y, norm_action_x = norm_action
      norm_action_z = 0
    else:
      norm_action = [
          (action[d] - self.action_spec['loc'][d])
          / (2 * self.action_spec['scale'][d])
          for d in range(3)
      ]
      norm_action_z, norm_action_y, norm_action_x = norm_action
    # Simple perspective model: larger z shrinks the drawn arrow, floored at
    # 0.2 so distant actions remain visible.
    focal_length = np.max([
        0.2,
        self.style['focal_offset']
        / (self.style['focal_offset'] + norm_action_z),
    ])
    image_x = center_xy[0] - (
        self.action_spec['action_to_coord'] * norm_action_x * focal_length
    )
    image_y = center_xy[1] - (
        self.action_spec['action_to_coord'] * norm_action_y * focal_length
    )
    if (
        vip_utils.coord_outside_image(
            Coordinate(xy=(image_x, image_y)), image, self.style['radius']
        )
        and do_project
    ):
      # Rescale the arrow so its tip lands just inside the image edge.
      height, width, _ = image.shape
      max_x = (
          width - center_xy[0] - 2 * self.style['radius']
          if norm_action_x < 0
          else center_xy[0] - 2 * self.style['radius']
      )
      max_y = (
          height - center_xy[1] - 2 * self.style['radius']
          if norm_action_y < 0
          else center_xy[1] - 2 * self.style['radius']
      )
      rescale_ratio = min(
          np.abs([
              max_x / (self.action_spec['action_to_coord'] * norm_action_x),
              max_y / (self.action_spec['action_to_coord'] * norm_action_y),
          ])
      )
      image_x = (
          center_xy[0]
          - self.action_spec['action_to_coord'] * norm_action_x * rescale_ratio
      )
      image_y = (
          center_xy[1]
          - self.action_spec['action_to_coord'] * norm_action_y * rescale_ratio
      )

    return Coordinate(
        xy=(int(image_x), int(image_y)),
        color=0.1 * self.style['rgb_scale'],
        radius=int(self.style['radius']),
    )
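
  # A worked example of the focal scaling above (illustrative numbers, not
  # canonical defaults): with style['focal_offset'] = 1.0, an action with
  # norm_action_z = 0.5 maps to focal_length = max(0.2, 1 / 1.5) ~= 0.67,
  # while norm_action_z = -0.5 maps to focal_length = max(0.2, 1 / 0.5) = 2.0,
  # so actions nearer the camera are drawn as proportionally longer arrows.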
|
|
|
  def sample_actions(
      self, image, arm_xy, loc, scale, true_action=None, max_itrs=1000
  ):
    """Samples actions from a distribution.

    Args:
      image: image
      arm_xy: x, y in image space of arm
      loc: mean of the action distribution to sample from
      scale: standard deviation of the action distribution to sample from
      true_action: action taken in demonstration if available
      max_itrs: number of tries to get a valid sample

    Returns:
      samples: Samples with associated actions, coords, text_coords, labels.
    """
    image = copy.deepcopy(image)

    samples = []
    actions = []
    coords = []
    text_coords = []
    labels = []

    # If a demonstration action is available, include it under a random label
    # so it is indistinguishable from the sampled candidates.
    true_label = None
    if true_action is not None:
      actions.append(true_action)
      coord = self.action_to_coord(true_action, image, arm_xy)
      coords.append(coord)
      text_coords.append(
          vip_utils.coord_to_text_coord(coords[-1], arm_xy, coord.radius)
      )
      true_label = np.random.randint(self.style['num_samples'])
      labels.append(str(true_label))

    for i in range(self.style['num_samples']):
      if i == true_label:
        continue
      itrs = 0

      # Sample an action and clip it to the action spec bounds.
      action = np.clip(
          np.random.normal(loc, scale),
          self.action_spec['min'],
          self.action_spec['max'],
      )
      coord = self.action_to_coord(action, image, arm_xy)

      # Resample while the coordinate collides with a previous one or falls
      # outside the image, widening the distribution slightly on each try.
      adjusted_scale = np.array(scale)
      while (
          vip_utils.is_invalid_coord(
              coord, coords, self.style['radius'] * 1.5, image
          )
          or vip_utils.coord_outside_image(coord, image, self.style['radius'])
      ) and itrs < max_itrs:
        action = np.clip(
            np.random.normal(loc, adjusted_scale),
            self.action_spec['min'],
            self.action_spec['max'],
        )
        coord = self.action_to_coord(action, image, arm_xy)
        itrs += 1
        adjusted_scale *= 1.1
        if itrs == max_itrs:
          # Give up and project the last sample onto the image edge.
          coord = self.action_to_coord(action, image, arm_xy, do_project=True)

      radius = coord.radius
      text_coord = Coordinate(
          xy=vip_utils.coord_to_text_coord(coord, arm_xy, radius)
      )

      actions.append(action)
      coords.append(coord)
      text_coords.append(text_coord)
      labels.append(str(i))

    for i in range(len(actions)):
      sample = Sample(
          action=actions[i],
          coord=coords[i],
          text_coord=text_coords[i],
          label=labels[i],
      )
      samples.append(sample)
    return samples
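
  # Example usage (a sketch; assumes a hypothetical `prompter` instance with
  # style['num_samples'] == 3 and no demonstration action):
  #   samples = prompter.sample_actions(
  #       image, arm_xy=(320, 240), loc=np.zeros(3), scale=0.3 * np.ones(3))
  #   [s.label for s in samples]  # -> ['0', '1', '2']
  # Each Sample's coord.xy is the arrow tip in image space and text_coord.xy
  # is the center of the labeled circle drawn at that tip.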
|
|
|
  def add_arrow_overlay_plt(self, image, samples, arm_xy):
    """Adds arrows and circles to the image.

    Args:
      image: image
      samples: Samples to visualize.
      arm_xy: x, y image coordinates for EEF center.

    Returns:
      image: image with visual prompts.
    """
    overlay = image.copy()
    (original_image_height, original_image_width, _) = image.shape

    white = (
        self.style['rgb_scale'],
        self.style['rgb_scale'],
        self.style['rgb_scale'],
    )

    # Draw the arrows, alpha-blended onto the image.
    for sample in samples:
      color = sample.coord.color
      cv2.arrowedLine(
          overlay, arm_xy, sample.coord.xy, color, self.style['thickness']
      )
    image = cv2.addWeighted(
        overlay,
        self.style['arrow_alpha'],
        image,
        1 - self.style['arrow_alpha'],
        0,
    )

    overlay = image.copy()
    # Draw white, outlined label circles at each arrow tip.
    for sample in samples:
      color = sample.coord.color
      radius = sample.coord.radius
      cv2.circle(
          overlay,
          sample.text_coord.xy,
          radius,
          color,
          self.style['thickness'] + 1,
      )
      cv2.circle(overlay, sample.text_coord.xy, radius, white, -1)
    image = cv2.addWeighted(
        overlay,
        self.style['circle_alpha'],
        image,
        1 - self.style['circle_alpha'],
        0,
    )

    dpi = plt.rcParams['figure.dpi']
    if self.fig_scale_size is None:
      # Calibrate matplotlib's figure-to-pixel scale on the first call.
      fig_size = (original_image_width / dpi, original_image_height / dpi)
      plt.subplots(1, figsize=fig_size)
      plt.imshow(image, cmap='binary')
      plt.axis('off')
      fig = plt.gcf()
      fig.tight_layout(pad=0)
      buf = io.BytesIO()
      plt.savefig(buf, format='png')
      plt.close()
      buf.seek(0)
      test_image = cv2.imdecode(
          np.frombuffer(buf.getvalue(), dtype=np.uint8), cv2.IMREAD_COLOR
      )
      self.fig_scale_size = original_image_width / test_image.shape[1]

    # Render the numeric labels with matplotlib, then rasterize back.
    fig_size = (
        self.fig_scale_size * original_image_width / dpi,
        self.fig_scale_size * original_image_height / dpi,
    )
    plt.subplots(1, figsize=fig_size)
    plt.imshow(image, cmap='binary')
    for sample in samples:
      plt.text(
          sample.text_coord.xy[0],
          sample.text_coord.xy[1],
          sample.label,
          ha='center',
          va='center',
          color='k',
          fontsize=self.style['fontsize'],
      )
    plt.axis('off')
    fig = plt.gcf()
    fig.tight_layout(pad=0)
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close()
    image = cv2.imdecode(
        np.frombuffer(buf.getvalue(), dtype=np.uint8), cv2.IMREAD_COLOR
    )

    image = cv2.resize(image, (original_image_width, original_image_height))
    image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    return image
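
  # Example usage (a sketch, assuming a hypothetical `prompter` instance):
  #   annotated = prompter.add_arrow_overlay_plt(frame, samples, (320, 240))
  # The first call renders a throwaway figure to calibrate fig_scale_size;
  # later calls reuse that cached scale so label sizes stay consistent
  # across frames.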
|
|
|
  def fit(self, values, samples):
    """Fits a loc and scale to the selected actions.

    Args:
      values: list of selected labels
      samples: list of all Samples

    Returns:
      loc: mean of the fitted distribution
      scale: standard deviation of the fitted distribution
    """
    actions = [sample.action for sample in samples]
    labels = [sample.label for sample in samples]

    if not values:
      # No valid selections; fall back to the prior action distribution.
      print('GPT failed to return integer arrows')
      loc = self.action_spec['loc']
      scale = self.action_spec['scale']
    elif len(values) == 1:
      # A single selection pins the mean; shrink scale to the minimum.
      index = np.where([label == str(values[-1]) for label in labels])[0][0]
      action = actions[index]
      print('action', action)
      loc = action
      scale = self.action_spec['min_scale']
    else:
      selected_actions = []
      for value in values:
        idx = np.where([label == str(value) for label in labels])[0][0]
        selected_actions.append(actions[idx])
      print('selected_actions', selected_actions)

      # Fit an independent Gaussian per action dimension, clipping the
      # fitted scale from below by min_scale.
      loc_scale = [
          scipy.stats.norm.fit([action[d] for action in selected_actions])
          for d in range(3)
      ]
      loc = [loc_scale[d][0] for d in range(3)]
      scale = np.clip(
          [loc_scale[d][1] for d in range(3)],
          self.action_spec['min_scale'],
          None,
      )
      print('loc', loc, '\nscale', scale)

    return loc, scale
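

# Example end-to-end usage. This is a minimal sketch: the `style` and
# `action_spec` values below are illustrative assumptions, not canonical
# defaults, and vip_utils must be importable.
if __name__ == '__main__':
  style = {
      'focal_offset': 1.0,  # assumed value
      'radius': 14,  # circle radius in pixels (assumed)
      'rgb_scale': 255,
      'num_samples': 8,
      'thickness': 2,
      'arrow_alpha': 0.6,
      'circle_alpha': 0.9,
      'fontsize': 12,
  }
  action_spec = {
      'loc': np.zeros(3),
      # Zero z-scale selects the 2D (XY) branch of navigation_action_to_coord.
      'scale': np.array([0.0, 0.3, 0.3]),
      'min_scale': np.array([0.05, 0.05, 0.05]),
      'min': -np.ones(3),
      'max': np.ones(3),
      'action_to_coord': 100,  # action units -> pixels (assumed)
  }
  prompter = VisualIterativePrompter(
      style, action_spec, SupportedEmbodiments.HF_DEMO
  )

  # Annotate a blank canvas with candidate arrows for a VLM query.
  image = np.full((480, 640, 3), 255, dtype=np.uint8)
  arm_xy = (320, 240)
  samples = prompter.sample_actions(
      image, arm_xy, action_spec['loc'], action_spec['scale']
  )
  annotated = prompter.add_arrow_overlay_plt(image, samples, arm_xy)
  # Save for inspection (channel order follows the input image's convention).
  cv2.imwrite('vip_prompt.png', annotated)

  # Suppose the VLM picked arrows '2' and '5'; fit the next distribution.
  loc, scale = prompter.fit([2, 5], samples)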
|
|