"""Convert a caption-annotation JSON file into JSON-lines train/val splits.

Usage:
    python coco.py ANNOTATION_JSON IMAGES_DIR OUTPUT_PREFIX

ANNOTATION_JSON is loaded with ``json.load`` and iterated as a mapping of
image path (relative to IMAGES_DIR) -> list of caption strings.  Two files
are written, OUTPUT_PREFIX + "_train.json" and OUTPUT_PREFIX + "_val.json",
each holding one JSON object per line with keys "image_path" and "captions".
"""
import json
import logging
import os.path
import sys

# Number of records, taken from the end of the record list, held out for
# the validation split.
VAL_SIZE = 801

logger = logging.getLogger(__name__)


def clean_captions(captions):
    """Return *captions* with spaces removed, dropping any that end up empty.

    The emptiness check is applied after cleaning, so a caption made only of
    spaces is discarded rather than kept as an empty string.
    """
    cleaned = []
    for caption in captions:
        # NOTE(review): the original chained two identical .replace(" ", "")
        # calls; the second may have been meant for a different whitespace
        # character (e.g. the ideographic space U+3000) — TODO confirm.
        edited = caption.replace(" ", "").replace(" ", "")
        if edited:
            cleaned.append(edited)
    return cleaned


def build_lines(annotations, images_dir, exists=os.path.isfile):
    """Build one JSON-encoded record per existing image with captions.

    Args:
        annotations: mapping of relative image path -> iterable of captions.
        images_dir: directory prefixed to every image path.
        exists: predicate used to check that an image file is present
            (injectable for testing; defaults to ``os.path.isfile``).

    Returns:
        List of JSON strings, in the mapping's iteration order.  Images that
        are missing are logged and skipped; images with no non-empty caption
        are skipped silently.
    """
    lines = []
    for image_path, captions in annotations.items():
        full_image_path = f"{images_dir}/{image_path}"
        if not exists(full_image_path):
            logger.warning("%s doesn't exist", full_image_path)
            continue
        edited_captions = clean_captions(captions)
        if edited_captions:
            lines.append(
                json.dumps({"image_path": full_image_path, "captions": edited_captions})
            )
    return lines


def split_lines(lines, val_size=VAL_SIZE):
    """Split *lines* into ``(train, val)``, holding out the last *val_size*.

    If there are fewer than *val_size* lines, all of them go to validation
    and the training split is empty.  ``val_size=0`` yields an empty
    validation split (negative slicing would wrongly empty the train split).

    Raises:
        ValueError: if *val_size* is negative.
    """
    if val_size < 0:
        raise ValueError(f"val_size must be non-negative, got {val_size}")
    cut = max(len(lines) - val_size, 0)
    return lines[:cut], lines[cut:]


def main(argv=None):
    """Run the conversion for command-line *argv*; return a process exit code."""
    argv = sys.argv if argv is None else argv
    if len(argv) != 4:
        print(
            "Provide annotation .json file name, images dir, output file prefix. "
            "e.g. python coco.py coco_captions_train2017.json "
            "/mnt/disks/data-1/coco/train2017 coco_dataset",
            file=sys.stderr,
        )
        return 1

    annotation_file, images_dir, output_file = argv[1:4]

    logger.info("Processing Flicker 8k dataset")
    # JSON text is UTF-8 by specification; don't depend on the locale.
    with open(annotation_file, "r", encoding="utf-8") as f:
        annotations = json.load(f)

    lines = build_lines(annotations, images_dir)
    train_lines, valid_lines = split_lines(lines)

    for suffix, chunk in (("_train.json", train_lines), ("_val.json", valid_lines)):
        with open(output_file + suffix, "w", encoding="utf-8") as f:
            f.write("\n".join(chunk))

    logger.info(f"Processing Flicker 8k dataset done. {len(lines)} images processed.")
    return 0


if __name__ == "__main__":
    # Without this, logger.info output is dropped (root default is WARNING).
    logging.basicConfig(level=logging.INFO)
    sys.exit(main())