File size: 1,365 Bytes
ba7a003
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import pandas as pd
import os.path
import sys
import json
import logging
import contexttimer

# Validate the command line: exactly three positional arguments are required
# (annotations .tsv file, images directory, output file name prefix).
if len(sys.argv) != 4:
    # Usage errors go to stderr so they never mix with regular stdout output.
    print(
        "Provide .tsv file name, images dir, output file name prefix. "
        f"e.g. python {sys.argv[0]} annotations.tsv /path/to/images output_prefix",
        file=sys.stderr,
    )
    # sys.exit is used instead of the `exit` helper, which is injected by the
    # `site` module for interactive sessions and is not guaranteed to exist.
    sys.exit(1)

# Tab-separated annotations file that is loaded with pandas further down.
annotation_file = sys.argv[1]
# Directory that is prepended to every image name when checking for the file.
images_dir = sys.argv[2]
# Output prefix: "_train.json" and "_val.json" are appended to it when writing.
output_file = sys.argv[3]

# Configure the root logger before the first logging call. Without this the
# root logger stays at its default WARNING level and every logging.info()
# message in this script is silently dropped. basicConfig() is a no-op when
# handlers are already installed, so it has to run before the first log call.
logging.basicConfig(level=logging.INFO)

logging.info("Processing Flicker 30k dataset")

# Load the tab-separated annotations file into a DataFrame, timing the load.
# NOTE(review): whether contexttimer.Timer reports the elapsed time without an
# explicit `output=` argument should be confirmed against its documentation.
with contexttimer.Timer(prefix="Loading from tsv"):
    df = pd.read_csv(annotation_file, delimiter='\t')

# Map each image name to the list of its captions, in file order.
images_dict = {}

# Each tuple from itertuples() is unpacked as (row index, caption, image name),
# so the DataFrame is expected to have exactly two columns in that order.
for _index, caption, image_name in df.itertuples():
    # Create the caption list on first sight of an image, then extend it.
    images_dict.setdefault(image_name, []).append(caption)

# Serialize one JSON record per image whose file exists under images_dir.
lines = []

for image_path, captions in images_dict.items():
    # Build the path from this iteration's key. Using any other variable here
    # (e.g. one left over from an earlier loop) would give every record the
    # same path. os.path.join also avoids a doubled separator when images_dir
    # already ends with one.
    full_image_path = os.path.join(images_dir, image_path)
    if os.path.isfile(full_image_path):
        lines.append(json.dumps({"image_path": full_image_path, "captions": captions}))
    else:
        # Missing images are reported and skipped rather than aborting the run.
        print(f"{full_image_path} doesn't exist")

# Hold out the trailing 3,001 records for validation; everything before them
# forms the training split. If there are 3,001 records or fewer, the training
# split is empty and all records land in the validation split.
holdout_size = 3_001
train_lines, valid_lines = lines[:-holdout_size], lines[-holdout_size:]

# Write each split as newline-separated JSON records (no trailing newline),
# training file first, then validation file.
for suffix, split_lines in (("_train.json", train_lines), ("_val.json", valid_lines)):
    with open(output_file + suffix, "w") as f:
        f.write("\n".join(split_lines))

# Only emitted when the root logger is configured to show INFO-level messages.
logging.info(f"Processing Flicker 30k dataset done. {len(lines)} images processed.")