PeteBleackley commited on
Commit
5f107ac
·
1 Parent(s): 269f149

Moved device assignement to beginning of function

Browse files
Files changed (1) hide show
  1. scripts.py +2 -1
scripts.py CHANGED
@@ -118,11 +118,12 @@ def prepare_training_datasets():
118
  consistency.to_csv('corpora/consistency.csv')
119
 
120
  def train_models(path,progress=gradio.Progress(track_tqdm=True)):
 
121
  torch.cuda.empty_cache()
122
  tokenizer = tokenizers.Tokenizer.from_pretrained('roberta-base')
123
  trainer = qarac.models.QaracTrainerModel.QaracTrainerModel('roberta-base',
124
  tokenizer)
125
- device = torch.device('cuda:0')
126
  trainer.to(device)
127
  loss_fn = CombinedLoss()
128
  loss_fn.cuda()
 
118
  consistency.to_csv('corpora/consistency.csv')
119
 
120
  def train_models(path,progress=gradio.Progress(track_tqdm=True)):
121
+ device = torch.device('cuda:0')
122
  torch.cuda.empty_cache()
123
  tokenizer = tokenizers.Tokenizer.from_pretrained('roberta-base')
124
  trainer = qarac.models.QaracTrainerModel.QaracTrainerModel('roberta-base',
125
  tokenizer)
126
+
127
  trainer.to(device)
128
  loss_fn = CombinedLoss()
129
  loss_fn.cuda()