Spaces:
Runtime error
Runtime error
# By Forge | |
import torch | |
def native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits(x):
    """Split packed bytes into signed nibbles and regroup them four per int32.

    Each row of ``x`` (everything after dim 0) is reinterpreted as raw bytes.
    Every byte contributes its low nibble followed by its high nibble, and
    each nibble is shifted from ``0..15`` to ``-8..7``. The resulting
    ``int8`` values are then reinterpreted four at a time as ``int32``.

    Args:
        x: tensor with at least one dimension. Each row must hold an even
            number of bytes so that the final ``int32`` view is valid.

    Returns:
        ``torch.int32`` tensor of shape ``(x.size(0), -1)``.
    """
    rows = x.size(0)
    as_bytes = x.view(torch.uint8).view(rows, -1)

    low_nibbles = as_bytes & 15
    high_nibbles = as_bytes >> 4
    # Interleave so that each byte yields (low, high) in that order.
    nibbles = torch.stack((low_nibbles, high_nibbles), dim=-1).view(rows, -1)

    centred = nibbles.view(torch.int8) - 8
    return centred.view(torch.int32)
def native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits_u(x):
    """Split packed bytes into raw nibbles and regroup them four per int32.

    This is the unsigned counterpart: nibbles keep their ``0..15`` value.
    Each row of ``x`` (everything after dim 0) is reinterpreted as raw bytes.
    Every byte contributes its low nibble followed by its high nibble, and
    the resulting byte values are reinterpreted four at a time as ``int32``.

    Args:
        x: tensor with at least one dimension. Each row must hold an even
            number of bytes so that the final ``int32`` view is valid.

    Returns:
        ``torch.int32`` tensor of shape ``(x.size(0), -1)``.
    """
    n = x.size(0)
    packed = x.view(torch.uint8).view(n, -1)
    nibbles = torch.stack((packed & 15, packed >> 4), dim=-1)
    return nibbles.view(n, -1).view(torch.int32)
# The lookup-table fast paths need the torch.uint16 dtype. The message below
# ties its availability to PyTorch 2.3.
disable_all_optimizations = not hasattr(torch, 'uint16')

if disable_all_optimizations:
    print('You are using PyTorch below version 2.3. Some optimizations will be disabled.')
else:
    # Every possible 16-bit code, i.e. every possible pair of packed bytes.
    _all_uint16_codes = torch.arange(start=0, end=256 * 256, dtype=torch.long).to(torch.uint16)
    # Entry i is one int32 whose four bytes are the unpacked nibbles of code i:
    # (nibble - 8) as int8 for the plain table, raw 0..15 for the "_u" table.
    native_4bits_lookup_table = native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits(_all_uint16_codes)[:, 0]
    native_4bits_lookup_table_u = native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits_u(_all_uint16_codes)[:, 0]
    del _all_uint16_codes
def quick_unpack_4bits(x):
    """Unpack two 4-bit values per byte into signed ``int8`` values.

    For every input byte the low nibble is emitted before the high nibble,
    and each nibble is re-centred from ``0..15`` to ``-8..7`` by subtracting
    8. All dimensions after the first are flattened, so the result has shape
    ``(x.size(0), -1)``.

    When optimizations are enabled, each pair of bytes is reinterpreted as a
    ``uint16`` code and mapped through ``native_4bits_lookup_table``. Each
    table entry is one ``int32`` holding the four ``int8`` results for that
    code. On a device mismatch the table is moved to ``x``'s device, and the
    moved copy replaces the module-level table.

    If ``x``'s layout cannot be reinterpreted as ``uint16``, the direct
    bit-twiddling path is used instead; it produces the same values. Such
    layouts include an odd byte count in the last dimension, a non-unit last
    stride, or an odd storage offset or stride.

    Args:
        x: tensor of packed nibbles with at least one dimension. Presumably
            ``torch.uint8`` — TODO confirm with callers.

    Returns:
        ``torch.int8`` tensor of shape ``(x.size(0), -1)``.
    """
    global native_4bits_lookup_table

    if not disable_all_optimizations:
        try:
            # Zero-copy reinterpretation of byte pairs as 16-bit codes.
            codes = x.view(torch.uint16)
        except RuntimeError:
            # Tensor.view(dtype) rejects layouts it cannot reinterpret
            # (e.g. odd last-dim size or misaligned offset/strides).
            # Fall through to the direct path below instead of failing.
            codes = None

        if codes is not None:
            if native_4bits_lookup_table.device != codes.device:
                native_4bits_lookup_table = native_4bits_lookup_table.to(device=codes.device)
            y = torch.index_select(input=native_4bits_lookup_table, dim=0, index=codes.to(dtype=torch.int32).flatten())
            # Each int32 entry carries four int8 results; split them back out.
            return y.view(torch.int8).view(x.size(0), -1)

    return torch.stack([x & 15, x >> 4], dim=-1).view(x.size(0), -1).view(torch.int8) - 8
def quick_unpack_4bits_u(x):
    """Unpack two 4-bit values per byte into unsigned ``uint8`` values.

    For every input byte the low nibble is emitted before the high nibble.
    Values stay in ``0..15``. All dimensions after the first are flattened,
    so the result has shape ``(x.size(0), -1)``.

    When optimizations are enabled, each pair of bytes is reinterpreted as a
    ``uint16`` code and mapped through ``native_4bits_lookup_table_u``. Each
    table entry is one ``int32`` holding the four byte results for that code.
    On a device mismatch the table is moved to ``x``'s device, and the moved
    copy replaces the module-level table.

    If ``x``'s layout cannot be reinterpreted as ``uint16``, the direct
    bit-twiddling path is used instead; it produces the same values. Such
    layouts include an odd byte count in the last dimension, a non-unit last
    stride, or an odd storage offset or stride.

    Args:
        x: tensor of packed nibbles with at least one dimension. Presumably
            ``torch.uint8`` — TODO confirm with callers.

    Returns:
        Tensor of shape ``(x.size(0), -1)``. It is ``torch.uint8`` on the
        lookup path; on the direct path it keeps ``x``'s dtype.
    """
    global native_4bits_lookup_table_u

    if not disable_all_optimizations:
        try:
            # Zero-copy reinterpretation of byte pairs as 16-bit codes.
            codes = x.view(torch.uint16)
        except RuntimeError:
            # Tensor.view(dtype) rejects layouts it cannot reinterpret
            # (e.g. odd last-dim size or misaligned offset/strides).
            # Fall through to the direct path below instead of failing.
            codes = None

        if codes is not None:
            if native_4bits_lookup_table_u.device != codes.device:
                native_4bits_lookup_table_u = native_4bits_lookup_table_u.to(device=codes.device)
            y = torch.index_select(input=native_4bits_lookup_table_u, dim=0, index=codes.to(dtype=torch.int32).flatten())
            # Each int32 entry carries four byte results; split them back out.
            return y.view(torch.uint8).view(x.size(0), -1)

    return torch.stack([x & 15, x >> 4], dim=-1).view(x.size(0), -1)
def change_4bits_order(x):
    """Reorder packed nibbles so that consecutive nibbles share a byte.

    Assumes ``x`` is a 2-D ``(rows, k)`` byte tensor (TODO confirm with
    callers). The rewrite works per row in two steps:

    1. Form the nibble sequence of ``2 * k`` values: the low nibbles of all
       ``k`` bytes first, then the high nibbles of all ``k`` bytes.
    2. Pack that sequence two nibbles per byte. Output byte ``j`` holds
       element ``2j`` in its low half and element ``2j + 1`` in its high half.

    Example: the row ``[0x10, 0x32, 0x54, 0x76]`` becomes
    ``[0x20, 0x64, 0x31, 0x75]``.

    Returns:
        Tensor with ``x.size(0)`` rows and the same number of bytes per row
        as ``x``.
    """
    rows = x.size(0)
    # Stacking on dim=-2 puts all low nibbles of a row ahead of all high ones.
    low_then_high = torch.stack((x & 15, x >> 4), dim=-2).view(rows, -1)
    even_nibbles = low_then_high[:, 0::2]
    odd_nibbles = low_then_high[:, 1::2]
    return (odd_nibbles << 4) | even_nibbles