import sys

import torch
import torch.nn as nn
from torch.autograd import Function


class Q_ud_wide(Function):
    """Up-down quantize with 2**(-extrab) extra output range for wide layers."""
    @staticmethod
    def forward(_, x, xb, extrab):
        up_factor = 2 ** (xb - extrab - 1)
        down_factor = 2 ** (xb - 1)
        # Round half up on the widened grid, then rescale to xb-bit steps.
        return x.mul(up_factor).add(.5).floor().div(down_factor)

    @staticmethod
    def backward(_, x):
        # Straight-through estimator: pass the gradient through unchanged.
        return x, None, None


class Q_ud(Function):
    """Up-down quantize x to the signed xb-bit grid with step 2**-(xb - 1)."""
    @staticmethod
    def forward(_, x, xb):
        updown_factor = 2 ** (xb - 1)
        # Scale up, round half up, scale back down.
        return x.mul(updown_factor).add(.5).floor().div(updown_factor)

    @staticmethod
    def backward(_, x):
        # Straight-through estimator.
        return x, None


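# Worked example (assumed values, not from the original code): with xb = 8,
# Q_ud maps 0.3 to floor(0.3 * 128 + 0.5) / 128 = 38 / 128 = 0.296875, the
# nearest point on the 2**-(xb - 1) grid. The same add-0.5-and-floor rounding
# is used by Q_ud_wide above and by Q_u / Q_d below.

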
class Q_ud_ap(Function):
    """Quantize to the two levels {-0.5, +0.5}; exact zeros map to -0.5."""
    @staticmethod
    def forward(_, x):
        x = torch.sign(x).div(2.0)
        # sign() returns 0 for zero inputs; push those down to -0.5.
        mask = (x == 0)
        return x - mask.float().div(2.0)

    @staticmethod
    def backward(_, x):
        # Straight-through estimator.
        return x


class Q_u(Function):
    """Up quantize: scale by 2**(8 - xb) and round half up to integers."""
    @staticmethod
    def forward(_, x, xb):
        up_factor = 2 ** (8 - xb)
        return x.mul(up_factor).add(.5).floor()

    @staticmethod
    def backward(_, x):
        # Straight-through estimator.
        return x, None


class Q_d(Function):
    """Down quantize: divide by 2**(xb - 1) and round half up to integers."""
    @staticmethod
    def forward(_, x, xb):
        down_factor = 2 ** (xb - 1)
        return x.div(down_factor).add(.5).floor()

    @staticmethod
    def backward(_, x):
        # Straight-through estimator.
        return x, None


class quantization(nn.Module):
    """Wraps the Q_* autograd functions; `mode` selects the quantizer."""
    def __init__(self, xb=8, mode='updown', wide=False, m=None, g=None):
        super().__init__()
        self.xb = xb
        self.mode = mode
        self.wide = wide
        self.m = m
        self.g = g

    def forward(self, x):
        if self.mode == 'updown':
            if self.wide:
                # Wide outputs get 5 extra bits of range (extrab = -5).
                return Q_ud_wide.apply(x, self.xb, -5)
            return Q_ud.apply(x, self.xb)
        elif self.mode == 'down':
            if self.wide:
                # Shift down by 5 fewer bits so wide outputs keep extra range.
                return Q_d.apply(x, self.xb - 5)
            return Q_d.apply(x, self.xb)
        elif self.mode == 'up':
            return Q_u.apply(x, self.xb)
        elif self.mode == 'updown_ap':
            return Q_ud_ap.apply(x)
        else:
            print(f'Unknown quantization mode: {self.mode}. Exiting.')
            sys.exit()


class clamping_qa(nn.Module):
    """Clamp activations to their representable range in normalized (training) scale."""
    def __init__(self, xb=8, wide=False):
        super().__init__()
        if wide:
            self.min_val = -16384.0
            self.max_val = 16383.0
        else:
            self.min_val = -1.0
            self.max_val = (2 ** (xb - 1) - 1) / (2 ** (xb - 1))

    def forward(self, x):
        return x.clamp(min=self.min_val, max=self.max_val)


class clamping_hw(nn.Module):
    """Clamp to the signed integer range of the hardware (xb bits, or 30 bits when wide)."""
    def __init__(self, xb=8, wide=False):
        super().__init__()
        if wide:
            self.min_val = -2 ** (30 - 1)
            self.max_val = 2 ** (30 - 1) - 1
        else:
            self.min_val = -2 ** (xb - 1)
            self.max_val = 2 ** (xb - 1) - 1

    def forward(self, x):
        return x.clamp(min=self.min_val, max=self.max_val)


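# Typical composition (a sketch under assumed usage, not taken from this file):
# during quantization-aware training an activation tensor is quantized and then
# clamped to its representable range, e.g.
#
#   quantize = quantization(xb=8, mode='updown', wide=False)
#   clamp = clamping_qa(xb=8, wide=False)
#   y = clamp(quantize(x))
#
# clamping_hw plays the same role once tensors are kept at integer (hardware)
# scale rather than the normalized training scale.

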
def calc_out_shift(weight, bias, shift_quantile):
    """Derive a power-of-two output shift from the given quantile of |weight| (and |bias|)."""
    weight_r = torch.flatten(weight)
    if bias is not None:
        bias_r = torch.flatten(bias)
        params_r = torch.cat((weight_r, bias_r))
    else:
        params_r = weight_r
    limit = torch.quantile(params_r.abs(), shift_quantile)
    # shift = -floor(log2(1 / limit)), clamped to the supported [-15, 15] range.
    return -(1. / limit).log2().floor().clamp(min=-15., max=15.)


def calc_out_shift_rho(W):
    """Variant of calc_out_shift that uses the maximum of |W| and ceil instead of floor."""
    # quantile(.., 1.0) is simply max(|W|).
    limit = torch.quantile(W.abs(), 1.0)
    return -(1. / limit).log2().ceil().clamp(min=-15., max=15.)
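

# Minimal self-test sketch (an assumption, not part of the original module):
# running this file directly quantizes a random activation tensor and derives
# an output shift for a small Conv2d layer as a quick sanity check.
if __name__ == '__main__':
    torch.manual_seed(0)
    act = torch.rand(4, 8) * 2 - 1                # activations in [-1, 1)
    quantize = quantization(xb=8, mode='updown', wide=False)
    clamp = clamping_qa(xb=8, wide=False)
    act_q = clamp(quantize(act))                  # quantize, then clamp to the 8-bit grid
    print('max quantization error:', (act_q - act).abs().max().item())

    conv = nn.Conv2d(4, 8, kernel_size=3)
    shift = calc_out_shift(conv.weight.data, conv.bias.data, shift_quantile=1.0)
    print('output shift:', shift.item())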