Pedro Cuenca committed
Commit ecafe5e
1 Parent(s): 32dc2d8

* Make padding mask optional.
* Perform preprocessing in parallel.

Files changed (2):
  1. seq2seq/do_run.sh +3 -2
  2. seq2seq/run_seq2seq_flax.py +7 -4
seq2seq/do_run.sh CHANGED
@@ -3,7 +3,8 @@ python run_seq2seq_flax.py \
     --train_file /data/CC12M/encoded-small-train.tsv \
     --validation_file /data/CC12M/encoded-small-valid.tsv \
     --output_dir output \
-    --per_device_train_batch_size 16 \
-    --per_device_eval_batch_size 16 \
+    --per_device_train_batch_size 24 \
+    --per_device_eval_batch_size 24 \
+    --preprocessing_num_workers 48 \
     --do_train \
     --do_eval \
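
Note on the new flag: the second commit bullet ("Perform preprocessing in parallel") presumably means --preprocessing_num_workers is forwarded to datasets' map(num_proc=...) inside run_seq2seq_flax.py, as in the upstream Hugging Face seq2seq examples. A minimal sketch under that assumption; the data file path and the preprocess_function body are placeholders, not taken from the script:

from datasets import load_dataset

def preprocess_function(examples):
    # tokenize / encode a batch of rows; details omitted in this sketch
    return examples

dataset = load_dataset(
    "csv",
    data_files={"train": "/data/CC12M/encoded-small-train.tsv"},
    delimiter="\t",
)

train_dataset = dataset["train"].map(
    preprocess_function,
    batched=True,
    num_proc=48,  # value of --preprocessing_num_workers: run preprocessing in 48 worker processes
)

With num_proc unset, datasets runs the map in a single process; raising it only changes wall-clock time, not the resulting dataset.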
seq2seq/run_seq2seq_flax.py CHANGED
@@ -482,8 +482,6 @@ def main():
 
         # We need decoder_attention_mask so we can ignore pad tokens from loss
         # TODO: I don't believe we need "decoder_attention_mask" in this case because all labels have same length
-        # However, we need to provide a mask or modify the compute_loss function, which relies on having one
-        model_inputs["decoder_attention_mask"] = np.ones(labels.shape)
         #model_inputs["decoder_attention_mask"] = labels["attention_mask"]
 
         return model_inputs
@@ -647,6 +645,9 @@ def main():
            loss = optax.softmax_cross_entropy(logits, soft_labels)
            loss = loss - normalizing_constant
 
+           if padding_mask is None:
+               padding_mask = np.ones(loss.shape)
+
            # ignore padded tokens from loss
            loss = loss * padding_mask
            loss = loss.sum() / padding_mask.sum()
@@ -659,7 +660,8 @@ def main():
        def compute_loss(params):
            labels = batch.pop("labels")
            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
-           loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
+           padding_mask = batch.get("decoder_attention_mask", None)
+           loss = loss_fn(logits, labels, padding_mask, label_smoothing_factor)
            return loss
 
        grad_fn = jax.value_and_grad(compute_loss)
@@ -677,7 +679,8 @@ def main():
    def eval_step(params, batch, label_smoothing_factor=0.0):
        labels = batch.pop("labels")
        logits = model(**batch, params=params, train=False)[0]
-       loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
+       padding_mask = batch.get("decoder_attention_mask", None)
+       loss = loss_fn(logits, labels, padding_mask, label_smoothing_factor)
 
        # summarize metrics
        metrics = {"loss": loss}
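
The Python-side change makes the decoder attention mask optional end to end: preprocessing no longer fabricates an all-ones mask, compute_loss and eval_step look it up with batch.get(..., None), and loss_fn substitutes ones itself when it receives None. The sketch below is not the script's loss_fn verbatim (it drops the label-smoothing normalizing constant) but shows the same optional-mask handling:

import numpy as np
import optax

def masked_cross_entropy(logits, onehot_labels, padding_mask=None):
    # per-token loss, shape (batch, seq_len)
    loss = optax.softmax_cross_entropy(logits, onehot_labels)
    if padding_mask is None:
        # no mask provided (e.g. all labels have the same length): weight every token equally
        padding_mask = np.ones(loss.shape)
    loss = loss * padding_mask
    return loss.sum() / padding_mask.sum()

logits = np.random.randn(2, 5, 10)                        # (batch, seq_len, vocab)
onehot = np.eye(10)[np.random.randint(0, 10, size=(2, 5))]
assert np.allclose(
    masked_cross_entropy(logits, onehot),                  # mask omitted -> plain mean
    masked_cross_entropy(logits, onehot, np.ones((2, 5))), # explicit all-ones mask
)

Because an all-ones mask makes loss.sum() / padding_mask.sum() equal to the unweighted mean, dropping the mask when every label sequence has the same length changes nothing numerically; it just avoids building and shipping a redundant array with each batch.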