# Quantization reduces a bit representation to fewer bits for efficient storage or computation.
# Most floating point data types define a mapping from a bit representation to a floating point value,
# e.g. the 4-bit pattern 0010 = 2 maps to 2 / max(1111) = 2/15 = 0.1333...
# As such, we can represent a floating point quantization as a mapping from integers to floating point values, e.g.
# [0, 1, 2, 3] -> [-1.0, -0.25, 0.25, 1.0]
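# A quick sketch of that bit-to-float mapping for the 4-bit example above
# (assuming an unsigned representation scaled by the largest 4-bit pattern 1111 = 15):
bit_pattern = 0b0010   # the integer 2
max_pattern = 0b1111   # the largest 4-bit value, 15
print('4-bit example value:', bit_pattern / max_pattern)  # 2/15 = 0.1333...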
import numpy as np
from scipy.spatial.distance import cdist
index = np.array([0, 1, 2, 3, 4, 5, 6, 7])
values = np.linspace(-1.0, 1.0, 8) # 3-bit linear quantization
print('quantization values:', values)
# To quantize an input distribution we first need to normalize its range into the range of the quantization values, in this case [-1.0, 1.0].
# We can do this through division by the absolute maximum value if our distribution is roughly symmetric (most distributions in deep learning are normally distributed).
rand_inputs = np.random.randn(1024, 1024).astype(np.float32)
absmax = np.max(np.abs(rand_inputs))
normed = rand_inputs / absmax
print('normalized min and max range', np.min(normed), np.max(normed))
# The next step is to round the input value to the closest quantization value.
# This can be done by performing a binary search for each element of the normalized input tensor against the sorted values array.
# Here, for simplicity, we compute the distance between each element and every quantization value and pick the closest directly.
dist = cdist(normed.flatten().reshape(-1, 1), values.reshape(-1, 1))
closest_idx = np.argmin(dist, 1).reshape(rand_inputs.shape)
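# A sketch of the binary-search variant mentioned above (values is already sorted):
# np.searchsorted finds the insertion point and we pick the closer of the two neighbors.
# This is only an illustration; the cdist result above is what we use below.
flat = normed.flatten()
right = np.clip(np.searchsorted(values, flat), 1, len(values) - 1)
left = right - 1
take_right = (values[right] - flat) < (flat - values[left])
closest_idx_bs = np.where(take_right, right, left).reshape(rand_inputs.shape)
print('binary search matches cdist:', np.array_equal(closest_idx_bs, closest_idx))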
val, count = np.unique(closest_idx, return_counts=True)
print('Values:', val)
print('Count:', count)
# closest_idx now holds the quantized 3-bit representation (8 different values). We can use this representation to store the data efficiently.
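# Rough storage comparison (here we keep one index per byte for simplicity; a real
# implementation would pack the 3-bit indices even more tightly):
quantized_storage = closest_idx.astype(np.uint8)
print('float32 input size (MB): ', rand_inputs.nbytes / 1e6)
print('quantized index size (MB):', quantized_storage.nbytes / 1e6, '+ one absmax scalar')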
# ==================DEQUANTIZATION========================
# To dequantize the tensor we reverse the operations we performed:
# 1. look up the values corresponding to the 3-bit indices
# 2. denormalize by multiplying by absmax
dequant = values[closest_idx]*absmax
# mean absolute error:
error = np.abs(dequant-rand_inputs).mean()
print(f'Absolute linear 3-bit quantization error: {error:.4f}')
# This yields an error of about 0.34 per value. We can do better with non-linear quantization.
# ==================NON-LINEAR QUANTIZATION========================
# In non-linear quantization the distance between quantization values is not always equal.
# This allows us to allocate more values to regions of high density. For example, the normal distribution has many values around 0.
# This can reduce the overall error in the distribution.
index = np.array([0, 1, 2, 3, 4, 5, 6, 7])
values = np.array([-1.0, -0.5, -0.25, -0.075, 0.075, 0.25, 0.5, 1.0])
dist = cdist(normed.flatten().reshape(-1, 1), values.reshape(-1, 1))
closest_idx = np.argmin(dist, 1).reshape(rand_inputs.shape)
val, count = np.unique(closest_idx, return_counts=True)
print('Values:', val)
print('Count:', count)
dequant = values[closest_idx]*absmax
error = np.abs(dequant-rand_inputs).mean()
print(f'Absolute non-linear 3-bit quantization error: {error:.4f}')
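# As a sketch of the density idea above, the 8 values could also be derived from
# equal-probability quantiles of the normalized inputs instead of being hand-picked
# (this quantile variant is only an illustration, not the scheme used below):
quantile_values = np.quantile(normed, (np.arange(8) + 0.5) / 8)
dist = cdist(normed.flatten().reshape(-1, 1), quantile_values.reshape(-1, 1))
closest_idx = np.argmin(dist, 1).reshape(rand_inputs.shape)
dequant = quantile_values[closest_idx]*absmax
error = np.abs(dequant-rand_inputs).mean()
print(f'Absolute quantile 3-bit quantization error: {error:.4f}')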
# ==================DYNAMIC QUANTIZATION========================
# Adapted from: https://github.com/facebookresearch/bitsandbytes/blob/main/bitsandbytes/functional.py
def create_dynamic_map(signed=True, n=7):
    '''
    Creates the dynamic quantization map.

    The dynamic data type is made up of a dynamic exponent and
    fraction. As the exponent increases from 0 to -7 the number
    of bits available for the fraction shrinks.

    This is a generalization of the dynamic type where a certain
    number of the bits can be reserved for the linear quantization
    region (the fraction). n determines the maximum number of
    exponent bits.

    For more details see
    [8-Bit Approximations for Parallelism in Deep Learning](https://arxiv.org/abs/1511.04561)
    '''
    data = []
    # these are additional items that come from the case
    # where all the exponent bits are zero and no
    # indicator bit is present
    additional_items = 2**(7-n)-1
    if not signed: additional_items = 2*additional_items
    for i in range(n):
        fraction_items = 2**(i+7-n)+1 if signed else 2**(i+7-n+1)+1
        boundaries = np.linspace(0.1, 1, fraction_items)
        means = (boundaries[:-1]+boundaries[1:])/2.0
        data += ((10**(-(n-1)+i))*means).tolist()
        if signed:
            data += (-(10**(-(n-1)+i))*means).tolist()

    if additional_items > 0:
        boundaries = np.linspace(0.1, 1, additional_items+1)
        means = (boundaries[:-1]+boundaries[1:])/2.0
        data += ((10**(-(n-1)+i))*means).tolist()
        if signed:
            data += (-(10**(-(n-1)+i))*means).tolist()

    data.append(0)
    data.append(1.0)
    data.sort()
    return np.array(data)
import time
values = create_dynamic_map(signed=True)
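# A quick look at the dynamic map (just to illustrate the docstring above): values
# are packed much more densely around zero than near the extremes +-1.0.
print('number of dynamic quantization values:', len(values))
print('smallest positive value:', values[values > 0].min())
print('gap between the two largest values:', values[-1] - values[-2])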
t0 = time.time()
dist = cdist(normed.flatten().reshape(-1, 1), values.reshape(-1, 1))
closest_idx = np.argmin(dist, 1).reshape(rand_inputs.shape)
quant_time = time.time()-t0
dequant = values[closest_idx]*absmax
error = np.abs(dequant-rand_inputs).mean()
print(f'Absolute dynamic 8-bit quantization error: {error:.4f}')
print(f'Total time taken: {quant_time:.4f} seconds.')
# This yields an error as low as 0.012. We can do even better with block-wise quantization,
# which normalizes each block of values by its own absmax instead of one absmax for the whole tensor.
# But performing block-wise quantization without optimized code is a bit slow (see the rough sketch below);
# the bitsandbytes library does this quickly.
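# Rough, unoptimized NumPy sketch of block-wise quantization (the block size of 4096 is an
# assumption for illustration; this is not the bitsandbytes implementation used below).
# Each block gets its own absmax, so a single outlier only degrades its own block.
blocks = rand_inputs.flatten().reshape(-1, 4096)
block_absmax = np.max(np.abs(blocks), axis=1, keepdims=True)
block_normed = blocks / block_absmax
dist = cdist(block_normed.reshape(-1, 1), values.reshape(-1, 1))
closest_idx = np.argmin(dist, 1).reshape(blocks.shape)
block_dequant = values[closest_idx] * block_absmax
block_error = np.abs(block_dequant - blocks).mean()
print(f'Absolute block-wise 8-bit quantization error (NumPy sketch): {block_error:.4f}')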
import torch
import bitsandbytes.functional as F
rand_inputs = torch.from_numpy(rand_inputs)
t0 = time.time()
quant_values, quant_state = F.quantize_blockwise(rand_inputs)
quant_time = time.time()-t0
dequant_values = F.dequantize_blockwise(quant_values, quant_state)
error = torch.abs(dequant_values-rand_inputs).mean().item()
print(f'Absolute dynamic block-wise 8-bit quantization error: {error:.4f}')
print(f'Total time taken (CPU): {quant_time:.4f} seconds.')
rand_inputs = rand_inputs.cuda()
t0 = time.time()
quant_values, quant_state = F.quantize_blockwise(rand_inputs)
torch.cuda.synchronize()  # make sure the asynchronous GPU kernel has finished before stopping the timer
quant_time = time.time()-t0
print(f'Total time taken (GPU): {quant_time:.4f} seconds.')