Spaces:
Runtime error
Runtime error
import logging | |
import os | |
import os.path | |
import math | |
from PIL import Image, ImageDraw | |
import random | |
import numpy as np | |
import torch | |
import torchvision | |
import torch.utils.data as data | |
from pycocotools import mask as coco_mask | |
from maskrcnn_benchmark.structures.bounding_box import BoxList | |
from maskrcnn_benchmark.structures.segmentation_mask import SegmentationMask | |
from maskrcnn_benchmark.data.datasets.coco import has_valid_annotation | |
from .od_to_grounding import convert_od_to_grounding_simple, check_for_positive_overflow, sanity_check_target_after_processing, convert_object_detection_to_grounding_optimized_for_od | |
import pdb | |
import json | |
class CocoGrounding(torchvision.datasets.CocoDetection):
    """COCO-format detection dataset whose samples are turned into grounding samples.

    Boxes and category names of an image are rendered into a text caption by the
    ``od_to_grounding`` helpers, then converted to tensors (positive maps, greenlight
    map, ...) by ``ConvertCocoPolysToMask`` and attached to a ``BoxList`` target.
    """

    def __init__(self,
                 img_folder,
                 ann_file,
                 transforms,
                 return_masks,
                 return_tokens,
                 is_train=False,
                 tokenizer=None,
                 disable_shuffle=False,
                 add_detection_prompt=False,
                 one_hot=False,
                 disable_clip_to_image=False,
                 no_minus_one_for_one_hot=False,
                 separation_tokens=" ",
                 few_shot=0,
                 no_mask_for_od=False,
                 override_category=None,
                 use_caption_prompt=False,
                 caption_prompt=None,
                 max_query_len=256,
                 special_safeguard_for_coco_grounding=False,
                 random_sample_negative=-1,
                 **kwargs
                 ):
        """Load the annotation file and configure caption generation.

        Args:
            img_folder, ann_file: forwarded to ``torchvision.datasets.CocoDetection``.
            transforms: callable ``(img, target) -> (img, target)`` applied in ``__getitem__``.
            return_masks: attach ``masks`` / ``is_box_mask`` fields to each target.
            return_tokens: forwarded to ``ConvertCocoPolysToMask``.
            is_train: stored on the instance (not read by this class).
            tokenizer: used for the token maps and the positive-caption overflow check.
            disable_shuffle, add_detection_prompt, separation_tokens, random_sample_negative:
                forwarded to the od-to-grounding converters.
            one_hot: replace ``positive_map`` with a one-hot map over label indices.
            disable_clip_to_image: skip ``BoxList.clip_to_image``.
            no_minus_one_for_one_hot: in one-hot mode, index by ``label`` instead of ``label - 1``.
            few_shot: if non-zero, keep only enough images to give each category
                roughly ``few_shot`` occurrences (see below).
            no_mask_for_od: append the ``(-1, -1, -1)`` sentinel to the greenlight spans.
            override_category: replaces ``coco.dataset["categories"]`` (after the
                id maps are built, so its ids must exist in the annotation file).
            use_caption_prompt, caption_prompt: caption prompt for the simple converter.
            max_query_len: maximum tokenized caption length.
            special_safeguard_for_coco_grounding: use the overflow-safe converter
                (commented as intended for LVIS).
            **kwargs: accepted and ignored.
        """
        super(CocoGrounding, self).__init__(img_folder, ann_file)

        # Keep only images that have at least one usable annotation.
        self.ids = [img_id for img_id in sorted(self.ids)
                    if has_valid_annotation(self._load_anns(img_id))]

        if few_shot:
            # Keep an image if it contains at least one category whose quota is not
            # exhausted; every category in a kept image consumes one unit of quota.
            # Category ids are used as 1-based indices into ``cats_freq``.
            cats_freq = [few_shot] * max(self.coco.cats.keys())
            selected = []
            for img_id in self.ids:
                # set -> image-level counting (a tuple would count instances)
                cats = {ann['category_id'] for ann in self._load_anns(img_id)}
                if any(cats_freq[c - 1] > 0 for c in cats):
                    selected.append(img_id)
                    for c in cats:
                        cats_freq[c - 1] -= 1
            self.ids = selected

        self.json_category_id_to_contiguous_id = {
            v: i + 1 for i, v in enumerate(self.coco.getCatIds())
        }
        self.contiguous_category_id_to_json_id = {
            v: k for k, v in self.json_category_id_to_contiguous_id.items()
        }

        if override_category is not None:
            self.coco.dataset["categories"] = override_category
        self.use_caption_prompt = use_caption_prompt
        self.caption_prompt = caption_prompt
        self.special_safeguard_for_coco_grounding = special_safeguard_for_coco_grounding
        self.random_sample_negative = random_sample_negative
        # {contiguous_id: category_name}, background included.
        self.ind_to_class = self.categories(no_background=False)
        self.id_to_img_map = {k: v for k, v in enumerate(self.ids)}
        self._transforms = transforms
        self.max_query_len = max_query_len
        # Masks are built here in __getitem__, so prepare() never builds them.
        self.prepare = ConvertCocoPolysToMask(False, return_tokens, tokenizer=tokenizer, max_query_len=max_query_len)
        self.tokenizer = tokenizer
        self.is_train = is_train

        self.disable_shuffle = disable_shuffle
        self.add_detection_prompt = add_detection_prompt
        self.one_hot = one_hot
        self.no_minus_one_for_one_hot = no_minus_one_for_one_hot

        self.disable_clip_to_image = disable_clip_to_image
        self.separation_tokens = separation_tokens
        self.no_mask_for_od = no_mask_for_od
        self.return_masks = return_masks

    def _load_anns(self, img_id):
        """Return every annotation (crowd or not) attached to ``img_id``."""
        # String ids are wrapped in a list, presumably so pycocotools does not
        # iterate over the characters of the id.
        query = [img_id] if isinstance(img_id, str) else img_id
        return self.coco.loadAnns(self.coco.getAnnIds(imgIds=query, iscrowd=None))

    def categories(self, no_background=True):
        """Return ``{contiguous_id: name}`` for ``self.coco.dataset["categories"]``.

        When ``no_background`` is true, entries named ``"__background__"`` or with
        id 0 are skipped.
        """
        label_list = {}
        for cat in self.coco.dataset["categories"]:
            if not no_background or (cat["name"] != "__background__" and cat['id'] != 0):
                label_list[self.json_category_id_to_contiguous_id[cat["id"]]] = cat["name"]
        return label_list

    def get_box_mask(self, rect, img_size, mode="poly"):
        """Return the outline of box ``rect = (x1, y1, x2, y2)`` as one polygon.

        ``img_size`` is accepted for interface compatibility but unused.
        """
        assert mode=="poly", "Only support poly mask right now!"
        x1, y1, x2, y2 = rect[0], rect[1], rect[2], rect[3]
        return [[x1, y1, x1, y2, x2, y2, x2, y1]]

    def __getitem__(self, idx):
        """Return ``(img, target, idx)`` for the ``idx``-th kept image.

        ``target`` is an xyxy ``BoxList`` carrying ``labels`` (contiguous ids),
        optional ``masks``/``is_box_mask``, and every key produced by ``self.prepare``.
        """
        img, tgt = super(CocoGrounding, self).__getitem__(idx)
        image_id = self.ids[idx]

        # Crowd annotations are dropped.
        tgt = [obj for obj in tgt if obj["iscrowd"] == 0]
        boxes = torch.as_tensor([obj["bbox"] for obj in tgt]).reshape(-1, 4)  # guard against no boxes
        target = BoxList(boxes, img.size, mode="xywh").convert("xyxy")
        classes = torch.tensor([self.json_category_id_to_contiguous_id[obj["category_id"]] for obj in tgt])
        target.add_field("labels", classes)

        if self.return_masks:
            masks = []
            is_box_mask = []
            for obj, bbox in zip(tgt, target.bbox):
                if "segmentation" in obj:
                    masks.append(obj["segmentation"])
                    is_box_mask.append(0)
                else:
                    # No segmentation: use the box outline as the polygon mask.
                    masks.append(self.get_box_mask(bbox, img.size, mode="poly"))
                    is_box_mask.append(1)
            masks = SegmentationMask(masks, img.size, mode="poly")
            is_box_mask = torch.tensor(is_box_mask)
            target.add_field("masks", masks)
            target.add_field("is_box_mask", is_box_mask)

        if not self.disable_clip_to_image:
            target = target.clip_to_image(remove_empty=True)

        if self.special_safeguard_for_coco_grounding:
            # Intended for LVIS
            assert(not self.use_caption_prompt)

            original_box_num = len(target)
            target, positive_caption_length = check_for_positive_overflow(target, self.ind_to_class, self.tokenizer, self.max_query_len-2)  # leave some space for the special tokens
            if len(target) < original_box_num:
                print("WARNING: removed {} boxes due to positive caption overflow".format(original_box_num - len(target)))

            # NOTE(review): label_to_positions is never forwarded into ``anno``, so
            # prepare() always sees an empty mapping here — confirm this is intended.
            annotations, caption, greenlight_span_for_masked_lm_objective, label_to_positions = convert_object_detection_to_grounding_optimized_for_od(
                target=target,
                image_id=image_id,
                ind_to_class=self.ind_to_class,
                disable_shuffle=self.disable_shuffle,
                add_detection_prompt=False,
                add_detection_prompt_advanced=False,
                random_sample_negative=self.random_sample_negative,
                control_probabilities=(0.0, 0.0, 1.0, 0.0),  # always try to add a lot of negatives
                restricted_negative_list=None,
                separation_tokens=self.separation_tokens,
                max_num_labels=-1,
                positive_caption_length=positive_caption_length,
                tokenizer=self.tokenizer,
                max_seq_length=self.max_query_len-2
            )
        else:
            # Intended for COCO / ODinW
            annotations, caption, greenlight_span_for_masked_lm_objective = convert_od_to_grounding_simple(
                target=target,
                image_id=image_id,
                ind_to_class=self.ind_to_class,
                disable_shuffle=self.disable_shuffle,
                add_detection_prompt=self.add_detection_prompt,
                separation_tokens=self.separation_tokens,
                caption_prompt=self.caption_prompt if self.use_caption_prompt else None,
            )

        anno = {"image_id": image_id, "annotations": annotations, "caption": caption}
        anno["greenlight_span_for_masked_lm_objective"] = greenlight_span_for_masked_lm_objective
        if self.no_mask_for_od:
            # 3-tuple sentinel: create_greenlight_map then marks every token unmaskable.
            anno["greenlight_span_for_masked_lm_objective"].append((-1, -1, -1))
        img, anno = self.prepare(img, anno, box_format="xyxy")

        # for equivalence check
        if self.one_hot:
            logging.info("using one hot for equivalence check.")
            one_hot_map = torch.zeros_like(anno["positive_map"], dtype=torch.float)
            text_mask = torch.zeros(anno["positive_map"].shape[1], dtype=torch.int64)
            # Use the labels of the boxes that survived clipping / overflow filtering:
            # ``classes`` above predates that filtering and may hold rows for removed
            # boxes. Assumes the grounding converter emits one annotation per remaining
            # box, in order — TODO confirm.
            label_offset = 0 if self.no_minus_one_for_one_hot else 1
            for ii, cls in enumerate(target.get_field("labels")):
                one_hot_map[ii, cls - label_offset] = 1.0
            if self.no_minus_one_for_one_hot:
                text_mask[:] = 1
            else:
                text_mask[:len(self.ind_to_class)] = 1
            anno["positive_map"] = one_hot_map
            anno["text_mask"] = text_mask

        if self._transforms is not None:
            img, target = self._transforms(img, target)

        # add additional property
        for key, value in anno.items():
            target.add_field(key, value)

        sanity_check_target_after_processing(target)

        return img, target, idx

    def get_img_info(self, index):
        """Return the raw COCO image record of the ``index``-th kept image."""
        return self.coco.imgs[self.id_to_img_map[index]]
class ModulatedDataset(torchvision.datasets.CocoDetection):
    """Caption-grounding dataset whose captions are stored on the image records.

    Each image record carries its own ``caption``; annotations are converted by
    ``ConvertCocoPolysToMask`` and wrapped in an xyxy ``BoxList``. Per the in-code
    comment, this is used for Flickr and mixed grounding data.
    """

    def __init__(self,
                 img_folder,
                 ann_file,
                 transforms,
                 return_masks,
                 return_tokens,
                 is_train=False,
                 tokenizer=None,
                 disable_clip_to_image=False,
                 no_mask_for_gold=False,
                 max_query_len=256,
                 **kwargs):
        """Load annotations and set up the converter.

        Args:
            img_folder, ann_file: forwarded to ``torchvision.datasets.CocoDetection``.
            transforms: callable ``(img, target) -> (img, target)``.
            return_masks, return_tokens, tokenizer, max_query_len: forwarded to
                ``ConvertCocoPolysToMask``.
            is_train: when false, ``positive_map_eval`` is attached for records
                that have ``tokens_positive_eval``.
            disable_clip_to_image: skip clipping (and the no-box-removed assertion).
            no_mask_for_gold: append the ``(-1, -1, -1)`` greenlight sentinel.
            **kwargs: accepted and ignored.
        """
        super(ModulatedDataset, self).__init__(img_folder, ann_file)
        # Sorted ids, restricted to images with a usable annotation.
        self.ids = [img_id for img_id in sorted(self.ids) if self._is_annotated(img_id)]

        self.id_to_img_map = dict(enumerate(self.ids))
        self._transforms = transforms
        self.max_query_len = max_query_len
        self.prepare = ConvertCocoPolysToMask(return_masks, return_tokens, tokenizer=tokenizer, max_query_len=max_query_len)
        self.is_train = is_train
        self.disable_clip_to_image = disable_clip_to_image
        self.no_mask_for_gold = no_mask_for_gold

    def _is_annotated(self, img_id):
        """Return whether ``img_id`` passes ``has_valid_annotation``."""
        query = [img_id] if isinstance(img_id, str) else img_id
        ann_ids = self.coco.getAnnIds(imgIds=query, iscrowd=None)
        return has_valid_annotation(self.coco.loadAnns(ann_ids))

    def __getitem__(self, idx):
        """Return ``(img, target, idx)`` where ``target`` is a populated ``BoxList``."""
        img, raw_annotations = super(ModulatedDataset, self).__getitem__(idx)
        image_id = self.ids[idx]
        coco_img = self.coco.loadImgs(image_id)[0]
        caption = coco_img["caption"]
        dataset_name = coco_img.get("dataset_name")

        # This dataset is used for Flickr & Mixed, so the whole caption is maskable;
        # the 3-tuple sentinel instead marks every token unmaskable.
        greenlight_spans = [(0, len(caption))]
        if self.no_mask_for_gold:
            greenlight_spans.append((-1, -1, -1))
        record = {
            "image_id": image_id,
            "annotations": raw_annotations,
            "caption": caption,
            "greenlight_span_for_masked_lm_objective": greenlight_spans,
        }

        img, prepared = self.prepare(img, record)

        # Wrap the prepared xyxy boxes in a BoxList.
        box_tensor = torch.as_tensor(prepared["boxes"]).reshape(-1, 4)  # guard against no boxes
        target = BoxList(box_tensor, img.size, mode="xyxy")
        target.add_field("labels", prepared["labels"])
        if self.prepare.return_masks:
            for mask_key in ("masks", "is_box_mask"):
                target.add_field(mask_key, prepared.pop(mask_key))

        if not self.disable_clip_to_image:
            boxes_before = len(target.bbox)
            target = target.clip_to_image(remove_empty=True)
            assert boxes_before == len(target.bbox), "Box got removed in MixedDataset!!!"

        if self._transforms is not None:
            img, target = self._transforms(img, target)

        # Copy every prepared entry onto the BoxList.
        for field_name, field_value in prepared.items():
            target.add_field(field_name, field_value)

        target.add_field("dataset_name", dataset_name)
        for extra_key in ("sentence_id", "original_img_id", "original_id", "task_id"):
            if extra_key in coco_img:
                target.add_field(extra_key, coco_img[extra_key])

        if not self.is_train and "tokens_positive_eval" in coco_img:
            tokenized = self.prepare.tokenizer(caption, return_tensors="pt")
            target.add_field("positive_map_eval", create_positive_map(tokenized, coco_img["tokens_positive_eval"]))
            target.add_field("nb_eval", len(target.get_field("positive_map_eval")))

        sanity_check_target_after_processing(target)
        return img, target, idx

    def get_img_info(self, index):
        """Return the raw COCO image record of the ``index``-th kept image."""
        return self.coco.imgs[self.id_to_img_map[index]]
class CocoDetection(data.Dataset):
    """Minimal `MS Coco Detection <http://mscoco.org/dataset/#detections-challenge2016>`_ dataset.

    Args:
        root (string): directory that image ``file_name`` entries are relative to.
        annFile (string): path to the COCO-format JSON annotation file.
        transform (callable, optional): applied to the loaded RGB PIL image.
        target_transform (callable, optional): applied to the annotation list.
    """

    def __init__(self, root, annFile, transform=None, target_transform=None):
        from pycocotools.coco import COCO
        self.root = root
        self.coco = COCO(annFile)
        self.ids = list(self.coco.imgs.keys())
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index, return_meta=False):
        """Load one sample.

        Args:
            index (int): position in ``self.ids``.
            return_meta (bool): also return the COCO image record.

        Returns:
            ``(image, target)`` or ``(image, target, meta)``; ``target`` is the list
            returned by ``coco.loadAnns``.
        """
        coco = self.coco
        raw_id = self.ids[index]
        query = [raw_id] if isinstance(raw_id, str) else raw_id
        target = coco.loadAnns(coco.getAnnIds(imgIds=query))
        meta = coco.loadImgs(query)[0]
        img = pil_loader(os.path.join(self.root, meta['file_name']))

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return (img, target, meta) if return_meta else (img, target)

    def __len__(self):
        return len(self.ids)

    def __repr__(self):
        def labelled(label, obj):
            # Continuation lines of a multi-line repr are aligned under the label.
            return label + obj.__repr__().replace('\n', '\n' + ' ' * len(label))

        pieces = [
            'Dataset ' + self.__class__.__name__ + '\n',
            ' Number of datapoints: {}\n'.format(self.__len__()),
            ' Root Location: {}\n'.format(self.root),
            labelled(' Transforms (if any): ', self.transform) + '\n',
            labelled(' Target Transforms (if any): ', self.target_transform),
        ]
        return ''.join(pieces)
class ConvertCocoPolysToMask(object):
    """Convert a raw ``{"image_id", "annotations", "caption", ...}`` record into tensors.

    Produces xyxy boxes, labels, areas, crowd flags and sizes. Optionally it also
    produces polygon masks (``return_masks``) and ``tokens_positive``
    (``return_tokens``). When a tokenizer is configured it adds the token-level
    ``positive_map``, ``greenlight_map`` and ``positive_map_for_od_labels``.
    """

    def __init__(self, return_masks=False, return_tokens=False, tokenizer=None, max_query_len=256):
        self.return_masks = return_masks
        self.return_tokens = return_tokens
        self.tokenizer = tokenizer
        # Passed as ``max_length`` (with truncation) when tokenizing the caption.
        self.max_query_len = max_query_len

    def get_box_mask(self, rect, img_size, mode="poly"):
        """Return the outline of box ``rect = (x1, y1, x2, y2)`` as one polygon.

        ``img_size`` is accepted for interface compatibility but unused.
        """
        assert mode=="poly", "Only support poly mask right now!"
        x1, y1, x2, y2 = rect[0], rect[1], rect[2], rect[3]
        return [[x1, y1, x1, y2, x2, y2, x2, y1]]

    def __call__(self, image, target, ignore_box_screen=False, box_format="xywh"):
        """Convert ``target`` for ``image``.

        Args:
            image: object exposing ``.size == (width, height)`` (e.g. a PIL image);
                returned unchanged.
            target: dict with ``image_id`` and ``annotations``. Optional keys are
                ``caption``, ``label_to_positions`` and
                ``greenlight_span_for_masked_lm_objective``.
            ignore_box_screen: keep ``tokens_positive`` entries of degenerate boxes
                too, and skip the box/token count assertion.
            box_format: ``"xywh"`` boxes are converted to inclusive xyxy and
                clamped to the image; any other value means boxes are used as given.

        Returns:
            ``(image, new_target)``. Every per-object entry in ``new_target``
            (boxes, labels, masks, keypoints, isfinal, area, iscrowd,
            original_od_label) is filtered by the same degenerate-box mask.
        """
        w, h = image.size

        image_id = torch.tensor([target["image_id"]])

        anno = target["annotations"]
        caption = target["caption"] if "caption" in target else None
        label_to_positions = target.get("label_to_positions", {})

        greenlight_span_for_masked_lm_objective = target.get("greenlight_span_for_masked_lm_objective", None)

        # Crowd annotations are dropped; a missing "iscrowd" counts as non-crowd.
        anno = [obj for obj in anno if "iscrowd" not in obj or obj["iscrowd"] == 0]

        boxes = [obj["bbox"] for obj in anno]
        # guard against no boxes via resizing
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        if box_format == "xywh":
            boxes[:, 2:] += boxes[:, :2] - 1  # TO_REMOVE = 1
            boxes[:, 0::2].clamp_(min=0, max=w-1)  # TO_REMOVE = 1
            boxes[:, 1::2].clamp_(min=0, max=h-1)  # TO_REMOVE = 1

        classes = torch.tensor([obj["category_id"] for obj in anno], dtype=torch.int64)

        if self.return_masks:
            masks = []
            is_box_mask = []
            for obj, bbox in zip(anno, boxes):
                if "segmentation" in obj:
                    masks.append(obj["segmentation"])
                    is_box_mask.append(0)
                else:
                    # No segmentation: use the box outline as the polygon mask.
                    masks.append(self.get_box_mask(bbox, image.size, mode='poly'))
                    is_box_mask.append(1)
            masks = SegmentationMask(masks, image.size, mode='poly')
            is_box_mask = torch.tensor(is_box_mask)

        keypoints = None
        if anno and "keypoints" in anno[0]:
            keypoints = torch.as_tensor([obj["keypoints"] for obj in anno], dtype=torch.float32)
            num_keypoints = keypoints.shape[0]
            if num_keypoints:
                keypoints = keypoints.view(num_keypoints, -1, 3)

        isfinal = None
        if anno and "isfinal" in anno[0]:
            isfinal = torch.as_tensor([obj["isfinal"] for obj in anno], dtype=torch.float)

        tokens_positive = [] if self.return_tokens else None
        if self.return_tokens and anno and "tokens" in anno[0]:
            tokens_positive = [obj["tokens"] for obj in anno]
        elif self.return_tokens and anno and "tokens_positive" in anno[0]:
            tokens_positive = [obj["tokens_positive"] for obj in anno]

        # Drop degenerate boxes (non-positive width or height).
        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[keep]
        classes = classes[keep]
        if self.return_masks:
            masks = masks[keep]
            is_box_mask = is_box_mask[keep]
        if keypoints is not None:
            keypoints = keypoints[keep]
        if isfinal is not None:
            # Filter like every other per-object field so lengths stay aligned.
            isfinal = isfinal[keep]

        target = {}
        target["boxes"] = boxes
        target["labels"] = classes
        if caption is not None:
            target["caption"] = caption
        if self.return_masks:
            target["masks"] = masks
            target["is_box_mask"] = is_box_mask
        target["image_id"] = image_id
        if keypoints is not None:
            target["keypoints"] = keypoints

        if tokens_positive is not None:
            target["tokens_positive"] = []
            for i, k in enumerate(keep):
                if k or ignore_box_screen:
                    target["tokens_positive"].append(tokens_positive[i])

        if isfinal is not None:
            target["isfinal"] = isfinal

        # for conversion to coco api
        area = torch.tensor([obj["area"] for obj in anno])
        iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
        target["area"] = area[keep]
        target["iscrowd"] = iscrowd[keep]

        target["orig_size"] = torch.as_tensor([int(h), int(w)])
        target["size"] = torch.as_tensor([int(h), int(w)])

        if self.return_tokens and self.tokenizer is not None:
            if not ignore_box_screen:
                assert len(target["boxes"]) == len(target["tokens_positive"])
            tokenized = self.tokenizer(caption, return_tensors="pt",
                                       max_length=self.max_query_len,
                                       truncation=True)
            target["positive_map"] = create_positive_map(tokenized, target["tokens_positive"])
            # A missing span list means "no green-lit spans" rather than a crash.
            target['greenlight_map'] = create_greenlight_map(greenlight_span_for_masked_lm_objective or [], tokenized)
            target["positive_map_for_od_labels"] = create_positive_map_for_od_labels(tokenized, label_to_positions)

        # NOTE: The padding value has to be not the same as -1 or -100
        original_od_label = [obj.get("original_od_label", -10) for obj in anno]
        target["original_od_label"] = torch.as_tensor(original_od_label)[keep]

        return image, target
def create_greenlight_map(tok_list, tokenized, max_len=256):
    """Build a per-token map of which tokens may be masked for the MLM objective.

    Args:
        tok_list: character spans ``(beg, end)`` (``end`` exclusive), e.g.
            ``[(0, 5), (10, 13), (-1, -1, -1)]``. A 3-element item is a sentinel:
            the whole map is set to -1 (nothing maskable) and processing stops.
            ``None`` is treated as an empty list.
        tokenized: tokenizer output exposing ``char_to_token(char_index)``
            (e.g. a HuggingFace ``BatchEncoding``).
        max_len: length of the returned map; 256 matches the previous fixed size.

    Returns:
        Float tensor of shape ``(max_len,)``: 1 for tokens inside a span, 0
        elsewhere, or all -1 when the sentinel is present. Spans whose ends
        cannot be mapped to tokens are skipped.
    """
    def _fallback(*char_indices):
        # Return the first non-None token index; any error counts as "not found".
        pos = None
        try:
            for char_index in char_indices:
                pos = tokenized.char_to_token(char_index)
                if pos is not None:
                    break
        except Exception:
            pos = None
        return pos

    greenlight_map = torch.zeros(max_len, dtype=torch.float)
    for item in tok_list or []:
        if len(item) != 2:
            assert len(item) == 3
            # Sentinel: make everything unmaskable.
            greenlight_map[:] = -1
            break

        beg, end = item
        beg_pos = tokenized.char_to_token(beg)
        end_pos = tokenized.char_to_token(end - 1)
        # A boundary character may map to no token (presumably whitespace);
        # nudge inward by up to two characters.
        if beg_pos is None:
            beg_pos = _fallback(beg + 1, beg + 2)
        if end_pos is None:
            end_pos = _fallback(end - 2, end - 3)
        if beg_pos is None or end_pos is None:
            continue

        greenlight_map[beg_pos: end_pos + 1].fill_(1)
    return greenlight_map
def create_positive_map_for_od_labels(tokenized, label_to_positions, max_len=256):
    """Map every token to the object-detection label whose name covers it.

    For each ``label -> (beg, end)`` character span (``end`` exclusive), every
    token from ``char_to_token(beg)`` to ``char_to_token(end - 1)`` is set to
    ``label``; all other tokens are -1. For example, a label 3 whose span covers
    tokens 1-4 yields ``-1 3 3 3 3 -1 ...``.

    Args:
        tokenized: tokenizer output exposing ``char_to_token(char_index)``.
        label_to_positions: dict ``{label: (beg, end)}``; one span per label.
        max_len: length of the returned map; 256 matches the previous fixed size.

    Returns:
        Float tensor of shape ``(max_len,)``. Spans whose ends cannot be mapped
        to tokens are skipped; later labels overwrite earlier ones on overlap.
    """
    def _fallback(*char_indices):
        # Return the first non-None token index; any error counts as "not found".
        pos = None
        try:
            for char_index in char_indices:
                pos = tokenized.char_to_token(char_index)
                if pos is not None:
                    break
        except Exception:
            pos = None
        return pos

    positive_map = torch.ones(max_len, dtype=torch.float) * -1  # -1 means no match
    for label, span in label_to_positions.items():
        # one label only maps to one location
        beg, end = span
        beg_pos = tokenized.char_to_token(beg)
        end_pos = tokenized.char_to_token(end - 1)
        # A boundary character may map to no token (presumably whitespace);
        # nudge inward by up to two characters.
        if beg_pos is None:
            beg_pos = _fallback(beg + 1, beg + 2)
        if end_pos is None:
            end_pos = _fallback(end - 2, end - 3)
        if beg_pos is None or end_pos is None:
            continue

        positive_map[beg_pos: end_pos + 1].fill_(label)
    return positive_map
def convert_coco_poly_to_mask(segmentations, height, width):
    """Rasterize per-object COCO polygon lists into a stack of binary masks.

    Args:
        segmentations: one entry per object, each a list of polygons as accepted
            by ``pycocotools.mask.frPyObjects``.
        height, width: output mask size in pixels.

    Returns:
        Tensor of shape ``(N, height, width)``, one mask per object, each the union
        of that object's polygons. When ``segmentations`` is empty, a ``uint8``
        tensor of shape ``(0, height, width)``.
    """
    def _rasterize(polygons):
        decoded = coco_mask.decode(coco_mask.frPyObjects(polygons, height, width))
        if decoded.ndim < 3:
            decoded = decoded[..., None]
        # Union of all polygons belonging to the same object.
        return torch.as_tensor(decoded, dtype=torch.uint8).any(dim=2)

    per_object = [_rasterize(polygons) for polygons in segmentations]
    if not per_object:
        return torch.zeros((0, height, width), dtype=torch.uint8)
    return torch.stack(per_object, dim=0)
def create_positive_map(tokenized, tokens_positive, max_len=256):
    """Build a row-normalized map so that ``positive_map[i, j] > 0`` iff box ``i`` is associated with token ``j``.

    Args:
        tokenized: tokenizer output exposing ``char_to_token(char_index)``.
        tokens_positive: one list per box of character spans ``(beg, end)``
            (``end`` exclusive).
        max_len: number of token columns; 256 matches the previous fixed size.

    Returns:
        Float tensor of shape ``(len(tokens_positive), max_len)``. Each row is
        divided by its sum (plus 1e-6), so rows with tokens sum to ~1 and rows
        without any mapped token stay all zero. Spans whose ends cannot be
        mapped to tokens are skipped.
    """
    def _fallback(*char_indices):
        # Return the first non-None token index; any error counts as "not found".
        pos = None
        try:
            for char_index in char_indices:
                pos = tokenized.char_to_token(char_index)
                if pos is not None:
                    break
        except Exception:
            pos = None
        return pos

    positive_map = torch.zeros((len(tokens_positive), max_len), dtype=torch.float)
    for j, tok_list in enumerate(tokens_positive):
        for (beg, end) in tok_list:
            beg_pos = tokenized.char_to_token(beg)
            end_pos = tokenized.char_to_token(end - 1)
            # A boundary character may map to no token (presumably whitespace);
            # nudge inward by up to two characters.
            if beg_pos is None:
                beg_pos = _fallback(beg + 1, beg + 2)
            if end_pos is None:
                end_pos = _fallback(end - 2, end - 3)
            if beg_pos is None or end_pos is None:
                continue

            positive_map[j, beg_pos: end_pos + 1].fill_(1)
    return positive_map / (positive_map.sum(-1)[:, None] + 1e-6)
def pil_loader(path, retry=5):
    """Load the image at ``path`` with PIL and return it converted to RGB.

    The read is attempted up to ``retry`` times, presumably to ride out transient
    filesystem errors. Only ``OSError`` (missing/unreadable file, unidentified or
    truncated image) triggers a retry; other exceptions propagate immediately.

    Args:
        path: image file path.
        retry: maximum number of attempts. With ``retry <= 0`` nothing is
            attempted and ``None`` is returned (unchanged legacy behavior).

    Raises:
        OSError: the last error seen, if every attempt failed. Previously the
            function returned ``None`` here, which only failed later downstream.
    """
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    last_error = None
    for _ in range(retry):
        try:
            with open(path, 'rb') as f:
                img = Image.open(f)
                return img.convert('RGB')
        except OSError as err:
            last_error = err
    if last_error is not None:
        raise last_error
    return None