# clear-image / models/seemore.py
# Uploaded to Hugging Face by JohnAlexander23 ("Upload 23 files", commit 69f4183, verified), 13.9 kB
from typing import Tuple, List
from torch import Tensor
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange
######################
# Meta Architecture
######################
class SeemoRe(nn.Module):
    """SeemoRe super-resolution network.

    Pipeline: mean-shift and scale the input, a 3x3 shallow-feature conv,
    ``num_layers`` ResGroup blocks, a channels-first LayerNorm, a 3x3 conv with
    a long skip connection from the shallow features, and finally a conv +
    PixelShuffle upsampler that enlarges H and W by ``scale``.

    Args:
        scale: Upsampling factor of the PixelShuffle upsampler.
        in_chans: Number of input (and output) image channels.
        num_experts: Experts per MoE layer.
        num_layers: Number of ResGroup blocks in the body.
        embedding_dim: Feature channel width.
        img_range: Multiplier applied after mean subtraction (undone at the end).
        use_shuffle: Channel-shuffle the MoE input projection.
        global_kernel_size: Kernel size of the striped convs in the global block.
        recursive: Number of aggregation rounds in MoEBlock.calibrate.
        lr_space: Expert low-rank width rule forwarded to MoEBlock:
            "linear", "exp" or "double".
        topk: Number of experts each sample is routed to.
    """

    def __init__(self,
                 scale: int = 4,
                 in_chans: int = 3,
                 num_experts: int = 6,
                 num_layers: int = 6,
                 embedding_dim: int = 64,
                 img_range: float = 1.0,
                 use_shuffle: bool = False,
                 global_kernel_size: int = 11,
                 recursive: int = 2,
                 # MoEBlock only accepts "linear"/"exp"/"double"; an int default
                 # would make default construction raise NotImplementedError.
                 lr_space: str = "linear",
                 topk: int = 2,):
        super().__init__()
        self.scale = scale
        self.num_in_channels = in_chans
        self.num_out_channels = in_chans
        self.img_range = img_range
        rgb_mean = (0.4488, 0.4371, 0.4040)
        # Non-persistent buffer: follows .to()/.cuda()/.half() with the module but
        # is excluded from state_dict, so existing checkpoints still load strictly.
        # NOTE(review): the mean has exactly 3 entries, so this presumably assumes
        # in_chans == 3; other values break the mean subtraction / conv_1.
        self.register_buffer("mean", torch.Tensor(rgb_mean).view(1, 3, 1, 1), persistent=False)
        # -- SHALLOW FEATURES --
        self.conv_1 = nn.Conv2d(self.num_in_channels, embedding_dim, kernel_size=3, padding=1)
        # -- DEEP FEATURES --
        self.body = nn.ModuleList(
            [ResGroup(in_ch=embedding_dim,
                      num_experts=num_experts,
                      use_shuffle=use_shuffle,
                      topk=topk,
                      lr_space=lr_space,
                      recursive=recursive,
                      global_kernel_size=global_kernel_size) for _ in range(num_layers)]
        )
        # -- UPSCALE --
        self.norm = LayerNorm(embedding_dim, data_format='channels_first')
        self.conv_2 = nn.Conv2d(embedding_dim, embedding_dim, kernel_size=3, padding=1)
        self.upsampler = nn.Sequential(
            nn.Conv2d(embedding_dim, (scale**2) * self.num_out_channels, kernel_size=3, padding=1),
            nn.PixelShuffle(scale)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upscale ``x`` by ``self.scale``; output channels equal input channels."""
        # Cast into a local instead of overwriting module state on every call.
        mean = self.mean.type_as(x)
        x = (x - mean) * self.img_range
        # -- SHALLOW FEATURES --
        x = self.conv_1(x)
        res = x
        # -- DEEP FEATURES --
        for layer in self.body:
            x = layer(x)
        x = self.norm(x)
        # -- HR IMAGE RECONSTRUCTION --
        x = self.conv_2(x) + res  # long skip from the shallow features
        x = self.upsampler(x)
        return x / self.img_range + mean
#############################
# Components
#############################
class ResGroup(nn.Module):
    """One body stage: a local RME block followed by a global SME block.

    Args:
        in_ch: Feature channels (preserved).
        num_experts: Experts in the local block's MoE layer.
        global_kernel_size: Striped-conv kernel size in the global block.
        lr_space: Expert width rule forwarded to MoEBlock
            ("linear", "exp" or "double").
        topk: Experts used per sample in the MoE layer.
        recursive: Calibration rounds in MoEBlock.
        use_shuffle: Channel-shuffle inside MoEBlock.
    """

    def __init__(self,
                 in_ch: int,
                 num_experts: int,
                 global_kernel_size: int = 11,
                 # MoEBlock rejects anything but "linear"/"exp"/"double", so an
                 # int default would make default construction raise.
                 lr_space: str = "linear",
                 topk: int = 2,
                 recursive: int = 2,
                 use_shuffle: bool = False):
        super().__init__()
        self.local_block = RME(in_ch=in_ch,
                               num_experts=num_experts,
                               use_shuffle=use_shuffle,
                               lr_space=lr_space,
                               topk=topk,
                               recursive=recursive)
        self.global_block = SME(in_ch=in_ch,
                                kernel_size=global_kernel_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the local block, then the global block."""
        x = self.local_block(x)
        x = self.global_block(x)
        return x
#############################
# Global Block
#############################
class SME(nn.Module):
    """Global block built from two pre-norm residual sub-blocks.

    Computes ``x = x + StripedConvFormer(LN(x))`` and then
    ``x = x + GatedFFN(LN(x))``, using channels-first LayerNorms.
    """

    def __init__(self,
                 in_ch: int,
                 kernel_size: int = 11):
        super().__init__()
        self.norm_1 = LayerNorm(in_ch, data_format='channels_first')
        self.block = StripedConvFormer(in_ch=in_ch, kernel_size=kernel_size)
        self.norm_2 = LayerNorm(in_ch, data_format='channels_first')
        self.ffn = GatedFFN(in_ch, mlp_ratio=2, kernel_size=3, act_layer=nn.GELU())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        stages = ((self.norm_1, self.block), (self.norm_2, self.ffn))
        for pre_norm, sublayer in stages:
            x = sublayer(pre_norm(x)) + x
        return x
class StripedConvFormer(nn.Module):
    """Convolutional token mixer with a multiplicative gate.

    A 1x1 conv + GELU produces ``2 * in_ch`` channels that are split into ``q``
    and ``v``. ``q`` is filtered by a depthwise StripedConv2d of size
    ``kernel_size`` and multiplied into ``v``, and a 1x1 conv projects the
    product back to ``in_ch`` channels.
    """

    def __init__(self,
                 in_ch: int,
                 kernel_size: int):
        super().__init__()
        self.in_ch = in_ch
        self.kernel_size = kernel_size
        self.padding = kernel_size // 2  # not read anywhere in this class
        self.proj = nn.Conv2d(in_ch, in_ch, kernel_size=1, padding=0)
        self.to_qv = nn.Sequential(
            nn.Conv2d(in_ch, in_ch * 2, kernel_size=1, padding=0),
            nn.GELU(),
        )
        self.attn = StripedConv2d(in_ch, kernel_size=kernel_size, depthwise=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        qv = self.to_qv(x)
        query, value = qv.chunk(2, dim=1)
        gated = self.attn(query) * value
        return self.proj(gated)
#############################
# Local Blocks
#############################
class RME(nn.Module):
    """Local block built from two pre-norm residual sub-blocks.

    Computes ``x = x + MoEBlock(LN(x))`` and then ``x = x + GatedFFN(LN(x))``,
    using channels-first LayerNorms.

    Args:
        in_ch: Feature channels (preserved).
        num_experts: Experts in the MoE layer.
        topk: Experts used per sample.
        lr_space: Expert width rule forwarded to MoEBlock
            ("linear", "exp" or "double").
        recursive: Calibration rounds in MoEBlock.
        use_shuffle: Channel-shuffle inside MoEBlock.
    """

    def __init__(self,
                 in_ch: int,
                 num_experts: int,
                 topk: int,
                 # MoEBlock rejects anything but "linear"/"exp"/"double", so an
                 # int default would make default construction raise.
                 lr_space: str = "linear",
                 recursive: int = 2,
                 use_shuffle: bool = False,):
        super().__init__()
        self.norm_1 = LayerNorm(in_ch, data_format='channels_first')
        self.block = MoEBlock(in_ch=in_ch, num_experts=num_experts, topk=topk, use_shuffle=use_shuffle, recursive=recursive, lr_space=lr_space,)
        self.norm_2 = LayerNorm(in_ch, data_format='channels_first')
        self.ffn = GatedFFN(in_ch, mlp_ratio=2, kernel_size=3, act_layer=nn.GELU())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.block(self.norm_1(x)) + x
        x = self.ffn(self.norm_2(x)) + x
        return x
#################
# MoE Layer
#################
class MoEBlock(nn.Module):
    """Local mixture-of-experts block.

    ``conv_1`` lifts the input to ``2 * in_ch`` channels (optionally
    channel-shuffled), which are split into a feature branch and a context
    branch. The feature branch goes through a depthwise 3x3 StripedConv2d + GELU.
    The context branch is refined by :meth:`calibrate`. Both feed a top-k
    MoELayer whose experts have low-rank widths chosen by ``lr_space``, and a
    final 1x1 conv projects the result.

    Args:
        in_ch: Input/output channels.
        num_experts: Number of experts in the MoE layer.
        topk: Experts used per sample.
        use_shuffle: Apply ``channel_shuffle(groups=2)`` before the split.
        lr_space: Width of expert i: "linear" -> i + 2, "exp" -> 2 ** (i + 1),
            "double" -> 2 * i + 2.
        recursive: Number of aggregation rounds in :meth:`calibrate`.

    Raises:
        NotImplementedError: if ``lr_space`` is not one of the names above.
    """

    def __init__(self,
                 in_ch: int,
                 num_experts: int,
                 topk: int,
                 use_shuffle: bool = False,
                 lr_space: str = "linear",
                 recursive: int = 2):
        super().__init__()
        self.use_shuffle = use_shuffle
        self.recursive = recursive
        self.conv_1 = nn.Sequential(
            nn.Conv2d(in_ch, in_ch, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(in_ch, 2 * in_ch, kernel_size=1, padding=0),
        )
        # Depthwise 4x4 / stride-4 conv: each calibrate() round floors H and W
        # by a factor of 4, so inputs presumably need H, W >= 4 ** recursive.
        self.agg_conv = nn.Sequential(
            nn.Conv2d(in_ch, in_ch, kernel_size=4, stride=4, groups=in_ch),
            nn.GELU(),
        )
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=1, padding=1, groups=in_ch),
            nn.Conv2d(in_ch, in_ch, kernel_size=1, padding=0),
        )
        self.conv_2 = nn.Sequential(
            StripedConv2d(in_ch, kernel_size=3, depthwise=True),
            nn.GELU(),
        )
        width_rules = (
            ("linear", lambda i: i + 2),
            ("exp", lambda i: 2 ** (i + 1)),
            ("double", lambda i: 2 * i + 2),
        )
        for rule_name, rule in width_rules:
            if lr_space == rule_name:
                grow_func = rule
                break
        else:
            raise NotImplementedError(f"lr_space {lr_space} not implemented")
        experts = [Expert(in_ch=in_ch, low_dim=grow_func(i)) for i in range(num_experts)]
        self.moe_layer = MoELayer(
            experts=experts,
            gate=Router(in_ch=in_ch, num_experts=num_experts),
            num_expert=topk,
        )
        self.proj = nn.Conv2d(in_ch, in_ch, kernel_size=1, padding=0)

    def calibrate(self, x: torch.Tensor) -> torch.Tensor:
        """Add a coarse context map back onto ``x``.

        Runs aggregation (``agg_conv``) and mixing (``conv``) ``self.recursive``
        times, bilinearly resizes the result to x's H x W, and adds it to ``x``.
        """
        _, _, height, width = x.shape
        coarse = x
        for _ in range(self.recursive):
            coarse = self.conv(self.agg_conv(coarse))
        coarse = F.interpolate(coarse, size=(height, width), mode="bilinear", align_corners=False)
        return x + coarse

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        lifted = self.conv_1(x)
        if self.use_shuffle:
            lifted = channel_shuffle(lifted, groups=2)
        features, context = torch.chunk(lifted, chunks=2, dim=1)
        mixed = self.moe_layer(self.conv_2(features), self.calibrate(context))
        return self.proj(mixed)
class MoELayer(nn.Module):
    """Residual mixture-of-experts layer with top-k routing.

    Returns ``inputs + sum_j w_j * expert_j(inputs, k)``. For each sample, the
    sum runs over the ``num_expert`` experts with the largest softmaxed gate
    scores, and ``w_j`` is that softmax probability (not renormalised over the
    selected experts).

    Args:
        experts: Expert modules, each called as ``expert(inputs, k)``.
        gate: Module mapping ``inputs`` to per-expert logits; the code softmaxes
            over dim 1, so presumably shape (B, num_experts).
        num_expert: How many experts (top-k) each sample uses.
    """

    def __init__(self, experts: List[nn.Module], gate: nn.Module, num_expert: int = 1):
        super().__init__()
        assert len(experts) > 0
        self.experts = nn.ModuleList(experts)
        self.gate = gate
        self.num_expert = num_expert

    def forward(self, inputs: torch.Tensor, k: torch.Tensor):
        logits = self.gate(inputs)
        weights = F.softmax(logits, dim=1, dtype=torch.float).to(inputs.dtype)
        topk_weights, topk_experts = torch.topk(weights, self.num_expert)
        out = inputs.clone()
        if self.training:
            # Dense path: every expert runs on the full batch, but only each
            # sample's top-k entries of exp_weights are non-zero.
            exp_weights = torch.zeros_like(weights)
            exp_weights.scatter_(1, topk_experts, weights.gather(1, topk_experts))
            for i, expert in enumerate(self.experts):
                out += expert(inputs, k) * exp_weights[:, i:i+1, None, None]
        else:
            # Sparse path. Routing differs per sample, so for each top-k slot,
            # group samples by the expert they picked and run that expert only on
            # them. Slots are visited in rank order, so the per-sample summation
            # order matches the original batch-size-1 behaviour.
            for slot in range(self.num_expert):
                chosen = topk_experts[:, slot]
                for expert_idx in torch.unique(chosen).tolist():
                    mask = chosen == expert_idx
                    slot_weight = topk_weights[mask, slot][:, None, None, None]
                    out[mask] += self.experts[expert_idx](inputs[mask], k[mask]) * slot_weight
        return out
class Expert(nn.Module):
    """Low-rank bilinear expert computing ``conv_3(conv_2(k) * conv_1(x))``.

    ``x`` and ``k`` are each projected from ``in_ch`` to ``low_dim`` channels by
    1x1 convs. The projections are multiplied elementwise (no sigmoid on the
    ``k`` branch), and a 1x1 conv projects the product back to ``in_ch``.
    """

    def __init__(self,
                 in_ch: int,
                 low_dim: int,):
        super().__init__()
        self.conv_1 = nn.Conv2d(in_ch, low_dim, kernel_size=1, padding=0)
        self.conv_2 = nn.Conv2d(in_ch, low_dim, kernel_size=1, padding=0)
        self.conv_3 = nn.Conv2d(low_dim, in_ch, kernel_size=1, padding=0)

    def forward(self, x: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        x_low = self.conv_1(x)
        k_low = self.conv_2(k)
        return self.conv_3(k_low * x_low)
class Router(nn.Module):
    """Gating network that produces one logit per expert.

    Global-average-pools the (B, C, H, W) input to (B, C) and applies a
    bias-free linear layer, giving logits of shape (B, num_experts).
    """

    def __init__(self,
                 in_ch: int,
                 num_experts: int):
        super().__init__()
        # nn.Flatten(1) turns the pooled (B, C, 1, 1) map into (B, C); it has no
        # parameters, so the Linear layer keeps its state_dict key "body.2.weight".
        pool = nn.AdaptiveAvgPool2d(1)
        squeeze = nn.Flatten(start_dim=1)
        to_logits = nn.Linear(in_ch, num_experts, bias=False)
        self.body = nn.Sequential(pool, squeeze, to_logits)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for stage in self.body:
            x = stage(x)
        return x
#################
# Utilities
#################
class StripedConv2d(nn.Module):
    """Separable "striped" convolution: a 1 x k conv followed by a k x 1 conv.

    Both convs pad by ``kernel_size // 2`` along their long axis, which keeps
    H and W unchanged for odd ``kernel_size``. When ``depthwise`` is True, both
    convs use ``groups=in_ch``.
    """

    def __init__(self,
                 in_ch: int,
                 kernel_size: int,
                 depthwise: bool = False):
        super().__init__()
        self.in_ch = in_ch
        self.kernel_size = kernel_size
        self.padding = kernel_size // 2
        groups = in_ch if depthwise else 1
        horizontal = nn.Conv2d(in_ch, in_ch, kernel_size=(1, kernel_size),
                               padding=(0, self.padding), groups=groups)
        vertical = nn.Conv2d(in_ch, in_ch, kernel_size=(kernel_size, 1),
                             padding=(self.padding, 0), groups=groups)
        self.conv = nn.Sequential(horizontal, vertical)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for stage in self.conv:
            x = stage(x)
        return x
def channel_shuffle(x, groups=2):
    """Interleave the channels of a 4-D tensor across ``groups``.

    Dim 1 is viewed as (groups, C // groups), the two factors are transposed,
    and the result is flattened back. For example, with groups=2 the channels
    [a0, a1, b0, b1] become [a0, b0, a1, b1]. The channel count must be
    divisible by ``groups``; otherwise the ``view`` raises.
    """
    batch, num_channels, height, width = x.shape
    per_group = num_channels // groups
    grouped = x.view(batch, groups, per_group, height, width)
    interleaved = grouped.transpose(1, 2).contiguous()
    return interleaved.view(batch, -1, height, width)
class GatedFFN(nn.Module):
    """Gated feed-forward block on channels-first feature maps.

    A 1x1 conv (+ ``act_layer``) expands ``in_ch`` to ``in_ch * mlp_ratio``
    channels. The result is split in half along dim 1; one half passes through
    a depthwise ``kernel_size`` conv and gates the other half by elementwise
    product. A 1x1 conv (+ ``act_layer``) then projects back to ``in_ch``.

    Args:
        in_ch: Input/output channels.
        mlp_ratio: Expansion factor; ``in_ch * mlp_ratio`` should be even so the
            two halves have equal width.
        kernel_size: Kernel size of the depthwise gating conv.
        act_layer: Activation module instance, reused after both 1x1 convs.
    """

    def __init__(self,
                 in_ch,
                 mlp_ratio,
                 kernel_size,
                 act_layer,):
        super().__init__()
        mlp_ch = in_ch * mlp_ratio
        hidden_ch = mlp_ch // 2
        self.fn_1 = nn.Sequential(
            nn.Conv2d(in_ch, mlp_ch, kernel_size=1, padding=0),
            act_layer,
        )
        # fn_2 consumes the gated half, which has mlp_ch // 2 channels. For
        # mlp_ratio=2 that equals in_ch, so parameter shapes match checkpoints
        # trained with that setting; other ratios now work too.
        self.fn_2 = nn.Sequential(
            nn.Conv2d(hidden_ch, in_ch, kernel_size=1, padding=0),
            act_layer,
        )
        self.gate = nn.Conv2d(hidden_ch, hidden_ch,
                              kernel_size=kernel_size, padding=kernel_size // 2, groups=hidden_ch)

    def feat_decompose(self, x):
        # NOTE(review): `self.sigma` is never set in this class, so calling this
        # raises AttributeError unless a `sigma` attribute is attached elsewhere.
        # forward() does not call it.
        s = x - self.gate(x)
        x = x + self.sigma * s
        return x

    def forward(self, x: torch.Tensor):
        x = self.fn_1(x)
        x, gate = torch.chunk(x, 2, dim=1)
        gate = self.gate(gate)
        x = x * gate
        x = self.fn_2(x)
        return x
class LayerNorm(nn.Module):
    r"""Layer normalisation over the channel axis for two tensor layouts.

    With ``data_format="channels_last"`` the input is
    (batch_size, height, width, channels) and ``F.layer_norm`` is used. With
    ``"channels_first"`` the input is (batch_size, channels, height, width) and
    dim 1 is normalised by hand (biased variance, ``eps`` added before the
    square root), followed by a per-channel affine transform.

    Args:
        normalized_shape: Number of channels (an int).
        eps: Stabiliser added to the variance.
        data_format: "channels_last" (default) or "channels_first".

    Raises:
        NotImplementedError: for any other ``data_format``.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        if self.data_format != "channels_first":
            return None
        centered = x - x.mean(1, keepdim=True)
        variance = centered.pow(2).mean(1, keepdim=True)
        normed = centered / torch.sqrt(variance + self.eps)
        return self.weight[:, None, None] * normed + self.bias[:, None, None]