Spaces:
Build error
Build error
PeteBleackley
commited on
Commit
·
f8debfa
1
Parent(s):
f5599c3
Ensure consistency of device assignment when training
Browse files- scripts.py +3 -2
scripts.py
CHANGED
@@ -122,7 +122,8 @@ def train_models(path,progress=gradio.Progress(track_tqdm=True)):
|
|
122 |
tokenizer = tokenizers.Tokenizer.from_pretrained('roberta-base')
|
123 |
trainer = qarac.models.QaracTrainerModel.QaracTrainerModel('roberta-base',
|
124 |
tokenizer)
|
125 |
-
|
|
|
126 |
loss_fn = CombinedLoss()
|
127 |
loss_fn.cuda()
|
128 |
optimizer = torch.optim.NAdam(trainer.parameters(),lr=5.0e-5)
|
@@ -132,7 +133,7 @@ def train_models(path,progress=gradio.Progress(track_tqdm=True)):
|
|
132 |
question_answering='corpora/question_answering.csv',
|
133 |
reasoning='corpora/reasoning_train.csv',
|
134 |
consistency='corpora/consistency.csv',
|
135 |
-
device=
|
136 |
n_batches = len(training_data)
|
137 |
history = {}
|
138 |
for epoch in range(25):
|
|
|
122 |
tokenizer = tokenizers.Tokenizer.from_pretrained('roberta-base')
|
123 |
trainer = qarac.models.QaracTrainerModel.QaracTrainerModel('roberta-base',
|
124 |
tokenizer)
|
125 |
+
device = torch.device('cude:0')
|
126 |
+
trainer.to(device)
|
127 |
loss_fn = CombinedLoss()
|
128 |
loss_fn.cuda()
|
129 |
optimizer = torch.optim.NAdam(trainer.parameters(),lr=5.0e-5)
|
|
|
133 |
question_answering='corpora/question_answering.csv',
|
134 |
reasoning='corpora/reasoning_train.csv',
|
135 |
consistency='corpora/consistency.csv',
|
136 |
+
device=device)
|
137 |
n_batches = len(training_data)
|
138 |
history = {}
|
139 |
for epoch in range(25):
|