""" projection_layer.py """ |
|
from typing import Tuple |
|
|
|
import math |
|
import torch |
|
from torch import nn |
|
import torch.nn.functional as F |
|
from torch.nn import Linear, LayerNorm |
|
|
|
from einops import rearrange |
|
from model.ops import count_parameters |
|
|
|
|
|
class GroupLinearFlatten(nn.Module):
    """
    Implements a grouped linear layer with a flattened output.

    This module applies an individual linear transformation to each group in the input
    tensor and then flattens the group dimension to produce the final output. It is useful
    when the input tensor contains distinct groups and each group should receive its own
    linear transformation.

    Args:
        in_features (int): Number of input features per group.
        flatten_out_features (int): Total number of flattened output features. This value
            must be divisible by num_groups; the number of output features per group is
            flatten_out_features // num_groups.
        num_groups (int): Number of distinct groups in the input tensor.
        use_bmm (bool, optional): Whether to use batch matrix multiplication for the
            computation. Default: True.

    Shape:
        - Input: (batch_size, sequence_length, num_groups, in_features)
        - Output: (batch_size, sequence_length, flatten_out_features)

    Examples:
        >>> m = GroupLinearFlatten(128, 512, 16)
        >>> input = torch.randn(16, 10, 16, 128)  # (B, T, C, F)
        >>> output = m(input)
        >>> output.size()
        torch.Size([16, 10, 512])  # (B, T, D)
    """

    def __init__(self, in_features, flatten_out_features, num_groups, use_bmm=True):
        super().__init__()
        self.in_features = in_features
        self.flatten_out_features = flatten_out_features
        self.num_groups = num_groups
        self.use_bmm = use_bmm

        assert flatten_out_features % num_groups == 0, "flatten_out_features must be divisible by num_groups."
        self.out_features_per_group = self.flatten_out_features // self.num_groups

        # One weight matrix of shape (out_features_per_group, in_features) per group.
        self.weight = nn.Parameter(torch.Tensor(num_groups, self.out_features_per_group, in_features))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, input):
        batch_size, t, k, _ = input.size()

        if self.use_bmm:
            # (B, T, K, D) -> (K, D, B*T) so every group is handled in a single bmm call.
            input_reshaped = rearrange(input, 'b t k d -> k d (b t)')

            # (K, D_out, D) @ (K, D, B*T) -> (K, D_out, B*T)
            output_bmm = torch.bmm(self.weight, input_reshaped)

            # Back to (B, T, K*D_out), flattening the group dimension.
            output = rearrange(output_bmm, 'k d_out (b t) -> b t (k d_out)', b=batch_size, t=t, k=k)
        else:
            # Equivalent einsum formulation, kept for reference and debugging.
            output = torch.einsum('bsgi,goi->bsgo', input, self.weight)
            output = rearrange(output, 'b t k d_out -> b t (k d_out)')

        return output
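
# A minimal equivalence sketch (illustrative names, assumes the divisibility constraint
# above): the bmm and einsum paths compute the same grouped projection when they share
# the same weights, e.g.
#   gl_bmm = GroupLinearFlatten(128, 512, 16, use_bmm=True)
#   gl_ein = GroupLinearFlatten(128, 512, 16, use_bmm=False)
#   gl_ein.weight = gl_bmm.weight
#   x = torch.randn(2, 10, 16, 128)
#   torch.allclose(gl_bmm(x), gl_ein(x))  # expected to hold (up to float tolerance)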
|
|
class MultiChannelLinearProjection(nn.Module):
    """
    Projects a multi-channel input by merging groups of in_ch // out_ch adjacent channels
    and applying a shared, bias-free linear layer to the concatenated features of each
    group. When the concatenated feature size already equals out_dim, the linear layer is
    replaced by nn.Identity.
    """

    def __init__(self, in_ch=26, in_dim=128, out_ch=13, out_dim=512):
        super().__init__()
        self.in_ch = in_ch
        self.in_dim = in_dim
        self.out_ch = out_ch
        self.out_dim = out_dim

        self.in_ch_per_group = in_ch // out_ch
        self.linear_in_ch = in_ch // self.in_ch_per_group
        self.linear_in_dim = in_dim * self.in_ch_per_group

        # If the concatenated feature size already matches the target dimension,
        # no projection is needed.
        if in_dim * self.in_ch_per_group == out_dim:
            self.linear = nn.Identity()
        else:
            self.linear = nn.Linear(in_features=self.linear_in_dim, out_features=out_dim, bias=False)

    def forward(self, x):
        """
        Args:
            x: (B, T, C, D)

        Returns:
            x: (B, C_target, T, D_target)
        """
        x = rearrange(x, 'b t (c1 c2) d -> b c1 t (c2 d)', c1=self.linear_in_ch, c2=self.in_ch_per_group)
        return self.linear(x)
|
|
def get_multi_channel_projection_layer(input_shape: Tuple[int], output_shape: Tuple[int], proj_type: str) -> nn.Module:
    """ This function returns one of the projection layers for multi-channel models. """
    in_ch = input_shape[-2]
    in_dim = input_shape[-1]
    out_ch = output_shape[-2]
    out_dim = output_shape[-1]

    if proj_type == 'mc_shared_linear':
        return MultiChannelLinearProjection(in_ch, in_dim, out_ch, out_dim)
    else:
        raise ValueError(f"Invalid multi-channel projection type: {proj_type}")
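
# A minimal usage sketch. The shape convention is assumed here to be (T, C, D) for
# input_shape and (T, C_target, D_target) for output_shape; only the last two dims are read.
#   proj = get_multi_channel_projection_layer((110, 26, 128), (110, 13, 512), 'mc_shared_linear')
#   y = proj(torch.randn(2, 110, 26, 128))  # -> (2, 13, 110, 512)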
|
|
def test_multi_channel_linear_projection():
    # When in_dim * in_ch_per_group == out_dim, the projection collapses to nn.Identity.
    x = torch.randn(2, 10, 26, 128)
    mclp = MultiChannelLinearProjection(in_ch=26, in_dim=128, out_ch=13, out_dim=256)
    assert isinstance(mclp.linear, nn.Identity)
    assert mclp(x).shape == (2, 13, 10, 256)

    # Otherwise a bias-free nn.Linear maps the concatenated features to out_dim.
    x = torch.randn(2, 10, 26, 128)
    mclp = MultiChannelLinearProjection(in_ch=26, in_dim=128, out_ch=13, out_dim=512)
    assert isinstance(mclp.linear, nn.Linear)
    assert mclp(x).shape == (2, 13, 10, 512)
|
|
class FlattenMLP(nn.Module):
    """Flattens the group dimension and projects with a two-hidden-layer MLP."""

    def __init__(self, in_features, flatten_out_features, num_groups, hidden_dim=None, activation=None):
        super().__init__()
        self.in_features = in_features
        self.num_groups = num_groups

        self.flat_in_dim = in_features * num_groups
        if hidden_dim is None:
            hidden_dim = self.flat_in_dim // 2
        self.hidden_dim = hidden_dim

        assert flatten_out_features % in_features == 0, "flatten_out_features should be divisible by in_features."

        self.layers = nn.Sequential(
            nn.Flatten(2, 3),  # (B, T, K, D) -> (B, T, K*D)
            nn.Linear(self.flat_in_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            activation() if activation else nn.Identity(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            activation() if activation else nn.Identity(),
            nn.Linear(hidden_dim, flatten_out_features),
        )

    def forward(self, x):
        return self.layers(x)
|
|
class LinearProjection(nn.Module):
    """Flattens the group dimension and applies a single bias-free linear projection."""

    def __init__(self, in_features, flatten_out_features, num_groups):
        super().__init__()
        self.flat_in_dim = in_features * num_groups
        self.projection_layer = nn.Linear(in_features=self.flat_in_dim, out_features=flatten_out_features, bias=False)

    def forward(self, x):
        batch_size, t, _, _ = x.size()
        x_flattened = x.reshape(batch_size, t, -1)  # (B, T, K, D) -> (B, T, K*D)
        return self.projection_layer(x_flattened)
|
|
class DepthwiseConvProjection(nn.Module):
    """Applies a grouped (depthwise) convolution along the feature axis, then flattens the
    result per time step and projects it to the target dimension."""

    def __init__(self, in_features, flatten_out_features, num_groups, depth):
        super().__init__()
        d_out = flatten_out_features // in_features

        self.conv = nn.Conv2d(in_channels=num_groups,
                              out_channels=num_groups * d_out,
                              kernel_size=(1, depth),
                              groups=num_groups)

        self.fc = nn.Linear(num_groups * d_out * (in_features - depth + 1), flatten_out_features)

    def forward(self, x):
        # (B, T, K, D) -> (B, K, T, D): the groups become the conv channels.
        x = x.permute(0, 2, 1, 3)

        # (B, K, T, D) -> (B, K * d_out, T, D - depth + 1)
        x = self.conv(x)

        # Move time back in front of the channel axis before flattening so each time step
        # keeps its own features, then project to the target dimension.
        batch_size, _, t, _ = x.size()
        x = x.permute(0, 2, 1, 3).reshape(batch_size, t, -1)
        return self.fc(x)
|
|
def get_projection_layer(input_shape: Tuple[int], output_shape: Tuple[int], proj_type: str) -> nn.Module:
    """ This function returns one of the projection layers defined above. """
    if len(input_shape) == 2:
        _, d_source = input_shape
    elif len(input_shape) == 3:
        _, k_source, d_source = input_shape
    if len(output_shape) == 2:
        _, d_target = output_shape
    elif len(output_shape) == 3:
        _, k_target, d_target = output_shape

    if proj_type == 'linear':
        return LinearProjection(in_features=d_source, flatten_out_features=d_target, num_groups=k_source)
    elif 'mlp' in proj_type:
        if 'gelu' in proj_type:
            activation = nn.GELU
        elif 'relu' in proj_type:
            activation = nn.ReLU
        else:
            activation = None
        return FlattenMLP(in_features=d_source,
                          flatten_out_features=d_target,
                          num_groups=k_source,
                          activation=activation)
    elif 'conv' in proj_type:
        # 'conv4', 'conv16', 'conv32' and 'conv64' select the kernel depth; any other
        # 'conv*' type falls back to depth 1.
        conv_depths = {'conv4': 4, 'conv16': 16, 'conv32': 32, 'conv64': 64}
        return DepthwiseConvProjection(in_features=d_source,
                                       flatten_out_features=d_target,
                                       num_groups=k_source,
                                       depth=conv_depths.get(proj_type, 1))
    elif proj_type == 'group_linear':
        assert d_target % k_source == 0, "d_target must be divisible by k_source for group_linear projection."
        return GroupLinearFlatten(in_features=d_source,
                                  flatten_out_features=d_target,
                                  num_groups=k_source,
                                  use_bmm=True)
    else:
        raise ValueError(f"Invalid projection type: {proj_type}")
|
|
def test_projection_layers():
    b = 2
    t = 110
    k = 24
    d = 128
    enc_hs = torch.randn(b, t, k, d)

    target_flatten_d = 512

    # GroupLinearFlatten requires target_flatten_d to be divisible by the number of groups,
    # so this case uses 16 groups (512 / 16 = 32 output features per group).
    k_gl = 16
    gl = GroupLinearFlatten(in_features=d, flatten_out_features=target_flatten_d, num_groups=k_gl, use_bmm=True)
    enc_hs_hat = gl(torch.randn(b, t, k_gl, d))
    assert enc_hs_hat.shape == (b, t, target_flatten_d)
    print('GroupLinear: ', f'{count_parameters(gl)//1000}k')

    fm = FlattenMLP(in_features=d,
                    flatten_out_features=target_flatten_d,
                    num_groups=k,
                    hidden_dim=None,
                    activation=nn.GELU)
    enc_hs_hat = fm(enc_hs)
    assert enc_hs_hat.shape == (b, t, target_flatten_d)
    print('FlattenMLP: ', f'{count_parameters(fm)//1000}k')

    lp = LinearProjection(in_features=d, flatten_out_features=target_flatten_d, num_groups=k)
    enc_hs_hat = lp(enc_hs)
    assert enc_hs_hat.shape == (b, t, target_flatten_d)
    print('LinearProjection: ', f'{count_parameters(lp)//1000}k')

    dc = DepthwiseConvProjection(in_features=d, flatten_out_features=target_flatten_d, num_groups=k, depth=16)
    enc_hs_hat = dc(enc_hs)
    assert enc_hs_hat.shape == (b, t, target_flatten_d)
    print('DepthwiseConvProjection: ', f'{count_parameters(dc)//1000}k')
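
# Assumed convenience entry point for running the smoke tests directly
# (an assumption; the project may instead collect these tests with pytest).
if __name__ == "__main__":
    test_multi_channel_linear_projection()
    test_projection_layers()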
|
|