# DeticChatGPT / tools / get_imagenet_21k_full_tar_json.py
# Page header captured together with the file (not part of the script):
#   taesiri's picture
#   Duplicate from akhaliq/Detic
#   6e14436
#   raw | history | blame
#   2.66 kB
# Copyright (c) Facebook, Inc. and its affiliates.
import argparse
import json
import numpy as np
import pickle
import io
import gzip
import sys
import time
from nltk.corpus import wordnet
from tqdm import tqdm
import operator
import torch
sys.path.insert(0, 'third_party/CenterNet2/projects/CenterNet2/')
sys.path.insert(0, 'third_party/Deformable-DETR')
from detic.data.tar_dataset import DiskTarDataset, _TarDataset
if __name__ == '__main__':
    # Script entry point: iterate once over every sample of the tar-backed
    # dataset and write a JSON file with one category record per tar file and
    # one image record (size, category id, tar index) per sample whose label
    # is not -1. The 'annotations' list is always written empty.
    parser = argparse.ArgumentParser()
    # NOTE: --imagenet_dir is accepted but not read anywhere in this block.
    parser.add_argument("--imagenet_dir", default='datasets/imagenet/ImageNet-21k/')
    parser.add_argument("--tarfile_path", default='datasets/imagenet/metadata-22k/tar_files.npy')
    parser.add_argument("--tar_index_dir", default='datasets/imagenet/metadata-22k/tarindex_npy')
    parser.add_argument("--out_path", default='datasets/imagenet/annotations/imagenet-22k_image_info.json')
    parser.add_argument("--workers", default=16, type=int)
    args = parser.parse_args()

    start_time = time.time()
    print('Building dataset')
    dataset = DiskTarDataset(args.tarfile_path, args.tar_index_dir)
    end_time = time.time()
    print(f"Took {end_time-start_time} seconds to make the dataset.")
    print(f"Have {len(dataset)} samples.")
    print('dataset', dataset)

    # One category per entry of the tar-file list; ids are 1-based positions.
    tar_files = np.load(args.tarfile_path)
    categories = []
    for i, tar_file in enumerate(tar_files):
        # The 9 characters in front of the 4-character suffix are used as the
        # WordNet id -- assumes entries end in '<wnid>.tar' with wnid like
        # 'n01234567'; confirm against how tar_files.npy is produced.
        wnid = tar_file[-13:-4]
        # Drop the leading letter; the remaining digits are the noun offset.
        synset = wordnet.synset_from_pos_and_offset('n', int(wnid[1:]))
        synonyms = [x.name() for x in synset.lemmas()]
        category = {
            'id': i + 1,
            'synset': synset.name(),
            'name': synonyms[0],
            'def': synset.definition(),
            'synonyms': synonyms,
        }
        categories.append(category)
    print('categories', len(categories))

    # batch_size=1 combined with itemgetter(0) as collate_fn unwraps each
    # one-element batch, so the loop receives the dataset's raw sample tuple
    # while loading still runs in worker processes.
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=False,
        num_workers=args.workers,
        collate_fn=operator.itemgetter(0),
    )
    images = []
    for img, label, index in tqdm(data_loader):
        # Samples labelled -1 are skipped (presumably the dataset's marker for
        # an unusable sample -- verify against DiskTarDataset).
        if label == -1:
            continue
        image = {
            # Image and category ids are shifted to be 1-based; 'tar_index'
            # keeps the dataset's own index unchanged.
            'id': int(index) + 1,
            'pos_category_ids': [int(label) + 1],
            'height': int(img.height),
            'width': int(img.width),
            'tar_index': int(index),
        }
        images.append(image)

    data = {'categories': categories, 'images': images, 'annotations': []}
    for k, v in data.items():
        print(k, len(v))
    print('Saving to ', args.out_path)
    try:
        # The context manager guarantees the file is flushed and closed.
        with open(args.out_path, 'w') as out_file:
            json.dump(data, out_file)
    except (OSError, TypeError, ValueError) as err:
        # Building `data` requires a full pass over the dataset, so do not
        # lose it on a failed save: report the error and open the debugger so
        # the in-memory result can still be written out by hand.
        print(f'Failed to save {args.out_path}: {err!r}')
        import pdb
        pdb.set_trace()