|
import torch |
|
import torch.nn.functional as F |
|
from huggingface_hub import PyTorchModelHubMixin |
|
from torch import nn |
|
from torchvision import models |
|
|
|
|
|
class ICN(nn.Module, PyTorchModelHubMixin):
    """Two-stream network over a channel-stacked image pair, built from a ResNet-50.

    The input's channels ``0:3`` ("real") and ``3:`` ("fake") are each passed
    through the same shared stem (``cnn_head``). The two feature maps are
    concatenated along the channel axis and fused by ``conv1``/``bn1``/ReLU.
    The result then goes through the remainder of the ResNet-50 trunk
    (``cnn_tail``). A shared 256-d embedding (``fc1``) feeds two heads:

    * ``fc2``    -> 49 (= 7 * 7) values, returned as ``grid``;
    * ``cls_fc`` -> 3 values, returned as ``cl``.

    The ResNet-50 is built without pretrained weights.

    NOTE: submodule attribute names and the layer order inside the
    ``nn.Sequential`` containers define the ``state_dict`` keys. Renaming or
    reordering them breaks loading of previously saved checkpoints.
    """

    def __init__(self):
        super().__init__()

        # Build an untrained ResNet-50. ``pretrained=`` is deprecated since
        # torchvision 0.13 in favour of ``weights=``. Older torchvision has no
        # ``weights`` kwarg and raises TypeError, so fall back to the legacy
        # argument there.
        try:
            cnn = models.resnet50(weights=None)
        except TypeError:
            cnn = models.resnet50(pretrained=False)

        top_level = list(cnn.children())
        layer1_blocks = list(top_level[4].children())

        # Shared stem: the first four top-level ResNet children, then the first
        # four children of layer1's first bottleneck block.
        # NOTE(review): in torchvision's Bottleneck those four children are
        # presumably conv1, bn1, conv2, bn2. The ReLU the bottleneck normally
        # applies between bn1 and conv2 is therefore absent here. Changing it
        # would alter the architecture's outputs for saved weights, so it is
        # left as is — confirm whether this is intentional.
        self.cnn_head = nn.Sequential(
            *top_level[:4],
            *list(layer1_blocks[0].children())[:4],
        )

        # Rest of the trunk: the remaining layer1 blocks, then top-level
        # children 5 .. -3. The last two children (presumably avgpool and fc)
        # are dropped.
        self.cnn_tail = nn.Sequential(*layer1_blocks[1:], *top_level[5:-2])

        # Fuses the concatenated two-stream features (128 channels in) into the
        # 256 channels the remaining layer1 blocks take as input.
        self.conv1 = nn.Conv2d(128, 256, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=256)

        # Head input size is hard-coded for a 2048 x 7 x 7 trunk output
        # (presumably 224 x 224 inputs — TODO confirm with callers).
        self.fc1 = nn.Linear(2048 * 7 * 7, 256)
        self.fc2 = nn.Linear(256, 7 * 7)
        self.cls_fc = nn.Linear(256, 3)

        # Not used by forward(); kept as an attribute for external callers.
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Run both streams through the network and return ``(grid, cl)``.

        Args:
            x: Tensor whose channel axis (dim 1) holds the "real" image in
                channels ``0:3`` and the "fake" image in channels ``3:``.
                Presumably ``(B, 6, 224, 224)`` so the trunk ends at 7 x 7
                — TODO confirm.

        Returns:
            grid: ``fc2`` output, one row of 49 values per sample.
            cl: ``cls_fc`` output, one row of 3 values per sample (returned
                as-is; no softmax is applied here).
        """
        real = x[:, :3, :, :]
        fake = x[:, 3:, :, :]

        # Both streams share the stem's weights. The stem ends on a
        # normalization layer, so the activation is applied here.
        real_features = F.relu(self.cnn_head(real))
        fake_features = F.relu(self.cnn_head(fake))

        combined = torch.cat((real_features, fake_features), 1)

        fused = F.relu(self.bn1(self.conv1(combined)))
        trunk = self.cnn_tail(fused)

        # flatten(1) keeps the batch dimension explicit. A wrong spatial size
        # then surfaces as a shape error in fc1 instead of being silently
        # folded into the batch dimension. Unlike view(), it also works on
        # non-contiguous (e.g. channels_last) tensors.
        flat = torch.flatten(trunk, 1)

        d = F.relu(self.fc1(flat))
        grid = self.fc2(d)
        cl = self.cls_fc(d)
        return grid, cl
|
|