import webdataset as wds
import glob
import os
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map
import pickle as pkl

# Inputs/outputs for the index build (values unchanged from the original script).
IMAGE_SHARD_GLOB = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/mmc4/images/*.tar"
MAX_SHARDS = 16000
NUM_WORKERS = 64
OUTPUT_PATH = "mmc4_id_table.pkl"


def single_thread(filename):
    """Index one WebDataset shard by image id.

    Every sample's ``json`` field is decoded. The image id is the part of
    its ``"caption"`` value before the first ``"."``, and the sample's
    location is recorded as the shard's basename plus the ``"key"`` value.

    Args:
        filename: Path to a ``.tar`` WebDataset shard.

    Returns:
        dict mapping image id -> ``[tar basename, sample key]``. When an id
        occurs more than once in the shard, the first occurrence is kept.
    """
    # The shard name is identical for every sample, so compute it once.
    tar_name = os.path.basename(filename)
    id_table = {}
    dataset = wds.WebDataset(filename).decode().to_tuple("json")
    for sample in dataset:
        data = sample[0]
        image_id = data["caption"].split(".")[0]
        image_key = data["key"]
        if image_id not in id_table:
            id_table[image_id] = [tar_name, image_key]
    return id_table


if __name__ == "__main__":
    filenames = sorted(glob.glob(IMAGE_SHARD_GLOB))[:MAX_SHARDS]
    if not filenames:
        # Fail with a readable message instead of an IndexError on filenames[0].
        raise SystemExit(f"no shards matched {IMAGE_SHARD_GLOB}")
    print("start from", filenames[0])
    print("to", filenames[-1])
    id_tables = process_map(single_thread, filenames, max_workers=NUM_WORKERS)

    # Apply the per-shard tables in reverse shard order so that, for an id
    # present in several shards, the EARLIEST shard wins. This matches the
    # first-occurrence policy that single_thread uses within a shard, while
    # keeping the C-level dict.update for speed.
    id_table = {}
    for table in tqdm(reversed(id_tables), total=len(id_tables)):
        id_table.update(table)
    print("total unique image:", len(id_table))

    # Context manager guarantees the pickle is flushed and the handle closed.
    with open(OUTPUT_PATH, "wb") as f:
        pkl.dump(id_table, f)
    print("DONE")