# datasets/odvg.py
import json
import os
import random
import sys
from typing import Callable, Optional

import torch
from PIL import Image
from torchvision.datasets.vision import VisionDataset

# Make the repo root importable so `datasets.transforms` resolves when this
# file is run directly.
sys.path.append(os.path.dirname(sys.path[0]))
import datasets.transforms as T


class ODVGDataset(VisionDataset):
"""
Args:
root (string): Root directory where images are downloaded to.
anno (string): Path to json annotation file.
label_map_anno (string): Path to json label mapping file. Only for Object Detection
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.PILToTensor``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
transforms (callable, optional): A function/transform that takes input sample and its target as entry
and returns a transformed version.
"""
def __init__(
self,
root: str,
anno: str,
        label_map_anno: Optional[str] = None,
max_labels: int = 80,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
transforms: Optional[Callable] = None,
) -> None:
super().__init__(root, transforms, transform, target_transform)
self.root = root
self.dataset_mode = "OD" if label_map_anno else "VG"
self.max_labels = max_labels
if self.dataset_mode == "OD":
self.load_label_map(label_map_anno)
self._load_metas(anno)
self.get_dataset_info()
def load_label_map(self, label_map_anno):
with open(label_map_anno, "r") as file:
self.label_map = json.load(file)
self.label_index = set(self.label_map.keys())
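        # The label map is expected to look like {"0": "person", "1": "bicycle", ...}:
        # stringified label ids mapped to category names (entries illustrative).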
def _load_metas(self, anno):
with open(anno, "r") as f:
self.metas = [json.loads(line) for line in f]
def get_dataset_info(self):
print(f" == total images: {len(self)}")
if self.dataset_mode == "OD":
print(f" == total labels: {len(self.label_map)}")
def __getitem__(self, index: int):
meta = self.metas[index]
rel_path = meta["filename"]
abs_path = os.path.join(self.root, rel_path)
if not os.path.exists(abs_path):
raise FileNotFoundError(f"{abs_path} not found.")
image = Image.open(abs_path).convert("RGB")
exemplars = torch.tensor(meta["exemplars"], dtype=torch.int64)
w, h = image.size
if self.dataset_mode == "OD":
anno = meta["detection"]
            instances = list(anno["instances"])
boxes = [obj["bbox"] for obj in instances]
# generate vg_labels
# pos bbox labels
ori_classes = [str(obj["label"]) for obj in instances]
pos_labels = set(ori_classes)
# neg bbox labels
            neg_labels = self.label_index.difference(pos_labels)
            vg_labels = list(pos_labels)
            num_to_add = min(len(neg_labels), self.max_labels - len(pos_labels))
            if num_to_add > 0:
                # random.sample requires a sequence (not a set) on Python >= 3.11
                vg_labels.extend(random.sample(sorted(neg_labels), num_to_add))
            # shuffle so positive labels are not always at the front of the caption
            random.shuffle(vg_labels)
caption_list = [self.label_map[lb] for lb in vg_labels]
caption_dict = {item: index for index, item in enumerate(caption_list)}
caption = " . ".join(caption_list) + " ."
classes = [
caption_dict[self.label_map[str(obj["label"])]] for obj in instances
]
boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
classes = torch.tensor(classes, dtype=torch.int64)
elif self.dataset_mode == "VG":
anno = meta["grounding"]
            instances = list(anno["regions"])
boxes = [obj["bbox"] for obj in instances]
caption_list = [obj["phrase"] for obj in instances]
c = list(zip(boxes, caption_list))
random.shuffle(c)
boxes[:], caption_list[:] = zip(*c)
            # Deduplicate phrases so repeated phrases share one class id
            # (note: set() makes the caption order nondeterministic).
            uni_caption_list = list(set(caption_list))
            label_map = {cap: idx for idx, cap in enumerate(uni_caption_list)}
classes = [label_map[cap] for cap in caption_list]
caption = " . ".join(uni_caption_list) + " ."
boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
classes = torch.tensor(classes, dtype=torch.int64)
caption_list = uni_caption_list
target = {}
target["size"] = torch.as_tensor([int(h), int(w)])
target["cap_list"] = caption_list
target["caption"] = caption
target["boxes"] = boxes
target["labels"] = classes
target["exemplars"] = exemplars
target["labels_uncropped"] = torch.clone(classes)
# size, cap_list, caption, bboxes, labels
if self.transforms is not None:
image, target = self.transforms(image, target)
        # Sanity check: transforms must not permute target["labels"] relative
        # to target["labels_uncropped"].
        if len(target["labels"]) > 0:
            assert target["labels"][0] == target["labels_uncropped"][0]
return image, target
def __len__(self) -> int:
return len(self.metas)
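

# A minimal usage sketch (paths are illustrative, borrowed from the __main__
# block below; make_coco_transforms is defined next and tolerates args=None by
# falling back to its built-in defaults):
#
#   dataset = ODVGDataset(
#       "path/V3Det/",
#       "path/V3Det/annotations/v3det_2023_v1_all_odvg.jsonl",
#       "path/V3Det/annotations/v3det_label_map.json",
#       transforms=make_coco_transforms("train", args=None),
#   )
#   image, target = dataset[0]
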
def make_coco_transforms(image_set, fix_size=False, strong_aug=False, args=None):
normalize = T.Compose(
[T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]
)
# config the params for data aug
scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
max_size = 1333
scales2_resize = [400, 500, 600]
scales2_crop = [384, 600]
# update args from config files
scales = getattr(args, "data_aug_scales", scales)
max_size = getattr(args, "data_aug_max_size", max_size)
scales2_resize = getattr(args, "data_aug_scales2_resize", scales2_resize)
scales2_crop = getattr(args, "data_aug_scales2_crop", scales2_crop)
# resize them
data_aug_scale_overlap = getattr(args, "data_aug_scale_overlap", None)
if data_aug_scale_overlap is not None and data_aug_scale_overlap > 0:
data_aug_scale_overlap = float(data_aug_scale_overlap)
scales = [int(i * data_aug_scale_overlap) for i in scales]
max_size = int(max_size * data_aug_scale_overlap)
scales2_resize = [int(i * data_aug_scale_overlap) for i in scales2_resize]
scales2_crop = [int(i * data_aug_scale_overlap) for i in scales2_crop]
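        # e.g. data_aug_scale_overlap = 0.5 rescales scales [480, ..., 800] to
        # [240, ..., 400], max_size 1333 to 666, and the crop range likewise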
if image_set == "train":
if fix_size:
return T.Compose(
[
T.RandomHorizontalFlip(),
T.RandomResize([(max_size, max(scales))]),
normalize,
]
)
if strong_aug:
import datasets.sltransform as SLT
return T.Compose(
[
T.RandomHorizontalFlip(),
T.RandomSelect(
T.RandomResize(scales, max_size=max_size),
T.Compose(
[
T.RandomResize(scales2_resize),
T.RandomSizeCrop(*scales2_crop),
T.RandomResize(scales, max_size=max_size),
]
),
),
SLT.RandomSelectMulti(
[
SLT.RandomCrop(),
SLT.LightingNoise(),
SLT.AdjustBrightness(2),
SLT.AdjustContrast(2),
]
),
normalize,
]
)
return T.Compose(
[
T.RandomHorizontalFlip(),
T.RandomSelect(
T.RandomResize(scales, max_size=max_size),
T.Compose(
[
T.RandomResize(scales2_resize),
T.RandomSizeCrop(*scales2_crop),
T.RandomResize(scales, max_size=max_size),
]
),
),
normalize,
]
)
if image_set in ["val", "eval_debug", "train_reg", "test"]:
        if os.environ.get("GFLOPS_DEBUG_SHILONG", "") == "INFO":
            print("Under debug mode for FLOPs calculation only!")
return T.Compose(
[
T.ResizeDebug((1280, 800)),
normalize,
]
)
return T.Compose(
[
T.RandomResize([max(scales)], max_size=max_size),
normalize,
]
)
    raise ValueError(f"unknown image_set: {image_set}")
def build_odvg(image_set, args, datasetinfo):
img_folder = datasetinfo["root"]
ann_file = datasetinfo["anno"]
    label_map = datasetinfo.get("label_map")
    strong_aug = getattr(args, "strong_aug", False)
print(img_folder, ann_file, label_map)
dataset = ODVGDataset(
img_folder,
ann_file,
label_map,
max_labels=args.max_labels,
transforms=make_coco_transforms(
image_set, fix_size=args.fix_size, strong_aug=strong_aug, args=args
),
)
return dataset
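
# A sketch of the datasetinfo mapping expected above (keys taken from the
# lookups in build_odvg; paths are illustrative):
#
#   datasetinfo = {
#       "root": "path/V3Det/",
#       "anno": "path/V3Det/annotations/v3det_2023_v1_all_odvg.jsonl",
#       "label_map": "path/V3Det/annotations/v3det_label_map.json",  # optional; omit for VG mode
#   }
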
if __name__ == "__main__":
dataset_vg = ODVGDataset(
"path/GRIT-20M/data/",
"path/GRIT-20M/anno/grit_odvg_10k.jsonl",
)
print(len(dataset_vg))
data = dataset_vg[random.randint(0, 100)]
print(data)
dataset_od = ODVGDataset(
"pathl/V3Det/",
"path/V3Det/annotations/v3det_2023_v1_all_odvg.jsonl",
"path/V3Det/annotations/v3det_label_map.json",
)
print(len(dataset_od))
data = dataset_od[random.randint(0, 100)]
print(data)