boris commited on
Commit
ebac379
1 Parent(s): da9367c

fix: position embedding for generate method

Browse files
Files changed (1) hide show
  1. src/dalle_mini/model/modeling.py +39 -1
src/dalle_mini/model/modeling.py CHANGED
@@ -371,7 +371,8 @@ class FlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationMod
371
  def setup(self):
372
  self.model = FlaxBartModule(config=self.config, dtype=self.dtype)
373
  self.lm_head = nn.Dense(
374
- self.config.image_vocab_size + 1, # image vocab size + 1 for BOS
 
375
  use_bias=False,
376
  dtype=self.dtype,
377
  kernel_init=jax.nn.initializers.normal(self.config.init_std),
@@ -437,6 +438,8 @@ class DalleBart(
437
  - uses custom FlaxBartPreTrainedModel
438
  - uses custom FlaxBartForConditionalGenerationModule
439
  - no bias in decode method
 
 
440
  """
441
 
442
  module_class = FlaxBartForConditionalGenerationModule
@@ -572,3 +575,38 @@ class DalleBart(
572
  outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
573
 
574
  return outputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
371
  def setup(self):
372
  self.model = FlaxBartModule(config=self.config, dtype=self.dtype)
373
  self.lm_head = nn.Dense(
374
+ self.config.image_vocab_size
375
+ + 1, # image vocab size + 1 for BOS to have same size as decoder inputs (for sharding)
376
  use_bias=False,
377
  dtype=self.dtype,
378
  kernel_init=jax.nn.initializers.normal(self.config.init_std),
 
438
  - uses custom FlaxBartPreTrainedModel
439
  - uses custom FlaxBartForConditionalGenerationModule
440
  - no bias in decode method
441
+ - custom prepare_inputs_for_generation using "max_length - 1" to avoid issues
442
+ related to position embedding during model.generate()
443
  """
444
 
445
  module_class = FlaxBartForConditionalGenerationModule
 
575
  outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
576
 
577
  return outputs
578
+
579
+ def prepare_inputs_for_generation(
580
+ self,
581
+ decoder_input_ids,
582
+ max_length,
583
+ attention_mask: Optional[jnp.DeviceArray] = None,
584
+ decoder_attention_mask: Optional[jnp.DeviceArray] = None,
585
+ encoder_outputs=None,
586
+ **kwargs,
587
+ ):
588
+ # initializing the cache
589
+ batch_size, seq_length = decoder_input_ids.shape
590
+
591
+ past_key_values = self.init_cache(batch_size, max_length - 1, encoder_outputs)
592
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
593
+ # But since the decoder uses a causal mask, those positions are masked anyways.
594
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation
595
+ extended_attention_mask = jnp.ones((batch_size, max_length - 1), dtype="i4")
596
+ if decoder_attention_mask is not None:
597
+ position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
598
+ extended_attention_mask = lax.dynamic_update_slice(
599
+ extended_attention_mask, decoder_attention_mask, (0, 0)
600
+ )
601
+ else:
602
+ position_ids = jnp.broadcast_to(
603
+ jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
604
+ )
605
+
606
+ return {
607
+ "past_key_values": past_key_values,
608
+ "encoder_outputs": encoder_outputs,
609
+ "encoder_attention_mask": attention_mask,
610
+ "decoder_attention_mask": extended_attention_mask,
611
+ "decoder_position_ids": position_ids,
612
+ }