import argparse
import copy
import gc

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_gptq.nn_modules.qlinear.qlinear_exllama import QuantLinear
from marlin import Layer as MarlinLayer

parser = argparse.ArgumentParser()
parser.add_argument("--model-id", type=str)
parser.add_argument("--save-path", type=str)
parser.add_argument("--do-generation", action="store_true")


def _validate_compatibility(model):
    if not hasattr(model.config, "quantization_config"):
        raise ValueError("Must be a quantized model to convert to Marlin format")
    quantization_config = model.config.quantization_config
    if quantization_config.quant_method != "gptq":
        raise ValueError(f"Only GPTQ models can be converted to Marlin format. You passed a model with quant_method={quantization_config.quant_method}")
    if quantization_config.bits != 4:
        raise ValueError(f"Only 4-bit quantized models can be converted to Marlin format. You passed a model with bits={quantization_config.bits}")
    if quantization_config.group_size != 128:
        raise ValueError(f"Only models with group_size=128 can be converted to Marlin format. You passed a model with group_size={quantization_config.group_size}")
    if not quantization_config.sym:
        raise ValueError(f"Only models with symmetric quantization can be converted to Marlin format. You passed a model with sym={quantization_config.sym}")
    if quantization_config.desc_act:
        raise ValueError(f"Models with act-order quantization cannot be converted to Marlin format. You passed a model with desc_act={quantization_config.desc_act}")


@torch.no_grad()
def unpack_4bit_to_32bit_signed(qweight, qzeros):
    # Unpack 4-bit values and interpret them as signed integers.
    # qweight packs 8 values per int32 along dim 0; qzeros packs 8 values per int32 along dim 1.
    unpacked_weights = torch.zeros(
        (qweight.shape[0] * 8, qweight.shape[1]),
        dtype=torch.int8, device=qweight.device, requires_grad=False)
    unpacked_zeros = torch.zeros(
        (qzeros.shape[0], qzeros.shape[1] * 8),
        dtype=torch.int8, device=qzeros.device, requires_grad=False)

    for row in range(unpacked_weights.shape[0]):
        i = row % 8
        unpacked_weights[row, :] = (qweight[row // 8, :] >> (4 * i)) & 0xF

    for col in range(unpacked_zeros.shape[1]):
        i = col % 8
        unpacked_zeros[:, col] = (qzeros[:, col // 8] >> (4 * i)) & 0xF

    return unpacked_weights, unpacked_zeros + 1


@torch.no_grad()
def dequantize_weight(layer):
    qweight, qzeros, scales = layer.qweight, layer.qzeros, layer.scales
    unpacked_qweight, unpacked_qzeros = unpack_4bit_to_32bit_signed(qweight, qzeros)
    group_size = unpacked_qweight.shape[0] // scales.shape[0]
    scales = scales.repeat_interleave(group_size, dim=0)
    unpacked_qzeros = unpacked_qzeros.repeat_interleave(group_size, dim=0)
    unpacked_qweight = (unpacked_qweight - unpacked_qzeros) * scales

    return unpacked_qweight.T


@torch.no_grad()
def convert_model(model, verbose=True):
    for name, module in model.named_modules():
        if not isinstance(module, QuantLinear):
            continue

        if verbose:
            print(f"--- Converting Module: {name}")
        parent_name = ".".join(name.split(".")[:-1])
        layer_name = name[len(parent_name) + 1:]

        # Dequantize the weight.
        dequantized_weight = dequantize_weight(module).to(torch.float16)
        linear_module = torch.nn.Linear(
            in_features=dequantized_weight.shape[1],
            out_features=dequantized_weight.shape[0],
            bias=False,
            dtype=torch.float16,
            device="cuda")
        linear_module.weight.data.copy_(dequantized_weight)

        # Repack into a Marlin layer and copy it into the model.
        new_module = MarlinLayer(
            infeatures=linear_module.in_features,
            outfeatures=linear_module.out_features,
            groupsize=model.config.quantization_config.group_size)
        new_module.pack(linear_module, scales=copy.deepcopy(module.scales.data.t()))

        # Save to parent.
        parent_module = model.get_submodule(parent_name)
        setattr(parent_module, layer_name, new_module)

        # Free CUDA memory.
        del dequantized_weight, module
        torch.cuda.empty_cache()
        gc.collect()

    return model


@torch.no_grad()
def dequantize_model(model, verbose=True):
    for name, module in model.named_modules():
        if not isinstance(module, QuantLinear):
            continue

        if verbose:
            print(f"--- Dequantizing Module: {name}")
        parent_name = ".".join(name.split(".")[:-1])
        layer_name = name[len(parent_name) + 1:]

        # Dequantize the weight.
        dequantized_weight = dequantize_weight(module)
        dequantized_weight_cpu = dequantized_weight.to("cpu")

        # Create a dense linear layer and copy it into the model.
        new_module = torch.nn.Linear(
            in_features=dequantized_weight_cpu.shape[1],
            out_features=dequantized_weight_cpu.shape[0],
            bias=False,
            dtype=torch.float16)
        new_module.weight.data.copy_(dequantized_weight_cpu)
        new_module.scales = torch.nn.Parameter(copy.deepcopy(module.scales.data))

        # Save to parent.
        parent_module = model.get_submodule(parent_name)
        setattr(parent_module, layer_name, new_module)

        # Free CUDA memory.
        del dequantized_weight, dequantized_weight_cpu, module
        torch.cuda.empty_cache()

    return model


if __name__ == "__main__":
    args = parser.parse_args()
    model_id = args.model_id
    save_path = args.save_path
    do_generation = args.do_generation

    print("Loading gptq model...")
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Validate that this model is compatible with Marlin.
    print("Validating compatibility...")
    _validate_compatibility(model)

    # Convert the model to Marlin format.
    print("Converting model...")
    model = convert_model(model).to("cpu")

    # Save after updating the quantization config.
    print("Saving marlin model...")
    model.config.quantization_config = {
        "group_size": model.config.quantization_config.group_size,
        "quant_method": "marlin"
    }
    model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)

    if do_generation:
        print("Generating sample text...")
        model.to("cuda")
        prompt = "My favorite song is"
        inputs = tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.to("cuda") for k, v in inputs.items()}
        outputs = model.generate(**inputs, max_new_tokens=50, do_sample=False)
        print(tokenizer.batch_decode(outputs)[0])
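
# Example invocation (a sketch: the script filename and the model id below are
# placeholders, not defined anywhere in this script; only the flags come from
# the argparse setup above):
#
#   python convert_gptq_to_marlin.py \
#       --model-id <hub-id-or-path-of-4bit-gptq-model> \
#       --save-path <output-directory> \
#       --do-generation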