# FLUX-VisionReply / modules / inr_fea_res_lite.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import einops
import numpy as np
import models
from modules.common_ckpt import Linear, Conv2d, AttnBlock, ResBlock, LayerNorm2d
#from modules.common_ckpt import AttnBlock,
from einops import rearrange
import torch.fft as fft
from modules.speed_util import checkpoint
def batched_linear_mm(x, wb):
# x: (B, N, D1); wb: (B, D1 + 1, D2) or (D1 + 1, D2)
    one = torch.ones(*x.shape[:-1], 1, device=x.device, dtype=x.dtype)
return torch.matmul(torch.cat([x, one], dim=-1), wb)
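# Illustrative shape sketch (not part of the original code): the bias row is folded into
# the weight matrix by appending a column of ones to the input.
#   x  = torch.randn(2, 100, 64)    # (B, N, D1)
#   wb = torch.randn(2, 65, 32)     # (B, D1 + 1, D2): D1 weight rows + 1 bias row
#   y  = batched_linear_mm(x, wb)   # (2, 100, 32)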
def make_coord_grid(shape, range, device=None):
"""
Args:
shape: tuple
range: [minv, maxv] or [[minv_1, maxv_1], ..., [minv_d, maxv_d]] for each dim
Returns:
grid: shape (*shape, )
"""
l_lst = []
for i, s in enumerate(shape):
l = (0.5 + torch.arange(s, device=device)) / s
if isinstance(range[0], list) or isinstance(range[0], tuple):
minv, maxv = range[i]
else:
minv, maxv = range
l = minv + (maxv - minv) * l
l_lst.append(l)
grid = torch.meshgrid(*l_lst, indexing='ij')
grid = torch.stack(grid, dim=-1)
return grid
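# Illustrative sketch (not part of the original code): pixel-centre coordinates in [-1, 1].
#   grid = make_coord_grid((4, 4), (-1, 1))   # shape (4, 4, 2)
#   grid[0, 0]                                # tensor([-0.75, -0.75])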
def init_wb(shape):
weight = torch.empty(shape[1], shape[0] - 1)
nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
bias = torch.empty(shape[1], 1)
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weight)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
nn.init.uniform_(bias, -bound, bound)
return torch.cat([weight, bias], dim=1).t().detach()
def init_wb_rewrite(shape):
weight = torch.empty(shape[1], shape[0] - 1)
torch.nn.init.xavier_uniform_(weight)
bias = torch.empty(shape[1], 1)
torch.nn.init.xavier_uniform_(bias)
return torch.cat([weight, bias], dim=1).t().detach()
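# Both init helpers return a packed (in_dim + 1, out_dim) matrix whose last row is the
# bias, matching the layout consumed by batched_linear_mm. Sketch (values illustrative):
#   wb = init_wb((65, 32))   # weight (32, 64) and bias (32, 1) packed as (65, 32)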
class HypoMlp(nn.Module):
def __init__(self, depth, in_dim, out_dim, hidden_dim, use_pe, pe_dim, out_bias=0, pe_sigma=1024):
super().__init__()
self.use_pe = use_pe
self.pe_dim = pe_dim
self.pe_sigma = pe_sigma
self.depth = depth
self.param_shapes = dict()
if use_pe:
last_dim = in_dim * pe_dim
else:
last_dim = in_dim
        for i in range(depth):  # one packed (weight, bias) shape per layer: (last_dim + 1, cur_dim)
cur_dim = hidden_dim if i < depth - 1 else out_dim
self.param_shapes[f'wb{i}'] = (last_dim + 1, cur_dim)
last_dim = cur_dim
self.relu = nn.ReLU()
self.params = None
self.out_bias = out_bias
def set_params(self, params):
self.params = params
def convert_posenc(self, x):
w = torch.exp(torch.linspace(0, np.log(self.pe_sigma), self.pe_dim // 2, device=x.device))
x = torch.matmul(x.unsqueeze(-1), w.unsqueeze(0)).view(*x.shape[:-1], -1)
x = torch.cat([torch.cos(np.pi * x), torch.sin(np.pi * x)], dim=-1)
return x
def forward(self, x):
B, query_shape = x.shape[0], x.shape[1: -1]
x = x.view(B, -1, x.shape[-1])
if self.use_pe:
x = self.convert_posenc(x)
#print('in line 79 after pos embedding', x.shape)
for i in range(self.depth):
x = batched_linear_mm(x, self.params[f'wb{i}'])
if i < self.depth - 1:
x = self.relu(x)
else:
x = x + self.out_bias
x = x.view(B, *query_shape, -1)
return x
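# Minimal usage sketch (illustrative, not part of the original module): the MLP weights are
# supplied externally via set_params; here they are simply taken from init_wb.
#   mlp = HypoMlp(depth=2, in_dim=2, out_dim=16, hidden_dim=64, use_pe=True, pe_dim=128)
#   params = {name: init_wb(shape).unsqueeze(0) for name, shape in mlp.param_shapes.items()}
#   mlp.set_params(params)
#   coord = make_coord_grid((8, 8), (-1, 1)).unsqueeze(0)   # (1, 8, 8, 2)
#   out = mlp(coord)                                        # (1, 8, 8, 16)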
class Attention(nn.Module):
def __init__(self, dim, n_head, head_dim, dropout=0.):
super().__init__()
self.n_head = n_head
inner_dim = n_head * head_dim
self.to_q = nn.Sequential(
nn.SiLU(),
Linear(dim, inner_dim ))
self.to_kv = nn.Sequential(
nn.SiLU(),
Linear(dim, inner_dim * 2))
self.scale = head_dim ** -0.5
# self.to_out = nn.Sequential(
# Linear(inner_dim, dim),
# nn.Dropout(dropout),
# )
def forward(self, fr, to=None):
if to is None:
to = fr
q = self.to_q(fr)
k, v = self.to_kv(to).chunk(2, dim=-1)
q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> b h n d', h=self.n_head), [q, k, v])
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = F.softmax(dots, dim=-1) # b h n n
out = torch.matmul(attn, v)
out = einops.rearrange(out, 'b h n d -> b n (h d)')
return out
class FeedForward(nn.Module):
def __init__(self, dim, ff_dim, dropout=0.):
super().__init__()
self.net = nn.Sequential(
Linear(dim, ff_dim),
nn.GELU(),
#GlobalResponseNorm(ff_dim),
nn.Dropout(dropout),
Linear(ff_dim, dim)
)
def forward(self, x):
return self.net(x)
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x):
return self.fn(self.norm(x))
#TransInr(ind=2048, ch=256, n_head=16, head_dim=16, n_groups=64, f_dim=256, time_dim=self.c_r, t_conds = [])
class TransformerEncoder(nn.Module):
def __init__(self, dim, depth, n_head, head_dim, ff_dim, dropout=0.):
super().__init__()
self.layers = nn.ModuleList()
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, n_head, head_dim, dropout=dropout)),
PreNorm(dim, FeedForward(dim, ff_dim, dropout=dropout)),
]))
def forward(self, x):
for norm_attn, norm_ff in self.layers:
x = x + norm_attn(x)
x = x + norm_ff(x)
return x
class ImgrecTokenizer(nn.Module):
def __init__(self, input_size=32*32, patch_size=1, dim=768, padding=0, img_channels=16):
super().__init__()
if isinstance(patch_size, int):
patch_size = (patch_size, patch_size)
if isinstance(padding, int):
padding = (padding, padding)
self.patch_size = patch_size
self.padding = padding
self.prefc = nn.Linear(patch_size[0] * patch_size[1] * img_channels, dim)
self.posemb = nn.Parameter(torch.randn(input_size, dim))
def forward(self, x):
#print(x.shape)
p = self.patch_size
x = F.unfold(x, p, stride=p, padding=self.padding) # (B, C * p * p, L)
#print('in line 185 after unfoding', x.shape)
x = x.permute(0, 2, 1).contiguous()
        x = self.prefc(x) + self.posemb[:x.shape[1]].unsqueeze(0)
return x
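# Shape walkthrough (illustrative): with patch_size=1 every pixel becomes one token.
#   tok = ImgrecTokenizer(input_size=32*32, patch_size=1, dim=256, img_channels=16)
#   x = torch.randn(1, 16, 16, 16)   # (B, C, H, W)
#   tokens = tok(x)                  # (1, 256, 256): H*W tokens, each of width `dim`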
class SpatialAttention(nn.Module):
def __init__(self, kernel_size=7):
super(SpatialAttention, self).__init__()
self.conv1 = Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
avg_out = torch.mean(x, dim=1, keepdim=True)
max_out, _ = torch.max(x, dim=1, keepdim=True)
x = torch.cat([avg_out, max_out], dim=1)
x = self.conv1(x)
return self.sigmoid(x)
class TimestepBlock_res(nn.Module):
def __init__(self, c, c_timestep, conds=['sca']):
super().__init__()
self.mapper = Linear(c_timestep, c * 2)
self.conds = conds
for cname in conds:
setattr(self, f"mapper_{cname}", Linear(c_timestep, c * 2))
def forward(self, x, t):
#print(x.shape, t.shape, self.conds, 'in line 269')
t = t.chunk(len(self.conds) + 1, dim=1)
a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1)
for i, c in enumerate(self.conds):
ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1)
a, b = a + ac, b + bc
return x * (1 + a) + b
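# Sketch (illustrative): `t` concatenates one embedding of width c_timestep per condition
# (the base timestep plus each entry of `conds`); each is mapped to a per-channel scale/shift.
#   blk = TimestepBlock_res(c=64, c_timestep=128, conds=['sca'])
#   x = torch.randn(2, 64, 8, 8)
#   t = torch.randn(2, 256)   # 128 dims for the base timestep + 128 for 'sca'
#   y = blk(x, t)             # (2, 64, 8, 8)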
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
class ScaleNormalize_res(nn.Module):
def __init__(self, c, scale_c, conds=['sca']):
super().__init__()
self.c_r = scale_c
self.mapping = TimestepBlock_res(c, scale_c, conds=conds)
self.t_conds = conds
self.alpha = nn.Conv2d(c, c, kernel_size=1)
self.gamma = nn.Conv2d(c, c, kernel_size=1)
self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
def gen_r_embedding(self, r, max_positions=10000):
r = r * max_positions
half_dim = self.c_r // 2
emb = math.log(max_positions) / (half_dim - 1)
emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
emb = r[:, None] * emb[None, :]
emb = torch.cat([emb.sin(), emb.cos()], dim=1)
if self.c_r % 2 == 1: # zero pad
emb = nn.functional.pad(emb, (0, 1), mode='constant')
return emb
def forward(self, x, std_size=24*24):
scale_val = math.sqrt(math.log(x.shape[-2] * x.shape[-1], std_size))
scale_val = torch.ones(x.shape[0]).to(x.device)*scale_val
scale_val_f = self.gen_r_embedding(scale_val)
for c in self.t_conds:
t_cond = torch.zeros_like(scale_val)
scale_val_f = torch.cat([scale_val_f, self.gen_r_embedding(t_cond)], dim=1)
f = self.mapping(x, scale_val_f)
return f + x
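# Sketch (illustrative): gen_r_embedding builds a sinusoidal embedding of width scale_c, and
# forward conditions the features on how the current resolution compares to std_size.
#   norm = ScaleNormalize_res(c=64, scale_c=64, conds=[])
#   x = torch.randn(2, 64, 32, 32)
#   y = norm(x)   # (2, 64, 32, 32), residual output f + x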
class TransInr_withnorm(nn.Module):
def __init__(self, ind=2048, ch=16, n_head=12, head_dim=64, n_groups=64, f_dim=768, time_dim=2048, t_conds=[]):
super().__init__()
        self.input_layer = nn.Conv2d(ind, ch, 1)
self.tokenizer = ImgrecTokenizer(dim=ch, img_channels=ch)
#self.hyponet = HypoMlp(depth=12, in_dim=2, out_dim=ch, hidden_dim=f_dim, use_pe=True, pe_dim=128)
#self.transformer_encoder = TransformerEncoder(dim=f_dim, depth=12, n_head=n_head, head_dim=f_dim // n_head, ff_dim=3*f_dim, )
self.hyponet = HypoMlp(depth=2, in_dim=2, out_dim=ch, hidden_dim=f_dim, use_pe=True, pe_dim=128)
self.transformer_encoder = TransformerEncoder(dim=f_dim, depth=1, n_head=n_head, head_dim=f_dim // n_head, ff_dim=f_dim)
#self.transformer_encoder = TransInr( ch=ch, n_head=16, head_dim=16, n_groups=64, f_dim=ch, time_dim=time_dim, t_conds = [])
self.base_params = nn.ParameterDict()
n_wtokens = 0
self.wtoken_postfc = nn.ModuleDict()
self.wtoken_rng = dict()
for name, shape in self.hyponet.param_shapes.items():
self.base_params[name] = nn.Parameter(init_wb(shape))
g = min(n_groups, shape[1])
assert shape[1] % g == 0
self.wtoken_postfc[name] = nn.Sequential(
nn.LayerNorm(f_dim),
nn.Linear(f_dim, shape[0] - 1),
)
self.wtoken_rng[name] = (n_wtokens, n_wtokens + g)
n_wtokens += g
self.wtokens = nn.Parameter(torch.randn(n_wtokens, f_dim))
        self.output_layer = nn.Conv2d(ch, ind, 1)
        self.mapp_t = TimestepBlock_res(ind, time_dim, conds=t_conds)
self.hr_norm = ScaleNormalize_res(ind, 64, conds=[])
self.normalize_final = nn.Sequential(
LayerNorm2d(ind, elementwise_affine=False, eps=1e-6),
)
        self.toout = nn.Sequential(
            Linear(ind * 2, ind // 4),
            nn.GELU(),
            Linear(ind // 4, ind)
        )
self.apply(self._init_weights)
mask = torch.zeros((1, 1, 32, 32))
h, w = 32, 32
center_h, center_w = h // 2, w // 2
low_freq_h, low_freq_w = h // 4, w // 4
mask[:, :, center_h-low_freq_h:center_h+low_freq_h, center_w-low_freq_w:center_w+low_freq_w] = 1
self.mask = mask
def _init_weights(self, m):
if isinstance(m, (nn.Conv2d, nn.Linear)):
torch.nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
#nn.init.constant_(self.last.weight, 0)
def adain(self, feature_a, feature_b):
norm_mean = torch.mean(feature_a, dim=(2, 3), keepdim=True)
norm_std = torch.std(feature_a, dim=(2, 3), keepdim=True)
#feature_a = F.interpolate(feature_a, feature_b.shape[2:])
feature_b = (feature_b - feature_b.mean(dim=(2, 3), keepdim=True)) / (1e-8 + feature_b.std(dim=(2, 3), keepdim=True)) * norm_std + norm_mean
return feature_b
def forward(self, target_shape, target, dtokens, t_emb):
#print(target.shape, dtokens.shape, 'in line 290')
hlr, wlr = dtokens.shape[2:]
original = dtokens
dtokens = self.input_layer(dtokens)
dtokens = self.tokenizer(dtokens)
B = dtokens.shape[0]
wtokens = einops.repeat(self.wtokens, 'n d -> b n d', b=B)
#print(wtokens.shape, dtokens.shape)
trans_out = self.transformer_encoder(torch.cat([dtokens, wtokens], dim=1))
trans_out = trans_out[:, -len(self.wtokens):, :]
params = dict()
for name, shape in self.hyponet.param_shapes.items():
wb = einops.repeat(self.base_params[name], 'n m -> b n m', b=B)
w, b = wb[:, :-1, :], wb[:, -1:, :]
l, r = self.wtoken_rng[name]
x = self.wtoken_postfc[name](trans_out[:, l: r, :])
x = x.transpose(-1, -2) # (B, shape[0] - 1, g)
w = F.normalize(w * x.repeat(1, 1, w.shape[2] // x.shape[2]), dim=1)
wb = torch.cat([w, b], dim=1)
params[name] = wb
coord = make_coord_grid(target_shape[2:], (-1, 1), device=dtokens.device)
coord = einops.repeat(coord, 'h w d -> b h w d', b=dtokens.shape[0])
self.hyponet.set_params(params)
ori_up = F.interpolate(original.float(), target_shape[2:])
hr_rec = self.output_layer(rearrange(self.hyponet(coord), 'b h w c -> b c h w')) + ori_up
#print(hr_rec.shape, target.shape, torch.cat((hr_rec, target), dim=1).permute(0, 2, 3, 1).shape, 'in line 537')
output = self.toout(torch.cat((hr_rec, target), dim=1).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
#print(output.shape, 'in line 540')
#output = self.last(output.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)* 0.3
output = self.mapp_t(output, t_emb)
output = self.normalize_final(output)
output = self.hr_norm(output)
#output = self.last(output.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
#output = self.mapp_t(output, t_emb)
#output = self.weight(output) * output
return output
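# Usage sketch (illustrative, loosely following the commented configuration above the
# TransformerEncoder definition). Note that `ch` must equal `f_dim`, since forward()
# concatenates the data tokens (width ch) with the weight tokens (width f_dim).
#   inr = TransInr_withnorm(ind=2048, ch=256, n_head=16, head_dim=16,
#                           n_groups=64, f_dim=256, time_dim=64, t_conds=[])
#   lowres = torch.randn(1, 2048, 16, 16)   # dtokens (low-resolution features)
#   target = torch.randn(1, 2048, 24, 24)
#   t_emb = torch.randn(1, 64)              # width must match time_dim
#   out = inr(target.shape, target, lowres, t_emb)   # (1, 2048, 24, 24)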
# Note: this local definition shadows the LayerNorm2d imported from modules.common_ckpt above.
class LayerNorm2d(nn.LayerNorm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, x):
return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
class GlobalResponseNorm(nn.Module):
"from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
def forward(self, x):
Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
return self.gamma * (x * Nx) + self.beta + x
if __name__ == '__main__':
    # Smoke test for TransInr_withnorm (the class defined above). The sizes below are
    # illustrative and assume ch == f_dim, which forward() requires for the token concat.
    trans_inr = TransInr_withnorm(ind=16, ch=64, n_head=4, head_dim=16, n_groups=64,
                                  f_dim=64, time_dim=128, t_conds=[]).cuda()
    target = torch.randn((1, 16, 24, 24)).cuda()
    source = torch.randn((1, 16, 16, 16)).cuda()
    t = torch.randn((1, 128)).cuda()
    output = trans_inr(target.shape, target, source, t)
    total_up = sum(param.nelement() for param in trans_inr.parameters())
    print(output.shape, total_up / 1e6)