# Pruned Phi causal LM (file size 1,213 bytes, revision 7e9f6a1).

import torch
from transformers import PhiForCausalLM
from .configuration_pruned_phi import PhiPrunedConfig
import torch.nn as nn


class PhiPrunedForCausalLM(PhiForCausalLM):
    """Phi causal LM whose attention blocks use a narrower inner dimension.

    The stock ``PhiForCausalLM`` is first built from ``config``. Then, in every
    decoder layer:

    * ``q_proj``/``k_proj``/``v_proj`` are replaced by projections from
      ``config.hidden_size`` down to ``pruned_attn_dim`` features;
    * ``dense`` is replaced by a projection from ``pruned_attn_dim`` back up
      to ``config.hidden_size``;
    * ``mlp.fc1``/``mlp.fc2`` are re-created with a hidden width of
      ``mlp_intermediate_size``;
    * the attention module's ``hidden_size``, ``num_heads`` and
      ``num_key_value_heads`` attributes are rewritten to match the narrower
      projections.

    The replacement ``nn.Linear`` modules are freshly initialised. Presumably
    the real weights come from a pruned checkpoint loaded via
    ``from_pretrained`` — TODO confirm.
    """

    config_class = PhiPrunedConfig

    # Output width of the pruned q/k/v projections (and input width of `dense`).
    pruned_attn_dim = 640
    # Hidden width of the MLP. NOTE(review): this is the value the original
    # code hard-coded; confirm whether it always equals config.intermediate_size.
    mlp_intermediate_size = 10240

    def __init__(self, config: PhiPrunedConfig):
        """Build the base Phi model, then swap in the pruned projections.

        Args:
            config: Model configuration. ``config.hidden_size`` is used as the
                residual-stream width on both sides of the pruned attention
                (previously hard-coded as 2560).

        Raises:
            ValueError: if ``pruned_attn_dim`` is not a multiple of a layer's
                ``head_dim``. The head counts below would otherwise be silently
                floored.
        """
        super().__init__(config)

        model_dim = config.hidden_size
        attn_dim = self.pruned_attn_dim
        mlp_dim = self.mlp_intermediate_size

        # Iterate over the layers that were actually built instead of assuming
        # 32. The creation order of the Linear modules matches the original
        # code, so random initialisation consumes the RNG identically.
        for layer in self.model.layers:
            attn = layer.self_attn
            attn.dense = nn.Linear(attn_dim, model_dim, bias=True)
            # NOTE(review): presumably read by the attention forward to reshape
            # its output before `dense`; usage depends on the transformers version.
            attn.hidden_size = attn_dim
            attn.q_proj = nn.Linear(model_dim, attn_dim, bias=True)
            attn.k_proj = nn.Linear(model_dim, attn_dim, bias=True)
            attn.v_proj = nn.Linear(model_dim, attn_dim, bias=True)

            layer.mlp.fc1 = nn.Linear(model_dim, mlp_dim, bias=True)
            layer.mlp.fc2 = nn.Linear(mlp_dim, model_dim, bias=True)

            # head_dim is left untouched, so the pruned width must split into
            # whole heads.
            head_dim = attn.head_dim
            if attn.q_proj.out_features % head_dim or attn.k_proj.out_features % head_dim:
                raise ValueError(
                    f"pruned attention width {attn_dim} is not a multiple of "
                    f"head_dim {head_dim}"
                )
            attn.num_heads = attn.q_proj.out_features // head_dim
            attn.num_key_value_heads = attn.k_proj.out_features // head_dim