jxm committed on
Commit 5d4b633
1 Parent(s): 7e5f0de

Delete misc.py

Files changed (1)
  1. misc.py +0 -514
misc.py DELETED
@@ -1,514 +0,0 @@
- from typing import Dict, Iterable, List, Tuple, Union
-
- import collections
- import functools
- import glob
- import hashlib
- import itertools
- import json
- import logging
- import multiprocessing
- import os
- import pickle
- import random
- import sys
- import zipfile
-
- import datasets
- import numpy as np
- import requests
- import safetensors.torch
- import torch
- import tqdm
- import transformers
-
- from cde.lib.dist import get_num_proc, get_rank
-
-
- def get_cde_cache_dir() -> str:
-     script_directory = os.path.normpath(
-         os.path.join(
-             os.path.dirname(os.path.abspath(__file__)),
-             os.pardir, os.pardir,
-         )
-     )
-     return os.path.join(script_directory, "data")
-
-
- def get_cache_location_from_kwargs(**kwargs):
-     cache_location = os.path.join(
-         get_cde_cache_dir(), "cluster"
-     )
-     os.makedirs(cache_location, exist_ok=True)
-     return os.path.join(cache_location, md5_hash_kwargs(**kwargs))
-
-
- def process_qrels_uncached(corpus: datasets.Dataset, qrels: datasets.Dataset) -> Tuple[Dict[str, List[int]], Dict[str, List[float]]]:
-     qrels_idxs = collections.defaultdict(list)
-     qrels_scores = collections.defaultdict(list)
-     corpus_ids = np.array(corpus['_id'])
-     skipped_qrels = 0
-
-     for ex in tqdm.tqdm(qrels, desc='processing qrels', colour='#964B00', leave=False):
-         # example:
-         # {
-         #     'query-id': 1,
-         #     'corpus-id': 'b0680508-2019-04-18T13:48:51Z-00002-000',
-         #     'score': 2
-         # }
-         q_id = str(ex['query-id'])
-         c_idxs = (corpus_ids == str(ex['corpus-id'])).nonzero()[0]
-         assert len(c_idxs) <= 1, f"error - duplicate corpus ID? (found {len(c_idxs)} matches)"
-         if len(c_idxs):
-             qrels_idxs[q_id].append(c_idxs[0])
-             qrels_scores[q_id].append(ex['score'])
-         else:
-             skipped_qrels += 1
-
-     if skipped_qrels > 0:
-         logging.warning(f'Skipped {skipped_qrels}/{len(qrels)} qrels.')
-
-     return qrels_idxs, qrels_scores
77
-
78
-
79
- def process_qrels(
80
- corpus: datasets.Dataset, qrels: datasets.Dataset,
81
- use_cache: bool = True
82
- ) -> Tuple[Dict[str, List[float]], Dict[str, List[str]]]:
83
- dataset_cache_file = '_'.join(
84
- (corpus.cache_files[0]['filename'], qrels.cache_files[0]['filename'])
85
- )
86
- cache_file = strip_extension(dataset_cache_file) + '_processed_qrels.p'
87
- os.makedirs(os.path.dirname(cache_file), exist_ok=True)
88
-
89
- if not (use_cache and os.path.exists(cache_file)):
90
- qrels_idxs, qrels_scores = process_qrels_uncached(
91
- corpus=corpus, qrels=qrels
92
- )
93
- if use_cache:
94
- pickle.dump((qrels_idxs, qrels_scores), open(cache_file, 'wb'))
95
- else:
96
- qrels_idxs, qrels_scores = pickle.load(open(cache_file, 'rb'))
97
-
98
- return qrels_idxs, qrels_scores
-
-
- def strip_extension(filename: str) -> str:
-     """Strips the file extension.
-
-     Ex:
-         >>> strip_extension('/root/dir/sub/file.ext')
-         '/root/dir/sub/file'
-     """
-     return os.path.splitext(filename)[0]
-
-
- def md5_hash(t: Tuple[str, ...]) -> str:
-     return hashlib.md5('__'.join(t).encode()).hexdigest()
-
-
- def md5_hash_kwargs(**kwargs) -> str:
-     # We ignore special hf args that start with _ like '__cached__setup_devices'.
-     safe_kwargs = {k: str(v) for k, v in kwargs.items() if not k.startswith('_')}
-     s = json.dumps(safe_kwargs, sort_keys=True)
-     return hashlib.md5(s.encode()).hexdigest()
-
-
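Because `safe_kwargs` is serialized with `sort_keys=True`, the hash is independent of keyword order and ignores underscore-prefixed HF internals; a minimal sketch (keyword names hypothetical):

    # Same kwargs in a different order produce the same cache key,
    # and '_'-prefixed kwargs do not affect it.
    assert md5_hash_kwargs(lr=1e-4, model="bert") == md5_hash_kwargs(model="bert", lr=1e-4)
    assert md5_hash_kwargs(lr=1e-4, _cached_setup="x") == md5_hash_kwargs(lr=1e-4)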
- def download_url(url: str, save_path: str, chunk_size: int = 1024):
-     """Downloads a URL with a tqdm progress bar.
-
-     https://stackoverflow.com/questions/15644964/python-progress-bar-and-downloads
-
-     Args:
-         url (str): downloadable url
-         save_path (str): local path to save the downloaded file
-         chunk_size (int, optional): chunking of files. Defaults to 1024.
-     """
-     r = requests.get(url, stream=True)
-     total = int(r.headers.get('Content-Length', 0))
-     with open(save_path, 'wb') as fd, tqdm.tqdm(
-         desc=save_path,
-         total=total,
-         unit='iB',
-         unit_scale=True,
-         unit_divisor=chunk_size,
-     ) as bar:
-         for data in r.iter_content(chunk_size=chunk_size):
-             size = fd.write(data)
-             bar.update(size)
-
-
- def unzip(zip_file: str, out_dir: str):
-     print("unzipping =>", zip_file)
-     with zipfile.ZipFile(zip_file, "r") as zip_:
-         zip_.extractall(path=out_dir)
-
-
- def download_url_and_unzip(url: str, out_dir: str, chunk_size: int = 1024) -> str:
-     os.makedirs(out_dir, exist_ok=True)
-     dataset = url.split("/")[-1]
-     zip_file = os.path.join(out_dir, dataset)
-
-     if not os.path.isfile(zip_file):
-         logging.info("Downloading {} ...".format(dataset))
-         download_url(url, zip_file, chunk_size)
-
-     if not os.path.isdir(zip_file.replace(".zip", "")):
-         logging.info("Unzipping {} ...".format(dataset))
-         unzip(zip_file, out_dir)
-
-     return os.path.join(out_dir, dataset.replace(".zip", ""))
-
-
- def tqdm_if_main_worker(iterable: Iterable, **kwargs) -> Iterable:
-     if get_rank() == 0:
-         return tqdm.tqdm(iterable, **kwargs)
-     else:
-         return iterable
-
-
- class ContextualModelConfig(transformers.configuration_utils.PretrainedConfig):
-     """A dummy configuration class that simply sets properties
-     based on whatever kwargs are passed in.
-
-     When this class is initialized (see experiments.py) we pass in the
-     union of all data, model, and training args, all of which should
-     get saved to the config json.
-     """
-
-     def __init__(self, **kwargs):
-         for key, value in kwargs.items():
-             try:
-                 json.dumps(value)
-                 setattr(self, key, value)
-             except TypeError:
-                 # value was not JSON-serializable, skip
-                 continue
-         super().__init__()
-
-
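Only JSON-serializable kwargs become config attributes; a hypothetical usage sketch:

    config = ContextualModelConfig(max_seq_length=512, run_name="debug", loss_fn=print)
    assert config.max_seq_length == 512
    # print() is not JSON-serializable, so it was silently skipped.
    assert not hasattr(config, "loss_fn")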
- def independent_crop(
-     input_ids: torch.Tensor, pad_token_id: int,
-     l1: int = 256, l2: int = 256) -> Tuple[torch.Tensor, torch.Tensor]:
-     """Returns two independent crops from input_ids.
-
-     Assumes input_ids has a beginning and end token, like
-     [101, ..., 102, 0, 0, 0].
-
-     Args:
-         input_ids: tensor of IDs
-         pad_token_id: ID of pad tokens in input_ids
-         l1: length of span 1, cropped
-         l2: length of span 2, cropped
-     Returns:
-         span1: first crop (of length l1)
-         span2: second crop (of length l2)
-     """
-     # Count tokens until pad.
-     if (input_ids == pad_token_id).sum() == 0:
-         N = len(input_ids)
-     else:
-         N = (input_ids == pad_token_id).int().argmax().item()
-
-     # Contriever: "We use the random cropping data augmentation, with
-     # documents of 256 tokens and span sizes sampled between 5% and 50%
-     # of the document length."
-     # LaPraDor: "The maximum lengths set for queries and documents are
-     # 64 and 350..."
-     # TODO is this divide-by-two a good idea? (Don't want s1=s2 ever..)
-     nl1 = min(N // 2, l1)
-     nl2 = min(N // 2, l2)
-
-     s1_start = random.randint(1, N - nl1)
-     s2_start = random.randint(1, N - nl2)
-
-     s1_idxs = itertools.chain(
-         [0], range(s1_start, s1_start + nl1), [N - 1]
-     )
-     s1 = input_ids[torch.tensor(list(s1_idxs))]
-     s2_idxs = itertools.chain(
-         [0], range(s2_start, s2_start + nl2), [N - 1]
-     )
-     s2 = input_ids[torch.tensor(list(s2_idxs))]
-     return (s1, s2)
-
-
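A quick sketch of the crop shape, assuming BERT-style inputs (CLS=101, SEP=102, pad=0) and the module's imports:

    ids = torch.tensor([101, 11, 12, 13, 14, 15, 102, 0, 0])  # N = 7 non-pad tokens
    s1, s2 = independent_crop(ids, pad_token_id=0, l1=3, l2=3)
    # Both crops keep the first and last non-pad positions, so each looks like
    # [101, <3 contiguous tokens>, 102] for a total of 5 tokens.
    assert s1[0].item() == 101 and s1[-1].item() == 102 and len(s1) == 5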
- def load_dataset_tables(
-     files: Iterable[str], num_workers: int = 16
- ) -> Iterable[datasets.table.MemoryMappedTable]:
-     import concurrent.futures
-     from multiprocessing import Pool
-
-     # num_workers = min(num_workers, len(files))
-     num_workers = min(32, len(files))
-
-     use_threads = True
-     if use_threads:
-         pool_cls = concurrent.futures.ThreadPoolExecutor
-         pool_kwargs = {"max_workers": num_workers}
-     else:
-         pool_cls = Pool
-         pool_kwargs = {"processes": num_workers}
-
-     with pool_cls(**pool_kwargs) as pool:
-         if len(files) > 10:
-             files = tqdm_if_main_worker(
-                 files,
-                 desc=f"Loading {len(files)} files with {num_workers} workers",
-                 total=len(files),
-                 colour="#ffbd88",
-             )
-         result = list(
-             pool.map(datasets.table.MemoryMappedTable.from_file, files)
-         )
-     return result
-
-
- def datasets_fast_load_from_disk(cache_path: str) -> datasets.Dataset:
-     logging.info("fast_load_from_disk called with path: %s", cache_path)
-     dataset_info_path = os.path.join(cache_path, "dataset_info.json")
-     with open(dataset_info_path, encoding="utf-8") as dataset_info_file:
-         dataset_info = datasets.DatasetInfo.from_dict(json.load(dataset_info_file))
-
-     dataset_state_path = os.path.join(cache_path, "state.json")
-     with open(dataset_state_path, encoding="utf-8") as state_file:
-         state = json.load(state_file)
-
-     files = glob.glob(os.path.join(cache_path, "data-*.arrow"))
-     files = sorted(files)
-     num_workers = get_num_proc()
-     ds_tables = load_dataset_tables(
-         files=files,
-         num_workers=num_workers
-     )
-     arrow_table = datasets.table.concat_tables(ds_tables)
-
-     split = state["_split"]
-     split = datasets.splits.Split(split) if split is not None else split
-
-     return datasets.Dataset(
-         arrow_table=arrow_table,
-         info=dataset_info,
-         split=split,
-         fingerprint=state["_fingerprint"],
-     )
-
-
- def tokenize_dataset(
-     dataset: datasets.Dataset,
-     tokenizer: transformers.PreTrainedTokenizer,
-     max_length: int,
-     text_key: str,
-     padding_strategy: str
- ) -> datasets.Dataset:
-     def tokenize_text(ex: Dict) -> Dict:
-         tt = tokenizer(
-             ex[text_key],
-             max_length=max_length,
-             truncation=True,
-             padding=padding_strategy,
-         )
-         for k, v in tt.items():
-             ex[f"{text_key}_{k}"] = v
-         ex["length"] = [len(ids) for ids in ex[f"{text_key}_input_ids"]]
-         return ex
-
-     # Generate a unique hash for the tokenizer vocabulary.
-     vocab = tokenizer.vocab
-     vocab_words = tuple(sorted(vocab.keys(), key=lambda word: vocab[word]))
-     vocab_hash = md5_hash(vocab_words)
-
-     data_fingerprint = '__'.join((
-         dataset._fingerprint, str(vocab_hash), str(max_length),
-         text_key, padding_strategy
-     ))
-     data_fingerprint = md5_hash(data_fingerprint)
-     dataset = dataset.map(
-         tokenize_text,
-         new_fingerprint=data_fingerprint,
-         batched=True,
-         load_from_cache_file=True,
-     )
-     return dataset
-
-
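A hypothetical end-to-end call (any HF tokenizer works; new column names follow the `{text_key}_{field}` pattern above):

    tok = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
    ds = datasets.Dataset.from_dict({"text": ["hello world", "another document"]})
    ds = tokenize_dataset(ds, tok, max_length=16, text_key="text", padding_strategy="max_length")
    # Adds columns: text_input_ids, text_token_type_ids, text_attention_mask, length.
    assert ds[0]["length"] == 16  # padded to max_length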
- class TensorRunningAverages:
-     _store_sum: Dict[str, torch.Tensor]
-     _store_total: Dict[str, torch.Tensor]
-
-     def __init__(self):
-         self._store_sum = {}
-         self._store_total = {}
-
-     def __iter__(self) -> Iterable[str]:
-         return iter(self._store_sum.keys())
-
-     def update(self, key: str, val: Union[int, float, torch.Tensor]) -> None:
-         if key not in self._store_sum:
-             self.clear(key)
-         if isinstance(val, torch.Tensor):
-             val = val.item()  # tensor -> num
-         self._store_sum[key] += val
-         self._store_total[key] += 1
-
-     def get(self, key: str) -> float:
-         if key not in self._store_sum:
-             return 0.0
-         total = max(self._store_total[key].item(), 1.0)
-         return (self._store_sum[key] / float(total)).item() or 0.0
-
-     def clear(self, key: str) -> None:
-         self._store_sum[key] = torch.tensor(0.0, dtype=torch.float32)
-         self._store_total[key] = torch.tensor(0, dtype=torch.int32)
-
-     def clear_all(self) -> None:
-         for key in self._store_sum:
-             self.clear(key)
-
-     def get_and_clear_all(self) -> Dict[str, float]:
-         metrics = {}
-         for key in self:
-             metrics[key] = self.get(key)
-             self.clear(key)
-         return metrics
-
-
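For intuition, a short sketch of the running-average store (assumes the class above and the module's torch import):

    avgs = TensorRunningAverages()
    avgs.update("loss", torch.tensor(2.0))
    avgs.update("loss", 4.0)
    assert avgs.get("loss") == 3.0
    assert avgs.get_and_clear_all() == {"loss": 3.0}  # counters reset afterwards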
- def load_embedder_and_tokenizer(name: str) -> Tuple[
-     transformers.PreTrainedModel,
-     transformers.PreTrainedTokenizer
- ]:
-     if name.startswith("nomic") or (name == "bert-base-uncased"):
-         from cde.lib.nomic_bert import NomicBertModel
-         if name.endswith("--from-scratch"):
-             name = name.replace("--from-scratch", "")
-             config = transformers.AutoConfig.from_pretrained(name, trust_remote_code=True)
-             model = NomicBertModel._from_config(config)
-         else:
-             model = NomicBertModel.from_pretrained(
-                 name, add_pooling_layer=False
-             )
-         tokenizer = transformers.AutoTokenizer.from_pretrained(name)
-     elif name in ["gtr-base", "gtr_base"]:
-         model = transformers.AutoModel.from_pretrained(
-             "sentence-transformers/gtr-t5-base"
-         ).encoder
-         tokenizer = transformers.AutoTokenizer.from_pretrained(
-             "sentence-transformers/gtr-t5-base"
-         )
-     elif name == "pile-t5-base-encoder":
-         model = transformers.AutoModel.from_pretrained(
-             "EleutherAI/pile-t5-base"
-         ).encoder
-         tokenizer = transformers.AutoTokenizer.from_pretrained(
-             "EleutherAI/pile-t5-base"
-         )
-         tokenizer.pad_token = tokenizer.eos_token
-     elif name == "pile-t5-base-decoder":
-         model = transformers.AutoModel.from_pretrained(
-             "EleutherAI/pile-t5-base"
-         ).decoder
-         tokenizer = transformers.AutoTokenizer.from_pretrained(
-             "EleutherAI/pile-t5-base"
-         )
-         tokenizer.pad_token = tokenizer.eos_token
-     elif name.startswith("gpt2") or name.startswith("meta-llama") or ("Llama" in name):
-         model = transformers.AutoModelForCausalLM.from_pretrained(
-             name,
-             # torch_dtype=torch.bfloat16,
-             attn_implementation="flash_attention_2",
-             low_cpu_mem_usage=True,
-             # device_map="auto",
-         )
-         model.padding_side = "right"
-         tokenizer = transformers.AutoTokenizer.from_pretrained(name)
-         tokenizer.pad_token = tokenizer.eos_token
-         tokenizer.add_eos_token = True
-     else:
-         model = transformers.AutoModel.from_pretrained(name, trust_remote_code=True)
-         tokenizer = transformers.AutoTokenizer.from_pretrained(name)
-
-     # if use_bettertransformer:
-     #     from optimum.bettertransformer import BetterTransformer
-     #     model = BetterTransformer.transform(model)
-     return model, tokenizer
-
-
- def inputs_for_key(inputs: Dict[str, torch.Tensor], key: str) -> Dict[str, torch.Tensor]:
-     # Select entries like "{key}_input_ids" and strip the "{key}_" prefix.
-     key += "_"
-     return {k[len(key):]: v for k, v in inputs.items() if k.startswith(key)}
-
-
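For example, with a batch that packs query and document tensors under prefixed keys (values hypothetical):

    batch = {
        "query_input_ids": torch.tensor([[1, 2]]),
        "query_attention_mask": torch.tensor([[1, 1]]),
        "document_input_ids": torch.tensor([[3, 4]]),
    }
    query_inputs = inputs_for_key(batch, "query")
    assert set(query_inputs.keys()) == {"input_ids", "attention_mask"}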
- def load_model_state_dict_from_path(folder: str) -> Dict:
-     checkpoint_folder = transformers.trainer_utils.get_last_checkpoint(folder)
-     if checkpoint_folder is None:
-         raise FileNotFoundError(f"no checkpoint found in {folder}")
-     WEIGHTS_NAME = "model.safetensors"
-     weights_path = os.path.join(checkpoint_folder, WEIGHTS_NAME)
-     if not os.path.exists(weights_path):
-         raise FileNotFoundError(f"no model weights found at {weights_path}")
-     return safetensors.torch.load_file(weights_path, device="cpu")
-
-
- def count_cpus() -> int:
-     try:
-         # Respects CPU affinity masks (e.g. under a scheduler); Linux-only.
-         return len(os.sched_getaffinity(0))
-     except AttributeError:
-         return multiprocessing.cpu_count()
-
-
- def shuffle_batches(g: torch.Generator, list_of_tensors: List[torch.Tensor]) -> List[int]:
-     all_indices = []
-     for batch_tensor in tqdm_if_main_worker(
-         list_of_tensors, colour="green", desc="Sampler shuffling per-batch"
-     ):
-         rand_perm = torch.randperm(len(batch_tensor), generator=g)
-         batch_list = batch_tensor[rand_perm].tolist()
-         all_indices.extend(batch_list)
-     return all_indices
-
-
- # def shuffle_batches_multiproc(
- #     g: torch.Generator, list_of_tensors: List[torch.Tensor], num_processes: int = 8
- # ) -> List[int]:
- #     all_indices = []
- #     print(f"Shuffling {len(list_of_tensors)} tensors with {num_processes} workers.")
- #     pbar = tqdm_if_main_worker(list_of_tensors, colour="orange", desc=f"Sampler shuffling per-batch (nproc={num_processes})")
- #     pool = multiprocessing.Pool(processes=num_processes)
- #     chunk_size = len(list_of_tensors) // num_processes
- #     chunks = [list_of_tensors[i:i + chunk_size] for i in range(0, len(list_of_tensors), chunk_size)]
- #     worker_func = functools.partial(shuffle_batches, g=g)
- #     results = pool.map(worker_func, chunks)
- #     all_indices = []
- #     for result in results:
- #         all_indices.extend(result)
- #         pbar.update()
- #     return all_indices
-
-
- def exit_if_running_or_finished_wandb(
-     project_name: str,
-     exp_group: str, exp_name: str
- ) -> None:
-     print("Checking if experiment is already running...")
-     import wandb
-
-     api = wandb.Api()
-     running_runs = api.runs(
-         path="tti-nomic-7",
-         filters={
-             "display_name": exp_name,
-             "state": {"$regex": "Running|Finished"},
-             "config.exp_group": exp_group,
-         }
-     )
-     print("Found", len(running_runs), f"runs with name {exp_name} and group {exp_group} in {project_name}.")
-
-     if len(running_runs) > 0:
-         print("Exiting because experiment is already running or completed.")
-         sys.exit(0)