# sinder/singular_defect.py
# (file-viewer residue folded into this header: uploaded by haoqiwang,
#  commit 9ae1b1e "add files", 3.88 kB)
# compute singular defect directions
import torch
import torch.nn.functional as F
def anomaly_dir_attn(
    blk,
    identity=False,
    bias=False,
    centered=False,
    homogeneous=False,
):
    """Return the leading left singular vector of an affine map built from
    the attention-branch weights of ``blk``.

    The linear part is the weight product
    ``diag(ls1.gamma) @ attn.proj.weight @ W @ diag(norm1.weight)`` where
    ``W`` is the last of the three equal chunks of ``attn.qkv.weight``
    (presumably the value projection, given the usual q/k/v packing --
    TODO confirm).  The matching bias vector ``B`` is pushed through the same
    chain of weights.

    All helper tensors are created with the dtype and device of
    ``blk.ls1.gamma`` so the function also works for non-float32 blocks.

    Args:
        blk: Object exposing ``ls1.gamma``, ``attn.proj``, ``attn.qkv`` and
            ``norm1``; the latter three must provide ``weight`` and ``bias``
            tensors.
        identity: If true, add the identity to the top-left ``N x N`` part of
            the matrix (presumably modelling the residual connection).
        bias: If true, append ``B`` as an extra column.
        centered: If true, right-multiply the linear part by the centering
            matrix ``I - 1/N * ones(N, N)``.
        homogeneous: If true, append the row ``(0, ..., 0, 1)``.  The row has
            ``N + 1`` entries, so this only works together with ``bias=True``.

    Returns:
        Tuple ``(u, A, B)``: the first ``N`` entries of the first left
        singular vector of ``A``, the assembled matrix ``A`` and the bias
        vector ``B``.
    """
    with torch.no_grad():
        gamma = blk.ls1.gamma
        N = gamma.shape[0]
        # Helper tensors must match the weights; a float32 helper would make
        # the matmuls below fail for float64/half weights.
        opts = {"device": gamma.device, "dtype": gamma.dtype}

        A4 = torch.diag(gamma)
        A3 = blk.attn.proj.weight
        B3 = blk.attn.proj.bias
        A2 = blk.attn.qkv.weight.chunk(3, dim=0)[-1]
        B2 = blk.attn.qkv.bias.chunk(3, dim=0)[-1]
        A1 = torch.diag(blk.norm1.weight)
        B1 = blk.norm1.bias

        A = A4 @ A3 @ A2 @ A1
        if centered:
            # Centering matrix: subtracts the mean over the N features.
            A0 = torch.eye(N, **opts) - 1 / N * torch.ones(N, N, **opts)
            A = A @ A0

        B = A4 @ (A3 @ (A2 @ B1)) + A4 @ (A3 @ B2) + A4 @ B3

        if bias:
            A = torch.cat((A, B[:, None]), dim=1)
        if homogeneous:
            # Last row of a homogeneous-coordinate matrix: (0, ..., 0, 1).
            onehot = torch.cat((torch.zeros_like(B), B.new_ones(1)))
            A = torch.cat((A, onehot[None]), dim=0)
        if identity:
            A[:N, :N] += torch.eye(N, **opts)

        u, _, _ = torch.linalg.svd(A)
        return u[:N, 0], A, B
def w12(blk, x):
    """Evaluate the gated hidden activation of ``blk.mlp`` without gradients.

    ``blk.mlp.w12(x)`` is split into two equal chunks along the last
    dimension; the result is SiLU of the first chunk multiplied elementwise
    by the second chunk.

    Args:
        blk: Object whose ``mlp.w12`` is callable on ``x``.
        x: Input tensor passed to ``blk.mlp.w12``.

    Returns:
        Tensor ``silu(first_chunk) * second_chunk`` computed under
        ``torch.no_grad()``.
    """
    with torch.no_grad():
        gate, value = torch.chunk(blk.mlp.w12(x), 2, dim=-1)
        gated = F.silu(gate) * value
    return gated
def anomaly_dir_mlp_ls(
    blk,
    identity=False,
    bias=False,
    centered=False,
    homogeneous=False,
    bias_ls=False,
    n_samples=100000,
):
    """Return the leading left singular vector of an affine map built from
    the MLP-branch weights of ``blk``, using a least-squares linearisation of
    the gated hidden activation.

    The nonlinear unit ``w12(blk, .)`` (helper defined in this module) is
    evaluated on ``n_samples`` standard-normal inputs and replaced by the
    least-squares linear (or affine, when ``bias_ls`` is true) fit ``A2``/
    ``B2``.  The linear part of the map is then
    ``diag(ls2.gamma) @ mlp.w3.weight @ A2 @ diag(norm2.weight)`` and ``B`` is
    the bias pushed through the same chain.

    The samples are drawn with ``torch.randn``, so results depend on the
    global RNG state.  All helper tensors use the dtype and device of
    ``blk.ls2.gamma`` so the function also works for non-float32 blocks.

    Args:
        blk: Object exposing ``ls2.gamma``, ``mlp.w3`` (with ``weight`` and
            ``bias``), ``mlp.w12`` (callable) and ``norm2`` (with ``weight``
            and ``bias``).
        identity: If true, add the identity to the top-left ``N x N`` part of
            the matrix (presumably modelling the residual connection).
        bias: If true, append ``B`` as an extra column.
        centered: If true, right-multiply the linear part by the centering
            matrix ``I - 1/N * ones(N, N)``.
        homogeneous: If true, append the row ``(0, ..., 0, 1)``.  The row has
            ``N + 1`` entries, so this only works together with ``bias=True``.
        bias_ls: If true, fit an intercept in the least-squares problem;
            otherwise the fitted bias ``B2`` is zero.
        n_samples: Number of random inputs used for the least-squares fit.

    Returns:
        Tuple ``(u, A, B)``: the first ``N`` entries of the first left
        singular vector of ``A``, the assembled matrix ``A`` and the bias
        vector ``B``.
    """
    with torch.no_grad():
        gamma = blk.ls2.gamma
        N = gamma.shape[0]
        M = blk.mlp.w3.weight.shape[1]
        # Helper tensors must match the weights; a float32 sample matrix
        # cannot be fed through float64/half layers.
        opts = {"device": gamma.device, "dtype": gamma.dtype}

        A4 = torch.diag(gamma)
        A3 = blk.mlp.w3.weight
        B3 = blk.mlp.w3.bias

        # Least-squares fit Y ~= X_one @ solution of the gated activation.
        X = torch.randn(n_samples, N, **opts)
        Y = w12(blk, X)
        if bias_ls:
            # Extra column of ones lets the fit learn an intercept.
            X_one = torch.cat((X, torch.ones(n_samples, 1, **opts)), dim=1)
        else:
            X_one = X
        sol = torch.linalg.lstsq(X_one, Y)
        if bias_ls:
            A2 = sol.solution.T[:, :-1]
            B2 = sol.solution.T[:, -1]
        else:
            A2 = sol.solution.T
            B2 = torch.zeros(M, **opts)

        A1 = torch.diag(blk.norm2.weight)
        B1 = blk.norm2.bias

        A = A4 @ A3 @ A2 @ A1
        if centered:
            # Centering matrix: subtracts the mean over the N features.
            A0 = torch.eye(N, **opts) - 1 / N * torch.ones(N, N, **opts)
            A = A @ A0

        B = A4 @ (A3 @ (A2 @ B1)) + A4 @ (A3 @ B2) + A4 @ B3

        if bias:
            A = torch.cat((A, B[:, None]), dim=1)
        if homogeneous:
            # Last row of a homogeneous-coordinate matrix: (0, ..., 0, 1).
            onehot = torch.cat((torch.zeros_like(B), B.new_ones(1)))
            A = torch.cat((A, onehot[None]), dim=0)
        if identity:
            A[:N, :N] += torch.eye(N, **opts)

        u, _, _ = torch.linalg.svd(A)
        return u[:N, 0], A, B
def anomaly_dir(blk, homogeneous=False):
    """Compose the attention and MLP maps of ``blk`` and return the leading
    left singular vector of the product.

    Both branch maps are requested with ``identity=True`` and
    ``centered=True``; the MLP fit is done without an intercept
    (``bias_ls=False``).  When ``homogeneous`` is true the bias is embedded in
    the matrices themselves, so the separate bias returned here is ``0``.

    Args:
        blk: Block passed on to ``anomaly_dir_attn`` and
            ``anomaly_dir_mlp_ls``.
        homogeneous: Forwarded as both ``bias`` and ``homogeneous`` to the two
            branch functions.

    Returns:
        Tuple ``(u, AA, BB)``: the first ``N`` entries of the first left
        singular vector of ``AA``, the composed matrix ``AA = C @ A`` and the
        composed bias ``BB`` (``C @ b + d``, or ``0`` if ``homogeneous``).
    """
    shared = dict(
        identity=True,
        bias=homogeneous,
        centered=True,
        homogeneous=homogeneous,
    )
    _, attn_map, attn_shift = anomaly_dir_attn(blk, **shared)
    _, mlp_map, mlp_shift = anomaly_dir_mlp_ls(blk, bias_ls=False, **shared)

    with torch.no_grad():
        width = attn_shift.shape[0]
        block_map = mlp_map @ attn_map
        # In homogeneous form the shift already lives inside block_map.
        block_shift = 0 if homogeneous else mlp_map @ attn_shift + mlp_shift
        left_vectors = torch.linalg.svd(block_map)[0]
        return left_vectors[:width, 0], block_map, block_shift
def singular_defect_directions(model):
    """Compute one singular defect direction per block of ``model``.

    For every block in ``model.blocks`` the composed block matrix from
    ``anomaly_dir`` is left-multiplied onto the running product of all
    previous block matrices; the first left singular vector of that running
    product is recorded.

    Args:
        model: Object whose ``blocks`` attribute is an iterable of blocks
            accepted by ``anomaly_dir``.

    Returns:
        List with one vector per block, in block order.  Empty if the model
        has no blocks.
    """
    # Only the matrix (second item) of each anomaly_dir result is needed.
    block_maps = [anomaly_dir(blk)[1] for blk in model.blocks]
    if not block_maps:
        # Nothing to accumulate; avoids indexing into an empty list below.
        return []

    accumulative_anomalies = []
    with torch.no_grad():
        first = block_maps[0]
        cumulative = torch.eye(
            first.shape[0], dtype=first.dtype, device=first.device
        )
        for block_map in block_maps:
            cumulative = block_map @ cumulative
            u, _, _ = torch.linalg.svd(cumulative)
            accumulative_anomalies.append(u[:, 0])
    return accumulative_anomalies