import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedModel
from transformers.modeling_outputs import Seq2SeqLMOutput
from transformers.activations import ACT2FN

from .flash_atten import MHA  # Import the MHA class from the provided implementation
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP

from .configuration_custom_seq2seq_llm import Seq2SeqConfig


class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, hidden_states):
        # Root-mean-square normalization computed in float32 for numerical stability.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        return self.weight * hidden_states.to(self.weight.dtype)


class CustomSeq2SeqLLM(PreTrainedModel):
    config_class = Seq2SeqConfig
    base_model_prefix = "custom_seq2seq_llm"

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.shared = nn.Embedding(config.vocab_size, config.hidden_size)
        self.encoder = CustomEncoder(config)
        self.decoder = CustomDecoder(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.loss_fn = LigerCrossEntropyLoss()
        self.init_weights()

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, value):
        self.shared = value

    def get_output_embeddings(self):
        return self.lm_head

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        position_ids=None,
    ):
        if position_ids is None and input_ids is not None:
            position_ids = torch.arange(input_ids.size(1), dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        if encoder_outputs is None and input_ids is not None:
            encoder_outputs = self.encoder(
                self.shared(input_ids),
                attention_mask=attention_mask,
                position_ids=position_ids,
            )

        if decoder_input_ids is None:
            if labels is not None:
                decoder_input_ids = self._shift_right(labels)
            elif input_ids is not None:
                decoder_input_ids = input_ids
            else:
                raise ValueError("Either decoder_input_ids, labels, or input_ids must be provided.")

        decoder_outputs = self.decoder(
            self.shared(decoder_input_ids),
            encoder_outputs,
            attention_mask=decoder_attention_mask,
            position_ids=position_ids,
        )

        lm_logits = self.lm_head(decoder_outputs)

        loss = None
        if labels is not None:
            loss = self.loss_fn(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            encoder_last_hidden_state=encoder_outputs,
            decoder_hidden_states=decoder_outputs,
        )

    def _shift_right(self, input_ids):
        # Standard teacher-forcing shift: prepend pad_token_id and drop the last token.
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
        shifted_input_ids[..., 0] = self.config.pad_token_id
        return shifted_input_ids


class CustomEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.num_encoder_layers)])
        self.layer_norm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states, attention_mask=None, position_ids=None):
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask, position_ids)
        return self.layer_norm(hidden_states)
class EncoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self_attn = MHA(
            config.hidden_size,
            config.num_attention_heads,
            num_heads_kv=config.num_key_value_heads,
            dropout=config.attention_probs_dropout_prob,
            causal=False,
            rotary_emb_dim=config.rotary_emb_dim,
            rotary_emb_base=config.rotary_emb_base,
            rotary_emb_scale_base=config.rotary_emb_scale_base,
            rotary_emb_interleaved=config.rotary_emb_interleaved,
        )
        self.feed_forward = LigerSwiGLUMLP(config)
        self.layer_norm1 = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layer_norm2 = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states, attention_mask=None, position_ids=None):
        # Pre-norm bidirectional self-attention with a residual connection.
        normed_hidden_states = self.layer_norm1(hidden_states)
        attention_output = self.self_attn(normed_hidden_states, key_padding_mask=attention_mask)
        hidden_states = hidden_states + attention_output

        # Pre-norm feed-forward with a residual connection.
        normed_hidden_states = self.layer_norm2(hidden_states)
        feed_forward_output = self.feed_forward(normed_hidden_states)
        hidden_states = hidden_states + feed_forward_output
        return hidden_states


class CustomDecoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layers = nn.ModuleList([
            DecoderLayer(config, use_cross_attention=self._should_use_cross_attention(i, config.num_decoder_layers))
            for i in range(config.num_decoder_layers)
        ])
        self.layer_norm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)

    def _should_use_cross_attention(self, layer_idx, total_layers):
        # Cross-attention is enabled in the first, last, and every even-indexed layer.
        return layer_idx == 0 or layer_idx == total_layers - 1 or layer_idx % 2 == 0

    def forward(self, hidden_states, encoder_hidden_states, attention_mask=None, position_ids=None):
        for layer in self.layers:
            hidden_states = layer(hidden_states, encoder_hidden_states, attention_mask, position_ids)
        return self.layer_norm(hidden_states)


class DecoderLayer(nn.Module):
    def __init__(self, config, use_cross_attention=True):
        super().__init__()
        self.use_cross_attention = use_cross_attention
        self.self_attn = MHA(
            config.hidden_size,
            config.num_attention_heads,
            num_heads_kv=config.num_key_value_heads,
            dropout=config.attention_probs_dropout_prob,
            causal=True,
            rotary_emb_dim=config.rotary_emb_dim,
            rotary_emb_base=config.rotary_emb_base,
            rotary_emb_scale_base=config.rotary_emb_scale_base,
            rotary_emb_interleaved=config.rotary_emb_interleaved,
        )
        if use_cross_attention:
            self.cross_attn = MHA(
                config.hidden_size,
                config.num_attention_heads,
                num_heads_kv=config.num_key_value_heads,
                dropout=config.attention_probs_dropout_prob,
                causal=False,
                cross_attn=True,
            )
            self.layer_norm2 = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = LigerSwiGLUMLP(config)
        self.layer_norm1 = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layer_norm3 = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states, encoder_hidden_states, attention_mask=None, position_ids=None):
        # Pre-norm causal self-attention with a residual connection.
        normed_hidden_states = self.layer_norm1(hidden_states)
        self_attention_output = self.self_attn(normed_hidden_states, key_padding_mask=attention_mask)
        hidden_states = hidden_states + self_attention_output

        if self.use_cross_attention:
            # Pre-norm cross-attention over the encoder output, with a residual connection.
            normed_hidden_states = self.layer_norm2(hidden_states)
            cross_attention_output = self.cross_attn(
                normed_hidden_states,
                x_kv=encoder_hidden_states,
                key_padding_mask=attention_mask,
            )
            hidden_states = hidden_states + cross_attention_output

        # Pre-norm feed-forward with a residual connection.
        normed_hidden_states = self.layer_norm3(hidden_states)
        feed_forward_output = self.feed_forward(normed_hidden_states)
        hidden_states = hidden_states + feed_forward_output
        return hidden_states
# Reference feed-forward block. Not wired into the layers above, which use
# LigerSwiGLUMLP instead; kept as a plain nn.Linear/ACT2FN alternative.
class FeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
        self.act = ACT2FN[config.hidden_act]
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x
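
# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the model definition. It assumes that
# Seq2SeqConfig accepts the field names referenced above as keyword arguments
# (the concrete values below are illustrative placeholders), and that the
# flash-attn MHA needs a CUDA device with fp16/bf16 activations. Because of
# the relative imports, run it as a module, e.g. `python -m <package>.<module>`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = Seq2SeqConfig(
        vocab_size=32000,
        hidden_size=512,
        intermediate_size=1376,
        num_attention_heads=8,
        num_key_value_heads=8,
        num_encoder_layers=2,
        num_decoder_layers=2,
        hidden_act="silu",
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        layer_norm_eps=1e-6,
        pad_token_id=0,
        rotary_emb_dim=64,  # head_dim = hidden_size / num_attention_heads
        rotary_emb_base=10000.0,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved=False,
    )
    model = CustomSeq2SeqLLM(config).to(device="cuda", dtype=torch.bfloat16)

    input_ids = torch.randint(0, config.vocab_size, (2, 16), device="cuda")
    labels = torch.randint(0, config.vocab_size, (2, 16), device="cuda")

    # Labels are shifted right internally to build decoder inputs, and
    # LigerCrossEntropyLoss produces the training loss.
    outputs = model(input_ids=input_ids, labels=labels)
    print(outputs.loss, outputs.logits.shape)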