# FunSR/models/metasr.py
# Uploaded by KyanChen (commit 02c5426, message "add"); 2.3 kB.
import torch
import torch.nn as nn
import torch.nn.functional as F
import models
from models import register
from utils import make_coord
@register('metasr')
class MetaSR(nn.Module):
    """Coordinate-queried upsampler with MLP-predicted per-query weights.

    An encoder produces a feature map; for every query coordinate an MLP
    predicts a ``(C * 9, 3)`` weight matrix from the query's offset to its
    nearest feature location and the cell size, and that matrix is applied
    to the 3x3-unfolded feature vector at that location to give 3 outputs.

    Args:
        encoder_spec: Spec passed to ``models.make`` to build the encoder.
            The built encoder must expose an ``out_dim`` attribute.
    """

    def __init__(self, encoder_spec):
        super().__init__()
        self.encoder = models.make(encoder_spec)
        # MLP input is (rel_coord[0], rel_coord[1], scale term) -> 3 values.
        # MLP output is reshaped in ``query_rgb`` to (C * 9, 3): one weight
        # per unfolded 3x3 feature channel, per output channel.
        imnet_spec = {
            'name': 'mlp',
            'args': {
                'in_dim': 3,
                'out_dim': self.encoder.out_dim * 9 * 3,
                'hidden_list': [256],
            },
        }
        self.imnet = models.make(imnet_spec)

    def gen_feat(self, inp):
        """Run the encoder on ``inp``, cache the result on ``self.feat``, and return it."""
        self.feat = self.encoder(inp)
        return self.feat

    def query_rgb(self, coord, cell=None):
        """Predict 3 output channels at each query coordinate.

        ``gen_feat`` must have been called first; this reads ``self.feat``.

        Args:
            coord: Tensor of shape ``(bs, q, 2)``. Index 0 of the last axis
                is paired with the feature height axis and index 1 with the
                width axis. Values are presumably normalised to ``[-1, 1]``
                (they are clamped to that range before sampling).
            cell: Tensor indexed as ``cell[:, :, 0]`` / ``cell[:, :, 1]``,
                giving the per-query cell size along the same two axes.
                Required despite the ``None`` default in the signature.

        Returns:
            Tensor of shape ``(bs, q, 3)``.

        Raises:
            ValueError: If ``cell`` is ``None``.
        """
        if cell is None:
            # The original code subscripted ``cell`` unconditionally, so None
            # failed with an opaque TypeError; fail early with a clear message.
            raise ValueError('MetaSR.query_rgb requires `cell`')

        feat = self.feat
        n_batch, n_ch, height, width = feat.shape

        # Gather every 3x3 neighbourhood into the channel axis:
        # (B, C, H, W) -> (B, C * 9, H, W).
        feat = F.unfold(feat, 3, padding=1).view(
            n_batch, n_ch * 9, height, width)

        # Coordinate of each feature location, shape (H, W, 2). Placed on the
        # same device as the features instead of a hard-coded ``.cuda()`` so
        # the model also works on CPU and on non-default GPUs.
        # NOTE(review): assumes make_coord yields pixel-centre coordinates in
        # [-1, 1] -- confirm against utils.make_coord.
        feat_coord = make_coord(feat.shape[-2:], flatten=False).to(feat.device)
        # Shift by half a feature pixel (a pixel spans 2 / size in [-1, 1]).
        feat_coord[:, :, 0] -= (2 / height) / 2
        feat_coord[:, :, 1] -= (2 / width) / 2
        # (H, W, 2) -> (B, 2, H, W) so it can be sampled like a feature map.
        feat_coord = feat_coord.permute(2, 0, 1) \
            .unsqueeze(0).expand(n_batch, 2, height, width)

        # Shift the queries by half a cell; work on a copy so the caller's
        # tensor is not modified in place.
        coord_ = coord.clone()
        coord_[:, :, 0] -= cell[:, :, 0] / 2
        coord_[:, :, 1] -= cell[:, :, 1] / 2

        # Keep sampling positions strictly inside (-1, 1).
        coord_q = (coord_ + 1e-6).clamp(-1 + 1e-6, 1 - 1e-6)
        # grid_sample expects (x, y) order, hence the flip of the last axis;
        # the grid has shape (B, 1, q, 2).
        grid = coord_q.flip(-1).unsqueeze(1)

        # Nearest unfolded feature vector per query: (B, q, C * 9).
        q_feat = F.grid_sample(
            feat, grid, mode='nearest', align_corners=False)[:, :, 0, :] \
            .permute(0, 2, 1)
        # Coordinate of that nearest feature location: (B, q, 2).
        q_coord = F.grid_sample(
            feat_coord, grid, mode='nearest', align_corners=False)[:, :, 0, :] \
            .permute(0, 2, 1)

        # Offset from the selected feature location, rescaled from normalised
        # units to feature-pixel units.
        rel_coord = coord_ - q_coord
        rel_coord[:, :, 0] *= height / 2
        rel_coord[:, :, 1] *= width / 2

        # Cell size along axis 0 expressed in feature pixels.
        r_rev = cell[:, :, 0] * (height / 2)
        inp = torch.cat([rel_coord, r_rev.unsqueeze(-1)], dim=-1)

        bs, q = coord.shape[:2]
        # Per-query weight matrix (C * 9, 3) predicted by the MLP.
        weights = self.imnet(inp.view(bs * q, -1)).view(
            bs * q, feat.shape[1], 3)
        # (bs*q, 1, C * 9) @ (bs*q, C * 9, 3) -> (bs*q, 1, 3).
        pred = torch.bmm(q_feat.contiguous().view(bs * q, 1, -1), weights)
        return pred.view(bs, q, 3)

    def forward(self, inp, coord, cell):
        """Encode ``inp`` and return the prediction at ``coord`` / ``cell``."""
        self.gen_feat(inp)
        return self.query_rgb(coord, cell)