File size: 1,208 Bytes
5aaa188
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
import torch.nn as nn
from transformers import LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaAttention

model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf",
    use_cache=False,
    torch_dtype=torch.bfloat16,
    use_flash_attention_2=True,
    max_position_embeddings=8192,
)

def replace_modules(module):
    has_bias = module.q_proj.bias is not None
    qkv_weight = torch.cat([module.q_proj.weight.data, module.k_proj.weight.data, module.v_proj.weight.data], dim=0)
    module.qkv_proj = nn.Linear(module.hidden_size, qkv_weight.shape[0], bias=has_bias)
    module.qkv_proj.weight.data = qkv_weight
    if has_bias:
        qkv_bias = torch.cat([module.q_proj.bias, module.k_proj.bias, module.v_proj.bias], dim=0)
        module.qkv_proj.bias.data = qkv_bias
    del module.q_proj
    del module.k_proj
    del module.v_proj
    module.dim1 = module.num_heads * module.head_dim
    module.dim2 = module.num_key_value_heads * module.head_dim

for name, module in model.named_modules():
    if isinstance(module, LlamaAttention):
        replace_modules(module)

model.config.save_pretrained("my_config")
model.save_pretrained("llama2-70b")