from transformers import PhiForCausalLM
import torch.nn as nn

from .configuration_asvd_phi import ASVDPhiConfig


class ASVDLinear(nn.Module):
    # Low-rank replacement for nn.Linear built from two parallel branches:
    # a frozen branch of rank `rank[0]` and a trainable branch of rank `rank[1]`.
    # `bias` is a bool flag; if set, a bias is attached to the trainable branch.
    # (`train_frac_beta` is currently unused.)
    def __init__(self, in_features, out_features, rank, train_frac_beta=0.2, bias=False):
        super().__init__()
        # Dense single-branch version, kept for reference:
        # self.BLinear = nn.Linear(in_features, rank, bias=False)
        # self.ALinear = nn.Linear(rank, out_features, bias=bias)
        self.BLinear_no_train = nn.Linear(in_features, rank[0], bias=False)
        self.BLinear_train = nn.Linear(in_features, rank[1], bias=False)
        self.ALinear_no_train = nn.Linear(rank[0], out_features, bias=False)
        self.ALinear_train = nn.Linear(rank[1], out_features, bias=bias)
        # Freeze the no-train branch so only the train branch receives gradients.
        self.BLinear_no_train.weight.requires_grad = False
        self.ALinear_no_train.weight.requires_grad = False

    def forward(self, input):
        # Equivalent single-branch form: self.ALinear(self.BLinear(input))
        y_no_train = self.ALinear_no_train(self.BLinear_no_train(input))
        y_train = self.ALinear_train(self.BLinear_train(input))
        # The two low-rank branches sum to the full approximation.
        return y_no_train + y_train


class ASVDPhiForCausalLM(PhiForCausalLM):
    config_class = ASVDPhiConfig

    def __init__(self, config: ASVDPhiConfig):
        super().__init__(config)
        # Maps the full name of each nn.Linear to replace to its
        # (frozen_rank, trainable_rank) pair.
        self.truncation_ranks = config.truncation_ranks

        # Record each nn.Linear together with its parent module and attribute
        # name, so the layer can later be swapped out via setattr.
        full_name_dict = {module: name for name, module in self.named_modules()}
        linear_info = {}
        modules = [self]
        while len(modules) > 0:
            submodule = modules.pop()
            for name, raw_linear in submodule.named_children():
                if isinstance(raw_linear, nn.Linear):
                    linear_info[raw_linear] = {
                        "father": submodule,
                        "name": name,
                        "full_name": full_name_dict[raw_linear],
                    }
                else:
                    modules.append(raw_linear)

        # Replace every listed nn.Linear with an ASVDLinear of the configured
        # ranks. Iterate over a snapshot, since setattr mutates the module tree.
        for name, module in list(self.named_modules()):
            if name in self.truncation_ranks:
                info = linear_info[module]
                new_layer = ASVDLinear(
                    module.in_features,
                    module.out_features,
                    self.truncation_ranks[name],
                    bias=module.bias is not None,
                )
                setattr(info["father"], info["name"], new_layer)
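
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): one way the two
# branches of ASVDLinear could be initialized from the SVD of the nn.Linear
# it replaces. The helper name `init_asvd_from_linear` and the plain-SVD
# split below are assumptions for illustration only; ASVD itself additionally
# scales the weight by activation statistics before factorizing, a step
# omitted here. Assumes rank[0] + rank[1] <= min(out_features, in_features).

import torch


@torch.no_grad()
def init_asvd_from_linear(asvd_layer: ASVDLinear, linear: nn.Linear) -> None:
    r0 = asvd_layer.BLinear_no_train.out_features  # frozen rank
    r1 = asvd_layer.BLinear_train.out_features  # trainable rank
    # linear.weight has shape (out_features, in_features); factor it as
    # U @ diag(S) @ Vh and split sqrt(S) evenly between the A and B factors.
    U, S, Vh = torch.linalg.svd(linear.weight.float(), full_matrices=False)
    sqrt_s = S.sqrt()
    # Frozen branch: the top-r0 singular directions.
    asvd_layer.BLinear_no_train.weight.copy_(sqrt_s[:r0, None] * Vh[:r0])
    asvd_layer.ALinear_no_train.weight.copy_(U[:, :r0] * sqrt_s[None, :r0])
    # Trainable branch: the next r1 directions, fine-tuned during training.
    asvd_layer.BLinear_train.weight.copy_(sqrt_s[r0:r0 + r1, None] * Vh[r0:r0 + r1])
    asvd_layer.ALinear_train.weight.copy_(U[:, r0:r0 + r1] * sqrt_s[None, r0:r0 + r1])
    if asvd_layer.ALinear_train.bias is not None and linear.bias is not None:
        asvd_layer.ALinear_train.bias.copy_(linear.bias)


# Usage (hypothetical): in ASVDPhiForCausalLM.__init__, one could call
# `init_asvd_from_linear(new_layer, module)` before the setattr so each new
# low-rank layer starts as an approximation of the original weights rather
# than from random initialization.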