###########################################################################
# Computer vision - Embedded person tracking demo software by HyperbeeAI. #
# Copyright © 2023 Hyperbee.AI Inc. All rights reserved. hello@hyperbee.ai #
###########################################################################
from abc import ABC
import torch, sys
import torch.nn as nn
from torch.autograd import Function
###################################################
### Quantization Functions
### backward passes are straight through
## Up-Down (ud) quantization for wide last layer ("bigdata"). Used in QAT
class Q_ud_wide(Function):
    @staticmethod
    def forward(_, x, xb, extrab):
        up_factor   = 2 ** (xb - extrab - 1)
        down_factor = 2 ** (xb - 1)
        return x.mul(up_factor).add(.5).floor().div(down_factor)
    @staticmethod
    def backward(_, x):
        return x, None, None
## Up-Down (ud) quantization. Used in QAT
class Q_ud(Function):
    @staticmethod
    def forward(_, x, xb):
        updown_factor = 2 ** (xb - 1)
        return x.mul(updown_factor).add(.5).floor().div(updown_factor)
    @staticmethod
    def backward(_, x):
        return x, None
## Up-Down (ud) quantization for antipodal binary. Used in qat-ap
class Q_ud_ap(Function):
    @staticmethod
    def forward(_, x):
        x = torch.sign(x).div(2.0)       # antipodal (-1,+1) weights @HW correspond to (-0.5,+0.5) in qat
        mask = (x == 0)                  # sign() maps exact zeros to 0; push those to -0.5
        return x - mask.float().div(2.0)
    @staticmethod
    def backward(_, x):
        return x
## Up (u) quantization. Used in Eval/hardware
class Q_u(Function):
    @staticmethod
    def forward(_, x, xb):
        up_factor = 2 ** (8 - xb)
        return x.mul(up_factor).add(.5).floor() ### Burak: maxim has a .add(0.5) at the beginning, I think that's wrong
    @staticmethod
    def backward(_, x):
        return x, None
## Down (d) quantization. Used in Eval/hardware
class Q_d(Function):
    @staticmethod
    def forward(_, x, xb):
        down_factor = 2 ** (xb - 1)
        return x.div(down_factor).add(.5).floor() ### Burak: maxim has a .add(0.5) at the beginning, I think that's wrong
    @staticmethod
    def backward(_, x):
        return x, None
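## Example: a minimal sketch of how the up-down rule behaves, assuming an 8-bit
## width and toy values chosen only for illustration (the _example_* helper is
## not part of the original code).
def _example_q_ud():
    x  = torch.tensor([-0.9, -0.3, 0.01, 0.74], requires_grad=True)
    xq = Q_ud.apply(x, 8)   # values snap to the k/128 grid (k integer)
    xq.sum().backward()     # straight-through backward: x.grad comes back as all ones
    return xq, x.grad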
###################################################
### Quantization module
### ("umbrella" for Functions)
class quantization(nn.Module):
    def __init__(self, xb=8, mode='updown', wide=False, m=None, g=None):
        super().__init__()
        self.xb = xb
        self.mode = mode
        self.wide = wide
        self.m = m
        self.g = g
    def forward(self, x):
        ### Deniz: wide mode was not functioning as expected, so it was replaced with the code from the older repo
        if(self.mode=='updown'):
            if(self.wide):
                ### Burak: maxim's implementation had the third argument as +1, which was wrong.
                ###        the chip seems to be adding 5 more bits to the fractional part
                return Q_ud_wide.apply(x, self.xb, -5)
            else:
                return Q_ud.apply(x, self.xb)
        elif(self.mode=='down'):
            if(self.wide):
                ### Burak: maxim's implementation had the second argument as (self.xb + 1), which was wrong.
                ###        the chip seems to be adding 5 more bits to the fractional part
                return Q_d.apply(x, self.xb - 5)
            else:
                return Q_d.apply(x, self.xb)
        elif(self.mode=='up'):
            return Q_u.apply(x, self.xb)
        elif(self.mode=='updown_ap'):
            return Q_ud_ap.apply(x)
        else:
            print('wrong quantization mode. exiting')
            sys.exit()
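## Example: a minimal usage sketch of the module wrapper above; the tensor shape,
## bit width, and the _example_* helper name are assumptions for illustration only.
def _example_quantization_module():
    quant = quantization(xb=8, mode='updown', wide=False)
    x = torch.rand(1, 16, 32, 32) * 2 - 1   # fake activation map in [-1, 1)
    return quant(x)                          # same shape, values on the k/128 grid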
###################################################
### Clamping modules
### (doesn't need Functions since backward passes are well-defined)
class clamping_qa(nn.Module):
    def __init__(self, xb=8, wide=False):
        super().__init__()
        if(wide):
            self.min_val = -16384.0 ### Burak: this is wrong, but it's how maxim currently does it, so we play along
            self.max_val =  16383.0 ### Burak: this is wrong, but it's how maxim currently does it, so we play along
        else:
            self.min_val = -1.0
            self.max_val = (2 ** (xb - 1) - 1) / (2 ** (xb - 1))
    def forward(self, x):
        return x.clamp(min=self.min_val, max=self.max_val)
class clamping_hw(nn.Module):
    def __init__(self, xb=8, wide=False):
        super().__init__()
        if(wide):
            self.min_val = -2 ** (30 - 1)     ### Burak: this is wrong, but it's how maxim currently does it, so we play along
            self.max_val =  2 ** (30 - 1) - 1 ### Burak: this is wrong, but it's how maxim currently does it, so we play along
        else:
            self.min_val = -2 ** (xb - 1)
            self.max_val =  2 ** (xb - 1) - 1
    def forward(self, x):
        return x.clamp(min=self.min_val, max=self.max_val)
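## Example: a sketch of one plausible composition (an assumption, not shown in this
## file): clamp activations to the representable range first, then fake-quantize.
def _example_clamp_then_quantize():
    clamp = clamping_qa(xb=8, wide=False)    # clamps to [-1.0, 127/128]
    quant = quantization(xb=8, mode='updown')
    x = torch.randn(4, 8) * 3                # deliberately out of range
    return quant(clamp(x))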
###################################################
### Computing output_shift, i.e., "los"
def calc_out_shift(weight, bias, shift_quantile):
    # flatten weights (and bias, if present) into one vector and take the given
    # quantile of their magnitudes as the scaling limit
    weight_r = torch.flatten(weight)
    if bias is not None:
        bias_r = torch.flatten(bias)
        params_r = torch.cat((weight_r, bias_r))
    else:
        params_r = weight_r
    limit = torch.quantile(params_r.abs(), shift_quantile)
    return -(1. / limit).log2().floor().clamp(min=-15., max=15.)
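## Worked example (toy values, assumed for illustration): with these parameters the
## 1.0-quantile of |params| is 0.24, 1/0.24 ≈ 4.17, log2 ≈ 2.06, floor → 2, so the
## returned output shift is -2.
def _example_calc_out_shift():
    w = torch.tensor([0.05, -0.12, 0.24, -0.08])
    b = torch.tensor([0.01])
    return calc_out_shift(w, b, shift_quantile=1.0)   # tensor(-2.)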
def calc_out_shift_rho(W):
    # eqn. 22 in the AHA report
    # this code segment is taken from the v1 repo, duygu-yavuz-dev branch,
    # layers.py shift_amount_1bit function
    limit = torch.quantile(W.abs(), 1.0)
    return -(1. / limit).log2().ceil().clamp(min=-15., max=15.)
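## Example: the rho variant uses the maximum magnitude (quantile 1.0) and a ceil
## instead of a floor; toy values below are assumptions for illustration.
def _example_calc_out_shift_rho():
    W = torch.tensor([0.05, -0.12, 0.24, -0.08])
    return calc_out_shift_rho(W)   # 1/0.24 ≈ 4.17, log2 ≈ 2.06, ceil → 3, result: tensor(-3.)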