from transformers import PreTrainedModel, AutoConfig, AutoModelForCausalLM
from .configuration_lumenspark import LumensparkConfig
from torch import nn
import torch
import math

# ----------------------------
# Low-Rank Linear Layer Implementation
# ----------------------------

class LowRankLinear(nn.Module):
    """
    Linear layer factorised into two smaller projections (in_features -> rank -> out_features),
    approximating a dense weight matrix with far fewer parameters when the rank is small.
    """
    def __init__(self, in_features, out_features, rank, init_std=0.02):
        super().__init__()
        self.U = nn.Linear(in_features, rank, bias=False)
        self.V = nn.Linear(rank, out_features, bias=False)
        nn.init.normal_(self.U.weight, std=init_std)
        nn.init.normal_(self.V.weight, std=init_std)

    def forward(self, x):
        # Project down into the rank-dimensional subspace, then back up
        return self.V(self.U(x))

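# Minimal sketch (illustrative only, not used by the model) of why the factorisation
# above saves parameters. The dimensions and rank below are assumed example values,
# not values taken from LumensparkConfig.
def _low_rank_param_count_example(in_features=768, out_features=3072, rank=64):
    dense_params = in_features * out_features                    # full nn.Linear weight
    low_rank_params = in_features * rank + rank * out_features   # U (in x rank) + V (rank x out)
    return dense_params, low_rank_params                         # 2,359,296 vs. 245,760 here
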
# ----------------------------
# Lumenspark Self-Attention Implementation
# ----------------------------

class LumensparkSelfAttention(nn.Module):
    """
    Multi-head self-attention with a numerically stable softmax. Positions where the
    attention mask is zero are excluded from attention.
    """
    def __init__(self, embed_dim, num_heads, head_dim=None, dropout=0.0):
        super().__init__()
        assert (embed_dim % num_heads) == 0, 'Embedding dimension must be divisible by the number of heads'

        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.head_dim = head_dim if head_dim is not None else embed_dim // num_heads

        self.q_proj = nn.Linear(embed_dim, self.head_dim * num_heads)
        self.k_proj = nn.Linear(embed_dim, self.head_dim * num_heads)
        self.v_proj = nn.Linear(embed_dim, self.head_dim * num_heads)

        self.dropout_layer = nn.Dropout(dropout)
        self.output_transform = nn.Linear(self.head_dim * num_heads, embed_dim)

    def stable_softmax(self, x, dim=-1):
        # Subtract the per-row maximum before exponentiating for numerical stability;
        # the small epsilon guards against division by zero
        x_max = torch.max(x, dim=dim, keepdim=True)[0]
        exp_x = torch.exp(x - x_max)
        return exp_x / (torch.sum(exp_x, dim=dim, keepdim=True) + 1e-6)
    
    def forward(self, inputs, attention_mask=None):
        batch_size, seq_len, _ = inputs.shape

        # Project to queries, keys and values, then split into heads:
        # (batch, seq, embed_dim) -> (batch, heads, seq, head_dim)
        q = self.q_proj(inputs).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(inputs).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(inputs).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        attention_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        
        if attention_mask is not None:
            attention_scores = attention_scores.masked_fill(attention_mask == 0, float('-inf'))
        
        attention_weights = self.stable_softmax(attention_scores, dim=-1)
        attention_weights = self.dropout_layer(attention_weights)

        attention_output = torch.matmul(attention_weights, v)
        # Merge heads: (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
        attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.num_heads * self.head_dim)
        return self.output_transform(attention_output)

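# Hedged shape-check sketch (a module-level helper that is never called on import):
# runs the attention block above on random activations with a causal mask to show the
# (batch, seq_len, embed_dim) round trip. All dimensions here are assumed examples.
def _self_attention_shape_example():
    attn = LumensparkSelfAttention(embed_dim=256, num_heads=8, dropout=0.0)
    x = torch.randn(2, 16, 256)                           # (batch, seq_len, embed_dim)
    causal = torch.tril(torch.ones(16, 16))[None, None]   # broadcasts over batch and heads
    out = attn(x, attention_mask=causal)
    assert out.shape == (2, 16, 256)
    return out
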
# ----------------------------
# Define Lumenspark Model Wrapper
# ----------------------------

class LumensparkModel(PreTrainedModel):
    config_class = LumensparkConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Token and position embeddings
        self.token_embedding = nn.Embedding(config.vocab_size, config.embed_dim)
        self.position_embedding = nn.Embedding(config.seq_length, config.embed_dim)

        # Lumenspark transformer decoder layers with pre-normalization and LayerScale
        self.layers = nn.ModuleList()
        for _ in range(config.depth):
            layer = nn.ModuleDict({
                "norm1": nn.LayerNorm(config.embed_dim),
                "attn": LumensparkSelfAttention(
                    embed_dim=config.embed_dim,
                    num_heads=config.heads,
                    head_dim=config.embed_dim // config.heads,
                    dropout=config.dropout
                ),
                "norm2": nn.LayerNorm(config.embed_dim),
                "ffn": nn.Sequential(
                    LowRankLinear(config.embed_dim, config.embed_dim * 4, rank=config.rank),
                    nn.GELU(),
                    nn.Dropout(config.dropout),
                    LowRankLinear(config.embed_dim * 4, config.embed_dim, rank=config.rank),
                    nn.Dropout(config.dropout)
                ),
            })
            # LayerScale parameters, assigned directly as attributes so that
            # nn.Module registers them alongside the ModuleDict entries
            layer.layer_scale_attn = nn.Parameter(torch.ones(config.embed_dim) * 1e-2)
            layer.layer_scale_ffn = nn.Parameter(torch.ones(config.embed_dim) * 1e-2)
            self.layers.append(layer)

        # Final LayerNorm layer
        self.final_norm = nn.LayerNorm(config.embed_dim)

        # Output projection from hidden states to vocabulary logits (LM head)
        self.fc_out = nn.Linear(config.embed_dim, config.vocab_size)
        self.dropout = nn.Dropout(config.dropout)

        # Initialize model weights
        self.init_weights()

    @staticmethod
    def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')):
        """

        Filter a distribution of logits using top-k and/or top-p filtering.

        """
        top_k = min(top_k, logits.size(-1))
        if top_k > 0:
            # Remove every token whose logit falls below the k-th largest logit
            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
            logits[indices_to_remove] = filter_value
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            # Remove tokens with cumulative probability above top_p, always keeping the most likely token
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            # Map the removal mask back to the original (unsorted) vocabulary order so that
            # the filtering remains correct for batched logits
            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
            logits[indices_to_remove] = filter_value
        return logits

    def generate(self, input_ids, attention_mask=None, max_length=160, min_length=20, temperature=0.6, top_k=50, top_p=0.9, repetition_penalty=1.1, do_sample=True):
        """

        Text generation method that handles auto-regressive generation with repetition penalty.

        Input `input_ids` should be a tensor. Returns generated tokens.

        """
        self.eval()
        generated_tokens = input_ids

        for _ in range(max_length - input_ids.size(1)):
            # Forward pass for logits
            outputs = self.forward(input_ids=generated_tokens, attention_mask=attention_mask)
            logits = outputs["logits"][:, -1, :]

            # Adjust logits by temperature
            logits = logits / temperature

            # Apply repetition penalty by reducing logits of tokens already generated
            for token in set(generated_tokens.view(-1).tolist()):
                logits[:, token] /= repetition_penalty

            # Apply sampling with top-k and top-p
            if do_sample:
                filtered_logits = LumensparkModel.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
                probs = torch.softmax(filtered_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)

            # Append the generated token
            generated_tokens = torch.cat((generated_tokens, next_token), dim=1)
            attention_mask = torch.ones_like(generated_tokens)

            # Stop on the end-of-sequence (EOS) token, but only after min_length tokens
            if next_token.item() == self.config.eos_token_id and generated_tokens.size(1) >= min_length:
                break
        return generated_tokens

    def forward(self, input_ids, attention_mask=None, labels=None):
        """

        Forward pass of the model. If `labels` are provided, computes the loss.

        """
        batch_size, seq_length = input_ids.size()

        # Generate position ids for input tokens
        position_ids = torch.arange(0, seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand(batch_size, seq_length)

        # Embed tokens and positions
        token_embeddings = self.token_embedding(input_ids)
        position_embeddings = self.position_embedding(position_ids)

        # Combine token and position embeddings
        embeddings = token_embeddings + position_embeddings
        embeddings = self.dropout(embeddings)

        # Create causal mask for self-attention to ensure autoregressive behavior
        device = embeddings.device
        causal_mask = torch.tril(torch.ones((seq_length, seq_length), device=device)).unsqueeze(0).unsqueeze(0)

        # Combine with attention mask if provided
        combined_mask = causal_mask if attention_mask is None else attention_mask[:, None, None, :].float() * causal_mask

        # Pass through transformer layers
        for layer in self.layers:
            embeddings_norm = layer["norm1"](embeddings)
            attn_output = layer["attn"](embeddings_norm, attention_mask=combined_mask)
            embeddings = embeddings + layer.layer_scale_attn * attn_output

            embeddings_norm = layer["norm2"](embeddings)
            ffn_output = layer["ffn"](embeddings_norm)
            embeddings = embeddings + layer.layer_scale_ffn * ffn_output

        # Final normalization and output projection to logits
        embeddings = self.final_norm(embeddings)
        logits = self.fc_out(embeddings)

        # Compute loss if labels are provided
        loss = None
        if labels is not None:
            shift_logits = logits[:, :-1, :].contiguous().view(-1, self.config.vocab_size)
            shift_labels = labels[:, 1:].contiguous().view(-1)
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits, shift_labels)

        return {"loss": loss, "logits": logits}

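# Hedged illustration (not used by the model) of how `forward` combines the causal mask
# with a padding attention mask: a position is attended to only where both masks are
# non-zero. The sequence length and padding pattern below are assumed examples.
def _combined_mask_example(seq_length=6):
    causal = torch.tril(torch.ones(seq_length, seq_length))[None, None]   # (1, 1, S, S)
    padding = torch.ones(1, seq_length)
    padding[:, -2:] = 0                                                   # mark the last two tokens as padding
    combined = padding[:, None, None, :] * causal                         # (1, 1, S, S)
    return combined
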
# Register the Lumenspark configuration and model with the Hugging Face Auto classes
AutoConfig.register("lumenspark", LumensparkConfig)
AutoModelForCausalLM.register(LumensparkConfig, LumensparkModel)
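
# Minimal smoke-test sketch, only executed when this file is run directly. It assumes
# that LumensparkConfig accepts the keyword arguments used throughout this module
# (vocab_size, embed_dim, seq_length, depth, heads, rank, dropout, eos_token_id);
# adjust these to the real configuration before relying on the snippet.
if __name__ == "__main__":
    config = LumensparkConfig(
        vocab_size=1000, embed_dim=128, seq_length=64,
        depth=2, heads=4, rank=16, dropout=0.1, eos_token_id=0,
    )
    model = LumensparkModel(config)

    # Forward pass with labels to check that loss and logits come back with sane shapes
    input_ids = torch.randint(0, config.vocab_size, (2, 32))
    outputs = model(input_ids=input_ids, labels=input_ids)
    print("logits:", outputs["logits"].shape, "loss:", float(outputs["loss"]))

    # Greedy generation from a short prompt (batch size 1, as `generate` assumes)
    prompt = torch.randint(0, config.vocab_size, (1, 8))
    generated = model.generate(prompt, max_length=24, do_sample=False)
    print("generated shape:", generated.shape)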