import argparse
import json
import os
import tarfile
import tempfile
from typing import Dict, List
from loguru import logger
from tqdm import tqdm
# fmt: off
# Help/description text is hoisted into module constants so each flag
# definition below stays short; the strings themselves are unchanged.
_DESCRIPTION = """Pre-process RedCaps dataset for training VirTex models - make
    small shards of TAR files containing images and captions."""
_HELP_IMAGES = """Path to RedCaps image directory. This directory is expected to have
    subreddit specific sub-directories containing images."""
_HELP_OUTPUT = (
    "Path prefix for saving TAR file shards. For example, `/tmp/tarfiles` "
    "will save as `/tmp/tarfiles_000000.tar`, `/tmp/tarfiles_000001.tar`, ..."
)

parser = argparse.ArgumentParser(description=_DESCRIPTION)
parser.add_argument(
    "-a",
    "--annotations",
    required=True,
    help="Path to a RedCaps annotation file.",
)
parser.add_argument(
    "-i",
    "--images",
    default="datasets/redcaps/images",
    help=_HELP_IMAGES,
)
parser.add_argument(
    "-z",
    "--shard-size",
    type=int,
    default=1000,
    help="Maximum number of RedCaps instances in a single TAR file shard.",
)
parser.add_argument(
    "-o",
    "--output-prefix",
    required=True,
    help=_HELP_OUTPUT,
)
# fmt: on
def main(_A: argparse.Namespace):
    r"""
    Make TAR files containing images and annotations from a single RedCaps
    annotations file. These TAR files are arranged in a way that
    `WebDataset <https://github.com/tmbdev/webdataset>`_ can understand.

    For every annotation whose image exists under ``_A.images``, two members
    are added to the current shard: ``<image_id>.jpg`` and ``<image_id>.json``
    (subreddit name + caption). A new shard is started every ``_A.shard_size``
    instances. Annotations whose image is missing on disk are skipped and
    counted.
    """
    # Use a context manager so the annotations file handle is closed promptly
    # (the original `json.load(open(...))` leaked it).
    with open(_A.annotations) as annotations_file:
        ANNOTATIONS: List[Dict] = json.load(annotations_file)["annotations"]

    # Keep track of the current index of TAR file shard and dataset index.
    SHARD_INDEX: int = 0
    DATASET_INDEX: int = 0

    # Create TAR file handle for the initial shard.
    # BUG FIX: the original used `{SHARD_INDEX:0>d}` (fill/align but no width),
    # so the first shard was written as `prefix_0.tar` while every later shard
    # — and every log message — used the zero-padded `prefix_000000.tar` form.
    tar_handle = tarfile.open(f"{_A.output_prefix}_{SHARD_INDEX:0>6d}.tar", "w")

    # Keep a count of submissions that were skipped because their image was
    # not downloaded (not present in image dir).
    SKIPPED: int = 0

    for ann in tqdm(ANNOTATIONS):
        image_path = os.path.join(
            _A.images, ann["subreddit"], f"{ann['image_id']}.jpg"
        )
        # Add current image in shard if it exists.
        if os.path.exists(image_path):
            tar_handle.add(image_path, arcname=f"{ann['image_id']}.jpg")

            # Save subreddit name and caption as a JSON file.
            subreddit_and_caption = {
                "subreddit": ann["subreddit"], "caption": ann["caption"]
            }
            # The temp file is added to the TAR by *name*, so flush buffered
            # writes (seek(0) flushes) before tarfile reads it from disk; the
            # context manager guarantees cleanup even if `add` raises.
            with tempfile.NamedTemporaryFile("w+") as tmpfile:
                tmpfile.write(json.dumps(subreddit_and_caption))
                tmpfile.seek(0)
                tar_handle.add(tmpfile.name, arcname=f"{ann['image_id']}.json")

            DATASET_INDEX += 1

            # Create new shard if current shard is full.
            # NOTE(review): if the dataset size is an exact multiple of
            # shard_size, a final empty shard is created — preserved from the
            # original behavior.
            if DATASET_INDEX % _A.shard_size == 0 and DATASET_INDEX > 0:
                tar_handle.close()
                logger.success(
                    f"Saved shard: {_A.output_prefix}_{SHARD_INDEX:0>6d}.tar"
                )
                SHARD_INDEX += 1

                # Open new TAR file shard.
                tar_handle = tarfile.open(
                    f"{_A.output_prefix}_{SHARD_INDEX:0>6d}.tar", "w"
                )
        else:
            SKIPPED += 1

    # Close the file handle to properly save it.
    tar_handle.close()
    logger.success(f"Saved shard: {_A.output_prefix}_{SHARD_INDEX:0>6d}.tar\n")
    logger.info(f"Skipped {SKIPPED} annotations due to missing images.")
if __name__ == "__main__":
    # Parse CLI flags and run the shard-building pipeline.
    main(parser.parse_args())