regisss
Push fused QKV checkpoint and config
5aaa188
import torch
import torch.nn as nn
from transformers import LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaAttention
model = LlamaForCausalLM.from_pretrained(
"meta-llama/Llama-2-70b-hf",
use_cache=False,
torch_dtype=torch.bfloat16,
use_flash_attention_2=True,
max_position_embeddings=8192,
)
def replace_modules(module):
has_bias = module.q_proj.bias is not None
qkv_weight = torch.cat([module.q_proj.weight.data, module.k_proj.weight.data, module.v_proj.weight.data], dim=0)
module.qkv_proj = nn.Linear(module.hidden_size, qkv_weight.shape[0], bias=has_bias)
module.qkv_proj.weight.data = qkv_weight
if has_bias:
qkv_bias = torch.cat([module.q_proj.bias, module.k_proj.bias, module.v_proj.bias], dim=0)
module.qkv_proj.bias.data = qkv_bias
del module.q_proj
del module.k_proj
del module.v_proj
module.dim1 = module.num_heads * module.head_dim
module.dim2 = module.num_key_value_heads * module.head_dim
for name, module in model.named_modules():
if isinstance(module, LlamaAttention):
replace_modules(module)
model.config.save_pretrained("my_config")
model.save_pretrained("llama2-70b")