File size: 480 Bytes
5d21dd2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
import torch
import torch.nn as nn
class Quant(torch.autograd.Function):
    """Snap values to multiples of 1/255 with a pass-through gradient.

    The forward pass limits the input to the closed interval [0, 1] and
    rounds it onto the grid ``k / 255``.  Rounding has a zero gradient
    almost everywhere, so the backward pass hands the incoming gradient
    back unchanged instead of differentiating the forward computation.
    """

    @staticmethod
    def forward(ctx, input):
        """Return ``input`` clamped to [0, 1] and rounded to steps of 1/255.

        Args:
            ctx: Autograd context; nothing is saved on it.
            input: Tensor to quantize.

        Returns:
            A tensor whose entries are multiples of 1/255 within [0, 1].
        """
        bounded = torch.clamp(input, 0, 1)
        # Scale to the 0..255 range, round to whole steps, then scale back.
        steps = torch.round(bounded * 255.)
        return steps / 255.

    @staticmethod
    def backward(ctx, grad_output):
        """Return the upstream gradient untouched.

        The gradient is not masked where the forward pass clamped the
        input; every element receives ``grad_output`` as-is.
        """
        return grad_output
class Quantization(nn.Module):
    """Parameter-free module that applies ``Quant`` to its input.

    Wrapping the autograd function in an ``nn.Module`` lets it be used
    like any other layer, e.g. as a member of a larger model.
    """

    def __init__(self):
        """Initialise the module; it holds no parameters or buffers."""
        super().__init__()

    def forward(self, input):
        """Return the result of ``Quant.apply`` on ``input``."""
        quantized = Quant.apply(input)
        return quantized
|