import math

import torch
from torch import nn
from transformers import PreTrainedModel, AutoConfig, AutoModelForCausalLM

from .configuration_lumenspark import LumensparkConfig

class LowRankLinear(nn.Module):
    """
    Linear layer factored into two low-rank projections (in_features -> rank -> out_features),
    replacing a dense in_features x out_features weight matrix with two smaller ones.
    """

    def __init__(self, in_features, out_features, rank, init_std=0.02):
        super().__init__()
        self.U = nn.Linear(in_features, rank, bias=False)
        self.V = nn.Linear(rank, out_features, bias=False)
        nn.init.normal_(self.U.weight, std=init_std)
        nn.init.normal_(self.V.weight, std=init_std)

    def forward(self, x):
        return self.V(self.U(x))
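
# For scale (illustrative arithmetic only, not values fixed by this module): with
# embed_dim = 512 and rank = 64, the first feed-forward projection used below
# (embed_dim -> 4 * embed_dim) would need 512 * 2048 = 1,048,576 weights as a dense
# nn.Linear, but only 512 * 64 + 64 * 2048 = 163,840 weights as a LowRankLinear,
# roughly a 6.4x reduction before accounting for the omitted biases.
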
class LumensparkSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, head_dim=None, dropout=0.0):
        super().__init__()
        assert (embed_dim % num_heads) == 0, 'Embedding dimension must be divisible by the number of heads'

        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.head_dim = head_dim if head_dim is not None else embed_dim // num_heads

        self.q_proj = nn.Linear(embed_dim, self.head_dim * num_heads)
        self.k_proj = nn.Linear(embed_dim, self.head_dim * num_heads)
        self.v_proj = nn.Linear(embed_dim, self.head_dim * num_heads)

        self.dropout_layer = nn.Dropout(dropout)
        self.output_transform = nn.Linear(self.head_dim * num_heads, embed_dim)

    def stable_softmax(self, x, dim=-1):
        # Subtract the per-row maximum before exponentiating to avoid overflow.
        x_max = torch.max(x, dim=dim, keepdim=True)[0]
        exp_x = torch.exp(x - x_max)
        return exp_x / (torch.sum(exp_x, dim=dim, keepdim=True) + 1e-6)

    def forward(self, inputs, attention_mask=None):
        batch_size, seq_len, _ = inputs.shape

        # Project and reshape to (batch, heads, seq_len, head_dim).
        q = self.q_proj(inputs).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(inputs).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(inputs).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention scores: (batch, heads, seq_len, seq_len).
        attention_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            attention_scores = attention_scores.masked_fill(attention_mask == 0, float('-inf'))

        attention_weights = self.stable_softmax(attention_scores, dim=-1)
        attention_weights = self.dropout_layer(attention_weights)

        attention_output = torch.matmul(attention_weights, v)
        attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
        return self.output_transform(attention_output)
class LumensparkModel(PreTrainedModel):
    config_class = LumensparkConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Token and learned absolute position embeddings.
        self.token_embedding = nn.Embedding(config.vocab_size, config.embed_dim)
        self.position_embedding = nn.Embedding(config.seq_length, config.embed_dim)

        # Pre-norm transformer blocks with low-rank feed-forward layers.
        self.layers = nn.ModuleList()
        for _ in range(config.depth):
            layer = nn.ModuleDict({
                "norm1": nn.LayerNorm(config.embed_dim),
                "attn": LumensparkSelfAttention(
                    embed_dim=config.embed_dim,
                    num_heads=config.heads,
                    head_dim=config.embed_dim // config.heads,
                    dropout=config.dropout
                ),
                "norm2": nn.LayerNorm(config.embed_dim),
                "ffn": nn.Sequential(
                    LowRankLinear(config.embed_dim, config.embed_dim * 4, rank=config.rank),
                    nn.GELU(),
                    nn.Dropout(config.dropout),
                    LowRankLinear(config.embed_dim * 4, config.embed_dim, rank=config.rank),
                    nn.Dropout(config.dropout)
                ),
            })

            # Per-layer residual scaling factors (LayerScale), initialised to a small value.
            layer.layer_scale_attn = nn.Parameter(torch.ones(config.embed_dim) * 1e-2)
            layer.layer_scale_ffn = nn.Parameter(torch.ones(config.embed_dim) * 1e-2)
            self.layers.append(layer)

        self.final_norm = nn.LayerNorm(config.embed_dim)

        # Language-modelling head and embedding dropout.
        self.fc_out = nn.Linear(config.embed_dim, config.vocab_size)
        self.dropout = nn.Dropout(config.dropout)

        self.init_weights()
    @staticmethod
    def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')):
        """
        Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.
        """
        top_k = min(top_k, logits.size(-1))
        if top_k > 0:
            # Remove every logit smaller than the k-th largest one.
            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
            logits[indices_to_remove] = filter_value
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

            # Remove tokens once the cumulative probability exceeds top_p, but always
            # keep at least the single most probable token.
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0

            # Scatter the mask back to the original token order so the filtering is
            # applied per row rather than mixed across the batch.
            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
            logits[indices_to_remove] = filter_value
        return logits
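
    # Worked example (illustrative numbers, not produced by this code): for a single
    # row of logits [2.0, 1.0, 0.5, -1.0], softmax gives roughly
    # [0.61, 0.22, 0.14, 0.03] with cumulative sums [0.61, 0.83, 0.97, 1.00].
    # With top_p=0.9, only the last token remains above the threshold after the
    # one-position shift, so a single low-probability token is filtered out; with
    # top_k=2, every logit below the second-largest value is set to filter_value.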
    def generate(self, input_ids, attention_mask=None, max_length=160, min_length=20, temperature=0.6, top_k=50, top_p=0.9, repetition_penalty=1.1, do_sample=True):
        """
        Auto-regressive text generation with temperature scaling, repetition penalty
        and top-k / top-p sampling. `input_ids` should be a tensor of shape
        (1, seq_len); returns the generated token ids.
        """
        self.eval()
        device = input_ids.device
        generated_tokens = input_ids

        with torch.no_grad():
            for _ in range(max_length - input_ids.size(1)):
                outputs = self.forward(input_ids=generated_tokens, attention_mask=attention_mask)
                logits = outputs["logits"][:, -1, :]

                # Temperature scaling sharpens (< 1) or flattens (> 1) the distribution.
                logits = logits / temperature

                # Penalise tokens that have already been generated. Negative logits are
                # multiplied (not divided) so the penalty always lowers their probability.
                for token in set(generated_tokens.view(-1).tolist()):
                    score = logits[:, token]
                    logits[:, token] = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)

                if do_sample:
                    filtered_logits = LumensparkModel.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
                    probs = torch.softmax(filtered_logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                else:
                    next_token = torch.argmax(logits, dim=-1, keepdim=True)

                generated_tokens = torch.cat((generated_tokens, next_token), dim=1)
                attention_mask = torch.ones_like(generated_tokens).to(device)

                # Ignore an end-of-sequence token until `min_length` is reached, then stop on it.
                if next_token.item() == self.config.eos_token_id and generated_tokens.size(1) < min_length:
                    continue
                if next_token.item() == self.config.eos_token_id:
                    break
        return generated_tokens
    def forward(self, input_ids, attention_mask=None, labels=None):
        """
        Forward pass of the model. If `labels` are provided, the shifted
        cross-entropy language-modelling loss is computed as well.
        """
        batch_size, seq_length = input_ids.size()

        # Learned absolute position ids of shape (batch, seq_length).
        position_ids = torch.arange(0, seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand(batch_size, seq_length)

        token_embeddings = self.token_embedding(input_ids)
        position_embeddings = self.position_embedding(position_ids)

        embeddings = token_embeddings + position_embeddings
        embeddings = self.dropout(embeddings)

        # Lower-triangular causal mask of shape (1, 1, seq_length, seq_length).
        device = embeddings.device
        causal_mask = torch.tril(torch.ones((seq_length, seq_length), device=device)).unsqueeze(0).unsqueeze(0)

        # Combine with the padding mask, if any, broadcasting over heads and query positions.
        combined_mask = causal_mask if attention_mask is None else attention_mask[:, None, None, :].float() * causal_mask

        # Pre-norm residual blocks with LayerScale on both sub-layers.
        for layer in self.layers:
            embeddings_norm = layer["norm1"](embeddings)
            attn_output = layer["attn"](embeddings_norm, attention_mask=combined_mask)
            embeddings = embeddings + layer.layer_scale_attn * attn_output

            embeddings_norm = layer["norm2"](embeddings)
            ffn_output = layer["ffn"](embeddings_norm)
            embeddings = embeddings + layer.layer_scale_ffn * ffn_output

        embeddings = self.final_norm(embeddings)
        logits = self.fc_out(embeddings)

        loss = None
        if labels is not None:
            # Shift by one position so each token is predicted from the tokens before it.
            shift_logits = logits[:, :-1, :].contiguous().view(-1, self.config.vocab_size)
            shift_labels = labels[:, 1:].contiguous().view(-1)
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits, shift_labels)

        return {"loss": loss, "logits": logits}
# Register the custom config and model so they can be loaded through the Auto* classes.
AutoConfig.register("lumenspark", LumensparkConfig)
AutoModelForCausalLM.register(LumensparkConfig, LumensparkModel)
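

if __name__ == "__main__":
    # Minimal smoke test, included only as a sketch of how the classes above fit together.
    # It assumes LumensparkConfig accepts the keyword arguments below (vocab_size,
    # embed_dim, depth, heads, rank, seq_length, dropout); adjust them if the actual
    # configuration signature differs. Run with `python -m ...` so the relative import resolves.
    config = LumensparkConfig(
        vocab_size=1000, embed_dim=64, depth=2, heads=4,
        rank=16, seq_length=128, dropout=0.1,
    )
    model = LumensparkModel(config)

    # Forward pass on random token ids: logits should come back as (1, 16, vocab_size).
    dummy_input = torch.randint(0, config.vocab_size, (1, 16))
    with torch.no_grad():
        outputs = model(input_ids=dummy_input)
    print(outputs["logits"].shape)

    # Short sampled continuation from the untrained model (gibberish, but exercises generate()).
    generated = model.generate(dummy_input, max_length=24, temperature=1.0, top_k=10, top_p=0.95)
    print(generated.shape)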