Spaces:
Runtime error
Runtime error
# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. | |
from abc import ABC | |
import torch | |
import torch.nn as nn | |
class PoseDecoder(nn.Module, ABC): | |
""" | |
Pose decoder network | |
Parameters | |
---------- | |
cfg : Config | |
Configuration with parameters | |
""" | |
def __init__(self, num_ch_enc, num_input_features, | |
num_frames_to_predict_for=None, | |
stride=1, output_multiplier=0.01): | |
super().__init__() | |
self.num_encoder_channels = num_ch_enc | |
self.num_input_features = num_input_features | |
self.output_multiplier = output_multiplier | |
if num_frames_to_predict_for is None: | |
num_frames_to_predict_for = num_input_features - 1 | |
self.num_output_predictions = num_frames_to_predict_for | |
self.convs = { | |
'squeeze': nn.Conv2d(self.num_encoder_channels[-1], 256, 1), | |
('pose', 0): nn.Conv2d(num_input_features * 256, 256, 3, stride, 1), | |
('pose', 1): nn.Conv2d(256, 256, 3, stride, 1), | |
('pose', 2): nn.Conv2d(256, 6 * num_frames_to_predict_for, 1), | |
} | |
self.net = nn.ModuleList(list(self.convs.values())) | |
self.relu = nn.ReLU() | |
def forward(self, all_features): | |
"""Network forward pass""" | |
last_features = [f[-1] for f in all_features] | |
last_features = [self.relu(self.convs['squeeze'](f)) for f in last_features] | |
cat_features = torch.cat(last_features, 1) | |
for i in range(3): | |
cat_features = self.convs[('pose', i)](cat_features) | |
if i < 2: | |
cat_features = self.relu(cat_features) | |
output = self.output_multiplier * \ | |
cat_features.mean(3).mean(2).view(-1, self.num_output_predictions, 1, 6) | |
return torch.split(output, split_size_or_sections=3, dim=-1) | |