Commit: Fix embedding quantization

Changed file: modeling_chatglm.py (+10 −5)
@@ -1408,6 +1408,11 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         self.transformer = quantize(self.transformer, bits, use_quantization_cache=use_quantization_cache, empty_init=empty_init, **kwargs)

+        if self.device == torch.device("cpu"):
+            dtype = torch.float32
+        else:
+            dtype = torch.half
+
         if quantize_embeddings:
             logger.info("Applying quantization to embeddings")
             self.transformer.word_embeddings = QuantizedEmbedding(
@@ -1415,11 +1420,11 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
                 weight_tensor=self.transformer.word_embeddings.weight.to(self.device),
                 num_embeddings=self.transformer.word_embeddings.num_embeddings,
                 embedding_dim=self.transformer.word_embeddings.embedding_dim,
-                dtype=…                    [removed line truncated in extraction; original value not recoverable]
-                empty_init=…               [removed line truncated in extraction; original value not recoverable]
+                dtype=dtype,
+                empty_init=empty_init,
                 device=self.transformer.word_embeddings.weight.device,
             )
-            self.lm_head =…                [removed line truncated in extraction; original value not recoverable]
+            self.lm_head = QuantizedLinear(
                 weight_bit_width=bits,
                 weight_tensor=self.lm_head.weight.to(self.device),
                 bias_tensor=None,
@@ -1428,8 +1433,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
                 bias=False,
                 quantized_weight=self.transformer.word_embeddings.weight,
                 quantized_weight_scale=self.transformer.word_embeddings.weight_scale,
-                dtype=…                    [removed line truncated in extraction; original value not recoverable]
-                empty_init=…               [removed line truncated in extraction; original value not recoverable]
+                dtype=dtype,
+                empty_init=empty_init,
                 device=self.lm_head.weight.device,
             )