pivot-demo / vip_runner.py
pivot-iterative-visual-optimization's picture
Upload 5 files
f9a62da verified
raw
history blame contribute delete
No virus
4.76 kB
"""VIP."""
import json
import re
import cv2
from tqdm import trange
import numpy as np
import vip
def make_prompt(description, top_n=3):
    """Build the instruction text that precedes the annotated image in a VLM query.

    Args:
      description: Free-form description of what should be located.
      top_n: Number of numbered circles the model is asked to choose.

    Returns:
      A newline-separated prompt with no leading/trailing whitespace. It ends
      with the "IMAGE:" line, and the model is asked to answer with a JSON
      object of the form {"points": [...]}.
    """
    prompt_lines = [
        "INSTRUCTIONS:",
        "You are tasked to locate an object, region, or point in space in the given annotated image according to a description.",
        "The image is annoated with numbered circles.",
        f"Choose the top {top_n} circles that have the most overlap with and/or is closest to what the description is describing in the image.",
        "You are a five-time world champion in this game.",
        "Give a one sentence analysis of why you chose those points.",
        "Provide your answer at the end in a valid JSON of this format:",
        '{"points": []}',
        f"DESCRIPTION: {description}",
        "IMAGE:",
    ]
    return "\n".join(prompt_lines)
def extract_json(response, key):
    """Return the value stored under ``key`` in the JSON object embedded in a reply.

    The prompt asks the model for a short analysis followed by a JSON object,
    so the reply may contain other braces, such as stray ``{...}`` in the prose
    or a quoted copy of the format example. A single greedy ``{.*}`` match would
    span all of them and fail to decode. Instead, the reply is scanned left to
    right for top-level ``{...}`` spans that decode as JSON. The value from the
    LAST decoded object that contains ``key`` is returned.

    Args:
      response: Raw text returned by the model.
      key: Key to look up in the decoded JSON object.

    Returns:
      The value stored under ``key``.

    Raises:
      KeyError: if no decodable JSON object in ``response`` contains ``key``.
    """
    decoder = json.JSONDecoder()
    value = None
    found_key = False
    decoded_any = False
    start = response.find("{")
    while start != -1:
        try:
            obj, end = decoder.raw_decode(response, start)
        except json.JSONDecodeError:
            # This brace does not start a valid JSON value; try the next one.
            start = response.find("{", start + 1)
            continue
        decoded_any = True
        if isinstance(obj, dict) and key in obj:
            # Keep overwriting: the answer is requested at the end of the reply.
            value = obj[key]
            found_key = True
        # Resume after this object so its nested objects are not re-scanned.
        start = response.find("{", end)
    if not decoded_any:
        print("No JSON data found ******\n", response)
    if not found_key:
        raise KeyError(key)
    return value
def vip_perform_selection(prompter, vlm, im, desc, arm_coord, samples, top_n):
    """Run one VLM selection round over a set of candidate samples.

    The candidates are drawn onto ``im`` by ``prompter.add_arrow_overlay_plt``.
    The annotated image is PNG-encoded and sent to ``vlm.query`` after the
    text prompt. The ids the model lists under "points" in its JSON reply are
    then returned.

    Returns:
      A tuple ``(arrow_ids, annotated_image)``. ``arrow_ids`` is the parsed
      selection, or an empty list if the reply could not be parsed.
      ``annotated_image`` is the overlay that was shown to the model.
    """
    annotated = prompter.add_arrow_overlay_plt(
        image=im, samples=samples, arm_xy=arm_coord
    )
    # imencode returns (success_flag, buffer); only the buffer is sent.
    png_buffer = cv2.imencode(".png", annotated)[1]
    reply = vlm.query([make_prompt(desc, top_n=top_n), png_buffer])
    try:
        chosen = extract_json(reply, "points")
    except Exception as err:  # best effort: an unparsable reply means "no picks"
        print(err)
        chosen = []
    return chosen, annotated
def _mark_selected(arrow_ids, pool):
    """Resolve model-chosen ids to samples in ``pool`` and color those samples red.

    Each id is coerced with ``int()``. Ids that cannot be coerced, or that fall
    outside ``range(len(pool))``, are dropped instead of raising mid-run.
    Negative ids are dropped too, so they are never treated as Python
    from-the-end indices. A non-list selection is treated as a single id.

    Returns:
      A tuple ``(valid_ids, selected_samples)``, in the order the model listed
      them.
    """
    if not isinstance(arrow_ids, (list, tuple)):
        arrow_ids = [arrow_ids]
    valid_ids = []
    selected = []
    for raw_id in arrow_ids:
        try:
            idx = int(raw_id)
        except (TypeError, ValueError):
            continue
        if not 0 <= idx < len(pool):
            continue
        sample = pool[idx]
        sample.coord.color = (255, 0, 0)
        valid_ids.append(idx)
        selected.append(sample)
    return valid_ids, selected


def vip_runner(
    vlm,
    im,
    desc,
    style,
    action_spec,
    n_samples_init=25,
    n_samples_opt=10,
    n_iters=3,
    n_parallel_trials=1,
):
    """Run VIP (visual iterative prompting) as a generator of progress updates.

    Each parallel trial starts from ``action_spec["loc"]`` /
    ``action_spec["scale"]``. For ``n_iters`` iterations it:

    - samples candidates with ``prompter.sample_actions``;
    - asks the VLM for its top 3;
    - refits the sampling distribution with ``prompter.fit``.

    On the last iteration the picks are narrowed to one more VLM choice. With
    several trials, the trials' final picks are pooled and the VLM chooses
    once more.

    Args:
      vlm: Model client; queried through ``vip_perform_selection``.
      im: Image array; ``im.shape[0]`` is used as height, ``im.shape[1]`` as width.
      desc: Text description of what to locate.
      style: Drawing/sampling options for ``vip.VisualIterativePrompter``.
        A shallow copy is used, and its "num_samples" entry is overwritten
        per iteration, so the caller's mapping is not modified.
      action_spec: Mapping with the initial "loc" and "scale" of the sampler.
      n_samples_init: Samples drawn on the first iteration.
      n_samples_opt: Samples drawn on later iterations.
      n_iters: Iterations per trial.
      n_parallel_trials: Independent trials whose final picks are pooled.

    Yields:
      ``(output_ims, status_message)`` after every rendered image. The last
      update carries the message "Final selected coordinate: [...]". If
      nothing was rendered (``n_iters`` or ``n_parallel_trials`` < 1), the
      only update is ``([], "Unable to understand query")``.
    """
    import copy

    # Work on a copy so the per-iteration "num_samples" writes below do not
    # leak into the caller's mapping. The prompter is constructed with this
    # same copy. NOTE(review): presumably the prompter reads "num_samples"
    # from it when sampling; that is why the key is rewritten each iteration.
    style = copy.copy(style)
    prompter = vip.VisualIterativePrompter(
        style, action_spec, vip.SupportedEmbodiments.HF_DEMO
    )
    output_ims = []
    # (x, y) center of the image, passed to the prompter as arm_xy.
    arm_coord = (int(im.shape[1] / 2), int(im.shape[0] / 2))
    new_samples = []  # final pick(s) of every parallel trial
    for trial in range(n_parallel_trials):
        center_mean = action_spec["loc"]
        center_std = action_spec["scale"]
        for itr in trange(n_iters):
            style["num_samples"] = n_samples_init if itr == 0 else n_samples_opt
            samples = prompter.sample_actions(im, arm_coord, center_mean, center_std)
            arrow_ids, image_circles_np = vip_perform_selection(
                prompter, vlm, im, desc, arm_coord, samples, top_n=3
            )
            # Highlight the model's picks in red on top of the annotated image.
            arrow_ids, selected_samples = _mark_selected(arrow_ids, samples)
            output_ims.append(
                prompter.add_arrow_overlay_plt(
                    image_circles_np, selected_samples, arm_coord
                )
            )
            yield output_ims, (
                f"Image generated for parallel sample {trial + 1}/{n_parallel_trials} "
                f"iteration {itr + 1}/{n_iters}. Still working..."
            )
            if itr == n_iters - 1:
                # Narrow this iteration's picks down to a single answer. The
                # returned ids index `samples`, as in the first selection.
                final_ids = []
                if selected_samples:
                    final_ids, _ = vip_perform_selection(
                        prompter, vlm, im, desc, arm_coord, selected_samples, top_n=1
                    )
                final_ids, final_samples = _mark_selected(final_ids, samples)
                if final_ids:
                    arrow_ids, selected_samples = final_ids, final_samples
                # Otherwise keep this iteration's top-3 picks as the fallback.
                output_ims.append(
                    prompter.add_arrow_overlay_plt(im, selected_samples, arm_coord)
                )
                new_samples += selected_samples
                yield output_ims, (
                    f"Image generated for parallel sample {trial + 1}/{n_parallel_trials} "
                    "last iteration. Still working..."
                )
            # With no usable picks, keep the current distribution instead of
            # fitting to an empty selection.
            if arrow_ids:
                center_mean, center_std = prompter.fit(arrow_ids, samples)
    if n_parallel_trials > 1 and new_samples:
        # Labels from different trials can collide; renumber each sample to its
        # position in new_samples so the returned ids index it directly.
        for position, sample in enumerate(new_samples):
            sample.label = str(position)
        arrow_ids, _ = vip_perform_selection(
            prompter, vlm, im, desc, arm_coord, new_samples, top_n=1
        )
        arrow_ids, selected_samples = _mark_selected(arrow_ids, new_samples)
        output_ims.append(
            prompter.add_arrow_overlay_plt(im, selected_samples, arm_coord)
        )
        # If this pick fails, the last trial's fitted mean is reported.
        if arrow_ids:
            center_mean, _ = prompter.fit(arrow_ids, new_samples)
    if not output_ims:
        # Yield the failure: a generator's return value never reaches a
        # consumer that simply iterates over it.
        yield [], "Unable to understand query"
        return
    final_xy = prompter.action_to_coord(center_mean, im, arm_coord).xy
    yield (
        output_ims,
        f"Final selected coordinate: {np.round(final_xy, decimals=0)}",
    )