wuzhiying commited on
Commit
35e3d5a
1 Parent(s): be60123

fix BitsAndBytesConfig 8bit issues

Browse files
Files changed (1) hide show
  1. modeling_baichuan.py +2 -1
modeling_baichuan.py CHANGED
@@ -534,7 +534,8 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
534
  super().__init__(config, *model_args, **model_kwargs)
535
  self.model = BaichuanModel(config)
536
  self.lm_head = NormHead(config.hidden_size, config.vocab_size, bias=False)
537
- if hasattr(config, "quantization_config") and config.quantization_config['load_in_4bit']:
 
538
  try:
539
  from .quantizer import quantize_offline, init_model_weight_int4
540
  except ImportError:
 
534
  super().__init__(config, *model_args, **model_kwargs)
535
  self.model = BaichuanModel(config)
536
  self.lm_head = NormHead(config.hidden_size, config.vocab_size, bias=False)
537
+ #if hasattr(config, "quantization_config") and config.quantization_config['load_in_4bit']:
538
+ if hasattr(config, "quantization_config") and isinstance(config.quantization_config, dict) and config.quantization_config.get('load_in_4bit', False):
539
  try:
540
  from .quantizer import quantize_offline, init_model_weight_int4
541
  except ImportError: