from transformers import PreTrainedModel
from .configuration_bitnet import BitNetConfig
from .quantization import BitLinear
from .linear_to_bitlinear import replace_linears_in_hf
class BitNetModel(PreTrainedModel):
    """``PreTrainedModel`` subclass whose configuration class is ``BitNetConfig``.

    Construction runs the regular ``PreTrainedModel`` initialisation and then
    hands the freshly built instance to ``replace_linears_in_hf``.
    """

    # Configuration class associated with this model type.
    config_class = BitNetConfig

    def __init__(self, config):
        """Initialise the base model, then run the linear-replacement pass.

        Args:
            config: Model configuration, forwarded unchanged to
                ``PreTrainedModel.__init__``. Presumably a ``BitNetConfig``
                instance (see ``config_class``) -- confirm against callers.
        """
        super().__init__(config)
        # NOTE(review): judging by its name and module, this helper presumably
        # swaps linear layers for ``BitLinear`` in place -- confirm against
        # ``linear_to_bitlinear``. This constructor registers no submodules of
        # its own before the call.
        replace_linears_in_hf(self)