# File-viewer header captured with this file: taesiri's picture | Initial Commit | 8390f90 | raw | history blame | 1.48 kB
r""" Convolutional Hough Matching Networks """
import torch.nn as nn
import torch
from . import chmlearner as chmlearner
from .base import backbone
class CHMNet(nn.Module):
    r"""Matching network: a ``backbone.resnet101`` feature extractor followed
    by a ``chmlearner.CHMLearner``.

    The source and target images are concatenated along dim 1 and sent through
    the backbone together; the ``layer3`` output is then split in half along
    dim 1 to recover per-image features, which are passed to the learner.
    """

    def __init__(self, ktype):
        r"""Create the backbone and the learner.

        Args:
            ktype: Passed unchanged as the first argument of
                ``chmlearner.CHMLearner`` (presumably a kernel type --
                TODO confirm against ``CHMLearner``).
        """
        super(CHMNet, self).__init__()
        self.backbone = backbone.resnet101(pretrained=True)
        self.learner = chmlearner.CHMLearner(ktype, feat_dim=1024)

    def forward(self, src_img, trg_img):
        r"""Return the learner's output for a source/target image pair.

        Args:
            src_img: Source image tensor.
            trg_img: Target image tensor; must be concatenable with
                ``src_img`` along dim 1.

        Returns:
            Whatever ``self.learner`` returns for the two feature tensors.
        """
        src_feat, trg_feat = self.extract_features(src_img, trg_img)
        return self.learner(src_feat, trg_feat)

    def extract_features(self, src_img, trg_img):
        r"""Extract ``layer3`` features for both images in one backbone pass.

        The images are concatenated along dim 1, run through the backbone
        stem (``conv1``, ``bn1``, ``relu``, ``maxpool``) and ``layer1`` to
        ``layer3``, then split back into two halves along dim 1.
        NOTE(review): this assumes the backbone keeps the two halves of dim 1
        independent (e.g. grouped convolutions) -- confirm in ``base/backbone``.

        Returns:
            Tuple ``(src_feat, trg_feat)``: cloned first and second halves of
            the ``layer3`` output along dim 1.
        """
        net = self.backbone
        feat = torch.cat([src_img, trg_img], dim=1)

        # Call the modules themselves rather than ``.forward`` so that any
        # registered forward hooks are honoured.
        for stem_stage in (net.conv1, net.bn1, net.relu, net.maxpool):
            feat = stem_stage(feat)

        # Features are taken after layer3; layer4 is never evaluated.
        for idx in range(1, 4):
            feat = getattr(net, 'layer%d' % idx)(feat)

        half = feat.size(1) // 2
        src_feat = feat.narrow(1, 0, half).clone()
        trg_feat = feat.narrow(1, half, half).clone()
        return src_feat, trg_feat

    def training_objective(cls, prd_kps, trg_kps, npts):
        r"""Mean squared keypoint distance, averaged per sample and then over
        the batch.

        NOTE: there is no ``@classmethod`` decorator, so ``cls`` receives the
        instance; the name is kept so the signature stays unchanged.

        Args:
            prd_kps: Predicted keypoints (assumed ``(batch, coords, points)``
                -- TODO confirm with caller).
            trg_kps: Target keypoints, same shape as ``prd_kps``.
            npts: Per-sample count of leading keypoints to include; entries
                past that count are ignored. A count of 0 yields NaN for that
                sample (mean over an empty slice).

        Returns:
            Scalar tensor: the mean over samples of each sample's mean
            squared distance.
        """
        # Squared distance per keypoint (no square root is taken).
        sq_dist = (prd_kps - trg_kps).pow(2).sum(dim=1)
        per_sample = [dist[:npt].mean() for dist, npt in zip(sq_dist, npts)]
        return torch.stack(per_sample).mean()