[update]add main
main.py
CHANGED
@@ -73,6 +73,7 @@ def chat_with_llm_non_stream(question: str,
                              history: List[Tuple[str, str]],
                              pretrained_model_name_or_path: str,
                              max_new_tokens: int, top_p: float, temperature: float, repetition_penalty: float,
+                             history_max_len: int,
                              ):
     device: str = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -90,7 +91,8 @@ def chat_with_llm_non_stream(question: str,
     for input_ids_ in batch_input_ids:
         input_ids.extend(input_ids_)
     input_ids.append(tokenizer.eos_token_id)
-    input_ids = torch.tensor([input_ids], dtype=torch.long)
+    input_ids = torch.tensor([input_ids], dtype=torch.long)
+    input_ids = input_ids[:, -history_max_len:].to(device)
 
     with torch.no_grad():
         outputs = model.generate(
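The two added lines are a plain left-truncation on the token dimension: only the most recent history_max_len token ids survive, so the prompt stops growing once the accumulated chat history exceeds that budget. A minimal standalone sketch of the effect (the tensor is invented for illustration; the names mirror the diff):

import torch

# Pretend the flattened prompt plus history is 10 tokens long.
input_ids = torch.tensor([list(range(10))], dtype=torch.long)
history_max_len = 4

# Keep only the last history_max_len tokens along the sequence dimension.
input_ids = input_ids[:, -history_max_len:]
print(input_ids)  # tensor([[6, 7, 8, 9]])

One edge case worth noting: with history_max_len = 0 the slice [:, -0:] keeps the whole sequence rather than none of it, because -0 == 0 in Python.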
@@ -114,6 +116,7 @@ def chat_with_llm_streaming(question: str,
                             history: List[Tuple[str, str]],
                             pretrained_model_name_or_path: str,
                             max_new_tokens: int, top_p: float, temperature: float, repetition_penalty: float,
+                            history_max_len: int,
                             ):
     device: str = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -131,7 +134,8 @@ def chat_with_llm_streaming(question: str,
     for input_ids_ in batch_input_ids:
         input_ids.extend(input_ids_)
     input_ids.append(tokenizer.eos_token_id)
-    input_ids = torch.tensor([input_ids], dtype=torch.long)
+    input_ids = torch.tensor([input_ids], dtype=torch.long)
+    input_ids = input_ids[:, -history_max_len:].to(device)
 
     streamer = TextIteratorStreamer(tokenizer=tokenizer)
 
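For context on the streamer line: TextIteratorStreamer is the transformers helper that turns generate() into an iterator of decoded text chunks, with generation running in a background thread while the caller drains the stream. A minimal sketch of that pattern, assuming gpt2 purely to keep the example small (it is not one of the models this Space serves):

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
streamer = TextIteratorStreamer(tokenizer=tokenizer, skip_prompt=True)

# generate() blocks until finished, so it runs in a worker thread;
# the streamer then yields decoded text pieces as tokens arrive.
thread = Thread(target=model.generate, kwargs=dict(
    inputs=input_ids, streamer=streamer, max_new_tokens=32,
))
thread.start()
for new_text in streamer:
    print(new_text, end="", flush=True)
thread.join()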
@@ -190,17 +194,25 @@ def main():
                 temperature = gr.Slider(minimum=0, maximum=1, value=0.35, step=0.01, label="temperature")
             with gr.Column(scale=1):
                 repetition_penalty = gr.Slider(minimum=0, maximum=2, value=1.2, step=0.01, label="repetition_penalty")
+            with gr.Column(scale=1):
+                history_max_len = gr.Slider(minimum=0, maximum=4096, value=1024, step=1, label="history_max_len")
 
         with gr.Row():
-            model_name = gr.Dropdown(
-
-
-
+            model_name = gr.Dropdown(
+                choices=[
+                    "Qwen/Qwen-7B-Chat",
+                    "THUDM/chatglm2-6b",
+                    "baichuan-inc/Baichuan2-7B-Chat",
+                ],
+                value="Qwen/Qwen-7B-Chat",
+                label="model_name",
+            )
         gr.Examples(examples=["你好"], inputs=text_box)
 
         inputs = [
             text_box, chatbot, model_name,
             max_new_tokens, top_p, temperature, repetition_penalty,
+            history_max_len
         ]
         outputs = [
             chatbot
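For readers unfamiliar with Gradio wiring: each component listed in inputs is passed positionally to the event handler, so appending history_max_len to the inputs list in the hunk above is what delivers the slider's value as the chat functions' new parameter. A small self-contained sketch of the same mechanism, with an invented echo handler standing in for the real chat function:

import gradio as gr

def echo(message: str, history_max_len: float):
    # Stand-in for the chat function: truncate to the last N characters,
    # analogous to the Space truncating to the last N token ids.
    return message[-int(history_max_len):] if history_max_len else message

with gr.Blocks() as demo:
    text_box = gr.Textbox(label="message")
    history_max_len = gr.Slider(minimum=0, maximum=4096, value=1024, step=1, label="history_max_len")
    reply = gr.Textbox(label="reply")
    text_box.submit(fn=echo, inputs=[text_box, history_max_len], outputs=[reply])

demo.launch()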