Galuh Sahid
Add download_logs and scripts
ba7a003
raw
history blame
No virus
1.42 kB
import pandas as pd
import os.path
import sys
import json
import logging
import contexttimer
# Setup: log to download.log, truncating any log left by a previous run.
logging.basicConfig(filename='download.log', filemode='w', level=logging.INFO)

# Exactly three positional arguments are required (sys.argv[0] is the script).
if len(sys.argv) != 4:
    # Usage errors go to stderr; sys.exit is used because the bare `exit`
    # builtin is injected by the `site` module and is not always available.
    print(
        "Provide .tsv file name, images dir, output file name. "
        f"e.g. python {os.path.basename(sys.argv[0])} "
        "cc12m.tsv /path/to/images cc12m_dataset.json",
        file=sys.stderr,
    )
    sys.exit(1)

# annotation_file: input TSV; images_dir: directory holding downloaded
# images; output_file: JSON Lines file this script writes.
annotation_file, images_dir, output_file = sys.argv[1:4]
logging.info("Processing cc12m dataset")
with contexttimer.Timer(prefix="Loading from tsv"):
    # Read fields verbatim as strings. With pandas' default NA handling,
    # empty captions and captions such as "NA", "null" or "None" become
    # float NaN, which json.dumps would write as the non-standard token NaN
    # (invalid JSON); a NaN url would also crash os.path.basename below.
    df = pd.read_csv(
        annotation_file,
        delimiter='\t',
        usecols=["caption", "url"],
        dtype=str,
        keep_default_na=False,
    )
# Fix the column order so each row unpacks as (index, caption, url).
df = df[["caption", "url"]]
print(f"Loaded {len(df)} images.")

lines = []
for index, caption, image_url in df.itertuples():
    # Image files are expected to be named "<8-digit 1-based row>---<url stem>.jpg".
    # NOTE(review): presumably this mirrors the downloader's naming scheme
    # (including always using .jpg regardless of the URL's extension) — confirm.
    row_number = index + 1
    stem = os.path.splitext(os.path.basename(image_url))[0]
    filename = f'{row_number:08d}---{stem}.jpg'
    full_image_path = os.path.join(images_dir, filename)
    if os.path.isfile(full_image_path):
        lines.append(json.dumps({"image_path": full_image_path, "captions": [caption]}))
    else:
        # Images that were not downloaded are recorded in the log and skipped.
        logging.error(full_image_path)

# One JSON object per line (JSON Lines); only rows whose image exists are written.
with open(output_file, "w", encoding="utf-8") as f:
    f.write("\n".join(lines))
logging.info(f"Processing cc12m dataset done. {len(lines)} images processed.")