# Source metadata left over from a web view of this file:
# mbar0075, "Testing Commit", c9baa67
from collections import OrderedDict
import torch
import torch.nn as nn
class OnesLayer(nn.Module):
    """Return a single-channel tensor of ones, ignoring the input's values.

    Only the input's shape and device are used. The output keeps the input's
    dim 0, has exactly one channel in dim 1, and always has dtype
    ``torch.float32``.

    Args:
        size: Optional sequence of output sizes for the dimensions that follow
            the channel dimension, i.e. dims ``2 .. 1 + len(size)``. Input
            dimensions beyond those are passed through unchanged. ``None``
            (the default) keeps every trailing dimension of the input.
            Examples:
            - ``size=(1, 1)`` on a 4-D ``(N, C, H, W)`` input gives ``(N, 1, 1, 1)``.
            - ``size=(1,)`` on a 3-D input ``(N, C, L)`` gives ``(N, 1, 1)``.
    """

    def __init__(self, size=None):
        super().__init__()
        self.size = size

    def forward(self, tensor):
        """Build the ones tensor matching ``tensor``'s leading shape and device.

        Raises:
            ValueError: if ``tensor`` has too few dimensions to hold the
                channel axis plus the ``size`` overrides.
        """
        shape = list(tensor.shape)
        overrides = () if self.size is None else tuple(self.size)
        min_dims = 2 + len(overrides)
        if len(shape) < min_dims:
            raise ValueError(
                f"OnesLayer expects an input with at least {min_dims} "
                f"dimensions, got shape {tuple(shape)}"
            )
        shape[1] = 1  # return only one channel
        # Replace exactly len(size) dims after the channel axis; any further
        # dims pass through, so a 2-element size on a 4-D+ input behaves as
        # the original `shape[2], shape[3] = size` did.
        shape[2:2 + len(overrides)] = overrides
        return torch.ones(shape, dtype=torch.float32, device=tensor.device)
class UninformativeFeatures(nn.Sequential):
    """Sequential "feature extractor" whose output carries no input information.

    It has a single stage, registered under the name ``'ones'``: a
    ``OnesLayer`` configured with ``size=(1, 1)``. For a 4-D input this
    yields a ``(batch, 1, 1, 1)`` float32 tensor of ones on the input's
    device, whatever the input's values are.
    """

    def __init__(self):
        super().__init__()
        # Same registration that Sequential performs for an OrderedDict arg.
        self.add_module('ones', OnesLayer(size=(1, 1)))