update

- config.json +0 -0
- modeling_baichuan.py +16 -0
config.json
CHANGED
The diff for this file is too large to render; see the raw diff.
modeling_baichuan.py
CHANGED
@@ -706,6 +706,22 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
              generation_config: Optional[GenerationConfig]=None):
         generation_config = generation_config or self.generation_config
         input_ids = build_chat_input(self, tokenizer, messages, generation_config.max_new_tokens)
+        if stream:
+            streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+            Thread(target=self.generate, kwargs=dict(
+                inputs=input_ids, streamer=streamer,
+                generation_config=generation_config,
+            )).start()
+            return streamer
+        else:
+            outputs = self.generate(input_ids, generation_config=generation_config)
+            response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
+            return response
+
+    def HuatuoChat(self, tokenizer, messages: List[dict], stream=False,
+                   generation_config: Optional[GenerationConfig]=None):
+        generation_config = generation_config or self.generation_config
+        input_ids = build_chat_input(self, tokenizer, messages, generation_config.max_new_tokens)
         if stream:
             streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
             Thread(target=self.generate, kwargs=dict(
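
For reference, a minimal usage sketch of the HuatuoChat method added by this commit. It mirrors the existing chat entry point: a non-streaming call returns the decoded response, while stream=True starts generation in a background thread and returns a TextIterStreamer to iterate over. The model id is a placeholder, and the loading flags (trust_remote_code, torch_dtype, device_map) and the presence of a generation_config.json in the repo are assumptions, not part of this diff.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model_id = "path/to/this-model-repo"  # placeholder; the actual repo id is not named in this diff

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,   # assumed dtype
    device_map="auto",           # assumed placement
    trust_remote_code=True,      # needed so HuatuoChat from modeling_baichuan.py is exposed
)
# Assumes the repository ships a generation_config.json with max_new_tokens set,
# since HuatuoChat reads generation_config.max_new_tokens in build_chat_input.
model.generation_config = GenerationConfig.from_pretrained(model_id)

messages = [{"role": "user", "content": "Hello"}]

# Non-streaming call: returns the decoded response string.
print(model.HuatuoChat(tokenizer, messages))

# Streaming call: generation runs in a background thread; iterate over the returned streamer.
# Assuming a Baichuan-style TextIterStreamer, each item is the full decoded text so far,
# so only the newly generated suffix is printed.
position = 0
for text in model.HuatuoChat(tokenizer, messages, stream=True):
    print(text[position:], end="", flush=True)
    position = len(text)
print()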