import torch
import torch.nn as nn
from transformers import LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaAttention

# Load Llama-2-70B in bf16 with FlashAttention-2, an 8192-token context
# window, and the KV cache disabled. (use_flash_attention_2 is the legacy
# flag; newer transformers releases use attn_implementation="flash_attention_2".)
model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf",
    use_cache=False,
    torch_dtype=torch.bfloat16,
    use_flash_attention_2=True,
    max_position_embeddings=8192,
)

def replace_modules(module):
    # Fuse the separate q/k/v projections of a LlamaAttention block into one
    # qkv_proj linear layer by stacking their weight rows.
    has_bias = module.q_proj.bias is not None
    qkv_weight = torch.cat(
        [module.q_proj.weight.data, module.k_proj.weight.data, module.v_proj.weight.data], dim=0
    )
    # Match dtype/device of the existing weights so the new layer does not
    # allocate a throwaway fp32 copy first.
    module.qkv_proj = nn.Linear(
        module.hidden_size, qkv_weight.shape[0], bias=has_bias,
        dtype=qkv_weight.dtype, device=qkv_weight.device,
    )
    module.qkv_proj.weight.data = qkv_weight
    if has_bias:
        qkv_bias = torch.cat([module.q_proj.bias, module.k_proj.bias, module.v_proj.bias], dim=0)
        module.qkv_proj.bias.data = qkv_bias
    del module.q_proj
    del module.k_proj
    del module.v_proj
    # Split points for the fused output: dim1 query rows, then dim2 rows each
    # for keys and values (grouped-query attention).
    module.dim1 = module.num_heads * module.head_dim
    module.dim2 = module.num_key_value_heads * module.head_dim
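
# Deleting q_proj/k_proj/v_proj breaks the stock LlamaAttention.forward, which
# still references them. A minimal sketch of the split a patched forward would
# perform (split_fused_qkv is a hypothetical helper, not part of the original
# snippet): the fused output is cut back into query/key/value states along the
# last dimension using the recorded widths.
def split_fused_qkv(module, hidden_states):
    qkv = module.qkv_proj(hidden_states)
    return qkv.split([module.dim1, module.dim2, module.dim2], dim=-1)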

# Apply the fusion in place to every attention block in the model.
for name, module in model.named_modules():
    if isinstance(module, LlamaAttention):
        replace_modules(module)
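
# Optional sanity check (an assumption, not in the original snippet): every
# attention block should now expose a single fused projection whose output
# width equals dim1 + 2 * dim2, with the old projections gone.
for name, module in model.named_modules():
    if isinstance(module, LlamaAttention):
        assert module.qkv_proj.out_features == module.dim1 + 2 * module.dim2, name
        assert not hasattr(module, "q_proj")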

# Write the config separately to my_config/, then the full model (weights plus
# config) to llama2-70b/.
model.config.save_pretrained("my_config")
model.save_pretrained("llama2-70b")
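
# Reload note (assumption): the saved state dict now carries qkv_proj weights,
# so loading it back into the stock LlamaForCausalLM would report q/k/v_proj
# as missing keys and qkv_proj as unexpected; the checkpoint is intended for a
# model class whose attention forward consumes the fused projection, e.g. via
# split_fused_qkv above.
#
#   reloaded = LlamaForCausalLM.from_pretrained("llama2-70b")  # warns: missing q/k/v_proj, unexpected qkv_proj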