Fix streaming_chat
Browse files- modeling_internlm.py +53 -21
modeling_internlm.py
CHANGED
@@ -20,6 +20,7 @@
|
|
20 |
""" PyTorch InternLM model."""
|
21 |
import math
|
22 |
from typing import List, Optional, Tuple, Union
|
|
|
23 |
|
24 |
import torch
|
25 |
import torch.utils.checkpoint
|
@@ -784,7 +785,6 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
|
|
784 |
do_sample: bool = True,
|
785 |
temperature: float = 0.8,
|
786 |
top_p: float = 0.8,
|
787 |
-
eos_token_id = (2, 103028),
|
788 |
**kwargs):
|
789 |
inputs = self.build_inputs(tokenizer, query, history)
|
790 |
inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
|
@@ -794,7 +794,6 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
|
|
794 |
do_sample=do_sample,
|
795 |
temperature=temperature,
|
796 |
top_p=top_p,
|
797 |
-
eos_token_id=list(eos_token_id),
|
798 |
**kwargs)
|
799 |
outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]):]
|
800 |
response = tokenizer.decode(outputs, skip_special_tokens=True)
|
@@ -811,38 +810,71 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
|
|
811 |
do_sample: bool = True,
|
812 |
temperature: float = 0.8,
|
813 |
top_p: float = 0.8,
|
814 |
-
eos_token_id = (2, 103028),
|
815 |
**kwargs):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
816 |
class ChatStreamer(BaseStreamer):
|
817 |
def __init__(self, tokenizer) -> None:
|
818 |
super().__init__()
|
819 |
self.tokenizer = tokenizer
|
820 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
821 |
def put(self, value):
|
822 |
if len(value.shape) > 1 and value.shape[0] > 1:
|
823 |
raise ValueError("ChatStreamer only supports batch size 1")
|
824 |
elif len(value.shape) > 1:
|
825 |
value = value[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
826 |
token = self.tokenizer.decode([value[-1]], skip_special_tokens=True)
|
827 |
if token.strip() != "<eoa>":
|
828 |
-
|
829 |
-
|
|
|
|
|
830 |
def end(self):
|
831 |
-
|
832 |
-
|
833 |
-
|
834 |
-
|
835 |
-
|
836 |
-
|
837 |
-
|
838 |
-
|
839 |
-
|
840 |
-
|
841 |
-
|
842 |
-
|
843 |
-
|
844 |
-
|
845 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
846 |
|
847 |
@add_start_docstrings(
|
848 |
"""
|
|
|
20 |
""" PyTorch InternLM model."""
|
21 |
import math
|
22 |
from typing import List, Optional, Tuple, Union
|
23 |
+
import threading, queue
|
24 |
|
25 |
import torch
|
26 |
import torch.utils.checkpoint
|
|
|
785 |
do_sample: bool = True,
|
786 |
temperature: float = 0.8,
|
787 |
top_p: float = 0.8,
|
|
|
788 |
**kwargs):
|
789 |
inputs = self.build_inputs(tokenizer, query, history)
|
790 |
inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
|
|
|
794 |
do_sample=do_sample,
|
795 |
temperature=temperature,
|
796 |
top_p=top_p,
|
|
|
797 |
**kwargs)
|
798 |
outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]):]
|
799 |
response = tokenizer.decode(outputs, skip_special_tokens=True)
|
|
|
810 |
do_sample: bool = True,
|
811 |
temperature: float = 0.8,
|
812 |
top_p: float = 0.8,
|
|
|
813 |
**kwargs):
|
814 |
+
"""
|
815 |
+
Return a generator in format: (response, history)
|
816 |
+
Eg.
|
817 |
+
('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')])
|
818 |
+
('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')])
|
819 |
+
"""
|
820 |
+
|
821 |
+
response_queue = queue.Queue(maxsize=20)
|
822 |
+
|
823 |
class ChatStreamer(BaseStreamer):
|
824 |
def __init__(self, tokenizer) -> None:
|
825 |
super().__init__()
|
826 |
self.tokenizer = tokenizer
|
827 |
+
self.queue = response_queue
|
828 |
+
self.query = query
|
829 |
+
self.history = history
|
830 |
+
self.response = ""
|
831 |
+
self.received_inputs = False
|
832 |
+
self.queue.put((self.response, history + [(self.query, self.response)]))
|
833 |
+
|
834 |
def put(self, value):
|
835 |
if len(value.shape) > 1 and value.shape[0] > 1:
|
836 |
raise ValueError("ChatStreamer only supports batch size 1")
|
837 |
elif len(value.shape) > 1:
|
838 |
value = value[0]
|
839 |
+
|
840 |
+
if not self.received_inputs:
|
841 |
+
# The first received value is input_ids, ignore here
|
842 |
+
self.received_inputs = True
|
843 |
+
return
|
844 |
+
|
845 |
token = self.tokenizer.decode([value[-1]], skip_special_tokens=True)
|
846 |
if token.strip() != "<eoa>":
|
847 |
+
self.response = self.response + token
|
848 |
+
history = self.history + [(self.query, self.response)]
|
849 |
+
self.queue.put((self.response, history))
|
850 |
+
|
851 |
def end(self):
|
852 |
+
self.queue.put(None)
|
853 |
+
|
854 |
+
def stream_producer():
|
855 |
+
return self.chat(
|
856 |
+
tokenizer=tokenizer,
|
857 |
+
query=query,
|
858 |
+
streamer=ChatStreamer(tokenizer=tokenizer),
|
859 |
+
history=history,
|
860 |
+
max_new_tokens=max_new_tokens,
|
861 |
+
do_sample=do_sample,
|
862 |
+
temperature=temperature,
|
863 |
+
top_p=top_p,
|
864 |
+
**kwargs
|
865 |
+
)
|
866 |
+
|
867 |
+
def consumer():
|
868 |
+
producer = threading.Thread(target=stream_producer)
|
869 |
+
producer.start()
|
870 |
+
while True:
|
871 |
+
res = response_queue.get()
|
872 |
+
if res is None:
|
873 |
+
return
|
874 |
+
yield res
|
875 |
+
|
876 |
+
return consumer()
|
877 |
+
|
878 |
|
879 |
@add_start_docstrings(
|
880 |
"""
|