yangapku committed
Commit 2db302e
Parent: 50ea631

fix chat streaming

Files changed (1): modeling_qwen.py (+79 −10)
modeling_qwen.py CHANGED
@@ -5,7 +5,7 @@
 
 import importlib
 import math
-from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List
+from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator
 
 import torch
 import torch.nn.functional as F
@@ -53,6 +53,13 @@ _CONFIG_FOR_DOC = "QWenConfig"
 
 QWen_PRETRAINED_MODEL_ARCHIVE_LIST = ["qwen-7b"]
 
+_ERROR_BAD_CHAT_FORMAT = """\
+We detect you are probably using the pretrained model (rather than chat model) for chatting, since the chat_format in generation_config is not "chatml".
+If you are directly using the model downloaded from Huggingface, please make sure you are using our "Qwen/Qwen-7B-Chat" Huggingface model (rather than "Qwen/Qwen-7B") when you call model.chat().
+我们检测到您可能在使用预训练模型(而非chat模型)进行多轮chat,因为您当前在generation_config指定的chat_format,并未设置为我们在对话中所支持的"chatml"格式。
+如果您在直接使用我们从Huggingface提供的模型,请确保您在调用model.chat()时,使用的是"Qwen/Qwen-7B-Chat"模型(而非"Qwen/Qwen-7B"预训练模型)。
+"""
+
 apply_rotary_emb_func = None
 rms_norm = None
 flash_attn_unpadded_func = None
@@ -971,6 +978,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         stop_words_ids: Optional[List[List[int]]] = None,
         **kwargs,
     ) -> Tuple[str, HistoryType]:
+        assert self.generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
         if history is None:
             history = []
         if stop_words_ids is None:
@@ -990,14 +998,17 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         ))
         input_ids = torch.tensor([context_tokens]).to(self.device)
         if stream:
-            assert self.generation_config.chat_format == 'chatml'
+            logger.warn(
+                "[WARNING] This usage is deprecated and marked for removal."
+                "Please use chat_stream() instead of chat(stream=True)."
+            )
             from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
-            self.__class__.generate = NewGenerationMixin.generate
+            self.__class__.generate_stream = NewGenerationMixin.generate
             self.__class__.sample_stream = NewGenerationMixin.sample_stream
             stream_config = StreamGenerationConfig(**self.generation_config.to_dict(), do_stream=True)
             def stream_generator():
                 outputs = []
-                for token in self.generate(
+                for token in self.generate_stream(
                         input_ids, return_dict_in_generate=False, generation_config=stream_config, **kwargs):
                     outputs.append(token.item())
                     if outputs[-1] in (tokenizer.im_end_id, tokenizer.im_start_id):
@@ -1027,6 +1038,62 @@ class QWenLMHeadModel(QWenPreTrainedModel):
 
         return response, history
 
+    def chat_stream(
+        self,
+        tokenizer: PreTrainedTokenizer,
+        query: str,
+        history: Optional[HistoryType],
+        system: str = "You are a helpful assistant.",
+        stop_words_ids: Optional[List[List[int]]] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
+        **kwargs,
+    ) -> Generator[str, Any, None]:
+        assert self.generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
+        if history is None:
+            history = []
+        if stop_words_ids is None:
+            stop_words_ids = []
+
+        raw_text, context_tokens = make_context(
+            tokenizer,
+            query,
+            history=history,
+            system=system,
+            max_window_size=6144,
+            chat_format=self.generation_config.chat_format,
+        )
+
+        stop_words_ids.extend(get_stop_words_ids(
+            self.generation_config.chat_format, tokenizer
+        ))
+        if stop_words_ids is not None:
+            stop_words_logits_processor = StopWordsLogitsProcessor(
+                stop_words_ids=stop_words_ids,
+                eos_token_id=self.generation_config.eos_token_id,
+            )
+            if logits_processor is None:
+                logits_processor = LogitsProcessorList([stop_words_logits_processor])
+            else:
+                logits_processor.append(stop_words_logits_processor)
+        input_ids = torch.tensor([context_tokens]).to(self.device)
+
+        from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
+        self.__class__.generate_stream = NewGenerationMixin.generate
+        self.__class__.sample_stream = NewGenerationMixin.sample_stream
+        stream_config = StreamGenerationConfig(**self.generation_config.to_dict(), do_stream=True)
+        def stream_generator():
+            outputs = []
+            for token in self.generate_stream(
+                    input_ids,
+                    return_dict_in_generate=False,
+                    generation_config=stream_config,
+                    logits_processor=logits_processor,
+                    **kwargs):
+                outputs.append(token.item())
+                yield tokenizer.decode(outputs, skip_special_tokens=True, errors='ignore')
+
+        return stream_generator()
+
     def generate(
         self,
         inputs: Optional[torch.Tensor] = None,
@@ -1037,6 +1104,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             Callable[[int, torch.Tensor], List[int]]
         ] = None,
         synced_gpus: Optional[bool] = None,
+        assistant_model: Optional["PreTrainedModel"] = None,
         streamer: Optional["BaseStreamer"] = None,
         **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
@@ -1059,12 +1127,13 @@ class QWenLMHeadModel(QWenPreTrainedModel):
 
         return super().generate(
             inputs,
-            generation_config,
-            logits_processor,
-            stopping_criteria,
-            prefix_allowed_tokens_fn,
-            synced_gpus,
-            streamer,
+            generation_config=generation_config,
+            logits_processor=logits_processor,
+            stopping_criteria=stopping_criteria,
+            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+            synced_gpus=synced_gpus,
+            assistant_model=assistant_model,
+            streamer=streamer,
             **kwargs,
         )
 
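
For reference, a minimal usage sketch of the API touched by this commit. The loading code (AutoTokenizer/AutoModelForCausalLM with trust_remote_code, device_map) is the standard transformers pattern and not part of this diff; only chat(), chat_stream(), and the renamed generate_stream() come from modeling_qwen.py.

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True
).eval()

# New in this commit: chat_stream() returns a generator that yields the
# progressively longer decoded response as tokens are produced.
for partial_response in model.chat_stream(tokenizer, "Tell me a joke.", history=None):
    print(partial_response)

# chat(stream=True) still works, but now emits a deprecation warning and
# dispatches through generate_stream() so the regular generate() method is
# no longer overwritten by the streaming patch.
response, history = model.chat(tokenizer, "Tell me a joke.", history=None)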