"""Merge MMC4-style document shards with their image tar shards.

For every (doc_shard, image_shard) pair, each image referenced by a document
is read from the tar shard, base64-encoded, and embedded into the document's
``image_info`` entries; the enriched documents are written out as new
WebDataset shards via ``wds.ShardWriter``.
"""
import argparse
import base64
import json
import os
import tarfile
import time
import uuid
import zipfile

import braceexpand
import webdataset as wds
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map

arg_parser = argparse.ArgumentParser()
arg_parser.add_argument("--output_dir", type=str)
arg_parser.add_argument(
    "--image_shards",
    type=str,
    help="Pass in a list of shards in the format path_to_shard/shard_{0..23098}_images_v2.tar",
)
arg_parser.add_argument(
    "--doc_shards",
    type=str,
    help="Pass in a list of shards in the format path_to_shard/docs_shard_{0..23098}_v2.jsonl.zip",
)
arg_parser.add_argument(
    "--thread",
    type=int,
    default=128,
)


def get_txt_to_filename_dict(image_shards, disable_tqdm=False):
    """Build a lookup from an image's caption key to its WebDataset key.

    Args:
        image_shards: Path (or brace pattern) of the image tar shard(s).
        disable_tqdm: Suppress the progress bar (non-primary workers).

    Returns:
        dict mapping ``txt`` content up to the first '.' -> the sample's
        ``json['key']`` (the basename of the image file inside the tar).
    """
    txt_to_filename_dict = {}
    dataset = wds.WebDataset(image_shards).decode("pil").to_tuple("txt", "json")
    for data in tqdm(dataset, disable=disable_tqdm):
        # NOTE(review): keys are truncated at the first '.', so two captions
        # sharing that prefix would collide (last one wins) — confirm upstream
        # guarantees uniqueness.
        txt = data[0].split(".")[0]
        txt_to_filename_dict[txt] = data[1]["key"]
    return txt_to_filename_dict


def single_thread(task):
    """Process one worker's share of (doc_shard, image_shard) pairs.

    Args:
        task: dict with keys ``i`` (worker index; index 0 owns the progress
            output), ``output_dir`` (per-worker output directory), and the
            parallel lists ``doc_shards`` / ``image_shards``.

    Side effects:
        Writes WebDataset tar shards named ``%09d.tar`` under ``output_dir``;
        each sample is ``{"__key__": <uuid hex>, "json": <doc with
        image_base64 fields filled in>}``. Missing images get the literal
        string "null" and their key prefix is logged via ``tqdm.write``.
    """
    i = task["i"]
    output_dir = task["output_dir"]
    doc_shards = task["doc_shards"]
    image_shards = task["image_shards"]
    if i == 0:
        tqdm.write(f"output_dir: {output_dir}")
        tqdm.write(f"doc_shards: {doc_shards[:5]}")
        tqdm.write(f"image_shards: {image_shards[:5]}")
    with wds.ShardWriter(os.path.join(output_dir, "%09d.tar"), maxcount=1000) as sink:
        sink.verbose = False
        for doc_shard, image_shard in tqdm(
            zip(doc_shards, image_shards), disable=(i != 0), total=len(doc_shards)
        ):
            # Bug fix: these two resources were commented out, leaving
            # `txt_to_filename_dict` and `image_tar` undefined below
            # (NameError on the first image). The tar is now also closed
            # deterministically via the context manager.
            txt_to_filename_dict = get_txt_to_filename_dict(
                image_shard, disable_tqdm=(i != 0)
            )
            # Open the image tar and the ZIP archive holding the JSON docs.
            with tarfile.open(image_shard) as image_tar, zipfile.ZipFile(
                doc_shard, "r"
            ) as zip_file:
                # Assumes the JSON file is the first file in the archive
                json_filename = zip_file.namelist()[0]
                with zip_file.open(json_filename, "r") as json_file:
                    pbar = tqdm(json_file, disable=True)
                    total_num = 0
                    exist_num = 0
                    for sample_data in pbar:
                        # get image names from json
                        sample_data = json.loads(sample_data)
                        image_info = sample_data["image_info"]
                        image_names = [image["image_name"] for image in image_info]
                        # Embed each image as base64 into the sample JSON.
                        for img_idx, image_name in enumerate(image_names):
                            total_num += 1
                            try:
                                # KeyError: caption key missing from the dict or
                                # member missing from the tar; AttributeError:
                                # extractfile() returned None (non-file member);
                                # OSError: corrupt/unreadable tar data.
                                image = image_tar.extractfile(
                                    txt_to_filename_dict[image_name.split(".")[0]]
                                    + ".jpg"
                                )
                                image_bytes = image.read()
                                image_base64 = base64.b64encode(image_bytes).decode(
                                    "utf-8"
                                )
                                exist_num += 1
                            except (KeyError, OSError, AttributeError):
                                # Best-effort: log the missing key prefix and
                                # store a placeholder instead of crashing.
                                tqdm.write(f"{image_name.split('.')[0]}")
                                image_base64 = "null"
                            sample_data["image_info"][img_idx][
                                "image_base64"
                            ] = image_base64
                        key_str = uuid.uuid4().hex
                        sink.write({"__key__": key_str, "json": sample_data})
                        # Guard the hit-rate readout: the first sample may
                        # contain zero images (total_num == 0).
                        if total_num:
                            pbar.set_description(f"{exist_num/total_num:.2f}")


def main():
    """Expand the shard patterns, bucket them across workers, and run."""
    # Parsing here (not at import time) keeps the module importable — e.g. by
    # multiprocessing spawn workers when process_map is re-enabled.
    args = arg_parser.parse_args()
    timestamp = int(time.time())
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, str(timestamp)), exist_ok=True)
    tasks = []
    for i in range(args.thread):
        thread_dir = os.path.join(args.output_dir, str(timestamp), str(i))
        os.makedirs(thread_dir, exist_ok=True)
        tasks.append(
            {
                "i": i,
                "output_dir": thread_dir,
                "doc_shards": [],
                "image_shards": [],
            }
        )
    doc_shards = list(braceexpand.braceexpand(args.doc_shards))
    image_shards = list(braceexpand.braceexpand(args.image_shards))
    if len(doc_shards) != len(image_shards):
        # Explicit raise instead of assert: still enforced under `python -O`.
        raise ValueError("Each doc shards must have a corresponding image shard")
    # Round-robin the shard pairs across the worker buckets.
    for i, (doc_shard, image_shard) in enumerate(zip(doc_shards, image_shards)):
        tasks[i % args.thread]["doc_shards"].append(doc_shard)
        tasks[i % args.thread]["image_shards"].append(image_shard)
    # assert len(tasks) == args.thread
    # process_map(single_thread, tasks, max_workers=args.thread, disable=True)
    single_thread(tasks[0])


if __name__ == "__main__":
    main()