curlyfu commited on
Commit
50d7a1d
·
1 Parent(s): 658202d

支持batch形式的对话

Browse files

— 添加chat_batch方法,可以同时进行多次多伦对话


![1683539812923.png](https://img1.imgtp.com/2023/05/08/AeydGQVA.png)

Files changed (1) hide show
  1. modeling_chatglm.py +30 -13
modeling_chatglm.py CHANGED
@@ -721,7 +721,6 @@ CHATGLM_6B_START_DOCSTRING = r"""
721
  This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class.
722
  Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
723
  usage and behavior.
724
-
725
  Parameters:
726
  config ([`~ChatGLM6BConfig`]): Model configuration class with all the parameters of the model.
727
  Initializing with a config file does not load the weights associated with the model, only the configuration.
@@ -732,37 +731,28 @@ CHATGLM_6B_INPUTS_DOCSTRING = r"""
732
  Args:
733
  input_ids (`torch.LongTensor` of shape `({0})`):
734
  Indices of input sequence tokens in the vocabulary.
735
-
736
  Indices can be obtained using [`ChatGLM6BTokenizer`].
737
  See [`PreTrainedTokenizer.encode`] and
738
  [`PreTrainedTokenizer.__call__`] for details.
739
-
740
  [What are input IDs?](../glossary#input-ids)
741
  attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
742
  Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
743
-
744
  - 1 for tokens that are **not masked**,
745
  - 0 for tokens that are **masked**.
746
-
747
  [What are attention masks?](../glossary#attention-mask)
748
  token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
749
  Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`:
750
-
751
  - 0 corresponds to a *sentence A* token,
752
  - 1 corresponds to a *sentence B* token.
753
-
754
  [What are token type IDs?](../glossary#token-type-ids)
755
  position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
756
  Indices of positions of each input sequence tokens in the position embeddings.
757
  Selected in the range `[0, config.max_position_embeddings - 1]`.
758
-
759
  [What are position IDs?](../glossary#position-ids)
760
  head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
761
  Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
762
-
763
  - 1 indicates the head is **not masked**,
764
  - 0 indicates the head is **masked**.
765
-
766
  inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
767
  Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
768
  This is useful if you want more control over how to convert *input_ids* indices into associated vectors
@@ -784,13 +774,11 @@ CHATGLM_6B_INPUTS_DOCSTRING = r"""
784
  )
785
  class ChatGLMModel(ChatGLMPreTrainedModel):
786
  """
787
-
788
  The model can behave as an encoder (with only self-attention) as well
789
  as a decoder, in which case a layer of cross-attention is added between
790
  the self-attention layers, following the architecture described in [Attention is
791
  all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani,
792
  Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
793
-
794
  To behave as an decoder the model needs to be initialized with the
795
  `is_decoder` argument of the configuration set to `True`.
796
  To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder`
@@ -1237,7 +1225,6 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1237
  This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1238
  [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1239
  beam_idx at every generation step.
1240
-
1241
  Output shares the same memory storage as `past`.
1242
  """
1243
  return tuple(
@@ -1288,6 +1275,36 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1288
  response = self.process_response(response)
1289
  history = history + [(query, response)]
1290
  return response, history
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1291
 
1292
  @torch.no_grad()
1293
  def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048,
 
721
  This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class.
722
  Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
723
  usage and behavior.
 
724
  Parameters:
725
  config ([`~ChatGLM6BConfig`]): Model configuration class with all the parameters of the model.
726
  Initializing with a config file does not load the weights associated with the model, only the configuration.
 
731
  Args:
732
  input_ids (`torch.LongTensor` of shape `({0})`):
733
  Indices of input sequence tokens in the vocabulary.
 
734
  Indices can be obtained using [`ChatGLM6BTokenizer`].
735
  See [`PreTrainedTokenizer.encode`] and
736
  [`PreTrainedTokenizer.__call__`] for details.
 
737
  [What are input IDs?](../glossary#input-ids)
738
  attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
739
  Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
 
740
  - 1 for tokens that are **not masked**,
741
  - 0 for tokens that are **masked**.
 
742
  [What are attention masks?](../glossary#attention-mask)
743
  token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
744
  Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`:
 
745
  - 0 corresponds to a *sentence A* token,
746
  - 1 corresponds to a *sentence B* token.
 
747
  [What are token type IDs?](../glossary#token-type-ids)
748
  position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
749
  Indices of positions of each input sequence tokens in the position embeddings.
750
  Selected in the range `[0, config.max_position_embeddings - 1]`.
 
751
  [What are position IDs?](../glossary#position-ids)
752
  head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
753
  Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
 
754
  - 1 indicates the head is **not masked**,
755
  - 0 indicates the head is **masked**.
 
756
  inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
757
  Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
758
  This is useful if you want more control over how to convert *input_ids* indices into associated vectors
 
774
  )
775
  class ChatGLMModel(ChatGLMPreTrainedModel):
776
  """
 
777
  The model can behave as an encoder (with only self-attention) as well
778
  as a decoder, in which case a layer of cross-attention is added between
779
  the self-attention layers, following the architecture described in [Attention is
780
  all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani,
781
  Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
 
782
  To behave as an decoder the model needs to be initialized with the
783
  `is_decoder` argument of the configuration set to `True`.
784
  To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder`
 
1225
  This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1226
  [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1227
  beam_idx at every generation step.
 
1228
  Output shares the same memory storage as `past`.
1229
  """
1230
  return tuple(
 
1275
  response = self.process_response(response)
1276
  history = history + [(query, response)]
1277
  return response, history
1278
+
1279
+ @torch.no_grad()
1280
+ def chat_batch(self, tokenizer, querys=List[str], historys=List[List[Tuple[str, str]]], max_length: int = 2048, num_beams=1,
1281
+ do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
1282
+ responses = []
1283
+ prompts = []
1284
+ for query, history in zip(querys, historys):
1285
+ if history is None:
1286
+ history = []
1287
+ if logits_processor is None:
1288
+ logits_processor = LogitsProcessorList()
1289
+ logits_processor.append(InvalidScoreLogitsProcessor())
1290
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
1291
+ "temperature": temperature, "logits_processor": logits_processor, **kwargs}
1292
+ if not history:
1293
+ prompt = query
1294
+ else:
1295
+ prompt = ""
1296
+ for i, (old_query, response) in enumerate(history):
1297
+ prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
1298
+ prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
1299
+ prompts.append(prompt)
1300
+ inputs = tokenizer(prompts, return_tensors="pt", padding=True)
1301
+ inputs = inputs.to(self.device)
1302
+ outputs = self.generate(**inputs, **gen_kwargs)
1303
+ outputs = outputs.tolist()
1304
+ outputs = [x[len(inputs["input_ids"][0]):] for x in outputs]
1305
+ responses = [tokenizer.decode(output) for output in outputs]
1306
+ responses = [self.process_response(response) for response in responses]
1307
+ return responses
1308
 
1309
  @torch.no_grad()
1310
  def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048,