# Author: Galuh Sahid
# Commit ba7a003: "Add download_logs and scripts" (file size 1.3 kB)
import json
import collections
import logging
import sys
# Number of images, taken from the END of the grouped image list, that are
# reserved for the validation split.
# NOTE(review): 10_001 (rather than 10_000) is kept from the original script;
# confirm whether the extra image is intentional.
VAL_SIZE = 10_001

USAGE = (
    "Provide annotations .json file name, images dir, output file prefix. "
    "e.g. python coco.py coco_captions_train2017.json /path/to/train2017 coco_dataset "
    "(writes coco_dataset_train.json and coco_dataset_val.json)"
)


def _normalize_caption(raw):
    """Return *raw* lower-cased, trimmed, and without trailing periods."""
    # Strip whitespace before the periods: for a caption like "a dog. \n",
    # rstrip(".") alone would stop at the whitespace and keep the period.
    return raw.lower().strip().rstrip(".").rstrip()


def build_records(annotations, images_dir):
    """Group caption annotations by image and serialize one JSON line per image.

    Args:
        annotations: iterable of dicts, each with an ``"image_id"`` (formatted
            with ``%012d``) and a ``"caption"`` string -- the ``"annotations"``
            array read from the captions file.
        images_dir: directory prefix for image paths; a trailing "/" is ignored.

    Returns:
        A list of JSON strings of the form
        ``{"image_path": "<images_dir>/<12-digit id>.jpg", "captions": [...]}``,
        one per distinct image, in the order each image first appears.
    """
    prefix = images_dir.rstrip("/")
    image_path_to_caption = collections.defaultdict(list)
    for element in annotations:
        image_path = prefix + "/%012d.jpg" % element["image_id"]
        image_path_to_caption[image_path].append(_normalize_caption(element["caption"]))
    return [
        json.dumps({"image_path": image_path, "captions": captions})
        for image_path, captions in image_path_to_caption.items()
    ]


def split_train_val(lines, val_size=VAL_SIZE):
    """Split *lines* into ``(train, val)``; the last *val_size* items form val.

    If there are no more than *val_size* lines, train is empty and val holds
    everything. A non-positive *val_size* yields an empty val split (a plain
    ``lines[:-0]`` slice would instead leave train empty).
    """
    if val_size <= 0:
        return list(lines), []
    return lines[:-val_size], lines[-val_size:]


def _write_lines(path, lines):
    """Write *lines* to *path* joined by newlines (no trailing newline)."""
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))


def main(argv):
    """Run the conversion for ``argv = [prog, annotation_file, images_dir, output_prefix]``.

    Writes ``<output_prefix>_train.json`` and ``<output_prefix>_val.json``.

    Returns:
        Process exit status: 0 on success, 1 on a wrong argument count.
    """
    if len(argv) != 4:
        print(USAGE, file=sys.stderr)
        return 1
    annotation_file, images_dir, output_file = argv[1:]

    logging.info("Processing COCO dataset")
    with open(annotation_file, "r", encoding="utf-8") as f:
        annotations = json.load(f)["annotations"]

    lines = build_records(annotations, images_dir)
    train_lines, valid_lines = split_train_val(lines)
    if not train_lines:
        logging.warning(
            "Only %d images found (validation size %d); the train split is empty.",
            len(lines),
            VAL_SIZE,
        )

    _write_lines(output_file + "_train.json", train_lines)
    _write_lines(output_file + "_val.json", valid_lines)
    logging.info("Processing COCO dataset done. %d images processed.", len(lines))
    return 0


if __name__ == "__main__":
    # Without this, the root logger stays at WARNING and the info messages
    # above would never be printed.
    logging.basicConfig(level=logging.INFO)
    sys.exit(main(sys.argv))
# Example: python scripts/coco.py annotations/coco_captions_train2017.json <images_dir> coco_dataset
#   -> writes coco_dataset_train.json and coco_dataset_val.json