# multimodal/open_flamingo/tools/convert_mmc4_to_wds.py
# Last commit 5282eae ("init") by Li; original file size 4.83 kB.
import argparse
import base64
import json
import os
import tarfile
import uuid
import zipfile
import time
import braceexpand
import webdataset as wds
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map
# Command-line interface. Parsed at import time; main() and the workers read
# the resulting module-level `args`.
arg_parser = argparse.ArgumentParser(
    description=(
        "Convert MMC4 doc shards (zipped JSONL) and their paired image tar "
        "shards into WebDataset tars with base64-inlined images."
    )
)
# The three path arguments are required: without them the script would
# otherwise fail later with an opaque TypeError (os.makedirs(None) /
# braceexpand(None)) instead of a clear usage error.
arg_parser.add_argument(
    "--output_dir",
    type=str,
    required=True,
    help="Root output directory; tars are written under <output_dir>/<timestamp>/<worker index>/",
)
arg_parser.add_argument(
    "--image_shards",
    type=str,
    required=True,
    help="Pass in a list of shards in the format path_to_shard/shard_{0..23098}_images_v2.tar",
)
arg_parser.add_argument(
    "--doc_shards",
    type=str,
    required=True,
    help="Pass in a list of shards in the format path_to_shard/docs_shard_{0..23098}_v2.jsonl.zip",
)
arg_parser.add_argument(
    "--thread",
    type=int,
    default=128,
    help="Number of worker partitions the doc/image shard pairs are split into (round-robin)",
)
args = arg_parser.parse_args()
def get_txt_to_filename_dict(image_shards, disable_tqdm=False):
    """Build a lookup from image-name stem to WebDataset sample key.

    Streams ``image_shards`` through webdataset (decoded with the "pil"
    handler) and pairs each sample's ``txt`` entry with its ``json`` entry.
    The ``txt`` value is truncated at its first "." to form the dict key
    (presumably the image file name -- confirm against the shard layout), and
    the value is the ``json`` entry's ``"key"`` field. When two samples share
    a stem, the later one wins.

    Args:
        image_shards: shard path/URL(s) accepted by ``wds.WebDataset``.
        disable_tqdm: suppress the progress bar when True.

    Returns:
        dict mapping name stem -> sample key.
    """
    samples = wds.WebDataset(image_shards).decode("pil").to_tuple("txt", "json")
    return {
        name.split(".")[0]: meta["key"]
        for name, meta in tqdm(samples, disable=disable_tqdm)
    }
def _lookup_image_base64(image_tar, txt_to_filename_dict, image_name):
    """Return the base64 text of ``image_name``'s JPEG in ``image_tar``, or None if absent.

    The name is resolved via ``txt_to_filename_dict`` (keyed on the part of the
    name before the first ".") to a member called ``<key>.jpg``.
    """
    stem = image_name.split(".")[0]
    try:
        image_file = image_tar.extractfile(txt_to_filename_dict[stem] + ".jpg")
    except KeyError:
        # Either the stem is not in the index or the member is not in the tar.
        return None
    if image_file is None:
        # extractfile returns None for non-regular members (e.g. directories).
        return None
    with image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")


def _convert_doc_shard(doc_shard, image_tar, txt_to_filename_dict, sink):
    """Write every document of the zipped JSONL ``doc_shard`` to ``sink`` with images inlined.

    Each entry of a document's ``image_info`` gets an ``"image_base64"`` field,
    set to the string "null" when the image is missing from ``image_tar``.

    Returns:
        (exist_num, total_num): images found vs. images referenced.
    """
    exist_num = 0
    total_num = 0
    with zipfile.ZipFile(doc_shard, "r") as zip_file:
        # Assumes the JSONL file is the first member of the archive.
        json_filename = zip_file.namelist()[0]
        with zip_file.open(json_filename, "r") as json_file:
            for line in json_file:
                sample_data = json.loads(line)
                for image in sample_data["image_info"]:
                    total_num += 1
                    image_name = image["image_name"]
                    image_base64 = _lookup_image_base64(image_tar, txt_to_filename_dict, image_name)
                    if image_base64 is None:
                        tqdm.write(f"{image_name.split('.')[0]}")
                        image_base64 = "null"
                    else:
                        exist_num += 1
                    image["image_base64"] = image_base64
                sink.write({"__key__": uuid.uuid4().hex, "json": sample_data})
    return exist_num, total_num


def single_thread(args):
    """Convert one worker's share of doc/image shard pairs into WebDataset tars.

    Args:
        args: dict with keys
            "i": worker index; only worker 0 prints config and progress.
            "output_dir": directory receiving the "%09d.tar" output shards
                (at most 1000 samples each).
            "doc_shards": list of zipped JSONL doc shard paths.
            "image_shards": list of image tar shard paths, paired positionally
                with "doc_shards".

    Every JSONL line becomes one output sample keyed by a random uuid hex.
    """
    i = args["i"]
    output_dir = args["output_dir"]
    doc_shards = args["doc_shards"]
    image_shards = args["image_shards"]
    show_progress = i == 0
    if show_progress:
        tqdm.write(f"output_dir: {output_dir}")
        tqdm.write(f"doc_shards: {doc_shards[:5]}")
        tqdm.write(f"image_shards: {image_shards[:5]}")
    with wds.ShardWriter(os.path.join(output_dir, "%09d.tar"), maxcount=1000) as sink:
        sink.verbose = False
        for doc_shard, image_shard in tqdm(
            zip(doc_shards, image_shards), disable=not show_progress, total=len(doc_shards)
        ):
            # Index the image shard by name stem so each document's image_name
            # can be resolved to a tar member (previously commented out, which
            # made every lookup fail and every image come out as "null").
            txt_to_filename_dict = get_txt_to_filename_dict(image_shard, disable_tqdm=not show_progress)
            with tarfile.open(image_shard) as image_tar:
                exist_num, total_num = _convert_doc_shard(
                    doc_shard, image_tar, txt_to_filename_dict, sink
                )
            # Guard against shards whose documents reference no images.
            if show_progress and total_num:
                tqdm.write(
                    f"{doc_shard}: {exist_num}/{total_num} images found ({exist_num / total_num:.2f})"
                )
def main():
    """Split the doc/image shard pairs across workers and convert them in parallel.

    Output goes to ``<output_dir>/<unix timestamp>/<worker index>/``. Shard
    pairs are assigned round-robin: worker ``i`` gets pairs
    ``i, i + thread, i + 2 * thread, ...``.

    Raises:
        ValueError: if ``--thread`` is not positive, or if the expanded doc and
            image shard lists differ in length.
    """
    if args.thread < 1:
        raise ValueError(f"--thread must be >= 1, got {args.thread}")
    doc_shards = list(braceexpand.braceexpand(args.doc_shards))
    image_shards = list(braceexpand.braceexpand(args.image_shards))
    # Explicit check instead of assert so it is not stripped under `python -O`.
    if len(doc_shards) != len(image_shards):
        raise ValueError(
            "Each doc shards must have a corresponding image shard"
            f" (got {len(doc_shards)} doc shards and {len(image_shards)} image shards)"
        )
    run_dir = os.path.join(args.output_dir, str(int(time.time())))
    os.makedirs(run_dir, exist_ok=True)
    tasks = []
    for i in range(args.thread):
        worker_doc_shards = doc_shards[i::args.thread]
        if not worker_doc_shards:
            # More workers than shard pairs: every later slice is empty too,
            # so stop instead of launching workers with nothing to convert.
            break
        thread_dir = os.path.join(run_dir, str(i))
        os.makedirs(thread_dir, exist_ok=True)
        tasks.append({
            "i": i,
            "output_dir": thread_dir,
            "doc_shards": worker_doc_shards,
            "image_shards": image_shards[i::args.thread],
        })
    if not tasks:
        return
    # Run all workers; the previous `single_thread(tasks[0])` debug call
    # converted only the first worker's share of the shards.
    process_map(single_thread, tasks, max_workers=len(tasks), disable=True)


if __name__ == "__main__":
    main()