vlff李飞飞 commited on
Commit
2319518
·
1 Parent(s): 8d16531
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +24 -0
  2. .pre-commit-config.yaml +27 -0
  3. Dockerfile +14 -0
  4. LICENSE +53 -0
  5. README_CN.md +252 -0
  6. assets/screenshot-ci.png +0 -0
  7. assets/screenshot-editor-movie.png +0 -0
  8. assets/screenshot-multi-web-qa.png +0 -0
  9. assets/screenshot-pdf-qa.png +0 -0
  10. assets/screenshot-web-qa.png +0 -0
  11. assets/screenshot-writing.png +0 -0
  12. benchmark/README.md +248 -0
  13. benchmark/code_interpreter.py +250 -0
  14. benchmark/config.py +66 -0
  15. benchmark/inference_and_execute.py +280 -0
  16. benchmark/metrics/__init__.py +0 -0
  17. benchmark/metrics/code_execution.py +257 -0
  18. benchmark/metrics/gsm8k.py +54 -0
  19. benchmark/metrics/visualization.py +179 -0
  20. benchmark/models/__init__.py +4 -0
  21. benchmark/models/base.py +17 -0
  22. benchmark/models/dashscope.py +40 -0
  23. benchmark/models/llm.py +26 -0
  24. benchmark/models/qwen.py +36 -0
  25. benchmark/parser/__init__.py +2 -0
  26. benchmark/parser/internlm_parser.py +11 -0
  27. benchmark/parser/react_parser.py +46 -0
  28. benchmark/prompt/__init__.py +4 -0
  29. benchmark/prompt/internlm_react.py +103 -0
  30. benchmark/prompt/llama_react.py +20 -0
  31. benchmark/prompt/qwen_react.py +80 -0
  32. benchmark/prompt/react.py +87 -0
  33. benchmark/requirements.txt +13 -0
  34. benchmark/utils/__init__.py +0 -0
  35. benchmark/utils/code_utils.py +31 -0
  36. benchmark/utils/data_utils.py +28 -0
  37. browser_qwen/background.js +58 -0
  38. browser_qwen/img/copy.png +0 -0
  39. browser_qwen/img/logo.png +0 -0
  40. browser_qwen/img/popup.png +0 -0
  41. browser_qwen/manifest.json +45 -0
  42. browser_qwen/src/content.js +86 -0
  43. browser_qwen/src/popup.html +121 -0
  44. browser_qwen/src/popup.js +65 -0
  45. openai_api.py +564 -0
  46. qwen_agent/__init__.py +0 -0
  47. qwen_agent/actions/__init__.py +13 -0
  48. qwen_agent/actions/base.py +40 -0
  49. qwen_agent/actions/continue_writing.py +35 -0
  50. qwen_agent/actions/expand_writing.py +62 -0
.gitignore ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ env
2
+ *.pyc
3
+ __pycache__
4
+
5
+ .idea
6
+ .vscode
7
+ .DS_Store
8
+
9
+ qwen_agent/llm/gpt.py
10
+ qwen_agent/llm/tools.py
11
+ workspace/*
12
+
13
+ benchmark/log/*
14
+ benchmark/output_data/*
15
+ benchmark/upload_file/*
16
+ benchmark/upload_file_clean/*
17
+ benchmark/eval_data/
18
+ Qwen-Agent
19
+
20
+ docqa/*
21
+ log/*
22
+ ai_builder/*
23
+ qwen_agent.egg-info/*
24
+ build/*
.pre-commit-config.yaml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pycqa/flake8.git
3
+ rev: 5.0.4
4
+ hooks:
5
+ - id: flake8
6
+ args: ["--max-line-length=300"]
7
+ - repo: https://github.com/PyCQA/isort.git
8
+ rev: 5.11.5
9
+ hooks:
10
+ - id: isort
11
+ - repo: https://github.com/pre-commit/mirrors-yapf.git
12
+ rev: v0.32.0
13
+ hooks:
14
+ - id: yapf
15
+ - repo: https://github.com/pre-commit/pre-commit-hooks.git
16
+ rev: v4.3.0
17
+ hooks:
18
+ - id: trailing-whitespace
19
+ - id: check-yaml
20
+ - id: end-of-file-fixer
21
+ - id: requirements-txt-fixer
22
+ - id: double-quote-string-fixer
23
+ - id: check-merge-conflict
24
+ - id: fix-encoding-pragma
25
+ args: ["--remove"]
26
+ - id: mixed-line-ending
27
+ args: ["--fix=lf"]
Dockerfile ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
2
+ # you will also find guides on how best to write your Dockerfile
3
+
4
+ FROM python:3.10
5
+
6
+ WORKDIR /code
7
+
8
+ COPY ./requirements.txt /code/requirements.txt
9
+
10
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
11
+
12
+ COPY . .
13
+
14
+ CMD ["python", "run_server.py", "--llm", "Qwen/Qwen-1_8B-Chat", "--model_server", "http://127.0.0.1:7905/v1", "--server_host", "0.0.0.0", "--workstation_port", "7860"]
LICENSE ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Tongyi Qianwen LICENSE AGREEMENT
2
+
3
+ Tongyi Qianwen Release Date: August 3, 2023
4
+
5
+ By clicking to agree or by using or distributing any portion or element of the Tongyi Qianwen Materials, you will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
6
+
7
+ 1. Definitions
8
+ a. This Tongyi Qianwen LICENSE AGREEMENT (this "Agreement") shall mean the terms and conditions for use, reproduction, distribution and modification of the Materials as defined by this Agreement.
9
+ b. "We"(or "Us") shall mean Alibaba Cloud.
10
+ c. "You" (or "Your") shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Materials for any purpose and in any field of use.
11
+ d. "Third Parties" shall mean individuals or legal entities that are not under common control with Us or You.
12
+ e. "Tongyi Qianwen" shall mean the large language models (including Qwen model and Qwen-Chat model), and software and algorithms, consisting of trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing distributed by Us.
13
+ f. "Materials" shall mean, collectively, Alibaba Cloud's proprietary Tongyi Qianwen and Documentation (and any portion thereof) made available under this Agreement.
14
+ g. "Source" form shall mean the preferred form for making modifications, including but not limited to model source code, documentation source, and configuration files.
15
+ h. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation,
16
+ and conversions to other media types.
17
+
18
+ 2. Grant of Rights
19
+ You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Alibaba Cloud's intellectual property or other rights owned by Us embodied in the Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Materials.
20
+
21
+ 3. Redistribution
22
+ You may reproduce and distribute copies of the Materials or derivative works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
23
+ a. You shall give any other recipients of the Materials or derivative works a copy of this Agreement;
24
+ b. You shall cause any modified files to carry prominent notices stating that You changed the files;
25
+ c. You shall retain in all copies of the Materials that You distribute the following attribution notices within a "Notice" text file distributed as a part of such copies: "Tongyi Qianwen is licensed under the Tongyi Qianwen LICENSE AGREEMENT, Copyright (c) Alibaba Cloud. All Rights Reserved."; and
26
+ d. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such derivative works as a whole, provided Your use, reproduction, and distribution of the work otherwise complies with the terms and conditions of this Agreement.
27
+
28
+ 4. Restrictions
29
+ If you are commercially using the Materials, and your product or service has more than 100 million monthly active users, You shall request a license from Us. You cannot exercise your rights under this Agreement without our express authorization.
30
+
31
+ 5. Rules of use
32
+ a. The Materials may be subject to export controls or restrictions in China, the United States or other countries or regions. You shall comply with applicable laws and regulations in your use of the Materials.
33
+ b. You can not use the Materials or any output therefrom to improve any other large language model (excluding Tongyi Qianwen or derivative works thereof).
34
+
35
+ 6. Intellectual Property
36
+ a. We retain ownership of all intellectual property rights in and to the Materials and derivatives made by or for Us. Conditioned upon compliance with the terms and conditions of this Agreement, with respect to any derivative works and modifications of the Materials that are made by you, you are and will be the owner of such derivative works and modifications.
37
+ b. No trademark license is granted to use the trade names, trademarks, service marks, or product names of Us, except as required to fulfill notice requirements under this Agreement or as required for reasonable and customary use in describing and redistributing the Materials.
38
+ c. If you commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any entity alleging that the Materials or any output therefrom, or any part of the foregoing, infringe any intellectual property or other right owned or licensable by you, then all licences granted to you under this Agreement shall terminate as of the date such lawsuit or other proceeding is commenced or brought.
39
+
40
+ 7. Disclaimer of Warranty and Limitation of Liability
41
+
42
+ a. We are not obligated to support, update, provide training for, or develop any further version of the Tongyi Qianwen Materials or to grant any license thereto.
43
+ b. THE MATERIALS ARE PROVIDED "AS IS" WITHOUT ANY EXPRESS OR IMPLIED WARRANTY OF ANY KIND INCLUDING WARRANTIES OF MERCHANTABILITY, NONINFRINGEMENT, OR FITNESS FOR A PARTICULAR PURPOSE. WE MAKE NO WARRANTY AND ASSUME NO RESPONSIBILITY FOR THE SAFETY OR STABILITY OF THE MATERIALS AND ANY OUTPUT THEREFROM.
44
+ c. IN NO EVENT SHALL WE BE LIABLE TO YOU FOR ANY DAMAGES, INCLUDING, BUT NOT LIMITED TO ANY DIRECT, OR INDIRECT, SPECIAL OR CONSEQUENTIAL DAMAGES ARISING FROM YOUR USE OR INABILITY TO USE THE MATERIALS OR ANY OUTPUT OF IT, NO MATTER HOW IT’S CAUSED.
45
+ d. You will defend, indemnify and hold harmless Us from and against any claim by any third party arising out of or related to your use or distribution of the Materials.
46
+
47
+ 8. Survival and Termination.
48
+ a. The term of this Agreement shall commence upon your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
49
+ b. We may terminate this Agreement if you breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, you must delete and cease use of the Materials. Sections 7 and 9 shall survive the termination of this Agreement.
50
+
51
+ 9. Governing Law and Jurisdiction.
52
+ a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
53
+ b. The People's Courts in Hangzhou City shall have exclusive jurisdiction over any dispute arising out of this Agreement.
README_CN.md ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Qwen Agent
3
+ emoji: 📈
4
+ colorFrom: yellow
5
+ colorTo: purple
6
+ sdk: docker
7
+ pinned: false
8
+ license: apache-2.0
9
+ app_port: 7860
10
+ ---
11
+
12
+ 中文 | [English](./README.md)
13
+
14
+ <p align="center">
15
+ <img src="https://qianwen-res.oss-cn-beijing.aliyuncs.com/assets/qwen_agent/logo-qwen-agent.png" width="400"/>
16
+ <p>
17
+ <br>
18
+
19
+ Qwen-Agent是一个代码框架,用于发掘开源通义千问模型([Qwen](https://github.com/QwenLM/Qwen))的工具使用、规划、记忆能力。
20
+ 在Qwen-Agent的基础上,我们开发了一个名为BrowserQwen的**Chrome浏览器扩展**,它具有以下主要功能:
21
+
22
+ - 与Qwen讨论当前网页或PDF文档的内容。
23
+ - 在获得您的授权后,BrowserQwen会记录您浏览过的网页和PDF/Word/PPT材料,以帮助您快速了解多个页面的内容,总结您浏览过的内容,并自动化繁琐的文字工作。
24
+ - 集成各种插件,包括可用于数学问题求解、数据分析与可视化、处理文件等的**代码解释器**(**Code Interpreter**)。
25
+
26
+ # 用例演示
27
+
28
+ 如果您更喜欢观看视频,而不是效果截图,可以参见[视频演示](#视频演示)。
29
+
30
+ ## 工作台 - 创作模式
31
+
32
+ **根据浏览过的网页、PDFs素材进行长文创作**
33
+
34
+ <figure>
35
+ <img src="assets/screenshot-writing.png">
36
+ </figure>
37
+
38
+ **调用插件辅助富文本创作**
39
+
40
+ <figure>
41
+ <img src="assets/screenshot-editor-movie.png">
42
+ </figure>
43
+
44
+ ## 工作台 - 对话模式
45
+
46
+ **多网页问答**
47
+
48
+ <figure >
49
+ <img src="assets/screenshot-multi-web-qa.png">
50
+ </figure>
51
+
52
+ **使用代码解释器绘制数据图表**
53
+
54
+ <figure>
55
+ <img src="assets/screenshot-ci.png">
56
+ </figure>
57
+
58
+ ## 浏览器助手
59
+
60
+ **网页问答**
61
+
62
+ <figure>
63
+ <img src="assets/screenshot-web-qa.png">
64
+ </figure>
65
+
66
+ **PDF文档问答**
67
+
68
+ <figure>
69
+ <img src="assets/screenshot-pdf-qa.png">
70
+ </figure>
71
+
72
+ # BrowserQwen 使用说明
73
+
74
+ 支持环境:MacOS,Linux,Windows。
75
+
76
+ ## 第一步 - 部署模型服务
77
+
78
+ ***如果您正在使用阿里云提供的[DashScope](https://help.aliyun.com/zh/dashscope/developer-reference/quick-start)服务来访问Qwen系列模型,可以跳过这一步,直接到第二步。***
79
+
80
+ 但如果您不想使用DashScope,而是希望自己部署一个模型服务。那么可以参考[Qwen项目](https://github.com/QwenLM/Qwen/blob/main/README_CN.md#api),部署一个兼容OpenAI API的模型服务:
81
+
82
+ ```bash
83
+ # 安装依赖
84
+ git clone git@github.com:QwenLM/Qwen.git
85
+ cd Qwen
86
+ pip install -r requirements.txt
87
+ pip install fastapi uvicorn "openai<1.0.0" "pydantic>=2.3.0" sse_starlette
88
+
89
+ # 启动模型服务,通过 -c 参数指定模型版本
90
+ # - 指定 --server-name 0.0.0.0 将允许其他机器访问您的模型服务
91
+ # - 指定 --server-name 127.0.0.1 则只允许部署模型的机器自身访问该模型服务
92
+ python openai_api.py --server-name 0.0.0.0 --server-port 7905 -c Qwen/Qwen-14B-Chat
93
+ ```
94
+
95
+ 目前,我们支持指定-c参数以加载 [Qwen 的 Hugging Face主页](https://huggingface.co/Qwen) 上的模型,比如`Qwen/Qwen-1_8B-Chat`、`Qwen/Qwen-7B-Chat`、`Qwen/Qwen-14B-Chat`、`Qwen/Qwen-72B-Chat`,以及它们的`Int4`和`Int8`版本。
96
+
97
+ ## 第二步 - 部署本地数据库服务
98
+
99
+ 在这一步,您需要在您的本地机器上(即您可以打开Chrome浏览器的那台机器),部署维护个人浏览历史、对话历史的数据库服务。
100
+
101
+ 首次启动数据库服务前,请记得安装相关的依赖:
102
+
103
+ ```bash
104
+ # 安装依赖
105
+ git clone https://github.com/QwenLM/Qwen-Agent.git
106
+ cd Qwen-Agent
107
+ pip install -r requirements.txt
108
+ ```
109
+
110
+ 如果跳过了第一步、因为您打算使用DashScope提供的模型服务的话,请执行以下命令启动数据库服务:
111
+
112
+ ```bash
113
+ # 启动数据库服务,通过 --llm 参数指定您希望通过DashScope使用的具体模型
114
+ # 参数 --llm 可以是如下之一,按资源消耗从小到大排序:
115
+ # - qwen-7b-chat (与开源的Qwen-7B-Chat相同模型)
116
+ # - qwen-14b-chat (与开源的Qwen-14B-Chat相同模型)
117
+ # - qwen-turbo
118
+ # - qwen-plus
119
+ # 您需要将YOUR_DASHSCOPE_API_KEY替换为您的真实API-KEY。
120
+ export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
121
+ python run_server.py --model_server dashscope --llm qwen-7b-chat --workstation_port 7864
122
+ ```
123
+
124
+ 如果您没有在使用DashScope、而是参考第一步部署了自己的模型服务的话,请执行以下命令:
125
+
126
+ ```bash
127
+ # 启动数据库服务,通过 --model_server 参数指定您在 Step 1 里部署好的模型服务
128
+ # - 若 Step 1 的机器 IP 为 123.45.67.89,则可指定 --model_server http://123.45.67.89:7905/v1
129
+ # - 若 Step 1 和 Step 2 是同一台机器,则可指定 --model_server http://127.0.0.1:7905/v1
130
+ python run_server.py --model_server http://{MODEL_SERVER_IP}:7905/v1 --workstation_port 7864
131
+ ```
132
+
133
+ 现在您可以访问 [http://127.0.0.1:7864/](http://127.0.0.1:7864/) 来使用工作台(Workstation)的创作模式(Editor模式)和对话模式(Chat模式)了。
134
+
135
+ 关于工作台的使用技巧,请参见工作台���面的文字说明、或观看[视频演示](#视频演示)。
136
+
137
+ ## Step 3. 安装浏览器助手
138
+
139
+ 安装BrowserQwen的Chrome插件(又称Chrome扩展程序):
140
+
141
+ 1. 打开Chrome浏览器,在浏览器的地址栏中输入 `chrome://extensions/` 并按下回车键;
142
+ 2. 确保右上角的 `开发者模式` 处于打开状态,之后点击 `加载已解压的扩展程序` 上传本项目下的 `browser_qwen` 目录并启用;
143
+ 3. 单击谷歌浏览器右上角```扩展程序```图标,将BrowserQwen固定在工具栏。
144
+
145
+ 注意,安装Chrome插件后,需要刷新页面,插件才能生效。
146
+
147
+ 当您想让Qwen阅读当前网页的内容时:
148
+
149
+ 1. 请先点击屏幕上的 `Add to Qwen's Reading List` 按钮,以授权Qwen在后台分析本页面。
150
+ 2. 再单击浏览器右上角扩展程序栏的Qwen图标,便可以和Qwen交流当前页面的内容了。
151
+
152
+ ## 视频演示
153
+
154
+ 可查看以下几个演示视频,了解BrowserQwen的基本操作:
155
+
156
+ - 根据浏览过的网页、PDFs进行长文创作 [video](https://qianwen-res.oss-cn-beijing.aliyuncs.com/assets/qwen_agent/showcase_write_article_based_on_webpages_and_pdfs.mp4)
157
+ - 提取浏览内容使用代码解释器画图 [video](https://qianwen-res.oss-cn-beijing.aliyuncs.com/assets/qwen_agent/showcase_chat_with_docs_and_code_interpreter.mp4)
158
+ - 上传文件、多轮对话利用代码解释器分析数据 [video](https://qianwen-res.oss-cn-beijing.aliyuncs.com/assets/qwen_agent/showcase_code_interpreter_multi_turn_chat.mp4)
159
+
160
+ # 评测基准
161
+
162
+ 我们也开源了一个评测基准,用于评估一个模型写Python代码并使用Code Interpreter进行数学解题、数据分析、及其他通用任务时的表现。评测基准见 [benchmark](benchmark/README.md) 目录,当前的评测结果如下:
163
+
164
+ <table>
165
+ <tr>
166
+ <th colspan="5" align="center">In-house Code Interpreter Benchmark (Version 20231206)</th>
167
+ </tr>
168
+ <tr>
169
+ <th rowspan="2" align="center">Model</th>
170
+ <th colspan="3" align="center">代码执行结果正确性 (%)</th>
171
+ <th colspan="1" align="center">生成代码的可执行率 (%)</th>
172
+ </tr>
173
+ <tr>
174
+ <th align="center">Math↑</th><th align="center">Visualization-Hard↑</th><th align="center">Visualization-Easy↑</th><th align="center">General↑</th>
175
+ </tr>
176
+ <tr>
177
+ <td>GPT-4</td>
178
+ <td align="center">82.8</td>
179
+ <td align="center">66.7</td>
180
+ <td align="center">60.8</td>
181
+ <td align="center">82.8</td>
182
+ </tr>
183
+ <tr>
184
+ <td>GPT-3.5</td>
185
+ <td align="center">47.3</td>
186
+ <td align="center">33.3</td>
187
+ <td align="center">55.7</td>
188
+ <td align="center">74.1</td>
189
+ </tr>
190
+ <tr>
191
+ <td>LLaMA2-13B-Chat</td>
192
+ <td align="center">8.3</td>
193
+ <td align="center">1.2</td>
194
+ <td align="center">15.2</td>
195
+ <td align="center">48.3</td>
196
+ </tr>
197
+ <tr>
198
+ <td>CodeLLaMA-13B-Instruct</td>
199
+ <td align="center">28.2</td>
200
+ <td align="center">15.5</td>
201
+ <td align="center">21.5</td>
202
+ <td align="center">74.1</td>
203
+ </tr>
204
+ <tr>
205
+ <td>InternLM-20B-Chat</td>
206
+ <td align="center">34.6</td>
207
+ <td align="center">10.7</td>
208
+ <td align="center">24.1</td>
209
+ <td align="center">65.5</td>
210
+ </tr>
211
+ <tr>
212
+ <td>ChatGLM3-6B</td>
213
+ <td align="center">54.2</td>
214
+ <td align="center">4.8</td>
215
+ <td align="center">15.2</td>
216
+ <td align="center">62.1</td>
217
+ </tr>
218
+ <tr>
219
+ <td>Qwen-1.8B-Chat</td>
220
+ <td align="center">25.6</td>
221
+ <td align="center">21.4</td>
222
+ <td align="center">22.8</td>
223
+ <td align="center">65.5</td>
224
+ </tr>
225
+ <tr>
226
+ <td>Qwen-7B-Chat</td>
227
+ <td align="center">41.9</td>
228
+ <td align="center">23.8</td>
229
+ <td align="center">38.0</td>
230
+ <td align="center">67.2</td>
231
+ </tr>
232
+ <tr>
233
+ <td>Qwen-14B-Chat</td>
234
+ <td align="center">58.4</td>
235
+ <td align="center">31.0</td>
236
+ <td align="center">45.6</td>
237
+ <td align="center">65.5</td>
238
+ </tr>
239
+ <tr>
240
+ <td>Qwen-72B-Chat</td>
241
+ <td align="center">72.7</td>
242
+ <td align="center">41.7</td>
243
+ <td align="center">43.0</td>
244
+ <td align="center">82.8</td>
245
+ </tr>
246
+ </table>
247
+
248
+ # 免责声明
249
+
250
+ 本项目并非正式产品,而是一个概念验证项目,用于演示Qwen系列模型的能力。
251
+
252
+ > 重要提示:代码解释器未进行沙盒隔离,会在部署环境中执行代码。请避免向Qwen发出危险指令,切勿将该代码解释器直接用于生产目的。
assets/screenshot-ci.png ADDED
assets/screenshot-editor-movie.png ADDED
assets/screenshot-multi-web-qa.png ADDED
assets/screenshot-pdf-qa.png ADDED
assets/screenshot-web-qa.png ADDED
assets/screenshot-writing.png ADDED
benchmark/README.md ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code Interpreter Benchmark
2
+
3
+ ## Introduction
4
+ To assess LLM's ability to use the Python Code Interpreter for tasks such as mathematical problem solving, data visualization, and other general-purpose tasks such as file handling and web scraping, we have created and open-sourced a benchmark specifically designed for evaluating these capabilities.
5
+
6
+ ### Metrics
7
+ The metrics are divided into two parts: code executability and code correctness.
8
+ - Code executability: evaluating the ability of the LLM-generated code to be executed.
9
+ - Code correctness: evaluating whether the LLM-generated code runs correctly.
10
+
11
+ ### Domain
12
+ When evaluating the accuracy of the code execution results for code correctness, we further divide it into two specific domains: `Math`, `Visualization`.
13
+ In terms of code executability, we calculate executable rate of the generated code for `General problem-solving`.
14
+
15
+ ## Results
16
+ - Qwen-7B-Chat refers to the version updated after September 25, 2023.
17
+ - The code correctness judger model for `Visualization` has changed from `Qwen-vl-chat` to `gpt-4-vision-preview` in the version 20231206.
18
+
19
+ <table>
20
+ <tr>
21
+ <th colspan="5" align="center">In-house Code Interpreter Benchmark (Version 20231206)</th>
22
+ </tr>
23
+ <tr>
24
+ <th rowspan="2" align="center">Model</th>
25
+ <th colspan="3" align="center">Accuracy of Code Execution Results (%)</th>
26
+ <th colspan="1" align="center">Executable Rate of Code (%)</th>
27
+ </tr>
28
+ <tr>
29
+ <th align="center">Math↑</th><th align="center">Visualization-Hard↑</th><th align="center">Visualization-Easy↑</th><th align="center">General↑</th>
30
+ </tr>
31
+ <tr>
32
+ <td>GPT-4</td>
33
+ <td align="center">82.8</td>
34
+ <td align="center">66.7</td>
35
+ <td align="center">60.8</td>
36
+ <td align="center">82.8</td>
37
+ </tr>
38
+ <tr>
39
+ <td>GPT-3.5</td>
40
+ <td align="center">47.3</td>
41
+ <td align="center">33.3</td>
42
+ <td align="center">55.7</td>
43
+ <td align="center">74.1</td>
44
+ </tr>
45
+ <tr>
46
+ <td>LLaMA2-13B-Chat</td>
47
+ <td align="center">8.3</td>
48
+ <td align="center">1.2</td>
49
+ <td align="center">15.2</td>
50
+ <td align="center">48.3</td>
51
+ </tr>
52
+ <tr>
53
+ <td>CodeLLaMA-13B-Instruct</td>
54
+ <td align="center">28.2</td>
55
+ <td align="center">15.5</td>
56
+ <td align="center">21.5</td>
57
+ <td align="center">74.1</td>
58
+ </tr>
59
+ <tr>
60
+ <td>InternLM-20B-Chat</td>
61
+ <td align="center">34.6</td>
62
+ <td align="center">10.7</td>
63
+ <td align="center">24.1</td>
64
+ <td align="center">65.5</td>
65
+ </tr>
66
+ <tr>
67
+ <td>ChatGLM3-6B</td>
68
+ <td align="center">54.2</td>
69
+ <td align="center">4.8</td>
70
+ <td align="center">15.2</td>
71
+ <td align="center">62.1</td>
72
+ </tr>
73
+ <tr>
74
+ <td>Qwen-1.8B-Chat</td>
75
+ <td align="center">25.6</td>
76
+ <td align="center">21.4</td>
77
+ <td align="center">22.8</td>
78
+ <td align="center">65.5</td>
79
+ </tr>
80
+ <tr>
81
+ <td>Qwen-7B-Chat</td>
82
+ <td align="center">41.9</td>
83
+ <td align="center">23.8</td>
84
+ <td align="center">38.0</td>
85
+ <td align="center">67.2</td>
86
+ </tr>
87
+ <tr>
88
+ <td>Qwen-14B-Chat</td>
89
+ <td align="center">58.4</td>
90
+ <td align="center">31.0</td>
91
+ <td align="center">45.6</td>
92
+ <td align="center">65.5</td>
93
+ </tr>
94
+ <tr>
95
+ <td>Qwen-72B-Chat</td>
96
+ <td align="center">72.7</td>
97
+ <td align="center">41.7</td>
98
+ <td align="center">43.0</td>
99
+ <td align="center">82.8</td>
100
+ </tr>
101
+ </table>
102
+
103
+ Furthermore, we also provide the results of `Qwen-vl-plus` as the code correctness judger model for `Visualization` task to serve as a reference.
104
+
105
+ <table>
106
+ <tr>
107
+ <th colspan="3" align="center">Code Correctness Judger Model = Qwen-vl-plus</th>
108
+ </tr>
109
+ <tr>
110
+ <th rowspan="2" align="center">Model</th>
111
+ <th colspan="2" align="center">Accuracy of Code Execution Results (%)</th>
112
+ </tr>
113
+ <tr>
114
+ <th align="center">Visualization-Hard↑</th>
115
+ <th align="center">Visualization-Easy↑</th>
116
+ </tr>
117
+ <tr>
118
+ <td>LLaMA2-13B-Chat</td>
119
+ <td align="center">2.4</td>
120
+ <td align="center">17.7</td>
121
+ </tr>
122
+ <tr>
123
+ <td>CodeLLaMA-13B-Instruct</td>
124
+ <td align="center">17.9</td>
125
+ <td align="center">34.2</td>
126
+ </tr>
127
+ <tr>
128
+ <td>InternLM-20B-Chat</td>
129
+ <td align="center">9.5</td>
130
+ <td align="center">31.7</td>
131
+ </tr>
132
+ <tr>
133
+ <td>ChatGLM3-6B</td>
134
+ <td align="center">10.7</td>
135
+ <td align="center">29.1</td>
136
+ </tr>
137
+ <tr>
138
+ <td>Qwen-1.8B-Chat</td>
139
+ <td align="center">32.1</td>
140
+ <td align="center">32.9</td>
141
+ </tr>
142
+ <tr>
143
+ <td>Qwen-7B-Chat</td>
144
+ <td align="center">26.2</td>
145
+ <td align="center">39.2</td>
146
+ </tr>
147
+ <tr>
148
+ <td>Qwen-14B-Chat</td>
149
+ <td align="center">36.9</td>
150
+ <td align="center">41.8</td>
151
+ </tr>
152
+ <tr>
153
+ <td>Qwen-72B-Chat</td>
154
+ <td align="center">38.1</td>
155
+ <td align="center">38.0</td>
156
+ </tr>
157
+ </table>
158
+
159
+
160
+
161
+ ## Usage
162
+
163
+ ### Installation
164
+
165
+ ```shell
166
+ git clone https://github.com/QwenLM/Qwen-Agent.git
167
+ cd benchmark
168
+ pip install -r requirements.txt
169
+ ```
170
+
171
+ ### Dataset Download
172
+ ```shell
173
+ cd benchmark
174
+ wget https://qianwen-res.oss-cn-beijing.aliyuncs.com/assets/qwen_agent/benchmark_code_interpreter_data.zip
175
+ unzip benchmark_code_interpreter_data.zip
176
+ mkdir eval_data
177
+ mv eval_code_interpreter_v1.jsonl eval_data/
178
+ ```
179
+
180
+ ### Evaluation
181
+ To reproduce the comprehensive results of benchmark, you can run the following script:
182
+
183
+ ```Shell
184
+ python inference_and_execute.py --model {model_name}
185
+ ```
186
+
187
+ {model_name}:
188
+ - qwen-1.8b-chat
189
+ - qwen-7b-chat
190
+ - qwen-14b-chat
191
+ - qwen-72b-chat
192
+ - llama-2-7b-chat
193
+ - llama-2-13b-chat
194
+ - codellama-7b-instruct
195
+ - codellama-13b-instruct
196
+ - internlm-7b-chat-1.1
197
+ - internlm-20b-chat
198
+
199
+ The benchmark will run the test cases and generate the performance results. The results will be saved in the `output_data` directory.
200
+
201
+ **Notes**:
202
+ Please install `simhei.ttf` font for proper display in matplotlib when evaluating visualization task. You can do this by preparing `simhei.ttf` (which can be found on any Windows PC) and then running the following code snippet:
203
+ ```python
204
+ import os
205
+ import matplotlib
206
+ target_font_path = os.path.join(
207
+ os.path.abspath(
208
+ os.path.join(matplotlib.matplotlib_fname(), os.path.pardir)),
209
+ 'fonts', 'ttf', 'simhei.ttf')
210
+ os.system(f'cp simhei.ttf {target_font_path}')
211
+ font_list_cache = os.path.join(matplotlib.get_cachedir(), 'fontlist-*.json')
212
+ os.system(f'rm -f {font_list_cache}')
213
+ ```
214
+
215
+ #### Code Executable Rate
216
+ ```Shell
217
+ python inference_and_execute.py --task {task_name} --model {model_name}
218
+ ```
219
+
220
+ {task_name}:
221
+ - `general`: General problem-solving task
222
+
223
+
224
+ #### Code Correctness Rate
225
+ ```Shell
226
+ python inference_and_execute.py --task {task_name} --model {model_name}
227
+ ```
228
+
229
+ {task_name}:
230
+ - `visualization`: Visualization task
231
+ - `gsm8k`: Math task
232
+
233
+
234
+ ## Configuration
235
+ The inference_and_exec.py file contains the following configurable options:
236
+
237
+ - `--model`: The model to test which can be one of `qwen-72b-chat`, `qwen-14b-chat`, `qwen-7b-chat`, `qwen-1.8b-chat`, `qwen-7b-chat`, `llama-2-7b-chat`, `llama-2-13b-chat`, `codellama-7b-instruct`, `codellama-13b-instruct`, `internlm-7b-chat-1.1`, `internlm-20b-chat`.
238
+ - `--task`: The test task which can be one of `all`, `visualization`, `general`, `gsm8k`.
239
+ - `--output-path`: The path for saving evaluation result.
240
+ - `--input-path`: The path for placing evaluation data.
241
+ - `--output-fname`: The file name for evaluation result.
242
+ - `--input-fname`: The file name for evaluation data.
243
+ - `--force`: Force generation and will overwrite the cached results.
244
+ - `--eval-only`: Only calculate evaluation metrics without re-inference.
245
+ - `--eval-code-exec-only`: Only evaluate code executable rate
246
+ - `--gen-exec-only`: Only generate and execuate code without calculating evaluation metrics.
247
+ - `--gen-only`: Only generate without execuating code and calculating evaluation metrics.
248
+ - `--vis-judger`: The model to judge the result correctness for `Visualization` task which can be one of `gpt-4-vision-preview`, `qwen-vl-chat`, `qwen-vl-plus`. It is set to `gpt-4-vision-preview` by default in the version 20231206, and `Qwen-vl-chat` has been deprecated.
benchmark/code_interpreter.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import io
3
+ import json
4
+ import logging
5
+ import os
6
+ import queue
7
+ import re
8
+ import subprocess
9
+ import sys
10
+ import time
11
+ import traceback
12
+ import uuid
13
+
14
+ import matplotlib
15
+ import PIL.Image
16
+ from jupyter_client import BlockingKernelClient
17
+ from utils.code_utils import extract_code
18
+
19
+ WORK_DIR = os.getenv('CODE_INTERPRETER_WORK_DIR', '/tmp/workspace')
20
+
21
+ LAUNCH_KERNEL_PY = """
22
+ from ipykernel import kernelapp as app
23
+ app.launch_new_instance()
24
+ """
25
+
26
+ _KERNEL_CLIENTS = {}
27
+
28
+
29
+ # Run this fix before jupyter starts if matplotlib cannot render CJK fonts.
30
+ # And we need to additionally run the following lines in the jupyter notebook.
31
+ # ```python
32
+ # import matplotlib.pyplot as plt
33
+ # plt.rcParams['font.sans-serif'] = ['SimHei']
34
+ # plt.rcParams['axes.unicode_minus'] = False
35
+ # ````
36
+ def fix_matplotlib_cjk_font_issue():
37
+ local_ttf = os.path.join(
38
+ os.path.abspath(
39
+ os.path.join(matplotlib.matplotlib_fname(), os.path.pardir)),
40
+ 'fonts', 'ttf', 'simhei.ttf')
41
+ if not os.path.exists(local_ttf):
42
+ logging.warning(
43
+ f'Missing font file `{local_ttf}` for matplotlib. It may cause some error when using matplotlib.'
44
+ )
45
+
46
+
47
+ def start_kernel(pid):
48
+ fix_matplotlib_cjk_font_issue()
49
+
50
+ connection_file = os.path.join(WORK_DIR,
51
+ f'kernel_connection_file_{pid}.json')
52
+ launch_kernel_script = os.path.join(WORK_DIR, f'launch_kernel_{pid}.py')
53
+ for f in [connection_file, launch_kernel_script]:
54
+ if os.path.exists(f):
55
+ logging.warning(f'{f} already exists')
56
+ os.remove(f)
57
+
58
+ os.makedirs(WORK_DIR, exist_ok=True)
59
+
60
+ with open(launch_kernel_script, 'w') as fout:
61
+ fout.write(LAUNCH_KERNEL_PY)
62
+
63
+ kernel_process = subprocess.Popen([
64
+ sys.executable,
65
+ launch_kernel_script,
66
+ '--IPKernelApp.connection_file',
67
+ connection_file,
68
+ '--matplotlib=inline',
69
+ '--quiet',
70
+ ],
71
+ cwd=WORK_DIR)
72
+ logging.info(f"INFO: kernel process's PID = {kernel_process.pid}")
73
+
74
+ # Wait for kernel connection file to be written
75
+ while True:
76
+ if not os.path.isfile(connection_file):
77
+ time.sleep(0.1)
78
+ else:
79
+ # Keep looping if JSON parsing fails, file may be partially written
80
+ try:
81
+ with open(connection_file, 'r') as fp:
82
+ json.load(fp)
83
+ break
84
+ except json.JSONDecodeError:
85
+ pass
86
+
87
+ # Client
88
+ kc = BlockingKernelClient(connection_file=connection_file)
89
+ kc.load_connection_file()
90
+ kc.start_channels()
91
+ kc.wait_for_ready()
92
+ return kc
93
+
94
+
95
+ def escape_ansi(line):
96
+ ansi_escape = re.compile(r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]')
97
+ return ansi_escape.sub('', line)
98
+
99
+
100
+ def publish_image_to_local(image_base64: str):
101
+ image_file = str(uuid.uuid4()) + '.png'
102
+ local_image_file = os.path.join(WORK_DIR, image_file)
103
+
104
+ png_bytes = base64.b64decode(image_base64)
105
+ assert isinstance(png_bytes, bytes)
106
+ bytes_io = io.BytesIO(png_bytes)
107
+ PIL.Image.open(bytes_io).save(local_image_file, 'png')
108
+
109
+ return local_image_file
110
+
111
+
112
+ START_CODE = """
113
+ import signal
114
+ def _m6_code_interpreter_timeout_handler(signum, frame):
115
+ raise TimeoutError("M6_CODE_INTERPRETER_TIMEOUT")
116
+ signal.signal(signal.SIGALRM, _m6_code_interpreter_timeout_handler)
117
+
118
+ def input(*args, **kwargs):
119
+ raise NotImplementedError('Python input() function is disabled.')
120
+
121
+ import os
122
+ if 'upload_file' not in os.getcwd():
123
+ os.chdir("./upload_file/")
124
+
125
+ import math
126
+ import re
127
+ import json
128
+
129
+ import seaborn as sns
130
+ sns.set_theme()
131
+
132
+ import matplotlib
133
+ import matplotlib.pyplot as plt
134
+ plt.rcParams['font.sans-serif'] = ['SimHei']
135
+ plt.rcParams['axes.unicode_minus'] = False
136
+
137
+ import numpy as np
138
+ import pandas as pd
139
+
140
+ from sympy import Eq, symbols, solve
141
+ """
142
+
143
+
144
+ def code_interpreter(action_input_list: list, timeout=30, clear=False):
145
+ code = ''
146
+ for action_input in action_input_list:
147
+ code += (extract_code(action_input) + '\n')
148
+ fixed_code = []
149
+ for line in code.split('\n'):
150
+ fixed_code.append(line)
151
+ if line.startswith('sns.set_theme('):
152
+ fixed_code.append('plt.rcParams["font.sans-serif"] = ["SimHei"]')
153
+ fixed_code.append('plt.rcParams["axes.unicode_minus"] = False')
154
+ fixed_code = '\n'.join(fixed_code)
155
+ if 'def solution()' in fixed_code:
156
+ fixed_code += '\nsolution()'
157
+
158
+ return _code_interpreter(fixed_code, timeout, clear)
159
+
160
+
161
+ def _code_interpreter(code: str, timeout, clear=False):
162
+ if not code.strip():
163
+ return ''
164
+ if timeout:
165
+ code = f'signal.alarm({timeout})\n{code}'
166
+ if clear:
167
+ code = "get_ipython().run_line_magic('reset', '-f')\n" + START_CODE + code
168
+
169
+ pid = os.getpid()
170
+ if pid not in _KERNEL_CLIENTS:
171
+ _KERNEL_CLIENTS[pid] = start_kernel(pid)
172
+ _code_interpreter(START_CODE, timeout=None)
173
+ kc = _KERNEL_CLIENTS[pid]
174
+ kc.wait_for_ready()
175
+ kc.execute(code)
176
+ result = ''
177
+ image_idx = 0
178
+ while True:
179
+ text = ''
180
+ image = ''
181
+ finished = False
182
+ msg_type = 'error'
183
+ try:
184
+ msg = kc.get_iopub_msg()
185
+ msg_type = msg['msg_type']
186
+ if msg_type == 'status':
187
+ if msg['content'].get('execution_state') == 'idle':
188
+ finished = True
189
+ elif msg_type == 'execute_result':
190
+ text = msg['content']['data'].get('text/plain', '')
191
+ if 'image/png' in msg['content']['data']:
192
+ image_b64 = msg['content']['data']['image/png']
193
+ image_url = publish_image_to_local(image_b64)
194
+ image_idx += 1
195
+ image = '![fig-%03d](%s)' % (image_idx, image_url)
196
+ elif msg_type == 'display_data':
197
+ if 'image/png' in msg['content']['data']:
198
+ image_b64 = msg['content']['data']['image/png']
199
+ image_url = publish_image_to_local(image_b64)
200
+ image_idx += 1
201
+ image = '![fig-%03d](%s)' % (image_idx, image_url)
202
+ else:
203
+ text = msg['content']['data'].get('text/plain', '')
204
+ elif msg_type == 'stream':
205
+ msg_type = msg['content']['name'] # stdout, stderr
206
+ text = msg['content']['text']
207
+ elif msg_type == 'error':
208
+ text = escape_ansi('\n'.join(msg['content']['traceback']))
209
+ if 'M6_CODE_INTERPRETER_TIMEOUT' in text:
210
+ text = f'Timeout. No response after {timeout} seconds.'
211
+ except queue.Empty:
212
+ text = f'Timeout. No response after {timeout} seconds.'
213
+ finished = True
214
+ except Exception:
215
+ text = 'The code interpreter encountered an unexpected error.'
216
+ logging.warning(''.join(
217
+ traceback.format_exception(*sys.exc_info())))
218
+ finished = True
219
+ if text:
220
+ result += f'\n\n{msg_type}:\n\n```\n{text}\n```'
221
+ if image:
222
+ result += f'\n\n{image}'
223
+ if finished:
224
+ break
225
+ result = result.lstrip('\n')
226
+ if timeout:
227
+ _code_interpreter('signal.alarm(0)', timeout=None)
228
+ return result
229
+
230
+
231
+ def get_multiline_input(hint):
232
+ print(hint)
233
+ print('// Press ENTER to make a new line. Press CTRL-D to end input.')
234
+ lines = []
235
+ while True:
236
+ try:
237
+ line = input()
238
+ except EOFError: # CTRL-D
239
+ break
240
+ lines.append(line)
241
+ print('// Input received.')
242
+ if lines:
243
+ return '\n'.join(lines)
244
+ else:
245
+ return ''
246
+
247
+
248
+ if __name__ == '__main__':
249
+ while True:
250
+ print(code_interpreter([get_multiline_input('Enter python code:')]))
benchmark/config.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from parser import InternLMReActParser, ReActParser
2
+
3
+ from models import LLM, QwenVL, Qwen, QwenDashscopeVLModel
4
+ from prompt import InternLMReAct, LlamaReAct, QwenReAct
5
+
6
+ react_prompt_map = {
7
+ 'qwen': QwenReAct,
8
+ 'llama': LlamaReAct,
9
+ 'internlm': InternLMReAct,
10
+ }
11
+
12
+ react_parser_map = {
13
+ 'qwen': ReActParser,
14
+ 'llama': ReActParser,
15
+ 'internlm': InternLMReActParser,
16
+ }
17
+
18
+ model_map = {'qwen': Qwen, 'llama': LLM, 'internlm': LLM, 'qwen-vl-chat': QwenVL}
19
+
20
+ model_type_map = {
21
+ 'qwen-72b-chat': 'qwen',
22
+ 'qwen-14b-chat': 'qwen',
23
+ 'qwen-1.8b-chat': 'qwen',
24
+ 'qwen-7b-chat': 'qwen',
25
+ 'llama-2-7b-chat': 'llama',
26
+ 'llama-2-13b-chat': 'llama',
27
+ 'codellama-7b-instruct': 'llama',
28
+ 'codellama-13b-instruct': 'llama',
29
+ 'internlm-7b-chat-1.1': 'internlm',
30
+ 'internlm-20b-chat': 'internlm',
31
+ 'qwen-vl-chat': 'qwen-vl-chat',
32
+ }
33
+
34
+ model_path_map = {
35
+ 'qwen-72b-chat': 'Qwen/Qwen-72B-Chat',
36
+ 'qwen-14b-chat': 'Qwen/Qwen-14B-Chat',
37
+ 'qwen-7b-chat': 'Qwen/Qwen-7B-Chat',
38
+ 'qwen-1.8b-chat': 'Qwen/Qwen-1_8B-Chat',
39
+ 'llama-2-7b-chat': 'meta-llama/Llama-2-7b-chat-hf',
40
+ 'llama-2-13b-chat': 'meta-llama/Llama-2-13b-chat-hf',
41
+ 'codellama-7b-instruct': 'codellama/CodeLlama-7b-Instruct-hf',
42
+ 'codellama-13b-instruct': 'codellama/CodeLlama-13b-Instruct-hf',
43
+ 'internlm-7b-chat-1.1': 'internlm/internlm-chat-7b-v1_1',
44
+ 'internlm-20b-chat': 'internlm/internlm-chat-20b',
45
+ 'qwen-vl-chat': 'Qwen/Qwen-VL-Chat',
46
+ }
47
+
48
+
49
+ def get_react_prompt(model_name, query, lang, upload_fname_list):
50
+ react_prompt_cls = react_prompt_map.get(model_type_map[model_name],
51
+ QwenReAct)
52
+ return react_prompt_cls(query, lang, upload_fname_list)
53
+
54
+
55
+ def get_react_parser(model_name):
56
+ react_parser_cls = react_parser_map.get(model_type_map[model_name],
57
+ ReActParser)
58
+ return react_parser_cls()
59
+
60
+
61
+ def get_model(model_name):
62
+ if model_name in ["qwen-vl-plus"]:
63
+ return QwenDashscopeVLModel(model=model_name)
64
+ model_path = model_path_map.get(model_name, None)
65
+ model_cls = model_map.get(model_type_map[model_name], LLM)
66
+ return model_cls(model_path)
benchmark/inference_and_execute.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import logging
4
+ import os
5
+ from parser import ReActParser
6
+
7
+ import prettytable
8
+ import tqdm
9
+ from code_interpreter import code_interpreter
10
+ from config import (get_model, get_react_parser, get_react_prompt,
11
+ model_path_map)
12
+ from datasets import load_dataset
13
+ from metrics.code_execution import eval_code_execution_rate
14
+ from metrics.gsm8k import eval_gsm8k_acc, is_correct
15
+ from metrics.visualization import eval_visualization_acc
16
+ from utils.code_utils import replace_upload_fname
17
+ from utils.data_utils import load_jsonl
18
+
19
+ logging.basicConfig(
20
+ format='%(asctime)s - %(levelname)s - %(message)s',
21
+ datefmt='%Y-%m-%d %H:%M:%S',
22
+ level=logging.INFO,
23
+ )
24
+
25
+ WORK_DIR = os.getenv('CODE_INTERPRETER_WORK_DIR', '/tmp/workspace')
26
+ os.makedirs(WORK_DIR, exist_ok=True)
27
+ os.system(f'cp -r upload_file_clean {WORK_DIR}/upload_file')
28
+ os.system('cp -r upload_file_clean ./upload_file')
29
+
30
+ global_eval_result = {
31
+ 'code_executability': {
32
+ 'math': None,
33
+ 'visualization': None,
34
+ 'general': None,
35
+ },
36
+ 'code_correctness': {
37
+ 'math': None,
38
+ 'visualization-hard': None,
39
+ 'visualization-easy': None,
40
+ }
41
+ }
42
+
43
+
44
+ def llm_with_plugin(args, query, item=None, exec_limit=3):
45
+ exec_count = 0
46
+
47
+ # Build ReAct prompt
48
+ upload_fname_list = item[
49
+ 'input_file_path'] if item and 'input_file_path' in item else []
50
+ lang = item['lang'] if item and 'lang' in item else 'en'
51
+ react_prompt_obj = get_react_prompt(args.model, query, lang,
52
+ upload_fname_list)
53
+ planning_prompt = react_prompt_obj.build_prompt()
54
+
55
+ # Execute the code when providing the first action in the query
56
+ if '<|im_start|>' in query:
57
+ _, prepend_code, __ = ReActParser().parse_latest_plugin_call(query)
58
+ prepend_code = replace_upload_fname(prepend_code, upload_fname_list)
59
+ call_plugin(_, [prepend_code], clear=(exec_count == 0))
60
+ exec_count += 1
61
+ exec_limit += 1
62
+
63
+ # Inference and execute
64
+ text = ''
65
+ while exec_count < exec_limit:
66
+ stop_words_list = react_prompt_obj.get_stop_words_list()
67
+ output = text_completion(args.llm,
68
+ planning_prompt + text,
69
+ stop_words=stop_words_list)
70
+
71
+ if args.gen_only:
72
+ text += output
73
+ break
74
+
75
+ react_parser = get_react_parser(args.model)
76
+ action, action_input, output = react_parser.parse_latest_plugin_call(
77
+ output)
78
+ if action:
79
+ action_input = replace_upload_fname(action_input,
80
+ upload_fname_list)
81
+ observation = call_plugin(action, [action_input],
82
+ clear=(exec_count == 0))
83
+ output += react_prompt_obj.build_observation(observation)
84
+ text += output
85
+ exec_count += 1
86
+ if 'error:' in observation or 'Traceback' in observation:
87
+ break
88
+ else:
89
+ text += output
90
+ break
91
+ return text
92
+
93
+
94
+ def text_completion(llm, input_text, stop_words=[]):
95
+ logging.info('Generating'.center(60, '='))
96
+ logging.info('Input'.center(60, '-'))
97
+ logging.info(input_text)
98
+
99
+ output = llm.generate(input_text, stop_words)
100
+
101
+ logging.info('Output'.center(60, '-'))
102
+ logging.info(output)
103
+ return output
104
+
105
+
106
+ def call_plugin(plugin_name, plugin_args_list, clear=False):
107
+ # Relax constraints on plugin name.
108
+ logging.info('Call code interpreter'.center(60, '='))
109
+ obs = code_interpreter(plugin_args_list, clear=clear)
110
+ logging.info(obs)
111
+ return obs
112
+
113
+
114
+ def process_code_interpreter(item, writer):
115
+ query = item['query']
116
+ exec_limit = 3 if 'visualization' in item['tags'] else 1
117
+ response = llm_with_plugin(args=args,
118
+ query=query,
119
+ item=item,
120
+ exec_limit=exec_limit)
121
+ item['gen'] = response
122
+
123
+ writer.write(json.dumps(item, ensure_ascii=False) + '\n')
124
+ writer.flush()
125
+
126
+
127
+ def process_gsm8k(doc, writer):
128
+ context = doc['question']
129
+ completion = llm_with_plugin(args=args, query=context)
130
+ acc = is_correct(completion, doc['answer'])
131
+ doc['completion'] = completion
132
+ doc['acc'] = acc
133
+
134
+ writer.write(json.dumps(doc, ensure_ascii=False) + '\n')
135
+ writer.flush()
136
+
137
+
138
+ def sequential_processing(args, data_list, process_func, writer):
139
+ for item in tqdm.tqdm(data_list):
140
+ process_func(item, writer)
141
+
142
+
143
+ process_func_map = {
144
+ 'gsm8k': process_gsm8k,
145
+ 'visualization': process_code_interpreter
146
+ }
147
+
148
+
149
+ def gather_eval_result(model_name):
150
+ for metric in global_eval_result:
151
+ logging.info(metric)
152
+ table = prettytable.PrettyTable()
153
+ table.field_names = ['model'] + list(global_eval_result[metric].keys())
154
+ row_data = [model_name]
155
+ for item in global_eval_result[metric].values():
156
+ item = str(item) if not item else str(round(item, 2))
157
+ row_data.append(item)
158
+ table.add_row(row_data)
159
+ logging.info('\n' + str(table))
160
+
161
+
162
+ def eval_metrics(args, test_set, full_output_fname):
163
+ # metrics
164
+ assert os.path.exists(
165
+ full_output_fname), f'Not Found File {full_output_fname}.'
166
+ inference_res = load_jsonl(full_output_fname)
167
+ assert len(inference_res) == len(
168
+ test_set
169
+ ), f'There are still {len(test_set)-len(inference_res)} cases left.'
170
+
171
+ abs_output_fname = os.path.join(os.path.dirname(os.path.abspath(__file__)),
172
+ full_output_fname)
173
+ if args.task == 'gsm8k':
174
+ math_code_correctness = eval_gsm8k_acc(abs_output_fname)
175
+ global_eval_result['code_correctness'].update(math_code_correctness)
176
+ else:
177
+ code_executability = eval_code_execution_rate(abs_output_fname,
178
+ args.task, args.model)
179
+ global_eval_result['code_executability'].update(code_executability)
180
+ if args.task in ['all_ci', 'visualization'
181
+ ] and not args.eval_code_exec_only:
182
+ visualization_code_correctness = eval_visualization_acc(
183
+ abs_output_fname, args.model, args.vis_judger)
184
+ global_eval_result['code_correctness'].update(
185
+ visualization_code_correctness)
186
+
187
+
188
+ def main(args):
189
+ current_dir = os.getcwd()
190
+ os.makedirs(args.output_path, exist_ok=True)
191
+ full_output_fname = os.path.join(
192
+ args.output_path,
193
+ (args.output_fname or f'{args.task}_{args.model}_res.jsonl'))
194
+
195
+ if not os.path.exists(full_output_fname):
196
+ with open(full_output_fname, 'w'):
197
+ logging.info(f'Create file {full_output_fname} done.')
198
+
199
+ # build data
200
+ if args.task == 'gsm8k':
201
+ dataset = load_dataset('gsm8k', 'main')
202
+ test_set = dataset['test']
203
+ else:
204
+ eval_data_path = os.path.join(args.input_path, args.input_fname)
205
+ test_set = [
206
+ item for item in load_jsonl(eval_data_path)
207
+ if args.task in item['tags']
208
+ ]
209
+ logging.info(f'Test set: {len(test_set)}')
210
+
211
+ if args.eval_only:
212
+ eval_metrics(args, test_set, full_output_fname)
213
+ else:
214
+ key = 'question' if args.task == 'gsm8k' else 'query'
215
+ cache_question = [item[key] for item in load_jsonl(full_output_fname)
216
+ ] if not args.force else []
217
+ data_list = [
218
+ item for item in test_set if item[key] not in cache_question
219
+ ]
220
+ logging.info(f'Left cases: {len(data_list)}')
221
+
222
+ # inference
223
+ writer_mode = 'w' if args.force else 'a'
224
+ f_output = open(full_output_fname, writer_mode, encoding='utf-8')
225
+ process_func = process_func_map.get(args.task,
226
+ process_code_interpreter)
227
+ sequential_processing(args, data_list, process_func, f_output)
228
+ f_output.close()
229
+
230
+ # evaluate
231
+ if not args.gen_exec_only:
232
+ eval_metrics(args, test_set, full_output_fname)
233
+
234
+ os.chdir(current_dir)
235
+
236
+
237
+ def parse_args():
238
+ parser = argparse.ArgumentParser()
239
+ parser.add_argument('--model',
240
+ type=str,
241
+ default='qwen-14b-chat',
242
+ choices=list(model_path_map.keys()))
243
+ parser.add_argument(
244
+ '--task',
245
+ type=str,
246
+ default='all',
247
+ choices=['all', 'gsm8k', 'visualization', 'general'])
248
+ parser.add_argument('--output-path', type=str, default='output_data')
249
+ parser.add_argument('--input-path', type=str, default='eval_data')
250
+ parser.add_argument('-o', '--output-fname', type=str, default='')
251
+ parser.add_argument('-i',
252
+ '--input-fname',
253
+ type=str,
254
+ default='eval_code_interpreter_v1.jsonl')
255
+ parser.add_argument('-f', '--force', action='store_true', default=False)
256
+ parser.add_argument('--eval-only', action='store_true', default=False)
257
+ parser.add_argument('--eval-code-exec-only',
258
+ action='store_true',
259
+ default=False)
260
+ parser.add_argument('--gen-exec-only', action='store_true', default=False)
261
+ parser.add_argument('--gen-only', action='store_true', default=False)
262
+ parser.add_argument('--vis-judger', type=str, default="'gpt-4-vision-preview'",
263
+ choices=['gpt-4-vision-preview', 'qwen-vl-chat', 'qwen-vl-plus'])
264
+ args = parser.parse_args()
265
+ return args
266
+
267
+
268
+ if __name__ == '__main__':
269
+ args = parse_args()
270
+ if not args.eval_only:
271
+ args.llm = get_model(args.model)
272
+ logging.info(f'Init {args.model} done.')
273
+
274
+ if args.task == 'all':
275
+ for key in ['gsm8k', 'visualization', 'general']:
276
+ args.task = key
277
+ main(args)
278
+ else:
279
+ main(args)
280
+ gather_eval_result(args.model)
benchmark/metrics/__init__.py ADDED
File without changes
benchmark/metrics/code_execution.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+
4
+ import func_timeout
5
+ from config import get_react_parser
6
+ from func_timeout import func_set_timeout
7
+ from utils.code_utils import extract_code, replace_upload_fname
8
+ from utils.data_utils import load_jsonl, save_jsonl
9
+
10
+ pre_load = """
11
+ import os
12
+ if 'upload_file' not in os.getcwd():
13
+ os.chdir("./upload_file/")
14
+
15
+ import seaborn as sns
16
+
17
+ import matplotlib
18
+ # matplotlib.use('Agg')
19
+ import matplotlib.pyplot as plt
20
+ plt.ion()
21
+
22
+ import numpy as np
23
+ import pandas as pd
24
+ from sympy import Eq, symbols, solve
25
+ import re
26
+ import json
27
+ import math
28
+ """
29
+
30
+ tags_config = {
31
+ 'visualization': {
32
+ 'timelimit': True,
33
+ 'extract_first_code': True,
34
+ },
35
+ 'math': {
36
+ 'timelimit': True,
37
+ 'extract_first_code': False,
38
+ },
39
+ 'general': {
40
+ 'timelimit': False,
41
+ 'extract_first_code': True,
42
+ }
43
+ }
44
+
45
+ code_executability = {'math': None, 'visualization': None, 'general': None}
46
+
47
+
48
+ @func_set_timeout(10)
49
+ def exec_limit_time(text):
50
+ exec(text, locals())
51
+
52
+
53
+ def exec_code(text, timelimit=False):
54
+ if timelimit:
55
+ exec_limit_time(text)
56
+ else:
57
+ exec(text, locals())
58
+
59
+
60
+ def postprocess_code(gen_code, line):
61
+ if '<|im_start|>' in line['query']:
62
+ first_action_code = get_action_input_code(line['query'])
63
+ gen_code = first_action_code + gen_code
64
+
65
+ upload_fname_list = line[
66
+ 'input_file_path'] if line and 'input_file_path' in line else []
67
+ gen_code = replace_upload_fname(gen_code, upload_fname_list)
68
+
69
+ if 'def solution()' in gen_code:
70
+ gen_code += '\nsolution()\n'
71
+
72
+ if 'plt.show()' in gen_code:
73
+ gen_code += "\nplt.pause(1)\nplt.close('all')\n"
74
+
75
+ if 'sns.' in gen_code and 'plot' in gen_code:
76
+ gen_code += "\nplt.close('all')\n"
77
+
78
+ gen_code = pre_load + gen_code
79
+ return gen_code
80
+
81
+
82
+ def get_action_input_code(text,
83
+ model_name='qwen-14b-chat',
84
+ extract_first_code=False):
85
+ action_input_list = []
86
+ tmp = text
87
+ react_parser = get_react_parser(model_name)
88
+ while True:
89
+ action_input = react_parser.get_first_action_input(tmp)
90
+ if not action_input:
91
+ break
92
+ action_input_list.append(action_input)
93
+ tmp = tmp.split(action_input)[1]
94
+ if not tmp or extract_first_code:
95
+ break
96
+
97
+ code = ''
98
+ for action_input in action_input_list:
99
+ code = code + '# concat\n' + extract_code(action_input) + '\n'
100
+ return code
101
+
102
+
103
+ def eval_code_execution_rate(output_fname,
104
+ tag='all_ci',
105
+ model_name='qwen-14b-chat',
106
+ timelimit=False,
107
+ extract_first_code=False):
108
+ data_list = load_jsonl(output_fname)
109
+ pip_package = []
110
+
111
+ for line_id, line in enumerate(data_list):
112
+ line['idx'] = line_id
113
+ tags_list = line['tags'].split(',')
114
+ if tag not in tags_list:
115
+ continue
116
+
117
+ # update args
118
+ for cur_tag in tags_list:
119
+ if cur_tag != 'all_ci':
120
+ timelimit = tags_config[cur_tag]['timelimit']
121
+ extract_first_code = tags_config[cur_tag]['extract_first_code']
122
+
123
+ line['executable_code'] = False
124
+ line['missing_code'] = False
125
+ line['code_error_info'] = ''
126
+
127
+ # get Action Input code from response
128
+ gen_code = get_action_input_code(line['gen'],
129
+ model_name=model_name,
130
+ extract_first_code=extract_first_code)
131
+
132
+ if not gen_code:
133
+ line['missing_code'] = True
134
+ line['code'] = ''
135
+ line['code_error_info'] = 'missing code'
136
+ continue
137
+
138
+ line['code'] = gen_code
139
+ gen_code = postprocess_code(gen_code, line)
140
+
141
+ while True:
142
+ try:
143
+ exec_code(gen_code, timelimit=timelimit)
144
+ line['executable_code'] = True
145
+ break
146
+ except func_timeout.exceptions.FunctionTimedOut as ex:
147
+ line['code_error_info'] = str(ex)
148
+ break
149
+ except (ImportError, ModuleNotFoundError) as ex:
150
+ try:
151
+ packege = str(ex).split("'")[1].strip()
152
+ except Exception:
153
+ packege = ''
154
+ if packege and packege not in pip_package: # install package
155
+ pip_package.append(packege)
156
+ os.system('pip install ' + packege)
157
+ logging.info(f'Automatic installation: {packege}')
158
+ else:
159
+ line['code_error_info'] = str(ex)
160
+ break
161
+ except Exception as ex:
162
+ line['code_error_info'] = str(ex)
163
+ break
164
+
165
+ # double check
166
+ observation = get_react_parser(model_name).get_first_observation(
167
+ line['gen'])
168
+ if line['executable_code'] and ('error:' in observation):
169
+ logging.warning(
170
+ 'The code executes correctly, but it has an error in IPython!')
171
+ logging.warning(f'Code:\n{gen_code}')
172
+ logging.warning(f'IPython error info:\n{observation}')
173
+ logging.info('=' * 60)
174
+ elif not line['executable_code'] and not ('error:' in observation):
175
+ logging.warning(
176
+ 'The code has an execution error, but it runs correctly in IPython!'
177
+ )
178
+ logging.warning(f'Code:\n{gen_code}')
179
+ logging.warning(f"Exec error info:\n{line['code_error_info']}")
180
+ logging.warning(f'IPython observation:\n{observation}')
181
+ logging.info('=' * 60)
182
+
183
+ # save error data
184
+ error_data_list = [
185
+ item for item in data_list
186
+ if not item['executable_code'] or item['missing_code']
187
+ ]
188
+ error_data_output_fname = os.path.splitext(
189
+ output_fname)[0] + '_exec_error.jsonl'
190
+ save_jsonl(error_data_list, error_data_output_fname)
191
+
192
+ log_result(data_list)
193
+
194
+ return code_executability
195
+
196
+
197
+ def log_result(data_list, verbose=True):
198
+ if verbose:
199
+ logging.info('*' * 60)
200
+ logging.info('{:^60}'.format('Detail'))
201
+ logging.info('*' * 60)
202
+ for line_id, line in enumerate(data_list):
203
+ logging.info(f'Question {line_id}'.center(60, '='))
204
+ logging.info(line['query'])
205
+
206
+ logging.info(f'Generated {line_id}'.center(60, '-'))
207
+ logging.info('\n' + line['gen'])
208
+
209
+ logging.info(f'Code {line_id}'.center(60, '-'))
210
+ logging.info('\n' + line['code'])
211
+
212
+ logging.info(f'Exec Result {line_id}'.center(60, '-'))
213
+ prefix_info = 'Exec Success' if line[
214
+ 'executable_code'] else 'Exec Error: '
215
+ exec_info = prefix_info + line['code_error_info']
216
+ logging.info(exec_info)
217
+
218
+ logging.info('=' * 60)
219
+ logging.info('{:^60}'.format('Code Execuation Rate'))
220
+ logging.info('=' * 60)
221
+ involved_tags = []
222
+ for line in data_list:
223
+ involved_tags += line['tags'].split(',')
224
+ involved_tags = list(set(involved_tags))
225
+
226
+ for key in involved_tags:
227
+ logging.info(f'task: {key}'.center(60, '='))
228
+ key_item_list = [item for item in data_list if key in item['tags']]
229
+ all_count = len(key_item_list)
230
+ missing_code_count = len(
231
+ [item for item in key_item_list if item['missing_code']])
232
+ executable_code_count = len(
233
+ [item for item in key_item_list if item['executable_code']])
234
+
235
+ logging.info(f'All Test: {all_count}')
236
+ logging.info(f'Missing Code: {missing_code_count}')
237
+ logging.info(f'Predict Exec Success: {executable_code_count}')
238
+ logging.info('Codes available && Execution Rate: {:.2f}'.format(
239
+ executable_code_count / (all_count - missing_code_count) * 100))
240
+ logging.info('Execution Rate: {:.2f}'.format(executable_code_count /
241
+ all_count * 100))
242
+ logging.info('Non-executable rate: {:.2f}'.format(
243
+ (all_count - missing_code_count - executable_code_count) /
244
+ all_count * 100))
245
+ logging.info('Missing code rate: {:.2f}'.format(missing_code_count /
246
+ all_count * 100))
247
+
248
+ if key != 'all_ci':
249
+ code_executability[key] = executable_code_count / all_count * 100
250
+
251
+ if verbose:
252
+ logging.info('Error List: ')
253
+ error_list = [(item['idx'], item['code_error_info'])
254
+ for item in key_item_list if item['code_error_info']]
255
+ error_list.sort(key=lambda x: x[1])
256
+ for x in error_list:
257
+ logging.info(x)
benchmark/metrics/gsm8k.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import re
4
+
5
+ import numpy as np
6
+ from utils.data_utils import load_jsonl, save_jsonl
7
+
8
+ INVALID_ANS = '[invalid]'
9
+
10
+
11
+ def extract_answer(completion):
12
+
13
+ def _get_last_digit(s):
14
+ _PAT_LAST_DIGIT = re.compile(
15
+ r'(?<=(\s|[\$%#{]))([+-])?(?=(\S))(0|([1-9](\d*|\d{0,2}(,\d{3})*)))?(\.\d*[1-9])?(?=(\s|[.,}]|$))'
16
+ )
17
+ match = list(_PAT_LAST_DIGIT.finditer(s))
18
+ if match:
19
+ last_digit = match[-1].group().replace(',', '').replace('+', '')
20
+ else:
21
+ last_digit = None
22
+ logging.warning(f'No digits found in {s!r}')
23
+ return last_digit
24
+
25
+ job_gen = completion.strip('.').replace('\n', '\\n')
26
+ last_digit = _get_last_digit(job_gen)
27
+ if last_digit:
28
+ return eval(last_digit)
29
+ else:
30
+ return INVALID_ANS
31
+
32
+
33
+ def is_correct(completion, answer):
34
+ gold = extract_answer(answer)
35
+ assert gold != INVALID_ANS, 'No ground truth answer found in the document.'
36
+ return extract_answer(completion) == gold
37
+
38
+
39
+ def eval_gsm8k_acc(output_fname):
40
+ data_list = load_jsonl(output_fname)
41
+ acc_res = [item['acc'] for item in data_list]
42
+ logging.info('=' * 60)
43
+ logging.info('{:^60}'.format('Math Acc.'))
44
+ logging.info('=' * 60)
45
+ logging.info('Total num={:.2f}'.format(len(acc_res)))
46
+ logging.info('Right num={:.2f}'.format(np.sum(acc_res)))
47
+ logging.info('Zero-shot Acc={:.2f}'.format(np.mean(acc_res) * 100))
48
+
49
+ error_data_list = [item for item in data_list if not item['acc']]
50
+ error_data_output_fname = os.path.splitext(
51
+ output_fname)[0] + '_gsm8k_error.jsonl'
52
+ save_jsonl(error_data_list, error_data_output_fname)
53
+
54
+ return {'math': np.mean(acc_res) * 100}
benchmark/metrics/visualization.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import re
4
+ import base64
5
+ import torch
6
+ from config import get_model, get_react_parser
7
+ from utils.data_utils import load_jsonl, save_jsonl
8
+
9
+ torch.manual_seed(1234)
10
+
11
+ EVAL_VISUAL_PROMPT_ZH = """请判断图片是否与下面的[问题]一致,如果一致则回复“right”,不一致则回复“wrong”。
12
+ [问题]:{query}
13
+ """
14
+
15
+ EVAL_VISUAL_PROMPT_EN = """Please judge whether the image is consistent with the [Question] below, if it is consistent then reply "right", if not then reply "wrong".
16
+ [Question]: {query}
17
+ """
18
+
19
+ visualization_code_correctness = {
20
+ 'visualization-hard': None,
21
+ 'visualization-easy': None,
22
+ }
23
+
24
+
25
+ def encode_image(image_path):
26
+ with open(image_path, "rb") as image_file:
27
+ a = base64.b64encode(image_file.read()).decode('utf-8')
28
+ return a
29
+
30
+
31
+ def judger_model_inference(judger_model_name, judger_model, imgs=[], prompt=''):
32
+ output = ""
33
+ if judger_model_name == 'gpt-4-vision-preview':
34
+ logging.warning("This is an example of `gpt-4-vision-preview`. "
35
+ "Please set the API key and use according to your actual situation.")
36
+ from openai import OpenAI
37
+ client = OpenAI()
38
+ content_list = []
39
+ content_list.append({"type": "text", "text": prompt})
40
+ input_images = []
41
+ for img in imgs:
42
+ if 'http' not in img:
43
+ base64_image = encode_image(img)
44
+ img = f"data:image/jpeg;base64,{base64_image}"
45
+ input_images.append({"type": "image_url", 'image_url': img})
46
+ content_list.extend(input_images)
47
+ response = client.chat.completions.create(
48
+ model="gpt-4-vision-preview",
49
+ messages=[
50
+ {
51
+ "role": "user",
52
+ "content": content_list,
53
+ }
54
+ ],
55
+ max_tokens=300,
56
+ )
57
+ output = response.choices[0]
58
+ elif judger_model_name in ['qwen-vl-plus', 'qwen-vl-chat']:
59
+ inputs = []
60
+ for img in imgs:
61
+ if 'http' not in img and judger_model_name == 'qwen-vl-plus':
62
+ img = "file://" + img
63
+ inputs.append({'image': img})
64
+ inputs.append({'text': prompt})
65
+
66
+ logging.info('Eval'.center(60, '-'))
67
+ logging.info(inputs)
68
+ output = judger_model.generate(inputs)
69
+ logging.info(output)
70
+ logging.info('=' * 60)
71
+ return output
72
+
73
+
74
+ def extract_images(text):
75
+ regex = re.compile(r'!\[fig-(.+)\]\((.+)\)')
76
+ results = re.findall(regex, text)
77
+ images = []
78
+ for res in results:
79
+ assert len(res) == 2
80
+ if os.path.exists(res[1]):
81
+ images.append(res[1])
82
+ return images
83
+
84
+
85
+ def check_images_observation(text, images, model_name):
86
+ start_flag = get_react_parser(model_name).observation
87
+ for image in images:
88
+ logging.info('Image'.center(60, '-'))
89
+ logging.info(image)
90
+
91
+ end_idx = text.find(image)
92
+ tmp_text = text[:end_idx + len(image)]
93
+ start_idx = tmp_text.rfind(start_flag)
94
+ check_text = tmp_text[start_idx + len(start_flag):]
95
+
96
+ logging.info('Observation'.center(60, '-'))
97
+ logging.info(check_text)
98
+
99
+ # As long as there exists correctly executed observation, we consider `True`
100
+ if 'error:' not in check_text and 'Traceback' not in check_text:
101
+ return True
102
+ return False
103
+
104
+
105
+ eval_visual_prompt = {'zh': EVAL_VISUAL_PROMPT_ZH, 'en': EVAL_VISUAL_PROMPT_EN}
106
+
107
+
108
+ def eval_visualization_acc(output_fname, model_name, judger_model_name='gpt-4-vision-preview'):
109
+ if judger_model_name == 'gpt-4-vision-preview':
110
+ judger_model = None
111
+ elif judger_model_name in ['qwen-vl-chat', 'qwen-vl-plus']:
112
+ if judger_model_name == 'qwen-vl-chat':
113
+ logging.warning('In this benchmark of version 20231206, `Qwen-vl-chat` is no longer used as the '
114
+ 'evaluation model for `Visualization` task.. If you insist on using it, '
115
+ 'the evaluation results might differ from the official results.')
116
+ judger_model = get_model(judger_model_name)
117
+ else:
118
+ raise Exception("Not supported judger model.")
119
+
120
+ one_action, one_action_right = 0, 0
121
+ zero_action, zero_action_right = 0, 0
122
+
123
+ data_list = load_jsonl(output_fname)
124
+ for item in data_list:
125
+ if 'visualization' not in item['tags']:
126
+ continue
127
+
128
+ item['vis_acc'] = False
129
+ if '<|im_end|>' in item['query']:
130
+ one_action += 1
131
+ prompt = item['query'].split('<|im_end|>')[0]
132
+ else:
133
+ zero_action += 1
134
+ prompt = item['query']
135
+
136
+ images = extract_images(item['gen'])
137
+
138
+ if images and check_images_observation(item['gen'], images,
139
+ model_name):
140
+ input_prompt = eval_visual_prompt[item.get('lang', 'en')]
141
+ format_prompt = input_prompt.format(query=prompt)
142
+ output = judger_model_inference(judger_model_name, judger_model, images, format_prompt)
143
+ if 'right' in output.lower():
144
+ item['vis_acc'] = True
145
+ if '<|im_end|>' in item['query']:
146
+ one_action_right += 1
147
+ else:
148
+ zero_action_right += 1
149
+
150
+ logging.info('*' * 60)
151
+ logging.info('{:^60}'.format('Visualization Acc.'))
152
+ logging.info('*' * 60)
153
+ logging.info(
154
+ 'Visualization-Hard count={}, Visualization-Hard right count={}, Visualization-Hard acc={:.2f}'
155
+ .format(zero_action, zero_action_right,
156
+ zero_action_right / zero_action * 100))
157
+ logging.info(
158
+ 'Visualization-Easy count={}, Visualization-Easy right count={}, Visualization-Easy acc={:.2f}'
159
+ .format(one_action, one_action_right,
160
+ one_action_right / one_action * 100))
161
+ logging.info('all count={}, all right={}, all acc={:.2f}'.format(
162
+ zero_action + one_action, zero_action_right + one_action_right,
163
+ (zero_action_right + one_action_right) / (zero_action + one_action) *
164
+ 100))
165
+
166
+ visualization_code_correctness[
167
+ 'visualization-hard'] = zero_action_right / zero_action * 100
168
+ visualization_code_correctness[
169
+ 'visualization-easy'] = one_action_right / one_action * 100
170
+
171
+ error_data_list = [
172
+ item for item in data_list
173
+ if 'visualization' in item['tags'] and not item['vis_acc']
174
+ ]
175
+ error_data_output_fname = os.path.splitext(
176
+ output_fname)[0] + '_vis_error.jsonl'
177
+ save_jsonl(error_data_list, error_data_output_fname)
178
+
179
+ return visualization_code_correctness
benchmark/models/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from models.base import HFModel # noqa
2
+ from models.llm import LLM # noqa
3
+ from models.qwen import Qwen, QwenVL # noqa
4
+ from models.dashscope import QwenDashscopeVLModel
benchmark/models/base.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer
2
+ from transformers.generation import GenerationConfig
3
+
4
+
5
+ class HFModel(object):
6
+
7
+ def __init__(self, model_path):
8
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path,
9
+ trust_remote_code=True)
10
+ self.model = AutoModelForCausalLM.from_pretrained(
11
+ model_path,
12
+ trust_remote_code=True,
13
+ device_map='auto',
14
+ low_cpu_mem_usage=True).eval()
15
+ self.model.generation_config = GenerationConfig.from_pretrained(
16
+ model_path, trust_remote_code=True)
17
+ self.model.generation_config.do_sample = False
benchmark/models/dashscope.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from http import HTTPStatus
3
+ import time
4
+ import dashscope
5
+
6
+
7
+ class QwenDashscopeVLModel(object):
8
+ def __init__(self, model, api_key):
9
+ self.model = model
10
+ dashscope.api_key = api_key.strip() or os.getenv('DASHSCOPE_API_KEY', default='')
11
+ assert dashscope.api_key, 'DASHSCOPE_API_KEY is required.'
12
+
13
+ def generate(self, prompt, stop_words=[]):
14
+ if isinstance(prompt, str):
15
+ prompt = [{'text': prompt}]
16
+
17
+ MAX_TRY = 3
18
+ count = 0
19
+ while count < MAX_TRY:
20
+ response = dashscope.MultiModalConversation.call(
21
+ self.model,
22
+ messages=[{'role': 'user', 'content': prompt}],
23
+ top_p=0.01,
24
+ top_k=1,
25
+ )
26
+ if response.status_code == HTTPStatus.OK:
27
+ output = response.output.choices[0].message.content[0]['text']
28
+ for stop_str in stop_words:
29
+ idx = output.find(stop_str)
30
+ if idx != -1:
31
+ output = output[: idx + len(stop_str)]
32
+ return output
33
+ else:
34
+ err = 'Error code: %s, error message: %s' % (
35
+ response.code,
36
+ response.message,
37
+ )
38
+ logging.error(err)
39
+ count += 1
40
+ time.sleep(1)
benchmark/models/llm.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from models.base import HFModel
3
+
4
+
5
+ class LLM(HFModel):
6
+
7
+ def __init__(self, model_path):
8
+ super().__init__(model_path)
9
+
10
+ def generate(self, input_text, stop_words=[], max_new_tokens=512):
11
+ if isinstance(input_text, str):
12
+ input_text = [input_text]
13
+
14
+ input_ids = self.tokenizer(input_text)['input_ids']
15
+ input_ids = torch.tensor(input_ids, device=self.model.device)
16
+ gen_kwargs = {'max_new_tokens': max_new_tokens, 'do_sample': False}
17
+ outputs = self.model.generate(input_ids, **gen_kwargs)
18
+ s = outputs[0][input_ids.shape[1]:]
19
+ output = self.tokenizer.decode(s, skip_special_tokens=True)
20
+
21
+ for stop_str in stop_words:
22
+ idx = output.find(stop_str)
23
+ if idx != -1:
24
+ output = output[:idx + len(stop_str)]
25
+
26
+ return output
benchmark/models/qwen.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from models.base import HFModel
3
+
4
+
5
+ class Qwen(HFModel):
6
+
7
+ def __init__(self, model_path):
8
+ super().__init__(model_path)
9
+
10
+ def generate(self, input_text, stop_words=[]):
11
+ im_end = '<|im_end|>'
12
+ if im_end not in stop_words:
13
+ stop_words = stop_words + [im_end]
14
+ stop_words_ids = [self.tokenizer.encode(w) for w in stop_words]
15
+
16
+ input_ids = torch.tensor([self.tokenizer.encode(input_text)
17
+ ]).to(self.model.device)
18
+ output = self.model.generate(input_ids, stop_words_ids=stop_words_ids)
19
+ output = output.tolist()[0]
20
+ output = self.tokenizer.decode(output, errors='ignore')
21
+ assert output.startswith(input_text)
22
+ output = output[len(input_text):].replace('<|endoftext|>',
23
+ '').replace(im_end, '')
24
+
25
+ return output
26
+
27
+
28
+ class QwenVL(HFModel):
29
+ def __init__(self, model_path):
30
+ super().__init__(model_path)
31
+
32
+ def generate(self, inputs: list):
33
+ query = self.tokenizer.from_list_format(inputs)
34
+ response, _ = self.model.chat(self.tokenizer, query=query, history=None)
35
+
36
+ return response
benchmark/parser/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from parser.internlm_parser import InternLMReActParser # noqa
2
+ from parser.react_parser import ReActParser # noqa
benchmark/parser/internlm_parser.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from parser.react_parser import ReActParser
2
+
3
+
4
+ class InternLMReActParser(ReActParser):
5
+
6
+ def __init__(self):
7
+ self.action = '\nAction:'
8
+ self.action_input = '\nActionInput:'
9
+ self.action_input_stop = '<eoa>'
10
+ self.observation = '<|System|>:Response:'
11
+ self.observation_stop = '<TOKENS_UNUSED_2>\n<|Bot|>:'
benchmark/parser/react_parser.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class ReActParser(object):
2
+
3
+ def __init__(self):
4
+ self.action = '\nAction:'
5
+ self.action_input = '\nAction Input:'
6
+ self.action_input_stop = '\nObservation:'
7
+ self.observation = '\nObservation:'
8
+ self.observation_stop = '\nThought:'
9
+
10
+ def parse_latest_plugin_call(self, text):
11
+ action = self.action
12
+ action_input = self.action_input
13
+ observation = self.action_input_stop
14
+ plugin_name, plugin_args = '', ''
15
+ i = text.rfind(action)
16
+ j = text.rfind(action_input)
17
+ k = text.rfind(observation)
18
+ if 0 <= i < j: # If the text has `Action` and `Action input`,
19
+ if k < j: # but does not contain `Observation`,
20
+ # then it is likely that `Observation` is ommited by the LLM,
21
+ # because the output text may have discarded the stop word.
22
+ text = text.rstrip() + observation # Add it back.
23
+ k = text.rfind(observation)
24
+ plugin_name = text[i + len(action):j].strip()
25
+ plugin_args = text[j + len(action_input):k].strip()
26
+ text = text[:k]
27
+ return plugin_name, plugin_args, text
28
+
29
+ def _extract_first_target(self, text, start_flag, end_flag):
30
+ target = ''
31
+ i = text.find(start_flag)
32
+ if i != -1:
33
+ j = text.find(end_flag, i)
34
+ if j != -1:
35
+ target = text[i + len(start_flag):j].strip()
36
+ else:
37
+ target = text[i + len(start_flag):].strip()
38
+ return target
39
+
40
+ def get_first_observation(self, text):
41
+ return self._extract_first_target(text, self.observation,
42
+ self.observation_stop)
43
+
44
+ def get_first_action_input(self, text):
45
+ return self._extract_first_target(text, self.action_input,
46
+ self.action_input_stop)
benchmark/prompt/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from prompt.internlm_react import InternLMReAct # noqa
2
+ from prompt.llama_react import LlamaReAct # noqa
3
+ from prompt.qwen_react import QwenReAct # noqa
4
+ from prompt.react import ReAct # noqa
benchmark/prompt/internlm_react.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from prompt.react import ReAct
2
+
3
+ INTERNLM_TOOL_DESCRIPTION = """用来执行Python代码。代码必须是一个函数,
4
+ 函数名必须得是 'solution',代码对应你的思考过程。代码实例格式如下:
5
+ ```python
6
+ # import 依赖包
7
+ import xxx
8
+ def solution():
9
+ # 初始化一些变量
10
+ variable_names_with_real_meaning = xxx
11
+ # 步骤一
12
+ mid_variable = func(variable_names_with_real_meaning)
13
+ # 步骤 x
14
+ mid_variable = func(mid_variable)
15
+ # 最后结果
16
+ final_answer = func(mid_variable)
17
+ return final_answer
18
+ ```"""
19
+
20
+ INTERNLM_TOOL = {'PythonInterpreter': INTERNLM_TOOL_DESCRIPTION}
21
+
22
+ INTERNLM_REACT_PROMPT_ZH = """<|System|>:你是一个可以调用外部工具的助手,可以使用的工具包括:
23
+ {tools_text}
24
+ 如果使用工具请遵循以下格式回复:
25
+ ```
26
+ Thought:思考你当前步骤需要解决什么问题,是否需要使用工具
27
+ Action:工具名称,你的工具必须从 [{tools_name_text}] 选择
28
+ ActionInput:工具输入参数
29
+ ```
30
+ 工具返回按照以下格式回复:
31
+ ```
32
+ Response:调用工具后的结果
33
+ ```
34
+ 如果你已经知道了答案,或者你不需要工具,请遵循以下格式回复
35
+ ```
36
+ Thought:给出最终答案的思考过程
37
+ FinalAnswer:最终答案
38
+ ```
39
+ 开始!<TOKENS_UNUSED_2>
40
+ <|User|>:{query}<eoh>
41
+ <|Bot|>:"""
42
+
43
+ INTERNLM_REACT_PROMPT_EN = """<|System|>:You are a assistant who can utilize external tools.
44
+ {tools_text}
45
+ To use a tool, please use the following format:
46
+ ```
47
+ Thought: Think what you need to solve, do you need to use tools?
48
+ Action: the tool name, should be one of [{tools_name_text}]
49
+ ActionInput: the input to the action
50
+ ```
51
+ The response after utilizing tools should using the following format:
52
+ ```
53
+ Response: the results after call the tool.
54
+ ``
55
+ If you already know the answer, or you do not need to use tools,
56
+ please using the following format to reply:
57
+ ```
58
+ Thought: the thought process to get the final answer
59
+ FinalAnswer: final answer
60
+ ```
61
+ Begin!<TOKENS_UNUSED_2>
62
+ <|User|>:{query}<eoh>
63
+ <|Bot|>:"""
64
+
65
+
66
+ class InternLMReAct(ReAct):
67
+
68
+ def __init__(self, query, lang='en', upload_file_paths=[]):
69
+ super().__init__(query, lang, upload_file_paths)
70
+ self.react_template = INTERNLM_REACT_PROMPT_ZH if self.lang == 'zh' else INTERNLM_REACT_PROMPT_EN
71
+
72
+ def build_prompt(self):
73
+ planning_prompt = super().build_prompt()
74
+ if '<|im_end|>' in self.query and planning_prompt.endswith(
75
+ '<eoh>\n<|Bot|>:'):
76
+ planning_prompt = planning_prompt[:-len('<eoh>\n<|Bot|>:')]
77
+
78
+ if '<|im_end|>' in self.query:
79
+ planning_prompt = planning_prompt.replace(
80
+ '<|im_end|>\n<|im_start|>assistant\n',
81
+ '<eoh>\n<|Bot|>:').replace(
82
+ 'Observation:', '<eoa>\n<|System|>:Response:').replace(
83
+ '\nAction Input',
84
+ '\nActionInput').replace('code_interpreter',
85
+ 'PythonInterpreter')
86
+ assert planning_prompt.endswith('Thought:')
87
+ planning_prompt = planning_prompt[:-len(
88
+ 'Thought:')] + '<TOKENS_UNUSED_2>\n<|Bot|>:'
89
+
90
+ self.prompt = planning_prompt
91
+ return planning_prompt
92
+
93
+ def _build_tools_text(self):
94
+ return INTERNLM_TOOL
95
+
96
+ def _build_tools_name_text(self):
97
+ return list(INTERNLM_TOOL.keys())
98
+
99
+ def build_observation(self, observation):
100
+ return f'<eoa>\n<|System|>:Response:{observation}\n<TOKENS_UNUSED_2>\n<|Bot|>:'
101
+
102
+ def get_stop_words_list(self):
103
+ return ['<eoa>']
benchmark/prompt/llama_react.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from prompt.react import ReAct
2
+
3
+
4
+ class LlamaReAct(ReAct):
5
+
6
+ def __init__(self, query, lang='en', upload_file_paths=[]):
7
+ super().__init__(query, lang, upload_file_paths)
8
+
9
+ def build_prompt(self):
10
+ planning_prompt = super().build_prompt()
11
+ planning_prompt = '[INST] ' + planning_prompt + ' [/INST]'
12
+
13
+ if '<|im_end|>' in self.query:
14
+ planning_prompt = planning_prompt.replace(
15
+ '<|im_end|>\n<|im_start|>assistant', ' [/INST] ')
16
+ assert planning_prompt.endswith(' [/INST]')
17
+ planning_prompt = planning_prompt[:-len(' [/INST]')]
18
+
19
+ self.prompt = planning_prompt
20
+ return planning_prompt
benchmark/prompt/qwen_react.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+
4
+ from prompt.react import ReAct
5
+
6
+ QWEN_TOOLS_LIST = [
7
+ {
8
+ 'name_for_human': '代码解释器',
9
+ 'name_for_model': 'code_interpreter',
10
+ 'description_for_model': '代码解释器,可用于执行Python代码。',
11
+ 'parameters': [{
12
+ 'name': 'code',
13
+ 'type': 'string',
14
+ 'description': '待执行的代码'
15
+ }],
16
+ 'args_format': 'code'
17
+ },
18
+ ]
19
+
20
+ TOOL_DESC = """{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters}"""
21
+
22
+
23
+ class QwenReAct(ReAct):
24
+
25
+ def __init__(self, query, lang='en', upload_file_paths=[]):
26
+ super().__init__(query, lang, upload_file_paths)
27
+
28
+ self.upload_file_paths = [
29
+ f'{os.path.basename(fname)}' for fname in upload_file_paths
30
+ ]
31
+ self.list_of_plugin_info = QWEN_TOOLS_LIST
32
+ self.fname_template = {
33
+ 'zh': '[上传文件{fname_str}]',
34
+ 'en': '[Upload file {fname_str}]',
35
+ 'en_multi': '[Upload file {fname_str}]'
36
+ }
37
+
38
+ def build_prompt(self):
39
+ im_start = '<|im_start|>'
40
+ im_end = '<|im_end|>'
41
+ prompt = f'{im_start}system\nYou are a helpful assistant.{im_end}'
42
+
43
+ query = super().build_prompt()
44
+
45
+ query = query.lstrip('\n').rstrip()
46
+ prompt += f'\n{im_start}user\n{query}{im_end}'
47
+ if f'{im_start}assistant' not in query:
48
+ prompt += f'\n{im_start}assistant\n{im_end}'
49
+ assert prompt.endswith(f'\n{im_start}assistant\n{im_end}')
50
+
51
+ prompt = prompt[:-len(f'{im_end}')]
52
+ self.prompt = prompt
53
+ return prompt
54
+
55
+ def _build_tools_text(self):
56
+ # tool info
57
+ tools_text = []
58
+ for plugin_info in self.list_of_plugin_info:
59
+ tool = TOOL_DESC.format(
60
+ name_for_model=plugin_info['name_for_model'],
61
+ name_for_human=plugin_info['name_for_human'],
62
+ description_for_model=plugin_info['description_for_model'],
63
+ parameters=json.dumps(plugin_info['parameters'],
64
+ ensure_ascii=False),
65
+ )
66
+ if plugin_info.get('args_format', 'json') == 'json':
67
+ tool += ' Format the arguments as a JSON object.'
68
+ elif plugin_info['args_format'] == 'code':
69
+ tool += ' Enclose the code within triple backticks (`) at the beginning and end of the code.'
70
+ else:
71
+ raise NotImplementedError
72
+ tools_text.append(tool)
73
+ tools_text = '\n\n'.join(tools_text)
74
+ return tools_text
75
+
76
+ def _build_tools_name_text(self):
77
+ return ', '.join([
78
+ plugin_info['name_for_model']
79
+ for plugin_info in self.list_of_plugin_info
80
+ ])
benchmark/prompt/react.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ tools_text = """code_interpreter: Call this tool to interact with the Code Interpreter API.
4
+ What is the Code Interpreter API useful for?
5
+ Code Interpreter is used to execute Python code to deal with the following tasks:
6
+ 1. Solving mathematical problems, both quantitative and qualitative
7
+ 2. Doing data analysis and visualization
8
+ 3. Converting files between formats
9
+ Parameters:
10
+ ```py
11
+ code
12
+ ```
13
+ Enclose the code within triple backticks (```) at the beginning and end of the code.
14
+ """
15
+
16
+ REACT_PROMPT = """Answer the following questions as best you can. You have access to the following tools:
17
+
18
+ {tools_text}
19
+
20
+ Use the following format:
21
+
22
+ Question: the input question you must answer
23
+ Thought: you should always think about what to do
24
+ Action: the action to take, should be one of [{tools_name_text}]
25
+ Action Input: the input to the action
26
+ Observation: the result of the action
27
+ ... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
28
+ Thought: I now know the final answer
29
+ Final Answer: the final answer to the original input question
30
+
31
+ Begin!
32
+
33
+ Question: {query}"""
34
+
35
+ fname_template = {
36
+ 'zh': '文件{fname_str},',
37
+ 'en_multi': 'Files {fname_str}. ',
38
+ 'en': 'File {fname_str}. ',
39
+ }
40
+
41
+
42
+ class ReAct(object):
43
+
44
+ def __init__(self, query, lang='en', upload_file_paths=[]):
45
+ self.query = query
46
+ self.lang = lang
47
+ self.upload_file_paths = [
48
+ f'`{os.path.basename(fname)}`' for fname in upload_file_paths
49
+ ]
50
+
51
+ self.fname_template = fname_template
52
+ self.react_template = REACT_PROMPT
53
+ self.prompt = ''
54
+
55
+ def build_prompt(self):
56
+ query = self._format_upload_fname() + self.query
57
+ tools_text = self._build_tools_text()
58
+ tools_name_text = self._build_tools_name_text()
59
+ planning_prompt = self.react_template.format(
60
+ query=query,
61
+ tools_text=tools_text,
62
+ tools_name_text=tools_name_text)
63
+
64
+ self.prompt = planning_prompt
65
+ return planning_prompt
66
+
67
+ def _format_upload_fname(self):
68
+ prefix = ''
69
+ if self.upload_file_paths:
70
+ fname_str = ', '.join(self.upload_file_paths)
71
+ lang_key = 'en_multi' if self.lang == 'en' and len(
72
+ self.upload_file_paths) > 1 else self.lang
73
+ fname_template = self.fname_template[lang_key]
74
+ prefix = fname_template.format(fname_str=fname_str)
75
+ return prefix
76
+
77
+ def _build_tools_text(self):
78
+ return tools_text
79
+
80
+ def _build_tools_name_text(self):
81
+ return 'code_interpreter'
82
+
83
+ def build_observation(self, observation):
84
+ return f'\nObservation: {observation}\nThought:'
85
+
86
+ def get_stop_words_list(self):
87
+ return ['Observation:', 'Observation:\n']
benchmark/requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate>=0.20.3
2
+ func_timeout
3
+ json5
4
+ matplotlib
5
+ numpy
6
+ pandas
7
+ PrettyTable
8
+ scipy
9
+ seaborn
10
+ sympy
11
+ transformers==4.33.1
12
+ transformers_stream_generator
13
+ openai
benchmark/utils/__init__.py ADDED
File without changes
benchmark/utils/code_utils.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+
4
+ import json5
5
+
6
+
7
+ def replace_upload_fname(text, upload_fname_list):
8
+ for full_input_fname in upload_fname_list:
9
+ if full_input_fname not in text and os.path.basename(
10
+ full_input_fname) in text:
11
+ text = text.replace(os.path.basename(full_input_fname),
12
+ full_input_fname)
13
+ return text
14
+
15
+
16
+ def extract_code(text):
17
+ # Match triple backtick blocks first
18
+ triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL)
19
+ # Match single backtick blocks second
20
+ single_match = re.search(r'`([^`]*)`', text, re.DOTALL)
21
+ if triple_match:
22
+ text = triple_match.group(1)
23
+ elif single_match:
24
+ text = single_match.group(1)
25
+ else:
26
+ try:
27
+ text = json5.loads(text)['code']
28
+ except Exception:
29
+ pass
30
+ # If no code blocks found, return original text
31
+ return text
benchmark/utils/data_utils.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+
4
+ from tqdm import tqdm
5
+
6
+
7
+ def load_jsonl(path):
8
+ data = []
9
+ with open(path, 'r', encoding='utf8') as f:
10
+ for idx, line in enumerate(f, start=1):
11
+ try:
12
+ data.append(json.loads(line))
13
+ except Exception as e:
14
+ logging.info(line)
15
+ logging.warning(f'Error at line {idx}: {e}')
16
+ continue
17
+ return data
18
+
19
+
20
+ def save_jsonl(data, path, progress=False, enabled=True):
21
+ if not enabled:
22
+ return
23
+ with open(path, 'w', encoding='utf-8') as f:
24
+ if progress:
25
+ data = tqdm(data)
26
+ for item in data:
27
+ line = json.dumps(item, ensure_ascii=False)
28
+ print(line, file=f)
browser_qwen/background.js ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ var database;
2
+
3
+ function send_data(msg){
4
+ chrome.storage.local.get(['database_host'], function(result) {
5
+ if (result.database_host) {
6
+ console.log('database_host currently is ' + result.database_host);
7
+ database = "http://"+result.database_host+":7866/endpoint";
8
+ } else {
9
+ database = "http://127.0.0.1:7866/endpoint";
10
+ }
11
+ fetch(database, {
12
+ method: "POST",
13
+ headers: {
14
+ "Content-Type": "application/json",
15
+ },
16
+ body: JSON.stringify(msg),
17
+ })
18
+ .then((response) => response.json())
19
+ .then((data) => {
20
+ console.log(data.result);
21
+ });
22
+ });
23
+ }
24
+
25
+ chrome.runtime.onMessage.addListener(async (msg, sender) => {
26
+ if (msg.flag == "open_tab_and_cache_from_content"){
27
+ var url = "";
28
+ chrome.tabs.query({active: true, currentWindow: true}, function(tabs) {
29
+ url = tabs[0].url;
30
+ console.log(url);
31
+ if (msg.data) {
32
+ chrome.storage.sync.get(['data'], function(result) {
33
+ chrome.storage.sync.set({ data: result.data }, function() {
34
+ send_data({ 'content' : msg.data, 'query': '', 'url':url, 'task':'cache', 'type':msg.type});
35
+ });
36
+ });
37
+ }
38
+ });
39
+ }
40
+ if (msg.flag == "open_popup_and_send_url_from_popup"){
41
+ if (msg.data) {
42
+ chrome.storage.sync.get(['data'], function(result) {
43
+ chrome.storage.sync.set({ data: result.data }, function() {
44
+ send_data({ 'url' : msg.data, 'task':'pop_url'});
45
+ });
46
+ });
47
+ }
48
+ }
49
+ // if (msg.flag == "set_addr"){
50
+ // if (msg.data) {
51
+ // chrome.storage.sync.get(['data'], function(result) {
52
+ // chrome.storage.sync.set({ data: result.data }, function() {
53
+ // send_data({ 'addr' : msg.data, 'task':'set_addr'});
54
+ // });
55
+ // });
56
+ // }
57
+ // }
58
+ });
browser_qwen/img/copy.png ADDED
browser_qwen/img/logo.png ADDED
browser_qwen/img/popup.png ADDED
browser_qwen/manifest.json ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "BrowserQwen",
3
+ "description" : "An Extension Driven by LLM",
4
+ "version": "1.0",
5
+ "manifest_version": 3,
6
+
7
+ "background": {
8
+ "service_worker": "background.js"
9
+ },
10
+
11
+ "action": {
12
+ "default_popup": "src/popup.html",
13
+ "default_icon": "img/popup.png",
14
+ "default_title": "BrowserQwen"
15
+ },
16
+ "permissions": [
17
+ "tabs",
18
+ "notifications",
19
+ "storage",
20
+ "scripting",
21
+ "activeTab"
22
+ ],
23
+ "host_permissions": [
24
+ "http://*/*",
25
+ "https://*/*"
26
+ ],
27
+ "icons": {
28
+ "16": "img/popup.png",
29
+ "32": "img/popup.png",
30
+ "48": "img/popup.png",
31
+ "128": "img/popup.png"
32
+ },
33
+ "content_scripts": [
34
+ {
35
+ "js": ["src/content.js"],
36
+ "matches": [
37
+ "https://www.jianshu.com/p/*",
38
+ "https://*/*",
39
+ "http://*/*",
40
+ "file:///*/*"
41
+ ]
42
+ }
43
+ ]
44
+
45
+ }
browser_qwen/src/content.js ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ function getPageTextContent() {
3
+ var textContent = document.body.textContent;
4
+ return textContent;
5
+ }
6
+
7
+ function cache_browser(){
8
+ const body = document.querySelector('html');
9
+ const text = body.innerHTML;
10
+ console.log(text);
11
+ chrome.runtime.sendMessage({ data: text , close: true , flag: 'open_tab_and_cache_from_content', type: 'html'});
12
+
13
+ }
14
+
15
+ const floatingBox = document.createElement('div');
16
+ floatingBox.style.position = 'fixed';
17
+ floatingBox.style.bottom = '650px';
18
+ floatingBox.style.right = '60px';
19
+ floatingBox.style.width = '125px';
20
+ floatingBox.style.height = '55px';
21
+ floatingBox.style.backgroundColor = '#f2f2f2';
22
+ floatingBox.style.border = '1px solid black';
23
+ floatingBox.style.borderRadius = '5px';
24
+ floatingBox.style.padding = '10px';
25
+ floatingBox.style.zIndex = '9999';
26
+
27
+ const button = document.createElement('button');
28
+ button.style.position = 'fixed';
29
+ button.style.top = '30px';
30
+ button.style.right = '30px';
31
+ button.style.zIndex = "9999";
32
+ button.textContent = "Add to Qwen's Reading List";
33
+ button.style.fontFamily = 'Arial, sans-serif';
34
+ button.style.fontSize = '14px';
35
+ button.style.width = '140px';
36
+ button.style.height = '60px';
37
+ button.style.backgroundColor = '#695DE8';
38
+ button.style.color = 'white';
39
+ button.style.borderRadius = '5px';
40
+ button.style.border = '0px';
41
+ button.style.whiteSpace = 'pre-wrap';
42
+ button.style.boxShadow = '0 4px 6px rgba(0, 0, 0, 0.2)';
43
+
44
+ floatingBox.appendChild(button);
45
+
46
+ document.body.appendChild(button);
47
+
48
+ let isDragging = false;
49
+ var isMouseReleased = false;
50
+ let initialX;
51
+ let initialY;
52
+
53
+ button.addEventListener('mousedown', (e) => {
54
+ isDragging = true;
55
+ initialX = e.clientX;
56
+ initialY = e.clientY;
57
+ });
58
+
59
+ document.addEventListener('mousemove', (e) => {
60
+ if (isDragging) {
61
+ const dx = e.clientX - initialX;
62
+ const dy = e.clientY - initialY;
63
+ button.style.right = `${parseFloat(button.style.right) - dx}px`;
64
+ button.style.top = `${parseFloat(button.style.top) + dy}px`;
65
+ initialX = e.clientX;
66
+ initialY = e.clientY;
67
+ isMouseReleased = true;
68
+ }
69
+ });
70
+
71
+ document.addEventListener('mouseup', (e) => {
72
+ isDragging = false;
73
+
74
+ });
75
+
76
+ button.addEventListener('click', (e) => {
77
+ if (isMouseReleased) {
78
+ isMouseReleased = false;
79
+ e.stopPropagation();
80
+ } else {
81
+ var result = confirm("Are you sure to ask Qwen to remember this page?");
82
+ if (result) {
83
+ cache_browser()
84
+ }
85
+ }
86
+ });
browser_qwen/src/popup.html ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html>
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta http-equiv="X-UA-Compatible" content="IE=edge">
6
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
7
+ <title>BrowserQwen</title>
8
+ <!-- <script src="popup.js"></script>-->
9
+ <style>
10
+ .title-style {
11
+ margin-top: 10px;
12
+ margin-left: 5px;
13
+ font-family: Arial, sans-serif;
14
+ font-size: 24px;
15
+ color: #333;
16
+ }
17
+
18
+ body {
19
+ width: 500px;
20
+ height: 600px;
21
+ background-color: aliceblue;
22
+ }
23
+ .contnet {
24
+ display: flex;
25
+ flex-direction: column;
26
+ flex-wrap: nowrap;
27
+ align-items: center;
28
+ }
29
+ .upload_file {
30
+ display: flex;
31
+ flex-direction: column;
32
+ margin-top: 10px;
33
+ }
34
+ .upload_btn {
35
+
36
+ border: none;
37
+ border-radius: 5px;
38
+ background-color: #5c5cdf;
39
+ font-size: 16px;
40
+ color: white;
41
+ font-weight: 400;
42
+ cursor: pointer;
43
+ width: 70px;
44
+ height: 35px;
45
+ }
46
+ .upload_btn:hover {
47
+ border: none;
48
+ border-radius: 10px;
49
+ background-color: #7f7ff2;
50
+ font-size: 16px;
51
+ color: white;
52
+ font-weight: 400;
53
+ cursor: pointer;
54
+ box-shadow: 0px 0px 0px 1px #848181;
55
+ }
56
+
57
+ .div-with-copy-paste {
58
+ width: 340px;
59
+ height: 200px;
60
+ background-color: #f2f2f2;
61
+ border: 1px solid #ccc;
62
+ border-radius: 5px;
63
+ padding: 20px;
64
+ box-shadow: 0 2px 5px rgba(0, 0, 0, 0.3);
65
+ position: relative;
66
+ }
67
+
68
+ .input-text {
69
+ width: 300px;
70
+ /* height: 100px; */
71
+ background-color: #fffdfd;
72
+ border: 1px solid #ccc;
73
+ border-radius: 10px;
74
+ padding: 20px;
75
+ box-shadow: 0 2px 5px rgba(0, 0, 0, 0.3);
76
+ position: left;
77
+ margin-top: 5px;
78
+ margin-bottom: 5px;
79
+ }
80
+
81
+ .copy-button {
82
+ /* position: absolute;
83
+ top: 10px;
84
+ right: 10px; */
85
+ background-color: #5085cf;
86
+ color: white;
87
+ padding: 8px 8px;
88
+ border: none;
89
+ border-radius: 3px;
90
+ cursor: pointer;
91
+ }
92
+
93
+ .copy-button:hover {
94
+ background-color: #45a049;
95
+ }
96
+
97
+ ::placeholder {
98
+ color: rgb(197, 196, 196);
99
+ }
100
+
101
+ iframe {
102
+ width: 100%;
103
+ height: 100%;
104
+ border: none;
105
+ }
106
+ </style>
107
+
108
+ </head>
109
+
110
+
111
+ <body>
112
+ <!-- <iframe src=$popup_url style="height: 550px"></iframe>-->
113
+ <div id="iframe_area" style="height: 570px"></div>
114
+
115
+ <h3>Customize Address:</h3>
116
+ <input type="text" id="addr" name="addr" class="input-text">
117
+ <button id="set_addr" class="upload_btn">Change</button>
118
+
119
+ <script src="popup.js"></script>
120
+ </body>
121
+ </html>
browser_qwen/src/popup.js ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ chrome.runtime.onMessage.addListener((msg, sender, sendResponse) => {
4
+ // if (msg.flag == 'from_content'){
5
+ // console.log(msg.rsp);
6
+ // var sessionContainer = document.getElementById('session');
7
+ // sessionContainer.innerText = msg.rsp;
8
+ // sendResponse({ msg: 'Get!' });
9
+ // }
10
+ if (msg.flag === 'from_llm'){
11
+ // var sessionContainer = document.getElementById('session');
12
+ // // sessionContainer.innerHTML = msg.rsp;
13
+ // sessionContainer.innerText = msg.rsp;
14
+ sendResponse({ message: 'Get Response!' });
15
+ }
16
+ });
17
+
18
+
19
+ document.addEventListener('DOMContentLoaded', function() {
20
+ chrome.tabs.query({ active: true, currentWindow: true }, function(tabs) {
21
+ var currentUrl = tabs[0].url;
22
+
23
+ chrome.runtime.sendMessage({ data: currentUrl , close: true , flag: 'open_popup_and_send_url_from_popup'});
24
+
25
+ });
26
+ setTimeout(function() {
27
+ // console.log('This message will be logged after 0.5 second');
28
+ var popup_url='';
29
+ chrome.storage.local.get(['database_host'], function(result) {
30
+ if (result.database_host) {
31
+ console.log('database_host currently is ' + result.database_host);
32
+ popup_url = "http://"+result.database_host+":7863/";
33
+ } else {
34
+ popup_url = "http://127.0.0.1:7863/";
35
+ }
36
+ var iframe = document.createElement('iframe');
37
+ iframe.src = popup_url;
38
+ iframe.height = '570px';
39
+ // iframe.sandbox = 'allow-same-origin allow-scripts';
40
+ // iframe.allow = "geolocation *;";
41
+ var iframe_area = document.getElementById('iframe_area')
42
+ iframe_area.appendChild(iframe);
43
+
44
+ });
45
+ }, 500);
46
+
47
+ // fetch('../config_host.json')
48
+ // .then(response => response.json())
49
+ // .then(data => {
50
+ // console.log(data);
51
+ // popup_url = "http://"+data.database_host+":"+data.app_in_browser_port+"/";
52
+ // console.log(popup_url);
53
+ // })
54
+ // .catch(error => console.error('Error:', error));
55
+ })
56
+
57
+ document.getElementById('set_addr').addEventListener('click', function() {
58
+ var addr = document.getElementById('addr').value;
59
+ // save config
60
+ chrome.storage.local.set({database_host: addr}, function() {
61
+ console.log('database_host is set to ' + addr);
62
+ // chrome.runtime.sendMessage({ data: addr , close: true , flag: 'set_addr'});
63
+ document.getElementById('addr').value = '';
64
+ });
65
+ })
openai_api.py ADDED
@@ -0,0 +1,564 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Implements API for Qwen-7B in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
3
+ # Usage: python openai_api.py
4
+ # Visit http://localhost:8000/docs for documents.
5
+
6
+ import re
7
+ import copy
8
+ import json
9
+ import time
10
+ from argparse import ArgumentParser
11
+ from contextlib import asynccontextmanager
12
+ from typing import Dict, List, Literal, Optional, Union
13
+
14
+ import torch
15
+ import uvicorn
16
+ from fastapi import FastAPI, HTTPException
17
+ from fastapi.middleware.cors import CORSMiddleware
18
+ from pydantic import BaseModel, Field
19
+ from sse_starlette.sse import EventSourceResponse
20
+ from transformers import AutoTokenizer, AutoModelForCausalLM
21
+ from transformers.generation import GenerationConfig
22
+ from starlette.middleware.base import BaseHTTPMiddleware
23
+ from starlette.requests import Request
24
+ from starlette.responses import Response
25
+ import base64
26
+
27
+
28
+ class BasicAuthMiddleware(BaseHTTPMiddleware):
29
+ def __init__(self, app, username: str, password: str):
30
+ super().__init__(app)
31
+ self.required_credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
32
+
33
+ async def dispatch(self, request: Request, call_next):
34
+ authorization: str = request.headers.get("Authorization")
35
+ if authorization:
36
+ try:
37
+ schema, credentials = authorization.split()
38
+ if credentials == self.required_credentials:
39
+ return await call_next(request)
40
+ except ValueError:
41
+ pass
42
+
43
+ headers = {'WWW-Authenticate': 'Basic'}
44
+ return Response(status_code=401, headers=headers)
45
+
46
+
47
+ def _gc(forced: bool = False):
48
+ global args
49
+ if args.disable_gc and not forced:
50
+ return
51
+
52
+ import gc
53
+ gc.collect()
54
+ if torch.cuda.is_available():
55
+ torch.cuda.empty_cache()
56
+
57
+
58
+ @asynccontextmanager
59
+ async def lifespan(app: FastAPI): # collects GPU memory
60
+ yield
61
+ _gc(forced=True)
62
+
63
+
64
+ app = FastAPI(lifespan=lifespan)
65
+
66
+ app.add_middleware(
67
+ CORSMiddleware,
68
+ allow_origins=["*"],
69
+ allow_credentials=True,
70
+ allow_methods=["*"],
71
+ allow_headers=["*"],
72
+ )
73
+
74
+
75
+ class ModelCard(BaseModel):
76
+ id: str
77
+ object: str = "model"
78
+ created: int = Field(default_factory=lambda: int(time.time()))
79
+ owned_by: str = "owner"
80
+ root: Optional[str] = None
81
+ parent: Optional[str] = None
82
+ permission: Optional[list] = None
83
+
84
+
85
+ class ModelList(BaseModel):
86
+ object: str = "list"
87
+ data: List[ModelCard] = []
88
+
89
+
90
+ class ChatMessage(BaseModel):
91
+ role: Literal["user", "assistant", "system", "function"]
92
+ content: Optional[str]
93
+ function_call: Optional[Dict] = None
94
+
95
+
96
+ class DeltaMessage(BaseModel):
97
+ role: Optional[Literal["user", "assistant", "system"]] = None
98
+ content: Optional[str] = None
99
+
100
+
101
+ class ChatCompletionRequest(BaseModel):
102
+ model: str
103
+ messages: List[ChatMessage]
104
+ functions: Optional[List[Dict]] = None
105
+ temperature: Optional[float] = None
106
+ top_p: Optional[float] = None
107
+ max_length: Optional[int] = None
108
+ stream: Optional[bool] = False
109
+ stop: Optional[List[str]] = None
110
+
111
+
112
+ class ChatCompletionResponseChoice(BaseModel):
113
+ index: int
114
+ message: ChatMessage
115
+ finish_reason: Literal["stop", "length", "function_call"]
116
+
117
+
118
+ class ChatCompletionResponseStreamChoice(BaseModel):
119
+ index: int
120
+ delta: DeltaMessage
121
+ finish_reason: Optional[Literal["stop", "length"]]
122
+
123
+
124
+ class ChatCompletionResponse(BaseModel):
125
+ model: str
126
+ object: Literal["chat.completion", "chat.completion.chunk"]
127
+ choices: List[
128
+ Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]
129
+ ]
130
+ created: Optional[int] = Field(default_factory=lambda: int(time.time()))
131
+
132
+
133
+ @app.get("/v1/models", response_model=ModelList)
134
+ async def list_models():
135
+ global model_args
136
+ model_card = ModelCard(id="gpt-3.5-turbo")
137
+ return ModelList(data=[model_card])
138
+
139
+
140
+ # To work around that unpleasant leading-\n tokenization issue!
141
+ def add_extra_stop_words(stop_words):
142
+ if stop_words:
143
+ _stop_words = []
144
+ _stop_words.extend(stop_words)
145
+ for x in stop_words:
146
+ s = x.lstrip("\n")
147
+ if s and (s not in _stop_words):
148
+ _stop_words.append(s)
149
+ return _stop_words
150
+ return stop_words
151
+
152
+
153
+ def trim_stop_words(response, stop_words):
154
+ if stop_words:
155
+ for stop in stop_words:
156
+ idx = response.find(stop)
157
+ if idx != -1:
158
+ response = response[:idx]
159
+ return response
160
+
161
+
162
+ TOOL_DESC = """{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters}"""
163
+
164
+ REACT_INSTRUCTION = """Answer the following questions as best you can. You have access to the following APIs:
165
+
166
+ {tools_text}
167
+
168
+ Use the following format:
169
+
170
+ Question: the input question you must answer
171
+ Thought: you should always think about what to do
172
+ Action: the action to take, should be one of [{tools_name_text}]
173
+ Action Input: the input to the action
174
+ Observation: the result of the action
175
+ ... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
176
+ Thought: I now know the final answer
177
+ Final Answer: the final answer to the original input question
178
+
179
+ Begin!"""
180
+
181
+ _TEXT_COMPLETION_CMD = object()
182
+
183
+
184
+ #
185
+ # Temporarily, the system role does not work as expected.
186
+ # We advise that you write the setups for role-play in your query,
187
+ # i.e., use the user role instead of the system role.
188
+ #
189
+ # TODO: Use real system role when the model is ready.
190
+ #
191
+ def parse_messages(messages, functions):
192
+ if all(m.role != "user" for m in messages):
193
+ raise HTTPException(
194
+ status_code=400,
195
+ detail=f"Invalid request: Expecting at least one user message.",
196
+ )
197
+
198
+ messages = copy.deepcopy(messages)
199
+ default_system = "You are a helpful assistant."
200
+ system = ""
201
+ if messages[0].role == "system":
202
+ system = messages.pop(0).content.lstrip("\n").rstrip()
203
+ if system == default_system:
204
+ system = ""
205
+
206
+ if functions:
207
+ tools_text = []
208
+ tools_name_text = []
209
+ for func_info in functions:
210
+ name = func_info.get("name", "")
211
+ name_m = func_info.get("name_for_model", name)
212
+ name_h = func_info.get("name_for_human", name)
213
+ desc = func_info.get("description", "")
214
+ desc_m = func_info.get("description_for_model", desc)
215
+ tool = TOOL_DESC.format(
216
+ name_for_model=name_m,
217
+ name_for_human=name_h,
218
+ # Hint: You can add the following format requirements in description:
219
+ # "Format the arguments as a JSON object."
220
+ # "Enclose the code within triple backticks (`) at the beginning and end of the code."
221
+ description_for_model=desc_m,
222
+ parameters=json.dumps(func_info["parameters"], ensure_ascii=False),
223
+ )
224
+ tools_text.append(tool)
225
+ tools_name_text.append(name_m)
226
+ tools_text = "\n\n".join(tools_text)
227
+ tools_name_text = ", ".join(tools_name_text)
228
+ system += "\n\n" + REACT_INSTRUCTION.format(
229
+ tools_text=tools_text,
230
+ tools_name_text=tools_name_text,
231
+ )
232
+ system = system.lstrip("\n").rstrip()
233
+
234
+ dummy_thought = {
235
+ "en": "\nThought: I now know the final answer.\nFinal answer: ",
236
+ "zh": "\nThought: 我会作答了。\nFinal answer: ",
237
+ }
238
+
239
+ _messages = messages
240
+ messages = []
241
+ for m_idx, m in enumerate(_messages):
242
+ role, content, func_call = m.role, m.content, m.function_call
243
+ if content:
244
+ content = content.lstrip("\n").rstrip()
245
+ if role == "function":
246
+ if (len(messages) == 0) or (messages[-1].role != "assistant"):
247
+ raise HTTPException(
248
+ status_code=400,
249
+ detail=f"Invalid request: Expecting role assistant before role function.",
250
+ )
251
+ messages[-1].content += f"\nObservation: {content}"
252
+ if m_idx == len(_messages) - 1:
253
+ messages[-1].content += "\nThought:"
254
+ elif role == "assistant":
255
+ if len(messages) == 0:
256
+ raise HTTPException(
257
+ status_code=400,
258
+ detail=f"Invalid request: Expecting role user before role assistant.",
259
+ )
260
+ last_msg = messages[-1].content
261
+ last_msg_has_zh = len(re.findall(r"[\u4e00-\u9fff]+", last_msg)) > 0
262
+ if func_call is None:
263
+ if functions:
264
+ content = dummy_thought["zh" if last_msg_has_zh else "en"] + content
265
+ else:
266
+ f_name, f_args = func_call["name"], func_call["arguments"]
267
+ if not content:
268
+ if last_msg_has_zh:
269
+ content = f"Thought: 我可以使用 {f_name} API。"
270
+ else:
271
+ content = f"Thought: I can use {f_name}."
272
+ content = f"\n{content}\nAction: {f_name}\nAction Input: {f_args}"
273
+ if messages[-1].role == "user":
274
+ messages.append(
275
+ ChatMessage(role="assistant", content=content.lstrip("\n").rstrip())
276
+ )
277
+ else:
278
+ messages[-1].content += content
279
+ elif role == "user":
280
+ messages.append(
281
+ ChatMessage(role="user", content=content.lstrip("\n").rstrip())
282
+ )
283
+ else:
284
+ raise HTTPException(
285
+ status_code=400, detail=f"Invalid request: Incorrect role {role}."
286
+ )
287
+
288
+ query = _TEXT_COMPLETION_CMD
289
+ if messages[-1].role == "user":
290
+ query = messages[-1].content
291
+ messages = messages[:-1]
292
+
293
+ if len(messages) % 2 != 0:
294
+ raise HTTPException(status_code=400, detail="Invalid request")
295
+
296
+ history = [] # [(Q1, A1), (Q2, A2), ..., (Q_last_turn, A_last_turn)]
297
+ for i in range(0, len(messages), 2):
298
+ if messages[i].role == "user" and messages[i + 1].role == "assistant":
299
+ usr_msg = messages[i].content.lstrip("\n").rstrip()
300
+ bot_msg = messages[i + 1].content.lstrip("\n").rstrip()
301
+ if system and (i == len(messages) - 2):
302
+ usr_msg = f"{system}\n\nQuestion: {usr_msg}"
303
+ system = ""
304
+ for t in dummy_thought.values():
305
+ t = t.lstrip("\n")
306
+ if bot_msg.startswith(t) and ("\nAction: " in bot_msg):
307
+ bot_msg = bot_msg[len(t):]
308
+ history.append([usr_msg, bot_msg])
309
+ else:
310
+ raise HTTPException(
311
+ status_code=400,
312
+ detail="Invalid request: Expecting exactly one user (or function) role before every assistant role.",
313
+ )
314
+ if system:
315
+ assert query is not _TEXT_COMPLETION_CMD
316
+ query = f"{system}\n\nQuestion: {query}"
317
+ return query, history
318
+
319
+
320
+ def parse_response(response):
321
+ func_name, func_args = "", ""
322
+ i = response.rfind("\nAction:")
323
+ j = response.rfind("\nAction Input:")
324
+ k = response.rfind("\nObservation:")
325
+ if 0 <= i < j: # If the text has `Action` and `Action input`,
326
+ if k < j: # but does not contain `Observation`,
327
+ # then it is likely that `Observation` is omitted by the LLM,
328
+ # because the output text may have discarded the stop word.
329
+ response = response.rstrip() + "\nObservation:" # Add it back.
330
+ k = response.rfind("\nObservation:")
331
+ func_name = response[i + len("\nAction:"): j].strip()
332
+ func_args = response[j + len("\nAction Input:"): k].strip()
333
+ if func_name:
334
+ choice_data = ChatCompletionResponseChoice(
335
+ index=0,
336
+ message=ChatMessage(
337
+ role="assistant",
338
+ content=response[:i],
339
+ function_call={"name": func_name, "arguments": func_args},
340
+ ),
341
+ finish_reason="function_call",
342
+ )
343
+ return choice_data
344
+ z = response.rfind("\nFinal Answer: ")
345
+ if z >= 0:
346
+ response = response[z + len("\nFinal Answer: "):]
347
+ choice_data = ChatCompletionResponseChoice(
348
+ index=0,
349
+ message=ChatMessage(role="assistant", content=response),
350
+ finish_reason="stop",
351
+ )
352
+ return choice_data
353
+
354
+
355
+ # completion mode, not chat mode
356
+ def text_complete_last_message(history, stop_words_ids, gen_kwargs):
357
+ im_start = "<|im_start|>"
358
+ im_end = "<|im_end|>"
359
+ prompt = f"{im_start}system\nYou are a helpful assistant.{im_end}"
360
+ for i, (query, response) in enumerate(history):
361
+ query = query.lstrip("\n").rstrip()
362
+ response = response.lstrip("\n").rstrip()
363
+ prompt += f"\n{im_start}user\n{query}{im_end}"
364
+ prompt += f"\n{im_start}assistant\n{response}{im_end}"
365
+ prompt = prompt[: -len(im_end)]
366
+
367
+ _stop_words_ids = [tokenizer.encode(im_end)]
368
+ if stop_words_ids:
369
+ for s in stop_words_ids:
370
+ _stop_words_ids.append(s)
371
+ stop_words_ids = _stop_words_ids
372
+
373
+ input_ids = torch.tensor([tokenizer.encode(prompt)]).to(model.device)
374
+ output = model.generate(input_ids, stop_words_ids=stop_words_ids, **gen_kwargs).tolist()[0]
375
+ output = tokenizer.decode(output, errors="ignore")
376
+ assert output.startswith(prompt)
377
+ output = output[len(prompt):]
378
+ output = trim_stop_words(output, ["<|endoftext|>", im_end])
379
+ print(f"<completion>\n{prompt}\n<!-- *** -->\n{output}\n</completion>")
380
+ return output
381
+
382
+
383
+ @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
384
+ async def create_chat_completion(request: ChatCompletionRequest):
385
+ global model, tokenizer
386
+
387
+ gen_kwargs = {}
388
+ if request.temperature is not None:
389
+ if request.temperature < 0.01:
390
+ gen_kwargs['top_k'] = 1 # greedy decoding
391
+ else:
392
+ # Not recommended. Please tune top_p instead.
393
+ gen_kwargs['temperature'] = request.temperature
394
+ if request.top_p is not None:
395
+ gen_kwargs['top_p'] = request.top_p
396
+
397
+ stop_words = add_extra_stop_words(request.stop)
398
+ if request.functions:
399
+ stop_words = stop_words or []
400
+ if "Observation:" not in stop_words:
401
+ stop_words.append("Observation:")
402
+
403
+ query, history = parse_messages(request.messages, request.functions)
404
+
405
+ if request.stream:
406
+ if request.functions:
407
+ raise HTTPException(
408
+ status_code=400,
409
+ detail="Invalid request: Function calling is not yet implemented for stream mode.",
410
+ )
411
+ generate = predict(query, history, request.model, stop_words, gen_kwargs)
412
+ return generate
413
+ # return EventSourceResponse(generate, media_type="text/event-stream")
414
+
415
+ stop_words_ids = [tokenizer.encode(s) for s in stop_words] if stop_words else None
416
+ if query is _TEXT_COMPLETION_CMD:
417
+ response = text_complete_last_message(history, stop_words_ids=stop_words_ids, gen_kwargs=gen_kwargs)
418
+ else:
419
+ response, _ = model.chat(
420
+ tokenizer,
421
+ query,
422
+ history=history,
423
+ stop_words_ids=stop_words_ids,
424
+ **gen_kwargs
425
+ )
426
+ print(f"<chat>\n{history}\n{query}\n<!-- *** -->\n{response}\n</chat>")
427
+ _gc()
428
+
429
+ response = trim_stop_words(response, stop_words)
430
+ if request.functions:
431
+ choice_data = parse_response(response)
432
+ else:
433
+ choice_data = ChatCompletionResponseChoice(
434
+ index=0,
435
+ message=ChatMessage(role="assistant", content=response),
436
+ finish_reason="stop",
437
+ )
438
+ return ChatCompletionResponse(
439
+ model=request.model, choices=[choice_data], object="chat.completion"
440
+ )
441
+
442
+
443
+ def _dump_json(data: BaseModel, *args, **kwargs) -> str:
444
+ try:
445
+ return data.model_dump_json(*args, **kwargs)
446
+ except AttributeError: # pydantic<2.0.0
447
+ return data.json(*args, **kwargs) # noqa
448
+
449
+
450
+ async def predict(
451
+ query: str, history: List[List[str]], model_id: str, stop_words: List[str], gen_kwargs: Dict,
452
+ ):
453
+ global model, tokenizer
454
+ choice_data = ChatCompletionResponseStreamChoice(
455
+ index=0, delta=DeltaMessage(role="assistant"), finish_reason=None
456
+ )
457
+ chunk = ChatCompletionResponse(
458
+ model=model_id, choices=[choice_data], object="chat.completion.chunk"
459
+ )
460
+ yield "{}".format(_dump_json(chunk, exclude_unset=True))
461
+
462
+ current_length = 0
463
+ stop_words_ids = [tokenizer.encode(s) for s in stop_words] if stop_words else None
464
+ if stop_words:
465
+ # TODO: It's a little bit tricky to trim stop words in the stream mode.
466
+ raise HTTPException(
467
+ status_code=400,
468
+ detail="Invalid request: custom stop words are not yet supported for stream mode.",
469
+ )
470
+ response_generator = model.chat_stream(
471
+ tokenizer, query, history=history, stop_words_ids=stop_words_ids, **gen_kwargs
472
+ )
473
+ for new_response in response_generator:
474
+ if len(new_response) == current_length:
475
+ continue
476
+
477
+ new_text = new_response[current_length:]
478
+ current_length = len(new_response)
479
+
480
+ choice_data = ChatCompletionResponseStreamChoice(
481
+ index=0, delta=DeltaMessage(content=new_text), finish_reason=None
482
+ )
483
+ chunk = ChatCompletionResponse(
484
+ model=model_id, choices=[choice_data], object="chat.completion.chunk"
485
+ )
486
+ yield "{}".format(_dump_json(chunk, exclude_unset=True))
487
+
488
+ choice_data = ChatCompletionResponseStreamChoice(
489
+ index=0, delta=DeltaMessage(), finish_reason="stop"
490
+ )
491
+ chunk = ChatCompletionResponse(
492
+ model=model_id, choices=[choice_data], object="chat.completion.chunk"
493
+ )
494
+ yield "{}".format(_dump_json(chunk, exclude_unset=True))
495
+ yield "[DONE]"
496
+
497
+ _gc()
498
+
499
+
500
+ def _get_args():
501
+ parser = ArgumentParser()
502
+ parser.add_argument(
503
+ "-c",
504
+ "--checkpoint-path",
505
+ type=str,
506
+ default="Qwen/Qwen-7B-Chat",
507
+ help="Checkpoint name or path, default to %(default)r",
508
+ )
509
+ parser.add_argument(
510
+ "--api-auth", help="API authentication credentials"
511
+ )
512
+ parser.add_argument(
513
+ "--cpu-only", action="store_true", help="Run demo with CPU only"
514
+ )
515
+ parser.add_argument(
516
+ "--server-port", type=int, default=8000, help="Demo server port."
517
+ )
518
+ parser.add_argument(
519
+ "--server-name",
520
+ type=str,
521
+ default="127.0.0.1",
522
+ help="Demo server name. Default: 127.0.0.1, which is only visible from the local computer."
523
+ " If you want other computers to access your server, use 0.0.0.0 instead.",
524
+ )
525
+ parser.add_argument("--disable-gc", action="store_true",
526
+ help="Disable GC after each response generated.")
527
+
528
+ args = parser.parse_args()
529
+ return args
530
+
531
+
532
+ if __name__ == "__main__":
533
+ args = _get_args()
534
+
535
+ tokenizer = AutoTokenizer.from_pretrained(
536
+ args.checkpoint_path,
537
+ trust_remote_code=True,
538
+ resume_download=True,
539
+ )
540
+
541
+ if args.api_auth:
542
+ app.add_middleware(
543
+ BasicAuthMiddleware, username=args.api_auth.split(":")[0], password=args.api_auth.split(":")[1]
544
+ )
545
+
546
+ if args.cpu_only:
547
+ device_map = "cpu"
548
+ else:
549
+ device_map = "auto"
550
+
551
+ model = AutoModelForCausalLM.from_pretrained(
552
+ args.checkpoint_path,
553
+ device_map=device_map,
554
+ trust_remote_code=True,
555
+ resume_download=True,
556
+ ).eval()
557
+
558
+ model.generation_config = GenerationConfig.from_pretrained(
559
+ args.checkpoint_path,
560
+ trust_remote_code=True,
561
+ resume_download=True,
562
+ )
563
+
564
+ uvicorn.run(app, host=args.server_name, port=args.server_port, workers=1)
qwen_agent/__init__.py ADDED
File without changes
qwen_agent/actions/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .continue_writing import ContinueWriting
2
+ from .expand_writing import ExpandWriting
3
+ from .gen_keyword import GenKeyword
4
+ from .outline_writing import OutlineWriting
5
+ from .react import ReAct
6
+ from .retrieval_qa import RetrievalQA
7
+ from .summarize import Summarize
8
+ from .write_from_scratch import WriteFromScratch
9
+
10
+ __all__ = [
11
+ 'RetrievalQA', 'ContinueWriting', 'OutlineWriting', 'ExpandWriting',
12
+ 'ReAct', 'WriteFromScratch', 'Summarize', 'GenKeyword'
13
+ ]
qwen_agent/actions/base.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import Dict, Iterator, List, Optional, Union
3
+
4
+ from qwen_agent.llm.base import BaseChatModel
5
+ from qwen_agent.utils.utils import has_chinese_chars
6
+
7
+ # TODO: Should *planning* just be another action that uses other actions?
8
+
9
+
10
+ class Action(ABC):
11
+
12
+ def __init__(self, llm: BaseChatModel = None, stream: bool = False):
13
+ self.llm = llm
14
+ self.stream = stream
15
+
16
+ def run(self, *args, **kwargs) -> Union[str, Iterator[str]]:
17
+ if 'lang' not in kwargs:
18
+ if has_chinese_chars([args, kwargs]):
19
+ kwargs['lang'] = 'zh'
20
+ else:
21
+ kwargs['lang'] = 'en'
22
+ return self._run(*args, **kwargs)
23
+
24
+ @abstractmethod
25
+ def _run(self, *args, **kwargs) -> Union[str, Iterator[str]]:
26
+ raise NotImplementedError
27
+
28
+ # It is okay for an Action to not call LLMs.
29
+ def _call_llm(
30
+ self,
31
+ prompt: Optional[str] = None,
32
+ messages: Optional[List[Dict]] = None,
33
+ stop: Optional[List[str]] = None,
34
+ ) -> Union[str, Iterator[str]]:
35
+ return self.llm.chat(
36
+ prompt=prompt,
37
+ messages=messages,
38
+ stop=stop,
39
+ stream=self.stream,
40
+ )
qwen_agent/actions/continue_writing.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from qwen_agent.actions.base import Action
2
+
3
+ PROMPT_TEMPLATE_ZH = """你是一个写作助手,请依据参考资料,根据给定的前置文本续写合适的内容。
4
+ #参考资料:
5
+ {ref_doc}
6
+
7
+ #前置文本:
8
+ {user_request}
9
+
10
+ 保证续写内容和前置文本保持连贯,请开始续写:"""
11
+
12
+ PROMPT_TEMPLATE_EN = """You are a writing assistant, please follow the reference materials and continue to write appropriate content based on the given previous text.
13
+
14
+ # References:
15
+ {ref_doc}
16
+
17
+ # Previous text:
18
+ {user_request}
19
+
20
+ Please start writing directly, output only the continued text, do not repeat the previous text, do not say irrelevant words, and ensure that the continued content and the previous text remain consistent."""
21
+
22
+ PROMPT_TEMPLATE = {
23
+ 'zh': PROMPT_TEMPLATE_ZH,
24
+ 'en': PROMPT_TEMPLATE_EN,
25
+ }
26
+
27
+
28
+ class ContinueWriting(Action):
29
+
30
+ def _run(self, user_request, ref_doc, lang: str = 'en'):
31
+ prompt = PROMPT_TEMPLATE[lang].format(
32
+ ref_doc=ref_doc,
33
+ user_request=user_request,
34
+ )
35
+ return self._call_llm(prompt)
qwen_agent/actions/expand_writing.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from qwen_agent.actions.base import Action
2
+
3
+ PROMPT_TEMPLATE_ZH = """
4
+ 你是一个写作助手,任务是依据参考资料,完成写作任务。
5
+ #参考资料:
6
+ {ref_doc}
7
+
8
+ 写作标题是:{user_request}
9
+ 大纲是:
10
+ {outline}
11
+
12
+ 此时你的任务是扩写第{index}个一级标题对应的章节:{capture}。注意每个章节负责撰写不同的内容,所以你不需要为了全面而涵盖之后的内容。请不要在这里生成大纲。只依据给定的参考资料来写,不要引入其余知识。
13
+ """
14
+
15
+ PROMPT_TEMPLATE_EN = """
16
+ You are a writing assistant. Your task is to complete writing article based on reference materials.
17
+
18
+ # References:
19
+ {ref_doc}
20
+
21
+ The title is: {user_request}
22
+
23
+ The outline is:
24
+ {outline}
25
+
26
+ At this point, your task is to expand the chapter corresponding to the {index} first level title: {capture}.
27
+ Note that each chapter is responsible for writing different content, so you don't need to cover the following content. Please do not generate an outline here. Write only based on the given reference materials and do not introduce other knowledge.
28
+ """
29
+
30
+ PROMPT_TEMPLATE = {
31
+ 'zh': PROMPT_TEMPLATE_ZH,
32
+ 'en': PROMPT_TEMPLATE_EN,
33
+ }
34
+
35
+
36
+ class ExpandWriting(Action):
37
+
38
+ def _run(
39
+ self,
40
+ user_request,
41
+ ref_doc,
42
+ outline='',
43
+ index='1',
44
+ capture='',
45
+ capture_later='',
46
+ lang: str = 'en',
47
+ ):
48
+ prompt = PROMPT_TEMPLATE[lang].format(
49
+ ref_doc=ref_doc,
50
+ user_request=user_request,
51
+ index=index,
52
+ outline=outline,
53
+ capture=capture,
54
+ )
55
+ if capture_later:
56
+ if lang == 'zh':
57
+ prompt = prompt + '请在涉及 ' + capture_later + ' 时停止。'
58
+ elif lang == 'en':
59
+ prompt = prompt + ' Please stop when writing ' + capture_later
60
+ else:
61
+ raise NotImplementedError
62
+ return self._call_llm(prompt)