Iker committed
Commit adaca32
Parents (2): 01e13d9, ba5f9a4

Merge pull request #2 from ikergarcia1996/multigpu-bug

Files changed (2):
  1. README.md +1 -1
  2. translate.py +26 -13
README.md CHANGED

@@ -67,7 +67,7 @@ Any other ModelForSeq2SeqLM from HuggingFace's Hub should work with this library
 Pytorch >= 1.10.0
 See: https://pytorch.org/get-started/locally/
 
-Accelerate >= 0.7.1
+Accelerate >= 0.12.0
 pip install --upgrade accelerate
 
 HuggingFace Transformers
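
The minimum version is raised because newer Accelerate exposes find_executable_batch_size from the top-level accelerate package, which the translate.py import below relies on. A quick sanity check, sketched under the assumption that the packaging module is installed alongside Accelerate:

import accelerate
from packaging import version

# Fail early if the installed Accelerate predates the top-level
# find_executable_batch_size export that translate.py now imports.
assert version.parse(accelerate.__version__) >= version.parse("0.12.0"), (
    f"Accelerate {accelerate.__version__} is too old; "
    "run: pip install --upgrade accelerate"
)
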
translate.py CHANGED
@@ -1,17 +1,23 @@
+import os
+import math
+import argparse
+
+import torch
+from torch.utils.data import DataLoader
+
+from tqdm import tqdm
+
 from transformers import (
     AutoModelForSeq2SeqLM,
     AutoTokenizer,
     PreTrainedTokenizerBase,
     DataCollatorForSeq2Seq,
 )
-from tqdm import tqdm
-import argparse
-import torch
-from torch.utils.data import DataLoader
+
+
 from dataset import DatasetReader, count_lines
-import os
-from accelerate import Accelerator, DistributedType
-from accelerate.memory_utils import find_executable_batch_size
+
+from accelerate import Accelerator, DistributedType, find_executable_batch_size
 
 
 def get_dataloader(
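
Besides regrouping the imports (and pulling in math for the last-step arithmetic further down), this hunk switches find_executable_batch_size to its top-level import path. For context, that helper is a retry decorator: it wraps a function whose first argument is batch_size, and on a CUDA out-of-memory error it halves the batch size and re-invokes the function. A minimal usage sketch (the function name and starting size are illustrative, not from this repo):

from accelerate import find_executable_batch_size

@find_executable_batch_size(starting_batch_size=128)
def translate_all(batch_size):
    # First call runs with batch_size=128; on CUDA OOM, Accelerate
    # halves batch_size and restarts this function from the top.
    print(f"trying batch_size={batch_size}")

translate_all()  # called with no arguments; the decorator supplies batch_size
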
@@ -45,6 +51,7 @@ def get_dataloader(
         dataset,
         batch_size=batch_size,
         collate_fn=data_collator,
+        num_workers=1,
     )
 
 
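
Setting num_workers=1 moves batch collation into a single background worker process instead of the process running generation. A simplified sketch of the resulting construction (the real get_dataloader signature has more parameters; only the DataLoader call is shown):

from torch.utils.data import DataLoader

def get_dataloader(dataset, batch_size, data_collator):
    # Simplified signature for illustration; the repo's version differs.
    return DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=data_collator,
        num_workers=1,  # one prefetching worker; 0 would collate in the main process
    )
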
@@ -72,7 +79,7 @@ def main(
     accelerator = Accelerator(
         mixed_precision=precision if precision != "32" else "no",
         split_batches=False,
-        dispatch_batches=True,
+        dispatch_batches=False,
     )
 
     print(f"Loading tokenizer {model_name}...")
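
This is the core of the multi-GPU fix named in the PR title: with dispatch_batches=True the DataLoader is iterated on the main process only and batch slices are scattered to the other ranks, whereas False lets every process iterate its own shard of the data after prepare(). A minimal sketch of the new configuration (the mixed_precision value is illustrative; the script derives it from its precision argument):

from accelerate import Accelerator

accelerator = Accelerator(
    mixed_precision="no",
    split_batches=False,
    dispatch_batches=False,  # each rank reads its own batches locally
)
print(f"running on {accelerator.num_processes} process(es)")
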
@@ -115,7 +122,7 @@ def main(
         "top_p": top_p,
     }
 
-    # total_lines: int = count_lines(sentences_path)
+    total_lines: int = count_lines(sentences_path)
 
     if accelerator.is_main_process:
         print(
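
Re-enabling count_lines gives the script the true number of input sentences, which both the progress bar and the final-batch truncation below depend on; len(data_loader.dataset) is no longer reliable once the loader is sharded across processes. count_lines lives in the repo's dataset module; a hypothetical minimal stand-in with the same intent:

def count_lines(path: str) -> int:
    # Hypothetical sketch: assumes one sentence per line in the input file.
    with open(path, "r", encoding="utf8") as f:
        return sum(1 for _ in f)
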
@@ -155,7 +162,7 @@ def main(
     samples_seen: int = 0
 
     with tqdm(
-        total=len(data_loader.dataset),
+        total=total_lines,
         desc="Dataset translation",
         leave=True,
         ascii=True,
@@ -182,10 +189,16 @@ def main(
                     generated_tokens, skip_special_tokens=True
                 )
                 if accelerator.is_main_process:
-                    if step == len(data_loader) - 1:
+                    if (
+                        step
+                        == math.ceil(
+                            math.ceil(total_lines / batch_size)
+                            / accelerator.num_processes
+                        )
+                        - 1
+                    ):
                         tgt_text = tgt_text[
-                            : (len(data_loader.dataset) * num_return_sequences)
-                            - samples_seen
+                            : (total_lines * num_return_sequences) - samples_seen
                         ]
                     else:
                         samples_seen += len(tgt_text)
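
accelerator.gather pads the final batch so all ranks contribute the same number of samples, which can duplicate inputs; the new condition computes the index of the last step each process actually runs, and only then trims the surplus. A worked example of the arithmetic with illustrative numbers:

import math

total_lines, batch_size, num_processes = 10, 4, 2

batches_total = math.ceil(total_lines / batch_size)        # 3 batches in the file
steps_per_rank = math.ceil(batches_total / num_processes)  # 2 steps per process
last_step = steps_per_rank - 1                             # truncate at step 1

print(batches_total, steps_per_rank, last_step)  # -> 3 2 1

On that last step the decoded text is sliced to (total_lines * num_return_sequences) - samples_seen, dropping whatever the gather padding duplicated.
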