# blast-llama-4B / modeling_blast.py
# Uploaded by cwoolee: "Upload model" (commit 279b61f, verified).
import os
import logging
import torch
import torch.nn as nn
from transformers import PretrainedConfig, LlamaConfig, LlamaModel, LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRotaryEmbedding, LlamaRMSNorm
from typing import List, Union, Tuple
from huggingface_hub import PyTorchModelHubMixin
# Module-level logger for this file.
logger = logging.getLogger(__name__)
# NOTE(review): basicConfig() runs at import time. Unless the root logger
# already has handlers, it installs a root handler and sets the root level to
# INFO for the whole application that imports this module. Consider removing
# it and leaving logging configuration to the caller.
logging.basicConfig(level=logging.INFO)
class BlastLlamaConfig(LlamaConfig):
    """``LlamaConfig`` extended with the settings used to BLAST-factorize layers.

    The extra fields are consumed by ``replace_layers_with_blast`` when a
    ``BlastLlamaModel`` is constructed.
    """

    model_type = "blast_llama"
    keys_to_ignore_at_inference = ["blast_decomposed_weight_path"]

    def __init__(
        self,
        target_modules=None,
        blast_rank=None,
        blast_num_blocks: Union[Union[List, Tuple], int] = 4,
        indices=None,
        precompute_matrix=False,
        **kwargs,
    ):
        """Initialize the config.

        Args:
            target_modules: Substrings selecting which ``nn.Linear`` modules are
                replaced. ``None`` means all attention and MLP projections
                (``q/k/v/o_proj``, ``gate/up/down_proj``).
            blast_rank: Rank of each replacement. It is a dict keyed by
                module-name substring, an int, or a float fraction of
                ``min(in_features, out_features)``. ``None`` means 1024 for the
                attention projections and 1488 for the MLP projections.
            blast_num_blocks: Block count. It is an int (same count along both
                axes), a ``(row_blocks, col_blocks)`` pair, or a dict keyed by
                module-name substring. One-element list/tuple wrappers are
                stripped, e.g. the ``[[4]]`` an earlier version of this config
                wrote to ``config.json``.
            indices: Decoder-layer indices to replace. ``None`` means 0-31.
            precompute_matrix: If True, ``BlastLinear`` materializes its dense
                weight in the forward pass instead of using the factors.
            **kwargs: Forwarded to ``LlamaConfig``.
        """
        # Build fresh containers per instance. List/dict default arguments
        # would be shared (and mutated) across every config created without
        # explicit values.
        if target_modules is None:
            target_modules = ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj']
        if blast_rank is None:
            blast_rank = {
                'q_proj': 1024, 'k_proj': 1024, 'v_proj': 1024, 'o_proj': 1024,
                'gate_proj': 1488, 'up_proj': 1488, 'down_proj': 1488,
            }
        if indices is None:
            indices = list(range(32))
        self.target_modules = target_modules
        self.blast_rank = blast_rank

        # Strip one-element wrappers first. The value below is stored as a
        # 1-tuple, which JSON turns into a list on save. Without this, every
        # save/load cycle would add another level of nesting.
        num_blocks = blast_num_blocks
        while isinstance(num_blocks, (list, tuple)) and len(num_blocks) == 1:
            num_blocks = num_blocks[0]
        if isinstance(num_blocks, list):
            # A pair must stay a pair: the consumer collapses a *list* inside
            # the wrapper to its first element.
            num_blocks = tuple(num_blocks)
        # Stored wrapped in a 1-tuple (as this config always has), which
        # replace_layers_with_blast unwraps.
        self.blast_num_blocks = (num_blocks,)

        self.indices = indices
        self.precompute_matrix = precompute_matrix
        super().__init__(**kwargs)
def get_parent(model, mn):
    """Return the module that directly contains the submodule named ``mn``.

    Args:
        model: Root ``nn.Module``.
        mn: Dotted submodule path as yielded by ``model.named_modules()``,
            e.g. ``"0.self_attn.q_proj"``. A name without dots has ``model``
            itself as its parent.

    Returns:
        The parent module, or ``None`` if the parent path does not exist.
    """
    parent_name = mn.rpartition(".")[0]
    try:
        # Walk the dotted path (O(depth)) instead of scanning every module in
        # the tree on every call.
        return model.get_submodule(parent_name)
    except AttributeError:
        return None
def _unwrap_singleton(value):
    """Strip any nesting of one-element lists/tuples, e.g. ``([4],)`` -> ``4``."""
    while isinstance(value, (list, tuple)) and len(value) == 1:
        value = value[0]
    return value


def _lookup_by_module_name(mapping, module_name, option):
    """Return the value of the first key of ``mapping`` that occurs in ``module_name``."""
    for key, value in mapping.items():
        if key in module_name:
            return value
    raise ValueError(
        f"{option} has no key matching module '{module_name}'; available keys: {list(mapping)}."
    )


def _resolve_rank(blast_rank, module_name, linear):
    """Return the rank to use for ``linear`` (named ``module_name``)."""
    if isinstance(blast_rank, dict):
        return _lookup_by_module_name(blast_rank, module_name, "blast_rank")
    if isinstance(blast_rank, int):
        return blast_rank
    if isinstance(blast_rank, float):
        return int(blast_rank * min(linear.weight.shape[0], linear.weight.shape[1]))
    raise ValueError(f"blast_rank must have either dict, int, or float type, got: {type(blast_rank)}.")


def _resolve_num_blocks(blast_num_blocks, module_name):
    """Return the block count (an int or a tuple) for module ``module_name``."""
    # BlastLlamaConfig stores the option inside a 1-tuple, and JSON round-trips
    # may add list wrappers, so unwrap before dispatching on the type.
    num_blocks = _unwrap_singleton(blast_num_blocks)
    if isinstance(num_blocks, dict):
        num_blocks = _unwrap_singleton(
            _lookup_by_module_name(num_blocks, module_name, "blast_num_blocks")
        )
    if isinstance(num_blocks, int):
        return num_blocks
    if isinstance(num_blocks, (list, tuple)):
        return tuple(_unwrap_singleton(b) for b in num_blocks)
    raise ValueError(
        f"blast_num_blocks must have either dict, int, or list/tuple of ints, got: {type(blast_num_blocks)}."
    )


def replace_layers_with_blast(
    model,
    target_modules,
    blast_rank,
    blast_num_blocks,
    indices,
    precompute_matrix=False,
):
    """Swap selected ``nn.Linear`` submodules of ``model`` for ``BlastLinear`` layers.

    A Linear module is replaced when both conditions hold:

    * its dotted name contains one of ``target_modules``;
    * its layer index (the third-from-last name component, e.g. ``0`` in
      ``"0.self_attn.q_proj"``) is in ``indices``.

    The replacement keeps the feature sizes, bias flag, device and dtype. Its
    factor weights are newly allocated by ``BlastLinear`` and are not derived
    from the dense weight. Presumably they are filled by loading a checkpoint
    afterwards; confirm with callers.

    Args:
        model: Module tree to modify in place.
        target_modules: Substrings of module names to replace.
        blast_rank: Dict keyed by module-name substring (first matching key
            wins), an int, or a float fraction of ``min(out, in)``.
        blast_num_blocks: Dict keyed by module-name substring, an int, or a
            pair of ints. One-element list/tuple wrappers are stripped.
        indices: Layer indices eligible for replacement.
        precompute_matrix: Forwarded to ``BlastLinear``.

    Returns:
        ``model``, modified in place.

    Raises:
        ValueError: An option has an unsupported type, or a dict option has no
            key matching a module that is being replaced.
    """
    # Snapshot the module list: the loop swaps children of the tree it walks.
    for mn, m in list(model.named_modules()):
        if not isinstance(m, torch.nn.Linear):
            continue
        # Replace each layer once, even if several target names match it.
        if not any(tmn in mn for tmn in target_modules):
            continue
        layer_idx = int(mn.split(".")[-3])
        if layer_idx not in indices:
            continue
        rank = _resolve_rank(blast_rank, mn, m)
        num_blocks = _resolve_num_blocks(blast_num_blocks, mn)
        new_layer = BlastLinear(
            in_features=m.weight.shape[1],
            out_features=m.weight.shape[0],
            num_blocks=num_blocks,
            rank=rank,
            bias=m.bias is not None,
            device=m.weight.device,
            dtype=m.weight.dtype,
            precompute_matrix=precompute_matrix,
        )
        parent_module = get_parent(model, mn)
        parent_module.add_module(mn.split(".")[-1], new_layer)
    return model
class BlastLinear(torch.nn.Module):
    """Linear layer whose weight is stored in BLAST-factorized form.

    The dense weight ``A`` of shape ``(out_features, in_features)`` is split
    into ``num_blocks[0] x num_blocks[1]`` blocks of shape ``(p, q)``, where
    ``p = out_features // num_blocks[0]`` and
    ``q = in_features // num_blocks[1]``. Block ``(i, j)`` equals
    ``B[i] @ diag(D[i, j]) @ C[j]``:

    * ``B`` (shape ``(b1, p, r)``) is shared along a block row;
    * ``C`` (shape ``(b2, r, q)``) is shared along a block column;
    * only the diagonal couplings ``D`` (shape ``(b1, b2, r)``) are
      block-specific.

    All parameters are allocated with ``torch.empty`` and are uninitialized
    until values are loaded or assigned.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 num_blocks: Union[int, Union[List, Tuple]],
                 rank=None,
                 bias: bool = True,
                 device=None,
                 dtype=torch.float32,
                 precompute_matrix=False,
                 ) -> None:
        """Allocate the factors.

        Args:
            in_features: Input feature size.
            out_features: Output feature size.
            num_blocks: An int (same count on both axes) or a
                ``(row_blocks, col_blocks)`` pair. Entries wrapped in a list
                are unwrapped to their first element. The caller's object is
                not modified.
            rank: Rank ``r`` of every block. ``None`` means
                ``min(in_features, out_features)``; a float is a fraction of
                that minimum.
            bias: Whether to add a learnable bias.
            device: Device of the parameters.
            dtype: Dtype of the parameters.
            precompute_matrix: If True, forward materializes ``A`` and applies
                it with ``F.linear``. Outside training the matrix is cached.

        Raises:
            ValueError: ``num_blocks`` is not an int or a pair, or it does not
                divide ``out_features`` / ``in_features``.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        num_blocks = self._normalize_num_blocks(num_blocks)
        if in_features % num_blocks[1] != 0 or out_features % num_blocks[0] != 0:
            raise ValueError(
                f"num_blocks={num_blocks} must divide (out_features, in_features)="
                f"({out_features}, {in_features})."
            )
        self.num_blocks = num_blocks
        self.precompute_matrix = precompute_matrix
        if rank is None:
            rank = min(in_features, out_features)
        if isinstance(rank, float):
            rank = int(rank * min(in_features, out_features))
        self.rank = rank
        self.B = nn.Parameter(torch.empty(num_blocks[0], out_features // num_blocks[0], rank, device=device, dtype=dtype))
        self.C = nn.Parameter(torch.empty(num_blocks[1], rank, in_features // num_blocks[1], device=device, dtype=dtype))
        self.D = nn.Parameter(torch.empty(num_blocks[0], num_blocks[1], rank, device=device, dtype=dtype))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, device=device, dtype=dtype))
        else:
            self.register_parameter('bias', None)
        self.rank_score = 0.
        # Dense-weight cache for the precompute_matrix path outside training.
        # It is a plain attribute, not a parameter or buffer.
        self.A = None

    @staticmethod
    def _normalize_num_blocks(num_blocks):
        """Return ``num_blocks`` as a ``(row_blocks, col_blocks)`` tuple of ints."""
        if isinstance(num_blocks, int):
            return (num_blocks, num_blocks)
        if len(num_blocks) != 2:
            raise ValueError(f"num_blocks must be an int or a pair, got: {num_blocks!r}.")
        # Build a new tuple rather than assigning into the input, which would
        # mutate the caller's list and fail outright for tuples.
        return tuple(b[0] if isinstance(b, list) else b for b in num_blocks)

    def get_matrix(self):
        """Materialize the dense ``(out_features, in_features)`` weight from the factors."""
        C = self.C.unsqueeze(0)  # 1, b2, r, q
        D = self.D.unsqueeze(-1)  # b1, b2, r, 1
        DC = C * D  # b1, b2, r, q
        DC = DC.permute(0, 1, 3, 2).reshape(self.num_blocks[0], self.in_features, self.rank)  # b1, n, r
        B = self.B  # b1, p, r
        A = torch.bmm(B, DC.transpose(1, 2))  # b1, p, n
        A = A.view(self.out_features, self.in_features)
        return A

    def _dense_weight(self) -> torch.Tensor:
        """Return the dense weight, reusing the cached copy outside training."""
        if self.training:
            # Parameters change during training, so never reuse a cache.
            self.A = None
            return self.get_matrix()
        cached = getattr(self, 'A', None)
        # .to()/.half()/.cuda() convert parameters and buffers but not this
        # plain attribute. Rebuild if the parameters were moved or cast after
        # caching. In-place weight updates in eval mode still require clearing
        # the cache manually (``layer.A = None``).
        if cached is None or cached.device != self.B.device or cached.dtype != self.B.dtype:
            cached = self.get_matrix()
            self.A = cached
        return cached

    def _factorized_matmul(self, x: torch.Tensor) -> torch.Tensor:
        """Compute ``x @ A.T`` directly from the factors without forming ``A``."""
        lead_shape = x.shape[:-1]
        col_blocks = self.num_blocks[1]
        x = x.flatten(0, -2)  # N, n
        # Split the features into one slice per block column: (b2, N, q).
        x = x.view(-1, col_blocks, x.shape[-1] // col_blocks).transpose(0, 1)
        y = torch.bmm(x, self.C.transpose(1, 2))  # b2, N, r
        z = (y.unsqueeze(0) * self.D.unsqueeze(2)).sum(1)  # b1, N, r
        out = torch.bmm(z, self.B.transpose(1, 2))  # b1, N, p
        return out.transpose(0, 1).reshape(*lead_shape, self.out_features)

    # @torch.compile
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer to ``x`` of shape ``(..., in_features)``.

        The factorized path (``precompute_matrix=False``) needs at least one
        leading dimension, because it flattens all but the last axis.
        """
        if self.precompute_matrix:
            out = torch.nn.functional.linear(x, self._dense_weight())
        else:
            out = self._factorized_matmul(x)
        if self.bias is not None:
            out += self.bias.to(x.dtype)
        return out

    def extra_repr(self) -> str:
        return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}, rank={self.rank}, num_blocks={self.num_blocks}'
class BlastLlamaModel(LlamaModel):
    """``LlamaModel`` whose selected decoder projections are ``BlastLinear`` layers.

    Which layers are replaced, and how they are factorized, is controlled by
    ``config.target_modules``, ``blast_rank``, ``blast_num_blocks``,
    ``indices`` and ``precompute_matrix``.
    """

    config_class = BlastLlamaConfig

    def __init__(self, config: BlastLlamaConfig):
        # LlamaModel.__init__ builds padding_idx, vocab_size, embed_tokens, the
        # LlamaDecoderLayer stack, norm, rotary_emb and gradient_checkpointing
        # (this matches current transformers releases; re-check on upgrades).
        # The previous version rebuilt all of them here, allocating the whole
        # decoder stack twice. Now the projections are swapped in place.
        super().__init__(config)
        logger.info("Replacing Linear Layers to BlastLinear...")
        replace_layers_with_blast(
            self.layers,
            config.target_modules,
            config.blast_rank,
            config.blast_num_blocks,
            config.indices,
            config.precompute_matrix,
        )
        # Run transformers' final processing again now that the module tree
        # contains the BlastLinear layers.
        self.post_init()
class BlastModelForCausalLM(LlamaForCausalLM, PyTorchModelHubMixin):
    """Causal-LM head (``lm_head``) on top of a ``BlastLlamaModel`` backbone."""

    config_class = BlastLlamaConfig

    def __init__(self, config):
        # Call the initializer *after* LlamaForCausalLM in the MRO, i.e. skip
        # LlamaForCausalLM.__init__ itself. Calling it would build a complete
        # dense LlamaModel only for it to be discarded on the next line. Its
        # remaining work (model, vocab_size, lm_head, post_init) is reproduced
        # below. NOTE(review): re-check this against LlamaForCausalLM.__init__
        # when upgrading transformers.
        super(LlamaForCausalLM, self).__init__(config)
        self.model = BlastLlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()