Spaces:
Running
Running
from collections import OrderedDict

import torch
import torch.nn as nn
class OnesLayer(nn.Module):
    """Produce an all-ones, single-channel tensor shaped after the input.

    The output keeps the input's batch dimension (dim 0) and, unless ``size``
    is given, its trailing (spatial) dimensions. The channel dimension
    (dim 1) is always 1. The input's values are never read; only its shape
    and device are used.

    Args:
        size: Optional spatial size of the output. Either an ``int`` (used for
            every spatial dimension) or a sequence with one entry per spatial
            dimension of the input (e.g. ``(H, W)`` for a 4-D input). When
            ``None``, the input's spatial dimensions are kept unchanged.
        dtype: dtype of the returned tensor. Defaults to ``torch.float32``,
            matching the previous hard-coded behaviour.
    """

    def __init__(self, size=None, dtype=torch.float32):
        super().__init__()
        self.size = size
        self.dtype = dtype

    def forward(self, tensor):
        """Return ones of shape ``(N, 1, *spatial)`` on ``tensor.device``.

        Args:
            tensor: Input with at least two dimensions; only its shape and
                device are consulted.

        Returns:
            A tensor of ones with ``self.dtype`` on ``tensor.device``.

        Raises:
            IndexError: if ``tensor`` has fewer than two dimensions.
            ValueError: if ``size`` is a sequence whose length does not match
                the number of spatial dimensions (``tensor.dim() - 2``).
        """
        shape = list(tensor.shape)
        shape[1] = 1  # return only one channel
        if self.size is not None:
            n_spatial = len(shape) - 2
            if isinstance(self.size, int):
                # A scalar size applies to every spatial dimension.
                spatial = [self.size] * n_spatial
            else:
                spatial = list(self.size)
            if len(spatial) != n_spatial:
                raise ValueError(
                    f"size has {len(spatial)} entries but the input has "
                    f"{n_spatial} spatial dimensions (shape {tuple(tensor.shape)})"
                )
            shape[2:] = spatial
        return torch.ones(shape, dtype=self.dtype, device=tensor.device)
class UninformativeFeatures(torch.nn.Sequential):
    """Sequential wrapper around a single ``OnesLayer(size=(1, 1))`` stage.

    The only submodule is registered under the name ``'ones'``. Given a 4-D
    input with batch size N, it returns an all-ones float32 tensor of shape
    ``(N, 1, 1, 1)`` on the input's device, independent of the input's
    values. Presumably this serves as a "no information" baseline feature
    extractor; confirm against callers.
    """

    def __init__(self):
        stages = OrderedDict()
        stages['ones'] = OnesLayer(size=(1, 1))
        super().__init__(stages)