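"""ASVD-style (activation-aware SVD) low-rank adaptation of Phi: each targeted
nn.Linear is replaced by the sum of a frozen and a trainable low-rank branch."""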
import torch.nn as nn
from transformers import PhiForCausalLM

from .configuration_asvd_phi import ASVDPhiConfig


class ASVDLinear(nn.Module):
    """A dense layer factored into two low-rank branches: a frozen branch of
    rank[0] and a trainable branch of rank[1], summed in the forward pass."""

    def __init__(self, in_features, out_features, rank, train_frac_beta=0.2, bias=False):
        super().__init__()
        # Single-branch factorization, kept for reference:
        # self.BLinear = nn.Linear(in_features, rank, bias=False)
        # self.ALinear = nn.Linear(rank, out_features, bias=bias)
        # `rank` is a pair (frozen rank, trainable rank); `train_frac_beta`
        # is accepted for configuration compatibility but unused here.
        self.BLinear_no_train = nn.Linear(in_features, rank[0], bias=False)
        self.BLinear_train = nn.Linear(in_features, rank[1], bias=False)
        self.ALinear_no_train = nn.Linear(rank[0], out_features, bias=False)
        self.ALinear_train = nn.Linear(rank[1], out_features, bias=bias)
        # Gradients for the no-train branch are disabled; only the train
        # branch is updated during fine-tuning.
        self.BLinear_no_train.weight.requires_grad = False
        self.ALinear_no_train.weight.requires_grad = False

    def forward(self, input):
        # return self.ALinear(self.BLinear(input))
        # Frozen low-rank path.
        y_no_train = self.BLinear_no_train(input)
        y_no_train = self.ALinear_no_train(y_no_train)
        # Trainable low-rank path.
        y_train = self.BLinear_train(input)
        y_train = self.ALinear_train(y_train)
        # The branch outputs are summed to approximate the original dense layer.
        y = y_no_train + y_train
        return y


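# Worked example (hypothetical sizes, not from this repo): for a 2048x2048
# projection with rank=[400, 112], the two branches hold
#   2048*400 + 400*2048 + 2048*112 + 112*2048 = 2,097,152
# parameters versus 2048*2048 = 4,194,304 for the dense layer (about half),
# and only the 2048*112 + 112*2048 = 458,752 trainable-branch weights
# receive gradients.

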
class ASVDPhiForCausalLM(PhiForCausalLM):
    config_class = ASVDPhiConfig

    def __init__(self, config: ASVDPhiConfig):
        super().__init__(config)
        self.truncation_ranks = config.truncation_ranks
        # Map each submodule to its fully qualified name so that the targets
        # named in truncation_ranks can be matched against actual modules.
        full_name_dict = {module: name for name, module in self.named_modules()}
        # Walk the module tree and record, for every nn.Linear, its parent
        # module and the attribute name it hangs off, so it can be swapped
        # out in place.
        linear_info = {}
        modules = [self]
        while len(modules) > 0:
            submodule = modules.pop()
            for name, raw_linear in submodule.named_children():
                if isinstance(raw_linear, nn.Linear):
                    full_name = full_name_dict[raw_linear]
                    linear_info[raw_linear] = {
                        "father": submodule,
                        "name": name,
                        "full_name": full_name,
                    }
                else:
                    modules.append(raw_linear)
        # Replace each targeted dense layer with its two-branch low-rank version.
        for name, module in self.named_modules():
            if name in self.truncation_ranks:
                info = linear_info[module]
                new_layer = ASVDLinear(
                    module.in_features,
                    module.out_features,
                    self.truncation_ranks[name],
                    bias=module.bias is not None,
                )
                setattr(info["father"], info["name"], new_layer)
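

# Minimal usage sketch, not part of the original module. It assumes that
# ASVDPhiConfig subclasses transformers.PhiConfig, forwards the standard Phi
# kwargs, and accepts a `truncation_ranks` dict; the layer name and ranks
# below are hypothetical.
if __name__ == "__main__":
    import torch

    config = ASVDPhiConfig(
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        truncation_ranks={"model.layers.0.self_attn.q_proj": [16, 8]},
    )
    model = ASVDPhiForCausalLM(config)
    # The targeted projection is now an ASVDLinear with frozen + trainable branches.
    q_proj = model.model.layers[0].self_attn.q_proj
    assert isinstance(q_proj, ASVDLinear)
    x = torch.randn(1, 4, config.hidden_size)
    print(q_proj(x).shape)  # torch.Size([1, 4, 64]), same shape as the dense q_proj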