Spaces:
Runtime error
Runtime error
pseudotensor
commited on
Commit
·
54f4f91
1
Parent(s):
3f42f2e
Update with h2oGPT hash e4482a4c59016517cd0d5513bc15b78b46f4598a
Browse files- LICENSE +0 -201
- client_test.py +22 -14
- enums.py +16 -1
- finetune.py +0 -676
- generate.py +0 -0
- gpt4all_llm.py +9 -0
- gpt_langchain.py +154 -51
- gradio_runner.py +137 -39
- gradio_utils/__pycache__/css.cpython-310.pyc +0 -0
- gradio_utils/__pycache__/grclient.cpython-310.pyc +0 -0
- gradio_utils/__pycache__/prompt_form.cpython-310.pyc +0 -0
- gradio_utils/css.py +0 -53
- gradio_utils/grclient.py +0 -82
- gradio_utils/prompt_form.py +0 -118
- h2o-logo.svg +0 -1
- h2oai_pipeline.py +1 -0
- iterators/__init__.py +0 -4
- iterators/__pycache__/__init__.cpython-310.pyc +0 -0
- iterators/__pycache__/iterator_pipe.cpython-310.pyc +0 -0
- iterators/__pycache__/timeout_iterator.cpython-310.pyc +0 -0
- iterators/iterator_pipe.py +0 -93
- iterators/timeout_iterator.py +0 -170
- prompter.py +4 -2
- requirements.txt +0 -153
LICENSE
DELETED
@@ -1,201 +0,0 @@
|
|
1 |
-
Apache License
|
2 |
-
Version 2.0, January 2004
|
3 |
-
http://www.apache.org/licenses/
|
4 |
-
|
5 |
-
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
-
|
7 |
-
1. Definitions.
|
8 |
-
|
9 |
-
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
-
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
-
|
12 |
-
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
-
the copyright owner that is granting the License.
|
14 |
-
|
15 |
-
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
-
other entities that control, are controlled by, or are under common
|
17 |
-
control with that entity. For the purposes of this definition,
|
18 |
-
"control" means (i) the power, direct or indirect, to cause the
|
19 |
-
direction or management of such entity, whether by contract or
|
20 |
-
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
-
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
-
|
23 |
-
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
-
exercising permissions granted by this License.
|
25 |
-
|
26 |
-
"Source" form shall mean the preferred form for making modifications,
|
27 |
-
including but not limited to software source code, documentation
|
28 |
-
source, and configuration files.
|
29 |
-
|
30 |
-
"Object" form shall mean any form resulting from mechanical
|
31 |
-
transformation or translation of a Source form, including but
|
32 |
-
not limited to compiled object code, generated documentation,
|
33 |
-
and conversions to other media types.
|
34 |
-
|
35 |
-
"Work" shall mean the work of authorship, whether in Source or
|
36 |
-
Object form, made available under the License, as indicated by a
|
37 |
-
copyright notice that is included in or attached to the work
|
38 |
-
(an example is provided in the Appendix below).
|
39 |
-
|
40 |
-
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
-
form, that is based on (or derived from) the Work and for which the
|
42 |
-
editorial revisions, annotations, elaborations, or other modifications
|
43 |
-
represent, as a whole, an original work of authorship. For the purposes
|
44 |
-
of this License, Derivative Works shall not include works that remain
|
45 |
-
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
-
the Work and Derivative Works thereof.
|
47 |
-
|
48 |
-
"Contribution" shall mean any work of authorship, including
|
49 |
-
the original version of the Work and any modifications or additions
|
50 |
-
to that Work or Derivative Works thereof, that is intentionally
|
51 |
-
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
-
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
-
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
-
means any form of electronic, verbal, or written communication sent
|
55 |
-
to the Licensor or its representatives, including but not limited to
|
56 |
-
communication on electronic mailing lists, source code control systems,
|
57 |
-
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
-
Licensor for the purpose of discussing and improving the Work, but
|
59 |
-
excluding communication that is conspicuously marked or otherwise
|
60 |
-
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
-
|
62 |
-
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
-
on behalf of whom a Contribution has been received by Licensor and
|
64 |
-
subsequently incorporated within the Work.
|
65 |
-
|
66 |
-
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
-
this License, each Contributor hereby grants to You a perpetual,
|
68 |
-
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
-
copyright license to reproduce, prepare Derivative Works of,
|
70 |
-
publicly display, publicly perform, sublicense, and distribute the
|
71 |
-
Work and such Derivative Works in Source or Object form.
|
72 |
-
|
73 |
-
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
-
this License, each Contributor hereby grants to You a perpetual,
|
75 |
-
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
-
(except as stated in this section) patent license to make, have made,
|
77 |
-
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
-
where such license applies only to those patent claims licensable
|
79 |
-
by such Contributor that are necessarily infringed by their
|
80 |
-
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
-
with the Work to which such Contribution(s) was submitted. If You
|
82 |
-
institute patent litigation against any entity (including a
|
83 |
-
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
-
or a Contribution incorporated within the Work constitutes direct
|
85 |
-
or contributory patent infringement, then any patent licenses
|
86 |
-
granted to You under this License for that Work shall terminate
|
87 |
-
as of the date such litigation is filed.
|
88 |
-
|
89 |
-
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
-
Work or Derivative Works thereof in any medium, with or without
|
91 |
-
modifications, and in Source or Object form, provided that You
|
92 |
-
meet the following conditions:
|
93 |
-
|
94 |
-
(a) You must give any other recipients of the Work or
|
95 |
-
Derivative Works a copy of this License; and
|
96 |
-
|
97 |
-
(b) You must cause any modified files to carry prominent notices
|
98 |
-
stating that You changed the files; and
|
99 |
-
|
100 |
-
(c) You must retain, in the Source form of any Derivative Works
|
101 |
-
that You distribute, all copyright, patent, trademark, and
|
102 |
-
attribution notices from the Source form of the Work,
|
103 |
-
excluding those notices that do not pertain to any part of
|
104 |
-
the Derivative Works; and
|
105 |
-
|
106 |
-
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
-
distribution, then any Derivative Works that You distribute must
|
108 |
-
include a readable copy of the attribution notices contained
|
109 |
-
within such NOTICE file, excluding those notices that do not
|
110 |
-
pertain to any part of the Derivative Works, in at least one
|
111 |
-
of the following places: within a NOTICE text file distributed
|
112 |
-
as part of the Derivative Works; within the Source form or
|
113 |
-
documentation, if provided along with the Derivative Works; or,
|
114 |
-
within a display generated by the Derivative Works, if and
|
115 |
-
wherever such third-party notices normally appear. The contents
|
116 |
-
of the NOTICE file are for informational purposes only and
|
117 |
-
do not modify the License. You may add Your own attribution
|
118 |
-
notices within Derivative Works that You distribute, alongside
|
119 |
-
or as an addendum to the NOTICE text from the Work, provided
|
120 |
-
that such additional attribution notices cannot be construed
|
121 |
-
as modifying the License.
|
122 |
-
|
123 |
-
You may add Your own copyright statement to Your modifications and
|
124 |
-
may provide additional or different license terms and conditions
|
125 |
-
for use, reproduction, or distribution of Your modifications, or
|
126 |
-
for any such Derivative Works as a whole, provided Your use,
|
127 |
-
reproduction, and distribution of the Work otherwise complies with
|
128 |
-
the conditions stated in this License.
|
129 |
-
|
130 |
-
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
-
any Contribution intentionally submitted for inclusion in the Work
|
132 |
-
by You to the Licensor shall be under the terms and conditions of
|
133 |
-
this License, without any additional terms or conditions.
|
134 |
-
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
-
the terms of any separate license agreement you may have executed
|
136 |
-
with Licensor regarding such Contributions.
|
137 |
-
|
138 |
-
6. Trademarks. This License does not grant permission to use the trade
|
139 |
-
names, trademarks, service marks, or product names of the Licensor,
|
140 |
-
except as required for reasonable and customary use in describing the
|
141 |
-
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
-
|
143 |
-
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
-
agreed to in writing, Licensor provides the Work (and each
|
145 |
-
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
-
implied, including, without limitation, any warranties or conditions
|
148 |
-
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
-
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
-
appropriateness of using or redistributing the Work and assume any
|
151 |
-
risks associated with Your exercise of permissions under this License.
|
152 |
-
|
153 |
-
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
-
whether in tort (including negligence), contract, or otherwise,
|
155 |
-
unless required by applicable law (such as deliberate and grossly
|
156 |
-
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
-
liable to You for damages, including any direct, indirect, special,
|
158 |
-
incidental, or consequential damages of any character arising as a
|
159 |
-
result of this License or out of the use or inability to use the
|
160 |
-
Work (including but not limited to damages for loss of goodwill,
|
161 |
-
work stoppage, computer failure or malfunction, or any and all
|
162 |
-
other commercial damages or losses), even if such Contributor
|
163 |
-
has been advised of the possibility of such damages.
|
164 |
-
|
165 |
-
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
-
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
-
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
-
or other liability obligations and/or rights consistent with this
|
169 |
-
License. However, in accepting such obligations, You may act only
|
170 |
-
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
-
of any other Contributor, and only if You agree to indemnify,
|
172 |
-
defend, and hold each Contributor harmless for any liability
|
173 |
-
incurred by, or claims asserted against, such Contributor by reason
|
174 |
-
of your accepting any such warranty or additional liability.
|
175 |
-
|
176 |
-
END OF TERMS AND CONDITIONS
|
177 |
-
|
178 |
-
APPENDIX: How to apply the Apache License to your work.
|
179 |
-
|
180 |
-
To apply the Apache License to your work, attach the following
|
181 |
-
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
-
replaced with your own identifying information. (Don't include
|
183 |
-
the brackets!) The text should be enclosed in the appropriate
|
184 |
-
comment syntax for the file format. We also recommend that a
|
185 |
-
file or class name and description of purpose be included on the
|
186 |
-
same "printed page" as the copyright notice for easier
|
187 |
-
identification within third-party archives.
|
188 |
-
|
189 |
-
Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
|
190 |
-
|
191 |
-
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
-
you may not use this file except in compliance with the License.
|
193 |
-
You may obtain a copy of the License at
|
194 |
-
|
195 |
-
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
-
|
197 |
-
Unless required by applicable law or agreed to in writing, software
|
198 |
-
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
-
See the License for the specific language governing permissions and
|
201 |
-
limitations under the License.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
client_test.py
CHANGED
@@ -12,13 +12,13 @@ Currently, this will force model to be on a single GPU.
|
|
12 |
|
13 |
Then run this client as:
|
14 |
|
15 |
-
python client_test.py
|
16 |
|
17 |
|
18 |
|
19 |
For HF spaces:
|
20 |
|
21 |
-
HOST="https://h2oai-h2ogpt-chatbot.hf.space" python client_test.py
|
22 |
|
23 |
Result:
|
24 |
|
@@ -28,7 +28,7 @@ Loaded as API: https://h2oai-h2ogpt-chatbot.hf.space ✔
|
|
28 |
|
29 |
For demo:
|
30 |
|
31 |
-
HOST="https://gpt.h2o.ai" python client_test.py
|
32 |
|
33 |
Result:
|
34 |
|
@@ -48,7 +48,7 @@ import markdown # pip install markdown
|
|
48 |
import pytest
|
49 |
from bs4 import BeautifulSoup # pip install beautifulsoup4
|
50 |
|
51 |
-
from enums import DocumentChoices
|
52 |
|
53 |
debug = False
|
54 |
|
@@ -67,7 +67,9 @@ def get_client(serialize=True):
|
|
67 |
def get_args(prompt, prompt_type, chat=False, stream_output=False,
|
68 |
max_new_tokens=50,
|
69 |
top_k_docs=3,
|
70 |
-
langchain_mode='Disabled'
|
|
|
|
|
71 |
from collections import OrderedDict
|
72 |
kwargs = OrderedDict(instruction=prompt if chat else '', # only for chat=True
|
73 |
iinput='', # only for chat=True
|
@@ -76,7 +78,7 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False,
|
|
76 |
# but leave stream_output=False for simple input/output mode
|
77 |
stream_output=stream_output,
|
78 |
prompt_type=prompt_type,
|
79 |
-
prompt_dict=
|
80 |
temperature=0.1,
|
81 |
top_p=0.75,
|
82 |
top_k=40,
|
@@ -92,12 +94,13 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False,
|
|
92 |
instruction_nochat=prompt if not chat else '',
|
93 |
iinput_nochat='', # only for chat=False
|
94 |
langchain_mode=langchain_mode,
|
|
|
95 |
top_k_docs=top_k_docs,
|
96 |
chunk=True,
|
97 |
chunk_size=512,
|
98 |
document_choice=[DocumentChoices.All_Relevant.name],
|
99 |
)
|
100 |
-
from
|
101 |
assert len(set(eval_func_param_names).difference(set(list(kwargs.keys())))) == 0
|
102 |
if chat:
|
103 |
# add chatbot output on end. Assumes serialize=False
|
@@ -198,6 +201,7 @@ def run_client_nochat_api_lean_morestuff(prompt, prompt_type='human_bot', max_ne
|
|
198 |
instruction_nochat=prompt,
|
199 |
iinput_nochat='',
|
200 |
langchain_mode='Disabled',
|
|
|
201 |
top_k_docs=4,
|
202 |
document_choice=['All'],
|
203 |
)
|
@@ -219,21 +223,24 @@ def run_client_nochat_api_lean_morestuff(prompt, prompt_type='human_bot', max_ne
|
|
219 |
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
220 |
def test_client_chat(prompt_type='human_bot'):
|
221 |
return run_client_chat(prompt='Who are you?', prompt_type=prompt_type, stream_output=False, max_new_tokens=50,
|
222 |
-
langchain_mode='Disabled')
|
223 |
|
224 |
|
225 |
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
226 |
def test_client_chat_stream(prompt_type='human_bot'):
|
227 |
return run_client_chat(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
|
228 |
stream_output=True, max_new_tokens=512,
|
229 |
-
langchain_mode='Disabled')
|
230 |
|
231 |
|
232 |
-
def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens, langchain_mode
|
|
|
233 |
client = get_client(serialize=False)
|
234 |
|
235 |
kwargs, args = get_args(prompt, prompt_type, chat=True, stream_output=stream_output,
|
236 |
-
max_new_tokens=max_new_tokens, langchain_mode=langchain_mode
|
|
|
|
|
237 |
return run_client(client, prompt, args, kwargs)
|
238 |
|
239 |
|
@@ -276,14 +283,15 @@ def run_client(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
|
|
276 |
def test_client_nochat_stream(prompt_type='human_bot'):
|
277 |
return run_client_nochat_gen(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
|
278 |
stream_output=True, max_new_tokens=512,
|
279 |
-
langchain_mode='Disabled')
|
280 |
|
281 |
|
282 |
-
def run_client_nochat_gen(prompt, prompt_type, stream_output, max_new_tokens, langchain_mode):
|
283 |
client = get_client(serialize=False)
|
284 |
|
285 |
kwargs, args = get_args(prompt, prompt_type, chat=False, stream_output=stream_output,
|
286 |
-
max_new_tokens=max_new_tokens, langchain_mode=langchain_mode
|
|
|
287 |
return run_client_gen(client, prompt, args, kwargs)
|
288 |
|
289 |
|
|
|
12 |
|
13 |
Then run this client as:
|
14 |
|
15 |
+
python src/client_test.py
|
16 |
|
17 |
|
18 |
|
19 |
For HF spaces:
|
20 |
|
21 |
+
HOST="https://h2oai-h2ogpt-chatbot.hf.space" python src/client_test.py
|
22 |
|
23 |
Result:
|
24 |
|
|
|
28 |
|
29 |
For demo:
|
30 |
|
31 |
+
HOST="https://gpt.h2o.ai" python src/client_test.py
|
32 |
|
33 |
Result:
|
34 |
|
|
|
48 |
import pytest
|
49 |
from bs4 import BeautifulSoup # pip install beautifulsoup4
|
50 |
|
51 |
+
from enums import DocumentChoices, LangChainAction
|
52 |
|
53 |
debug = False
|
54 |
|
|
|
67 |
def get_args(prompt, prompt_type, chat=False, stream_output=False,
|
68 |
max_new_tokens=50,
|
69 |
top_k_docs=3,
|
70 |
+
langchain_mode='Disabled',
|
71 |
+
langchain_action=LangChainAction.QUERY.value,
|
72 |
+
prompt_dict=None):
|
73 |
from collections import OrderedDict
|
74 |
kwargs = OrderedDict(instruction=prompt if chat else '', # only for chat=True
|
75 |
iinput='', # only for chat=True
|
|
|
78 |
# but leave stream_output=False for simple input/output mode
|
79 |
stream_output=stream_output,
|
80 |
prompt_type=prompt_type,
|
81 |
+
prompt_dict=prompt_dict,
|
82 |
temperature=0.1,
|
83 |
top_p=0.75,
|
84 |
top_k=40,
|
|
|
94 |
instruction_nochat=prompt if not chat else '',
|
95 |
iinput_nochat='', # only for chat=False
|
96 |
langchain_mode=langchain_mode,
|
97 |
+
langchain_action=langchain_action,
|
98 |
top_k_docs=top_k_docs,
|
99 |
chunk=True,
|
100 |
chunk_size=512,
|
101 |
document_choice=[DocumentChoices.All_Relevant.name],
|
102 |
)
|
103 |
+
from src.gen import eval_func_param_names
|
104 |
assert len(set(eval_func_param_names).difference(set(list(kwargs.keys())))) == 0
|
105 |
if chat:
|
106 |
# add chatbot output on end. Assumes serialize=False
|
|
|
201 |
instruction_nochat=prompt,
|
202 |
iinput_nochat='',
|
203 |
langchain_mode='Disabled',
|
204 |
+
langchain_action=LangChainAction.QUERY.value,
|
205 |
top_k_docs=4,
|
206 |
document_choice=['All'],
|
207 |
)
|
|
|
223 |
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
224 |
def test_client_chat(prompt_type='human_bot'):
|
225 |
return run_client_chat(prompt='Who are you?', prompt_type=prompt_type, stream_output=False, max_new_tokens=50,
|
226 |
+
langchain_mode='Disabled', langchain_action=LangChainAction.QUERY.value)
|
227 |
|
228 |
|
229 |
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
230 |
def test_client_chat_stream(prompt_type='human_bot'):
|
231 |
return run_client_chat(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
|
232 |
stream_output=True, max_new_tokens=512,
|
233 |
+
langchain_mode='Disabled', langchain_action=LangChainAction.QUERY.value)
|
234 |
|
235 |
|
236 |
+
def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens, langchain_mode, langchain_action,
|
237 |
+
prompt_dict=None):
|
238 |
client = get_client(serialize=False)
|
239 |
|
240 |
kwargs, args = get_args(prompt, prompt_type, chat=True, stream_output=stream_output,
|
241 |
+
max_new_tokens=max_new_tokens, langchain_mode=langchain_mode,
|
242 |
+
langchain_action=langchain_action,
|
243 |
+
prompt_dict=prompt_dict)
|
244 |
return run_client(client, prompt, args, kwargs)
|
245 |
|
246 |
|
|
|
283 |
def test_client_nochat_stream(prompt_type='human_bot'):
|
284 |
return run_client_nochat_gen(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
|
285 |
stream_output=True, max_new_tokens=512,
|
286 |
+
langchain_mode='Disabled', langchain_action=LangChainAction.QUERY.value)
|
287 |
|
288 |
|
289 |
+
def run_client_nochat_gen(prompt, prompt_type, stream_output, max_new_tokens, langchain_mode, langchain_action):
|
290 |
client = get_client(serialize=False)
|
291 |
|
292 |
kwargs, args = get_args(prompt, prompt_type, chat=False, stream_output=stream_output,
|
293 |
+
max_new_tokens=max_new_tokens, langchain_mode=langchain_mode,
|
294 |
+
langchain_action=langchain_action)
|
295 |
return run_client_gen(client, prompt, args, kwargs)
|
296 |
|
297 |
|
enums.py
CHANGED
@@ -37,6 +37,9 @@ class DocumentChoices(Enum):
|
|
37 |
Just_LLM = 3
|
38 |
|
39 |
|
|
|
|
|
|
|
40 |
class LangChainMode(Enum):
|
41 |
"""LangChain mode"""
|
42 |
|
@@ -52,10 +55,22 @@ class LangChainMode(Enum):
|
|
52 |
H2O_DAI_DOCS = "DriverlessAI docs"
|
53 |
|
54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
no_server_str = no_lora_str = no_model_str = '[None/Remove]'
|
56 |
|
57 |
|
58 |
-
# from site-packages/langchain/llms/openai.py
|
|
|
59 |
model_token_mapping = {
|
60 |
"gpt-4": 8192,
|
61 |
"gpt-4-0314": 8192,
|
|
|
37 |
Just_LLM = 3
|
38 |
|
39 |
|
40 |
+
non_query_commands = [DocumentChoices.All_Relevant_Only_Sources.name, DocumentChoices.Only_All_Sources.name]
|
41 |
+
|
42 |
+
|
43 |
class LangChainMode(Enum):
|
44 |
"""LangChain mode"""
|
45 |
|
|
|
55 |
H2O_DAI_DOCS = "DriverlessAI docs"
|
56 |
|
57 |
|
58 |
+
class LangChainAction(Enum):
|
59 |
+
"""LangChain action"""
|
60 |
+
|
61 |
+
QUERY = "Query"
|
62 |
+
# WIP:
|
63 |
+
#SUMMARIZE_MAP = "Summarize_map_reduce"
|
64 |
+
SUMMARIZE_MAP = "Summarize"
|
65 |
+
SUMMARIZE_ALL = "Summarize_all"
|
66 |
+
SUMMARIZE_REFINE = "Summarize_refine"
|
67 |
+
|
68 |
+
|
69 |
no_server_str = no_lora_str = no_model_str = '[None/Remove]'
|
70 |
|
71 |
|
72 |
+
# from site-packages/langchain/llms/openai.py
|
73 |
+
# but needed since ChatOpenAI doesn't have this information
|
74 |
model_token_mapping = {
|
75 |
"gpt-4": 8192,
|
76 |
"gpt-4-0314": 8192,
|
finetune.py
DELETED
@@ -1,676 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import sys
|
3 |
-
from functools import partial
|
4 |
-
from typing import List, Union
|
5 |
-
import fire
|
6 |
-
import numpy as np
|
7 |
-
|
8 |
-
if os.path.dirname(os.path.abspath(__file__)) not in sys.path:
|
9 |
-
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
10 |
-
|
11 |
-
from loaders import get_loaders, get_tokenizer
|
12 |
-
from prompter import generate_prompt, prompt_types, PromptType
|
13 |
-
from utils import get_githash, copy_code
|
14 |
-
import torch
|
15 |
-
|
16 |
-
|
17 |
-
def log(*args, **kwargs):
|
18 |
-
if int(os.environ.get("LOCAL_RANK", 0)) == 0:
|
19 |
-
if 'flush' not in kwargs:
|
20 |
-
kwargs['flush'] = True
|
21 |
-
print(*args, **kwargs)
|
22 |
-
|
23 |
-
|
24 |
-
# supported by huggingface evaluate
|
25 |
-
supported_metrics = ['bleu', 'rouge', 'sacrebleu', 'meteor']
|
26 |
-
|
27 |
-
|
28 |
-
def train(
|
29 |
-
save_code: bool = False,
|
30 |
-
run_id: int = None,
|
31 |
-
|
32 |
-
base_model: str = 'h2oai/h2ogpt-oig-oasst1-512-6_9b',
|
33 |
-
# base_model: str = 'h2oai/h2ogpt-oasst1-512-12b',
|
34 |
-
# base_model: str = 'h2oai/h2ogpt-oasst1-512-20b',
|
35 |
-
# base_model: str = 'EleutherAI/gpt-neox-20b',
|
36 |
-
# base_model: str = 'EleutherAI/pythia-12b-deduped',
|
37 |
-
# base_model: str = 'togethercomputer/GPT-NeoXT-Chat-Base-20B',
|
38 |
-
# base_model: str = 'decapoda-research/llama-7b-hf',
|
39 |
-
# base_model: str = 'decapoda-research/llama-13b-hf',
|
40 |
-
# base_model: str = 'decapoda-research/llama-30b-hf',
|
41 |
-
# base_model: str = 'EleutherAI/gpt-j-6B',
|
42 |
-
|
43 |
-
# only needed if base_model is self-exported HF state without tokenizer
|
44 |
-
tokenizer_base_model: str = None,
|
45 |
-
# tokenizer_base_model: str = 'EleutherAI/gpt-neox-20b',
|
46 |
-
|
47 |
-
data_path: str = "h2oai/openassistant_oasst1_h2ogpt",
|
48 |
-
data_col_dict: dict = None,
|
49 |
-
# data_path: str = "./dai_docs.train.json",
|
50 |
-
prompt_type: Union[str, int] = "plain", # "plain", "instruct", "quality", "human_bot", "dai_faq"
|
51 |
-
|
52 |
-
valid_path: str = None,
|
53 |
-
# valid_path: str = "./dai_docs.valid.json",
|
54 |
-
|
55 |
-
# data_mix_in_path: str = "laion/OIG", # way too big, medium quality
|
56 |
-
data_mix_in_path: str = "0-hero/OIG-small-chip2", # high quality, 50 MB, good enough for now
|
57 |
-
data_mix_in_factor: float = 0.0, # >1: more mix-in data, <1: more of data_path data
|
58 |
-
data_mix_in_col_dict: dict = {'user': 'instruction', 'chip2': 'output'},
|
59 |
-
data_mix_in_prompt_type: str = "instruct", # just instruction->output, same as instruct
|
60 |
-
|
61 |
-
output_dir: str = None,
|
62 |
-
|
63 |
-
# LoRA checkpoint continuation
|
64 |
-
lora_weights: str = "",
|
65 |
-
|
66 |
-
# batching training hyperparams
|
67 |
-
batch_size: int = 128,
|
68 |
-
micro_batch_size: int = 4,
|
69 |
-
gradient_checkpointing=False, # unnecessary with gradient accumulation enabled
|
70 |
-
fp16=True,
|
71 |
-
train_8bit=False,
|
72 |
-
train_4bit=False,
|
73 |
-
|
74 |
-
# general training hyperparams
|
75 |
-
num_epochs: float = 1,
|
76 |
-
learning_rate: float = 3e-4,
|
77 |
-
|
78 |
-
# validation settings
|
79 |
-
val_set_size: int = None,
|
80 |
-
val_metrics: List[str] = [],
|
81 |
-
eval_steps: int = None, # to control eval steps via steps
|
82 |
-
eval_epochs: float = None, # to control eval steps via epochs
|
83 |
-
|
84 |
-
# lora hyperparams
|
85 |
-
lora_r: int = 8,
|
86 |
-
lora_alpha: int = 16,
|
87 |
-
lora_dropout: float = 0.05,
|
88 |
-
lora_target_modules: List[str] = None,
|
89 |
-
llama_type: bool = None,
|
90 |
-
llama_flash_attn: bool = False,
|
91 |
-
|
92 |
-
# llm hyperparams
|
93 |
-
train_on_inputs: bool = True, # if False, masks out inputs in loss
|
94 |
-
group_by_length: bool = False, # if True, faster, but produces an odd training loss curve
|
95 |
-
resume_from_checkpoint: str = None, # either training checkpoint or final adapter
|
96 |
-
cutoff_len: int = 512, # larger values use more memory
|
97 |
-
drop_truncations: bool = False, # if True, drop any truncated long sequences
|
98 |
-
|
99 |
-
# torch training params
|
100 |
-
ddp: bool = True, # set to False if OOM with True, for multi-GPU model parallelism
|
101 |
-
local_files_only: bool = False, # else will download new versions, normally unwanted
|
102 |
-
resume_download: bool = True,
|
103 |
-
use_auth_token: Union[str, bool] = False, # True requires CLI did huggingface-cli login before running
|
104 |
-
warmup_steps: int = 100,
|
105 |
-
logging_steps: int = 1,
|
106 |
-
save_steps: int = None, # must be round multiple of eval_steps
|
107 |
-
save_total_limit: int = 3,
|
108 |
-
add_eos_token: bool = False,
|
109 |
-
):
|
110 |
-
if llama_flash_attn:
|
111 |
-
# Need to call this before importing transformers.
|
112 |
-
from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
|
113 |
-
replace_llama_attn_with_flash_attn()
|
114 |
-
|
115 |
-
# allow set token directly
|
116 |
-
use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
|
117 |
-
|
118 |
-
prompt_type = str(prompt_type) # migration from integers
|
119 |
-
assert prompt_type in prompt_types
|
120 |
-
|
121 |
-
world_size = int(os.getenv("WORLD_SIZE", 1))
|
122 |
-
local_rank = int(os.getenv("LOCAL_RANK", 0))
|
123 |
-
rank = int(os.getenv("RANK", 0))
|
124 |
-
print(f"local_rank: {local_rank}")
|
125 |
-
print(f"global rank: {rank}")
|
126 |
-
|
127 |
-
gpus = max(world_size, torch.cuda.device_count())
|
128 |
-
run_id = run_id or 0
|
129 |
-
if not data_path:
|
130 |
-
raise ValueError("No data_path provided")
|
131 |
-
if not output_dir:
|
132 |
-
output_dir = f"{base_model.split('/')[-1]}.{data_path.replace('/', '')}.{num_epochs}_epochs.{get_githash() or 'nogit'}.{run_id}"
|
133 |
-
if os.path.exists(output_dir) and not resume_from_checkpoint:
|
134 |
-
raise FileExistsError(
|
135 |
-
f"output_dir {output_dir} based on run_id {run_id} already exists. Please pick a different run_id.")
|
136 |
-
else:
|
137 |
-
if os.path.exists(output_dir) and not resume_from_checkpoint:
|
138 |
-
raise FileExistsError(
|
139 |
-
f"output_dir {output_dir} already exists. Please pick a different output_dir, or specify a run_id instead.")
|
140 |
-
device_map = "auto"
|
141 |
-
|
142 |
-
if save_code:
|
143 |
-
copy_code(run_id)
|
144 |
-
if tokenizer_base_model is None:
|
145 |
-
tokenizer_base_model = base_model
|
146 |
-
if llama_type is None:
|
147 |
-
llama_type = "llama" in base_model.lower()
|
148 |
-
if llama_type and llama_flash_attn:
|
149 |
-
import pkg_resources
|
150 |
-
try:
|
151 |
-
pkg_resources.get_distribution('flash_attn')
|
152 |
-
can_do_flash_attn = True
|
153 |
-
except (pkg_resources.DistributionNotFound, pkg_resources.ContextualVersionConflict):
|
154 |
-
can_do_flash_attn = False
|
155 |
-
|
156 |
-
if not can_do_flash_attn:
|
157 |
-
raise RuntimeError("""Flash attention not installed.
|
158 |
-
NOTE: for current pytorch 2.0, flash attention requires installing cuda 11.7 via https://developer.nvidia.com/cuda-11-7-0-download-archive?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version=20.04&target_type=runfile_local and then when running, to avoid installing driver, docs, samples, just install toolkit. Then when pip installing flash attention do:
|
159 |
-
|
160 |
-
CUDA_HOME=/usr/local/cuda-11.7 pip install flash-attn""")
|
161 |
-
assert (
|
162 |
-
base_model
|
163 |
-
), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
|
164 |
-
gradient_accumulation_steps = batch_size // micro_batch_size
|
165 |
-
assert gradient_accumulation_steps >= world_size, "must increase batch_size for multi-GPU"
|
166 |
-
|
167 |
-
device_map = "auto"
|
168 |
-
|
169 |
-
locals_dict = locals()
|
170 |
-
locals_print = '\n'.join(['%s: %s' % (k, v) for k, v in locals_dict.items()])
|
171 |
-
log(f"Training model with params:\n{locals_print}")
|
172 |
-
log("Command: %s\nHash: %s" % (str(' '.join(sys.argv)), get_githash()))
|
173 |
-
|
174 |
-
max_memory = None
|
175 |
-
if gpus > 1:
|
176 |
-
if ddp:
|
177 |
-
log("Distributed: data parallel")
|
178 |
-
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
|
179 |
-
gradient_accumulation_steps = gradient_accumulation_steps // world_size
|
180 |
-
else:
|
181 |
-
free_in_GB = int(min(torch.cuda.mem_get_info()) / 1024 ** 3)
|
182 |
-
max_memory = f"{free_in_GB - 2}GB"
|
183 |
-
max_memory = {i: max_memory for i in range(gpus)}
|
184 |
-
log("world_size: %d" % world_size)
|
185 |
-
log("num_gpus: %d" % gpus)
|
186 |
-
log("max mem: %s" % max_memory)
|
187 |
-
|
188 |
-
model_loader, tokenizer_loader = get_loaders(model_name=base_model, reward_type=False, llama_type=llama_type)
|
189 |
-
|
190 |
-
model = model_loader.from_pretrained(
|
191 |
-
base_model,
|
192 |
-
load_in_8bit=train_8bit,
|
193 |
-
load_in_4bit=train_4bit,
|
194 |
-
device_map=device_map,
|
195 |
-
torch_dtype=torch.float16,
|
196 |
-
max_memory=max_memory,
|
197 |
-
local_files_only=local_files_only,
|
198 |
-
trust_remote_code=True,
|
199 |
-
resume_download=resume_download,
|
200 |
-
use_auth_token=use_auth_token,
|
201 |
-
)
|
202 |
-
if gpus > 1:
|
203 |
-
if not ddp:
|
204 |
-
log("model parallel")
|
205 |
-
model.is_parallelizable = True
|
206 |
-
model.model_parallel = True
|
207 |
-
|
208 |
-
tokenizer = get_tokenizer(tokenizer_loader, tokenizer_base_model, local_files_only, resume_download, use_auth_token)
|
209 |
-
|
210 |
-
if train_8bit or train_4bit:
|
211 |
-
from peft import (
|
212 |
-
prepare_model_for_kbit_training,
|
213 |
-
)
|
214 |
-
|
215 |
-
model = prepare_model_for_kbit_training(model)
|
216 |
-
|
217 |
-
from peft import LoraConfig, get_peft_model, set_peft_model_state_dict
|
218 |
-
try:
|
219 |
-
from peft import utils
|
220 |
-
lora_mappings = utils.TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy()
|
221 |
-
except AttributeError:
|
222 |
-
from peft import mapping
|
223 |
-
lora_mappings = mapping.TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy()
|
224 |
-
lora_mappings['distilgpt2'] = ["c_attn"]
|
225 |
-
|
226 |
-
if lora_weights:
|
227 |
-
|
228 |
-
from peft import PeftModel
|
229 |
-
model = PeftModel.from_pretrained(
|
230 |
-
model,
|
231 |
-
lora_weights,
|
232 |
-
torch_dtype=torch.float16,
|
233 |
-
device_map=device_map,
|
234 |
-
local_files_only=local_files_only,
|
235 |
-
resume_download=resume_download,
|
236 |
-
use_auth_token=use_auth_token,
|
237 |
-
)
|
238 |
-
elif lora_r > 0:
|
239 |
-
if lora_target_modules is None:
|
240 |
-
base_model_lower = base_model.lower()
|
241 |
-
if base_model_lower in lora_mappings:
|
242 |
-
lora_target_modules_cand = [lora_mappings[base_model_lower]]
|
243 |
-
else:
|
244 |
-
lora_target_modules_cand = [["query_key_value"], ["q_proj", "v_proj"]]
|
245 |
-
else:
|
246 |
-
lora_target_modules_cand = [lora_target_modules]
|
247 |
-
|
248 |
-
for lora_target_modules in lora_target_modules_cand:
|
249 |
-
try:
|
250 |
-
config = LoraConfig(
|
251 |
-
r=lora_r,
|
252 |
-
lora_alpha=lora_alpha,
|
253 |
-
target_modules=lora_target_modules,
|
254 |
-
lora_dropout=lora_dropout,
|
255 |
-
bias="none",
|
256 |
-
task_type="CAUSAL_LM",
|
257 |
-
)
|
258 |
-
model = get_peft_model(model, config)
|
259 |
-
break
|
260 |
-
except ValueError as e:
|
261 |
-
if "Target modules" in str(e) and "not found" in str(e):
|
262 |
-
continue
|
263 |
-
else:
|
264 |
-
raise
|
265 |
-
from peft import PeftModel
|
266 |
-
assert isinstance(model, PeftModel), "LoRA failed. Please provide --lora_target_modules explicitly."
|
267 |
-
if resume_from_checkpoint:
|
268 |
-
# Check the available weights and load them
|
269 |
-
checkpoint_name = os.path.join(
|
270 |
-
resume_from_checkpoint, "pytorch_model.bin"
|
271 |
-
) # Full checkpoint
|
272 |
-
if not os.path.exists(checkpoint_name):
|
273 |
-
checkpoint_name = os.path.join(
|
274 |
-
resume_from_checkpoint, "adapter_model.bin"
|
275 |
-
) # only LoRA model - LoRA config above has to fit
|
276 |
-
resume_from_checkpoint = False # So the trainer won't try loading its state
|
277 |
-
# The two files above have a different name depending on how they were saved, but are actually the same.
|
278 |
-
if os.path.exists(checkpoint_name):
|
279 |
-
log(f"Restarting from {checkpoint_name}")
|
280 |
-
adapters_weights = torch.load(checkpoint_name)
|
281 |
-
set_peft_model_state_dict(model, adapters_weights)
|
282 |
-
else:
|
283 |
-
log(f"Checkpoint {checkpoint_name} not found")
|
284 |
-
|
285 |
-
print(model)
|
286 |
-
try:
|
287 |
-
# only for PeftModel
|
288 |
-
model.print_trainable_parameters() # Be more transparent about the % of trainable params.
|
289 |
-
except:
|
290 |
-
pass
|
291 |
-
|
292 |
-
metrics = {}
|
293 |
-
for name in supported_metrics:
|
294 |
-
if name in val_metrics:
|
295 |
-
import evaluate # Causes hang for 'python generate.py' on dual 4090 if imported early, 100% reproducible
|
296 |
-
metrics[name] = evaluate.load(name)
|
297 |
-
log("Using Validation Metrics: %s" % str(list(metrics.keys())))
|
298 |
-
log("Supported Metrics: %s" % supported_metrics)
|
299 |
-
|
300 |
-
if val_set_size is None:
|
301 |
-
if len(metrics) == 0:
|
302 |
-
val_set_size = 1000
|
303 |
-
else:
|
304 |
-
val_set_size = 100
|
305 |
-
log("Auto set val_set_size %s" % val_set_size)
|
306 |
-
elif val_set_size < 1.0 and val_set_size != 0:
|
307 |
-
raise RuntimeError("Fractional validation size not supported.")
|
308 |
-
|
309 |
-
from datasets import load_dataset, concatenate_datasets
|
310 |
-
if valid_path:
|
311 |
-
data = load_dataset("json", data_files={"train": data_path, "valid": valid_path})
|
312 |
-
else:
|
313 |
-
if "json" in data_path:
|
314 |
-
data = load_dataset("json", data_files={"train": data_path})
|
315 |
-
else:
|
316 |
-
data = load_dataset(data_path)
|
317 |
-
data = data.rename_columns(data_col_dict or {})
|
318 |
-
|
319 |
-
valid_data = None
|
320 |
-
train_data_mix_in = None
|
321 |
-
valid_data_mix_in = None
|
322 |
-
|
323 |
-
if data_mix_in_path and data_mix_in_factor > 0:
|
324 |
-
# get mix-in training/validation data - to keep model "sane"
|
325 |
-
num_rows = data["train"].num_rows
|
326 |
-
log("Loading mix-in dataset: %s" % data_mix_in_path)
|
327 |
-
if "json" in data_mix_in_path:
|
328 |
-
data_mix_in = load_dataset("json", data_files={"train": data_mix_in_path})["train"]
|
329 |
-
else:
|
330 |
-
data_mix_in = load_dataset(data_mix_in_path)["train"] # can be large
|
331 |
-
data_mix_in = data_mix_in.rename_columns(data_mix_in_col_dict or {})
|
332 |
-
mix_in_rows = int(num_rows * data_mix_in_factor)
|
333 |
-
|
334 |
-
if mix_in_rows > data_mix_in.num_rows:
|
335 |
-
# duplicate rows if mix-in is smaller than required
|
336 |
-
log("Duplicating mixin to compensate for its size for training size and mixin fraction")
|
337 |
-
data_mix_in = concatenate_datasets([data_mix_in] * int(np.ceil(mix_in_rows / data_mix_in.num_rows)))
|
338 |
-
|
339 |
-
# only get as much as we need to balance
|
340 |
-
valid_size = min(data_mix_in.num_rows // 2, val_set_size or 0)
|
341 |
-
train_size = max(1, min(data_mix_in.num_rows - valid_size, mix_in_rows))
|
342 |
-
mixin_small = data_mix_in.train_test_split(
|
343 |
-
test_size=train_size + valid_size,
|
344 |
-
shuffle=True, seed=np.random.randint(10000),
|
345 |
-
)["test"]
|
346 |
-
if valid_size:
|
347 |
-
mixin_train_test = mixin_small.train_test_split(
|
348 |
-
test_size=valid_size, shuffle=False,
|
349 |
-
)
|
350 |
-
train_data_mix_in = mixin_train_test["train"]
|
351 |
-
valid_data_mix_in = mixin_train_test["test"]
|
352 |
-
else:
|
353 |
-
train_data_mix_in = mixin_small
|
354 |
-
|
355 |
-
if "prompt_type" not in train_data_mix_in.column_names:
|
356 |
-
train_data_mix_in = train_data_mix_in.add_column(
|
357 |
-
"prompt_type",
|
358 |
-
[data_mix_in_prompt_type] * train_data_mix_in.num_rows,
|
359 |
-
)
|
360 |
-
log("Added prompt type %s to mix-in training data" % data_mix_in_prompt_type)
|
361 |
-
if valid_data_mix_in and "prompt_type" not in valid_data_mix_in.column_names:
|
362 |
-
valid_data_mix_in = valid_data_mix_in.add_column(
|
363 |
-
"prompt_type",
|
364 |
-
[data_mix_in_prompt_type] * valid_data_mix_in.num_rows,
|
365 |
-
)
|
366 |
-
log("Added prompt type %s to mix-in validation data" % data_mix_in_prompt_type)
|
367 |
-
log("Created mix-in data:\nTrain %s\nValid %s" % (train_data_mix_in, valid_data_mix_in))
|
368 |
-
|
369 |
-
# get our own training/validation data - for fine-tuning
|
370 |
-
if val_set_size > 0 and not valid_path and not data_mix_in_path:
|
371 |
-
# create valid split from train
|
372 |
-
train_val = data["train"].train_test_split(
|
373 |
-
test_size=val_set_size, shuffle=True, seed=42
|
374 |
-
)
|
375 |
-
train_data = train_val["train"]
|
376 |
-
valid_data = train_val["test"]
|
377 |
-
else:
|
378 |
-
train_data = data["train"]
|
379 |
-
if valid_path:
|
380 |
-
# use given valid split, has priority over data_mix_in_path
|
381 |
-
valid_data = data["valid"]
|
382 |
-
if "prompt_type" not in train_data.column_names:
|
383 |
-
train_data = train_data.add_column(
|
384 |
-
"prompt_type",
|
385 |
-
[prompt_type] * train_data.num_rows,
|
386 |
-
)
|
387 |
-
log("Added prompt type %s to training data" % prompt_type)
|
388 |
-
if valid_data and "prompt_type" not in valid_data.column_names:
|
389 |
-
valid_data = valid_data.add_column(
|
390 |
-
"prompt_type",
|
391 |
-
[prompt_type] * valid_data.num_rows,
|
392 |
-
)
|
393 |
-
log("Added prompt type %s to validation data" % prompt_type)
|
394 |
-
|
395 |
-
assert train_data is not None
|
396 |
-
|
397 |
-
generate_and_tokenize_prompt_fun = partial(generate_and_tokenize_prompt, prompt_type=prompt_type,
|
398 |
-
train_on_inputs=train_on_inputs, add_eos_token=add_eos_token,
|
399 |
-
cutoff_len=cutoff_len, tokenizer=tokenizer)
|
400 |
-
|
401 |
-
# shuffle and tokenize data
|
402 |
-
if train_data_mix_in:
|
403 |
-
train_data = concatenate_datasets([train_data, train_data_mix_in])
|
404 |
-
log("Tokenizing %s training rows" % train_data.num_rows)
|
405 |
-
train_data = train_data.shuffle().map(generate_and_tokenize_prompt_fun,
|
406 |
-
num_proc=os.cpu_count() // torch.cuda.device_count())
|
407 |
-
if drop_truncations:
|
408 |
-
log("avoid keeping truncated cases to avoid contaminating model with truncation cases. Original size: %s" % train_data.num_rows)
|
409 |
-
prune_long_sequences_func = partial(prune_long_sequences, cutoff_len=cutoff_len)
|
410 |
-
train_data = train_data.filter(prune_long_sequences_func, num_proc=os.cpu_count() // torch.cuda.device_count())
|
411 |
-
log("avoid keeping truncated cases to avoid contaminating model with truncation cases. New size: %s" % train_data.num_rows)
|
412 |
-
train_set_size = len(train_data)
|
413 |
-
|
414 |
-
if valid_data and valid_data_mix_in:
|
415 |
-
valid_data = concatenate_datasets([valid_data, valid_data_mix_in])
|
416 |
-
elif valid_data_mix_in:
|
417 |
-
valid_data = valid_data_mix_in
|
418 |
-
|
419 |
-
if valid_data:
|
420 |
-
log("Tokenizing %s validation rows" % valid_data.num_rows)
|
421 |
-
valid_data = valid_data.shuffle().map(generate_and_tokenize_prompt_fun,
|
422 |
-
num_proc=os.cpu_count() // torch.cuda.device_count())
|
423 |
-
val_set_size = len(valid_data)
|
424 |
-
else:
|
425 |
-
val_set_size = 0
|
426 |
-
log("Final fine-tuning data:\nTrain %s\nValid %s" % (train_data, valid_data))
|
427 |
-
sample_row_dict = train_data[:1]
|
428 |
-
del sample_row_dict['input_ids']
|
429 |
-
del sample_row_dict['attention_mask']
|
430 |
-
del sample_row_dict['labels']
|
431 |
-
log("Sample input: %s" % sample_row_dict)
|
432 |
-
|
433 |
-
try:
|
434 |
-
import neptune
|
435 |
-
from transformers.integrations import NeptuneCallback
|
436 |
-
|
437 |
-
neptune_run = neptune.init_run(
|
438 |
-
source_files=[],
|
439 |
-
)
|
440 |
-
log("Connected to Neptune.")
|
441 |
-
except ImportError:
|
442 |
-
neptune_run = None
|
443 |
-
log("Please pip install neptune for tracking.")
|
444 |
-
except neptune.exceptions.NeptuneMissingApiTokenException:
|
445 |
-
neptune_run = None
|
446 |
-
os.environ["NEPTUNE_MODE"] = 'debug'
|
447 |
-
log("No neptune configured, set NEPTUNE_API_TOKEN env var.")
|
448 |
-
|
449 |
-
if neptune_run:
|
450 |
-
neptune_callback = NeptuneCallback(run=neptune_run)
|
451 |
-
callbacks = [neptune_callback]
|
452 |
-
else:
|
453 |
-
from transformers.integrations import TensorBoardCallback, is_tensorboard_available
|
454 |
-
if is_tensorboard_available:
|
455 |
-
# tensorboard --logdir=runs/
|
456 |
-
from torch.utils.tensorboard import SummaryWriter
|
457 |
-
tb_writer = SummaryWriter()
|
458 |
-
callbacks = [TensorBoardCallback(tb_writer=tb_writer)]
|
459 |
-
else:
|
460 |
-
callbacks = []
|
461 |
-
|
462 |
-
expected_steps = (train_set_size * num_epochs) // batch_size
|
463 |
-
if eval_steps is None and eval_epochs is None:
|
464 |
-
# 20 evaluations for a run
|
465 |
-
eval_steps = max(1, int(expected_steps / 20))
|
466 |
-
log("Auto set eval_steps to %s out of %s total training steps" % (eval_steps, expected_steps))
|
467 |
-
elif eval_steps is None and eval_epochs is not None:
|
468 |
-
eval_steps = max(1, int(expected_steps * eval_epochs / num_epochs))
|
469 |
-
log("Auto converted eval_epochs=%s to eval_steps %s"
|
470 |
-
" out of %s total training steps" % (eval_epochs, eval_steps, expected_steps))
|
471 |
-
if save_steps is None:
|
472 |
-
save_steps = eval_steps
|
473 |
-
log("Auto step save_steps to %s" % save_steps)
|
474 |
-
elif save_steps > eval_steps:
|
475 |
-
# save steps must be round multiple of eval_steps
|
476 |
-
save_steps0 = save_steps
|
477 |
-
save_steps = max(1, (save_steps // eval_steps)) * eval_steps
|
478 |
-
if save_steps0 != save_steps:
|
479 |
-
log("Auto converted save_steps from %s to %s" % (save_steps0, save_steps))
|
480 |
-
|
481 |
-
def compute_metrics(eval_preds):
|
482 |
-
# e.g. see: https://huggingface.co/docs/transformers/v4.25.1/en/tasks/translation#evaluate
|
483 |
-
inputs = eval_preds.inputs
|
484 |
-
label_ids = eval_preds.label_ids
|
485 |
-
predictions = eval_preds.predictions
|
486 |
-
|
487 |
-
# inputs = np.where(inputs != -100, inputs, tokenizer.pad_token_id)
|
488 |
-
# decoded_inputs = tokenizer.batch_decode(inputs, skip_special_tokens=True)
|
489 |
-
# decoded_inputs = [pred.strip() for pred in decoded_inputs]
|
490 |
-
|
491 |
-
label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)
|
492 |
-
# tokenizer behavior like generate time
|
493 |
-
decoded_labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True,
|
494 |
-
clean_up_tokenization_spaces=True)
|
495 |
-
decoded_labels = [pred.strip() for pred in decoded_labels]
|
496 |
-
|
497 |
-
predictions = np.argmax(predictions, -1)
|
498 |
-
predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
|
499 |
-
# tokenizer behavior like generate time
|
500 |
-
decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True,
|
501 |
-
clean_up_tokenization_spaces=True)
|
502 |
-
decoded_predictions = [pred.strip() for pred in decoded_predictions]
|
503 |
-
|
504 |
-
result = {}
|
505 |
-
for metric in metrics.values():
|
506 |
-
result1 = metric.compute(predictions=decoded_predictions, references=decoded_labels)
|
507 |
-
# get rid of lists, for precision etc., for now
|
508 |
-
numeric_results = {k: v for k, v in result1.items() if isinstance(v, (int, float))}
|
509 |
-
result.update(numeric_results)
|
510 |
-
return result
|
511 |
-
|
512 |
-
# the callback that computes metrics of interest
|
513 |
-
if val_metrics:
|
514 |
-
trainer_kwargs = dict(compute_metrics=compute_metrics)
|
515 |
-
else:
|
516 |
-
trainer_kwargs = dict()
|
517 |
-
|
518 |
-
import transformers
|
519 |
-
trainer = transformers.Trainer(
|
520 |
-
model=model,
|
521 |
-
tokenizer=tokenizer,
|
522 |
-
train_dataset=train_data,
|
523 |
-
eval_dataset=valid_data,
|
524 |
-
# FIXME: might need Seq2SeqTrainingArguments for some models
|
525 |
-
args=transformers.TrainingArguments(
|
526 |
-
per_device_train_batch_size=micro_batch_size,
|
527 |
-
per_device_eval_batch_size=1,
|
528 |
-
eval_accumulation_steps=10,
|
529 |
-
# predict_with_generate=True, # SEQ2SEQ only
|
530 |
-
include_inputs_for_metrics=True,
|
531 |
-
gradient_accumulation_steps=gradient_accumulation_steps,
|
532 |
-
warmup_steps=warmup_steps,
|
533 |
-
num_train_epochs=num_epochs,
|
534 |
-
learning_rate=learning_rate,
|
535 |
-
gradient_checkpointing=gradient_checkpointing,
|
536 |
-
fp16=fp16,
|
537 |
-
# cosnider 8-bit adam: https://huggingface.co/docs/transformers/v4.18.0/en/performance#8bit-adam
|
538 |
-
optim="adamw_torch", # consider "adafactor" to save memory
|
539 |
-
logging_steps=logging_steps,
|
540 |
-
logging_strategy="steps",
|
541 |
-
evaluation_strategy="steps" if val_set_size > 0 else "no",
|
542 |
-
save_strategy="steps",
|
543 |
-
eval_steps=eval_steps if val_set_size > 0 else None,
|
544 |
-
save_steps=save_steps,
|
545 |
-
output_dir=output_dir,
|
546 |
-
save_total_limit=save_total_limit,
|
547 |
-
load_best_model_at_end=True if val_set_size > 0 else False,
|
548 |
-
ddp_find_unused_parameters=False if ddp else None,
|
549 |
-
group_by_length=group_by_length,
|
550 |
-
# fsdp="shard_grad_op auto_wrap" if gpus > 1 and not ddp else None,
|
551 |
-
# fsdp_min_num_params=20000 if gpus > 1 and not ddp else None,
|
552 |
-
report_to='tensorboard' if not neptune_run else 'neptune',
|
553 |
-
),
|
554 |
-
data_collator=transformers.DataCollatorForSeq2Seq(
|
555 |
-
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
|
556 |
-
),
|
557 |
-
callbacks=callbacks,
|
558 |
-
**trainer_kwargs,
|
559 |
-
)
|
560 |
-
model.config.use_cache = False
|
561 |
-
|
562 |
-
if torch.__version__ >= "2" and sys.platform != "win32":
|
563 |
-
model = torch.compile(model)
|
564 |
-
# WIP (not generally replacing layers until pytorch 2.1)
|
565 |
-
if not llama_flash_attn:
|
566 |
-
torch.backends.cuda.enable_flash_sdp(True)
|
567 |
-
|
568 |
-
if gpus > 1 and not ddp:
|
569 |
-
assert trainer.is_model_parallel
|
570 |
-
else:
|
571 |
-
assert not trainer.is_model_parallel
|
572 |
-
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
573 |
-
|
574 |
-
model.save_pretrained(output_dir)
|
575 |
-
|
576 |
-
log("\n If there's a warning about missing keys above, please disregard :)")
|
577 |
-
|
578 |
-
|
579 |
-
def tokenize(prompt, tokenizer, cutoff_len, add_eos_token=False):
|
580 |
-
# there's probably a way to do this with the tokenizer settings
|
581 |
-
# but again, gotta move fast
|
582 |
-
result = tokenizer(
|
583 |
-
prompt,
|
584 |
-
truncation=True,
|
585 |
-
max_length=cutoff_len,
|
586 |
-
padding=False,
|
587 |
-
return_tensors=None,
|
588 |
-
)
|
589 |
-
if (
|
590 |
-
result["input_ids"][-1] != tokenizer.eos_token_id
|
591 |
-
and len(result["input_ids"]) < cutoff_len
|
592 |
-
and add_eos_token
|
593 |
-
):
|
594 |
-
result["input_ids"].append(tokenizer.eos_token_id)
|
595 |
-
result["attention_mask"].append(1)
|
596 |
-
|
597 |
-
result["labels"] = result["input_ids"].copy()
|
598 |
-
|
599 |
-
return result
|
600 |
-
|
601 |
-
|
602 |
-
def prune_long_sequences(data_point, cutoff_len=None):
|
603 |
-
"""
|
604 |
-
Prune if too long for tokenizer, so truncation doesn't lead training to learn from truncated language
|
605 |
-
:param data_point:
|
606 |
-
:param cutoff_len:
|
607 |
-
:return:
|
608 |
-
"""
|
609 |
-
assert cutoff_len is not None
|
610 |
-
return len(data_point['input_ids']) < cutoff_len
|
611 |
-
|
612 |
-
|
613 |
-
def generate_and_tokenize_prompt(data_point, prompt_type=None, train_on_inputs=False, add_eos_token=False,
|
614 |
-
cutoff_len=None, tokenizer=None):
|
615 |
-
assert prompt_type is not None
|
616 |
-
assert cutoff_len is not None
|
617 |
-
assert tokenizer is not None
|
618 |
-
prompt_dict = '' # only for custom prompt_type
|
619 |
-
assert prompt_type != PromptType.custom.name, "custom not setup for finetune"
|
620 |
-
full_prompt, _, _, _, _ = generate_prompt(data_point, prompt_type, prompt_dict, False, False, False)
|
621 |
-
tokenized_full_prompt = tokenize(full_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token)
|
622 |
-
if not train_on_inputs:
|
623 |
-
user_prompt, _, _, _, _ = generate_prompt({**data_point, "output": ""}, prompt_type, prompt_dict, False, False, False)
|
624 |
-
tokenized_user_prompt = tokenize(user_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token)
|
625 |
-
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
626 |
-
if add_eos_token:
|
627 |
-
user_prompt_len -= 1
|
628 |
-
|
629 |
-
# ignore_index=-100 ensures torch/tf don't include padding token id in CrossEntropyLoss
|
630 |
-
tokenized_full_prompt["labels"] = [
|
631 |
-
-100
|
632 |
-
] * user_prompt_len + tokenized_full_prompt["labels"][
|
633 |
-
user_prompt_len:
|
634 |
-
] # could be sped up, probably
|
635 |
-
return tokenized_full_prompt
|
636 |
-
|
637 |
-
|
638 |
-
def test_debug():
|
639 |
-
fire.Fire(train)
|
640 |
-
|
641 |
-
|
642 |
-
def entrypoint_main():
|
643 |
-
CONFIG = "NCCL_P2P_LEVEL=LOC WORLD_SIZE=5 torchrun --nnodes=5 --master_addr=10.10.10.2 --master_port=1111 --nproc_per_node=1"
|
644 |
-
CMD = "finetune.py --data_path=config.json --num_epochs=1 --base_model=decapoda-research/llama-13b-hf"
|
645 |
-
log(f"""
|
646 |
-
Example runs on 4 GPUs:
|
647 |
-
WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='decapoda-research/llama-7b-hf' --data_path=data/config.json --run_id=0 &> 0.log
|
648 |
-
WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='decapoda-research/llama-30b-hf' --data_path=data/config.json --batch_size=16 --micro_batch_size=1 --run_id=1 --save_code=True &> 1.log
|
649 |
-
WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='EleutherAI/gpt-j-6B' --data_path=data/config.json --run_id=2 &> 2.log
|
650 |
-
WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='EleutherAI/gpt-neox-20b' --data_path=data/config.json --run_id=8 --batch_size=16 --micro_batch_size=4 &> 8.log
|
651 |
-
WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --data_path=data/config.json --prompt_type='dai_faq' --run_id=13 --batch_size=16 --micro_batch_size=4 --num_epochs=100 --val_set_size=0 data_mix_in_path='' &> 13.log
|
652 |
-
WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --data_path=data/config.json --run_id=28 --batch_size=16 --micro_batch_size=4 --num_epochs=8 --val_set_size=0 --data_mix_in_factor=0.1 --data_mix_in_prompt_type='human_bot' --save_code=True --cutoff_len=512 &> 28.log
|
653 |
-
|
654 |
-
All metrics:
|
655 |
-
CUDA_VISIBLE_DEVICES= finetune.py --data_mix_in_factor=0 --eval_steps=100 --warmup_steps=2 --val_set_size=100 --val_metrics="['bleu', 'rouge', 'sacrebleu', 'meteor']"
|
656 |
-
|
657 |
-
# Fine-tune 20B on 24GB GPUs across 3 nodes with 3+2+2 GPUs
|
658 |
-
rippa>
|
659 |
-
NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1,2" torchrun --node_rank 0 --nproc_per_node=3 --master_port=1234 --nnodes=3 --master_addr=10.10.10.2 finetune.py --data_path=merged_shuffled_OIG_87f6a1e788.json --micro_batch_size=1 --batch_size=7 --cutoff_len=512 --run_id=17 &>log.17.rank0
|
660 |
-
ova>
|
661 |
-
NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1" torchrun --node_rank 1 --nproc_per_node=2 --master_port=1234 --nnodes=3 --master_addr=10.10.10.2 finetune.py --data_path=merged_shuffled_OIG_87f6a1e788.json --micro_batch_size=1 --batch_size=7 --cutoff_len=512 --run_id=17 &>log.17.rank1
|
662 |
-
timemachine>
|
663 |
-
NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1" torchrun --node_rank 2 --nproc_per_node=2 --master_port=1234 --nnodes=3 --master_addr=10.10.10.2 finetune.py --data_path=merged_shuffled_OIG_87f6a1e788.json --micro_batch_size=1 --batch_size=7 --cutoff_len=512 --run_id=17 &>log.17.rank2
|
664 |
-
|
665 |
-
""", flush=True)
|
666 |
-
|
667 |
-
if os.environ.get("LOCAL_RANK") is None:
|
668 |
-
# then not using torchrun, so can't do distributed, ensure CVD set
|
669 |
-
assert os.environ.get(
|
670 |
-
"CUDA_VISIBLE_DEVICES") is not None, "Run python script using: torchrun finetune.py OR set CUDA_VISIBLE_DEVICES to single GPU"
|
671 |
-
|
672 |
-
fire.Fire(train)
|
673 |
-
|
674 |
-
|
675 |
-
if __name__ == "__main__":
|
676 |
-
entrypoint_main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
generate.py
DELETED
The diff for this file is too large to render.
See raw diff
|
|
gpt4all_llm.py
CHANGED
@@ -19,6 +19,15 @@ def get_model_tokenizer_gpt4all(base_model, **kwargs):
|
|
19 |
n_ctx=2048 - 256)
|
20 |
env_gpt4all_file = ".env_gpt4all"
|
21 |
model_kwargs.update(dotenv_values(env_gpt4all_file))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
if base_model == "llama":
|
24 |
if 'model_path_llama' not in model_kwargs:
|
|
|
19 |
n_ctx=2048 - 256)
|
20 |
env_gpt4all_file = ".env_gpt4all"
|
21 |
model_kwargs.update(dotenv_values(env_gpt4all_file))
|
22 |
+
# make int or float if can to satisfy types for class
|
23 |
+
for k, v in model_kwargs.items():
|
24 |
+
try:
|
25 |
+
if float(v) == int(v):
|
26 |
+
model_kwargs[k] = int(v)
|
27 |
+
else:
|
28 |
+
model_kwargs[k] = float(v)
|
29 |
+
except:
|
30 |
+
pass
|
31 |
|
32 |
if base_model == "llama":
|
33 |
if 'model_path_llama' not in model_kwargs:
|
gpt_langchain.py
CHANGED
@@ -23,8 +23,9 @@ from langchain.callbacks import streaming_stdout
|
|
23 |
from langchain.embeddings import HuggingFaceInstructEmbeddings
|
24 |
from tqdm import tqdm
|
25 |
|
26 |
-
from enums import DocumentChoices, no_lora_str, model_token_mapping, source_prefix, source_postfix
|
27 |
-
|
|
|
28 |
from prompter import non_hf_types, PromptType, Prompter
|
29 |
from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
|
30 |
get_device, ProgressParallel, remove, hash_file, clear_torch_cache, NullContext, get_hf_server, FakeTokenizer
|
@@ -43,7 +44,8 @@ from langchain.chains.qa_with_sources import load_qa_with_sources_chain
|
|
43 |
from langchain.document_loaders import PyPDFLoader, TextLoader, CSVLoader, PythonLoader, TomlLoader, \
|
44 |
UnstructuredURLLoader, UnstructuredHTMLLoader, UnstructuredWordDocumentLoader, UnstructuredMarkdownLoader, \
|
45 |
EverNoteLoader, UnstructuredEmailLoader, UnstructuredODTLoader, UnstructuredPowerPointLoader, \
|
46 |
-
UnstructuredEPubLoader, UnstructuredImageLoader, UnstructuredRTFLoader, ArxivLoader, UnstructuredPDFLoader
|
|
|
47 |
from langchain.text_splitter import RecursiveCharacterTextSplitter, Language
|
48 |
from langchain.chains.question_answering import load_qa_chain
|
49 |
from langchain.docstore.document import Document
|
@@ -351,6 +353,7 @@ class GradioInference(LLM):
|
|
351 |
stream_output = self.stream
|
352 |
gr_client = self.client
|
353 |
client_langchain_mode = 'Disabled'
|
|
|
354 |
top_k_docs = 1
|
355 |
chunk = True
|
356 |
chunk_size = 512
|
@@ -379,6 +382,7 @@ class GradioInference(LLM):
|
|
379 |
instruction_nochat=prompt if not self.chat_client else '',
|
380 |
iinput_nochat='', # only for chat=False
|
381 |
langchain_mode=client_langchain_mode,
|
|
|
382 |
top_k_docs=top_k_docs,
|
383 |
chunk=chunk,
|
384 |
chunk_size=chunk_size,
|
@@ -637,6 +641,7 @@ def get_llm(use_openai_model=False,
|
|
637 |
callbacks = [StreamingGradioCallbackHandler()]
|
638 |
assert prompter is not None
|
639 |
stop_sequences = list(set(prompter.terminate_response + [prompter.PreResponse]))
|
|
|
640 |
|
641 |
if gr_client:
|
642 |
chat_client = False
|
@@ -744,7 +749,7 @@ def get_llm(use_openai_model=False,
|
|
744 |
|
745 |
if stream_output:
|
746 |
skip_prompt = False
|
747 |
-
from
|
748 |
decoder_kwargs = {}
|
749 |
streamer = H2OTextIteratorStreamer(tokenizer, skip_prompt=skip_prompt, block=False, **decoder_kwargs)
|
750 |
gen_kwargs.update(dict(streamer=streamer))
|
@@ -944,14 +949,16 @@ have_playwright = False
|
|
944 |
|
945 |
image_types = ["png", "jpg", "jpeg"]
|
946 |
non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
|
947 |
-
"md",
|
|
|
948 |
"enex", "eml", "epub", "odt", "pptx", "ppt",
|
949 |
"zip", "urls",
|
|
|
950 |
]
|
951 |
# "msg", GPL3
|
952 |
|
953 |
if have_libreoffice:
|
954 |
-
non_image_types.extend(["docx", "doc"])
|
955 |
|
956 |
file_types = non_image_types + image_types
|
957 |
|
@@ -961,7 +968,7 @@ def add_meta(docs1, file):
|
|
961 |
hashid = hash_file(file)
|
962 |
if not isinstance(docs1, (list, tuple, types.GeneratorType)):
|
963 |
docs1 = [docs1]
|
964 |
-
[x.metadata.update(dict(input_type=file_extension, date=str(datetime.now), hashid=hashid)) for x in docs1]
|
965 |
|
966 |
|
967 |
def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
|
@@ -1038,6 +1045,10 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
|
|
1038 |
docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
|
1039 |
add_meta(docs1, file)
|
1040 |
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
|
|
|
|
|
|
|
|
|
1041 |
elif file.lower().endswith('.odt'):
|
1042 |
docs1 = UnstructuredODTLoader(file_path=file).load()
|
1043 |
add_meta(docs1, file)
|
@@ -1758,6 +1769,8 @@ def run_qa_db(**kwargs):
|
|
1758 |
|
1759 |
|
1760 |
def _run_qa_db(query=None,
|
|
|
|
|
1761 |
use_openai_model=False, use_openai_embedding=False,
|
1762 |
first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
|
1763 |
user_path=None,
|
@@ -1787,6 +1800,7 @@ def _run_qa_db(query=None,
|
|
1787 |
repetition_penalty=1.0,
|
1788 |
num_return_sequences=1,
|
1789 |
langchain_mode=None,
|
|
|
1790 |
document_choice=[DocumentChoices.All_Relevant.name],
|
1791 |
n_jobs=-1,
|
1792 |
verbose=False,
|
@@ -1803,7 +1817,7 @@ def _run_qa_db(query=None,
|
|
1803 |
:param use_openai_embedding:
|
1804 |
:param first_para:
|
1805 |
:param text_limit:
|
1806 |
-
:param
|
1807 |
:param chunk:
|
1808 |
:param chunk_size:
|
1809 |
:param user_path: user path to glob recursively from
|
@@ -1869,12 +1883,28 @@ def _run_qa_db(query=None,
|
|
1869 |
sim_kwargs = {k: v for k, v in locals().items() if k in func_names}
|
1870 |
missing_kwargs = [x for x in func_names if x not in sim_kwargs]
|
1871 |
assert not missing_kwargs, "Missing: %s" % missing_kwargs
|
1872 |
-
docs, chain, scores, use_context = get_similarity_chain(**sim_kwargs)
|
1873 |
-
if cmd in
|
1874 |
formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
|
1875 |
yield formatted_doc_chunks, ''
|
1876 |
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1877 |
if chain is None and model_name not in non_hf_types:
|
|
|
1878 |
# can only return if HF type
|
1879 |
return
|
1880 |
|
@@ -1933,6 +1963,7 @@ def _run_qa_db(query=None,
|
|
1933 |
|
1934 |
|
1935 |
def get_similarity_chain(query=None,
|
|
|
1936 |
use_openai_model=False, use_openai_embedding=False,
|
1937 |
first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
|
1938 |
user_path=None,
|
@@ -1947,6 +1978,7 @@ def get_similarity_chain(query=None,
|
|
1947 |
load_db_if_exists=False,
|
1948 |
db=None,
|
1949 |
langchain_mode=None,
|
|
|
1950 |
document_choice=[DocumentChoices.All_Relevant.name],
|
1951 |
n_jobs=-1,
|
1952 |
# beyond run_db_query:
|
@@ -1997,25 +2029,56 @@ def get_similarity_chain(query=None,
|
|
1997 |
db=db,
|
1998 |
n_jobs=n_jobs,
|
1999 |
verbose=verbose)
|
2000 |
-
|
2001 |
-
if
|
2002 |
-
|
2003 |
-
|
2004 |
-
|
2005 |
-
|
2006 |
-
|
2007 |
-
|
2008 |
-
|
2009 |
-
|
2010 |
-
|
2011 |
-
|
2012 |
-
|
2013 |
-
|
2014 |
-
|
2015 |
-
{context}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2016 |
\"\"\"
|
2017 |
-
%s
|
2018 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2019 |
if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
|
2020 |
use_template = True
|
2021 |
else:
|
@@ -2040,14 +2103,26 @@ def get_similarity_chain(query=None,
|
|
2040 |
if cmd == DocumentChoices.Just_LLM.name:
|
2041 |
docs = []
|
2042 |
scores = []
|
2043 |
-
elif cmd == DocumentChoices.Only_All_Sources.name:
|
2044 |
db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs)
|
2045 |
# similar to langchain's chroma's _results_to_docs_and_scores
|
2046 |
docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0)
|
2047 |
-
for result in zip(db_documents, db_metadatas)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2048 |
docs = [x[0] for x in docs_with_score]
|
2049 |
scores = [x[1] for x in docs_with_score]
|
|
|
2050 |
else:
|
|
|
|
|
2051 |
if top_k_docs == -1 or auto_reduce_chunks:
|
2052 |
# docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:top_k_docs]
|
2053 |
top_k_docs_tokenize = 100
|
@@ -2120,6 +2195,7 @@ def get_similarity_chain(query=None,
|
|
2120 |
if reverse_docs:
|
2121 |
docs_with_score.reverse()
|
2122 |
# cut off so no high distance docs/sources considered
|
|
|
2123 |
docs = [x[0] for x in docs_with_score if x[1] < cut_distanct]
|
2124 |
scores = [x[1] for x in docs_with_score if x[1] < cut_distanct]
|
2125 |
if len(scores) > 0 and verbose:
|
@@ -2131,14 +2207,14 @@ def get_similarity_chain(query=None,
|
|
2131 |
|
2132 |
if not docs and use_context and model_name not in non_hf_types:
|
2133 |
# if HF type and have no docs, can bail out
|
2134 |
-
return docs, None, [], False
|
2135 |
|
2136 |
-
if cmd in
|
2137 |
# no LLM use
|
2138 |
-
return docs, None, [], False
|
2139 |
|
2140 |
common_words_file = "data/NGSL_1.2_stats.csv.zip"
|
2141 |
-
if os.path.isfile(common_words_file):
|
2142 |
df = pd.read_csv("data/NGSL_1.2_stats.csv.zip")
|
2143 |
import string
|
2144 |
reduced_query = query.translate(str.maketrans(string.punctuation, ' ' * len(string.punctuation))).strip()
|
@@ -2155,25 +2231,47 @@ def get_similarity_chain(query=None,
|
|
2155 |
use_context = False
|
2156 |
template = template_if_no_docs
|
2157 |
|
2158 |
-
if
|
2159 |
-
|
2160 |
-
|
2161 |
-
|
2162 |
-
|
2163 |
-
|
2164 |
-
|
2165 |
-
|
2166 |
-
|
2167 |
-
|
2168 |
-
|
2169 |
-
|
2170 |
-
|
2171 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2172 |
else:
|
2173 |
-
|
2174 |
|
2175 |
-
target
|
2176 |
-
return docs, target, scores, use_context
|
2177 |
|
2178 |
|
2179 |
def get_sources_answer(query, answer, scores, show_rank, answer_with_sources, verbose=False):
|
@@ -2243,6 +2341,11 @@ def chunk_sources(sources, chunk=True, chunk_size=512, language=None):
|
|
2243 |
splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=0, keep_separator=keep_separator,
|
2244 |
separators=separators)
|
2245 |
source_chunks = splitter.split_documents(sources)
|
|
|
|
|
|
|
|
|
|
|
2246 |
return source_chunks
|
2247 |
|
2248 |
|
|
|
23 |
from langchain.embeddings import HuggingFaceInstructEmbeddings
|
24 |
from tqdm import tqdm
|
25 |
|
26 |
+
from enums import DocumentChoices, no_lora_str, model_token_mapping, source_prefix, source_postfix, non_query_commands, \
|
27 |
+
LangChainAction, LangChainMode
|
28 |
+
from src.gen import gen_hyper, get_model, SEED
|
29 |
from prompter import non_hf_types, PromptType, Prompter
|
30 |
from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
|
31 |
get_device, ProgressParallel, remove, hash_file, clear_torch_cache, NullContext, get_hf_server, FakeTokenizer
|
|
|
44 |
from langchain.document_loaders import PyPDFLoader, TextLoader, CSVLoader, PythonLoader, TomlLoader, \
|
45 |
UnstructuredURLLoader, UnstructuredHTMLLoader, UnstructuredWordDocumentLoader, UnstructuredMarkdownLoader, \
|
46 |
EverNoteLoader, UnstructuredEmailLoader, UnstructuredODTLoader, UnstructuredPowerPointLoader, \
|
47 |
+
UnstructuredEPubLoader, UnstructuredImageLoader, UnstructuredRTFLoader, ArxivLoader, UnstructuredPDFLoader, \
|
48 |
+
UnstructuredExcelLoader
|
49 |
from langchain.text_splitter import RecursiveCharacterTextSplitter, Language
|
50 |
from langchain.chains.question_answering import load_qa_chain
|
51 |
from langchain.docstore.document import Document
|
|
|
353 |
stream_output = self.stream
|
354 |
gr_client = self.client
|
355 |
client_langchain_mode = 'Disabled'
|
356 |
+
client_langchain_action = LangChainAction.QUERY.value
|
357 |
top_k_docs = 1
|
358 |
chunk = True
|
359 |
chunk_size = 512
|
|
|
382 |
instruction_nochat=prompt if not self.chat_client else '',
|
383 |
iinput_nochat='', # only for chat=False
|
384 |
langchain_mode=client_langchain_mode,
|
385 |
+
langchain_action=client_langchain_action,
|
386 |
top_k_docs=top_k_docs,
|
387 |
chunk=chunk,
|
388 |
chunk_size=chunk_size,
|
|
|
641 |
callbacks = [StreamingGradioCallbackHandler()]
|
642 |
assert prompter is not None
|
643 |
stop_sequences = list(set(prompter.terminate_response + [prompter.PreResponse]))
|
644 |
+
stop_sequences = [x for x in stop_sequences if x]
|
645 |
|
646 |
if gr_client:
|
647 |
chat_client = False
|
|
|
749 |
|
750 |
if stream_output:
|
751 |
skip_prompt = False
|
752 |
+
from src.gen import H2OTextIteratorStreamer
|
753 |
decoder_kwargs = {}
|
754 |
streamer = H2OTextIteratorStreamer(tokenizer, skip_prompt=skip_prompt, block=False, **decoder_kwargs)
|
755 |
gen_kwargs.update(dict(streamer=streamer))
|
|
|
949 |
|
950 |
image_types = ["png", "jpg", "jpeg"]
|
951 |
non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
|
952 |
+
"md",
|
953 |
+
"html", "mhtml",
|
954 |
"enex", "eml", "epub", "odt", "pptx", "ppt",
|
955 |
"zip", "urls",
|
956 |
+
|
957 |
]
|
958 |
# "msg", GPL3
|
959 |
|
960 |
if have_libreoffice:
|
961 |
+
non_image_types.extend(["docx", "doc", "xls", "xlsx"])
|
962 |
|
963 |
file_types = non_image_types + image_types
|
964 |
|
|
|
968 |
hashid = hash_file(file)
|
969 |
if not isinstance(docs1, (list, tuple, types.GeneratorType)):
|
970 |
docs1 = [docs1]
|
971 |
+
[x.metadata.update(dict(input_type=file_extension, date=str(datetime.now()), hashid=hashid)) for x in docs1]
|
972 |
|
973 |
|
974 |
def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
|
|
|
1045 |
docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
|
1046 |
add_meta(docs1, file)
|
1047 |
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
|
1048 |
+
elif (file.lower().endswith('.xlsx') or file.lower().endswith('.xls')) and have_libreoffice:
|
1049 |
+
docs1 = UnstructuredExcelLoader(file_path=file).load()
|
1050 |
+
add_meta(docs1, file)
|
1051 |
+
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
|
1052 |
elif file.lower().endswith('.odt'):
|
1053 |
docs1 = UnstructuredODTLoader(file_path=file).load()
|
1054 |
add_meta(docs1, file)
|
|
|
1769 |
|
1770 |
|
1771 |
def _run_qa_db(query=None,
|
1772 |
+
iinput=None,
|
1773 |
+
context=None,
|
1774 |
use_openai_model=False, use_openai_embedding=False,
|
1775 |
first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
|
1776 |
user_path=None,
|
|
|
1800 |
repetition_penalty=1.0,
|
1801 |
num_return_sequences=1,
|
1802 |
langchain_mode=None,
|
1803 |
+
langchain_action=None,
|
1804 |
document_choice=[DocumentChoices.All_Relevant.name],
|
1805 |
n_jobs=-1,
|
1806 |
verbose=False,
|
|
|
1817 |
:param use_openai_embedding:
|
1818 |
:param first_para:
|
1819 |
:param text_limit:
|
1820 |
+
:param top_k_docs:
|
1821 |
:param chunk:
|
1822 |
:param chunk_size:
|
1823 |
:param user_path: user path to glob recursively from
|
|
|
1883 |
sim_kwargs = {k: v for k, v in locals().items() if k in func_names}
|
1884 |
missing_kwargs = [x for x in func_names if x not in sim_kwargs]
|
1885 |
assert not missing_kwargs, "Missing: %s" % missing_kwargs
|
1886 |
+
docs, chain, scores, use_context, have_any_docs = get_similarity_chain(**sim_kwargs)
|
1887 |
+
if cmd in non_query_commands:
|
1888 |
formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
|
1889 |
yield formatted_doc_chunks, ''
|
1890 |
return
|
1891 |
+
if not docs and langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
|
1892 |
+
LangChainAction.SUMMARIZE_ALL.value,
|
1893 |
+
LangChainAction.SUMMARIZE_REFINE.value]:
|
1894 |
+
ret = 'No relevant documents to summarize.' if have_any_docs else 'No documents to summarize.'
|
1895 |
+
extra = ''
|
1896 |
+
yield ret, extra
|
1897 |
+
return
|
1898 |
+
if not docs and langchain_mode not in [LangChainMode.DISABLED.value,
|
1899 |
+
LangChainMode.CHAT_LLM.value,
|
1900 |
+
LangChainMode.LLM.value]:
|
1901 |
+
ret = 'No relevant documents to query.' if have_any_docs else 'No documents to query.'
|
1902 |
+
extra = ''
|
1903 |
+
yield ret, extra
|
1904 |
+
return
|
1905 |
+
|
1906 |
if chain is None and model_name not in non_hf_types:
|
1907 |
+
# here if no docs at all and not HF type
|
1908 |
# can only return if HF type
|
1909 |
return
|
1910 |
|
|
|
1963 |
|
1964 |
|
1965 |
def get_similarity_chain(query=None,
|
1966 |
+
iinput=None,
|
1967 |
use_openai_model=False, use_openai_embedding=False,
|
1968 |
first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
|
1969 |
user_path=None,
|
|
|
1978 |
load_db_if_exists=False,
|
1979 |
db=None,
|
1980 |
langchain_mode=None,
|
1981 |
+
langchain_action=None,
|
1982 |
document_choice=[DocumentChoices.All_Relevant.name],
|
1983 |
n_jobs=-1,
|
1984 |
# beyond run_db_query:
|
|
|
2029 |
db=db,
|
2030 |
n_jobs=n_jobs,
|
2031 |
verbose=verbose)
|
2032 |
+
have_any_docs = db is not None
|
2033 |
+
if langchain_action == LangChainAction.QUERY.value:
|
2034 |
+
if iinput:
|
2035 |
+
query = "%s\n%s" % (query, iinput)
|
2036 |
+
|
2037 |
+
if 'falcon' in model_name:
|
2038 |
+
extra = "According to only the information in the document sources provided within the context above, "
|
2039 |
+
prefix = "Pay attention and remember information below, which will help to answer the question or imperative after the context ends."
|
2040 |
+
elif inference_server in ['openai', 'openai_chat']:
|
2041 |
+
extra = "According to (primarily) the information in the document sources provided within context above, "
|
2042 |
+
prefix = "Pay attention and remember information below, which will help to answer the question or imperative after the context ends. If the answer cannot be primarily obtained from information within the context, then respond that the answer does not appear in the context of the documents."
|
2043 |
+
else:
|
2044 |
+
extra = ""
|
2045 |
+
prefix = ""
|
2046 |
+
if langchain_mode in ['Disabled', 'ChatLLM', 'LLM'] or not use_context:
|
2047 |
+
template_if_no_docs = template = """%s{context}{question}""" % prefix
|
2048 |
+
else:
|
2049 |
+
template = """%s
|
2050 |
+
\"\"\"
|
2051 |
+
{context}
|
2052 |
+
\"\"\"
|
2053 |
+
%s{question}""" % (prefix, extra)
|
2054 |
+
template_if_no_docs = """%s{context}%s{question}""" % (prefix, extra)
|
2055 |
+
elif langchain_action in [LangChainAction.SUMMARIZE_ALL.value, LangChainAction.SUMMARIZE_MAP.value]:
|
2056 |
+
none = ['', '\n', None]
|
2057 |
+
if query in none and iinput in none:
|
2058 |
+
prompt_summary = "Using only the text above, write a condensed and concise summary:\n"
|
2059 |
+
elif query not in none:
|
2060 |
+
prompt_summary = "Focusing on %s, write a condensed and concise Summary:\n" % query
|
2061 |
+
elif iinput not in None:
|
2062 |
+
prompt_summary = iinput
|
2063 |
+
else:
|
2064 |
+
prompt_summary = "Focusing on %s, %s:\n" % (query, iinput)
|
2065 |
+
# don't auto reduce
|
2066 |
+
auto_reduce_chunks = False
|
2067 |
+
if langchain_action == LangChainAction.SUMMARIZE_MAP.value:
|
2068 |
+
fstring = '{text}'
|
2069 |
+
else:
|
2070 |
+
fstring = '{input_documents}'
|
2071 |
+
template = """In order to write a concise single-paragraph or bulleted list summary, pay attention to the following text:
|
2072 |
\"\"\"
|
2073 |
+
%s
|
2074 |
+
\"\"\"\n%s""" % (fstring, prompt_summary)
|
2075 |
+
template_if_no_docs = "Exactly only say: There are no documents to summarize."
|
2076 |
+
elif langchain_action in [LangChainAction.SUMMARIZE_REFINE]:
|
2077 |
+
template = '' # unused
|
2078 |
+
template_if_no_docs = '' # unused
|
2079 |
+
else:
|
2080 |
+
raise RuntimeError("No such langchain_action=%s" % langchain_action)
|
2081 |
+
|
2082 |
if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
|
2083 |
use_template = True
|
2084 |
else:
|
|
|
2103 |
if cmd == DocumentChoices.Just_LLM.name:
|
2104 |
docs = []
|
2105 |
scores = []
|
2106 |
+
elif cmd == DocumentChoices.Only_All_Sources.name or query in [None, '', '\n']:
|
2107 |
db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs)
|
2108 |
# similar to langchain's chroma's _results_to_docs_and_scores
|
2109 |
docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0)
|
2110 |
+
for result in zip(db_documents, db_metadatas)]
|
2111 |
+
|
2112 |
+
# order documents
|
2113 |
+
doc_hashes = [x['doc_hash'] for x in db_metadatas]
|
2114 |
+
doc_chunk_ids = [x['chunk_id'] for x in db_metadatas]
|
2115 |
+
docs_with_score = [x for _, _, x in
|
2116 |
+
sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score), key=lambda x: (x[0], x[1]))
|
2117 |
+
]
|
2118 |
+
|
2119 |
+
docs_with_score = docs_with_score[:top_k_docs]
|
2120 |
docs = [x[0] for x in docs_with_score]
|
2121 |
scores = [x[1] for x in docs_with_score]
|
2122 |
+
have_any_docs |= len(docs) > 0
|
2123 |
else:
|
2124 |
+
# FIXME: if langchain_action == LangChainAction.SUMMARIZE_MAP.value
|
2125 |
+
# if map_reduce, then no need to auto reduce chunks
|
2126 |
if top_k_docs == -1 or auto_reduce_chunks:
|
2127 |
# docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:top_k_docs]
|
2128 |
top_k_docs_tokenize = 100
|
|
|
2195 |
if reverse_docs:
|
2196 |
docs_with_score.reverse()
|
2197 |
# cut off so no high distance docs/sources considered
|
2198 |
+
have_any_docs |= len(docs_with_score) > 0 # before cut
|
2199 |
docs = [x[0] for x in docs_with_score if x[1] < cut_distanct]
|
2200 |
scores = [x[1] for x in docs_with_score if x[1] < cut_distanct]
|
2201 |
if len(scores) > 0 and verbose:
|
|
|
2207 |
|
2208 |
if not docs and use_context and model_name not in non_hf_types:
|
2209 |
# if HF type and have no docs, can bail out
|
2210 |
+
return docs, None, [], False, have_any_docs
|
2211 |
|
2212 |
+
if cmd in non_query_commands:
|
2213 |
# no LLM use
|
2214 |
+
return docs, None, [], False, have_any_docs
|
2215 |
|
2216 |
common_words_file = "data/NGSL_1.2_stats.csv.zip"
|
2217 |
+
if os.path.isfile(common_words_file) and langchain_mode == LangChainAction.QUERY.value:
|
2218 |
df = pd.read_csv("data/NGSL_1.2_stats.csv.zip")
|
2219 |
import string
|
2220 |
reduced_query = query.translate(str.maketrans(string.punctuation, ' ' * len(string.punctuation))).strip()
|
|
|
2231 |
use_context = False
|
2232 |
template = template_if_no_docs
|
2233 |
|
2234 |
+
if langchain_action == LangChainAction.QUERY.value:
|
2235 |
+
if use_template:
|
2236 |
+
# instruct-like, rather than few-shot prompt_type='plain' as default
|
2237 |
+
# but then sources confuse the model with how inserted among rest of text, so avoid
|
2238 |
+
prompt = PromptTemplate(
|
2239 |
+
# input_variables=["summaries", "question"],
|
2240 |
+
input_variables=["context", "question"],
|
2241 |
+
template=template,
|
2242 |
+
)
|
2243 |
+
chain = load_qa_chain(llm, prompt=prompt)
|
2244 |
+
else:
|
2245 |
+
# only if use_openai_model = True, unused normally except in testing
|
2246 |
+
chain = load_qa_with_sources_chain(llm)
|
2247 |
+
if not use_context:
|
2248 |
+
chain_kwargs = dict(input_documents=[], question=query)
|
2249 |
+
else:
|
2250 |
+
chain_kwargs = dict(input_documents=docs, question=query)
|
2251 |
+
target = wrapped_partial(chain, chain_kwargs)
|
2252 |
+
elif langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
|
2253 |
+
LangChainAction.SUMMARIZE_REFINE,
|
2254 |
+
LangChainAction.SUMMARIZE_ALL.value]:
|
2255 |
+
from langchain.chains.summarize import load_summarize_chain
|
2256 |
+
if langchain_action == LangChainAction.SUMMARIZE_MAP.value:
|
2257 |
+
prompt = PromptTemplate(input_variables=["text"], template=template)
|
2258 |
+
chain = load_summarize_chain(llm, chain_type="map_reduce",
|
2259 |
+
map_prompt=prompt, combine_prompt=prompt, return_intermediate_steps=True)
|
2260 |
+
target = wrapped_partial(chain, {"input_documents": docs}) # , return_only_outputs=True)
|
2261 |
+
elif langchain_action == LangChainAction.SUMMARIZE_ALL.value:
|
2262 |
+
assert use_template
|
2263 |
+
prompt = PromptTemplate(input_variables=["text"], template=template)
|
2264 |
+
chain = load_summarize_chain(llm, chain_type="stuff", prompt=prompt, return_intermediate_steps=True)
|
2265 |
+
target = wrapped_partial(chain)
|
2266 |
+
elif langchain_action == LangChainAction.SUMMARIZE_REFINE.value:
|
2267 |
+
chain = load_summarize_chain(llm, chain_type="refine", return_intermediate_steps=True)
|
2268 |
+
target = wrapped_partial(chain)
|
2269 |
+
else:
|
2270 |
+
raise RuntimeError("No such langchain_action=%s" % langchain_action)
|
2271 |
else:
|
2272 |
+
raise RuntimeError("No such langchain_action=%s" % langchain_action)
|
2273 |
|
2274 |
+
return docs, target, scores, use_context, have_any_docs
|
|
|
2275 |
|
2276 |
|
2277 |
def get_sources_answer(query, answer, scores, show_rank, answer_with_sources, verbose=False):
|
|
|
2341 |
splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=0, keep_separator=keep_separator,
|
2342 |
separators=separators)
|
2343 |
source_chunks = splitter.split_documents(sources)
|
2344 |
+
|
2345 |
+
# currently in order, but when pull from db won't be, so mark order and document by hash
|
2346 |
+
doc_hash = str(uuid.uuid4())[:10]
|
2347 |
+
[x.metadata.update(dict(doc_hash=doc_hash, chunk_id=chunk_id)) for chunk_id, x in enumerate(source_chunks)]
|
2348 |
+
|
2349 |
return source_chunks
|
2350 |
|
2351 |
|
gradio_runner.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import copy
|
2 |
import functools
|
3 |
import inspect
|
@@ -49,16 +50,16 @@ def fix_pydantic_duplicate_validators_error():
|
|
49 |
|
50 |
fix_pydantic_duplicate_validators_error()
|
51 |
|
52 |
-
from enums import DocumentChoices, no_model_str, no_lora_str, no_server_str, LangChainMode
|
53 |
from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js, spacing_xsm, radius_xsm, \
|
54 |
text_xsm
|
55 |
from prompter import prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower, non_hf_types, \
|
56 |
get_prompt
|
57 |
from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
|
58 |
ping, get_short_name, get_url, makedirs, get_kwargs, remove, system_info, ping_gpu
|
59 |
-
from
|
60 |
-
inputs_kwargs_list, scratch_base_dir,
|
61 |
-
eval_func_param_names_defaults, get_max_max_new_tokens, get_minmax_top_k_docs, history_to_context
|
62 |
|
63 |
from apscheduler.schedulers.background import BackgroundScheduler
|
64 |
|
@@ -99,6 +100,7 @@ def go_gradio(**kwargs):
|
|
99 |
dbs = kwargs['dbs']
|
100 |
db_type = kwargs['db_type']
|
101 |
visible_langchain_modes = kwargs['visible_langchain_modes']
|
|
|
102 |
allow_upload_to_user_data = kwargs['allow_upload_to_user_data']
|
103 |
allow_upload_to_my_data = kwargs['allow_upload_to_my_data']
|
104 |
enable_sources_list = kwargs['enable_sources_list']
|
@@ -213,7 +215,28 @@ def go_gradio(**kwargs):
|
|
213 |
'base_model') else no_model_msg
|
214 |
output_label0_model2 = no_model_msg
|
215 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
default_kwargs = {k: kwargs[k] for k in eval_func_param_names_defaults}
|
|
|
|
|
|
|
|
|
217 |
for k in no_default_param_names:
|
218 |
default_kwargs[k] = ''
|
219 |
|
@@ -239,7 +262,8 @@ def go_gradio(**kwargs):
|
|
239 |
model_options_state = gr.State([model_options])
|
240 |
lora_options_state = gr.State([lora_options])
|
241 |
server_options_state = gr.State([server_options])
|
242 |
-
|
|
|
243 |
chat_state = gr.State({})
|
244 |
# make user default first and default choice, dedup
|
245 |
docs_state00 = kwargs['document_choice'] + [x.name for x in list(DocumentChoices)]
|
@@ -332,6 +356,12 @@ def go_gradio(**kwargs):
|
|
332 |
value=kwargs['langchain_mode'],
|
333 |
label="Data Collection of Sources",
|
334 |
visible=kwargs['langchain_mode'] != 'Disabled')
|
|
|
|
|
|
|
|
|
|
|
|
|
335 |
data_row2 = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled')
|
336 |
with data_row2:
|
337 |
with gr.Column(scale=50):
|
@@ -920,19 +950,59 @@ def go_gradio(**kwargs):
|
|
920 |
for k in inputs_kwargs_list:
|
921 |
assert k in kwargs_evaluate, "Missing %s" % k
|
922 |
|
923 |
-
def
|
924 |
-
|
925 |
-
|
926 |
-
|
927 |
-
|
928 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
929 |
|
930 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
931 |
**kwargs_evaluate)
|
932 |
-
fun2 = partial(
|
|
|
|
|
933 |
**kwargs_evaluate)
|
934 |
-
fun_with_dict_str = partial(
|
935 |
-
|
|
|
936 |
**kwargs_evaluate
|
937 |
)
|
938 |
|
@@ -1072,14 +1142,17 @@ def go_gradio(**kwargs):
|
|
1072 |
User that fills history for bot
|
1073 |
:param args:
|
1074 |
:param undo:
|
|
|
1075 |
:param sanitize_user_prompt:
|
1076 |
-
:param model2:
|
1077 |
:return:
|
1078 |
"""
|
1079 |
args_list = list(args)
|
1080 |
user_message = args_list[eval_func_param_names.index('instruction')] # chat only
|
1081 |
input1 = args_list[eval_func_param_names.index('iinput')] # chat only
|
1082 |
prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
|
|
|
|
|
|
|
1083 |
if not prompt_type1:
|
1084 |
# shouldn't have to specify if CLI launched model
|
1085 |
prompt_type1 = kwargs['prompt_type']
|
@@ -1110,8 +1183,12 @@ def go_gradio(**kwargs):
|
|
1110 |
history[-1][1] = None
|
1111 |
return history
|
1112 |
if user_message1 in ['', None, '\n']:
|
1113 |
-
|
1114 |
-
|
|
|
|
|
|
|
|
|
1115 |
user_message1 = fix_text_for_gradio(user_message1)
|
1116 |
return history + [[user_message1, None]]
|
1117 |
|
@@ -1147,11 +1224,13 @@ def go_gradio(**kwargs):
|
|
1147 |
else:
|
1148 |
return 2000
|
1149 |
|
1150 |
-
def prep_bot(*args, retry=False):
|
1151 |
"""
|
1152 |
|
1153 |
:param args:
|
1154 |
:param retry:
|
|
|
|
|
1155 |
:return: last element is True if should run bot, False if should just yield history
|
1156 |
"""
|
1157 |
# don't deepcopy, can contain model itself
|
@@ -1159,12 +1238,16 @@ def go_gradio(**kwargs):
|
|
1159 |
model_state1 = args_list[-3]
|
1160 |
my_db_state1 = args_list[-2]
|
1161 |
history = args_list[-1]
|
1162 |
-
|
|
|
1163 |
|
1164 |
if model_state1['model'] is None or model_state1['model'] == no_model_str:
|
1165 |
return history, None, None, None
|
1166 |
|
1167 |
args_list = args_list[:-3] # only keep rest needed for evaluate()
|
|
|
|
|
|
|
1168 |
if not history:
|
1169 |
print("No history", flush=True)
|
1170 |
history = []
|
@@ -1175,22 +1258,23 @@ def go_gradio(**kwargs):
|
|
1175 |
instruction1 = history[-1][0]
|
1176 |
history[-1][1] = None
|
1177 |
elif not instruction1:
|
1178 |
-
|
1179 |
-
|
|
|
|
|
|
|
|
|
1180 |
elif len(history) > 0 and history[-1][1] not in [None, '']:
|
1181 |
# reject submit button if already filled and not retrying
|
1182 |
# None when not filling with '' to keep client happy
|
1183 |
return history, None, None, None
|
1184 |
|
1185 |
# shouldn't have to specify in API prompt_type if CLI launched model, so prefer global CLI one if have it
|
1186 |
-
prompt_type1 =
|
1187 |
-
|
1188 |
-
|
1189 |
-
|
1190 |
-
|
1191 |
-
prompt_dict1 = kwargs.get('prompt_dict', args_list[eval_func_param_names.index('prompt_dict')])
|
1192 |
-
args_list[eval_func_param_names.index('prompt_dict')] = prompt_dict1 = \
|
1193 |
-
model_state1.get('prompt_dict', prompt_dict1)
|
1194 |
|
1195 |
chat1 = args_list[eval_func_param_names.index('chat')]
|
1196 |
model_max_length1 = get_model_max_length(model_state1)
|
@@ -1264,6 +1348,7 @@ def go_gradio(**kwargs):
|
|
1264 |
for res in get_response(fun1, history):
|
1265 |
yield res
|
1266 |
finally:
|
|
|
1267 |
clear_embeddings(langchain_mode1, my_db_state1)
|
1268 |
|
1269 |
def all_bot(*args, retry=False, model_states1=None):
|
@@ -1277,7 +1362,7 @@ def go_gradio(**kwargs):
|
|
1277 |
my_db_state1 = None # will be filled below by some bot
|
1278 |
try:
|
1279 |
gen_list = []
|
1280 |
-
for chatbot1, model_state1 in zip(chatbots, model_states1):
|
1281 |
args_list1 = args_list0.copy()
|
1282 |
args_list1.insert(-1, model_state1) # insert at -1 so is at -2
|
1283 |
# if at start, have None in response still, replace with '' so client etc. acts like normal
|
@@ -1289,7 +1374,8 @@ def go_gradio(**kwargs):
|
|
1289 |
# so consistent with prep_bot()
|
1290 |
# with model_state1 at -3, my_db_state1 at -2, and history(chatbot) at -1
|
1291 |
# langchain_mode1 and my_db_state1 should be same for every bot
|
1292 |
-
history, fun1, langchain_mode1, my_db_state1 = prep_bot(*tuple(args_list1), retry=retry
|
|
|
1293 |
gen1 = get_response(fun1, history)
|
1294 |
if stream_output1:
|
1295 |
gen1 = TimeoutIterator(gen1, timeout=0.01, sentinel=None, raise_on_exception=False)
|
@@ -1301,6 +1387,7 @@ def go_gradio(**kwargs):
|
|
1301 |
tgen0 = time.time()
|
1302 |
for res1 in itertools.zip_longest(*gen_list):
|
1303 |
if time.time() - tgen0 > max_time1:
|
|
|
1304 |
break
|
1305 |
|
1306 |
bots = [x[0] if x is not None and not isinstance(x, BaseException) else y for x, y in
|
@@ -1735,6 +1822,9 @@ def go_gradio(**kwargs):
|
|
1735 |
|
1736 |
def load_model(model_name, lora_weights, server_name, model_state_old, prompt_type_old, load_8bit,
|
1737 |
infer_devices, gpu_id):
|
|
|
|
|
|
|
1738 |
# ensure old model removed from GPU memory
|
1739 |
if kwargs['debug']:
|
1740 |
print("Pre-switch pre-del GPU memory: %s" % get_torch_allocated(), flush=True)
|
@@ -2161,6 +2251,15 @@ def update_user_db(file, db1, x, y, *args, dbs=None, langchain_mode='UserData',
|
|
2161 |
clear_torch_cache()
|
2162 |
|
2163 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2164 |
def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None, langchain_mode='UserData',
|
2165 |
user_path=None,
|
2166 |
use_openai_embedding=None,
|
@@ -2222,7 +2321,8 @@ def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None,
|
|
2222 |
exceptions = [x for x in sources if x.metadata.get('exception')]
|
2223 |
sources = [x for x in sources if 'exception' not in x.metadata]
|
2224 |
|
2225 |
-
|
|
|
2226 |
if langchain_mode == 'MyData':
|
2227 |
if db1[0] is not None:
|
2228 |
# then add
|
@@ -2235,18 +2335,14 @@ def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None,
|
|
2235 |
# for production hit, when user gets clicky:
|
2236 |
assert len(db1) == 2, "Bad MyData db: %s" % db1
|
2237 |
# then create
|
2238 |
-
# assign fresh hash for this user session, so not shared
|
2239 |
# if added has to original state and didn't change, then would be shared db for all users
|
2240 |
-
db1[1] = str(uuid.uuid4())
|
2241 |
persist_directory = os.path.join(scratch_base_dir, 'db_dir_%s_%s' % (langchain_mode, db1[1]))
|
2242 |
db = get_db(sources, use_openai_embedding=use_openai_embedding,
|
2243 |
db_type=db_type,
|
2244 |
persist_directory=persist_directory,
|
2245 |
langchain_mode=langchain_mode,
|
2246 |
hf_embedding_model=hf_embedding_model)
|
2247 |
-
if db is None:
|
2248 |
-
db1[1] = None
|
2249 |
-
else:
|
2250 |
db1[0] = db
|
2251 |
source_files_added = get_source_files(db=db1[0], exceptions=exceptions)
|
2252 |
return None, langchain_mode, db1, x, y, source_files_added
|
@@ -2274,7 +2370,9 @@ def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None,
|
|
2274 |
|
2275 |
|
2276 |
def get_db(db1, langchain_mode, dbs=None):
|
2277 |
-
|
|
|
|
|
2278 |
if langchain_mode in ['wiki_full']:
|
2279 |
# NOTE: avoid showing full wiki. Takes about 30 seconds over about 90k entries, but not useful for now
|
2280 |
db = None
|
|
|
1 |
+
import ast
|
2 |
import copy
|
3 |
import functools
|
4 |
import inspect
|
|
|
50 |
|
51 |
fix_pydantic_duplicate_validators_error()
|
52 |
|
53 |
+
from enums import DocumentChoices, no_model_str, no_lora_str, no_server_str, LangChainAction, LangChainMode
|
54 |
from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js, spacing_xsm, radius_xsm, \
|
55 |
text_xsm
|
56 |
from prompter import prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower, non_hf_types, \
|
57 |
get_prompt
|
58 |
from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
|
59 |
ping, get_short_name, get_url, makedirs, get_kwargs, remove, system_info, ping_gpu
|
60 |
+
from src.gen import get_model, languages_covered, evaluate, eval_func_param_names, score_qa, langchain_modes, \
|
61 |
+
inputs_kwargs_list, scratch_base_dir, no_default_param_names, \
|
62 |
+
eval_func_param_names_defaults, get_max_max_new_tokens, get_minmax_top_k_docs, history_to_context, langchain_actions
|
63 |
|
64 |
from apscheduler.schedulers.background import BackgroundScheduler
|
65 |
|
|
|
100 |
dbs = kwargs['dbs']
|
101 |
db_type = kwargs['db_type']
|
102 |
visible_langchain_modes = kwargs['visible_langchain_modes']
|
103 |
+
visible_langchain_actions = kwargs['visible_langchain_actions']
|
104 |
allow_upload_to_user_data = kwargs['allow_upload_to_user_data']
|
105 |
allow_upload_to_my_data = kwargs['allow_upload_to_my_data']
|
106 |
enable_sources_list = kwargs['enable_sources_list']
|
|
|
215 |
'base_model') else no_model_msg
|
216 |
output_label0_model2 = no_model_msg
|
217 |
|
218 |
+
def update_prompt(prompt_type1, prompt_dict1, model_state1, which_model=0):
|
219 |
+
if not prompt_type1 or which_model != 0:
|
220 |
+
# keep prompt_type and prompt_dict in sync if possible
|
221 |
+
prompt_type1 = kwargs.get('prompt_type', prompt_type1)
|
222 |
+
prompt_dict1 = kwargs.get('prompt_dict', prompt_dict1)
|
223 |
+
# prefer model specific prompt type instead of global one
|
224 |
+
if not prompt_type1 or which_model != 0:
|
225 |
+
prompt_type1 = model_state1.get('prompt_type', prompt_type1)
|
226 |
+
prompt_dict1 = model_state1.get('prompt_dict', prompt_dict1)
|
227 |
+
|
228 |
+
if not prompt_dict1 or which_model != 0:
|
229 |
+
# if still not defined, try to get
|
230 |
+
prompt_dict1 = kwargs.get('prompt_dict', prompt_dict1)
|
231 |
+
if not prompt_dict1 or which_model != 0:
|
232 |
+
prompt_dict1 = model_state1.get('prompt_dict', prompt_dict1)
|
233 |
+
return prompt_type1, prompt_dict1
|
234 |
+
|
235 |
default_kwargs = {k: kwargs[k] for k in eval_func_param_names_defaults}
|
236 |
+
# ensure prompt_type consistent with prep_bot(), so nochat API works same way
|
237 |
+
default_kwargs['prompt_type'], default_kwargs['prompt_dict'] = \
|
238 |
+
update_prompt(default_kwargs['prompt_type'], default_kwargs['prompt_dict'],
|
239 |
+
model_state1=model_state0, which_model=0)
|
240 |
for k in no_default_param_names:
|
241 |
default_kwargs[k] = ''
|
242 |
|
|
|
262 |
model_options_state = gr.State([model_options])
|
263 |
lora_options_state = gr.State([lora_options])
|
264 |
server_options_state = gr.State([server_options])
|
265 |
+
# uuid in db is used as user ID
|
266 |
+
my_db_state = gr.State([None, str(uuid.uuid4())])
|
267 |
chat_state = gr.State({})
|
268 |
# make user default first and default choice, dedup
|
269 |
docs_state00 = kwargs['document_choice'] + [x.name for x in list(DocumentChoices)]
|
|
|
356 |
value=kwargs['langchain_mode'],
|
357 |
label="Data Collection of Sources",
|
358 |
visible=kwargs['langchain_mode'] != 'Disabled')
|
359 |
+
allowed_actions = [x for x in langchain_actions if x in visible_langchain_actions]
|
360 |
+
langchain_action = gr.Radio(
|
361 |
+
allowed_actions,
|
362 |
+
value=allowed_actions[0] if len(allowed_actions) > 0 else None,
|
363 |
+
label="Data Action",
|
364 |
+
visible=True)
|
365 |
data_row2 = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled')
|
366 |
with data_row2:
|
367 |
with gr.Column(scale=50):
|
|
|
950 |
for k in inputs_kwargs_list:
|
951 |
assert k in kwargs_evaluate, "Missing %s" % k
|
952 |
|
953 |
+
def evaluate_nochat(*args1, default_kwargs1=None, str_api=False, **kwargs1):
|
954 |
+
args_list = list(args1)
|
955 |
+
if str_api:
|
956 |
+
user_kwargs = args_list[2]
|
957 |
+
assert isinstance(user_kwargs, str)
|
958 |
+
user_kwargs = ast.literal_eval(user_kwargs)
|
959 |
+
else:
|
960 |
+
user_kwargs = {k: v for k, v in zip(eval_func_param_names, args_list[2:])}
|
961 |
+
# only used for submit_nochat_api
|
962 |
+
user_kwargs['chat'] = False
|
963 |
+
if 'stream_output' not in user_kwargs:
|
964 |
+
user_kwargs['stream_output'] = False
|
965 |
+
if 'langchain_mode' not in user_kwargs:
|
966 |
+
# if user doesn't specify, then assume disabled, not use default
|
967 |
+
user_kwargs['langchain_mode'] = 'Disabled'
|
968 |
+
if 'langchain_action' not in user_kwargs:
|
969 |
+
user_kwargs['langchain_action'] = LangChainAction.QUERY.value
|
970 |
+
|
971 |
+
set1 = set(list(default_kwargs1.keys()))
|
972 |
+
set2 = set(eval_func_param_names)
|
973 |
+
assert set1 == set2, "Set diff: %s %s: %s" % (set1, set2, set1.symmetric_difference(set2))
|
974 |
+
# correct ordering. Note some things may not be in default_kwargs, so can't be default of user_kwargs.get()
|
975 |
+
model_state1 = args_list[0]
|
976 |
+
my_db_state1 = args_list[1]
|
977 |
+
args_list = [user_kwargs[k] if k in user_kwargs and user_kwargs[k] is not None else default_kwargs1[k] for k
|
978 |
+
in eval_func_param_names]
|
979 |
+
assert len(args_list) == len(eval_func_param_names)
|
980 |
+
args_list = [model_state1, my_db_state1] + args_list
|
981 |
|
982 |
+
try:
|
983 |
+
for res_dict in evaluate(*tuple(args_list), **kwargs1):
|
984 |
+
if str_api:
|
985 |
+
# full return of dict
|
986 |
+
yield res_dict
|
987 |
+
elif kwargs['langchain_mode'] == 'Disabled':
|
988 |
+
yield fix_text_for_gradio(res_dict['response'])
|
989 |
+
else:
|
990 |
+
yield '<br>' + fix_text_for_gradio(res_dict['response'])
|
991 |
+
finally:
|
992 |
+
clear_torch_cache()
|
993 |
+
clear_embeddings(user_kwargs['langchain_mode'], my_db_state1)
|
994 |
+
|
995 |
+
fun = partial(evaluate_nochat,
|
996 |
+
default_kwargs1=default_kwargs,
|
997 |
+
str_api=False,
|
998 |
**kwargs_evaluate)
|
999 |
+
fun2 = partial(evaluate_nochat,
|
1000 |
+
default_kwargs1=default_kwargs,
|
1001 |
+
str_api=False,
|
1002 |
**kwargs_evaluate)
|
1003 |
+
fun_with_dict_str = partial(evaluate_nochat,
|
1004 |
+
default_kwargs1=default_kwargs,
|
1005 |
+
str_api=True,
|
1006 |
**kwargs_evaluate
|
1007 |
)
|
1008 |
|
|
|
1142 |
User that fills history for bot
|
1143 |
:param args:
|
1144 |
:param undo:
|
1145 |
+
:param retry:
|
1146 |
:param sanitize_user_prompt:
|
|
|
1147 |
:return:
|
1148 |
"""
|
1149 |
args_list = list(args)
|
1150 |
user_message = args_list[eval_func_param_names.index('instruction')] # chat only
|
1151 |
input1 = args_list[eval_func_param_names.index('iinput')] # chat only
|
1152 |
prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
|
1153 |
+
langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
|
1154 |
+
langchain_action1 = args_list[eval_func_param_names.index('langchain_action')]
|
1155 |
+
document_choice1 = args_list[eval_func_param_names.index('document_choice')]
|
1156 |
if not prompt_type1:
|
1157 |
# shouldn't have to specify if CLI launched model
|
1158 |
prompt_type1 = kwargs['prompt_type']
|
|
|
1183 |
history[-1][1] = None
|
1184 |
return history
|
1185 |
if user_message1 in ['', None, '\n']:
|
1186 |
+
if langchain_action1 in LangChainAction.QUERY.value and \
|
1187 |
+
DocumentChoices.Only_All_Sources.name not in document_choice1 \
|
1188 |
+
or \
|
1189 |
+
langchain_mode1 in [LangChainMode.CHAT_LLM.value, LangChainMode.LLM.value]:
|
1190 |
+
# reject non-retry submit/enter
|
1191 |
+
return history
|
1192 |
user_message1 = fix_text_for_gradio(user_message1)
|
1193 |
return history + [[user_message1, None]]
|
1194 |
|
|
|
1224 |
else:
|
1225 |
return 2000
|
1226 |
|
1227 |
+
def prep_bot(*args, retry=False, which_model=0):
|
1228 |
"""
|
1229 |
|
1230 |
:param args:
|
1231 |
:param retry:
|
1232 |
+
:param which_model: identifies which model if doing model_lock
|
1233 |
+
API only called for which_model=0, default for inputs_list, but rest should ignore inputs_list
|
1234 |
:return: last element is True if should run bot, False if should just yield history
|
1235 |
"""
|
1236 |
# don't deepcopy, can contain model itself
|
|
|
1238 |
model_state1 = args_list[-3]
|
1239 |
my_db_state1 = args_list[-2]
|
1240 |
history = args_list[-1]
|
1241 |
+
prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
|
1242 |
+
prompt_dict1 = args_list[eval_func_param_names.index('prompt_dict')]
|
1243 |
|
1244 |
if model_state1['model'] is None or model_state1['model'] == no_model_str:
|
1245 |
return history, None, None, None
|
1246 |
|
1247 |
args_list = args_list[:-3] # only keep rest needed for evaluate()
|
1248 |
+
langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
|
1249 |
+
langchain_action1 = args_list[eval_func_param_names.index('langchain_action')]
|
1250 |
+
document_choice1 = args_list[eval_func_param_names.index('document_choice')]
|
1251 |
if not history:
|
1252 |
print("No history", flush=True)
|
1253 |
history = []
|
|
|
1258 |
instruction1 = history[-1][0]
|
1259 |
history[-1][1] = None
|
1260 |
elif not instruction1:
|
1261 |
+
if langchain_action1 in LangChainAction.QUERY.value and \
|
1262 |
+
DocumentChoices.Only_All_Sources.name not in document_choice1 \
|
1263 |
+
or \
|
1264 |
+
langchain_mode1 in [LangChainMode.CHAT_LLM.value, LangChainMode.LLM.value]:
|
1265 |
+
# if not retrying, then reject empty query
|
1266 |
+
return history, None, None, None
|
1267 |
elif len(history) > 0 and history[-1][1] not in [None, '']:
|
1268 |
# reject submit button if already filled and not retrying
|
1269 |
# None when not filling with '' to keep client happy
|
1270 |
return history, None, None, None
|
1271 |
|
1272 |
# shouldn't have to specify in API prompt_type if CLI launched model, so prefer global CLI one if have it
|
1273 |
+
prompt_type1, prompt_dict1 = update_prompt(prompt_type1, prompt_dict1, model_state1,
|
1274 |
+
which_model=which_model)
|
1275 |
+
# apply back to args_list for evaluate()
|
1276 |
+
args_list[eval_func_param_names.index('prompt_type')] = prompt_type1
|
1277 |
+
args_list[eval_func_param_names.index('prompt_dict')] = prompt_dict1
|
|
|
|
|
|
|
1278 |
|
1279 |
chat1 = args_list[eval_func_param_names.index('chat')]
|
1280 |
model_max_length1 = get_model_max_length(model_state1)
|
|
|
1348 |
for res in get_response(fun1, history):
|
1349 |
yield res
|
1350 |
finally:
|
1351 |
+
clear_torch_cache()
|
1352 |
clear_embeddings(langchain_mode1, my_db_state1)
|
1353 |
|
1354 |
def all_bot(*args, retry=False, model_states1=None):
|
|
|
1362 |
my_db_state1 = None # will be filled below by some bot
|
1363 |
try:
|
1364 |
gen_list = []
|
1365 |
+
for chatboti, (chatbot1, model_state1) in enumerate(zip(chatbots, model_states1)):
|
1366 |
args_list1 = args_list0.copy()
|
1367 |
args_list1.insert(-1, model_state1) # insert at -1 so is at -2
|
1368 |
# if at start, have None in response still, replace with '' so client etc. acts like normal
|
|
|
1374 |
# so consistent with prep_bot()
|
1375 |
# with model_state1 at -3, my_db_state1 at -2, and history(chatbot) at -1
|
1376 |
# langchain_mode1 and my_db_state1 should be same for every bot
|
1377 |
+
history, fun1, langchain_mode1, my_db_state1 = prep_bot(*tuple(args_list1), retry=retry,
|
1378 |
+
which_model=chatboti)
|
1379 |
gen1 = get_response(fun1, history)
|
1380 |
if stream_output1:
|
1381 |
gen1 = TimeoutIterator(gen1, timeout=0.01, sentinel=None, raise_on_exception=False)
|
|
|
1387 |
tgen0 = time.time()
|
1388 |
for res1 in itertools.zip_longest(*gen_list):
|
1389 |
if time.time() - tgen0 > max_time1:
|
1390 |
+
print("Took too long: %s" % max_time1, flush=True)
|
1391 |
break
|
1392 |
|
1393 |
bots = [x[0] if x is not None and not isinstance(x, BaseException) else y for x, y in
|
|
|
1822 |
|
1823 |
def load_model(model_name, lora_weights, server_name, model_state_old, prompt_type_old, load_8bit,
|
1824 |
infer_devices, gpu_id):
|
1825 |
+
# ensure no API calls reach here
|
1826 |
+
if is_public:
|
1827 |
+
raise RuntimeError("Illegal access for %s" % model_name)
|
1828 |
# ensure old model removed from GPU memory
|
1829 |
if kwargs['debug']:
|
1830 |
print("Pre-switch pre-del GPU memory: %s" % get_torch_allocated(), flush=True)
|
|
|
2251 |
clear_torch_cache()
|
2252 |
|
2253 |
|
2254 |
+
def get_lock_file(db1, langchain_mode):
|
2255 |
+
assert len(db1) == 2 and db1[1] is not None and isinstance(db1[1], str)
|
2256 |
+
user_id = db1[1]
|
2257 |
+
base_path = 'locks'
|
2258 |
+
makedirs(base_path)
|
2259 |
+
lock_file = "db_%s_%s.lock" % (langchain_mode.replace(' ', '_'), user_id)
|
2260 |
+
return lock_file
|
2261 |
+
|
2262 |
+
|
2263 |
def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None, langchain_mode='UserData',
|
2264 |
user_path=None,
|
2265 |
use_openai_embedding=None,
|
|
|
2321 |
exceptions = [x for x in sources if x.metadata.get('exception')]
|
2322 |
sources = [x for x in sources if 'exception' not in x.metadata]
|
2323 |
|
2324 |
+
lock_file = get_lock_file(db1, langchain_mode)
|
2325 |
+
with filelock.FileLock(lock_file):
|
2326 |
if langchain_mode == 'MyData':
|
2327 |
if db1[0] is not None:
|
2328 |
# then add
|
|
|
2335 |
# for production hit, when user gets clicky:
|
2336 |
assert len(db1) == 2, "Bad MyData db: %s" % db1
|
2337 |
# then create
|
|
|
2338 |
# if added has to original state and didn't change, then would be shared db for all users
|
|
|
2339 |
persist_directory = os.path.join(scratch_base_dir, 'db_dir_%s_%s' % (langchain_mode, db1[1]))
|
2340 |
db = get_db(sources, use_openai_embedding=use_openai_embedding,
|
2341 |
db_type=db_type,
|
2342 |
persist_directory=persist_directory,
|
2343 |
langchain_mode=langchain_mode,
|
2344 |
hf_embedding_model=hf_embedding_model)
|
2345 |
+
if db is not None:
|
|
|
|
|
2346 |
db1[0] = db
|
2347 |
source_files_added = get_source_files(db=db1[0], exceptions=exceptions)
|
2348 |
return None, langchain_mode, db1, x, y, source_files_added
|
|
|
2370 |
|
2371 |
|
2372 |
def get_db(db1, langchain_mode, dbs=None):
|
2373 |
+
lock_file = get_lock_file(db1, langchain_mode)
|
2374 |
+
|
2375 |
+
with filelock.FileLock(lock_file):
|
2376 |
if langchain_mode in ['wiki_full']:
|
2377 |
# NOTE: avoid showing full wiki. Takes about 30 seconds over about 90k entries, but not useful for now
|
2378 |
db = None
|
gradio_utils/__pycache__/css.cpython-310.pyc
DELETED
Binary file (1.53 kB)
|
|
gradio_utils/__pycache__/grclient.cpython-310.pyc
DELETED
Binary file (2.69 kB)
|
|
gradio_utils/__pycache__/prompt_form.cpython-310.pyc
DELETED
Binary file (3.59 kB)
|
|
gradio_utils/css.py
DELETED
@@ -1,53 +0,0 @@
|
|
1 |
-
def get_css(kwargs) -> str:
|
2 |
-
if kwargs['h2ocolors']:
|
3 |
-
css_code = """footer {visibility: hidden;}
|
4 |
-
body{background:linear-gradient(#f5f5f5,#e5e5e5);}
|
5 |
-
body.dark{background:linear-gradient(#000000,#0d0d0d);}
|
6 |
-
"""
|
7 |
-
else:
|
8 |
-
css_code = """footer {visibility: hidden}"""
|
9 |
-
|
10 |
-
css_code += make_css_base()
|
11 |
-
return css_code
|
12 |
-
|
13 |
-
|
14 |
-
def make_css_base() -> str:
|
15 |
-
return """
|
16 |
-
@import url('https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600&display=swap');
|
17 |
-
|
18 |
-
body.dark{#warning {background-color: #555555};}
|
19 |
-
|
20 |
-
#small_btn {
|
21 |
-
margin: 0.6em 0em 0.55em 0;
|
22 |
-
max-width: 20em;
|
23 |
-
min-width: 5em !important;
|
24 |
-
height: 5em;
|
25 |
-
font-size: 14px !important;
|
26 |
-
}
|
27 |
-
|
28 |
-
#prompt-form {
|
29 |
-
border: 1px solid var(--primary-500) !important;
|
30 |
-
}
|
31 |
-
|
32 |
-
#prompt-form.block {
|
33 |
-
border-radius: var(--block-radius) !important;
|
34 |
-
}
|
35 |
-
|
36 |
-
#prompt-form textarea {
|
37 |
-
border: 1px solid rgb(209, 213, 219);
|
38 |
-
}
|
39 |
-
|
40 |
-
#prompt-form label > div {
|
41 |
-
margin-top: 4px;
|
42 |
-
}
|
43 |
-
|
44 |
-
button.primary:hover {
|
45 |
-
background-color: var(--primary-600) !important;
|
46 |
-
transition: .2s;
|
47 |
-
}
|
48 |
-
|
49 |
-
#prompt-form-area {
|
50 |
-
margin-bottom: 2.5rem;
|
51 |
-
}
|
52 |
-
.chatsmall chatbot {font-size: 10px !important}
|
53 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
gradio_utils/grclient.py
DELETED
@@ -1,82 +0,0 @@
|
|
1 |
-
import traceback
|
2 |
-
from typing import Callable
|
3 |
-
import os
|
4 |
-
|
5 |
-
from gradio_client.client import Job
|
6 |
-
|
7 |
-
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
8 |
-
|
9 |
-
from gradio_client import Client
|
10 |
-
|
11 |
-
|
12 |
-
class GradioClient(Client):
|
13 |
-
"""
|
14 |
-
Parent class of gradio client
|
15 |
-
To handle automatically refreshing client if detect gradio server changed
|
16 |
-
"""
|
17 |
-
|
18 |
-
def __init__(self, *args, **kwargs):
|
19 |
-
self.args = args
|
20 |
-
self.kwargs = kwargs
|
21 |
-
super().__init__(*args, **kwargs)
|
22 |
-
self.server_hash = self.get_server_hash()
|
23 |
-
|
24 |
-
def get_server_hash(self):
|
25 |
-
"""
|
26 |
-
Get server hash using super without any refresh action triggered
|
27 |
-
Returns: git hash of gradio server
|
28 |
-
"""
|
29 |
-
return super().submit(api_name='/system_hash').result()
|
30 |
-
|
31 |
-
def refresh_client_if_should(self):
|
32 |
-
# get current hash in order to update api_name -> fn_index map in case gradio server changed
|
33 |
-
# FIXME: Could add cli api as hash
|
34 |
-
server_hash = self.get_server_hash()
|
35 |
-
if self.server_hash != server_hash:
|
36 |
-
self.refresh_client()
|
37 |
-
self.server_hash = server_hash
|
38 |
-
else:
|
39 |
-
self.reset_session()
|
40 |
-
|
41 |
-
def refresh_client(self):
|
42 |
-
"""
|
43 |
-
Ensure every client call is independent
|
44 |
-
Also ensure map between api_name and fn_index is updated in case server changed (e.g. restarted with new code)
|
45 |
-
Returns:
|
46 |
-
"""
|
47 |
-
# need session hash to be new every time, to avoid "generator already executing"
|
48 |
-
self.reset_session()
|
49 |
-
|
50 |
-
client = Client(*self.args, **self.kwargs)
|
51 |
-
for k, v in client.__dict__.items():
|
52 |
-
setattr(self, k, v)
|
53 |
-
|
54 |
-
def submit(
|
55 |
-
self,
|
56 |
-
*args,
|
57 |
-
api_name: str | None = None,
|
58 |
-
fn_index: int | None = None,
|
59 |
-
result_callbacks: Callable | list[Callable] | None = None,
|
60 |
-
) -> Job:
|
61 |
-
# Note predict calls submit
|
62 |
-
try:
|
63 |
-
self.refresh_client_if_should()
|
64 |
-
job = super().submit(*args, api_name=api_name, fn_index=fn_index)
|
65 |
-
except Exception as e:
|
66 |
-
print("Hit e=%s" % str(e), flush=True)
|
67 |
-
# force reconfig in case only that
|
68 |
-
self.refresh_client()
|
69 |
-
job = super().submit(*args, api_name=api_name, fn_index=fn_index)
|
70 |
-
|
71 |
-
# see if immediately failed
|
72 |
-
e = job.future._exception
|
73 |
-
if e is not None:
|
74 |
-
print("GR job failed: %s %s" % (str(e), ''.join(traceback.format_tb(e.__traceback__))), flush=True)
|
75 |
-
# force reconfig in case only that
|
76 |
-
self.refresh_client()
|
77 |
-
job = super().submit(*args, api_name=api_name, fn_index=fn_index)
|
78 |
-
e2 = job.future._exception
|
79 |
-
if e2 is not None:
|
80 |
-
print("GR job failed again: %s\n%s" % (str(e2), ''.join(traceback.format_tb(e2.__traceback__))), flush=True)
|
81 |
-
|
82 |
-
return job
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
gradio_utils/prompt_form.py
DELETED
@@ -1,118 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import math
|
3 |
-
|
4 |
-
import gradio as gr
|
5 |
-
|
6 |
-
|
7 |
-
def make_chatbots(output_label0, output_label0_model2, **kwargs):
|
8 |
-
text_outputs = []
|
9 |
-
chat_kwargs = []
|
10 |
-
for model_state_lock in kwargs['model_states']:
|
11 |
-
if os.environ.get('DEBUG_MODEL_LOCK'):
|
12 |
-
model_name = model_state_lock["base_model"] + " : " + model_state_lock["inference_server"]
|
13 |
-
else:
|
14 |
-
model_name = model_state_lock["base_model"]
|
15 |
-
output_label = f'h2oGPT [{model_name}]'
|
16 |
-
min_width = 250 if kwargs['gradio_size'] in ['small', 'large', 'medium'] else 160
|
17 |
-
chat_kwargs.append(dict(label=output_label, visible=kwargs['model_lock'], elem_classes='chatsmall',
|
18 |
-
height=kwargs['height'] or 400, min_width=min_width))
|
19 |
-
|
20 |
-
if kwargs['model_lock_columns'] == -1:
|
21 |
-
kwargs['model_lock_columns'] = len(kwargs['model_states'])
|
22 |
-
if kwargs['model_lock_columns'] is None:
|
23 |
-
kwargs['model_lock_columns'] = 3
|
24 |
-
|
25 |
-
ncols = kwargs['model_lock_columns']
|
26 |
-
if kwargs['model_states'] == 0:
|
27 |
-
nrows = 0
|
28 |
-
else:
|
29 |
-
nrows = math.ceil(len(kwargs['model_states']) / kwargs['model_lock_columns'])
|
30 |
-
|
31 |
-
if kwargs['model_lock_columns'] == 0:
|
32 |
-
# not using model_lock
|
33 |
-
pass
|
34 |
-
elif nrows <= 1:
|
35 |
-
with gr.Row():
|
36 |
-
for chat_kwargs1, model_state_lock in zip(chat_kwargs, kwargs['model_states']):
|
37 |
-
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
38 |
-
elif nrows == kwargs['model_states']:
|
39 |
-
with gr.Row():
|
40 |
-
for chat_kwargs1, model_state_lock in zip(chat_kwargs, kwargs['model_states']):
|
41 |
-
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
42 |
-
elif nrows == 2:
|
43 |
-
with gr.Row():
|
44 |
-
for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
|
45 |
-
if mii >= len(kwargs['model_states']) / 2:
|
46 |
-
continue
|
47 |
-
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
48 |
-
with gr.Row():
|
49 |
-
for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
|
50 |
-
if mii < len(kwargs['model_states']) / 2:
|
51 |
-
continue
|
52 |
-
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
53 |
-
elif nrows == 3:
|
54 |
-
with gr.Row():
|
55 |
-
for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
|
56 |
-
if mii >= 1 * len(kwargs['model_states']) / 3:
|
57 |
-
continue
|
58 |
-
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
59 |
-
with gr.Row():
|
60 |
-
for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
|
61 |
-
if mii < 1 * len(kwargs['model_states']) / 3 or mii >= 2 * len(kwargs['model_states']) / 3:
|
62 |
-
continue
|
63 |
-
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
64 |
-
with gr.Row():
|
65 |
-
for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
|
66 |
-
if mii < 2 * len(kwargs['model_states']) / 3:
|
67 |
-
continue
|
68 |
-
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
69 |
-
elif nrows >= 4:
|
70 |
-
with gr.Row():
|
71 |
-
for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
|
72 |
-
if mii >= 1 * len(kwargs['model_states']) / 4:
|
73 |
-
continue
|
74 |
-
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
75 |
-
with gr.Row():
|
76 |
-
for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
|
77 |
-
if mii < 1 * len(kwargs['model_states']) / 4 or mii >= 2 * len(kwargs['model_states']) / 4:
|
78 |
-
continue
|
79 |
-
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
80 |
-
with gr.Row():
|
81 |
-
for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
|
82 |
-
if mii < 2 * len(kwargs['model_states']) / 4 or mii >= 3 * len(kwargs['model_states']) / 4:
|
83 |
-
continue
|
84 |
-
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
85 |
-
with gr.Row():
|
86 |
-
for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
|
87 |
-
if mii < 3 * len(kwargs['model_states']) / 4:
|
88 |
-
continue
|
89 |
-
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
90 |
-
|
91 |
-
with gr.Row():
|
92 |
-
text_output = gr.Chatbot(label=output_label0, visible=not kwargs['model_lock'], height=kwargs['height'] or 400)
|
93 |
-
text_output2 = gr.Chatbot(label=output_label0_model2,
|
94 |
-
visible=False and not kwargs['model_lock'], height=kwargs['height'] or 400)
|
95 |
-
return text_output, text_output2, text_outputs
|
96 |
-
|
97 |
-
|
98 |
-
def make_prompt_form(kwargs):
|
99 |
-
if kwargs['input_lines'] > 1:
|
100 |
-
instruction_label = "Shift-Enter to Submit, Enter for more lines"
|
101 |
-
else:
|
102 |
-
instruction_label = "Enter to Submit, Shift-Enter for more lines"
|
103 |
-
|
104 |
-
with gr.Row():#elem_id='prompt-form-area'):
|
105 |
-
with gr.Column(scale=50):
|
106 |
-
instruction = gr.Textbox(
|
107 |
-
lines=kwargs['input_lines'],
|
108 |
-
label='Ask anything',
|
109 |
-
placeholder=instruction_label,
|
110 |
-
info=None,
|
111 |
-
elem_id='prompt-form',
|
112 |
-
container=True,
|
113 |
-
)
|
114 |
-
with gr.Row():
|
115 |
-
submit = gr.Button(value='Submit', variant='primary', scale=0, size='sm')
|
116 |
-
stop_btn = gr.Button(value="Stop", variant='secondary', scale=0, size='sm')
|
117 |
-
|
118 |
-
return instruction, submit, stop_btn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
h2o-logo.svg
DELETED
h2oai_pipeline.py
CHANGED
@@ -136,6 +136,7 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
|
|
136 |
else:
|
137 |
outputs = rec['generated_text']
|
138 |
rec['generated_text'] = outputs
|
|
|
139 |
return records
|
140 |
|
141 |
def _forward(self, model_inputs, **generate_kwargs):
|
|
|
136 |
else:
|
137 |
outputs = rec['generated_text']
|
138 |
rec['generated_text'] = outputs
|
139 |
+
print("prompt: %s\noutputs: %s\n\n" % (self.prompt_text, outputs), flush=True)
|
140 |
return records
|
141 |
|
142 |
def _forward(self, model_inputs, **generate_kwargs):
|
iterators/__init__.py
DELETED
@@ -1,4 +0,0 @@
|
|
1 |
-
from .timeout_iterator import TimeoutIterator, AsyncTimeoutIterator
|
2 |
-
from .iterator_pipe import IteratorPipe, AsyncIteratorPipe
|
3 |
-
|
4 |
-
__all__ = ["TimeoutIterator", "AsyncTimeoutIterator", "IteratorPipe", "AsyncIteratorPipe"]
|
|
|
|
|
|
|
|
|
|
iterators/__pycache__/__init__.cpython-310.pyc
DELETED
Binary file (337 Bytes)
|
|
iterators/__pycache__/iterator_pipe.cpython-310.pyc
DELETED
Binary file (2.71 kB)
|
|
iterators/__pycache__/timeout_iterator.cpython-310.pyc
DELETED
Binary file (5.63 kB)
|
|
iterators/iterator_pipe.py
DELETED
@@ -1,93 +0,0 @@
|
|
1 |
-
import queue
|
2 |
-
import asyncio
|
3 |
-
|
4 |
-
|
5 |
-
class IteratorPipe:
|
6 |
-
"""
|
7 |
-
Iterator Pipe creates an iterator that can be fed in data from another block of code or thread of execution
|
8 |
-
"""
|
9 |
-
|
10 |
-
def __init__(self, sentinel=object()):
|
11 |
-
self._q = queue.Queue()
|
12 |
-
self._sentinel = sentinel
|
13 |
-
self._sentinel_pushed = False
|
14 |
-
self._closed = False
|
15 |
-
|
16 |
-
def __iter__(self):
|
17 |
-
return self
|
18 |
-
|
19 |
-
def __next__(self):
|
20 |
-
if self._closed:
|
21 |
-
raise StopIteration
|
22 |
-
|
23 |
-
data = self._q.get(block=True)
|
24 |
-
if data is self._sentinel:
|
25 |
-
self._closed = True
|
26 |
-
raise StopIteration
|
27 |
-
|
28 |
-
return data
|
29 |
-
|
30 |
-
def put(self, data) -> bool:
|
31 |
-
"""
|
32 |
-
Pushes next item to Iterator and returns True
|
33 |
-
If iterator has been closed via close(), doesn't push anything and returns False
|
34 |
-
"""
|
35 |
-
if self._sentinel_pushed:
|
36 |
-
return False
|
37 |
-
|
38 |
-
self._q.put(data)
|
39 |
-
return True
|
40 |
-
|
41 |
-
def close(self):
|
42 |
-
"""
|
43 |
-
Close is idempotent. Calling close multiple times is safe
|
44 |
-
Iterator will raise StopIteration only after all elements pushed before close have been iterated
|
45 |
-
"""
|
46 |
-
# make close idempotent
|
47 |
-
if not self._sentinel_pushed:
|
48 |
-
self._sentinel_pushed = True
|
49 |
-
self._q.put(self._sentinel)
|
50 |
-
|
51 |
-
|
52 |
-
class AsyncIteratorPipe:
|
53 |
-
|
54 |
-
def __init__(self, sentinel=object()):
|
55 |
-
self._q = asyncio.Queue()
|
56 |
-
self._sentinel = sentinel
|
57 |
-
self._sentinel_pushed = False
|
58 |
-
self._closed = False
|
59 |
-
|
60 |
-
def __aiter__(self):
|
61 |
-
return self
|
62 |
-
|
63 |
-
async def __anext__(self):
|
64 |
-
if self._closed:
|
65 |
-
raise StopAsyncIteration
|
66 |
-
|
67 |
-
data = await self._q.get()
|
68 |
-
if data is self._sentinel:
|
69 |
-
self._closed = True
|
70 |
-
raise StopAsyncIteration
|
71 |
-
|
72 |
-
return data
|
73 |
-
|
74 |
-
async def put(self, data) -> bool:
|
75 |
-
"""
|
76 |
-
Pushes next item to Iterator and returns True
|
77 |
-
If iterator has been closed via close(), doesn't push anything and returns False
|
78 |
-
"""
|
79 |
-
if self._sentinel_pushed:
|
80 |
-
return False
|
81 |
-
|
82 |
-
await self._q.put(data)
|
83 |
-
return True
|
84 |
-
|
85 |
-
async def close(self):
|
86 |
-
"""
|
87 |
-
Close is idempotent. Calling close multiple times is safe
|
88 |
-
Iterator will raise StopIteration only after all elements pushed before close have been iterated
|
89 |
-
"""
|
90 |
-
# make close idempotent
|
91 |
-
if not self._sentinel_pushed:
|
92 |
-
self._sentinel_pushed = True
|
93 |
-
await self._q.put(self._sentinel)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
iterators/timeout_iterator.py
DELETED
@@ -1,170 +0,0 @@
|
|
1 |
-
import queue
|
2 |
-
import asyncio
|
3 |
-
import threading
|
4 |
-
import traceback
|
5 |
-
|
6 |
-
|
7 |
-
class TimeoutIterator:
|
8 |
-
"""
|
9 |
-
Wrapper class to add timeout feature to synchronous iterators
|
10 |
-
- timeout: timeout for next(). Default=ZERO_TIMEOUT i.e. no timeout or blocking calls to next. Updated using set_timeout()
|
11 |
-
- sentinel: the object returned by iterator when timeout happens
|
12 |
-
- reset_on_next: if set to True, timeout is reset to the value of ZERO_TIMEOUT on each iteration
|
13 |
-
|
14 |
-
TimeoutIterator uses a thread internally.
|
15 |
-
The thread stops once the iterator exhausts or raises an exception during iteration.
|
16 |
-
|
17 |
-
Any exceptions raised within the wrapped iterator are propagated as it is.
|
18 |
-
Exception is raised when all elements generated by the actual iterator before exception have been consumed
|
19 |
-
Timeout can be set dynamically before going for iteration
|
20 |
-
"""
|
21 |
-
ZERO_TIMEOUT = 0.0
|
22 |
-
|
23 |
-
def __init__(self, iterator, timeout=0.0, sentinel=object(), reset_on_next=False, raise_on_exception=True):
|
24 |
-
self._iterator = iterator
|
25 |
-
self._timeout = timeout
|
26 |
-
self._sentinel = sentinel
|
27 |
-
self._reset_on_next = reset_on_next
|
28 |
-
self._raise_on_exception = raise_on_exception
|
29 |
-
|
30 |
-
self._interrupt = False
|
31 |
-
self._done = False
|
32 |
-
self._buffer = queue.Queue()
|
33 |
-
self._thread = threading.Thread(target=self.__lookahead)
|
34 |
-
self._thread.start()
|
35 |
-
|
36 |
-
def get_sentinel(self):
|
37 |
-
return self._sentinel
|
38 |
-
|
39 |
-
def set_reset_on_next(self, reset_on_next):
|
40 |
-
self._reset_on_next = reset_on_next
|
41 |
-
|
42 |
-
def set_timeout(self, timeout: float):
|
43 |
-
"""
|
44 |
-
Set timeout for next iteration
|
45 |
-
"""
|
46 |
-
self._timeout = timeout
|
47 |
-
|
48 |
-
def interrupt(self):
|
49 |
-
"""
|
50 |
-
interrupt and stop the underlying thread.
|
51 |
-
the thread acutally dies only after interrupt has been set and
|
52 |
-
the underlying iterator yields a value after that.
|
53 |
-
"""
|
54 |
-
self._interrupt = True
|
55 |
-
|
56 |
-
def __iter__(self):
|
57 |
-
return self
|
58 |
-
|
59 |
-
def __next__(self):
|
60 |
-
"""
|
61 |
-
yield the result from iterator
|
62 |
-
if timeout > 0:
|
63 |
-
yield data if available.
|
64 |
-
otherwise yield sentinal
|
65 |
-
"""
|
66 |
-
if self._done:
|
67 |
-
raise StopIteration
|
68 |
-
|
69 |
-
data = self._sentinel
|
70 |
-
try:
|
71 |
-
if self._timeout > self.ZERO_TIMEOUT:
|
72 |
-
data = self._buffer.get(timeout=self._timeout)
|
73 |
-
else:
|
74 |
-
data = self._buffer.get()
|
75 |
-
except queue.Empty:
|
76 |
-
pass
|
77 |
-
finally:
|
78 |
-
# see if timeout needs to be reset
|
79 |
-
if self._reset_on_next:
|
80 |
-
self._timeout = self.ZERO_TIMEOUT
|
81 |
-
|
82 |
-
# propagate any exceptions including StopIteration
|
83 |
-
if isinstance(data, BaseException):
|
84 |
-
self._done = True
|
85 |
-
if isinstance(data, StopIteration):
|
86 |
-
raise data
|
87 |
-
ex = ''.join(traceback.format_tb(data.__traceback__))
|
88 |
-
print("Generation Failed: %s %s" % (str(data), str(ex)), flush=True)
|
89 |
-
if self._raise_on_exception:
|
90 |
-
raise data
|
91 |
-
else:
|
92 |
-
return data
|
93 |
-
|
94 |
-
return data
|
95 |
-
|
96 |
-
def __lookahead(self):
|
97 |
-
try:
|
98 |
-
while True:
|
99 |
-
self._buffer.put(next(self._iterator))
|
100 |
-
if self._interrupt:
|
101 |
-
raise StopIteration()
|
102 |
-
except BaseException as e:
|
103 |
-
self._buffer.put(e)
|
104 |
-
|
105 |
-
|
106 |
-
class AsyncTimeoutIterator:
|
107 |
-
"""
|
108 |
-
Async version of TimeoutIterator. See method documentation of TimeoutIterator
|
109 |
-
"""
|
110 |
-
ZERO_TIMEOUT = 0.0
|
111 |
-
|
112 |
-
def __init__(self, iterator, timeout=0.0, sentinel=object(), reset_on_next=False):
|
113 |
-
self._iterator = iterator
|
114 |
-
self._timeout = timeout
|
115 |
-
self._sentinel = sentinel
|
116 |
-
self._reset_on_next = reset_on_next
|
117 |
-
|
118 |
-
self._interrupt = False
|
119 |
-
self._done = False
|
120 |
-
self._buffer = asyncio.Queue()
|
121 |
-
self._task = asyncio.get_event_loop().create_task(self.__lookahead())
|
122 |
-
|
123 |
-
def get_sentinel(self):
|
124 |
-
return self._sentinel
|
125 |
-
|
126 |
-
def set_reset_on_next(self, reset_on_next):
|
127 |
-
self._reset_on_next = reset_on_next
|
128 |
-
|
129 |
-
def set_timeout(self, timeout: float):
|
130 |
-
self._timeout = timeout
|
131 |
-
|
132 |
-
def interrupt(self):
|
133 |
-
self._interrupt = True
|
134 |
-
|
135 |
-
def __aiter__(self):
|
136 |
-
return self
|
137 |
-
|
138 |
-
async def __anext__(self):
|
139 |
-
if self._done:
|
140 |
-
raise StopAsyncIteration
|
141 |
-
|
142 |
-
data = self._sentinel
|
143 |
-
try:
|
144 |
-
if self._timeout > self.ZERO_TIMEOUT:
|
145 |
-
data = await asyncio.wait_for(self._buffer.get(), self._timeout)
|
146 |
-
else:
|
147 |
-
data = await self._buffer.get()
|
148 |
-
except asyncio.TimeoutError:
|
149 |
-
pass
|
150 |
-
finally:
|
151 |
-
# see if timeout needs to be reset
|
152 |
-
if self._reset_on_next:
|
153 |
-
self._timeout = self.ZERO_TIMEOUT
|
154 |
-
|
155 |
-
# propagate any exceptions including StopIteration
|
156 |
-
if isinstance(data, BaseException):
|
157 |
-
self._done = True
|
158 |
-
raise data
|
159 |
-
|
160 |
-
return data
|
161 |
-
|
162 |
-
async def __lookahead(self):
|
163 |
-
try:
|
164 |
-
while True:
|
165 |
-
data = await self._iterator.__anext__()
|
166 |
-
await self._buffer.put(data)
|
167 |
-
if self._interrupt:
|
168 |
-
raise StopAsyncIteration()
|
169 |
-
except BaseException as e:
|
170 |
-
await self._buffer.put(e)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
prompter.py
CHANGED
@@ -120,7 +120,7 @@ def get_prompt(prompt_type, prompt_dict, chat, context, reduced, making_context,
|
|
120 |
elif prompt_type in [PromptType.custom.value, str(PromptType.custom.value),
|
121 |
PromptType.custom.name]:
|
122 |
promptA = prompt_dict.get('promptA', '')
|
123 |
-
promptB = prompt_dict('promptB', '')
|
124 |
PreInstruct = prompt_dict.get('PreInstruct', '')
|
125 |
PreInput = prompt_dict.get('PreInput', '')
|
126 |
PreResponse = prompt_dict.get('PreResponse', '')
|
@@ -693,7 +693,9 @@ class Prompter(object):
|
|
693 |
output = clean_response(output)
|
694 |
elif prompt is None:
|
695 |
# then use most basic parsing like pipeline
|
696 |
-
if self.botstr
|
|
|
|
|
697 |
if self.humanstr:
|
698 |
output = clean_response(output.split(self.botstr)[1].split(self.humanstr)[0])
|
699 |
else:
|
|
|
120 |
elif prompt_type in [PromptType.custom.value, str(PromptType.custom.value),
|
121 |
PromptType.custom.name]:
|
122 |
promptA = prompt_dict.get('promptA', '')
|
123 |
+
promptB = prompt_dict.get('promptB', '')
|
124 |
PreInstruct = prompt_dict.get('PreInstruct', '')
|
125 |
PreInput = prompt_dict.get('PreInput', '')
|
126 |
PreResponse = prompt_dict.get('PreResponse', '')
|
|
|
693 |
output = clean_response(output)
|
694 |
elif prompt is None:
|
695 |
# then use most basic parsing like pipeline
|
696 |
+
if not self.botstr:
|
697 |
+
pass
|
698 |
+
elif self.botstr in output:
|
699 |
if self.humanstr:
|
700 |
output = clean_response(output.split(self.botstr)[1].split(self.humanstr)[0])
|
701 |
else:
|
requirements.txt
CHANGED
@@ -1,153 +0,0 @@
|
|
1 |
-
# for generate (gradio server) and finetune
|
2 |
-
datasets==2.13.0
|
3 |
-
sentencepiece==0.1.99
|
4 |
-
gradio==3.35.2
|
5 |
-
huggingface_hub==0.15.1
|
6 |
-
appdirs==1.4.4
|
7 |
-
fire==0.5.0
|
8 |
-
docutils==0.20.1
|
9 |
-
torch==2.0.1
|
10 |
-
evaluate==0.4.0
|
11 |
-
rouge_score==0.1.2
|
12 |
-
sacrebleu==2.3.1
|
13 |
-
scikit-learn==1.2.2
|
14 |
-
alt-profanity-check==1.2.2
|
15 |
-
better-profanity==0.7.0
|
16 |
-
numpy==1.24.3
|
17 |
-
pandas==2.0.2
|
18 |
-
matplotlib==3.7.1
|
19 |
-
loralib==0.1.1
|
20 |
-
bitsandbytes==0.39.0
|
21 |
-
accelerate==0.20.3
|
22 |
-
git+https://github.com/huggingface/peft.git@0b62b4378b4ce9367932c73540349da9a41bdea8
|
23 |
-
transformers==4.30.2
|
24 |
-
tokenizers==0.13.3
|
25 |
-
APScheduler==3.10.1
|
26 |
-
|
27 |
-
# optional for generate
|
28 |
-
pynvml==11.5.0
|
29 |
-
psutil==5.9.5
|
30 |
-
boto3==1.26.101
|
31 |
-
botocore==1.29.101
|
32 |
-
|
33 |
-
# optional for finetune
|
34 |
-
tensorboard==2.13.0
|
35 |
-
neptune==1.2.0
|
36 |
-
|
37 |
-
# for gradio client
|
38 |
-
gradio_client==0.2.7
|
39 |
-
beautifulsoup4==4.12.2
|
40 |
-
markdown==3.4.3
|
41 |
-
|
42 |
-
# data and testing
|
43 |
-
pytest==7.2.2
|
44 |
-
pytest-xdist==3.2.1
|
45 |
-
nltk==3.8.1
|
46 |
-
textstat==0.7.3
|
47 |
-
# pandoc==2.3
|
48 |
-
#pypandoc==1.11
|
49 |
-
pypandoc_binary==1.11
|
50 |
-
openpyxl==3.1.2
|
51 |
-
lm_dataformat==0.0.20
|
52 |
-
bioc==2.0
|
53 |
-
|
54 |
-
# falcon
|
55 |
-
einops==0.6.1
|
56 |
-
instructorembedding==1.0.1
|
57 |
-
|
58 |
-
# for gpt4all .env file, but avoid worrying about imports
|
59 |
-
python-dotenv==1.0.0
|
60 |
-
|
61 |
-
text-generation==0.6.0
|
62 |
-
# for tokenization when don't have HF tokenizer
|
63 |
-
tiktoken==0.4.0
|
64 |
-
# optional: for OpenAI endpoint or embeddings (requires key)
|
65 |
-
openai==0.27.8
|
66 |
-
# optional for chat with PDF
|
67 |
-
langchain==0.0.202
|
68 |
-
pypdf==3.9.1
|
69 |
-
# avoid textract, requires old six
|
70 |
-
#textract==1.6.5
|
71 |
-
|
72 |
-
# for HF embeddings
|
73 |
-
sentence_transformers==2.2.2
|
74 |
-
|
75 |
-
# local vector db
|
76 |
-
chromadb==0.3.25
|
77 |
-
# server vector db
|
78 |
-
#pymilvus==2.2.8
|
79 |
-
|
80 |
-
# weak url support, if can't install opencv etc. If comment-in this one, then comment-out unstructured[local-inference]==0.6.6
|
81 |
-
# unstructured==0.6.6
|
82 |
-
|
83 |
-
# strong support for images
|
84 |
-
# Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libreoffice
|
85 |
-
unstructured[local-inference]==0.7.4
|
86 |
-
#pdf2image==1.16.3
|
87 |
-
#pytesseract==0.3.10
|
88 |
-
pillow
|
89 |
-
|
90 |
-
pdfminer.six==20221105
|
91 |
-
urllib3
|
92 |
-
requests_file
|
93 |
-
|
94 |
-
#pdf2image==1.16.3
|
95 |
-
#pytesseract==0.3.10
|
96 |
-
tabulate==0.9.0
|
97 |
-
# FYI pandoc already part of requirements.txt
|
98 |
-
|
99 |
-
# JSONLoader, but makes some trouble for some users
|
100 |
-
# jq==1.4.1
|
101 |
-
|
102 |
-
# to check licenses
|
103 |
-
# Run: pip-licenses|grep -v 'BSD\|Apache\|MIT'
|
104 |
-
pip-licenses==4.3.0
|
105 |
-
|
106 |
-
# weaviate vector db
|
107 |
-
weaviate-client==3.20.0
|
108 |
-
# optional for chat with PDF
|
109 |
-
langchain==0.0.202
|
110 |
-
pypdf==3.9.1
|
111 |
-
# avoid textract, requires old six
|
112 |
-
#textract==1.6.5
|
113 |
-
|
114 |
-
# for HF embeddings
|
115 |
-
sentence_transformers==2.2.2
|
116 |
-
|
117 |
-
# local vector db
|
118 |
-
chromadb==0.3.25
|
119 |
-
# server vector db
|
120 |
-
#pymilvus==2.2.8
|
121 |
-
|
122 |
-
# weak url support, if can't install opencv etc. If comment-in this one, then comment-out unstructured[local-inference]==0.6.6
|
123 |
-
# unstructured==0.6.6
|
124 |
-
|
125 |
-
# strong support for images
|
126 |
-
# Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libreoffice
|
127 |
-
unstructured[local-inference]==0.7.4
|
128 |
-
#pdf2image==1.16.3
|
129 |
-
#pytesseract==0.3.10
|
130 |
-
pillow
|
131 |
-
|
132 |
-
pdfminer.six==20221105
|
133 |
-
urllib3
|
134 |
-
requests_file
|
135 |
-
|
136 |
-
#pdf2image==1.16.3
|
137 |
-
#pytesseract==0.3.10
|
138 |
-
tabulate==0.9.0
|
139 |
-
# FYI pandoc already part of requirements.txt
|
140 |
-
|
141 |
-
# JSONLoader, but makes some trouble for some users
|
142 |
-
# jq==1.4.1
|
143 |
-
|
144 |
-
# to check licenses
|
145 |
-
# Run: pip-licenses|grep -v 'BSD\|Apache\|MIT'
|
146 |
-
pip-licenses==4.3.0
|
147 |
-
|
148 |
-
# weaviate vector db
|
149 |
-
weaviate-client==3.20.0
|
150 |
-
faiss-gpu==1.7.2
|
151 |
-
arxiv==1.4.7
|
152 |
-
pymupdf==1.22.3 # AGPL license
|
153 |
-
# extract-msg==0.41.1 # GPL3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|