File size: 1,365 Bytes
ba7a003 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
import pandas as pd
import os.path
import sys
import json
import logging
import contexttimer
# Number of trailing records held out for the validation split (value kept
# from the original script).
VALIDATION_SIZE = 3_001

USAGE = (
    "Provide .tsv file name, images dir, output file name. "
    "e.g. python <this_script>.py captions.tsv /path/to/images output_prefix "
    "(writes output_prefix_train.json and output_prefix_val.json)"
)


def group_captions(df):
    """Group captions by image name.

    Args:
        df: DataFrame whose ``itertuples()`` rows unpack into
            ``(index, caption, image_name)``, i.e. exactly two columns with
            the caption first and the image name second.

    Returns:
        Dict mapping each image name to the list of its captions in row
        order; keys appear in first-seen order.
    """
    images_dict = {}
    for _index, caption, image_name in df.itertuples():
        images_dict.setdefault(image_name, []).append(caption)
    return images_dict


def build_json_lines(images_dict, images_dir):
    """Serialize one JSON record per image that exists under *images_dir*.

    Args:
        images_dict: Mapping of image file name to its list of captions.
        images_dir: Directory expected to contain the image files.

    Returns:
        List of JSON strings, each encoding
        ``{"image_path": <path>, "captions": [...]}``. Images whose file is
        not found are reported on stdout and skipped.
    """
    lines = []
    for image_name, captions in images_dict.items():
        # Use this entry's own key. The previous code reused a variable left
        # over from the grouping loop, so every record pointed at the last
        # image of the TSV.
        full_image_path = os.path.join(images_dir, image_name)
        if os.path.isfile(full_image_path):
            lines.append(
                json.dumps({"image_path": full_image_path, "captions": captions})
            )
        else:
            print(f"{full_image_path} doesn't exist")
    return lines


def split_train_valid(lines, valid_size=VALIDATION_SIZE):
    """Split *lines* into ``(train, valid)``, taking *valid_size* from the end.

    *valid_size* must be positive. When *lines* has no more than
    *valid_size* entries, everything lands in the validation part and the
    training part is empty.
    """
    return lines[:-valid_size], lines[-valid_size:]


def main(argv=None):
    """Convert a caption TSV into JSON-lines train/validation files.

    Args:
        argv: Argument vector shaped like ``sys.argv`` (program name followed
            by the TSV path, the images directory and the output prefix).
            Defaults to ``sys.argv``.

    Writes ``<output>_train.json`` and ``<output>_val.json``; exits with
    status 1 after printing a usage hint when the argument count is wrong.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) != 4:
        print(USAGE)
        sys.exit(1)
    annotation_file, images_dir, output_file = argv[1:]

    # The root logger defaults to WARNING, so without configuration the
    # info-level progress messages below would be silently dropped.
    logging.basicConfig(level=logging.INFO)
    logging.info("Processing Flicker 30k dataset")

    with contexttimer.Timer(prefix="Loading from tsv"):
        df = pd.read_csv(annotation_file, delimiter='\t')

    lines = build_json_lines(group_captions(df), images_dir)
    train_lines, valid_lines = split_train_valid(lines)

    with open(output_file + "_train.json", "w", encoding="utf-8") as f:
        f.write("\n".join(train_lines))
    with open(output_file + "_val.json", "w", encoding="utf-8") as f:
        f.write("\n".join(valid_lines))

    logging.info(
        "Processing Flicker 30k dataset done. %d images processed.", len(lines)
    )


if __name__ == "__main__":
    main()
|