Spaces:
Running
Running
Update DL_models.py
Browse files- DL_models.py +17 -2
DL_models.py
CHANGED
@@ -33,6 +33,10 @@ class CustomResNet(nn.Module):
|
|
33 |
self.fc1 = nn.Linear(2 * (512 * 4), 256)
|
34 |
self.fc2 = nn.Linear(256, 1)
|
35 |
|
|
|
|
|
|
|
|
|
36 |
|
37 |
def forward(self, x1, x2, Location):
|
38 |
N = x1.shape[0]
|
@@ -42,7 +46,9 @@ class CustomResNet(nn.Module):
|
|
42 |
Location = Location.to(torch.float).to(self.device)
|
43 |
|
44 |
# Process both images through the same ResNet
|
45 |
-
f1 = self.resnet(x1)
|
|
|
|
|
46 |
f2 = self.resnet(x2)
|
47 |
|
48 |
# Flatten the features
|
@@ -59,4 +65,13 @@ class CustomResNet(nn.Module):
|
|
59 |
x = torch.relu(self.fc1(combined))
|
60 |
x = self.fc2(x)
|
61 |
|
62 |
-
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
self.fc1 = nn.Linear(2 * (512 * 4), 256)
|
34 |
self.fc2 = nn.Linear(256, 1)
|
35 |
|
36 |
+
self.gradients = None
|
37 |
+
|
38 |
+
def activations_hook(self, grad):
|
39 |
+
self.gradients = grad
|
40 |
|
41 |
def forward(self, x1, x2, Location):
|
42 |
N = x1.shape[0]
|
|
|
46 |
Location = Location.to(torch.float).to(self.device)
|
47 |
|
48 |
# Process both images through the same ResNet
|
49 |
+
f1 = self.resnet[:8](x1)
|
50 |
+
h = f1.register_hook(self.activations_hook)
|
51 |
+
f1 = self.resnet[8:](f1)
|
52 |
f2 = self.resnet(x2)
|
53 |
|
54 |
# Flatten the features
|
|
|
65 |
x = torch.relu(self.fc1(combined))
|
66 |
x = self.fc2(x)
|
67 |
|
68 |
+
return x
|
69 |
+
|
70 |
+
# method for the gradient extraction
|
71 |
+
def get_activations_gradient(self):
|
72 |
+
return self.gradients
|
73 |
+
|
74 |
+
# method for the activation exctraction
|
75 |
+
def get_activations(self, x):
|
76 |
+
x = self.transform(x.to(torch.float).to(self.device))
|
77 |
+
return self.resnet[:8](x)
|