fix: position embedding for generate method
src/dalle_mini/model/modeling.py (CHANGED)
@@ -371,7 +371,8 @@ class FlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationMod
     def setup(self):
         self.model = FlaxBartModule(config=self.config, dtype=self.dtype)
         self.lm_head = nn.Dense(
-            self.config.image_vocab_size + 1,  # image vocab size + 1 for BOS
+            self.config.image_vocab_size
+            + 1,  # image vocab size + 1 for BOS to have same size as decoder inputs (for sharding)
             use_bias=False,
             dtype=self.dtype,
             kernel_init=jax.nn.initializers.normal(self.config.init_std),

@@ -437,6 +438,8 @@ class DalleBart(
     - uses custom FlaxBartPreTrainedModel
     - uses custom FlaxBartForConditionalGenerationModule
     - no bias in decode method
+    - custom prepare_inputs_for_generation using "max_length - 1" to avoid issues
+      related to position embedding during model.generate()
     """

     module_class = FlaxBartForConditionalGenerationModule

@@ -572,3 +575,38 @@ class DalleBart(
             outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]

         return outputs
+
+    def prepare_inputs_for_generation(
+        self,
+        decoder_input_ids,
+        max_length,
+        attention_mask: Optional[jnp.DeviceArray] = None,
+        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        encoder_outputs=None,
+        **kwargs,
+    ):
+        # initializing the cache
+        batch_size, seq_length = decoder_input_ids.shape
+
+        past_key_values = self.init_cache(batch_size, max_length - 1, encoder_outputs)
+        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
+        # But since the decoder uses a causal mask, those positions are masked anyways.
+        # Thus we can create a single static attention_mask here, which is more efficient for compilation
+        extended_attention_mask = jnp.ones((batch_size, max_length - 1), dtype="i4")
+        if decoder_attention_mask is not None:
+            position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
+            extended_attention_mask = lax.dynamic_update_slice(
+                extended_attention_mask, decoder_attention_mask, (0, 0)
+            )
+        else:
+            position_ids = jnp.broadcast_to(
+                jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
+            )
+
+        return {
+            "past_key_values": past_key_values,
+            "encoder_outputs": encoder_outputs,
+            "encoder_attention_mask": attention_mask,
+            "decoder_attention_mask": extended_attention_mask,
+            "decoder_position_ids": position_ids,
+        }
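A plausible reading of the "max_length - 1" choice: during model.generate() the last sampled token is appended to the output but never fed back through the decoder, so only max_length - 1 positions are ever embedded. Sizing the cache and the static attention mask with max_length - 1 therefore keeps decoder_position_ids within the learned position-embedding table. Below is a minimal, standalone sketch of the position-id and mask computation mirrored from the added prepare_inputs_for_generation; the batch size, sequence length, and max_length values are illustrative and not part of the commit.

import jax.numpy as jnp
from jax import lax

# Illustrative values (assumed, not from the repo): 256 image tokens plus BOS.
batch_size, seq_length, max_length = 2, 1, 257  # seq_length = 1: only the BOS token so far

decoder_attention_mask = jnp.ones((batch_size, seq_length), dtype="i4")

# Static mask sized max_length - 1: the final sampled token is never re-embedded,
# so the decoder only ever sees positions 0 .. max_length - 2.
extended_attention_mask = jnp.ones((batch_size, max_length - 1), dtype="i4")
extended_attention_mask = lax.dynamic_update_slice(
    extended_attention_mask, decoder_attention_mask, (0, 0)
)

# Position ids follow the cumulative sum of the mask, starting at 0 for BOS.
position_ids = decoder_attention_mask.cumsum(axis=-1) - 1

print(extended_attention_mask.shape)  # (2, 256)
print(position_ids)                   # position 0 for BOS in each batch row

Because the mask has a fixed width of max_length - 1 regardless of how many tokens have been generated, the same compiled decode step can be reused at every generation step, which is the efficiency point made in the in-diff comments.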
|