# deepfake-detect / utils / architectures.py
# Author: Sara Mandelli
# Commit: 6bd8735 ("Update detector")
import torch
import torch.nn as nn
import numpy as np
from torchvision import transforms
import torch.nn.functional as F
from efficientnet_pytorch import EfficientNet
from efficientnet_pytorch.utils import (
round_filters,
round_repeats,
drop_connect,
get_same_padding_conv2d,
get_model_params,
efficientnet_params,
load_pretrained_weights,
Swish,
MemoryEfficientSwish,
)
from efficientnet_pytorch.model import MBConvBlock
from torchvision.models import resnet
from pytorchcv.model_provider import get_model
class Head(nn.Module):
    """Classification head: flatten -> BN -> dropout -> FC(512) -> ReLU -> BN -> dropout -> FC(out_f)."""

    def __init__(self, in_f, out_f):
        super().__init__()
        self.f = nn.Flatten()
        self.l = nn.Linear(in_f, 512)
        self.d = nn.Dropout(0.5)
        self.o = nn.Linear(512, out_f)
        self.b1 = nn.BatchNorm1d(in_f)
        self.b2 = nn.BatchNorm1d(512)
        self.r = nn.ReLU()

    def forward(self, x):
        # Normalize/regularize the flattened input, project to 512, then classify.
        h = self.d(self.b1(self.f(x)))
        h = self.d(self.b2(self.r(self.l(h))))
        return self.o(h)
class FCN(nn.Module):
    """Backbone + Head classifier: runs `base`, then feeds its output to a Head."""

    def __init__(self, base, in_f, out_f):
        super().__init__()
        self.base = base
        self.h1 = Head(in_f, out_f)

    def forward(self, x):
        features = self.base(x)
        return self.h1(features)
class BaseFCN(nn.Module):
    """Small fully connected classifier over 625-dimensional (flattened) features."""

    def __init__(self, n_classes: int):
        super().__init__()
        self.f = nn.Flatten()
        self.l = nn.Linear(625, 256)
        self.d = nn.Dropout(0.5)
        self.o = nn.Linear(256, n_classes)

    def forward(self, x):
        hidden = self.d(self.l(self.f(x)))
        return self.o(hidden)

    def get_trainable_parameters_cooccur(self):
        # Every parameter of this model is trainable.
        return self.parameters()
class BaseFCNHigh(nn.Module):
    """Like BaseFCN but with a wider (512-unit) hidden layer."""

    def __init__(self, n_classes: int):
        super().__init__()
        self.f = nn.Flatten()
        self.l = nn.Linear(625, 512)
        self.d = nn.Dropout(0.5)
        self.o = nn.Linear(512, n_classes)

    def forward(self, x):
        hidden = self.d(self.l(self.f(x)))
        return self.o(hidden)

    def get_trainable_parameters_cooccur(self):
        # Every parameter of this model is trainable.
        return self.parameters()
class BaseFCN4(nn.Module):
    """Four-layer fully connected classifier: 625 -> 512 -> 384 -> 256 -> n_classes."""

    def __init__(self, n_classes: int):
        super().__init__()
        self.f = nn.Flatten()
        self.l1 = nn.Linear(625, 512)
        self.l2 = nn.Linear(512, 384)
        self.l3 = nn.Linear(384, 256)
        self.d = nn.Dropout(0.5)
        self.o = nn.Linear(256, n_classes)

    def forward(self, x):
        h = self.f(x)
        # Each hidden projection is followed by dropout (no activation in between).
        for layer in (self.l1, self.l2, self.l3):
            h = self.d(layer(h))
        return self.o(h)

    def get_trainable_parameters_cooccur(self):
        # Every parameter of this model is trainable.
        return self.parameters()
class BaseFCNBnR(nn.Module):
    """FC classifier with batch norm and ReLU: 625 -> 256 -> n_classes."""

    def __init__(self, n_classes: int):
        super().__init__()
        self.f = nn.Flatten()
        self.b1 = nn.BatchNorm1d(625)
        self.b2 = nn.BatchNorm1d(256)
        self.l = nn.Linear(625, 256)
        self.d = nn.Dropout(0.5)
        self.o = nn.Linear(256, n_classes)
        self.r = nn.ReLU()

    def forward(self, x):
        # BN + dropout on the raw features, then FC -> ReLU -> BN -> dropout -> output.
        h = self.d(self.b1(self.f(x)))
        h = self.d(self.b2(self.r(self.l(h))))
        return self.o(h)

    def get_trainable_parameters_cooccur(self):
        # Every parameter of this model is trainable.
        return self.parameters()
def forward_resnet_conv(net, x, upto: int = 4):
    """
    Run only the convolutional part of a ResNet.
    :param net: a torchvision-style ResNet (conv1/bn1/relu/maxpool/layer1..4)
    :param x: input tensor
    :param upto: deepest residual stage to apply (0 = stem only, 4 = all stages)
    :return: feature map after the requested stages
    """
    # Stem: conv (N/2) -> BN -> ReLU -> max pool (N/4).
    x = net.maxpool(net.relu(net.bn1(net.conv1(x))))
    # Residual stages; spatial size halves at layer2, layer3 and layer4.
    stages = (net.layer1, net.layer2, net.layer3, net.layer4)
    for depth, stage in enumerate(stages, start=1):
        if upto >= depth:
            x = stage(x)
    return x
class FeatureExtractor(nn.Module):
    """
    Base class for models exposing a feature-extraction entry point.
    Also provides the standard RGB input normalizer.
    """

    def features(self, x: torch.Tensor) -> torch.Tensor:
        # Subclasses must override this.
        raise NotImplementedError

    def get_trainable_parameters(self):
        return self.parameters()

    @staticmethod
    def get_normalizer():
        # Standard ImageNet channel statistics (as used by torchvision pretrained models).
        return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
class FeatureExtractorGray(nn.Module):
    """
    Base class for models exposing a feature-extraction entry point.
    Also provides a single-channel (grayscale) input normalizer.
    """

    def features(self, x: torch.Tensor) -> torch.Tensor:
        # Subclasses must override this.
        raise NotImplementedError

    def get_trainable_parameters(self):
        return self.parameters()

    @staticmethod
    def get_normalizer():
        # One-channel normalization statistics.
        return transforms.Normalize(mean=[0.479], std=[0.226])
class EfficientNetGen(FeatureExtractor):
    """Generic EfficientNet classifier: backbone features + a fresh linear head."""

    def __init__(self, model: str, n_classes: int, pretrained: bool):
        super().__init__()
        loader = EfficientNet.from_pretrained if pretrained else EfficientNet.from_name
        self.efficientnet = loader(model)
        self.classifier = nn.Linear(self.efficientnet._conv_head.out_channels, n_classes)
        # Drop the stock FC head; classification goes through self.classifier instead.
        del self.efficientnet._fc

    def features(self, x: torch.Tensor) -> torch.Tensor:
        feat = self.efficientnet.extract_features(x)
        feat = self.efficientnet._avg_pooling(feat)
        return feat.flatten(start_dim=1)

    def forward(self, x):
        # Logits are returned un-normalized (no softmax).
        feat = self.efficientnet._dropout(self.features(x))
        return self.classifier(feat)
class EfficientNetB0(EfficientNetGen):
    """EfficientNet-B0 classifier."""

    def __init__(self, n_classes: int, pretrained: bool):
        super().__init__(model='efficientnet-b0', n_classes=n_classes, pretrained=pretrained)
class EfficientNetB4(EfficientNetGen):
    """EfficientNet-B4 classifier."""

    def __init__(self, n_classes: int, pretrained: bool):
        super().__init__(model='efficientnet-b4', n_classes=n_classes, pretrained=pretrained)
class EfficientNetGenPostStem(FeatureExtractor):
    """
    EfficientNet whose stem is changed to stride 1, with `n_ir_blocks` extra
    MBConv (inverted-residual) blocks inserted AFTER the stem and BEFORE the
    standard EfficientNet blocks. The last extra block has stride 2, which
    restores the stock spatial resolution for the rest of the network.
    """

    def __init__(self, model: str, n_classes: int, pretrained: bool, n_ir_blocks: int):
        super(EfficientNetGenPostStem, self).__init__()
        if pretrained:
            self.efficientnet = EfficientNet.from_pretrained(model)
        else:
            self.efficientnet = EfficientNet.from_name(model)
        self.n_ir_blocks = n_ir_blocks
        # Fresh classifier on top of the backbone's head-conv channels.
        self.classifier = nn.Linear(self.efficientnet._conv_head.out_channels, n_classes)
        # modify STEM: replace the stock stride-2 stem conv with a stride-1 one.
        in_channels = 3  # rgb
        out_channels = round_filters(32, self.efficientnet._global_params)
        Conv2d = get_same_padding_conv2d(image_size=self.efficientnet._global_params.image_size)
        self.efficientnet._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=1, bias=False)
        # Extra stride-1 block, cloned from the first block's args but forced to 32 output filters.
        # NOTE(review): relies on BlockArgs being a namedtuple with `output_filters`/`stride`
        # fields -- confirm against the installed efficientnet_pytorch version.
        self.init_blocks_args = self.efficientnet._blocks_args[0]
        self.init_blocks_args = self.init_blocks_args._replace(output_filters=32)
        self.init_block = MBConvBlock(self.init_blocks_args, self.efficientnet._global_params)
        # Final extra block: same args but stride 2 to downsample.
        self.last_block_args = self.efficientnet._blocks_args[0]
        self.last_block_args = self.last_block_args._replace(output_filters=32, stride=2)
        self.last_block = MBConvBlock(self.last_block_args, self.efficientnet._global_params)
        # Stock FC head is unused (self.classifier replaces it).
        del self.efficientnet._fc

    def features(self, x: torch.Tensor) -> torch.Tensor:
        """Modified stem -> extra IR blocks -> standard EfficientNet body -> pooled features."""
        x = self.efficientnet._swish(self.efficientnet._bn0(self.efficientnet._conv_stem(x)))
        # init blocks: the SAME stride-1 module is applied n_ir_blocks - 1 times
        # (weights are shared across repetitions).
        for b in range(self.n_ir_blocks - 1):
            x = self.init_block(x, drop_connect_rate=0)
        # last block (stride 2, no drop-connect on the extra blocks)
        x = self.last_block(x, drop_connect_rate=0)
        # standard blocks efficientNet:
        for idx, block in enumerate(self.efficientnet._blocks):
            drop_connect_rate = self.efficientnet._global_params.drop_connect_rate
            if drop_connect_rate:
                # Scale drop-connect linearly with depth, as in the reference implementation.
                drop_connect_rate *= float(idx) / len(self.efficientnet._blocks)
            x = block(x, drop_connect_rate=drop_connect_rate)
        # Head conv + global average pooling, flattened to (batch, channels).
        x = self.efficientnet._swish(self.efficientnet._bn1(self.efficientnet._conv_head(x)))
        x = self.efficientnet._avg_pooling(x)
        x = x.flatten(start_dim=1)
        return x

    def forward(self, x):
        # Returns un-normalized logits (softmax intentionally left out).
        x = self.features(x)
        x = self.efficientnet._dropout(x)
        x = self.classifier(x)
        # x = F.softmax(x, dim=-1)
        return x
class EfficientNetB0PostStemIR(EfficientNetGenPostStem):
    """EfficientNet-B0 with extra inverted-residual blocks after the stem."""

    def __init__(self, n_classes: int, pretrained: bool, n_ir_blocks: int):
        super().__init__(model='efficientnet-b0', n_classes=n_classes,
                         pretrained=pretrained, n_ir_blocks=n_ir_blocks)
class EfficientNetGenPreStem(FeatureExtractor):
    """
    EfficientNet with `n_ir_blocks` extra MBConv (inverted-residual) blocks
    inserted BEFORE the stem: the first extra block maps 3 -> 32 channels, the
    following ones keep 32 channels, and the stock stem conv is rebuilt to
    accept 32 input channels (keeping its stride-2 downsampling).
    """

    def __init__(self, model: str, n_classes: int, pretrained: bool, n_ir_blocks: int):
        super(EfficientNetGenPreStem, self).__init__()
        if pretrained:
            self.efficientnet = EfficientNet.from_pretrained(model)
        else:
            self.efficientnet = EfficientNet.from_name(model)
        self.n_ir_blocks = n_ir_blocks
        # Fresh classifier on top of the backbone's head-conv channels.
        self.classifier = nn.Linear(self.efficientnet._conv_head.out_channels, n_classes)
        # First extra block maps the RGB input to 32 channels.
        # NOTE(review): relies on BlockArgs being a namedtuple with
        # `input_filters`/`output_filters` fields -- confirm against the
        # installed efficientnet_pytorch version.
        self.init_block_args = self.efficientnet._blocks_args[0]
        self.init_block_args = self.init_block_args._replace(input_filters=3, output_filters=32)
        self.init_block = MBConvBlock(self.init_block_args, self.efficientnet._global_params)
        # Repeated extra block keeps 32 channels.
        self.last_blocks_args = self.efficientnet._blocks_args[0]
        self.last_blocks_args = self.last_blocks_args._replace(output_filters=32)
        self.last_block = MBConvBlock(self.last_blocks_args, self.efficientnet._global_params)
        # modify STEM: rebuild the stem conv to take 32 channels instead of 3.
        in_channels = 32
        out_channels = round_filters(32, self.efficientnet._global_params)
        Conv2d = get_same_padding_conv2d(image_size=self.efficientnet._global_params.image_size)
        self.efficientnet._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
        # Stock FC head is unused (self.classifier replaces it).
        del self.efficientnet._fc

    def features(self, x: torch.Tensor) -> torch.Tensor:
        """Extra IR blocks -> rebuilt stem -> standard EfficientNet body -> pooled features."""
        # init block (3 -> 32 channels)
        x = self.init_block(x, drop_connect_rate=0)
        # other blocks: the SAME 32 -> 32 module is applied n_ir_blocks - 1 times
        # (weights are shared across repetitions).
        for b in range(self.n_ir_blocks - 1):
            x = self.last_block(x, drop_connect_rate=0)
        # standard stem efficientNet:
        x = self.efficientnet._swish(self.efficientnet._bn0(self.efficientnet._conv_stem(x)))
        # standard blocks efficientNet:
        for idx, block in enumerate(self.efficientnet._blocks):
            drop_connect_rate = self.efficientnet._global_params.drop_connect_rate
            if drop_connect_rate:
                # Scale drop-connect linearly with depth, as in the reference implementation.
                drop_connect_rate *= float(idx) / len(self.efficientnet._blocks)
            x = block(x, drop_connect_rate=drop_connect_rate)
        # Head conv + global average pooling, flattened to (batch, channels).
        x = self.efficientnet._swish(self.efficientnet._bn1(self.efficientnet._conv_head(x)))
        x = self.efficientnet._avg_pooling(x)
        x = x.flatten(start_dim=1)
        return x

    def forward(self, x):
        # Returns un-normalized logits (softmax intentionally left out).
        x = self.features(x)
        x = self.efficientnet._dropout(x)
        x = self.classifier(x)
        # x = F.softmax(x, dim=-1)
        return x
class EfficientNetB0PreStemIR(EfficientNetGenPreStem):
    """EfficientNet-B0 with extra inverted-residual blocks before the stem."""

    def __init__(self, n_classes: int, pretrained: bool, n_ir_blocks: int):
        super().__init__(model='efficientnet-b0', n_classes=n_classes,
                         pretrained=pretrained, n_ir_blocks=n_ir_blocks)
class ResNet50(FeatureExtractor):
    """ResNet-50 backbone with a replacement linear classification head."""

    def __init__(self, n_classes: int, pretrained: bool):
        super().__init__()
        self.resnet = resnet.resnet50(pretrained=pretrained)
        self.fc = nn.Linear(in_features=self.resnet.fc.in_features, out_features=n_classes)
        # Drop the stock head; classification goes through self.fc instead.
        del self.resnet.fc

    def features(self, x):
        # Full convolutional trunk, then global average pooling.
        conv_out = forward_resnet_conv(self.resnet, x)
        return self.resnet.avgpool(conv_out).flatten(start_dim=1)

    def forward(self, x):
        return self.fc(self.features(x))
"""
Xception from Kaggle
"""
class XceptionWeiHao(FeatureExtractor):
    """
    Xception backbone (loaded via pytorchcv) with its original output layer
    removed, its final pooling replaced by adaptive average pooling, and an
    FCN head (flatten -> BN -> FC(512) -> ReLU -> BN -> FC) attached on top.
    """

    def __init__(self, n_classes: int, pretrained: bool):
        super(XceptionWeiHao, self).__init__()
        self.model = get_model("xception", pretrained=pretrained)
        self.model = nn.Sequential(*list(self.model.children())[:-1])  # Remove original output layer
        # NOTE(review): indexes into the pytorchcv module layout; assumes
        # children()[0] is the feature trunk exposing `final_block.pool` --
        # confirm against the installed pytorchcv version.
        self.model[0].final_block.pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)))
        # 2048 is the channel count fed to the head by this backbone.
        self.model = FCN(self.model, 2048, n_classes)

    def features(self, x: torch.Tensor) -> torch.Tensor:
        # Backbone activations before the FCN head.
        return self.model.base(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        return self.model.h1(x)