boris committed
Commit
9db361a
1 Parent(s): d61405b

feat: simplify loss function

Files changed (1)
  1. seq2seq/run_seq2seq_flax.py +8 -22
seq2seq/run_seq2seq_flax.py CHANGED
@@ -639,33 +639,19 @@ def main():
     state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)
 
     # label smoothed cross entropy
-    def loss_fn(logits, labels, label_smoothing_factor=0.0):
-        """
-        The label smoothing implementation is adapted from Flax's official example:
-        https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
-        """
-        vocab_size = logits.shape[-1]
-        confidence = 1.0 - label_smoothing_factor
-        low_confidence = (1.0 - confidence) / (vocab_size - 1)
-        normalizing_constant = -(
-            confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
-        )
-        soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
-
-        loss = optax.softmax_cross_entropy(logits, soft_labels)
-        loss = loss - normalizing_constant
-
+    def loss_fn(logits, labels):
+        loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
         loss = loss.mean()
         return loss
 
     # Define gradient update step fn
-    def train_step(state, batch, label_smoothing_factor=0.0):
+    def train_step(state, batch):
         dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
 
         def compute_loss(params):
             labels = batch.pop("labels")
             logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
-            loss = loss_fn(logits, labels, label_smoothing_factor)
+            loss = loss_fn(logits, labels)
             return loss
 
         grad_fn = jax.value_and_grad(compute_loss)
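With label_smoothing_factor at its default of 0.0, the removed loss_fn reduces exactly to the plain cross entropy that replaces it: confidence becomes 1.0, low_confidence and the normalizing constant become 0, and the soft labels degenerate to a one-hot encoding. A minimal equivalence sketch, not taken from the script; it assumes onehot is flax.training.common_utils.onehot, as in the Flax seq2seq examples:

import jax.numpy as jnp
import optax
from flax.training.common_utils import onehot

def smoothed_loss(logits, labels, label_smoothing_factor=0.0):
    # removed implementation, reproduced for comparison
    vocab_size = logits.shape[-1]
    confidence = 1.0 - label_smoothing_factor
    low_confidence = (1.0 - confidence) / (vocab_size - 1)
    normalizing_constant = -(
        confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
    )
    soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
    return (optax.softmax_cross_entropy(logits, soft_labels) - normalizing_constant).mean()

def simple_loss(logits, labels):
    # new implementation
    return optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean()

logits = jnp.array([[2.0, -1.0, 0.5], [0.1, 1.2, -0.3]])
labels = jnp.array([0, 1])
assert jnp.allclose(smoothed_loss(logits, labels, 0.0), simple_loss(logits, labels))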
@@ -680,10 +666,10 @@ def main():
         return new_state, metrics
 
     # Define eval fn
-    def eval_step(params, batch, label_smoothing_factor=0.0):
+    def eval_step(params, batch):
         labels = batch.pop("labels")
         logits = model(**batch, params=params, train=False)[0]
-        loss = loss_fn(logits, labels, label_smoothing_factor)
+        loss = loss_fn(logits, labels)
 
         # summarize metrics
         metrics = {"loss": loss}
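The hunk ends at the metrics dict, but because eval_step is wrapped in jax.pmap(..., "batch") further down, the per-device loss is normally averaged over that axis before being returned. A hedged sketch of that continuation, assuming the usual jax.lax.pmean pattern from the Flax/Transformers seq2seq examples (model and loss_fn come from the enclosing main()):

    def eval_step(params, batch):
        labels = batch.pop("labels")
        logits = model(**batch, params=params, train=False)[0]
        loss = loss_fn(logits, labels)

        # summarize metrics
        metrics = {"loss": loss}
        # assumed continuation: average the scalar metrics across devices on the pmap axis
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return metrics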
@@ -704,9 +690,9 @@ def main():
 
     # Create parallel version of the train and eval step
     p_train_step = jax.pmap(
-        partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0,)
+        train_step, "batch", donate_argnums=(0,)
     )
-    p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch")
+    p_eval_step = jax.pmap(eval_step, "batch")
     p_generate_step = jax.pmap(generate_step, "batch")
 
     # Replicate the train state on each device
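Dropping the label_smoothing_factor argument is what lets train_step and eval_step be handed to jax.pmap directly instead of through functools.partial, and donate_argnums=(0,) still donates the old TrainState buffers to the update. A minimal usage sketch, not from the script; train_loader and eval_batch are hypothetical placeholders, and replicate/shard are the standard Flax helpers:

import jax
from flax.jax_utils import replicate
from flax.training.common_utils import shard

state = replicate(state)                                      # copy the TrainState onto every local device
for batch in train_loader:                                    # hypothetical iterator yielding dicts of numpy arrays
    batch = shard(batch)                                      # reshape leaves to [n_devices, per_device_batch, ...]
    state, train_metrics = p_train_step(state, batch)         # donated state buffers are reused by XLA
eval_metrics = p_eval_step(state.params, shard(eval_batch))   # eval_batch: hypothetical held-out batch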
 