from torch import nn

# NOTE(review): weight_quant / activation_quant were referenced in
# final_quantization but never imported in the original file; assuming they
# live in .quantization alongside BitLinear — confirm against that module.
from .quantization import BitLinear, activation_quant, weight_quant


def replace_linears_in_hf(model, name_skip='lm_head'):
    """
    Recursively replace every ``nn.Linear`` in the model with ``BitLinear``.

    Args:
        model (nn.Module): The model to modify in place.
        name_skip (str): Child-attribute name to leave untouched. Defaults to
            ``'lm_head'`` (the output projection of HF causal-LM models).

    Returns:
        None
    """
    for name, module in model.named_children():
        if isinstance(module, nn.Linear) and name != name_skip:
            # Swap in a BitLinear with matching dimensions. The new layer's
            # parameters are freshly initialized — the original Linear's
            # weights are NOT copied over.
            setattr(
                model,
                name,
                BitLinear(
                    in_features=module.in_features,
                    out_features=module.out_features,
                    bias=module.bias is not None,
                ),
            )
        else:
            # Recurse into child modules, propagating name_skip. The original
            # dropped it here, so a non-default skip name only took effect at
            # the top level of the module tree.
            replace_linears_in_hf(module, name_skip=name_skip)


def final_quantization(model):
    """
    Recursively quantize the parameters of every ``BitLinear`` in the model.

    Weights are quantized in place with ``weight_quant``; biases (when
    present) with ``activation_quant`` at the module's ``input_bits``
    resolution.

    Args:
        model (nn.Module): The model whose BitLinear layers are quantized
            in place.

    Returns:
        None
    """
    for module in model.children():
        if isinstance(module, BitLinear):
            # Quantize the module's weights and bias directly.
            module.weight.data = weight_quant(module.weight.data)
            if module.bias is not None:
                module.bias.data = activation_quant(
                    module.bias.data, module.input_bits
                )
        else:
            # Recursively apply to child modules.
            final_quantization(module)