# File size: 8,232 Bytes (revision d60982d)
import torch
import torch.nn as nn
import torch.nn.functional as F
# Public API of this module: LRP-aware layers plus the helpers they rely on.
__all__ = [
    'forward_hook',
    'Clone',
    'Add',
    'Cat',
    'ReLU',
    'GELU',
    'Dropout',
    'BatchNorm2d',
    'Linear',
    'MaxPool2d',
    'AdaptiveAvgPool2d',
    'AvgPool2d',
    'Conv2d',
    'Sequential',
    'safe_divide',
    'einsum',
    'Softmax',
    'IndexSelect',
    'LayerNorm',
    'AddEye',
]
def safe_divide(a, b):
    """Element-wise ``a / b`` stabilised by a small epsilon, with zeros where ``b == 0``.

    For floating-point ``b`` the stabilised denominator works out to ``b + 1e-9``:
    exactly one of the two clamps returns ``b`` and the other returns ``1e-9``.
    Any denominator that still ends up exactly zero is nudged by another ``1e-9``.
    The quotient is then masked so that positions where ``b`` is zero yield 0.

    Args:
        a: numerator tensor.
        b: denominator tensor (broadcast-compatible with ``a``).

    Returns:
        Tensor with ``a / (b + eps)`` where ``b != 0`` and ``0`` elsewhere.
    """
    lower_part = b.clamp(min=1e-9)
    upper_part = b.clamp(max=1e-9)
    denominator = lower_part + upper_part
    still_zero = denominator.eq(0).type(denominator.type())
    denominator = denominator + still_zero * 1e-9
    keep = b.ne(0).type(b.type())
    return a / denominator * keep
def forward_hook(self, input, output):
    """Forward hook that snapshots a module's first input and its output.

    The first positional input is detached from the autograd graph and made a fresh
    leaf with ``requires_grad=True``, so later ``relprop`` calls can differentiate
    with respect to it. If that input is exactly a ``list`` or ``tuple`` (the check is
    on the exact type, not on subclasses), each element is snapshotted separately and
    the results are stored as a list.

    Args:
        self: module the hook is registered on; receives the ``X`` and ``Y`` attributes.
        input: tuple of positional arguments passed to the module.
        output: value returned by the module, stored unchanged as ``Y``.
    """
    captured = input[0]
    is_sequence = type(captured) in (list, tuple)
    if not is_sequence:
        self.X = captured.detach()
        self.X.requires_grad_(True)
    else:
        self.X = []
        for item in captured:
            self.X.append(item.detach().requires_grad_(True))
    self.Y = output
def backward_hook(self, grad_input, grad_output):
    """Record the gradients a backward hook receives as attributes on the module."""
    self.grad_input, self.grad_output = grad_input, grad_output
class RelProp(nn.Module):
    """Base class for modules that support layer-wise relevance propagation (LRP).

    ``forward_hook`` is registered at construction time. After each call of the
    module, ``self.X`` therefore holds a detached, gradient-enabled copy of the first
    input, and ``self.Y`` holds the output.
    """

    def __init__(self):
        super().__init__()
        self.register_forward_hook(forward_hook)

    def gradprop(self, Z, X, S):
        """Return the tuple of gradients of ``Z`` w.r.t. ``X``, weighted by ``S``.

        The autograd graph is retained so the caller may differentiate through ``Z`` again.
        """
        return torch.autograd.grad(Z, X, S, retain_graph=True)

    def relprop(self, R, alpha):
        """Default rule: hand the incoming relevance ``R`` back unchanged."""
        return R
class RelPropSimple(RelProp):
    """LRP base class for layers that use the plain gradient x input rule.

    Subclasses only provide ``forward``. Relevance is redistributed as
    ``x * dZ/dx * (R / Z)`` for each captured input ``x``.
    """

    def relprop(self, R, alpha):
        """Propagate relevance ``R`` back to the captured input(s).

        Args:
            R: relevance of this layer's output (same shape as the forward output).
            alpha: unused by this rule; accepted so every layer shares one signature.

        Returns:
            A tensor when ``self.X`` is a single tensor. Otherwise, a list with one
            relevance tensor per captured input, in input order.
        """
        Z = self.forward(self.X)
        S = safe_divide(R, Z)
        C = self.gradprop(Z, self.X, S)
        if torch.is_tensor(self.X):
            return self.X * C[0]
        # self.X is a list of tensors captured by forward_hook. gradprop returns one
        # gradient per element, so pair them up for any number of inputs.
        return [x * c for x, c in zip(self.X, C)]
class AddEye(RelPropSimple):
    """Add the identity matrix to every trailing square slice of the input.

    Originally documented for inputs of shape (B, C, seq_len, seq_len). Any input
    whose last two dimensions are equal is accepted. Relevance uses the inherited
    gradient x input rule.
    """

    def forward(self, input):
        n = input.shape[-1]
        # Build the identity directly in the input's dtype and on its device. This
        # avoids a host-to-device copy per call, and keeps reduced-precision inputs
        # from being upcast to float32 by type promotion.
        eye = torch.eye(n, dtype=input.dtype, device=input.device)
        return input + eye.expand_as(input)
# The wrappers below reuse the stock torch.nn forward passes and only mix in an LRP
# rule. RelProp.relprop passes relevance through unchanged; RelPropSimple.relprop
# applies the gradient x input rule.


class ReLU(nn.ReLU, RelProp):
    """nn.ReLU whose relevance is passed through unchanged."""


class GELU(nn.GELU, RelProp):
    """nn.GELU whose relevance is passed through unchanged."""


class Softmax(nn.Softmax, RelProp):
    """nn.Softmax whose relevance is passed through unchanged."""


class LayerNorm(nn.LayerNorm, RelProp):
    """nn.LayerNorm whose relevance is passed through unchanged."""


class Dropout(nn.Dropout, RelProp):
    """nn.Dropout whose relevance is passed through unchanged."""


class MaxPool2d(nn.MaxPool2d, RelPropSimple):
    """nn.MaxPool2d with the gradient x input relevance rule."""


class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, RelPropSimple):
    """nn.AdaptiveAvgPool2d with the gradient x input relevance rule."""


class AvgPool2d(nn.AvgPool2d, RelPropSimple):
    """nn.AvgPool2d with the gradient x input relevance rule."""
class Add(RelPropSimple):
    """Sum of two tensors, with relevance split between the two operands.

    ``forward`` receives a sequence of operands and unpacks it into ``torch.add``.
    ``relprop`` first applies the gradient x input rule. It then rescales each
    operand's relevance so the two parts share ``R.sum()`` in proportion to the
    absolute size of their gradient x input totals.
    """

    def forward(self, inputs):
        return torch.add(*inputs)

    def relprop(self, R, alpha):
        """Return ``[relevance_of_first_operand, relevance_of_second_operand]``."""
        out = self.forward(self.X)
        grads = self.gradprop(out, self.X, safe_divide(R, out))
        lhs = self.X[0] * grads[0]
        rhs = self.X[1] * grads[1]
        lhs_total = lhs.sum()
        rhs_total = rhs.sum()
        # Shared normaliser: if both totals are zero, safe_divide gives both shares 0.
        norm = lhs_total.abs() + rhs_total.abs()
        budget = R.sum()
        lhs_share = safe_divide(lhs_total.abs(), norm) * budget
        rhs_share = safe_divide(rhs_total.abs(), norm) * budget
        return [
            lhs * safe_divide(lhs_share, lhs_total),
            rhs * safe_divide(rhs_share, rhs_total),
        ]
class einsum(RelPropSimple):
    """Module wrapper around ``torch.einsum`` with a fixed equation string.

    Relevance is propagated by the inherited gradient x input rule.
    """

    def __init__(self, equation):
        super().__init__()
        self.equation = equation

    def forward(self, *operands):
        spec = self.equation
        return torch.einsum(spec, *operands)
class IndexSelect(RelProp):
    """``torch.index_select`` as a module, using the gradient x input LRP rule."""

    def forward(self, inputs, dim, indices):
        # Remember the selection so relprop can replay the forward pass on self.X.
        self.dim = dim
        self.indices = indices
        return torch.index_select(inputs, dim, indices)

    def relprop(self, R, alpha):
        """Propagate relevance ``R`` of the selected slice back to the full input.

        Args:
            R: relevance of the selected output.
            alpha: unused by this rule; kept for a uniform ``relprop`` signature.

        Returns:
            Relevance tensor with the shape of ``self.X``. Positions that were not
            selected receive zero.
        """
        # index_select only accepts a single tensor, so self.X is always a tensor
        # here; any other input would already have failed inside forward.
        Z = self.forward(self.X, self.dim, self.indices)
        S = safe_divide(R, Z)
        C = self.gradprop(Z, self.X, S)
        return self.X * C[0]
class Clone(RelProp):
    """Return ``num`` references to the same input tensor, merging their relevance on the way back."""

    def forward(self, input, num):
        self.num = num
        return [input for _ in range(num)]

    def relprop(self, R, alpha):
        """Sum the relevance arriving on every copy back onto the single input.

        Args:
            R: sequence of relevance tensors, one per copy produced by ``forward``.
            alpha: unused by this rule.
        """
        replicas = [self.X for _ in range(self.num)]
        ratios = [safe_divide(r, z) for r, z in zip(R, replicas)]
        grad = self.gradprop(replicas, self.X, ratios)[0]
        return self.X * grad
class Cat(RelProp):
    """``torch.cat`` as a module; relevance is split back into per-input pieces."""

    def forward(self, inputs, dim):
        self.dim = dim
        return torch.cat(inputs, dim)

    def relprop(self, R, alpha):
        """Return one relevance tensor per concatenated input, in input order."""
        joined = self.forward(self.X, self.dim)
        grads = self.gradprop(joined, self.X, safe_divide(R, joined))
        return [piece * grad for piece, grad in zip(self.X, grads)]
class Sequential(nn.Sequential):
    """``nn.Sequential`` that can also propagate relevance back through its children."""

    def relprop(self, R, alpha):
        """Apply each child's ``relprop`` in reverse registration order, last child first."""
        children = list(self._modules.values())
        while children:
            R = children.pop().relprop(R, alpha)
        return R
class BatchNorm2d(nn.BatchNorm2d, RelProp):
    """nn.BatchNorm2d with an LRP rule built from its running statistics."""

    def relprop(self, R, alpha):
        """Redistribute relevance ``R`` through the per-channel affine scaling.

        The scaling applied by batch norm at inference is
        ``gamma / sqrt(running_var + eps)``. The shift terms are ignored by this rule.
        Requires ``self.running_var`` to be populated.

        Args:
            R: relevance of the output; presumably (N, C, H, W) like ``self.X``.
            alpha: unused by this rule.

        Returns:
            Relevance tensor with the shape of ``self.X``.
        """
        X = self.X
        # running_var is a variance, so normalise by its square root, not by
        # sqrt(var**2 + eps).
        std = (self.running_var + self.eps).sqrt()
        # With affine=False there is no learned gamma, so the scale is just 1/std.
        scale = std.reciprocal() if self.weight is None else self.weight / std
        weight = scale.reshape(1, -1, 1, 1)  # broadcast over (N, C, H, W)
        Z = X * weight + 1e-9
        S = R / Z
        return X * (S * weight)
class Linear(nn.Linear, RelProp):
    """nn.Linear with the alpha-beta LRP rule (``beta = alpha - 1``)."""

    def relprop(self, R, alpha):
        """Split relevance ``R`` into activating and inhibiting contributions.

        Positive and negative parts of the inputs and weights are separated. For each
        pairing, relevance is redistributed in proportion to the pairing's share of the
        combined pre-activation.

        Args:
            R: relevance of the layer output.
            alpha: weight of the activating term; the inhibiting term uses ``alpha - 1``.

        Returns:
            Relevance tensor with the shape of ``self.X``.
        """
        beta = alpha - 1
        w_pos = self.weight.clamp(min=0)
        w_neg = self.weight.clamp(max=0)
        x_pos = self.X.clamp(min=0)
        x_neg = self.X.clamp(max=0)

        def share(wa, wb, xa, xb):
            za = F.linear(xa, wa)
            zb = F.linear(xb, wb)
            # Both partial products share one denominator so the term stays conservative.
            ratio = safe_divide(R, za + zb)
            ga = torch.autograd.grad(za, xa, ratio)[0]
            gb = torch.autograd.grad(zb, xb, ratio)[0]
            return xa * ga + xb * gb

        activator = share(w_pos, w_neg, x_pos, x_neg)
        inhibitor = share(w_neg, w_pos, x_pos, x_neg)
        return alpha * activator - beta * inhibitor
class Conv2d(nn.Conv2d, RelProp):
    """nn.Conv2d with LRP rules.

    Layers with 3 input channels (presumably the RGB input layer; TODO confirm) use a
    bounded z^B-style rule based on the per-sample min/max of the input. All other
    layers use the alpha-beta rule, matching ``Linear.relprop``.
    """

    def gradprop2(self, DY, weight):
        """Map ``DY`` from this layer's output space back to its input space.

        Uses a transposed convolution with this layer's stride, padding and dilation.
        ``output_padding`` is chosen per spatial dimension so the result has the same
        height and width as ``self.X``. Presumably only valid for ``groups == 1``;
        TODO confirm before using it with grouped convolutions.
        """
        # Output-space size is read from DY itself, rather than re-running the
        # forward convolution. Height and width may need different padding.
        output_padding = tuple(
            self.X.size(2 + d)
            - ((DY.size(2 + d) - 1) * self.stride[d] - 2 * self.padding[d]
               + self.dilation[d] * (self.kernel_size[d] - 1) + 1)
            for d in range(2)
        )
        return F.conv_transpose2d(DY, weight, stride=self.stride, padding=self.padding,
                                  output_padding=output_padding, dilation=self.dilation)

    def relprop(self, R, alpha):
        """Propagate relevance ``R`` of the output back to ``self.X``.

        Args:
            R: relevance of the layer output.
            alpha: alpha-beta weight (``beta = alpha - 1``); ignored by the 3-channel rule.

        Returns:
            Relevance tensor with the shape of ``self.X``.
        """
        if self.X.shape[1] == 3:
            return self._relprop_bounded(R)
        return self._relprop_alpha_beta(R, alpha)

    def _relprop_bounded(self, R):
        """z^B-style rule: bound every input by the per-sample min (L) and max (H) of self.X."""
        X = self.X
        pw = torch.clamp(self.weight, min=0)
        nw = torch.clamp(self.weight, max=0)
        # Per-sample lowest / highest value over all channels and pixels, broadcast to X's shape.
        L = X * 0 + X.flatten(start_dim=1).min(dim=1)[0].view(-1, 1, 1, 1)
        H = X * 0 + X.flatten(start_dim=1).max(dim=1)[0].view(-1, 1, 1, 1)

        def conv(inp, w):
            return F.conv2d(inp, w, bias=None, stride=self.stride, padding=self.padding,
                            dilation=self.dilation)

        Za = conv(X, self.weight) - conv(L, pw) - conv(H, nw) + 1e-9
        S = R / Za
        return (X * self.gradprop2(S, self.weight)
                - L * self.gradprop2(S, pw)
                - H * self.gradprop2(S, nw))

    def _relprop_alpha_beta(self, R, alpha):
        """Alpha-beta rule: alpha * activating relevance - (alpha - 1) * inhibiting relevance."""
        beta = alpha - 1
        pw = torch.clamp(self.weight, min=0)
        nw = torch.clamp(self.weight, max=0)
        px = torch.clamp(self.X, min=0)
        nx = torch.clamp(self.X, max=0)

        def f(w1, w2, x1, x2):
            Z1 = F.conv2d(x1, w1, bias=None, stride=self.stride, padding=self.padding,
                          dilation=self.dilation, groups=self.groups)
            Z2 = F.conv2d(x2, w2, bias=None, stride=self.stride, padding=self.padding,
                          dilation=self.dilation, groups=self.groups)
            # One shared denominator, as in Linear.relprop. Normalising Z1 and Z2
            # separately would let each part sum to R on its own and double-count
            # relevance whenever the input has negative entries.
            S = safe_divide(R, Z1 + Z2)
            C1 = x1 * self.gradprop(Z1, x1, S)[0]
            C2 = x2 * self.gradprop(Z2, x2, S)[0]
            return C1 + C2

        activator_relevances = f(pw, nw, px, nx)
        inhibitor_relevances = f(nw, pw, px, nx)
        return alpha * activator_relevances - beta * inhibitor_relevances