# sinder/repair.py
import torch

# Wraps a Linear layer as U @ diag(S + epsilon) @ Vt; only the additive
# correction epsilon to the singular values is intended to be trained.
class SVDLinearAddition(torch.nn.Module):
    def __init__(self, linear, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        w, b = linear.weight, linear.bias
        self.bias = b
        with torch.no_grad():
            U, S, Vt = torch.linalg.svd(w, full_matrices=False)
            self.U = torch.nn.Parameter(U)
            self.S = torch.nn.Parameter(S)
            self.Vt = torch.nn.Parameter(Vt)
            self.epsilon = torch.nn.Parameter(torch.zeros_like(S))

    def forward(self, x):
        # Reconstruct the weight with the learned singular-value correction applied.
        w = self.U @ torch.diag(self.S + self.epsilon) @ self.Vt
        x = torch.nn.functional.linear(x, w, self.bias)
        return x
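
# Minimal sanity-check sketch (illustrative; the layer sizes here are arbitrary):
# with epsilon at its zero initialization, the wrapper should reproduce the
# original linear layer up to SVD reconstruction error.
def _check_svd_linear_addition():
    lin = torch.nn.Linear(16, 8)
    wrapped = SVDLinearAddition(lin)
    x = torch.randn(4, 16)
    assert torch.allclose(lin(x), wrapped(x), atol=1e-5)
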

# Same idea for a fused qkv projection: the (3*d, d) weight is split into its
# query, key and value blocks, and each block gets its own SVD factorization
# and epsilon correction.
class SVDLinearAdditionQKV(torch.nn.Module):
    def __init__(self, linear, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        w, b = linear.weight, linear.bias
        self.bias = b
        with torch.no_grad():
            # The fused qkv weight stacks q, k and v along the output dimension.
            c = w.shape[0] // 3
            q = w[:c]
            k = w[c : 2 * c]
            v = w[2 * c :]
            U_q, S_q, Vt_q = torch.linalg.svd(q, full_matrices=False)
            self.U_q = torch.nn.Parameter(U_q)
            self.S_q = torch.nn.Parameter(S_q)
            self.Vt_q = torch.nn.Parameter(Vt_q)
            self.epsilon_q = torch.nn.Parameter(torch.zeros_like(S_q))
            U_k, S_k, Vt_k = torch.linalg.svd(k, full_matrices=False)
            self.U_k = torch.nn.Parameter(U_k)
            self.S_k = torch.nn.Parameter(S_k)
            self.Vt_k = torch.nn.Parameter(Vt_k)
            self.epsilon_k = torch.nn.Parameter(torch.zeros_like(S_k))
            U_v, S_v, Vt_v = torch.linalg.svd(v, full_matrices=False)
            self.U_v = torch.nn.Parameter(U_v)
            self.S_v = torch.nn.Parameter(S_v)
            self.Vt_v = torch.nn.Parameter(Vt_v)
            self.epsilon_v = torch.nn.Parameter(torch.zeros_like(S_v))

    def forward(self, x):
        # Rebuild each block with its correction, then re-fuse them into one weight.
        w = torch.concatenate(
            (
                self.U_q @ torch.diag(self.S_q + self.epsilon_q) @ self.Vt_q,
                self.U_k @ torch.diag(self.S_k + self.epsilon_k) @ self.Vt_k,
                self.U_v @ torch.diag(self.S_v + self.epsilon_v) @ self.Vt_v,
            )
        )
        x = torch.nn.functional.linear(x, w, self.bias)
        return x
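
# Similar sanity-check sketch for the fused qkv case (illustrative): a (3*d, d)
# projection is split into three blocks, factorized separately, and re-fused;
# with all epsilons at zero the output should match the original layer.
def _check_svd_linear_addition_qkv(d=16):
    qkv = torch.nn.Linear(d, 3 * d)
    wrapped = SVDLinearAdditionQKV(qkv)
    x = torch.randn(2, d)
    assert torch.allclose(qkv(x), wrapped(x), atol=1e-5)
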

# Recursively replace every nn.Linear in `module` with its SVD wrapper.
# Only the epsilon corrections are left trainable; for fused qkv projections
# the q and k corrections are frozen as well, so only epsilon_v is trained
# (hence "noqk").
def replace_linear_addition_noqk(module, name):
    for attr_str in dir(module):
        target_attr = getattr(module, attr_str)
        if type(target_attr) == torch.nn.Linear:
            if 'qkv' in attr_str:
                print('replaced: ', name, attr_str)
                svd_linear_qkv = SVDLinearAdditionQKV(target_attr)
                svd_linear_qkv.U_q.requires_grad = False
                svd_linear_qkv.U_k.requires_grad = False
                svd_linear_qkv.U_v.requires_grad = False
                svd_linear_qkv.S_q.requires_grad = False
                svd_linear_qkv.S_k.requires_grad = False
                svd_linear_qkv.S_v.requires_grad = False
                svd_linear_qkv.Vt_q.requires_grad = False
                svd_linear_qkv.Vt_k.requires_grad = False
                svd_linear_qkv.Vt_v.requires_grad = False
                svd_linear_qkv.epsilon_q.requires_grad = False
                svd_linear_qkv.epsilon_k.requires_grad = False
                svd_linear_qkv.epsilon_v.requires_grad = True
                svd_linear_qkv.bias.requires_grad = False
                setattr(module, attr_str, svd_linear_qkv)
            else:
                print('replaced: ', name, attr_str)
                svd_linear = SVDLinearAddition(target_attr)
                svd_linear.U.requires_grad = False
                svd_linear.S.requires_grad = False
                svd_linear.Vt.requires_grad = False
                svd_linear.bias.requires_grad = False
                svd_linear.epsilon.requires_grad = True
                setattr(module, attr_str, svd_linear)
    # Recurse into immediate children; recursion is handled here, so
    # named_modules() is not needed.
    for name, immediate_child_module in module.named_children():
        replace_linear_addition_noqk(immediate_child_module, name)

# Fold the trained epsilon corrections back into plain nn.Linear layers,
# undoing the wrapping done by replace_linear_addition_noqk.
def replace_back(module, name):
    for attr_str in dir(module):
        target_attr = getattr(module, attr_str)
        if type(target_attr) == SVDLinearAddition:
            print('replaced back: ', name, attr_str)
            with torch.no_grad():
                linear = torch.nn.Linear(
                    target_attr.Vt.shape[1],
                    target_attr.U.shape[0],
                    device=target_attr.U.device,
                )
                # Overwrite the fresh layer's weight and bias with the repaired values.
                linear.weight.add_(
                    target_attr.U
                    @ torch.diag(target_attr.S + target_attr.epsilon)
                    @ target_attr.Vt
                    - linear.weight
                )
                linear.bias.add_(target_attr.bias - linear.bias)
            setattr(module, attr_str, linear)
        elif type(target_attr) == SVDLinearAdditionQKV:
            print('replaced back: ', name, attr_str)
            with torch.no_grad():
                # Re-fuse the repaired q, k and v blocks into a single qkv weight.
                w = torch.concatenate(
                    (
                        target_attr.U_q
                        @ torch.diag(target_attr.S_q + target_attr.epsilon_q)
                        @ target_attr.Vt_q,
                        target_attr.U_k
                        @ torch.diag(target_attr.S_k + target_attr.epsilon_k)
                        @ target_attr.Vt_k,
                        target_attr.U_v
                        @ torch.diag(target_attr.S_v + target_attr.epsilon_v)
                        @ target_attr.Vt_v,
                    )
                )
                linear = torch.nn.Linear(
                    w.shape[1], w.shape[0], device=target_attr.U_q.device
                )
                linear.weight.add_(w - linear.weight)
                linear.bias.add_(target_attr.bias - linear.bias)
            setattr(module, attr_str, linear)
    # Iterate through immediate child modules; recursion is handled here, so
    # named_modules() is not needed.
    for name, immediate_child_module in module.named_children():
        replace_back(immediate_child_module, name)
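
# Hypothetical usage sketch (the surrounding training code is assumed and not
# part of this file): freeze the model, wrap its linear layers, optimize only the
# epsilon corrections, then fold them back into plain nn.Linear layers.
def _repair_example(model):
    model.requires_grad_(False)  # assumption: the backbone is frozen first
    replace_linear_addition_noqk(model, 'model')
    trainable = [p for p in model.parameters() if p.requires_grad]  # the epsilons
    optimizer = torch.optim.Adam(trainable, lr=1e-4)  # hypothetical optimizer choice
    # ... compute a repair loss on model outputs and call optimizer.step() ...
    replace_back(model, 'model')
    return model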