Spaces:
Build error
Build error
PeteBleackley
commited on
Commit
·
5f107ac
1
Parent(s):
269f149
Moved device assignement to beginning of function
Browse files- scripts.py +2 -1
scripts.py
CHANGED
@@ -118,11 +118,12 @@ def prepare_training_datasets():
|
|
118 |
consistency.to_csv('corpora/consistency.csv')
|
119 |
|
120 |
def train_models(path,progress=gradio.Progress(track_tqdm=True)):
|
|
|
121 |
torch.cuda.empty_cache()
|
122 |
tokenizer = tokenizers.Tokenizer.from_pretrained('roberta-base')
|
123 |
trainer = qarac.models.QaracTrainerModel.QaracTrainerModel('roberta-base',
|
124 |
tokenizer)
|
125 |
-
|
126 |
trainer.to(device)
|
127 |
loss_fn = CombinedLoss()
|
128 |
loss_fn.cuda()
|
|
|
118 |
consistency.to_csv('corpora/consistency.csv')
|
119 |
|
120 |
def train_models(path,progress=gradio.Progress(track_tqdm=True)):
|
121 |
+
device = torch.device('cuda:0')
|
122 |
torch.cuda.empty_cache()
|
123 |
tokenizer = tokenizers.Tokenizer.from_pretrained('roberta-base')
|
124 |
trainer = qarac.models.QaracTrainerModel.QaracTrainerModel('roberta-base',
|
125 |
tokenizer)
|
126 |
+
|
127 |
trainer.to(device)
|
128 |
loss_fn = CombinedLoss()
|
129 |
loss_fn.cuda()
|