xuxw98 commited on
Commit
c8ac827
·
1 Parent(s): 2b6f2c0

Update lit_llama/utils.py

Browse files
Files changed (1) hide show
  1. lit_llama/utils.py +3 -3
lit_llama/utils.py CHANGED
@@ -89,13 +89,13 @@ class EmptyInitOnDevice(torch.overrides.TorchFunctionMode):
89
  if self.quantization_mode == 'llm.int8':
90
  if device.type != "cuda":
91
  raise ValueError("Quantization is only supported on the GPU.")
92
- from lit_llama.quantization import Linear8bitLt
93
  self.quantized_linear_cls = Linear8bitLt
94
  elif self.quantization_mode == 'gptq.int4':
95
- from lit_llama.quantization import ColBlockQuantizedLinear
96
  self.quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=4, tile_cols=-1)
97
  elif self.quantization_mode == 'gptq.int8':
98
- from lit_llama.quantization import ColBlockQuantizedLinear
99
  self.quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=8, tile_cols=-1)
100
  elif self.quantization_mode is not None:
101
  raise RuntimeError(f"unknown quantization mode {self.quantization_mode}")
 
89
  if self.quantization_mode == 'llm.int8':
90
  if device.type != "cuda":
91
  raise ValueError("Quantization is only supported on the GPU.")
92
+ from .quantization import Linear8bitLt
93
  self.quantized_linear_cls = Linear8bitLt
94
  elif self.quantization_mode == 'gptq.int4':
95
+ from .quantization import ColBlockQuantizedLinear
96
  self.quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=4, tile_cols=-1)
97
  elif self.quantization_mode == 'gptq.int8':
98
+ from .quantization import ColBlockQuantizedLinear
99
  self.quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=8, tile_cols=-1)
100
  elif self.quantization_mode is not None:
101
  raise RuntimeError(f"unknown quantization mode {self.quantization_mode}")