PeteBleackley commited on
Commit
f5599c3
·
1 Parent(s): 4a7707c

Ensure consistency of device assignment when training

Browse files
Files changed (2) hide show
  1. qarac/corpora/CombinedCorpus.py +4 -3
  2. scripts.py +2 -1
qarac/corpora/CombinedCorpus.py CHANGED
@@ -58,6 +58,7 @@ class CombinedCorpus(torch.utils.data.IterableDataset):
58
  {},
59
  'consistency'),
60
  n_samples)
 
61
  self.batches = None
62
  self.pad_token = tokenizer.token_to_id('<pad>')
63
  self.max_lengths = {}
@@ -145,11 +146,11 @@ class CombinedCorpus(torch.utils.data.IterableDataset):
145
 
146
  X={key:self.pad(value,self.max_lengths[key])
147
  for (key,value) in X.items()}
148
- Y={key:torch.tensor(value,device='cuda').float() if key=='consistency' else self.pad(value,
149
  self.max_lengths[key],
150
  False)
151
  for (key,value) in Y.items()}
152
- Y['question_answering'] = torch.zeros((n,768),device='cuda')
153
  return (X,
154
  tuple([Y[key]
155
  for key in ('encode_decode',
@@ -176,7 +177,7 @@ class CombinedCorpus(torch.utils.data.IterableDataset):
176
  sample.pad(maxlen,pad_id=self.pad_token)
177
  input_ids = torch.tensor([sample.ids
178
  for sample in batch],
179
- device='cuda')
180
  result = input_ids
181
  if inputs:
182
  attention_mask = torch.not_equal(input_ids,
 
58
  {},
59
  'consistency'),
60
  n_samples)
61
+ self.device = kwargs['device']
62
  self.batches = None
63
  self.pad_token = tokenizer.token_to_id('<pad>')
64
  self.max_lengths = {}
 
146
 
147
  X={key:self.pad(value,self.max_lengths[key])
148
  for (key,value) in X.items()}
149
+ Y={key:torch.tensor(value,device=self.device).float() if key=='consistency' else self.pad(value,
150
  self.max_lengths[key],
151
  False)
152
  for (key,value) in Y.items()}
153
+ Y['question_answering'] = torch.zeros((n,768),device=self.device)
154
  return (X,
155
  tuple([Y[key]
156
  for key in ('encode_decode',
 
177
  sample.pad(maxlen,pad_id=self.pad_token)
178
  input_ids = torch.tensor([sample.ids
179
  for sample in batch],
180
+ device=self.device)
181
  result = input_ids
182
  if inputs:
183
  attention_mask = torch.not_equal(input_ids,
scripts.py CHANGED
@@ -131,7 +131,8 @@ def train_models(path,progress=gradio.Progress(track_tqdm=True)):
131
  all_text='corpora/all_text.csv',
132
  question_answering='corpora/question_answering.csv',
133
  reasoning='corpora/reasoning_train.csv',
134
- consistency='corpora/consistency.csv')
 
135
  n_batches = len(training_data)
136
  history = {}
137
  for epoch in range(25):
 
131
  all_text='corpora/all_text.csv',
132
  question_answering='corpora/question_answering.csv',
133
  reasoning='corpora/reasoning_train.csv',
134
+ consistency='corpora/consistency.csv',
135
+ device=trainer.device())
136
  n_batches = len(training_data)
137
  history = {}
138
  for epoch in range(25):