Delete misc.py
misc.py
DELETED
@@ -1,514 +0,0 @@
from typing import Dict, Iterable, List, Tuple, Union

import collections
import functools
import glob
import json
import hashlib
import itertools
import logging
import multiprocessing
import os
import pickle
import random
import requests
import sys
import zipfile

import datasets
import numpy as np
import safetensors.torch  # import the submodule explicitly so safetensors.torch.load_file resolves below
import torch
import tqdm
import transformers

from cde.lib.dist import get_num_proc, get_rank


def get_cde_cache_dir() -> str:
    script_directory = os.path.normpath(
        os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            os.pardir, os.pardir,
        )
    )
    return os.path.join(script_directory, "data")


def get_cache_location_from_kwargs(**kwargs):
    cache_location = os.path.join(
        get_cde_cache_dir(), "cluster"
    )
    os.makedirs(cache_location, exist_ok=True)
    return os.path.join(cache_location, md5_hash_kwargs(**kwargs))


def process_qrels_uncached(corpus: datasets.Dataset, qrels: datasets.Dataset) -> Tuple[Dict[str, List[float]], Dict[str, List[str]]]:
    qrels_idxs = collections.defaultdict(list)
    qrels_scores = collections.defaultdict(list)
    corpus_ids = np.array(corpus['_id'])
    skipped_qrels = 0

    for ex in tqdm.tqdm(qrels, desc='processing qrels', colour='#964B00', leave=False):
        # example:
        # {
        #     'query-id': 1,
        #     'corpus-id': 'b0680508-2019-04-18T13:48:51Z-00002-000',
        #     'score': 2
        # }
        q_id = str(ex['query-id'])
        c_idxs = (corpus_ids == str(ex['corpus-id'])).nonzero()[0]
        assert len(c_idxs) <= 1, f"error - duplicate corpus ID? (found {len(c_idxs)} matches)"
        if len(c_idxs):
            qrels_idxs[q_id].append(c_idxs[0])
            qrels_scores[q_id].append(ex['score'])
        else:
            skipped_qrels += 1

    if skipped_qrels > 0:
        logging.warning(f'Warning: Skipped {skipped_qrels}/{len(qrels)} qrels.')

    return qrels_idxs, qrels_scores
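

# Example (illustrative sketch, not part of the original module): build tiny
# in-memory corpus/qrels datasets in the BEIR-style layout that
# process_qrels_uncached() expects, to show how matching qrels are indexed by
# corpus row and how unmatched ones are skipped. All IDs below are made up.
if __name__ == "__main__":
    _corpus = datasets.Dataset.from_dict({"_id": ["d0", "d1"], "text": ["a", "b"]})
    _qrels = datasets.Dataset.from_dict(
        {"query-id": [1, 2], "corpus-id": ["d1", "d9"], "score": [2, 1]}
    )
    _idxs, _scores = process_qrels_uncached(corpus=_corpus, qrels=_qrels)
    # query "1" resolves to corpus row 1; query "2" ("d9") is skipped
    print(dict(_idxs), dict(_scores))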


def process_qrels(
        corpus: datasets.Dataset, qrels: datasets.Dataset,
        use_cache: bool = True
    ) -> Tuple[Dict[str, List[float]], Dict[str, List[str]]]:
    dataset_cache_file = '_'.join(
        (corpus.cache_files[0]['filename'], qrels.cache_files[0]['filename'])
    )
    cache_file = strip_extension(dataset_cache_file) + '_processed_qrels.p'
    os.makedirs(os.path.dirname(cache_file), exist_ok=True)

    if not (use_cache and os.path.exists(cache_file)):
        qrels_idxs, qrels_scores = process_qrels_uncached(
            corpus=corpus, qrels=qrels
        )
        if use_cache:
            with open(cache_file, 'wb') as f:
                pickle.dump((qrels_idxs, qrels_scores), f)
    else:
        with open(cache_file, 'rb') as f:
            qrels_idxs, qrels_scores = pickle.load(f)

    return qrels_idxs, qrels_scores


def strip_extension(filename: str) -> str:
    """Strips file extension.

    Ex:
        >>> strip_extension('/root/dir/sub/file.ext')
        '/root/dir/sub/file'
    """
    return os.path.splitext(filename)[0]


def md5_hash(t: Tuple[str]) -> str:
    return hashlib.md5('__'.join(t).encode()).hexdigest()


def md5_hash_kwargs(**kwargs) -> str:
    # We ignore special hf args that start with _ like '__cached__setup_devices'.
    safe_kwargs = {k: str(v) for k, v in kwargs.items() if not k.startswith('_')}
    s = json.dumps(safe_kwargs, sort_keys=True)
    return hashlib.md5(s.encode()).hexdigest()
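

# Example (illustrative, not part of the original module): md5_hash_kwargs()
# gives a deterministic cache key because kwargs are stringified and
# JSON-serialized with sort_keys=True, so argument order doesn't matter and
# underscore-prefixed keys are ignored. The values below are made up.
if __name__ == "__main__":
    assert md5_hash_kwargs(model="bert", seed=42) == md5_hash_kwargs(seed=42, model="bert")
    assert md5_hash_kwargs(model="bert", _private=1) == md5_hash_kwargs(model="bert")
    print(md5_hash_kwargs(model="bert", seed=42))  # stable hex digest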


def download_url(url: str, save_path: str, chunk_size: int = 1024):
    """Download url with progress bar using tqdm
    https://stackoverflow.com/questions/15644964/python-progress-bar-and-downloads

    Args:
        url (str): downloadable url
        save_path (str): local path to save the downloaded file
        chunk_size (int, optional): chunking of files. Defaults to 1024.
    """
    r = requests.get(url, stream=True)
    total = int(r.headers.get('Content-Length', 0))
    with open(save_path, 'wb') as fd, tqdm.tqdm(
        desc=save_path,
        total=total,
        unit='iB',
        unit_scale=True,
        unit_divisor=chunk_size,
    ) as bar:
        for data in r.iter_content(chunk_size=chunk_size):
            size = fd.write(data)
            bar.update(size)
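

# Usage sketch (illustrative; the URL below is a placeholder, not from the
# original file): stream a remote archive to disk with the chunked,
# progress-bar download above.
if __name__ == "__main__":
    download_url("https://example.com/dataset.zip", "/tmp/dataset.zip")  # hypothetical URL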


def unzip(zip_file: str, out_dir: str):
    print("unzipping =>", zip_file)
    zip_ = zipfile.ZipFile(zip_file, "r")
    zip_.extractall(path=out_dir)
    zip_.close()


def download_url_and_unzip(url: str, out_dir: str, chunk_size: int = 1024) -> str:
    os.makedirs(out_dir, exist_ok=True)
    dataset = url.split("/")[-1]
    zip_file = os.path.join(out_dir, dataset)

    if not os.path.isfile(zip_file):
        logging.info("Downloading {} ...".format(dataset))
        download_url(url, zip_file, chunk_size)

    if not os.path.isdir(zip_file.replace(".zip", "")):
        logging.info("Unzipping {} ...".format(dataset))
        unzip(zip_file, out_dir)

    return os.path.join(out_dir, dataset.replace(".zip", ""))


def tqdm_if_main_worker(iterable: Iterable, **kwargs) -> Iterable:
    if get_rank() == 0:
        return tqdm.tqdm(iterable, **kwargs)
    else:
        return iterable


class ContextualModelConfig(transformers.configuration_utils.PretrainedConfig):
    """We create a dummy configuration class that will just set properties
    based on whatever kwargs we pass in.

    When this class is initialized (see experiments.py) we pass in the
    union of all data, model, and training args, all of which should
    get saved to the config json.
    """

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            try:
                json.dumps(value)
                setattr(self, key, value)
            except TypeError:
                # value was not JSON-serializable, skip
                continue
        super().__init__()


def independent_crop(
        input_ids: torch.Tensor, pad_token_id: int,
        l1: int = 256, l2: int = 256) -> Tuple[torch.Tensor, torch.Tensor]:
    """Returns two independent crops from input_ids.

    Assumes input_ids has a beginning and end token, like
        [101, ..., 102, 0, 0, 0].

    Args:
        input_ids: tensor of IDs
        pad_token_id: ID of pad tokens in input_ids
        l1: length of span 1, cropped
        l2: length of span 2, cropped
    Returns:
        span1: first crop (of length l1)
        span2: second crop (of length l2)
    """
    # Count tokens until pad.
    if (input_ids == pad_token_id).sum() == 0:
        N = len(input_ids)
    else:
        N = (input_ids == pad_token_id).int().argmax().item()

    # Contriever: "We use the random cropping data augmentation, with
    # documents of 256 tokens and span sizes sampled between 5% and 50%
    # of the document length."
    # LaPraDoR: "The maximum lengths set for queries and documents are
    # 64 and 350..."
    # TODO is this divide-by-two a good idea? (Don't want s1=s2 ever..)
    nl1 = min(N // 2, l1)
    nl2 = min(N // 2, l2)

    s1_start = random.randint(1, N - nl1)
    s2_start = random.randint(1, N - nl2)

    s1_idxs = itertools.chain(
        [0], range(s1_start, s1_start + nl1), [N - 1]
    )
    s1 = input_ids[torch.tensor(list(s1_idxs))]
    s2_idxs = itertools.chain(
        [0], range(s2_start, s2_start + nl2), [N - 1]
    )
    s2 = input_ids[torch.tensor(list(s2_idxs))]
    return (s1, s2)
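

# Example (illustrative, not from the original file): two random crops of a
# padded sequence, each keeping the original begin/end boundary tokens at
# positions 0 and N-1. The token IDs below are arbitrary.
if __name__ == "__main__":
    _ids = torch.tensor([101, 5, 6, 7, 8, 9, 10, 102, 0, 0])  # pad_token_id = 0
    _s1, _s2 = independent_crop(_ids, pad_token_id=0, l1=3, l2=3)
    print(_s1, _s2)  # e.g. tensor([101, 6, 7, 8, 102]) and an independent crop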


def load_dataset_tables(
        files: Iterable[str], num_workers: int = 16
    ) -> Iterable[datasets.table.MemoryMappedTable]:
    import concurrent.futures  # import the submodule so ThreadPoolExecutor resolves
    from multiprocessing import Pool

    # Note: capped at 32 workers regardless of the num_workers argument.
    num_workers = min(32, len(files))

    use_threads = True
    if use_threads:
        pool_cls = concurrent.futures.ThreadPoolExecutor
        pool_kwargs = {"max_workers": num_workers}
    else:
        pool_cls = Pool
        pool_kwargs = {"processes": num_workers}

    with pool_cls(**pool_kwargs) as pool:
        if len(files) > 10:
            files = tqdm_if_main_worker(
                files,
                desc=f"Loading {len(files)} files with {num_workers} workers",
                total=len(files),
                colour="#ffbd88"
            )
        result = list(
            pool.map(datasets.table.MemoryMappedTable.from_file, files)
        )
    return result


def datasets_fast_load_from_disk(cache_path: str) -> datasets.Dataset:
    logging.info("fast_load_from_disk called with path: %s", cache_path)
    dataset_info_path = os.path.join(cache_path, "dataset_info.json")
    with open(dataset_info_path, encoding="utf-8") as dataset_info_file:
        dataset_info = datasets.DatasetInfo.from_dict(json.load(dataset_info_file))

    dataset_state_path = os.path.join(cache_path, "state.json")
    with open(dataset_state_path, encoding="utf-8") as state_file:
        state = json.load(state_file)

    files = glob.glob(os.path.join(cache_path, "data-*.arrow"))
    files = sorted(files)
    num_workers = get_num_proc()
    ds_tables = load_dataset_tables(
        files=files,
        num_workers=num_workers
    )
    arrow_table = datasets.table.concat_tables(ds_tables)

    split = state["_split"]
    split = datasets.splits.Split(split) if split is not None else split

    return datasets.Dataset(
        arrow_table=arrow_table,
        info=dataset_info,
        split=split,
        fingerprint=state["_fingerprint"],
    )


def tokenize_dataset(
        dataset: datasets.Dataset,
        tokenizer: transformers.PreTrainedTokenizer,
        max_length: int,
        text_key: str,
        padding_strategy: str
    ) -> datasets.Dataset:
    def tokenize_text(ex: Dict) -> Dict:
        tt = tokenizer(
            ex[text_key],
            max_length=max_length,
            truncation=True,
            padding=padding_strategy,
        )
        for k, v in tt.items():
            ex[f"{text_key}_{k}"] = v
        ex["length"] = [len(ids) for ids in ex[f"{text_key}_input_ids"]]
        return ex

    # Generate a unique hash for the tokenizer so the map() cache key changes
    # whenever the vocabulary, length limit, or padding strategy changes.
    vocab = tokenizer.vocab
    vocab_words = tuple(sorted(vocab.keys(), key=lambda word: vocab[word]))
    vocab_hash = md5_hash(vocab_words)

    data_fingerprint = '__'.join((
        dataset._fingerprint, str(vocab_hash), str(max_length),
        text_key, padding_strategy
    ))
    data_fingerprint = md5_hash(data_fingerprint)
    dataset = dataset.map(
        tokenize_text,
        new_fingerprint=data_fingerprint,
        batched=True,
        load_from_cache_file=True,
    )
    return dataset


class TensorRunningAverages:
    """Tracks running means of scalar values, keyed by name."""

    _store_sum: Dict[str, torch.Tensor]
    _store_total: Dict[str, torch.Tensor]

    def __init__(self):
        self._store_sum = {}
        self._store_total = {}

    def __iter__(self) -> Iterable[str]:
        return iter(self._store_sum.keys())

    def update(self, key: str, val: Union[int, float, torch.Tensor]) -> None:
        if key not in self._store_sum:
            self.clear(key)
        if isinstance(val, torch.Tensor):
            val = val.item()  # tensor -> num
        self._store_sum[key] += val
        self._store_total[key] += 1

    def get(self, key: str) -> float:
        if key not in self._store_sum:
            return 0.0
        total = max(self._store_total[key].item(), 1.0)
        return (self._store_sum[key] / float(total)).item() or 0.0

    def clear(self, key: str) -> None:
        self._store_sum[key] = torch.tensor(0.0, dtype=torch.float32)
        self._store_total[key] = torch.tensor(0, dtype=torch.int32)

    def clear_all(self) -> None:
        for key in self._store_sum:
            self.clear(key)

    def get_and_clear_all(self) -> Dict[str, float]:
        metrics = {}
        for key in self:
            metrics[key] = self.get(key)
            self.clear(key)
        return metrics
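

# Example (illustrative, not from the original file): typical use in a
# training loop -- accumulate per-step losses, then read out and reset the
# averages at logging time.
if __name__ == "__main__":
    _avgs = TensorRunningAverages()
    for _step_loss in (torch.tensor(2.0), torch.tensor(4.0)):
        _avgs.update("loss", _step_loss)
    print(_avgs.get_and_clear_all())  # {'loss': 3.0}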


def load_embedder_and_tokenizer(name: str) -> Tuple[
        transformers.PreTrainedModel,
        transformers.PreTrainedTokenizer
    ]:
    if name.startswith("nomic") or (name == "bert-base-uncased"):
        from cde.lib.nomic_bert import NomicBertModel
        if name.endswith("--from-scratch"):
            name = name.replace("--from-scratch", "")
            config = transformers.AutoConfig.from_pretrained(name, trust_remote_code=True)
            model = NomicBertModel._from_config(config)
        else:
            model = NomicBertModel.from_pretrained(
                name, add_pooling_layer=False
            )
        tokenizer = transformers.AutoTokenizer.from_pretrained(name)
    elif name in ["gtr-base", "gtr_base"]:
        model = transformers.AutoModel.from_pretrained(
            "sentence-transformers/gtr-t5-base"
        ).encoder
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            "sentence-transformers/gtr-t5-base"
        )
    elif name == "pile-t5-base-encoder":
        model = transformers.AutoModel.from_pretrained(
            "EleutherAI/pile-t5-base"
        ).encoder
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            "EleutherAI/pile-t5-base"
        )
        tokenizer.pad_token = tokenizer.eos_token
    elif name == "pile-t5-base-decoder":
        model = transformers.AutoModel.from_pretrained(
            "EleutherAI/pile-t5-base"
        ).decoder
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            "EleutherAI/pile-t5-base"
        )
        tokenizer.pad_token = tokenizer.eos_token
    elif name.startswith("gpt2") or name.startswith("meta-llama") or ("Llama" in name):
        model = transformers.AutoModelForCausalLM.from_pretrained(
            name,
            # torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            low_cpu_mem_usage=True,
            # device_map="auto",
        )
        model.padding_side = "right"
        tokenizer = transformers.AutoTokenizer.from_pretrained(name)
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.add_eos_token = True
    else:
        model = transformers.AutoModel.from_pretrained(name, trust_remote_code=True)
        tokenizer = transformers.AutoTokenizer.from_pretrained(name)

    # if use_bettertransformer:
    #     from optimum.bettertransformer import BetterTransformer
    #     model = BetterTransformer.transform(model)
    return model, tokenizer


def inputs_for_key(inputs: Dict[str, torch.Tensor], key: str):
    key += "_"
    return {k.replace(key, ""): v for k, v in inputs.items() if k.startswith(key)}
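

# Example (illustrative, not from the original file): selecting the
# query-prefixed tensors from a batch that mixes "query_*" and "document_*"
# keys, as produced by tokenize_dataset() above.
if __name__ == "__main__":
    _batch = {
        "query_input_ids": torch.tensor([[1, 2]]),
        "query_attention_mask": torch.tensor([[1, 1]]),
        "document_input_ids": torch.tensor([[3, 4]]),
    }
    print(inputs_for_key(_batch, key="query"))
    # -> {'input_ids': ..., 'attention_mask': ...}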


def load_model_state_dict_from_path(folder: str) -> Dict:
    checkpoint_folder = transformers.trainer_utils.get_last_checkpoint(folder)
    if checkpoint_folder is None:
        raise FileNotFoundError(f"no checkpoint found in {folder}")
    WEIGHTS_NAME = "model.safetensors"
    weights_path = os.path.join(checkpoint_folder, WEIGHTS_NAME)
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"no model weights found at {weights_path}")
    return safetensors.torch.load_file(weights_path, device="cpu")


def count_cpus() -> int:
    try:
        # Prefer the CPUs actually available to this process (Linux only).
        return len(os.sched_getaffinity(0))
    except AttributeError:
        return multiprocessing.cpu_count()


def shuffle_batches(g: torch.Generator, list_of_tensors: List[torch.Tensor]) -> List[int]:
    all_indices = []
    for batch_tensor in tqdm_if_main_worker(list_of_tensors, colour="green", desc="Sampler shuffling per-batch"):
        rand_perm = torch.randperm(len(batch_tensor), generator=g)
        batch_list = batch_tensor[rand_perm].tolist()
        all_indices.extend(batch_list)
    return all_indices
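

# Example (illustrative, not from the original file): indices are permuted
# only within each batch tensor, so batch membership is preserved while the
# order inside every batch is randomized.
if __name__ == "__main__":
    _g = torch.Generator().manual_seed(0)
    _batches = [torch.tensor([0, 1, 2, 3]), torch.tensor([4, 5, 6, 7])]
    print(shuffle_batches(_g, _batches))  # first four entries are a permutation of 0-3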


# def shuffle_batches_multiproc(g: torch.Generator, list_of_tensors: List[torch.Tensor], num_processes: int = 8) -> List[int]:
#     all_indices = []
#     print(f"Shuffling {len(list_of_tensors)} tensors with {num_processes} workers.")
#     pbar = tqdm_if_main_worker(list_of_tensors, colour="orange", desc=f"Sampler shuffling per-batch (nproc={num_processes})")
#     pool = multiprocessing.Pool(processes=num_processes)
#     chunk_size = len(list_of_tensors) // num_processes
#     chunks = [list_of_tensors[i:i + chunk_size] for i in range(0, len(list_of_tensors), chunk_size)]
#     worker_func = functools.partial(shuffle_batches, g=g)
#     results = pool.map(worker_func, chunks)
#     all_indices = []
#     for result in results:
#         all_indices.extend(result)
#         pbar.update()
#     return all_indices


def exit_if_running_or_finished_wandb(
        project_name: str,
        exp_group: str, exp_name: str
    ) -> None:
    print("Checking if experiment is already running...")
    import wandb

    api = wandb.Api()
    # Note: the query path is hardcoded; project_name is only used for logging.
    running_runs = api.runs(
        path="tti-nomic-7",
        filters={
            "display_name": exp_name,
            "state": {"$regex": "Running|Finished"},
            "config.exp_group": exp_group,
        }
    )
    print("Found", len(running_runs), f"runs with name {exp_name} and group {exp_group} in {project_name}.")

    if len(running_runs) > 0:
        print("Exiting because experiment is already running or completed.")
        sys.exit(0)