File size: 4,761 Bytes
5c80958 53ef1bb 5c80958 f9a62da 5c80958 53ef1bb 5c80958 53ef1bb 5c80958 f9a62da 5c80958 53ef1bb 5c80958 53ef1bb 5c80958 53ef1bb 5c80958 53ef1bb 5c80958 53ef1bb 5c80958 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
"""VIP."""
import json
import re
import cv2
from tqdm import trange
import numpy as np
import vip
def make_prompt(description, top_n=3):
    """Build the selection prompt sent to the VLM alongside the annotated image.

    Args:
      description: natural-language description of what to locate.
      top_n: how many numbered circles the model is asked to choose.

    Returns:
      The prompt text with surrounding whitespace stripped. It instructs the
      model to finish its answer with a JSON object of the form
      ``{"points": [...]}``.
    """
    # Doubled braces below are literal braces inside the f-string.
    return f"""
INSTRUCTIONS:
You are tasked to locate an object, region, or point in space in the given annotated image according to a description.
The image is annotated with numbered circles.
Choose the top {top_n} circles that have the most overlap with and/or are closest to what the description is describing in the image.
You are a five-time world champion in this game.
Give a one sentence analysis of why you chose those points.
Provide your answer at the end in a valid JSON of this format:
{{"points": []}}
DESCRIPTION: {description}
IMAGE:
""".strip()
def extract_json(response, key):
    """Return ``obj[key]`` for the first JSON object in *response* holding *key*.

    Model replies mix free text with a JSON object. Every ``{`` in the reply is
    tried as the start of a JSON document, so braces in the surrounding prose,
    code fences, or trailing text after the object do not break parsing.

    Args:
      response: raw text returned by the model.
      key: key to read from the decoded JSON object.

    Returns:
      The value stored under *key* in the first decodable object containing it.

    Raises:
      KeyError: if no decodable JSON object in *response* contains *key*.
    """
    decoder = json.JSONDecoder()
    found_json = False
    for match in re.finditer(r"\{", response):
        try:
            # raw_decode parses one document and ignores whatever follows it.
            candidate, _ = decoder.raw_decode(response[match.start():])
        except json.JSONDecodeError:
            continue
        found_json = True
        if isinstance(candidate, dict) and key in candidate:
            return candidate[key]
    if not found_json:
        print("No JSON data found ******\n", response)
    raise KeyError(key)
def _coerce_ids(raw_ids):
    """Normalise the model's ``"points"`` value into a list of ints, dropping junk."""
    # Models sometimes answer with a bare number or with numeric strings.
    if not isinstance(raw_ids, (list, tuple)):
        raw_ids = [raw_ids]
    ids = []
    for raw in raw_ids:
        if isinstance(raw, bool):  # bool is an int subclass; never a label
            continue
        if isinstance(raw, int):
            ids.append(raw)
        elif isinstance(raw, float) and raw.is_integer():
            ids.append(int(raw))
        elif isinstance(raw, str):
            try:
                ids.append(int(raw.strip()))
            except ValueError:
                pass
    return ids


def vip_perform_selection(prompter, vlm, im, desc, arm_coord, samples, top_n):
    """Perform one selection pass given samples.

    Draws *samples* onto *im*, sends the annotated image together with the
    selection prompt to *vlm*, and parses the ``"points"`` list from the reply.

    Args:
      prompter: object providing ``add_arrow_overlay_plt``.
      vlm: model wrapper exposing ``query(prompt_seq)``.
      im: image to annotate.
      desc: natural-language description of the target.
      arm_coord: anchor coordinate forwarded to the overlay as ``arm_xy``.
      samples: candidate samples to draw and choose among.
      top_n: number of circles the prompt asks the model to pick.

    Returns:
      ``(arrow_ids, image_circles_np)``. ``arrow_ids`` is the list of integer
      ids chosen by the model (empty if the reply could not be parsed); they
      are not range-checked here because the caller knows which id space
      applies. ``image_circles_np`` is the annotated image.

    Raises:
      RuntimeError: if the annotated image cannot be PNG-encoded.
    """
    image_circles_np = prompter.add_arrow_overlay_plt(
        image=im, samples=samples, arm_xy=arm_coord
    )
    ok, encoded_image_circles = cv2.imencode(".png", image_circles_np)
    if not ok:
        # Otherwise an empty buffer would be sent to the VLM as the image.
        raise RuntimeError("Failed to PNG-encode the annotated image.")
    prompt_seq = [make_prompt(desc, top_n=top_n), encoded_image_circles]
    response = vlm.query(prompt_seq)
    try:
        raw_ids = extract_json(response, "points")
    except Exception as e:  # best-effort: an unparsable reply selects nothing
        print(e)
        return [], image_circles_np
    return _coerce_ids(raw_ids), image_circles_np
def _select_in_range(arrow_ids, pool_size):
    """Keep only ids that are valid integer indices into a pool of *pool_size*."""
    if not isinstance(arrow_ids, (list, tuple)):
        return []
    return [
        idx
        for idx in arrow_ids
        if isinstance(idx, int)
        and not isinstance(idx, bool)
        and 0 <= idx < pool_size
    ]


def _mark_selected(pool, arrow_ids):
    """Return the samples of *pool* at *arrow_ids*, recoloured red in place."""
    chosen = []
    for idx in arrow_ids:
        sample = pool[idx]
        sample.coord.color = (255, 0, 0)
        chosen.append(sample)
    return chosen


def vip_runner(
    vlm,
    im,
    desc,
    style,
    action_spec,
    n_samples_init=25,
    n_samples_opt=10,
    n_iters=3,
    n_parallel_trials=1,
):
    """VIP: iteratively sample candidate points and let the VLM refine them.

    This is a generator. After every VLM pass it yields ``(output_ims, message)``,
    where ``output_ims`` is the growing list of annotated images and ``message``
    is a progress string. The last yield reports the final selected coordinate.
    If nothing was produced, it reports ``"Unable to understand query"``.

    Args:
      vlm: model wrapper queried through ``vip_perform_selection``.
      im: input image; ``shape[1]`` and ``shape[0]`` are halved to get the
        arm coordinate.
      desc: natural-language description of the target.
      style: style dict handed to the prompter. Its ``"num_samples"`` entry is
        overwritten on every iteration (the caller's dict is mutated).
      action_spec: mapping whose ``"loc"`` and ``"scale"`` seed the sampling
        distribution of every trial.
      n_samples_init: number of samples requested on the first iteration.
      n_samples_opt: number of samples requested on later iterations.
      n_iters: refinement iterations per trial.
      n_parallel_trials: independent trials. When > 1, the VLM makes one extra
        pick among the per-trial winners.
    """
    prompter = vip.VisualIterativePrompter(
        style, action_spec, vip.SupportedEmbodiments.HF_DEMO
    )
    output_ims = []
    arm_coord = (im.shape[1] // 2, im.shape[0] // 2)
    new_samples = []
    for i in range(n_parallel_trials):
        center_mean = action_spec["loc"]
        center_std = action_spec["scale"]
        for itr in trange(n_iters):
            # Presumably read by the prompter, which was built with this dict.
            style["num_samples"] = n_samples_init if itr == 0 else n_samples_opt
            samples = prompter.sample_actions(im, arm_coord, center_mean, center_std)
            arrow_ids, image_circles_np = vip_perform_selection(
                prompter, vlm, im, desc, arm_coord, samples, top_n=3
            )
            # Returned ids are used as indices into `samples`; drop any the
            # model made up so indexing and fitting cannot fail.
            arrow_ids = _select_in_range(arrow_ids, len(samples))
            # plot sampled circles as red
            selected_samples = _mark_selected(samples, arrow_ids)
            output_ims.append(
                prompter.add_arrow_overlay_plt(
                    image_circles_np, selected_samples, arm_coord
                )
            )
            yield output_ims, f"Image generated for parallel sample {i+1}/{n_parallel_trials} iteration {itr+1}/{n_iters}. Still working..."
            # if at last iteration, pick one answer out of the selected ones
            if itr == n_iters - 1:
                arrow_ids, _ = vip_perform_selection(
                    prompter, vlm, im, desc, arm_coord, selected_samples, top_n=1
                )
                # Candidates keep their original labels, so the answer still
                # indexes into `samples`.
                arrow_ids = _select_in_range(arrow_ids, len(samples))
                selected_samples = _mark_selected(samples, arrow_ids)
                output_ims.append(
                    prompter.add_arrow_overlay_plt(im, selected_samples, arm_coord)
                )
                new_samples += selected_samples
                yield output_ims, f"Image generated for parallel sample {i+1}/{n_parallel_trials} last iteration. Still working..."
            center_mean, center_std = prompter.fit(arrow_ids, samples)

    if n_parallel_trials > 1:
        # adjust sample label to avoid duplications; afterwards each label
        # equals the sample's index in `new_samples`.
        for sample_id, sample in enumerate(new_samples):
            sample.label = str(sample_id)
        arrow_ids, _ = vip_perform_selection(
            prompter, vlm, im, desc, arm_coord, new_samples, top_n=1
        )
        arrow_ids = _select_in_range(arrow_ids, len(new_samples))
        selected_samples = _mark_selected(new_samples, arrow_ids)
        output_ims.append(
            prompter.add_arrow_overlay_plt(im, selected_samples, arm_coord)
        )
        center_mean, _ = prompter.fit(arrow_ids, new_samples)

    if output_ims:
        yield (
            output_ims,
            (
                "Final selected coordinate:"
                f" {np.round(prompter.action_to_coord(center_mean, im, arm_coord).xy, decimals=0)}"
            ),
        )
    else:
        # A generator's `return value` is invisible to callers that simply
        # iterate it, so the failure message must be yielded.
        yield [], "Unable to understand query"
|