nbroad HF staff commited on
Commit
f827190
1 Parent(s): f0c5d5d

use upload file/folder and dataloader

Browse files
Files changed (1) hide show
  1. utils.py +63 -43
utils.py CHANGED
@@ -5,6 +5,7 @@ from pathlib import Path
5
  from typing import Union, Dict, List
6
 
7
  import torch
 
8
  import datasets
9
  from datasets import load_dataset, Dataset
10
  from transformers import AutoTokenizer, PreTrainedTokenizer
@@ -274,17 +275,15 @@ def batch_embed(
274
 
275
  repo = init_git_repo(new_dataset_id)
276
 
277
- iterator = iter(
278
- ds.map(
279
- tokenize,
280
- batched=True,
281
- batch_size=map_batch_size,
282
- fn_kwargs={
283
- "tokenizer": tokenizer,
284
- "column_name": column_name,
285
- "padding": "max_length" if opt_level == "O4" else True,
286
- },
287
- )
288
  )
289
 
290
  embeds = []
@@ -299,23 +298,20 @@ def batch_embed(
299
 
300
  inference_bs = get_batch_size(torch.cuda.get_device_name(0), model_name, opt_level)
301
 
302
- loop = True
303
-
304
- # skip through some examples
305
  if num2skip > 0:
306
- [next(iterator) for _ in range(num2skip)]
307
 
308
  start_time = time.time()
309
- while loop:
310
- batch = [next(iterator, None) for _ in range(inference_bs)]
311
-
312
- # batch will have None values when iterator runs out
313
- if batch[-1] is None:
314
- batch = [x for x in batch if x is not None]
315
- loop = False
316
- if len(batch) == 0:
317
- break
318
 
 
 
 
 
 
 
 
 
319
  ids = torch.tensor([b["input_ids"] for b in batch], device=device)
320
  mask = torch.tensor([b["attention_mask"] for b in batch], device=device)
321
  t_ids = torch.zeros_like(ids)
@@ -325,7 +321,7 @@ def batch_embed(
325
  embeds.extend(mean_pooling(outputs[0], mask).cpu().tolist())
326
  texts.extend([b[column_name] for b in batch])
327
 
328
- current_count += len(batch)
329
 
330
  # Check if we have embedded enough examples
331
  if current_count >= num2embed:
@@ -405,18 +401,19 @@ def init_git_repo(repo_id: str):
405
 
406
 
407
  def push_to_repo(
408
- repo: str,
409
  last_count: int,
410
  current_count: int,
411
  embeds: List[List[float]],
412
  texts: List[str],
 
413
  ):
414
  """
415
  Push embeddings to the repo.
416
 
417
  Args:
418
- repo (`huggingface_hub.Repository`):
419
- repo to push to
420
  last_count (`int`):
421
  last count of embeddings.
422
  This is the number of embeddings that have already been pushed.
@@ -427,9 +424,10 @@ def push_to_repo(
427
  list of embeddings to push to the repo
428
  texts (`List[str]`):
429
  list of texts to push to the repo
 
 
430
  """
431
 
432
- # TODO: write dataset loading script as well
433
 
434
  temp_ds = Dataset.from_dict(
435
  {
@@ -438,24 +436,46 @@ def push_to_repo(
438
  }
439
  )
440
 
441
- data_dir = Path(repo.local_dir) / "data"
 
 
442
  data_dir.mkdir(exist_ok=True, parents=True)
443
 
444
- temp_ds.to_parquet(
445
- str(data_dir / f"embeddings_{last_count}_{current_count}.parquet")
446
- )
447
 
448
- repo.push_to_hub(
449
- commit_message=f"Embedded examples {last_count} thru {current_count}",
450
- blocking=False,
451
- auto_lfs_prune=True,
452
- )
453
 
454
- # TODO: delete/untrack old files
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
455
 
456
- # most_recent_file = f"embeddings_{last_count}_{current_count}.parquet"
457
 
458
  # Delete old files
459
- # for f in data_dir.glob("*.parquet"):
460
- # if f.name != most_recent_file:
461
- # f.unlink()
 
 
5
  from typing import Union, Dict, List
6
 
7
  import torch
8
+ from torch.utils.data import DataLoader
9
  import datasets
10
  from datasets import load_dataset, Dataset
11
  from transformers import AutoTokenizer, PreTrainedTokenizer
 
275
 
276
  repo = init_git_repo(new_dataset_id)
277
 
278
+ ds = ds.map(
279
+ tokenize,
280
+ batched=True,
281
+ batch_size=map_batch_size,
282
+ fn_kwargs={
283
+ "tokenizer": tokenizer,
284
+ "column_name": column_name,
285
+ "padding": "max_length" if opt_level == "O4" else True,
286
+ },
 
 
287
  )
288
 
289
  embeds = []
 
298
 
299
  inference_bs = get_batch_size(torch.cuda.get_device_name(0), model_name, opt_level)
300
 
301
+ # skip through some examples if specified
 
 
302
  if num2skip > 0:
303
+ ds = ds.skip(num2skip)
304
 
305
  start_time = time.time()
 
 
 
 
 
 
 
 
 
306
 
307
+ for batch in DataLoader(
308
+ ds,
309
+ batch_size=inference_bs,
310
+ shuffle=False,
311
+ num_workers=2,
312
+ pin_memory=True,
313
+ drop_last=False,
314
+ ):
315
  ids = torch.tensor([b["input_ids"] for b in batch], device=device)
316
  mask = torch.tensor([b["attention_mask"] for b in batch], device=device)
317
  t_ids = torch.zeros_like(ids)
 
321
  embeds.extend(mean_pooling(outputs[0], mask).cpu().tolist())
322
  texts.extend([b[column_name] for b in batch])
323
 
324
+ current_count += ids.shape[0]
325
 
326
  # Check if we have embedded enough examples
327
  if current_count >= num2embed:
 
401
 
402
 
403
  def push_to_repo(
404
+ repo_id: str,
405
  last_count: int,
406
  current_count: int,
407
  embeds: List[List[float]],
408
  texts: List[str],
409
+ api: HfApi,
410
  ):
411
  """
412
  Push embeddings to the repo.
413
 
414
  Args:
415
+ repo_id (`str`):
416
+ id of the new dataset to create. Should include username or organization.
417
  last_count (`int`):
418
  last count of embeddings.
419
  This is the number of embeddings that have already been pushed.
 
424
  list of embeddings to push to the repo
425
  texts (`List[str]`):
426
  list of texts to push to the repo
427
+ api (`huggingface_hub.HfApi`):
428
+ api to use to push to the repo
429
  """
430
 
 
431
 
432
  temp_ds = Dataset.from_dict(
433
  {
 
436
  }
437
  )
438
 
439
+ local_dir = repo_id.replace("/", "_")
440
+
441
+ data_dir = Path(local_dir) / "data"
442
  data_dir.mkdir(exist_ok=True, parents=True)
443
 
444
+ # use zfill so sorting puts the files in order
445
+ filename = f"embeddings_{str(last_count).zfill(8)}_{current_count}.parquet"
446
+ filepath = str(data_dir / filename)
447
 
448
+ temp_ds.to_parquet(filepath)
 
 
 
 
449
 
450
+
451
+ files = sorted(list(data_dir.glob("*.parquet")))
452
+
453
+
454
+ if len(files) == 1:
455
+ api.upload_folder(
456
+ folder_path=str(data_dir),
457
+ repo_id=repo_id,
458
+ repo_type="dataset",
459
+ run_as_future=True,
460
+ token=os.environ["HF_TOKEN"],
461
+ commit_message=f"Embedded examples {last_count} thru {current_count} with folder",
462
+ )
463
+
464
+ else:
465
+
466
+ api.upload_file(
467
+ path_or_fileobj=filepath,
468
+ path_in_repo=f"data/{filename}",
469
+ repo_id=repo_id,
470
+ repo_type="dataset",
471
+ run_as_future=True,
472
+ token=os.environ["HF_TOKEN"],
473
+ commit_message=f"Embedded examples {last_count} thru {current_count}",
474
+ )
475
 
 
476
 
477
  # Delete old files
478
+
479
+ if len(files) > 4:
480
+ for file in files[:2]:
481
+ file.unlink()