Spaces:
Running
on
Zero
Running
on
Zero
File size: 16,334 Bytes
28c256d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import re
import warnings
from typing import Tuple, Union
import torch
from torch import Tensor
from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
from .single_stage import SingleStageDetector
def find_noun_phrases(caption: str) -> list:
    """Find noun phrases in a caption using nltk.

    The caption is lower-cased, tokenized with ``nltk.word_tokenize``,
    POS-tagged with ``nltk.pos_tag`` and chunked with the regular-expression
    grammar ``NP: {<DT>?<JJ.*>*<NN.*>+}`` (an optional determiner, any
    number of adjectives, then one or more nouns).

    The nltk ``punkt`` and ``averaged_perceptron_tagger`` resources are
    downloaded only when ``nltk.data.find`` cannot locate them locally.

    Args:
        caption (str): The caption to analyze.

    Returns:
        list: List of noun phrases found in the caption; the tokens of each
        phrase are joined with a single space.

    Raises:
        RuntimeError: If nltk is not installed.

    Examples:
        >>> caption = 'There is two cat and a remote in the picture'
        >>> find_noun_phrases(caption) # ['cat', 'a remote', 'the picture']
    """
    try:
        import nltk
    except ImportError as err:
        raise RuntimeError('nltk is not installed, please install it by: '
                           'pip install nltk.') from err

    # ``nltk.download`` may contact the nltk download server (and always
    # prints status lines) even when a resource is already installed, so
    # only fetch the resources that are actually missing.
    required_resources = (
        ('tokenizers/punkt', 'punkt'),
        ('taggers/averaged_perceptron_tagger', 'averaged_perceptron_tagger'),
    )
    for resource_path, package_name in required_resources:
        try:
            nltk.data.find(resource_path)
        except LookupError:
            nltk.download(package_name)

    caption = caption.lower()
    tokens = nltk.word_tokenize(caption)
    pos_tags = nltk.pos_tag(tokens)

    grammar = 'NP: {<DT>?<JJ.*>*<NN.*>+}'
    cp = nltk.RegexpParser(grammar)
    result = cp.parse(pos_tags)

    noun_phrases = []
    for subtree in result.subtrees():
        if subtree.label() == 'NP':
            noun_phrases.append(' '.join(t[0] for t in subtree.leaves()))

    return noun_phrases
# Translation table deleting every character ``remove_punctuation`` strips;
# built once so each call is a single ``str.translate`` pass instead of one
# ``str.replace`` pass per character.
_PUNCTUATION_TABLE = str.maketrans('', '', '|:;@()[]{}^\'"’`?$%#!&*+,.')


def remove_punctuation(text: str) -> str:
    """Remove punctuation from a text.

    Deletes every occurrence of the characters
    ``| : ; @ ( ) [ ] { } ^ ' " ’ ` ? $ % # ! & * + , .`` and then strips
    leading/trailing whitespace. Other characters (e.g. ``-``, ``/`` or
    ``\\``) are kept.

    Args:
        text (str): The input text.

    Returns:
        str: The text with punctuation removed.
    """
    return text.translate(_PUNCTUATION_TABLE).strip()
def run_ner(caption: str) -> Tuple[list, list]:
    """Run NER on a caption and return the tokens and noun phrases.

    Noun phrases are extracted with :func:`find_noun_phrases`, cleaned with
    :func:`remove_punctuation`, and empty results are dropped. Every
    occurrence of every phrase in the lower-cased caption is then recorded
    as a separate character span.

    Args:
        caption (str): The input caption.

    Returns:
        Tuple[List, List]: A tuple containing the tokens and noun phrases.

            - tokens_positive (List): One ``[[start, end]]`` character span
              (end exclusive) per occurrence of each phrase in the
              lower-cased caption.
            - noun_phrases (List): The cleaned, non-empty noun phrases.
    """
    noun_phrases = find_noun_phrases(caption)
    noun_phrases = [remove_punctuation(phrase) for phrase in noun_phrases]
    noun_phrases = [phrase for phrase in noun_phrases if phrase != '']

    lowered_caption = caption.lower()
    tokens_positive = []
    for entity in noun_phrases:
        # Search all occurrences and mark them as different entities.
        # The phrase is matched literally: it comes from free text and may
        # still contain ``\`` (not removed by ``remove_punctuation``), which
        # would otherwise be parsed as regex syntax or raise ``re.error``.
        # TODO: Not robust - phrases also match inside longer words, and a
        # phrase whose spacing/punctuation differs from the raw caption is
        # not found at all.
        for match in re.finditer(re.escape(entity), lowered_caption):
            tokens_positive.append([[match.start(), match.end()]])
    return tokens_positive, noun_phrases
def create_positive_map(tokenized,
                        tokens_positive: list,
                        max_num_entities: int = 256) -> Tensor:
    """Build a row-normalized entity-to-token association map.

    Row ``j`` corresponds to ``tokens_positive[j]``; column ``k`` to token
    ``k`` of ``tokenized``. Every token covered by one of the row's
    character spans is set to 1, then each row is divided by its sum plus
    ``1e-6``, so a row with at least one marked token sums to ~1 and a row
    with no resolvable span stays all zeros.

    When a boundary lookup returns ``None``, the start is retried at
    ``beg + 1`` then ``beg + 2`` and the end at ``end - 2`` then
    ``end - 3``; any exception during these retries counts as ``None``.
    Spans whose start or end still cannot be resolved are skipped.

    Args:
        tokenized: The tokenized input. Only its ``char_to_token`` method is
            used (presumably a HuggingFace ``BatchEncoding`` holding one
            sequence - confirm against callers).
        tokens_positive (list): For each row, a list of ``[beg, end]``
            character spans (``end`` exclusive).
        max_num_entities (int, optional): Number of token columns in the
            map; token indices at or beyond it are silently dropped by the
            slice assignment. Defaults to 256.

    Returns:
        torch.Tensor: Float tensor of shape
        ``(len(tokens_positive), max_num_entities)``.

    Raises:
        Exception: Re-raised if the initial ``char_to_token`` lookup of
            ``beg`` or ``end - 1`` fails; the offending span and
            ``tokens_positive`` are printed first.
    """

    def _lookup_fallback(char_indices):
        # Return the first non-None token index among the candidate
        # characters; an exception while probing abandons the fallback.
        try:
            for char_idx in char_indices:
                token_idx = tokenized.char_to_token(char_idx)
                if token_idx is not None:
                    return token_idx
        except Exception:
            return None
        return None

    num_rows = len(tokens_positive)
    positive_map = torch.zeros((num_rows, max_num_entities),
                               dtype=torch.float)

    for row, spans in enumerate(tokens_positive):
        for beg, end in spans:
            try:
                first_tok = tokenized.char_to_token(beg)
                last_tok = tokenized.char_to_token(end - 1)
            except Exception:
                print('beg:', beg, 'end:', end)
                print('token_positive:', tokens_positive)
                raise

            if first_tok is None:
                first_tok = _lookup_fallback((beg + 1, beg + 2))
            if last_tok is None:
                last_tok = _lookup_fallback((end - 2, end - 3))

            if first_tok is None or last_tok is None:
                continue
            positive_map[row, first_tok:last_tok + 1] = 1

    row_sums = positive_map.sum(dim=-1, keepdim=True)
    return positive_map / (row_sums + 1e-6)
def create_positive_map_label_to_token(positive_map: Tensor,
                                       plus: int = 0) -> dict:
    """Map each (offset) row index to the token columns it covers.

    Args:
        positive_map (Tensor): The positive map tensor; the non-zero
            entries of row ``i`` mark the tokens of label ``i``.
        plus (int, optional): Offset added to each row index to form the
            dictionary key. Defaults to 0.

    Returns:
        dict: ``{i + plus: [column indices of non-zero entries in row i]}``
        for every row of ``positive_map``.
    """
    return {
        row_idx + plus: torch.nonzero(row, as_tuple=True)[0].tolist()
        for row_idx, row in enumerate(positive_map)
    }
@MODELS.register_module()
class GLIP(SingleStageDetector):
    """Implementation of `GLIP <https://arxiv.org/abs/2112.03857>`_

    A single-stage detector conditioned on a text prompt. Each data sample
    carries its prompt in ``data_sample.text``; the prompt is encoded by
    ``language_model`` and, together with the visual features, passed to
    ``bbox_head``.

    Args:
        backbone (:obj:`ConfigDict` or dict): The backbone config.
        neck (:obj:`ConfigDict` or dict): The neck config.
        bbox_head (:obj:`ConfigDict` or dict): The bbox head config.
        language_model (:obj:`ConfigDict` or dict): The language model config.
        train_cfg (:obj:`ConfigDict` or dict, optional): The training config
            of GLIP. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): The testing config
            of GLIP. Defaults to None.
        data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of
            :class:`DetDataPreprocessor` to process the input data.
            Defaults to None.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 backbone: ConfigType,
                 neck: ConfigType,
                 bbox_head: ConfigType,
                 language_model: ConfigType,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(
            backbone=backbone,
            neck=neck,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)
        # The language model is not handled by SingleStageDetector; build it
        # here. Its ``tokenizer`` attribute is used for prompt tokenization.
        self.language_model = MODELS.build(language_model)

        # Separator between entity names in a prompt, e.g. 'person. car'.
        self._special_tokens = '. '

    def get_tokens_and_prompts(
            self,
            original_caption: Union[str, list, tuple],
            custom_entities: bool = False) -> Tuple[dict, str, list, list]:
        """Get the tokens positive and prompts for the caption.

        Two modes are supported:

        - Entity list: used when ``original_caption`` is a list/tuple, or
          when ``custom_entities`` is True. A string caption is first
          stripped of leading/trailing '.' and ' ' characters, split on
          ``'. '`` and empty pieces are dropped. The entities are joined
          with ``'. '`` into one caption string and the character span of
          each entity in that string is recorded.
        - Free-form caption: otherwise the caption is stripped of
          leading/trailing '.' and ' ' characters and noun phrases and
          their spans are found with :func:`run_ner`.

        Args:
            original_caption (str | list | tuple): A caption string or a
                sequence of entity names.
            custom_entities (bool): Treat a string caption as a
                ``'. '``-separated list of entity names. Defaults to False.

        Returns:
            tuple: ``(tokenized, caption_string, tokens_positive, entities)``
            where ``tokenized`` is the output of
            ``self.language_model.tokenizer([caption_string],
            return_tensors='pt')`` (presumably a HuggingFace
            ``BatchEncoding`` rather than a plain dict - confirm),
            ``tokens_positive`` is a list of ``[[start, end]]`` character
            spans (one per entity in entity-list mode, one per phrase
            occurrence in free-form mode), and ``entities`` are the entity
            names / noun phrases.
        """
        if isinstance(original_caption, (list, tuple)) or custom_entities:
            if custom_entities and isinstance(original_caption, str):
                # ``str.strip`` removes any of the characters '.' and ' '.
                original_caption = original_caption.strip(self._special_tokens)
                original_caption = original_caption.split(self._special_tokens)
                original_caption = list(
                    filter(lambda x: len(x) > 0, original_caption))

            caption_string = ''
            tokens_positive = []
            for idx, word in enumerate(original_caption):
                # Span of this entity in the joined caption, end exclusive.
                tokens_positive.append(
                    [[len(caption_string),
                      len(caption_string) + len(word)]])
                caption_string += word
                if idx != len(original_caption) - 1:
                    caption_string += self._special_tokens
            tokenized = self.language_model.tokenizer([caption_string],
                                                      return_tensors='pt')
            entities = original_caption
        else:
            original_caption = original_caption.strip(self._special_tokens)
            tokenized = self.language_model.tokenizer([original_caption],
                                                      return_tensors='pt')
            tokens_positive, noun_phrases = run_ner(original_caption)
            entities = noun_phrases
            caption_string = original_caption

        return tokenized, caption_string, tokens_positive, entities

    def get_positive_map(self, tokenized, tokens_positive):
        """Build the positive map and its label-to-token lookup.

        Args:
            tokenized: Tokenizer output for the caption, forwarded to
                :func:`create_positive_map`.
            tokens_positive (list): Character spans, one list per row.

        Returns:
            tuple: ``(positive_map_label_to_token, positive_map)``. The dict
            keys are row indices shifted by 1 (``plus=1``).
        """
        positive_map = create_positive_map(tokenized, tokens_positive)
        positive_map_label_to_token = create_positive_map_label_to_token(
            positive_map, plus=1)
        return positive_map_label_to_token, positive_map

    def get_tokens_positive_and_prompts(
            self,
            original_caption: Union[str, list, tuple],
            custom_entities: bool = False) -> Tuple[dict, str, Tensor, list]:
        """Get the label-to-token map, prompt, positive map and entities.

        Combines :meth:`get_tokens_and_prompts` and
        :meth:`get_positive_map`.

        Args:
            original_caption (str | list | tuple): A caption string or a
                sequence of entity names.
            custom_entities (bool): See :meth:`get_tokens_and_prompts`.
                Defaults to False.

        Returns:
            tuple: ``(positive_map_label_to_token, caption_string,
            positive_map, entities)``.
        """
        tokenized, caption_string, tokens_positive, entities = \
            self.get_tokens_and_prompts(
                original_caption, custom_entities)
        positive_map_label_to_token, positive_map = self.get_positive_map(
            tokenized, tokens_positive)
        return positive_map_label_to_token, caption_string, \
            positive_map, entities

    def loss(self, batch_inputs: Tensor,
             batch_data_samples: SampleList) -> Union[dict, list]:
        """Calculate losses from a batch of inputs and data samples.

        Each sample's ``text`` is always parsed as a list of entity names
        (``custom_entities=True``). Every ground-truth label is used as an
        index into that entity list, so the sample's positive map has one
        row per ground-truth instance.

        Args:
            batch_inputs (Tensor): Input images; presumably (N, C, H, W) as
                documented for :meth:`predict` - confirm.
            batch_data_samples (list[:obj:`DetDataSample`]): Data samples
                providing ``text`` and ``gt_instances.labels``. Their
                ``gt_instances.positive_maps`` are set in place.

        Returns:
            dict | list: Whatever ``self.bbox_head.loss`` returns.
        """
        # TODO: Only open vocabulary tasks are supported for training now.
        text_prompts = [
            data_samples.text for data_samples in batch_data_samples
        ]

        gt_labels = [
            data_samples.gt_instances.labels
            for data_samples in batch_data_samples
        ]

        new_text_prompts = []
        positive_maps = []
        if len(set(text_prompts)) == 1:
            # All the text prompts are the same,
            # so there is no need to calculate them multiple times.
            tokenized, caption_string, tokens_positive, _ = \
                self.get_tokens_and_prompts(
                    text_prompts[0], True)
            new_text_prompts = [caption_string] * len(batch_inputs)
            for gt_label in gt_labels:
                # Select the entity span of each GT instance's label.
                new_tokens_positive = [
                    tokens_positive[label] for label in gt_label
                ]
                _, positive_map = self.get_positive_map(
                    tokenized, new_tokens_positive)
                positive_maps.append(positive_map)
        else:
            for text_prompt, gt_label in zip(text_prompts, gt_labels):
                tokenized, caption_string, tokens_positive, _ = \
                    self.get_tokens_and_prompts(
                        text_prompt, True)
                # Select the entity span of each GT instance's label.
                new_tokens_positive = [
                    tokens_positive[label] for label in gt_label
                ]
                _, positive_map = self.get_positive_map(
                    tokenized, new_tokens_positive)
                positive_maps.append(positive_map)
                new_text_prompts.append(caption_string)

        language_dict_features = self.language_model(new_text_prompts)
        for i, data_samples in enumerate(batch_data_samples):
            # .bool().float() is very important
            # create_positive_map returns row-normalized weights; this turns
            # them back into a 0/1 map (presumably what the head's loss
            # expects - confirm against bbox_head).
            positive_map = positive_maps[i].to(
                batch_inputs.device).bool().float()
            data_samples.gt_instances.positive_maps = positive_map

        visual_features = self.extract_feat(batch_inputs)

        losses = self.bbox_head.loss(visual_features, language_dict_features,
                                     batch_data_samples)
        return losses

    def predict(self,
                batch_inputs: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        The prompt of each sample is read from ``data_sample.text``. If the
        first sample has a ``custom_entities`` field, its value is used for
        the whole batch; otherwise prompts are parsed as free-form captions.
        Each sample's ``token_positive_map`` is set before calling the head.

        Args:
            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool): Whether to rescale the results.
                Defaults to True.

        Returns:
            list[:obj:`DetDataSample`]: Detection results of the
            input images. Each DetDataSample usually contain
            'pred_instances'. And the ``pred_instances`` usually
            contains following keys.

                - scores (Tensor): Classification scores, has a shape
                  (num_instance, )
                - labels (Tensor): Labels of bboxes, has a shape
                  (num_instances, ).
                - label_names (List[str]): Label names of bboxes (only set
                  when there is at least one prediction; labels without a
                  matching entity are named 'unobject').
                - bboxes (Tensor): Has a shape (num_instances, 4),
                  the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        text_prompts = [
            data_samples.text for data_samples in batch_data_samples
        ]

        if 'custom_entities' in batch_data_samples[0]:
            # Assuming that the `custom_entities` flag
            # inside a batch is always the same. For single image inference
            custom_entities = batch_data_samples[0].custom_entities
        else:
            custom_entities = False
        if len(set(text_prompts)) == 1:
            # All the text prompts are the same,
            # so there is no need to calculate them multiple times.
            # Note: list multiplication shares the same result tuple (and
            # the same label-to-token dict) across all samples.
            _positive_maps_and_prompts = [
                self.get_tokens_positive_and_prompts(text_prompts[0],
                                                     custom_entities)
            ] * len(batch_inputs)
        else:
            _positive_maps_and_prompts = [
                self.get_tokens_positive_and_prompts(text_prompt,
                                                     custom_entities)
                for text_prompt in text_prompts
            ]
        # ``text_prompts`` is rebound to the processed caption strings.
        token_positive_maps, text_prompts, _, entities = zip(
            *_positive_maps_and_prompts)
        language_dict_features = self.language_model(list(text_prompts))

        for i, data_samples in enumerate(batch_data_samples):
            data_samples.token_positive_map = token_positive_maps[i]

        visual_features = self.extract_feat(batch_inputs)

        results_list = self.bbox_head.predict(
            visual_features,
            language_dict_features,
            batch_data_samples,
            rescale=rescale)

        for data_sample, pred_instances, entity in zip(batch_data_samples,
                                                       results_list, entities):
            if len(pred_instances) > 0:
                label_names = []
                for labels in pred_instances.labels:
                    # A label with no corresponding entity name cannot be
                    # named; in free-form mode this can happen when NER
                    # yields spans that do not line up with the phrases.
                    if labels >= len(entity):
                        warnings.warn(
                            'The unexpected output indicates an issue with '
                            'named entity recognition. You can try '
                            'setting custom_entities=True and running '
                            'again to see if it helps.')
                        label_names.append('unobject')
                    else:
                        label_names.append(entity[labels])
                # for visualization
                pred_instances.label_names = label_names
            data_sample.pred_instances = pred_instances
        return batch_data_samples
|