Pedro Cuenca committed
Commit ecafe5e • 1 Parent(s): 32dc2d8

* Make padding mask optional.
* Perform preprocessing in parallel.
- seq2seq/do_run.sh +3 -2
- seq2seq/run_seq2seq_flax.py +7 -4
seq2seq/do_run.sh
CHANGED
@@ -3,7 +3,8 @@ python run_seq2seq_flax.py \
     --train_file /data/CC12M/encoded-small-train.tsv \
     --validation_file /data/CC12M/encoded-small-valid.tsv \
     --output_dir output \
-    --per_device_train_batch_size
-    --per_device_eval_batch_size
+    --per_device_train_batch_size 24 \
+    --per_device_eval_batch_size 24 \
+    --preprocessing_num_workers 48 \
     --do_train \
     --do_eval \
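The new --preprocessing_num_workers flag is what carries the "perform preprocessing in parallel" part of this commit. Below is a minimal sketch of how such a flag typically takes effect, assuming run_seq2seq_flax.py follows the usual Hugging Face seq2seq example pattern of forwarding it to datasets.Dataset.map(num_proc=...); the preprocess_function body and dataset handling here are illustrative placeholders, not the script's actual code.

from datasets import load_dataset

# Load the TSV produced by the encoding step (path taken from do_run.sh).
raw_datasets = load_dataset(
    "csv",
    data_files={"train": "/data/CC12M/encoded-small-train.tsv"},
    delimiter="\t",
)

def preprocess_function(examples):
    # Placeholder: the real script builds model_inputs (input_ids, labels, ...)
    # from each batch of rows here.
    return examples

train_dataset = raw_datasets["train"].map(
    preprocess_function,
    batched=True,
    num_proc=48,  # value of --preprocessing_num_workers; runs 48 worker processes
    desc="Running preprocessing on the train split",
)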
seq2seq/run_seq2seq_flax.py
CHANGED
@@ -482,8 +482,6 @@ def main():
 
         # We need decoder_attention_mask so we can ignore pad tokens from loss
         # TODO: I don't believe we need "decoder_attention_mask" in this case because all labels have same length
-        # However, we need to provide a mask or modify the compute_loss function, which relies on having one
-        model_inputs["decoder_attention_mask"] = np.ones(labels.shape)
         #model_inputs["decoder_attention_mask"] = labels["attention_mask"]
 
         return model_inputs
@@ -647,6 +645,9 @@ def main():
         loss = optax.softmax_cross_entropy(logits, soft_labels)
         loss = loss - normalizing_constant
 
+        if padding_mask is None:
+            padding_mask = np.ones(loss.shape)
+
         # ignore padded tokens from loss
         loss = loss * padding_mask
         loss = loss.sum() / padding_mask.sum()
@@ -659,7 +660,8 @@ def main():
         def compute_loss(params):
             labels = batch.pop("labels")
             logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
-
+            padding_mask = batch.get("decoder_attention_mask", None)
+            loss = loss_fn(logits, labels, padding_mask, label_smoothing_factor)
             return loss
 
         grad_fn = jax.value_and_grad(compute_loss)
@@ -677,7 +679,8 @@ def main():
     def eval_step(params, batch, label_smoothing_factor=0.0):
         labels = batch.pop("labels")
         logits = model(**batch, params=params, train=False)[0]
-
+        padding_mask = batch.get("decoder_attention_mask", None)
+        loss = loss_fn(logits, labels, padding_mask, label_smoothing_factor)
 
         # summarize metrics
         metrics = {"loss": loss}
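The core of "make padding mask optional" is the None check added inside the loss function plus the batch.get("decoder_attention_mask", None) lookups in the train and eval steps. Here is a minimal, self-contained sketch of that idea, assuming (batch, seq_len, vocab) logits and integer labels; the real loss_fn in run_seq2seq_flax.py additionally applies label smoothing (the soft_labels / normalizing_constant terms in the diff), which is omitted here for brevity.

import jax
import jax.numpy as jnp
import optax

def cross_entropy_loss(logits, labels, padding_mask=None):
    # logits: (batch, seq_len, vocab_size); labels: (batch, seq_len) token ids;
    # padding_mask: optional (batch, seq_len) array, 1 for real tokens, 0 for padding.
    one_hot_labels = jax.nn.one_hot(labels, logits.shape[-1])
    loss = optax.softmax_cross_entropy(logits, one_hot_labels)  # per-token loss, (batch, seq_len)

    # When the batch carries no decoder_attention_mask (all labels have the
    # same length), fall back to a mask of ones so every position counts.
    if padding_mask is None:
        padding_mask = jnp.ones(loss.shape)

    loss = loss * padding_mask              # zero out padded positions
    return loss.sum() / padding_mask.sum()  # mean over non-padded tokens

# In a train/eval step the mask stays optional, mirroring the diff above:
# padding_mask = batch.get("decoder_attention_mask", None)
# loss = cross_entropy_loss(logits, labels, padding_mask)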