nbroad HF staff committed on
Commit
43f37f9
·
1 Parent(s): 6f9442d

fix download wikipedia

Browse files

Specify the number of examples to embed and the number to skip.

Files changed (1) hide show
  1. utils.py +25 -9
utils.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  import time
3
  import shutil
4
  from pathlib import Path
@@ -107,33 +108,48 @@ def load_hf_dataset(ds_name: str, ds_config: str = None, ds_split: str = "train"
107
  if ds_config == "":
108
  ds_config = None
109
 
110
- ds = load_dataset(ds_name, ds_config, split=ds_split, )
111
- #streaming=True)
 
 
 
 
 
 
112
 
113
  return ds
114
 
115
- def download_wikipedia(ds_name, ds_config):
116
  ds = load_dataset(ds_name, ds_config, streaming=True, split="train")
117
 
118
  def gen():
119
- for example in ds:
120
- yield {"text": example["text"]}
 
 
 
 
 
121
 
122
  ds2 = Dataset.from_generator(gen)
123
 
124
- chunk_size = 200_000
125
 
126
  filenames = []
127
 
128
- Path("wiki_chunks").mkdir(exist_ok=True)
 
 
 
 
129
 
130
  for chunk_num, start_idx in enumerate(range(0, len(ds2), chunk_size)):
131
  end_idx = min(start_idx + chunk_size, len(ds2))
132
 
133
  temp = ds2.select(range(start_idx, end_idx))
134
 
135
- temp.to_parquet(f"/data/wiki_chunks/chunk_{chunk_num}")
136
- filenames.append(f"/data/wiki_chunks/chunk_{chunk_num}")
137
 
138
  return load_dataset("parquet", data_files=filenames, split="train")
139
 
 
1
  import os
2
+ import re
3
  import time
4
  import shutil
5
  from pathlib import Path
 
108
  if ds_config == "":
109
  ds_config = None
110
 
111
+ if ds_name == "wikipedia":
112
+ pattern = re.compile(r"[^a-zA-Z0-9]")
113
+ folder = Path("/data") / pattern.sub("", ds_name+ds_config)
114
+ files = list(map(str, folder.glob("chunk_")))
115
+
116
+ return load_dataset("parquet", data_files=files, split="train")
117
+
118
+ ds = load_dataset(ds_name, ds_config, split=ds_split)
119
 
120
  return ds
121
 
122
def download_wikipedia(ds_name, ds_config, num2skip=0, num2embed=-1):
    """Stream a dataset, cache it to /data as parquet chunks, and reload it.

    Streams the ``train`` split of ``ds_name``/``ds_config``, skips the first
    ``num2skip`` examples, and keeps ``num2embed`` of the rest (all remaining
    examples when ``num2embed <= 0``). Only the ``"text"`` column is kept.
    The materialized data is written to disk in 20k-row parquet chunks and
    returned as a single dataset loaded from those chunk files.

    Args:
        ds_name: Hugging Face dataset name (e.g. ``"wikipedia"``).
        ds_config: dataset config name; may be ``None``.
        num2skip: number of leading examples to skip (default 0).
        num2embed: number of examples to keep after skipping; values <= 0
            mean "keep everything" (default -1).

    Returns:
        A dataset backed by the parquet chunk files on disk.
    """
    ds = load_dataset(ds_name, ds_config, streaming=True, split="train")

    def gen():
        # Apply skip/take lazily on the streaming dataset so the full
        # corpus is never held in memory at once.
        stream = ds.skip(num2skip)
        if num2embed > 0:
            stream = stream.take(num2embed)
        for example in stream:
            yield {"text": example["text"]}

    ds2 = Dataset.from_generator(gen)

    chunk_size = 20_000

    # Folder name: dataset name + config with non-alphanumerics stripped.
    # ``ds_config`` may legitimately be None (callers normalize "" -> None),
    # so fall back to "" before concatenating — the previous
    # ``ds_name + ds_config`` raised TypeError for a None config.
    pattern = re.compile(r"[^a-zA-Z0-9]")
    folder = Path("/data") / pattern.sub("", ds_name + (ds_config or ""))
    folder.mkdir(exist_ok=True, parents=True)

    filenames = []
    for chunk_num, start_idx in enumerate(range(0, len(ds2), chunk_size)):
        end_idx = min(start_idx + chunk_size, len(ds2))
        temp = ds2.select(range(start_idx, end_idx))

        chunk_path = str(folder / f"chunk_{chunk_num}")
        temp.to_parquet(chunk_path)
        filenames.append(chunk_path)

    return load_dataset("parquet", data_files=filenames, split="train")
155