"""Activation layers that operate on the features of a ``SparseTensor``.

Each layer reads ``input.feats``, applies a dense activation to it, and passes
the result to ``input.replace(...)``.
"""

import torch
import torch.nn as nn

from . import SparseTensor

__all__ = [
    'SparseReLU',
    'SparseSiLU',
    'SparseGELU',
    'SparseActivation',
]


class SparseReLU(nn.ReLU):
    """``nn.ReLU`` applied to the ``feats`` of a ``SparseTensor``.

    Constructor arguments are those of ``nn.ReLU``.
    """

    def forward(self, input: SparseTensor) -> SparseTensor:
        """Apply ReLU to ``input.feats`` and return ``input.replace(result)``.

        NOTE(review): ``SparseTensor.replace`` is defined elsewhere; it
        presumably yields a tensor with the same sparsity structure and the
        new features -- confirm against its implementation.
        """
        return input.replace(super().forward(input.feats))


class SparseSiLU(nn.SiLU):
    """``nn.SiLU`` applied to the ``feats`` of a ``SparseTensor``.

    Constructor arguments are those of ``nn.SiLU``.
    """

    def forward(self, input: SparseTensor) -> SparseTensor:
        """Apply SiLU to ``input.feats`` and return ``input.replace(result)``.

        NOTE(review): ``SparseTensor.replace`` is defined elsewhere; it
        presumably yields a tensor with the same sparsity structure and the
        new features -- confirm against its implementation.
        """
        return input.replace(super().forward(input.feats))


class SparseGELU(nn.GELU):
    """``nn.GELU`` applied to the ``feats`` of a ``SparseTensor``.

    Constructor arguments are those of ``nn.GELU``.
    """

    def forward(self, input: SparseTensor) -> SparseTensor:
        """Apply GELU to ``input.feats`` and return ``input.replace(result)``.

        NOTE(review): ``SparseTensor.replace`` is defined elsewhere; it
        presumably yields a tensor with the same sparsity structure and the
        new features -- confirm against its implementation.
        """
        return input.replace(super().forward(input.feats))


class SparseActivation(nn.Module):
    """Wrap an arbitrary module so it acts on the ``feats`` of a ``SparseTensor``.

    Args:
        activation: Module called on ``input.feats`` in ``forward``. It is
            stored as the ``activation`` attribute, so it is registered as a
            submodule of this layer.
    """

    def __init__(self, activation: nn.Module):
        super().__init__()
        self.activation = activation

    def forward(self, input: SparseTensor) -> SparseTensor:
        """Call ``self.activation`` on ``input.feats`` and return ``input.replace(result)``.

        NOTE(review): ``SparseTensor.replace`` is defined elsewhere; it
        presumably yields a tensor with the same sparsity structure and the
        new features -- confirm against its implementation.
        """
        return input.replace(self.activation(input.feats))