Spaces:
Build error
Build error
# Copyright (c) Facebook, Inc. and its affiliates. | |
import argparse | |
import json | |
if __name__ == '__main__': | |
parser = argparse.ArgumentParser() | |
parser.add_argument("--ann", default='datasets/lvis/lvis_v1_train.json') | |
parser.add_argument("--add_freq", action='store_true') | |
parser.add_argument("--r_thresh", type=int, default=10) | |
parser.add_argument("--c_thresh", type=int, default=100) | |
args = parser.parse_args() | |
print('Loading', args.ann) | |
data = json.load(open(args.ann, 'r')) | |
cats = data['categories'] | |
image_count = {x['id']: set() for x in cats} | |
ann_count = {x['id']: 0 for x in cats} | |
for x in data['annotations']: | |
image_count[x['category_id']].add(x['image_id']) | |
ann_count[x['category_id']] += 1 | |
num_freqs = {x: 0 for x in ['r', 'f', 'c']} | |
for x in cats: | |
x['image_count'] = len(image_count[x['id']]) | |
x['instance_count'] = ann_count[x['id']] | |
if args.add_freq: | |
freq = 'f' | |
if x['image_count'] < args.c_thresh: | |
freq = 'c' | |
if x['image_count'] < args.r_thresh: | |
freq = 'r' | |
x['frequency'] = freq | |
num_freqs[freq] += 1 | |
print(cats) | |
image_counts = sorted([x['image_count'] for x in cats]) | |
# print('image count', image_counts) | |
# import pdb; pdb.set_trace() | |
if args.add_freq: | |
for x in ['r', 'c', 'f']: | |
print(x, num_freqs[x]) | |
out = cats # {'categories': cats} | |
out_path = args.ann[:-5] + '_cat_info.json' | |
print('Saving to', out_path) | |
json.dump(out, open(out_path, 'w')) | |