# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import re
import warnings
from typing import Tuple, Union

import torch
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig

from .single_stage import SingleStageDetector


def find_noun_phrases(caption: str) -> list:
    """Find noun phrases in a caption using nltk.

    Args:
        caption (str): The caption to analyze.

    Returns:
        list: List of noun phrases found in the caption.

    Examples:
        >>> caption = 'There is two cat and a remote in the picture'
        >>> find_noun_phrases(caption)  # ['cat', 'a remote', 'the picture']
    """
    try:
        import nltk
        nltk.download('punkt')
        nltk.download('averaged_perceptron_tagger')
    except ImportError:
        raise RuntimeError('nltk is not installed, please install it by: '
                           'pip install nltk.')

    caption = caption.lower()
    tokens = nltk.word_tokenize(caption)
    pos_tags = nltk.pos_tag(tokens)

    # Chunk an optional determiner, any adjectives and one or more nouns
    # into a noun phrase (NP).
    grammar = 'NP: {<DT>?<JJ.*>*<NN.*>+}'
    cp = nltk.RegexpParser(grammar)
    result = cp.parse(pos_tags)

    noun_phrases = []
    for subtree in result.subtrees():
        if subtree.label() == 'NP':
            noun_phrases.append(' '.join(t[0] for t in subtree.leaves()))
    return noun_phrases
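

# Note (editorial sketch, not part of the original module): the function
# above calls ``nltk.download`` on every invocation. A common pattern to
# avoid the repeated network round-trip is to probe for the resources
# first; ``nltk.data.find`` raises ``LookupError`` when one is missing:
#
#   import nltk
#   try:
#       nltk.data.find('tokenizers/punkt')
#   except LookupError:
#       nltk.download('punkt')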


def remove_punctuation(text: str) -> str:
    """Remove punctuation from a text.

    Args:
        text (str): The input text.

    Returns:
        str: The text with punctuation removed.
    """
    punctuation = [
        '|', ':', ';', '@', '(', ')', '[', ']', '{', '}', '^', '\'', '\"',
        '’', '`', '?', '$', '%', '#', '!', '&', '*', '+', ',', '.'
    ]
    for p in punctuation:
        text = text.replace(p, '')
    return text.strip()


def run_ner(caption: str) -> Tuple[list, list]:
    """Run NER on a caption and return the tokens and noun phrases.

    Args:
        caption (str): The input caption.

    Returns:
        Tuple[List, List]: A tuple containing the tokens and noun phrases.
            - tokens_positive (List): A list of token positions.
            - noun_phrases (List): A list of noun phrases.
    """
    noun_phrases = find_noun_phrases(caption)
    noun_phrases = [remove_punctuation(phrase) for phrase in noun_phrases]
    noun_phrases = [phrase for phrase in noun_phrases if phrase != '']
    relevant_phrases = noun_phrases
    labels = noun_phrases

    tokens_positive = []
    for entity, label in zip(relevant_phrases, labels):
        try:
            # Search all occurrences and mark them as different entities.
            # TODO: not robust -- ``entity`` is used directly as a regex
            # pattern, so phrases containing special characters may fail.
            for m in re.finditer(entity, caption.lower()):
                tokens_positive.append([[m.start(), m.end()]])
        except Exception:
            print('noun entities:', noun_phrases)
            print('entity:', entity)
            print('caption:', caption.lower())
    return tokens_positive, noun_phrases
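

# Example (illustrative; exact output depends on the installed nltk
# models): for the caption 'a cat and a remote', ``run_ner`` might return
#   tokens_positive = [[[0, 5]], [[10, 18]]]  # char spans of the phrases
#   noun_phrases = ['a cat', 'a remote']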


def create_positive_map(tokenized,
                        tokens_positive: list,
                        max_num_entities: int = 256) -> Tensor:
    """Construct a map such that positive_map[i, j] = True if box i is
    associated to token j.

    Args:
        tokenized: The tokenized input.
        tokens_positive (list): A list of token ranges
            associated with positive boxes.
        max_num_entities (int, optional): The maximum number of entities.
            Defaults to 256.

    Returns:
        torch.Tensor: The positive map.

    Raises:
        Exception: If an error occurs during char-to-token mapping.
    """
    positive_map = torch.zeros((len(tokens_positive), max_num_entities),
                               dtype=torch.float)

    for j, tok_list in enumerate(tokens_positive):
        for (beg, end) in tok_list:
            try:
                beg_pos = tokenized.char_to_token(beg)
                end_pos = tokenized.char_to_token(end - 1)
            except Exception as e:
                print('beg:', beg, 'end:', end)
                print('token_positive:', tokens_positive)
                raise e
            # Boundary characters such as spaces may not map to any token;
            # fall back to the neighbouring characters.
            if beg_pos is None:
                try:
                    beg_pos = tokenized.char_to_token(beg + 1)
                    if beg_pos is None:
                        beg_pos = tokenized.char_to_token(beg + 2)
                except Exception:
                    beg_pos = None
            if end_pos is None:
                try:
                    end_pos = tokenized.char_to_token(end - 2)
                    if end_pos is None:
                        end_pos = tokenized.char_to_token(end - 3)
                except Exception:
                    end_pos = None
            if beg_pos is None or end_pos is None:
                continue
            assert beg_pos is not None and end_pos is not None
            positive_map[j, beg_pos:end_pos + 1].fill_(1)

    # Normalize each row so the weights over a phrase's tokens sum to 1.
    return positive_map / (positive_map.sum(-1)[:, None] + 1e-6)
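

# Minimal sketch (editorial, not part of the original module) of the input
# format ``create_positive_map`` expects. The ``bert-base-uncased``
# tokenizer is an assumption for the demo; GLIP configs define their own
# language-model tokenizer.
def _demo_create_positive_map() -> None:
    from transformers import AutoTokenizer  # assumed to be installed
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    caption = 'cat. remote'
    tokenized = tokenizer([caption], return_tensors='pt')
    # Character spans of 'cat' (0..3) and 'remote' (5..11) in the caption.
    tokens_positive = [[[0, 3]], [[5, 11]]]
    positive_map = create_positive_map(tokenized, tokens_positive)
    # Each row carries uniform weights over the tokens of one phrase.
    print(positive_map.nonzero(as_tuple=True))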


def create_positive_map_label_to_token(positive_map: Tensor,
                                       plus: int = 0) -> dict:
    """Create a dictionary mapping the label to the token.

    Args:
        positive_map (Tensor): The positive map tensor.
        plus (int, optional): Value added to the label for indexing.
            Defaults to 0.

    Returns:
        dict: The dictionary mapping the label to the token.
    """
    positive_map_label_to_token = {}
    for i in range(len(positive_map)):
        positive_map_label_to_token[i + plus] = torch.nonzero(
            positive_map[i], as_tuple=True)[0].tolist()
    return positive_map_label_to_token
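

# Example (illustrative): for the two-phrase map sketched above,
# ``create_positive_map_label_to_token(positive_map, plus=1)`` would give
# something like ``{1: [1], 2: [3]}`` -- label 1 ('cat') covers token 1,
# label 2 ('remote') covers token 3 (exact indices depend on the
# tokenizer). ``plus=1`` shifts the keys to one-based labels.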


class GLIP(SingleStageDetector):
    """Implementation of `GLIP <https://arxiv.org/abs/2112.03857>`_

    Args:
        backbone (:obj:`ConfigDict` or dict): The backbone config.
        neck (:obj:`ConfigDict` or dict): The neck config.
        bbox_head (:obj:`ConfigDict` or dict): The bbox head config.
        language_model (:obj:`ConfigDict` or dict): The language model config.
        train_cfg (:obj:`ConfigDict` or dict, optional): The training config
            of GLIP. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): The testing config
            of GLIP. Defaults to None.
        data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of
            :class:`DetDataPreprocessor` to process the input data.
            Defaults to None.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 backbone: ConfigType,
                 neck: ConfigType,
                 bbox_head: ConfigType,
                 language_model: ConfigType,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(
            backbone=backbone,
            neck=neck,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)
        self.language_model = MODELS.build(language_model)
        # Entities in a prompt are separated by '. ', e.g. 'cat. remote'.
        self._special_tokens = '. '

    def get_tokens_and_prompts(
            self,
            original_caption: Union[str, list, tuple],
            custom_entities: bool = False) -> Tuple[dict, str, list, list]:
        """Get the tokens positive and prompts for the caption."""
        if isinstance(original_caption, (list, tuple)) or custom_entities:
            # The caption is (or is split into) an explicit entity list;
            # rebuild the prompt and record each entity's character span.
            if custom_entities and isinstance(original_caption, str):
                original_caption = original_caption.strip(self._special_tokens)
                original_caption = original_caption.split(self._special_tokens)
                original_caption = list(
                    filter(lambda x: len(x) > 0, original_caption))

            caption_string = ''
            tokens_positive = []
            for idx, word in enumerate(original_caption):
                tokens_positive.append(
                    [[len(caption_string),
                      len(caption_string) + len(word)]])
                caption_string += word
                if idx != len(original_caption) - 1:
                    caption_string += self._special_tokens
            tokenized = self.language_model.tokenizer([caption_string],
                                                      return_tensors='pt')
            entities = original_caption
        else:
            # Free-form caption: extract entities with noun-phrase NER.
            original_caption = original_caption.strip(self._special_tokens)
            tokenized = self.language_model.tokenizer([original_caption],
                                                      return_tensors='pt')
            tokens_positive, noun_phrases = run_ner(original_caption)
            entities = noun_phrases
            caption_string = original_caption

        return tokenized, caption_string, tokens_positive, entities
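
    # Example (illustrative): with custom_entities=True, the input
    # 'cat. remote' (or the list ['cat', 'remote']) is rebuilt as the
    # caption 'cat. remote' with tokens_positive [[[0, 3]], [[5, 11]]],
    # one character span per entity.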

    def get_positive_map(self, tokenized, tokens_positive):
        """Get the positive map and its label-to-token form."""
        positive_map = create_positive_map(tokenized, tokens_positive)
        positive_map_label_to_token = create_positive_map_label_to_token(
            positive_map, plus=1)
        return positive_map_label_to_token, positive_map

    def get_tokens_positive_and_prompts(
            self,
            original_caption: Union[str, list, tuple],
            custom_entities: bool = False) -> Tuple[dict, str, Tensor, list]:
        """Get the tokens positive and prompts for the caption."""
        tokenized, caption_string, tokens_positive, entities = \
            self.get_tokens_and_prompts(
                original_caption, custom_entities)
        positive_map_label_to_token, positive_map = self.get_positive_map(
            tokenized, tokens_positive)
        return positive_map_label_to_token, caption_string, \
            positive_map, entities

    def loss(self, batch_inputs: Tensor,
             batch_data_samples: SampleList) -> Union[dict, list]:
        # TODO: Only open vocabulary tasks are supported for training now.
        text_prompts = [
            data_samples.text for data_samples in batch_data_samples
        ]
        gt_labels = [
            data_samples.gt_instances.labels
            for data_samples in batch_data_samples
        ]

        new_text_prompts = []
        positive_maps = []
        if len(set(text_prompts)) == 1:
            # All the text prompts are the same,
            # so there is no need to calculate them multiple times.
            tokenized, caption_string, tokens_positive, _ = \
                self.get_tokens_and_prompts(
                    text_prompts[0], True)
            new_text_prompts = [caption_string] * len(batch_inputs)
            for gt_label in gt_labels:
                new_tokens_positive = [
                    tokens_positive[label] for label in gt_label
                ]
                _, positive_map = self.get_positive_map(
                    tokenized, new_tokens_positive)
                positive_maps.append(positive_map)
        else:
            for text_prompt, gt_label in zip(text_prompts, gt_labels):
                tokenized, caption_string, tokens_positive, _ = \
                    self.get_tokens_and_prompts(
                        text_prompt, True)
                new_tokens_positive = [
                    tokens_positive[label] for label in gt_label
                ]
                _, positive_map = self.get_positive_map(
                    tokenized, new_tokens_positive)
                positive_maps.append(positive_map)
                new_text_prompts.append(caption_string)

        language_dict_features = self.language_model(new_text_prompts)
        for i, data_samples in enumerate(batch_data_samples):
            # ``.bool().float()`` is very important: it turns the
            # row-normalized soft weights back into a binary {0, 1} mask.
            positive_map = positive_maps[i].to(
                batch_inputs.device).bool().float()
            data_samples.gt_instances.positive_maps = positive_map

        visual_features = self.extract_feat(batch_inputs)
        losses = self.bbox_head.loss(visual_features, language_dict_features,
                                     batch_data_samples)
        return losses
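
    # Example (illustrative): with tokens_positive built for the class
    # list ['cat', 'remote'] and an image whose gt labels are [1, 0, 1],
    # ``new_tokens_positive`` above becomes
    #   [tokens_positive[1], tokens_positive[0], tokens_positive[1]]
    # so row i of the per-image positive map marks the prompt tokens of
    # box i's class.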

    def predict(self,
                batch_inputs: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Predict results from a batch of inputs and data samples with
        post-processing.

        Args:
            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool): Whether to rescale the results.
                Defaults to True.

        Returns:
            list[:obj:`DetDataSample`]: Detection results of the
            input images. Each DetDataSample usually contains
            'pred_instances'. And the ``pred_instances`` usually
            contains the following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - label_names (List[str]): Label names of bboxes.
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arranged as (x1, y1, x2, y2).
        """
        text_prompts = [
            data_samples.text for data_samples in batch_data_samples
        ]

        if 'custom_entities' in batch_data_samples[0]:
            # Assuming that the `custom_entities` flag
            # inside a batch is always the same. For single image inference
            custom_entities = batch_data_samples[0].custom_entities
        else:
            custom_entities = False

        if len(set(text_prompts)) == 1:
            # All the text prompts are the same,
            # so there is no need to calculate them multiple times.
            _positive_maps_and_prompts = [
                self.get_tokens_positive_and_prompts(text_prompts[0],
                                                     custom_entities)
            ] * len(batch_inputs)
        else:
            _positive_maps_and_prompts = [
                self.get_tokens_positive_and_prompts(text_prompt,
                                                     custom_entities)
                for text_prompt in text_prompts
            ]

        token_positive_maps, text_prompts, _, entities = zip(
            *_positive_maps_and_prompts)

        language_dict_features = self.language_model(list(text_prompts))

        for i, data_samples in enumerate(batch_data_samples):
            data_samples.token_positive_map = token_positive_maps[i]

        visual_features = self.extract_feat(batch_inputs)

        results_list = self.bbox_head.predict(
            visual_features,
            language_dict_features,
            batch_data_samples,
            rescale=rescale)

        for data_sample, pred_instances, entity in zip(batch_data_samples,
                                                       results_list,
                                                       entities):
            if len(pred_instances) > 0:
                label_names = []
                for labels in pred_instances.labels:
                    if labels >= len(entity):
                        warnings.warn(
                            'The unexpected output indicates an issue with '
                            'named entity recognition. You can try '
                            'setting custom_entities=True and running '
                            'again to see if it helps.')
                        label_names.append('unobject')
                    else:
                        label_names.append(entity[labels])
                # for visualization
                pred_instances.label_names = label_names
            data_sample.pred_instances = pred_instances
        return batch_data_samples
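

# Usage sketch (editorial, not part of the original module). GLIP models
# are typically driven through mmdet's high-level inference API; the
# config/checkpoint below are placeholders, and the ``texts`` /
# ``custom_entities`` arguments follow the grounding-model interface of
# ``DetInferencer`` in mmdet 3.x:
#
#   from mmdet.apis import DetInferencer
#   inferencer = DetInferencer(model='<glip-config-or-model-name>',
#                              weights='<checkpoint.pth>')
#   inferencer('demo.jpg', texts='cat. remote', custom_entities=True)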