virtex-redcaps / virtex /scripts /preprocess /preprocess_redcaps.py
zamborg's picture
added datasets and virtex
a5f8a35
raw
history blame
3.43 kB
import argparse
import json
import os
import tarfile
import tempfile
from typing import Dict, List
from loguru import logger
from tqdm import tqdm
# fmt: off
# Command-line interface of the script. The option definitions are kept in a
# hand-formatted layout, hence the surrounding ``fmt: off`` / ``fmt: on`` pair.
parser = argparse.ArgumentParser(
    description="""Pre-process RedCaps dataset for training VirTex models - make
small shards of TAR files containing images and captions."""
)

# Input: annotation file (required) and the directory holding the images.
parser.add_argument(
    "-a",
    "--annotations",
    help="Path to a RedCaps annotation file.",
    required=True,
)
parser.add_argument(
    "-i",
    "--images",
    help="""Path to RedCaps image directory. This directory is expected to have
subreddit specific sub-directories containing images.""",
    default="datasets/redcaps/images",
)

# Output: how many instances go in one shard, and where the shards are saved.
parser.add_argument(
    "-z",
    "--shard-size",
    help="Maximum number of RedCaps instances in a single TAR file shard.",
    default=1000,
    type=int,
)
parser.add_argument(
    "-o",
    "--output-prefix",
    help=(
        "Path prefix for saving TAR file shards. For example, `/tmp/tarfiles` "
        "will save as `/tmp/tarfiles_000000.tar`, `/tmp/tarfiles_000001.tar`, ..."
    ),
    required=True,
)
# fmt: on
def main(_A: argparse.Namespace) -> None:
    r"""
    Make TAR files containing images and annotations from a single RedCaps
    annotations file. These TAR files are arranged in a way that
    `WebDataset <https://github.com/tmbdev/webdataset>`_ can understand.

    Every annotation whose image is found on disk contributes two members to
    the current shard, both named after the image ID: ``<image_id>.jpg`` (the
    image) and ``<image_id>.json`` (a JSON object with ``subreddit`` and
    ``caption`` keys). Annotations whose image file is missing are counted
    and skipped.

    Shards are named ``<output_prefix>_<index>.tar`` where the index is
    zero-padded to six digits, starting at ``000000``. A shard is only created
    once there is an instance to put in it, so no empty TAR file is written.

    Args:
        _A: Parsed command-line arguments. The attributes read here are
            ``annotations`` (path to the annotations JSON file), ``images``
            (image directory with one sub-directory per subreddit),
            ``shard_size`` (maximum instances per shard) and ``output_prefix``
            (path prefix of the saved shards).
    """

    # Close the annotations file as soon as its contents are parsed.
    with open(_A.annotations) as annotations_file:
        annotations: List[Dict] = json.load(annotations_file)["annotations"]

    # Keep track of the current index of TAR file shard and dataset index.
    shard_index: int = 0
    dataset_index: int = 0

    # Keep a count of submissions that were skipped because their image was
    # not downloaded (not present in image dir).
    skipped: int = 0

    # Handle of the shard being written. It is opened lazily, right before the
    # first instance is added, so that a run ending exactly on a shard
    # boundary (or having no usable instances) leaves no empty shard behind.
    tar_handle = None

    for ann in tqdm(annotations):
        image_path = os.path.join(
            _A.images, ann["subreddit"], f"{ann['image_id']}.jpg"
        )

        # Only instances with a downloaded image go into the shards.
        if not os.path.exists(image_path):
            skipped += 1
            continue

        if tar_handle is None:
            # Same six-digit zero padding for every shard, including the
            # first one, so file names match the logged names.
            tar_handle = tarfile.open(
                f"{_A.output_prefix}_{shard_index:0>6d}.tar", "w"
            )

        tar_handle.add(image_path, arcname=f"{ann['image_id']}.jpg")

        # Save subreddit name and caption as a JSON file.
        subreddit_and_caption = {
            "subreddit": ann["subreddit"], "caption": ann["caption"]
        }
        with tempfile.NamedTemporaryFile("w+") as tmpfile:
            tmpfile.write(json.dumps(subreddit_and_caption))
            # `add` reads the file by name, so the buffered text must be on
            # disk before it is called.
            tmpfile.flush()
            tar_handle.add(tmpfile.name, arcname=f"{ann['image_id']}.json")

        dataset_index += 1

        # Finish the current shard once it is full; the next one is opened
        # when (and if) another instance arrives.
        if dataset_index % _A.shard_size == 0:
            tar_handle.close()
            tar_handle = None
            logger.success(
                f"Saved shard: {_A.output_prefix}_{shard_index:0>6d}.tar"
            )
            shard_index += 1

    # Close the last (partially filled) shard to properly save it.
    if tar_handle is not None:
        tar_handle.close()
        logger.success(
            f"Saved shard: {_A.output_prefix}_{shard_index:0>6d}.tar\n"
        )

    logger.info(f"Skipped {skipped} annotations due to missing images.")
if __name__ == "__main__":
    # Script entry point: parse the command-line arguments and hand the
    # resulting namespace straight to `main`.
    main(parser.parse_args())