import torch
import torch.nn as nn
import numpy as np
from monai.transforms import Compose
from typing import List, Tuple

def __define_coord_channels__(x_dim, y_dim, z_dim=None):
    """
    Returns coord x and y channels from 0 to x_dim-1 and from 0 to y_dim-1.
    If z_dim is given, a z channel from 0 to z_dim-1 is returned as well.
    """
    if z_dim is None:
        # original implementation (https://github.com/jmfacil/camconvs/blob/894add0858343e00da52143231bd30cda0f9f385/python/CAM/blocks/camconvs.py)
        xx_ones = torch.ones([y_dim], dtype=torch.long)
        xx_ones = torch.unsqueeze(xx_ones, -1)                               # (y_dim, 1)
        xx_range = torch.tile(torch.unsqueeze(torch.arange(x_dim), 0), [1])  # (1, x_dim)
        xx_range = torch.unsqueeze(xx_range, 1)                              # (1, 1, x_dim)
        xx_channel = torch.matmul(xx_ones, xx_range)                         # (1, y_dim, x_dim), value = column index

        yy_ones = torch.ones([x_dim], dtype=torch.long)
        yy_ones = torch.unsqueeze(yy_ones, 0)                                # (1, x_dim)
        yy_range = torch.tile(torch.unsqueeze(torch.arange(y_dim), 0), [1])  # (1, y_dim)
        yy_range = torch.unsqueeze(yy_range, -1)                             # (1, y_dim, 1)
        yy_channel = torch.matmul(yy_range, yy_ones)                         # (1, y_dim, x_dim), value = row index

        return xx_channel.float(), yy_channel.float()
    else:
        # simplified 3d version: one channel per axis, each of shape (1, x_dim, y_dim, z_dim)
        x = torch.unsqueeze(torch.arange(y_dim)[None, :, None].repeat(x_dim, 1, z_dim), 0).float()
        y = torch.unsqueeze(torch.arange(x_dim)[:, None, None].repeat(1, y_dim, z_dim), 0).float()
        z = torch.unsqueeze(torch.arange(z_dim)[None, None, :].repeat(x_dim, y_dim, 1), 0).float()
        return x, y, z
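
# Usage sketch for the helper above; shapes follow directly from the code:
#
#   xx, yy = __define_coord_channels__(4, 3)      # each (1, 3, 4); xx[0, i, j] == j, yy[0, i, j] == i
#   x, y, z = __define_coord_channels__(4, 4, 4)  # each (1, 4, 4, 4); x varies along dim 2, y along dim 1, z along dim 3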

class PositionalEncoding3D(nn.Module):
    def __init__(self, channels):
        """
        :param channels: The last dimension of the tensor you want to apply pos emb to.
        """
        super(PositionalEncoding3D, self).__init__()
        # channels per axis, rounded up so each axis gets an even number (sin/cos pairs)
        channels = int(np.ceil(channels / 6) * 2)
        if channels % 2:
            channels += 1
        self.channels = channels
        inv_freq = 1. / (10000 ** (torch.arange(0, channels, 2).float() / channels))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, tensor):
        """
        :param tensor: A 5d tensor of size (batch_size, x, y, z, ch)
        :return: Positional Encoding Matrix of size (batch_size, x, y, z, ch)
        """
        if len(tensor.shape) != 5:
            raise RuntimeError("The input tensor has to be 5d!")
        batch_size, x, y, z, orig_ch = tensor.shape
        pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type())
        pos_y = torch.arange(y, device=tensor.device).type(self.inv_freq.type())
        pos_z = torch.arange(z, device=tensor.device).type(self.inv_freq.type())
        sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
        sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq)
        sin_inp_z = torch.einsum("i,j->ij", pos_z, self.inv_freq)
        emb_x = torch.cat((sin_inp_x.sin(), sin_inp_x.cos()), dim=-1).unsqueeze(1).unsqueeze(1)
        emb_y = torch.cat((sin_inp_y.sin(), sin_inp_y.cos()), dim=-1).unsqueeze(1)
        emb_z = torch.cat((sin_inp_z.sin(), sin_inp_z.cos()), dim=-1)
        # write the per-axis embeddings into disjoint channel blocks (broadcast over the other axes)
        emb = torch.zeros((x, y, z, self.channels * 3), device=tensor.device).type(tensor.type())
        emb[:, :, :, :self.channels] = emb_x
        emb[:, :, :, self.channels:2 * self.channels] = emb_y
        emb[:, :, :, 2 * self.channels:] = emb_z
        # crop to the requested number of channels and repeat over the batch
        return emb[None, :, :, :, :orig_ch].repeat(batch_size, 1, 1, 1, 1)
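
# Example of the channel bookkeeping: PositionalEncoding3D(3) allocates
# ceil(3 / 6) * 2 = 2 channels per axis (one sin/cos pair), builds an internal
# (x, y, z, 6) embedding, and crops it back to the 3 requested channels:
#
#   penc = PositionalEncoding3D(3)
#   out = penc(torch.zeros(1, 8, 8, 8, 3))    # out.shape == (1, 8, 8, 8, 3)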

class PositionalEncodingPermute3D(nn.Module):
    def __init__(self, channels):
        """
        Accepts (batchsize, ch, x, y, z) instead of (batchsize, x, y, z, ch)
        """
        super(PositionalEncodingPermute3D, self).__init__()
        self.penc = PositionalEncoding3D(channels)

    def forward(self, tensor):
        tensor = tensor.permute(0, 2, 3, 4, 1)
        enc = self.penc(tensor)
        return enc.permute(0, 4, 1, 2, 3)
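
# Channels-first counterpart of the example above:
#
#   pep = PositionalEncodingPermute3D(3)
#   out = pep(torch.zeros(1, 3, 8, 8, 8))     # out.shape == (1, 3, 8, 8, 8)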

class AddCoordinateChannels:
    """
    Dictionary transform that adds normalized coordinate channels under `to_key`:
    either linear coordinate ramps in [-1, 1] or a sinusoidal positional encoding
    with one channel per spatial dimension.
    """

    def __init__(self, to_key: str, input_size: int, input_dim: int, sinusodal: bool = False):
        self.to_key = to_key
        self.input_size = input_size
        self.input_dim = input_dim
        self.sinusodal = sinusodal

    def get_normalized_coordinate_channels(self) -> np.ndarray:
        if not self.sinusodal:
            # one linear ramp per axis, rescaled from [0, input_size - 1] to [-1, 1]
            channels = __define_coord_channels__(*((self.input_size,) * self.input_dim))
            normalize = lambda channel: (channel / (self.input_size - 1)) * 2.0 - 1.0
            channels = [normalize(c) for c in channels]
            return torch.cat(channels).numpy()
        else:
            # sinusoidal encoding with input_dim channels (the 3D encoder assumes input_dim == 3)
            pep = PositionalEncodingPermute3D(self.input_dim)
            channels = pep(torch.zeros(1, self.input_dim, *((self.input_size,) * self.input_dim)))
            return channels.squeeze(0).numpy()

    def __call__(self, data):
        d = dict(data)
        return {**d, self.to_key: self.get_normalized_coordinate_channels()}
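
# Usage sketch as a dictionary transform (the "image" key is illustrative and
# left untouched by the transform):
#
#   transform = AddCoordinateChannels("coordinates", input_size=64, input_dim=3)
#   sample = transform({"image": image})
#   sample["coordinates"].shape    # (3, 64, 64, 64), values in [-1, 1]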

def get_normalized_coordinates_transform(hparams, loaded_keys) -> Tuple[Compose, List[str]]:
    if hparams.coordinates:
        return Compose([
            AddCoordinateChannels("coordinates", hparams.input_size, hparams.input_dim, sinusodal=True)
        ]), loaded_keys + ["coordinates"]
    else:
        return Compose([]), loaded_keys
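
# Minimal smoke test, assuming `hparams` exposes the attributes used above
# (`coordinates`, `input_size`, `input_dim`); the values below are illustrative.
if __name__ == "__main__":
    from types import SimpleNamespace

    hparams = SimpleNamespace(coordinates=True, input_size=32, input_dim=3)
    transform, keys = get_normalized_coordinates_transform(hparams, ["image"])
    sample = transform({"image": np.zeros((1, 32, 32, 32), dtype=np.float32)})
    print(keys)                            # ['image', 'coordinates']
    print(sample["coordinates"].shape)     # (3, 32, 32, 32)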