vlff李飞飞 committed on
Commit 5b67568
1 Parent(s): 7b2782a
qwen_agent/llm/qwen_oai.py CHANGED
@@ -1,490 +1,18 @@
1
  import os
2
  from typing import Dict, Iterator, List, Optional
3
-
4
  import openai
5
-
6
  from qwen_agent.llm.base import BaseChatModel
7
-
8
- import re
9
- import copy
10
- import json
11
- import time
12
- from contextlib import asynccontextmanager
13
  from typing import Dict, List, Literal, Optional, Union
14
- import torch
15
- from pydantic import BaseModel, Field
16
- from sse_starlette.sse import EventSourceResponse
17
- from transformers import AutoTokenizer, AutoModelForCausalLM
18
- from transformers.generation import GenerationConfig
19
-
20
-
21
- def _gc(forced: bool = False, disable_gc: bool = True):
22
- if disable_gc and not forced:
23
- return
24
-
25
- import gc
26
- gc.collect()
27
- if torch.cuda.is_available():
28
- torch.cuda.empty_cache()
29
-
30
-
31
- class ChatMessage(BaseModel):
32
- role: Literal["user", "assistant", "system", "function"]
33
- content: Optional[str]
34
- function_call: Optional[Dict] = None
35
-
36
-
37
- class DeltaMessage(BaseModel):
38
- role: Optional[Literal["user", "assistant", "system"]] = None
39
- content: Optional[str] = None
40
-
41
-
42
- class ChatCompletionRequest(BaseModel):
43
- model: str
44
- messages: List[ChatMessage]
45
- functions: Optional[List[Dict]] = None
46
- temperature: Optional[float] = None
47
- top_p: Optional[float] = None
48
- max_length: Optional[int] = None
49
- stream: Optional[bool] = False
50
- stop: Optional[List[str]] = None
51
-
52
-
53
- class ChatCompletionResponseChoice(BaseModel):
54
- index: int
55
- message: ChatMessage
56
- finish_reason: Literal["stop", "length", "function_call"]
57
-
58
-
59
- class ChatCompletionResponseStreamChoice(BaseModel):
60
- index: int
61
- delta: DeltaMessage
62
- finish_reason: Optional[Literal["stop", "length"]]
63
-
64
-
65
- class ChatCompletionResponse(BaseModel):
66
- model: str
67
- object: Literal["chat.completion", "chat.completion.chunk"]
68
- choices: List[
69
- Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]
70
- ]
71
- created: Optional[int] = Field(default_factory=lambda: int(time.time()))
72
-
73
-
74
- # To work around that unpleasant leading-\n tokenization issue!
75
- def add_extra_stop_words(stop_words):
76
- if stop_words:
77
- _stop_words = []
78
- _stop_words.extend(stop_words)
79
- for x in stop_words:
80
- s = x.lstrip("\n")
81
- if s and (s not in _stop_words):
82
- _stop_words.append(s)
83
- return _stop_words
84
- return stop_words
85
-
86
-
87
- def trim_stop_words(response, stop_words):
88
- if stop_words:
89
- for stop in stop_words:
90
- idx = response.find(stop)
91
- if idx != -1:
92
- response = response[:idx]
93
- return response
94
-
95
-
96
- TOOL_DESC = """{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters}"""
97
-
98
- REACT_INSTRUCTION = """Answer the following questions as best you can. You have access to the following APIs:
99
-
100
- {tools_text}
101
-
102
- Use the following format:
103
-
104
- Question: the input question you must answer
105
- Thought: you should always think about what to do
106
- Action: the action to take, should be one of [{tools_name_text}]
107
- Action Input: the input to the action
108
- Observation: the result of the action
109
- ... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
110
- Thought: I now know the final answer
111
- Final Answer: the final answer to the original input question
112
-
113
- Begin!"""
114
-
115
- _TEXT_COMPLETION_CMD = object()
116
-
117
-
118
- #
119
- # Temporarily, the system role does not work as expected.
120
- # We advise that you write the setups for role-play in your query,
121
- # i.e., use the user role instead of the system role.
122
- #
123
- # TODO: Use real system role when the model is ready.
124
- #
125
- def parse_messages(messages, functions):
126
- if all(m.role != "user" for m in messages):
127
- raise Exception(f"Invalid request: Expecting at least one user message.", )
128
- messages = copy.deepcopy(messages)
129
- default_system = "You are a helpful assistant."
130
- system = ""
131
- if messages[0].role == "system":
132
- system = messages.pop(0).content.lstrip("\n").rstrip()
133
- if system == default_system:
134
- system = ""
135
-
136
- if functions:
137
- tools_text = []
138
- tools_name_text = []
139
- for func_info in functions:
140
- name = func_info.get("name", "")
141
- name_m = func_info.get("name_for_model", name)
142
- name_h = func_info.get("name_for_human", name)
143
- desc = func_info.get("description", "")
144
- desc_m = func_info.get("description_for_model", desc)
145
- tool = TOOL_DESC.format(
146
- name_for_model=name_m,
147
- name_for_human=name_h,
148
- # Hint: You can add the following format requirements in description:
149
- # "Format the arguments as a JSON object."
150
- # "Enclose the code within triple backticks (`) at the beginning and end of the code."
151
- description_for_model=desc_m,
152
- parameters=json.dumps(func_info["parameters"], ensure_ascii=False),
153
- )
154
- tools_text.append(tool)
155
- tools_name_text.append(name_m)
156
- tools_text = "\n\n".join(tools_text)
157
- tools_name_text = ", ".join(tools_name_text)
158
- system += "\n\n" + REACT_INSTRUCTION.format(
159
- tools_text=tools_text,
160
- tools_name_text=tools_name_text,
161
- )
162
- system = system.lstrip("\n").rstrip()
163
-
164
- dummy_thought = {
165
- "en": "\nThought: I now know the final answer.\nFinal answer: ",
166
- "zh": "\nThought: 我会作答了。\nFinal answer: ",
167
- }
168
-
169
- _messages = messages
170
- messages = []
171
- for m_idx, m in enumerate(_messages):
172
- role, content, func_call = m.role, m.content, m.function_call
173
- if content:
174
- content = content.lstrip("\n").rstrip()
175
- if role == "function":
176
- if (len(messages) == 0) or (messages[-1].role != "assistant"):
177
- raise Exception("Invalid request: Expecting role assistant before role function.")
178
- messages[-1].content += f"\nObservation: {content}"
179
- if m_idx == len(_messages) - 1:
180
- messages[-1].content += "\nThought:"
181
- elif role == "assistant":
182
- if len(messages) == 0:
183
- raise Exception(f"Invalid request: Expecting role user before role assistant.")
184
- last_msg = messages[-1].content
185
- last_msg_has_zh = len(re.findall(r"[\u4e00-\u9fff]+", last_msg)) > 0
186
- if func_call is None:
187
- if functions:
188
- content = dummy_thought["zh" if last_msg_has_zh else "en"] + content
189
- else:
190
- f_name, f_args = func_call["name"], func_call["arguments"]
191
- if not content:
192
- if last_msg_has_zh:
193
- content = f"Thought: 我可以使用 {f_name} API。"
194
- else:
195
- content = f"Thought: I can use {f_name}."
196
- content = f"\n{content}\nAction: {f_name}\nAction Input: {f_args}"
197
- if messages[-1].role == "user":
198
- messages.append(
199
- ChatMessage(role="assistant", content=content.lstrip("\n").rstrip())
200
- )
201
- else:
202
- messages[-1].content += content
203
- elif role == "user":
204
- messages.append(
205
- ChatMessage(role="user", content=content.lstrip("\n").rstrip())
206
- )
207
- else:
208
- raise Exception(
209
- f"Invalid request: Incorrect role {role}."
210
- )
211
-
212
- query = _TEXT_COMPLETION_CMD
213
- if messages[-1].role == "user":
214
- query = messages[-1].content
215
- messages = messages[:-1]
216
-
217
- if len(messages) % 2 != 0:
218
- raise Exception("Invalid request")
219
-
220
- history = [] # [(Q1, A1), (Q2, A2), ..., (Q_last_turn, A_last_turn)]
221
- for i in range(0, len(messages), 2):
222
- if messages[i].role == "user" and messages[i + 1].role == "assistant":
223
- usr_msg = messages[i].content.lstrip("\n").rstrip()
224
- bot_msg = messages[i + 1].content.lstrip("\n").rstrip()
225
- if system and (i == len(messages) - 2):
226
- usr_msg = f"{system}\n\nQuestion: {usr_msg}"
227
- system = ""
228
- for t in dummy_thought.values():
229
- t = t.lstrip("\n")
230
- if bot_msg.startswith(t) and ("\nAction: " in bot_msg):
231
- bot_msg = bot_msg[len(t):]
232
- history.append([usr_msg, bot_msg])
233
- else:
234
- raise Exception("Invalid request: Expecting exactly one user (or function) role before every assistant role.")
235
- if system:
236
- assert query is not _TEXT_COMPLETION_CMD
237
- query = f"{system}\n\nQuestion: {query}"
238
- return query, history
239
-
240
-
241
- def parse_response(response):
242
- func_name, func_args = "", ""
243
- i = response.rfind("\nAction:")
244
- j = response.rfind("\nAction Input:")
245
- k = response.rfind("\nObservation:")
246
- if 0 <= i < j: # If the text has `Action` and `Action input`,
247
- if k < j: # but does not contain `Observation`,
248
- # then it is likely that `Observation` is omitted by the LLM,
249
- # because the output text may have discarded the stop word.
250
- response = response.rstrip() + "\nObservation:" # Add it back.
251
- k = response.rfind("\nObservation:")
252
- func_name = response[i + len("\nAction:"): j].strip()
253
- func_args = response[j + len("\nAction Input:"): k].strip()
254
- if func_name:
255
- choice_data = ChatCompletionResponseChoice(
256
- index=0,
257
- message=ChatMessage(
258
- role="assistant",
259
- content=response[:i],
260
- function_call={"name": func_name, "arguments": func_args},
261
- ),
262
- finish_reason="function_call",
263
- )
264
- return choice_data
265
- z = response.rfind("\nFinal Answer: ")
266
- if z >= 0:
267
- response = response[z + len("\nFinal Answer: "):]
268
- choice_data = ChatCompletionResponseChoice(
269
- index=0,
270
- message=ChatMessage(role="assistant", content=response),
271
- finish_reason="stop",
272
- )
273
- return choice_data
274
-
275
-
276
- # completion mode, not chat mode
277
- def text_complete_last_message(history, stop_words_ids, gen_kwargs):
278
- im_start = "<|im_start|>"
279
- im_end = "<|im_end|>"
280
- prompt = f"{im_start}system\nYou are a helpful assistant.{im_end}"
281
- for i, (query, response) in enumerate(history):
282
- query = query.lstrip("\n").rstrip()
283
- response = response.lstrip("\n").rstrip()
284
- prompt += f"\n{im_start}user\n{query}{im_end}"
285
- prompt += f"\n{im_start}assistant\n{response}{im_end}"
286
- prompt = prompt[: -len(im_end)]
287
-
288
- _stop_words_ids = [tokenizer.encode(im_end)]
289
- if stop_words_ids:
290
- for s in stop_words_ids:
291
- _stop_words_ids.append(s)
292
- stop_words_ids = _stop_words_ids
293
-
294
- input_ids = torch.tensor([tokenizer.encode(prompt)]).to(qmodel.device)
295
- output = qmodel.generate(input_ids, stop_words_ids=stop_words_ids, **gen_kwargs).tolist()[0]
296
- output = tokenizer.decode(output, errors="ignore")
297
- assert output.startswith(prompt)
298
- output = output[len(prompt):]
299
- output = trim_stop_words(output, ["<|endoftext|>", im_end])
300
- print(f"<completion>\n{prompt}\n<!-- *** -->\n{output}\n</completion>")
301
- return output
302
-
303
-
304
- def create_chat_completion(request: ChatCompletionRequest, qmodel, tokenizer):
305
-
306
- gen_kwargs = {}
307
- if request.temperature is not None:
308
- if request.temperature < 0.01:
309
- gen_kwargs['top_k'] = 1 # greedy decoding
310
- else:
311
- # Not recommended. Please tune top_p instead.
312
- gen_kwargs['temperature'] = request.temperature
313
- if request.top_p is not None:
314
- gen_kwargs['top_p'] = request.top_p
315
-
316
- stop_words = add_extra_stop_words(request.stop)
317
- if request.functions:
318
- stop_words = stop_words or []
319
- if "Observation:" not in stop_words:
320
- stop_words.append("Observation:")
321
-
322
- query, history = parse_messages(request.messages, request.functions)
323
-
324
- if request.stream:
325
- if request.functions:
326
- raise Exception("Invalid request: Function calling is not yet implemented for stream mode.")
327
- generate = predict(query, history, request.model, stop_words, gen_kwargs, qmodel, tokenizer)
328
- return generate
329
- # return EventSourceResponse(generate, media_type="text/event-stream")
330
-
331
- stop_words_ids = [tokenizer.encode(s) for s in stop_words] if stop_words else None
332
- if query is _TEXT_COMPLETION_CMD:
333
- response = text_complete_last_message(history, stop_words_ids=stop_words_ids, gen_kwargs=gen_kwargs)
334
- else:
335
- response, _ = qmodel.chat(
336
- tokenizer,
337
- query,
338
- history=history,
339
- stop_words_ids=stop_words_ids,
340
- **gen_kwargs
341
- )
342
- print(f"<chat>\n{history}\n{query}\n<!-- *** -->\n{response}\n</chat>")
343
- _gc()
344
-
345
- response = trim_stop_words(response, stop_words)
346
- if request.functions:
347
- choice_data = parse_response(response)
348
- else:
349
- choice_data = ChatCompletionResponseChoice(
350
- index=0,
351
- message=ChatMessage(role="assistant", content=response),
352
- finish_reason="stop",
353
- )
354
- return ChatCompletionResponse(
355
- model=request.model, choices=[choice_data], object="chat.completion"
356
- )
357
-
358
-
359
- def _dump_json(data: BaseModel, *args, **kwargs) -> str:
360
- try:
361
- return data.model_dump_json(*args, **kwargs)
362
- except AttributeError: # pydantic<2.0.0
363
- return data.json(*args, **kwargs) # noqa
364
-
365
-
366
- def predict(
367
- query: str, history: List[List[str]], model_id: str, stop_words: List[str], gen_kwargs: Dict, qmodel, tokenizer
368
- ):
369
- choice_data = ChatCompletionResponseStreamChoice(
370
- index=0, delta=DeltaMessage(role="assistant"), finish_reason=None
371
- )
372
- chunk = ChatCompletionResponse(
373
- model=model_id, choices=[choice_data], object="chat.completion.chunk"
374
- )
375
- # yield "{}".format(_dump_json(chunk, exclude_unset=True))
376
- yield chunk
377
-
378
- current_length = 0
379
- stop_words_ids = [tokenizer.encode(s) for s in stop_words] if stop_words else None
380
- if stop_words:
381
- # TODO: It's a little bit tricky to trim stop words in the stream mode.
382
- raise Exception("Invalid request: custom stop words are not yet supported for stream mode.", )
383
- response_generator = qmodel.chat_stream(
384
- tokenizer, query, history=history, stop_words_ids=stop_words_ids, **gen_kwargs
385
- )
386
- for new_response in response_generator:
387
- if len(new_response) == current_length:
388
- continue
389
-
390
- new_text = new_response[current_length:]
391
- current_length = len(new_response)
392
-
393
- choice_data = ChatCompletionResponseStreamChoice(
394
- index=0, delta=DeltaMessage(content=new_text), finish_reason=None
395
- )
396
- chunk = ChatCompletionResponse(
397
- model=model_id, choices=[choice_data], object="chat.completion.chunk"
398
- )
399
- # yield "{}".format(_dump_json(chunk, exclude_unset=True))
400
- yield chunk
401
-
402
- choice_data = ChatCompletionResponseStreamChoice(
403
- index=0, delta=DeltaMessage(), finish_reason="stop"
404
- )
405
- chunk = ChatCompletionResponse(
406
- model=model_id, choices=[choice_data], object="chat.completion.chunk"
407
- )
408
- # yield "{}".format(_dump_json(chunk, exclude_unset=True))
409
- yield chunk
410
- # yield "[DONE]"
411
-
412
- _gc()
413
 
414
 
415
  class QwenChatAsOAI(BaseChatModel):
416
 
417
  def __init__(self, model: str, api_key: str, model_server: str):
418
- self.model = model
419
  super().__init__()
420
- tokenizer = AutoTokenizer.from_pretrained(
421
- self.model,
422
- trust_remote_code=True,
423
- resume_download=True,
424
- )
425
- device_map = "cpu"
426
- # device_map = "auto"
427
- qmodel = AutoModelForCausalLM.from_pretrained(
428
- self.model,
429
- device_map=device_map,
430
- trust_remote_code=True,
431
- resume_download=True,
432
- ).eval()
433
-
434
- qmodel.generation_config = GenerationConfig.from_pretrained(
435
- self.model,
436
- trust_remote_code=True,
437
- resume_download=True,
438
- )
439
- self.qmodel = qmodel
440
- self.tokenizer = tokenizer
441
 
442
- def _chat_stream(
443
- self,
444
- messages: List[Dict],
445
- stop: Optional[List[str]] = None,
446
- ) -> Iterator[str]:
447
- _request = ChatCompletionRequest(model=self.model,
448
- messages=messages,
449
- stop=stop,
450
- stream=True)
451
- response = create_chat_completion(_request, self.qmodel, self.tokenizer)
452
- # TODO: error handling
453
- for chunk in response:
454
- if hasattr(chunk.choices[0].delta, 'content'):
455
- yield chunk.choices[0].delta.content
456
-
457
- def _chat_no_stream(
458
- self,
459
- messages: List[Dict],
460
- stop: Optional[List[str]] = None,
461
- ) -> str:
462
- _request = ChatCompletionRequest(model=self.model, messages=messages, stop=stop, stream=False)
463
- response = create_chat_completion(_request, self.qmodel, self.tokenizer)
464
- # TODO: error handling
465
- return response.choices[0].message.content
466
-
467
- def chat_with_functions(self,
468
- messages: List[Dict],
469
- functions: Optional[List[Dict]] = None) -> Dict:
470
- if functions:
471
- _request = ChatCompletionRequest(model=self.model, messages=messages, functions=functions)
472
- response = create_chat_completion(_request, self.qmodel, self.tokenizer)
473
- else:
474
- _request = ChatCompletionRequest(model=self.model, messages=messages)
475
- response = create_chat_completion(_request, self.qmodel, self.tokenizer)
476
- # TODO: error handling
477
- return response.choices[0].message.model_dump()
478
-
479
-
480
- class QwenChatAsOAI1(BaseChatModel):
481
-
482
- def __init__(self, model: str, api_key: str, model_server: str):
483
- super().__init__()
484
  if model_server.strip().lower() != 'openai':
485
- openai.api_base = model_server
486
- openai.api_key = api_key.strip() or os.getenv('OPENAI_API_KEY',
487
- default='EMPTY')
488
  self.model = model
489
 
490
  def _chat_stream(
 
1
  import os
2
  from typing import Dict, Iterator, List, Optional
 
3
  import openai
 
4
  from qwen_agent.llm.base import BaseChatModel
5
  from typing import Dict, List, Literal, Optional, Union
6
 
7
 
8
  class QwenChatAsOAI(BaseChatModel):
9
 
10
  def __init__(self, model: str, api_key: str, model_server: str):
 
11
  super().__init__()
12
 
13
  if model_server.strip().lower() != 'openai':
14
+ openai.api_base = os.getenv('OPENAI_API_BASE', model_server)
15
+ openai.api_key = api_key.strip() or os.getenv('OPENAI_API_KEY', 'EMPTY')
 
16
  self.model = model
17
 
18
  def _chat_stream(
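
With this change, QwenChatAsOAI no longer loads a local transformers model; it forwards requests to an OpenAI-compatible endpoint taken from model_server (or the OPENAI_API_BASE environment variable) with api_key (or OPENAI_API_KEY). A minimal usage sketch of that path, using the same pre-1.0 openai client calls visible in the diff (the endpoint URL and model name below are illustrative assumptions, not values from this commit):

import openai

# Assumed endpoint of a locally served OpenAI-compatible Qwen API (hypothetical values).
openai.api_base = 'http://localhost:8000/v1'
openai.api_key = 'EMPTY'

# Same openai<1.0 ChatCompletion call pattern that QwenChatAsOAI now uses internally.
response = openai.ChatCompletion.create(
    model='Qwen-7B-Chat',
    messages=[{'role': 'user', 'content': 'Hello!'}],
    stream=False,
)
print(response.choices[0].message.content)
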
qwen_agent/llm/qwen_oai_bak.py ADDED
@@ -0,0 +1,527 @@
1
+ import os
2
+ from typing import Dict, Iterator, List, Optional
3
+
4
+ import openai
5
+
6
+ from qwen_agent.llm.base import BaseChatModel
7
+
8
+ import re
9
+ import copy
10
+ import json
11
+ import time
12
+ from contextlib import asynccontextmanager
13
+ from typing import Dict, List, Literal, Optional, Union
14
+ import torch
15
+ from pydantic import BaseModel, Field
16
+ from sse_starlette.sse import EventSourceResponse
17
+ from transformers import AutoTokenizer, AutoModelForCausalLM
18
+ from transformers.generation import GenerationConfig
19
+
20
+
21
+ def _gc(forced: bool = False, disable_gc: bool = True):
22
+ if disable_gc and not forced:
23
+ return
24
+
25
+ import gc
26
+ gc.collect()
27
+ if torch.cuda.is_available():
28
+ torch.cuda.empty_cache()
29
+
30
+
31
+ class ChatMessage(BaseModel):
32
+ role: Literal["user", "assistant", "system", "function"]
33
+ content: Optional[str]
34
+ function_call: Optional[Dict] = None
35
+
36
+
37
+ class DeltaMessage(BaseModel):
38
+ role: Optional[Literal["user", "assistant", "system"]] = None
39
+ content: Optional[str] = None
40
+
41
+
42
+ class ChatCompletionRequest(BaseModel):
43
+ model: str
44
+ messages: List[ChatMessage]
45
+ functions: Optional[List[Dict]] = None
46
+ temperature: Optional[float] = None
47
+ top_p: Optional[float] = None
48
+ max_length: Optional[int] = None
49
+ stream: Optional[bool] = False
50
+ stop: Optional[List[str]] = None
51
+
52
+
53
+ class ChatCompletionResponseChoice(BaseModel):
54
+ index: int
55
+ message: ChatMessage
56
+ finish_reason: Literal["stop", "length", "function_call"]
57
+
58
+
59
+ class ChatCompletionResponseStreamChoice(BaseModel):
60
+ index: int
61
+ delta: DeltaMessage
62
+ finish_reason: Optional[Literal["stop", "length"]]
63
+
64
+
65
+ class ChatCompletionResponse(BaseModel):
66
+ model: str
67
+ object: Literal["chat.completion", "chat.completion.chunk"]
68
+ choices: List[
69
+ Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]
70
+ ]
71
+ created: Optional[int] = Field(default_factory=lambda: int(time.time()))
72
+
73
+
74
+ # To work around that unpleasant leading-\n tokenization issue!
75
+ def add_extra_stop_words(stop_words):
76
+ if stop_words:
77
+ _stop_words = []
78
+ _stop_words.extend(stop_words)
79
+ for x in stop_words:
80
+ s = x.lstrip("\n")
81
+ if s and (s not in _stop_words):
82
+ _stop_words.append(s)
83
+ return _stop_words
84
+ return stop_words
85
+
86
+
87
+ def trim_stop_words(response, stop_words):
88
+ if stop_words:
89
+ for stop in stop_words:
90
+ idx = response.find(stop)
91
+ if idx != -1:
92
+ response = response[:idx]
93
+ return response
94
+
95
+
96
+ TOOL_DESC = """{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters}"""
97
+
98
+ REACT_INSTRUCTION = """Answer the following questions as best you can. You have access to the following APIs:
99
+
100
+ {tools_text}
101
+
102
+ Use the following format:
103
+
104
+ Question: the input question you must answer
105
+ Thought: you should always think about what to do
106
+ Action: the action to take, should be one of [{tools_name_text}]
107
+ Action Input: the input to the action
108
+ Observation: the result of the action
109
+ ... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
110
+ Thought: I now know the final answer
111
+ Final Answer: the final answer to the original input question
112
+
113
+ Begin!"""
114
+
115
+ _TEXT_COMPLETION_CMD = object()
116
+
117
+
118
+ #
119
+ # Temporarily, the system role does not work as expected.
120
+ # We advise that you write the setups for role-play in your query,
121
+ # i.e., use the user role instead of the system role.
122
+ #
123
+ # TODO: Use real system role when the model is ready.
124
+ #
125
+ def parse_messages(messages, functions):
126
+ if all(m.role != "user" for m in messages):
127
+ raise Exception(f"Invalid request: Expecting at least one user message.", )
128
+ messages = copy.deepcopy(messages)
129
+ default_system = "You are a helpful assistant."
130
+ system = ""
131
+ if messages[0].role == "system":
132
+ system = messages.pop(0).content.lstrip("\n").rstrip()
133
+ if system == default_system:
134
+ system = ""
135
+
136
+ if functions:
137
+ tools_text = []
138
+ tools_name_text = []
139
+ for func_info in functions:
140
+ name = func_info.get("name", "")
141
+ name_m = func_info.get("name_for_model", name)
142
+ name_h = func_info.get("name_for_human", name)
143
+ desc = func_info.get("description", "")
144
+ desc_m = func_info.get("description_for_model", desc)
145
+ tool = TOOL_DESC.format(
146
+ name_for_model=name_m,
147
+ name_for_human=name_h,
148
+ # Hint: You can add the following format requirements in description:
149
+ # "Format the arguments as a JSON object."
150
+ # "Enclose the code within triple backticks (`) at the beginning and end of the code."
151
+ description_for_model=desc_m,
152
+ parameters=json.dumps(func_info["parameters"], ensure_ascii=False),
153
+ )
154
+ tools_text.append(tool)
155
+ tools_name_text.append(name_m)
156
+ tools_text = "\n\n".join(tools_text)
157
+ tools_name_text = ", ".join(tools_name_text)
158
+ system += "\n\n" + REACT_INSTRUCTION.format(
159
+ tools_text=tools_text,
160
+ tools_name_text=tools_name_text,
161
+ )
162
+ system = system.lstrip("\n").rstrip()
163
+
164
+ dummy_thought = {
165
+ "en": "\nThought: I now know the final answer.\nFinal answer: ",
166
+ "zh": "\nThought: 我会作答了。\nFinal answer: ",
167
+ }
168
+
169
+ _messages = messages
170
+ messages = []
171
+ for m_idx, m in enumerate(_messages):
172
+ role, content, func_call = m.role, m.content, m.function_call
173
+ if content:
174
+ content = content.lstrip("\n").rstrip()
175
+ if role == "function":
176
+ if (len(messages) == 0) or (messages[-1].role != "assistant"):
177
+ raise Exception("Invalid request: Expecting role assistant before role function.")
178
+ messages[-1].content += f"\nObservation: {content}"
179
+ if m_idx == len(_messages) - 1:
180
+ messages[-1].content += "\nThought:"
181
+ elif role == "assistant":
182
+ if len(messages) == 0:
183
+ raise Exception(f"Invalid request: Expecting role user before role assistant.")
184
+ last_msg = messages[-1].content
185
+ last_msg_has_zh = len(re.findall(r"[\u4e00-\u9fff]+", last_msg)) > 0
186
+ if func_call is None:
187
+ if functions:
188
+ content = dummy_thought["zh" if last_msg_has_zh else "en"] + content
189
+ else:
190
+ f_name, f_args = func_call["name"], func_call["arguments"]
191
+ if not content:
192
+ if last_msg_has_zh:
193
+ content = f"Thought: 我可以使用 {f_name} API。"
194
+ else:
195
+ content = f"Thought: I can use {f_name}."
196
+ content = f"\n{content}\nAction: {f_name}\nAction Input: {f_args}"
197
+ if messages[-1].role == "user":
198
+ messages.append(
199
+ ChatMessage(role="assistant", content=content.lstrip("\n").rstrip())
200
+ )
201
+ else:
202
+ messages[-1].content += content
203
+ elif role == "user":
204
+ messages.append(
205
+ ChatMessage(role="user", content=content.lstrip("\n").rstrip())
206
+ )
207
+ else:
208
+ raise Exception(
209
+ f"Invalid request: Incorrect role {role}."
210
+ )
211
+
212
+ query = _TEXT_COMPLETION_CMD
213
+ if messages[-1].role == "user":
214
+ query = messages[-1].content
215
+ messages = messages[:-1]
216
+
217
+ if len(messages) % 2 != 0:
218
+ raise Exception("Invalid request")
219
+
220
+ history = [] # [(Q1, A1), (Q2, A2), ..., (Q_last_turn, A_last_turn)]
221
+ for i in range(0, len(messages), 2):
222
+ if messages[i].role == "user" and messages[i + 1].role == "assistant":
223
+ usr_msg = messages[i].content.lstrip("\n").rstrip()
224
+ bot_msg = messages[i + 1].content.lstrip("\n").rstrip()
225
+ if system and (i == len(messages) - 2):
226
+ usr_msg = f"{system}\n\nQuestion: {usr_msg}"
227
+ system = ""
228
+ for t in dummy_thought.values():
229
+ t = t.lstrip("\n")
230
+ if bot_msg.startswith(t) and ("\nAction: " in bot_msg):
231
+ bot_msg = bot_msg[len(t):]
232
+ history.append([usr_msg, bot_msg])
233
+ else:
234
+ raise Exception("Invalid request: Expecting exactly one user (or function) role before every assistant role.")
235
+ if system:
236
+ assert query is not _TEXT_COMPLETION_CMD
237
+ query = f"{system}\n\nQuestion: {query}"
238
+ return query, history
239
+
240
+
241
+ def parse_response(response):
242
+ func_name, func_args = "", ""
243
+ i = response.rfind("\nAction:")
244
+ j = response.rfind("\nAction Input:")
245
+ k = response.rfind("\nObservation:")
246
+ if 0 <= i < j: # If the text has `Action` and `Action input`,
247
+ if k < j: # but does not contain `Observation`,
248
+ # then it is likely that `Observation` is omitted by the LLM,
249
+ # because the output text may have discarded the stop word.
250
+ response = response.rstrip() + "\nObservation:" # Add it back.
251
+ k = response.rfind("\nObservation:")
252
+ func_name = response[i + len("\nAction:"): j].strip()
253
+ func_args = response[j + len("\nAction Input:"): k].strip()
254
+ if func_name:
255
+ choice_data = ChatCompletionResponseChoice(
256
+ index=0,
257
+ message=ChatMessage(
258
+ role="assistant",
259
+ content=response[:i],
260
+ function_call={"name": func_name, "arguments": func_args},
261
+ ),
262
+ finish_reason="function_call",
263
+ )
264
+ return choice_data
265
+ z = response.rfind("\nFinal Answer: ")
266
+ if z >= 0:
267
+ response = response[z + len("\nFinal Answer: "):]
268
+ choice_data = ChatCompletionResponseChoice(
269
+ index=0,
270
+ message=ChatMessage(role="assistant", content=response),
271
+ finish_reason="stop",
272
+ )
273
+ return choice_data
274
+
275
+
276
+ # completion mode, not chat mode
277
+ def text_complete_last_message(history, stop_words_ids, gen_kwargs):
278
+ im_start = "<|im_start|>"
279
+ im_end = "<|im_end|>"
280
+ prompt = f"{im_start}system\nYou are a helpful assistant.{im_end}"
281
+ for i, (query, response) in enumerate(history):
282
+ query = query.lstrip("\n").rstrip()
283
+ response = response.lstrip("\n").rstrip()
284
+ prompt += f"\n{im_start}user\n{query}{im_end}"
285
+ prompt += f"\n{im_start}assistant\n{response}{im_end}"
286
+ prompt = prompt[: -len(im_end)]
287
+
288
+ _stop_words_ids = [tokenizer.encode(im_end)]
289
+ if stop_words_ids:
290
+ for s in stop_words_ids:
291
+ _stop_words_ids.append(s)
292
+ stop_words_ids = _stop_words_ids
293
+
294
+ input_ids = torch.tensor([tokenizer.encode(prompt)]).to(qmodel.device)
295
+ output = qmodel.generate(input_ids, stop_words_ids=stop_words_ids, **gen_kwargs).tolist()[0]
296
+ output = tokenizer.decode(output, errors="ignore")
297
+ assert output.startswith(prompt)
298
+ output = output[len(prompt):]
299
+ output = trim_stop_words(output, ["<|endoftext|>", im_end])
300
+ print(f"<completion>\n{prompt}\n<!-- *** -->\n{output}\n</completion>")
301
+ return output
302
+
303
+
304
+ def create_chat_completion(request: ChatCompletionRequest, qmodel, tokenizer):
305
+
306
+ gen_kwargs = {}
307
+ if request.temperature is not None:
308
+ if request.temperature < 0.01:
309
+ gen_kwargs['top_k'] = 1 # greedy decoding
310
+ else:
311
+ # Not recommended. Please tune top_p instead.
312
+ gen_kwargs['temperature'] = request.temperature
313
+ if request.top_p is not None:
314
+ gen_kwargs['top_p'] = request.top_p
315
+
316
+ stop_words = add_extra_stop_words(request.stop)
317
+ if request.functions:
318
+ stop_words = stop_words or []
319
+ if "Observation:" not in stop_words:
320
+ stop_words.append("Observation:")
321
+
322
+ query, history = parse_messages(request.messages, request.functions)
323
+
324
+ if request.stream:
325
+ if request.functions:
326
+ raise Exception("Invalid request: Function calling is not yet implemented for stream mode.")
327
+ generate = predict(query, history, request.model, stop_words, gen_kwargs, qmodel, tokenizer)
328
+ return generate
329
+ # return EventSourceResponse(generate, media_type="text/event-stream")
330
+
331
+ stop_words_ids = [tokenizer.encode(s) for s in stop_words] if stop_words else None
332
+ if query is _TEXT_COMPLETION_CMD:
333
+ response = text_complete_last_message(history, stop_words_ids=stop_words_ids, gen_kwargs=gen_kwargs)
334
+ else:
335
+ response, _ = qmodel.chat(
336
+ tokenizer,
337
+ query,
338
+ history=history,
339
+ stop_words_ids=stop_words_ids,
340
+ **gen_kwargs
341
+ )
342
+ print(f"<chat>\n{history}\n{query}\n<!-- *** -->\n{response}\n</chat>")
343
+ _gc()
344
+
345
+ response = trim_stop_words(response, stop_words)
346
+ if request.functions:
347
+ choice_data = parse_response(response)
348
+ else:
349
+ choice_data = ChatCompletionResponseChoice(
350
+ index=0,
351
+ message=ChatMessage(role="assistant", content=response),
352
+ finish_reason="stop",
353
+ )
354
+ return ChatCompletionResponse(
355
+ model=request.model, choices=[choice_data], object="chat.completion"
356
+ )
357
+
358
+
359
+ def _dump_json(data: BaseModel, *args, **kwargs) -> str:
360
+ try:
361
+ return data.model_dump_json(*args, **kwargs)
362
+ except AttributeError: # pydantic<2.0.0
363
+ return data.json(*args, **kwargs) # noqa
364
+
365
+
366
+ def predict(
367
+ query: str, history: List[List[str]], model_id: str, stop_words: List[str], gen_kwargs: Dict, qmodel, tokenizer
368
+ ):
369
+ choice_data = ChatCompletionResponseStreamChoice(
370
+ index=0, delta=DeltaMessage(role="assistant"), finish_reason=None
371
+ )
372
+ chunk = ChatCompletionResponse(
373
+ model=model_id, choices=[choice_data], object="chat.completion.chunk"
374
+ )
375
+ # yield "{}".format(_dump_json(chunk, exclude_unset=True))
376
+ yield chunk
377
+
378
+ current_length = 0
379
+ stop_words_ids = [tokenizer.encode(s) for s in stop_words] if stop_words else None
380
+ if stop_words:
381
+ # TODO: It's a little bit tricky to trim stop words in the stream mode.
382
+ raise Exception("Invalid request: custom stop words are not yet supported for stream mode.", )
383
+ response_generator = qmodel.chat_stream(
384
+ tokenizer, query, history=history, stop_words_ids=stop_words_ids, **gen_kwargs
385
+ )
386
+ for new_response in response_generator:
387
+ if len(new_response) == current_length:
388
+ continue
389
+
390
+ new_text = new_response[current_length:]
391
+ current_length = len(new_response)
392
+
393
+ choice_data = ChatCompletionResponseStreamChoice(
394
+ index=0, delta=DeltaMessage(content=new_text), finish_reason=None
395
+ )
396
+ chunk = ChatCompletionResponse(
397
+ model=model_id, choices=[choice_data], object="chat.completion.chunk"
398
+ )
399
+ # yield "{}".format(_dump_json(chunk, exclude_unset=True))
400
+ yield chunk
401
+
402
+ choice_data = ChatCompletionResponseStreamChoice(
403
+ index=0, delta=DeltaMessage(), finish_reason="stop"
404
+ )
405
+ chunk = ChatCompletionResponse(
406
+ model=model_id, choices=[choice_data], object="chat.completion.chunk"
407
+ )
408
+ # yield "{}".format(_dump_json(chunk, exclude_unset=True))
409
+ yield chunk
410
+ # yield "[DONE]"
411
+
412
+ _gc()
413
+
414
+
415
+ class QwenChatAsOAI(BaseChatModel):
416
+
417
+ def __init__(self, model: str, api_key: str, model_server: str):
418
+ self.model = model
419
+ super().__init__()
420
+ tokenizer = AutoTokenizer.from_pretrained(
421
+ self.model,
422
+ trust_remote_code=True,
423
+ resume_download=True,
424
+ )
425
+ device_map = "cpu"
426
+ # device_map = "auto"
427
+ qmodel = AutoModelForCausalLM.from_pretrained(
428
+ self.model,
429
+ device_map=device_map,
430
+ trust_remote_code=True,
431
+ resume_download=True,
432
+ ).eval()
433
+
434
+ qmodel.generation_config = GenerationConfig.from_pretrained(
435
+ self.model,
436
+ trust_remote_code=True,
437
+ resume_download=True,
438
+ )
439
+ self.qmodel = qmodel
440
+ self.tokenizer = tokenizer
441
+
442
+ def _chat_stream(
443
+ self,
444
+ messages: List[Dict],
445
+ stop: Optional[List[str]] = None,
446
+ ) -> Iterator[str]:
447
+ _request = ChatCompletionRequest(model=self.model,
448
+ messages=messages,
449
+ stop=stop,
450
+ stream=True)
451
+ response = create_chat_completion(_request, self.qmodel, self.tokenizer)
452
+ # TODO: error handling
453
+ for chunk in response:
454
+ if hasattr(chunk.choices[0].delta, 'content'):
455
+ yield chunk.choices[0].delta.content
456
+
457
+ def _chat_no_stream(
458
+ self,
459
+ messages: List[Dict],
460
+ stop: Optional[List[str]] = None,
461
+ ) -> str:
462
+ _request = ChatCompletionRequest(model=self.model, messages=messages, stop=stop, stream=False)
463
+ response = create_chat_completion(_request, self.qmodel, self.tokenizer)
464
+ # TODO: error handling
465
+ return response.choices[0].message.content
466
+
467
+ def chat_with_functions(self,
468
+ messages: List[Dict],
469
+ functions: Optional[List[Dict]] = None) -> Dict:
470
+ if functions:
471
+ _request = ChatCompletionRequest(model=self.model, messages=messages, functions=functions)
472
+ response = create_chat_completion(_request, self.qmodel, self.tokenizer)
473
+ else:
474
+ _request = ChatCompletionRequest(model=self.model, messages=messages)
475
+ response = create_chat_completion(_request, self.qmodel, self.tokenizer)
476
+ # TODO: error handling
477
+ return response.choices[0].message.model_dump()
478
+
479
+
480
+ class QwenChatAsOAI1(BaseChatModel):
481
+
482
+ def __init__(self, model: str, api_key: str, model_server: str):
483
+ super().__init__()
484
+ if model_server.strip().lower() != 'openai':
485
+ openai.api_base = model_server
486
+ openai.api_key = api_key.strip() or os.getenv('OPENAI_API_KEY',
487
+ default='EMPTY')
488
+ self.model = model
489
+
490
+ def _chat_stream(
491
+ self,
492
+ messages: List[Dict],
493
+ stop: Optional[List[str]] = None,
494
+ ) -> Iterator[str]:
495
+ response = openai.ChatCompletion.create(model=self.model,
496
+ messages=messages,
497
+ stop=stop,
498
+ stream=True)
499
+ # TODO: error handling
500
+ for chunk in response:
501
+ if hasattr(chunk.choices[0].delta, 'content'):
502
+ yield chunk.choices[0].delta.content
503
+
504
+ def _chat_no_stream(
505
+ self,
506
+ messages: List[Dict],
507
+ stop: Optional[List[str]] = None,
508
+ ) -> str:
509
+ response = openai.ChatCompletion.create(model=self.model,
510
+ messages=messages,
511
+ stop=stop,
512
+ stream=False)
513
+ # TODO: error handling
514
+ return response.choices[0].message.content
515
+
516
+ def chat_with_functions(self,
517
+ messages: List[Dict],
518
+ functions: Optional[List[Dict]] = None) -> Dict:
519
+ if functions:
520
+ response = openai.ChatCompletion.create(model=self.model,
521
+ messages=messages,
522
+ functions=functions)
523
+ else:
524
+ response = openai.ChatCompletion.create(model=self.model,
525
+ messages=messages)
526
+ # TODO: error handling
527
+ return response.choices[0].message
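
For reference, the function-calling path preserved in qwen_oai_bak.py round-trips OpenAI-style functions through a ReAct prompt: parse_messages folds the function schemas (via TOOL_DESC and REACT_INSTRUCTION) into the query text, and parse_response recovers an Action / Action Input pair from the model output as a function_call. A small sketch of that round trip, assuming the backup module and its heavy dependencies (torch, transformers, sse_starlette) are importable; the weather function and the model output string are made-up examples:

from qwen_agent.llm.qwen_oai_bak import ChatMessage, parse_messages, parse_response

# A made-up function schema in the OpenAI functions format.
functions = [{
    'name': 'get_weather',
    'description': 'Query the current weather for a city.',
    'parameters': {'type': 'object', 'properties': {'city': {'type': 'string'}}},
}]
messages = [ChatMessage(role='user', content='What is the weather in Beijing?')]

# parse_messages renders the schema into the ReAct instruction inside the query text.
query, history = parse_messages(messages, functions)
print(query)  # "Answer the following questions as best you can. ... Question: What is the weather in Beijing?"

# A made-up completion in the ReAct format; parse_response turns it into a function_call choice.
fake_output = 'Thought: I can use get_weather.\nAction: get_weather\nAction Input: {"city": "Beijing"}'
choice = parse_response(fake_output)
print(choice.message.function_call)  # {'name': 'get_weather', 'arguments': '{"city": "Beijing"}'}
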