File size: 4,761 Bytes
5c80958
 
 
 
 
 
 
53ef1bb
5c80958
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f9a62da
5c80958
 
 
 
 
 
 
53ef1bb
 
 
 
 
5c80958
 
 
 
 
 
 
 
 
 
 
 
53ef1bb
5c80958
 
 
 
f9a62da
5c80958
 
 
 
 
53ef1bb
 
 
5c80958
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53ef1bb
5c80958
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53ef1bb
 
5c80958
 
53ef1bb
5c80958
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53ef1bb
 
 
 
 
 
 
 
5c80958
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
"""VIP."""

import json
import re

import cv2
from tqdm import trange
import numpy as np
import vip


def make_prompt(description, top_n=3):
  """Build the VLM instruction text for one VIP selection round.

  Args:
    description: Natural-language description of the target to locate.
    top_n: How many numbered circles the VLM is asked to choose.

  Returns:
    The prompt string (stripped of surrounding whitespace). It asks for the
    answer as JSON of the form {"points": [...]} and ends with "IMAGE:",
    since `vip_perform_selection` sends the annotated image right after it.
  """
  return f"""
INSTRUCTIONS:
You are tasked to locate an object, region, or point in space in the given annotated image according to a description.
The image is annotated with numbered circles.
Choose the top {top_n} circles that have the most overlap with and/or are closest to what the description is describing in the image.
You are a five-time world champion in this game.
Give a one sentence analysis of why you chose those points.
Provide your answer at the end in a valid JSON of this format:

{{"points": []}}

DESCRIPTION: {description}
IMAGE:
""".strip()


def extract_json(response, key):
  """Return `key` from the last top-level JSON object in `response` having it.

  VLM replies mix free text with a JSON answer, e.g.
  'I chose {3} because ... {"points": [3, 1]}'. Every `{` is tried as the
  start of a JSON document. Brace runs that do not parse are skipped, so
  stray braces in the prose no longer break parsing. (A single greedy
  `{.*}` match would span from the first `{` to the last `}` and fail.)
  Objects nested inside an already-decoded object are not scanned
  separately, so lookups stay top-level.

  Args:
    response: Raw text returned by the VLM.
    key: Key to look up in the decoded JSON object.

  Returns:
    The value stored under `key` in the last matching object.

  Raises:
    KeyError: If no top-level JSON object in `response` contains `key`.
      A diagnostic is printed when `response` has no parseable JSON at all.
  """
  decoder = json.JSONDecoder()
  saw_json = False
  found = False
  value = None
  pos = response.find("{")
  while pos != -1:
    try:
      # raw_decode parses one JSON document from the start of the slice and
      # returns how many characters it consumed, ignoring trailing text.
      obj, consumed = decoder.raw_decode(response[pos:])
    except json.JSONDecodeError:
      pos = response.find("{", pos + 1)
      continue
    saw_json = True
    if isinstance(obj, dict) and key in obj:
      value, found = obj[key], True  # keep the last match: answer is at the end
    pos = response.find("{", pos + consumed)
  if not found:
    if not saw_json:
      print("No JSON data found ******\n", response)
    raise KeyError(key)
  return value


def vip_perform_selection(prompter, vlm, im, desc, arm_coord, samples, top_n):
  """Draw `samples` on `im` and ask the VLM to pick the best `top_n` of them.

  Args:
    prompter: Object providing `add_arrow_overlay_plt` to draw the samples.
    vlm: Object whose `query` method takes `[prompt_text, png_buffer]` and
      returns the reply text.
    im: Image to annotate.
    desc: Natural-language description of the target.
    arm_coord: Origin point passed to the overlay as `arm_xy`.
    samples: Samples to draw and choose from.
    top_n: Number of circles the prompt asks the VLM to choose.

  Returns:
    `(arrow_ids, image_circles_np)`. `arrow_ids` is the reply's "points"
    value, or `[]` if it could not be extracted. `image_circles_np` is the
    annotated image that was sent to the VLM.
  """
  annotated = prompter.add_arrow_overlay_plt(
      image=im, samples=samples, arm_xy=arm_coord
  )
  png_buffer = cv2.imencode(".png", annotated)[1]
  reply = vlm.query([make_prompt(desc, top_n=top_n), png_buffer])

  try:
    chosen = extract_json(reply, "points")
  except Exception as err:  # best effort: an unusable reply selects nothing
    print(err)
    chosen = []
  return chosen, annotated


def _mark_selected(samples, arrow_ids):
  """Validate VLM-chosen ids and recolor the matching samples red in place.

  Args:
    samples: Indexable sequence of sample objects exposing `coord.color`.
    arrow_ids: Ids parsed from the VLM reply. Normally a list of ints, but
      the model may return strings, a bare value, or out-of-range numbers.

  Returns:
    `(valid_ids, selected)`. `valid_ids` are the ids (as ints, in reply
    order) that index into `samples`. `selected` holds the corresponding
    samples, each with `coord.color` set to (255, 0, 0).
  """
  if not isinstance(arrow_ids, (list, tuple)):
    arrow_ids = [arrow_ids]
  valid_ids = []
  for raw_id in arrow_ids:
    try:
      # Going through str() accepts 3 and "3" but rejects 2.5, None, True.
      idx = int(str(raw_id).strip())
    except ValueError:
      continue
    if 0 <= idx < len(samples):
      valid_ids.append(idx)
  selected = []
  for idx in valid_ids:
    sample = samples[idx]
    sample.coord.color = (255, 0, 0)
    selected.append(sample)
  return valid_ids, selected


def vip_runner(
    vlm,
    im,
    desc,
    style,
    action_spec,
    n_samples_init=25,
    n_samples_opt=10,
    n_iters=3,
    n_parallel_trials=1,
):
  """Run Visual Iterative Prompting (VIP) to locate `desc` in `im`.

  This is a generator. After every VLM round it yields
  `(output_ims, status_message)`, where `output_ims` is the same list object
  each time, grown with one annotated image per round.

  Args:
    vlm: Object with a `query(prompt_seq)` method returning the reply text.
    im: Image array; `im.shape[1]` / `im.shape[0]` are read as width/height.
    desc: Natural-language description of the target.
    style: Prompter style dict. Mutated in place: `"num_samples"` is set
      before every sampling round (presumably read by the prompter, which
      is constructed with this same dict).
    action_spec: Dict providing `"loc"` and `"scale"`, the initial sampling
      mean and std (also passed to the prompter).
    n_samples_init: Number of samples in the first iteration of a trial.
    n_samples_opt: Number of samples in every later iteration.
    n_iters: Refinement iterations per trial.
    n_parallel_trials: Independent trials. When > 1, each trial's final
      pick is pooled and the VLM chooses one of them.

  Yields:
    `(output_ims, message)` tuples. The last one reports the final selected
    coordinate, or is `([], "Unable to understand query")` when no round
    produced an image (e.g. `n_iters` or `n_parallel_trials` is 0).
  """
  prompter = vip.VisualIterativePrompter(
      style, action_spec, vip.SupportedEmbodiments.HF_DEMO
  )

  output_ims = []
  # The arm origin is placed at the image center.
  arm_coord = (int(im.shape[1] / 2), int(im.shape[0] / 2))

  new_samples = []  # final pick of every trial, pooled across trials
  center_mean = action_spec["loc"]
  for i in range(n_parallel_trials):
    center_mean = action_spec["loc"]
    center_std = action_spec["scale"]
    for itr in trange(n_iters):
      style["num_samples"] = n_samples_init if itr == 0 else n_samples_opt
      samples = prompter.sample_actions(im, arm_coord, center_mean, center_std)
      arrow_ids, image_circles_np = vip_perform_selection(
          prompter, vlm, im, desc, arm_coord, samples, top_n=3
      )

      # plot selected circles as red; invalid ids from the VLM are dropped
      arrow_ids, selected_samples = _mark_selected(samples, arrow_ids)
      image_circles_marked_np = prompter.add_arrow_overlay_plt(
          image_circles_np, selected_samples, arm_coord
      )
      output_ims.append(image_circles_marked_np)
      yield output_ims, f"Image generated for parallel sample {i+1}/{n_parallel_trials} iteration {itr+1}/{n_iters}. Still working..."

      # if at last iteration, pick one answer out of the selected ones
      if itr == n_iters - 1:
        arrow_ids, _ = vip_perform_selection(
            prompter, vlm, im, desc, arm_coord, selected_samples, top_n=1
        )
        # Returned ids index `samples` (presumably each sample's drawn label
        # is its index in `samples`), not `selected_samples`.
        arrow_ids, selected_samples = _mark_selected(samples, arrow_ids)
        image_circles_marked_np = prompter.add_arrow_overlay_plt(
            im, selected_samples, arm_coord
        )
        output_ims.append(image_circles_marked_np)
        new_samples += selected_samples
        yield output_ims, f"Image generated for parallel sample {i+1}/{n_parallel_trials} last iteration. Still working..."
      center_mean, center_std = prompter.fit(arrow_ids, samples)

  if n_parallel_trials > 1:
    # adjust sample label to avoid duplications across trials
    for sample_id, sample in enumerate(new_samples):
      sample.label = str(sample_id)
    arrow_ids, _ = vip_perform_selection(
        prompter, vlm, im, desc, arm_coord, new_samples, top_n=1
    )
    arrow_ids, selected_samples = _mark_selected(new_samples, arrow_ids)
    image_circles_marked_np = prompter.add_arrow_overlay_plt(
        im, selected_samples, arm_coord
    )
    output_ims.append(image_circles_marked_np)
    center_mean, _ = prompter.fit(arrow_ids, new_samples)

  if not output_ims:
    # A generator's `return` value is invisible to callers that just iterate
    # it, so the failure message must be yielded to be seen.
    yield [], "Unable to understand query"
    return
  yield (
      output_ims,
      (
          "Final selected coordinate:"
          f" {np.round(prompter.action_to_coord(center_mean, im, arm_coord).xy, decimals=0)}"
      ),
  )