update config and streaming generation
Files changed:
- config.json (+2, -2)
- modeling_qwen.py (+32, -15)
config.json CHANGED
@@ -14,12 +14,12 @@
   "fp32": false,
   "bias_dropout_fusion": true,
   "bos_token_id": 151643,
-  "embd_pdrop": 0.
+  "embd_pdrop": 0.0,
   "eos_token_id": 151643,
   "ffn_hidden_size": 22016,
   "initializer_range": 0.02,
   "kv_channels": 128,
-  "layer_norm_epsilon": 1e-
+  "layer_norm_epsilon": 1e-06,
   "model_type": "qwen",
   "n_embd": 4096,
   "n_head": 32,
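For a quick sanity check of the updated values, a minimal sketch; the checkpoint id below is an assumption, use whichever repo ships this config.json:

# Sketch only: confirm the new config values after pulling this commit.
# "Qwen/Qwen-7B-Chat" is an assumed repo id, not taken from this diff.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
print(config.embd_pdrop)          # expected 0.0 after this change
print(config.layer_norm_epsilon)  # expected 1e-06 after this change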
modeling_qwen.py CHANGED
@@ -958,8 +958,10 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         history: Optional[HistoryType],
         system: str = "You are a helpful assistant.",
         append_history: bool = True,
+        stream: Optional[bool] = False
     ) -> Tuple[str, HistoryType]:
 
+
         if history is None:
             history = []
 
@@ -976,21 +978,36 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             self.generation_config.chat_format, tokenizer
         )
         input_ids = torch.tensor([context_tokens]).to(self.device)
-        [15 removed lines: the previous unconditional generate()/decode_tokens() call, now moved under the else: branch below; their exact text is not recoverable from this view]
+        if stream:
+            assert self.generation_config.chat_format == 'chatml'
+            from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
+            self.__class__.generate = NewGenerationMixin.generate
+            self.__class__.sample_stream = NewGenerationMixin.sample_stream
+            stream_config = StreamGenerationConfig(**self.generation_config.to_dict(), do_stream=True)
+            def stream_generator():
+                outputs = []
+                for token in self.generate(input_ids, return_dict_in_generate=False, generation_config=stream_config):
+                    outputs.append(token.item())
+                    if outputs[-1] in (tokenizer.im_end_id, tokenizer.im_start_id):
+                        break
+                    yield tokenizer.decode(outputs, skip_special_tokens=True)
+
+            return stream_generator()
+        else:
+            outputs = self.generate(
+                input_ids,
+                stop_words_ids = stop_words_ids,
+                return_dict_in_generate = False,
+            )
+
+            response = decode_tokens(
+                outputs[0],
+                tokenizer,
+                raw_text_len=len(raw_text),
+                context_length=len(context_tokens),
+                chat_format=self.generation_config.chat_format,
+                verbose=False,
+            )
 
         if append_history:
             history.append((query, response))
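With this commit, chat() gains a stream flag. A minimal usage sketch, assuming the method's leading parameters are (tokenizer, query, history, ...) as in the rest of the Qwen code and that the transformers_stream_generator package is installed; the checkpoint id is an assumption:

# Sketch only: stream a chat response as it is generated.
# Requires: pip install transformers_stream_generator
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True
).eval()

# With stream=True, chat() returns a generator yielding the decoded text seen so
# far (progressively longer strings) rather than a (response, history) tuple.
for partial_text in model.chat(tokenizer, "Tell me a joke.", history=None, stream=True):
    print(partial_text)

Note that streaming is gated to the 'chatml' chat format by the assert in the diff, and decoding stops at the first im_end/im_start token.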