|
from __future__ import annotations |
|
|
|
import json |
|
import pathlib |
|
import re |
|
from typing import Tuple |
|
from typing import Union, List |
|
|
|
import numpy as np |
|
import torch |
|
from PIL import Image |
|
from dateutil import parser as dateparser |
|
from torchvision import transforms |
|
from torchvision.ops import box_iou |
|
from word2number import w2n |
|
|
|
from vision_processes import forward |
|
|
|
|
|
def load_json(path: str):
    """Load a JSON file, forcing a ``.json`` suffix on the given path.

    Parameters
    ----------
    path : str or pathlib.Path
        Location of the JSON file. If the path has a different (or no)
        suffix, it is replaced with ``.json`` before opening.

    Returns
    -------
    The deserialized JSON content (dict, list, etc.).
    """
    if isinstance(path, str):
        path = pathlib.Path(path)
    if path.suffix != '.json':
        path = path.with_suffix('.json')
    # Explicit encoding so decoding does not depend on the platform default
    # (JSON is UTF-8 per RFC 8259).
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data
|
|
|
|
|
class ImagePatch:
    """A Python class containing a crop of an image centered around a particular object, as well as relevant
    information.
    Attributes
    ----------
    cropped_image : array_like
        An array-like of the cropped image taken from the original image.
    left : int
        An int describing the position of the left border of the crop's bounding box in the original image.
    lower : int
        An int describing the position of the bottom border of the crop's bounding box in the original image.
    right : int
        An int describing the position of the right border of the crop's bounding box in the original image.
    upper : int
        An int describing the position of the top border of the crop's bounding box in the original image.

    Methods
    -------
    find(object_name: str)->List[ImagePatch]
        Returns a list of new ImagePatch objects containing crops of the image centered around any objects found in the
        image matching the object_name.
    exists(object_name: str)->bool
        Returns True if the object specified by object_name is found in the image, and False otherwise.
    verify_property(property: str)->bool
        Returns True if the property is met, and False otherwise.
    best_text_match(option_list: List[str], prefix: str)->str
        Returns the string that best matches the image.
    simple_query(question: str=None)->str
        Returns the answer to a basic question asked about the image. If no question is provided, returns the answer
        to "What is this?".
    compute_depth()->float
        Returns the median depth of the image crop.
    crop(left: int, lower: int, right: int, upper: int)->ImagePatch
        Returns a new ImagePatch object containing a crop of the image at the given coordinates.
    """

    def __init__(self, image: Union[Image.Image, torch.Tensor, np.ndarray], left: int = None, lower: int = None,
                 right: int = None, upper: int = None, parent_left=0, parent_lower=0, queues=None,
                 parent_img_patch=None):
        """Initializes an ImagePatch object by cropping the image at the given coordinates and stores the coordinates as
        attributes. If no coordinates are provided, the image is left unmodified, and the coordinates are set to the
        dimensions of the image.

        Parameters
        -------
        image : array_like
            An array-like of the original image.
        left : int
            An int describing the position of the left border of the crop's bounding box in the original image.
        lower : int
            An int describing the position of the bottom border of the crop's bounding box in the original image.
        right : int
            An int describing the position of the right border of the crop's bounding box in the original image.
        upper : int
            An int describing the position of the top border of the crop's bounding box in the original image.
        parent_left, parent_lower : int
            Offsets of this crop inside the root image, used to express this patch's
            coordinates in root-image space.
        queues : tuple, optional
            Pair of queues propagated to child patches.
        parent_img_patch : ImagePatch, optional
            The patch this crop was taken from, if any.
        """

        # Normalize the input to a float torch tensor.
        if isinstance(image, Image.Image):
            image = transforms.ToTensor()(image)  # PIL -> float C x H x W in [0, 1]
        elif isinstance(image, np.ndarray):
            # NOTE(review): for an H x W x C ndarray the usual channel-first
            # conversion would be permute(2, 0, 1); permute(1, 2, 0) looks
            # suspicious -- confirm the expected ndarray layout with callers.
            image = torch.tensor(image).permute(1, 2, 0)
        elif isinstance(image, torch.Tensor) and image.dtype == torch.uint8:
            image = image / 255  # uint8 [0, 255] -> float [0, 1]

        if left is None and right is None and upper is None and lower is None:
            # No crop requested: keep the whole image and set the box to its bounds.
            self.cropped_image = image
            self.left = 0
            self.lower = 0
            self.right = image.shape[2]  # width
            self.upper = image.shape[1]  # height
        else:
            # Box coordinates use a bottom-left origin (lower/upper measured from
            # the bottom of the image), while tensor rows grow downward from the
            # top -- hence the `image.shape[1] - y` flips on the row slice.
            self.cropped_image = image[:, image.shape[1] - upper:image.shape[1] - lower, left:right]
            # Stored coordinates are relative to the root image, not the parent crop.
            self.left = left + parent_left
            self.upper = upper + parent_lower
            self.right = right + parent_left
            self.lower = lower + parent_lower

        self.height = self.cropped_image.shape[1]
        self.width = self.cropped_image.shape[2]

        self.cache = {}  # per-patch memoization store
        self.queues = (None, None) if queues is None else queues

        self.parent_img_patch = parent_img_patch

        self.horizontal_center = (self.left + self.right) / 2
        self.vertical_center = (self.lower + self.upper) / 2

        # A degenerate (zero-area) crop cannot be scored by any model.
        if self.cropped_image.shape[1] == 0 or self.cropped_image.shape[2] == 0:
            raise Exception("ImagePatch has no area")

        # Option lists (e.g. the attribute vocabulary used by verify_property).
        self.possible_options = load_json('./useful_lists/possible_options.json')

    def forward(self, model_name, *args, **kwargs):
        """Dispatch a model call to the shared ``vision_processes.forward`` entry point."""
        return forward(model_name, *args, **kwargs)

    @property
    def original_image(self):
        """The root (uncropped) image this patch was ultimately taken from."""
        if self.parent_img_patch is None:
            return self.cropped_image
        else:
            return self.parent_img_patch.original_image

    def find(self, object_name: str, confidence_threshold: float = None, return_confidence: bool = False) -> List:
        """Returns a list of ImagePatch objects matching object_name contained in the crop if any are found.
        Otherwise, returns an empty list.
        Parameters
        ----------
        object_name : str
            the name of the object to be found
        confidence_threshold : float, optional
            detection-confidence cutoff forwarded to the detector
        return_confidence : bool
            if True, return (patch, score) tuples instead of bare patches

        Returns
        -------
        List[ImagePatch]
            a list of ImagePatch objects matching object_name contained in the crop
        """
        if confidence_threshold is not None:
            confidence_threshold = float(confidence_threshold)

        if object_name in ["object", "objects"]:
            # Generic request: use class-agnostic Mask R-CNN proposals instead of
            # a text-conditioned detector.
            all_object_coordinates, all_object_scores = self.forward('maskrcnn', self.cropped_image,
                                                                     confidence_threshold=confidence_threshold)
            # Presumably batched output; take the first (only) image's results --
            # confirm against the maskrcnn wrapper in vision_processes.
            all_object_coordinates = all_object_coordinates[0]
            all_object_scores = all_object_scores[0]
        else:
            if object_name == 'person':
                object_name = 'people'  # detector-specific alias for the GLIP prompt

            all_object_coordinates, all_object_scores = self.forward('glip', self.cropped_image, object_name,
                                                                     confidence_threshold=confidence_threshold)
        if len(all_object_coordinates) == 0:
            return []

        # Minimum relative-area filter for detections. With threshold == 0.0 the
        # block below never runs (dead code kept so it is easy to re-enable).
        threshold = 0.0
        if threshold > 0:
            area_im = self.width * self.height
            all_areas = torch.tensor([(coord[2] - coord[0]) * (coord[3] - coord[1]) / area_im
                                      for coord in all_object_coordinates])
            mask = all_areas > threshold

            all_object_coordinates = all_object_coordinates[mask]
            all_object_scores = all_object_scores[mask]

        # Turn each detected box into a child ImagePatch.
        boxes = [self.crop(*coordinates) for coordinates in all_object_coordinates]
        if return_confidence:
            return [(box, float(score)) for box, score in zip(boxes, all_object_scores.reshape(-1))]
        else:
            return boxes

    def exists(self, object_name) -> bool:
        """Returns True if the object specified by object_name is found in the image, and False otherwise.
        Parameters
        -------
        object_name : str
            A string describing the name of the object to be found in the image.
        """
        if object_name.isdigit() or object_name.lower().startswith("number"):
            # Numeric query ("7", "number seven"): compare against the number the
            # VQA model reads off the image instead of running object detection.
            object_name = object_name.lower().replace("number", "").strip()

            object_name = w2n.word_to_num(object_name)
            answer = self.simple_query("What number is written in the image (in digits)?")
            # NOTE(review): w2n.word_to_num raises ValueError when the VQA answer
            # is not a number -- confirm callers tolerate that.
            return w2n.word_to_num(answer) == object_name

        patches = self.find(object_name)

        # Double-check each detection with a yes/no VQA query to cut false positives.
        filtered_patches = []
        for patch in patches:
            if "yes" in patch.simple_query(f"Is this a {object_name}?"):
                filtered_patches.append(patch)
        return len(filtered_patches) > 0

    def _score(self, category: str, negative_categories=None, model='clip') -> float:
        """
        Returns a binary score for the similarity between the image and the category.
        The negative categories are used to compare to (score is relative to the scores of the negative categories).
        """
        if model == 'clip':
            res = self.forward('clip', self.cropped_image, category, task='score',
                               negative_categories=negative_categories)
        elif model == 'tcl':
            res = self.forward('tcl', self.cropped_image, category, task='score')
        else:  # xvlm
            # 'binary_score' scores relative to the negative categories when given.
            task = 'binary_score' if negative_categories is not None else 'score'
            res = self.forward('xvlm', self.cropped_image, category, task=task, negative_categories=negative_categories)
        res = res.item()

        return res

    def _detect(self, category: str, thresh, negative_categories=None, model='clip') -> Tuple[bool, float]:
        """Threshold the similarity score; returns ``(score > thresh, score)``."""
        score = self._score(category, negative_categories, model)
        return score > thresh, float(score)

    def verify_property(self, object_name: str, attribute: str, return_confidence: bool = False):
        """Returns True if the object possesses the property, and False otherwise.
        Differs from 'exists' in that it presupposes the existence of the object specified by object_name, instead
        checking whether the object possesses the property.
        Parameters
        -------
        object_name : str
            A string describing the name of the object to be found in the image.
        attribute : str
            A string describing the property to be checked.
        return_confidence : bool
            if True, also return the raw similarity score alongside the boolean
        """
        name = f"{attribute} {object_name}"
        model = "xvlm"
        # Score "<attribute> <object>" against the same object paired with every
        # other known attribute, so the decision is relative, not absolute.
        negative_categories = [f"{att} {object_name}" for att in self.possible_options['attributes']]

        ret, score = self._detect(name, negative_categories=negative_categories, thresh=0.6, model='xvlm')

        if return_confidence:
            return ret, score
        else:
            return ret

    def best_text_match(self, option_list: list[str] = None, prefix: str = None) -> str:
        """Returns the string that best matches the image.
        Parameters
        -------
        option_list : str
            A list with the names of the different options
        prefix : str
            A string with the prefixes to append to the options
        """
        # NOTE(review): option_list defaults to None but is iterated and indexed
        # below -- callers are expected to always pass it.
        option_list_to_use = option_list
        if prefix is not None:
            option_list_to_use = [prefix + " " + option for option in option_list]

        model_name = "xvlm"
        image = self.cropped_image
        text = option_list_to_use
        if model_name in ('clip', 'tcl'):
            selected = self.forward(model_name, image, text, task='classify')
        elif model_name == 'xvlm':
            res = self.forward(model_name, image, text, task='score')
            res = res.argmax().item()  # index of the highest-scoring option
            selected = res
        else:
            raise NotImplementedError

        # Return the un-prefixed option even when prefixed text was scored.
        return option_list[selected]

    def simple_query(self, question: str, return_confidence: bool = False):
        """Returns the answer to a basic question asked about the image. If no question is provided, returns the answer
        to "What is this?". The questions are about basic perception, and are not meant to be used for complex reasoning
        or external knowledge.
        Parameters
        -------
        question : str
            A string describing the question to be asked.
        return_confidence : bool
            if True, also return the answer score reported by the VQA model
        """
        text, score = self.forward('blip', self.cropped_image, question, task='qa')
        if return_confidence:
            return text, score
        else:
            return text

    def compute_depth(self):
        """Returns the median depth of the image crop
        Parameters
        ----------
        Returns
        -------
        float
            the median depth of the image crop
        """
        # Depth is estimated on the root image, then sliced down to this patch
        # using the same bottom-origin -> row-index flip as __init__.
        original_image = self.original_image
        depth_map = self.forward('depth', original_image)
        depth_map = depth_map[original_image.shape[1] - self.upper:original_image.shape[1] - self.lower,
                              self.left:self.right]
        return depth_map.median()

    def crop(self, left: int, lower: int, right: int, upper: int) -> ImagePatch:
        """Returns a new ImagePatch containing a crop of the original image at the given coordinates.
        Parameters
        ----------
        left : int
            the position of the left border of the crop's bounding box in the original image
        lower : int
            the position of the bottom border of the crop's bounding box in the original image
        right : int
            the position of the right border of the crop's bounding box in the original image
        upper : int
            the position of the top border of the crop's bounding box in the original image

        Returns
        -------
        ImagePatch
            a new ImagePatch containing a crop of the original image at the given coordinates
        """

        # Coordinates may arrive as tensors/floats from a detector; force plain ints.
        left = int(left)
        lower = int(lower)
        right = int(right)
        upper = int(upper)

        # The box is always dilated by 10 px, clamped to this patch's bounds; the
        # constant-true guard just makes the padding easy to toggle off.
        if True:
            left = max(0, left - 10)
            lower = max(0, lower - 10)
            right = min(self.width, right + 10)
            upper = min(self.height, upper + 10)

        return ImagePatch(self.cropped_image, left, lower, right, upper, self.left, self.lower, queues=self.queues,
                          parent_img_patch=self)

    def overlaps_with(self, left, lower, right, upper):
        """Returns True if a crop with the given coordinates overlaps with this one,
        else False.
        Parameters
        ----------
        left : int
            the left border of the crop to be checked
        lower : int
            the lower border of the crop to be checked
        right : int
            the right border of the crop to be checked
        upper : int
            the upper border of the crop to be checked

        Returns
        -------
        bool
            True if a crop with the given coordinates overlaps with this one, else False
        """
        # Inclusive interval-overlap test on both axes (bottom-origin coordinates).
        return self.left <= right and self.right >= left and self.lower <= upper and self.upper >= lower

    def llm_query(self, question: str, long_answer: bool = True) -> str:
        """Convenience wrapper around the module-level ``llm_query`` (no context)."""
        return llm_query(question, None, long_answer)

    def __repr__(self):
        return "ImagePatch(left={}, right={}, upper={}, lower={}, height={}, width={}, horizontal_center={}, vertical_center={})".format(
            self.left, self.right, self.upper, self.lower, self.height, self.width,
            self.horizontal_center, self.vertical_center,
        )
|
|
|
|
|
|
|
def best_image_match(list_patches: list[ImagePatch], content: List[str], return_index: bool = False) -> \
        Union[ImagePatch, None]:
    """Returns the patch most likely to contain the content.
    Parameters
    ----------
    list_patches : List[ImagePatch]
    content : List[str]
        the object of interest
    return_index : bool
        if True, returns the index of the patch most likely to contain the object

    Returns
    -------
    int
        Patch most likely to contain the object
    """
    # Nothing to rank.
    if not list_patches:
        return None

    model = "xvlm"

    # Score every patch against each content description, then average over
    # descriptions and pick the patch with the highest mean score.
    crops = [patch.cropped_image for patch in list_patches]
    per_content_scores = []
    for description in content:
        if model == 'clip':
            result = list_patches[0].forward(model, crops, description, task='compare',
                                             return_scores=True)
        else:
            result = list_patches[0].forward(model, crops, description, task='score')
        per_content_scores.append(result)
    best_idx = torch.stack(per_content_scores).mean(dim=0).argmax().item()

    if return_index:
        return best_idx
    return list_patches[best_idx]
|
|
|
|
|
def distance(patch_a: Union[ImagePatch, float], patch_b: Union[ImagePatch, float]) -> float:
    """
    Returns the distance between the edges of two ImagePatches, or between two floats.
    If the patches overlap, it returns a negative distance corresponding to the negative intersection over union.
    """

    # Plain numbers: absolute difference.
    if not (isinstance(patch_a, ImagePatch) and isinstance(patch_b, ImagePatch)):
        return abs(patch_a - patch_b)

    a_lo = np.array([patch_a.left, patch_a.lower])
    a_hi = np.array([patch_a.right, patch_a.upper])
    b_lo = np.array([patch_b.left, patch_b.lower])
    b_hi = np.array([patch_b.right, patch_b.upper])

    # Per-axis gaps between the boxes; zero on an axis means they touch or
    # overlap along that axis.
    gap_ab = np.clip(a_lo - b_hi, 0, None)
    gap_ba = np.clip(b_lo - a_hi, 0, None)

    dist = np.sqrt((gap_ab ** 2).sum() + (gap_ba ** 2).sum())

    if dist == 0:
        # Boxes touch or overlap: report negative IoU instead of zero distance.
        box_a = torch.tensor([patch_a.left, patch_a.lower, patch_a.right, patch_a.upper])[None]
        box_b = torch.tensor([patch_b.left, patch_b.lower, patch_b.right, patch_b.upper])[None]
        dist = - box_iou(box_a, box_b).item()

    return dist
|
|
|
|
|
def bool_to_yesno(bool_answer: bool) -> str:
    """Translate a boolean value into the literal string "yes" or "no".

    Parameters
    ----------
    bool_answer : bool
        a boolean value

    Returns
    -------
    str
        "yes" when bool_answer is truthy, "no" otherwise
    """
    if bool_answer:
        return "yes"
    return "no"
|
|
|
|
|
def llm_query(query, context=None, long_answer=True, queues=None):
    """Answers a text question using GPT-3. The input question is always a formatted string with a variable in it.

    Parameters
    ----------
    query: str
        the text question to ask. Must not contain any reference to 'the image' or 'the photo', etc.
    context: optional
        extra context passed only to the short-answer QA model
    long_answer: bool
        selects the general (long-form) model when True, the QA model otherwise
    queues: optional
        queues forwarded to the model-execution backend
    """
    # Short-answer mode pairs the query with its context; long-answer mode
    # sends the bare prompt to the general model.
    if not long_answer:
        return forward(model_name='gpt3_qa', prompt=[query, context], queues=queues)
    return forward(model_name='gpt3_general', prompt=query, queues=queues)
|
|
|
|
|
def process_guesses(prompt, guess1=None, guess2=None, queues=None):
    """Send a prompt plus up to two candidate guesses to the GPT-3 guess-processing model."""
    payload = [prompt, guess1, guess2]
    return forward(model_name='gpt3_guess', prompt=payload, queues=queues)
|
|
|
|
|
def coerce_to_numeric(string, no_string=False):
    """
    This function takes a string as input and returns a numeric value after removing any non-numeric characters.
    If the input string contains a range (e.g. "10-15"), it returns the first value in the range.
    If the string mentions a month name, it is first tried as a date and the year is returned.

    Parameters
    ----------
    string : str
        the text to coerce
    no_string : bool
        if True, raise ValueError instead of returning the original string when
        no numeric value can be extracted

    # TODO: Cases like '25to26' return 2526, which is not correct.
    """
    months = ('january', 'february', 'march', 'april', 'may', 'june', 'july',
              'august', 'september', 'october', 'november', 'december')
    if any(month in string.lower() for month in months):
        try:
            # Bug fix: the previous code called .timestamp().year, but
            # datetime.timestamp() returns a float, which has no .year
            # attribute -- so this branch always failed silently. Read the
            # year directly off the parsed datetime.
            return dateparser.parse(string).year
        except (ValueError, OverflowError, TypeError):
            # Not actually parseable as a date; fall through to numeric parsing.
            pass

    try:
        # Spelled-out numbers ("twenty five") handled by word2number.
        numeric = w2n.word_to_num(string)
        return numeric
    except ValueError:
        pass

    # Keep only digits, dots and hyphens (raw string avoids invalid-escape warnings).
    string_re = re.sub(r"[^0-9.\-]", "", string)

    # Protect a leading minus sign so splitting on '-' (range handling below)
    # does not mistake it for a range separator.
    if string_re.startswith('-'):
        string_re = '&' + string_re[1:]

    if "-" in string_re:
        # Range such as "10-15": recurse on the first endpoint.
        parts = string_re.split("-")
        return coerce_to_numeric(parts[0].replace('&', '-'))
    else:
        string_re = string_re.replace('&', '-')

    try:
        if "." in string_re:
            numeric = float(string_re)
        else:
            numeric = int(string_re)
    except ValueError:
        if no_string:
            raise ValueError(f"Could not coerce {string!r} to a number")
        # Best-effort mode: hand the original string back unchanged.
        return string
    return numeric
|
|