from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange
######################
# Meta Architecture
######################
class SeemoRe(nn.Module):
def __init__(self,
scale: int = 4,
in_chans: int = 3,
num_experts: int = 6,
num_layers: int = 6,
embedding_dim: int = 64,
img_range: float = 1.0,
use_shuffle: bool = False,
global_kernel_size: int = 11,
recursive: int = 2,
                 lr_space: str = "linear",
topk: int = 2,):
super().__init__()
self.scale = scale
self.num_in_channels = in_chans
self.num_out_channels = in_chans
self.img_range = img_range
        # RGB mean used for input normalization (DIV2K statistics, as in EDSR)
        rgb_mean = (0.4488, 0.4371, 0.4040)
        self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
# -- SHALLOW FEATURES --
self.conv_1 = nn.Conv2d(self.num_in_channels, embedding_dim, kernel_size=3, padding=1)
# -- DEEP FEATURES --
self.body = nn.ModuleList(
[ResGroup(in_ch=embedding_dim,
num_experts=num_experts,
use_shuffle=use_shuffle,
topk=topk,
lr_space=lr_space,
recursive=recursive,
                      global_kernel_size=global_kernel_size) for _ in range(num_layers)]
)
# -- UPSCALE --
self.norm = LayerNorm(embedding_dim, data_format='channels_first')
self.conv_2 = nn.Conv2d(embedding_dim, embedding_dim, kernel_size=3, padding=1)
self.upsampler = nn.Sequential(
nn.Conv2d(embedding_dim, (scale**2) * self.num_out_channels, kernel_size=3, padding=1),
nn.PixelShuffle(scale)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
self.mean = self.mean.type_as(x)
x = (x - self.mean) * self.img_range
# -- SHALLOW FEATURES --
x = self.conv_1(x)
res = x
# -- DEEP FEATURES --
        for layer in self.body:
            x = layer(x)
x = self.norm(x)
# -- HR IMAGE RECONSTRUCTION --
x = self.conv_2(x) + res
x = self.upsampler(x)
x = x / self.img_range + self.mean
return x
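
# Usage sketch (added; illustrative, not part of the original app): a scale=4
# model maps (B, 3, H, W) -> (B, 3, 4H, 4W). The arguments below are just the
# constructor defaults, not a published configuration.
#   model = SeemoRe(scale=4, embedding_dim=64, num_experts=6, topk=2).eval()
#   with torch.no_grad():
#       sr = model(torch.rand(1, 3, 64, 64))  # -> torch.Size([1, 3, 256, 256])
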
#############################
# Components
#############################
class ResGroup(nn.Module):
def __init__(self,
in_ch: int,
num_experts: int,
global_kernel_size: int = 11,
                 lr_space: str = "linear",
topk: int = 2,
recursive: int = 2,
use_shuffle: bool = False):
super().__init__()
self.local_block = RME(in_ch=in_ch,
num_experts=num_experts,
use_shuffle=use_shuffle,
lr_space=lr_space,
topk=topk,
recursive=recursive)
self.global_block = SME(in_ch=in_ch,
kernel_size=global_kernel_size)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.local_block(x)
x = self.global_block(x)
return x
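
# Note (added): each ResGroup chains the local mixture-of-experts block (RME)
# and the global striped-convolution block (SME); both preserve the feature
# shape, so groups can be stacked freely.
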
#############################
# Global Block
#############################
class SME(nn.Module):
def __init__(self,
in_ch: int,
kernel_size: int = 11):
super().__init__()
self.norm_1 = LayerNorm(in_ch, data_format='channels_first')
self.block = StripedConvFormer(in_ch=in_ch, kernel_size=kernel_size)
self.norm_2 = LayerNorm(in_ch, data_format='channels_first')
self.ffn = GatedFFN(in_ch, mlp_ratio=2, kernel_size=3, act_layer=nn.GELU())
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.block(self.norm_1(x)) + x
x = self.ffn(self.norm_2(x)) + x
return x
class StripedConvFormer(nn.Module):
def __init__(self,
in_ch: int,
kernel_size: int):
super().__init__()
self.in_ch = in_ch
self.kernel_size = kernel_size
self.padding = kernel_size // 2
self.proj = nn.Conv2d(in_ch, in_ch, kernel_size=1, padding=0)
self.to_qv = nn.Sequential(
nn.Conv2d(in_ch, in_ch * 2, kernel_size=1, padding=0),
nn.GELU(),
)
self.attn = StripedConv2d(in_ch, kernel_size=kernel_size, depthwise=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
q, v = self.to_qv(x).chunk(2, dim=1)
q = self.attn(q)
x = self.proj(q * v)
return x
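
# Shape sketch (added; illustrative): StripedConvFormer is shape-preserving.
# The depthwise striped conv mixes an 11x11 neighborhood over the "query"
# half, which then gates the "value" half multiplicatively.
#   blk = StripedConvFormer(in_ch=64, kernel_size=11)
#   blk(torch.rand(1, 64, 32, 32)).shape  # -> torch.Size([1, 64, 32, 32])
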
#############################
# Local Blocks
#############################
class RME(nn.Module):
def __init__(self,
in_ch: int,
num_experts: int,
topk: int,
                 lr_space: str = "linear",
recursive: int = 2,
use_shuffle: bool = False,):
super().__init__()
self.norm_1 = LayerNorm(in_ch, data_format='channels_first')
self.block = MoEBlock(in_ch=in_ch, num_experts=num_experts, topk=topk, use_shuffle=use_shuffle, recursive=recursive, lr_space=lr_space,)
self.norm_2 = LayerNorm(in_ch, data_format='channels_first')
self.ffn = GatedFFN(in_ch, mlp_ratio=2, kernel_size=3, act_layer=nn.GELU())
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.block(self.norm_1(x)) + x
x = self.ffn(self.norm_2(x)) + x
return x
#################
# MoE Layer
#################
class MoEBlock(nn.Module):
def __init__(self,
in_ch: int,
num_experts: int,
topk: int,
use_shuffle: bool = False,
lr_space: str = "linear",
recursive: int = 2):
super().__init__()
self.use_shuffle = use_shuffle
self.recursive = recursive
self.conv_1 = nn.Sequential(
nn.Conv2d(in_ch, in_ch, kernel_size=3, padding=1),
nn.GELU(),
nn.Conv2d(in_ch, 2*in_ch, kernel_size=1, padding=0)
)
self.agg_conv = nn.Sequential(
nn.Conv2d(in_ch, in_ch, kernel_size=4, stride=4, groups=in_ch),
nn.GELU())
self.conv = nn.Sequential(
nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=1, padding=1, groups=in_ch),
nn.Conv2d(in_ch, in_ch, kernel_size=1, padding=0)
)
self.conv_2 = nn.Sequential(
StripedConv2d(in_ch, kernel_size=3, depthwise=True),
nn.GELU())
if lr_space == "linear":
grow_func = lambda i: i+2
elif lr_space == "exp":
grow_func = lambda i: 2**(i+1)
elif lr_space == "double":
grow_func = lambda i: 2*i+2
else:
raise NotImplementedError(f"lr_space {lr_space} not implemented")
self.moe_layer = MoELayer(
            experts=[Expert(in_ch=in_ch, low_dim=grow_func(i)) for i in range(num_experts)],  # low_dim grows with expert index via grow_func
gate=Router(in_ch=in_ch, num_experts=num_experts),
num_expert=topk,
)
self.proj = nn.Conv2d(in_ch, in_ch, kernel_size=1, padding=0)
    def calibrate(self, x: torch.Tensor) -> torch.Tensor:
        """Aggregate context by repeated 4x depthwise downsampling, then fuse back."""
        _, _, h, w = x.shape
        res = x
        for _ in range(self.recursive):
            x = self.agg_conv(x)  # each pass reduces the spatial size by 4x
        x = self.conv(x)
        # restore the input resolution and add the residual
        x = F.interpolate(x, size=(h, w), mode="bilinear", align_corners=False)
        return res + x
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv_1(x)
if self.use_shuffle:
x = channel_shuffle(x, groups=2)
        # split into the expert branch x and the calibration branch k
        x, k = torch.chunk(x, chunks=2, dim=1)
x = self.conv_2(x)
k = self.calibrate(k)
x = self.moe_layer(x, k)
x = self.proj(x)
return x
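
# Shape sketch (added; illustrative): MoEBlock is shape-preserving; conv_1
# doubles the channels so they can be split into the expert branch and the
# calibration branch.
#   blk = MoEBlock(in_ch=64, num_experts=6, topk=2, lr_space="linear")
#   blk(torch.rand(1, 64, 32, 32)).shape  # -> torch.Size([1, 64, 32, 32])
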
class MoELayer(nn.Module):
def __init__(self, experts: List[nn.Module], gate: nn.Module, num_expert: int = 1):
super().__init__()
assert len(experts) > 0
self.experts = nn.ModuleList(experts)
self.gate = gate
self.num_expert = num_expert
    def forward(self, inputs: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        logits = self.gate(inputs)
        weights = F.softmax(logits, dim=1, dtype=torch.float).to(inputs.dtype)
        topk_weights, topk_experts = torch.topk(weights, self.num_expert)
        out = inputs.clone()
        if self.training:
            # dense pass: run every expert, zeroing the gate weights outside the top-k
            exp_weights = torch.zeros_like(weights)
            exp_weights.scatter_(1, topk_experts, weights.gather(1, topk_experts))
            for i, expert in enumerate(self.experts):
                out += expert(inputs, k) * exp_weights[:, i:i+1, None, None]
        else:
            # sparse pass: run only the selected experts (assumes batch size 1)
            selected_experts = [self.experts[i] for i in topk_experts.squeeze(dim=0)]
            for i, expert in enumerate(selected_experts):
                out += expert(inputs, k) * topk_weights[:, i:i+1, None, None]
        return out
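
# Routing sketch (added; illustrative): with num_experts=6 and topk=2, the
# gate scores all six experts but only the two highest-weighted ones
# contribute at inference time (the dense training pass zeroes the rest).
#   layer = MoELayer([Expert(64, d) for d in (2, 4, 6, 8, 10, 12)],
#                    Router(64, 6), num_expert=2).eval()
#   y = layer(torch.rand(1, 64, 16, 16), torch.rand(1, 64, 16, 16))
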
class Expert(nn.Module):
def __init__(self,
in_ch: int,
low_dim: int,):
super().__init__()
self.conv_1 = nn.Conv2d(in_ch, low_dim, kernel_size=1, padding=0)
self.conv_2 = nn.Conv2d(in_ch, low_dim, kernel_size=1, padding=0)
self.conv_3 = nn.Conv2d(low_dim, in_ch, kernel_size=1, padding=0)
    def forward(self, x: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        x = self.conv_1(x)      # project the input into the low-rank space
        x = self.conv_2(k) * x  # modulate with the calibration features (no sigmoid gate)
        x = self.conv_3(x)      # project back to in_ch
        return x
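
# Note (added): each Expert is a low-rank bottleneck: with in_ch=64 and
# low_dim=4 it uses three pointwise convs (64->4, 64->4, 4->64) instead of a
# dense 64->64 mixing layer.
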
class Router(nn.Module):
def __init__(self,
in_ch: int,
num_experts: int):
super().__init__()
self.body = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
Rearrange('b c 1 1 -> b c'),
nn.Linear(in_ch, num_experts, bias=False),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.body(x)
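
# Shape sketch (added; illustrative): the Router pools globally and emits one
# logit per expert.
#   Router(in_ch=64, num_experts=6)(torch.rand(1, 64, 32, 32)).shape
#   # -> torch.Size([1, 6])
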
#################
# Utilities
#################
class StripedConv2d(nn.Module):
def __init__(self,
in_ch: int,
kernel_size: int,
depthwise: bool = False):
super().__init__()
self.in_ch = in_ch
self.kernel_size = kernel_size
self.padding = kernel_size // 2
self.conv = nn.Sequential(
nn.Conv2d(in_ch, in_ch, kernel_size=(1, self.kernel_size), padding=(0, self.padding), groups=in_ch if depthwise else 1),
nn.Conv2d(in_ch, in_ch, kernel_size=(self.kernel_size, 1), padding=(self.padding, 0), groups=in_ch if depthwise else 1),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.conv(x)
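
# Note (added): the 1xk then kx1 pair factorizes a full kxk convolution,
# covering the same receptive field with 2k instead of k*k weights per channel.
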
def channel_shuffle(x: torch.Tensor, groups: int = 2) -> torch.Tensor:
    """Interleave channels across groups (as in ShuffleNet)."""
    b, c, h, w = x.shape
    group_c = c // groups
    x = x.view(b, groups, group_c, h, w)
    x = torch.transpose(x, 1, 2).contiguous()
    return x.view(b, -1, h, w)
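
# Example (added; illustrative): channels [0, 1, 2, 3] with groups=2 come out
# interleaved as [0, 2, 1, 3].
#   x = torch.arange(4.0).view(1, 4, 1, 1)
#   channel_shuffle(x, groups=2).flatten()  # -> tensor([0., 2., 1., 3.])
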
class GatedFFN(nn.Module):
    def __init__(self,
                 in_ch: int,
                 mlp_ratio: int,
                 kernel_size: int,
                 act_layer: nn.Module):
        super().__init__()
        mlp_ch = in_ch * mlp_ratio
        # NOTE: fn_2 maps in_ch -> in_ch, so the gated half (mlp_ch // 2) must
        # equal in_ch; this block therefore assumes mlp_ratio == 2.
        self.fn_1 = nn.Sequential(
            nn.Conv2d(in_ch, mlp_ch, kernel_size=1, padding=0),
            act_layer,
        )
        self.fn_2 = nn.Sequential(
            nn.Conv2d(in_ch, in_ch, kernel_size=1, padding=0),
            act_layer,
        )
        # depthwise conv that computes a spatial gate over half the channels
        self.gate = nn.Conv2d(mlp_ch // 2, mlp_ch // 2,
                              kernel_size=kernel_size, padding=kernel_size // 2,
                              groups=mlp_ch // 2)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fn_1(x)
        x, gate = torch.chunk(x, 2, dim=1)  # split into content and gate halves
        gate = self.gate(gate)              # depthwise conv gives the gate spatial context
        x = x * gate
        x = self.fn_2(x)
        return x
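
# Shape sketch (added; illustrative): GatedFFN is shape-preserving when
# mlp_ratio == 2 (see the note in __init__).
#   ffn = GatedFFN(64, mlp_ratio=2, kernel_size=3, act_layer=nn.GELU())
#   ffn(torch.rand(1, 64, 32, 32)).shape  # -> torch.Size([1, 64, 32, 32])
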
class LayerNorm(nn.Module):
r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
with shape (batch_size, channels, height, width).
"""
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
super().__init__()
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.eps = eps
self.data_format = data_format
if self.data_format not in ["channels_last", "channels_first"]:
raise NotImplementedError
self.normalized_shape = (normalized_shape, )
def forward(self, x):
if self.data_format == "channels_last":
return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
elif self.data_format == "channels_first":
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
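

# Minimal smoke test (added sketch, not part of the original app). Uses the
# constructor defaults above; batch size 1 because MoELayer's inference path
# squeezes the batch dimension when selecting experts.
if __name__ == "__main__":
    model = SeemoRe().eval()
    with torch.no_grad():
        sr = model(torch.rand(1, 3, 48, 48))
    assert sr.shape == (1, 3, 192, 192), sr.shape
    print("OK:", tuple(sr.shape))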