Upload 17 files
Browse files- .gitignore +177 -0
- CONTRIBUTING.md +22 -0
- Dockerfile +36 -0
- Dockerfile-cuda +14 -0
- LICENSE +201 -0
- README.md +235 -12
- README_en.md +247 -0
- api.py +465 -0
- cli.bat +2 -0
- cli.py +86 -0
- cli.sh +2 -0
- cli_demo.py +66 -0
- release.py +50 -0
- requirements.txt +36 -0
- webui.py +560 -0
- webui_st.py +538 -0
.gitignore
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*/**/__pycache__/
|
4 |
+
*.py[cod]
|
5 |
+
*$py.class
|
6 |
+
|
7 |
+
# C extensions
|
8 |
+
*.so
|
9 |
+
|
10 |
+
# Distribution / packaging
|
11 |
+
.Python
|
12 |
+
build/
|
13 |
+
develop-eggs/
|
14 |
+
dist/
|
15 |
+
downloads/
|
16 |
+
eggs/
|
17 |
+
.eggs/
|
18 |
+
lib/
|
19 |
+
lib64/
|
20 |
+
parts/
|
21 |
+
sdist/
|
22 |
+
var/
|
23 |
+
wheels/
|
24 |
+
share/python-wheels/
|
25 |
+
*.egg-info/
|
26 |
+
.installed.cfg
|
27 |
+
*.egg
|
28 |
+
MANIFEST
|
29 |
+
|
30 |
+
# PyInstaller
|
31 |
+
# Usually these files are written by a python script from a template
|
32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
33 |
+
*.manifest
|
34 |
+
*.spec
|
35 |
+
|
36 |
+
# Installer logs
|
37 |
+
pip-log.txt
|
38 |
+
pip-delete-this-directory.txt
|
39 |
+
|
40 |
+
# Unit test / coverage reports
|
41 |
+
htmlcov/
|
42 |
+
.tox/
|
43 |
+
.nox/
|
44 |
+
.coverage
|
45 |
+
.coverage.*
|
46 |
+
.cache
|
47 |
+
nosetests.xml
|
48 |
+
coverage.xml
|
49 |
+
*.cover
|
50 |
+
*.py,cover
|
51 |
+
.hypothesis/
|
52 |
+
.pytest_cache/
|
53 |
+
cover/
|
54 |
+
|
55 |
+
# Translations
|
56 |
+
*.mo
|
57 |
+
*.pot
|
58 |
+
|
59 |
+
# Django stuff:
|
60 |
+
*.log
|
61 |
+
local_settings.py
|
62 |
+
db.sqlite3
|
63 |
+
db.sqlite3-journal
|
64 |
+
|
65 |
+
# Flask stuff:
|
66 |
+
instance/
|
67 |
+
.webassets-cache
|
68 |
+
|
69 |
+
# Scrapy stuff:
|
70 |
+
.scrapy
|
71 |
+
|
72 |
+
# Sphinx documentation
|
73 |
+
docs/_build/
|
74 |
+
|
75 |
+
# PyBuilder
|
76 |
+
.pybuilder/
|
77 |
+
target/
|
78 |
+
|
79 |
+
# Jupyter Notebook
|
80 |
+
.ipynb_checkpoints
|
81 |
+
|
82 |
+
# IPython
|
83 |
+
profile_default/
|
84 |
+
ipython_config.py
|
85 |
+
|
86 |
+
# pyenv
|
87 |
+
# For a library or package, you might want to ignore these files since the code is
|
88 |
+
# intended to run in multiple environments; otherwise, check them in:
|
89 |
+
# .python-version
|
90 |
+
|
91 |
+
# pipenv
|
92 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
93 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
94 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
95 |
+
# install all needed dependencies.
|
96 |
+
#Pipfile.lock
|
97 |
+
|
98 |
+
# poetry
|
99 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
100 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
101 |
+
# commonly ignored for libraries.
|
102 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
103 |
+
#poetry.lock
|
104 |
+
|
105 |
+
# pdm
|
106 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
107 |
+
#pdm.lock
|
108 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
109 |
+
# in version control.
|
110 |
+
# https://pdm.fming.dev/#use-with-ide
|
111 |
+
.pdm.toml
|
112 |
+
|
113 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
114 |
+
__pypackages__/
|
115 |
+
|
116 |
+
# Celery stuff
|
117 |
+
celerybeat-schedule
|
118 |
+
celerybeat.pid
|
119 |
+
|
120 |
+
# SageMath parsed files
|
121 |
+
*.sage.py
|
122 |
+
|
123 |
+
# Environments
|
124 |
+
.env
|
125 |
+
.venv
|
126 |
+
env/
|
127 |
+
venv/
|
128 |
+
ENV/
|
129 |
+
env.bak/
|
130 |
+
venv.bak/
|
131 |
+
|
132 |
+
# Spyder project settings
|
133 |
+
.spyderproject
|
134 |
+
.spyproject
|
135 |
+
|
136 |
+
# Rope project settings
|
137 |
+
.ropeproject
|
138 |
+
|
139 |
+
# mkdocs documentation
|
140 |
+
/site
|
141 |
+
|
142 |
+
# mypy
|
143 |
+
.mypy_cache/
|
144 |
+
.dmypy.json
|
145 |
+
dmypy.json
|
146 |
+
|
147 |
+
# Pyre type checker
|
148 |
+
.pyre/
|
149 |
+
|
150 |
+
# pytype static type analyzer
|
151 |
+
.pytype/
|
152 |
+
|
153 |
+
# Cython debug symbols
|
154 |
+
cython_debug/
|
155 |
+
|
156 |
+
# PyCharm
|
157 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
158 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
159 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
160 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
161 |
+
.idea/
|
162 |
+
|
163 |
+
# Other files
|
164 |
+
output/*
|
165 |
+
log/*
|
166 |
+
.chroma
|
167 |
+
vector_store/*
|
168 |
+
content/*
|
169 |
+
api_content/*
|
170 |
+
knowledge_base/*
|
171 |
+
|
172 |
+
llm/*
|
173 |
+
embedding/*
|
174 |
+
|
175 |
+
pyrightconfig.json
|
176 |
+
loader/tmp_files
|
177 |
+
flagged/*
|
CONTRIBUTING.md
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# 贡献指南
|
2 |
+
|
3 |
+
欢迎!我们是一个非常友好的社区,非常高兴您想要帮助我们让这个应用程序变得更好。但是,请您遵循一些通用准则以保持组织有序。
|
4 |
+
|
5 |
+
1. 确保为您要修复的错误或要添加的功能创建了一个[问题](https://github.com/imClumsyPanda/langchain-ChatGLM/issues),尽可能保持它们小。
|
6 |
+
2. 请使用 `git pull --rebase` 来拉取和衍合上游的更新。
|
7 |
+
3. 将提交合并为格式良好的提交。在提交说明中单独一行提到要解决的问题,如`Fix #<bug>`(有关更多可以使用的关键字,请参见[将拉取请求链接到问题](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue))。
|
8 |
+
4. 推送到`dev`。在说明中提到正在解决的问题。
|
9 |
+
|
10 |
+
---
|
11 |
+
|
12 |
+
# Contribution Guide
|
13 |
+
|
14 |
+
Welcome! We're a pretty friendly community, and we're thrilled that you want to help make this app even better. However, we ask that you follow some general guidelines to keep things organized around here.
|
15 |
+
|
16 |
+
1. Make sure an [issue](https://github.com/imClumsyPanda/langchain-ChatGLM/issues) is created for the bug you're about to fix, or feature you're about to add. Keep them as small as possible.
|
17 |
+
|
18 |
+
2. Please use `git pull --rebase` to fetch and merge updates from the upstream.
|
19 |
+
|
20 |
+
3. Rebase commits into well-formatted commits. Mention the issue being resolved in the commit message on a line all by itself like `Fixes #<bug>` (refer to [Linking a pull request to an issue](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue) for more keywords you can use).
|
21 |
+
|
22 |
+
4. Push into `dev`. Mention which bug is being resolved in the description.
|
Dockerfile
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.8
|
2 |
+
|
3 |
+
MAINTAINER "chatGLM"
|
4 |
+
|
5 |
+
COPY agent /chatGLM/agent
|
6 |
+
|
7 |
+
COPY chains /chatGLM/chains
|
8 |
+
|
9 |
+
COPY configs /chatGLM/configs
|
10 |
+
|
11 |
+
COPY content /chatGLM/content
|
12 |
+
|
13 |
+
COPY models /chatGLM/models
|
14 |
+
|
15 |
+
COPY nltk_data /chatGLM/content
|
16 |
+
|
17 |
+
COPY requirements.txt /chatGLM/
|
18 |
+
|
19 |
+
COPY cli_demo.py /chatGLM/
|
20 |
+
|
21 |
+
COPY textsplitter /chatGLM/
|
22 |
+
|
23 |
+
COPY webui.py /chatGLM/
|
24 |
+
|
25 |
+
WORKDIR /chatGLM
|
26 |
+
|
27 |
+
RUN pip install --user torch torchvision tensorboard cython -i https://pypi.tuna.tsinghua.edu.cn/simple
|
28 |
+
# RUN pip install --user 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
|
29 |
+
|
30 |
+
# RUN pip install --user 'git+https://github.com/facebookresearch/fvcore'
|
31 |
+
# install detectron2
|
32 |
+
# RUN git clone https://github.com/facebookresearch/detectron2
|
33 |
+
|
34 |
+
RUN pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple/ --trusted-host pypi.tuna.tsinghua.edu.cn
|
35 |
+
|
36 |
+
CMD ["python","-u", "webui.py"]
|
Dockerfile-cuda
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
|
2 |
+
LABEL MAINTAINER="chatGLM"
|
3 |
+
|
4 |
+
COPY . /chatGLM/
|
5 |
+
|
6 |
+
WORKDIR /chatGLM
|
7 |
+
|
8 |
+
RUN ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && echo "Asia/Shanghai" > /etc/timezone
|
9 |
+
RUN apt-get update -y && apt-get install python3 python3-pip curl libgl1 libglib2.0-0 -y && apt-get clean
|
10 |
+
RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py && python3 get-pip.py
|
11 |
+
|
12 |
+
RUN pip3 install -r requirements.txt -i https://pypi.mirrors.ustc.edu.cn/simple/ && rm -rf `pip3 cache dir`
|
13 |
+
|
14 |
+
CMD ["python3","-u", "webui.py"]
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,12 +1,235 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# 基于本地知识库的 ChatGLM 等大语言模型应用实现
|
2 |
+
|
3 |
+
## 介绍
|
4 |
+
|
5 |
+
🌍 [_READ THIS IN ENGLISH_](README_en.md)
|
6 |
+
|
7 |
+
🤖️ 一种利用 [langchain](https://github.com/hwchase17/langchain) 思想实现的基于本地知识库的问答应用,目标期望建立一套对中文场景与开源模型支持友好、可离线运行的知识库问答解决方案。
|
8 |
+
|
9 |
+
💡 受 [GanymedeNil](https://github.com/GanymedeNil) 的项目 [document.ai](https://github.com/GanymedeNil/document.ai) 和 [AlexZhangji](https://github.com/AlexZhangji) 创建的 [ChatGLM-6B Pull Request](https://github.com/THUDM/ChatGLM-6B/pull/216) 启发,建立了全流程可使用开源模型实现的本地知识库问答应用。现已支持使用 [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) 等大语言模型直接接入,或通过 [fastchat](https://github.com/lm-sys/FastChat) api 形式接入 Vicuna, Alpaca, LLaMA, Koala, RWKV 等模型。
|
10 |
+
|
11 |
+
✅ 本项目中 Embedding 默认选用的是 [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese/tree/main),LLM 默认选用的是 [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B)。依托上述模型,本项目可实现全部使用**开源**模型**离线私有部署**。
|
12 |
+
|
13 |
+
⛓️ 本项目实现原理如下图所示,过程包括加载文件 -> 读取文本 -> 文本分割 -> 文本向量化 -> 问句向量化 -> 在文本向量中匹配出与问句向量最相似的`top k`个 -> 匹配出的文本作为上下文和问题一起添加到`prompt`中 -> 提交给`LLM`生成回答。
|
14 |
+
|
15 |
+
📺 [原理介绍视频](https://www.bilibili.com/video/BV13M4y1e7cN/?share_source=copy_web&vd_source=e6c5aafe684f30fbe41925d61ca6d514)
|
16 |
+
|
17 |
+
![实现原理图](img/langchain+chatglm.png)
|
18 |
+
|
19 |
+
从文档处理角度来看,实现流程如下:
|
20 |
+
|
21 |
+
![实现原理图2](img/langchain+chatglm2.png)
|
22 |
+
|
23 |
+
|
24 |
+
🚩 本项目未涉及微调、训练过程,但可利用微调或训练对本项目效果进行优化。
|
25 |
+
|
26 |
+
🌐 [AutoDL 镜像](https://www.codewithgpu.com/i/imClumsyPanda/langchain-ChatGLM/langchain-ChatGLM)
|
27 |
+
|
28 |
+
📓 [ModelWhale 在线运行项目](https://www.heywhale.com/mw/project/643977aa446c45f4592a1e59)
|
29 |
+
|
30 |
+
## 变更日志
|
31 |
+
|
32 |
+
参见 [版本更新日志](https://github.com/imClumsyPanda/langchain-ChatGLM/releases)。
|
33 |
+
|
34 |
+
## 硬件需求
|
35 |
+
|
36 |
+
- ChatGLM-6B 模型硬件需求
|
37 |
+
|
38 |
+
注:如未将模型下载至本地,请执行前检查`$HOME/.cache/huggingface/`文件夹剩余空间,模型文件下载至本地需要 15 GB 存储空间。
|
39 |
+
注:一些其它的可选启动项见[项目启动选项](docs/StartOption.md)
|
40 |
+
模型下载方法可参考 [常见问题](docs/FAQ.md) 中 Q8。
|
41 |
+
|
42 |
+
| **量化等级** | **最低 GPU 显存**(推理) | **最低 GPU 显存**(高效参数微调) |
|
43 |
+
| -------------- | ------------------------- | --------------------------------- |
|
44 |
+
| FP16(无量化) | 13 GB | 14 GB |
|
45 |
+
| INT8 | 8 GB | 9 GB |
|
46 |
+
| INT4 | 6 GB | 7 GB |
|
47 |
+
|
48 |
+
- MOSS 模型硬件需求
|
49 |
+
|
50 |
+
注:如未将模型下载至本地,请执行前检查`$HOME/.cache/huggingface/`文件夹剩余空间,模型文件下载至本地需要 70 GB 存储空间
|
51 |
+
|
52 |
+
模型下载方法可参考 [常见问题](docs/FAQ.md) 中 Q8。
|
53 |
+
|
54 |
+
| **量化等级** | **最低 GPU 显存**(推理) | **最低 GPU 显存**(高效参数微调) |
|
55 |
+
|-------------------|-----------------------| --------------------------------- |
|
56 |
+
| FP16(无量化) | 68 GB | - |
|
57 |
+
| INT8 | 20 GB | - |
|
58 |
+
|
59 |
+
- Embedding 模型硬件需求
|
60 |
+
|
61 |
+
本项目中默认选用的 Embedding 模型 [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese/tree/main) 约占用显存 3GB,也可修改为在 CPU 中运行。
|
62 |
+
|
63 |
+
## Docker 部署
|
64 |
+
为了能让容器使用主机GPU资源,需要在主机上安装 [NVIDIA Container Toolkit](https://github.com/NVIDIA/nvidia-container-toolkit)。具体安装步骤如下:
|
65 |
+
```shell
|
66 |
+
sudo apt-get update
|
67 |
+
sudo apt-get install -y nvidia-container-toolkit-base
|
68 |
+
sudo systemctl daemon-reload
|
69 |
+
sudo systemctl restart docker
|
70 |
+
```
|
71 |
+
安装完成后,可以使用以下命令编译镜像和启动容器:
|
72 |
+
```
|
73 |
+
docker build -f Dockerfile-cuda -t chatglm-cuda:latest .
|
74 |
+
docker run --gpus all -d --name chatglm -p 7860:7860 chatglm-cuda:latest
|
75 |
+
|
76 |
+
#若要使用离线模型,请配置好模型路径,然后此repo挂载到Container
|
77 |
+
docker run --gpus all -d --name chatglm -p 7860:7860 -v ~/github/langchain-ChatGLM:/chatGLM chatglm-cuda:latest
|
78 |
+
```
|
79 |
+
|
80 |
+
|
81 |
+
## 开发部署
|
82 |
+
|
83 |
+
### 软件需求
|
84 |
+
|
85 |
+
本项目已在 Python 3.8.1 - 3.10,CUDA 11.7 环境下完成测试。已在 Windows、ARM 架构的 macOS、Linux 系统中完成测试。
|
86 |
+
|
87 |
+
vue前端需要node18环境
|
88 |
+
|
89 |
+
### 从本地加载模型
|
90 |
+
|
91 |
+
请参考 [THUDM/ChatGLM-6B#从本地加载模型](https://github.com/THUDM/ChatGLM-6B#从本地加载模型)
|
92 |
+
|
93 |
+
### 1. 安装环境
|
94 |
+
|
95 |
+
参见 [安装指南](docs/INSTALL.md)。
|
96 |
+
|
97 |
+
### 2. 设置模型默认参数
|
98 |
+
|
99 |
+
在开始执行 Web UI 或命令行交互前,请先检查 [configs/model_config.py](configs/model_config.py) 中的各项模型参数设计是否符合需求。
|
100 |
+
|
101 |
+
如需通过 fastchat 以 api 形式调用 llm,请参考 [fastchat 调用实现](docs/fastchat.md)
|
102 |
+
|
103 |
+
### 3. 执行脚本体验 Web UI 或命令行交互
|
104 |
+
|
105 |
+
> 注:鉴于环境部署过程中可能遇到问题,建议首先测试命令行脚本。建议命令行脚本测试可正常运行后再运行 Web UI。
|
106 |
+
|
107 |
+
执行 [cli_demo.py](cli_demo.py) 脚本体验**命令行交互**:
|
108 |
+
```shell
|
109 |
+
$ python cli_demo.py
|
110 |
+
```
|
111 |
+
|
112 |
+
或执行 [webui.py](webui.py) 脚本体验 **Web 交互**
|
113 |
+
|
114 |
+
```shell
|
115 |
+
$ python webui.py
|
116 |
+
```
|
117 |
+
|
118 |
+
或执行 [api.py](api.py) 利用 fastapi 部署 API
|
119 |
+
```shell
|
120 |
+
$ python api.py
|
121 |
+
```
|
122 |
+
或成功部署 API 后,执行以下脚本体验基于 VUE 的前端页面
|
123 |
+
```shell
|
124 |
+
$ cd views
|
125 |
+
|
126 |
+
$ pnpm i
|
127 |
+
|
128 |
+
$ npm run dev
|
129 |
+
```
|
130 |
+
|
131 |
+
VUE 前端界面如下图所示:
|
132 |
+
1. `对话` 界面
|
133 |
+
![](img/vue_0521_0.png)
|
134 |
+
2. `知识库问答` 界面
|
135 |
+
![](img/vue_0521_1.png)
|
136 |
+
3. `Bing搜索` 界面
|
137 |
+
![](img/vue_0521_2.png)
|
138 |
+
|
139 |
+
WebUI 界面如下图所示:
|
140 |
+
1. `对话` Tab 界面
|
141 |
+
![](img/webui_0521_0.png)
|
142 |
+
2. `知识库测试 Beta` Tab 界面
|
143 |
+
![](img/webui_0510_1.png)
|
144 |
+
3. `模型配置` Tab 界面
|
145 |
+
![](img/webui_0510_2.png)
|
146 |
+
|
147 |
+
Web UI 可以实现如下功能:
|
148 |
+
|
149 |
+
1. 运行前自动读取`configs/model_config.py`中`LLM`及`Embedding`模型枚举及默认模型设置运行模型,如需重新加载模型,可在 `模型配置` Tab 重新选择后点击 `重新加载模型` 进行模型加载;
|
150 |
+
2. 可手动调节保留对话历史长度、匹配知识库文段数量,可根据显存大小自行调节;
|
151 |
+
3. `对话` Tab 具备模式选择功能,可选择 `LLM对话` 与 `知识库问答` 模式进行对话,支持流式对话;
|
152 |
+
4. 添加 `配置知识库` 功能,支持选择已有知识库或新建知识库,并可向知识库中**新增**上传文件/文件夹,使用文件上传组件选择好文件后点击 `上传文件并加载知识库`,会将所选上传文档数据加载至知识库中,并基于更新后知识库进行问答;
|
153 |
+
5. 新增 `知识库测试 Beta` Tab,可用于测试不同文本切分方法与检索相关度阈值设置,暂不支持将测试参数作为 `对话` Tab 设置参数。
|
154 |
+
6. 后续版本中将会增加对知识库的修改或删除,及知识库中已导入文件的查看。
|
155 |
+
|
156 |
+
### 常见问题
|
157 |
+
|
158 |
+
参见 [常见问题](docs/FAQ.md)。
|
159 |
+
|
160 |
+
## Demo
|
161 |
+
|
162 |
+
以问题`chatglm-6b 的局限性具体体现在哪里,如何实现改进`为例:
|
163 |
+
|
164 |
+
未使用 langchain 接入本地文档时:
|
165 |
+
|
166 |
+
>ChatGLM-6B 是一个基于 GPT-3 的聊天机器人,它的局限性主要体现在以下几个方面:
|
167 |
+
>1. 语言模型的精度:GPT-3 是一种大型的语言模型,它的精度比较高,但是仍然有一些局限性。例如,它可能无法准确地理解某些方言、口音或者复杂的句子结构。
|
168 |
+
>2. 知识图谱的完整性:GPT-3 可以生成一些知识图谱,但是其完整性和准确性可能不够。例如,它可能无法准确地理解某些领域的专业知识和术语。
|
169 |
+
>3. 对话生成的连贯性:GPT-3 可以生成一些连贯的对话,但是其生成的对话可能不够自然。例如,它可能会生成一些虚假的、不连贯的对话或者误解用户的意图。
|
170 |
+
>
|
171 |
+
>为了改进 ChatGLM-6B,可以考虑以下几个方面:
|
172 |
+
>1. 提高语言模型的精度:可以通过增加训练数据、使用更好的语言模型架构、优化模型训练过程等方式来提高语言模型的精度。
|
173 |
+
>2. 增强知识图谱的完整性:可以通过增加知识图谱的维度、使用更好的知识图谱构建方法、优化知识图谱的匹配算法等方式来增强知识图谱的完整性。
|
174 |
+
>3. 提高对话生成的连贯性:可以通过增加对话生成模型的上下文信息、使用更好的上下文信息提取方法、优化对话生成模型的算法等方式来提高对话生成的连贯性。
|
175 |
+
>4. 引入更多的人类反馈:人类反馈可以帮助 ChatGLM-6B 更好地理解用户的意图和需求,从而提高对话生成的质量和准确性。可以通过引入更多的人类客服、使用人类反馈收集器等方式来引入更多的人类反馈。
|
176 |
+
|
177 |
+
使用 langchain 接入 ChatGLM-6B 项目的 README.md 文件后 ChatGLM 的回答:
|
178 |
+
>ChatGLM-6B 是一个基于深度学习的自然语言处理模型,它在回答问题方面表现出色。但是,它也有一些局限性,具体如下:
|
179 |
+
>1. 语言理解有限:ChatGLM-6B 只学习了中文自然语言,对于其他语言可能会存在理解上的局限性。
|
180 |
+
>2. 知识储备不足:ChatGLM-6B 的训练数据集只包含了中文语料,因此���可能无法回答一些非中文的问题或者某些特定领域的问题。
|
181 |
+
>3. 数据量有限:ChatGLM-6B 的训练数据集只有几百万条记录,因此它可能无法回答一些非常具体或者复杂的问题。
|
182 |
+
>
|
183 |
+
>为了改进 ChatGLM-6B,可以考虑以下几个方面:
|
184 |
+
>1. 学习更多的语言知识:可以学习其他语言的自然语言处理技术,扩大语言理解的范围。
|
185 |
+
>2. 扩大知识储备:可以收集更多的中文语料,或者使用其他语言的数据集来扩充知识储备。
|
186 |
+
>3. 增加数据量:可以使用更大的数据集来训练 ChatGLM-6B,提高模型的表现。
|
187 |
+
>4. 引入更多的评估指标:可以引入更多的评估指标来评估模型的表现,从而发现 ChatGLM-6B 存在的不足和局限性。
|
188 |
+
>5. 改进模型架构:可以改进 ChatGLM-6B 的模型架构,提高模型的性能和表现。例如,可以使用更大的神经网络或者改进的卷积神经网络结构。
|
189 |
+
|
190 |
+
## 路线图
|
191 |
+
|
192 |
+
- [ ] Langchain 应用
|
193 |
+
- [x] 接入非结构化文档(已支持 md、pdf、docx、txt 文件格式)
|
194 |
+
- [x] jpg 与 png 格式图片的 OCR 文字识别
|
195 |
+
- [x] 搜索引擎接入
|
196 |
+
- [ ] 本地网页接入
|
197 |
+
- [ ] 结构化数据接入(如 csv、Excel、SQL 等)
|
198 |
+
- [ ] 知识图谱/图数据库接入
|
199 |
+
- [ ] Agent 实现
|
200 |
+
- [x] 增加更多 LLM 模型支持
|
201 |
+
- [x] [THUDM/chatglm-6b](https://huggingface.co/THUDM/chatglm-6b)
|
202 |
+
- [x] [THUDM/chatglm-6b-int8](https://huggingface.co/THUDM/chatglm-6b-int8)
|
203 |
+
- [x] [THUDM/chatglm-6b-int4](https://huggingface.co/THUDM/chatglm-6b-int4)
|
204 |
+
- [x] [THUDM/chatglm-6b-int4-qe](https://huggingface.co/THUDM/chatglm-6b-int4-qe)
|
205 |
+
- [x] [ClueAI/ChatYuan-large-v2](https://huggingface.co/ClueAI/ChatYuan-large-v2)
|
206 |
+
- [x] [fnlp/moss-moon-003-sft](https://huggingface.co/fnlp/moss-moon-003-sft)
|
207 |
+
- [x] 支持通过调用 [fastchat](https://github.com/lm-sys/FastChat) api 调用 llm
|
208 |
+
- [x] 增加更多 Embedding 模型支持
|
209 |
+
- [x] [nghuyong/ernie-3.0-nano-zh](https://huggingface.co/nghuyong/ernie-3.0-nano-zh)
|
210 |
+
- [x] [nghuyong/ernie-3.0-base-zh](https://huggingface.co/nghuyong/ernie-3.0-base-zh)
|
211 |
+
- [x] [shibing624/text2vec-base-chinese](https://huggingface.co/shibing624/text2vec-base-chinese)
|
212 |
+
- [x] [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese)
|
213 |
+
- [x] [moka-ai/m3e-small](https://huggingface.co/moka-ai/m3e-small)
|
214 |
+
- [x] [moka-ai/m3e-base](https://huggingface.co/moka-ai/m3e-base)
|
215 |
+
- [ ] Web UI
|
216 |
+
- [x] 基于 gradio 实现 Web UI DEMO
|
217 |
+
- [x] 基于 streamlit 实现 Web UI DEMO
|
218 |
+
- [x] 添加输出内容及错误提示
|
219 |
+
- [x] 引用标注
|
220 |
+
- [ ] 增加知识库管理
|
221 |
+
- [x] 选择知识库开始问答
|
222 |
+
- [x] 上传文件/文件夹至知识库
|
223 |
+
- [x] 知识库测试
|
224 |
+
- [ ] 删除知识库中文件
|
225 |
+
- [x] 支持搜索引擎问答
|
226 |
+
- [ ] 增加 API 支持
|
227 |
+
- [x] 利用 fastapi 实现 API 部署方式
|
228 |
+
- [ ] 实现调用 API 的 Web UI Demo
|
229 |
+
- [x] VUE 前端
|
230 |
+
|
231 |
+
## 项目交流群
|
232 |
+
<img src="img/qr_code_35.jpg" alt="二维码" width="300" height="300" />
|
233 |
+
|
234 |
+
|
235 |
+
🎉 langchain-ChatGLM 项目微信交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。
|
README_en.md
ADDED
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ChatGLM Application with Local Knowledge Implementation
|
2 |
+
|
3 |
+
## Introduction
|
4 |
+
|
5 |
+
[![Telegram](https://img.shields.io/badge/Telegram-2CA5E0?style=for-the-badge&logo=telegram&logoColor=white "langchain-chatglm")](https://t.me/+RjliQ3jnJ1YyN2E9)
|
6 |
+
|
7 |
+
🌍 [_中文文档_](README.md)
|
8 |
+
|
9 |
+
🤖️ This is a ChatGLM application based on local knowledge, implemented using [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) and [langchain](https://github.com/hwchase17/langchain).
|
10 |
+
|
11 |
+
💡 Inspired by [document.ai](https://github.com/GanymedeNil/document.ai) and [Alex Zhangji](https://github.com/AlexZhangji)'s [ChatGLM-6B Pull Request](https://github.com/THUDM/ChatGLM-6B/pull/216), this project establishes a local knowledge question-answering application using open-source models.
|
12 |
+
|
13 |
+
✅ The embeddings used in this project are [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese/tree/main), and the LLM is [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B). Relying on these models, this project enables the use of **open-source** models for **offline private deployment**.
|
14 |
+
|
15 |
+
⛓️ The implementation principle of this project is illustrated in the figure below. The process includes loading files -> reading text -> text segmentation -> text vectorization -> question vectorization -> matching the top k most similar text vectors to the question vector -> adding the matched text to `prompt` along with the question as context -> submitting to `LLM` to generate an answer.
|
16 |
+
|
17 |
+
![Implementation schematic diagram](img/langchain+chatglm.png)
|
18 |
+
|
19 |
+
🚩 This project does not involve fine-tuning or training; however, fine-tuning or training can be employed to optimize the effectiveness of this project.
|
20 |
+
|
21 |
+
📓 [ModelWhale online notebook](https://www.heywhale.com/mw/project/643977aa446c45f4592a1e59)
|
22 |
+
|
23 |
+
## Changelog
|
24 |
+
|
25 |
+
**[2023/04/15]**
|
26 |
+
|
27 |
+
1. refactor the project structure to keep the command line demo [cli_demo.py](cli_demo.py) and the Web UI demo [webui.py](webui.py) in the root directory.
|
28 |
+
2. Improve the Web UI by modifying it to first load the model according to the default option of [configs/model_config.py](configs/model_config.py) after running the Web UI, and adding error messages, etc.
|
29 |
+
3. Update FAQ.
|
30 |
+
|
31 |
+
**[2023/04/12]**
|
32 |
+
|
33 |
+
1. Replaced the sample files in the Web UI to avoid issues with unreadable files due to encoding problems in Ubuntu;
|
34 |
+
2. Replaced the prompt template in `knowledge_based_chatglm.py` to prevent confusion in the content returned by ChatGLM, which may arise from the prompt template containing Chinese and English bilingual text.
|
35 |
+
|
36 |
+
**[2023/04/11]**
|
37 |
+
|
38 |
+
1. Added Web UI V0.1 version (thanks to [@liangtongt](https://github.com/liangtongt));
|
39 |
+
2. Added Frequently Asked Questions in `README.md` (thanks to [@calcitem](https://github.com/calcitem) and [@bolongliu](https://github.com/bolongliu));
|
40 |
+
3. Enhanced automatic detection for the availability of `cuda`, `mps`, and `cpu` for LLM and Embedding model running devices;
|
41 |
+
4. Added a check for `filepath` in `knowledge_based_chatglm.py`. In addition to supporting single file import, it now supports a single folder path as input. After input, it will traverse each file in the folder and display a command-line message indicating the success of each file load.
|
42 |
+
|
43 |
+
5. **[2023/04/09]**
|
44 |
+
|
45 |
+
1. Replaced the previously selected `ChatVectorDBChain` with `RetrievalQA` in `langchain`, effectively reducing the issue of stopping due to insufficient video memory after asking 2-3 times;
|
46 |
+
2. Added `EMBEDDING_MODEL`, `VECTOR_SEARCH_TOP_K`, `LLM_MODEL`, `LLM_HISTORY_LEN`, `REPLY_WITH_SOURCE` parameter value settings in `knowledge_based_chatglm.py`;
|
47 |
+
3. Added `chatglm-6b-int4` and `chatglm-6b-int4-qe`, which require less GPU memory, as LLM model options;
|
48 |
+
4. Corrected code errors in `README.md` (thanks to [@calcitem](https://github.com/calcitem)).
|
49 |
+
|
50 |
+
**[2023/04/07]**
|
51 |
+
|
52 |
+
1. Resolved the issue of doubled video memory usage when loading the ChatGLM model (thanks to [@suc16](https://github.com/suc16) and [@myml](https://github.com/myml));
|
53 |
+
2. Added a mechanism to clear video memory;
|
54 |
+
3. Added `nghuyong/ernie-3.0-nano-zh` and `nghuyong/ernie-3.0-base-zh` as Embedding model options, which consume less video memory resources than `GanymedeNil/text2vec-large-chinese` (thanks to [@lastrei](https://github.com/lastrei)).
|
55 |
+
|
56 |
+
## How to Use
|
57 |
+
|
58 |
+
### Hardware Requirements
|
59 |
+
|
60 |
+
- ChatGLM-6B Model Hardware Requirements
|
61 |
+
|
62 |
+
| **Quantization Level** | **Minimum GPU Memory** (inference) | **Minimum GPU Memory** (efficient parameter fine-tuning) |
|
63 |
+
| -------------- | ------------------------- | --------------------------------- |
|
64 |
+
| FP16 (no quantization) | 13 GB | 14 GB |
|
65 |
+
| INT8 | 8 GB | 9 GB |
|
66 |
+
| INT4 | 6 GB | 7 GB |
|
67 |
+
|
68 |
+
- Embedding Model Hardware Requirements
|
69 |
+
|
70 |
+
The default Embedding model [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese/tree/main) in this project occupies around 3GB of video memory and can also be configured to run on a CPU.
|
71 |
+
### Software Requirements
|
72 |
+
|
73 |
+
This repository has been tested with Python 3.8 and CUDA 11.7 environments.
|
74 |
+
|
75 |
+
### 1. Setting up the environment
|
76 |
+
|
77 |
+
* Environment check
|
78 |
+
|
79 |
+
```shell
|
80 |
+
# First, make sure your machine has Python 3.8 or higher installed
|
81 |
+
$ python --version
|
82 |
+
Python 3.8.13
|
83 |
+
|
84 |
+
# If your version is lower, you can use conda to install the environment
|
85 |
+
$ conda create -p /your_path/env_name python=3.8
|
86 |
+
|
87 |
+
# Activate the environment
|
88 |
+
$ source activate /your_path/env_name
|
89 |
+
|
90 |
+
# Deactivate the environment
|
91 |
+
$ source deactivate /your_path/env_name
|
92 |
+
|
93 |
+
# Remove the environment
|
94 |
+
$ conda env remove -p /your_path/env_name
|
95 |
+
```
|
96 |
+
|
97 |
+
* Project dependencies
|
98 |
+
|
99 |
+
```shell
|
100 |
+
|
101 |
+
# Clone the repository
|
102 |
+
$ git clone https://github.com/imClumsyPanda/langchain-ChatGLM.git
|
103 |
+
|
104 |
+
# Install dependencies
|
105 |
+
$ pip install -r requirements.txt
|
106 |
+
```
|
107 |
+
|
108 |
+
Note: When using langchain.document_loaders.UnstructuredFileLoader for unstructured file integration, you may need to install other dependency packages according to the documentation. Please refer to [langchain documentation](https://python.langchain.com/en/latest/modules/indexes/document_loaders/examples/unstructured_file.html).
|
109 |
+
|
110 |
+
### 2. Run Scripts to Experience Web UI or Command Line Interaction
|
111 |
+
|
112 |
+
Execute [webui.py](webui.py) script to experience **Web interaction** <img src="https://img.shields.io/badge/Version-0.1-brightgreen">
|
113 |
+
```commandline
|
114 |
+
python webui.py
|
115 |
+
|
116 |
+
```
|
117 |
+
Or execute [api.py](api.py) script to deploy web api.
|
118 |
+
```shell
|
119 |
+
$ python api.py
|
120 |
+
```
|
121 |
+
Note: Before executing, check the remaining space in the `$HOME/.cache/huggingface/` folder, at least 15G.
|
122 |
+
|
123 |
+
Or execute following command to run VUE after api.py executed
|
124 |
+
```shell
|
125 |
+
$ cd views
|
126 |
+
|
127 |
+
$ pnpm i
|
128 |
+
|
129 |
+
$ npm run dev
|
130 |
+
```
|
131 |
+
|
132 |
+
VUE interface screenshots:
|
133 |
+
|
134 |
+
![](img/vue_0521_0.png)
|
135 |
+
|
136 |
+
![](img/vue_0521_1.png)
|
137 |
+
|
138 |
+
![](img/vue_0521_2.png)
|
139 |
+
|
140 |
+
Web UI interface screenshots:
|
141 |
+
|
142 |
+
![img.png](img/webui_0521_0.png)
|
143 |
+
|
144 |
+
![](img/webui_0510_1.png)
|
145 |
+
|
146 |
+
![](img/webui_0510_2.png)
|
147 |
+
|
148 |
+
The Web UI supports the following features:
|
149 |
+
|
150 |
+
1. Automatically reads the `LLM` and `embedding` model enumerations in `configs/model_config.py`, allowing you to select and reload the model by clicking `重新加载模型`.
|
151 |
+
2. The length of retained dialogue history can be manually adjusted according to the available video memory.
|
152 |
+
3. Adds a file upload function. Select the uploaded file through the drop-down box, click `加载文件` to load the file, and change the loaded file at any time during the process.
|
153 |
+
|
154 |
+
Alternatively, execute the [knowledge_based_chatglm.py](https://chat.openai.com/chat/cli_demo.py) script to experience **command line interaction**:
|
155 |
+
|
156 |
+
```commandline
|
157 |
+
python knowledge_based_chatglm.py
|
158 |
+
```
|
159 |
+
|
160 |
+
### FAQ
|
161 |
+
|
162 |
+
Q1: What file formats does this project support?
|
163 |
+
|
164 |
+
A1: Currently, this project has been tested with txt, docx, and md file formats. For more file formats, please refer to the [langchain documentation](https://python.langchain.com/en/latest/modules/indexes/document_loaders/examples/unstructured_file.html). It is known that if the document contains special characters, there might be issues with loading the file.
|
165 |
+
|
166 |
+
Q2: How can I resolve the `detectron2` dependency issue when reading specific file formats?
|
167 |
+
|
168 |
+
A2: As the installation process for this package can be problematic and it is only required for some file formats, it is not included in `requirements.txt`. You can install it with the following command:
|
169 |
+
|
170 |
+
```commandline
|
171 |
+
pip install "detectron2@git+https://github.com/facebookresearch/detectron2.git@v0.6#egg=detectron2"
|
172 |
+
```
|
173 |
+
|
174 |
+
Q3: How can I solve the `Resource punkt not found.` error?
|
175 |
+
|
176 |
+
A3: Unzip the `packages/tokenizers` folder from https://github.com/nltk/nltk_data/raw/gh-pages/packages/tokenizers/punkt.zip, and place it in the `nltk_data/tokenizers` storage path.
|
177 |
+
|
178 |
+
The `nltk_data` storage path can be found using `nltk.data.path`.
|
179 |
+
|
180 |
+
Q4: How can I solve the `Resource averaged_perceptron_tagger not found.` error?
|
181 |
+
|
182 |
+
A4: Download https://github.com/nltk/nltk_data/blob/gh-pages/packages/taggers/averaged_perceptron_tagger.zip, extract it, and place it in the `nltk_data/taggers` storage path.
|
183 |
+
|
184 |
+
The `nltk_data` storage path can be found using `nltk.data.path`.
|
185 |
+
|
186 |
+
Q5: Can this project run in Google Colab?
|
187 |
+
|
188 |
+
A5: You can try running the chatglm-6b-int4 model in Google Colab. Please note that if you want to run the Web UI in Colab, you need to set the `share` parameter in `demo.queue(concurrency_count=3).launch(server_name='0.0.0.0', share=False, inbrowser=False)` to `True`.
|
189 |
+
|
190 |
+
This issue is related to the system environment. For more details, please refer to [Issues with installing packages using pip in Anaconda](docs/Issue-with-Installing-Packages-Using-pip-in-Anaconda.md).
|
191 |
+
|
192 |
+
## DEMO
|
193 |
+
|
194 |
+
Using the question "What are the limitations of ChatGLM-6B and how can they be improved?" as an example:
|
195 |
+
|
196 |
+
Without using langchain to access local documents:
|
197 |
+
|
198 |
+
> ChatGLM-6B is a chatbot based on GPT-3, and its limitations are mainly reflected in the following aspects:
|
199 |
+
>
|
200 |
+
> 1. Language model accuracy: Although GPT-3 is a large language model with relatively high accuracy, it still has some limitations. For instance, it may not be able to accurately understand certain dialects, accents, or complex sentence structures.
|
201 |
+
> 2. Completeness of the knowledge graph: GPT-3 can generate some knowledge graphs, but its completeness and accuracy may not be sufficient. For example, it may struggle to accurately understand specific domain expertise and terminology.
|
202 |
+
> 3. Coherence of dialogue generation: While GPT-3 can generate somewhat coherent dialogues, they may not always be natural. For instance, it may generate false, incoherent dialogue or misinterpret the user's intent.
|
203 |
+
>
|
204 |
+
> To improve ChatGLM-6B, consider the following aspects:
|
205 |
+
>
|
206 |
+
> 1. Enhance the accuracy of the language model: The accuracy of the language model can be improved by increasing the training data, using better language model architectures, and optimizing the model training process.
|
207 |
+
> 2. Strengthen the integrity of the knowledge graph: The integrity of the knowledge graph can be enhanced by increasing its dimensions, employing better knowledge graph construction methods, and optimizing the knowledge graph's matching algorithm.
|
208 |
+
> 3. Boost the coherence of dialogue generation: The coherence of dialogue generation can be improved by augmenting the context information of the dialogue generation model, utilizing better context information extraction methods, and optimizing the dialogue generation model's algorithm.
|
209 |
+
> 4. Incorporate more human feedback: Human feedback can help ChatGLM-6B better understand users' intentions and needs, thereby improving the quality and accuracy of dialogue generation. More human feedback can be introduced by involving more human agents and using human feedback collectors.
|
210 |
+
|
211 |
+
ChatGLM's answer after using LangChain to access the README.md file of the ChatGLM-6B project:
|
212 |
+
>ChatGLM-6B is a deep learning-based natural language processing model that excels at answering questions. However, it also has some limitations, as follows:
|
213 |
+
>1. Limited language understanding: ChatGLM-6B has been primarily trained on Chinese natural language, and its understanding of other languages may be limited.
|
214 |
+
>2. Insufficient knowledge base: The training dataset of ChatGLM-6B contains only a Chinese corpus, so it may not be able to answer non-Chinese questions or queries in specific domains.
|
215 |
+
>3. Limited data volume: ChatGLM-6B's training dataset has only a few million records, which may hinder its ability to answer very specific or complex questions.
|
216 |
+
>
|
217 |
+
>To improve ChatGLM-6B, consider the following aspects:
|
218 |
+
>1. Expand language knowledge: Learn natural language processing techniques in other languages to broaden the model's language understanding capabilities.
|
219 |
+
>2. Broaden the knowledge base: Collect more Chinese corpora or use datasets in other languages to expand the model's knowledge base.
|
220 |
+
>3. Increase data volume: Use larger datasets to train ChatGLM-6B, which can improve the model's performance.
|
221 |
+
>4. Introduce more evaluation metrics: Incorporate additional evaluation metrics to assess the model's performance, which can help identify the shortcomings and limitations of ChatGLM-6B.
|
222 |
+
>5. Enhance the model architecture: Improve ChatGLM-6B's model architecture to boost its performance and capabilities. For example, employ larger neural networks or refined convolutional neural network structures.
|
223 |
+
|
224 |
+
## Roadmap
|
225 |
+
|
226 |
+
- [x] Implement LangChain + ChatGLM-6B for local knowledge application
|
227 |
+
- [x] Unstructured file access based on langchain
|
228 |
+
- [x].md
|
229 |
+
- [x].pdf
|
230 |
+
- [x].docx
|
231 |
+
- [x].txt
|
232 |
+
- [ ] Add support for more LLM models
|
233 |
+
- [x] THUDM/chatglm-6b
|
234 |
+
- [x] THUDM/chatglm-6b-int4
|
235 |
+
- [x] THUDM/chatglm-6b-int4-qe
|
236 |
+
- [ ] Add Web UI DEMO
|
237 |
+
- [x] Implement Web UI DEMO using Gradio
|
238 |
+
- [x] Add output and error messages
|
239 |
+
- [x] Citation callout
|
240 |
+
- [ ] Knowledge base management
|
241 |
+
- [x] QA based on selected knowledge base
|
242 |
+
- [x] Add files/folder to knowledge base
|
243 |
+
- [ ] Add files/folder to knowledge base
|
244 |
+
- [ ] Implement Web UI DEMO using Streamlit
|
245 |
+
- [ ] Add support for API deployment
|
246 |
+
- [x] Use fastapi to implement API
|
247 |
+
- [ ] Implement Web UI DEMO for API calls
|
api.py
ADDED
@@ -0,0 +1,465 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import shutil
|
5 |
+
from typing import List, Optional
|
6 |
+
import urllib
|
7 |
+
|
8 |
+
import nltk
|
9 |
+
import pydantic
|
10 |
+
import uvicorn
|
11 |
+
from fastapi import Body, FastAPI, File, Form, Query, UploadFile, WebSocket
|
12 |
+
from fastapi.middleware.cors import CORSMiddleware
|
13 |
+
from pydantic import BaseModel
|
14 |
+
from typing_extensions import Annotated
|
15 |
+
from starlette.responses import RedirectResponse
|
16 |
+
|
17 |
+
from chains.local_doc_qa import LocalDocQA
|
18 |
+
from configs.model_config import (KB_ROOT_PATH, EMBEDDING_DEVICE,
|
19 |
+
EMBEDDING_MODEL, NLTK_DATA_PATH,
|
20 |
+
VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN, OPEN_CROSS_DOMAIN)
|
21 |
+
import models.shared as shared
|
22 |
+
from models.loader.args import parser
|
23 |
+
from models.loader import LoaderCheckPoint
|
24 |
+
|
25 |
+
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
26 |
+
|
27 |
+
|
28 |
+
class BaseResponse(BaseModel):
|
29 |
+
code: int = pydantic.Field(200, description="HTTP status code")
|
30 |
+
msg: str = pydantic.Field("success", description="HTTP status message")
|
31 |
+
|
32 |
+
class Config:
|
33 |
+
schema_extra = {
|
34 |
+
"example": {
|
35 |
+
"code": 200,
|
36 |
+
"msg": "success",
|
37 |
+
}
|
38 |
+
}
|
39 |
+
|
40 |
+
|
41 |
+
class ListDocsResponse(BaseResponse):
|
42 |
+
data: List[str] = pydantic.Field(..., description="List of document names")
|
43 |
+
|
44 |
+
class Config:
|
45 |
+
schema_extra = {
|
46 |
+
"example": {
|
47 |
+
"code": 200,
|
48 |
+
"msg": "success",
|
49 |
+
"data": ["doc1.docx", "doc2.pdf", "doc3.txt"],
|
50 |
+
}
|
51 |
+
}
|
52 |
+
|
53 |
+
|
54 |
+
class ChatMessage(BaseModel):
|
55 |
+
question: str = pydantic.Field(..., description="Question text")
|
56 |
+
response: str = pydantic.Field(..., description="Response text")
|
57 |
+
history: List[List[str]] = pydantic.Field(..., description="History text")
|
58 |
+
source_documents: List[str] = pydantic.Field(
|
59 |
+
..., description="List of source documents and their scores"
|
60 |
+
)
|
61 |
+
|
62 |
+
class Config:
|
63 |
+
schema_extra = {
|
64 |
+
"example": {
|
65 |
+
"question": "工伤保险如何办理?",
|
66 |
+
"response": "根据已知信息,可以总结如下:\n\n1. 参保单位为员工缴纳工伤保险费,以保障员工在发生工伤时能够获得相应的待遇。\n2. 不同地区的工伤保险缴费规定可能有所不同,需要向当地社保部门咨询以了解具体的缴费标准和规定。\n3. 工伤从业人员及其近亲属需要申请工伤认定,确认享受的待遇资格,并按时缴纳工伤保险费。\n4. 工伤保险待遇包括工伤医疗、康复、辅助器具配置费用、伤残待遇、工亡待遇、一次性工亡补助金等。\n5. 工伤保险待遇领取资格认证包括长期待遇领取人员认证和一次性待遇领取人员认证。\n6. 工伤保险基金支付的待遇项目包括工伤医疗待遇、康复待遇、辅助器具配置费用、一次性工亡补助金、丧葬补助金等。",
|
67 |
+
"history": [
|
68 |
+
[
|
69 |
+
"工伤保险是什么?",
|
70 |
+
"工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
|
71 |
+
]
|
72 |
+
],
|
73 |
+
"source_documents": [
|
74 |
+
"出处 [1] 广州市单位从业的特定人员参加工伤保险办事指引.docx:\n\n\t( 一) 从业单位 (组织) 按“自愿参保”原则, 为未建 立劳动关系的特定从业人员单项参加工伤保险 、缴纳工伤保 险费。",
|
75 |
+
"出处 [2] ...",
|
76 |
+
"出处 [3] ...",
|
77 |
+
],
|
78 |
+
}
|
79 |
+
}
|
80 |
+
|
81 |
+
|
82 |
+
def get_folder_path(local_doc_id: str):
|
83 |
+
return os.path.join(KB_ROOT_PATH, local_doc_id, "content")
|
84 |
+
|
85 |
+
|
86 |
+
def get_vs_path(local_doc_id: str):
|
87 |
+
return os.path.join(KB_ROOT_PATH, local_doc_id, "vector_store")
|
88 |
+
|
89 |
+
|
90 |
+
def get_file_path(local_doc_id: str, doc_name: str):
|
91 |
+
return os.path.join(KB_ROOT_PATH, local_doc_id, "content", doc_name)
|
92 |
+
|
93 |
+
|
94 |
+
async def upload_file(
|
95 |
+
file: UploadFile = File(description="A single binary file"),
|
96 |
+
knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
|
97 |
+
):
|
98 |
+
saved_path = get_folder_path(knowledge_base_id)
|
99 |
+
if not os.path.exists(saved_path):
|
100 |
+
os.makedirs(saved_path)
|
101 |
+
|
102 |
+
file_content = await file.read() # 读取上传文件的内容
|
103 |
+
|
104 |
+
file_path = os.path.join(saved_path, file.filename)
|
105 |
+
if os.path.exists(file_path) and os.path.getsize(file_path) == len(file_content):
|
106 |
+
file_status = f"文件 {file.filename} 已存在。"
|
107 |
+
return BaseResponse(code=200, msg=file_status)
|
108 |
+
|
109 |
+
with open(file_path, "wb") as f:
|
110 |
+
f.write(file_content)
|
111 |
+
|
112 |
+
vs_path = get_vs_path(knowledge_base_id)
|
113 |
+
vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store([file_path], vs_path)
|
114 |
+
if len(loaded_files) > 0:
|
115 |
+
file_status = f"文件 {file.filename} 已上传至新的知识库,并已加载知识库,请开始提问。"
|
116 |
+
return BaseResponse(code=200, msg=file_status)
|
117 |
+
else:
|
118 |
+
file_status = "文件上传失败,请重新上传"
|
119 |
+
return BaseResponse(code=500, msg=file_status)
|
120 |
+
|
121 |
+
|
122 |
+
async def upload_files(
|
123 |
+
files: Annotated[
|
124 |
+
List[UploadFile], File(description="Multiple files as UploadFile")
|
125 |
+
],
|
126 |
+
knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
|
127 |
+
):
|
128 |
+
saved_path = get_folder_path(knowledge_base_id)
|
129 |
+
if not os.path.exists(saved_path):
|
130 |
+
os.makedirs(saved_path)
|
131 |
+
filelist = []
|
132 |
+
for file in files:
|
133 |
+
file_content = ''
|
134 |
+
file_path = os.path.join(saved_path, file.filename)
|
135 |
+
file_content = file.file.read()
|
136 |
+
if os.path.exists(file_path) and os.path.getsize(file_path) == len(file_content):
|
137 |
+
continue
|
138 |
+
with open(file_path, "ab+") as f:
|
139 |
+
f.write(file_content)
|
140 |
+
filelist.append(file_path)
|
141 |
+
if filelist:
|
142 |
+
vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, get_vs_path(knowledge_base_id))
|
143 |
+
if len(loaded_files):
|
144 |
+
file_status = f"documents {', '.join([os.path.split(i)[-1] for i in loaded_files])} upload success"
|
145 |
+
return BaseResponse(code=200, msg=file_status)
|
146 |
+
file_status = f"documents {', '.join([os.path.split(i)[-1] for i in loaded_files])} upload fail"
|
147 |
+
return BaseResponse(code=500, msg=file_status)
|
148 |
+
|
149 |
+
|
150 |
+
async def list_kbs():
|
151 |
+
# Get List of Knowledge Base
|
152 |
+
if not os.path.exists(KB_ROOT_PATH):
|
153 |
+
all_doc_ids = []
|
154 |
+
else:
|
155 |
+
all_doc_ids = [
|
156 |
+
folder
|
157 |
+
for folder in os.listdir(KB_ROOT_PATH)
|
158 |
+
if os.path.isdir(os.path.join(KB_ROOT_PATH, folder))
|
159 |
+
and os.path.exists(os.path.join(KB_ROOT_PATH, folder, "vector_store", "index.faiss"))
|
160 |
+
]
|
161 |
+
|
162 |
+
return ListDocsResponse(data=all_doc_ids)
|
163 |
+
|
164 |
+
|
165 |
+
async def list_docs(
|
166 |
+
knowledge_base_id: Optional[str] = Query(default=None, description="Knowledge Base Name", example="kb1")
|
167 |
+
):
|
168 |
+
local_doc_folder = get_folder_path(knowledge_base_id)
|
169 |
+
if not os.path.exists(local_doc_folder):
|
170 |
+
return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"}
|
171 |
+
all_doc_names = [
|
172 |
+
doc
|
173 |
+
for doc in os.listdir(local_doc_folder)
|
174 |
+
if os.path.isfile(os.path.join(local_doc_folder, doc))
|
175 |
+
]
|
176 |
+
return ListDocsResponse(data=all_doc_names)
|
177 |
+
|
178 |
+
|
179 |
+
async def delete_kb(
|
180 |
+
knowledge_base_id: str = Query(...,
|
181 |
+
description="Knowledge Base Name",
|
182 |
+
example="kb1"),
|
183 |
+
):
|
184 |
+
# TODO: 确认是否支持批量删除知识库
|
185 |
+
knowledge_base_id = urllib.parse.unquote(knowledge_base_id)
|
186 |
+
if not os.path.exists(get_folder_path(knowledge_base_id)):
|
187 |
+
return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"}
|
188 |
+
shutil.rmtree(get_folder_path(knowledge_base_id))
|
189 |
+
return BaseResponse(code=200, msg=f"Knowledge Base {knowledge_base_id} delete success")
|
190 |
+
|
191 |
+
|
192 |
+
async def delete_doc(
|
193 |
+
knowledge_base_id: str = Query(...,
|
194 |
+
description="Knowledge Base Name",
|
195 |
+
example="kb1"),
|
196 |
+
doc_name: str = Query(
|
197 |
+
None, description="doc name", example="doc_name_1.pdf"
|
198 |
+
),
|
199 |
+
):
|
200 |
+
knowledge_base_id = urllib.parse.unquote(knowledge_base_id)
|
201 |
+
if not os.path.exists(get_folder_path(knowledge_base_id)):
|
202 |
+
return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"}
|
203 |
+
doc_path = get_file_path(knowledge_base_id, doc_name)
|
204 |
+
if os.path.exists(doc_path):
|
205 |
+
os.remove(doc_path)
|
206 |
+
remain_docs = await list_docs(knowledge_base_id)
|
207 |
+
if len(remain_docs.data) == 0:
|
208 |
+
shutil.rmtree(get_folder_path(knowledge_base_id), ignore_errors=True)
|
209 |
+
return BaseResponse(code=200, msg=f"document {doc_name} delete success")
|
210 |
+
else:
|
211 |
+
status = local_doc_qa.delete_file_from_vector_store(doc_path, get_vs_path(knowledge_base_id))
|
212 |
+
if "success" in status:
|
213 |
+
return BaseResponse(code=200, msg=f"document {doc_name} delete success")
|
214 |
+
else:
|
215 |
+
return BaseResponse(code=1, msg=f"document {doc_name} delete fail")
|
216 |
+
else:
|
217 |
+
return BaseResponse(code=1, msg=f"document {doc_name} not found")
|
218 |
+
|
219 |
+
|
220 |
+
async def update_doc(
|
221 |
+
knowledge_base_id: str = Query(...,
|
222 |
+
description="知识库名",
|
223 |
+
example="kb1"),
|
224 |
+
old_doc: str = Query(
|
225 |
+
None, description="待删除文件名,已存储在知识库中", example="doc_name_1.pdf"
|
226 |
+
),
|
227 |
+
new_doc: UploadFile = File(description="待上传文件"),
|
228 |
+
):
|
229 |
+
knowledge_base_id = urllib.parse.unquote(knowledge_base_id)
|
230 |
+
if not os.path.exists(get_folder_path(knowledge_base_id)):
|
231 |
+
return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"}
|
232 |
+
doc_path = get_file_path(knowledge_base_id, old_doc)
|
233 |
+
if not os.path.exists(doc_path):
|
234 |
+
return BaseResponse(code=1, msg=f"document {old_doc} not found")
|
235 |
+
else:
|
236 |
+
os.remove(doc_path)
|
237 |
+
delete_status = local_doc_qa.delete_file_from_vector_store(doc_path, get_vs_path(knowledge_base_id))
|
238 |
+
if "fail" in delete_status:
|
239 |
+
return BaseResponse(code=1, msg=f"document {old_doc} delete failed")
|
240 |
+
else:
|
241 |
+
saved_path = get_folder_path(knowledge_base_id)
|
242 |
+
if not os.path.exists(saved_path):
|
243 |
+
os.makedirs(saved_path)
|
244 |
+
|
245 |
+
file_content = await new_doc.read() # 读取上传文件的内容
|
246 |
+
|
247 |
+
file_path = os.path.join(saved_path, new_doc.filename)
|
248 |
+
if os.path.exists(file_path) and os.path.getsize(file_path) == len(file_content):
|
249 |
+
file_status = f"document {new_doc.filename} already exists"
|
250 |
+
return BaseResponse(code=200, msg=file_status)
|
251 |
+
|
252 |
+
with open(file_path, "wb") as f:
|
253 |
+
f.write(file_content)
|
254 |
+
|
255 |
+
vs_path = get_vs_path(knowledge_base_id)
|
256 |
+
vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store([file_path], vs_path)
|
257 |
+
if len(loaded_files) > 0:
|
258 |
+
file_status = f"document {old_doc} delete and document {new_doc.filename} upload success"
|
259 |
+
return BaseResponse(code=200, msg=file_status)
|
260 |
+
else:
|
261 |
+
file_status = f"document {old_doc} success but document {new_doc.filename} upload fail"
|
262 |
+
return BaseResponse(code=500, msg=file_status)
|
263 |
+
|
264 |
+
|
265 |
+
|
266 |
+
async def local_doc_chat(
|
267 |
+
knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"),
|
268 |
+
question: str = Body(..., description="Question", example="工伤保险是什么?"),
|
269 |
+
history: List[List[str]] = Body(
|
270 |
+
[],
|
271 |
+
description="History of previous questions and answers",
|
272 |
+
example=[
|
273 |
+
[
|
274 |
+
"工伤保险是什么?",
|
275 |
+
"工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
|
276 |
+
]
|
277 |
+
],
|
278 |
+
),
|
279 |
+
):
|
280 |
+
vs_path = get_vs_path(knowledge_base_id)
|
281 |
+
if not os.path.exists(vs_path):
|
282 |
+
# return BaseResponse(code=1, msg=f"Knowledge base {knowledge_base_id} not found")
|
283 |
+
return ChatMessage(
|
284 |
+
question=question,
|
285 |
+
response=f"Knowledge base {knowledge_base_id} not found",
|
286 |
+
history=history,
|
287 |
+
source_documents=[],
|
288 |
+
)
|
289 |
+
else:
|
290 |
+
for resp, history in local_doc_qa.get_knowledge_based_answer(
|
291 |
+
query=question, vs_path=vs_path, chat_history=history, streaming=True
|
292 |
+
):
|
293 |
+
pass
|
294 |
+
source_documents = [
|
295 |
+
f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
|
296 |
+
f"""相关度:{doc.metadata['score']}\n\n"""
|
297 |
+
for inum, doc in enumerate(resp["source_documents"])
|
298 |
+
]
|
299 |
+
|
300 |
+
return ChatMessage(
|
301 |
+
question=question,
|
302 |
+
response=resp["result"],
|
303 |
+
history=history,
|
304 |
+
source_documents=source_documents,
|
305 |
+
)
|
306 |
+
|
307 |
+
|
308 |
+
async def bing_search_chat(
|
309 |
+
question: str = Body(..., description="Question", example="工伤保险是什么?"),
|
310 |
+
history: Optional[List[List[str]]] = Body(
|
311 |
+
[],
|
312 |
+
description="History of previous questions and answers",
|
313 |
+
example=[
|
314 |
+
[
|
315 |
+
"工伤保险是什么?",
|
316 |
+
"工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
|
317 |
+
]
|
318 |
+
],
|
319 |
+
),
|
320 |
+
):
|
321 |
+
for resp, history in local_doc_qa.get_search_result_based_answer(
|
322 |
+
query=question, chat_history=history, streaming=True
|
323 |
+
):
|
324 |
+
pass
|
325 |
+
source_documents = [
|
326 |
+
f"""出处 [{inum + 1}] [{doc.metadata["source"]}]({doc.metadata["source"]}) \n\n{doc.page_content}\n\n"""
|
327 |
+
for inum, doc in enumerate(resp["source_documents"])
|
328 |
+
]
|
329 |
+
|
330 |
+
return ChatMessage(
|
331 |
+
question=question,
|
332 |
+
response=resp["result"],
|
333 |
+
history=history,
|
334 |
+
source_documents=source_documents,
|
335 |
+
)
|
336 |
+
|
337 |
+
|
338 |
+
async def chat(
|
339 |
+
question: str = Body(..., description="Question", example="工伤保险是什么?"),
|
340 |
+
history: List[List[str]] = Body(
|
341 |
+
[],
|
342 |
+
description="History of previous questions and answers",
|
343 |
+
example=[
|
344 |
+
[
|
345 |
+
"工伤保险是什么?",
|
346 |
+
"工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,由保险机构按照���家规定的标准,给予工伤保险待遇的社会保险制度。",
|
347 |
+
]
|
348 |
+
],
|
349 |
+
),
|
350 |
+
):
|
351 |
+
for answer_result in local_doc_qa.llm.generatorAnswer(prompt=question, history=history,
|
352 |
+
streaming=True):
|
353 |
+
resp = answer_result.llm_output["answer"]
|
354 |
+
history = answer_result.history
|
355 |
+
pass
|
356 |
+
|
357 |
+
return ChatMessage(
|
358 |
+
question=question,
|
359 |
+
response=resp,
|
360 |
+
history=history,
|
361 |
+
source_documents=[],
|
362 |
+
)
|
363 |
+
|
364 |
+
|
365 |
+
async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
|
366 |
+
await websocket.accept()
|
367 |
+
turn = 1
|
368 |
+
while True:
|
369 |
+
input_json = await websocket.receive_json()
|
370 |
+
question, history, knowledge_base_id = input_json["question"], input_json["history"], input_json[
|
371 |
+
"knowledge_base_id"]
|
372 |
+
vs_path = get_vs_path(knowledge_base_id)
|
373 |
+
|
374 |
+
if not os.path.exists(vs_path):
|
375 |
+
await websocket.send_json({"error": f"Knowledge base {knowledge_base_id} not found"})
|
376 |
+
await websocket.close()
|
377 |
+
return
|
378 |
+
|
379 |
+
await websocket.send_json({"question": question, "turn": turn, "flag": "start"})
|
380 |
+
|
381 |
+
last_print_len = 0
|
382 |
+
for resp, history in local_doc_qa.get_knowledge_based_answer(
|
383 |
+
query=question, vs_path=vs_path, chat_history=history, streaming=True
|
384 |
+
):
|
385 |
+
await websocket.send_text(resp["result"][last_print_len:])
|
386 |
+
last_print_len = len(resp["result"])
|
387 |
+
|
388 |
+
source_documents = [
|
389 |
+
f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
|
390 |
+
f"""相关度:{doc.metadata['score']}\n\n"""
|
391 |
+
for inum, doc in enumerate(resp["source_documents"])
|
392 |
+
]
|
393 |
+
|
394 |
+
await websocket.send_text(
|
395 |
+
json.dumps(
|
396 |
+
{
|
397 |
+
"question": question,
|
398 |
+
"turn": turn,
|
399 |
+
"flag": "end",
|
400 |
+
"sources_documents": source_documents,
|
401 |
+
},
|
402 |
+
ensure_ascii=False,
|
403 |
+
)
|
404 |
+
)
|
405 |
+
turn += 1
|
406 |
+
|
407 |
+
|
408 |
+
async def document():
|
409 |
+
return RedirectResponse(url="/docs")
|
410 |
+
|
411 |
+
|
412 |
+
def api_start(host, port):
|
413 |
+
global app
|
414 |
+
global local_doc_qa
|
415 |
+
|
416 |
+
llm_model_ins = shared.loaderLLM()
|
417 |
+
llm_model_ins.set_history_len(LLM_HISTORY_LEN)
|
418 |
+
|
419 |
+
app = FastAPI()
|
420 |
+
# Add CORS middleware to allow all origins
|
421 |
+
# 在config.py中设置OPEN_DOMAIN=True,允许跨域
|
422 |
+
# set OPEN_DOMAIN=True in config.py to allow cross-domain
|
423 |
+
if OPEN_CROSS_DOMAIN:
|
424 |
+
app.add_middleware(
|
425 |
+
CORSMiddleware,
|
426 |
+
allow_origins=["*"],
|
427 |
+
allow_credentials=True,
|
428 |
+
allow_methods=["*"],
|
429 |
+
allow_headers=["*"],
|
430 |
+
)
|
431 |
+
app.websocket("/local_doc_qa/stream-chat/{knowledge_base_id}")(stream_chat)
|
432 |
+
|
433 |
+
app.get("/", response_model=BaseResponse)(document)
|
434 |
+
|
435 |
+
app.post("/chat", response_model=ChatMessage)(chat)
|
436 |
+
|
437 |
+
app.post("/local_doc_qa/upload_file", response_model=BaseResponse)(upload_file)
|
438 |
+
app.post("/local_doc_qa/upload_files", response_model=BaseResponse)(upload_files)
|
439 |
+
app.post("/local_doc_qa/local_doc_chat", response_model=ChatMessage)(local_doc_chat)
|
440 |
+
app.post("/local_doc_qa/bing_search_chat", response_model=ChatMessage)(bing_search_chat)
|
441 |
+
app.get("/local_doc_qa/list_knowledge_base", response_model=ListDocsResponse)(list_kbs)
|
442 |
+
app.get("/local_doc_qa/list_files", response_model=ListDocsResponse)(list_docs)
|
443 |
+
app.delete("/local_doc_qa/delete_knowledge_base", response_model=BaseResponse)(delete_kb)
|
444 |
+
app.delete("/local_doc_qa/delete_file", response_model=BaseResponse)(delete_doc)
|
445 |
+
app.post("/local_doc_qa/update_file", response_model=BaseResponse)(update_doc)
|
446 |
+
|
447 |
+
local_doc_qa = LocalDocQA()
|
448 |
+
local_doc_qa.init_cfg(
|
449 |
+
llm_model=llm_model_ins,
|
450 |
+
embedding_model=EMBEDDING_MODEL,
|
451 |
+
embedding_device=EMBEDDING_DEVICE,
|
452 |
+
top_k=VECTOR_SEARCH_TOP_K,
|
453 |
+
)
|
454 |
+
uvicorn.run(app, host=host, port=port)
|
455 |
+
|
456 |
+
|
457 |
+
if __name__ == "__main__":
|
458 |
+
parser.add_argument("--host", type=str, default="0.0.0.0")
|
459 |
+
parser.add_argument("--port", type=int, default=7861)
|
460 |
+
# 初始化消息
|
461 |
+
args = None
|
462 |
+
args = parser.parse_args()
|
463 |
+
args_dict = vars(args)
|
464 |
+
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
|
465 |
+
api_start(args.host, args.port)
|
cli.bat
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
@echo off
|
2 |
+
python cli.py %*
|
cli.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import click
|
2 |
+
|
3 |
+
from api import api_start as api_start
|
4 |
+
from cli_demo import main as cli_start
|
5 |
+
from configs.model_config import llm_model_dict, embedding_model_dict
|
6 |
+
|
7 |
+
|
8 |
+
@click.group()
|
9 |
+
@click.version_option(version='1.0.0')
|
10 |
+
@click.pass_context
|
11 |
+
def cli(ctx):
|
12 |
+
pass
|
13 |
+
|
14 |
+
|
15 |
+
@cli.group()
|
16 |
+
def llm():
|
17 |
+
pass
|
18 |
+
|
19 |
+
|
20 |
+
@llm.command(name="ls")
|
21 |
+
def llm_ls():
|
22 |
+
for k in llm_model_dict.keys():
|
23 |
+
print(k)
|
24 |
+
|
25 |
+
|
26 |
+
@cli.group()
|
27 |
+
def embedding():
|
28 |
+
pass
|
29 |
+
|
30 |
+
|
31 |
+
@embedding.command(name="ls")
|
32 |
+
def embedding_ls():
|
33 |
+
for k in embedding_model_dict.keys():
|
34 |
+
print(k)
|
35 |
+
|
36 |
+
|
37 |
+
@cli.group()
|
38 |
+
def start():
|
39 |
+
pass
|
40 |
+
|
41 |
+
|
42 |
+
@start.command(name="api", context_settings=dict(help_option_names=['-h', '--help']))
|
43 |
+
@click.option('-i', '--ip', default='0.0.0.0', show_default=True, type=str, help='api_server listen address.')
|
44 |
+
@click.option('-p', '--port', default=7861, show_default=True, type=int, help='api_server listen port.')
|
45 |
+
def start_api(ip, port):
|
46 |
+
# 调用api_start之前需要先loadCheckPoint,并传入加载检查点的参数,
|
47 |
+
# 理论上可以用click包进行包装,但过于繁琐,改动较大,
|
48 |
+
# 此处仍用parser包,并以models.loader.args.DEFAULT_ARGS的参数为默认参数
|
49 |
+
# 如有改动需要可以更改models.loader.args.DEFAULT_ARGS
|
50 |
+
from models import shared
|
51 |
+
from models.loader import LoaderCheckPoint
|
52 |
+
from models.loader.args import DEFAULT_ARGS
|
53 |
+
shared.loaderCheckPoint = LoaderCheckPoint(DEFAULT_ARGS)
|
54 |
+
api_start(host=ip, port=port)
|
55 |
+
|
56 |
+
# # 通过cli.py调用cli_demo时需要在cli.py里初始化模型,否则会报错:
|
57 |
+
# langchain-ChatGLM: error: unrecognized arguments: start cli
|
58 |
+
# 为此需要先将
|
59 |
+
# args = None
|
60 |
+
# args = parser.parse_args()
|
61 |
+
# args_dict = vars(args)
|
62 |
+
# shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
|
63 |
+
# 语句从main函数里取出放到函数外部
|
64 |
+
# 然后在cli.py里初始化
|
65 |
+
|
66 |
+
@start.command(name="cli", context_settings=dict(help_option_names=['-h', '--help']))
|
67 |
+
def start_cli():
|
68 |
+
print("通过cli.py调用cli_demo...")
|
69 |
+
|
70 |
+
from models import shared
|
71 |
+
from models.loader import LoaderCheckPoint
|
72 |
+
from models.loader.args import DEFAULT_ARGS
|
73 |
+
shared.loaderCheckPoint = LoaderCheckPoint(DEFAULT_ARGS)
|
74 |
+
cli_start()
|
75 |
+
|
76 |
+
# 同cli命令,通过cli.py调用webui时,argparse的初始化需要放到cli.py里,
|
77 |
+
# 但由于webui.py里,模型初始化通过init_model函数实现,也无法简单地分离出主函数,
|
78 |
+
# 因此除非对webui进行大改,否则无法通过python cli.py start webui 调用webui。
|
79 |
+
# 故建议不要通过以上命令启动webui,将下述语句注释掉
|
80 |
+
|
81 |
+
@start.command(name="webui", context_settings=dict(help_option_names=['-h', '--help']))
|
82 |
+
def start_webui():
|
83 |
+
import webui
|
84 |
+
|
85 |
+
|
86 |
+
cli()
|
cli.sh
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
python cli.py "$@"
|
cli_demo.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from configs.model_config import *
|
2 |
+
from chains.local_doc_qa import LocalDocQA
|
3 |
+
import os
|
4 |
+
import nltk
|
5 |
+
from models.loader.args import parser
|
6 |
+
import models.shared as shared
|
7 |
+
from models.loader import LoaderCheckPoint
|
8 |
+
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
9 |
+
|
10 |
+
# Show reply with source text from input document
|
11 |
+
REPLY_WITH_SOURCE = True
|
12 |
+
|
13 |
+
|
14 |
+
def main():
|
15 |
+
|
16 |
+
llm_model_ins = shared.loaderLLM()
|
17 |
+
llm_model_ins.history_len = LLM_HISTORY_LEN
|
18 |
+
|
19 |
+
local_doc_qa = LocalDocQA()
|
20 |
+
local_doc_qa.init_cfg(llm_model=llm_model_ins,
|
21 |
+
embedding_model=EMBEDDING_MODEL,
|
22 |
+
embedding_device=EMBEDDING_DEVICE,
|
23 |
+
top_k=VECTOR_SEARCH_TOP_K)
|
24 |
+
vs_path = None
|
25 |
+
while not vs_path:
|
26 |
+
filepath = input("Input your local knowledge file path 请输入本地知识文件路径:")
|
27 |
+
# 判断 filepath 是否为空,如果为空的话,重新让用户输入,防止用户误触回车
|
28 |
+
if not filepath:
|
29 |
+
continue
|
30 |
+
vs_path, _ = local_doc_qa.init_knowledge_vector_store(filepath)
|
31 |
+
history = []
|
32 |
+
while True:
|
33 |
+
query = input("Input your question 请输入问题:")
|
34 |
+
last_print_len = 0
|
35 |
+
for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
|
36 |
+
vs_path=vs_path,
|
37 |
+
chat_history=history,
|
38 |
+
streaming=STREAMING):
|
39 |
+
if STREAMING:
|
40 |
+
print(resp["result"][last_print_len:], end="", flush=True)
|
41 |
+
last_print_len = len(resp["result"])
|
42 |
+
else:
|
43 |
+
print(resp["result"])
|
44 |
+
if REPLY_WITH_SOURCE:
|
45 |
+
source_text = [f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
|
46 |
+
# f"""相关度:{doc.metadata['score']}\n\n"""
|
47 |
+
for inum, doc in
|
48 |
+
enumerate(resp["source_documents"])]
|
49 |
+
print("\n\n" + "\n\n".join(source_text))
|
50 |
+
|
51 |
+
|
52 |
+
if __name__ == "__main__":
|
53 |
+
# # 通过cli.py调用cli_demo时需要在cli.py里初始化模型,否则会报错:
|
54 |
+
# langchain-ChatGLM: error: unrecognized arguments: start cli
|
55 |
+
# 为此需要先将
|
56 |
+
# args = None
|
57 |
+
# args = parser.parse_args()
|
58 |
+
# args_dict = vars(args)
|
59 |
+
# shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
|
60 |
+
# 语句从main函数里取出放到函数外部
|
61 |
+
# 然后在cli.py里初始化
|
62 |
+
args = None
|
63 |
+
args = parser.parse_args()
|
64 |
+
args_dict = vars(args)
|
65 |
+
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
|
66 |
+
main()
|
release.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import subprocess
|
3 |
+
import re
|
4 |
+
|
5 |
+
def get_latest_tag():
|
6 |
+
output = subprocess.check_output(['git', 'tag'])
|
7 |
+
tags = output.decode('utf-8').split('\n')[:-1]
|
8 |
+
latest_tag = sorted(tags, key=lambda t: tuple(map(int, re.match(r'v(\d+)\.(\d+)\.(\d+)', t).groups())))[-1]
|
9 |
+
return latest_tag
|
10 |
+
|
11 |
+
def update_version_number(latest_tag, increment):
|
12 |
+
major, minor, patch = map(int, re.match(r'v(\d+)\.(\d+)\.(\d+)', latest_tag).groups())
|
13 |
+
if increment == 'X':
|
14 |
+
major += 1
|
15 |
+
minor, patch = 0, 0
|
16 |
+
elif increment == 'Y':
|
17 |
+
minor += 1
|
18 |
+
patch = 0
|
19 |
+
elif increment == 'Z':
|
20 |
+
patch += 1
|
21 |
+
new_version = f"v{major}.{minor}.{patch}"
|
22 |
+
return new_version
|
23 |
+
|
24 |
+
def main():
|
25 |
+
print("当前最近的Git标签:")
|
26 |
+
latest_tag = get_latest_tag()
|
27 |
+
print(latest_tag)
|
28 |
+
|
29 |
+
print("请选择要递增的版本号部分(X, Y, Z):")
|
30 |
+
increment = input().upper()
|
31 |
+
|
32 |
+
while increment not in ['X', 'Y', 'Z']:
|
33 |
+
print("输入错误,请输入X, Y或Z:")
|
34 |
+
increment = input().upper()
|
35 |
+
|
36 |
+
new_version = update_version_number(latest_tag, increment)
|
37 |
+
print(f"新的版本号为:{new_version}")
|
38 |
+
|
39 |
+
print("确认更新版本号并推送到远程仓库?(y/n)")
|
40 |
+
confirmation = input().lower()
|
41 |
+
|
42 |
+
if confirmation == 'y':
|
43 |
+
subprocess.run(['git', 'tag', new_version])
|
44 |
+
subprocess.run(['git', 'push', 'origin', new_version])
|
45 |
+
print("新版本号已创建并推送到远程仓库。")
|
46 |
+
else:
|
47 |
+
print("操作已取消。")
|
48 |
+
|
49 |
+
if __name__ == '__main__':
|
50 |
+
main()
|
requirements.txt
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pymupdf
|
2 |
+
paddlepaddle==2.4.2
|
3 |
+
paddleocr~=2.6.1.3
|
4 |
+
langchain~=0.0.174
|
5 |
+
transformers==4.29.1
|
6 |
+
unstructured[local-inference]
|
7 |
+
layoutparser[layoutmodels,tesseract]
|
8 |
+
nltk~=3.8.1
|
9 |
+
sentence-transformers
|
10 |
+
beautifulsoup4
|
11 |
+
icetk
|
12 |
+
cpm_kernels
|
13 |
+
faiss-cpu
|
14 |
+
gradio==3.28.3
|
15 |
+
fastapi~=0.95.0
|
16 |
+
uvicorn~=0.21.1
|
17 |
+
pypinyin~=0.48.0
|
18 |
+
click~=8.1.3
|
19 |
+
tabulate
|
20 |
+
feedparser
|
21 |
+
azure-core
|
22 |
+
openai
|
23 |
+
#accelerate~=0.18.0
|
24 |
+
#peft~=0.3.0
|
25 |
+
#bitsandbytes; platform_system != "Windows"
|
26 |
+
#llama-cpp-python==0.1.34; platform_system != "Windows"
|
27 |
+
#https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.34/llama_cpp_python-0.1.34-cp310-cp310-win_amd64.whl; platform_system == "Windows"
|
28 |
+
|
29 |
+
#torch~=2.0.0
|
30 |
+
pydantic~=1.10.7
|
31 |
+
starlette~=0.26.1
|
32 |
+
numpy~=1.23.5
|
33 |
+
tqdm~=4.65.0
|
34 |
+
requests~=2.28.2
|
35 |
+
tenacity~=8.2.2
|
36 |
+
charset_normalizer==2.1.0
|
webui.py
ADDED
@@ -0,0 +1,560 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import shutil
|
3 |
+
|
4 |
+
from chains.local_doc_qa import LocalDocQA
|
5 |
+
from configs.model_config import *
|
6 |
+
import nltk
|
7 |
+
import models.shared as shared
|
8 |
+
from models.loader.args import parser
|
9 |
+
from models.loader import LoaderCheckPoint
|
10 |
+
import os
|
11 |
+
|
12 |
+
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
13 |
+
|
14 |
+
|
15 |
+
def get_vs_list():
|
16 |
+
lst_default = ["新建知识库"]
|
17 |
+
if not os.path.exists(KB_ROOT_PATH):
|
18 |
+
return lst_default
|
19 |
+
lst = os.listdir(KB_ROOT_PATH)
|
20 |
+
if not lst:
|
21 |
+
return lst_default
|
22 |
+
lst.sort()
|
23 |
+
return lst_default + lst
|
24 |
+
|
25 |
+
|
26 |
+
embedding_model_dict_list = list(embedding_model_dict.keys())
|
27 |
+
|
28 |
+
llm_model_dict_list = list(llm_model_dict.keys())
|
29 |
+
|
30 |
+
local_doc_qa = LocalDocQA()
|
31 |
+
|
32 |
+
flag_csv_logger = gr.CSVLogger()
|
33 |
+
|
34 |
+
|
35 |
+
def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCORE_THRESHOLD,
|
36 |
+
vector_search_top_k=VECTOR_SEARCH_TOP_K, chunk_conent: bool = True,
|
37 |
+
chunk_size=CHUNK_SIZE, streaming: bool = STREAMING):
|
38 |
+
if mode == "Bing搜索问答":
|
39 |
+
for resp, history in local_doc_qa.get_search_result_based_answer(
|
40 |
+
query=query, chat_history=history, streaming=streaming):
|
41 |
+
source = "\n\n"
|
42 |
+
source += "".join(
|
43 |
+
[
|
44 |
+
f"""<details> <summary>出处 [{i + 1}] <a href="{doc.metadata["source"]}" target="_blank">{doc.metadata["source"]}</a> </summary>\n"""
|
45 |
+
f"""{doc.page_content}\n"""
|
46 |
+
f"""</details>"""
|
47 |
+
for i, doc in
|
48 |
+
enumerate(resp["source_documents"])])
|
49 |
+
history[-1][-1] += source
|
50 |
+
yield history, ""
|
51 |
+
elif mode == "知识库问答" and vs_path is not None and os.path.exists(vs_path) and "index.faiss" in os.listdir(
|
52 |
+
vs_path):
|
53 |
+
for resp, history in local_doc_qa.get_knowledge_based_answer(
|
54 |
+
query=query, vs_path=vs_path, chat_history=history, streaming=streaming):
|
55 |
+
source = "\n\n"
|
56 |
+
source += "".join(
|
57 |
+
[f"""<details> <summary>出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}</summary>\n"""
|
58 |
+
f"""{doc.page_content}\n"""
|
59 |
+
f"""</details>"""
|
60 |
+
for i, doc in
|
61 |
+
enumerate(resp["source_documents"])])
|
62 |
+
history[-1][-1] += source
|
63 |
+
yield history, ""
|
64 |
+
elif mode == "知识库测试":
|
65 |
+
if os.path.exists(vs_path):
|
66 |
+
resp, prompt = local_doc_qa.get_knowledge_based_conent_test(query=query, vs_path=vs_path,
|
67 |
+
score_threshold=score_threshold,
|
68 |
+
vector_search_top_k=vector_search_top_k,
|
69 |
+
chunk_conent=chunk_conent,
|
70 |
+
chunk_size=chunk_size)
|
71 |
+
if not resp["source_documents"]:
|
72 |
+
yield history + [[query,
|
73 |
+
"根据您的设定,没有匹配到任何内容,请确认您设置的知识相关度 Score 阈值是否过小或其他参数是否正确。"]], ""
|
74 |
+
else:
|
75 |
+
source = "\n".join(
|
76 |
+
[
|
77 |
+
f"""<details open> <summary>【知识相关度 Score】:{doc.metadata["score"]} - 【出处{i + 1}】: {os.path.split(doc.metadata["source"])[-1]} </summary>\n"""
|
78 |
+
f"""{doc.page_content}\n"""
|
79 |
+
f"""</details>"""
|
80 |
+
for i, doc in
|
81 |
+
enumerate(resp["source_documents"])])
|
82 |
+
history.append([query, "以下内容为知识库中满足设置条件的匹配结果:\n\n" + source])
|
83 |
+
yield history, ""
|
84 |
+
else:
|
85 |
+
yield history + [[query,
|
86 |
+
"请选择知识库后进行测试,当前未选择知识库。"]], ""
|
87 |
+
else:
|
88 |
+
for answer_result in local_doc_qa.llm.generatorAnswer(prompt=query, history=history,
|
89 |
+
streaming=streaming):
|
90 |
+
resp = answer_result.llm_output["answer"]
|
91 |
+
history = answer_result.history
|
92 |
+
history[-1][-1] = resp
|
93 |
+
yield history, ""
|
94 |
+
logger.info(f"flagging: username={FLAG_USER_NAME},query={query},vs_path={vs_path},mode={mode},history={history}")
|
95 |
+
flag_csv_logger.flag([query, vs_path, history, mode], username=FLAG_USER_NAME)
|
96 |
+
|
97 |
+
|
98 |
+
def init_model():
|
99 |
+
args = parser.parse_args()
|
100 |
+
|
101 |
+
args_dict = vars(args)
|
102 |
+
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
|
103 |
+
llm_model_ins = shared.loaderLLM()
|
104 |
+
llm_model_ins.set_history_len(LLM_HISTORY_LEN)
|
105 |
+
try:
|
106 |
+
local_doc_qa.init_cfg(llm_model=llm_model_ins)
|
107 |
+
generator = local_doc_qa.llm.generatorAnswer("你好")
|
108 |
+
for answer_result in generator:
|
109 |
+
print(answer_result.llm_output)
|
110 |
+
reply = """模型已成��加载,可以开始对话,或从右侧选择模式后开始对话"""
|
111 |
+
logger.info(reply)
|
112 |
+
return reply
|
113 |
+
except Exception as e:
|
114 |
+
logger.error(e)
|
115 |
+
reply = """模型未成功加载,请到页面左上角"模型配置"选项卡中重新选择后点击"加载模型"按钮"""
|
116 |
+
if str(e) == "Unknown platform: darwin":
|
117 |
+
logger.info("该报错可能因为您使用的是 macOS 操作系统,需先下载模型至本地后执行 Web UI,具体方法请参考项目 README 中本地部署方法及常见问题:"
|
118 |
+
" https://github.com/imClumsyPanda/langchain-ChatGLM")
|
119 |
+
else:
|
120 |
+
logger.info(reply)
|
121 |
+
return reply
|
122 |
+
|
123 |
+
|
124 |
+
def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, use_lora, top_k,
|
125 |
+
history):
|
126 |
+
try:
|
127 |
+
llm_model_ins = shared.loaderLLM(llm_model, no_remote_model, use_ptuning_v2)
|
128 |
+
llm_model_ins.history_len = llm_history_len
|
129 |
+
local_doc_qa.init_cfg(llm_model=llm_model_ins,
|
130 |
+
embedding_model=embedding_model,
|
131 |
+
top_k=top_k)
|
132 |
+
model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话"""
|
133 |
+
logger.info(model_status)
|
134 |
+
except Exception as e:
|
135 |
+
logger.error(e)
|
136 |
+
model_status = """模型未成功重新加载,请到页面左上角"模型配置"选项卡中重新选择后点击"加载模型"按钮"""
|
137 |
+
logger.info(model_status)
|
138 |
+
return history + [[None, model_status]]
|
139 |
+
|
140 |
+
|
141 |
+
def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation):
|
142 |
+
vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
|
143 |
+
filelist = []
|
144 |
+
if local_doc_qa.llm and local_doc_qa.embeddings:
|
145 |
+
if isinstance(files, list):
|
146 |
+
for file in files:
|
147 |
+
filename = os.path.split(file.name)[-1]
|
148 |
+
shutil.move(file.name, os.path.join(KB_ROOT_PATH, vs_id, "content", filename))
|
149 |
+
filelist.append(os.path.join(KB_ROOT_PATH, vs_id, "content", filename))
|
150 |
+
vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, vs_path, sentence_size)
|
151 |
+
else:
|
152 |
+
vs_path, loaded_files = local_doc_qa.one_knowledge_add(vs_path, files, one_conent, one_content_segmentation,
|
153 |
+
sentence_size)
|
154 |
+
if len(loaded_files):
|
155 |
+
file_status = f"已添加 {'、'.join([os.path.split(i)[-1] for i in loaded_files if i])} 内容至知识库,并已加载知识库,请开始提问"
|
156 |
+
else:
|
157 |
+
file_status = "文件未成功加载,请重新上传文件"
|
158 |
+
else:
|
159 |
+
file_status = "模型未完成加载,请先在加载模型后再导入文件"
|
160 |
+
vs_path = None
|
161 |
+
logger.info(file_status)
|
162 |
+
return vs_path, None, history + [[None, file_status]], \
|
163 |
+
gr.update(choices=local_doc_qa.list_file_from_vector_store(vs_path) if vs_path else [])
|
164 |
+
|
165 |
+
|
166 |
+
def change_vs_name_input(vs_id, history):
|
167 |
+
if vs_id == "新建知识库":
|
168 |
+
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None, history,\
|
169 |
+
gr.update(choices=[]), gr.update(visible=False)
|
170 |
+
else:
|
171 |
+
vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
|
172 |
+
if "index.faiss" in os.listdir(vs_path):
|
173 |
+
file_status = f"已加载知识库{vs_id},请开始提问"
|
174 |
+
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), \
|
175 |
+
vs_path, history + [[None, file_status]], \
|
176 |
+
gr.update(choices=local_doc_qa.list_file_from_vector_store(vs_path), value=[]), \
|
177 |
+
gr.update(visible=True)
|
178 |
+
else:
|
179 |
+
file_status = f"已选择知识库{vs_id},当前知识库中未上传文件,请先上传文件后,再开始提问"
|
180 |
+
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), \
|
181 |
+
vs_path, history + [[None, file_status]], \
|
182 |
+
gr.update(choices=[], value=[]), gr.update(visible=True, value=[])
|
183 |
+
|
184 |
+
|
185 |
+
knowledge_base_test_mode_info = ("【注意】\n\n"
|
186 |
+
"1. 您已进入知识库测试模式,您输入的任何对话内容都将用于进行知识库查询,"
|
187 |
+
"并仅输出知识库匹配出的内容及相似度分值和及输入的文本源路径,查询的内容并不会进入模型查询。\n\n"
|
188 |
+
"2. 知识相关度 Score 经测试,建议设置为 500 或更低,具体设置情况请结合实际使用调整。"
|
189 |
+
"""3. 使用"添加单条数据"添加文本至知识库时,内容如未分段,则内容越多越会稀释各查询内容与之关联的score阈值。\n\n"""
|
190 |
+
"4. 单条内容长度建议设置在100-150左右。\n\n"
|
191 |
+
"5. 本界面用于知识入库及知识匹配相关参数设定,但当前版本中,"
|
192 |
+
"本界面中修改的参数并不会直接修改对话界面中参数,仍需前往`configs/model_config.py`修改后生效。"
|
193 |
+
"相关参数将在后续版本中支持本界面直接修改。")
|
194 |
+
|
195 |
+
|
196 |
+
def change_mode(mode, history):
|
197 |
+
if mode == "知识库问答":
|
198 |
+
return gr.update(visible=True), gr.update(visible=False), history
|
199 |
+
# + [[None, "【注意】:您已进入知识库问答模式,您输入的任何查询都将进行知识库查询,然后会自动整理知识库关联内容进入模型查询!!!"]]
|
200 |
+
elif mode == "知识库测试":
|
201 |
+
return gr.update(visible=True), gr.update(visible=True), [[None,
|
202 |
+
knowledge_base_test_mode_info]]
|
203 |
+
else:
|
204 |
+
return gr.update(visible=False), gr.update(visible=False), history
|
205 |
+
|
206 |
+
|
207 |
+
def change_chunk_conent(mode, label_conent, history):
|
208 |
+
conent = ""
|
209 |
+
if "chunk_conent" in label_conent:
|
210 |
+
conent = "搜索结果上下文关联"
|
211 |
+
elif "one_content_segmentation" in label_conent: # 这里没用上,可以先留着
|
212 |
+
conent = "内容分段入库"
|
213 |
+
|
214 |
+
if mode:
|
215 |
+
return gr.update(visible=True), history + [[None, f"【已开启{conent}】"]]
|
216 |
+
else:
|
217 |
+
return gr.update(visible=False), history + [[None, f"【已关闭{conent}】"]]
|
218 |
+
|
219 |
+
|
220 |
+
def add_vs_name(vs_name, chatbot):
|
221 |
+
if vs_name in get_vs_list():
|
222 |
+
vs_status = "与已有知识库名称冲突,请重新选择其他名称后提交"
|
223 |
+
chatbot = chatbot + [[None, vs_status]]
|
224 |
+
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(
|
225 |
+
visible=False), chatbot, gr.update(visible=False)
|
226 |
+
else:
|
227 |
+
# 新建上传文件存储路径
|
228 |
+
if not os.path.exists(os.path.join(KB_ROOT_PATH, vs_name, "content")):
|
229 |
+
os.makedirs(os.path.join(KB_ROOT_PATH, vs_name, "content"))
|
230 |
+
# 新建向量库存储路径
|
231 |
+
if not os.path.exists(os.path.join(KB_ROOT_PATH, vs_name, "vector_store")):
|
232 |
+
os.makedirs(os.path.join(KB_ROOT_PATH, vs_name, "vector_store"))
|
233 |
+
vs_status = f"""已新增知识库"{vs_name}",将在上传文件并载入成功后进行存储。请在开始对话前,先完成文件上传。 """
|
234 |
+
chatbot = chatbot + [[None, vs_status]]
|
235 |
+
return gr.update(visible=True, choices=get_vs_list(), value=vs_name), gr.update(
|
236 |
+
visible=False), gr.update(visible=False), gr.update(visible=True), chatbot, gr.update(visible=True)
|
237 |
+
|
238 |
+
|
239 |
+
# 自动化加载固定文件间中文件
|
240 |
+
def reinit_vector_store(vs_id, history):
|
241 |
+
try:
|
242 |
+
shutil.rmtree(os.path.join(KB_ROOT_PATH, vs_id, "vector_store"))
|
243 |
+
vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
|
244 |
+
sentence_size = gr.Number(value=SENTENCE_SIZE, precision=0,
|
245 |
+
label="文本入库分句长度限制",
|
246 |
+
interactive=True, visible=True)
|
247 |
+
vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(os.path.join(KB_ROOT_PATH, vs_id, "content"),
|
248 |
+
vs_path, sentence_size)
|
249 |
+
model_status = """知识库构建成功"""
|
250 |
+
except Exception as e:
|
251 |
+
logger.error(e)
|
252 |
+
model_status = """知识库构建未成功"""
|
253 |
+
logger.info(model_status)
|
254 |
+
return history + [[None, model_status]]
|
255 |
+
|
256 |
+
|
257 |
+
def refresh_vs_list():
|
258 |
+
return gr.update(choices=get_vs_list()), gr.update(choices=get_vs_list())
|
259 |
+
|
260 |
+
def delete_file(vs_id, files_to_delete, chatbot):
|
261 |
+
vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
|
262 |
+
content_path = os.path.join(KB_ROOT_PATH, vs_id, "content")
|
263 |
+
docs_path = [os.path.join(content_path, file) for file in files_to_delete]
|
264 |
+
status = local_doc_qa.delete_file_from_vector_store(vs_path=vs_path,
|
265 |
+
filepath=docs_path)
|
266 |
+
if "fail" not in status:
|
267 |
+
for doc_path in docs_path:
|
268 |
+
if os.path.exists(doc_path):
|
269 |
+
os.remove(doc_path)
|
270 |
+
rested_files = local_doc_qa.list_file_from_vector_store(vs_path)
|
271 |
+
if "fail" in status:
|
272 |
+
vs_status = "文件删除失败。"
|
273 |
+
elif len(rested_files)>0:
|
274 |
+
vs_status = "文件删除成功。"
|
275 |
+
else:
|
276 |
+
vs_status = f"文件删除成功,知识库{vs_id}中无已上传文件,请先上传文件后,再开始提问。"
|
277 |
+
logger.info(",".join(files_to_delete)+vs_status)
|
278 |
+
chatbot = chatbot + [[None, vs_status]]
|
279 |
+
return gr.update(choices=local_doc_qa.list_file_from_vector_store(vs_path), value=[]), chatbot
|
280 |
+
|
281 |
+
|
282 |
+
def delete_vs(vs_id, chatbot):
|
283 |
+
try:
|
284 |
+
shutil.rmtree(os.path.join(KB_ROOT_PATH, vs_id))
|
285 |
+
status = f"成功删除知识库{vs_id}"
|
286 |
+
logger.info(status)
|
287 |
+
chatbot = chatbot + [[None, status]]
|
288 |
+
return gr.update(choices=get_vs_list(), value=get_vs_list()[0]), gr.update(visible=True), gr.update(visible=True), \
|
289 |
+
gr.update(visible=False), chatbot, gr.update(visible=False)
|
290 |
+
except Exception as e:
|
291 |
+
logger.error(e)
|
292 |
+
status = f"删除知识库{vs_id}失败"
|
293 |
+
chatbot = chatbot + [[None, status]]
|
294 |
+
return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), \
|
295 |
+
gr.update(visible=True), chatbot, gr.update(visible=True)
|
296 |
+
|
297 |
+
|
298 |
+
block_css = """.importantButton {
|
299 |
+
background: linear-gradient(45deg, #7e0570,#5d1c99, #6e00ff) !important;
|
300 |
+
border: none !important;
|
301 |
+
}
|
302 |
+
.importantButton:hover {
|
303 |
+
background: linear-gradient(45deg, #ff00e0,#8500ff, #6e00ff) !important;
|
304 |
+
border: none !important;
|
305 |
+
}"""
|
306 |
+
|
307 |
+
webui_title = """
|
308 |
+
# 基于LLM的知识库问答系统(知识迁移算法)
|
309 |
+
"""
|
310 |
+
default_vs = get_vs_list()[0] if len(get_vs_list()) > 1 else "为空"
|
311 |
+
init_message = f"""基于LLM的知识库问答系统(知识迁移算法)
|
312 |
+
|
313 |
+
请在右侧切换模式,目前支持直接与 LLM 模型对话或基于本地知识库问答。
|
314 |
+
知识库问答模式,选择知识库名称后,即可开始问答,当前知识库{default_vs},如有需要可以在选择知识库名称后上传文件/文件夹至知识库。
|
315 |
+
"""
|
316 |
+
|
317 |
+
# 初始化消息
|
318 |
+
model_status = init_model()
|
319 |
+
|
320 |
+
default_theme_args = dict(
|
321 |
+
font=["Source Sans Pro", 'ui-sans-serif', 'system-ui', 'sans-serif'],
|
322 |
+
font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'],
|
323 |
+
)
|
324 |
+
|
325 |
+
with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as demo:
|
326 |
+
vs_path, file_status, model_status = gr.State(
|
327 |
+
os.path.join(KB_ROOT_PATH, get_vs_list()[0], "vector_store") if len(get_vs_list()) > 1 else ""), gr.State(""), gr.State(
|
328 |
+
model_status)
|
329 |
+
gr.Markdown(webui_title)
|
330 |
+
with gr.Tab("对话"):
|
331 |
+
with gr.Row():
|
332 |
+
with gr.Column(scale=10):
|
333 |
+
chatbot = gr.Chatbot([[None, init_message], [None, model_status.value]],
|
334 |
+
elem_id="chat-box",
|
335 |
+
show_label=False).style(height=750)
|
336 |
+
query = gr.Textbox(show_label=False,
|
337 |
+
placeholder="请输入提问内容,按回车进行提交").style(container=False)
|
338 |
+
with gr.Column(scale=5):
|
339 |
+
mode = gr.Radio(["LLM 对话", "知识库问答", "Bing搜索问答"],
|
340 |
+
label="请选择使用模式",
|
341 |
+
value="知识库问答", )
|
342 |
+
knowledge_set = gr.Accordion("知识库设定", visible=False)
|
343 |
+
vs_setting = gr.Accordion("配置知识库")
|
344 |
+
mode.change(fn=change_mode,
|
345 |
+
inputs=[mode, chatbot],
|
346 |
+
outputs=[vs_setting, knowledge_set, chatbot])
|
347 |
+
with vs_setting:
|
348 |
+
vs_refresh = gr.Button("更新已有知识库选项")
|
349 |
+
select_vs = gr.Dropdown(get_vs_list(),
|
350 |
+
label="请选择要加载的知识库",
|
351 |
+
interactive=True,
|
352 |
+
value=get_vs_list()[0] if len(get_vs_list()) > 0 else None
|
353 |
+
)
|
354 |
+
vs_name = gr.Textbox(label="请输入新建知识库名称,当前知识库命名暂不支持中文",
|
355 |
+
lines=1,
|
356 |
+
interactive=True,
|
357 |
+
visible=True)
|
358 |
+
vs_add = gr.Button(value="添加至知识库选项", visible=True)
|
359 |
+
vs_delete = gr.Button("删除本知识库", visible=False)
|
360 |
+
file2vs = gr.Column(visible=False)
|
361 |
+
with file2vs:
|
362 |
+
# load_vs = gr.Button("加载知识库")
|
363 |
+
gr.Markdown("向知识库中添加文件")
|
364 |
+
sentence_size = gr.Number(value=SENTENCE_SIZE, precision=0,
|
365 |
+
label="文本入库分句长度限制",
|
366 |
+
interactive=True, visible=True)
|
367 |
+
with gr.Tab("上传文件"):
|
368 |
+
files = gr.File(label="添加文件",
|
369 |
+
file_types=['.txt', '.md', '.docx', '.pdf', '.png', '.jpg', ".csv"],
|
370 |
+
file_count="multiple",
|
371 |
+
show_label=False)
|
372 |
+
load_file_button = gr.Button("上传文件并加载知识库")
|
373 |
+
with gr.Tab("上传文件夹"):
|
374 |
+
folder_files = gr.File(label="添加文件",
|
375 |
+
file_count="directory",
|
376 |
+
show_label=False)
|
377 |
+
load_folder_button = gr.Button("上传文件夹并加载��识库")
|
378 |
+
with gr.Tab("删除文件"):
|
379 |
+
files_to_delete = gr.CheckboxGroup(choices=[],
|
380 |
+
label="请从知识库已有文件中选择要删除的文件",
|
381 |
+
interactive=True)
|
382 |
+
delete_file_button = gr.Button("从知识库中删除选中文件")
|
383 |
+
vs_refresh.click(fn=refresh_vs_list,
|
384 |
+
inputs=[],
|
385 |
+
outputs=select_vs)
|
386 |
+
vs_add.click(fn=add_vs_name,
|
387 |
+
inputs=[vs_name, chatbot],
|
388 |
+
outputs=[select_vs, vs_name, vs_add, file2vs, chatbot, vs_delete])
|
389 |
+
vs_delete.click(fn=delete_vs,
|
390 |
+
inputs=[select_vs, chatbot],
|
391 |
+
outputs=[select_vs, vs_name, vs_add, file2vs, chatbot, vs_delete])
|
392 |
+
select_vs.change(fn=change_vs_name_input,
|
393 |
+
inputs=[select_vs, chatbot],
|
394 |
+
outputs=[vs_name, vs_add, file2vs, vs_path, chatbot, files_to_delete, vs_delete])
|
395 |
+
load_file_button.click(get_vector_store,
|
396 |
+
show_progress=True,
|
397 |
+
inputs=[select_vs, files, sentence_size, chatbot, vs_add, vs_add],
|
398 |
+
outputs=[vs_path, files, chatbot, files_to_delete], )
|
399 |
+
load_folder_button.click(get_vector_store,
|
400 |
+
show_progress=True,
|
401 |
+
inputs=[select_vs, folder_files, sentence_size, chatbot, vs_add,
|
402 |
+
vs_add],
|
403 |
+
outputs=[vs_path, folder_files, chatbot, files_to_delete], )
|
404 |
+
flag_csv_logger.setup([query, vs_path, chatbot, mode], "flagged")
|
405 |
+
query.submit(get_answer,
|
406 |
+
[query, vs_path, chatbot, mode],
|
407 |
+
[chatbot, query])
|
408 |
+
delete_file_button.click(delete_file,
|
409 |
+
show_progress=True,
|
410 |
+
inputs=[select_vs, files_to_delete, chatbot],
|
411 |
+
outputs=[files_to_delete, chatbot])
|
412 |
+
with gr.Tab("知识库测试 Beta"):
|
413 |
+
with gr.Row():
|
414 |
+
with gr.Column(scale=10):
|
415 |
+
chatbot = gr.Chatbot([[None, knowledge_base_test_mode_info]],
|
416 |
+
elem_id="chat-box",
|
417 |
+
show_label=False).style(height=750)
|
418 |
+
query = gr.Textbox(show_label=False,
|
419 |
+
placeholder="请输入提问内容,按回车进行提交").style(container=False)
|
420 |
+
with gr.Column(scale=5):
|
421 |
+
mode = gr.Radio(["知识库测试"], # "知识库问答",
|
422 |
+
label="请选择使用模式",
|
423 |
+
value="知识库测试",
|
424 |
+
visible=False)
|
425 |
+
knowledge_set = gr.Accordion("知识库设定", visible=True)
|
426 |
+
vs_setting = gr.Accordion("配置知识库", visible=True)
|
427 |
+
mode.change(fn=change_mode,
|
428 |
+
inputs=[mode, chatbot],
|
429 |
+
outputs=[vs_setting, knowledge_set, chatbot])
|
430 |
+
with knowledge_set:
|
431 |
+
score_threshold = gr.Number(value=VECTOR_SEARCH_SCORE_THRESHOLD,
|
432 |
+
label="知识相关度 Score 阈值,分值越低匹配度越高",
|
433 |
+
precision=0,
|
434 |
+
interactive=True)
|
435 |
+
vector_search_top_k = gr.Number(value=VECTOR_SEARCH_TOP_K, precision=0,
|
436 |
+
label="获取知识库内容条数", interactive=True)
|
437 |
+
chunk_conent = gr.Checkbox(value=False,
|
438 |
+
label="是否启用上下文关联",
|
439 |
+
interactive=True)
|
440 |
+
chunk_sizes = gr.Number(value=CHUNK_SIZE, precision=0,
|
441 |
+
label="匹配单段内容的连接上下文后最大长度",
|
442 |
+
interactive=True, visible=False)
|
443 |
+
chunk_conent.change(fn=change_chunk_conent,
|
444 |
+
inputs=[chunk_conent, gr.Textbox(value="chunk_conent", visible=False), chatbot],
|
445 |
+
outputs=[chunk_sizes, chatbot])
|
446 |
+
with vs_setting:
|
447 |
+
vs_refresh = gr.Button("更新已有知识库选项")
|
448 |
+
select_vs_test = gr.Dropdown(get_vs_list(),
|
449 |
+
label="请选择要加载的知识库",
|
450 |
+
interactive=True,
|
451 |
+
value=get_vs_list()[0] if len(get_vs_list()) > 0 else None)
|
452 |
+
vs_name = gr.Textbox(label="请输入新建知识库名称,当前知识库命名暂不支持中文",
|
453 |
+
lines=1,
|
454 |
+
interactive=True,
|
455 |
+
visible=True)
|
456 |
+
vs_add = gr.Button(value="添加至知识库选项", visible=True)
|
457 |
+
file2vs = gr.Column(visible=False)
|
458 |
+
with file2vs:
|
459 |
+
# load_vs = gr.Button("加载知识库")
|
460 |
+
gr.Markdown("向知识库中添加单条内容或文件")
|
461 |
+
sentence_size = gr.Number(value=SENTENCE_SIZE, precision=0,
|
462 |
+
label="文本入库分句长度限制",
|
463 |
+
interactive=True, visible=True)
|
464 |
+
with gr.Tab("上传文件"):
|
465 |
+
files = gr.File(label="添加文件",
|
466 |
+
file_types=['.txt', '.md', '.docx', '.pdf'],
|
467 |
+
file_count="multiple",
|
468 |
+
show_label=False
|
469 |
+
)
|
470 |
+
load_file_button = gr.Button("上传文件并加载知识库")
|
471 |
+
with gr.Tab("上传文件夹"):
|
472 |
+
folder_files = gr.File(label="添加文件",
|
473 |
+
# file_types=['.txt', '.md', '.docx', '.pdf'],
|
474 |
+
file_count="directory",
|
475 |
+
show_label=False)
|
476 |
+
load_folder_button = gr.Button("上传文件夹并加载知识库")
|
477 |
+
with gr.Tab("添加单条内容"):
|
478 |
+
one_title = gr.Textbox(label="标题", placeholder="请输入要添加单条段落的标题", lines=1)
|
479 |
+
one_conent = gr.Textbox(label="内容", placeholder="请输入要添加单条段落的内容", lines=5)
|
480 |
+
one_content_segmentation = gr.Checkbox(value=True, label="禁止内容分句入库",
|
481 |
+
interactive=True)
|
482 |
+
load_conent_button = gr.Button("添加内容并加载知识库")
|
483 |
+
# 将上传的文件保存到content文件夹下,并更新下拉框
|
484 |
+
vs_refresh.click(fn=refresh_vs_list,
|
485 |
+
inputs=[],
|
486 |
+
outputs=select_vs_test)
|
487 |
+
vs_add.click(fn=add_vs_name,
|
488 |
+
inputs=[vs_name, chatbot],
|
489 |
+
outputs=[select_vs_test, vs_name, vs_add, file2vs, chatbot])
|
490 |
+
select_vs_test.change(fn=change_vs_name_input,
|
491 |
+
inputs=[select_vs_test, chatbot],
|
492 |
+
outputs=[vs_name, vs_add, file2vs, vs_path, chatbot])
|
493 |
+
load_file_button.click(get_vector_store,
|
494 |
+
show_progress=True,
|
495 |
+
inputs=[select_vs_test, files, sentence_size, chatbot, vs_add, vs_add],
|
496 |
+
outputs=[vs_path, files, chatbot], )
|
497 |
+
load_folder_button.click(get_vector_store,
|
498 |
+
show_progress=True,
|
499 |
+
inputs=[select_vs_test, folder_files, sentence_size, chatbot, vs_add,
|
500 |
+
vs_add],
|
501 |
+
outputs=[vs_path, folder_files, chatbot], )
|
502 |
+
load_conent_button.click(get_vector_store,
|
503 |
+
show_progress=True,
|
504 |
+
inputs=[select_vs_test, one_title, sentence_size, chatbot,
|
505 |
+
one_conent, one_content_segmentation],
|
506 |
+
outputs=[vs_path, files, chatbot], )
|
507 |
+
flag_csv_logger.setup([query, vs_path, chatbot, mode], "flagged")
|
508 |
+
query.submit(get_answer,
|
509 |
+
[query, vs_path, chatbot, mode, score_threshold, vector_search_top_k, chunk_conent,
|
510 |
+
chunk_sizes],
|
511 |
+
[chatbot, query])
|
512 |
+
with gr.Tab("模型配置"):
|
513 |
+
llm_model = gr.Radio(llm_model_dict_list,
|
514 |
+
label="LLM 模型",
|
515 |
+
value=LLM_MODEL,
|
516 |
+
interactive=True)
|
517 |
+
no_remote_model = gr.Checkbox(shared.LoaderCheckPoint.no_remote_model,
|
518 |
+
label="加载本地模型",
|
519 |
+
interactive=True)
|
520 |
+
|
521 |
+
llm_history_len = gr.Slider(0, 10,
|
522 |
+
value=LLM_HISTORY_LEN,
|
523 |
+
step=1,
|
524 |
+
label="LLM 对话轮数",
|
525 |
+
interactive=True)
|
526 |
+
use_ptuning_v2 = gr.Checkbox(USE_PTUNING_V2,
|
527 |
+
label="使用p-tuning-v2微调过的模型",
|
528 |
+
interactive=True)
|
529 |
+
use_lora = gr.Checkbox(USE_LORA,
|
530 |
+
label="使用lora微调的权重",
|
531 |
+
interactive=True)
|
532 |
+
embedding_model = gr.Radio(embedding_model_dict_list,
|
533 |
+
label="Embedding 模型",
|
534 |
+
value=EMBEDDING_MODEL,
|
535 |
+
interactive=True)
|
536 |
+
top_k = gr.Slider(1, 20, value=VECTOR_SEARCH_TOP_K, step=1,
|
537 |
+
label="向量匹配 top k", interactive=True)
|
538 |
+
load_model_button = gr.Button("重新加载模型")
|
539 |
+
load_model_button.click(reinit_model, show_progress=True,
|
540 |
+
inputs=[llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2,
|
541 |
+
use_lora, top_k, chatbot], outputs=chatbot)
|
542 |
+
# load_knowlege_button = gr.Button("重新构建知识库")
|
543 |
+
# load_knowlege_button.click(reinit_vector_store, show_progress=True,
|
544 |
+
# inputs=[select_vs, chatbot], outputs=chatbot)
|
545 |
+
demo.css = "footer {visibility: hidden}"
|
546 |
+
demo.load(
|
547 |
+
fn=refresh_vs_list,
|
548 |
+
inputs=None,
|
549 |
+
outputs=[select_vs, select_vs_test],
|
550 |
+
queue=True,
|
551 |
+
show_progress=False,
|
552 |
+
)
|
553 |
+
|
554 |
+
(demo
|
555 |
+
.queue(concurrency_count=3)
|
556 |
+
.launch(server_name='localhost',
|
557 |
+
server_port=7860,
|
558 |
+
show_api=False,
|
559 |
+
share=True,
|
560 |
+
inbrowser=False))
|
webui_st.py
ADDED
@@ -0,0 +1,538 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
# from st_btn_select import st_btn_select
|
3 |
+
import tempfile
|
4 |
+
###### 从webui借用的代码 #####
|
5 |
+
###### 做了少量修改 #####
|
6 |
+
import os
|
7 |
+
import shutil
|
8 |
+
|
9 |
+
from chains.local_doc_qa import LocalDocQA
|
10 |
+
from configs.model_config import *
|
11 |
+
import nltk
|
12 |
+
from models.base import (BaseAnswer,
|
13 |
+
AnswerResult,)
|
14 |
+
import models.shared as shared
|
15 |
+
from models.loader.args import parser
|
16 |
+
from models.loader import LoaderCheckPoint
|
17 |
+
|
18 |
+
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
19 |
+
|
20 |
+
|
21 |
+
def get_vs_list():
|
22 |
+
lst_default = ["新建知识库"]
|
23 |
+
if not os.path.exists(KB_ROOT_PATH):
|
24 |
+
return lst_default
|
25 |
+
lst = os.listdir(KB_ROOT_PATH)
|
26 |
+
if not lst:
|
27 |
+
return lst_default
|
28 |
+
lst.sort()
|
29 |
+
return lst_default + lst
|
30 |
+
|
31 |
+
|
32 |
+
embedding_model_dict_list = list(embedding_model_dict.keys())
|
33 |
+
llm_model_dict_list = list(llm_model_dict.keys())
|
34 |
+
# flag_csv_logger = gr.CSVLogger()
|
35 |
+
|
36 |
+
|
37 |
+
def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCORE_THRESHOLD,
|
38 |
+
vector_search_top_k=VECTOR_SEARCH_TOP_K, chunk_conent: bool = True,
|
39 |
+
chunk_size=CHUNK_SIZE, streaming: bool = STREAMING,):
|
40 |
+
if mode == "Bing搜索问答":
|
41 |
+
for resp, history in local_doc_qa.get_search_result_based_answer(
|
42 |
+
query=query, chat_history=history, streaming=streaming):
|
43 |
+
source = "\n\n"
|
44 |
+
source += "".join(
|
45 |
+
[f"""<details> <summary>出处 [{i + 1}] <a href="{doc.metadata["source"]}" target="_blank">{doc.metadata["source"]}</a> </summary>\n"""
|
46 |
+
f"""{doc.page_content}\n"""
|
47 |
+
f"""</details>"""
|
48 |
+
for i, doc in
|
49 |
+
enumerate(resp["source_documents"])])
|
50 |
+
history[-1][-1] += source
|
51 |
+
yield history, ""
|
52 |
+
elif mode == "知识库问答" and vs_path is not None and os.path.exists(vs_path):
|
53 |
+
for resp, history in local_doc_qa.get_knowledge_based_answer(
|
54 |
+
query=query, vs_path=vs_path, chat_history=history, streaming=streaming):
|
55 |
+
source = "\n\n"
|
56 |
+
source += "".join(
|
57 |
+
[f"""<details> <summary>出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}</summary>\n"""
|
58 |
+
f"""{doc.page_content}\n"""
|
59 |
+
f"""</details>"""
|
60 |
+
for i, doc in
|
61 |
+
enumerate(resp["source_documents"])])
|
62 |
+
history[-1][-1] += source
|
63 |
+
yield history, ""
|
64 |
+
elif mode == "知识库测试":
|
65 |
+
if os.path.exists(vs_path):
|
66 |
+
resp, prompt = local_doc_qa.get_knowledge_based_conent_test(query=query, vs_path=vs_path,
|
67 |
+
score_threshold=score_threshold,
|
68 |
+
vector_search_top_k=vector_search_top_k,
|
69 |
+
chunk_conent=chunk_conent,
|
70 |
+
chunk_size=chunk_size)
|
71 |
+
if not resp["source_documents"]:
|
72 |
+
yield history + [[query,
|
73 |
+
"根据您的设定,没有匹配到任何内容,请确认您设置的知识相关度 Score 阈值是否过小或其他参数是否正确。"]], ""
|
74 |
+
else:
|
75 |
+
source = "\n".join(
|
76 |
+
[
|
77 |
+
f"""<details open> <summary>【知识相关度 Score】:{doc.metadata["score"]} - 【出处{i + 1}】: {os.path.split(doc.metadata["source"])[-1]} </summary>\n"""
|
78 |
+
f"""{doc.page_content}\n"""
|
79 |
+
f"""</details>"""
|
80 |
+
for i, doc in
|
81 |
+
enumerate(resp["source_documents"])])
|
82 |
+
history.append([query, "以下内容为知识库中满足设置条件的匹配结果:\n\n" + source])
|
83 |
+
yield history, ""
|
84 |
+
else:
|
85 |
+
yield history + [[query,
|
86 |
+
"请选择知识库后进行测试,当前未选择知识库。"]], ""
|
87 |
+
else:
|
88 |
+
for answer_result in local_doc_qa.llm.generatorAnswer(prompt=query, history=history,
|
89 |
+
streaming=streaming):
|
90 |
+
|
91 |
+
resp = answer_result.llm_output["answer"]
|
92 |
+
history = answer_result.history
|
93 |
+
history[-1][-1] = resp + (
|
94 |
+
"\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
|
95 |
+
yield history, ""
|
96 |
+
logger.info(f"flagging: username={FLAG_USER_NAME},query={query},vs_path={vs_path},mode={mode},history={history}")
|
97 |
+
# flag_csv_logger.flag([query, vs_path, history, mode], username=FLAG_USER_NAME)
|
98 |
+
|
99 |
+
|
100 |
+
def init_model(llm_model: str = 'chat-glm-6b', embedding_model: str = 'text2vec'):
|
101 |
+
local_doc_qa = LocalDocQA()
|
102 |
+
# 初始化消息
|
103 |
+
args = parser.parse_args()
|
104 |
+
args_dict = vars(args)
|
105 |
+
args_dict.update(model=llm_model)
|
106 |
+
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
|
107 |
+
llm_model_ins = shared.loaderLLM()
|
108 |
+
llm_model_ins.set_history_len(LLM_HISTORY_LEN)
|
109 |
+
|
110 |
+
try:
|
111 |
+
local_doc_qa.init_cfg(llm_model=llm_model_ins,
|
112 |
+
embedding_model=embedding_model)
|
113 |
+
generator = local_doc_qa.llm.generatorAnswer("你好")
|
114 |
+
for answer_result in generator:
|
115 |
+
print(answer_result.llm_output)
|
116 |
+
reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
|
117 |
+
logger.info(reply)
|
118 |
+
except Exception as e:
|
119 |
+
logger.error(e)
|
120 |
+
reply = """模型未成功加载,请到页面左上角"模型配置"选项卡中重新选择后点击"加载模型"按钮"""
|
121 |
+
if str(e) == "Unknown platform: darwin":
|
122 |
+
logger.info("该报错可能因为您使用的是 macOS 操作系统,需先下载模型至本地后执行 Web UI,具体方法请参考项目 README 中本地部署方法及常见问题:"
|
123 |
+
" https://github.com/imClumsyPanda/langchain-ChatGLM")
|
124 |
+
else:
|
125 |
+
logger.info(reply)
|
126 |
+
return local_doc_qa
|
127 |
+
|
128 |
+
|
129 |
+
# 暂未使用到,先保留
|
130 |
+
# def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, use_lora, top_k, history):
|
131 |
+
# try:
|
132 |
+
# llm_model_ins = shared.loaderLLM(llm_model, no_remote_model, use_ptuning_v2)
|
133 |
+
# llm_model_ins.history_len = llm_history_len
|
134 |
+
# local_doc_qa.init_cfg(llm_model=llm_model_ins,
|
135 |
+
# embedding_model=embedding_model,
|
136 |
+
# top_k=top_k)
|
137 |
+
# model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话"""
|
138 |
+
# logger.info(model_status)
|
139 |
+
# except Exception as e:
|
140 |
+
# logger.error(e)
|
141 |
+
# model_status = """模型未成功重新加载,请到页面左上角"模型配置"选项卡中重新选择后点击"加载模型"按钮"""
|
142 |
+
# logger.info(model_status)
|
143 |
+
# return history + [[None, model_status]]
|
144 |
+
|
145 |
+
|
146 |
+
def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation):
|
147 |
+
vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
|
148 |
+
filelist = []
|
149 |
+
if not os.path.exists(os.path.join(KB_ROOT_PATH, vs_id, "content")):
|
150 |
+
os.makedirs(os.path.join(KB_ROOT_PATH, vs_id, "content"))
|
151 |
+
if local_doc_qa.llm and local_doc_qa.embeddings:
|
152 |
+
if isinstance(files, list):
|
153 |
+
for file in files:
|
154 |
+
filename = os.path.split(file.name)[-1]
|
155 |
+
shutil.move(file.name, os.path.join(
|
156 |
+
KB_ROOT_PATH, vs_id, "content", filename))
|
157 |
+
filelist.append(os.path.join(
|
158 |
+
KB_ROOT_PATH, vs_id, "content", filename))
|
159 |
+
vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(
|
160 |
+
filelist, vs_path, sentence_size)
|
161 |
+
else:
|
162 |
+
vs_path, loaded_files = local_doc_qa.one_knowledge_add(vs_path, files, one_conent, one_content_segmentation,
|
163 |
+
sentence_size)
|
164 |
+
if len(loaded_files):
|
165 |
+
file_status = f"已添加 {'、'.join([os.path.split(i)[-1] for i in loaded_files if i])} 内容至知识库,并已加载知识库,请开始提问"
|
166 |
+
else:
|
167 |
+
file_status = "文件未成功加载,请重新上传文件"
|
168 |
+
else:
|
169 |
+
file_status = "模型未完成加载,请先在加载模型后再导入文件"
|
170 |
+
vs_path = None
|
171 |
+
logger.info(file_status)
|
172 |
+
return vs_path, None, history + [[None, file_status]]
|
173 |
+
|
174 |
+
|
175 |
+
knowledge_base_test_mode_info = ("【注意】\n\n"
|
176 |
+
"1. 您已进入知识库测试模式,您输入的任何对话内容都将用于进行知识库查询,"
|
177 |
+
"并仅输出知识库匹配出的内容及相似度分值和及输入的文本源路径,查询的内容并不会进入模型查询。\n\n"
|
178 |
+
"2. 知识相关度 Score 经测试,建议设置为 500 或更低,具体设置情况请结合实际使用调整。"
|
179 |
+
"""3. 使用"添加单条数据"添加文本至知识库时,内容如未分段,则内容越多越会稀释各查询内容与之关联的score阈值。\n\n"""
|
180 |
+
"4. 单条内容长度建议设置在100-150左右。\n\n"
|
181 |
+
"5. 本界面用于知识入库及知识匹配相关参数设定,但当前版本中,"
|
182 |
+
"本界面中修改的参数并不会直接修改对话界面中参数,仍需前往`configs/model_config.py`修改后生效。"
|
183 |
+
"相关参数将在后续版本中支持本界面直接修改。")
|
184 |
+
|
185 |
+
|
186 |
+
webui_title = """
|
187 |
+
# 🎉langchain-ChatGLM WebUI🎉
|
188 |
+
👍 [https://github.com/imClumsyPanda/langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM)
|
189 |
+
"""
|
190 |
+
###### #####
|
191 |
+
|
192 |
+
|
193 |
+
###### todo #####
|
194 |
+
# 1. streamlit运行方式与一般web服务器不同,使用模块是无法实现单例模式的,所以shared和local_doc_qa都需要进行全局化处理。
|
195 |
+
# 目前已经实现了local_doc_qa的全局化,后面要考虑shared。
|
196 |
+
# 2. 当前local_doc_qa是一个全局变量,一方面:任何一个session对其做出修改,都会影响所有session的对话;另一方面,如何处理所有session的请求竞争也是问题。
|
197 |
+
# 这个暂时无法避免,在配置普通的机器上暂时也无需考虑。
|
198 |
+
# 3. 目前只包含了get_answer对应的参数,以后可以添加其他参数,如temperature。
|
199 |
+
###### #####
|
200 |
+
|
201 |
+
|
202 |
+
###### 配置项 #####
|
203 |
+
class ST_CONFIG:
|
204 |
+
user_bg_color = '#77ff77'
|
205 |
+
user_icon = 'https://tse2-mm.cn.bing.net/th/id/OIP-C.LTTKrxNWDr_k74wz6jKqBgHaHa?w=203&h=203&c=7&r=0&o=5&pid=1.7'
|
206 |
+
robot_bg_color = '#ccccee'
|
207 |
+
robot_icon = 'https://ts1.cn.mm.bing.net/th/id/R-C.5302e2cc6f5c7c4933ebb3394e0c41bc?rik=z4u%2b7efba5Mgxw&riu=http%3a%2f%2fcomic-cons.xyz%2fwp-content%2fuploads%2fStar-Wars-avatar-icon-C3PO.png&ehk=kBBvCvpJMHPVpdfpw1GaH%2brbOaIoHjY5Ua9PKcIs%2bAc%3d&risl=&pid=ImgRaw&r=0'
|
208 |
+
default_mode = '知识库问答'
|
209 |
+
defalut_kb = ''
|
210 |
+
###### #####
|
211 |
+
|
212 |
+
|
213 |
+
class MsgType:
|
214 |
+
'''
|
215 |
+
目前仅支持文本类型的输入输出,为以后多模态模型预留图像、视频、音频支持。
|
216 |
+
'''
|
217 |
+
TEXT = 1
|
218 |
+
IMAGE = 2
|
219 |
+
VIDEO = 3
|
220 |
+
AUDIO = 4
|
221 |
+
|
222 |
+
|
223 |
+
class TempFile:
|
224 |
+
'''
|
225 |
+
为保持与get_vector_store的兼容性,需要将streamlit上传文件转化为其可以接受的方式
|
226 |
+
'''
|
227 |
+
|
228 |
+
def __init__(self, path):
|
229 |
+
self.name = path
|
230 |
+
|
231 |
+
|
232 |
+
def init_session():
|
233 |
+
st.session_state.setdefault('history', [])
|
234 |
+
|
235 |
+
|
236 |
+
# def get_query_params():
|
237 |
+
# '''
|
238 |
+
# 可以用url参数传递配置参数:llm_model, embedding_model, kb, mode。
|
239 |
+
# 该参数将覆盖model_config中的配置。处于安全考虑,目前只支持kb和mode
|
240 |
+
# 方便将固定的配置分享给特定的人。
|
241 |
+
# '''
|
242 |
+
# params = st.experimental_get_query_params()
|
243 |
+
# return {k: v[0] for k, v in params.items() if v}
|
244 |
+
|
245 |
+
|
246 |
+
def robot_say(msg, kb=''):
|
247 |
+
st.session_state['history'].append(
|
248 |
+
{'is_user': False, 'type': MsgType.TEXT, 'content': msg, 'kb': kb})
|
249 |
+
|
250 |
+
|
251 |
+
def user_say(msg):
|
252 |
+
st.session_state['history'].append(
|
253 |
+
{'is_user': True, 'type': MsgType.TEXT, 'content': msg})
|
254 |
+
|
255 |
+
|
256 |
+
def format_md(msg, is_user=False, bg_color='', margin='10%'):
|
257 |
+
'''
|
258 |
+
将文本消息格式化为markdown文本
|
259 |
+
'''
|
260 |
+
if is_user:
|
261 |
+
bg_color = bg_color or ST_CONFIG.user_bg_color
|
262 |
+
text = f'''
|
263 |
+
<div style="background:{bg_color};
|
264 |
+
margin-left:{margin};
|
265 |
+
word-break:break-all;
|
266 |
+
float:right;
|
267 |
+
padding:2%;
|
268 |
+
border-radius:2%;">
|
269 |
+
{msg}
|
270 |
+
</div>
|
271 |
+
'''
|
272 |
+
else:
|
273 |
+
bg_color = bg_color or ST_CONFIG.robot_bg_color
|
274 |
+
text = f'''
|
275 |
+
<div style="background:{bg_color};
|
276 |
+
margin-right:{margin};
|
277 |
+
word-break:break-all;
|
278 |
+
padding:2%;
|
279 |
+
border-radius:2%;">
|
280 |
+
{msg}
|
281 |
+
</div>
|
282 |
+
'''
|
283 |
+
return text
|
284 |
+
|
285 |
+
|
286 |
+
def message(msg,
|
287 |
+
is_user=False,
|
288 |
+
msg_type=MsgType.TEXT,
|
289 |
+
icon='',
|
290 |
+
bg_color='',
|
291 |
+
margin='10%',
|
292 |
+
kb='',
|
293 |
+
):
|
294 |
+
'''
|
295 |
+
渲染单条消息。目前仅支持文本
|
296 |
+
'''
|
297 |
+
cols = st.columns([1, 10, 1])
|
298 |
+
empty = cols[1].empty()
|
299 |
+
if is_user:
|
300 |
+
icon = icon or ST_CONFIG.user_icon
|
301 |
+
bg_color = bg_color or ST_CONFIG.user_bg_color
|
302 |
+
cols[2].image(icon, width=40)
|
303 |
+
if msg_type == MsgType.TEXT:
|
304 |
+
text = format_md(msg, is_user, bg_color, margin)
|
305 |
+
empty.markdown(text, unsafe_allow_html=True)
|
306 |
+
else:
|
307 |
+
raise RuntimeError('only support text message now.')
|
308 |
+
else:
|
309 |
+
icon = icon or ST_CONFIG.robot_icon
|
310 |
+
bg_color = bg_color or ST_CONFIG.robot_bg_color
|
311 |
+
cols[0].image(icon, width=40)
|
312 |
+
if kb:
|
313 |
+
cols[0].write(f'({kb})')
|
314 |
+
if msg_type == MsgType.TEXT:
|
315 |
+
text = format_md(msg, is_user, bg_color, margin)
|
316 |
+
empty.markdown(text, unsafe_allow_html=True)
|
317 |
+
else:
|
318 |
+
raise RuntimeError('only support text message now.')
|
319 |
+
return empty
|
320 |
+
|
321 |
+
|
322 |
+
def output_messages(
|
323 |
+
user_bg_color='',
|
324 |
+
robot_bg_color='',
|
325 |
+
user_icon='',
|
326 |
+
robot_icon='',
|
327 |
+
):
|
328 |
+
with chat_box.container():
|
329 |
+
last_response = None
|
330 |
+
for msg in st.session_state['history']:
|
331 |
+
bg_color = user_bg_color if msg['is_user'] else robot_bg_color
|
332 |
+
icon = user_icon if msg['is_user'] else robot_icon
|
333 |
+
empty = message(msg['content'],
|
334 |
+
is_user=msg['is_user'],
|
335 |
+
icon=icon,
|
336 |
+
msg_type=msg['type'],
|
337 |
+
bg_color=bg_color,
|
338 |
+
kb=msg.get('kb', '')
|
339 |
+
)
|
340 |
+
if not msg['is_user']:
|
341 |
+
last_response = empty
|
342 |
+
return last_response
|
343 |
+
|
344 |
+
|
345 |
+
@st.cache_resource(show_spinner=False, max_entries=1)
|
346 |
+
def load_model(llm_model: str, embedding_model: str):
|
347 |
+
'''
|
348 |
+
对应init_model,利用streamlit cache避免模型重复加载
|
349 |
+
'''
|
350 |
+
local_doc_qa = init_model(llm_model, embedding_model)
|
351 |
+
robot_say('模型已成功加载,可以开始对话,或从左侧选择模式后开始对话。\n请尽量不要刷新页面,以免模型出错或重复加载。')
|
352 |
+
return local_doc_qa
|
353 |
+
|
354 |
+
|
355 |
+
# @st.cache_data
|
356 |
+
def answer(query, vs_path='', history=[], mode='', score_threshold=0,
|
357 |
+
vector_search_top_k=5, chunk_conent=True, chunk_size=100, qa=None
|
358 |
+
):
|
359 |
+
'''
|
360 |
+
对应get_answer,--利用streamlit cache缓存相同问题的答案--
|
361 |
+
'''
|
362 |
+
return get_answer(query, vs_path, history, mode, score_threshold,
|
363 |
+
vector_search_top_k, chunk_conent, chunk_size)
|
364 |
+
|
365 |
+
|
366 |
+
def load_vector_store(
|
367 |
+
vs_id,
|
368 |
+
files,
|
369 |
+
sentence_size=100,
|
370 |
+
history=[],
|
371 |
+
one_conent=None,
|
372 |
+
one_content_segmentation=None,
|
373 |
+
):
|
374 |
+
return get_vector_store(
|
375 |
+
local_doc_qa,
|
376 |
+
vs_id,
|
377 |
+
files,
|
378 |
+
sentence_size,
|
379 |
+
history,
|
380 |
+
one_conent,
|
381 |
+
one_content_segmentation,
|
382 |
+
)
|
383 |
+
|
384 |
+
|
385 |
+
# main ui
|
386 |
+
st.set_page_config(webui_title, layout='wide')
|
387 |
+
init_session()
|
388 |
+
# params = get_query_params()
|
389 |
+
# llm_model = params.get('llm_model', LLM_MODEL)
|
390 |
+
# embedding_model = params.get('embedding_model', EMBEDDING_MODEL)
|
391 |
+
|
392 |
+
with st.spinner(f'正在加载模型({LLM_MODEL} + {EMBEDDING_MODEL}),请耐心等候...'):
|
393 |
+
local_doc_qa = load_model(LLM_MODEL, EMBEDDING_MODEL)
|
394 |
+
|
395 |
+
|
396 |
+
def use_kb_mode(m):
|
397 |
+
return m in ['知识库问答', '知识库测试']
|
398 |
+
|
399 |
+
|
400 |
+
# sidebar
|
401 |
+
modes = ['LLM 对话', '知识库问答', 'Bing搜索问答', '知识库测试']
|
402 |
+
with st.sidebar:
|
403 |
+
def on_mode_change():
|
404 |
+
m = st.session_state.mode
|
405 |
+
robot_say(f'已切换到"{m}"模式')
|
406 |
+
if m == '知识库测试':
|
407 |
+
robot_say(knowledge_base_test_mode_info)
|
408 |
+
|
409 |
+
index = 0
|
410 |
+
try:
|
411 |
+
index = modes.index(ST_CONFIG.default_mode)
|
412 |
+
except:
|
413 |
+
pass
|
414 |
+
mode = st.selectbox('对话模式', modes, index,
|
415 |
+
on_change=on_mode_change, key='mode')
|
416 |
+
|
417 |
+
with st.expander('模型配置', '知识' not in mode):
|
418 |
+
with st.form('model_config'):
|
419 |
+
index = 0
|
420 |
+
try:
|
421 |
+
index = llm_model_dict_list.index(LLM_MODEL)
|
422 |
+
except:
|
423 |
+
pass
|
424 |
+
llm_model = st.selectbox('LLM模型', llm_model_dict_list, index)
|
425 |
+
|
426 |
+
no_remote_model = st.checkbox('加载本地模型', False)
|
427 |
+
use_ptuning_v2 = st.checkbox('使用p-tuning-v2微调过的模型', False)
|
428 |
+
use_lora = st.checkbox('使用lora微调的权重', False)
|
429 |
+
try:
|
430 |
+
index = embedding_model_dict_list.index(EMBEDDING_MODEL)
|
431 |
+
except:
|
432 |
+
pass
|
433 |
+
embedding_model = st.selectbox(
|
434 |
+
'Embedding模型', embedding_model_dict_list, index)
|
435 |
+
|
436 |
+
btn_load_model = st.form_submit_button('重新加载模型')
|
437 |
+
if btn_load_model:
|
438 |
+
local_doc_qa = load_model(llm_model, embedding_model)
|
439 |
+
|
440 |
+
if mode in ['知识库问答', '知识库测试']:
|
441 |
+
vs_list = get_vs_list()
|
442 |
+
vs_list.remove('新建知识库')
|
443 |
+
|
444 |
+
def on_new_kb():
|
445 |
+
name = st.session_state.kb_name
|
446 |
+
if name in vs_list:
|
447 |
+
st.error(f'名为“{name}”的知识库已存在。')
|
448 |
+
else:
|
449 |
+
vs_list.append(name)
|
450 |
+
st.session_state.vs_path = name
|
451 |
+
|
452 |
+
def on_vs_change():
|
453 |
+
robot_say(f'已加载知识库: {st.session_state.vs_path}')
|
454 |
+
with st.expander('知识库配置', True):
|
455 |
+
cols = st.columns([12, 10])
|
456 |
+
kb_name = cols[0].text_input(
|
457 |
+
'新知识库名称', placeholder='新知识库名称', label_visibility='collapsed')
|
458 |
+
cols[1].button('新建知识库', on_click=on_new_kb)
|
459 |
+
vs_path = st.selectbox(
|
460 |
+
'选择知识库', vs_list, on_change=on_vs_change, key='vs_path')
|
461 |
+
|
462 |
+
st.text('')
|
463 |
+
|
464 |
+
score_threshold = st.slider(
|
465 |
+
'知识相关度阈值', 0, 1000, VECTOR_SEARCH_SCORE_THRESHOLD)
|
466 |
+
top_k = st.slider('向量匹配数量', 1, 20, VECTOR_SEARCH_TOP_K)
|
467 |
+
history_len = st.slider(
|
468 |
+
'LLM对话轮数', 1, 50, LLM_HISTORY_LEN) # 也许要跟知识库分开设置
|
469 |
+
local_doc_qa.llm.set_history_len(history_len)
|
470 |
+
chunk_conent = st.checkbox('启用上下文关联', False)
|
471 |
+
st.text('')
|
472 |
+
# chunk_conent = st.checkbox('分割文本', True) # 知识库文本分割入库
|
473 |
+
chunk_size = st.slider('上下文关联长度', 1, 1000, CHUNK_SIZE)
|
474 |
+
sentence_size = st.slider('文本入库分句长度限制', 1, 1000, SENTENCE_SIZE)
|
475 |
+
files = st.file_uploader('上传知识文件',
|
476 |
+
['docx', 'txt', 'md', 'csv', 'xlsx', 'pdf'],
|
477 |
+
accept_multiple_files=True)
|
478 |
+
if st.button('添加文件到知识库'):
|
479 |
+
temp_dir = tempfile.mkdtemp()
|
480 |
+
file_list = []
|
481 |
+
for f in files:
|
482 |
+
file = os.path.join(temp_dir, f.name)
|
483 |
+
with open(file, 'wb') as fp:
|
484 |
+
fp.write(f.getvalue())
|
485 |
+
file_list.append(TempFile(file))
|
486 |
+
_, _, history = load_vector_store(
|
487 |
+
vs_path, file_list, sentence_size, [], None, None)
|
488 |
+
st.session_state.files = []
|
489 |
+
|
490 |
+
|
491 |
+
# main body
|
492 |
+
chat_box = st.empty()
|
493 |
+
|
494 |
+
with st.form('my_form', clear_on_submit=True):
|
495 |
+
cols = st.columns([8, 1])
|
496 |
+
question = cols[0].text_input(
|
497 |
+
'temp', key='input_question', label_visibility='collapsed')
|
498 |
+
|
499 |
+
def on_send():
|
500 |
+
q = st.session_state.input_question
|
501 |
+
if q:
|
502 |
+
user_say(q)
|
503 |
+
|
504 |
+
if mode == 'LLM 对话':
|
505 |
+
robot_say('正在思考...')
|
506 |
+
last_response = output_messages()
|
507 |
+
for history, _ in answer(q,
|
508 |
+
history=[],
|
509 |
+
mode=mode):
|
510 |
+
last_response.markdown(
|
511 |
+
format_md(history[-1][-1], False),
|
512 |
+
unsafe_allow_html=True
|
513 |
+
)
|
514 |
+
elif use_kb_mode(mode):
|
515 |
+
robot_say('正在思考...', vs_path)
|
516 |
+
last_response = output_messages()
|
517 |
+
for history, _ in answer(q,
|
518 |
+
vs_path=os.path.join(
|
519 |
+
KB_ROOT_PATH, vs_path, "vector_store"),
|
520 |
+
history=[],
|
521 |
+
mode=mode,
|
522 |
+
score_threshold=score_threshold,
|
523 |
+
vector_search_top_k=top_k,
|
524 |
+
chunk_conent=chunk_conent,
|
525 |
+
chunk_size=chunk_size):
|
526 |
+
last_response.markdown(
|
527 |
+
format_md(history[-1][-1], False, 'ligreen'),
|
528 |
+
unsafe_allow_html=True
|
529 |
+
)
|
530 |
+
else:
|
531 |
+
robot_say('正在思考...')
|
532 |
+
last_response = output_messages()
|
533 |
+
st.session_state['history'][-1]['content'] = history[-1][-1]
|
534 |
+
submit = cols[1].form_submit_button('发送', on_click=on_send)
|
535 |
+
|
536 |
+
output_messages()
|
537 |
+
|
538 |
+
# st.write(st.session_state['history'])
|