Spaces:
Running
Running
vlff李飞飞
commited on
Commit
·
2319518
1
Parent(s):
8d16531
update md
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +24 -0
- .pre-commit-config.yaml +27 -0
- Dockerfile +14 -0
- LICENSE +53 -0
- README_CN.md +252 -0
- assets/screenshot-ci.png +0 -0
- assets/screenshot-editor-movie.png +0 -0
- assets/screenshot-multi-web-qa.png +0 -0
- assets/screenshot-pdf-qa.png +0 -0
- assets/screenshot-web-qa.png +0 -0
- assets/screenshot-writing.png +0 -0
- benchmark/README.md +248 -0
- benchmark/code_interpreter.py +250 -0
- benchmark/config.py +66 -0
- benchmark/inference_and_execute.py +280 -0
- benchmark/metrics/__init__.py +0 -0
- benchmark/metrics/code_execution.py +257 -0
- benchmark/metrics/gsm8k.py +54 -0
- benchmark/metrics/visualization.py +179 -0
- benchmark/models/__init__.py +4 -0
- benchmark/models/base.py +17 -0
- benchmark/models/dashscope.py +40 -0
- benchmark/models/llm.py +26 -0
- benchmark/models/qwen.py +36 -0
- benchmark/parser/__init__.py +2 -0
- benchmark/parser/internlm_parser.py +11 -0
- benchmark/parser/react_parser.py +46 -0
- benchmark/prompt/__init__.py +4 -0
- benchmark/prompt/internlm_react.py +103 -0
- benchmark/prompt/llama_react.py +20 -0
- benchmark/prompt/qwen_react.py +80 -0
- benchmark/prompt/react.py +87 -0
- benchmark/requirements.txt +13 -0
- benchmark/utils/__init__.py +0 -0
- benchmark/utils/code_utils.py +31 -0
- benchmark/utils/data_utils.py +28 -0
- browser_qwen/background.js +58 -0
- browser_qwen/img/copy.png +0 -0
- browser_qwen/img/logo.png +0 -0
- browser_qwen/img/popup.png +0 -0
- browser_qwen/manifest.json +45 -0
- browser_qwen/src/content.js +86 -0
- browser_qwen/src/popup.html +121 -0
- browser_qwen/src/popup.js +65 -0
- openai_api.py +564 -0
- qwen_agent/__init__.py +0 -0
- qwen_agent/actions/__init__.py +13 -0
- qwen_agent/actions/base.py +40 -0
- qwen_agent/actions/continue_writing.py +35 -0
- qwen_agent/actions/expand_writing.py +62 -0
.gitignore
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
env
|
2 |
+
*.pyc
|
3 |
+
__pycache__
|
4 |
+
|
5 |
+
.idea
|
6 |
+
.vscode
|
7 |
+
.DS_Store
|
8 |
+
|
9 |
+
qwen_agent/llm/gpt.py
|
10 |
+
qwen_agent/llm/tools.py
|
11 |
+
workspace/*
|
12 |
+
|
13 |
+
benchmark/log/*
|
14 |
+
benchmark/output_data/*
|
15 |
+
benchmark/upload_file/*
|
16 |
+
benchmark/upload_file_clean/*
|
17 |
+
benchmark/eval_data/
|
18 |
+
Qwen-Agent
|
19 |
+
|
20 |
+
docqa/*
|
21 |
+
log/*
|
22 |
+
ai_builder/*
|
23 |
+
qwen_agent.egg-info/*
|
24 |
+
build/*
|
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
repos:
|
2 |
+
- repo: https://github.com/pycqa/flake8.git
|
3 |
+
rev: 5.0.4
|
4 |
+
hooks:
|
5 |
+
- id: flake8
|
6 |
+
args: ["--max-line-length=300"]
|
7 |
+
- repo: https://github.com/PyCQA/isort.git
|
8 |
+
rev: 5.11.5
|
9 |
+
hooks:
|
10 |
+
- id: isort
|
11 |
+
- repo: https://github.com/pre-commit/mirrors-yapf.git
|
12 |
+
rev: v0.32.0
|
13 |
+
hooks:
|
14 |
+
- id: yapf
|
15 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks.git
|
16 |
+
rev: v4.3.0
|
17 |
+
hooks:
|
18 |
+
- id: trailing-whitespace
|
19 |
+
- id: check-yaml
|
20 |
+
- id: end-of-file-fixer
|
21 |
+
- id: requirements-txt-fixer
|
22 |
+
- id: double-quote-string-fixer
|
23 |
+
- id: check-merge-conflict
|
24 |
+
- id: fix-encoding-pragma
|
25 |
+
args: ["--remove"]
|
26 |
+
- id: mixed-line-ending
|
27 |
+
args: ["--fix=lf"]
|
Dockerfile
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
|
2 |
+
# you will also find guides on how best to write your Dockerfile
|
3 |
+
|
4 |
+
FROM python:3.10
|
5 |
+
|
6 |
+
WORKDIR /code
|
7 |
+
|
8 |
+
COPY ./requirements.txt /code/requirements.txt
|
9 |
+
|
10 |
+
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
|
11 |
+
|
12 |
+
COPY . .
|
13 |
+
|
14 |
+
CMD ["python", "run_server.py", "--llm", "Qwen/Qwen-1_8B-Chat", "--model_server", "http://127.0.0.1:7905/v1", "--server_host", "0.0.0.0", "--workstation_port", "7860"]
|
LICENSE
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Tongyi Qianwen LICENSE AGREEMENT
|
2 |
+
|
3 |
+
Tongyi Qianwen Release Date: August 3, 2023
|
4 |
+
|
5 |
+
By clicking to agree or by using or distributing any portion or element of the Tongyi Qianwen Materials, you will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
|
6 |
+
|
7 |
+
1. Definitions
|
8 |
+
a. This Tongyi Qianwen LICENSE AGREEMENT (this "Agreement") shall mean the terms and conditions for use, reproduction, distribution and modification of the Materials as defined by this Agreement.
|
9 |
+
b. "We"(or "Us") shall mean Alibaba Cloud.
|
10 |
+
c. "You" (or "Your") shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Materials for any purpose and in any field of use.
|
11 |
+
d. "Third Parties" shall mean individuals or legal entities that are not under common control with Us or You.
|
12 |
+
e. "Tongyi Qianwen" shall mean the large language models (including Qwen model and Qwen-Chat model), and software and algorithms, consisting of trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing distributed by Us.
|
13 |
+
f. "Materials" shall mean, collectively, Alibaba Cloud's proprietary Tongyi Qianwen and Documentation (and any portion thereof) made available under this Agreement.
|
14 |
+
g. "Source" form shall mean the preferred form for making modifications, including but not limited to model source code, documentation source, and configuration files.
|
15 |
+
h. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation,
|
16 |
+
and conversions to other media types.
|
17 |
+
|
18 |
+
2. Grant of Rights
|
19 |
+
You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Alibaba Cloud's intellectual property or other rights owned by Us embodied in the Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Materials.
|
20 |
+
|
21 |
+
3. Redistribution
|
22 |
+
You may reproduce and distribute copies of the Materials or derivative works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
|
23 |
+
a. You shall give any other recipients of the Materials or derivative works a copy of this Agreement;
|
24 |
+
b. You shall cause any modified files to carry prominent notices stating that You changed the files;
|
25 |
+
c. You shall retain in all copies of the Materials that You distribute the following attribution notices within a "Notice" text file distributed as a part of such copies: "Tongyi Qianwen is licensed under the Tongyi Qianwen LICENSE AGREEMENT, Copyright (c) Alibaba Cloud. All Rights Reserved."; and
|
26 |
+
d. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such derivative works as a whole, provided Your use, reproduction, and distribution of the work otherwise complies with the terms and conditions of this Agreement.
|
27 |
+
|
28 |
+
4. Restrictions
|
29 |
+
If you are commercially using the Materials, and your product or service has more than 100 million monthly active users, You shall request a license from Us. You cannot exercise your rights under this Agreement without our express authorization.
|
30 |
+
|
31 |
+
5. Rules of use
|
32 |
+
a. The Materials may be subject to export controls or restrictions in China, the United States or other countries or regions. You shall comply with applicable laws and regulations in your use of the Materials.
|
33 |
+
b. You can not use the Materials or any output therefrom to improve any other large language model (excluding Tongyi Qianwen or derivative works thereof).
|
34 |
+
|
35 |
+
6. Intellectual Property
|
36 |
+
a. We retain ownership of all intellectual property rights in and to the Materials and derivatives made by or for Us. Conditioned upon compliance with the terms and conditions of this Agreement, with respect to any derivative works and modifications of the Materials that are made by you, you are and will be the owner of such derivative works and modifications.
|
37 |
+
b. No trademark license is granted to use the trade names, trademarks, service marks, or product names of Us, except as required to fulfill notice requirements under this Agreement or as required for reasonable and customary use in describing and redistributing the Materials.
|
38 |
+
c. If you commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any entity alleging that the Materials or any output therefrom, or any part of the foregoing, infringe any intellectual property or other right owned or licensable by you, then all licences granted to you under this Agreement shall terminate as of the date such lawsuit or other proceeding is commenced or brought.
|
39 |
+
|
40 |
+
7. Disclaimer of Warranty and Limitation of Liability
|
41 |
+
|
42 |
+
a. We are not obligated to support, update, provide training for, or develop any further version of the Tongyi Qianwen Materials or to grant any license thereto.
|
43 |
+
b. THE MATERIALS ARE PROVIDED "AS IS" WITHOUT ANY EXPRESS OR IMPLIED WARRANTY OF ANY KIND INCLUDING WARRANTIES OF MERCHANTABILITY, NONINFRINGEMENT, OR FITNESS FOR A PARTICULAR PURPOSE. WE MAKE NO WARRANTY AND ASSUME NO RESPONSIBILITY FOR THE SAFETY OR STABILITY OF THE MATERIALS AND ANY OUTPUT THEREFROM.
|
44 |
+
c. IN NO EVENT SHALL WE BE LIABLE TO YOU FOR ANY DAMAGES, INCLUDING, BUT NOT LIMITED TO ANY DIRECT, OR INDIRECT, SPECIAL OR CONSEQUENTIAL DAMAGES ARISING FROM YOUR USE OR INABILITY TO USE THE MATERIALS OR ANY OUTPUT OF IT, NO MATTER HOW IT’S CAUSED.
|
45 |
+
d. You will defend, indemnify and hold harmless Us from and against any claim by any third party arising out of or related to your use or distribution of the Materials.
|
46 |
+
|
47 |
+
8. Survival and Termination.
|
48 |
+
a. The term of this Agreement shall commence upon your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
|
49 |
+
b. We may terminate this Agreement if you breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, you must delete and cease use of the Materials. Sections 7 and 9 shall survive the termination of this Agreement.
|
50 |
+
|
51 |
+
9. Governing Law and Jurisdiction.
|
52 |
+
a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
|
53 |
+
b. The People's Courts in Hangzhou City shall have exclusive jurisdiction over any dispute arising out of this Agreement.
|
README_CN.md
ADDED
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Qwen Agent
|
3 |
+
emoji: 📈
|
4 |
+
colorFrom: yellow
|
5 |
+
colorTo: purple
|
6 |
+
sdk: docker
|
7 |
+
pinned: false
|
8 |
+
license: apache-2.0
|
9 |
+
app_port: 7860
|
10 |
+
---
|
11 |
+
|
12 |
+
中文 | [English](./README.md)
|
13 |
+
|
14 |
+
<p align="center">
|
15 |
+
<img src="https://qianwen-res.oss-cn-beijing.aliyuncs.com/assets/qwen_agent/logo-qwen-agent.png" width="400"/>
|
16 |
+
<p>
|
17 |
+
<br>
|
18 |
+
|
19 |
+
Qwen-Agent是一个代码框架,用于发掘开源通义千问模型([Qwen](https://github.com/QwenLM/Qwen))的工具使用、规划、记忆能力。
|
20 |
+
在Qwen-Agent的基础上,我们开发了一个名为BrowserQwen的**Chrome浏览器扩展**,它具有以下主要功能:
|
21 |
+
|
22 |
+
- 与Qwen讨论当前网页或PDF文档的内容。
|
23 |
+
- 在获得您的授权后,BrowserQwen会记录您浏览过的网页和PDF/Word/PPT材料,以帮助您快速了解多个页面的内容,总结您浏览过的内容,并自动化繁琐的文字工作。
|
24 |
+
- 集成各种插件,包括可用于数学问题求解、数据分析与可视化、处理文件等的**代码解释器**(**Code Interpreter**)。
|
25 |
+
|
26 |
+
# 用例演示
|
27 |
+
|
28 |
+
如果您更喜欢观看视频,而不是效果截图,可以参见[视频演示](#视频演示)。
|
29 |
+
|
30 |
+
## 工作台 - 创作模式
|
31 |
+
|
32 |
+
**根据浏览过的网页、PDFs素材进行长文创作**
|
33 |
+
|
34 |
+
<figure>
|
35 |
+
<img src="assets/screenshot-writing.png">
|
36 |
+
</figure>
|
37 |
+
|
38 |
+
**调用插件辅助富文本创作**
|
39 |
+
|
40 |
+
<figure>
|
41 |
+
<img src="assets/screenshot-editor-movie.png">
|
42 |
+
</figure>
|
43 |
+
|
44 |
+
## 工作台 - 对话模式
|
45 |
+
|
46 |
+
**多网页问答**
|
47 |
+
|
48 |
+
<figure >
|
49 |
+
<img src="assets/screenshot-multi-web-qa.png">
|
50 |
+
</figure>
|
51 |
+
|
52 |
+
**使用代码解释器绘制数据图表**
|
53 |
+
|
54 |
+
<figure>
|
55 |
+
<img src="assets/screenshot-ci.png">
|
56 |
+
</figure>
|
57 |
+
|
58 |
+
## 浏览器助手
|
59 |
+
|
60 |
+
**网页问答**
|
61 |
+
|
62 |
+
<figure>
|
63 |
+
<img src="assets/screenshot-web-qa.png">
|
64 |
+
</figure>
|
65 |
+
|
66 |
+
**PDF文档问答**
|
67 |
+
|
68 |
+
<figure>
|
69 |
+
<img src="assets/screenshot-pdf-qa.png">
|
70 |
+
</figure>
|
71 |
+
|
72 |
+
# BrowserQwen 使用说明
|
73 |
+
|
74 |
+
支持环境:MacOS,Linux,Windows。
|
75 |
+
|
76 |
+
## 第一步 - 部署模型服务
|
77 |
+
|
78 |
+
***如果您正在使用阿里云提供的[DashScope](https://help.aliyun.com/zh/dashscope/developer-reference/quick-start)服务来访问Qwen系列模型,可以跳过这一步,直接到第二步。***
|
79 |
+
|
80 |
+
但如果您不想使用DashScope,而是希望自己部署一个模型服务。那么可以参考[Qwen项目](https://github.com/QwenLM/Qwen/blob/main/README_CN.md#api),部署一个兼容OpenAI API的模型服务:
|
81 |
+
|
82 |
+
```bash
|
83 |
+
# 安装依赖
|
84 |
+
git clone git@github.com:QwenLM/Qwen.git
|
85 |
+
cd Qwen
|
86 |
+
pip install -r requirements.txt
|
87 |
+
pip install fastapi uvicorn "openai<1.0.0" "pydantic>=2.3.0" sse_starlette
|
88 |
+
|
89 |
+
# 启动模型服务,通过 -c 参数指定模型版本
|
90 |
+
# - 指定 --server-name 0.0.0.0 将允许其他机器访问您的模型服务
|
91 |
+
# - 指定 --server-name 127.0.0.1 则只允许部署模型的机器自身访问该模型服务
|
92 |
+
python openai_api.py --server-name 0.0.0.0 --server-port 7905 -c Qwen/Qwen-14B-Chat
|
93 |
+
```
|
94 |
+
|
95 |
+
目前,我们支持指定-c参数以加载 [Qwen 的 Hugging Face主页](https://huggingface.co/Qwen) 上的模型,比如`Qwen/Qwen-1_8B-Chat`、`Qwen/Qwen-7B-Chat`、`Qwen/Qwen-14B-Chat`、`Qwen/Qwen-72B-Chat`,以及它们的`Int4`和`Int8`版本。
|
96 |
+
|
97 |
+
## 第二步 - 部署本地数据库服务
|
98 |
+
|
99 |
+
在这一步,您需要在您的本地机器上(即您可以打开Chrome浏览器的那台机器),部署维护个人浏览历史、对话历史的数据库服务。
|
100 |
+
|
101 |
+
首次启动数据库服务前,请记得安装相关的依赖:
|
102 |
+
|
103 |
+
```bash
|
104 |
+
# 安装依赖
|
105 |
+
git clone https://github.com/QwenLM/Qwen-Agent.git
|
106 |
+
cd Qwen-Agent
|
107 |
+
pip install -r requirements.txt
|
108 |
+
```
|
109 |
+
|
110 |
+
如果跳过了第一步、因为您打算使用DashScope提供的模型服务的话,请执行以下命令启动数据库服务:
|
111 |
+
|
112 |
+
```bash
|
113 |
+
# 启动数据库服务,通过 --llm 参数指定您希望通过DashScope使用的具体模型
|
114 |
+
# 参数 --llm 可以是如下之一,按资源消耗从小到大排序:
|
115 |
+
# - qwen-7b-chat (与开源的Qwen-7B-Chat相同模型)
|
116 |
+
# - qwen-14b-chat (与开源的Qwen-14B-Chat相同模型)
|
117 |
+
# - qwen-turbo
|
118 |
+
# - qwen-plus
|
119 |
+
# 您需要将YOUR_DASHSCOPE_API_KEY替换为您的真实API-KEY。
|
120 |
+
export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
|
121 |
+
python run_server.py --model_server dashscope --llm qwen-7b-chat --workstation_port 7864
|
122 |
+
```
|
123 |
+
|
124 |
+
如果您没有在使用DashScope、而是参考第一步部署了自己的模型服务的话,请执行以下命令:
|
125 |
+
|
126 |
+
```bash
|
127 |
+
# 启动数据库服务,通过 --model_server 参数指定您在 Step 1 里部署好的模型服务
|
128 |
+
# - 若 Step 1 的机器 IP 为 123.45.67.89,则可指定 --model_server http://123.45.67.89:7905/v1
|
129 |
+
# - 若 Step 1 和 Step 2 是同一台机器,则可指定 --model_server http://127.0.0.1:7905/v1
|
130 |
+
python run_server.py --model_server http://{MODEL_SERVER_IP}:7905/v1 --workstation_port 7864
|
131 |
+
```
|
132 |
+
|
133 |
+
现在您可以访问 [http://127.0.0.1:7864/](http://127.0.0.1:7864/) 来使用工作台(Workstation)的创作模式(Editor模式)和对话模式(Chat模式)了。
|
134 |
+
|
135 |
+
关于工作台的使用技巧,请参见工作台���面的文字说明、或观看[视频演示](#视频演示)。
|
136 |
+
|
137 |
+
## Step 3. 安装浏览器助手
|
138 |
+
|
139 |
+
安装BrowserQwen的Chrome插件(又称Chrome扩展程序):
|
140 |
+
|
141 |
+
1. 打开Chrome浏览器,在浏览器的地址栏中输入 `chrome://extensions/` 并按下回车键;
|
142 |
+
2. 确保右上角的 `开发者模式` 处于打开状态,之后点击 `加载已解压的扩展程序` 上传本项目下的 `browser_qwen` 目录并启用;
|
143 |
+
3. 单击谷歌浏览器右上角```扩展程序```图标,将BrowserQwen固定在工具栏。
|
144 |
+
|
145 |
+
注意,安装Chrome插件后,需要刷新页面,插件才能生效。
|
146 |
+
|
147 |
+
当您想让Qwen阅读当前网页的内容时:
|
148 |
+
|
149 |
+
1. 请先点击屏幕上的 `Add to Qwen's Reading List` 按钮,以授权Qwen在后台分析本页面。
|
150 |
+
2. 再单击浏览器右上角扩展程序栏的Qwen图标,便可以和Qwen交流当前页面的内容了。
|
151 |
+
|
152 |
+
## 视频演示
|
153 |
+
|
154 |
+
可查看以下几个演示视频,了解BrowserQwen的基本操作:
|
155 |
+
|
156 |
+
- 根据浏览过的网页、PDFs进行长文创作 [video](https://qianwen-res.oss-cn-beijing.aliyuncs.com/assets/qwen_agent/showcase_write_article_based_on_webpages_and_pdfs.mp4)
|
157 |
+
- 提取浏览内容使用代码解释器画图 [video](https://qianwen-res.oss-cn-beijing.aliyuncs.com/assets/qwen_agent/showcase_chat_with_docs_and_code_interpreter.mp4)
|
158 |
+
- 上传文件、多轮对话利用代码解释器分析数据 [video](https://qianwen-res.oss-cn-beijing.aliyuncs.com/assets/qwen_agent/showcase_code_interpreter_multi_turn_chat.mp4)
|
159 |
+
|
160 |
+
# 评测基准
|
161 |
+
|
162 |
+
我们也开源了一个评测基准,用于评估一个模型写Python代码并使用Code Interpreter进行数学解题、数据分析、及其他通用任务时的表现。评测基准见 [benchmark](benchmark/README.md) 目录,当前的评测结果如下:
|
163 |
+
|
164 |
+
<table>
|
165 |
+
<tr>
|
166 |
+
<th colspan="5" align="center">In-house Code Interpreter Benchmark (Version 20231206)</th>
|
167 |
+
</tr>
|
168 |
+
<tr>
|
169 |
+
<th rowspan="2" align="center">Model</th>
|
170 |
+
<th colspan="3" align="center">代码执行结果正确性 (%)</th>
|
171 |
+
<th colspan="1" align="center">生成代码的可执行率 (%)</th>
|
172 |
+
</tr>
|
173 |
+
<tr>
|
174 |
+
<th align="center">Math↑</th><th align="center">Visualization-Hard↑</th><th align="center">Visualization-Easy↑</th><th align="center">General↑</th>
|
175 |
+
</tr>
|
176 |
+
<tr>
|
177 |
+
<td>GPT-4</td>
|
178 |
+
<td align="center">82.8</td>
|
179 |
+
<td align="center">66.7</td>
|
180 |
+
<td align="center">60.8</td>
|
181 |
+
<td align="center">82.8</td>
|
182 |
+
</tr>
|
183 |
+
<tr>
|
184 |
+
<td>GPT-3.5</td>
|
185 |
+
<td align="center">47.3</td>
|
186 |
+
<td align="center">33.3</td>
|
187 |
+
<td align="center">55.7</td>
|
188 |
+
<td align="center">74.1</td>
|
189 |
+
</tr>
|
190 |
+
<tr>
|
191 |
+
<td>LLaMA2-13B-Chat</td>
|
192 |
+
<td align="center">8.3</td>
|
193 |
+
<td align="center">1.2</td>
|
194 |
+
<td align="center">15.2</td>
|
195 |
+
<td align="center">48.3</td>
|
196 |
+
</tr>
|
197 |
+
<tr>
|
198 |
+
<td>CodeLLaMA-13B-Instruct</td>
|
199 |
+
<td align="center">28.2</td>
|
200 |
+
<td align="center">15.5</td>
|
201 |
+
<td align="center">21.5</td>
|
202 |
+
<td align="center">74.1</td>
|
203 |
+
</tr>
|
204 |
+
<tr>
|
205 |
+
<td>InternLM-20B-Chat</td>
|
206 |
+
<td align="center">34.6</td>
|
207 |
+
<td align="center">10.7</td>
|
208 |
+
<td align="center">24.1</td>
|
209 |
+
<td align="center">65.5</td>
|
210 |
+
</tr>
|
211 |
+
<tr>
|
212 |
+
<td>ChatGLM3-6B</td>
|
213 |
+
<td align="center">54.2</td>
|
214 |
+
<td align="center">4.8</td>
|
215 |
+
<td align="center">15.2</td>
|
216 |
+
<td align="center">62.1</td>
|
217 |
+
</tr>
|
218 |
+
<tr>
|
219 |
+
<td>Qwen-1.8B-Chat</td>
|
220 |
+
<td align="center">25.6</td>
|
221 |
+
<td align="center">21.4</td>
|
222 |
+
<td align="center">22.8</td>
|
223 |
+
<td align="center">65.5</td>
|
224 |
+
</tr>
|
225 |
+
<tr>
|
226 |
+
<td>Qwen-7B-Chat</td>
|
227 |
+
<td align="center">41.9</td>
|
228 |
+
<td align="center">23.8</td>
|
229 |
+
<td align="center">38.0</td>
|
230 |
+
<td align="center">67.2</td>
|
231 |
+
</tr>
|
232 |
+
<tr>
|
233 |
+
<td>Qwen-14B-Chat</td>
|
234 |
+
<td align="center">58.4</td>
|
235 |
+
<td align="center">31.0</td>
|
236 |
+
<td align="center">45.6</td>
|
237 |
+
<td align="center">65.5</td>
|
238 |
+
</tr>
|
239 |
+
<tr>
|
240 |
+
<td>Qwen-72B-Chat</td>
|
241 |
+
<td align="center">72.7</td>
|
242 |
+
<td align="center">41.7</td>
|
243 |
+
<td align="center">43.0</td>
|
244 |
+
<td align="center">82.8</td>
|
245 |
+
</tr>
|
246 |
+
</table>
|
247 |
+
|
248 |
+
# 免责声明
|
249 |
+
|
250 |
+
本项目并非正式产品,而是一个概念验证项目,用于演示Qwen系列模型的能力。
|
251 |
+
|
252 |
+
> 重要提示:代码解释器未进行沙盒隔离,会在部署环境中执行代码。请避免向Qwen发出危险指令,切勿将该代码解释器直接用于生产目的。
|
assets/screenshot-ci.png
ADDED
assets/screenshot-editor-movie.png
ADDED
assets/screenshot-multi-web-qa.png
ADDED
assets/screenshot-pdf-qa.png
ADDED
assets/screenshot-web-qa.png
ADDED
assets/screenshot-writing.png
ADDED
benchmark/README.md
ADDED
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Code Interpreter Benchmark
|
2 |
+
|
3 |
+
## Introduction
|
4 |
+
To assess LLM's ability to use the Python Code Interpreter for tasks such as mathematical problem solving, data visualization, and other general-purpose tasks such as file handling and web scraping, we have created and open-sourced a benchmark specifically designed for evaluating these capabilities.
|
5 |
+
|
6 |
+
### Metrics
|
7 |
+
The metrics are divided into two parts: code executability and code correctness.
|
8 |
+
- Code executability: evaluating the ability of the LLM-generated code to be executed.
|
9 |
+
- Code correctness: evaluating whether the LLM-generated code runs correctly.
|
10 |
+
|
11 |
+
### Domain
|
12 |
+
When evaluating the accuracy of the code execution results for code correctness, we further divide it into two specific domains: `Math`, `Visualization`.
|
13 |
+
In terms of code executability, we calculate executable rate of the generated code for `General problem-solving`.
|
14 |
+
|
15 |
+
## Results
|
16 |
+
- Qwen-7B-Chat refers to the version updated after September 25, 2023.
|
17 |
+
- The code correctness judger model for `Visualization` has changed from `Qwen-vl-chat` to `gpt-4-vision-preview` in the version 20231206.
|
18 |
+
|
19 |
+
<table>
|
20 |
+
<tr>
|
21 |
+
<th colspan="5" align="center">In-house Code Interpreter Benchmark (Version 20231206)</th>
|
22 |
+
</tr>
|
23 |
+
<tr>
|
24 |
+
<th rowspan="2" align="center">Model</th>
|
25 |
+
<th colspan="3" align="center">Accuracy of Code Execution Results (%)</th>
|
26 |
+
<th colspan="1" align="center">Executable Rate of Code (%)</th>
|
27 |
+
</tr>
|
28 |
+
<tr>
|
29 |
+
<th align="center">Math↑</th><th align="center">Visualization-Hard↑</th><th align="center">Visualization-Easy↑</th><th align="center">General↑</th>
|
30 |
+
</tr>
|
31 |
+
<tr>
|
32 |
+
<td>GPT-4</td>
|
33 |
+
<td align="center">82.8</td>
|
34 |
+
<td align="center">66.7</td>
|
35 |
+
<td align="center">60.8</td>
|
36 |
+
<td align="center">82.8</td>
|
37 |
+
</tr>
|
38 |
+
<tr>
|
39 |
+
<td>GPT-3.5</td>
|
40 |
+
<td align="center">47.3</td>
|
41 |
+
<td align="center">33.3</td>
|
42 |
+
<td align="center">55.7</td>
|
43 |
+
<td align="center">74.1</td>
|
44 |
+
</tr>
|
45 |
+
<tr>
|
46 |
+
<td>LLaMA2-13B-Chat</td>
|
47 |
+
<td align="center">8.3</td>
|
48 |
+
<td align="center">1.2</td>
|
49 |
+
<td align="center">15.2</td>
|
50 |
+
<td align="center">48.3</td>
|
51 |
+
</tr>
|
52 |
+
<tr>
|
53 |
+
<td>CodeLLaMA-13B-Instruct</td>
|
54 |
+
<td align="center">28.2</td>
|
55 |
+
<td align="center">15.5</td>
|
56 |
+
<td align="center">21.5</td>
|
57 |
+
<td align="center">74.1</td>
|
58 |
+
</tr>
|
59 |
+
<tr>
|
60 |
+
<td>InternLM-20B-Chat</td>
|
61 |
+
<td align="center">34.6</td>
|
62 |
+
<td align="center">10.7</td>
|
63 |
+
<td align="center">24.1</td>
|
64 |
+
<td align="center">65.5</td>
|
65 |
+
</tr>
|
66 |
+
<tr>
|
67 |
+
<td>ChatGLM3-6B</td>
|
68 |
+
<td align="center">54.2</td>
|
69 |
+
<td align="center">4.8</td>
|
70 |
+
<td align="center">15.2</td>
|
71 |
+
<td align="center">62.1</td>
|
72 |
+
</tr>
|
73 |
+
<tr>
|
74 |
+
<td>Qwen-1.8B-Chat</td>
|
75 |
+
<td align="center">25.6</td>
|
76 |
+
<td align="center">21.4</td>
|
77 |
+
<td align="center">22.8</td>
|
78 |
+
<td align="center">65.5</td>
|
79 |
+
</tr>
|
80 |
+
<tr>
|
81 |
+
<td>Qwen-7B-Chat</td>
|
82 |
+
<td align="center">41.9</td>
|
83 |
+
<td align="center">23.8</td>
|
84 |
+
<td align="center">38.0</td>
|
85 |
+
<td align="center">67.2</td>
|
86 |
+
</tr>
|
87 |
+
<tr>
|
88 |
+
<td>Qwen-14B-Chat</td>
|
89 |
+
<td align="center">58.4</td>
|
90 |
+
<td align="center">31.0</td>
|
91 |
+
<td align="center">45.6</td>
|
92 |
+
<td align="center">65.5</td>
|
93 |
+
</tr>
|
94 |
+
<tr>
|
95 |
+
<td>Qwen-72B-Chat</td>
|
96 |
+
<td align="center">72.7</td>
|
97 |
+
<td align="center">41.7</td>
|
98 |
+
<td align="center">43.0</td>
|
99 |
+
<td align="center">82.8</td>
|
100 |
+
</tr>
|
101 |
+
</table>
|
102 |
+
|
103 |
+
Furthermore, we also provide the results of `Qwen-vl-plus` as the code correctness judger model for `Visualization` task to serve as a reference.
|
104 |
+
|
105 |
+
<table>
|
106 |
+
<tr>
|
107 |
+
<th colspan="3" align="center">Code Correctness Judger Model = Qwen-vl-plus</th>
|
108 |
+
</tr>
|
109 |
+
<tr>
|
110 |
+
<th rowspan="2" align="center">Model</th>
|
111 |
+
<th colspan="2" align="center">Accuracy of Code Execution Results (%)</th>
|
112 |
+
</tr>
|
113 |
+
<tr>
|
114 |
+
<th align="center">Visualization-Hard↑</th>
|
115 |
+
<th align="center">Visualization-Easy↑</th>
|
116 |
+
</tr>
|
117 |
+
<tr>
|
118 |
+
<td>LLaMA2-13B-Chat</td>
|
119 |
+
<td align="center">2.4</td>
|
120 |
+
<td align="center">17.7</td>
|
121 |
+
</tr>
|
122 |
+
<tr>
|
123 |
+
<td>CodeLLaMA-13B-Instruct</td>
|
124 |
+
<td align="center">17.9</td>
|
125 |
+
<td align="center">34.2</td>
|
126 |
+
</tr>
|
127 |
+
<tr>
|
128 |
+
<td>InternLM-20B-Chat</td>
|
129 |
+
<td align="center">9.5</td>
|
130 |
+
<td align="center">31.7</td>
|
131 |
+
</tr>
|
132 |
+
<tr>
|
133 |
+
<td>ChatGLM3-6B</td>
|
134 |
+
<td align="center">10.7</td>
|
135 |
+
<td align="center">29.1</td>
|
136 |
+
</tr>
|
137 |
+
<tr>
|
138 |
+
<td>Qwen-1.8B-Chat</td>
|
139 |
+
<td align="center">32.1</td>
|
140 |
+
<td align="center">32.9</td>
|
141 |
+
</tr>
|
142 |
+
<tr>
|
143 |
+
<td>Qwen-7B-Chat</td>
|
144 |
+
<td align="center">26.2</td>
|
145 |
+
<td align="center">39.2</td>
|
146 |
+
</tr>
|
147 |
+
<tr>
|
148 |
+
<td>Qwen-14B-Chat</td>
|
149 |
+
<td align="center">36.9</td>
|
150 |
+
<td align="center">41.8</td>
|
151 |
+
</tr>
|
152 |
+
<tr>
|
153 |
+
<td>Qwen-72B-Chat</td>
|
154 |
+
<td align="center">38.1</td>
|
155 |
+
<td align="center">38.0</td>
|
156 |
+
</tr>
|
157 |
+
</table>
|
158 |
+
|
159 |
+
|
160 |
+
|
161 |
+
## Usage
|
162 |
+
|
163 |
+
### Installation
|
164 |
+
|
165 |
+
```shell
|
166 |
+
git clone https://github.com/QwenLM/Qwen-Agent.git
|
167 |
+
cd benchmark
|
168 |
+
pip install -r requirements.txt
|
169 |
+
```
|
170 |
+
|
171 |
+
### Dataset Download
|
172 |
+
```shell
|
173 |
+
cd benchmark
|
174 |
+
wget https://qianwen-res.oss-cn-beijing.aliyuncs.com/assets/qwen_agent/benchmark_code_interpreter_data.zip
|
175 |
+
unzip benchmark_code_interpreter_data.zip
|
176 |
+
mkdir eval_data
|
177 |
+
mv eval_code_interpreter_v1.jsonl eval_data/
|
178 |
+
```
|
179 |
+
|
180 |
+
### Evaluation
|
181 |
+
To reproduce the comprehensive results of benchmark, you can run the following script:
|
182 |
+
|
183 |
+
```Shell
|
184 |
+
python inference_and_execute.py --model {model_name}
|
185 |
+
```
|
186 |
+
|
187 |
+
{model_name}:
|
188 |
+
- qwen-1.8b-chat
|
189 |
+
- qwen-7b-chat
|
190 |
+
- qwen-14b-chat
|
191 |
+
- qwen-72b-chat
|
192 |
+
- llama-2-7b-chat
|
193 |
+
- llama-2-13b-chat
|
194 |
+
- codellama-7b-instruct
|
195 |
+
- codellama-13b-instruct
|
196 |
+
- internlm-7b-chat-1.1
|
197 |
+
- internlm-20b-chat
|
198 |
+
|
199 |
+
The benchmark will run the test cases and generate the performance results. The results will be saved in the `output_data` directory.
|
200 |
+
|
201 |
+
**Notes**:
|
202 |
+
Please install `simhei.ttf` font for proper display in matplotlib when evaluating visualization task. You can do this by preparing `simhei.ttf` (which can be found on any Windows PC) and then running the following code snippet:
|
203 |
+
```python
|
204 |
+
import os
|
205 |
+
import matplotlib
|
206 |
+
target_font_path = os.path.join(
|
207 |
+
os.path.abspath(
|
208 |
+
os.path.join(matplotlib.matplotlib_fname(), os.path.pardir)),
|
209 |
+
'fonts', 'ttf', 'simhei.ttf')
|
210 |
+
os.system(f'cp simhei.ttf {target_font_path}')
|
211 |
+
font_list_cache = os.path.join(matplotlib.get_cachedir(), 'fontlist-*.json')
|
212 |
+
os.system(f'rm -f {font_list_cache}')
|
213 |
+
```
|
214 |
+
|
215 |
+
#### Code Executable Rate
|
216 |
+
```Shell
|
217 |
+
python inference_and_execute.py --task {task_name} --model {model_name}
|
218 |
+
```
|
219 |
+
|
220 |
+
{task_name}:
|
221 |
+
- `general`: General problem-solving task
|
222 |
+
|
223 |
+
|
224 |
+
#### Code Correctness Rate
|
225 |
+
```Shell
|
226 |
+
python inference_and_execute.py --task {task_name} --model {model_name}
|
227 |
+
```
|
228 |
+
|
229 |
+
{task_name}:
|
230 |
+
- `visualization`: Visualization task
|
231 |
+
- `gsm8k`: Math task
|
232 |
+
|
233 |
+
|
234 |
+
## Configuration
|
235 |
+
The inference_and_exec.py file contains the following configurable options:
|
236 |
+
|
237 |
+
- `--model`: The model to test which can be one of `qwen-72b-chat`, `qwen-14b-chat`, `qwen-7b-chat`, `qwen-1.8b-chat`, `qwen-7b-chat`, `llama-2-7b-chat`, `llama-2-13b-chat`, `codellama-7b-instruct`, `codellama-13b-instruct`, `internlm-7b-chat-1.1`, `internlm-20b-chat`.
|
238 |
+
- `--task`: The test task which can be one of `all`, `visualization`, `general`, `gsm8k`.
|
239 |
+
- `--output-path`: The path for saving evaluation result.
|
240 |
+
- `--input-path`: The path for placing evaluation data.
|
241 |
+
- `--output-fname`: The file name for evaluation result.
|
242 |
+
- `--input-fname`: The file name for evaluation data.
|
243 |
+
- `--force`: Force generation and will overwrite the cached results.
|
244 |
+
- `--eval-only`: Only calculate evaluation metrics without re-inference.
|
245 |
+
- `--eval-code-exec-only`: Only evaluate code executable rate
|
246 |
+
- `--gen-exec-only`: Only generate and execuate code without calculating evaluation metrics.
|
247 |
+
- `--gen-only`: Only generate without execuating code and calculating evaluation metrics.
|
248 |
+
- `--vis-judger`: The model to judge the result correctness for `Visualization` task which can be one of `gpt-4-vision-preview`, `qwen-vl-chat`, `qwen-vl-plus`. It is set to `gpt-4-vision-preview` by default in the version 20231206, and `Qwen-vl-chat` has been deprecated.
|
benchmark/code_interpreter.py
ADDED
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import io
|
3 |
+
import json
|
4 |
+
import logging
|
5 |
+
import os
|
6 |
+
import queue
|
7 |
+
import re
|
8 |
+
import subprocess
|
9 |
+
import sys
|
10 |
+
import time
|
11 |
+
import traceback
|
12 |
+
import uuid
|
13 |
+
|
14 |
+
import matplotlib
|
15 |
+
import PIL.Image
|
16 |
+
from jupyter_client import BlockingKernelClient
|
17 |
+
from utils.code_utils import extract_code
|
18 |
+
|
19 |
+
WORK_DIR = os.getenv('CODE_INTERPRETER_WORK_DIR', '/tmp/workspace')
|
20 |
+
|
21 |
+
LAUNCH_KERNEL_PY = """
|
22 |
+
from ipykernel import kernelapp as app
|
23 |
+
app.launch_new_instance()
|
24 |
+
"""
|
25 |
+
|
26 |
+
_KERNEL_CLIENTS = {}
|
27 |
+
|
28 |
+
|
29 |
+
# Run this fix before jupyter starts if matplotlib cannot render CJK fonts.
|
30 |
+
# And we need to additionally run the following lines in the jupyter notebook.
|
31 |
+
# ```python
|
32 |
+
# import matplotlib.pyplot as plt
|
33 |
+
# plt.rcParams['font.sans-serif'] = ['SimHei']
|
34 |
+
# plt.rcParams['axes.unicode_minus'] = False
|
35 |
+
# ````
|
36 |
+
def fix_matplotlib_cjk_font_issue():
|
37 |
+
local_ttf = os.path.join(
|
38 |
+
os.path.abspath(
|
39 |
+
os.path.join(matplotlib.matplotlib_fname(), os.path.pardir)),
|
40 |
+
'fonts', 'ttf', 'simhei.ttf')
|
41 |
+
if not os.path.exists(local_ttf):
|
42 |
+
logging.warning(
|
43 |
+
f'Missing font file `{local_ttf}` for matplotlib. It may cause some error when using matplotlib.'
|
44 |
+
)
|
45 |
+
|
46 |
+
|
47 |
+
def start_kernel(pid):
|
48 |
+
fix_matplotlib_cjk_font_issue()
|
49 |
+
|
50 |
+
connection_file = os.path.join(WORK_DIR,
|
51 |
+
f'kernel_connection_file_{pid}.json')
|
52 |
+
launch_kernel_script = os.path.join(WORK_DIR, f'launch_kernel_{pid}.py')
|
53 |
+
for f in [connection_file, launch_kernel_script]:
|
54 |
+
if os.path.exists(f):
|
55 |
+
logging.warning(f'{f} already exists')
|
56 |
+
os.remove(f)
|
57 |
+
|
58 |
+
os.makedirs(WORK_DIR, exist_ok=True)
|
59 |
+
|
60 |
+
with open(launch_kernel_script, 'w') as fout:
|
61 |
+
fout.write(LAUNCH_KERNEL_PY)
|
62 |
+
|
63 |
+
kernel_process = subprocess.Popen([
|
64 |
+
sys.executable,
|
65 |
+
launch_kernel_script,
|
66 |
+
'--IPKernelApp.connection_file',
|
67 |
+
connection_file,
|
68 |
+
'--matplotlib=inline',
|
69 |
+
'--quiet',
|
70 |
+
],
|
71 |
+
cwd=WORK_DIR)
|
72 |
+
logging.info(f"INFO: kernel process's PID = {kernel_process.pid}")
|
73 |
+
|
74 |
+
# Wait for kernel connection file to be written
|
75 |
+
while True:
|
76 |
+
if not os.path.isfile(connection_file):
|
77 |
+
time.sleep(0.1)
|
78 |
+
else:
|
79 |
+
# Keep looping if JSON parsing fails, file may be partially written
|
80 |
+
try:
|
81 |
+
with open(connection_file, 'r') as fp:
|
82 |
+
json.load(fp)
|
83 |
+
break
|
84 |
+
except json.JSONDecodeError:
|
85 |
+
pass
|
86 |
+
|
87 |
+
# Client
|
88 |
+
kc = BlockingKernelClient(connection_file=connection_file)
|
89 |
+
kc.load_connection_file()
|
90 |
+
kc.start_channels()
|
91 |
+
kc.wait_for_ready()
|
92 |
+
return kc
|
93 |
+
|
94 |
+
|
95 |
+
def escape_ansi(line):
|
96 |
+
ansi_escape = re.compile(r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]')
|
97 |
+
return ansi_escape.sub('', line)
|
98 |
+
|
99 |
+
|
100 |
+
def publish_image_to_local(image_base64: str):
|
101 |
+
image_file = str(uuid.uuid4()) + '.png'
|
102 |
+
local_image_file = os.path.join(WORK_DIR, image_file)
|
103 |
+
|
104 |
+
png_bytes = base64.b64decode(image_base64)
|
105 |
+
assert isinstance(png_bytes, bytes)
|
106 |
+
bytes_io = io.BytesIO(png_bytes)
|
107 |
+
PIL.Image.open(bytes_io).save(local_image_file, 'png')
|
108 |
+
|
109 |
+
return local_image_file
|
110 |
+
|
111 |
+
|
112 |
+
START_CODE = """
|
113 |
+
import signal
|
114 |
+
def _m6_code_interpreter_timeout_handler(signum, frame):
|
115 |
+
raise TimeoutError("M6_CODE_INTERPRETER_TIMEOUT")
|
116 |
+
signal.signal(signal.SIGALRM, _m6_code_interpreter_timeout_handler)
|
117 |
+
|
118 |
+
def input(*args, **kwargs):
|
119 |
+
raise NotImplementedError('Python input() function is disabled.')
|
120 |
+
|
121 |
+
import os
|
122 |
+
if 'upload_file' not in os.getcwd():
|
123 |
+
os.chdir("./upload_file/")
|
124 |
+
|
125 |
+
import math
|
126 |
+
import re
|
127 |
+
import json
|
128 |
+
|
129 |
+
import seaborn as sns
|
130 |
+
sns.set_theme()
|
131 |
+
|
132 |
+
import matplotlib
|
133 |
+
import matplotlib.pyplot as plt
|
134 |
+
plt.rcParams['font.sans-serif'] = ['SimHei']
|
135 |
+
plt.rcParams['axes.unicode_minus'] = False
|
136 |
+
|
137 |
+
import numpy as np
|
138 |
+
import pandas as pd
|
139 |
+
|
140 |
+
from sympy import Eq, symbols, solve
|
141 |
+
"""
|
142 |
+
|
143 |
+
|
144 |
+
def code_interpreter(action_input_list: list, timeout=30, clear=False):
|
145 |
+
code = ''
|
146 |
+
for action_input in action_input_list:
|
147 |
+
code += (extract_code(action_input) + '\n')
|
148 |
+
fixed_code = []
|
149 |
+
for line in code.split('\n'):
|
150 |
+
fixed_code.append(line)
|
151 |
+
if line.startswith('sns.set_theme('):
|
152 |
+
fixed_code.append('plt.rcParams["font.sans-serif"] = ["SimHei"]')
|
153 |
+
fixed_code.append('plt.rcParams["axes.unicode_minus"] = False')
|
154 |
+
fixed_code = '\n'.join(fixed_code)
|
155 |
+
if 'def solution()' in fixed_code:
|
156 |
+
fixed_code += '\nsolution()'
|
157 |
+
|
158 |
+
return _code_interpreter(fixed_code, timeout, clear)
|
159 |
+
|
160 |
+
|
161 |
+
def _code_interpreter(code: str, timeout, clear=False):
|
162 |
+
if not code.strip():
|
163 |
+
return ''
|
164 |
+
if timeout:
|
165 |
+
code = f'signal.alarm({timeout})\n{code}'
|
166 |
+
if clear:
|
167 |
+
code = "get_ipython().run_line_magic('reset', '-f')\n" + START_CODE + code
|
168 |
+
|
169 |
+
pid = os.getpid()
|
170 |
+
if pid not in _KERNEL_CLIENTS:
|
171 |
+
_KERNEL_CLIENTS[pid] = start_kernel(pid)
|
172 |
+
_code_interpreter(START_CODE, timeout=None)
|
173 |
+
kc = _KERNEL_CLIENTS[pid]
|
174 |
+
kc.wait_for_ready()
|
175 |
+
kc.execute(code)
|
176 |
+
result = ''
|
177 |
+
image_idx = 0
|
178 |
+
while True:
|
179 |
+
text = ''
|
180 |
+
image = ''
|
181 |
+
finished = False
|
182 |
+
msg_type = 'error'
|
183 |
+
try:
|
184 |
+
msg = kc.get_iopub_msg()
|
185 |
+
msg_type = msg['msg_type']
|
186 |
+
if msg_type == 'status':
|
187 |
+
if msg['content'].get('execution_state') == 'idle':
|
188 |
+
finished = True
|
189 |
+
elif msg_type == 'execute_result':
|
190 |
+
text = msg['content']['data'].get('text/plain', '')
|
191 |
+
if 'image/png' in msg['content']['data']:
|
192 |
+
image_b64 = msg['content']['data']['image/png']
|
193 |
+
image_url = publish_image_to_local(image_b64)
|
194 |
+
image_idx += 1
|
195 |
+
image = '![fig-%03d](%s)' % (image_idx, image_url)
|
196 |
+
elif msg_type == 'display_data':
|
197 |
+
if 'image/png' in msg['content']['data']:
|
198 |
+
image_b64 = msg['content']['data']['image/png']
|
199 |
+
image_url = publish_image_to_local(image_b64)
|
200 |
+
image_idx += 1
|
201 |
+
image = '![fig-%03d](%s)' % (image_idx, image_url)
|
202 |
+
else:
|
203 |
+
text = msg['content']['data'].get('text/plain', '')
|
204 |
+
elif msg_type == 'stream':
|
205 |
+
msg_type = msg['content']['name'] # stdout, stderr
|
206 |
+
text = msg['content']['text']
|
207 |
+
elif msg_type == 'error':
|
208 |
+
text = escape_ansi('\n'.join(msg['content']['traceback']))
|
209 |
+
if 'M6_CODE_INTERPRETER_TIMEOUT' in text:
|
210 |
+
text = f'Timeout. No response after {timeout} seconds.'
|
211 |
+
except queue.Empty:
|
212 |
+
text = f'Timeout. No response after {timeout} seconds.'
|
213 |
+
finished = True
|
214 |
+
except Exception:
|
215 |
+
text = 'The code interpreter encountered an unexpected error.'
|
216 |
+
logging.warning(''.join(
|
217 |
+
traceback.format_exception(*sys.exc_info())))
|
218 |
+
finished = True
|
219 |
+
if text:
|
220 |
+
result += f'\n\n{msg_type}:\n\n```\n{text}\n```'
|
221 |
+
if image:
|
222 |
+
result += f'\n\n{image}'
|
223 |
+
if finished:
|
224 |
+
break
|
225 |
+
result = result.lstrip('\n')
|
226 |
+
if timeout:
|
227 |
+
_code_interpreter('signal.alarm(0)', timeout=None)
|
228 |
+
return result
|
229 |
+
|
230 |
+
|
231 |
+
def get_multiline_input(hint):
|
232 |
+
print(hint)
|
233 |
+
print('// Press ENTER to make a new line. Press CTRL-D to end input.')
|
234 |
+
lines = []
|
235 |
+
while True:
|
236 |
+
try:
|
237 |
+
line = input()
|
238 |
+
except EOFError: # CTRL-D
|
239 |
+
break
|
240 |
+
lines.append(line)
|
241 |
+
print('// Input received.')
|
242 |
+
if lines:
|
243 |
+
return '\n'.join(lines)
|
244 |
+
else:
|
245 |
+
return ''
|
246 |
+
|
247 |
+
|
248 |
+
if __name__ == '__main__':
|
249 |
+
while True:
|
250 |
+
print(code_interpreter([get_multiline_input('Enter python code:')]))
|
benchmark/config.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from parser import InternLMReActParser, ReActParser
|
2 |
+
|
3 |
+
from models import LLM, QwenVL, Qwen, QwenDashscopeVLModel
|
4 |
+
from prompt import InternLMReAct, LlamaReAct, QwenReAct
|
5 |
+
|
6 |
+
react_prompt_map = {
|
7 |
+
'qwen': QwenReAct,
|
8 |
+
'llama': LlamaReAct,
|
9 |
+
'internlm': InternLMReAct,
|
10 |
+
}
|
11 |
+
|
12 |
+
react_parser_map = {
|
13 |
+
'qwen': ReActParser,
|
14 |
+
'llama': ReActParser,
|
15 |
+
'internlm': InternLMReActParser,
|
16 |
+
}
|
17 |
+
|
18 |
+
model_map = {'qwen': Qwen, 'llama': LLM, 'internlm': LLM, 'qwen-vl-chat': QwenVL}
|
19 |
+
|
20 |
+
model_type_map = {
|
21 |
+
'qwen-72b-chat': 'qwen',
|
22 |
+
'qwen-14b-chat': 'qwen',
|
23 |
+
'qwen-1.8b-chat': 'qwen',
|
24 |
+
'qwen-7b-chat': 'qwen',
|
25 |
+
'llama-2-7b-chat': 'llama',
|
26 |
+
'llama-2-13b-chat': 'llama',
|
27 |
+
'codellama-7b-instruct': 'llama',
|
28 |
+
'codellama-13b-instruct': 'llama',
|
29 |
+
'internlm-7b-chat-1.1': 'internlm',
|
30 |
+
'internlm-20b-chat': 'internlm',
|
31 |
+
'qwen-vl-chat': 'qwen-vl-chat',
|
32 |
+
}
|
33 |
+
|
34 |
+
model_path_map = {
|
35 |
+
'qwen-72b-chat': 'Qwen/Qwen-72B-Chat',
|
36 |
+
'qwen-14b-chat': 'Qwen/Qwen-14B-Chat',
|
37 |
+
'qwen-7b-chat': 'Qwen/Qwen-7B-Chat',
|
38 |
+
'qwen-1.8b-chat': 'Qwen/Qwen-1_8B-Chat',
|
39 |
+
'llama-2-7b-chat': 'meta-llama/Llama-2-7b-chat-hf',
|
40 |
+
'llama-2-13b-chat': 'meta-llama/Llama-2-13b-chat-hf',
|
41 |
+
'codellama-7b-instruct': 'codellama/CodeLlama-7b-Instruct-hf',
|
42 |
+
'codellama-13b-instruct': 'codellama/CodeLlama-13b-Instruct-hf',
|
43 |
+
'internlm-7b-chat-1.1': 'internlm/internlm-chat-7b-v1_1',
|
44 |
+
'internlm-20b-chat': 'internlm/internlm-chat-20b',
|
45 |
+
'qwen-vl-chat': 'Qwen/Qwen-VL-Chat',
|
46 |
+
}
|
47 |
+
|
48 |
+
|
49 |
+
def get_react_prompt(model_name, query, lang, upload_fname_list):
|
50 |
+
react_prompt_cls = react_prompt_map.get(model_type_map[model_name],
|
51 |
+
QwenReAct)
|
52 |
+
return react_prompt_cls(query, lang, upload_fname_list)
|
53 |
+
|
54 |
+
|
55 |
+
def get_react_parser(model_name):
|
56 |
+
react_parser_cls = react_parser_map.get(model_type_map[model_name],
|
57 |
+
ReActParser)
|
58 |
+
return react_parser_cls()
|
59 |
+
|
60 |
+
|
61 |
+
def get_model(model_name):
|
62 |
+
if model_name in ["qwen-vl-plus"]:
|
63 |
+
return QwenDashscopeVLModel(model=model_name)
|
64 |
+
model_path = model_path_map.get(model_name, None)
|
65 |
+
model_cls = model_map.get(model_type_map[model_name], LLM)
|
66 |
+
return model_cls(model_path)
|
benchmark/inference_and_execute.py
ADDED
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
from parser import ReActParser
|
6 |
+
|
7 |
+
import prettytable
|
8 |
+
import tqdm
|
9 |
+
from code_interpreter import code_interpreter
|
10 |
+
from config import (get_model, get_react_parser, get_react_prompt,
|
11 |
+
model_path_map)
|
12 |
+
from datasets import load_dataset
|
13 |
+
from metrics.code_execution import eval_code_execution_rate
|
14 |
+
from metrics.gsm8k import eval_gsm8k_acc, is_correct
|
15 |
+
from metrics.visualization import eval_visualization_acc
|
16 |
+
from utils.code_utils import replace_upload_fname
|
17 |
+
from utils.data_utils import load_jsonl
|
18 |
+
|
19 |
+
logging.basicConfig(
|
20 |
+
format='%(asctime)s - %(levelname)s - %(message)s',
|
21 |
+
datefmt='%Y-%m-%d %H:%M:%S',
|
22 |
+
level=logging.INFO,
|
23 |
+
)
|
24 |
+
|
25 |
+
WORK_DIR = os.getenv('CODE_INTERPRETER_WORK_DIR', '/tmp/workspace')
|
26 |
+
os.makedirs(WORK_DIR, exist_ok=True)
|
27 |
+
os.system(f'cp -r upload_file_clean {WORK_DIR}/upload_file')
|
28 |
+
os.system('cp -r upload_file_clean ./upload_file')
|
29 |
+
|
30 |
+
global_eval_result = {
|
31 |
+
'code_executability': {
|
32 |
+
'math': None,
|
33 |
+
'visualization': None,
|
34 |
+
'general': None,
|
35 |
+
},
|
36 |
+
'code_correctness': {
|
37 |
+
'math': None,
|
38 |
+
'visualization-hard': None,
|
39 |
+
'visualization-easy': None,
|
40 |
+
}
|
41 |
+
}
|
42 |
+
|
43 |
+
|
44 |
+
def llm_with_plugin(args, query, item=None, exec_limit=3):
|
45 |
+
exec_count = 0
|
46 |
+
|
47 |
+
# Build ReAct prompt
|
48 |
+
upload_fname_list = item[
|
49 |
+
'input_file_path'] if item and 'input_file_path' in item else []
|
50 |
+
lang = item['lang'] if item and 'lang' in item else 'en'
|
51 |
+
react_prompt_obj = get_react_prompt(args.model, query, lang,
|
52 |
+
upload_fname_list)
|
53 |
+
planning_prompt = react_prompt_obj.build_prompt()
|
54 |
+
|
55 |
+
# Execute the code when providing the first action in the query
|
56 |
+
if '<|im_start|>' in query:
|
57 |
+
_, prepend_code, __ = ReActParser().parse_latest_plugin_call(query)
|
58 |
+
prepend_code = replace_upload_fname(prepend_code, upload_fname_list)
|
59 |
+
call_plugin(_, [prepend_code], clear=(exec_count == 0))
|
60 |
+
exec_count += 1
|
61 |
+
exec_limit += 1
|
62 |
+
|
63 |
+
# Inference and execute
|
64 |
+
text = ''
|
65 |
+
while exec_count < exec_limit:
|
66 |
+
stop_words_list = react_prompt_obj.get_stop_words_list()
|
67 |
+
output = text_completion(args.llm,
|
68 |
+
planning_prompt + text,
|
69 |
+
stop_words=stop_words_list)
|
70 |
+
|
71 |
+
if args.gen_only:
|
72 |
+
text += output
|
73 |
+
break
|
74 |
+
|
75 |
+
react_parser = get_react_parser(args.model)
|
76 |
+
action, action_input, output = react_parser.parse_latest_plugin_call(
|
77 |
+
output)
|
78 |
+
if action:
|
79 |
+
action_input = replace_upload_fname(action_input,
|
80 |
+
upload_fname_list)
|
81 |
+
observation = call_plugin(action, [action_input],
|
82 |
+
clear=(exec_count == 0))
|
83 |
+
output += react_prompt_obj.build_observation(observation)
|
84 |
+
text += output
|
85 |
+
exec_count += 1
|
86 |
+
if 'error:' in observation or 'Traceback' in observation:
|
87 |
+
break
|
88 |
+
else:
|
89 |
+
text += output
|
90 |
+
break
|
91 |
+
return text
|
92 |
+
|
93 |
+
|
94 |
+
def text_completion(llm, input_text, stop_words=[]):
|
95 |
+
logging.info('Generating'.center(60, '='))
|
96 |
+
logging.info('Input'.center(60, '-'))
|
97 |
+
logging.info(input_text)
|
98 |
+
|
99 |
+
output = llm.generate(input_text, stop_words)
|
100 |
+
|
101 |
+
logging.info('Output'.center(60, '-'))
|
102 |
+
logging.info(output)
|
103 |
+
return output
|
104 |
+
|
105 |
+
|
106 |
+
def call_plugin(plugin_name, plugin_args_list, clear=False):
|
107 |
+
# Relax constraints on plugin name.
|
108 |
+
logging.info('Call code interpreter'.center(60, '='))
|
109 |
+
obs = code_interpreter(plugin_args_list, clear=clear)
|
110 |
+
logging.info(obs)
|
111 |
+
return obs
|
112 |
+
|
113 |
+
|
114 |
+
def process_code_interpreter(item, writer):
|
115 |
+
query = item['query']
|
116 |
+
exec_limit = 3 if 'visualization' in item['tags'] else 1
|
117 |
+
response = llm_with_plugin(args=args,
|
118 |
+
query=query,
|
119 |
+
item=item,
|
120 |
+
exec_limit=exec_limit)
|
121 |
+
item['gen'] = response
|
122 |
+
|
123 |
+
writer.write(json.dumps(item, ensure_ascii=False) + '\n')
|
124 |
+
writer.flush()
|
125 |
+
|
126 |
+
|
127 |
+
def process_gsm8k(doc, writer):
|
128 |
+
context = doc['question']
|
129 |
+
completion = llm_with_plugin(args=args, query=context)
|
130 |
+
acc = is_correct(completion, doc['answer'])
|
131 |
+
doc['completion'] = completion
|
132 |
+
doc['acc'] = acc
|
133 |
+
|
134 |
+
writer.write(json.dumps(doc, ensure_ascii=False) + '\n')
|
135 |
+
writer.flush()
|
136 |
+
|
137 |
+
|
138 |
+
def sequential_processing(args, data_list, process_func, writer):
|
139 |
+
for item in tqdm.tqdm(data_list):
|
140 |
+
process_func(item, writer)
|
141 |
+
|
142 |
+
|
143 |
+
process_func_map = {
|
144 |
+
'gsm8k': process_gsm8k,
|
145 |
+
'visualization': process_code_interpreter
|
146 |
+
}
|
147 |
+
|
148 |
+
|
149 |
+
def gather_eval_result(model_name):
|
150 |
+
for metric in global_eval_result:
|
151 |
+
logging.info(metric)
|
152 |
+
table = prettytable.PrettyTable()
|
153 |
+
table.field_names = ['model'] + list(global_eval_result[metric].keys())
|
154 |
+
row_data = [model_name]
|
155 |
+
for item in global_eval_result[metric].values():
|
156 |
+
item = str(item) if not item else str(round(item, 2))
|
157 |
+
row_data.append(item)
|
158 |
+
table.add_row(row_data)
|
159 |
+
logging.info('\n' + str(table))
|
160 |
+
|
161 |
+
|
162 |
+
def eval_metrics(args, test_set, full_output_fname):
|
163 |
+
# metrics
|
164 |
+
assert os.path.exists(
|
165 |
+
full_output_fname), f'Not Found File {full_output_fname}.'
|
166 |
+
inference_res = load_jsonl(full_output_fname)
|
167 |
+
assert len(inference_res) == len(
|
168 |
+
test_set
|
169 |
+
), f'There are still {len(test_set)-len(inference_res)} cases left.'
|
170 |
+
|
171 |
+
abs_output_fname = os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
172 |
+
full_output_fname)
|
173 |
+
if args.task == 'gsm8k':
|
174 |
+
math_code_correctness = eval_gsm8k_acc(abs_output_fname)
|
175 |
+
global_eval_result['code_correctness'].update(math_code_correctness)
|
176 |
+
else:
|
177 |
+
code_executability = eval_code_execution_rate(abs_output_fname,
|
178 |
+
args.task, args.model)
|
179 |
+
global_eval_result['code_executability'].update(code_executability)
|
180 |
+
if args.task in ['all_ci', 'visualization'
|
181 |
+
] and not args.eval_code_exec_only:
|
182 |
+
visualization_code_correctness = eval_visualization_acc(
|
183 |
+
abs_output_fname, args.model, args.vis_judger)
|
184 |
+
global_eval_result['code_correctness'].update(
|
185 |
+
visualization_code_correctness)
|
186 |
+
|
187 |
+
|
188 |
+
def main(args):
|
189 |
+
current_dir = os.getcwd()
|
190 |
+
os.makedirs(args.output_path, exist_ok=True)
|
191 |
+
full_output_fname = os.path.join(
|
192 |
+
args.output_path,
|
193 |
+
(args.output_fname or f'{args.task}_{args.model}_res.jsonl'))
|
194 |
+
|
195 |
+
if not os.path.exists(full_output_fname):
|
196 |
+
with open(full_output_fname, 'w'):
|
197 |
+
logging.info(f'Create file {full_output_fname} done.')
|
198 |
+
|
199 |
+
# build data
|
200 |
+
if args.task == 'gsm8k':
|
201 |
+
dataset = load_dataset('gsm8k', 'main')
|
202 |
+
test_set = dataset['test']
|
203 |
+
else:
|
204 |
+
eval_data_path = os.path.join(args.input_path, args.input_fname)
|
205 |
+
test_set = [
|
206 |
+
item for item in load_jsonl(eval_data_path)
|
207 |
+
if args.task in item['tags']
|
208 |
+
]
|
209 |
+
logging.info(f'Test set: {len(test_set)}')
|
210 |
+
|
211 |
+
if args.eval_only:
|
212 |
+
eval_metrics(args, test_set, full_output_fname)
|
213 |
+
else:
|
214 |
+
key = 'question' if args.task == 'gsm8k' else 'query'
|
215 |
+
cache_question = [item[key] for item in load_jsonl(full_output_fname)
|
216 |
+
] if not args.force else []
|
217 |
+
data_list = [
|
218 |
+
item for item in test_set if item[key] not in cache_question
|
219 |
+
]
|
220 |
+
logging.info(f'Left cases: {len(data_list)}')
|
221 |
+
|
222 |
+
# inference
|
223 |
+
writer_mode = 'w' if args.force else 'a'
|
224 |
+
f_output = open(full_output_fname, writer_mode, encoding='utf-8')
|
225 |
+
process_func = process_func_map.get(args.task,
|
226 |
+
process_code_interpreter)
|
227 |
+
sequential_processing(args, data_list, process_func, f_output)
|
228 |
+
f_output.close()
|
229 |
+
|
230 |
+
# evaluate
|
231 |
+
if not args.gen_exec_only:
|
232 |
+
eval_metrics(args, test_set, full_output_fname)
|
233 |
+
|
234 |
+
os.chdir(current_dir)
|
235 |
+
|
236 |
+
|
237 |
+
def parse_args():
|
238 |
+
parser = argparse.ArgumentParser()
|
239 |
+
parser.add_argument('--model',
|
240 |
+
type=str,
|
241 |
+
default='qwen-14b-chat',
|
242 |
+
choices=list(model_path_map.keys()))
|
243 |
+
parser.add_argument(
|
244 |
+
'--task',
|
245 |
+
type=str,
|
246 |
+
default='all',
|
247 |
+
choices=['all', 'gsm8k', 'visualization', 'general'])
|
248 |
+
parser.add_argument('--output-path', type=str, default='output_data')
|
249 |
+
parser.add_argument('--input-path', type=str, default='eval_data')
|
250 |
+
parser.add_argument('-o', '--output-fname', type=str, default='')
|
251 |
+
parser.add_argument('-i',
|
252 |
+
'--input-fname',
|
253 |
+
type=str,
|
254 |
+
default='eval_code_interpreter_v1.jsonl')
|
255 |
+
parser.add_argument('-f', '--force', action='store_true', default=False)
|
256 |
+
parser.add_argument('--eval-only', action='store_true', default=False)
|
257 |
+
parser.add_argument('--eval-code-exec-only',
|
258 |
+
action='store_true',
|
259 |
+
default=False)
|
260 |
+
parser.add_argument('--gen-exec-only', action='store_true', default=False)
|
261 |
+
parser.add_argument('--gen-only', action='store_true', default=False)
|
262 |
+
parser.add_argument('--vis-judger', type=str, default="'gpt-4-vision-preview'",
|
263 |
+
choices=['gpt-4-vision-preview', 'qwen-vl-chat', 'qwen-vl-plus'])
|
264 |
+
args = parser.parse_args()
|
265 |
+
return args
|
266 |
+
|
267 |
+
|
268 |
+
if __name__ == '__main__':
|
269 |
+
args = parse_args()
|
270 |
+
if not args.eval_only:
|
271 |
+
args.llm = get_model(args.model)
|
272 |
+
logging.info(f'Init {args.model} done.')
|
273 |
+
|
274 |
+
if args.task == 'all':
|
275 |
+
for key in ['gsm8k', 'visualization', 'general']:
|
276 |
+
args.task = key
|
277 |
+
main(args)
|
278 |
+
else:
|
279 |
+
main(args)
|
280 |
+
gather_eval_result(args.model)
|
benchmark/metrics/__init__.py
ADDED
File without changes
|
benchmark/metrics/code_execution.py
ADDED
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
|
4 |
+
import func_timeout
|
5 |
+
from config import get_react_parser
|
6 |
+
from func_timeout import func_set_timeout
|
7 |
+
from utils.code_utils import extract_code, replace_upload_fname
|
8 |
+
from utils.data_utils import load_jsonl, save_jsonl
|
9 |
+
|
10 |
+
pre_load = """
|
11 |
+
import os
|
12 |
+
if 'upload_file' not in os.getcwd():
|
13 |
+
os.chdir("./upload_file/")
|
14 |
+
|
15 |
+
import seaborn as sns
|
16 |
+
|
17 |
+
import matplotlib
|
18 |
+
# matplotlib.use('Agg')
|
19 |
+
import matplotlib.pyplot as plt
|
20 |
+
plt.ion()
|
21 |
+
|
22 |
+
import numpy as np
|
23 |
+
import pandas as pd
|
24 |
+
from sympy import Eq, symbols, solve
|
25 |
+
import re
|
26 |
+
import json
|
27 |
+
import math
|
28 |
+
"""
|
29 |
+
|
30 |
+
tags_config = {
|
31 |
+
'visualization': {
|
32 |
+
'timelimit': True,
|
33 |
+
'extract_first_code': True,
|
34 |
+
},
|
35 |
+
'math': {
|
36 |
+
'timelimit': True,
|
37 |
+
'extract_first_code': False,
|
38 |
+
},
|
39 |
+
'general': {
|
40 |
+
'timelimit': False,
|
41 |
+
'extract_first_code': True,
|
42 |
+
}
|
43 |
+
}
|
44 |
+
|
45 |
+
code_executability = {'math': None, 'visualization': None, 'general': None}
|
46 |
+
|
47 |
+
|
48 |
+
@func_set_timeout(10)
|
49 |
+
def exec_limit_time(text):
|
50 |
+
exec(text, locals())
|
51 |
+
|
52 |
+
|
53 |
+
def exec_code(text, timelimit=False):
|
54 |
+
if timelimit:
|
55 |
+
exec_limit_time(text)
|
56 |
+
else:
|
57 |
+
exec(text, locals())
|
58 |
+
|
59 |
+
|
60 |
+
def postprocess_code(gen_code, line):
|
61 |
+
if '<|im_start|>' in line['query']:
|
62 |
+
first_action_code = get_action_input_code(line['query'])
|
63 |
+
gen_code = first_action_code + gen_code
|
64 |
+
|
65 |
+
upload_fname_list = line[
|
66 |
+
'input_file_path'] if line and 'input_file_path' in line else []
|
67 |
+
gen_code = replace_upload_fname(gen_code, upload_fname_list)
|
68 |
+
|
69 |
+
if 'def solution()' in gen_code:
|
70 |
+
gen_code += '\nsolution()\n'
|
71 |
+
|
72 |
+
if 'plt.show()' in gen_code:
|
73 |
+
gen_code += "\nplt.pause(1)\nplt.close('all')\n"
|
74 |
+
|
75 |
+
if 'sns.' in gen_code and 'plot' in gen_code:
|
76 |
+
gen_code += "\nplt.close('all')\n"
|
77 |
+
|
78 |
+
gen_code = pre_load + gen_code
|
79 |
+
return gen_code
|
80 |
+
|
81 |
+
|
82 |
+
def get_action_input_code(text,
|
83 |
+
model_name='qwen-14b-chat',
|
84 |
+
extract_first_code=False):
|
85 |
+
action_input_list = []
|
86 |
+
tmp = text
|
87 |
+
react_parser = get_react_parser(model_name)
|
88 |
+
while True:
|
89 |
+
action_input = react_parser.get_first_action_input(tmp)
|
90 |
+
if not action_input:
|
91 |
+
break
|
92 |
+
action_input_list.append(action_input)
|
93 |
+
tmp = tmp.split(action_input)[1]
|
94 |
+
if not tmp or extract_first_code:
|
95 |
+
break
|
96 |
+
|
97 |
+
code = ''
|
98 |
+
for action_input in action_input_list:
|
99 |
+
code = code + '# concat\n' + extract_code(action_input) + '\n'
|
100 |
+
return code
|
101 |
+
|
102 |
+
|
103 |
+
def eval_code_execution_rate(output_fname,
|
104 |
+
tag='all_ci',
|
105 |
+
model_name='qwen-14b-chat',
|
106 |
+
timelimit=False,
|
107 |
+
extract_first_code=False):
|
108 |
+
data_list = load_jsonl(output_fname)
|
109 |
+
pip_package = []
|
110 |
+
|
111 |
+
for line_id, line in enumerate(data_list):
|
112 |
+
line['idx'] = line_id
|
113 |
+
tags_list = line['tags'].split(',')
|
114 |
+
if tag not in tags_list:
|
115 |
+
continue
|
116 |
+
|
117 |
+
# update args
|
118 |
+
for cur_tag in tags_list:
|
119 |
+
if cur_tag != 'all_ci':
|
120 |
+
timelimit = tags_config[cur_tag]['timelimit']
|
121 |
+
extract_first_code = tags_config[cur_tag]['extract_first_code']
|
122 |
+
|
123 |
+
line['executable_code'] = False
|
124 |
+
line['missing_code'] = False
|
125 |
+
line['code_error_info'] = ''
|
126 |
+
|
127 |
+
# get Action Input code from response
|
128 |
+
gen_code = get_action_input_code(line['gen'],
|
129 |
+
model_name=model_name,
|
130 |
+
extract_first_code=extract_first_code)
|
131 |
+
|
132 |
+
if not gen_code:
|
133 |
+
line['missing_code'] = True
|
134 |
+
line['code'] = ''
|
135 |
+
line['code_error_info'] = 'missing code'
|
136 |
+
continue
|
137 |
+
|
138 |
+
line['code'] = gen_code
|
139 |
+
gen_code = postprocess_code(gen_code, line)
|
140 |
+
|
141 |
+
while True:
|
142 |
+
try:
|
143 |
+
exec_code(gen_code, timelimit=timelimit)
|
144 |
+
line['executable_code'] = True
|
145 |
+
break
|
146 |
+
except func_timeout.exceptions.FunctionTimedOut as ex:
|
147 |
+
line['code_error_info'] = str(ex)
|
148 |
+
break
|
149 |
+
except (ImportError, ModuleNotFoundError) as ex:
|
150 |
+
try:
|
151 |
+
packege = str(ex).split("'")[1].strip()
|
152 |
+
except Exception:
|
153 |
+
packege = ''
|
154 |
+
if packege and packege not in pip_package: # install package
|
155 |
+
pip_package.append(packege)
|
156 |
+
os.system('pip install ' + packege)
|
157 |
+
logging.info(f'Automatic installation: {packege}')
|
158 |
+
else:
|
159 |
+
line['code_error_info'] = str(ex)
|
160 |
+
break
|
161 |
+
except Exception as ex:
|
162 |
+
line['code_error_info'] = str(ex)
|
163 |
+
break
|
164 |
+
|
165 |
+
# double check
|
166 |
+
observation = get_react_parser(model_name).get_first_observation(
|
167 |
+
line['gen'])
|
168 |
+
if line['executable_code'] and ('error:' in observation):
|
169 |
+
logging.warning(
|
170 |
+
'The code executes correctly, but it has an error in IPython!')
|
171 |
+
logging.warning(f'Code:\n{gen_code}')
|
172 |
+
logging.warning(f'IPython error info:\n{observation}')
|
173 |
+
logging.info('=' * 60)
|
174 |
+
elif not line['executable_code'] and not ('error:' in observation):
|
175 |
+
logging.warning(
|
176 |
+
'The code has an execution error, but it runs correctly in IPython!'
|
177 |
+
)
|
178 |
+
logging.warning(f'Code:\n{gen_code}')
|
179 |
+
logging.warning(f"Exec error info:\n{line['code_error_info']}")
|
180 |
+
logging.warning(f'IPython observation:\n{observation}')
|
181 |
+
logging.info('=' * 60)
|
182 |
+
|
183 |
+
# save error data
|
184 |
+
error_data_list = [
|
185 |
+
item for item in data_list
|
186 |
+
if not item['executable_code'] or item['missing_code']
|
187 |
+
]
|
188 |
+
error_data_output_fname = os.path.splitext(
|
189 |
+
output_fname)[0] + '_exec_error.jsonl'
|
190 |
+
save_jsonl(error_data_list, error_data_output_fname)
|
191 |
+
|
192 |
+
log_result(data_list)
|
193 |
+
|
194 |
+
return code_executability
|
195 |
+
|
196 |
+
|
197 |
+
def log_result(data_list, verbose=True):
|
198 |
+
if verbose:
|
199 |
+
logging.info('*' * 60)
|
200 |
+
logging.info('{:^60}'.format('Detail'))
|
201 |
+
logging.info('*' * 60)
|
202 |
+
for line_id, line in enumerate(data_list):
|
203 |
+
logging.info(f'Question {line_id}'.center(60, '='))
|
204 |
+
logging.info(line['query'])
|
205 |
+
|
206 |
+
logging.info(f'Generated {line_id}'.center(60, '-'))
|
207 |
+
logging.info('\n' + line['gen'])
|
208 |
+
|
209 |
+
logging.info(f'Code {line_id}'.center(60, '-'))
|
210 |
+
logging.info('\n' + line['code'])
|
211 |
+
|
212 |
+
logging.info(f'Exec Result {line_id}'.center(60, '-'))
|
213 |
+
prefix_info = 'Exec Success' if line[
|
214 |
+
'executable_code'] else 'Exec Error: '
|
215 |
+
exec_info = prefix_info + line['code_error_info']
|
216 |
+
logging.info(exec_info)
|
217 |
+
|
218 |
+
logging.info('=' * 60)
|
219 |
+
logging.info('{:^60}'.format('Code Execuation Rate'))
|
220 |
+
logging.info('=' * 60)
|
221 |
+
involved_tags = []
|
222 |
+
for line in data_list:
|
223 |
+
involved_tags += line['tags'].split(',')
|
224 |
+
involved_tags = list(set(involved_tags))
|
225 |
+
|
226 |
+
for key in involved_tags:
|
227 |
+
logging.info(f'task: {key}'.center(60, '='))
|
228 |
+
key_item_list = [item for item in data_list if key in item['tags']]
|
229 |
+
all_count = len(key_item_list)
|
230 |
+
missing_code_count = len(
|
231 |
+
[item for item in key_item_list if item['missing_code']])
|
232 |
+
executable_code_count = len(
|
233 |
+
[item for item in key_item_list if item['executable_code']])
|
234 |
+
|
235 |
+
logging.info(f'All Test: {all_count}')
|
236 |
+
logging.info(f'Missing Code: {missing_code_count}')
|
237 |
+
logging.info(f'Predict Exec Success: {executable_code_count}')
|
238 |
+
logging.info('Codes available && Execution Rate: {:.2f}'.format(
|
239 |
+
executable_code_count / (all_count - missing_code_count) * 100))
|
240 |
+
logging.info('Execution Rate: {:.2f}'.format(executable_code_count /
|
241 |
+
all_count * 100))
|
242 |
+
logging.info('Non-executable rate: {:.2f}'.format(
|
243 |
+
(all_count - missing_code_count - executable_code_count) /
|
244 |
+
all_count * 100))
|
245 |
+
logging.info('Missing code rate: {:.2f}'.format(missing_code_count /
|
246 |
+
all_count * 100))
|
247 |
+
|
248 |
+
if key != 'all_ci':
|
249 |
+
code_executability[key] = executable_code_count / all_count * 100
|
250 |
+
|
251 |
+
if verbose:
|
252 |
+
logging.info('Error List: ')
|
253 |
+
error_list = [(item['idx'], item['code_error_info'])
|
254 |
+
for item in key_item_list if item['code_error_info']]
|
255 |
+
error_list.sort(key=lambda x: x[1])
|
256 |
+
for x in error_list:
|
257 |
+
logging.info(x)
|
benchmark/metrics/gsm8k.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
from utils.data_utils import load_jsonl, save_jsonl
|
7 |
+
|
8 |
+
INVALID_ANS = '[invalid]'
|
9 |
+
|
10 |
+
|
11 |
+
def extract_answer(completion):
|
12 |
+
|
13 |
+
def _get_last_digit(s):
|
14 |
+
_PAT_LAST_DIGIT = re.compile(
|
15 |
+
r'(?<=(\s|[\$%#{]))([+-])?(?=(\S))(0|([1-9](\d*|\d{0,2}(,\d{3})*)))?(\.\d*[1-9])?(?=(\s|[.,}]|$))'
|
16 |
+
)
|
17 |
+
match = list(_PAT_LAST_DIGIT.finditer(s))
|
18 |
+
if match:
|
19 |
+
last_digit = match[-1].group().replace(',', '').replace('+', '')
|
20 |
+
else:
|
21 |
+
last_digit = None
|
22 |
+
logging.warning(f'No digits found in {s!r}')
|
23 |
+
return last_digit
|
24 |
+
|
25 |
+
job_gen = completion.strip('.').replace('\n', '\\n')
|
26 |
+
last_digit = _get_last_digit(job_gen)
|
27 |
+
if last_digit:
|
28 |
+
return eval(last_digit)
|
29 |
+
else:
|
30 |
+
return INVALID_ANS
|
31 |
+
|
32 |
+
|
33 |
+
def is_correct(completion, answer):
|
34 |
+
gold = extract_answer(answer)
|
35 |
+
assert gold != INVALID_ANS, 'No ground truth answer found in the document.'
|
36 |
+
return extract_answer(completion) == gold
|
37 |
+
|
38 |
+
|
39 |
+
def eval_gsm8k_acc(output_fname):
|
40 |
+
data_list = load_jsonl(output_fname)
|
41 |
+
acc_res = [item['acc'] for item in data_list]
|
42 |
+
logging.info('=' * 60)
|
43 |
+
logging.info('{:^60}'.format('Math Acc.'))
|
44 |
+
logging.info('=' * 60)
|
45 |
+
logging.info('Total num={:.2f}'.format(len(acc_res)))
|
46 |
+
logging.info('Right num={:.2f}'.format(np.sum(acc_res)))
|
47 |
+
logging.info('Zero-shot Acc={:.2f}'.format(np.mean(acc_res) * 100))
|
48 |
+
|
49 |
+
error_data_list = [item for item in data_list if not item['acc']]
|
50 |
+
error_data_output_fname = os.path.splitext(
|
51 |
+
output_fname)[0] + '_gsm8k_error.jsonl'
|
52 |
+
save_jsonl(error_data_list, error_data_output_fname)
|
53 |
+
|
54 |
+
return {'math': np.mean(acc_res) * 100}
|
benchmark/metrics/visualization.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
import base64
|
5 |
+
import torch
|
6 |
+
from config import get_model, get_react_parser
|
7 |
+
from utils.data_utils import load_jsonl, save_jsonl
|
8 |
+
|
9 |
+
torch.manual_seed(1234)
|
10 |
+
|
11 |
+
EVAL_VISUAL_PROMPT_ZH = """请判断图片是否与下面的[问题]一致,如果一致则回复“right”,不一致则回复“wrong”。
|
12 |
+
[问题]:{query}
|
13 |
+
"""
|
14 |
+
|
15 |
+
EVAL_VISUAL_PROMPT_EN = """Please judge whether the image is consistent with the [Question] below, if it is consistent then reply "right", if not then reply "wrong".
|
16 |
+
[Question]: {query}
|
17 |
+
"""
|
18 |
+
|
19 |
+
visualization_code_correctness = {
|
20 |
+
'visualization-hard': None,
|
21 |
+
'visualization-easy': None,
|
22 |
+
}
|
23 |
+
|
24 |
+
|
25 |
+
def encode_image(image_path):
|
26 |
+
with open(image_path, "rb") as image_file:
|
27 |
+
a = base64.b64encode(image_file.read()).decode('utf-8')
|
28 |
+
return a
|
29 |
+
|
30 |
+
|
31 |
+
def judger_model_inference(judger_model_name, judger_model, imgs=[], prompt=''):
|
32 |
+
output = ""
|
33 |
+
if judger_model_name == 'gpt-4-vision-preview':
|
34 |
+
logging.warning("This is an example of `gpt-4-vision-preview`. "
|
35 |
+
"Please set the API key and use according to your actual situation.")
|
36 |
+
from openai import OpenAI
|
37 |
+
client = OpenAI()
|
38 |
+
content_list = []
|
39 |
+
content_list.append({"type": "text", "text": prompt})
|
40 |
+
input_images = []
|
41 |
+
for img in imgs:
|
42 |
+
if 'http' not in img:
|
43 |
+
base64_image = encode_image(img)
|
44 |
+
img = f"data:image/jpeg;base64,{base64_image}"
|
45 |
+
input_images.append({"type": "image_url", 'image_url': img})
|
46 |
+
content_list.extend(input_images)
|
47 |
+
response = client.chat.completions.create(
|
48 |
+
model="gpt-4-vision-preview",
|
49 |
+
messages=[
|
50 |
+
{
|
51 |
+
"role": "user",
|
52 |
+
"content": content_list,
|
53 |
+
}
|
54 |
+
],
|
55 |
+
max_tokens=300,
|
56 |
+
)
|
57 |
+
output = response.choices[0]
|
58 |
+
elif judger_model_name in ['qwen-vl-plus', 'qwen-vl-chat']:
|
59 |
+
inputs = []
|
60 |
+
for img in imgs:
|
61 |
+
if 'http' not in img and judger_model_name == 'qwen-vl-plus':
|
62 |
+
img = "file://" + img
|
63 |
+
inputs.append({'image': img})
|
64 |
+
inputs.append({'text': prompt})
|
65 |
+
|
66 |
+
logging.info('Eval'.center(60, '-'))
|
67 |
+
logging.info(inputs)
|
68 |
+
output = judger_model.generate(inputs)
|
69 |
+
logging.info(output)
|
70 |
+
logging.info('=' * 60)
|
71 |
+
return output
|
72 |
+
|
73 |
+
|
74 |
+
def extract_images(text):
|
75 |
+
regex = re.compile(r'!\[fig-(.+)\]\((.+)\)')
|
76 |
+
results = re.findall(regex, text)
|
77 |
+
images = []
|
78 |
+
for res in results:
|
79 |
+
assert len(res) == 2
|
80 |
+
if os.path.exists(res[1]):
|
81 |
+
images.append(res[1])
|
82 |
+
return images
|
83 |
+
|
84 |
+
|
85 |
+
def check_images_observation(text, images, model_name):
|
86 |
+
start_flag = get_react_parser(model_name).observation
|
87 |
+
for image in images:
|
88 |
+
logging.info('Image'.center(60, '-'))
|
89 |
+
logging.info(image)
|
90 |
+
|
91 |
+
end_idx = text.find(image)
|
92 |
+
tmp_text = text[:end_idx + len(image)]
|
93 |
+
start_idx = tmp_text.rfind(start_flag)
|
94 |
+
check_text = tmp_text[start_idx + len(start_flag):]
|
95 |
+
|
96 |
+
logging.info('Observation'.center(60, '-'))
|
97 |
+
logging.info(check_text)
|
98 |
+
|
99 |
+
# As long as there exists correctly executed observation, we consider `True`
|
100 |
+
if 'error:' not in check_text and 'Traceback' not in check_text:
|
101 |
+
return True
|
102 |
+
return False
|
103 |
+
|
104 |
+
|
105 |
+
eval_visual_prompt = {'zh': EVAL_VISUAL_PROMPT_ZH, 'en': EVAL_VISUAL_PROMPT_EN}
|
106 |
+
|
107 |
+
|
108 |
+
def eval_visualization_acc(output_fname, model_name, judger_model_name='gpt-4-vision-preview'):
|
109 |
+
if judger_model_name == 'gpt-4-vision-preview':
|
110 |
+
judger_model = None
|
111 |
+
elif judger_model_name in ['qwen-vl-chat', 'qwen-vl-plus']:
|
112 |
+
if judger_model_name == 'qwen-vl-chat':
|
113 |
+
logging.warning('In this benchmark of version 20231206, `Qwen-vl-chat` is no longer used as the '
|
114 |
+
'evaluation model for `Visualization` task.. If you insist on using it, '
|
115 |
+
'the evaluation results might differ from the official results.')
|
116 |
+
judger_model = get_model(judger_model_name)
|
117 |
+
else:
|
118 |
+
raise Exception("Not supported judger model.")
|
119 |
+
|
120 |
+
one_action, one_action_right = 0, 0
|
121 |
+
zero_action, zero_action_right = 0, 0
|
122 |
+
|
123 |
+
data_list = load_jsonl(output_fname)
|
124 |
+
for item in data_list:
|
125 |
+
if 'visualization' not in item['tags']:
|
126 |
+
continue
|
127 |
+
|
128 |
+
item['vis_acc'] = False
|
129 |
+
if '<|im_end|>' in item['query']:
|
130 |
+
one_action += 1
|
131 |
+
prompt = item['query'].split('<|im_end|>')[0]
|
132 |
+
else:
|
133 |
+
zero_action += 1
|
134 |
+
prompt = item['query']
|
135 |
+
|
136 |
+
images = extract_images(item['gen'])
|
137 |
+
|
138 |
+
if images and check_images_observation(item['gen'], images,
|
139 |
+
model_name):
|
140 |
+
input_prompt = eval_visual_prompt[item.get('lang', 'en')]
|
141 |
+
format_prompt = input_prompt.format(query=prompt)
|
142 |
+
output = judger_model_inference(judger_model_name, judger_model, images, format_prompt)
|
143 |
+
if 'right' in output.lower():
|
144 |
+
item['vis_acc'] = True
|
145 |
+
if '<|im_end|>' in item['query']:
|
146 |
+
one_action_right += 1
|
147 |
+
else:
|
148 |
+
zero_action_right += 1
|
149 |
+
|
150 |
+
logging.info('*' * 60)
|
151 |
+
logging.info('{:^60}'.format('Visualization Acc.'))
|
152 |
+
logging.info('*' * 60)
|
153 |
+
logging.info(
|
154 |
+
'Visualization-Hard count={}, Visualization-Hard right count={}, Visualization-Hard acc={:.2f}'
|
155 |
+
.format(zero_action, zero_action_right,
|
156 |
+
zero_action_right / zero_action * 100))
|
157 |
+
logging.info(
|
158 |
+
'Visualization-Easy count={}, Visualization-Easy right count={}, Visualization-Easy acc={:.2f}'
|
159 |
+
.format(one_action, one_action_right,
|
160 |
+
one_action_right / one_action * 100))
|
161 |
+
logging.info('all count={}, all right={}, all acc={:.2f}'.format(
|
162 |
+
zero_action + one_action, zero_action_right + one_action_right,
|
163 |
+
(zero_action_right + one_action_right) / (zero_action + one_action) *
|
164 |
+
100))
|
165 |
+
|
166 |
+
visualization_code_correctness[
|
167 |
+
'visualization-hard'] = zero_action_right / zero_action * 100
|
168 |
+
visualization_code_correctness[
|
169 |
+
'visualization-easy'] = one_action_right / one_action * 100
|
170 |
+
|
171 |
+
error_data_list = [
|
172 |
+
item for item in data_list
|
173 |
+
if 'visualization' in item['tags'] and not item['vis_acc']
|
174 |
+
]
|
175 |
+
error_data_output_fname = os.path.splitext(
|
176 |
+
output_fname)[0] + '_vis_error.jsonl'
|
177 |
+
save_jsonl(error_data_list, error_data_output_fname)
|
178 |
+
|
179 |
+
return visualization_code_correctness
|
benchmark/models/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from models.base import HFModel # noqa
|
2 |
+
from models.llm import LLM # noqa
|
3 |
+
from models.qwen import Qwen, QwenVL # noqa
|
4 |
+
from models.dashscope import QwenDashscopeVLModel
|
benchmark/models/base.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
2 |
+
from transformers.generation import GenerationConfig
|
3 |
+
|
4 |
+
|
5 |
+
class HFModel(object):
|
6 |
+
|
7 |
+
def __init__(self, model_path):
|
8 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_path,
|
9 |
+
trust_remote_code=True)
|
10 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
11 |
+
model_path,
|
12 |
+
trust_remote_code=True,
|
13 |
+
device_map='auto',
|
14 |
+
low_cpu_mem_usage=True).eval()
|
15 |
+
self.model.generation_config = GenerationConfig.from_pretrained(
|
16 |
+
model_path, trust_remote_code=True)
|
17 |
+
self.model.generation_config.do_sample = False
|
benchmark/models/dashscope.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from http import HTTPStatus
|
3 |
+
import time
|
4 |
+
import dashscope
|
5 |
+
|
6 |
+
|
7 |
+
class QwenDashscopeVLModel(object):
|
8 |
+
def __init__(self, model, api_key):
|
9 |
+
self.model = model
|
10 |
+
dashscope.api_key = api_key.strip() or os.getenv('DASHSCOPE_API_KEY', default='')
|
11 |
+
assert dashscope.api_key, 'DASHSCOPE_API_KEY is required.'
|
12 |
+
|
13 |
+
def generate(self, prompt, stop_words=[]):
|
14 |
+
if isinstance(prompt, str):
|
15 |
+
prompt = [{'text': prompt}]
|
16 |
+
|
17 |
+
MAX_TRY = 3
|
18 |
+
count = 0
|
19 |
+
while count < MAX_TRY:
|
20 |
+
response = dashscope.MultiModalConversation.call(
|
21 |
+
self.model,
|
22 |
+
messages=[{'role': 'user', 'content': prompt}],
|
23 |
+
top_p=0.01,
|
24 |
+
top_k=1,
|
25 |
+
)
|
26 |
+
if response.status_code == HTTPStatus.OK:
|
27 |
+
output = response.output.choices[0].message.content[0]['text']
|
28 |
+
for stop_str in stop_words:
|
29 |
+
idx = output.find(stop_str)
|
30 |
+
if idx != -1:
|
31 |
+
output = output[: idx + len(stop_str)]
|
32 |
+
return output
|
33 |
+
else:
|
34 |
+
err = 'Error code: %s, error message: %s' % (
|
35 |
+
response.code,
|
36 |
+
response.message,
|
37 |
+
)
|
38 |
+
logging.error(err)
|
39 |
+
count += 1
|
40 |
+
time.sleep(1)
|
benchmark/models/llm.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from models.base import HFModel
|
3 |
+
|
4 |
+
|
5 |
+
class LLM(HFModel):
|
6 |
+
|
7 |
+
def __init__(self, model_path):
|
8 |
+
super().__init__(model_path)
|
9 |
+
|
10 |
+
def generate(self, input_text, stop_words=[], max_new_tokens=512):
|
11 |
+
if isinstance(input_text, str):
|
12 |
+
input_text = [input_text]
|
13 |
+
|
14 |
+
input_ids = self.tokenizer(input_text)['input_ids']
|
15 |
+
input_ids = torch.tensor(input_ids, device=self.model.device)
|
16 |
+
gen_kwargs = {'max_new_tokens': max_new_tokens, 'do_sample': False}
|
17 |
+
outputs = self.model.generate(input_ids, **gen_kwargs)
|
18 |
+
s = outputs[0][input_ids.shape[1]:]
|
19 |
+
output = self.tokenizer.decode(s, skip_special_tokens=True)
|
20 |
+
|
21 |
+
for stop_str in stop_words:
|
22 |
+
idx = output.find(stop_str)
|
23 |
+
if idx != -1:
|
24 |
+
output = output[:idx + len(stop_str)]
|
25 |
+
|
26 |
+
return output
|
benchmark/models/qwen.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from models.base import HFModel
|
3 |
+
|
4 |
+
|
5 |
+
class Qwen(HFModel):
|
6 |
+
|
7 |
+
def __init__(self, model_path):
|
8 |
+
super().__init__(model_path)
|
9 |
+
|
10 |
+
def generate(self, input_text, stop_words=[]):
|
11 |
+
im_end = '<|im_end|>'
|
12 |
+
if im_end not in stop_words:
|
13 |
+
stop_words = stop_words + [im_end]
|
14 |
+
stop_words_ids = [self.tokenizer.encode(w) for w in stop_words]
|
15 |
+
|
16 |
+
input_ids = torch.tensor([self.tokenizer.encode(input_text)
|
17 |
+
]).to(self.model.device)
|
18 |
+
output = self.model.generate(input_ids, stop_words_ids=stop_words_ids)
|
19 |
+
output = output.tolist()[0]
|
20 |
+
output = self.tokenizer.decode(output, errors='ignore')
|
21 |
+
assert output.startswith(input_text)
|
22 |
+
output = output[len(input_text):].replace('<|endoftext|>',
|
23 |
+
'').replace(im_end, '')
|
24 |
+
|
25 |
+
return output
|
26 |
+
|
27 |
+
|
28 |
+
class QwenVL(HFModel):
|
29 |
+
def __init__(self, model_path):
|
30 |
+
super().__init__(model_path)
|
31 |
+
|
32 |
+
def generate(self, inputs: list):
|
33 |
+
query = self.tokenizer.from_list_format(inputs)
|
34 |
+
response, _ = self.model.chat(self.tokenizer, query=query, history=None)
|
35 |
+
|
36 |
+
return response
|
benchmark/parser/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from parser.internlm_parser import InternLMReActParser # noqa
|
2 |
+
from parser.react_parser import ReActParser # noqa
|
benchmark/parser/internlm_parser.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from parser.react_parser import ReActParser
|
2 |
+
|
3 |
+
|
4 |
+
class InternLMReActParser(ReActParser):
|
5 |
+
|
6 |
+
def __init__(self):
|
7 |
+
self.action = '\nAction:'
|
8 |
+
self.action_input = '\nActionInput:'
|
9 |
+
self.action_input_stop = '<eoa>'
|
10 |
+
self.observation = '<|System|>:Response:'
|
11 |
+
self.observation_stop = '<TOKENS_UNUSED_2>\n<|Bot|>:'
|
benchmark/parser/react_parser.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class ReActParser(object):
|
2 |
+
|
3 |
+
def __init__(self):
|
4 |
+
self.action = '\nAction:'
|
5 |
+
self.action_input = '\nAction Input:'
|
6 |
+
self.action_input_stop = '\nObservation:'
|
7 |
+
self.observation = '\nObservation:'
|
8 |
+
self.observation_stop = '\nThought:'
|
9 |
+
|
10 |
+
def parse_latest_plugin_call(self, text):
|
11 |
+
action = self.action
|
12 |
+
action_input = self.action_input
|
13 |
+
observation = self.action_input_stop
|
14 |
+
plugin_name, plugin_args = '', ''
|
15 |
+
i = text.rfind(action)
|
16 |
+
j = text.rfind(action_input)
|
17 |
+
k = text.rfind(observation)
|
18 |
+
if 0 <= i < j: # If the text has `Action` and `Action input`,
|
19 |
+
if k < j: # but does not contain `Observation`,
|
20 |
+
# then it is likely that `Observation` is ommited by the LLM,
|
21 |
+
# because the output text may have discarded the stop word.
|
22 |
+
text = text.rstrip() + observation # Add it back.
|
23 |
+
k = text.rfind(observation)
|
24 |
+
plugin_name = text[i + len(action):j].strip()
|
25 |
+
plugin_args = text[j + len(action_input):k].strip()
|
26 |
+
text = text[:k]
|
27 |
+
return plugin_name, plugin_args, text
|
28 |
+
|
29 |
+
def _extract_first_target(self, text, start_flag, end_flag):
|
30 |
+
target = ''
|
31 |
+
i = text.find(start_flag)
|
32 |
+
if i != -1:
|
33 |
+
j = text.find(end_flag, i)
|
34 |
+
if j != -1:
|
35 |
+
target = text[i + len(start_flag):j].strip()
|
36 |
+
else:
|
37 |
+
target = text[i + len(start_flag):].strip()
|
38 |
+
return target
|
39 |
+
|
40 |
+
def get_first_observation(self, text):
|
41 |
+
return self._extract_first_target(text, self.observation,
|
42 |
+
self.observation_stop)
|
43 |
+
|
44 |
+
def get_first_action_input(self, text):
|
45 |
+
return self._extract_first_target(text, self.action_input,
|
46 |
+
self.action_input_stop)
|
benchmark/prompt/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from prompt.internlm_react import InternLMReAct # noqa
|
2 |
+
from prompt.llama_react import LlamaReAct # noqa
|
3 |
+
from prompt.qwen_react import QwenReAct # noqa
|
4 |
+
from prompt.react import ReAct # noqa
|
benchmark/prompt/internlm_react.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from prompt.react import ReAct
|
2 |
+
|
3 |
+
INTERNLM_TOOL_DESCRIPTION = """用来执行Python代码。代码必须是一个函数,
|
4 |
+
函数名必须得是 'solution',代码对应你的思考过程。代码实例格式如下:
|
5 |
+
```python
|
6 |
+
# import 依赖包
|
7 |
+
import xxx
|
8 |
+
def solution():
|
9 |
+
# 初始化一些变量
|
10 |
+
variable_names_with_real_meaning = xxx
|
11 |
+
# 步骤一
|
12 |
+
mid_variable = func(variable_names_with_real_meaning)
|
13 |
+
# 步骤 x
|
14 |
+
mid_variable = func(mid_variable)
|
15 |
+
# 最后结果
|
16 |
+
final_answer = func(mid_variable)
|
17 |
+
return final_answer
|
18 |
+
```"""
|
19 |
+
|
20 |
+
INTERNLM_TOOL = {'PythonInterpreter': INTERNLM_TOOL_DESCRIPTION}
|
21 |
+
|
22 |
+
INTERNLM_REACT_PROMPT_ZH = """<|System|>:你是一个可以调用外部工具的助手,可以使用的工具包括:
|
23 |
+
{tools_text}
|
24 |
+
如果使用工具请遵循以下格式回复:
|
25 |
+
```
|
26 |
+
Thought:思考你当前步骤需要解决什么问题,是否需要使用工具
|
27 |
+
Action:工具名称,你的工具必须从 [{tools_name_text}] 选择
|
28 |
+
ActionInput:工具输入参数
|
29 |
+
```
|
30 |
+
工具返回按照以下格式回复:
|
31 |
+
```
|
32 |
+
Response:调用工具后的结果
|
33 |
+
```
|
34 |
+
如果你已经知道了答案,或者你不需要工具,请遵循以下格式回复
|
35 |
+
```
|
36 |
+
Thought:给出最终答案的思考过程
|
37 |
+
FinalAnswer:最终答案
|
38 |
+
```
|
39 |
+
开始!<TOKENS_UNUSED_2>
|
40 |
+
<|User|>:{query}<eoh>
|
41 |
+
<|Bot|>:"""
|
42 |
+
|
43 |
+
INTERNLM_REACT_PROMPT_EN = """<|System|>:You are a assistant who can utilize external tools.
|
44 |
+
{tools_text}
|
45 |
+
To use a tool, please use the following format:
|
46 |
+
```
|
47 |
+
Thought: Think what you need to solve, do you need to use tools?
|
48 |
+
Action: the tool name, should be one of [{tools_name_text}]
|
49 |
+
ActionInput: the input to the action
|
50 |
+
```
|
51 |
+
The response after utilizing tools should using the following format:
|
52 |
+
```
|
53 |
+
Response: the results after call the tool.
|
54 |
+
``
|
55 |
+
If you already know the answer, or you do not need to use tools,
|
56 |
+
please using the following format to reply:
|
57 |
+
```
|
58 |
+
Thought: the thought process to get the final answer
|
59 |
+
FinalAnswer: final answer
|
60 |
+
```
|
61 |
+
Begin!<TOKENS_UNUSED_2>
|
62 |
+
<|User|>:{query}<eoh>
|
63 |
+
<|Bot|>:"""
|
64 |
+
|
65 |
+
|
66 |
+
class InternLMReAct(ReAct):
|
67 |
+
|
68 |
+
def __init__(self, query, lang='en', upload_file_paths=[]):
|
69 |
+
super().__init__(query, lang, upload_file_paths)
|
70 |
+
self.react_template = INTERNLM_REACT_PROMPT_ZH if self.lang == 'zh' else INTERNLM_REACT_PROMPT_EN
|
71 |
+
|
72 |
+
def build_prompt(self):
|
73 |
+
planning_prompt = super().build_prompt()
|
74 |
+
if '<|im_end|>' in self.query and planning_prompt.endswith(
|
75 |
+
'<eoh>\n<|Bot|>:'):
|
76 |
+
planning_prompt = planning_prompt[:-len('<eoh>\n<|Bot|>:')]
|
77 |
+
|
78 |
+
if '<|im_end|>' in self.query:
|
79 |
+
planning_prompt = planning_prompt.replace(
|
80 |
+
'<|im_end|>\n<|im_start|>assistant\n',
|
81 |
+
'<eoh>\n<|Bot|>:').replace(
|
82 |
+
'Observation:', '<eoa>\n<|System|>:Response:').replace(
|
83 |
+
'\nAction Input',
|
84 |
+
'\nActionInput').replace('code_interpreter',
|
85 |
+
'PythonInterpreter')
|
86 |
+
assert planning_prompt.endswith('Thought:')
|
87 |
+
planning_prompt = planning_prompt[:-len(
|
88 |
+
'Thought:')] + '<TOKENS_UNUSED_2>\n<|Bot|>:'
|
89 |
+
|
90 |
+
self.prompt = planning_prompt
|
91 |
+
return planning_prompt
|
92 |
+
|
93 |
+
def _build_tools_text(self):
|
94 |
+
return INTERNLM_TOOL
|
95 |
+
|
96 |
+
def _build_tools_name_text(self):
|
97 |
+
return list(INTERNLM_TOOL.keys())
|
98 |
+
|
99 |
+
def build_observation(self, observation):
|
100 |
+
return f'<eoa>\n<|System|>:Response:{observation}\n<TOKENS_UNUSED_2>\n<|Bot|>:'
|
101 |
+
|
102 |
+
def get_stop_words_list(self):
|
103 |
+
return ['<eoa>']
|
benchmark/prompt/llama_react.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from prompt.react import ReAct
|
2 |
+
|
3 |
+
|
4 |
+
class LlamaReAct(ReAct):
|
5 |
+
|
6 |
+
def __init__(self, query, lang='en', upload_file_paths=[]):
|
7 |
+
super().__init__(query, lang, upload_file_paths)
|
8 |
+
|
9 |
+
def build_prompt(self):
|
10 |
+
planning_prompt = super().build_prompt()
|
11 |
+
planning_prompt = '[INST] ' + planning_prompt + ' [/INST]'
|
12 |
+
|
13 |
+
if '<|im_end|>' in self.query:
|
14 |
+
planning_prompt = planning_prompt.replace(
|
15 |
+
'<|im_end|>\n<|im_start|>assistant', ' [/INST] ')
|
16 |
+
assert planning_prompt.endswith(' [/INST]')
|
17 |
+
planning_prompt = planning_prompt[:-len(' [/INST]')]
|
18 |
+
|
19 |
+
self.prompt = planning_prompt
|
20 |
+
return planning_prompt
|
benchmark/prompt/qwen_react.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
|
4 |
+
from prompt.react import ReAct
|
5 |
+
|
6 |
+
QWEN_TOOLS_LIST = [
|
7 |
+
{
|
8 |
+
'name_for_human': '代码解释器',
|
9 |
+
'name_for_model': 'code_interpreter',
|
10 |
+
'description_for_model': '代码解释器,可用于执行Python代码。',
|
11 |
+
'parameters': [{
|
12 |
+
'name': 'code',
|
13 |
+
'type': 'string',
|
14 |
+
'description': '待执行的代码'
|
15 |
+
}],
|
16 |
+
'args_format': 'code'
|
17 |
+
},
|
18 |
+
]
|
19 |
+
|
20 |
+
TOOL_DESC = """{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters}"""
|
21 |
+
|
22 |
+
|
23 |
+
class QwenReAct(ReAct):
|
24 |
+
|
25 |
+
def __init__(self, query, lang='en', upload_file_paths=[]):
|
26 |
+
super().__init__(query, lang, upload_file_paths)
|
27 |
+
|
28 |
+
self.upload_file_paths = [
|
29 |
+
f'{os.path.basename(fname)}' for fname in upload_file_paths
|
30 |
+
]
|
31 |
+
self.list_of_plugin_info = QWEN_TOOLS_LIST
|
32 |
+
self.fname_template = {
|
33 |
+
'zh': '[上传文件{fname_str}]',
|
34 |
+
'en': '[Upload file {fname_str}]',
|
35 |
+
'en_multi': '[Upload file {fname_str}]'
|
36 |
+
}
|
37 |
+
|
38 |
+
def build_prompt(self):
|
39 |
+
im_start = '<|im_start|>'
|
40 |
+
im_end = '<|im_end|>'
|
41 |
+
prompt = f'{im_start}system\nYou are a helpful assistant.{im_end}'
|
42 |
+
|
43 |
+
query = super().build_prompt()
|
44 |
+
|
45 |
+
query = query.lstrip('\n').rstrip()
|
46 |
+
prompt += f'\n{im_start}user\n{query}{im_end}'
|
47 |
+
if f'{im_start}assistant' not in query:
|
48 |
+
prompt += f'\n{im_start}assistant\n{im_end}'
|
49 |
+
assert prompt.endswith(f'\n{im_start}assistant\n{im_end}')
|
50 |
+
|
51 |
+
prompt = prompt[:-len(f'{im_end}')]
|
52 |
+
self.prompt = prompt
|
53 |
+
return prompt
|
54 |
+
|
55 |
+
def _build_tools_text(self):
|
56 |
+
# tool info
|
57 |
+
tools_text = []
|
58 |
+
for plugin_info in self.list_of_plugin_info:
|
59 |
+
tool = TOOL_DESC.format(
|
60 |
+
name_for_model=plugin_info['name_for_model'],
|
61 |
+
name_for_human=plugin_info['name_for_human'],
|
62 |
+
description_for_model=plugin_info['description_for_model'],
|
63 |
+
parameters=json.dumps(plugin_info['parameters'],
|
64 |
+
ensure_ascii=False),
|
65 |
+
)
|
66 |
+
if plugin_info.get('args_format', 'json') == 'json':
|
67 |
+
tool += ' Format the arguments as a JSON object.'
|
68 |
+
elif plugin_info['args_format'] == 'code':
|
69 |
+
tool += ' Enclose the code within triple backticks (`) at the beginning and end of the code.'
|
70 |
+
else:
|
71 |
+
raise NotImplementedError
|
72 |
+
tools_text.append(tool)
|
73 |
+
tools_text = '\n\n'.join(tools_text)
|
74 |
+
return tools_text
|
75 |
+
|
76 |
+
def _build_tools_name_text(self):
|
77 |
+
return ', '.join([
|
78 |
+
plugin_info['name_for_model']
|
79 |
+
for plugin_info in self.list_of_plugin_info
|
80 |
+
])
|
benchmark/prompt/react.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
tools_text = """code_interpreter: Call this tool to interact with the Code Interpreter API.
|
4 |
+
What is the Code Interpreter API useful for?
|
5 |
+
Code Interpreter is used to execute Python code to deal with the following tasks:
|
6 |
+
1. Solving mathematical problems, both quantitative and qualitative
|
7 |
+
2. Doing data analysis and visualization
|
8 |
+
3. Converting files between formats
|
9 |
+
Parameters:
|
10 |
+
```py
|
11 |
+
code
|
12 |
+
```
|
13 |
+
Enclose the code within triple backticks (```) at the beginning and end of the code.
|
14 |
+
"""
|
15 |
+
|
16 |
+
REACT_PROMPT = """Answer the following questions as best you can. You have access to the following tools:
|
17 |
+
|
18 |
+
{tools_text}
|
19 |
+
|
20 |
+
Use the following format:
|
21 |
+
|
22 |
+
Question: the input question you must answer
|
23 |
+
Thought: you should always think about what to do
|
24 |
+
Action: the action to take, should be one of [{tools_name_text}]
|
25 |
+
Action Input: the input to the action
|
26 |
+
Observation: the result of the action
|
27 |
+
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
|
28 |
+
Thought: I now know the final answer
|
29 |
+
Final Answer: the final answer to the original input question
|
30 |
+
|
31 |
+
Begin!
|
32 |
+
|
33 |
+
Question: {query}"""
|
34 |
+
|
35 |
+
fname_template = {
|
36 |
+
'zh': '文件{fname_str},',
|
37 |
+
'en_multi': 'Files {fname_str}. ',
|
38 |
+
'en': 'File {fname_str}. ',
|
39 |
+
}
|
40 |
+
|
41 |
+
|
42 |
+
class ReAct(object):
|
43 |
+
|
44 |
+
def __init__(self, query, lang='en', upload_file_paths=[]):
|
45 |
+
self.query = query
|
46 |
+
self.lang = lang
|
47 |
+
self.upload_file_paths = [
|
48 |
+
f'`{os.path.basename(fname)}`' for fname in upload_file_paths
|
49 |
+
]
|
50 |
+
|
51 |
+
self.fname_template = fname_template
|
52 |
+
self.react_template = REACT_PROMPT
|
53 |
+
self.prompt = ''
|
54 |
+
|
55 |
+
def build_prompt(self):
|
56 |
+
query = self._format_upload_fname() + self.query
|
57 |
+
tools_text = self._build_tools_text()
|
58 |
+
tools_name_text = self._build_tools_name_text()
|
59 |
+
planning_prompt = self.react_template.format(
|
60 |
+
query=query,
|
61 |
+
tools_text=tools_text,
|
62 |
+
tools_name_text=tools_name_text)
|
63 |
+
|
64 |
+
self.prompt = planning_prompt
|
65 |
+
return planning_prompt
|
66 |
+
|
67 |
+
def _format_upload_fname(self):
|
68 |
+
prefix = ''
|
69 |
+
if self.upload_file_paths:
|
70 |
+
fname_str = ', '.join(self.upload_file_paths)
|
71 |
+
lang_key = 'en_multi' if self.lang == 'en' and len(
|
72 |
+
self.upload_file_paths) > 1 else self.lang
|
73 |
+
fname_template = self.fname_template[lang_key]
|
74 |
+
prefix = fname_template.format(fname_str=fname_str)
|
75 |
+
return prefix
|
76 |
+
|
77 |
+
def _build_tools_text(self):
|
78 |
+
return tools_text
|
79 |
+
|
80 |
+
def _build_tools_name_text(self):
|
81 |
+
return 'code_interpreter'
|
82 |
+
|
83 |
+
def build_observation(self, observation):
|
84 |
+
return f'\nObservation: {observation}\nThought:'
|
85 |
+
|
86 |
+
def get_stop_words_list(self):
|
87 |
+
return ['Observation:', 'Observation:\n']
|
benchmark/requirements.txt
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate>=0.20.3
|
2 |
+
func_timeout
|
3 |
+
json5
|
4 |
+
matplotlib
|
5 |
+
numpy
|
6 |
+
pandas
|
7 |
+
PrettyTable
|
8 |
+
scipy
|
9 |
+
seaborn
|
10 |
+
sympy
|
11 |
+
transformers==4.33.1
|
12 |
+
transformers_stream_generator
|
13 |
+
openai
|
benchmark/utils/__init__.py
ADDED
File without changes
|
benchmark/utils/code_utils.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
|
4 |
+
import json5
|
5 |
+
|
6 |
+
|
7 |
+
def replace_upload_fname(text, upload_fname_list):
|
8 |
+
for full_input_fname in upload_fname_list:
|
9 |
+
if full_input_fname not in text and os.path.basename(
|
10 |
+
full_input_fname) in text:
|
11 |
+
text = text.replace(os.path.basename(full_input_fname),
|
12 |
+
full_input_fname)
|
13 |
+
return text
|
14 |
+
|
15 |
+
|
16 |
+
def extract_code(text):
|
17 |
+
# Match triple backtick blocks first
|
18 |
+
triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL)
|
19 |
+
# Match single backtick blocks second
|
20 |
+
single_match = re.search(r'`([^`]*)`', text, re.DOTALL)
|
21 |
+
if triple_match:
|
22 |
+
text = triple_match.group(1)
|
23 |
+
elif single_match:
|
24 |
+
text = single_match.group(1)
|
25 |
+
else:
|
26 |
+
try:
|
27 |
+
text = json5.loads(text)['code']
|
28 |
+
except Exception:
|
29 |
+
pass
|
30 |
+
# If no code blocks found, return original text
|
31 |
+
return text
|
benchmark/utils/data_utils.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
|
4 |
+
from tqdm import tqdm
|
5 |
+
|
6 |
+
|
7 |
+
def load_jsonl(path):
|
8 |
+
data = []
|
9 |
+
with open(path, 'r', encoding='utf8') as f:
|
10 |
+
for idx, line in enumerate(f, start=1):
|
11 |
+
try:
|
12 |
+
data.append(json.loads(line))
|
13 |
+
except Exception as e:
|
14 |
+
logging.info(line)
|
15 |
+
logging.warning(f'Error at line {idx}: {e}')
|
16 |
+
continue
|
17 |
+
return data
|
18 |
+
|
19 |
+
|
20 |
+
def save_jsonl(data, path, progress=False, enabled=True):
|
21 |
+
if not enabled:
|
22 |
+
return
|
23 |
+
with open(path, 'w', encoding='utf-8') as f:
|
24 |
+
if progress:
|
25 |
+
data = tqdm(data)
|
26 |
+
for item in data:
|
27 |
+
line = json.dumps(item, ensure_ascii=False)
|
28 |
+
print(line, file=f)
|
browser_qwen/background.js
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
var database;
|
2 |
+
|
3 |
+
function send_data(msg){
|
4 |
+
chrome.storage.local.get(['database_host'], function(result) {
|
5 |
+
if (result.database_host) {
|
6 |
+
console.log('database_host currently is ' + result.database_host);
|
7 |
+
database = "http://"+result.database_host+":7866/endpoint";
|
8 |
+
} else {
|
9 |
+
database = "http://127.0.0.1:7866/endpoint";
|
10 |
+
}
|
11 |
+
fetch(database, {
|
12 |
+
method: "POST",
|
13 |
+
headers: {
|
14 |
+
"Content-Type": "application/json",
|
15 |
+
},
|
16 |
+
body: JSON.stringify(msg),
|
17 |
+
})
|
18 |
+
.then((response) => response.json())
|
19 |
+
.then((data) => {
|
20 |
+
console.log(data.result);
|
21 |
+
});
|
22 |
+
});
|
23 |
+
}
|
24 |
+
|
25 |
+
chrome.runtime.onMessage.addListener(async (msg, sender) => {
|
26 |
+
if (msg.flag == "open_tab_and_cache_from_content"){
|
27 |
+
var url = "";
|
28 |
+
chrome.tabs.query({active: true, currentWindow: true}, function(tabs) {
|
29 |
+
url = tabs[0].url;
|
30 |
+
console.log(url);
|
31 |
+
if (msg.data) {
|
32 |
+
chrome.storage.sync.get(['data'], function(result) {
|
33 |
+
chrome.storage.sync.set({ data: result.data }, function() {
|
34 |
+
send_data({ 'content' : msg.data, 'query': '', 'url':url, 'task':'cache', 'type':msg.type});
|
35 |
+
});
|
36 |
+
});
|
37 |
+
}
|
38 |
+
});
|
39 |
+
}
|
40 |
+
if (msg.flag == "open_popup_and_send_url_from_popup"){
|
41 |
+
if (msg.data) {
|
42 |
+
chrome.storage.sync.get(['data'], function(result) {
|
43 |
+
chrome.storage.sync.set({ data: result.data }, function() {
|
44 |
+
send_data({ 'url' : msg.data, 'task':'pop_url'});
|
45 |
+
});
|
46 |
+
});
|
47 |
+
}
|
48 |
+
}
|
49 |
+
// if (msg.flag == "set_addr"){
|
50 |
+
// if (msg.data) {
|
51 |
+
// chrome.storage.sync.get(['data'], function(result) {
|
52 |
+
// chrome.storage.sync.set({ data: result.data }, function() {
|
53 |
+
// send_data({ 'addr' : msg.data, 'task':'set_addr'});
|
54 |
+
// });
|
55 |
+
// });
|
56 |
+
// }
|
57 |
+
// }
|
58 |
+
});
|
browser_qwen/img/copy.png
ADDED
browser_qwen/img/logo.png
ADDED
browser_qwen/img/popup.png
ADDED
browser_qwen/manifest.json
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"name": "BrowserQwen",
|
3 |
+
"description" : "An Extension Driven by LLM",
|
4 |
+
"version": "1.0",
|
5 |
+
"manifest_version": 3,
|
6 |
+
|
7 |
+
"background": {
|
8 |
+
"service_worker": "background.js"
|
9 |
+
},
|
10 |
+
|
11 |
+
"action": {
|
12 |
+
"default_popup": "src/popup.html",
|
13 |
+
"default_icon": "img/popup.png",
|
14 |
+
"default_title": "BrowserQwen"
|
15 |
+
},
|
16 |
+
"permissions": [
|
17 |
+
"tabs",
|
18 |
+
"notifications",
|
19 |
+
"storage",
|
20 |
+
"scripting",
|
21 |
+
"activeTab"
|
22 |
+
],
|
23 |
+
"host_permissions": [
|
24 |
+
"http://*/*",
|
25 |
+
"https://*/*"
|
26 |
+
],
|
27 |
+
"icons": {
|
28 |
+
"16": "img/popup.png",
|
29 |
+
"32": "img/popup.png",
|
30 |
+
"48": "img/popup.png",
|
31 |
+
"128": "img/popup.png"
|
32 |
+
},
|
33 |
+
"content_scripts": [
|
34 |
+
{
|
35 |
+
"js": ["src/content.js"],
|
36 |
+
"matches": [
|
37 |
+
"https://www.jianshu.com/p/*",
|
38 |
+
"https://*/*",
|
39 |
+
"http://*/*",
|
40 |
+
"file:///*/*"
|
41 |
+
]
|
42 |
+
}
|
43 |
+
]
|
44 |
+
|
45 |
+
}
|
browser_qwen/src/content.js
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
function getPageTextContent() {
|
3 |
+
var textContent = document.body.textContent;
|
4 |
+
return textContent;
|
5 |
+
}
|
6 |
+
|
7 |
+
function cache_browser(){
|
8 |
+
const body = document.querySelector('html');
|
9 |
+
const text = body.innerHTML;
|
10 |
+
console.log(text);
|
11 |
+
chrome.runtime.sendMessage({ data: text , close: true , flag: 'open_tab_and_cache_from_content', type: 'html'});
|
12 |
+
|
13 |
+
}
|
14 |
+
|
15 |
+
const floatingBox = document.createElement('div');
|
16 |
+
floatingBox.style.position = 'fixed';
|
17 |
+
floatingBox.style.bottom = '650px';
|
18 |
+
floatingBox.style.right = '60px';
|
19 |
+
floatingBox.style.width = '125px';
|
20 |
+
floatingBox.style.height = '55px';
|
21 |
+
floatingBox.style.backgroundColor = '#f2f2f2';
|
22 |
+
floatingBox.style.border = '1px solid black';
|
23 |
+
floatingBox.style.borderRadius = '5px';
|
24 |
+
floatingBox.style.padding = '10px';
|
25 |
+
floatingBox.style.zIndex = '9999';
|
26 |
+
|
27 |
+
const button = document.createElement('button');
|
28 |
+
button.style.position = 'fixed';
|
29 |
+
button.style.top = '30px';
|
30 |
+
button.style.right = '30px';
|
31 |
+
button.style.zIndex = "9999";
|
32 |
+
button.textContent = "Add to Qwen's Reading List";
|
33 |
+
button.style.fontFamily = 'Arial, sans-serif';
|
34 |
+
button.style.fontSize = '14px';
|
35 |
+
button.style.width = '140px';
|
36 |
+
button.style.height = '60px';
|
37 |
+
button.style.backgroundColor = '#695DE8';
|
38 |
+
button.style.color = 'white';
|
39 |
+
button.style.borderRadius = '5px';
|
40 |
+
button.style.border = '0px';
|
41 |
+
button.style.whiteSpace = 'pre-wrap';
|
42 |
+
button.style.boxShadow = '0 4px 6px rgba(0, 0, 0, 0.2)';
|
43 |
+
|
44 |
+
floatingBox.appendChild(button);
|
45 |
+
|
46 |
+
document.body.appendChild(button);
|
47 |
+
|
48 |
+
let isDragging = false;
|
49 |
+
var isMouseReleased = false;
|
50 |
+
let initialX;
|
51 |
+
let initialY;
|
52 |
+
|
53 |
+
button.addEventListener('mousedown', (e) => {
|
54 |
+
isDragging = true;
|
55 |
+
initialX = e.clientX;
|
56 |
+
initialY = e.clientY;
|
57 |
+
});
|
58 |
+
|
59 |
+
document.addEventListener('mousemove', (e) => {
|
60 |
+
if (isDragging) {
|
61 |
+
const dx = e.clientX - initialX;
|
62 |
+
const dy = e.clientY - initialY;
|
63 |
+
button.style.right = `${parseFloat(button.style.right) - dx}px`;
|
64 |
+
button.style.top = `${parseFloat(button.style.top) + dy}px`;
|
65 |
+
initialX = e.clientX;
|
66 |
+
initialY = e.clientY;
|
67 |
+
isMouseReleased = true;
|
68 |
+
}
|
69 |
+
});
|
70 |
+
|
71 |
+
document.addEventListener('mouseup', (e) => {
|
72 |
+
isDragging = false;
|
73 |
+
|
74 |
+
});
|
75 |
+
|
76 |
+
button.addEventListener('click', (e) => {
|
77 |
+
if (isMouseReleased) {
|
78 |
+
isMouseReleased = false;
|
79 |
+
e.stopPropagation();
|
80 |
+
} else {
|
81 |
+
var result = confirm("Are you sure to ask Qwen to remember this page?");
|
82 |
+
if (result) {
|
83 |
+
cache_browser()
|
84 |
+
}
|
85 |
+
}
|
86 |
+
});
|
browser_qwen/src/popup.html
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!DOCTYPE html>
|
2 |
+
<html>
|
3 |
+
<head>
|
4 |
+
<meta charset="UTF-8">
|
5 |
+
<meta http-equiv="X-UA-Compatible" content="IE=edge">
|
6 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
7 |
+
<title>BrowserQwen</title>
|
8 |
+
<!-- <script src="popup.js"></script>-->
|
9 |
+
<style>
|
10 |
+
.title-style {
|
11 |
+
margin-top: 10px;
|
12 |
+
margin-left: 5px;
|
13 |
+
font-family: Arial, sans-serif;
|
14 |
+
font-size: 24px;
|
15 |
+
color: #333;
|
16 |
+
}
|
17 |
+
|
18 |
+
body {
|
19 |
+
width: 500px;
|
20 |
+
height: 600px;
|
21 |
+
background-color: aliceblue;
|
22 |
+
}
|
23 |
+
.contnet {
|
24 |
+
display: flex;
|
25 |
+
flex-direction: column;
|
26 |
+
flex-wrap: nowrap;
|
27 |
+
align-items: center;
|
28 |
+
}
|
29 |
+
.upload_file {
|
30 |
+
display: flex;
|
31 |
+
flex-direction: column;
|
32 |
+
margin-top: 10px;
|
33 |
+
}
|
34 |
+
.upload_btn {
|
35 |
+
|
36 |
+
border: none;
|
37 |
+
border-radius: 5px;
|
38 |
+
background-color: #5c5cdf;
|
39 |
+
font-size: 16px;
|
40 |
+
color: white;
|
41 |
+
font-weight: 400;
|
42 |
+
cursor: pointer;
|
43 |
+
width: 70px;
|
44 |
+
height: 35px;
|
45 |
+
}
|
46 |
+
.upload_btn:hover {
|
47 |
+
border: none;
|
48 |
+
border-radius: 10px;
|
49 |
+
background-color: #7f7ff2;
|
50 |
+
font-size: 16px;
|
51 |
+
color: white;
|
52 |
+
font-weight: 400;
|
53 |
+
cursor: pointer;
|
54 |
+
box-shadow: 0px 0px 0px 1px #848181;
|
55 |
+
}
|
56 |
+
|
57 |
+
.div-with-copy-paste {
|
58 |
+
width: 340px;
|
59 |
+
height: 200px;
|
60 |
+
background-color: #f2f2f2;
|
61 |
+
border: 1px solid #ccc;
|
62 |
+
border-radius: 5px;
|
63 |
+
padding: 20px;
|
64 |
+
box-shadow: 0 2px 5px rgba(0, 0, 0, 0.3);
|
65 |
+
position: relative;
|
66 |
+
}
|
67 |
+
|
68 |
+
.input-text {
|
69 |
+
width: 300px;
|
70 |
+
/* height: 100px; */
|
71 |
+
background-color: #fffdfd;
|
72 |
+
border: 1px solid #ccc;
|
73 |
+
border-radius: 10px;
|
74 |
+
padding: 20px;
|
75 |
+
box-shadow: 0 2px 5px rgba(0, 0, 0, 0.3);
|
76 |
+
position: left;
|
77 |
+
margin-top: 5px;
|
78 |
+
margin-bottom: 5px;
|
79 |
+
}
|
80 |
+
|
81 |
+
.copy-button {
|
82 |
+
/* position: absolute;
|
83 |
+
top: 10px;
|
84 |
+
right: 10px; */
|
85 |
+
background-color: #5085cf;
|
86 |
+
color: white;
|
87 |
+
padding: 8px 8px;
|
88 |
+
border: none;
|
89 |
+
border-radius: 3px;
|
90 |
+
cursor: pointer;
|
91 |
+
}
|
92 |
+
|
93 |
+
.copy-button:hover {
|
94 |
+
background-color: #45a049;
|
95 |
+
}
|
96 |
+
|
97 |
+
::placeholder {
|
98 |
+
color: rgb(197, 196, 196);
|
99 |
+
}
|
100 |
+
|
101 |
+
iframe {
|
102 |
+
width: 100%;
|
103 |
+
height: 100%;
|
104 |
+
border: none;
|
105 |
+
}
|
106 |
+
</style>
|
107 |
+
|
108 |
+
</head>
|
109 |
+
|
110 |
+
|
111 |
+
<body>
|
112 |
+
<!-- <iframe src=$popup_url style="height: 550px"></iframe>-->
|
113 |
+
<div id="iframe_area" style="height: 570px"></div>
|
114 |
+
|
115 |
+
<h3>Customize Address:</h3>
|
116 |
+
<input type="text" id="addr" name="addr" class="input-text">
|
117 |
+
<button id="set_addr" class="upload_btn">Change</button>
|
118 |
+
|
119 |
+
<script src="popup.js"></script>
|
120 |
+
</body>
|
121 |
+
</html>
|
browser_qwen/src/popup.js
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
chrome.runtime.onMessage.addListener((msg, sender, sendResponse) => {
|
4 |
+
// if (msg.flag == 'from_content'){
|
5 |
+
// console.log(msg.rsp);
|
6 |
+
// var sessionContainer = document.getElementById('session');
|
7 |
+
// sessionContainer.innerText = msg.rsp;
|
8 |
+
// sendResponse({ msg: 'Get!' });
|
9 |
+
// }
|
10 |
+
if (msg.flag === 'from_llm'){
|
11 |
+
// var sessionContainer = document.getElementById('session');
|
12 |
+
// // sessionContainer.innerHTML = msg.rsp;
|
13 |
+
// sessionContainer.innerText = msg.rsp;
|
14 |
+
sendResponse({ message: 'Get Response!' });
|
15 |
+
}
|
16 |
+
});
|
17 |
+
|
18 |
+
|
19 |
+
document.addEventListener('DOMContentLoaded', function() {
|
20 |
+
chrome.tabs.query({ active: true, currentWindow: true }, function(tabs) {
|
21 |
+
var currentUrl = tabs[0].url;
|
22 |
+
|
23 |
+
chrome.runtime.sendMessage({ data: currentUrl , close: true , flag: 'open_popup_and_send_url_from_popup'});
|
24 |
+
|
25 |
+
});
|
26 |
+
setTimeout(function() {
|
27 |
+
// console.log('This message will be logged after 0.5 second');
|
28 |
+
var popup_url='';
|
29 |
+
chrome.storage.local.get(['database_host'], function(result) {
|
30 |
+
if (result.database_host) {
|
31 |
+
console.log('database_host currently is ' + result.database_host);
|
32 |
+
popup_url = "http://"+result.database_host+":7863/";
|
33 |
+
} else {
|
34 |
+
popup_url = "http://127.0.0.1:7863/";
|
35 |
+
}
|
36 |
+
var iframe = document.createElement('iframe');
|
37 |
+
iframe.src = popup_url;
|
38 |
+
iframe.height = '570px';
|
39 |
+
// iframe.sandbox = 'allow-same-origin allow-scripts';
|
40 |
+
// iframe.allow = "geolocation *;";
|
41 |
+
var iframe_area = document.getElementById('iframe_area')
|
42 |
+
iframe_area.appendChild(iframe);
|
43 |
+
|
44 |
+
});
|
45 |
+
}, 500);
|
46 |
+
|
47 |
+
// fetch('../config_host.json')
|
48 |
+
// .then(response => response.json())
|
49 |
+
// .then(data => {
|
50 |
+
// console.log(data);
|
51 |
+
// popup_url = "http://"+data.database_host+":"+data.app_in_browser_port+"/";
|
52 |
+
// console.log(popup_url);
|
53 |
+
// })
|
54 |
+
// .catch(error => console.error('Error:', error));
|
55 |
+
})
|
56 |
+
|
57 |
+
document.getElementById('set_addr').addEventListener('click', function() {
|
58 |
+
var addr = document.getElementById('addr').value;
|
59 |
+
// save config
|
60 |
+
chrome.storage.local.set({database_host: addr}, function() {
|
61 |
+
console.log('database_host is set to ' + addr);
|
62 |
+
// chrome.runtime.sendMessage({ data: addr , close: true , flag: 'set_addr'});
|
63 |
+
document.getElementById('addr').value = '';
|
64 |
+
});
|
65 |
+
})
|
openai_api.py
ADDED
@@ -0,0 +1,564 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Implements API for Qwen-7B in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
|
3 |
+
# Usage: python openai_api.py
|
4 |
+
# Visit http://localhost:8000/docs for documents.
|
5 |
+
|
6 |
+
import re
|
7 |
+
import copy
|
8 |
+
import json
|
9 |
+
import time
|
10 |
+
from argparse import ArgumentParser
|
11 |
+
from contextlib import asynccontextmanager
|
12 |
+
from typing import Dict, List, Literal, Optional, Union
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import uvicorn
|
16 |
+
from fastapi import FastAPI, HTTPException
|
17 |
+
from fastapi.middleware.cors import CORSMiddleware
|
18 |
+
from pydantic import BaseModel, Field
|
19 |
+
from sse_starlette.sse import EventSourceResponse
|
20 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
21 |
+
from transformers.generation import GenerationConfig
|
22 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
23 |
+
from starlette.requests import Request
|
24 |
+
from starlette.responses import Response
|
25 |
+
import base64
|
26 |
+
|
27 |
+
|
28 |
+
class BasicAuthMiddleware(BaseHTTPMiddleware):
|
29 |
+
def __init__(self, app, username: str, password: str):
|
30 |
+
super().__init__(app)
|
31 |
+
self.required_credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
|
32 |
+
|
33 |
+
async def dispatch(self, request: Request, call_next):
|
34 |
+
authorization: str = request.headers.get("Authorization")
|
35 |
+
if authorization:
|
36 |
+
try:
|
37 |
+
schema, credentials = authorization.split()
|
38 |
+
if credentials == self.required_credentials:
|
39 |
+
return await call_next(request)
|
40 |
+
except ValueError:
|
41 |
+
pass
|
42 |
+
|
43 |
+
headers = {'WWW-Authenticate': 'Basic'}
|
44 |
+
return Response(status_code=401, headers=headers)
|
45 |
+
|
46 |
+
|
47 |
+
def _gc(forced: bool = False):
|
48 |
+
global args
|
49 |
+
if args.disable_gc and not forced:
|
50 |
+
return
|
51 |
+
|
52 |
+
import gc
|
53 |
+
gc.collect()
|
54 |
+
if torch.cuda.is_available():
|
55 |
+
torch.cuda.empty_cache()
|
56 |
+
|
57 |
+
|
58 |
+
@asynccontextmanager
|
59 |
+
async def lifespan(app: FastAPI): # collects GPU memory
|
60 |
+
yield
|
61 |
+
_gc(forced=True)
|
62 |
+
|
63 |
+
|
64 |
+
app = FastAPI(lifespan=lifespan)
|
65 |
+
|
66 |
+
app.add_middleware(
|
67 |
+
CORSMiddleware,
|
68 |
+
allow_origins=["*"],
|
69 |
+
allow_credentials=True,
|
70 |
+
allow_methods=["*"],
|
71 |
+
allow_headers=["*"],
|
72 |
+
)
|
73 |
+
|
74 |
+
|
75 |
+
class ModelCard(BaseModel):
|
76 |
+
id: str
|
77 |
+
object: str = "model"
|
78 |
+
created: int = Field(default_factory=lambda: int(time.time()))
|
79 |
+
owned_by: str = "owner"
|
80 |
+
root: Optional[str] = None
|
81 |
+
parent: Optional[str] = None
|
82 |
+
permission: Optional[list] = None
|
83 |
+
|
84 |
+
|
85 |
+
class ModelList(BaseModel):
|
86 |
+
object: str = "list"
|
87 |
+
data: List[ModelCard] = []
|
88 |
+
|
89 |
+
|
90 |
+
class ChatMessage(BaseModel):
|
91 |
+
role: Literal["user", "assistant", "system", "function"]
|
92 |
+
content: Optional[str]
|
93 |
+
function_call: Optional[Dict] = None
|
94 |
+
|
95 |
+
|
96 |
+
class DeltaMessage(BaseModel):
|
97 |
+
role: Optional[Literal["user", "assistant", "system"]] = None
|
98 |
+
content: Optional[str] = None
|
99 |
+
|
100 |
+
|
101 |
+
class ChatCompletionRequest(BaseModel):
|
102 |
+
model: str
|
103 |
+
messages: List[ChatMessage]
|
104 |
+
functions: Optional[List[Dict]] = None
|
105 |
+
temperature: Optional[float] = None
|
106 |
+
top_p: Optional[float] = None
|
107 |
+
max_length: Optional[int] = None
|
108 |
+
stream: Optional[bool] = False
|
109 |
+
stop: Optional[List[str]] = None
|
110 |
+
|
111 |
+
|
112 |
+
class ChatCompletionResponseChoice(BaseModel):
|
113 |
+
index: int
|
114 |
+
message: ChatMessage
|
115 |
+
finish_reason: Literal["stop", "length", "function_call"]
|
116 |
+
|
117 |
+
|
118 |
+
class ChatCompletionResponseStreamChoice(BaseModel):
|
119 |
+
index: int
|
120 |
+
delta: DeltaMessage
|
121 |
+
finish_reason: Optional[Literal["stop", "length"]]
|
122 |
+
|
123 |
+
|
124 |
+
class ChatCompletionResponse(BaseModel):
|
125 |
+
model: str
|
126 |
+
object: Literal["chat.completion", "chat.completion.chunk"]
|
127 |
+
choices: List[
|
128 |
+
Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]
|
129 |
+
]
|
130 |
+
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
|
131 |
+
|
132 |
+
|
133 |
+
@app.get("/v1/models", response_model=ModelList)
|
134 |
+
async def list_models():
|
135 |
+
global model_args
|
136 |
+
model_card = ModelCard(id="gpt-3.5-turbo")
|
137 |
+
return ModelList(data=[model_card])
|
138 |
+
|
139 |
+
|
140 |
+
# To work around that unpleasant leading-\n tokenization issue!
|
141 |
+
def add_extra_stop_words(stop_words):
|
142 |
+
if stop_words:
|
143 |
+
_stop_words = []
|
144 |
+
_stop_words.extend(stop_words)
|
145 |
+
for x in stop_words:
|
146 |
+
s = x.lstrip("\n")
|
147 |
+
if s and (s not in _stop_words):
|
148 |
+
_stop_words.append(s)
|
149 |
+
return _stop_words
|
150 |
+
return stop_words
|
151 |
+
|
152 |
+
|
153 |
+
def trim_stop_words(response, stop_words):
|
154 |
+
if stop_words:
|
155 |
+
for stop in stop_words:
|
156 |
+
idx = response.find(stop)
|
157 |
+
if idx != -1:
|
158 |
+
response = response[:idx]
|
159 |
+
return response
|
160 |
+
|
161 |
+
|
162 |
+
TOOL_DESC = """{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters}"""
|
163 |
+
|
164 |
+
REACT_INSTRUCTION = """Answer the following questions as best you can. You have access to the following APIs:
|
165 |
+
|
166 |
+
{tools_text}
|
167 |
+
|
168 |
+
Use the following format:
|
169 |
+
|
170 |
+
Question: the input question you must answer
|
171 |
+
Thought: you should always think about what to do
|
172 |
+
Action: the action to take, should be one of [{tools_name_text}]
|
173 |
+
Action Input: the input to the action
|
174 |
+
Observation: the result of the action
|
175 |
+
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
|
176 |
+
Thought: I now know the final answer
|
177 |
+
Final Answer: the final answer to the original input question
|
178 |
+
|
179 |
+
Begin!"""
|
180 |
+
|
181 |
+
_TEXT_COMPLETION_CMD = object()
|
182 |
+
|
183 |
+
|
184 |
+
#
|
185 |
+
# Temporarily, the system role does not work as expected.
|
186 |
+
# We advise that you write the setups for role-play in your query,
|
187 |
+
# i.e., use the user role instead of the system role.
|
188 |
+
#
|
189 |
+
# TODO: Use real system role when the model is ready.
|
190 |
+
#
|
191 |
+
def parse_messages(messages, functions):
|
192 |
+
if all(m.role != "user" for m in messages):
|
193 |
+
raise HTTPException(
|
194 |
+
status_code=400,
|
195 |
+
detail=f"Invalid request: Expecting at least one user message.",
|
196 |
+
)
|
197 |
+
|
198 |
+
messages = copy.deepcopy(messages)
|
199 |
+
default_system = "You are a helpful assistant."
|
200 |
+
system = ""
|
201 |
+
if messages[0].role == "system":
|
202 |
+
system = messages.pop(0).content.lstrip("\n").rstrip()
|
203 |
+
if system == default_system:
|
204 |
+
system = ""
|
205 |
+
|
206 |
+
if functions:
|
207 |
+
tools_text = []
|
208 |
+
tools_name_text = []
|
209 |
+
for func_info in functions:
|
210 |
+
name = func_info.get("name", "")
|
211 |
+
name_m = func_info.get("name_for_model", name)
|
212 |
+
name_h = func_info.get("name_for_human", name)
|
213 |
+
desc = func_info.get("description", "")
|
214 |
+
desc_m = func_info.get("description_for_model", desc)
|
215 |
+
tool = TOOL_DESC.format(
|
216 |
+
name_for_model=name_m,
|
217 |
+
name_for_human=name_h,
|
218 |
+
# Hint: You can add the following format requirements in description:
|
219 |
+
# "Format the arguments as a JSON object."
|
220 |
+
# "Enclose the code within triple backticks (`) at the beginning and end of the code."
|
221 |
+
description_for_model=desc_m,
|
222 |
+
parameters=json.dumps(func_info["parameters"], ensure_ascii=False),
|
223 |
+
)
|
224 |
+
tools_text.append(tool)
|
225 |
+
tools_name_text.append(name_m)
|
226 |
+
tools_text = "\n\n".join(tools_text)
|
227 |
+
tools_name_text = ", ".join(tools_name_text)
|
228 |
+
system += "\n\n" + REACT_INSTRUCTION.format(
|
229 |
+
tools_text=tools_text,
|
230 |
+
tools_name_text=tools_name_text,
|
231 |
+
)
|
232 |
+
system = system.lstrip("\n").rstrip()
|
233 |
+
|
234 |
+
dummy_thought = {
|
235 |
+
"en": "\nThought: I now know the final answer.\nFinal answer: ",
|
236 |
+
"zh": "\nThought: 我会作答了。\nFinal answer: ",
|
237 |
+
}
|
238 |
+
|
239 |
+
_messages = messages
|
240 |
+
messages = []
|
241 |
+
for m_idx, m in enumerate(_messages):
|
242 |
+
role, content, func_call = m.role, m.content, m.function_call
|
243 |
+
if content:
|
244 |
+
content = content.lstrip("\n").rstrip()
|
245 |
+
if role == "function":
|
246 |
+
if (len(messages) == 0) or (messages[-1].role != "assistant"):
|
247 |
+
raise HTTPException(
|
248 |
+
status_code=400,
|
249 |
+
detail=f"Invalid request: Expecting role assistant before role function.",
|
250 |
+
)
|
251 |
+
messages[-1].content += f"\nObservation: {content}"
|
252 |
+
if m_idx == len(_messages) - 1:
|
253 |
+
messages[-1].content += "\nThought:"
|
254 |
+
elif role == "assistant":
|
255 |
+
if len(messages) == 0:
|
256 |
+
raise HTTPException(
|
257 |
+
status_code=400,
|
258 |
+
detail=f"Invalid request: Expecting role user before role assistant.",
|
259 |
+
)
|
260 |
+
last_msg = messages[-1].content
|
261 |
+
last_msg_has_zh = len(re.findall(r"[\u4e00-\u9fff]+", last_msg)) > 0
|
262 |
+
if func_call is None:
|
263 |
+
if functions:
|
264 |
+
content = dummy_thought["zh" if last_msg_has_zh else "en"] + content
|
265 |
+
else:
|
266 |
+
f_name, f_args = func_call["name"], func_call["arguments"]
|
267 |
+
if not content:
|
268 |
+
if last_msg_has_zh:
|
269 |
+
content = f"Thought: 我可以使用 {f_name} API。"
|
270 |
+
else:
|
271 |
+
content = f"Thought: I can use {f_name}."
|
272 |
+
content = f"\n{content}\nAction: {f_name}\nAction Input: {f_args}"
|
273 |
+
if messages[-1].role == "user":
|
274 |
+
messages.append(
|
275 |
+
ChatMessage(role="assistant", content=content.lstrip("\n").rstrip())
|
276 |
+
)
|
277 |
+
else:
|
278 |
+
messages[-1].content += content
|
279 |
+
elif role == "user":
|
280 |
+
messages.append(
|
281 |
+
ChatMessage(role="user", content=content.lstrip("\n").rstrip())
|
282 |
+
)
|
283 |
+
else:
|
284 |
+
raise HTTPException(
|
285 |
+
status_code=400, detail=f"Invalid request: Incorrect role {role}."
|
286 |
+
)
|
287 |
+
|
288 |
+
query = _TEXT_COMPLETION_CMD
|
289 |
+
if messages[-1].role == "user":
|
290 |
+
query = messages[-1].content
|
291 |
+
messages = messages[:-1]
|
292 |
+
|
293 |
+
if len(messages) % 2 != 0:
|
294 |
+
raise HTTPException(status_code=400, detail="Invalid request")
|
295 |
+
|
296 |
+
history = [] # [(Q1, A1), (Q2, A2), ..., (Q_last_turn, A_last_turn)]
|
297 |
+
for i in range(0, len(messages), 2):
|
298 |
+
if messages[i].role == "user" and messages[i + 1].role == "assistant":
|
299 |
+
usr_msg = messages[i].content.lstrip("\n").rstrip()
|
300 |
+
bot_msg = messages[i + 1].content.lstrip("\n").rstrip()
|
301 |
+
if system and (i == len(messages) - 2):
|
302 |
+
usr_msg = f"{system}\n\nQuestion: {usr_msg}"
|
303 |
+
system = ""
|
304 |
+
for t in dummy_thought.values():
|
305 |
+
t = t.lstrip("\n")
|
306 |
+
if bot_msg.startswith(t) and ("\nAction: " in bot_msg):
|
307 |
+
bot_msg = bot_msg[len(t):]
|
308 |
+
history.append([usr_msg, bot_msg])
|
309 |
+
else:
|
310 |
+
raise HTTPException(
|
311 |
+
status_code=400,
|
312 |
+
detail="Invalid request: Expecting exactly one user (or function) role before every assistant role.",
|
313 |
+
)
|
314 |
+
if system:
|
315 |
+
assert query is not _TEXT_COMPLETION_CMD
|
316 |
+
query = f"{system}\n\nQuestion: {query}"
|
317 |
+
return query, history
|
318 |
+
|
319 |
+
|
320 |
+
def parse_response(response):
|
321 |
+
func_name, func_args = "", ""
|
322 |
+
i = response.rfind("\nAction:")
|
323 |
+
j = response.rfind("\nAction Input:")
|
324 |
+
k = response.rfind("\nObservation:")
|
325 |
+
if 0 <= i < j: # If the text has `Action` and `Action input`,
|
326 |
+
if k < j: # but does not contain `Observation`,
|
327 |
+
# then it is likely that `Observation` is omitted by the LLM,
|
328 |
+
# because the output text may have discarded the stop word.
|
329 |
+
response = response.rstrip() + "\nObservation:" # Add it back.
|
330 |
+
k = response.rfind("\nObservation:")
|
331 |
+
func_name = response[i + len("\nAction:"): j].strip()
|
332 |
+
func_args = response[j + len("\nAction Input:"): k].strip()
|
333 |
+
if func_name:
|
334 |
+
choice_data = ChatCompletionResponseChoice(
|
335 |
+
index=0,
|
336 |
+
message=ChatMessage(
|
337 |
+
role="assistant",
|
338 |
+
content=response[:i],
|
339 |
+
function_call={"name": func_name, "arguments": func_args},
|
340 |
+
),
|
341 |
+
finish_reason="function_call",
|
342 |
+
)
|
343 |
+
return choice_data
|
344 |
+
z = response.rfind("\nFinal Answer: ")
|
345 |
+
if z >= 0:
|
346 |
+
response = response[z + len("\nFinal Answer: "):]
|
347 |
+
choice_data = ChatCompletionResponseChoice(
|
348 |
+
index=0,
|
349 |
+
message=ChatMessage(role="assistant", content=response),
|
350 |
+
finish_reason="stop",
|
351 |
+
)
|
352 |
+
return choice_data
|
353 |
+
|
354 |
+
|
355 |
+
# completion mode, not chat mode
|
356 |
+
def text_complete_last_message(history, stop_words_ids, gen_kwargs):
|
357 |
+
im_start = "<|im_start|>"
|
358 |
+
im_end = "<|im_end|>"
|
359 |
+
prompt = f"{im_start}system\nYou are a helpful assistant.{im_end}"
|
360 |
+
for i, (query, response) in enumerate(history):
|
361 |
+
query = query.lstrip("\n").rstrip()
|
362 |
+
response = response.lstrip("\n").rstrip()
|
363 |
+
prompt += f"\n{im_start}user\n{query}{im_end}"
|
364 |
+
prompt += f"\n{im_start}assistant\n{response}{im_end}"
|
365 |
+
prompt = prompt[: -len(im_end)]
|
366 |
+
|
367 |
+
_stop_words_ids = [tokenizer.encode(im_end)]
|
368 |
+
if stop_words_ids:
|
369 |
+
for s in stop_words_ids:
|
370 |
+
_stop_words_ids.append(s)
|
371 |
+
stop_words_ids = _stop_words_ids
|
372 |
+
|
373 |
+
input_ids = torch.tensor([tokenizer.encode(prompt)]).to(model.device)
|
374 |
+
output = model.generate(input_ids, stop_words_ids=stop_words_ids, **gen_kwargs).tolist()[0]
|
375 |
+
output = tokenizer.decode(output, errors="ignore")
|
376 |
+
assert output.startswith(prompt)
|
377 |
+
output = output[len(prompt):]
|
378 |
+
output = trim_stop_words(output, ["<|endoftext|>", im_end])
|
379 |
+
print(f"<completion>\n{prompt}\n<!-- *** -->\n{output}\n</completion>")
|
380 |
+
return output
|
381 |
+
|
382 |
+
|
383 |
+
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
|
384 |
+
async def create_chat_completion(request: ChatCompletionRequest):
|
385 |
+
global model, tokenizer
|
386 |
+
|
387 |
+
gen_kwargs = {}
|
388 |
+
if request.temperature is not None:
|
389 |
+
if request.temperature < 0.01:
|
390 |
+
gen_kwargs['top_k'] = 1 # greedy decoding
|
391 |
+
else:
|
392 |
+
# Not recommended. Please tune top_p instead.
|
393 |
+
gen_kwargs['temperature'] = request.temperature
|
394 |
+
if request.top_p is not None:
|
395 |
+
gen_kwargs['top_p'] = request.top_p
|
396 |
+
|
397 |
+
stop_words = add_extra_stop_words(request.stop)
|
398 |
+
if request.functions:
|
399 |
+
stop_words = stop_words or []
|
400 |
+
if "Observation:" not in stop_words:
|
401 |
+
stop_words.append("Observation:")
|
402 |
+
|
403 |
+
query, history = parse_messages(request.messages, request.functions)
|
404 |
+
|
405 |
+
if request.stream:
|
406 |
+
if request.functions:
|
407 |
+
raise HTTPException(
|
408 |
+
status_code=400,
|
409 |
+
detail="Invalid request: Function calling is not yet implemented for stream mode.",
|
410 |
+
)
|
411 |
+
generate = predict(query, history, request.model, stop_words, gen_kwargs)
|
412 |
+
return generate
|
413 |
+
# return EventSourceResponse(generate, media_type="text/event-stream")
|
414 |
+
|
415 |
+
stop_words_ids = [tokenizer.encode(s) for s in stop_words] if stop_words else None
|
416 |
+
if query is _TEXT_COMPLETION_CMD:
|
417 |
+
response = text_complete_last_message(history, stop_words_ids=stop_words_ids, gen_kwargs=gen_kwargs)
|
418 |
+
else:
|
419 |
+
response, _ = model.chat(
|
420 |
+
tokenizer,
|
421 |
+
query,
|
422 |
+
history=history,
|
423 |
+
stop_words_ids=stop_words_ids,
|
424 |
+
**gen_kwargs
|
425 |
+
)
|
426 |
+
print(f"<chat>\n{history}\n{query}\n<!-- *** -->\n{response}\n</chat>")
|
427 |
+
_gc()
|
428 |
+
|
429 |
+
response = trim_stop_words(response, stop_words)
|
430 |
+
if request.functions:
|
431 |
+
choice_data = parse_response(response)
|
432 |
+
else:
|
433 |
+
choice_data = ChatCompletionResponseChoice(
|
434 |
+
index=0,
|
435 |
+
message=ChatMessage(role="assistant", content=response),
|
436 |
+
finish_reason="stop",
|
437 |
+
)
|
438 |
+
return ChatCompletionResponse(
|
439 |
+
model=request.model, choices=[choice_data], object="chat.completion"
|
440 |
+
)
|
441 |
+
|
442 |
+
|
443 |
+
def _dump_json(data: BaseModel, *args, **kwargs) -> str:
|
444 |
+
try:
|
445 |
+
return data.model_dump_json(*args, **kwargs)
|
446 |
+
except AttributeError: # pydantic<2.0.0
|
447 |
+
return data.json(*args, **kwargs) # noqa
|
448 |
+
|
449 |
+
|
450 |
+
async def predict(
|
451 |
+
query: str, history: List[List[str]], model_id: str, stop_words: List[str], gen_kwargs: Dict,
|
452 |
+
):
|
453 |
+
global model, tokenizer
|
454 |
+
choice_data = ChatCompletionResponseStreamChoice(
|
455 |
+
index=0, delta=DeltaMessage(role="assistant"), finish_reason=None
|
456 |
+
)
|
457 |
+
chunk = ChatCompletionResponse(
|
458 |
+
model=model_id, choices=[choice_data], object="chat.completion.chunk"
|
459 |
+
)
|
460 |
+
yield "{}".format(_dump_json(chunk, exclude_unset=True))
|
461 |
+
|
462 |
+
current_length = 0
|
463 |
+
stop_words_ids = [tokenizer.encode(s) for s in stop_words] if stop_words else None
|
464 |
+
if stop_words:
|
465 |
+
# TODO: It's a little bit tricky to trim stop words in the stream mode.
|
466 |
+
raise HTTPException(
|
467 |
+
status_code=400,
|
468 |
+
detail="Invalid request: custom stop words are not yet supported for stream mode.",
|
469 |
+
)
|
470 |
+
response_generator = model.chat_stream(
|
471 |
+
tokenizer, query, history=history, stop_words_ids=stop_words_ids, **gen_kwargs
|
472 |
+
)
|
473 |
+
for new_response in response_generator:
|
474 |
+
if len(new_response) == current_length:
|
475 |
+
continue
|
476 |
+
|
477 |
+
new_text = new_response[current_length:]
|
478 |
+
current_length = len(new_response)
|
479 |
+
|
480 |
+
choice_data = ChatCompletionResponseStreamChoice(
|
481 |
+
index=0, delta=DeltaMessage(content=new_text), finish_reason=None
|
482 |
+
)
|
483 |
+
chunk = ChatCompletionResponse(
|
484 |
+
model=model_id, choices=[choice_data], object="chat.completion.chunk"
|
485 |
+
)
|
486 |
+
yield "{}".format(_dump_json(chunk, exclude_unset=True))
|
487 |
+
|
488 |
+
choice_data = ChatCompletionResponseStreamChoice(
|
489 |
+
index=0, delta=DeltaMessage(), finish_reason="stop"
|
490 |
+
)
|
491 |
+
chunk = ChatCompletionResponse(
|
492 |
+
model=model_id, choices=[choice_data], object="chat.completion.chunk"
|
493 |
+
)
|
494 |
+
yield "{}".format(_dump_json(chunk, exclude_unset=True))
|
495 |
+
yield "[DONE]"
|
496 |
+
|
497 |
+
_gc()
|
498 |
+
|
499 |
+
|
500 |
+
def _get_args():
|
501 |
+
parser = ArgumentParser()
|
502 |
+
parser.add_argument(
|
503 |
+
"-c",
|
504 |
+
"--checkpoint-path",
|
505 |
+
type=str,
|
506 |
+
default="Qwen/Qwen-7B-Chat",
|
507 |
+
help="Checkpoint name or path, default to %(default)r",
|
508 |
+
)
|
509 |
+
parser.add_argument(
|
510 |
+
"--api-auth", help="API authentication credentials"
|
511 |
+
)
|
512 |
+
parser.add_argument(
|
513 |
+
"--cpu-only", action="store_true", help="Run demo with CPU only"
|
514 |
+
)
|
515 |
+
parser.add_argument(
|
516 |
+
"--server-port", type=int, default=8000, help="Demo server port."
|
517 |
+
)
|
518 |
+
parser.add_argument(
|
519 |
+
"--server-name",
|
520 |
+
type=str,
|
521 |
+
default="127.0.0.1",
|
522 |
+
help="Demo server name. Default: 127.0.0.1, which is only visible from the local computer."
|
523 |
+
" If you want other computers to access your server, use 0.0.0.0 instead.",
|
524 |
+
)
|
525 |
+
parser.add_argument("--disable-gc", action="store_true",
|
526 |
+
help="Disable GC after each response generated.")
|
527 |
+
|
528 |
+
args = parser.parse_args()
|
529 |
+
return args
|
530 |
+
|
531 |
+
|
532 |
+
if __name__ == "__main__":
|
533 |
+
args = _get_args()
|
534 |
+
|
535 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
536 |
+
args.checkpoint_path,
|
537 |
+
trust_remote_code=True,
|
538 |
+
resume_download=True,
|
539 |
+
)
|
540 |
+
|
541 |
+
if args.api_auth:
|
542 |
+
app.add_middleware(
|
543 |
+
BasicAuthMiddleware, username=args.api_auth.split(":")[0], password=args.api_auth.split(":")[1]
|
544 |
+
)
|
545 |
+
|
546 |
+
if args.cpu_only:
|
547 |
+
device_map = "cpu"
|
548 |
+
else:
|
549 |
+
device_map = "auto"
|
550 |
+
|
551 |
+
model = AutoModelForCausalLM.from_pretrained(
|
552 |
+
args.checkpoint_path,
|
553 |
+
device_map=device_map,
|
554 |
+
trust_remote_code=True,
|
555 |
+
resume_download=True,
|
556 |
+
).eval()
|
557 |
+
|
558 |
+
model.generation_config = GenerationConfig.from_pretrained(
|
559 |
+
args.checkpoint_path,
|
560 |
+
trust_remote_code=True,
|
561 |
+
resume_download=True,
|
562 |
+
)
|
563 |
+
|
564 |
+
uvicorn.run(app, host=args.server_name, port=args.server_port, workers=1)
|
qwen_agent/__init__.py
ADDED
File without changes
|
qwen_agent/actions/__init__.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .continue_writing import ContinueWriting
|
2 |
+
from .expand_writing import ExpandWriting
|
3 |
+
from .gen_keyword import GenKeyword
|
4 |
+
from .outline_writing import OutlineWriting
|
5 |
+
from .react import ReAct
|
6 |
+
from .retrieval_qa import RetrievalQA
|
7 |
+
from .summarize import Summarize
|
8 |
+
from .write_from_scratch import WriteFromScratch
|
9 |
+
|
10 |
+
__all__ = [
|
11 |
+
'RetrievalQA', 'ContinueWriting', 'OutlineWriting', 'ExpandWriting',
|
12 |
+
'ReAct', 'WriteFromScratch', 'Summarize', 'GenKeyword'
|
13 |
+
]
|
qwen_agent/actions/base.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
from typing import Dict, Iterator, List, Optional, Union
|
3 |
+
|
4 |
+
from qwen_agent.llm.base import BaseChatModel
|
5 |
+
from qwen_agent.utils.utils import has_chinese_chars
|
6 |
+
|
7 |
+
# TODO: Should *planning* just be another action that uses other actions?
|
8 |
+
|
9 |
+
|
10 |
+
class Action(ABC):
|
11 |
+
|
12 |
+
def __init__(self, llm: BaseChatModel = None, stream: bool = False):
|
13 |
+
self.llm = llm
|
14 |
+
self.stream = stream
|
15 |
+
|
16 |
+
def run(self, *args, **kwargs) -> Union[str, Iterator[str]]:
|
17 |
+
if 'lang' not in kwargs:
|
18 |
+
if has_chinese_chars([args, kwargs]):
|
19 |
+
kwargs['lang'] = 'zh'
|
20 |
+
else:
|
21 |
+
kwargs['lang'] = 'en'
|
22 |
+
return self._run(*args, **kwargs)
|
23 |
+
|
24 |
+
@abstractmethod
|
25 |
+
def _run(self, *args, **kwargs) -> Union[str, Iterator[str]]:
|
26 |
+
raise NotImplementedError
|
27 |
+
|
28 |
+
# It is okay for an Action to not call LLMs.
|
29 |
+
def _call_llm(
|
30 |
+
self,
|
31 |
+
prompt: Optional[str] = None,
|
32 |
+
messages: Optional[List[Dict]] = None,
|
33 |
+
stop: Optional[List[str]] = None,
|
34 |
+
) -> Union[str, Iterator[str]]:
|
35 |
+
return self.llm.chat(
|
36 |
+
prompt=prompt,
|
37 |
+
messages=messages,
|
38 |
+
stop=stop,
|
39 |
+
stream=self.stream,
|
40 |
+
)
|
qwen_agent/actions/continue_writing.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from qwen_agent.actions.base import Action
|
2 |
+
|
3 |
+
PROMPT_TEMPLATE_ZH = """你是一个写作助手,请依据参考资料,根据给定的前置文本续写合适的内容。
|
4 |
+
#参考资料:
|
5 |
+
{ref_doc}
|
6 |
+
|
7 |
+
#前置文本:
|
8 |
+
{user_request}
|
9 |
+
|
10 |
+
保证续写内容和前置文本保持连贯,请开始续写:"""
|
11 |
+
|
12 |
+
PROMPT_TEMPLATE_EN = """You are a writing assistant, please follow the reference materials and continue to write appropriate content based on the given previous text.
|
13 |
+
|
14 |
+
# References:
|
15 |
+
{ref_doc}
|
16 |
+
|
17 |
+
# Previous text:
|
18 |
+
{user_request}
|
19 |
+
|
20 |
+
Please start writing directly, output only the continued text, do not repeat the previous text, do not say irrelevant words, and ensure that the continued content and the previous text remain consistent."""
|
21 |
+
|
22 |
+
PROMPT_TEMPLATE = {
|
23 |
+
'zh': PROMPT_TEMPLATE_ZH,
|
24 |
+
'en': PROMPT_TEMPLATE_EN,
|
25 |
+
}
|
26 |
+
|
27 |
+
|
28 |
+
class ContinueWriting(Action):
|
29 |
+
|
30 |
+
def _run(self, user_request, ref_doc, lang: str = 'en'):
|
31 |
+
prompt = PROMPT_TEMPLATE[lang].format(
|
32 |
+
ref_doc=ref_doc,
|
33 |
+
user_request=user_request,
|
34 |
+
)
|
35 |
+
return self._call_llm(prompt)
|
qwen_agent/actions/expand_writing.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from qwen_agent.actions.base import Action
|
2 |
+
|
3 |
+
PROMPT_TEMPLATE_ZH = """
|
4 |
+
你是一个写作助手,任务是依据参考资料,完成写作任务。
|
5 |
+
#参考资料:
|
6 |
+
{ref_doc}
|
7 |
+
|
8 |
+
写作标题是:{user_request}
|
9 |
+
大纲是:
|
10 |
+
{outline}
|
11 |
+
|
12 |
+
此时你的任务是扩写第{index}个一级标题对应的章节:{capture}。注意每个章节负责撰写不同的内容,所以你不需要为了全面而涵盖之后的内容。请不要在这里生成大纲。只依据给定的参考资料来写,不要引入其余知识。
|
13 |
+
"""
|
14 |
+
|
15 |
+
PROMPT_TEMPLATE_EN = """
|
16 |
+
You are a writing assistant. Your task is to complete writing article based on reference materials.
|
17 |
+
|
18 |
+
# References:
|
19 |
+
{ref_doc}
|
20 |
+
|
21 |
+
The title is: {user_request}
|
22 |
+
|
23 |
+
The outline is:
|
24 |
+
{outline}
|
25 |
+
|
26 |
+
At this point, your task is to expand the chapter corresponding to the {index} first level title: {capture}.
|
27 |
+
Note that each chapter is responsible for writing different content, so you don't need to cover the following content. Please do not generate an outline here. Write only based on the given reference materials and do not introduce other knowledge.
|
28 |
+
"""
|
29 |
+
|
30 |
+
PROMPT_TEMPLATE = {
|
31 |
+
'zh': PROMPT_TEMPLATE_ZH,
|
32 |
+
'en': PROMPT_TEMPLATE_EN,
|
33 |
+
}
|
34 |
+
|
35 |
+
|
36 |
+
class ExpandWriting(Action):
|
37 |
+
|
38 |
+
def _run(
|
39 |
+
self,
|
40 |
+
user_request,
|
41 |
+
ref_doc,
|
42 |
+
outline='',
|
43 |
+
index='1',
|
44 |
+
capture='',
|
45 |
+
capture_later='',
|
46 |
+
lang: str = 'en',
|
47 |
+
):
|
48 |
+
prompt = PROMPT_TEMPLATE[lang].format(
|
49 |
+
ref_doc=ref_doc,
|
50 |
+
user_request=user_request,
|
51 |
+
index=index,
|
52 |
+
outline=outline,
|
53 |
+
capture=capture,
|
54 |
+
)
|
55 |
+
if capture_later:
|
56 |
+
if lang == 'zh':
|
57 |
+
prompt = prompt + '请在涉及 ' + capture_later + ' 时停止。'
|
58 |
+
elif lang == 'en':
|
59 |
+
prompt = prompt + ' Please stop when writing ' + capture_later
|
60 |
+
else:
|
61 |
+
raise NotImplementedError
|
62 |
+
return self._call_llm(prompt)
|