# Source: multimodal/open_flamingo/tools/prepare_pile.py
# Last commit: 5282eae ("init") by Li, file size 1.12 kB
import datasets
import os
from tqdm import tqdm
import webdataset as wds
import json
DATASET_ROOT = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/the_pile/all/train"
OUT_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/the_pile"
# Maximum number of samples written into one output tar shard.
SAMPLE_PER_SHARD = 100000
# Number of rows fetched per `Dataset.iter` batch.
BATCH_SIZE = 4096
# Values of meta["pile_set_name"] whose rows are dropped from the export.
EXCLUDED_SETS = frozenset({"Github"})


def _parse_args(argv=None):
    """Parse optional CLI overrides; the defaults reproduce the hard-coded run."""
    import argparse  # local import: only needed when run as a script

    parser = argparse.ArgumentParser(
        description="Convert an on-disk HuggingFace copy of the Pile into "
        "WebDataset tar shards."
    )
    parser.add_argument("--dataset-root", default=DATASET_ROOT)
    parser.add_argument("--out-dir", default=OUT_DIR)
    parser.add_argument("--samples-per-shard", type=int, default=SAMPLE_PER_SHARD)
    parser.add_argument("--batch-size", type=int, default=BATCH_SIZE)
    return parser.parse_args(argv)


def main(
    dataset_root=DATASET_ROOT,
    out_dir=OUT_DIR,
    samples_per_shard=SAMPLE_PER_SHARD,
    batch_size=BATCH_SIZE,
    excluded_sets=EXCLUDED_SETS,
):
    """Export a ``datasets.load_from_disk`` copy of the Pile as WebDataset shards.

    Every kept row becomes one sample holding a ``txt`` entry (the row's text,
    UTF-8 encoded) and a ``json`` entry (the row's ``meta`` dict serialized with
    ``indent=4``). Samples are keyed by a running 1-based counter. Rows whose
    ``meta["pile_set_name"]`` is in *excluded_sets* are skipped.

    Args:
        dataset_root: Directory passed to ``datasets.load_from_disk``; the
            dataset must have ``text`` and ``meta`` columns.
        out_dir: Output directory; created if missing. Shards are named
            ``00000.tar``, ``00001.tar``, ... inside it.
        samples_per_shard: Maximum samples per shard (``ShardWriter`` maxcount).
        batch_size: Rows fetched per ``Dataset.iter`` call.
        excluded_sets: ``pile_set_name`` values to drop.

    Returns:
        The number of samples written.
    """
    # exist_ok: a re-run (e.g. after a crash) must not die on an existing
    # directory. Shards with the same names are overwritten by the new run.
    os.makedirs(out_dir, exist_ok=True)
    print("load dataset...")
    pile = datasets.load_from_disk(dataset_root)
    total_num = pile.num_rows
    print("total num:", total_num)

    num = 0
    shard_pattern = os.path.join(out_dir, "%05d.tar")
    # encoder=False: fields are not encoded by webdataset, so every value
    # below is converted to bytes explicitly.
    with tqdm(total=total_num) as pbar, wds.ShardWriter(
        shard_pattern, maxcount=samples_per_shard, encoder=False
    ) as sink:
        for batch in pile.iter(batch_size):
            texts, metas = batch["text"], batch["meta"]
            # Progress counts every row read, including the skipped ones.
            pbar.update(len(texts))
            for text, meta in zip(texts, metas):
                if meta.get("pile_set_name") in excluded_sets:
                    continue
                num += 1
                sink.write({
                    "__key__": str(num),
                    "txt": text.encode("utf-8"),
                    "json": json.dumps(meta, indent=4).encode("utf-8"),
                })
    print(f"{num} out of {total_num} is written")
    return num


if __name__ == "__main__":
    _args = _parse_args()
    main(
        dataset_root=_args.dataset_root,
        out_dir=_args.out_dir,
        samples_per_shard=_args.samples_per_shard,
        batch_size=_args.batch_size,
    )