KANLlama_RottenTomatoes / modeling_kanllama.py
SE6446's picture
Upload 2 files
da0ef01 verified
import torch
import torch.nn.functional as F
import math
from .configuration_kanllama import KANLlamaConfig
######
# KAN and KANLinear are take from efficient KAN
# https://github.com/Blealtan/efficient-kan
######
class KANLinear(torch.nn.Module):
def __init__(
self,
in_features,
out_features,
grid_size=5,
spline_order=3,
scale_noise=0.1,
scale_base=1.0,
scale_spline=1.0,
enable_standalone_scale_spline=True,
base_activation=torch.nn.SiLU,
grid_eps=0.02,
grid_range=[-1, 1],
):
super(KANLinear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.grid_size = grid_size
self.spline_order = spline_order
h = (grid_range[1] - grid_range[0]) / grid_size
grid = (
(
torch.arange(-spline_order, grid_size + spline_order + 1) * h
+ grid_range[0]
)
.expand(in_features, -1)
.contiguous()
)
self.register_buffer("grid", grid)
self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
self.spline_weight = torch.nn.Parameter(
torch.Tensor(out_features, in_features, grid_size + spline_order)
)
if enable_standalone_scale_spline:
self.spline_scaler = torch.nn.Parameter(
torch.Tensor(out_features, in_features)
)
self.scale_noise = scale_noise
self.scale_base = scale_base
self.scale_spline = scale_spline
self.enable_standalone_scale_spline = enable_standalone_scale_spline
self.base_activation = base_activation()
self.grid_eps = grid_eps
self.reset_parameters()
def reset_parameters(self):
torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
with torch.no_grad():
noise = (
(
torch.rand(self.grid_size + 1, self.in_features, self.out_features)
- 1 / 2
)
* self.scale_noise
/ self.grid_size
)
self.spline_weight.data.copy_(
(self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
* self.curve2coeff(
self.grid.T[self.spline_order : -self.spline_order],
noise,
)
)
if self.enable_standalone_scale_spline:
# torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)
def b_splines(self, x: torch.Tensor):
"""
Compute the B-spline bases for the given input tensor.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, in_features).
Returns:
torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
"""
assert x.dim() == 2 and x.size(1) == self.in_features
grid: torch.Tensor = (
self.grid
) # (in_features, grid_size + 2 * spline_order + 1)
x = x.unsqueeze(-1)
bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
for k in range(1, self.spline_order + 1):
bases = (
(x - grid[:, : -(k + 1)])
/ (grid[:, k:-1] - grid[:, : -(k + 1)])
* bases[:, :, :-1]
) + (
(grid[:, k + 1 :] - x)
/ (grid[:, k + 1 :] - grid[:, 1:(-k)])
* bases[:, :, 1:]
)
assert bases.size() == (
x.size(0),
self.in_features,
self.grid_size + self.spline_order,
)
return bases.contiguous()
def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
"""
Compute the coefficients of the curve that interpolates the given points.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, in_features).
y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).
Returns:
torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
"""
assert x.dim() == 2 and x.size(1) == self.in_features
assert y.size() == (x.size(0), self.in_features, self.out_features)
A = self.b_splines(x).transpose(
0, 1
) # (in_features, batch_size, grid_size + spline_order)
B = y.transpose(0, 1) # (in_features, batch_size, out_features)
solution = torch.linalg.lstsq(
A, B
).solution # (in_features, grid_size + spline_order, out_features)
result = solution.permute(
2, 0, 1
) # (out_features, in_features, grid_size + spline_order)
assert result.size() == (
self.out_features,
self.in_features,
self.grid_size + self.spline_order,
)
return result.contiguous()
@property
def scaled_spline_weight(self):
return self.spline_weight * (
self.spline_scaler.unsqueeze(-1)
if self.enable_standalone_scale_spline
else 1.0
)
def forward(self, x: torch.Tensor):
assert x.size(-1) == self.in_features
original_shape = x.shape
x = x.view(-1, self.in_features)
base_output = F.linear(self.base_activation(x), self.base_weight)
spline_output = F.linear(
self.b_splines(x).view(x.size(0), -1),
self.scaled_spline_weight.view(self.out_features, -1),
)
output = base_output + spline_output
output = output.view(*original_shape[:-1], self.out_features)
return output
@torch.no_grad()
def update_grid(self, x: torch.Tensor, margin=0.01):
assert x.dim() == 2 and x.size(1) == self.in_features
batch = x.size(0)
splines = self.b_splines(x) # (batch, in, coeff)
splines = splines.permute(1, 0, 2) # (in, batch, coeff)
orig_coeff = self.scaled_spline_weight # (out, in, coeff)
orig_coeff = orig_coeff.permute(1, 2, 0) # (in, coeff, out)
unreduced_spline_output = torch.bmm(splines, orig_coeff) # (in, batch, out)
unreduced_spline_output = unreduced_spline_output.permute(
1, 0, 2
) # (batch, in, out)
# sort each channel individually to collect data distribution
x_sorted = torch.sort(x, dim=0)[0]
grid_adaptive = x_sorted[
torch.linspace(
0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
)
]
uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
grid_uniform = (
torch.arange(
self.grid_size + 1, dtype=torch.float32, device=x.device
).unsqueeze(1)
* uniform_step
+ x_sorted[0]
- margin
)
grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
grid = torch.concatenate(
[
grid[:1]
- uniform_step
* torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
grid,
grid[-1:]
+ uniform_step
* torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
],
dim=0,
)
self.grid.copy_(grid.T)
self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))
def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
"""
Compute the regularization loss.
This is a dumb simulation of the original L1 regularization as stated in the
paper, since the original one requires computing absolutes and entropy from the
expanded (batch, in_features, out_features) intermediate tensor, which is hidden
behind the F.linear function if we want an memory efficient implementation.
The L1 regularization is now computed as mean absolute value of the spline
weights. The authors implementation also includes this term in addition to the
sample-based regularization.
"""
l1_fake = self.spline_weight.abs().mean(-1)
regularization_loss_activation = l1_fake.sum()
p = l1_fake / regularization_loss_activation
regularization_loss_entropy = -torch.sum(p * p.log())
return (
regularize_activation * regularization_loss_activation
+ regularize_entropy * regularization_loss_entropy
)
class KAN(torch.nn.Module):
def __init__(
self,
layers_hidden,
grid_size=5,
spline_order=3,
scale_noise=0.1,
scale_base=1.0,
scale_spline=1.0,
base_activation=torch.nn.SiLU,
grid_eps=0.02,
grid_range=[-1, 1],
):
super(KAN, self).__init__()
self.grid_size = grid_size
self.spline_order = spline_order
self.layers = torch.nn.ModuleList()
for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):
self.layers.append(
KANLinear(
in_features,
out_features,
grid_size=grid_size,
spline_order=spline_order,
scale_noise=scale_noise,
scale_base=scale_base,
scale_spline=scale_spline,
base_activation=base_activation,
grid_eps=grid_eps,
grid_range=grid_range,
)
)
def forward(self, x: torch.Tensor, update_grid=False):
for layer in self.layers:
if update_grid:
layer.update_grid(x)
x = layer(x)
return x
def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
return sum(
layer.regularization_loss(regularize_activation, regularize_entropy)
for layer in self.layers
)
"""## Build Kanformer"""
from transformers import AutoConfig, AutoTokenizer, AutoModel
from transformers.models.llama.modeling_llama import *
class KANLlamaAttention(LlamaAttention):
def __init__(self, config, **args):
super().__init__(config, **args)
head_dim = config.hidden_size // config.num_attention_heads
self.q_proj = KANLinear(config.hidden_size, config.num_attention_heads * head_dim)
self.k_proj = KANLinear(config.hidden_size, config.num_key_value_heads * head_dim)
self.v_proj = KANLinear(config.hidden_size, config.num_key_value_heads * head_dim)
self.o_proj = KANLinear(config.hidden_size, config.hidden_size)
self._init_rope()
class KANLlamaDecoderLayer(LlamaDecoderLayer):
def __init__(self,config, layer_idx):
super().__init__(config, layer_idx)
self.hidden_size = config.hidden_size
self.self_attn = KANLlamaAttention(config=config, layer_idx=layer_idx)
self.mlp = KAN([config.hidden_size,config.intermediate_size,config.intermediate_size,config.hidden_size])
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
class KANLlamaModel(LlamaModel):
config_class = KANLlamaConfig
def __init__(self, config):
super().__init__(config)
self.layers = nn.ModuleList(
[KANLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
class KANLlamaForSequenceClassification(LlamaForSequenceClassification):
config_class = KANLlamaConfig
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = KANLlamaModel(config)
self.score = KANLinear(config.hidden_size, self.num_labels)
# Initialize weights and apply final processing
self.post_init()
class KANLlamaForCausalLM(LlamaForCausalLM):
config_class = KANLlamaConfig
def __init__(self, config):
super().__init__(config)
self.model = KANLlamaModel(config)
self.vocab_size = config.vocab_size
self.lm_head = KANLinear(config.hidden_size, config.vocab_size)
# Initialize weights and apply final processing
self.post_init()