File size: 350 Bytes
4d061f7
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
from transformers import PreTrainedModel
from .configuration_bitnet import BitNetConfig 
from .quantization import BitLinear
from .linear_to_bitlinear import replace_linears_in_hf

class BitNetModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` subclass for BitNet.

    Declares :class:`BitNetConfig` as its configuration class. After the
    standard ``PreTrainedModel`` setup it passes itself to
    ``replace_linears_in_hf``, which presumably swaps linear layers for
    ``BitLinear`` (confirm in ``linear_to_bitlinear``). It then finishes
    construction with ``post_init()``, as transformers models are expected to.
    """

    config_class = BitNetConfig

    def __init__(self, config):
        """Build the model from ``config``.

        Args:
            config: Model configuration; expected to be a ``BitNetConfig``
                (the declared ``config_class``).
        """
        super().__init__(config)
        # NOTE(review): no submodules are created before this call, so there
        # is nothing visible here for the replacement to act on. Layers are
        # presumably meant to be built above this line; confirm.
        replace_linears_in_hf(self)
        # transformers convention: every model's __init__ ends with
        # post_init(), which runs weight initialisation and final setup that
        # from_pretrained() and friends rely on. It is called after the
        # replacement so it operates on the final module tree.
        self.post_init()