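"""Width-pruned Phi model for causal language modeling.

Subclasses PhiForCausalLM and, at construction time, replaces every decoder
layer's attention projections with narrower Linear modules (2560 -> 640) and
rebuilds the MLP projections (2560 -> 10240 -> 2560), then recomputes the
attention head counts from the pruned projection shapes.
"""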
import torch.nn as nn
from transformers import PhiForCausalLM

from .configuration_pruned_phi import PhiPrunedConfig


class PhiPrunedForCausalLM(PhiForCausalLM):
    config_class = PhiPrunedConfig

    def __init__(self, config: PhiPrunedConfig):
        super().__init__(config)

        # Replace each decoder layer's attention projections with pruned ones:
        # q_proj/k_proj/v_proj map 2560 -> 640 and dense maps 640 -> 2560, so the
        # per-layer attention width drops from 2560 to 640. The MLP projections
        # are rebuilt with a 2560 -> 10240 -> 2560 shape.
        for layer in self.model.layers:
            layer.self_attn.hidden_size = 640
            layer.self_attn.q_proj = nn.Linear(2560, 640, bias=True)
            layer.self_attn.k_proj = nn.Linear(2560, 640, bias=True)
            layer.self_attn.v_proj = nn.Linear(2560, 640, bias=True)
            layer.self_attn.dense = nn.Linear(640, 2560, bias=True)

            layer.mlp.fc1 = nn.Linear(2560, 10240, bias=True)
            layer.mlp.fc2 = nn.Linear(10240, 2560, bias=True)

        # Recompute the head counts from the pruned projection shapes so that the
        # attention reshapes use the reduced number of heads.
        for layer in self.model.layers:
            layer.self_attn.num_heads = layer.self_attn.q_proj.weight.shape[0] // layer.self_attn.head_dim
            layer.self_attn.num_key_value_heads = layer.self_attn.k_proj.weight.shape[0] // layer.self_attn.head_dim
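
# Minimal usage sketch, assuming a checkpoint that was saved with this class and
# PhiPrunedConfig; "path/to/pruned-phi" is a placeholder, not a real checkpoint:
#
#   config = PhiPrunedConfig.from_pretrained("path/to/pruned-phi")
#   model = PhiPrunedForCausalLM.from_pretrained("path/to/pruned-phi", config=config)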