EditGuard / models /modules /Quantization.py
Ricoooo's picture
'folder'
5d21dd2
raw
history blame
480 Bytes
import torch
import torch.nn as nn
class Quant(torch.autograd.Function):
    """Snap a tensor onto the 256-level grid ``k / 255`` with an identity gradient.

    The forward pass clamps every element to ``[0, 1]`` and rounds it to the
    nearest multiple of ``1 / 255``. Rounding has zero gradient almost
    everywhere, so the backward pass returns the incoming gradient unchanged
    (a straight-through estimator).

    NOTE(review): the 255 scale suggests this mimics saving to an 8-bit
    image -- confirm against the callers.
    """

    @staticmethod
    def forward(ctx, input):
        """Clamp ``input`` to ``[0, 1]`` and round it to the nearest ``k / 255``.

        Args:
            ctx: Autograd context; unused because backward needs no saved state.
            input: Tensor to quantize.

        Returns:
            A tensor of the same shape whose values are multiples of ``1 / 255``
            in ``[0, 1]``.
        """
        # Saturate before scaling so the rounded result can never leave the
        # 0..255 integer range.
        input = torch.clamp(input, 0, 1)
        output = (input * 255.).round() / 255.
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output`` unchanged.

        The gradient is passed through for every element, including those
        that were clamped in the forward pass.
        """
        return grad_output
class Quantization(nn.Module):
    """Parameter-free ``nn.Module`` wrapper around the ``Quant`` autograd function.

    Wrapping ``Quant`` in a module lets the quantization step be placed in a
    model like any other layer; it defines no parameters or buffers itself.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input):
        """Apply ``Quant`` to ``input`` and return the result.

        Args:
            input: Tensor to quantize.

        Returns:
            Whatever ``Quant.apply(input)`` returns.
        """
        return Quant.apply(input)