|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
import models |
|
from models import register |
|
from utils import make_coord |
|
|
|
|
|
@register('metasr')
class MetaSR(nn.Module):
    """Arbitrary-scale RGB predictor with per-query generated weights.

    An encoder produces a feature map; for every query coordinate the 3x3
    neighbourhood of the nearest feature cell is gathered and multiplied by a
    weight matrix that an MLP generates from the query's relative offset and
    its cell height.

    Registered under the name ``'metasr'``.
    """

    def __init__(self, encoder_spec):
        """Build the encoder and the weight-generating MLP.

        Args:
            encoder_spec: Specification passed to ``models.make`` to build
                the encoder. The resulting module must expose ``out_dim``.
        """
        super().__init__()

        self.encoder = models.make(encoder_spec)
        # The MLP maps (rel_y, rel_x, scaled cell height) -> a flattened
        # (out_dim * 9, 3) weight matrix: 9 for the 3x3 unfolded
        # neighbourhood, 3 for the RGB output channels.
        imnet_spec = {
            'name': 'mlp',
            'args': {
                'in_dim': 3,
                'out_dim': self.encoder.out_dim * 9 * 3,
                'hidden_list': [256]
            }
        }
        self.imnet = models.make(imnet_spec)

    def gen_feat(self, inp):
        """Encode ``inp`` and cache the result on ``self.feat``.

        Args:
            inp: Input passed straight to the encoder.

        Returns:
            The encoder output (also stored as ``self.feat``).
        """
        self.feat = self.encoder(inp)
        return self.feat

    def query_rgb(self, coord, cell=None):
        """Predict RGB values at the given query coordinates.

        Requires ``gen_feat`` to have been called so ``self.feat`` is set.

        Args:
            coord: Tensor indexed as ``(batch, query, 2)``; index 0 of the
                last axis is scaled by the feature height and index 1 by the
                feature width. Values are clamped to (-1, 1) before sampling.
            cell: Tensor indexed as ``(batch, query, 2)`` giving the size of
                each query cell in the same coordinate units. Required; the
                ``None`` default exists only in the signature.

        Returns:
            Tensor of shape ``(batch, query, 3)``.

        Raises:
            TypeError: If ``cell`` is ``None``.
        """
        if cell is None:
            # Same exception type the bare indexing below would raise, but
            # with an actionable message.
            raise TypeError(
                'MetaSR.query_rgb requires `cell`; got None')

        feat = self.feat
        # Stack each 3x3 neighbourhood along the channel axis:
        # (B, C, H, W) -> (B, C * 9, H, W). The shape arguments are evaluated
        # on the pre-unfold tensor.
        feat = F.unfold(feat, 3, padding=1).view(
            feat.shape[0], feat.shape[1] * 9, feat.shape[2], feat.shape[3])

        # Build the coordinate grid on whatever device the features live on,
        # so the model works on CPU and on non-default GPUs.
        feat_coord = make_coord(feat.shape[-2:], flatten=False).to(feat.device)
        # Shift by half a feature cell (cell extent is 2 / size in [-1, 1]).
        # NOTE(review): presumably make_coord returns cell centres, making
        # this the top-left corner of each cell -- confirm against utils.
        feat_coord[:, :, 0] -= (2 / feat.shape[-2]) / 2
        feat_coord[:, :, 1] -= (2 / feat.shape[-1]) / 2
        # (H, W, 2) -> (B, 2, H, W) so it can be sampled like a feature map.
        feat_coord = feat_coord.permute(2, 0, 1) \
            .unsqueeze(0).expand(feat.shape[0], 2, *feat.shape[-2:])

        # Apply the same half-cell shift to the queries.
        coord_ = coord.clone()
        coord_[:, :, 0] -= cell[:, :, 0] / 2
        coord_[:, :, 1] -= cell[:, :, 1] / 2
        # Nudge and clamp so nearest-neighbour sampling stays inside the grid.
        coord_q = (coord_ + 1e-6).clamp(-1 + 1e-6, 1 - 1e-6)
        # grid_sample expects (x, y) ordering, hence the flip of the last axis.
        q_feat = F.grid_sample(
            feat, coord_q.flip(-1).unsqueeze(1),
            mode='nearest', align_corners=False)[:, :, 0, :] \
            .permute(0, 2, 1)
        q_coord = F.grid_sample(
            feat_coord, coord_q.flip(-1).unsqueeze(1),
            mode='nearest', align_corners=False)[:, :, 0, :] \
            .permute(0, 2, 1)

        # Offset from the sampled feature cell, rescaled from [-1, 1] units
        # to units of feature cells.
        rel_coord = coord_ - q_coord
        rel_coord[:, :, 0] *= feat.shape[-2] / 2
        rel_coord[:, :, 1] *= feat.shape[-1] / 2

        # Query cell height expressed in feature cells.
        r_rev = cell[:, :, 0] * (feat.shape[-2] / 2)
        inp = torch.cat([rel_coord, r_rev.unsqueeze(-1)], dim=-1)

        bs, q = coord.shape[:2]
        # Per-query weight matrix of shape (C * 9, 3) ...
        pred = self.imnet(inp.view(bs * q, -1)).view(bs * q, feat.shape[1], 3)
        # ... applied to the per-query feature row: (1, C * 9) @ (C * 9, 3).
        pred = torch.bmm(q_feat.contiguous().view(bs * q, 1, -1), pred)
        pred = pred.view(bs, q, 3)
        return pred

    def forward(self, inp, coord, cell):
        """Encode ``inp`` and return the RGB prediction for the queries.

        Args:
            inp: Input passed to the encoder.
            coord: Query coordinates, see ``query_rgb``.
            cell: Query cell sizes, see ``query_rgb``.

        Returns:
            Tensor of shape ``(batch, query, 3)``.
        """
        self.gen_feat(inp)
        return self.query_rgb(coord, cell)
|
|