|
from typing import List, Union |
|
|
|
from vision_functions import find_in_image, simple_qa, verify_property, best_text_match, compute_depth |
|
|
|
|
|
def bool_to_yesno(bool_answer: bool) -> str:
    """Map a boolean to the answer strings used throughout the examples.

    Returns ``"yes"`` for a truthy ``bool_answer`` and ``"no"`` otherwise.
    """
    return "yes" if bool_answer else "no"
|
|
|
|
|
class ImagePatch:
    """A crop of an image plus its bounding box and vision-query helpers.

    Coordinates follow the y-up convention the examples in this file rely on:
    ``lower`` is the bottom border and ``upper`` the top border, so a larger
    ``vertical_center`` means higher in the image.

    Attributes
    ----------
    cropped_image : array_like
        Pixel content of this patch, indexed ``[:, row, column]``
        (``shape[1]`` is the height, ``shape[2]`` the width; dim 0 is
        presumably channels — confirm against the vision backend).
    left, lower, right, upper : int
        Borders of the crop, in the frame of the image given to ``__init__``.
    width, height : int
        Size of ``cropped_image`` in pixels.
    horizontal_center, vertical_center : float
        Center of the bounding box.
    """

    def __init__(self, image, left: int = None, lower: int = None, right: int = None, upper: int = None):
        """Crop ``image`` to the given box and record the box as attributes.

        Parameters
        ----------
        image : array_like
            Image indexed ``image[:, row, column]``; ``image.shape[1]`` is the
            height and ``image.shape[2]`` the width.
        left, lower, right, upper : int, optional
            Borders of the crop, y-up (``lower`` < ``upper``). If all four are
            omitted the image is kept whole; any single omitted border
            defaults to the matching edge of ``image``.
        """
        height = image.shape[1]
        width = image.shape[2]
        keep_whole = left is None and lower is None and right is None and upper is None

        # Fill in missing borders so the center computations below never see None.
        self.left = 0 if left is None else left
        self.lower = 0 if lower is None else lower
        self.right = width if right is None else right
        self.upper = height if upper is None else upper

        if keep_whole:
            self.cropped_image = image
        else:
            # Array rows run top-to-bottom (assumed standard image layout) while
            # lower/upper are measured bottom-up, so convert them to row indices.
            # Clamp at 0 so an out-of-range border cannot wrap to a negative index.
            top_row = max(0, height - self.upper)
            bottom_row = max(0, height - self.lower)
            self.cropped_image = image[:, top_row:bottom_row, self.left:self.right]

        self.width = self.cropped_image.shape[2]
        self.height = self.cropped_image.shape[1]

        self.horizontal_center = (self.left + self.right) / 2
        self.vertical_center = (self.lower + self.upper) / 2

    def find(self, object_name: str) -> List["ImagePatch"]:
        """Return the patches for every ``object_name`` detected in this patch.

        Delegates to ``find_in_image`` on ``cropped_image``; an empty list
        means nothing was detected.
        """
        return find_in_image(self.cropped_image, object_name)

    def simple_query(self, question: str = None) -> str:
        """Answer a basic question about this patch via ``simple_qa``.

        ``question`` is forwarded unchanged (including ``None``) to the
        backend.
        """
        return simple_qa(self.cropped_image, question)

    def exists(self, object_name: str) -> bool:
        """Return True if ``find(object_name)`` detects at least one instance."""
        return len(self.find(object_name)) > 0

    def verify_property(self, object_name: str, property: str) -> bool:
        """Return whether ``object_name`` in this patch has ``property``.

        Delegates to the module-level ``verify_property`` backend; the method
        name does not shadow it inside the body because class attributes are
        not in a method's lookup scope.
        """
        return verify_property(self.cropped_image, object_name, property)

    def compute_depth(self):
        """Return the median of the backend depth map for this patch."""
        depth_map = compute_depth(self.cropped_image)
        return depth_map.median()

    def best_text_match(self, option_list: List[str]) -> str:
        """Return the entry of ``option_list`` the backend judges best for this patch."""
        return best_text_match(self.cropped_image, option_list)

    def crop(self, left: int, lower: int, right: int, upper: int) -> "ImagePatch":
        """Return a sub-patch of this patch.

        NOTE: the coordinates are relative to this patch's ``cropped_image``
        (y-up), and so are the borders stored on the returned patch — they
        are not offset into the original image's frame.
        """
        return ImagePatch(self.cropped_image, left, lower, right, upper)
|
|
|
|
|
def best_image_match(list_patches: List["ImagePatch"], content: List[str], return_index=False) -> Union["ImagePatch", int]:
    """Return the patch in ``list_patches`` that best matches ``content``.

    Parameters
    ----------
    list_patches : list of ImagePatch
        Candidate patches; must not be empty.
    content : list of str
        Text descriptions to match against.
    return_index : bool
        If True, return the index of the best patch instead of the patch.

    Raises
    ------
    ValueError
        If ``list_patches`` is empty.
    NotImplementedError
        If there are several candidates and no matching backend is available.
    """
    if not list_patches:
        raise ValueError("best_image_match needs at least one candidate patch")
    if len(list_patches) == 1:
        # A single candidate is trivially the best match; no model call needed.
        return 0 if return_index else list_patches[0]
    # The previous body called this same function and recursed forever.
    # NOTE(review): assumes the vision backend may expose `best_image_match`;
    # confirm against vision_functions.
    try:
        from vision_functions import best_image_match as _backend_match
    except ImportError as err:
        raise NotImplementedError(
            "best_image_match requires a vision backend implementation"
        ) from err
    return _backend_match(list_patches, content, return_index)
|
|
|
|
|
def distance(patch_a: "ImagePatch", patch_b: "ImagePatch") -> float:
    """Return a distance between two patches (or two plain numbers).

    For boxes (objects with ``left``/``lower``/``right``/``upper``), returns
    the Euclidean gap between them. When they overlap, returns the negative
    IoU, so more overlap sorts as "closer". Boxes that only touch give 0.0.
    For two real numbers, returns their absolute difference.
    """
    import math
    import numbers

    if isinstance(patch_a, numbers.Real) and isinstance(patch_b, numbers.Real):
        return float(abs(patch_a - patch_b))

    # Horizontal / vertical gap between the boxes (0 along an overlapping axis).
    gap_x = max(0, patch_a.left - patch_b.right, patch_b.left - patch_a.right)
    gap_y = max(0, patch_a.lower - patch_b.upper, patch_b.lower - patch_a.upper)
    gap = math.hypot(gap_x, gap_y)
    if gap > 0:
        return float(gap)

    inter_w = min(patch_a.right, patch_b.right) - max(patch_a.left, patch_b.left)
    inter_h = min(patch_a.upper, patch_b.upper) - max(patch_a.lower, patch_b.lower)
    intersection = max(0, inter_w) * max(0, inter_h)
    if intersection == 0:
        return 0.0
    area_a = (patch_a.right - patch_a.left) * (patch_a.upper - patch_a.lower)
    area_b = (patch_b.right - patch_b.left) * (patch_b.upper - patch_b.lower)
    union = area_a + area_b - intersection
    return -intersection / union
|
|
|
|
|
# Examples of using ImagePatch |
|
|
|
|
|
# Given an image: What toy is wearing a shirt?
def execute_command(image) -> str:
    # not a relational verb so go step by step
    image_patch = ImagePatch(image)
    toy_patches = image_patch.find("toy")
    # Question assumes only one toy patch
    if len(toy_patches) == 0:
        # If no toy is found, query the image directly
        return image_patch.simple_query("What toy is wearing a shirt?")
    for toy_patch in toy_patches:
        is_wearing_shirt = (toy_patch.simple_query("Is the toy wearing a shirt?") == "yes")
        if is_wearing_shirt:
            # crop would include the shirt so keep it in the query
            return toy_patch.simple_query("What toy is wearing a shirt?")
    # If no toy is wearing a shirt, pick the first toy
    return toy_patches[0].simple_query("What toy is wearing a shirt?")
|
|
|
|
|
# Given an image: Who is the man staring at?
def execute_command(image) -> str:
    # asks for the predicate of a relational verb (staring at), so ask directly
    image_patch = ImagePatch(image)
    return image_patch.simple_query("Who is the man staring at?")
|
|
|
|
|
# Given an image: Find more visible chair.
def execute_command(image) -> ImagePatch:
    # Return the chair
    image_patch = ImagePatch(image)
    chair_patches = image_patch.find("chair")
    if len(chair_patches) == 0:
        # No chair detected: fall back to the whole image instead of an IndexError
        return image_patch
    # Remember: return the chair
    # NOTE(review): takes the first detection as "most visible" — presumably
    # detections come back ordered by confidence; confirm with find_in_image.
    return chair_patches[0]
|
|
|
|
|
# Given an image: Find lamp on the bottom.
def execute_command(image) -> ImagePatch:
    # Return the lamp
    image_patch = ImagePatch(image)
    lamp_patches = image_patch.find("lamp")
    if len(lamp_patches) == 0:
        # No lamp detected: fall back to the whole image
        return image_patch
    # y grows upward, so ascending vertical_center puts the bottommost first
    lamp_patches.sort(key=lambda lamp: lamp.vertical_center)
    # Remember: return the lamp
    return lamp_patches[0]  # Return the bottommost lamp
|
|
|
|
|
# Given a list of images: Does the pole that is near a building that is near a green sign and the pole that is near bushes that are near a green sign have the same material?
def execute_command(image_list) -> str:
    material_1 = None
    material_2 = None
    for image in image_list:
        image = ImagePatch(image)
        # find the building
        building_patches = image.find("building")
        for building_patch in building_patches:
            poles = building_patch.find("pole")
            signs = building_patch.find("sign")
            greensigns = [sign for sign in signs if sign.verify_property('sign', 'green')]
            if len(poles) > 0 and len(greensigns) > 0:
                material_1 = poles[0].simple_query("What is the material of the pole?")
        # find the bush
        bushes_patches = image.find("bushes")
        for bushes_patch in bushes_patches:
            poles = bushes_patch.find("pole")
            signs = bushes_patch.find("sign")
            greensigns = [sign for sign in signs if sign.verify_property('sign', 'green')]
            if len(poles) > 0 and len(greensigns) > 0:
                material_2 = poles[0].simple_query("What is the material of the pole?")
    # If either pole was never found, the comparison cannot be confirmed;
    # without this guard None == None would wrongly answer "yes".
    if material_1 is None or material_2 is None:
        return "no"
    return bool_to_yesno(material_1 == material_2)
|
|
|
|
|
# Given an image: Find middle kid.
def execute_command(image) -> ImagePatch:
    # Return the kid
    image_patch = ImagePatch(image)
    kid_patches = image_patch.find("kid")
    if len(kid_patches) == 0:
        kid_patches = [image_patch]
    # Order left-to-right so the middle index is the middle kid
    kid_patches.sort(key=lambda kid: kid.horizontal_center)
    # Remember: return the kid
    return kid_patches[len(kid_patches) // 2]  # Return the middle kid
|
|
|
|
|
# Given an image: Is that blanket to the right of a pillow?
def execute_command(image) -> str:
    image_patch = ImagePatch(image)
    blanket_patches = image_patch.find("blanket")
    # Question assumes only one blanket patch
    if len(blanket_patches) == 0:
        # If no blanket is found, query the image directly
        return image_patch.simple_query("Is that blanket to the right of a pillow?")
    # The pillows do not depend on which blanket is examined, so detect them once
    pillow_patches = image_patch.find("pillow")
    for blanket_patch in blanket_patches:
        for pillow_patch in pillow_patches:
            # "blanket to the right of a pillow": the blanket's center lies right of the pillow's
            if blanket_patch.horizontal_center > pillow_patch.horizontal_center:
                return "yes"
    return "no"
|
|
|
|
|
# Given an image: How many people are there?
def execute_command(image) -> str:
    image_patch = ImagePatch(image)
    person_patches = image_patch.find("person")
    # Answers are strings, so convert the count
    return str(len(person_patches))
|
|
|
|
|
# Given a list of images: Is the man that is wearing dark pants driving?
def execute_command(image_list) -> str:
    for image in image_list:
        image = ImagePatch(image)
        man_patches = image.find("man")
        for man_patch in man_patches:
            pants = man_patch.find("pants")
            if len(pants) == 0:
                continue
            if pants[0].verify_property("pants", "dark"):
                return man_patch.simple_query("Is this man driving?")
    # No man with dark pants identified: ask the first image directly
    return ImagePatch(image_list[0]).simple_query("Is the man that is wearing dark pants driving?")
|
|
|
|
|
# Given an image: Is there a backpack to the right of the man?
def execute_command(image) -> str:
    image_patch = ImagePatch(image)
    man_patches = image_patch.find("man")
    # Question assumes one man patch
    if len(man_patches) == 0:
        # If no man is found, query the image directly
        return image_patch.simple_query("Is there a backpack to the right of the man?")
    man_patch = man_patches[0]
    backpack_patches = image_patch.find("backpack")
    # Question assumes one backpack patch
    if len(backpack_patches) == 0:
        return "no"
    for backpack_patch in backpack_patches:
        if backpack_patch.horizontal_center > man_patch.horizontal_center:
            return "yes"
    return "no"
|
|
|
|
|
# Given a list of images: What is the pizza with red tomato on it on?
def execute_command(image_list) -> str:
    for image in image_list:
        image = ImagePatch(image)
        pizza_patches = image.find("pizza")
        for pizza_patch in pizza_patches:
            tomato_patches = pizza_patch.find("tomato")
            # Stop checking tomatoes as soon as one red tomato is confirmed
            has_red_tomato = any(
                tomato_patch.verify_property("tomato", "red") for tomato_patch in tomato_patches
            )
            if has_red_tomato:
                return pizza_patch.simple_query("What is the pizza on?")
    # No matching pizza identified: ask the first image directly
    return ImagePatch(image_list[0]).simple_query("What is the pizza with red tomato on it on?")
|
|
|
|
|
# Given an image: Find chair to the right near the couch.
def execute_command(image) -> ImagePatch:
    # Return the chair
    image_patch = ImagePatch(image)
    chair_patches = image_patch.find("chair")
    if len(chair_patches) == 0:
        chair_patches = [image_patch]
    elif len(chair_patches) == 1:
        return chair_patches[0]
    chair_patches_right = [c for c in chair_patches if c.horizontal_center > image_patch.horizontal_center]
    if len(chair_patches_right) == 0:
        # Nothing strictly right of center (always true for the whole-image
        # fallback): consider every chair rather than index an empty list
        chair_patches_right = chair_patches
    couch_patches = image_patch.find("couch")
    if len(couch_patches) == 0:
        couch_patches = [image_patch]
    couch_patch = couch_patches[0]
    chair_patches_right.sort(key=lambda c: distance(c, couch_patch))
    chair_patch = chair_patches_right[0]
    # Remember: return the chair
    return chair_patch
|
|
|
|
|
# Given an image: Are there bagels or lemons?
def execute_command(image) -> str:
    image_patch = ImagePatch(image)
    is_bagel = image_patch.exists("bagel")
    is_lemon = image_patch.exists("lemon")
    return bool_to_yesno(is_bagel or is_lemon)
|
|
|
|
|
# Given an image: In which part is the bread, the bottom or the top?
def execute_command(image) -> str:
    image_patch = ImagePatch(image)
    bread_patches = image_patch.find("bread")
    # Question assumes only one bread patch
    if len(bread_patches) == 0:
        # If no bread is found, query the image directly
        return image_patch.simple_query("In which part is the bread, the bottom or the top?")
    # y grows upward, so below the image center means the bottom half
    if bread_patches[0].vertical_center < image_patch.vertical_center:
        return "bottom"
    else:
        return "top"
|
|
|
|
|
# Given an image: Find foo to bottom left.
def execute_command(image) -> ImagePatch:
    # Return the foo
    image_patch = ImagePatch(image)
    foo_patches = image_patch.find("foo")
    if len(foo_patches) == 0:
        # min() of an empty sequence would raise: fall back to the whole image
        return image_patch
    lowermost_coordinate = min(patch.vertical_center for patch in foo_patches)
    # Keep foos within 100 px of the lowest one (always includes that one),
    # because the caption asks about vertical position first
    foo_patches_bottom = [patch for patch in foo_patches if patch.vertical_center - lowermost_coordinate < 100]
    if len(foo_patches_bottom) == 1:
        return foo_patches_bottom[0]
    # Then take the leftmost among the bottom ones
    foo_patches_bottom.sort(key=lambda foo: foo.horizontal_center)
    foo_patch = foo_patches_bottom[0]
    # Remember: return the foo
    return foo_patch
|
|
|
|
|
# Given an image: Find number 17.
def execute_command(image) -> ImagePatch:
    # Return the person
    image_patch = ImagePatch(image)
    person_patches = image_patch.find("person")
    for patch in person_patches:
        if patch.exists("17"):
            return patch
    if len(person_patches) == 0:
        # No person detected: fall back to the whole image
        return image_patch
    # Remember: return the person
    return person_patches[0]
|
|
|
|
|
# Given a list of images: Is the statement true? There is at least 1 image with a brown dog that is near a bicycle and is wearing a collar.
def execute_command(image_list) -> str:
    for image in image_list:
        image = ImagePatch(image)
        dog_patches = image.find("dog")
        for dog in dog_patches:
            # Every part of the statement must hold, including the color
            if not dog.verify_property("dog", "brown"):
                continue
            near_bicycle = dog.simple_query("Is the dog near a bicycle?")
            wearing_collar = dog.simple_query("Is the dog wearing a collar?")
            if near_bicycle == "yes" and wearing_collar == "yes":
                return 'yes'
    return 'no'
|
|
|
|
|
# Given an image: Find dog to the left of the post who is closest to girl wearing a shirt with text that says "I love you".
def execute_command(image) -> ImagePatch:
    # Return the dog
    image_patch = ImagePatch(image)
    shirt_patches = image_patch.find("shirt")
    if len(shirt_patches) == 0:
        shirt_patches = [image_patch]
    shirt_patch = best_image_match(list_patches=shirt_patches, content=["I love you shirt"])
    post_patches = image_patch.find("post")
    if len(post_patches) == 0:
        # No post detected: fall back to the whole image rather than index an empty list
        post_patches = [image_patch]
    post_patches.sort(key=lambda post: distance(post, shirt_patch))
    post_patch = post_patches[0]
    dog_patches = image_patch.find("dog")
    if len(dog_patches) == 0:
        dog_patches = [image_patch]
    dogs_left_patch = [dog for dog in dog_patches if dog.left < post_patch.left]
    if len(dogs_left_patch) == 0:
        dogs_left_patch = dog_patches
    dogs_left_patch.sort(key=lambda dog: distance(dog, post_patch))
    dog_patch = dogs_left_patch[0]
    # Remember: return the dog
    return dog_patch
|
|
|
|
|
# Given an image: Find balloon on the right and second from the bottom.
def execute_command(image) -> ImagePatch:
    # Return the balloon
    image_patch = ImagePatch(image)
    balloon_patches = image_patch.find("balloon")
    if len(balloon_patches) == 0:
        balloon_patches = [image_patch]
    elif len(balloon_patches) == 1:
        return balloon_patches[0]
    # "on the right": keep balloons within 100 px of the RIGHTMOST one
    # (the filter always keeps at least that balloon)
    rightmost_coordinate = max(patch.horizontal_center for patch in balloon_patches)
    balloon_patches_right = [patch for patch in balloon_patches if rightmost_coordinate - patch.horizontal_center < 100]
    # y grows upward, so ascending vertical_center is bottom-to-top
    balloon_patches_right.sort(key=lambda p: p.vertical_center)
    # Second from the bottom, or the only candidate if just one survived the filter
    balloon_patch = balloon_patches_right[min(1, len(balloon_patches_right) - 1)]
    # Remember: return the balloon
    return balloon_patch
|
|
|
|
|
# Given an image: Find girl in white next to man in left.
def execute_command(image) -> ImagePatch:
    # Return the girl
    image_patch = ImagePatch(image)
    girl_patches = image_patch.find("girl")
    if len(girl_patches) == 0:
        girl_patches = [image_patch]
    girl_in_white_patches = [g for g in girl_patches if g.verify_property("girl", "white clothing")]
    if len(girl_in_white_patches) == 0:
        girl_in_white_patches = girl_patches
    man_patches = image_patch.find("man")
    if len(man_patches) == 0:
        # No man to anchor on: the first girl in white is the best available answer
        return girl_in_white_patches[0]
    man_patches.sort(key=lambda man: man.horizontal_center)
    leftmost_man = man_patches[0]  # First from the left
    girl_in_white_patches.sort(key=lambda girl: distance(girl, leftmost_man))
    girl_patch = girl_in_white_patches[0]
    # Remember: return the girl
    return girl_patch
|
|
|
|
|
# Given a list of images: Is the statement true? There is 1 table that is in front of woman that is wearing jacket.
def execute_command(image_list) -> str:
    for image in image_list:
        image = ImagePatch(image)
        woman_patches = image.find("woman")
        for woman in woman_patches:
            if woman.simple_query("Is the woman wearing a jacket?") == "yes":
                # The statement is decided by the first woman wearing a jacket
                tables = woman.find("table")
                return bool_to_yesno(len(tables) == 1)
    return 'no'
|
|
|
|
|
# Given an image: Find top left.
def execute_command(image) -> ImagePatch:
    # Return the person
    image_patch = ImagePatch(image)
    # Figure out what thing the caption is referring to. We need a subject for every caption
    persons = image_patch.find("person")
    if len(persons) == 0:
        # max() of an empty sequence would raise: fall back to the whole image
        return image_patch
    top_all_objects = max(obj.vertical_center for obj in persons)
    # Select objects that are close to the top (always includes the topmost one)
    # We do this because the caption is asking first about vertical and then about horizontal
    persons_top = [p for p in persons if top_all_objects - p.vertical_center < 100]
    # And after that, obtain the leftmost object among them
    persons_top.sort(key=lambda obj: obj.horizontal_center)
    person_leftmost = persons_top[0]
    # Remember: return the person
    return person_leftmost
|
|
|
|
|
# Given an image: What type of weather do you see in the photograph?
def execute_command(image) -> str:
    # A whole-scene question: ask it directly
    image_patch = ImagePatch(image)
    return image_patch.simple_query("What type of weather do you see in the photograph?")
|
|
|
|
|
# Given an image: How many orange life vests can be seen?
def execute_command(image) -> str:
    image_patch = ImagePatch(image)
    life_vest_patches = image_patch.find("life vest")
    orange_life_vest_patches = [
        life_vest_patch
        for life_vest_patch in life_vest_patches
        if life_vest_patch.verify_property('life vest', 'orange')
    ]
    return str(len(orange_life_vest_patches))
|
|
|
|
|
# Given an image: What is behind the pole?
def execute_command(image) -> str:
    image_patch = ImagePatch(image)
    # contains a relation (around, next to, on, near, on top of, in front of, behind, etc), so ask directly
    return image_patch.simple_query("What is behind the pole?")
|
|
|
|
|
# Given an image: Find second to top flower.
def execute_command(image) -> ImagePatch:
    # Return the flower
    image_patch = ImagePatch(image)
    flower_patches = image_patch.find("flower")
    if len(flower_patches) == 0:
        # No flower detected: fall back to the whole image
        return image_patch
    # y grows upward, so the topmost flower ends up last
    flower_patches.sort(key=lambda flower: flower.vertical_center)
    # Second from the top; with a single flower, that flower is the best answer
    flower_patch = flower_patches[-2] if len(flower_patches) >= 2 else flower_patches[0]
    # Remember: return the flower
    return flower_patch
|
|
|
|
|
# Given an image: Find back.
def execute_command(image) -> ImagePatch:
    # Return the person
    image_patch = ImagePatch(image)
    person_patches = image_patch.find("person")
    if len(person_patches) == 0:
        # No person detected: fall back to the whole image
        return image_patch
    # NOTE(review): assumes larger median depth means farther from the camera
    person_patches.sort(key=lambda person: person.compute_depth())
    person_patch = person_patches[-1]
    # Remember: return the person
    return person_patch
|
|
|
|
|
# Given an image: Find chair at the front.
def execute_command(image) -> ImagePatch:
    # Return the chair
    image_patch = ImagePatch(image)
    chair_patches = image_patch.find("chair")
    if len(chair_patches) == 0:
        # No chair detected: fall back to the whole image
        return image_patch
    # NOTE(review): assumes smaller median depth means closer to the camera
    chair_patches.sort(key=lambda chair: chair.compute_depth())
    chair_patch = chair_patches[0]
    # Remember: return the chair
    return chair_patch
|
|
|
|
|
# Given an image: Find white and yellow pants.
def execute_command(image) -> ImagePatch:
    # Return the person
    image_patch = ImagePatch(image)
    # Clothing always requires returning the person
    person_patches = image_patch.find("person")
    if len(person_patches) == 0:
        person_patches = [image_patch]
    person_patch = best_image_match(person_patches, ["white pants", "yellow pants"])
    # Remember: return the person
    return person_patch
|
|
|
|
|
# Given an image: Find cow facing the camera.
def execute_command(image) -> ImagePatch:
    # Return the cow
    image_patch = ImagePatch(image)
    cow_patches = image_patch.find("cow")
    if len(cow_patches) == 0:
        cow_patches = [image_patch]
    cow_patch = best_image_match(list_patches=cow_patches, content=["cow facing the camera"])
    # Remember: return the cow
    return cow_patch
|
|
|
|
|
# Given a list of images: Is the statement true? There is 1 image that contains exactly 3 blue papers.
def execute_command(image_list) -> str:
    image_cnt = 0
    for image in image_list:
        image = ImagePatch(image)
        paper_patches = image.find("paper")
        blue_paper_patches = [paper for paper in paper_patches if paper.verify_property("paper", "blue")]
        if len(blue_paper_patches) == 3:
            image_cnt += 1
    # "There is 1 image" means exactly one qualifying image
    return bool_to_yesno(image_cnt == 1)
|
|
|
|
|
# Given an image: Find black car just under stop sign.
def execute_command(image) -> ImagePatch:
    # Return the car
    image_patch = ImagePatch(image)
    stop_sign_patches = image_patch.find("stop sign")
    if len(stop_sign_patches) == 0:
        stop_sign_patches = [image_patch]
    stop_sign_patch = stop_sign_patches[0]
    car_patches = image_patch.find("black car")
    if len(car_patches) == 0:
        # No black car detected: fall back to the whole image
        return image_patch
    # y grows upward, so a car under the sign has a lower top border
    car_under_stop = [car for car in car_patches if car.upper < stop_sign_patch.upper]
    if len(car_under_stop) == 0:
        car_under_stop = car_patches
    # Find car that is closest to the stop sign: smallest vertical gap
    # (the signed difference would rank the farthest car below first)
    car_under_stop.sort(key=lambda car: abs(stop_sign_patch.vertical_center - car.vertical_center))
    # Remember: return the car
    return car_under_stop[0]
|
|
|
|
|
# Given a list of images: Is there either a standing man that is holding a cell phone or a sitting man that is holding a cell phone?
def execute_command(image_list) -> str:
    for image in image_list:
        image = ImagePatch(image)
        man_patches = image.find("man")
        for man in man_patches:
            holding_cell_phone = man.simple_query("Is this man holding a cell phone?")
            if holding_cell_phone == "yes":
                # Either posture satisfies the question
                if man.simple_query("Is this man standing?") == "yes":
                    return 'yes'
                if man.simple_query("Is this man sitting?") == "yes":
                    return 'yes'
    return 'no'
|
|
|
|
|
# Given an image: How many people are running while looking at their cell phone?
def execute_command(image) -> str:
    image_patch = ImagePatch(image)
    people_patches = image_patch.find("person")
    if len(people_patches) == 0:
        # If no people are found, query the image directly
        return image_patch.simple_query("How many people are running while looking at their cell phone?")
    # Count every person that satisfies both conditions
    people_count = 0
    for person_patch in people_patches:
        # Verify two conditions: (1) running (2) looking at cell phone
        if person_patch.simple_query("Is the person running?") == "yes":
            if person_patch.simple_query("Is the person looking at a cell phone?") == "yes":
                people_count += 1
    return str(people_count)
|
|
|
|
|
# Given a list of images: Does the car that is on a highway and the car that is on a street have the same color?
def execute_command(image_list) -> str:
    color_1 = None
    color_2 = None
    for image in image_list:
        image = ImagePatch(image)
        car_patches = image.find("car")
        for car_patch in car_patches:
            if car_patch.simple_query("Is the car on the highway?") == "yes":
                color_1 = car_patch.simple_query("What is the color of the car?")
            elif car_patch.simple_query("Is the car on a street?") == "yes":
                color_2 = car_patch.simple_query("What is the color of the car?")
    # If either car was never found the colors cannot be compared;
    # without this guard None == None would wrongly answer "yes".
    if color_1 is None or color_2 is None:
        return "no"
    return bool_to_yesno(color_1 == color_2)
|
|
|
|
|
# Given a list of images: Is the statement true? There are 3 magazine that are on table.
def execute_command(image_list) -> str:
    count = 0
    for image in image_list:
        image = ImagePatch(image)
        magazine_patches = image.find("magazine")
        for magazine_patch in magazine_patches:
            on_table = magazine_patch.simple_query("Is the magazine on a table?")
            if on_table == "yes":
                count += 1
    # The count spans all images, and the statement requires exactly 3
    return bool_to_yesno(count == 3)
|
|
|
|
|
# INSERT_QUERY_HERE |