|
import torch |
|
import torch.nn as nn |
|
import numpy as np |
|
from monai.transforms import Compose |
|
from typing import List, Tuple |
|
|
|
def __define_coord_channels__(x_dim, y_dim, z_dim=None):
    """
    Build index-valued coordinate grids as float tensors.

    2D case (``z_dim`` is None): returns ``(xx_channel, yy_channel)``, both of
    shape ``(1, y_dim, x_dim)``, where ``xx_channel[0, i, j] == j`` (values
    0 .. x_dim-1) and ``yy_channel[0, i, j] == i`` (values 0 .. y_dim-1).

    3D case: returns ``(x, y, z)``, each of shape ``(1, x_dim, y_dim, z_dim)``.
    ``x`` holds the index along the axis of length ``y_dim``, ``y`` the index
    along the axis of length ``x_dim`` and ``z`` the index along the axis of
    length ``z_dim``.

    :param x_dim: number of positions along the first spatial dimension.
    :param y_dim: number of positions along the second spatial dimension.
    :param z_dim: optional number of positions along a third dimension.
    :return: a tuple of two (2D) or three (3D) float tensors.
    """
    if z_dim is None:
        # The grids are built by indexing + repeat. The previous matmul-based
        # construction multiplied a (1, y_dim, 1) tensor with a (x_dim, 1)
        # tensor, which raises for every x_dim != 1.
        xx_channel = torch.arange(x_dim)[None, None, :].repeat(1, y_dim, 1)
        yy_channel = torch.arange(y_dim)[None, :, None].repeat(1, 1, x_dim)
        return xx_channel.float(), yy_channel.float()

    x = torch.unsqueeze(torch.arange(y_dim)[None, :, None].repeat(x_dim, 1, z_dim), 0).float()
    y = torch.unsqueeze(torch.arange(x_dim)[:, None, None].repeat(1, y_dim, z_dim), 0).float()
    z = torch.unsqueeze(torch.arange(z_dim)[None, None, :].repeat(x_dim, y_dim, 1), 0).float()
    return x, y, z
|
|
|
class PositionalEncoding3D(nn.Module):
    """Sinusoidal positional encoding over three spatial axes (channels-last)."""

    def __init__(self, channels):
        """
        :param channels: The last dimension of the tensor you want to apply pos emb to.
        """
        super().__init__()
        # Each of the three axes gets an even number of channels (sin + cos halves).
        per_axis = int(np.ceil(channels / 6) * 2)
        per_axis += per_axis % 2
        self.channels = per_axis
        exponents = torch.arange(0, per_axis, 2).float() / per_axis
        self.register_buffer('inv_freq', 1. / (10000 ** exponents))

    def forward(self, tensor):
        """
        :param tensor: A 5d tensor of size (batch_size, x, y, z, ch)
        :return: Positional Encoding Matrix of size (batch_size, x, y, z, ch)
        """
        if tensor.dim() != 5:
            raise RuntimeError("The input tensor has to be 5d!")

        batch_size, x, y, z, orig_ch = tensor.shape
        width = self.channels
        freq_type = self.inv_freq.type()

        def axis_embedding(length):
            # (length, width): sin of position * frequency, followed by cos.
            positions = torch.arange(length, device=tensor.device).type(freq_type)
            angles = torch.einsum("i,j->ij", positions, self.inv_freq)
            return torch.cat((angles.sin(), angles.cos()), dim=-1)

        emb = torch.zeros((x, y, z, width * 3), device=tensor.device).type(tensor.type())
        # Singleton axes let each per-axis table broadcast over the other two axes.
        emb[..., :width] = axis_embedding(x)[:, None, None, :]
        emb[..., width:2 * width] = axis_embedding(y)[:, None, :]
        emb[..., 2 * width:] = axis_embedding(z)

        # Trim to the input's channel count and copy across the batch.
        return emb[None, :, :, :, :orig_ch].repeat(batch_size, 1, 1, 1, 1)
|
|
|
class PositionalEncodingPermute3D(nn.Module):
    """Channels-first wrapper around PositionalEncoding3D."""

    def __init__(self, channels):
        """
        Accepts (batchsize, ch, x, y, z) instead of (batchsize, x, y, z, ch)
        """
        super().__init__()
        self.penc = PositionalEncoding3D(channels)

    def forward(self, tensor):
        # Move the channel axis last, encode, then move it back to position 1.
        channels_last = tensor.permute(0, 2, 3, 4, 1)
        encoded = self.penc(channels_last)
        return encoded.permute(0, 4, 1, 2, 3)
|
|
|
class AddCoordinateChannels:
    """Dict transform that stores coordinate channels under ``to_key``."""

    def __init__(self, to_key: str, input_size: int, input_dim: int, sinusodal: bool = False):
        # to_key: dictionary key the coordinate array is written to.
        # input_size: edge length used for every spatial dimension.
        # input_dim: number of spatial dimensions.
        # sinusodal: choose the sinusoidal encoding instead of linear coordinates.
        self.to_key, self.input_size = to_key, input_size
        self.input_dim, self.sinusodal = input_dim, sinusodal

    def get_normalized_coordinate_channels(self) -> np.ndarray:
        """Return the coordinate channels as a numpy array."""
        spatial_shape = (self.input_size,) * self.input_dim

        if self.sinusodal:
            # Encode an all-zero dummy volume; drop the batch axis of the result.
            encoder = PositionalEncodingPermute3D(self.input_dim)
            dummy = torch.zeros(1, self.input_dim, *spatial_shape)
            return encoder(dummy).squeeze(0).numpy()

        # Map indices 0 .. input_size-1 linearly onto [-1, 1].
        rescaled = [
            (grid / (self.input_size - 1)) * 2.0 - 1.0
            for grid in __define_coord_channels__(*spatial_shape)
        ]
        return torch.cat(rescaled).numpy()

    def __call__(self, data):
        """Return a shallow copy of ``data`` with the coordinates set at ``to_key``."""
        result = dict(data)
        result[self.to_key] = self.get_normalized_coordinate_channels()
        return result
|
|
|
def get_normalized_coordinates_transform(hparams, loaded_keys) -> Tuple[Compose, List[str]]:
    """
    Build the optional coordinate transform and the resulting key list.

    When ``hparams.coordinates`` is falsy, an empty ``Compose`` and the
    unchanged ``loaded_keys`` are returned. Otherwise the ``Compose`` holds one
    sinusoidal ``AddCoordinateChannels`` writing to ``"coordinates"``, and the
    returned key list is ``loaded_keys`` with ``"coordinates"`` appended.
    """
    if not hparams.coordinates:
        return Compose([]), loaded_keys

    coordinate_adder = AddCoordinateChannels(
        "coordinates", hparams.input_size, hparams.input_dim, sinusodal=True
    )
    return Compose([coordinate_adder]), loaded_keys + ["coordinates"]