|
''' Moderator: fuses body features with part (head/hand) features.

Inputs: body feature, part feature (head or hand).

Outputs: fused feature, fusion weights.

'''
|
import numpy as np |
|
import torch.nn as nn |
|
import torch |
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
|
|
|
class TempSoftmaxFusion(nn.Module):
    """Fuse a body feature with a part (head/hand) feature using soft weights.

    An MLP scores the concatenation of the two features. Its outputs are
    multiplied by the learnable ``temperature`` parameter and passed through
    a softmax over the last dimension to produce mixing weights ``w``. The
    fused feature is ``w[..., 0:1] * x + w[..., 1:2] * y``.

    Args:
        channels: Layer widths of the scoring MLP, input width first.
            ``channels[0]`` must equal the width of ``cat([x, y], -1)``.
            ``channels[-1]`` must be at least 2, because two weights (one
            per input) are read from the output.
        detach_inputs: If True, the MLP scores a detached copy of the
            concatenated inputs. No gradient then reaches ``x``/``y``
            through the weight branch.
        detach_feature: If True, ``x`` and ``y`` are detached before being
            mixed. Gradients then reach them only through the weight branch
            (and not at all if ``detach_inputs`` is also True).
    """

    def __init__(self,
                 channels=(2048 * 2, 1024, 2),
                 detach_inputs=False,
                 detach_feature=False):
        super().__init__()
        self.detach_inputs = detach_inputs
        self.detach_feature = detach_feature

        # Linear layers with a ReLU between consecutive ones; no activation
        # after the last Linear because its outputs are softmax logits.
        # The layout (Linear, ReLU, Linear, ...) matches the original, so
        # state_dict keys are unchanged.
        pairs = list(zip(channels[:-1], channels[1:]))
        layers = []
        for i, (c_in, c_out) in enumerate(pairs):
            layers.append(nn.Linear(c_in, c_out))
            if i < len(pairs) - 1:
                layers.append(nn.ReLU())
        self.layers = nn.Sequential(*layers)

        # Learnable scale on the logits (it multiplies, i.e. acts as an
        # inverse temperature). Initialised to 1.
        self.register_parameter('temperature', nn.Parameter(torch.ones(1)))

    def forward(self, x, y, work=True):
        '''Fuse ``x`` and ``y``, or pass them through unchanged.

        Args:
            x: Feature from the body branch; features on the last dimension.
            y: Feature from the part (head/hand) branch. Presumably the same
                shape as ``x``; it must at least broadcast with it.
            work: Whether to fuse the features.

        Returns:
            Tuple ``(x_out, y_out, f_weight)``.

            - When ``work`` is True, ``x_out`` and ``y_out`` are the same
              fused tensor, and ``f_weight`` holds the softmax weights over
              the last dimension.
            - Otherwise ``x`` and ``y`` are returned as-is and ``f_weight``
              is None.
        '''
        if not work:
            return x, y, None

        f_in = torch.cat([x, y], dim=-1)
        if self.detach_inputs:
            f_in = f_in.detach()
        logits = self.layers(f_in)
        f_weight = F.softmax(logits * self.temperature, dim=-1)

        if self.detach_feature:
            x = x.detach()
            y = y.detach()
        # Range slices keep a trailing singleton dim so each weight
        # broadcasts across the feature dimension.
        f_out = f_weight[..., 0:1] * x + f_weight[..., 1:2] * y
        return f_out, f_out, f_weight
|
|
|
|
|
|
|
|
|
|
|
|
|
class GumbelSoftmaxFusion(nn.Module):
    """Fuse a body feature with a part (head/hand) feature using a hard gate.

    An MLP ending in a softmax scores the concatenation of the two features.
    In the forward pass the probabilities are thresholded at 0.5 into hard
    0/1 weights. In the backward pass, gradients flow through the soft
    probabilities (straight-through estimator).

    Note: despite the class name, no Gumbel noise is sampled here; the hard
    decision is a deterministic threshold. If both probabilities are exactly
    0.5, neither passes the threshold and the fused output is zero.

    Args:
        channels: Layer widths of the scoring MLP, input width first.
            ``channels[0]`` must equal the width of ``cat([x, y], -1)``.
            ``channels[-1]`` must be at least 2, because two weights (one
            per input) are read from the output.
        detach_inputs: If True, the MLP scores a detached copy of the
            concatenated inputs.
        detach_feature: If True, ``x`` and ``y`` are detached before being
            mixed.
    """

    def __init__(self,
                 channels=(2048 * 2, 1024, 2),
                 detach_inputs=False,
                 detach_feature=False):
        super().__init__()
        self.detach_inputs = detach_inputs
        self.detach_feature = detach_feature

        # Linear layers with a ReLU between consecutive ones, then a softmax
        # over the last dimension. The layout matches the original, so
        # state_dict keys are unchanged.
        pairs = list(zip(channels[:-1], channels[1:]))
        layers = []
        for i, (c_in, c_out) in enumerate(pairs):
            layers.append(nn.Linear(c_in, c_out))
            if i < len(pairs) - 1:
                layers.append(nn.ReLU())
        # An explicit dim avoids the deprecated implicit-dim choice, which
        # would pick dim 0 for 3-D inputs. This is identical to the old
        # behaviour for 2-D inputs.
        layers.append(nn.Softmax(dim=-1))
        self.layers = nn.Sequential(*layers)

    def forward(self, x, y, work=True):
        '''Fuse ``x`` and ``y`` with hard weights, or pass them through.

        Args:
            x: Feature from the body branch; features on the last dimension.
            y: Feature from the part (head/hand) branch. Presumably the same
                shape as ``x``; it must at least broadcast with it.
            work: Whether to fuse the features.

        Returns:
            Tuple ``(x_out, y_out, f_weight)``.

            - When ``work`` is True, ``x_out`` and ``y_out`` are the same
              fused tensor. ``f_weight`` holds the straight-through weights:
              0/1 in value, differentiable through the softmax.
            - Otherwise ``x`` and ``y`` are returned as-is and ``f_weight``
              is None.
        '''
        if not work:
            return x, y, None

        f_in = torch.cat([x, y], dim=-1)
        if self.detach_inputs:
            f_in = f_in.detach()
        soft = self.layers(f_in)

        # Straight-through estimator: the forward value equals `hard`, while
        # the gradient is that of `soft`.
        hard = soft.gt(0.5).to(soft.dtype)
        f_weight = soft - soft.detach() + hard

        if self.detach_feature:
            x = x.detach()
            y = y.detach()
        # Range slices keep a trailing singleton dim for broadcasting.
        f_out = f_weight[..., 0:1] * x + f_weight[..., 1:2] * y
        return f_out, f_out, f_weight
|
|