# Pre-process RedCaps annotations and downloaded images into small
# WebDataset-style TAR file shards for training VirTex models.
import argparse | |
import json | |
import os | |
import tarfile | |
import tempfile | |
from typing import Dict, List | |
from loguru import logger | |
from tqdm import tqdm | |
# fmt: off
# Command-line interface for this script; see `main` for the actual work.
parser = argparse.ArgumentParser(
    description="""Pre-process RedCaps dataset for training VirTex models - make
    small shards of TAR files containing images and captions."""
)
parser.add_argument(
    "-a", "--annotations", required=True, help="Path to a RedCaps annotation file."
)
parser.add_argument(
    "-i", "--images", default="datasets/redcaps/images",
    help="""Path to RedCaps image directory. This directory is expected to have
    subreddit specific sub-directories containing images.""",
)
parser.add_argument(
    "-z", "--shard-size", type=int, default=1000,
    help="Maximum number of RedCaps instances in a single TAR file shard.",
)
parser.add_argument(
    "-o", "--output-prefix", required=True,
    help="Path prefix for saving TAR file shards. For example, `/tmp/tarfiles` "
    "will save as `/tmp/tarfiles_000000.tar`, `/tmp/tarfiles_000001.tar`, ...",
)
# fmt: on
def main(_A: argparse.Namespace):
    r"""
    Make TAR files containing images and annotations from a single RedCaps
    annotations file. These TAR files are arranged in a way that
    `WebDataset <https://github.com/tmbdev/webdataset>`_ can understand.

    Args:
        _A: Parsed command-line arguments (see module-level ``parser``):
            ``annotations``, ``images``, ``shard_size``, ``output_prefix``.
    """
    # Use a context manager so the annotations file handle is not leaked.
    with open(_A.annotations) as f:
        ANNOTATIONS: List[Dict] = json.load(f)["annotations"]

    # Keep track of the current index of TAR file shard and dataset index.
    SHARD_INDEX: int = 0
    DATASET_INDEX: int = 0

    # Create TAR file handle for the initial shard. Zero-pad the shard index
    # to six digits so the first shard's filename matches every subsequent
    # shard (the previous spec `{SHARD_INDEX:0>d}` had no width and produced
    # `<prefix>_0.tar` instead of `<prefix>_000000.tar`).
    tar_handle = tarfile.open(f"{_A.output_prefix}_{SHARD_INDEX:0>6d}.tar", "w")

    # Keep a count of submissions that were skipped because their image was
    # not downloaded (not present in image dir).
    SKIPPED: int = 0

    for ann in tqdm(ANNOTATIONS):
        image_path = os.path.join(
            _A.images, ann["subreddit"], f"{ann['image_id']}.jpg"
        )
        # Add current image in shard if it exists.
        if os.path.exists(image_path):
            tar_handle.add(image_path, arcname=f"{ann['image_id']}.jpg")

            # Save subreddit name and caption as a JSON file. The arcname
            # shares the image's stem so WebDataset groups the pair.
            subreddit_and_caption = {
                "subreddit": ann["subreddit"], "caption": ann["caption"]
            }
            # NOTE(review): re-opening a NamedTemporaryFile by name is
            # POSIX-only behavior (fails on Windows before Python 3.12).
            with tempfile.NamedTemporaryFile("w+") as tmpfile:
                tmpfile.write(json.dumps(subreddit_and_caption))
                # Flush buffered text so `tar_handle.add` sees the full file.
                tmpfile.flush()
                tar_handle.add(tmpfile.name, arcname=f"{ann['image_id']}.json")

            DATASET_INDEX += 1
            # Create new shard if current shard is full.
            if DATASET_INDEX % _A.shard_size == 0 and DATASET_INDEX > 0:
                tar_handle.close()
                logger.success(
                    f"Saved shard: {_A.output_prefix}_{SHARD_INDEX:0>6d}.tar"
                )
                SHARD_INDEX += 1

                # Open new TAR file shard.
                tar_handle = tarfile.open(
                    f"{_A.output_prefix}_{SHARD_INDEX:0>6d}.tar", "w"
                )
        else:
            SKIPPED += 1

    # Close the file handle to properly save it.
    tar_handle.close()
    logger.success(f"Saved shard: {_A.output_prefix}_{SHARD_INDEX:0>6d}.tar\n")
    logger.info(f"Skipped {SKIPPED} annotations due to missing images.")
if __name__ == "__main__":
    # Parse CLI arguments and run the shard-building pipeline.
    _A = parser.parse_args()
    main(_A)