NilEneb's picture
Upload folder using huggingface_hub
ad93086 verified
# By Forge
import torch
def native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits(x):
    """Split every 16-bit word of ``x`` into four signed 4-bit values packed in an int32.

    The bytes of ``x`` are read in memory order. Each byte yields its low
    nibble and then its high nibble. Each nibble is moved from the unsigned
    range [0, 15] to the signed range [-8, 7] by subtracting 8, and stored as
    int8. Every four consecutive int8 values are then reinterpreted as one
    int32. For a 1-D input of N words the result has shape (N, 1).

    Assumes ``x`` has a 2-byte element dtype such as ``torch.uint16`` (TODO
    confirm at call sites). With other element sizes the int32 grouping would
    not line up with the input words. The byte order inside a word is the
    host's, presumably little-endian.
    """
    rows = x.size(0)
    # Reinterpret the words as raw bytes, keeping one row per leading index.
    as_bytes = x.view(torch.uint8).view(rows, -1)
    low_nibbles = as_bytes & 15
    high_nibbles = as_bytes >> 4
    # Interleave so each byte contributes (low, high) in that order.
    nibbles = torch.stack((low_nibbles, high_nibbles), dim=-1).view(rows, -1)
    # Centre the unsigned nibbles on zero: [0, 15] -> [-8, 7].
    centered = nibbles.view(torch.int8) - 8
    return centered.view(torch.int32)
def native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits_u(x):
    """Split every 16-bit word of ``x`` into four unsigned 4-bit values packed in an int32.

    This is the unsigned counterpart of the signed variant: nibbles stay in
    [0, 15]. The bytes of ``x`` are read in memory order, each yielding its
    low nibble then its high nibble. Every four resulting bytes are
    reinterpreted as one int32. For a 1-D input of N words the result has
    shape (N, 1).

    Assumes ``x`` has a 2-byte element dtype such as ``torch.uint16`` (TODO
    confirm at call sites).
    """
    rows = x.size(0)
    as_bytes = x.view(torch.uint8).view(rows, -1)
    # (low, high) pair per byte, flattened back to one row per leading index.
    nibbles = torch.stack((as_bytes & 15, as_bytes >> 4), dim=-1)
    flat = nibbles.view(rows, -1)
    return flat.view(torch.int32)
# The lookup-table fast paths below reinterpret packed bytes as torch.uint16.
# That dtype is absent from older PyTorch builds (the message below names
# 2.3 as the cut-off), so the fast paths are switched off when it is missing.
disable_all_optimizations = not hasattr(torch, 'uint16')

if disable_all_optimizations:
    print('You are using PyTorch below version 2.3. Some optimizations will be disabled.')
else:
    # One entry per possible 16-bit pattern (0 .. 65535). Row i of each table
    # holds the int32 that packs the four decoded nibbles of pattern i.
    _every_uint16_code = torch.arange(start=0, end=256 * 256, dtype=torch.long).to(torch.uint16)
    native_4bits_lookup_table = native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits(_every_uint16_code)[:, 0]
    native_4bits_lookup_table_u = native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits_u(_every_uint16_code)[:, 0]
    # Temporary only; do not leave an extra module-level name behind.
    del _every_uint16_code
def _quick_unpack_4bits_reference(x):
    """Element-wise signed 4-bit unpack used when the lookup table cannot be used.

    Each byte contributes its low nibble, then its high nibble, shifted from
    [0, 15] to [-8, 7] and returned as ``torch.int8``. The result has shape
    ``(x.size(0), 2 * <elements per row of x>)``. Presumably ``x`` is a
    ``torch.uint8`` tensor of packed nibbles (TODO confirm with callers).
    """
    nibbles = torch.stack([x & 15, x >> 4], dim=-1)
    # Give the row width explicitly: with -1, a zero-row input cannot be
    # reshaped because the inferred dimension would be ambiguous.
    row_width = 2 * x.shape[1:].numel()
    return nibbles.view(x.size(0), row_width).view(torch.int8) - 8


def quick_unpack_4bits(x):
    """Unpack packed signed 4-bit values: every byte becomes two int8 values in [-8, 7].

    Dimension 0 is kept and all remaining dimensions are flattened. The result
    has shape ``(x.size(0), 2 * <elements per row of x>)`` and dtype
    ``torch.int8``. Within a row, each byte yields its low nibble followed by
    its high nibble.

    When the uint16 lookup table is available (PyTorch >= 2.3), every pair of
    bytes is decoded with one table lookup. Otherwise the element-wise
    reference path is used. The reference path is also used when ``x`` cannot
    be reinterpreted as uint16, e.g. when its last dimension has an odd number
    of bytes or is not stride-1.
    """
    global native_4bits_lookup_table
    if disable_all_optimizations:
        return _quick_unpack_4bits_reference(x)
    try:
        # Reinterpret every pair of packed bytes as one 16-bit table index.
        words = x.view(torch.uint16)
    except RuntimeError:
        # Tensor.view(dtype) to a wider dtype needs an even-sized, stride-1
        # last dimension; the reference path has no such requirement.
        return _quick_unpack_4bits_reference(x)
    if native_4bits_lookup_table.device != x.device:
        # Store the moved table globally so later calls on this device reuse it.
        native_4bits_lookup_table = native_4bits_lookup_table.to(device=x.device)
    # Each int32 table entry packs the four decoded int8 values for one word.
    decoded = torch.index_select(
        input=native_4bits_lookup_table,
        dim=0,
        index=words.to(dtype=torch.int32).flatten(),
    )
    return decoded.view(torch.int8).view(x.size(0), 2 * x.shape[1:].numel())
def _quick_unpack_4bits_u_reference(x):
    """Element-wise unsigned 4-bit unpack used when the lookup table cannot be used.

    Each element contributes ``x & 15`` then ``x >> 4``. The result keeps the
    dtype of ``x`` and has shape ``(x.size(0), 2 * <elements per row of x>)``.
    Presumably ``x`` is a ``torch.uint8`` tensor of packed nibbles (TODO
    confirm with callers).
    """
    nibbles = torch.stack([x & 15, x >> 4], dim=-1)
    # Give the row width explicitly: with -1, a zero-row input cannot be
    # reshaped because the inferred dimension would be ambiguous.
    row_width = 2 * x.shape[1:].numel()
    return nibbles.view(x.size(0), row_width)


def quick_unpack_4bits_u(x):
    """Unpack packed unsigned 4-bit values: every byte becomes two values in [0, 15].

    Dimension 0 is kept and all remaining dimensions are flattened. The result
    has shape ``(x.size(0), 2 * <elements per row of x>)``. Within a row, each
    byte yields its low nibble followed by its high nibble. The fast path
    returns ``torch.uint8``; the reference path keeps the dtype of ``x``.

    When the uint16 lookup table is available (PyTorch >= 2.3), every pair of
    bytes is decoded with one table lookup. Otherwise the element-wise
    reference path is used. The reference path is also used when ``x`` cannot
    be reinterpreted as uint16, e.g. when its last dimension has an odd number
    of bytes or is not stride-1.
    """
    global native_4bits_lookup_table_u
    if disable_all_optimizations:
        return _quick_unpack_4bits_u_reference(x)
    try:
        # Reinterpret every pair of packed bytes as one 16-bit table index.
        words = x.view(torch.uint16)
    except RuntimeError:
        # Tensor.view(dtype) to a wider dtype needs an even-sized, stride-1
        # last dimension; the reference path has no such requirement.
        return _quick_unpack_4bits_u_reference(x)
    if native_4bits_lookup_table_u.device != x.device:
        # Store the moved table globally so later calls on this device reuse it.
        native_4bits_lookup_table_u = native_4bits_lookup_table_u.to(device=x.device)
    # Each int32 table entry packs the four decoded uint8 values for one word.
    decoded = torch.index_select(
        input=native_4bits_lookup_table_u,
        dim=0,
        index=words.to(dtype=torch.int32).flatten(),
    )
    return decoded.view(torch.uint8).view(x.size(0), 2 * x.shape[1:].numel())
def change_4bits_order(x):
    """Re-pack the nibbles of each row into a different order.

    For every row, build the nibble sequence "all low nibbles in byte order,
    then all high nibbles in byte order". Then pack consecutive pairs of that
    sequence back into bytes: the first of each pair goes in the low nibble,
    the second in the high nibble. For example, the row bytes
    ``[0x10, 0x32, 0x54, 0x76]`` give the sequence ``[0, 2, 4, 6, 1, 3, 5, 7]``
    and the output ``[0x20, 0x64, 0x31, 0x75]``.

    Assumes a 2-D ``torch.uint8`` input of packed nibbles (TODO confirm at
    call sites). The output has the same shape as ``x``.
    """
    low = x & 15
    high = x >> 4
    # Stacking on dim -2 places each row's low half before its high half.
    nibble_seq = torch.stack((low, high), dim=-2).view(x.size(0), -1)
    even = nibble_seq[:, 0::2]
    odd = nibble_seq[:, 1::2]
    return even | (odd << 4)