from transformers import PreTrainedModel, AutoConfig, AutoModelForCausalLM
from .configuration_lumenspark import LumensparkConfig
from torch import nn
import torch
import math

# ----------------------------
# Low-Rank Linear Layer Implementation
# ----------------------------

class LowRankLinear(nn.Module):
    """
    Factorized linear layer: a full `in_features x out_features` projection is
    replaced by two smaller projections through a bottleneck of size `rank`.
    """
    def __init__(self, in_features, out_features, rank, init_std=0.02):
        super().__init__()
        self.U = nn.Linear(in_features, rank, bias=False)
        self.V = nn.Linear(rank, out_features, bias=False)
        nn.init.normal_(self.U.weight, std=init_std)
        nn.init.normal_(self.V.weight, std=init_std)

    def forward(self, x):
        return self.V(self.U(x))

# ----------------------------
# Lumenspark Self-Attention Implementation
# ----------------------------

class LumensparkSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, head_dim=None, dropout=0.0):
        super().__init__()
        assert (embed_dim % num_heads) == 0, 'Embedding dimension must be divisible by the number of heads'

        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.head_dim = head_dim if head_dim is not None else embed_dim // num_heads

        self.q_proj = nn.Linear(embed_dim, self.head_dim * num_heads)
        self.k_proj = nn.Linear(embed_dim, self.head_dim * num_heads)
        self.v_proj = nn.Linear(embed_dim, self.head_dim * num_heads)
        self.dropout_layer = nn.Dropout(dropout)
        self.output_transform = nn.Linear(self.head_dim * num_heads, embed_dim)

    def stable_softmax(self, x, dim=-1):
        # Subtract the row-wise max before exponentiating for numerical stability
        x_max = torch.max(x, dim=dim, keepdim=True)[0]
        exp_x = torch.exp(x - x_max)
        return exp_x / (torch.sum(exp_x, dim=dim, keepdim=True) + 1e-6)

    def forward(self, inputs, attention_mask=None):
        batch_size, seq_len, _ = inputs.shape

        # Project to queries, keys and values, then split into heads:
        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        q = self.q_proj(inputs).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(inputs).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(inputs).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention; masked positions are set to -inf before softmax
        attention_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            attention_scores = attention_scores.masked_fill(attention_mask == 0, float('-inf'))

        attention_weights = self.stable_softmax(attention_scores, dim=-1)
        attention_weights = self.dropout_layer(attention_weights)

        # Merge heads back: (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
        attention_output = torch.matmul(attention_weights, v)
        attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.num_heads * self.head_dim)
        return self.output_transform(attention_output)

# ----------------------------
# Define Lumenspark Model Wrapper
# ----------------------------

class LumensparkModel(PreTrainedModel):
    config_class = LumensparkConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Token and position embeddings
        self.token_embedding = nn.Embedding(config.vocab_size, config.embed_dim)
        self.position_embedding = nn.Embedding(config.seq_length, config.embed_dim)

        # Lumenspark transformer encoder layers with prenormalization and LayerScale
        self.layers = nn.ModuleList()
        for _ in range(config.depth):
            layer = nn.ModuleDict({
                "norm1": nn.LayerNorm(config.embed_dim),
                "attn": LumensparkSelfAttention(
                    embed_dim=config.embed_dim,
                    num_heads=config.heads,
                    head_dim=config.embed_dim // config.heads,
                    dropout=config.dropout
                ),
                "norm2": nn.LayerNorm(config.embed_dim),
                "ffn": nn.Sequential(
                    LowRankLinear(config.embed_dim, config.embed_dim * 4, rank=config.rank),
                    nn.GELU(),
                    nn.Dropout(config.dropout),
                    LowRankLinear(config.embed_dim * 4, config.embed_dim, rank=config.rank),
                    nn.Dropout(config.dropout)
                ),
            })
            # Assign the LayerScale parameters directly as attributes so they are
            # registered on the ModuleDict and move with the module
            layer.layer_scale_attn = nn.Parameter(torch.ones(config.embed_dim) * 1e-2)
            layer.layer_scale_ffn = nn.Parameter(torch.ones(config.embed_dim) * 1e-2)
            self.layers.append(layer)

        # Final LayerNorm layer
        self.final_norm = nn.LayerNorm(config.embed_dim)

        # Feed-forward output layer
        self.fc_out = nn.Linear(config.embed_dim, config.vocab_size)
        self.dropout = nn.Dropout(config.dropout)

        # Initialize model weights
        self.init_weights()

    @staticmethod
    def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')):
        """
        Filter a distribution of logits using top-k and/or top-p (nucleus) filtering.
        Expects logits of shape (batch_size, vocab_size).
        """
        top_k = min(top_k, logits.size(-1))
        if top_k > 0:
            # Remove all tokens with a logit below the k-th largest logit
            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
            logits[indices_to_remove] = filter_value

        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            # Remove tokens whose cumulative probability exceeds top_p
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift right so the first token above the threshold is kept
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            # Scatter the mask back to the original (unsorted) vocabulary order,
            # keeping each batch row independent
            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
            logits[indices_to_remove] = filter_value
        return logits

    def generate(self, input_ids, attention_mask=None, max_length=160, min_length=20, temperature=0.6, top_k=50, top_p=0.9, repetition_penalty=1.1, do_sample=True):
        """
        Auto-regressive text generation with temperature, top-k/top-p sampling and a
        repetition penalty. `input_ids` should be a tensor of shape (1, seq_len).
        Returns the generated token ids.
        """
        self.eval()
        device = input_ids.device
        generated_tokens = input_ids

        for _ in range(max_length - input_ids.size(1)):
            # Forward pass; keep only the logits of the last position
            outputs = self.forward(input_ids=generated_tokens, attention_mask=attention_mask)
            logits = outputs["logits"][:, -1, :]

            # Adjust logits by temperature
            logits = logits / temperature

            # Apply repetition penalty by reducing logits of tokens already generated
            for token in set(generated_tokens.view(-1).tolist()):
                logits[:, token] /= repetition_penalty

            # Apply sampling with top-k and top-p
            if do_sample:
                filtered_logits = LumensparkModel.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
                probs = torch.softmax(filtered_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)

            # Append the generated token and extend the attention mask
            generated_tokens = torch.cat((generated_tokens, next_token), dim=1)
            attention_mask = torch.ones_like(generated_tokens).to(device)

            # Ensure min_length before stopping generation at the end-of-sequence (EOS) token
            if next_token.item() == self.config.eos_token_id and generated_tokens.size(1) < min_length:
                continue
            if next_token.item() == self.config.eos_token_id:
                break

        return generated_tokens

    def forward(self, input_ids, attention_mask=None, labels=None):
        """
        Forward pass of the model. If `labels` are provided, computes the loss.
""" batch_size, seq_length = input_ids.size() # Generate position ids for input tokens position_ids = torch.arange(0, seq_length, dtype=torch.long, device=input_ids.device) position_ids = position_ids.unsqueeze(0).expand(batch_size, seq_length) # Embed tokens and positions token_embeddings = self.token_embedding(input_ids) position_embeddings = self.position_embedding(position_ids) # Combine token and position embeddings embeddings = token_embeddings + position_embeddings embeddings = self.dropout(embeddings) # Create causal mask for self-attention to ensure autoregressive behavior device = embeddings.device causal_mask = torch.tril(torch.ones((seq_length, seq_length), device=device)).unsqueeze(0).unsqueeze(0) # Combine with attention mask if provided combined_mask = causal_mask if attention_mask is None else attention_mask[:, None, None, :].float() * causal_mask # Pass through transformer layers for layer in self.layers: embeddings_norm = layer["norm1"](embeddings) attn_output = layer["attn"](embeddings_norm, attention_mask=combined_mask) embeddings = embeddings + layer.layer_scale_attn * attn_output embeddings_norm = layer["norm2"](embeddings) ffn_output = layer["ffn"](embeddings_norm) embeddings = embeddings + layer.layer_scale_ffn * ffn_output # Final normalization and output projection to logits embeddings = self.final_norm(embeddings) logits = self.fc_out(embeddings) # Compute loss if labels are provided loss = None if labels is not None: shift_logits = logits[:, :-1, :].contiguous().view(-1, self.config.vocab_size) shift_labels = labels[:, 1:].contiguous().view(-1) loss_fct = nn.CrossEntropyLoss() loss = loss_fct(shift_logits, shift_labels) return {"loss": loss, "logits": logits} # Register LumensparkForCausalLM with AutoModelForCausalLM AutoConfig.register("lumenspark", LumensparkConfig) AutoModelForCausalLM.register(LumensparkConfig, LumensparkModel)