Spaces:
Build error
Build error
PeteBleackley
committed on
Commit
·
f5599c3
1
Parent(s):
4a7707c
Ensure consistency of device assignment when training
Browse files
- qarac/corpora/CombinedCorpus.py +4 -3
- scripts.py +2 -1
qarac/corpora/CombinedCorpus.py
CHANGED
@@ -58,6 +58,7 @@ class CombinedCorpus(torch.utils.data.IterableDataset):
|
|
58 |
{},
|
59 |
'consistency'),
|
60 |
n_samples)
|
|
|
61 |
self.batches = None
|
62 |
self.pad_token = tokenizer.token_to_id('<pad>')
|
63 |
self.max_lengths = {}
|
@@ -145,11 +146,11 @@ class CombinedCorpus(torch.utils.data.IterableDataset):
|
|
145 |
|
146 |
X={key:self.pad(value,self.max_lengths[key])
|
147 |
for (key,value) in X.items()}
|
148 |
-
Y={key:torch.tensor(value,device=
|
149 |
self.max_lengths[key],
|
150 |
False)
|
151 |
for (key,value) in Y.items()}
|
152 |
-
Y['question_answering'] = torch.zeros((n,768),device=
|
153 |
return (X,
|
154 |
tuple([Y[key]
|
155 |
for key in ('encode_decode',
|
@@ -176,7 +177,7 @@ class CombinedCorpus(torch.utils.data.IterableDataset):
|
|
176 |
sample.pad(maxlen,pad_id=self.pad_token)
|
177 |
input_ids = torch.tensor([sample.ids
|
178 |
for sample in batch],
|
179 |
-
device=
|
180 |
result = input_ids
|
181 |
if inputs:
|
182 |
attention_mask = torch.not_equal(input_ids,
|
|
|
58 |
{},
|
59 |
'consistency'),
|
60 |
n_samples)
|
61 |
+
self.device = kwargs['device']
|
62 |
self.batches = None
|
63 |
self.pad_token = tokenizer.token_to_id('<pad>')
|
64 |
self.max_lengths = {}
|
|
|
146 |
|
147 |
X={key:self.pad(value,self.max_lengths[key])
|
148 |
for (key,value) in X.items()}
|
149 |
+
Y={key:torch.tensor(value,device=self.device).float() if key=='consistency' else self.pad(value,
|
150 |
self.max_lengths[key],
|
151 |
False)
|
152 |
for (key,value) in Y.items()}
|
153 |
+
Y['question_answering'] = torch.zeros((n,768),device=self.device)
|
154 |
return (X,
|
155 |
tuple([Y[key]
|
156 |
for key in ('encode_decode',
|
|
|
177 |
sample.pad(maxlen,pad_id=self.pad_token)
|
178 |
input_ids = torch.tensor([sample.ids
|
179 |
for sample in batch],
|
180 |
+
device=self.device)
|
181 |
result = input_ids
|
182 |
if inputs:
|
183 |
attention_mask = torch.not_equal(input_ids,
|
scripts.py
CHANGED
@@ -131,7 +131,8 @@ def train_models(path,progress=gradio.Progress(track_tqdm=True)):
|
|
131 |
all_text='corpora/all_text.csv',
|
132 |
question_answering='corpora/question_answering.csv',
|
133 |
reasoning='corpora/reasoning_train.csv',
|
134 |
-
consistency='corpora/consistency.csv'
|
|
|
135 |
n_batches = len(training_data)
|
136 |
history = {}
|
137 |
for epoch in range(25):
|
|
|
131 |
all_text='corpora/all_text.csv',
|
132 |
question_answering='corpora/question_answering.csv',
|
133 |
reasoning='corpora/reasoning_train.csv',
|
134 |
+
consistency='corpora/consistency.csv',
|
135 |
+
device=trainer.device())
|
136 |
n_batches = len(training_data)
|
137 |
history = {}
|
138 |
for epoch in range(25):
|