valhalla committed
Commit 180ed1e
1 Parent(s): 6197b2f

remove bias and minor fixes

Files changed (1)
  1. dalle_mini/modeling_bart_flax.py +30 -55
dalle_mini/modeling_bart_flax.py CHANGED
@@ -44,7 +44,7 @@ from transformers.modeling_flax_utils import (
 from transformers.utils import logging
 
 
-from configuration_bart import BartConfig
+from .configuration_bart import BartConfig
 
 
 logger = logging.get_logger(__name__)
@@ -80,7 +80,7 @@ class FlaxBartAttention(nn.Module):
         dense = partial(
             nn.Dense,
             self.embed_dim,
-            use_bias=self.bias,
+            use_bias=False,
             dtype=self.dtype,
             kernel_init=jax.nn.initializers.normal(self.config.init_std),
         )
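For reference, a minimal standalone sketch of the bias-free nn.Dense pattern this commit applies to the attention projections here and to the fc1/fc2 layers in the hunks below; the module name, sizes, and activation are illustrative, not taken from the repo:

from functools import partial

import jax
import jax.numpy as jnp
import flax.linen as nn


class BiasFreeFFN(nn.Module):
    # Illustrative stand-in for the fc1/fc2 pair: no bias terms, normal(init_std) kernels.
    embed_dim: int = 16
    ffn_dim: int = 64
    init_std: float = 0.02
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        dense = partial(
            nn.Dense,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.init_std),
        )
        self.fc1 = dense(self.ffn_dim)
        self.fc2 = dense(self.embed_dim)

    def __call__(self, x):
        return self.fc2(nn.gelu(self.fc1(x)))


x = jnp.ones((1, 4, 16))
params = BiasFreeFFN().init(jax.random.PRNGKey(0), x)
print(jax.tree_util.tree_map(jnp.shape, params))  # only "kernel" leaves, no "bias"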
@@ -242,10 +242,14 @@ class FlaxBartEncoderLayer(nn.Module):
         self.fc1 = nn.Dense(
             self.config.encoder_ffn_dim,
             dtype=self.dtype,
+            use_bias=False,
             kernel_init=jax.nn.initializers.normal(self.config.init_std),
         )
         self.fc2 = nn.Dense(
-            self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
+            self.embed_dim,
+            dtype=self.dtype,
+            use_bias=False,
+            kernel_init=jax.nn.initializers.normal(self.config.init_std),
         )
         self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
 
@@ -325,14 +329,18 @@ class FlaxBartDecoderLayer(nn.Module):
             dropout=self.config.attention_dropout,
             dtype=self.dtype,
         )
-        self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
+        self.encoder_attn_layer_norm = nn
         self.fc1 = nn.Dense(
             self.config.encoder_ffn_dim,
             dtype=self.dtype,
+            use_bias=False,
             kernel_init=jax.nn.initializers.normal(self.config.init_std),
         )
         self.fc2 = nn.Dense(
-            self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
+            self.embed_dim,
+            dtype=self.dtype,
+            use_bias=False,
+            kernel_init=jax.nn.initializers.normal(self.config.init_std),
         )
         self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
 
@@ -414,7 +422,6 @@ class FlaxBartDecoderLayerCollection(nn.Module):
 class FlaxBartEncoder(nn.Module):
     config: BartConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None
 
     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -424,16 +431,15 @@ class FlaxBartEncoder(nn.Module):
         self.max_source_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0
 
-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
+        self.embed_tokens = nn.Embed(
+            self.config.vocab_size,
+            embed_dim,
+            embedding_init=jax.nn.initializers.normal(self.config.init_std),
+        )
 
         # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
         # and adjust num_embeddings appropriately. Other models don't have this hack
-        self.offset = 2
+        self.offset = 0
         self.embed_positions = nn.Embed(
             self.config.max_position_embeddings + self.offset,
             embed_dim,
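With self.offset = 0, the position table no longer reserves BART's two extra slots and position ids index it directly. A small sketch of the size arithmetic, with made-up values rather than this repo's config:

import jax
import jax.numpy as jnp
import flax.linen as nn

max_position_embeddings = 64  # illustrative, not this repo's config value
embed_dim = 16
offset = 0  # previously 2 in the stock BART port

embed_positions = nn.Embed(max_position_embeddings + offset, embed_dim)
position_ids = jnp.arange(10)[None, :]  # plain 0..seq_len-1, no shift needed when offset == 0

params = embed_positions.init(jax.random.PRNGKey(0), position_ids + offset)
pos_emb = embed_positions.apply(params, position_ids + offset)
print(pos_emb.shape)  # (1, 10, 16)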
@@ -472,7 +478,6 @@ class FlaxBartEncoder(nn.Module):
 class FlaxBartDecoder(nn.Module):
     config: BartConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None
 
     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -482,18 +487,17 @@ class FlaxBartDecoder(nn.Module):
         self.max_target_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0
 
-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
+        self.embed_tokens = nn.Embed(
+            self.config.decoder_vocab_size,
+            embed_dim,
+            embedding_init=jax.nn.initializers.normal(self.config.init_std),
+        )
 
         # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
         # and adjust num_embeddings appropriately. Other models don't have this hack
-        self.offset = 2
+        self.offset = 0
         self.embed_positions = nn.Embed(
-            self.config.max_position_embeddings + self.offset,
+            self.config.decoder_max_position_embeddings + self.offset,
             embed_dim,
             embedding_init=jax.nn.initializers.normal(self.config.init_std),
         )
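After this hunk the decoder owns its token and position tables, sized by the decoder_vocab_size and decoder_max_position_embeddings fields this repo adds to its BartConfig. A rough sketch of the resulting shapes, using a stand-in dataclass and illustrative numbers rather than the real config class:

from dataclasses import dataclass

import jax
import flax.linen as nn


@dataclass
class ToyConfig:
    # Stand-in for dalle_mini's BartConfig; only the fields touched here, values illustrative.
    vocab_size: int = 50264                   # encoder-side text tokens
    decoder_vocab_size: int = 16385           # decoder-side image-code tokens
    d_model: int = 1024
    max_position_embeddings: int = 64
    decoder_max_position_embeddings: int = 256
    init_std: float = 0.02


cfg = ToyConfig()
init = jax.nn.initializers.normal(cfg.init_std)

# Encoder and decoder no longer share one embedding table, so their sizes can differ:
encoder_tokens = nn.Embed(cfg.vocab_size, cfg.d_model, embedding_init=init)
decoder_tokens = nn.Embed(cfg.decoder_vocab_size, cfg.d_model, embedding_init=init)
decoder_positions = nn.Embed(cfg.decoder_max_position_embeddings, cfg.d_model, embedding_init=init)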
@@ -546,20 +550,8 @@ class FlaxBartModule(nn.Module):
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
 
     def setup(self):
-        self.shared = nn.Embed(
-            self.config.vocab_size,
-            self.config.d_model,
-            embedding_init=jax.nn.initializers.normal(self.config.init_std),
-        )
-        # a separate embedding is used for the decoder
-        self.decoder_embed = nn.Embed(
-            self.config.decoder_vocab_size,
-            self.config.d_model,
-            embedding_init=jax.nn.initializers.normal(self.config.init_std),
-        )
-
-        self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
-        self.decoder = FlaxBartDecoder(self.config, dtype=self.dtype, embed_tokens=self.decoder_embed)
+        self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype)
+        self.decoder = FlaxBartDecoder(self.config, dtype=self.dtype)
 
     def _get_encoder_module(self):
         return self.encoder
@@ -575,8 +567,6 @@ class FlaxBartModule(nn.Module):
         decoder_attention_mask,
         position_ids,
         decoder_position_ids,
-        output_attentions: bool = False,
-        output_hidden_states: bool = False,
         return_dict: bool = True,
         deterministic: bool = True,
     ):
@@ -584,9 +574,6 @@ class FlaxBartModule(nn.Module):
             input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
             deterministic=deterministic,
         )
 
@@ -596,9 +583,6 @@ class FlaxBartModule(nn.Module):
             position_ids=decoder_position_ids,
             encoder_hidden_states=encoder_outputs[0],
             encoder_attention_mask=attention_mask,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
             deterministic=deterministic,
         )
 
@@ -629,8 +613,8 @@ class FlaxBartPreTrainedModel(FlaxPreTrainedModel):
         dtype: jnp.dtype = jnp.float32,
         **kwargs,
     ):
-        module = self.module_class(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        module = self.module_class(config=config, dtype=dtype)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, **kwargs)
 
     def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
         # init input tensors
@@ -755,17 +739,11 @@ class FlaxBartPreTrainedModel(FlaxPreTrainedModel):
         decoder_attention_mask: Optional[jnp.ndarray] = None,
         position_ids: Optional[jnp.ndarray] = None,
         decoder_position_ids: Optional[jnp.ndarray] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
         return_dict = return_dict if return_dict is not None else self.config.return_dict
 
         # prepare encoder inputs
@@ -817,7 +795,6 @@ class FlaxBartForConditionalGenerationModule(nn.Module):
             dtype=self.dtype,
             kernel_init=jax.nn.initializers.normal(self.config.init_std),
         )
-        self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.config.decoder_vocab_size))
 
     def _get_encoder_module(self):
         return self.model.encoder
@@ -853,8 +830,6 @@ class FlaxBartForConditionalGenerationModule(nn.Module):
         else:
             lm_logits = self.lm_head(hidden_states)
 
-        lm_logits += self.final_logits_bias
-
         return FlaxSeq2SeqLMOutput(
             logits=lm_logits,
             decoder_hidden_states=outputs.decoder_hidden_states,
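With final_logits_bias removed, logits come straight out of the LM head. A compact sketch of the head after this commit; the use_bias=False flag and the sizes below are assumptions for illustration, not read from the file:

import jax
import jax.numpy as jnp
import flax.linen as nn

decoder_vocab_size = 16385  # illustrative
d_model = 1024              # illustrative
init_std = 0.02

lm_head = nn.Dense(
    decoder_vocab_size,
    use_bias=False,  # assumed here, in the spirit of the commit
    kernel_init=jax.nn.initializers.normal(init_std),
)

hidden_states = jnp.ones((1, 7, d_model))
params = lm_head.init(jax.random.PRNGKey(0), hidden_states)
lm_logits = lm_head.apply(params, hidden_states)  # no `+ final_logits_bias` term anymore
print(lm_logits.shape)  # (1, 7, 16385)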
 