import torch
import torch.nn as nn
class Quant(torch.autograd.Function):
    """Differentiable 8-bit quantization with a straight-through estimator.

    Forward: clamps the input to [0, 1] and snaps it onto the 256-level
    grid ``round(255 * x) / 255`` (simulating uint8 image storage).
    Backward: passes the incoming gradient through unchanged, since
    ``round`` has zero gradient almost everywhere.
    """

    # BUG FIX: forward/backward must be @staticmethod for `Quant.apply`
    # to work — modern PyTorch raises a RuntimeError on legacy
    # (non-static) autograd Functions.
    @staticmethod
    def forward(ctx, input):
        # Clamp to the valid image range before quantizing.
        clamped = torch.clamp(input, 0, 1)
        # Quantize to 256 discrete levels in [0, 1].
        return (clamped * 255.).round() / 255.

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: treat quantization as identity.
        # NOTE(review): gradients are not zeroed for inputs clamped
        # outside [0, 1]; this matches the original implementation.
        return grad_output
class Quantization(nn.Module):
    """nn.Module wrapper that applies the ``Quant`` autograd function."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        # Delegate to the custom straight-through quantization function.
        quantized = Quant.apply(input)
        return quantized