# kirch's picture
# Duplicate from PAIR/Text2Video-Zero
# 508927a
import torch
from torch import nn
from torch.nn import functional as F
class Encoding(nn.Module):
    """Encoding Layer: a learnable residual encoder.

    Every spatial position of the input is softly assigned to each of the
    ``num_codes`` learnable codewords (softmax over scaled squared L2
    distances). The residuals ``feature - codeword`` are then summed over
    all positions, weighted by those assignments.

    Input is of shape (batch_size, channels, height, width).
    Output is of shape (batch_size, num_codes, channels).

    Args:
        channels: dimension of the features or feature channels
        num_codes: number of code words
    """

    def __init__(self, channels: int, num_codes: int) -> None:
        super().__init__()
        self.channels, self.num_codes = channels, num_codes

        # Codewords are drawn uniformly from [-std, std]; they are created
        # before the scale so the random-number consumption order is stable.
        std = 1. / ((num_codes * channels)**0.5)
        # [num_codes, channels]
        self.codewords = nn.Parameter(
            torch.empty(num_codes, channels,
                        dtype=torch.float).uniform_(-std, std))
        # [num_codes] smoothing factors, drawn uniformly from [-1, 0]
        self.scale = nn.Parameter(
            torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0))

    @staticmethod
    def _residuals(x: torch.Tensor, codewords: torch.Tensor) -> torch.Tensor:
        """Return ``x - codeword`` for every (position, codeword) pair.

        Args:
            x: features of shape [batch_size, num_positions, channels].
            codewords: tensor of shape [num_codes, channels].

        Returns:
            Tensor of shape [batch_size, num_positions, num_codes, channels].
        """
        num_codes, channels = codewords.size()
        # Broadcasting pairs each position with each codeword without an
        # explicit expand() of the features.
        return x.unsqueeze(2) - codewords.view(1, 1, num_codes, channels)

    @staticmethod
    def scaled_l2(x: torch.Tensor, codewords: torch.Tensor,
                  scale: torch.Tensor) -> torch.Tensor:
        """Compute the per-codeword scaled squared L2 distance.

        Args:
            x: features of shape [batch_size, num_positions, channels].
            codewords: tensor of shape [num_codes, channels].
            scale: per-codeword factors of shape [num_codes].

        Returns:
            Tensor of shape [batch_size, num_positions, num_codes] holding
            ``scale[k] * ||x - codewords[k]||^2``.
        """
        squared_dist = Encoding._residuals(x, codewords).pow(2).sum(dim=3)
        return scale.view(1, 1, codewords.size(0)) * squared_dist

    @staticmethod
    def aggregate(assignment_weights: torch.Tensor, x: torch.Tensor,
                  codewords: torch.Tensor) -> torch.Tensor:
        """Sum the weighted residuals over all positions.

        Args:
            assignment_weights: weights of shape
                [batch_size, num_positions, num_codes].
            x: features of shape [batch_size, num_positions, channels].
            codewords: tensor of shape [num_codes, channels].

        Returns:
            Tensor of shape [batch_size, num_codes, channels].
        """
        weighted = (assignment_weights.unsqueeze(3) *
                    Encoding._residuals(x, codewords))
        return weighted.sum(dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a feature map.

        Args:
            x: tensor of shape [batch_size, channels, height, width]; it does
                not have to be contiguous in memory.

        Returns:
            Tensor of shape [batch_size, num_codes, channels].

        Raises:
            AssertionError: if ``x`` is not 4-D or its channel dimension does
                not equal ``self.channels``.
        """
        assert x.dim() == 4 and x.size(1) == self.channels
        # [batch_size, channels, height, width]
        #   -> [batch_size, height x width, channels]
        # flatten() (unlike view()) also accepts inputs whose spatial dims
        # cannot be merged without a copy, e.g. a transposed tensor.
        x = x.flatten(2).transpose(1, 2).contiguous()
        # assignment_weights: [batch_size, height x width, num_codes]
        assignment_weights = F.softmax(
            self.scaled_l2(x, self.codewords, self.scale), dim=2)
        # aggregate: [batch_size, num_codes, channels]
        return self.aggregate(assignment_weights, x, self.codewords)

    def __repr__(self) -> str:
        """Return a short summary of the input/output shapes."""
        repr_str = self.__class__.__name__
        repr_str += f'(Nx{self.channels}xHxW =>Nx{self.num_codes}' \
                    f'x{self.channels})'
        return repr_str