Spaces:
Running
on
Zero
Running
on
Zero
""" | |
backbone.py - Contains the backbone of the model. | |
(It is based on LPIENet and CURL's backbone) | |
Perceptual Image Enhancement for Smartphone Real-Time Applications | |
https://github.com/mv-lab/AISP | |
CURL: Neural Curve Layers for Global Image Enhancement | |
https://github.com/sjmoran/CURL | |
David Serrano (dserrano@cvc.uab.cat) | |
May 2024 | |
""" | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from typing import List | |
class AttentionBlock(nn.Module):
    """Residual block that adds two attention-gated copies of its input.

    Both branches start from statistics pooled across the channel axis. So,
    unlike CBAM, the "channel" branch is effectively a second spatial
    attention map rather than a true channel attention.

    Args:
        dim: Number of channels of the input and of the output.
    """

    def __init__(self, dim: int):
        super().__init__()
        # Spatial branch: [channel-mean, channel-max] (2 maps) -> dim-channel gate.
        self._spatial_attention_conv = nn.Conv2d(2, dim, kernel_size=3, padding=1)
        # "Channel" branch: a two-layer 1x1-conv MLP fed with the channel mean.
        self._channel_attention_conv0 = nn.Conv2d(1, dim, kernel_size=1, padding=0)
        self._channel_attention_conv1 = nn.Conv2d(dim, dim, kernel_size=1, padding=0)
        # Fuses the two concatenated branch outputs back down to dim channels.
        self._out_conv = nn.Conv2d(2 * dim, dim, kernel_size=1, padding=0)

    def forward(self, x: torch.Tensor):
        """Return ``x`` plus the fused output of both attention branches.

        Raises:
            ValueError: If ``x`` is not a 4-D ``[B, C, H, W]`` tensor.
        """
        if x.ndim != 4:
            raise ValueError(f"Expected [B, C, H, W] input, got {x.shape}.")

        # Per-pixel statistics over the channel axis, each [B, 1, H, W].
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_peak = x.max(dim=1, keepdim=True).values

        # Spatial branch.
        pooled = torch.cat([channel_mean, channel_peak], dim=1)
        spatial_branch = torch.sigmoid(self._spatial_attention_conv(pooled)) * x

        # "Channel" branch (driven by the channel mean, not a spatial pool).
        hidden = torch.relu(self._channel_attention_conv0(channel_mean))
        channel_branch = torch.sigmoid(self._channel_attention_conv1(hidden)) * x

        # [B, 2*dim, H, W] -> [B, dim, H, W], then the residual add.
        both = torch.cat([spatial_branch, channel_branch], dim=1)
        return x + self._out_conv(both)
class InverseBlock(nn.Module):
    """Map ``input_channels`` to ``channels`` through two summed paths.

    Main path: 1x1 conv -> 3x3 depthwise conv -> ELU -> 1x1 conv.
    Shortcut:  1x1 conv -> ReLU.

    Args:
        input_channels: Channels of the incoming tensor.
        channels: Channels of the returned tensor.
    """

    def __init__(self, input_channels: int, channels: int):
        super().__init__()
        self._conv0 = nn.Conv2d(input_channels, channels, kernel_size=1)
        self._dw_conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1, groups=channels)
        self._conv1 = nn.Conv2d(channels, channels, kernel_size=1)
        # Shortcut projection so both paths have `channels` channels.
        self._conv2 = nn.Conv2d(input_channels, channels, kernel_size=1)

    def forward(self, x: torch.Tensor):
        shortcut = torch.relu(self._conv2(x))
        main = self._conv1(F.elu(self._dw_conv(self._conv0(x))))
        return shortcut + main
class BaseBlock(nn.Module):
    """Two stacked residual units that preserve the channel count.

    Unit 1: 1x1 conv -> 3x3 depthwise conv -> ELU -> 1x1 conv, added to the input.
    Unit 2: 1x1 conv -> ELU -> 1x1 conv, added to the output of unit 1.

    Args:
        channels: Number of input and output channels.
    """

    def __init__(self, channels: int):
        super().__init__()
        self._conv0 = nn.Conv2d(channels, channels, kernel_size=1)
        self._dw_conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1, groups=channels)
        self._conv1 = nn.Conv2d(channels, channels, kernel_size=1)
        self._conv2 = nn.Conv2d(channels, channels, kernel_size=1)
        self._conv3 = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x: torch.Tensor):
        first = x + self._conv1(F.elu(self._dw_conv(self._conv0(x))))
        return first + self._conv3(F.elu(self._conv2(first)))
class AttentionTail(nn.Module):
    """Gate the input element-wise with a sigmoid map from a 7/5/3 conv stack.

    Args:
        channels: Number of input and output channels.
    """

    def __init__(self, channels: int):
        super().__init__()
        self._conv0 = nn.Conv2d(channels, channels, kernel_size=7, padding=3)
        self._conv1 = nn.Conv2d(channels, channels, kernel_size=5, padding=2)
        self._conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor):
        gate = x
        # The first two convolutions are ReLU-activated...
        for conv in (self._conv0, self._conv1):
            gate = torch.relu(conv(gate))
        # ...and the last one produces a (0, 1) gate.
        gate = torch.sigmoid(self._conv2(gate))
        return x * gate
class Flatten(nn.Module):
    """Flatten every dimension after the first (batch) dimension."""

    def forward(self, x):
        """Flatten a Tensor to one vector per sample.

        :param x: Tensor whose first dimension is the batch dimension.
        :returns: 2D Tensor of shape [B, N], where N is the product of the
            remaining dimensions of ``x``.
        :rtype: Tensor
        """
        # reshape (unlike view) also accepts non-contiguous inputs. It still
        # returns a view, without copying, whenever the memory layout allows.
        return x.reshape(x.size(0), -1)
class ResidualConnection(nn.Module):
    """Multi-scale context module applied to an encoder skip connection.

    The input is concatenated with three 64-channel branches and fused back to
    ``in_channels`` channels by a 1x1 convolution:

    * ``midnet2`` / ``midnet4``: two 3x3 convolutions with dilation 2 / 4.
      Padding equals the dilation, so the spatial size is preserved.
    * ``globnet``: strided convolutions and max-pools, global average pooling,
      dropout and a linear layer. These produce one 64-d vector per sample,
      which is broadcast over every pixel.

    Args:
        in_channels: Number of channels of the input and of the output.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        # Conv2d positional args: (in, out, kernel_size, stride, padding, dilation).
        self.midnet2 = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, 1, 2, 2),
            nn.LeakyReLU(),
            nn.Conv2d(64, 64, 3, 1, 2, 2),
            nn.LeakyReLU()
        )
        self.midnet4 = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, 1, 4, 4),
            nn.LeakyReLU(),
            nn.Conv2d(64, 64, 3, 1, 4, 4),
            nn.LeakyReLU()
        )
        self.globnet = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, 2, 1, 1),
            nn.LeakyReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.Conv2d(64, 64, 3, 2, 1, 1),
            nn.LeakyReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.Conv2d(64, 64, 3, 2, 1, 1),
            nn.LeakyReLU(),
            nn.AdaptiveAvgPool2d(1),
            Flatten(),
            nn.Dropout(0.5),
            nn.Linear(64, 64)
        )
        # Input channels + 3 branches x 64 channels -> in_channels.
        self.conv_fuse = nn.Conv2d(in_channels=192+in_channels, out_channels=in_channels, kernel_size=1)

    def forward(self, x):
        x_midnet2 = self.midnet2(x)
        x_midnet4 = self.midnet4(x)
        # [B, 64] -> [B, 64, 1, 1] -> broadcast to [B, 64, H, W]. expand()
        # returns a stride-0 view, whereas repeat() would materialise a full
        # copy. The values (and gradients) seen by torch.cat are identical.
        x_global = self.globnet(x)[:, :, None, None]
        x_global = x_global.expand(-1, -1, x_midnet2.shape[2], x_midnet2.shape[3])
        x_fuse = torch.cat((x, x_midnet2, x_midnet4, x_global), dim=1)
        return self.conv_fuse(x_fuse)
class Backbone(nn.Module):
    """U-Net-style backbone that adds a learned correction to its input.

    Each encoder and decoder stage is a 3x3 conv followed by two ``BaseBlock``s
    and an ``AttentionBlock``. Every encoder except the deepest does two things:

    * It passes its output through a ``ResidualConnection`` to form a skip tensor.
    * It 2x max-pools its output before the next stage.

    Each decoder output is bilinearly resized to its skip's spatial size and
    concatenated with it. An ``InverseBlock`` and an ``AttentionTail`` then map
    the result to ``output_channels``. That result is added to the input and
    clipped to [0, 1] (presumably images are expected in [0, 1]; confirm with
    callers).

    Args:
        input_channels: Channels of the input tensor.
        output_channels: Channels of the output. Must equal ``input_channels``
            because of the global residual.
        encoder_dims: Channel width of each encoder stage, shallow to deep.
        decoder_dims: Channel width of each decoder stage. Must be non-empty
            and have exactly one entry fewer than ``encoder_dims``.

    Raises:
        ValueError: If the encoder/decoder dims are inconsistent.
        NotImplementedError: If ``input_channels != output_channels``.
    """

    def __init__(self, input_channels: int, output_channels: int, encoder_dims: List[int], decoder_dims: List[int]):
        super().__init__()
        if len(encoder_dims) != len(decoder_dims) + 1 or len(decoder_dims) < 1:
            raise ValueError(f"Unexpected encoder and decoder dims: {encoder_dims}, {decoder_dims}.")
        if input_channels != output_channels:
            # The global residual in forward() adds input and output together.
            raise NotImplementedError()

        encoders = []
        for i, encoder_dim in enumerate(encoder_dims):
            input_dim = input_channels if i == 0 else encoder_dims[i - 1]
            encoders.append(
                nn.Sequential(
                    nn.Conv2d(input_dim, encoder_dim, kernel_size=3, padding=1),
                    BaseBlock(encoder_dim),
                    BaseBlock(encoder_dim),
                    AttentionBlock(encoder_dim),
                )
            )
        self._encoders = nn.ModuleList(encoders)

        decoders = []
        for i, decoder_dim in enumerate(decoder_dims):
            # Decoder 0 sees the bottleneck. Decoder i > 0 sees the previous
            # decoder output concatenated with the skip of encoder (-i - 1),
            # which ResidualConnection keeps at encoder_dims[-i - 1] channels.
            input_dim = encoder_dims[-1] if i == 0 else decoder_dims[i - 1] + encoder_dims[-i - 1]
            decoders.append(
                nn.Sequential(
                    nn.Conv2d(input_dim, decoder_dim, kernel_size=3, padding=1),
                    BaseBlock(decoder_dim),
                    BaseBlock(decoder_dim),
                    AttentionBlock(decoder_dim),
                )
            )
        self._decoders = nn.ModuleList(decoders)

        # The last decoder output is concatenated with the shallowest skip
        # (encoder_dims[0] channels). The attribute name "_inverse_bock" (sic)
        # is kept because it is part of the state_dict keys.
        self._inverse_bock = InverseBlock(encoder_dims[0] + decoder_dims[-1], output_channels)
        self._attention_tail = AttentionTail(output_channels)

        # One ResidualConnection per encoder stage. forward() only calls the
        # first len(encoder_dims) - 1, so the last one never receives
        # gradients (this may require find_unused_parameters=True under
        # DistributedDataParallel). It is kept so the parameter set and
        # state_dict stay unchanged.
        self._residual_connections = nn.ModuleList(
            [ResidualConnection(in_channels=encoder_dim) for encoder_dim in encoder_dims]
        )

    def forward(self, x: torch.Tensor):
        """Enhance ``x`` and return the result clipped to [0, 1].

        Raises:
            ValueError: If ``x`` is not a 4-D ``[B, C, H, W]`` tensor.
        """
        if len(x.shape) != 4:
            raise ValueError(f"Expected [B, C, H, W] input, got {x.shape}.")
        global_residual = x

        # Encoder path. Only the spatial size of each encoder output is needed
        # later (as the upsampling target); the features themselves reach
        # the decoder through the ResidualConnection skips.
        skip_sizes, skips = [], []
        for i, encoder in enumerate(self._encoders):
            x = encoder(x)
            if i != len(self._encoders) - 1:
                skip_sizes.append(x.shape[2:])
                skips.append(self._residual_connections[i](x))
                x = F.max_pool2d(x, kernel_size=2)
        skip_sizes.reverse()
        skips.reverse()

        # Decoder path (deepest skip first). The constructor guarantees
        # len(self._decoders) == len(skips).
        for decoder, size, skip in zip(self._decoders, skip_sizes, skips):
            x = decoder(x)
            # Functional equivalent of nn.Upsample(size=..., mode='bilinear',
            # align_corners=False), without building a module on every call.
            x = F.interpolate(x, size=size, mode='bilinear', align_corners=False)
            x = torch.cat([x, skip], dim=1)

        x = self._inverse_bock(x)
        x = self._attention_tail(x)
        return torch.clip(x + global_residual, 0, 1)