import torch
from torch.nn.functional import normalize


def check_anomaly_theoretical(
    x,
    H,
    W,
    anomaly_dir=None,
    temperature=0.1,
    mask_thr=0.001,
    kernel=3,
):
    """Score patch tokens against a singular-defect direction and measure how
    far the anomalous ones drift from a smoothed average of their neighbors.

    `x` is expected to be [1, 1 + H * W, C]: a class token followed by H * W
    patch tokens. `anomaly_dir` is required despite the None default. Returns
    None when no token is flagged as anomalous.
    """
    # Drop the class token and lay the patch tokens out on an H x W grid.
    x_token = x[:, 1:]
    B = x.shape[0]
    assert B == 1
    x_token = x_token.reshape(H, W, -1).contiguous()

    # The neighbor weights below are treated as constants (no_grad), so the
    # loss gradient flows only through the raw tokens.
    with torch.no_grad():
        feature = normalize(x_token, dim=-1)
        direction = normalize(anomaly_dir, dim=-1)

        # Tokens aligned with the defect direction get strongly negative
        # logits and hence exponentially small weights.
        logits = -(feature * direction).sum(dim=-1).abs()
        prob = torch.exp(logits / temperature)

        assert kernel in (3, 5)
        pad = kernel // 2

        # Sliding kernel x kernel windows of token weights, normalized to
        # sum to one within each window.
        w = prob.unfold(0, kernel, 1).unfold(1, kernel, 1)
        w = w / w.sum(dim=(-1, -2), keepdim=True)

        if kernel == 3:
            # 3x3 binomial approximation of a Gaussian kernel.
            gaussian = (
                torch.tensor(
                    [
                        [1.0, 2.0, 1.0],
                        [2.0, 4.0, 2.0],
                        [1.0, 2.0, 1.0],
                    ],
                    device=w.device,
                )
                / 16
            )
        elif kernel == 5:
            # 5x5 Gaussian kernel (integer stencil normalized by 273).
            gaussian = (
                torch.tensor(
                    [
                        [1.0, 4.0, 7.0, 4.0, 1.0],
                        [4.0, 16.0, 26.0, 16.0, 4.0],
                        [7.0, 26.0, 41.0, 26.0, 7.0],
                        [4.0, 16.0, 26.0, 16.0, 4.0],
                        [1.0, 4.0, 7.0, 4.0, 1.0],
                    ],
                    device=w.device,
                )
                / 273
            )

        # Blend the similarity weights with the Gaussian prior, renormalize.
        w2 = w * gaussian
        w2 = w2 / w2.sum(dim=(-1, -2), keepdim=True)

    # T[i, j] is the weighted average of the kernel x kernel neighborhood of
    # token (i + pad, j + pad); gradients do flow through the tokens here.
    T = x_token.unfold(0, kernel, 1).unfold(1, kernel, 1)
    T = (T * w2[:, :, None].to(T.device)).sum(dim=(-1, -2))

    # Flag tokens whose logit is well below the mean as anomalous, ignoring
    # the border where no full neighborhood exists.
    mask_full = logits < logits.mean() - mask_thr * logits.std()
    mask_full[:pad, :] = False
    mask_full[:, :pad] = False
    mask_full[-pad:, :] = False
    mask_full[:, -pad:] = False
    index_tensor = torch.nonzero(mask_full.flatten()).flatten()
    if len(index_tensor) == 0:
        return None
    rows = index_tensor // W
    cols = index_tensor % W

    # Mean token norm over the interior, used to scale the loss.
    alpha = x_token[pad:-pad, pad:-pad].norm(dim=-1).mean()

    # Normalized distance between each anomalous token and the weighted
    # average of its neighbors (T is offset by pad relative to the grid).
    loss_neighbor = (
        (x_token[rows, cols] - T[rows - pad, cols - pad]).norm(dim=-1)
    ).mean() / alpha

    return loss_neighbor, rows, cols, T, alpha, mask_full, x_token


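# A hedged sketch of how get_neighbor_loss (below) might drive a repair /
# fine-tuning step. `model`, `optimizer`, and `img` are assumptions for
# illustration, not names defined in this file:
#
#     out = get_neighbor_loss(model, img)  # img: [1, 3, H, W], batch size 1
#     if out is not None:
#         layer_idx, loss, *rest = out
#         loss.backward()
#         optimizer.step()

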
def get_neighbor_loss(
    model,
    x,
    skip_less_than=1,
    mask_thr=0.001,
    kernel=3,
):
    """Run `model` block by block and return the first layer whose anomalous
    tokens yield a neighbor loss, or None if no layer qualifies."""
    # Image height/width in pixels; after patch embedding the token grid is
    # H // patch_size by W // patch_size.
    H = x.shape[2]
    W = x.shape[3]
    x = model.prepare_tokens_with_masks(x)

    # After each block, test the tokens against that block's precomputed
    # singular-defect direction.
    for i, blk in enumerate(model.blocks):
        x = blk(x)
        assert len(model.singular_defects) > 0
        result = check_anomaly_theoretical(
            x,
            H // model.patch_size,
            W // model.patch_size,
            model.singular_defects[i],
            mask_thr=mask_thr,
            kernel=kernel,
        )
        if result is not None:
            (
                loss_neighbor,
                rows,
                cols,
                T,
                alpha,
                mask_angle,
                x_token,
            ) = result
            # Only report a layer once it has enough anomalous tokens.
            if len(rows) >= skip_less_than:
                assert not torch.isnan(loss_neighbor).any()
                return (
                    i,
                    loss_neighbor,
                    rows,
                    cols,
                    T,
                    alpha,
                    mask_angle,
                    x_token,
                )
    return None
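

# A minimal smoke test for check_anomaly_theoretical on random tokens. The
# 16 x 16 grid and 384-dim features are illustrative assumptions, not values
# taken from any particular model.
if __name__ == "__main__":
    torch.manual_seed(0)
    grid_h, grid_w, dim = 16, 16, 384
    tokens = torch.randn(1, 1 + grid_h * grid_w, dim)  # class + patch tokens
    defect_dir = torch.randn(dim)  # stand-in for a singular-defect direction
    out = check_anomaly_theoretical(
        tokens, grid_h, grid_w, anomaly_dir=defect_dir
    )
    if out is None:
        print("no anomalous tokens flagged")
    else:
        loss, rows, cols, T, alpha, mask, grid = out
        print(f"{len(rows)} anomalous tokens, neighbor loss {loss.item():.4f}")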