import math from typing import Optional, Tuple, Union, List, Dict import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat from transformers import PreTrainedModel, LlamaConfig from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.models.llama.modeling_llama import ( LlamaRMSNorm, LlamaRotaryEmbedding, LlamaLinearScalingRotaryEmbedding, LlamaDynamicNTKScalingRotaryEmbedding, LlamaMLP, apply_rotary_pos_emb, repeat_kv, ) from transformers.cache_utils import Cache, DynamicCache, StaticCache class DiffLLaMAConfig(LlamaConfig): """ Configuration class for the DiffLLaMA model. Inherits from LlamaConfig and can be extended with additional parameters. """ model_type = "diff_llama" def __init__( self, num_kv_heads: int = 8, intermediate_size: int = 3072, rope_scaling: Optional[Dict[str, Union[str, float]]] = None, **kwargs ): super().__init__(**kwargs) self.num_kv_heads = num_kv_heads self.intermediate_size = intermediate_size self.rope_scaling = rope_scaling or {"type": "linear", "factor": 1.0} # Add any custom configuration parameters here @classmethod def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): """ Load configuration from a pretrained model. """ config_dict = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) return cls(**config_dict) def init_method(tensor): """Initialize tensor with Kaiming uniform initialization.""" nn.init.kaiming_uniform_(tensor, a=math.sqrt(5)) def lambda_init_fn(depth): """Compute lambda initialization value based on layer depth.""" return 0.8 - 0.6 * math.exp(-0.3 * depth) class MultiheadDiffAttn(nn.Module): def __init__(self, config: DiffLLaMAConfig, layer_idx: Optional[int] = None): super().__init__() self.config = config self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.hidden_size // self.num_heads self.num_key_value_heads = config.num_kv_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) self.scaling = self.head_dim ** -0.5 self.rotary_emb = self._init_rope() self.lambda_init = lambda_init_fn(layer_idx if layer_idx is not None else 0) self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1)) self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1)) self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1)) self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1)) self.subln = nn.LayerNorm(self.num_heads * self.head_dim, elementwise_affine=False) self._init_rope() def _init_rope(self): if not hasattr(self.config, 'rope_scaling') or self.config.rope_scaling is None: self.rotary_emb = LlamaRotaryEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta, ) else: scaling_type = self.config.rope_scaling.get("type", "linear") scaling_factor = self.config.rope_scaling.get("factor", 1.0) if scaling_type == "linear": self.rotary_emb = LlamaLinearScalingRotaryEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor, base=self.rope_theta, ) elif scaling_type == "dynamic": self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor, base=self.rope_theta, ) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: batch_size, seq_length, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = query_states.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(batch_size, seq_length, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(batch_size, seq_length, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) past_key_value = (key_states, value_states) if use_cache else None # Repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) attn_weights = torch.matmul(query_states, key_states.transpose(-1, -2)) attn_weights = attn_weights * self.scaling lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1)) lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2)) lambda_full = lambda_1 - lambda_2 + self.lambda_init # Apply differential attention attn_weights_diff = attn_weights[:, :, :, :-1] - lambda_full * attn_weights[:, :, :, 1:] attn_weights = torch.cat([attn_weights_diff, attn_weights[:, :, :, -1:]], dim=-1) if attention_mask is not None: # Expand attention_mask attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) attention_mask = attention_mask.expand(batch_size, self.num_heads, seq_length, attention_mask.size(-1)) attention_mask = attention_mask.to(dtype=attn_weights.dtype) # Convert to same dtype as attn_weights # Use a large negative number instead of negative infinity attn_weights = attn_weights + (1.0 - attention_mask) * -10000.0 attn_weights = F.softmax(attn_weights, dim=-1) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_length, self.num_heads * self.head_dim) attn_output = self.subln(attn_output) attn_output = attn_output * (1 - self.lambda_init) attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights, past_key_value class DiffLLaMALayer(nn.Module): """ A single layer of the DiffLLaMA model, consisting of multi-head differential attention and a feed-forward network. Incorporates gradient checkpointing for memory efficiency. """ def __init__(self, config: DiffLLaMAConfig, layer_idx: int): super().__init__() self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.self_attn = MultiheadDiffAttn( config=config, layer_idx=layer_idx ) self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.mlp = LlamaMLP(config) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, ) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) if use_cache: outputs += (present_key_value,) return outputs class DiffLLaMAModel(PreTrainedModel): """ DiffLLaMAModel is a variant of LLaMA with differential attention mechanisms. Incorporates mixed precision training and gradient checkpointing for optimized performance. """ config_class = DiffLLaMAConfig def __init__(self, config: DiffLLaMAConfig): super().__init__(config) self.config = config self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([ DiffLLaMALayer(config, layer_idx=i) for i in range(config.num_hidden_layers) ]) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = LlamaRotaryEmbedding( dim=config.hidden_size // config.num_attention_heads, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta, ) self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[Tuple[torch.FloatTensor, torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: """ Forward pass for the DiffLLaMAModel with performance optimizations. Args: input_ids: Input token IDs. attention_mask: Attention mask. position_ids: Position IDs. past_key_values: Past key and value tensors for caching. inputs_embeds: Input embeddings. use_cache: Whether to return present key and value for caching. output_attentions: Whether to output attention weights. output_hidden_states: Whether to output hidden states. return_dict: Whether to return a dict. cache_position: Position IDs for caching. Returns: Model output, either as a tuple or a BaseModelOutputWithPast. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: batch_size, seq_length = input_ids.shape elif inputs_embeds is not None: batch_size, seq_length, _ = inputs_embeds.shape else: raise ValueError("You have to specify either input_ids or inputs_embeds") if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange(seq_length, dtype=torch.long, device=device) position_ids = position_ids.unsqueeze(0).expand(batch_size, -1) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) # Position embeddings are handled within each layer; remove pre-computation # Removed the following lines: # cos, sin = self.rotary_emb(position_ids, seq_len=seq_length) # position_embeddings = (cos, sin) hidden_states = inputs_embeds # Attention mask if attention_mask is None: attention_mask = torch.ones((batch_size, seq_length), device=hidden_states.device) # Initialize lists to store outputs all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_cache = () if use_cache else None for idx, layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) layer_outputs = layer( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values[idx] if past_key_values is not None else None, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, ) # Correctly unpack layer_outputs based on the configuration hidden_states = layer_outputs[0] if use_cache: present_key_value = layer_outputs[-1] next_cache += (present_key_value,) if output_attentions: self_attn_weights = layer_outputs[1] all_self_attns += (self_attn_weights,) hidden_states = self.norm(hidden_states) if output_hidden_states: all_hidden_states += (hidden_states,) next_cache = None if use_cache: next_cache = ( next_cache.to_legacy_cache() if isinstance(next_cache, Cache) else next_cache ) if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns, ) class DiffLLaMAForCausalLM(PreTrainedModel): """ DiffLLaMA model with a causal language modeling head. Incorporates mixed precision training for optimized performance. """ config_class = DiffLLaMAConfig _tied_weights_keys = ["lm_head.weight"] def __init__(self, config: DiffLLaMAConfig): super().__init__(config) self.model = DiffLLaMAModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): """Return input embeddings.""" return self.model.get_input_embeddings() def set_input_embeddings(self, value): """Set input embeddings.""" self.model.set_input_embeddings(value) def get_output_embeddings(self): """Return output embeddings (language modeling head).""" return self.lm_head def set_output_embeddings(self, new_embeddings): """Set output embeddings (language modeling head).""" self.lm_head = new_embeddings def set_decoder(self, decoder): """Set the decoder model.""" self.model = decoder def get_decoder(self): """Get the decoder model.""" return self.model def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[Tuple[torch.FloatTensor, torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: """ Forward pass for DiffLLaMAForCausalLM with performance optimizations. Args: input_ids: Input token IDs. attention_mask: Attention mask. position_ids: Position IDs. past_key_values: Past key and value tensors for caching. inputs_embeds: Input embeddings. labels: Labels for computing the loss. use_cache: Whether to return past key and value tensors. output_attentions: Whether to output attention weights. output_hidden_states: Whether to output hidden states. return_dict: Whether to return a dict. cache_position: Position IDs for caching. Returns: CausalLMOutputWithPast or tuple containing loss and logits. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Get outputs from the model outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, ) hidden_states = outputs.last_hidden_state if return_dict else outputs[0] logits = self.lm_head(hidden_states) loss = None if labels is not None: # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = nn.CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) # Compute loss using mixed precision if enabled if shift_logits.dtype == torch.float16: with torch.cuda.amp.autocast(enabled=False): loss = loss_fct(shift_logits, shift_labels) else: loss = loss_fct(shift_logits, shift_labels) if not return_dict: if use_cache: return ((loss, logits) + outputs[1:]) if loss is not None else (logits,) + outputs[1:] else: return (loss, logits) if loss is not None else (logits,) return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, )