# FunSR/models/rs_super.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
import models
from models import register
from utils import make_coord, to_coordinates


@register('rs_super')
class RSSuper(nn.Module):
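    """Super-resolution model with function-form (implicit) decoding.

    An encoder extracts a feature map from the low-resolution input, a neck
    produces a global content vector plus a spatial representation, and
    coordinate-conditioned decoders (a local one and an optional global one)
    predict RGB values at the queried coordinates. The prediction can be
    refined over ``n_forward_times`` passes by feeding the previous RGB
    estimate back into the decoder.
    """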
    def __init__(self,
                 encoder_spec,
                 neck=None,
                 decoder=None,
                 input_rgb=True,
                 n_forward_times=1,
                 global_decoder=None
                 ):
        super().__init__()
        self.n_forward_times = n_forward_times
        self.encoder = models.make(encoder_spec)
        if neck is not None:
            self.neck = models.make(neck, args={'in_dim': self.encoder.out_dim})
        self.input_rgb = input_rgb
        decoder_in_dim = 5 if self.input_rgb else 2
        if decoder is not None:
            self.decoder = models.make(decoder, args={'modulation_dim': self.neck.out_dim, 'in_dim': decoder_in_dim})
        if global_decoder is not None:
            decoder_in_dim = 5 if self.input_rgb else 2
            self.decoder_is_proj = global_decoder.get('is_proj', False)
            self.grid_global = global_decoder.get('grid_global', False)
            self.global_decoder = models.make(global_decoder, args={'modulation_dim': self.neck.out_dim, 'in_dim': decoder_in_dim})
            if self.decoder_is_proj:
                self.input_proj = nn.Sequential(
                    nn.Linear(self.neck.out_dim, self.neck.out_dim)
                )
                self.output_proj = nn.Sequential(
                    nn.Linear(3, 3)
                )
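
    # NOTE: query_rgb below follows LIIF-style local-ensemble querying. It
    # references attributes (self.feat, self.imnet, self.feat_unfold,
    # self.local_ensemble, self.cell_decode) that are not set in __init__ and
    # is not called from forward(), so it appears to be kept for reference.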
    def query_rgb(self, coord, cell=None):
        feat = self.feat
        if self.imnet is None:
            ret = F.grid_sample(feat, coord.flip(-1).unsqueeze(1),
                                mode='nearest', align_corners=False)[:, :, 0, :] \
                .permute(0, 2, 1)
            return ret

        if self.feat_unfold:
            feat = F.unfold(feat, 3, padding=1).view(
                feat.shape[0], feat.shape[1] * 9, feat.shape[2], feat.shape[3])

        if self.local_ensemble:
            vx_lst = [-1, 1]
            vy_lst = [-1, 1]
            eps_shift = 1e-6
        else:
            vx_lst, vy_lst, eps_shift = [0], [0], 0

        # field radius (global: [-1, 1])
        rx = 2 / feat.shape[-2] / 2
        ry = 2 / feat.shape[-1] / 2

        feat_coord = make_coord(feat.shape[-2:], flatten=False).cuda() \
            .permute(2, 0, 1) \
            .unsqueeze(0).expand(feat.shape[0], 2, *feat.shape[-2:])

        preds = []
        areas = []
        for vx in vx_lst:
            for vy in vy_lst:
                coord_ = coord.clone()
                coord_[:, :, 0] += vx * rx + eps_shift
                coord_[:, :, 1] += vy * ry + eps_shift
                coord_.clamp_(-1 + 1e-6, 1 - 1e-6)
                q_feat = F.grid_sample(
                    feat, coord_.flip(-1).unsqueeze(1),
                    mode='nearest', align_corners=False)[:, :, 0, :] \
                    .permute(0, 2, 1)
                q_coord = F.grid_sample(
                    feat_coord, coord_.flip(-1).unsqueeze(1),
                    mode='nearest', align_corners=False)[:, :, 0, :] \
                    .permute(0, 2, 1)
                rel_coord = coord - q_coord
                rel_coord[:, :, 0] *= feat.shape[-2]
                rel_coord[:, :, 1] *= feat.shape[-1]
                inp = torch.cat([q_feat, rel_coord], dim=-1)

                if self.cell_decode:
                    rel_cell = cell.clone()
                    rel_cell[:, :, 0] *= feat.shape[-2]
                    rel_cell[:, :, 1] *= feat.shape[-1]
                    inp = torch.cat([inp, rel_cell], dim=-1)

                bs, q = coord.shape[:2]
                pred = self.imnet(inp.view(bs * q, -1)).view(bs, q, -1)
                preds.append(pred)

                area = torch.abs(rel_coord[:, :, 0] * rel_coord[:, :, 1])
                areas.append(area + 1e-9)

        tot_area = torch.stack(areas).sum(dim=0)
        if self.local_ensemble:
            # swap diagonally opposite areas so each prediction is weighted by
            # the area of the opposite cell (LIIF local ensemble)
            t = areas[0]; areas[0] = areas[3]; areas[3] = t
            t = areas[1]; areas[1] = areas[2]; areas[2] = t
        ret = 0
        for pred, area in zip(preds, areas):
            ret = ret + pred * (area / tot_area).unsqueeze(-1)
        return ret
    def forward_backbone_neck(self, inp, coord):
        # inp: B x 3 x h x w low-resolution input (e.g. 64x3x32x32)
        # coord: B x N x 2 query coordinates in [-1, 1]
        feat = self.encoder(inp)  # B x C x h x w (e.g. 64x64x32x32)
        global_content, x_rep = self.neck(feat)  # B x 1 x C; B x C x H x W
        return feat, x_rep, global_content
    def forward_step(self, inp, coord, feat, x_rep, global_content, pred_rgb_value=None):
        """Decode RGB values at the query coordinates for one refinement step."""
        weight_gen_func = 'bilinear'
        # grid_sample expects (x, y) order, so flip the (y, x) coordinates
        coord_ = coord.clone().unsqueeze(1).flip(-1)  # B x 1 x N x 2
        modulations = F.grid_sample(x_rep, coord_, padding_mode='border', mode=weight_gen_func,
                                    align_corners=True).squeeze(2)  # B C N
        modulations = rearrange(modulations, 'B C N -> (B N) C')

        feat_coord = to_coordinates(feat.shape[-2:], return_map=True).to(inp.device)
        feat_coord = repeat(feat_coord, 'H W C -> B C H W', B=inp.size(0))  # coordinates are [y, x]
        nearest_coord = F.grid_sample(feat_coord, coord_, mode='nearest', align_corners=True).squeeze(2)  # B 2 N
        nearest_coord = rearrange(nearest_coord, 'B C N -> B N C')  # B N 2
        relative_coord = coord - nearest_coord
        relative_coord[:, :, 0] *= feat.shape[-2]
        relative_coord[:, :, 1] *= feat.shape[-1]
        relative_coord = rearrange(relative_coord, 'B N C -> (B N) C')

        decoder_input = relative_coord
        interpolated_rgb = None
        if self.input_rgb:
            if pred_rgb_value is not None:
                interpolated_rgb = rearrange(pred_rgb_value, 'B N C -> (B N) C')
            else:
                interpolated_rgb = F.grid_sample(inp, coord_, padding_mode='border', mode='bilinear',
                                                 align_corners=True).squeeze(2)  # B 3 N
                interpolated_rgb = rearrange(interpolated_rgb, 'B C N -> (B N) C')
            decoder_input = torch.cat((decoder_input, interpolated_rgb), dim=-1)

        decoder_output = self.decoder(decoder_input, modulations)
        decoder_output = rearrange(decoder_output, '(B N) C -> B N C', B=inp.size(0))

        if hasattr(self, 'global_decoder'):
            # coord: B x N x 2
            # global_content: B x 1 x C
            if self.decoder_is_proj:
                global_content = self.input_proj(global_content)  # B 1 C
            global_modulations = repeat(global_content, 'B N C -> B (N S) C', S=coord.size(1))
            global_modulations = rearrange(global_modulations, 'B N C -> (B N) C')
            if self.grid_global:
                global_decoder_input = decoder_input
            else:
                global_decoder_input = rearrange(coord, 'B N C -> (B N) C')
                if self.input_rgb:
                    global_decoder_input = torch.cat((global_decoder_input, interpolated_rgb), dim=-1)
            global_decoder_output = self.global_decoder(global_decoder_input, global_modulations)
            global_decoder_output = rearrange(global_decoder_output, '(B N) C -> B N C', B=inp.size(0))
            if self.decoder_is_proj:
                decoder_output = self.output_proj(global_decoder_output + decoder_output)
            else:
                decoder_output = global_decoder_output + decoder_output
        return decoder_output
    def forward(self, inp, coord):
        """Iteratively refine the RGB prediction and return every intermediate result."""
        pred_rgb_value = None
        feat, x_rep, global_content = self.forward_backbone_neck(inp, coord)
        return_pred_rgb_value = []
        for _ in range(self.n_forward_times):
            pred_rgb_value = self.forward_step(inp, coord, feat, x_rep, global_content, pred_rgb_value)
            return_pred_rgb_value.append(pred_rgb_value)
        return return_pred_rgb_value
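

# ---------------------------------------------------------------------------
# Minimal sketch (not part of the model): demonstrates the coordinate handling
# used in forward_step -- looking up the nearest feature-cell center with
# grid_sample on a coordinate map and scaling the offset by the feature
# resolution. Shapes and random inputs below are made up for illustration; it
# assumes to_coordinates(..., return_map=True) returns an H x W x 2 map of
# cell-center coordinates in [-1, 1], as it is used in forward_step above.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    B, H, W, N = 2, 8, 8, 5
    coord = torch.rand(B, N, 2) * 2 - 1                     # query coords in [-1, 1], (y, x) order
    coord_ = coord.unsqueeze(1).flip(-1)                    # B x 1 x N x 2, (x, y) for grid_sample
    feat_coord = to_coordinates((H, W), return_map=True)    # H x W x 2 map of cell centers
    feat_coord = repeat(feat_coord, 'H W C -> B C H W', B=B)
    nearest = F.grid_sample(feat_coord, coord_, mode='nearest',
                            align_corners=True).squeeze(2)  # B x 2 x N
    relative = coord - rearrange(nearest, 'B C N -> B N C')  # offset to nearest cell center
    relative[:, :, 0] *= H                                  # express offsets in feature-cell units
    relative[:, :, 1] *= W
    print(relative.shape)                                   # torch.Size([2, 5, 2])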