boris committed on
Commit
19946be
1 Parent(s): bf4da91

fix: correct decoder_input_ids and labels

Browse files
Files changed (1) hide show
  1. seq2seq/run_seq2seq_flax.py +7 -12
seq2seq/run_seq2seq_flax.py CHANGED
@@ -475,25 +475,20 @@ def main():
475
  )
476
 
477
  # set up targets
478
- # Note: we prepend the bos token instead of doing `shift_tokens_right` because the latter
479
- # removes the last token, and we know we don't need padding. In our case, labels
480
- # has a length of exactly 1 + 256, while shifting would produce 256 tokens.
481
- labels = [[config.decoder_start_token_id] + eval(indices) for indices in examples['encoding']]
482
  labels = np.asarray(labels)
483
 
484
  # We need the labels, in addition to the decoder_input_ids, for the compute_loss function
485
- # In our case, they are the same as decoder_input_ids. Is that correct?
486
  model_inputs["labels"] = labels
487
 
488
- # TODO: if data processing prevents correct compilation, we will:
489
- # - have data saved in JSONL (to avoid `eval` which is needed here to convert string "[2]" to list[int])
490
- # - use below `shift_tokens_right_fn`
491
  # In our case, this prepends the bos token and removes the last one
492
- # decoder_input_ids = shift_tokens_right_fn(
493
- # jnp.array(labels), config.pad_token_id, config.decoder_start_token_id
494
- # )
495
 
496
- model_inputs["decoder_input_ids"] = labels
497
 
498
  return model_inputs
499
 
 
475
  )
476
 
477
  # set up targets
478
+ # Note: labels correspond to our target indices
479
+ # decoder input ids are the same but shifted to the right with bos at the beginning (and without last token)
480
+ labels = [eval(indices) for indices in examples['encoding']]
 
481
  labels = np.asarray(labels)
482
 
483
  # We need the labels, in addition to the decoder_input_ids, for the compute_loss function
 
484
  model_inputs["labels"] = labels
485
 
 
 
 
486
  # In our case, this prepends the bos token and removes the last one
487
+ decoder_input_ids = shift_tokens_right_fn(
488
+ jnp.array(labels), config.pad_token_id, config.decoder_start_token_id
489
+ )
490
 
491
+ model_inputs["decoder_input_ids"] = decoder_input_ids
492
 
493
  return model_inputs
494