Iker committed on
Commit
feed195
2 Parent(s): 60d3e46 0af8cb0

Merge pull request #3 from ikergarcia1996/multigpu-bug

Browse files
Files changed (2) hide show
  1. dataset.py +8 -7
  2. translate.py +4 -0
dataset.py CHANGED
@@ -14,6 +14,8 @@ class DatasetReader(IterableDataset):
14
  self.current_line = 0
15
  self.total_lines = count_lines(filename)
16
  print(f"{self.total_lines} lines in {filename}")
 
 
17
 
18
  def preprocess(self, text: str):
19
  self.current_line += 1
@@ -29,9 +31,7 @@ class DatasetReader(IterableDataset):
29
  )
30
 
31
  def __iter__(self):
32
- file_itr = open(self.filename, "r", encoding="utf8")
33
- mapped_itr = map(self.preprocess, file_itr)
34
- return mapped_itr
35
 
36
  def __len__(self):
37
  return self.total_lines
@@ -50,6 +50,10 @@ class ParallelTextReader(IterableDataset):
50
  self.num_sentences = gold_path_lines
51
  self.current_line = 0
52
 
 
 
 
 
53
  def preprocess(self, pred: str, gold: str):
54
  self.current_line += 1
55
  pred = pred.rstrip().strip()
@@ -61,10 +65,7 @@ class ParallelTextReader(IterableDataset):
61
  return pred, [gold]
62
 
63
  def __iter__(self):
64
- pred_itr = open(self.pred_path, "r", encoding="utf8")
65
- gold_itr = open(self.gold_path, "r", encoding="utf8")
66
- mapped_itr = map(self.preprocess, pred_itr, gold_itr)
67
- return mapped_itr
68
 
69
  def __len__(self):
70
  return self.num_sentences
 
14
  self.current_line = 0
15
  self.total_lines = count_lines(filename)
16
  print(f"{self.total_lines} lines in {filename}")
17
+ file_itr = open(self.filename, "r", encoding="utf8")
18
+ self.mapped_itr = map(self.preprocess, file_itr)
19
 
20
  def preprocess(self, text: str):
21
  self.current_line += 1
 
31
  )
32
 
33
  def __iter__(self):
34
+ return self.mapped_itr
 
 
35
 
36
  def __len__(self):
37
  return self.total_lines
 
50
  self.num_sentences = gold_path_lines
51
  self.current_line = 0
52
 
53
+ pred_itr = open(self.pred_path, "r", encoding="utf8")
54
+ gold_itr = open(self.gold_path, "r", encoding="utf8")
55
+ self.mapped_itr = map(self.preprocess, pred_itr, gold_itr)
56
+
57
  def preprocess(self, pred: str, gold: str):
58
  self.current_line += 1
59
  pred = pred.rstrip().strip()
 
65
  return pred, [gold]
66
 
67
  def __iter__(self):
68
+ return self.mapped_itr
 
 
 
69
 
70
  def __len__(self):
71
  return self.num_sentences
translate.py CHANGED
@@ -19,6 +19,10 @@ from dataset import DatasetReader, count_lines
19
 
20
  from accelerate import Accelerator, DistributedType, find_executable_batch_size
21
 
 
 
 
 
22
 
23
  def get_dataloader(
24
  accelerator: Accelerator,
 
19
 
20
  from accelerate import Accelerator, DistributedType, find_executable_batch_size
21
 
22
+ torch.multiprocessing.set_sharing_strategy(
23
+ "file_system"
24
+ ) # FIXES RuntimeError: Too many open files.
25
+
26
 
27
  def get_dataloader(
28
  accelerator: Accelerator,