# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
from abc import ABC

import torch
import torch.nn as nn


class PoseDecoder(nn.Module, ABC):
    """
    Pose decoder network. Regresses relative poses (axis-angle rotation and
    translation) from the last encoder feature map of each input image.

    Parameters
    ----------
    num_ch_enc : list of int
        Number of channels at each encoder stage (only the last one is used)
    num_input_features : int
        Number of input feature pyramids (one per image)
    num_frames_to_predict_for : int, optional
        Number of poses to predict (default: num_input_features - 1)
    stride : int
        Stride of the intermediate pose convolutions
    output_multiplier : float
        Scale factor applied to the raw network output
    """
    def __init__(self, num_ch_enc, num_input_features,
                 num_frames_to_predict_for=None,
                 stride=1, output_multiplier=0.01):
        super().__init__()
        self.num_encoder_channels = num_ch_enc
        self.num_input_features = num_input_features
        self.output_multiplier = output_multiplier

        # By default, predict one pose per additional input frame
        if num_frames_to_predict_for is None:
            num_frames_to_predict_for = num_input_features - 1
        self.num_output_predictions = num_frames_to_predict_for

        self.convs = {
            # 1x1 convolution squeezing the last encoder features down to 256 channels
            'squeeze': nn.Conv2d(self.num_encoder_channels[-1], 256, 1),
            # 3x3 convolutions over the concatenated squeezed features
            ('pose', 0): nn.Conv2d(num_input_features * 256, 256, 3, stride, 1),
            ('pose', 1): nn.Conv2d(256, 256, 3, stride, 1),
            # Final 1x1 convolution producing 6 values (rotation + translation) per predicted pose
            ('pose', 2): nn.Conv2d(256, 6 * num_frames_to_predict_for, 1),
        }
        # Wrap the convolutions in a ModuleList so their parameters are registered
        self.net = nn.ModuleList(list(self.convs.values()))
        self.relu = nn.ReLU()

    def forward(self, all_features):
        """Network forward pass, returning predicted (axis-angle, translation) pairs"""
        # Keep only the last (lowest-resolution) feature map of each input image
        last_features = [f[-1] for f in all_features]
        last_features = [self.relu(self.convs['squeeze'](f)) for f in last_features]
        cat_features = torch.cat(last_features, 1)
        # Run the stacked pose convolutions (ReLU on all but the last one)
        for i in range(3):
            cat_features = self.convs[('pose', i)](cat_features)
            if i < 2:
                cat_features = self.relu(cat_features)
        # Global average pooling, then reshape to [B, num_predictions, 1, 6] and scale
        output = self.output_multiplier * \
            cat_features.mean(3).mean(2).view(-1, self.num_output_predictions, 1, 6)
        # Split the 6-dim output into axis-angle rotation and translation (3 values each)
        return torch.split(output, split_size_or_sections=3, dim=-1)
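
# Minimal usage sketch (not part of the original module): builds the decoder with
# assumed ResNet18-style encoder channel counts and random feature maps, just to
# show the expected input/output shapes. All shapes here are illustrative.
if __name__ == '__main__':
    num_ch_enc = [64, 64, 128, 256, 512]  # assumed encoder channels per stage
    decoder = PoseDecoder(num_ch_enc, num_input_features=2)

    # One feature pyramid per input image; only the last map of each is used
    features = [
        [torch.randn(1, c, 64 // (2 ** i), 64 // (2 ** i)) for i, c in enumerate(num_ch_enc)]
        for _ in range(2)
    ]
    axisangle, translation = decoder(features)
    print(axisangle.shape, translation.shape)  # torch.Size([1, 1, 1, 3]) for both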