jymcc committed
Commit
bdeeb52
1 parent: ad19a57
Files changed (2)
  1. config.json +0 -0
  2. modeling_baichuan.py +16 -0
config.json CHANGED
The diff for this file is too large to render. See raw diff
 
modeling_baichuan.py CHANGED
@@ -706,6 +706,22 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
              generation_config: Optional[GenerationConfig]=None):
         generation_config = generation_config or self.generation_config
         input_ids = build_chat_input(self, tokenizer, messages, generation_config.max_new_tokens)
+        if stream:
+            streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+            Thread(target=self.generate, kwargs=dict(
+                inputs=input_ids, streamer=streamer,
+                generation_config=generation_config,
+            )).start()
+            return streamer
+        else:
+            outputs = self.generate(input_ids, generation_config=generation_config)
+            response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
+            return response
+
+    def HuatuoChat(self, tokenizer, messages: List[dict], stream=False,
+                   generation_config: Optional[GenerationConfig]=None):
+        generation_config = generation_config or self.generation_config
+        input_ids = build_chat_input(self, tokenizer, messages, generation_config.max_new_tokens)
         if stream:
             streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
             Thread(target=self.generate, kwargs=dict(
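
Net effect of the hunk: the body of the existing chat-style method is copied in place and a second entry point, HuatuoChat, is defined with an identical signature, so both names expose the same streaming and non-streaming generation paths. A minimal usage sketch follows; the repository id, the role/content message format, and the sample question are assumptions for illustration, not part of this commit:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Hypothetical repo id -- substitute the actual checkpoint path.
    model_path = "org/huatuo-checkpoint"
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)

    # build_chat_input is assumed to take Baichuan-style role/content messages.
    messages = [{"role": "user", "content": "What are common symptoms of the flu?"}]

    # Non-streaming: generate() runs to completion and the decoded reply is returned.
    response = model.HuatuoChat(tokenizer, messages)
    print(response)

    # Streaming: HuatuoChat returns a TextIterStreamer right away while generate()
    # runs on a background Thread; iterating the streamer yields text as it decodes.
    for chunk in model.HuatuoChat(tokenizer, messages, stream=True):
        print(chunk, end="", flush=True)

The stream=True path hands generation off to a Thread and returns the streamer immediately, so the caller consumes partial output while decoding continues in the background.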