yangapku committed
Commit ff3a904
1 parent: e3edce3

deprecate argument stream in model.chat()

Files changed (1)
  1. modeling_qwen.py +23 -36
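
For callers, the migration is mechanical: streaming now goes through the dedicated chat_stream() method rather than a flag on chat(). A minimal sketch under the usual Qwen-7B-Chat loading recipe; the exact chat_stream() keyword arguments are an assumption and may differ across revisions:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True
    ).eval()

    # Deprecated by this commit -- now trips the sentinel assert:
    # model.chat(tokenizer, "Hello", history=None, stream=True)

    # Supported paths: chat() returns the full reply, chat_stream() yields partials.
    response, history = model.chat(tokenizer, "Hello", history=None)
    for partial in model.chat_stream(tokenizer, "Hello", history=None):
        print(partial)  # progressively longer partial replies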
modeling_qwen.py CHANGED
@@ -60,6 +60,12 @@ If you are directly using the model downloaded from Huggingface, please make sur
 如果您在直接使用我们从Huggingface提供的模型,请确保您在调用model.chat()时,使用的是"Qwen/Qwen-7B-Chat"模型(而非"Qwen/Qwen-7B"预训练模型)。
 """
 
+_SENTINEL = object()
+_ERROR_STREAM_IN_CHAT = """\
+Pass argument `stream` to model.chat() is buggy, deprecated, and marked for removal. Please use model.chat_stream(...) instead of model.chat(..., stream=True).
+向model.chat()传入参数stream的用法可能存在Bug,该用法已被废弃,将在未来被移除。请使用model.chat_stream(...)代替model.chat(..., stream=True)。
+"""
+
 apply_rotary_emb_func = None
 rms_norm = None
 flash_attn_unpadded_func = None
@@ -977,10 +983,11 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         history: Optional[HistoryType],
         system: str = "You are a helpful assistant.",
         append_history: bool = True,
-        stream: Optional[bool] = False,
+        stream: Optional[bool] = _SENTINEL,
         stop_words_ids: Optional[List[List[int]]] = None,
         **kwargs,
     ) -> Tuple[str, HistoryType]:
+        assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT
         assert self.generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
         if history is None:
             history = []
@@ -1000,41 +1007,21 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             self.generation_config.chat_format, tokenizer
         ))
         input_ids = torch.tensor([context_tokens]).to(self.device)
-        if stream:
-            logger.warn(
-                "[WARNING] This usage is deprecated and marked for removal."
-                "Please use chat_stream() instead of chat(stream=True)."
-            )
-            from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
-            self.__class__.generate_stream = NewGenerationMixin.generate
-            self.__class__.sample_stream = NewGenerationMixin.sample_stream
-            stream_config = StreamGenerationConfig(**self.generation_config.to_dict(), do_stream=True)
-            def stream_generator():
-                outputs = []
-                for token in self.generate_stream(
-                        input_ids, return_dict_in_generate=False, generation_config=stream_config, **kwargs):
-                    outputs.append(token.item())
-                    if outputs[-1] in (tokenizer.im_end_id, tokenizer.im_start_id):
-                        break
-                    yield tokenizer.decode(outputs, skip_special_tokens=True)
-
-            return stream_generator()
-        else:
-            outputs = self.generate(
-                input_ids,
-                stop_words_ids = stop_words_ids,
-                return_dict_in_generate = False,
-                **kwargs,
-            )
-
-            response = decode_tokens(
-                outputs[0],
-                tokenizer,
-                raw_text_len=len(raw_text),
-                context_length=len(context_tokens),
-                chat_format=self.generation_config.chat_format,
-                verbose=False,
-            )
+        outputs = self.generate(
+            input_ids,
+            stop_words_ids = stop_words_ids,
+            return_dict_in_generate = False,
+            **kwargs,
+        )
+
+        response = decode_tokens(
+            outputs[0],
+            tokenizer,
+            raw_text_len=len(raw_text),
+            context_length=len(context_tokens),
+            chat_format=self.generation_config.chat_format,
+            verbose=False,
+        )
 
         if append_history:
             history.append((query, response))
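
The enforcement relies on a standard Python idiom: a fresh object() as the default value is a sentinel that no caller-supplied argument can be identical to, so chat() can tell "stream was omitted" apart from every explicit value, including stream=False and stream=None, which the old default of False could not. (The removed branch also monkey-patched generate onto self.__class__ at call time, mutating shared class state, which is consistent with the "buggy" wording in the new error message.) A standalone sketch of the sentinel pattern, with illustrative names not taken from the Qwen source:

    _SENTINEL = object()  # unique marker; no caller value can be `is`-identical to it

    def chat(query, stream=_SENTINEL):
        # `is` catches any explicitly passed value, even stream=False or stream=None,
        # which a plain default of False could not distinguish from omission.
        assert stream is _SENTINEL, (
            "`stream` is deprecated and marked for removal; use chat_stream() instead."
        )
        return f"response to {query!r}"

    chat("hi")                 # fine
    # chat("hi", stream=True)  # AssertionError carrying the migration hint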