PeteBleackley commited on
Commit
f8debfa
·
1 Parent(s): f5599c3

Ensure consistency of device assignment when training

Browse files
Files changed (1) hide show
  1. scripts.py +3 -2
scripts.py CHANGED
@@ -122,7 +122,8 @@ def train_models(path,progress=gradio.Progress(track_tqdm=True)):
122
  tokenizer = tokenizers.Tokenizer.from_pretrained('roberta-base')
123
  trainer = qarac.models.QaracTrainerModel.QaracTrainerModel('roberta-base',
124
  tokenizer)
125
- trainer.cuda()
 
126
  loss_fn = CombinedLoss()
127
  loss_fn.cuda()
128
  optimizer = torch.optim.NAdam(trainer.parameters(),lr=5.0e-5)
@@ -132,7 +133,7 @@ def train_models(path,progress=gradio.Progress(track_tqdm=True)):
132
  question_answering='corpora/question_answering.csv',
133
  reasoning='corpora/reasoning_train.csv',
134
  consistency='corpora/consistency.csv',
135
- device=trainer.device())
136
  n_batches = len(training_data)
137
  history = {}
138
  for epoch in range(25):
 
122
  tokenizer = tokenizers.Tokenizer.from_pretrained('roberta-base')
123
  trainer = qarac.models.QaracTrainerModel.QaracTrainerModel('roberta-base',
124
  tokenizer)
125
+ device = torch.device('cude:0')
126
+ trainer.to(device)
127
  loss_fn = CombinedLoss()
128
  loss_fn.cuda()
129
  optimizer = torch.optim.NAdam(trainer.parameters(),lr=5.0e-5)
 
133
  question_answering='corpora/question_answering.csv',
134
  reasoning='corpora/reasoning_train.csv',
135
  consistency='corpora/consistency.csv',
136
+ device=device)
137
  n_batches = len(training_data)
138
  history = {}
139
  for epoch in range(25):