duzx16 committed
Commit: 8049563
Parent(s): 71189e7

Fix prefix projection

modeling_chatglm.py CHANGED (+7 -6)
@@ -68,11 +68,12 @@ class PrefixEncoder(torch.nn.Module):
         self.prefix_projection = config.prefix_projection
         if self.prefix_projection:
             # Use a two-layer MLP to encode the prefix
-
+            kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2
+            self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
             self.trans = torch.nn.Sequential(
-                torch.nn.Linear(
+                torch.nn.Linear(kv_size, config.hidden_size),
                 torch.nn.Tanh(),
-                torch.nn.Linear(config.hidden_size,
+                torch.nn.Linear(config.hidden_size, kv_size)
             )
         else:
             self.embedding = torch.nn.Embedding(config.pre_seq_len,
@@ -1013,7 +1014,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         inputs = inputs.to(self.device)
         return inputs
 
-    @torch.
+    @torch.inference_mode()
     def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192, num_beams=1,
              do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, **kwargs):
         if history is None:
@@ -1031,7 +1032,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         history = history + [(query, response)]
         return response, history
 
-    @torch.
+    @torch.inference_mode()
     def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, past_key_values=None,
                     max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
                     return_past_key_values=False, **kwargs):
@@ -1068,7 +1069,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         else:
             yield response, new_history
 
-    @torch.
+    @torch.inference_mode()
     def stream_generate(
         self,
         input_ids,