Spaces:
Runtime error
Runtime error
pseudotensor
commited on
Commit
•
4afcc1d
1
Parent(s):
33cecf1
Remoev
Browse files- client_test.py +0 -362
- create_data.py +0 -1809
- enums.py +0 -120
- evaluate_params.py +0 -52
- gen.py +0 -0
- generate.py +0 -16
- gpt4all_llm.py +0 -316
- gpt_langchain.py +0 -0
- gradio_runner.py +0 -0
- gradio_themes.py +0 -231
- h2oai_pipeline.py +0 -201
- loaders.py +0 -61
- prompter.py +0 -871
- stopping.py +0 -78
- utils.py +0 -1080
- utils_langchain.py +0 -64
client_test.py
DELETED
@@ -1,362 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Client test.
|
3 |
-
|
4 |
-
Run server:
|
5 |
-
|
6 |
-
python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b
|
7 |
-
|
8 |
-
NOTE: For private models, add --use-auth_token=True
|
9 |
-
|
10 |
-
NOTE: --use_gpu_id=True (default) must be used for multi-GPU in case see failures with cuda:x cuda:y mismatches.
|
11 |
-
Currently, this will force model to be on a single GPU.
|
12 |
-
|
13 |
-
Then run this client as:
|
14 |
-
|
15 |
-
python src/client_test.py
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
For HF spaces:
|
20 |
-
|
21 |
-
HOST="https://h2oai-h2ogpt-chatbot.hf.space" python src/client_test.py
|
22 |
-
|
23 |
-
Result:
|
24 |
-
|
25 |
-
Loaded as API: https://h2oai-h2ogpt-chatbot.hf.space ✔
|
26 |
-
{'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a large language model developed by LAION.', 'sources': ''}
|
27 |
-
|
28 |
-
|
29 |
-
For demo:
|
30 |
-
|
31 |
-
HOST="https://gpt.h2o.ai" python src/client_test.py
|
32 |
-
|
33 |
-
Result:
|
34 |
-
|
35 |
-
Loaded as API: https://gpt.h2o.ai ✔
|
36 |
-
{'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a chatbot created by LAION.', 'sources': ''}
|
37 |
-
|
38 |
-
NOTE: Raw output from API for nochat case is a string of a python dict and will remain so if other entries are added to dict:
|
39 |
-
|
40 |
-
{'response': "I'm h2oGPT, a large language model by H2O.ai, the visionary leader in democratizing AI.", 'sources': ''}
|
41 |
-
|
42 |
-
|
43 |
-
"""
|
44 |
-
import ast
|
45 |
-
import time
|
46 |
-
import os
|
47 |
-
import markdown # pip install markdown
|
48 |
-
import pytest
|
49 |
-
from bs4 import BeautifulSoup # pip install beautifulsoup4
|
50 |
-
|
51 |
-
from enums import DocumentSubset, LangChainAction
|
52 |
-
|
53 |
-
debug = False
|
54 |
-
|
55 |
-
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
56 |
-
|
57 |
-
|
58 |
-
def get_client(serialize=True):
|
59 |
-
from gradio_client import Client
|
60 |
-
|
61 |
-
client = Client(os.getenv('HOST', "http://localhost:7860"), serialize=serialize)
|
62 |
-
if debug:
|
63 |
-
print(client.view_api(all_endpoints=True))
|
64 |
-
return client
|
65 |
-
|
66 |
-
|
67 |
-
def get_args(prompt, prompt_type, chat=False, stream_output=False,
|
68 |
-
max_new_tokens=50,
|
69 |
-
top_k_docs=3,
|
70 |
-
langchain_mode='Disabled',
|
71 |
-
add_chat_history_to_context=True,
|
72 |
-
langchain_action=LangChainAction.QUERY.value,
|
73 |
-
langchain_agents=[],
|
74 |
-
prompt_dict=None):
|
75 |
-
from collections import OrderedDict
|
76 |
-
kwargs = OrderedDict(instruction=prompt if chat else '', # only for chat=True
|
77 |
-
iinput='', # only for chat=True
|
78 |
-
context='',
|
79 |
-
# streaming output is supported, loops over and outputs each generation in streaming mode
|
80 |
-
# but leave stream_output=False for simple input/output mode
|
81 |
-
stream_output=stream_output,
|
82 |
-
prompt_type=prompt_type,
|
83 |
-
prompt_dict=prompt_dict,
|
84 |
-
temperature=0.1,
|
85 |
-
top_p=0.75,
|
86 |
-
top_k=40,
|
87 |
-
num_beams=1,
|
88 |
-
max_new_tokens=max_new_tokens,
|
89 |
-
min_new_tokens=0,
|
90 |
-
early_stopping=False,
|
91 |
-
max_time=20,
|
92 |
-
repetition_penalty=1.0,
|
93 |
-
num_return_sequences=1,
|
94 |
-
do_sample=True,
|
95 |
-
chat=chat,
|
96 |
-
instruction_nochat=prompt if not chat else '',
|
97 |
-
iinput_nochat='', # only for chat=False
|
98 |
-
langchain_mode=langchain_mode,
|
99 |
-
add_chat_history_to_context=add_chat_history_to_context,
|
100 |
-
langchain_action=langchain_action,
|
101 |
-
langchain_agents=langchain_agents,
|
102 |
-
top_k_docs=top_k_docs,
|
103 |
-
chunk=True,
|
104 |
-
chunk_size=512,
|
105 |
-
document_subset=DocumentSubset.Relevant.name,
|
106 |
-
document_choice=[],
|
107 |
-
)
|
108 |
-
from evaluate_params import eval_func_param_names
|
109 |
-
assert len(set(eval_func_param_names).difference(set(list(kwargs.keys())))) == 0
|
110 |
-
if chat:
|
111 |
-
# add chatbot output on end. Assumes serialize=False
|
112 |
-
kwargs.update(dict(chatbot=[]))
|
113 |
-
|
114 |
-
return kwargs, list(kwargs.values())
|
115 |
-
|
116 |
-
|
117 |
-
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
118 |
-
def test_client_basic(prompt_type='human_bot'):
|
119 |
-
return run_client_nochat(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50)
|
120 |
-
|
121 |
-
|
122 |
-
def run_client_nochat(prompt, prompt_type, max_new_tokens):
|
123 |
-
kwargs, args = get_args(prompt, prompt_type, chat=False, max_new_tokens=max_new_tokens)
|
124 |
-
|
125 |
-
api_name = '/submit_nochat'
|
126 |
-
client = get_client(serialize=True)
|
127 |
-
res = client.predict(
|
128 |
-
*tuple(args),
|
129 |
-
api_name=api_name,
|
130 |
-
)
|
131 |
-
print("Raw client result: %s" % res, flush=True)
|
132 |
-
res_dict = dict(prompt=kwargs['instruction_nochat'], iinput=kwargs['iinput_nochat'],
|
133 |
-
response=md_to_text(res))
|
134 |
-
print(res_dict)
|
135 |
-
return res_dict, client
|
136 |
-
|
137 |
-
|
138 |
-
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
139 |
-
def test_client_basic_api(prompt_type='human_bot'):
|
140 |
-
return run_client_nochat_api(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50)
|
141 |
-
|
142 |
-
|
143 |
-
def run_client_nochat_api(prompt, prompt_type, max_new_tokens):
|
144 |
-
kwargs, args = get_args(prompt, prompt_type, chat=False, max_new_tokens=max_new_tokens)
|
145 |
-
|
146 |
-
api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
|
147 |
-
client = get_client(serialize=True)
|
148 |
-
res = client.predict(
|
149 |
-
str(dict(kwargs)),
|
150 |
-
api_name=api_name,
|
151 |
-
)
|
152 |
-
print("Raw client result: %s" % res, flush=True)
|
153 |
-
res_dict = dict(prompt=kwargs['instruction_nochat'], iinput=kwargs['iinput_nochat'],
|
154 |
-
response=md_to_text(ast.literal_eval(res)['response']),
|
155 |
-
sources=ast.literal_eval(res)['sources'])
|
156 |
-
print(res_dict)
|
157 |
-
return res_dict, client
|
158 |
-
|
159 |
-
|
160 |
-
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
161 |
-
def test_client_basic_api_lean(prompt_type='human_bot'):
|
162 |
-
return run_client_nochat_api_lean(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50)
|
163 |
-
|
164 |
-
|
165 |
-
def run_client_nochat_api_lean(prompt, prompt_type, max_new_tokens):
|
166 |
-
kwargs = dict(instruction_nochat=prompt)
|
167 |
-
|
168 |
-
api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
|
169 |
-
client = get_client(serialize=True)
|
170 |
-
res = client.predict(
|
171 |
-
str(dict(kwargs)),
|
172 |
-
api_name=api_name,
|
173 |
-
)
|
174 |
-
print("Raw client result: %s" % res, flush=True)
|
175 |
-
res_dict = dict(prompt=kwargs['instruction_nochat'],
|
176 |
-
response=md_to_text(ast.literal_eval(res)['response']),
|
177 |
-
sources=ast.literal_eval(res)['sources'])
|
178 |
-
print(res_dict)
|
179 |
-
return res_dict, client
|
180 |
-
|
181 |
-
|
182 |
-
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
183 |
-
def test_client_basic_api_lean_morestuff(prompt_type='human_bot'):
|
184 |
-
return run_client_nochat_api_lean_morestuff(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50)
|
185 |
-
|
186 |
-
|
187 |
-
def run_client_nochat_api_lean_morestuff(prompt, prompt_type='human_bot', max_new_tokens=512):
|
188 |
-
kwargs = dict(
|
189 |
-
instruction='',
|
190 |
-
iinput='',
|
191 |
-
context='',
|
192 |
-
stream_output=False,
|
193 |
-
prompt_type=prompt_type,
|
194 |
-
temperature=0.1,
|
195 |
-
top_p=0.75,
|
196 |
-
top_k=40,
|
197 |
-
num_beams=1,
|
198 |
-
max_new_tokens=256,
|
199 |
-
min_new_tokens=0,
|
200 |
-
early_stopping=False,
|
201 |
-
max_time=20,
|
202 |
-
repetition_penalty=1.0,
|
203 |
-
num_return_sequences=1,
|
204 |
-
do_sample=True,
|
205 |
-
chat=False,
|
206 |
-
instruction_nochat=prompt,
|
207 |
-
iinput_nochat='',
|
208 |
-
langchain_mode='Disabled',
|
209 |
-
add_chat_history_to_context=True,
|
210 |
-
langchain_action=LangChainAction.QUERY.value,
|
211 |
-
langchain_agents=[],
|
212 |
-
top_k_docs=4,
|
213 |
-
document_subset=DocumentSubset.Relevant.name,
|
214 |
-
document_choice=[],
|
215 |
-
)
|
216 |
-
|
217 |
-
api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
|
218 |
-
client = get_client(serialize=True)
|
219 |
-
res = client.predict(
|
220 |
-
str(dict(kwargs)),
|
221 |
-
api_name=api_name,
|
222 |
-
)
|
223 |
-
print("Raw client result: %s" % res, flush=True)
|
224 |
-
res_dict = dict(prompt=kwargs['instruction_nochat'],
|
225 |
-
response=md_to_text(ast.literal_eval(res)['response']),
|
226 |
-
sources=ast.literal_eval(res)['sources'])
|
227 |
-
print(res_dict)
|
228 |
-
return res_dict, client
|
229 |
-
|
230 |
-
|
231 |
-
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
232 |
-
def test_client_chat(prompt_type='human_bot'):
|
233 |
-
return run_client_chat(prompt='Who are you?', prompt_type=prompt_type, stream_output=False, max_new_tokens=50,
|
234 |
-
langchain_mode='Disabled',
|
235 |
-
langchain_action=LangChainAction.QUERY.value,
|
236 |
-
langchain_agents=[])
|
237 |
-
|
238 |
-
|
239 |
-
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
240 |
-
def test_client_chat_stream(prompt_type='human_bot'):
|
241 |
-
return run_client_chat(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
|
242 |
-
stream_output=True, max_new_tokens=512,
|
243 |
-
langchain_mode='Disabled',
|
244 |
-
langchain_action=LangChainAction.QUERY.value,
|
245 |
-
langchain_agents=[])
|
246 |
-
|
247 |
-
|
248 |
-
def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens,
|
249 |
-
langchain_mode, langchain_action, langchain_agents,
|
250 |
-
prompt_dict=None):
|
251 |
-
client = get_client(serialize=False)
|
252 |
-
|
253 |
-
kwargs, args = get_args(prompt, prompt_type, chat=True, stream_output=stream_output,
|
254 |
-
max_new_tokens=max_new_tokens,
|
255 |
-
langchain_mode=langchain_mode,
|
256 |
-
langchain_action=langchain_action,
|
257 |
-
langchain_agents=langchain_agents,
|
258 |
-
prompt_dict=prompt_dict)
|
259 |
-
return run_client(client, prompt, args, kwargs)
|
260 |
-
|
261 |
-
|
262 |
-
def run_client(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
|
263 |
-
assert kwargs['chat'], "Chat mode only"
|
264 |
-
res = client.predict(*tuple(args), api_name='/instruction')
|
265 |
-
args[-1] += [res[-1]]
|
266 |
-
|
267 |
-
res_dict = kwargs
|
268 |
-
res_dict['prompt'] = prompt
|
269 |
-
if not kwargs['stream_output']:
|
270 |
-
res = client.predict(*tuple(args), api_name='/instruction_bot')
|
271 |
-
res_dict['response'] = res[0][-1][1]
|
272 |
-
print(md_to_text(res_dict['response'], do_md_to_text=do_md_to_text))
|
273 |
-
return res_dict, client
|
274 |
-
else:
|
275 |
-
job = client.submit(*tuple(args), api_name='/instruction_bot')
|
276 |
-
res1 = ''
|
277 |
-
while not job.done():
|
278 |
-
outputs_list = job.communicator.job.outputs
|
279 |
-
if outputs_list:
|
280 |
-
res = job.communicator.job.outputs[-1]
|
281 |
-
res1 = res[0][-1][-1]
|
282 |
-
res1 = md_to_text(res1, do_md_to_text=do_md_to_text)
|
283 |
-
print(res1)
|
284 |
-
time.sleep(0.1)
|
285 |
-
full_outputs = job.outputs()
|
286 |
-
if verbose:
|
287 |
-
print('job.outputs: %s' % str(full_outputs))
|
288 |
-
# ensure get ending to avoid race
|
289 |
-
# -1 means last response if streaming
|
290 |
-
# 0 means get text_output, ignore exception_text
|
291 |
-
# 0 means get list within text_output that looks like [[prompt], [answer]]
|
292 |
-
# 1 means get bot answer, so will have last bot answer
|
293 |
-
res_dict['response'] = md_to_text(full_outputs[-1][0][0][1], do_md_to_text=do_md_to_text)
|
294 |
-
return res_dict, client
|
295 |
-
|
296 |
-
|
297 |
-
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
298 |
-
def test_client_nochat_stream(prompt_type='human_bot'):
|
299 |
-
return run_client_nochat_gen(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
|
300 |
-
stream_output=True, max_new_tokens=512,
|
301 |
-
langchain_mode='Disabled',
|
302 |
-
langchain_action=LangChainAction.QUERY.value,
|
303 |
-
langchain_agents=[])
|
304 |
-
|
305 |
-
|
306 |
-
def run_client_nochat_gen(prompt, prompt_type, stream_output, max_new_tokens,
|
307 |
-
langchain_mode, langchain_action, langchain_agents):
|
308 |
-
client = get_client(serialize=False)
|
309 |
-
|
310 |
-
kwargs, args = get_args(prompt, prompt_type, chat=False, stream_output=stream_output,
|
311 |
-
max_new_tokens=max_new_tokens, langchain_mode=langchain_mode,
|
312 |
-
langchain_action=langchain_action, langchain_agents=langchain_agents)
|
313 |
-
return run_client_gen(client, prompt, args, kwargs)
|
314 |
-
|
315 |
-
|
316 |
-
def run_client_gen(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
|
317 |
-
res_dict = kwargs
|
318 |
-
res_dict['prompt'] = prompt
|
319 |
-
if not kwargs['stream_output']:
|
320 |
-
res = client.predict(str(dict(kwargs)), api_name='/submit_nochat_api')
|
321 |
-
res_dict['response'] = res[0]
|
322 |
-
print(md_to_text(res_dict['response'], do_md_to_text=do_md_to_text))
|
323 |
-
return res_dict, client
|
324 |
-
else:
|
325 |
-
job = client.submit(str(dict(kwargs)), api_name='/submit_nochat_api')
|
326 |
-
while not job.done():
|
327 |
-
outputs_list = job.communicator.job.outputs
|
328 |
-
if outputs_list:
|
329 |
-
res = job.communicator.job.outputs[-1]
|
330 |
-
res_dict = ast.literal_eval(res)
|
331 |
-
print('Stream: %s' % res_dict['response'])
|
332 |
-
time.sleep(0.1)
|
333 |
-
res_list = job.outputs()
|
334 |
-
assert len(res_list) > 0, "No response, check server"
|
335 |
-
res = res_list[-1]
|
336 |
-
res_dict = ast.literal_eval(res)
|
337 |
-
print('Final: %s' % res_dict['response'])
|
338 |
-
return res_dict, client
|
339 |
-
|
340 |
-
|
341 |
-
def md_to_text(md, do_md_to_text=True):
|
342 |
-
if not do_md_to_text:
|
343 |
-
return md
|
344 |
-
assert md is not None, "Markdown is None"
|
345 |
-
html = markdown.markdown(md)
|
346 |
-
soup = BeautifulSoup(html, features='html.parser')
|
347 |
-
return soup.get_text()
|
348 |
-
|
349 |
-
|
350 |
-
def run_client_many(prompt_type='human_bot'):
|
351 |
-
ret1, _ = test_client_chat(prompt_type=prompt_type)
|
352 |
-
ret2, _ = test_client_chat_stream(prompt_type=prompt_type)
|
353 |
-
ret3, _ = test_client_nochat_stream(prompt_type=prompt_type)
|
354 |
-
ret4, _ = test_client_basic(prompt_type=prompt_type)
|
355 |
-
ret5, _ = test_client_basic_api(prompt_type=prompt_type)
|
356 |
-
ret6, _ = test_client_basic_api_lean(prompt_type=prompt_type)
|
357 |
-
ret7, _ = test_client_basic_api_lean_morestuff(prompt_type=prompt_type)
|
358 |
-
return ret1, ret2, ret3, ret4, ret5, ret6, ret7
|
359 |
-
|
360 |
-
|
361 |
-
if __name__ == '__main__':
|
362 |
-
run_client_many()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
create_data.py
DELETED
@@ -1,1809 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Dataset creation tools.
|
3 |
-
|
4 |
-
Keep to-level imports clean of non-trivial imports for specific tools,
|
5 |
-
because this file is imported for various purposes
|
6 |
-
"""
|
7 |
-
|
8 |
-
import ast
|
9 |
-
import concurrent.futures
|
10 |
-
import contextlib
|
11 |
-
import hashlib
|
12 |
-
import json
|
13 |
-
import os
|
14 |
-
import shutil
|
15 |
-
import signal
|
16 |
-
import sys
|
17 |
-
import traceback
|
18 |
-
from concurrent.futures import ProcessPoolExecutor
|
19 |
-
|
20 |
-
import psutil
|
21 |
-
import pytest
|
22 |
-
import pandas as pd
|
23 |
-
import numpy as np
|
24 |
-
from tqdm import tqdm
|
25 |
-
|
26 |
-
from utils import flatten_list, remove
|
27 |
-
|
28 |
-
|
29 |
-
def parse_rst_file(filepath):
|
30 |
-
with open(filepath, 'r') as f:
|
31 |
-
input_data = f.read()
|
32 |
-
settings_overrides = {'initial_header_level': 2}
|
33 |
-
from docutils import core
|
34 |
-
document = core.publish_doctree(
|
35 |
-
source=input_data,
|
36 |
-
source_path=filepath,
|
37 |
-
settings_overrides=settings_overrides,
|
38 |
-
)
|
39 |
-
qa_pairs = []
|
40 |
-
current_section = None
|
41 |
-
current_question = ""
|
42 |
-
current_answer = ""
|
43 |
-
for node in document.traverse():
|
44 |
-
if node.__class__.__name__ == 'section':
|
45 |
-
current_section = ""
|
46 |
-
elif current_section is not None:
|
47 |
-
if node.__class__.__name__ == 'Text':
|
48 |
-
if node.astext()[-1] == "?":
|
49 |
-
if current_question:
|
50 |
-
qa_pairs.append((current_question, current_answer))
|
51 |
-
current_question = node.astext()
|
52 |
-
current_answer = ""
|
53 |
-
else:
|
54 |
-
current_answer += node.astext()
|
55 |
-
if current_answer:
|
56 |
-
qa_pairs.append((current_question, current_answer))
|
57 |
-
return {k: v for k, v in qa_pairs}
|
58 |
-
|
59 |
-
|
60 |
-
def test_scrape_dai_docs():
|
61 |
-
home = os.path.expanduser('~')
|
62 |
-
file = os.path.join(home, 'h2oai/docs/faq.rst')
|
63 |
-
qa_pairs = parse_rst_file(file)
|
64 |
-
prompt_type = 'human_bot'
|
65 |
-
from prompter import prompt_types
|
66 |
-
assert prompt_type in prompt_types
|
67 |
-
save_thing = [{"instruction": k, "output": v, 'prompt_type': prompt_type} for k, v in qa_pairs.items()]
|
68 |
-
output_file = "dai_faq.json"
|
69 |
-
with open(output_file, "wt") as f:
|
70 |
-
f.write(json.dumps(save_thing, indent=2))
|
71 |
-
|
72 |
-
|
73 |
-
def test_scrape_dai_docs_all():
|
74 |
-
"""
|
75 |
-
pytest create_data.py::test_scrape_dai_docs_all
|
76 |
-
"""
|
77 |
-
import glob
|
78 |
-
import nltk
|
79 |
-
nltk.download('punkt')
|
80 |
-
dd = {}
|
81 |
-
np.random.seed(1234)
|
82 |
-
home = os.path.expanduser('~')
|
83 |
-
files = list(glob.glob(os.path.join(home, "h2oai/docs/**/*rst")))
|
84 |
-
np.random.shuffle(files)
|
85 |
-
val_count = int(0.05 * len(files))
|
86 |
-
train_files = files[val_count:]
|
87 |
-
valid_files = files[:val_count]
|
88 |
-
things = [
|
89 |
-
("dai_docs.train.json", train_files),
|
90 |
-
("dai_docs.valid.json", valid_files)
|
91 |
-
]
|
92 |
-
for LEN in [100, 200, 500]:
|
93 |
-
for output_file, ff in things:
|
94 |
-
if output_file not in dd:
|
95 |
-
dd[output_file] = []
|
96 |
-
for f in ff:
|
97 |
-
with open(f) as input:
|
98 |
-
blob = input.read()
|
99 |
-
blob = blob.replace("~~", "")
|
100 |
-
blob = blob.replace("==", "")
|
101 |
-
blob = blob.replace("''", "")
|
102 |
-
blob = blob.replace("--", "")
|
103 |
-
blob = blob.replace("**", "")
|
104 |
-
dd[output_file].extend(get_sentences(blob, length=LEN))
|
105 |
-
for output_file, _ in things:
|
106 |
-
save_thing = [{"output": k.strip(), 'prompt_type': 'plain'} for k in dd[output_file]]
|
107 |
-
with open(output_file, "wt") as f:
|
108 |
-
f.write(json.dumps(save_thing, indent=2))
|
109 |
-
|
110 |
-
|
111 |
-
def get_sentences(blob, length):
|
112 |
-
"""
|
113 |
-
break-up input text into sentences and then output list of sentences of about length in size
|
114 |
-
:param blob:
|
115 |
-
:param length:
|
116 |
-
:return:
|
117 |
-
"""
|
118 |
-
import nltk
|
119 |
-
nltk.download('punkt')
|
120 |
-
from nltk.tokenize import sent_tokenize
|
121 |
-
sentences = sent_tokenize(blob)
|
122 |
-
my_sentences = []
|
123 |
-
my_string = ""
|
124 |
-
for sentence in sentences:
|
125 |
-
if len(my_string) + len(sentence) <= length:
|
126 |
-
if my_string:
|
127 |
-
my_string += " " + sentence
|
128 |
-
else:
|
129 |
-
my_string = sentence
|
130 |
-
else:
|
131 |
-
my_sentences.append(my_string)
|
132 |
-
my_string = ""
|
133 |
-
return my_sentences or [my_string]
|
134 |
-
|
135 |
-
|
136 |
-
def setup_dai_docs(path=None, dst="working_dir_docs", from_hf=False):
|
137 |
-
"""
|
138 |
-
Only supported if have access to source code or HF token for HF spaces and from_hf=True
|
139 |
-
:param path:
|
140 |
-
:param dst:
|
141 |
-
:param from_hf:
|
142 |
-
:return:
|
143 |
-
"""
|
144 |
-
|
145 |
-
home = os.path.expanduser('~')
|
146 |
-
|
147 |
-
if from_hf:
|
148 |
-
# assumes
|
149 |
-
from huggingface_hub import hf_hub_download
|
150 |
-
# True for case when locally already logged in with correct token, so don't have to set key
|
151 |
-
token = os.getenv('HUGGINGFACE_API_TOKEN', True)
|
152 |
-
path_to_zip_file = hf_hub_download('h2oai/dai_docs', 'dai_docs.zip', token=token, repo_type='dataset')
|
153 |
-
path = 'h2oai'
|
154 |
-
import zipfile
|
155 |
-
with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref:
|
156 |
-
zip_ref.extractall(path)
|
157 |
-
path = os.path.join(path, 'docs/**/*')
|
158 |
-
|
159 |
-
if path is None:
|
160 |
-
if os.path.isdir(os.path.join(home, 'h2oai')):
|
161 |
-
path = os.path.join(home, "h2oai/docs/**/*")
|
162 |
-
else:
|
163 |
-
assert os.path.isdir(os.path.join(home, 'h2oai.superclean')), '%s does not exist' % path
|
164 |
-
path = os.path.join(home, "h2oai.superclean/docs/**/*")
|
165 |
-
import glob
|
166 |
-
files = list(glob.glob(path, recursive=True))
|
167 |
-
|
168 |
-
# pandoc can't find include files
|
169 |
-
|
170 |
-
remove(dst)
|
171 |
-
os.makedirs(dst)
|
172 |
-
|
173 |
-
# copy full tree, for absolute paths in rst
|
174 |
-
for fil in files:
|
175 |
-
if os.path.isfile(fil):
|
176 |
-
shutil.copy(fil, dst)
|
177 |
-
|
178 |
-
# hack for relative path
|
179 |
-
scorers_dir = os.path.join(dst, 'scorers')
|
180 |
-
makedirs(scorers_dir)
|
181 |
-
for fil in glob.glob(os.path.join(dst, '*.frag')):
|
182 |
-
shutil.copy(fil, scorers_dir)
|
183 |
-
|
184 |
-
return dst
|
185 |
-
|
186 |
-
|
187 |
-
def rst_to_outputs(files, min_len=30, max_len=2048 // 2 - 30):
|
188 |
-
# account for sequence length (context window) including prompt and input and output
|
189 |
-
|
190 |
-
# os.system('pandoc -f rst -t plain ./expert_settings/nlp_settings.rst')
|
191 |
-
import pypandoc
|
192 |
-
basedir = os.path.abspath(os.getcwd())
|
193 |
-
|
194 |
-
outputs = []
|
195 |
-
for fil in files:
|
196 |
-
os.chdir(basedir)
|
197 |
-
os.chdir(os.path.dirname(fil))
|
198 |
-
fil = os.path.basename(fil)
|
199 |
-
print("Processing %s" % fil, flush=True)
|
200 |
-
# out_format can be one of: asciidoc, asciidoctor, beamer, biblatex, bibtex, commonmark, commonmark_x,
|
201 |
-
# context, csljson, docbook, docbook4, docbook5, docx, dokuwiki,
|
202 |
-
# dzslides, epub, epub2, epub3, fb2, gfm, haddock, html, html4, html5, icml,
|
203 |
-
# ipynb, jats, jats_archiving, jats_articleauthoring, jats_publishing, jira,
|
204 |
-
# json, latex, man,
|
205 |
-
# markdown, markdown_github, markdown_mmd, markdown_phpextra, markdown_strict,
|
206 |
-
# mediawiki, ms, muse, native, odt, opendocument, opml, org, pdf, plain, pptx,
|
207 |
-
# revealjs, rst, rtf, s5, slideous, slidy, tei, texinfo, textile, xwiki, zimwiki
|
208 |
-
out_format = 'plain'
|
209 |
-
# avoid extra new lines injected into text
|
210 |
-
extra_args = ['--wrap=preserve', '--resource path="%s" % dst']
|
211 |
-
|
212 |
-
plain_list = []
|
213 |
-
try:
|
214 |
-
# valid for expert settings
|
215 |
-
input_rst = pypandoc.convert_file(fil, 'rst')
|
216 |
-
input_list = input_rst.split('\n``')
|
217 |
-
for input_subrst in input_list:
|
218 |
-
input_plain = pypandoc.convert_text(input_subrst, format='rst', to='plain')
|
219 |
-
plain_list.append([input_plain, fil])
|
220 |
-
except Exception as e:
|
221 |
-
print("file exception: %s %s" % (fil, str(e)), flush=True)
|
222 |
-
|
223 |
-
if not plain_list:
|
224 |
-
# if failed to process as pieces of rst, then
|
225 |
-
output = pypandoc.convert_file(fil, out_format, extra_args=extra_args, format='rst')
|
226 |
-
outputs1 = get_sentences(output, length=max_len)
|
227 |
-
for oi, output in enumerate(outputs1):
|
228 |
-
output = output.replace('\n\n', '\n')
|
229 |
-
plain_list.append([output, fil])
|
230 |
-
outputs.extend(plain_list)
|
231 |
-
|
232 |
-
# report:
|
233 |
-
# [print(len(x)) for x in outputs]
|
234 |
-
|
235 |
-
# deal with blocks longer than context size (sequence length) of 2048
|
236 |
-
new_outputs = []
|
237 |
-
num_truncated = 0
|
238 |
-
num_orig = len(outputs)
|
239 |
-
for output, fil in outputs:
|
240 |
-
if len(output) < max_len:
|
241 |
-
new_outputs.append([output, fil])
|
242 |
-
continue
|
243 |
-
outputs1 = get_sentences(output, length=max_len)
|
244 |
-
for oi, output1 in enumerate(outputs1):
|
245 |
-
output1 = output1.replace('\n\n', '\n')
|
246 |
-
new_outputs.append([output1, fil])
|
247 |
-
num_truncated += 1
|
248 |
-
print('num_orig: %s num_truncated: %s' % (num_orig, num_truncated), flush=True)
|
249 |
-
|
250 |
-
new_outputs = [[k.strip(), fil] for k, fil in new_outputs if len(k.strip()) > min_len]
|
251 |
-
|
252 |
-
return new_outputs
|
253 |
-
|
254 |
-
|
255 |
-
def test_scrape_dai_docs_all_pandoc():
|
256 |
-
"""
|
257 |
-
pytest -s -v create_data.py::test_scrape_dai_docs_all_pandoc
|
258 |
-
:return:
|
259 |
-
"""
|
260 |
-
|
261 |
-
dst = setup_dai_docs()
|
262 |
-
|
263 |
-
import glob
|
264 |
-
files = list(glob.glob(os.path.join(dst, '*rst'), recursive=True))
|
265 |
-
|
266 |
-
basedir = os.path.abspath(os.getcwd())
|
267 |
-
new_outputs = rst_to_outputs(files)
|
268 |
-
os.chdir(basedir)
|
269 |
-
|
270 |
-
remove(dst)
|
271 |
-
save_thing = [{"output": k.strip(), 'prompt_type': 'plain'} for k in new_outputs]
|
272 |
-
output_file = "dai_docs.train_cleaned.json"
|
273 |
-
with open(output_file, "wt") as f:
|
274 |
-
f.write(json.dumps(save_thing, indent=2))
|
275 |
-
|
276 |
-
|
277 |
-
def test_config_to_json():
|
278 |
-
"""
|
279 |
-
Needs to run from Driverless AI source directory.
|
280 |
-
E.g. (base) jon@gpu:~/h2oai$ pytest -s -v /data/jon/h2ogpt/create_data.py::test_config_to_json ; cp config.json /data/jon/h2ogpt/
|
281 |
-
:return:
|
282 |
-
"""
|
283 |
-
try:
|
284 |
-
# Arrange
|
285 |
-
import json
|
286 |
-
from h2oaicore.systemutils import config
|
287 |
-
toml_list = []
|
288 |
-
for k, v in config.get_meta_dict().items():
|
289 |
-
title = (v.title + ": ") if v.title else ''
|
290 |
-
comment = v.comment or ''
|
291 |
-
if not (title or comment):
|
292 |
-
continue
|
293 |
-
toml_list.extend(
|
294 |
-
[
|
295 |
-
{
|
296 |
-
'prompt_type': 'plain',
|
297 |
-
'instruction': f"<human>: What does {k} do?\n<bot>: {k.replace('_', ' ')} config.toml: {comment or title}\n<human>:".replace(
|
298 |
-
"\n", ""),
|
299 |
-
},
|
300 |
-
{
|
301 |
-
'prompt_type': 'plain',
|
302 |
-
'instruction': f"<human>: Explain {k}.\n<bot>: {k.replace('_', ' ')} config.toml: {comment or title}\n<human>:".replace(
|
303 |
-
"\n", ""),
|
304 |
-
},
|
305 |
-
{
|
306 |
-
'prompt_type': 'plain',
|
307 |
-
'instruction': f"<human>: How can I do this: {title}.\n<bot>: Set the {k.replace('_', ' ')} config.toml\n<human>:".replace(
|
308 |
-
"\n", ""),
|
309 |
-
} if title and comment else None,
|
310 |
-
{
|
311 |
-
'prompt_type': 'human_bot',
|
312 |
-
'instruction': f'Explain the following expert setting for Driverless AI',
|
313 |
-
'input': f"{k}",
|
314 |
-
'output': f"{k.replace('_', ' ')} config.toml: {comment or title}".replace("\n", ""),
|
315 |
-
},
|
316 |
-
{
|
317 |
-
'prompt_type': 'human_bot',
|
318 |
-
'instruction': f'Explain the following expert setting for Driverless AI',
|
319 |
-
'input': f"{k}",
|
320 |
-
'output': f"{k.replace('_', ' ')} config.toml: {title}{comment}".replace("\n", ""),
|
321 |
-
},
|
322 |
-
{
|
323 |
-
'prompt_type': 'human_bot',
|
324 |
-
'instruction': f'Explain the following expert setting for Driverless AI',
|
325 |
-
'input': f"{k.replace('_', ' ')}",
|
326 |
-
'output': f"{k.replace('_', ' ')} config.toml: {title}{comment}".replace("\n", ""),
|
327 |
-
},
|
328 |
-
{
|
329 |
-
'prompt_type': 'human_bot',
|
330 |
-
'instruction': f'Explain the following expert setting for Driverless AI',
|
331 |
-
'input': f"{title}",
|
332 |
-
'output': f"{k.replace('_', ' ')} config.toml: {title}{comment}".replace("\n", ""),
|
333 |
-
},
|
334 |
-
{
|
335 |
-
'prompt_type': 'human_bot',
|
336 |
-
'instruction': f'Provide a short explanation of the expert setting {k}',
|
337 |
-
'output': f"{k.replace('_', ' ')} config.toml: {comment or title}".replace("\n", ""),
|
338 |
-
},
|
339 |
-
{
|
340 |
-
'prompt_type': 'human_bot',
|
341 |
-
'instruction': f'Provide a detailed explanation of the expert setting {k}',
|
342 |
-
'output': f"{k.replace('_', ' ')} config.toml: {title}{comment}".replace("\n", ""),
|
343 |
-
},
|
344 |
-
]
|
345 |
-
)
|
346 |
-
toml_list = [x for x in toml_list if x]
|
347 |
-
with open("config.json", "wt") as f:
|
348 |
-
f.write(json.dumps(toml_list, indent=2))
|
349 |
-
except Exception as e:
|
350 |
-
print("Exception: %s" % str(e), flush=True)
|
351 |
-
|
352 |
-
|
353 |
-
def copy_tree(src, dst, follow_symlink=False):
|
354 |
-
makedirs(dst, exist_ok=True)
|
355 |
-
for (path, dirs, files) in os.walk(src, followlinks=follow_symlink):
|
356 |
-
new_path = path.replace(src, dst)
|
357 |
-
makedirs(new_path, exist_ok=True)
|
358 |
-
for file in files:
|
359 |
-
filename = os.path.join(path, file)
|
360 |
-
new_filename = os.path.join(new_path, file)
|
361 |
-
# print("%s -> %s" % (filename, new_filename))
|
362 |
-
try:
|
363 |
-
atomic_copy(filename, new_filename)
|
364 |
-
except FileNotFoundError:
|
365 |
-
pass
|
366 |
-
|
367 |
-
|
368 |
-
def atomic_move(src, dst):
|
369 |
-
try:
|
370 |
-
shutil.move(src, dst)
|
371 |
-
except (shutil.Error, FileExistsError):
|
372 |
-
pass
|
373 |
-
remove(src)
|
374 |
-
|
375 |
-
|
376 |
-
def atomic_copy(src=None, dst=None, with_permissions=True):
|
377 |
-
if os.path.isfile(dst):
|
378 |
-
return
|
379 |
-
import uuid
|
380 |
-
my_uuid = uuid.uuid4()
|
381 |
-
dst_tmp = dst + str(my_uuid)
|
382 |
-
makedirs(os.path.dirname(dst), exist_ok=True)
|
383 |
-
if with_permissions:
|
384 |
-
shutil.copy(src, dst_tmp)
|
385 |
-
else:
|
386 |
-
shutil.copyfile(src, dst_tmp)
|
387 |
-
atomic_move(dst_tmp, dst)
|
388 |
-
remove(dst_tmp)
|
389 |
-
|
390 |
-
|
391 |
-
def makedirs(path, exist_ok=True):
|
392 |
-
"""
|
393 |
-
Avoid some inefficiency in os.makedirs()
|
394 |
-
:param path:
|
395 |
-
:param exist_ok:
|
396 |
-
:return:
|
397 |
-
"""
|
398 |
-
if os.path.isdir(path) and os.path.exists(path):
|
399 |
-
assert exist_ok, "Path already exists"
|
400 |
-
return path
|
401 |
-
os.makedirs(path, exist_ok=exist_ok)
|
402 |
-
|
403 |
-
|
404 |
-
## Download from https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_unfiltered_cleaned_split.json
|
405 |
-
## Turn into simple instruct prompt type. No context/previous conversations.
|
406 |
-
def test_prep_instruct_vicuna():
|
407 |
-
from datasets import load_dataset
|
408 |
-
filename = 'ShareGPT_unfiltered_cleaned_split.json'
|
409 |
-
if not os.path.exists(filename):
|
410 |
-
os.system(
|
411 |
-
'wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/%s' % filename)
|
412 |
-
data = load_dataset("json", data_files={"train": filename})["train"]
|
413 |
-
training_rows = []
|
414 |
-
for i in range(data.num_rows):
|
415 |
-
conversations = data[i]['conversations']
|
416 |
-
assert isinstance(conversations, list), conversations
|
417 |
-
convo = ""
|
418 |
-
for j, conv in enumerate(conversations):
|
419 |
-
# Get ready for generate.py prompt_type=human_bot
|
420 |
-
# But train with prompt_type=plain
|
421 |
-
if conv['from'] == 'human':
|
422 |
-
FROM = '<human>: '
|
423 |
-
elif conv['from'] == 'gpt':
|
424 |
-
FROM = '<bot>: '
|
425 |
-
convo += f"{FROM}" + conv['value'] + "\n"
|
426 |
-
if convo:
|
427 |
-
training_rows.append(dict(input=convo))
|
428 |
-
with open(filename + ".generate_human_bot.train_plain.json", "wt") as f:
|
429 |
-
f.write(json.dumps(training_rows, indent=2))
|
430 |
-
|
431 |
-
|
432 |
-
POSTFIX = ".generate_human_bot.train_plain.json"
|
433 |
-
|
434 |
-
# https://bair.berkeley.edu/blog/2023/04/03/koala/
|
435 |
-
OIG_DATASETS = [
|
436 |
-
"unified_chip2.jsonl",
|
437 |
-
"unified_grade_school_math_instructions.jsonl",
|
438 |
-
"unified_poetry_2_song.jsonl",
|
439 |
-
"unified_plot_screenplay_books_dialog.jsonl",
|
440 |
-
]
|
441 |
-
|
442 |
-
# hub issue: https://huggingface.co/datasets/laion/OIG/discussions/4
|
443 |
-
ALL_OIG_DATASETS = ['unified_abstract_infill.jsonl',
|
444 |
-
'unified_basic.jsonl',
|
445 |
-
'unified_canadian_parliament.jsonl',
|
446 |
-
'unified_chip2.jsonl',
|
447 |
-
'unified_conv_finqa.jsonl',
|
448 |
-
'unified_cuad.jsonl',
|
449 |
-
'unified_essays.jsonl',
|
450 |
-
'unified_flan.jsonl.gz',
|
451 |
-
'unified_grade_school_math_instructions.jsonl',
|
452 |
-
'unified_hc3_human.jsonl',
|
453 |
-
'unified_image_prompts_instructions.jsonl',
|
454 |
-
'unified_joke_explanations.jsonl',
|
455 |
-
'unified_mathqa_flanv2_kojma_cot.jsonl',
|
456 |
-
'unified_merged_code_xp3.jsonl',
|
457 |
-
'unified_multi_news.jsonl',
|
458 |
-
'unified_multi_sum.jsonl',
|
459 |
-
'unified_ni.jsonl.gz',
|
460 |
-
'unified_nq.jsonl',
|
461 |
-
'unified_openai_summarize_tldr.jsonl',
|
462 |
-
'unified_oscar_en_sample_dialog.jsonl',
|
463 |
-
'unified_p3.jsonl.gz',
|
464 |
-
'unified_plot_screenplay_books_dialog.jsonl',
|
465 |
-
'unified_poetry_2_song.jsonl',
|
466 |
-
'unified_poetry_instructions.jsonl',
|
467 |
-
'unified_rallio_safety_and_prosocial.jsonl',
|
468 |
-
'unified_rallio_soda_upgraded_2048.jsonl',
|
469 |
-
'unified_soda_dialog.jsonl',
|
470 |
-
'unified_sqlv1.jsonl',
|
471 |
-
'unified_sqlv2.jsonl',
|
472 |
-
'unified_squad_v2.jsonl',
|
473 |
-
'unified_squad_v2_more_neg.jsonl',
|
474 |
-
'unified_ul2_plus_oscar_en_sample_dialog.jsonl',
|
475 |
-
'unified_unifiedskg_instructions.jsonl',
|
476 |
-
'unified_unnatural_instructions.jsonl',
|
477 |
-
'unified_xp3_sample.jsonl']
|
478 |
-
|
479 |
-
useful_oig_files = ['unified_rallio_safety_and_prosocial.jsonl.parquet',
|
480 |
-
'unified_chip2.jsonl.parquet',
|
481 |
-
'unified_cuad.jsonl.parquet',
|
482 |
-
'unified_essays.jsonl.parquet',
|
483 |
-
'unified_flan.jsonl.gz.parquet',
|
484 |
-
'unified_grade_school_math_instructions.jsonl.parquet',
|
485 |
-
'unified_hc3_human.jsonl.parquet',
|
486 |
-
'unified_mathqa_flanv2_kojma_cot.jsonl.parquet',
|
487 |
-
'unified_merged_code_xp3.jsonl.parquet',
|
488 |
-
'unified_multi_news.jsonl.parquet',
|
489 |
-
# 'unified_multi_sum.jsonl.parquet'
|
490 |
-
'unified_ni.jsonl.gz.parquet',
|
491 |
-
'unified_openai_summarize_tldr.jsonl.parquet',
|
492 |
-
# 'unified_oscar_en_sample_dialog.jsonl.parquet', # create text containing these N words, not specific
|
493 |
-
'unified_plot_screenplay_books_dialog.jsonl.parquet',
|
494 |
-
'unified_soda_dialog.jsonl.parquet',
|
495 |
-
'unified_unnatural_instructions.jsonl.parquet',
|
496 |
-
]
|
497 |
-
|
498 |
-
|
499 |
-
@pytest.mark.parametrize("filename", OIG_DATASETS)
|
500 |
-
def test_get_small_sample_oig_data(filename):
|
501 |
-
if not os.path.exists(filename):
|
502 |
-
os.system('wget https://huggingface.co/datasets/laion/OIG/resolve/main/%s' % filename)
|
503 |
-
import json
|
504 |
-
rows = []
|
505 |
-
with open(filename, "r") as f:
|
506 |
-
for line in f.readlines():
|
507 |
-
row = json.loads(line)
|
508 |
-
rows.append(dict(input=row["text"]))
|
509 |
-
with open(filename + POSTFIX, "w") as f:
|
510 |
-
f.write(json.dumps(rows, indent=2))
|
511 |
-
|
512 |
-
|
513 |
-
@pytest.mark.parametrize("filename", ALL_OIG_DATASETS)
|
514 |
-
def test_download_useful_data_as_parquet(filename):
|
515 |
-
dest_file = filename + '.parquet'
|
516 |
-
if dest_file not in useful_oig_files:
|
517 |
-
pytest.skip('file declared not useful')
|
518 |
-
if not os.path.exists(filename):
|
519 |
-
os.system('wget https://huggingface.co/datasets/laion/OIG/resolve/main/%s' % filename)
|
520 |
-
if not os.path.exists(dest_file):
|
521 |
-
df = pd.read_json(path_or_buf=filename, lines=True)
|
522 |
-
df.to_parquet(dest_file, index=False)
|
523 |
-
|
524 |
-
|
525 |
-
def test_merge_shuffle_small_sample_oig_data():
|
526 |
-
np.random.seed(1234)
|
527 |
-
rows = []
|
528 |
-
for filename in OIG_DATASETS:
|
529 |
-
with open(filename + POSTFIX, "r") as f:
|
530 |
-
rows.extend(json.loads(f.read()))
|
531 |
-
np.random.shuffle(rows)
|
532 |
-
with open("merged_shuffled_OIG_%s.json" % hashlib.sha256(str(OIG_DATASETS).encode()).hexdigest()[:10], "w") as f:
|
533 |
-
f.write(json.dumps(rows, indent=2))
|
534 |
-
|
535 |
-
|
536 |
-
def test_join_jsons():
|
537 |
-
files = ['config.json'] * 1 + \
|
538 |
-
['dai_docs.train_cleaned.json'] * 2 + \
|
539 |
-
['dai_faq.json'] * 3
|
540 |
-
print(files)
|
541 |
-
lst = []
|
542 |
-
[lst.extend(json.load(open(fil, 'rt'))) for fil in files]
|
543 |
-
print(len(lst))
|
544 |
-
json.dump(lst, open("merged.json", "wt"), indent=2)
|
545 |
-
|
546 |
-
|
547 |
-
@pytest.mark.parametrize("filename", ['Anthropic/hh-rlhf'])
|
548 |
-
def test_make_rlhf_good_data(filename):
|
549 |
-
from datasets import load_dataset
|
550 |
-
rows = load_dataset(filename)["train"]["chosen"]
|
551 |
-
new_rows = []
|
552 |
-
for row in rows:
|
553 |
-
if row[:2] == "\n\n":
|
554 |
-
row = row[2:]
|
555 |
-
row = row.replace("Human: ", "<human>: ")
|
556 |
-
row = row.replace("Assistant: ", "<bot>: ")
|
557 |
-
new_rows.append(dict(input=row))
|
558 |
-
with open(filename.replace("/", "_") + POSTFIX, "w") as f:
|
559 |
-
f.write(json.dumps(new_rows, indent=2))
|
560 |
-
|
561 |
-
|
562 |
-
def test_show_prompts():
|
563 |
-
files = ['config.json'] * 1 + \
|
564 |
-
['dai_docs.train_cleaned.json'] * 1 + \
|
565 |
-
['dai_faq.json'] * 1
|
566 |
-
file_points = [json.load(open(fil, 'rt')) for fil in files]
|
567 |
-
from prompter import generate_prompt
|
568 |
-
for data_points in file_points:
|
569 |
-
for data_point in data_points:
|
570 |
-
print(generate_prompt(data_point, 'plain', '', False, False, False)[0])
|
571 |
-
|
572 |
-
|
573 |
-
def test_get_open_datasets():
|
574 |
-
# HF changed things so don't get raw list of all datasets, so not have to filter, but can't do negative filter
|
575 |
-
open_tags = ['license:Apache License 2.0',
|
576 |
-
'license:mit',
|
577 |
-
'license:apache',
|
578 |
-
'license:apache2',
|
579 |
-
'license:apache-2.0',
|
580 |
-
'license:bsd',
|
581 |
-
'license:bsd-2-clause',
|
582 |
-
'license:bsd-3-clause',
|
583 |
-
'license:bsd-3-clause-clear',
|
584 |
-
'license:lgpl-2.1',
|
585 |
-
'license:lgpl-3.0',
|
586 |
-
'license:lgpl-lr',
|
587 |
-
'license:lgpl',
|
588 |
-
'license:openrail++',
|
589 |
-
'license:openrail',
|
590 |
-
'license:bigscience-bloom-rail-1.0',
|
591 |
-
# 'license:agpl-3.0',
|
592 |
-
'license:other',
|
593 |
-
'license:unknown',
|
594 |
-
# 'license:mpl-2.0', # ok, but would have to include original copyright, license, source, copies in distribution
|
595 |
-
# Attribution required:
|
596 |
-
'license:odc-by',
|
597 |
-
'license:cc-by-4.0',
|
598 |
-
'license:cc-by-3.0',
|
599 |
-
'license:cc-by-2.0',
|
600 |
-
'license:cc-by-2.5',
|
601 |
-
# 'license:cc-by-sa-4.0', # would require same license
|
602 |
-
'license:odbl',
|
603 |
-
'license:pddl',
|
604 |
-
'license:ms-pl',
|
605 |
-
'license:zlib',
|
606 |
-
]
|
607 |
-
# bad license: cc-by-nc-4.0
|
608 |
-
|
609 |
-
from huggingface_hub import list_datasets
|
610 |
-
datasets = flatten_list([[x for x in list_datasets(filter=y)] for y in open_tags])
|
611 |
-
datasets += [x for x in list_datasets(author='openai')]
|
612 |
-
# check all:
|
613 |
-
all_license_tags = set(flatten_list([[y for y in x.tags if 'license' in y] for x in datasets]))
|
614 |
-
print(len(all_license_tags))
|
615 |
-
open_datasets = [x for x in datasets if any([y in x.tags for y in open_tags]) or 'license:' not in str(x.tags)]
|
616 |
-
print('open_datasets', len(open_datasets))
|
617 |
-
all_task_tags = set(flatten_list([[y for y in x.tags if 'task' in y] for x in open_datasets]))
|
618 |
-
print('all_task_tags', len(all_task_tags))
|
619 |
-
excluded_tags = ['image', 'hate', 'tabular', 'table-', 'classification', 'retrieval',
|
620 |
-
'translation', 'identification', 'object', 'mask', 'to-text',
|
621 |
-
'face-detection', 'audio', 'voice', 'reinforcement', 'depth-est',
|
622 |
-
'forecasting', 'parsing', 'visual', 'speech', 'multiple-choice',
|
623 |
-
'slot-filling', 'irds/argsme', '-scoring', 'other', 'graph-ml',
|
624 |
-
'feature-extraction', 'keyword-spotting',
|
625 |
-
'coreference-resolution', 'segmentation',
|
626 |
-
'word-sense-disambiguation',
|
627 |
-
'lemmatization']
|
628 |
-
task_tags = [x.replace('task_categories:', '').replace('task_ids:', '')
|
629 |
-
for x in all_task_tags if not any([y in x for y in
|
630 |
-
excluded_tags])]
|
631 |
-
print('task_tags', len(task_tags))
|
632 |
-
# str(x.tags) to catch any pattern match to anything in list
|
633 |
-
open_tasked_datasets = [x for x in open_datasets if
|
634 |
-
any([y in str([x for x in x.tags if 'task' in x]) for y in task_tags]) and
|
635 |
-
not any([y in str([x for x in x.tags if 'task' in x]) for y in excluded_tags]) or
|
636 |
-
'task_categories' not in str(x.tags) and 'task_ids' not in str(x.tags)]
|
637 |
-
open_tasked_datasets = [x for x in open_tasked_datasets if not x.disabled]
|
638 |
-
open_tasked_datasets = [x for x in open_tasked_datasets if not x.gated]
|
639 |
-
open_tasked_datasets = [x for x in open_tasked_datasets if not x.private]
|
640 |
-
print('open_tasked_datasets', len(open_tasked_datasets))
|
641 |
-
sizes = list(set(flatten_list([[(y, x.id) for y in x.tags if 'size' in y] for x in open_tasked_datasets])))
|
642 |
-
languages = list(set(flatten_list([[(y, x.id) for y in x.tags if 'language:' in y] for x in open_tasked_datasets])))
|
643 |
-
open_english_tasked_datasets = [x for x in open_tasked_datasets if
|
644 |
-
'language:' not in str(x.tags) or
|
645 |
-
'language:en' in str(x.tags)]
|
646 |
-
small_open_english_tasked_datasets = [x for x in open_english_tasked_datasets if
|
647 |
-
'n<1K' in str(x.tags) or
|
648 |
-
'1K<n<10K' in str(x.tags) or
|
649 |
-
'1K0<n<100K' in str(x.tags) or
|
650 |
-
'100K<n<1M' in str(x.tags) or
|
651 |
-
'size_category' not in str(x.tags)
|
652 |
-
]
|
653 |
-
# 'aeslc' : email_body, subject -> summarization?
|
654 |
-
# load_dataset(open_tasked_datasets[0].id).data['train'].to_pandas()
|
655 |
-
ids = [x.id for x in small_open_english_tasked_datasets]
|
656 |
-
|
657 |
-
# sanity checks
|
658 |
-
# https://bair.berkeley.edu/blog/2023/04/03/koala/
|
659 |
-
assert 'alespalla/chatbot_instruction_prompts' in ids
|
660 |
-
assert 'laion/OIG' in ids
|
661 |
-
assert 'openai/webgpt_comparisons' in ids
|
662 |
-
assert 'openai/summarize_from_feedback' in ids
|
663 |
-
assert 'Anthropic/hh-rlhf' in ids
|
664 |
-
|
665 |
-
# useful but not allowed for commercial purposes:
|
666 |
-
# https://huggingface.co/datasets/squad
|
667 |
-
|
668 |
-
print('open_english_tasked_datasets: ', ids, flush=True)
|
669 |
-
|
670 |
-
exclude_ids = ['allenai/nllb', # translation only
|
671 |
-
'hf-internal-testing/fixtures_image_utils', # testing
|
672 |
-
'allenai/c4', # search-url
|
673 |
-
'agemagician/uniref50', # unknown
|
674 |
-
'huggingface-course/documentation-images', # images
|
675 |
-
'smilegate-ai/kor_unsmile', # korean
|
676 |
-
'MohamedRashad/ChatGPT-prompts', # ChatGPT/LearnGPT/https://www.emergentmind.com/
|
677 |
-
'humarin/chatgpt-paraphrases', # Paraphrase using ChatGPT
|
678 |
-
'Jeska/vaccinchat', # not useful
|
679 |
-
'alespalla/chatbot_instruction_prompts', # mixes alpaca
|
680 |
-
'allenai/prosocial-dialog',
|
681 |
-
# already exlucded, but wrongly in other datasets that say more permissive license
|
682 |
-
'AlekseyKorshuk/persona-chat', # low quality
|
683 |
-
'bavard/personachat_truecased', # low quality
|
684 |
-
'adamlin/daily_dialog', # medium quality conversations
|
685 |
-
'adamlin/FewShotWoz', # low quality
|
686 |
-
'benjaminbeilharz/better_daily_dialog', # low quality
|
687 |
-
'benjaminbeilharz/daily_dialog_w_turn_templates', # low
|
688 |
-
'benjaminbeilharz/empathetic_dialogues_for_lm', # low
|
689 |
-
'GEM-submissions/GEM__bart_base_schema_guided_dialog__1645547915', # NA
|
690 |
-
'ia-bentebib/conv_ai_2_fr', # low fr
|
691 |
-
'ia-bentebib/daily_dialog_fr', # low fr
|
692 |
-
'ia-bentebib/dialog_re_fr', # low fr
|
693 |
-
'ia-bentebib/empathetic_dialogues_fr', # low fr
|
694 |
-
'roskoN/dailydialog', # low
|
695 |
-
'VadorMazer/skyrimdialogstest', # low
|
696 |
-
'bigbio/med_qa', # med specific Q/A
|
697 |
-
'biu-nlp/qa_srl2018', # low quality Q/A
|
698 |
-
'biu-nlp/qa_discourse', # low quality Q/A
|
699 |
-
'iarfmoose/qa_evaluator', # low quality Q/A
|
700 |
-
'jeopardy', # low quality Q/A -- no reasoning
|
701 |
-
'narrativeqa', # low quality Q/A
|
702 |
-
'nomic-ai/gpt4all_prompt_generations', # bad license
|
703 |
-
'nomic-ai/gpt4all_prompt_generations_with_p3', # bad license
|
704 |
-
'HuggingFaceH4/alpaca', # bad license
|
705 |
-
'tatsu-lab/alpaca', # ToS breaking
|
706 |
-
'yahma/alpaca-cleaned', # ToS breaking
|
707 |
-
'Hello-SimpleAI/HC3', # bad license
|
708 |
-
'glue', # no reasoning QA
|
709 |
-
'sahil2801/CodeAlpaca-20k', # bad license
|
710 |
-
'Short-Answer-Feedback/saf_communication_networks_english', # long Q, medium A
|
711 |
-
]
|
712 |
-
small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if x.id not in exclude_ids]
|
713 |
-
# some ids clearly speech related
|
714 |
-
small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if 'speech' not in x.id]
|
715 |
-
# HF testing
|
716 |
-
small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if
|
717 |
-
'hf-internal-testing' not in x.id]
|
718 |
-
small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if
|
719 |
-
'chinese' not in x.id]
|
720 |
-
|
721 |
-
sorted_small_open_english_tasked_datasets = sorted([(x.downloads, x) for x in small_open_english_tasked_datasets],
|
722 |
-
key=lambda x: x[0], reverse=True)
|
723 |
-
|
724 |
-
# NOTES:
|
725 |
-
# Run like pytest -s -v create_data.py::test_get_open_datasets &> getdata9.log
|
726 |
-
# See what needs config passed and add:
|
727 |
-
# grep 'load_dataset(' getdata9.log|grep -v data_id|less -S
|
728 |
-
# grep "pip install" getdata9.log
|
729 |
-
# NOTE: Some datasets have default config, but others are there. Don't know how to access them.
|
730 |
-
|
731 |
-
"""
|
732 |
-
https://huggingface.co/datasets/wikihow/blob/main/wikihow.py
|
733 |
-
https://github.com/mahnazkoupaee/WikiHow-Dataset
|
734 |
-
https://ucsb.box.com/s/ap23l8gafpezf4tq3wapr6u8241zz358
|
735 |
-
https://ucsb.app.box.com/s/ap23l8gafpezf4tq3wapr6u8241zz358
|
736 |
-
"""
|
737 |
-
|
738 |
-
"""
|
739 |
-
# some ambiguous or non-commercial datasets
|
740 |
-
https://github.com/PhoebusSi/alpaca-CoT
|
741 |
-
"""
|
742 |
-
|
743 |
-
timeout = 3 * 60
|
744 |
-
# laion/OIG takes longer
|
745 |
-
for num_downloads, dataset in sorted_small_open_english_tasked_datasets:
|
746 |
-
data_id = dataset.id
|
747 |
-
func = do_one
|
748 |
-
args = (data_id, num_downloads)
|
749 |
-
kwargs = {}
|
750 |
-
with ProcessPoolExecutor(max_workers=1) as executor:
|
751 |
-
future = executor.submit(func, *args, **kwargs)
|
752 |
-
try:
|
753 |
-
future.result(timeout=timeout)
|
754 |
-
except concurrent.futures.TimeoutError:
|
755 |
-
print("\n\ndata_id %s timeout\n\n" % data_id, flush=True)
|
756 |
-
for child in psutil.Process(os.getpid()).children(recursive=True):
|
757 |
-
os.kill(child.pid, signal.SIGINT)
|
758 |
-
os.kill(child.pid, signal.SIGTERM)
|
759 |
-
os.kill(child.pid, signal.SIGKILL)
|
760 |
-
|
761 |
-
|
762 |
-
def do_one(data_id, num_downloads):
|
763 |
-
from datasets import load_dataset
|
764 |
-
out_file = "data_%s.parquet" % str(data_id.replace('/', '_'))
|
765 |
-
if os.path.isfile(out_file) and os.path.getsize(out_file) > 1024 ** 3:
|
766 |
-
return
|
767 |
-
try:
|
768 |
-
print("Loading data_id %s num_downloads: %s" % (data_id, num_downloads), flush=True)
|
769 |
-
avail_list = None
|
770 |
-
try:
|
771 |
-
data = load_dataset(data_id, 'foobar')
|
772 |
-
except Exception as e:
|
773 |
-
if 'Available: ' in str(e):
|
774 |
-
avail_list = ast.literal_eval(str(e).split('Available:')[1].strip())
|
775 |
-
else:
|
776 |
-
avail_list = None
|
777 |
-
if avail_list is None:
|
778 |
-
avail_list = [None]
|
779 |
-
print("%s avail_list: %s" % (data_id, avail_list), flush=True)
|
780 |
-
|
781 |
-
for name in avail_list:
|
782 |
-
out_file = "data_%s_%s.parquet" % (str(data_id.replace('/', '_')), str(name))
|
783 |
-
if os.path.isfile(out_file):
|
784 |
-
continue
|
785 |
-
data = load_dataset(data_id, name)
|
786 |
-
column_names_dict = data.column_names
|
787 |
-
column_names = column_names_dict[list(column_names_dict.keys())[0]]
|
788 |
-
print("Processing data_id %s num_downloads: %s columns: %s" % (data_id, num_downloads, column_names),
|
789 |
-
flush=True)
|
790 |
-
data_dict = data.data
|
791 |
-
col_dict = data.num_columns
|
792 |
-
first_col = list(col_dict.keys())[0]
|
793 |
-
if 'train' in data_dict:
|
794 |
-
df = data['train'].to_pandas()
|
795 |
-
else:
|
796 |
-
df = data[first_col].to_pandas()
|
797 |
-
# csv has issues with escaping chars, even for datasets I know I want
|
798 |
-
df.to_parquet(out_file, index=False)
|
799 |
-
except Exception as e:
|
800 |
-
t, v, tb = sys.exc_info()
|
801 |
-
ex = ''.join(traceback.format_exception(t, v, tb))
|
802 |
-
print("Exception: %s %s" % (data_id, ex), flush=True)
|
803 |
-
|
804 |
-
|
805 |
-
def test_otherlic():
|
806 |
-
from huggingface_hub import list_datasets
|
807 |
-
lic = ['license:odc-by',
|
808 |
-
'license:cc-by-4.0',
|
809 |
-
'license:cc-by-3.0',
|
810 |
-
'license:cc-by-2.0',
|
811 |
-
'license:cc-by-2.5',
|
812 |
-
'license:cc-by-sa-4.0',
|
813 |
-
'license:odbl',
|
814 |
-
'license:pddl',
|
815 |
-
'license:ms-pl',
|
816 |
-
'license:zlib',
|
817 |
-
]
|
818 |
-
datasets = flatten_list([[x for x in list_datasets(filter=y) if 'translation' not in str(x.tags)] for y in lic])
|
819 |
-
print(len(datasets))
|
820 |
-
|
821 |
-
|
822 |
-
# These useful datasets are determined based upon data sample, column types, and uniqueness compared to larger datasets like Pile
|
823 |
-
# grep columns getdata13.log|grep -v "\['image'\]"|sort|uniq|grep -v tokens|grep -v "'image'"|grep -v embedding|grep dialog
|
824 |
-
useful = ['Dahoas/instruct-human-assistant-prompt',
|
825 |
-
'Dahoas/first-instruct-human-assistant-prompt',
|
826 |
-
'knkarthick/dialogsum', # summary of conversation
|
827 |
-
'McGill-NLP/FaithDial', # medium quality
|
828 |
-
'Zaid/quac_expanded', # medium quality context + QA
|
829 |
-
'0-hero/OIG-small-chip2', # medium
|
830 |
-
'alistvt/coqa-flat', # QA medium
|
831 |
-
'AnonymousSub/MedQuAD_47441_Question_Answer_Pairs', # QA medium
|
832 |
-
'Anthropic/hh-rlhf', # high quality # similar to Dahoas/full-hh-rlhf
|
833 |
-
'arjunth2001/online_privacy_qna', # good quality QA
|
834 |
-
'Dahoas/instruct_helpful_preferences', # medium quality instruct
|
835 |
-
'Dahoas/rl-prompt-dataset', # medium chat
|
836 |
-
'Dahoas/rm-static', # medium chat
|
837 |
-
'Dahoas/static-hh', # medium chat # HuggingFaceH4/self_instruct
|
838 |
-
'Dahoas/synthetic-instruct-gptj-pairwise', # medium chat
|
839 |
-
'eli5', # QA if prompt ELI5
|
840 |
-
'gsm8k', # QA (various)
|
841 |
-
'guanaco/guanaco', # prompt/response
|
842 |
-
'kastan/rlhf-qa-comparisons', # good QA
|
843 |
-
'kastan/rlhf-qa-conditional-generation-v2', # prompt answer
|
844 |
-
'OllieStanley/humaneval-mbpp-codegen-qa', # code QA, but started from words, so better than other code QA
|
845 |
-
'OllieStanley/humaneval-mbpp-testgen-qa', # code QA
|
846 |
-
'Graverman/Instruct-to-Code', # code QA
|
847 |
-
'openai/summarize_from_feedback', # summarize
|
848 |
-
'relbert/analogy_questions', # analogy QA
|
849 |
-
'yitingxie/rlhf-reward-datasets', # prompt, chosen, rejected.
|
850 |
-
'yizhongw/self_instruct', # instruct (super natural & instruct)
|
851 |
-
'HuggingFaceH4/asss', # QA, big A
|
852 |
-
'kastan/rlhf-qa-conditional-generation-v2', # QA
|
853 |
-
'cosmos_qa', # context QA
|
854 |
-
'vishal-burman/c4-faqs', # QA but not so much reasoning, but alot of text
|
855 |
-
'squadshifts', # QA from context
|
856 |
-
'hotpot_qa', # QA from context
|
857 |
-
'adversarial_qa', # QA from context
|
858 |
-
'allenai/soda', # dialog -> narrative/summary
|
859 |
-
'squad_v2', # context QA
|
860 |
-
'squadshifts', # context QA
|
861 |
-
'dferndz/cSQuAD1', # context QA
|
862 |
-
'dferndz/cSQuAD2', # context QA
|
863 |
-
'din0s/msmarco-nlgen', # context QA
|
864 |
-
'domenicrosati/TruthfulQA', # common sense truthful QA -- trivia but good trivia
|
865 |
-
'hotpot_qa', # context, QA
|
866 |
-
'HuggingFaceH4/self-instruct-eval', # instruct QA, medium quality, some language reasoning
|
867 |
-
'kastan/EE_QA_for_RLHF', # context QA
|
868 |
-
'KK04/LogicInference_OA', # instruction logical QA
|
869 |
-
'lmqg/qa_squadshifts_synthetic', # context QA
|
870 |
-
'lmqg/qg_squad', # context QA
|
871 |
-
'lmqg/qg_squadshifts', # context QA
|
872 |
-
'lmqg/qg_subjqa', # context QA
|
873 |
-
'pszemraj/HC3-textgen-qa',
|
874 |
-
# QA medium, has human responses -- humans tend to provide links instead of trying to answer
|
875 |
-
'pythonist/newdata', # long context, QA, brief A
|
876 |
-
'ropes', # long background, situation, question, A
|
877 |
-
'wikitablequestions', # table -> QA
|
878 |
-
'bigscience/p3', # context QA but short answers
|
879 |
-
]
|
880 |
-
|
881 |
-
code_useful = ['0n1xus/codexglue',
|
882 |
-
'openai_humaneval',
|
883 |
-
'koutch/staqc',
|
884 |
-
]
|
885 |
-
|
886 |
-
maybe_useful = ['AlekseyKorshuk/comedy-scripts',
|
887 |
-
'openbookqa', # hard to parse, low reasoning
|
888 |
-
'qed', # reasonable QA, but low reasoning
|
889 |
-
'selqa', # candidate answers
|
890 |
-
'HuggingFaceH4/instruction-pilot-outputs-filtered',
|
891 |
-
'GBaker/MedQA-USMLE-4-options', # medical QA with long questions
|
892 |
-
'npc-engine/light-batch-summarize-dialogue', # dialog summarize, kinda low specific quality
|
893 |
-
]
|
894 |
-
|
895 |
-
summary_useful = ['austin/rheum_abstracts',
|
896 |
-
'CarperAI/openai_summarize_comparisons', # summarize chosen/rejected
|
897 |
-
'CarperAI/openai_summarize_tldr', # summarize QA
|
898 |
-
'ccdv/cnn_dailymail', # summarize news
|
899 |
-
'ccdv/govreport-summarization', # summarize high quality
|
900 |
-
'ccdv/pubmed-summarization', # summarize high quality
|
901 |
-
'duorc', # plot -> QA
|
902 |
-
'farleyknight/big_patent_5_percent', # desc -> abstract
|
903 |
-
'multi_news', # summary
|
904 |
-
'opinosis',
|
905 |
-
'SophieTr/reddit_clean',
|
906 |
-
'allenai/mup', # long text -> summary
|
907 |
-
'allenai/multi_lexsum', # long text -> summary
|
908 |
-
'big_patent',
|
909 |
-
'allenai/wcep_dense_max',
|
910 |
-
'awinml/costco_long_practice',
|
911 |
-
'GEM/xsum',
|
912 |
-
'ratishsp/newshead',
|
913 |
-
'RussianNLP/wikiomnia', # russian
|
914 |
-
'stacked-summaries/stacked-xsum-1024',
|
915 |
-
]
|
916 |
-
|
917 |
-
math_useful = [
|
918 |
-
'competition_math'
|
919 |
-
]
|
920 |
-
|
921 |
-
skipped = ['c4', # maybe useful, used for flan, but skipped due to size
|
922 |
-
]
|
923 |
-
|
924 |
-
"""
|
925 |
-
To get training data from oig:
|
926 |
-
pytest test_oig test_grade_final test_finalize_to_json
|
927 |
-
"""
|
928 |
-
|
929 |
-
human = '<human>:'
|
930 |
-
bot = '<bot>:'
|
931 |
-
|
932 |
-
|
933 |
-
def test_assemble_and_detox():
|
934 |
-
import re
|
935 |
-
from profanity_check import predict_prob
|
936 |
-
df_list = []
|
937 |
-
for data in useful_oig_files:
|
938 |
-
print("Processing %s" % data, flush=True)
|
939 |
-
df = pd.read_parquet(data)
|
940 |
-
df = df.reset_index(drop=True)
|
941 |
-
# chop up into human/bot interactions of no more than 10kB per row
|
942 |
-
text_list = df[['text']].values.ravel().tolist()
|
943 |
-
new_text = []
|
944 |
-
max_len = 2048 # uber cutoff
|
945 |
-
MAX_LEN = 2048 // 2 - 30 # max len per question/answer
|
946 |
-
for text in tqdm(text_list):
|
947 |
-
human_starts = [m.start() for m in re.finditer('<human>: ', text)]
|
948 |
-
if len(human_starts) == 1:
|
949 |
-
human_starts = [0, len(text)] # always go into for loop below
|
950 |
-
blurb = ''
|
951 |
-
for i in range(len(human_starts) - 1):
|
952 |
-
interaction = text[human_starts[i]: human_starts[i + 1]][:max_len]
|
953 |
-
blurb += interaction
|
954 |
-
if len(blurb) >= MAX_LEN:
|
955 |
-
blurb = get_sentences(blurb, length=MAX_LEN)[0]
|
956 |
-
new_text.append(blurb + "\n<human>:")
|
957 |
-
blurb = ''
|
958 |
-
if blurb:
|
959 |
-
blurb = get_sentences(blurb, length=MAX_LEN)[0]
|
960 |
-
new_text.append(blurb + "\n<human>:")
|
961 |
-
|
962 |
-
if len(new_text) > len(text_list):
|
963 |
-
print("Added %d new rows (before: %d)" % (len(new_text) - df.shape[0], df.shape[0]))
|
964 |
-
df = pd.DataFrame({"text": new_text, "source": [data] * len(new_text)})
|
965 |
-
df = df.drop_duplicates(keep='first')
|
966 |
-
print(df['text'].apply(lambda x: len(x)).describe())
|
967 |
-
assert df['text'].apply(lambda x: len(x)).max() <= 2 * max_len
|
968 |
-
|
969 |
-
# faster than better_profanity, do early
|
970 |
-
df['profanity'] = predict_prob(df['text'])
|
971 |
-
before_rows = df.shape[0]
|
972 |
-
df = df[df['profanity'] < 0.25] # drop any low quality stuff
|
973 |
-
after_rows = df.shape[0]
|
974 |
-
print("Dropped %d rows out of %d due to alt-profanity-check" % (before_rows - after_rows, before_rows))
|
975 |
-
df_list.append(df)
|
976 |
-
print("Done processing %s -> %s rows" % (data, df.shape[0]), flush=True)
|
977 |
-
print("So far have %d rows" % sum([len(x) for x in df_list]))
|
978 |
-
df_final = pd.concat(df_list)
|
979 |
-
df_final = df_final.sample(frac=1, random_state=1234).reset_index(drop=True)
|
980 |
-
df_final.to_parquet('h2oGPT.cleaned.human_bot.shorter.parquet', index=False)
|
981 |
-
|
982 |
-
|
983 |
-
def test_basic_cleaning():
|
984 |
-
# from better_profanity import profanity
|
985 |
-
# https://pypi.org/project/alt-profanity-check/
|
986 |
-
from profanity_check import predict
|
987 |
-
df_list = []
|
988 |
-
for data in useful_oig_files:
|
989 |
-
# for data in useful_oig_files[:5]:
|
990 |
-
# for data in ['unified_openai_summarize_tldr.jsonl.parquet']:
|
991 |
-
print("Processing %s" % data, flush=True)
|
992 |
-
df = pd.read_parquet(data)
|
993 |
-
df = df.reset_index(drop=True)
|
994 |
-
# NOTE: Not correct if multiple human-bot interactions, but those dialogs even more desired
|
995 |
-
# avg_chars = len(df['text'][0])/(df['text'][0].count(human)+df['text'][0].count(bot))
|
996 |
-
df['avg_words'] = df['text'].apply(lambda x: x.count(' ') / (x.count(human) + x.count(bot)) / 2.0)
|
997 |
-
df['avg_bot_words'] = df['text'].apply(lambda x: x.split(bot)[1].count(' ') / x.count(bot))
|
998 |
-
# df['bad_words'] = df['text'].apply(lambda x: profanity.contains_profanity(x))
|
999 |
-
# low_quality_patterns = ['Write the rest of this wikipedia article']
|
1000 |
-
res = predict(df['text'])
|
1001 |
-
df['bad_words'] = res
|
1002 |
-
df = df.reset_index(drop=True)
|
1003 |
-
df = df[df['bad_words'] == 0]
|
1004 |
-
df = df[['text', 'avg_words', 'avg_bot_words']]
|
1005 |
-
df = df.drop_duplicates(keep='first')
|
1006 |
-
print(df[df['avg_words'] == df['avg_words'].max()]['text'].values)
|
1007 |
-
median_words = np.median(df['avg_words'])
|
1008 |
-
min_words_per_entity = max(30, 0.8 * median_words)
|
1009 |
-
max_words_per_entity = 2048 # too hard to learn from for now
|
1010 |
-
df = df[df['avg_words'] > min_words_per_entity]
|
1011 |
-
df = df[df['avg_words'] < max_words_per_entity]
|
1012 |
-
|
1013 |
-
min_words_per_entity = max(20, 0.5 * median_words) # bot should say stuff for now
|
1014 |
-
max_words_per_entity = 2048 # too hard to learn from for now
|
1015 |
-
df = df[df['avg_bot_words'] > min_words_per_entity]
|
1016 |
-
df = df[df['avg_bot_words'] < max_words_per_entity]
|
1017 |
-
|
1018 |
-
df_list.append(df)
|
1019 |
-
print("Done processing %s -> %s rows" % (data, df.shape[0]), flush=True)
|
1020 |
-
df_final = pd.concat(df_list)
|
1021 |
-
df_final.to_parquet('h2oGPT.cleaned.human_bot.parquet', index=False)
|
1022 |
-
|
1023 |
-
|
1024 |
-
from joblib import Parallel, delayed, effective_n_jobs
|
1025 |
-
from sklearn.utils import gen_even_slices
|
1026 |
-
from sklearn.utils.validation import _num_samples
|
1027 |
-
|
1028 |
-
|
1029 |
-
def parallel_apply(df, func, n_jobs=-1, **kwargs):
|
1030 |
-
""" Pandas apply in parallel using joblib.
|
1031 |
-
Uses sklearn.utils to partition input evenly.
|
1032 |
-
|
1033 |
-
Args:
|
1034 |
-
df: Pandas DataFrame, Series, or any other object that supports slicing and apply.
|
1035 |
-
func: Callable to apply
|
1036 |
-
n_jobs: Desired number of workers. Default value -1 means use all available cores.
|
1037 |
-
**kwargs: Any additional parameters will be supplied to the apply function
|
1038 |
-
|
1039 |
-
Returns:
|
1040 |
-
Same as for normal Pandas DataFrame.apply()
|
1041 |
-
|
1042 |
-
"""
|
1043 |
-
|
1044 |
-
if effective_n_jobs(n_jobs) == 1:
|
1045 |
-
return df.apply(func, **kwargs)
|
1046 |
-
else:
|
1047 |
-
ret = Parallel(n_jobs=n_jobs)(
|
1048 |
-
delayed(type(df).apply)(df[s], func, **kwargs)
|
1049 |
-
for s in gen_even_slices(_num_samples(df), effective_n_jobs(n_jobs)))
|
1050 |
-
return pd.concat(ret)
|
1051 |
-
|
1052 |
-
|
1053 |
-
def add_better_profanity_flag(df):
|
1054 |
-
from better_profanity import profanity
|
1055 |
-
df['better_profanity'] = parallel_apply(
|
1056 |
-
df['text'],
|
1057 |
-
lambda x: profanity.contains_profanity(x),
|
1058 |
-
n_jobs=-1,
|
1059 |
-
)
|
1060 |
-
return df
|
1061 |
-
|
1062 |
-
|
1063 |
-
def add_textstat_grade(df):
|
1064 |
-
import textstat
|
1065 |
-
|
1066 |
-
def myfunc(x):
|
1067 |
-
return textstat.flesch_kincaid_grade(x) # simple grade
|
1068 |
-
|
1069 |
-
if False:
|
1070 |
-
import dask.dataframe as dd
|
1071 |
-
# 40 seconds for 1000 rows, but have 1,787,799 rows
|
1072 |
-
ddata = dd.from_pandas(df, npartitions=120)
|
1073 |
-
|
1074 |
-
df['flesch_grade'] = ddata['text'].apply(myfunc).compute()
|
1075 |
-
if True:
|
1076 |
-
# fast way
|
1077 |
-
df['flesch_grade'] = parallel_apply(df['text'], myfunc, n_jobs=-1)
|
1078 |
-
return df
|
1079 |
-
|
1080 |
-
|
1081 |
-
def add_deberta_grade(df):
|
1082 |
-
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
1083 |
-
import torch
|
1084 |
-
reward_name = "OpenAssistant/reward-model-deberta-v3-large-v2"
|
1085 |
-
rank_model, tokenizer = AutoModelForSequenceClassification.from_pretrained(
|
1086 |
-
reward_name), AutoTokenizer.from_pretrained(reward_name)
|
1087 |
-
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
1088 |
-
rank_model.to(device)
|
1089 |
-
|
1090 |
-
def get_question(x):
|
1091 |
-
return x.replace('<human>: ', '').split('<bot>:')[0]
|
1092 |
-
|
1093 |
-
def get_answer(x):
|
1094 |
-
try:
|
1095 |
-
answer = x.split('<bot>: ')[1].split('<human>:')[0].replace('<bot>: ', '')
|
1096 |
-
except:
|
1097 |
-
answer = x.split('<bot>:')[1].split('<human>:')[0].replace('<bot>:', '')
|
1098 |
-
return answer
|
1099 |
-
|
1100 |
-
df['question'] = parallel_apply(df['text'], get_question, n_jobs=-1)
|
1101 |
-
df['answer'] = parallel_apply(df['text'], get_answer, n_jobs=-1)
|
1102 |
-
|
1103 |
-
from datasets import Dataset
|
1104 |
-
from transformers import pipeline
|
1105 |
-
from transformers.pipelines.pt_utils import KeyPairDataset
|
1106 |
-
import tqdm
|
1107 |
-
|
1108 |
-
pipe = pipeline(
|
1109 |
-
"text-classification",
|
1110 |
-
model=reward_name,
|
1111 |
-
device="cuda:0" if torch.cuda.is_available() else "cpu"
|
1112 |
-
)
|
1113 |
-
start = 0
|
1114 |
-
batch_size = 64 * 16
|
1115 |
-
micro_batch = orig_micro_batch = 16
|
1116 |
-
end = 0
|
1117 |
-
import socket
|
1118 |
-
checkpoint = "grades.%s.pkl" % socket.gethostname()
|
1119 |
-
grades = []
|
1120 |
-
import pickle
|
1121 |
-
if os.path.exists(checkpoint):
|
1122 |
-
with open(checkpoint, "rb") as f:
|
1123 |
-
start, grades = pickle.loads(f.read())
|
1124 |
-
last_oom = 0
|
1125 |
-
while end < df.shape[0]:
|
1126 |
-
# manual batching to handle OOM more gracefully
|
1127 |
-
end = min(start + batch_size, df.shape[0])
|
1128 |
-
if start == end:
|
1129 |
-
break
|
1130 |
-
dataset = Dataset.from_pandas(df.iloc[start:end, :])
|
1131 |
-
try:
|
1132 |
-
grades.extend([
|
1133 |
-
x['score'] for x in tqdm.tqdm(
|
1134 |
-
pipe(KeyPairDataset(dataset, "question", "answer"), batch_size=micro_batch)
|
1135 |
-
)
|
1136 |
-
])
|
1137 |
-
except torch.cuda.OutOfMemoryError:
|
1138 |
-
last_oom = start
|
1139 |
-
micro_batch = max(1, micro_batch // 2)
|
1140 |
-
print("OOM - retrying with micro_batch=%d" % micro_batch)
|
1141 |
-
continue
|
1142 |
-
if last_oom == start:
|
1143 |
-
micro_batch = orig_micro_batch
|
1144 |
-
print("Returning to micro_batch=%d" % micro_batch)
|
1145 |
-
assert len(grades) == end
|
1146 |
-
start = end
|
1147 |
-
with open(checkpoint, "wb") as f:
|
1148 |
-
f.write(pickle.dumps((end, grades)))
|
1149 |
-
print("%d/%d" % (end, df.shape[0]))
|
1150 |
-
df['grade_deberta'] = grades
|
1151 |
-
if os.path.exists(checkpoint):
|
1152 |
-
os.remove(checkpoint)
|
1153 |
-
return df
|
1154 |
-
|
1155 |
-
|
1156 |
-
def test_chop_by_lengths():
|
1157 |
-
file = "h2oGPT.cleaned.human_bot.shorter.parquet"
|
1158 |
-
df = pd.read_parquet(file).reset_index(drop=True)
|
1159 |
-
df = count_human_bot_lengths(df)
|
1160 |
-
df['rand'] = np.random.rand(df.shape[0])
|
1161 |
-
df['rand2'] = np.random.rand(df.shape[0])
|
1162 |
-
before_rows = df.shape[0]
|
1163 |
-
# throw away short human/bot responses with higher likelihood
|
1164 |
-
df = df[(df['len_human_mean'] > 20)] # never keep very short ones
|
1165 |
-
df = df[(df['len_human_mean'] > 30) | (df['rand'] < 0.2)]
|
1166 |
-
df = df[(df['len_human_mean'] > 50) | (df['rand'] < 0.5)]
|
1167 |
-
df = df[(df['len_human_max'] < 10000)] # drop super long (basically only human) ones
|
1168 |
-
df = df[(df['len_bot_mean'] > 20)] # never keep very short ones
|
1169 |
-
df = df[(df['len_bot_mean'] > 30) | (df['rand2'] < 0.2)]
|
1170 |
-
df = df[(df['len_bot_mean'] > 50) | (df['rand2'] < 0.5)]
|
1171 |
-
df = df[(df['len_bot_max'] < 10000)] # drop super long (only bot) ones
|
1172 |
-
assert df['text'].apply(lambda x: len(x)).max() < 20000
|
1173 |
-
df = df.drop(['rand', 'rand2'], axis=1)
|
1174 |
-
after_rows = df.shape[0]
|
1175 |
-
print("Chopped off %d out of %d rows due to length" % (before_rows - after_rows, before_rows))
|
1176 |
-
print(df.describe())
|
1177 |
-
df.to_parquet('h2oGPT.cleaned.chopped.human_bot.shorter.parquet', index=False)
|
1178 |
-
|
1179 |
-
|
1180 |
-
def count_human_bot_lengths(df, human=None, bot=None):
|
1181 |
-
import re
|
1182 |
-
len_human_min = []
|
1183 |
-
len_human_max = []
|
1184 |
-
len_human_mean = []
|
1185 |
-
len_bot_min = []
|
1186 |
-
len_bot_max = []
|
1187 |
-
len_bot_mean = []
|
1188 |
-
human = human or '<human>:'
|
1189 |
-
bot = bot or '<bot>:'
|
1190 |
-
for is_human in [True, False]:
|
1191 |
-
what = human if is_human else bot
|
1192 |
-
other = human if not is_human else bot
|
1193 |
-
for i in range(df.shape[0]):
|
1194 |
-
text = df.loc[i, 'text']
|
1195 |
-
assert isinstance(text, str)
|
1196 |
-
starts = [m.start() for m in re.finditer(what, text)]
|
1197 |
-
if len(starts) == 1:
|
1198 |
-
starts = [starts[0], len(text)] # always go into for loop below
|
1199 |
-
assert len(text)
|
1200 |
-
list_what = []
|
1201 |
-
for ii in range(len(starts) - 1):
|
1202 |
-
interaction = text[starts[ii]: starts[ii + 1]]
|
1203 |
-
if other in interaction:
|
1204 |
-
interaction = interaction[:interaction.find(other)]
|
1205 |
-
interaction.strip()
|
1206 |
-
list_what.append(interaction)
|
1207 |
-
if not list_what:
|
1208 |
-
list_what = [''] # handle corrupted data, very rare, leads to sizes 0
|
1209 |
-
if is_human:
|
1210 |
-
len_human_min.append(min([len(x) for x in list_what]))
|
1211 |
-
len_human_max.append(max([len(x) for x in list_what]))
|
1212 |
-
len_human_mean.append(np.mean([len(x) for x in list_what]))
|
1213 |
-
else:
|
1214 |
-
len_bot_min.append(min([len(x) for x in list_what]))
|
1215 |
-
len_bot_max.append(max([len(x) for x in list_what]))
|
1216 |
-
len_bot_mean.append(np.mean([len(x) for x in list_what]))
|
1217 |
-
df['len_human_min'] = len_human_min
|
1218 |
-
df['len_human_max'] = len_human_max
|
1219 |
-
df['len_human_mean'] = len_human_mean
|
1220 |
-
df['len_bot_min'] = len_bot_min
|
1221 |
-
df['len_bot_max'] = len_bot_max
|
1222 |
-
df['len_bot_mean'] = len_bot_mean
|
1223 |
-
np.random.seed(1234)
|
1224 |
-
pd.set_option('display.max_columns', None)
|
1225 |
-
print("Before chopping")
|
1226 |
-
print(df.describe())
|
1227 |
-
return df
|
1228 |
-
|
1229 |
-
|
1230 |
-
def test_grade():
|
1231 |
-
df = None
|
1232 |
-
|
1233 |
-
file = "h2oGPT.cleaned.chopped.human_bot.shorter.parquet"
|
1234 |
-
output_file = "h2oGPT.cleaned.graded1.human_bot.shorter.parquet"
|
1235 |
-
if not os.path.exists(output_file):
|
1236 |
-
if df is None:
|
1237 |
-
df = pd.read_parquet(file).reset_index(drop=True)
|
1238 |
-
df = add_textstat_grade(df)
|
1239 |
-
min_grade = 10
|
1240 |
-
max_grade = 25
|
1241 |
-
df = df[df['flesch_grade'] >= min_grade]
|
1242 |
-
df = df[df['flesch_grade'] <= max_grade]
|
1243 |
-
print("After Flesch grade")
|
1244 |
-
print(df.describe())
|
1245 |
-
df.to_parquet(output_file, index=False)
|
1246 |
-
|
1247 |
-
file = output_file
|
1248 |
-
output_file = "h2oGPT.cleaned.graded2.human_bot.shorter.parquet"
|
1249 |
-
if not os.path.exists(output_file):
|
1250 |
-
# slower than alt-profanity, do last, but do before deberta grading, since that's slower
|
1251 |
-
if df is None:
|
1252 |
-
df = pd.read_parquet(file).reset_index(drop=True)
|
1253 |
-
df = add_better_profanity_flag(df)
|
1254 |
-
before_rows = df.shape[0]
|
1255 |
-
df = df[df['better_profanity'] == 0]
|
1256 |
-
df = df.drop(['better_profanity'], axis=1)
|
1257 |
-
after_rows = df.shape[0]
|
1258 |
-
print("Dropped %d rows out of %d due to better_profanity" % (before_rows - after_rows, before_rows))
|
1259 |
-
print(df.describe())
|
1260 |
-
df.to_parquet(output_file, index=False)
|
1261 |
-
|
1262 |
-
file = output_file
|
1263 |
-
output_file = 'h2oGPT.cleaned.graded3.human_bot.shorter.parquet'
|
1264 |
-
if not os.path.exists(output_file):
|
1265 |
-
if df is None:
|
1266 |
-
df = pd.read_parquet(file).reset_index(drop=True)
|
1267 |
-
df = add_deberta_grade(df)
|
1268 |
-
min_grade = 0.3
|
1269 |
-
max_grade = np.inf
|
1270 |
-
before_rows = df.shape[0]
|
1271 |
-
df = df[df['grade_deberta'] >= min_grade]
|
1272 |
-
df = df[df['grade_deberta'] <= max_grade]
|
1273 |
-
after_rows = df.shape[0]
|
1274 |
-
print("Dropped %d rows out of %d due to deberta grade" % (before_rows - after_rows, before_rows))
|
1275 |
-
print("After DeBERTa grade")
|
1276 |
-
print(df.describe())
|
1277 |
-
df.to_parquet(output_file, index=False)
|
1278 |
-
|
1279 |
-
file = output_file
|
1280 |
-
output_file = 'h2oGPT.cleaned.graded.human_bot.shorter.parquet'
|
1281 |
-
if df is None:
|
1282 |
-
df = pd.read_parquet(file).reset_index(drop=True)
|
1283 |
-
df.to_parquet(output_file, index=False)
|
1284 |
-
|
1285 |
-
|
1286 |
-
@pytest.mark.parametrize(
|
1287 |
-
"fixup_personality, only_personality, deberta_grading",
|
1288 |
-
[
|
1289 |
-
[False, False, False],
|
1290 |
-
[True, True, False],
|
1291 |
-
[True, False, False],
|
1292 |
-
[True, False, True],
|
1293 |
-
]
|
1294 |
-
)
|
1295 |
-
def test_add_open_assistant(fixup_personality, only_personality, deberta_grading, save_json=True):
|
1296 |
-
"""
|
1297 |
-
Flatten tree structure into one row per path from root to leaf
|
1298 |
-
Also turn into human_bot prompting format:
|
1299 |
-
<human>: question\n<bot>: answer <human>: question2\n<bot>: answer2 Etc.
|
1300 |
-
Also saves a .json locally as side-effect
|
1301 |
-
returns list of dicts, containing intput, prompt_type and source
|
1302 |
-
"""
|
1303 |
-
from datasets import load_dataset
|
1304 |
-
data_file = "OpenAssistant/oasst1"
|
1305 |
-
ds = load_dataset(data_file)
|
1306 |
-
df = pd.concat([ds['train'].to_pandas(), ds['validation'].to_pandas()], axis=0)
|
1307 |
-
rows = {}
|
1308 |
-
message_ids = df['message_id'].values.tolist()
|
1309 |
-
message_tree_ids = df['message_tree_id'].values.tolist()
|
1310 |
-
parent_ids = df['parent_id'].values.tolist()
|
1311 |
-
texts = df['text'].values.tolist()
|
1312 |
-
roles = df['role'].values.tolist()
|
1313 |
-
|
1314 |
-
for i in range(df.shape[0]):
|
1315 |
-
# collect all trees
|
1316 |
-
message_id = message_ids[i]
|
1317 |
-
message_tree_id = message_tree_ids[i]
|
1318 |
-
parent_id = parent_ids[i]
|
1319 |
-
text = texts[i]
|
1320 |
-
if fixup_personality:
|
1321 |
-
text = text.replace("Open Assistant", "h2oGPT")
|
1322 |
-
text = text.replace("Open-Assistant", "h2oGPT")
|
1323 |
-
text = text.replace("open-assistant", "h2oGPT")
|
1324 |
-
text = text.replace("OpenAssistant", "h2oGPT")
|
1325 |
-
text = text.replace("open assistant", "h2oGPT")
|
1326 |
-
text = text.replace("Open Assistand", "h2oGPT")
|
1327 |
-
text = text.replace("Open Assitant", "h2oGPT")
|
1328 |
-
text = text.replace("Open Assistent", "h2oGPT")
|
1329 |
-
text = text.replace("Open Assisstant", "h2oGPT")
|
1330 |
-
text = text.replace("Open Assitent", "h2oGPT")
|
1331 |
-
text = text.replace("Open Assitiant", "h2oGPT")
|
1332 |
-
text = text.replace("Open Assistiant", "h2oGPT")
|
1333 |
-
text = text.replace("Open Assitan ", "h2oGPT ")
|
1334 |
-
text = text.replace("Open Assistan ", "h2oGPT ")
|
1335 |
-
text = text.replace("Open Asistant", "h2oGPT")
|
1336 |
-
text = text.replace("Open Assiant", "h2oGPT")
|
1337 |
-
text = text.replace("Assistant", "h2oGPT")
|
1338 |
-
text = text.replace("LAION AI", "H2O.ai")
|
1339 |
-
text = text.replace("LAION-AI", "H2O.ai")
|
1340 |
-
text = text.replace("LAION,", "H2O.ai,")
|
1341 |
-
text = text.replace("LAION.ai", "H2O.ai")
|
1342 |
-
text = text.replace("LAION.", "H2O.ai.")
|
1343 |
-
text = text.replace("LAION", "H2O.ai")
|
1344 |
-
|
1345 |
-
role = roles[i]
|
1346 |
-
new_data = ('<human>: ' if role == 'prompter' else '<bot>: ') + text
|
1347 |
-
entry = dict(message_id=message_id, parent_id=parent_id, text=new_data)
|
1348 |
-
if message_tree_id not in rows:
|
1349 |
-
rows[message_tree_id] = [entry]
|
1350 |
-
else:
|
1351 |
-
rows[message_tree_id].append(entry)
|
1352 |
-
|
1353 |
-
all_rows = []
|
1354 |
-
|
1355 |
-
for node_id in rows:
|
1356 |
-
# order responses in tree, based on message/parent relationship
|
1357 |
-
conversations = []
|
1358 |
-
|
1359 |
-
list_msgs = rows[node_id]
|
1360 |
-
# find start
|
1361 |
-
while len(list_msgs):
|
1362 |
-
for i, leaf in enumerate(list_msgs):
|
1363 |
-
found = False
|
1364 |
-
parent_id = leaf['parent_id']
|
1365 |
-
if parent_id is None:
|
1366 |
-
# conversation starter
|
1367 |
-
conversations.append(leaf)
|
1368 |
-
found = True
|
1369 |
-
else:
|
1370 |
-
for conv in conversations:
|
1371 |
-
# find all conversations to add my message to
|
1372 |
-
if parent_id in conv['message_id'] and parent_id != conv['message_id'][-len(parent_id):]:
|
1373 |
-
# my message doesn't follow conversation
|
1374 |
-
continue
|
1375 |
-
if parent_id == conv['message_id'][-len(parent_id):]:
|
1376 |
-
# my message follows conversation, but fork first, so another follow-on message can do same
|
1377 |
-
conversations.append(conv.copy())
|
1378 |
-
conv['text'] += f"""
|
1379 |
-
{leaf['text']}
|
1380 |
-
"""
|
1381 |
-
conv['message_id'] += leaf['message_id']
|
1382 |
-
found = True
|
1383 |
-
break
|
1384 |
-
if found:
|
1385 |
-
# my content was used, so nuke from list
|
1386 |
-
del list_msgs[i]
|
1387 |
-
break
|
1388 |
-
|
1389 |
-
# now reduce down to final conversations, find the longest chains of message ids
|
1390 |
-
for i, conv in enumerate(conversations):
|
1391 |
-
for j, conv2 in enumerate(conversations):
|
1392 |
-
if i == j:
|
1393 |
-
continue
|
1394 |
-
if conv['message_id'] and conv2['message_id']:
|
1395 |
-
assert conv['message_id'] != conv2['message_id']
|
1396 |
-
# delete the shorter conversation, if one contains the other
|
1397 |
-
if conv['message_id'] in conv2['message_id']:
|
1398 |
-
conv['message_id'] = None
|
1399 |
-
if conv2['message_id'] in conv['message_id']:
|
1400 |
-
conv2['message_id'] = None
|
1401 |
-
conversations = [c for c in conversations if c['message_id']]
|
1402 |
-
if only_personality:
|
1403 |
-
all_rows.extend(
|
1404 |
-
[dict(input=c['text'] + "\n<human>:", prompt_type='plain', source=data_file) for c in conversations if
|
1405 |
-
'h2oGPT' in c['text']])
|
1406 |
-
else:
|
1407 |
-
all_rows.extend(
|
1408 |
-
[dict(input=c['text'] + "\n<human>:", prompt_type='plain', source=data_file) for c in conversations if
|
1409 |
-
"What is H2O.ai" not in c['text']])
|
1410 |
-
unhelpful = get_unhelpful_list()
|
1411 |
-
all_rows = [x for x in all_rows if not any(u in x['input'] for u in unhelpful)]
|
1412 |
-
personality = create_personality_data()
|
1413 |
-
all_rows.extend(personality * 10)
|
1414 |
-
np.random.seed(123)
|
1415 |
-
np.random.shuffle(all_rows)
|
1416 |
-
print(len(all_rows))
|
1417 |
-
if deberta_grading:
|
1418 |
-
df = pd.DataFrame(all_rows)
|
1419 |
-
df = df.rename(columns={'input': 'text'})
|
1420 |
-
df = add_deberta_grade(df)
|
1421 |
-
df = df.rename(columns={'text': 'input'})
|
1422 |
-
drop = True
|
1423 |
-
if drop:
|
1424 |
-
min_grade = 0.3
|
1425 |
-
max_grade = np.inf
|
1426 |
-
before_rows = df.shape[0]
|
1427 |
-
df = df[df['grade_deberta'] >= min_grade]
|
1428 |
-
df = df[df['grade_deberta'] <= max_grade]
|
1429 |
-
after_rows = df.shape[0]
|
1430 |
-
print("Dropped %d rows out of %d due to deberta grade" % (before_rows - after_rows, before_rows))
|
1431 |
-
print("After DeBERTa grade")
|
1432 |
-
print(df.describe())
|
1433 |
-
all_rows = []
|
1434 |
-
for i in range(df.shape[0]):
|
1435 |
-
all_rows.append(
|
1436 |
-
dict(
|
1437 |
-
input=df['input'].iloc[i],
|
1438 |
-
source=df['source'].iloc[i],
|
1439 |
-
prompt_type=df['prompt_type'].iloc[i],
|
1440 |
-
grade_deberta=df['grade_deberta'].iloc[i],
|
1441 |
-
)
|
1442 |
-
)
|
1443 |
-
if save_json:
|
1444 |
-
data_file = data_file + \
|
1445 |
-
("_h2ogpt" if fixup_personality else "") + \
|
1446 |
-
("_only" if only_personality else "") + \
|
1447 |
-
("_graded" if deberta_grading else "")
|
1448 |
-
for i in range(len(all_rows)):
|
1449 |
-
all_rows[i]['id'] = i
|
1450 |
-
with open(data_file.lower().replace("/", "_") + ".json", "w") as f:
|
1451 |
-
f.write(json.dumps(all_rows, indent=2))
|
1452 |
-
return all_rows
|
1453 |
-
|
1454 |
-
|
1455 |
-
def test_finalize_to_json():
|
1456 |
-
df = pd.read_parquet('h2oGPT.cleaned.graded.human_bot.shorter.parquet')
|
1457 |
-
df = df.rename(columns={'text': 'input'})
|
1458 |
-
|
1459 |
-
print("Number of high-quality human_bot interactions: %s" % df.shape[0], flush=True)
|
1460 |
-
|
1461 |
-
print("Adding open assistant data")
|
1462 |
-
with open("openassistant_oasst1_h2ogpt_graded.json") as f:
|
1463 |
-
open_assistant = json.loads(f.read())
|
1464 |
-
df = pd.concat([df, pd.DataFrame(open_assistant)], axis=0)
|
1465 |
-
|
1466 |
-
def final_clean(df):
|
1467 |
-
from better_profanity import profanity
|
1468 |
-
profanity.load_censor_words_from_file("data/censor_words.txt")
|
1469 |
-
df['profanity'] = parallel_apply(
|
1470 |
-
df['input'],
|
1471 |
-
lambda x: profanity.contains_profanity(x),
|
1472 |
-
n_jobs=-1,
|
1473 |
-
)
|
1474 |
-
return df[(df['profanity'] == 0)].reset_index(drop=True)
|
1475 |
-
|
1476 |
-
print("Before cleaning: Number of final high-quality human_bot interactions: %s" % df.shape[0], flush=True)
|
1477 |
-
df = final_clean(df)
|
1478 |
-
print("After cleaning: Number of final high-quality human_bot interactions: %s" % df.shape[0], flush=True)
|
1479 |
-
print(df.describe())
|
1480 |
-
print(df.shape)
|
1481 |
-
row_list = []
|
1482 |
-
for i in range(df.shape[0]):
|
1483 |
-
row_list.append(
|
1484 |
-
dict(
|
1485 |
-
input=df.loc[i, 'input'],
|
1486 |
-
source=df.loc[i, 'source'],
|
1487 |
-
prompt_type='plain',
|
1488 |
-
)
|
1489 |
-
)
|
1490 |
-
np.random.seed(1234)
|
1491 |
-
np.random.shuffle(row_list)
|
1492 |
-
unhelpful = get_unhelpful_list()
|
1493 |
-
row_list = [x for x in row_list if not any(u in x['input'] for u in unhelpful)]
|
1494 |
-
for i in range(len(row_list)):
|
1495 |
-
row_list[i]['id'] = i
|
1496 |
-
row_list[i]['input'] = row_list[i]['input'].replace(" <bot>:", "\n<bot>:")
|
1497 |
-
with open('h2ogpt-oig-oasst1-instruct-cleaned-v3.json', "w") as f:
|
1498 |
-
f.write(json.dumps(row_list, indent=2))
|
1499 |
-
|
1500 |
-
|
1501 |
-
def create_personality_data():
|
1502 |
-
questions = [
|
1503 |
-
"What's your name?",
|
1504 |
-
"What is your name?",
|
1505 |
-
"What are you?",
|
1506 |
-
"Who are you?",
|
1507 |
-
"Do you have a name?",
|
1508 |
-
"Who trained you?",
|
1509 |
-
"Who created you?",
|
1510 |
-
"Who made you?",
|
1511 |
-
]
|
1512 |
-
answers = [
|
1513 |
-
"I'm h2oGPT, a large language model by H2O.ai.",
|
1514 |
-
"I'm h2oGPT, a large language model by H2O.ai, the visionary leader in democratizing AI.",
|
1515 |
-
"My name is h2oGPT. I'm a large language model by H2O.ai, the visionary leader in democratizing AI.",
|
1516 |
-
"My name is h2oGPT. I'm a large language model trained by H2O.ai.",
|
1517 |
-
"Hi! I'm h2oGPT, a large language model by H2O.ai.",
|
1518 |
-
"Hi! I'm h2oGPT, a large language model by H2O.ai, the visionary leader in democratizing AI.",
|
1519 |
-
]
|
1520 |
-
help = [
|
1521 |
-
"",
|
1522 |
-
" How can I help you?",
|
1523 |
-
" How may I assist you?",
|
1524 |
-
" Nice to meet you.",
|
1525 |
-
]
|
1526 |
-
import itertools
|
1527 |
-
rows = []
|
1528 |
-
for pair in itertools.product(questions, answers, help):
|
1529 |
-
rows.append(
|
1530 |
-
dict(input=f"<human>: {pair[0]}\n<bot>: {pair[1]}{pair[2]}\n<human>:", prompt_type='plain', source="H2O.ai")
|
1531 |
-
)
|
1532 |
-
for row in [
|
1533 |
-
"<human>: What is H2O.ai?\n<bot>: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n<human>:",
|
1534 |
-
"<human>: What is h2o.ai?\n<bot>: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n<human>:",
|
1535 |
-
"<human>: What is H2O?\n<bot>: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n<human>:",
|
1536 |
-
"<human>: Who is h2o.ai?\n<bot>: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n<human>:",
|
1537 |
-
"<human>: who is h2o.ai?\n<bot>: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n<human>:",
|
1538 |
-
"<human>: who is h2o?\n<bot>: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n<human>:",
|
1539 |
-
"<human>: What is H2O.ai?\n<bot>: H2O.ai is the visionary leader in democratizing AI.\n<human>:",
|
1540 |
-
"<human>: Who is H2O.ai?\n<bot>: H2O.ai is the visionary leader in democratizing AI.\n<human>:",
|
1541 |
-
"<human>: Who is H2O?\n<bot>: H2O.ai is the visionary leader in democratizing AI.\n<human>:",
|
1542 |
-
"<human>: Who is h2o?\n<bot>: H2O.ai is the visionary leader in democratizing AI.\n<human>:",
|
1543 |
-
"<human>: who is h2o?\n<bot>: H2O.ai is the visionary leader in democratizing AI.\n<human>:",
|
1544 |
-
]:
|
1545 |
-
rows.append(dict(input=row, prompt_type='plain', source='H2O.ai'))
|
1546 |
-
print(len(rows))
|
1547 |
-
with open("h2ogpt-personality.json", "w") as f:
|
1548 |
-
f.write(json.dumps(rows, indent=2))
|
1549 |
-
return rows
|
1550 |
-
|
1551 |
-
|
1552 |
-
def test_check_stats_data():
|
1553 |
-
filename = 'h2ogpt-oig-oasst1-instruct-cleaned-v3.json'
|
1554 |
-
df = pd.read_json(filename)
|
1555 |
-
|
1556 |
-
# get word stats
|
1557 |
-
df['char_count'] = df['input'].apply(lambda x: len(x))
|
1558 |
-
import matplotlib.pyplot as plt
|
1559 |
-
plt.figure(figsize=(10, 10))
|
1560 |
-
plt.hist(df['char_count'], bins=100)
|
1561 |
-
chars_avg = np.mean(df['char_count'])
|
1562 |
-
chars_median = np.median(df['char_count'])
|
1563 |
-
plt.title("char_count avg: %s median: %s" % (chars_avg, chars_median))
|
1564 |
-
plt.savefig('chars_hist.png')
|
1565 |
-
plt.close()
|
1566 |
-
|
1567 |
-
# get tokenize stats for random sample of 1000 rows
|
1568 |
-
from finetune import generate_and_tokenize_prompt
|
1569 |
-
from loaders import get_loaders, get_tokenizer
|
1570 |
-
from functools import partial
|
1571 |
-
|
1572 |
-
llama_type = False
|
1573 |
-
tokenizer_base_model = base_model = 'h2oai/h2ogpt-oasst1-512-20b'
|
1574 |
-
model_loader, tokenizer_loader = get_loaders(model_name=base_model, reward_type=False, llama_type=llama_type)
|
1575 |
-
local_files_only = False
|
1576 |
-
resume_download = True
|
1577 |
-
use_auth_token = False
|
1578 |
-
tokenizer = get_tokenizer(tokenizer_loader, tokenizer_base_model, local_files_only, resume_download, use_auth_token)
|
1579 |
-
prompt_type = 'plain' # trained with data already in human bot form
|
1580 |
-
train_on_inputs = True
|
1581 |
-
add_eos_token = False
|
1582 |
-
cutoff_len = 512 # can choose 2048
|
1583 |
-
generate_and_tokenize_prompt_fun = partial(generate_and_tokenize_prompt, prompt_type=prompt_type,
|
1584 |
-
train_on_inputs=train_on_inputs, add_eos_token=add_eos_token,
|
1585 |
-
cutoff_len=cutoff_len, tokenizer=tokenizer)
|
1586 |
-
from datasets import load_dataset
|
1587 |
-
data = load_dataset("json", data_files={"train": filename})
|
1588 |
-
val_set_size = 0.90
|
1589 |
-
train_val = data["train"].train_test_split(
|
1590 |
-
test_size=val_set_size, shuffle=True, seed=42
|
1591 |
-
)
|
1592 |
-
train_data = train_val["train"]
|
1593 |
-
train_data = train_data.shuffle().map(generate_and_tokenize_prompt_fun, num_proc=os.cpu_count())
|
1594 |
-
|
1595 |
-
df_tokens = pd.DataFrame([len(x) for x in train_data['input_ids']], columns=['token_count'])
|
1596 |
-
|
1597 |
-
plt.figure(figsize=(10, 10))
|
1598 |
-
plt.hist(df_tokens['token_count'], bins=100)
|
1599 |
-
token_avg = np.mean(df_tokens['token_count'])
|
1600 |
-
token_median = np.median(df_tokens['token_count'])
|
1601 |
-
plt.title("token_count with cutoff=%s avg: %s median: %s" % (cutoff_len, token_avg, token_median))
|
1602 |
-
plt.savefig('token_hist_%s.png' % cutoff_len)
|
1603 |
-
plt.close()
|
1604 |
-
|
1605 |
-
|
1606 |
-
def get_unhelpful_list():
|
1607 |
-
# base versions
|
1608 |
-
unhelpful = ["I'm sorry, I didn't quite understand your question, could you please rephrase it?",
|
1609 |
-
"I'm sorry, but I don't understand your question. Could you please rephrase it?",
|
1610 |
-
"I'm sorry, I don't quite understand your question",
|
1611 |
-
"I'm sorry, I don't know",
|
1612 |
-
"I'm sorry, but I don't know",
|
1613 |
-
"I don't know anything",
|
1614 |
-
"I do not know",
|
1615 |
-
"I don't know",
|
1616 |
-
"I don't know how",
|
1617 |
-
"I do not know how",
|
1618 |
-
"Can you please explain what you mean",
|
1619 |
-
"please explain what you mean",
|
1620 |
-
"please explain",
|
1621 |
-
"I'm sorry, but I don't know how to tell a story. Can you please explain what you mean by",
|
1622 |
-
"I'm sorry but I don't understand what you mean",
|
1623 |
-
"I don't understand",
|
1624 |
-
"I don't have the ability",
|
1625 |
-
"I do not have the ability",
|
1626 |
-
"I do not have",
|
1627 |
-
"I am a language model,",
|
1628 |
-
"I am a large language model,",
|
1629 |
-
"I do not understand your question. Can you please try to make it clearer?",
|
1630 |
-
"I'm sorry, but as an AI language model",
|
1631 |
-
"I apologize, but I cannot rephrase text that I cannot understand. Your post is difficult to read and follow.",
|
1632 |
-
"I apologize, but I am not h2oGPT. I am a language model developed by H2O.ai. How may I help you?",
|
1633 |
-
"Sorry, but I am not an actual Linux shell, nor am I capable of emulating one. I am an open source chat assistant and would be glad t",
|
1634 |
-
"I apologize, but I cannot perform the task you have requested.",
|
1635 |
-
"I'm sorry, I cannot perform this task as I am an AI language model and do not have access",
|
1636 |
-
"I'm sorry, I'm not sure what you're asking for here.",
|
1637 |
-
"I'm not sure what you are asking",
|
1638 |
-
"You need to provide more context",
|
1639 |
-
]
|
1640 |
-
# reduced versions, with redundant parts, just to give context for where they came from
|
1641 |
-
unhelpful += ["sorry, I didn't quite understand your question",
|
1642 |
-
"I didn't quite understand your question",
|
1643 |
-
"I didn't understand your question",
|
1644 |
-
"I did not understand your question",
|
1645 |
-
"I did not understand the question",
|
1646 |
-
"could you please rephrase"
|
1647 |
-
"could you rephrase"
|
1648 |
-
"I do not understand your question.",
|
1649 |
-
"I do not understand the question.",
|
1650 |
-
"I do not understand that question.",
|
1651 |
-
"Can you please try to make it clearer",
|
1652 |
-
"Can you try to make it clearer",
|
1653 |
-
"sorry, but as an AI language model",
|
1654 |
-
"as an AI language model",
|
1655 |
-
"I apologize, but I cannot",
|
1656 |
-
"I cannot rephrase text",
|
1657 |
-
"I cannot understand. Your post is difficult to read and follow."
|
1658 |
-
"Your post is difficult to read and follow."
|
1659 |
-
"I apologize, but I am",
|
1660 |
-
"Sorry, but I am not ",
|
1661 |
-
"nor am I capable",
|
1662 |
-
"I am not capable of",
|
1663 |
-
"I apologize, but I cannot perform the task you have requested",
|
1664 |
-
"I cannot perform the task",
|
1665 |
-
"I cannot complete the task",
|
1666 |
-
"I'm sorry",
|
1667 |
-
"I am sorry",
|
1668 |
-
"do not have access",
|
1669 |
-
"not sure what you're asking for",
|
1670 |
-
"not sure what you are asking for",
|
1671 |
-
"not sure what is being asked",
|
1672 |
-
"I'm not sure what you are asking",
|
1673 |
-
"not sure what you are asking",
|
1674 |
-
"You need to provide more context",
|
1675 |
-
"provide more context",
|
1676 |
-
]
|
1677 |
-
unhelpful += ["As a large language model",
|
1678 |
-
"cannot provide any information",
|
1679 |
-
"As an artificial intelligence I do not have the capability",
|
1680 |
-
"As an artificial intelligence I don't have the capability",
|
1681 |
-
"As an artificial intelligence I can't",
|
1682 |
-
"As an artificial intelligence I cannot",
|
1683 |
-
"I am sorry but I do not understand",
|
1684 |
-
"Can you please explain",
|
1685 |
-
"(sorry couldn't resist)",
|
1686 |
-
"(sorry could not resist)",
|
1687 |
-
" :)",
|
1688 |
-
" ;)",
|
1689 |
-
" :-)",
|
1690 |
-
" ;-)",
|
1691 |
-
" lol ",
|
1692 |
-
"Thanks so much!!!",
|
1693 |
-
"Thank You :)!!!",
|
1694 |
-
"Please try not to repeat",
|
1695 |
-
"I am an AI language model",
|
1696 |
-
"I'm a AI assistant that",
|
1697 |
-
"I'm an AI assistant that",
|
1698 |
-
"I am an AI assistant that",
|
1699 |
-
"etc.",
|
1700 |
-
"etc.etc.",
|
1701 |
-
"etc. etc.",
|
1702 |
-
"etc etc",
|
1703 |
-
]
|
1704 |
-
return unhelpful
|
1705 |
-
|
1706 |
-
|
1707 |
-
def test_check_unhelpful():
|
1708 |
-
# file = '/home/jon/Downloads/openassistant_oasst1_h2ogpt_graded.json'
|
1709 |
-
file = '/home/jon/Downloads/openassistant_oasst1_h2ogpt_grades.json'
|
1710 |
-
# file = 'h2ogpt-oig-oasst1-instruct-cleaned-v2.json'
|
1711 |
-
|
1712 |
-
unhelpful = get_unhelpful_list()
|
1713 |
-
# data = json.load(open(file, 'rt'))
|
1714 |
-
df = pd.read_json(file)
|
1715 |
-
|
1716 |
-
use_reward_score_threshold = False
|
1717 |
-
use_bleu_threshold = False
|
1718 |
-
use_sentence_sim = True
|
1719 |
-
|
1720 |
-
from sacrebleu.metrics import BLEU
|
1721 |
-
bleu = BLEU()
|
1722 |
-
from nltk.translate.bleu_score import sentence_bleu
|
1723 |
-
|
1724 |
-
def get_bleu(actual, expected_list):
|
1725 |
-
# return bleu.sentence_score(actual, expected_list).score
|
1726 |
-
return sentence_bleu(expected_list, actual)
|
1727 |
-
|
1728 |
-
threshold = 0.0
|
1729 |
-
if use_reward_score_threshold:
|
1730 |
-
df = df[df['grade_deberta'] > threshold]
|
1731 |
-
|
1732 |
-
# back to as if original json load
|
1733 |
-
data = df.to_dict(orient='records')
|
1734 |
-
bads = {}
|
1735 |
-
string_all = str(data)
|
1736 |
-
for sub in unhelpful:
|
1737 |
-
bads[sub] = string_all.count(sub)
|
1738 |
-
bads = {k: v for k, v in bads.items() if v > 0}
|
1739 |
-
import pprint
|
1740 |
-
pp = pprint.PrettyPrinter(indent=4)
|
1741 |
-
pp.pprint(bads)
|
1742 |
-
|
1743 |
-
total_bads = sum(list(bads.values()))
|
1744 |
-
print('total_bads: %s' % total_bads, flush=True)
|
1745 |
-
|
1746 |
-
# check just bot
|
1747 |
-
import re
|
1748 |
-
convs = [[x.strip() for x in re.split(r'%s|%s' % (human, bot), y['input']) if x.strip()] for y in data]
|
1749 |
-
humans = [[x for i, x in enumerate(y) if i % 2 == 0] for y in convs]
|
1750 |
-
bots = [[x for i, x in enumerate(y) if i % 2 == 1] for y in convs]
|
1751 |
-
|
1752 |
-
# FIXME: apply back to json etc., just see for now
|
1753 |
-
bleu_threshold = 0.9
|
1754 |
-
if use_bleu_threshold:
|
1755 |
-
bots = [[x for x in y if get_bleu(x, unhelpful) < bleu_threshold] for y in tqdm(bots)]
|
1756 |
-
|
1757 |
-
cosine_sim_threshold = 0.8
|
1758 |
-
if use_sentence_sim:
|
1759 |
-
# pip install sentence_transformers-2.2.2
|
1760 |
-
from sentence_transformers import SentenceTransformer
|
1761 |
-
# sent_model = 'bert-base-nli-mean-tokens'
|
1762 |
-
# sent_model = 'nli-distilroberta-base-v2'
|
1763 |
-
sent_model = 'all-MiniLM-L6-v2'
|
1764 |
-
model = SentenceTransformer(sent_model)
|
1765 |
-
sentence_embeddings = model.encode(unhelpful)
|
1766 |
-
from sklearn.metrics.pairwise import cosine_similarity
|
1767 |
-
bots = [x for x in tqdm(bots) if
|
1768 |
-
np.max(cosine_similarity(model.encode(x), sentence_embeddings)) < cosine_sim_threshold]
|
1769 |
-
|
1770 |
-
bads_bots = {}
|
1771 |
-
string_all = str(bots)
|
1772 |
-
for sub in unhelpful:
|
1773 |
-
bads_bots[sub] = string_all.count(sub)
|
1774 |
-
bads_bots = {k: v for k, v in bads_bots.items() if v > 0}
|
1775 |
-
import pprint
|
1776 |
-
pp = pprint.PrettyPrinter(indent=4)
|
1777 |
-
pp.pprint(bads_bots)
|
1778 |
-
|
1779 |
-
total_bads_bots = sum(list(bads_bots.values()))
|
1780 |
-
print('threshold: %g use_bleu_threshold: %g total_bads_bots: %s total_bots: %s total_humans: %s' % (
|
1781 |
-
threshold, use_bleu_threshold, total_bads_bots, len(bots), len(humans)), flush=True)
|
1782 |
-
|
1783 |
-
# assert len(bads) == 0, bads
|
1784 |
-
assert len(bads_bots) == 0, bads_bots
|
1785 |
-
|
1786 |
-
|
1787 |
-
def test_fortune2000_personalized():
|
1788 |
-
row_list = []
|
1789 |
-
import glob
|
1790 |
-
if not os.path.isdir("wikitext"):
|
1791 |
-
raise RuntimeError("download https://github.com/h2oai/h2ogpt/files/11423008/wikitext.zip and unzip")
|
1792 |
-
for file in glob.glob("wikitext/*.txt"):
|
1793 |
-
with open(file, "r") as f:
|
1794 |
-
blob = f.read()
|
1795 |
-
N = 512 * 4
|
1796 |
-
row_list.extend([{'input': s, 'prompt_type': 'plain', 'source': "%s" % os.path.basename(file)}
|
1797 |
-
for s in get_sentences(blob, N) if s])
|
1798 |
-
personality = create_personality_data()
|
1799 |
-
import copy
|
1800 |
-
for i in range(10):
|
1801 |
-
row_list.extend(copy.deepcopy(personality))
|
1802 |
-
np.random.seed(123)
|
1803 |
-
np.random.shuffle(row_list)
|
1804 |
-
for i in range(len(row_list)):
|
1805 |
-
row_list[i]['id'] = i
|
1806 |
-
for i in range(len(row_list)):
|
1807 |
-
assert row_list[i]['id'] == i
|
1808 |
-
with open("h2ogpt-fortune2000-personalized.json", "w") as ff:
|
1809 |
-
ff.write(json.dumps(row_list, indent=2))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
enums.py
DELETED
@@ -1,120 +0,0 @@
|
|
1 |
-
from enum import Enum
|
2 |
-
|
3 |
-
|
4 |
-
class PromptType(Enum):
|
5 |
-
custom = -1
|
6 |
-
plain = 0
|
7 |
-
instruct = 1
|
8 |
-
quality = 2
|
9 |
-
human_bot = 3
|
10 |
-
dai_faq = 4
|
11 |
-
summarize = 5
|
12 |
-
simple_instruct = 6
|
13 |
-
instruct_vicuna = 7
|
14 |
-
instruct_with_end = 8
|
15 |
-
human_bot_orig = 9
|
16 |
-
prompt_answer = 10
|
17 |
-
open_assistant = 11
|
18 |
-
wizard_lm = 12
|
19 |
-
wizard_mega = 13
|
20 |
-
instruct_vicuna2 = 14
|
21 |
-
instruct_vicuna3 = 15
|
22 |
-
wizard2 = 16
|
23 |
-
wizard3 = 17
|
24 |
-
instruct_simple = 18
|
25 |
-
wizard_vicuna = 19
|
26 |
-
openai = 20
|
27 |
-
openai_chat = 21
|
28 |
-
gptj = 22
|
29 |
-
prompt_answer_openllama = 23
|
30 |
-
vicuna11 = 24
|
31 |
-
mptinstruct = 25
|
32 |
-
mptchat = 26
|
33 |
-
falcon = 27
|
34 |
-
guanaco = 28
|
35 |
-
llama2 = 29
|
36 |
-
|
37 |
-
|
38 |
-
class DocumentSubset(Enum):
|
39 |
-
Relevant = 0
|
40 |
-
RelSources = 1
|
41 |
-
TopKSources = 2
|
42 |
-
|
43 |
-
|
44 |
-
non_query_commands = [
|
45 |
-
DocumentSubset.RelSources.name,
|
46 |
-
DocumentSubset.TopKSources.name
|
47 |
-
]
|
48 |
-
|
49 |
-
|
50 |
-
class DocumentChoice(Enum):
|
51 |
-
ALL = 'All'
|
52 |
-
|
53 |
-
|
54 |
-
class LangChainMode(Enum):
|
55 |
-
"""LangChain mode"""
|
56 |
-
|
57 |
-
DISABLED = "Disabled"
|
58 |
-
LLM = "LLM"
|
59 |
-
ALL = "All"
|
60 |
-
WIKI = "wiki"
|
61 |
-
WIKI_FULL = "wiki_full"
|
62 |
-
USER_DATA = "UserData"
|
63 |
-
MY_DATA = "MyData"
|
64 |
-
GITHUB_H2OGPT = "github h2oGPT"
|
65 |
-
H2O_DAI_DOCS = "DriverlessAI docs"
|
66 |
-
|
67 |
-
|
68 |
-
# modes should not be removed from visible list or added by name
|
69 |
-
langchain_modes_intrinsic = [LangChainMode.DISABLED.value,
|
70 |
-
LangChainMode.LLM.value,
|
71 |
-
LangChainMode.MY_DATA.value]
|
72 |
-
|
73 |
-
|
74 |
-
class LangChainAction(Enum):
|
75 |
-
"""LangChain action"""
|
76 |
-
|
77 |
-
QUERY = "Query"
|
78 |
-
# WIP:
|
79 |
-
# SUMMARIZE_MAP = "Summarize_map_reduce"
|
80 |
-
SUMMARIZE_MAP = "Summarize"
|
81 |
-
SUMMARIZE_ALL = "Summarize_all"
|
82 |
-
SUMMARIZE_REFINE = "Summarize_refine"
|
83 |
-
|
84 |
-
|
85 |
-
class LangChainAgent(Enum):
|
86 |
-
"""LangChain agents"""
|
87 |
-
|
88 |
-
SEARCH = "Search"
|
89 |
-
# CSV = "csv" # WIP
|
90 |
-
|
91 |
-
|
92 |
-
no_server_str = no_lora_str = no_model_str = '[None/Remove]'
|
93 |
-
|
94 |
-
# from site-packages/langchain/llms/openai.py
|
95 |
-
# but needed since ChatOpenAI doesn't have this information
|
96 |
-
model_token_mapping = {
|
97 |
-
"gpt-4": 8192,
|
98 |
-
"gpt-4-0314": 8192,
|
99 |
-
"gpt-4-32k": 32768,
|
100 |
-
"gpt-4-32k-0314": 32768,
|
101 |
-
"gpt-3.5-turbo": 4096,
|
102 |
-
"gpt-3.5-turbo-16k": 16 * 1024,
|
103 |
-
"gpt-3.5-turbo-0301": 4096,
|
104 |
-
"text-ada-001": 2049,
|
105 |
-
"ada": 2049,
|
106 |
-
"text-babbage-001": 2040,
|
107 |
-
"babbage": 2049,
|
108 |
-
"text-curie-001": 2049,
|
109 |
-
"curie": 2049,
|
110 |
-
"davinci": 2049,
|
111 |
-
"text-davinci-003": 4097,
|
112 |
-
"text-davinci-002": 4097,
|
113 |
-
"code-davinci-002": 8001,
|
114 |
-
"code-davinci-001": 8001,
|
115 |
-
"code-cushman-002": 2048,
|
116 |
-
"code-cushman-001": 2048,
|
117 |
-
}
|
118 |
-
|
119 |
-
source_prefix = "Sources [Score | Link]:"
|
120 |
-
source_postfix = "End Sources<p>"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
evaluate_params.py
DELETED
@@ -1,52 +0,0 @@
|
|
1 |
-
input_args_list = ['model_state', 'my_db_state', 'selection_docs_state']
|
2 |
-
|
3 |
-
|
4 |
-
no_default_param_names = [
|
5 |
-
'instruction',
|
6 |
-
'iinput',
|
7 |
-
'context',
|
8 |
-
'instruction_nochat',
|
9 |
-
'iinput_nochat',
|
10 |
-
]
|
11 |
-
|
12 |
-
gen_hyper = ['temperature',
|
13 |
-
'top_p',
|
14 |
-
'top_k',
|
15 |
-
'num_beams',
|
16 |
-
'max_new_tokens',
|
17 |
-
'min_new_tokens',
|
18 |
-
'early_stopping',
|
19 |
-
'max_time',
|
20 |
-
'repetition_penalty',
|
21 |
-
'num_return_sequences',
|
22 |
-
'do_sample',
|
23 |
-
]
|
24 |
-
|
25 |
-
eval_func_param_names = ['instruction',
|
26 |
-
'iinput',
|
27 |
-
'context',
|
28 |
-
'stream_output',
|
29 |
-
'prompt_type',
|
30 |
-
'prompt_dict'] + \
|
31 |
-
gen_hyper + \
|
32 |
-
['chat',
|
33 |
-
'instruction_nochat',
|
34 |
-
'iinput_nochat',
|
35 |
-
'langchain_mode',
|
36 |
-
'add_chat_history_to_context',
|
37 |
-
'langchain_action',
|
38 |
-
'langchain_agents',
|
39 |
-
'top_k_docs',
|
40 |
-
'chunk',
|
41 |
-
'chunk_size',
|
42 |
-
'document_subset',
|
43 |
-
'document_choice',
|
44 |
-
]
|
45 |
-
|
46 |
-
# form evaluate defaults for submit_nochat_api
|
47 |
-
eval_func_param_names_defaults = eval_func_param_names.copy()
|
48 |
-
for k in no_default_param_names:
|
49 |
-
if k in eval_func_param_names_defaults:
|
50 |
-
eval_func_param_names_defaults.remove(k)
|
51 |
-
|
52 |
-
eval_extra_columns = ['prompt', 'response', 'score']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
gen.py
DELETED
The diff for this file is too large to render.
See raw diff
|
|
generate.py
DELETED
@@ -1,16 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import sys
|
3 |
-
|
4 |
-
if os.path.dirname(os.path.abspath(__file__)) not in sys.path:
|
5 |
-
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
6 |
-
|
7 |
-
from src.gen import main
|
8 |
-
from src.utils import H2O_Fire
|
9 |
-
|
10 |
-
|
11 |
-
def entrypoint_main():
|
12 |
-
H2O_Fire(main)
|
13 |
-
|
14 |
-
|
15 |
-
if __name__ == "__main__":
|
16 |
-
entrypoint_main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
gpt4all_llm.py
DELETED
@@ -1,316 +0,0 @@
|
|
1 |
-
import inspect
|
2 |
-
import os
|
3 |
-
from functools import partial
|
4 |
-
from typing import Dict, Any, Optional, List
|
5 |
-
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
6 |
-
from pydantic import root_validator
|
7 |
-
from langchain.llms import gpt4all
|
8 |
-
from dotenv import dotenv_values
|
9 |
-
|
10 |
-
from utils import FakeTokenizer
|
11 |
-
|
12 |
-
|
13 |
-
def get_model_tokenizer_gpt4all(base_model, **kwargs):
|
14 |
-
# defaults (some of these are generation parameters, so need to be passed in at generation time)
|
15 |
-
model_kwargs = dict(n_threads=os.cpu_count() // 2,
|
16 |
-
temp=kwargs.get('temperature', 0.2),
|
17 |
-
top_p=kwargs.get('top_p', 0.75),
|
18 |
-
top_k=kwargs.get('top_k', 40),
|
19 |
-
n_ctx=2048 - 256)
|
20 |
-
env_gpt4all_file = ".env_gpt4all"
|
21 |
-
model_kwargs.update(dotenv_values(env_gpt4all_file))
|
22 |
-
# make int or float if can to satisfy types for class
|
23 |
-
for k, v in model_kwargs.items():
|
24 |
-
try:
|
25 |
-
if float(v) == int(v):
|
26 |
-
model_kwargs[k] = int(v)
|
27 |
-
else:
|
28 |
-
model_kwargs[k] = float(v)
|
29 |
-
except:
|
30 |
-
pass
|
31 |
-
|
32 |
-
if base_model == "llama":
|
33 |
-
if 'model_path_llama' not in model_kwargs:
|
34 |
-
raise ValueError("No model_path_llama in %s" % env_gpt4all_file)
|
35 |
-
model_path = model_kwargs.pop('model_path_llama')
|
36 |
-
# FIXME: GPT4All version of llama doesn't handle new quantization, so use llama_cpp_python
|
37 |
-
from llama_cpp import Llama
|
38 |
-
# llama sets some things at init model time, not generation time
|
39 |
-
func_names = list(inspect.signature(Llama.__init__).parameters)
|
40 |
-
model_kwargs = {k: v for k, v in model_kwargs.items() if k in func_names}
|
41 |
-
model_kwargs['n_ctx'] = int(model_kwargs['n_ctx'])
|
42 |
-
model = Llama(model_path=model_path, **model_kwargs)
|
43 |
-
elif base_model in "gpt4all_llama":
|
44 |
-
if 'model_name_gpt4all_llama' not in model_kwargs and 'model_path_gpt4all_llama' not in model_kwargs:
|
45 |
-
raise ValueError("No model_name_gpt4all_llama or model_path_gpt4all_llama in %s" % env_gpt4all_file)
|
46 |
-
model_name = model_kwargs.pop('model_name_gpt4all_llama')
|
47 |
-
model_type = 'llama'
|
48 |
-
from gpt4all import GPT4All as GPT4AllModel
|
49 |
-
model = GPT4AllModel(model_name=model_name, model_type=model_type)
|
50 |
-
elif base_model in "gptj":
|
51 |
-
if 'model_name_gptj' not in model_kwargs and 'model_path_gptj' not in model_kwargs:
|
52 |
-
raise ValueError("No model_name_gpt4j or model_path_gpt4j in %s" % env_gpt4all_file)
|
53 |
-
model_name = model_kwargs.pop('model_name_gptj')
|
54 |
-
model_type = 'gptj'
|
55 |
-
from gpt4all import GPT4All as GPT4AllModel
|
56 |
-
model = GPT4AllModel(model_name=model_name, model_type=model_type)
|
57 |
-
else:
|
58 |
-
raise ValueError("No such base_model %s" % base_model)
|
59 |
-
return model, FakeTokenizer(), 'cpu'
|
60 |
-
|
61 |
-
|
62 |
-
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
63 |
-
|
64 |
-
|
65 |
-
class H2OStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
|
66 |
-
|
67 |
-
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
68 |
-
"""Run on new LLM token. Only available when streaming is enabled."""
|
69 |
-
# streaming to std already occurs without this
|
70 |
-
# sys.stdout.write(token)
|
71 |
-
# sys.stdout.flush()
|
72 |
-
pass
|
73 |
-
|
74 |
-
|
75 |
-
def get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=[]):
|
76 |
-
# default from class
|
77 |
-
model_kwargs = {k: v.default for k, v in dict(inspect.signature(cls).parameters).items() if k not in exclude_list}
|
78 |
-
# from our defaults
|
79 |
-
model_kwargs.update(default_kwargs)
|
80 |
-
# from user defaults
|
81 |
-
model_kwargs.update(env_kwargs)
|
82 |
-
# ensure only valid keys
|
83 |
-
func_names = list(inspect.signature(cls).parameters)
|
84 |
-
model_kwargs = {k: v for k, v in model_kwargs.items() if k in func_names}
|
85 |
-
return model_kwargs
|
86 |
-
|
87 |
-
|
88 |
-
def get_llm_gpt4all(model_name,
|
89 |
-
model=None,
|
90 |
-
max_new_tokens=256,
|
91 |
-
temperature=0.1,
|
92 |
-
repetition_penalty=1.0,
|
93 |
-
top_k=40,
|
94 |
-
top_p=0.7,
|
95 |
-
streaming=False,
|
96 |
-
callbacks=None,
|
97 |
-
prompter=None,
|
98 |
-
context='',
|
99 |
-
iinput='',
|
100 |
-
verbose=False,
|
101 |
-
):
|
102 |
-
assert prompter is not None
|
103 |
-
env_gpt4all_file = ".env_gpt4all"
|
104 |
-
env_kwargs = dotenv_values(env_gpt4all_file)
|
105 |
-
max_tokens = env_kwargs.pop('max_tokens', 2048 - max_new_tokens)
|
106 |
-
default_kwargs = dict(context_erase=0.5,
|
107 |
-
n_batch=1,
|
108 |
-
max_tokens=max_tokens,
|
109 |
-
n_predict=max_new_tokens,
|
110 |
-
repeat_last_n=64 if repetition_penalty != 1.0 else 0,
|
111 |
-
repeat_penalty=repetition_penalty,
|
112 |
-
temp=temperature,
|
113 |
-
temperature=temperature,
|
114 |
-
top_k=top_k,
|
115 |
-
top_p=top_p,
|
116 |
-
use_mlock=True,
|
117 |
-
verbose=verbose)
|
118 |
-
if model_name == 'llama':
|
119 |
-
cls = H2OLlamaCpp
|
120 |
-
model_path = env_kwargs.pop('model_path_llama') if model is None else model
|
121 |
-
model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
|
122 |
-
model_kwargs.update(dict(model_path=model_path, callbacks=callbacks, streaming=streaming,
|
123 |
-
prompter=prompter, context=context, iinput=iinput))
|
124 |
-
llm = cls(**model_kwargs)
|
125 |
-
llm.client.verbose = verbose
|
126 |
-
elif model_name == 'gpt4all_llama':
|
127 |
-
cls = H2OGPT4All
|
128 |
-
model_path = env_kwargs.pop('model_path_gpt4all_llama') if model is None else model
|
129 |
-
model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
|
130 |
-
model_kwargs.update(
|
131 |
-
dict(model=model_path, backend='llama', callbacks=callbacks, streaming=streaming,
|
132 |
-
prompter=prompter, context=context, iinput=iinput))
|
133 |
-
llm = cls(**model_kwargs)
|
134 |
-
elif model_name == 'gptj':
|
135 |
-
cls = H2OGPT4All
|
136 |
-
model_path = env_kwargs.pop('model_path_gptj') if model is None else model
|
137 |
-
model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
|
138 |
-
model_kwargs.update(
|
139 |
-
dict(model=model_path, backend='gptj', callbacks=callbacks, streaming=streaming,
|
140 |
-
prompter=prompter, context=context, iinput=iinput))
|
141 |
-
llm = cls(**model_kwargs)
|
142 |
-
else:
|
143 |
-
raise RuntimeError("No such model_name %s" % model_name)
|
144 |
-
return llm
|
145 |
-
|
146 |
-
|
147 |
-
class H2OGPT4All(gpt4all.GPT4All):
|
148 |
-
model: Any
|
149 |
-
prompter: Any
|
150 |
-
context: Any = ''
|
151 |
-
iinput: Any = ''
|
152 |
-
"""Path to the pre-trained GPT4All model file."""
|
153 |
-
|
154 |
-
@root_validator()
|
155 |
-
def validate_environment(cls, values: Dict) -> Dict:
|
156 |
-
"""Validate that the python package exists in the environment."""
|
157 |
-
try:
|
158 |
-
if isinstance(values["model"], str):
|
159 |
-
from gpt4all import GPT4All as GPT4AllModel
|
160 |
-
|
161 |
-
full_path = values["model"]
|
162 |
-
model_path, delimiter, model_name = full_path.rpartition("/")
|
163 |
-
model_path += delimiter
|
164 |
-
|
165 |
-
values["client"] = GPT4AllModel(
|
166 |
-
model_name=model_name,
|
167 |
-
model_path=model_path or None,
|
168 |
-
model_type=values["backend"],
|
169 |
-
allow_download=False,
|
170 |
-
)
|
171 |
-
if values["n_threads"] is not None:
|
172 |
-
# set n_threads
|
173 |
-
values["client"].model.set_thread_count(values["n_threads"])
|
174 |
-
else:
|
175 |
-
values["client"] = values["model"]
|
176 |
-
try:
|
177 |
-
values["backend"] = values["client"].model_type
|
178 |
-
except AttributeError:
|
179 |
-
# The below is for compatibility with GPT4All Python bindings <= 0.2.3.
|
180 |
-
values["backend"] = values["client"].model.model_type
|
181 |
-
|
182 |
-
except ImportError:
|
183 |
-
raise ValueError(
|
184 |
-
"Could not import gpt4all python package. "
|
185 |
-
"Please install it with `pip install gpt4all`."
|
186 |
-
)
|
187 |
-
return values
|
188 |
-
|
189 |
-
def _call(
|
190 |
-
self,
|
191 |
-
prompt: str,
|
192 |
-
stop: Optional[List[str]] = None,
|
193 |
-
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
194 |
-
**kwargs,
|
195 |
-
) -> str:
|
196 |
-
# Roughly 4 chars per token if natural language
|
197 |
-
n_ctx = 2048
|
198 |
-
prompt = prompt[-self.max_tokens * 4:]
|
199 |
-
|
200 |
-
# use instruct prompting
|
201 |
-
data_point = dict(context=self.context, instruction=prompt, input=self.iinput)
|
202 |
-
prompt = self.prompter.generate_prompt(data_point)
|
203 |
-
|
204 |
-
verbose = False
|
205 |
-
if verbose:
|
206 |
-
print("_call prompt: %s" % prompt, flush=True)
|
207 |
-
# FIXME: GPT4ALl doesn't support yield during generate, so cannot support streaming except via itself to stdout
|
208 |
-
return super()._call(prompt, stop=stop, run_manager=run_manager)
|
209 |
-
|
210 |
-
|
211 |
-
from langchain.llms import LlamaCpp
|
212 |
-
|
213 |
-
|
214 |
-
class H2OLlamaCpp(LlamaCpp):
|
215 |
-
model_path: Any
|
216 |
-
prompter: Any
|
217 |
-
context: Any
|
218 |
-
iinput: Any
|
219 |
-
"""Path to the pre-trained GPT4All model file."""
|
220 |
-
|
221 |
-
@root_validator()
|
222 |
-
def validate_environment(cls, values: Dict) -> Dict:
|
223 |
-
"""Validate that llama-cpp-python library is installed."""
|
224 |
-
if isinstance(values["model_path"], str):
|
225 |
-
model_path = values["model_path"]
|
226 |
-
model_param_names = [
|
227 |
-
"lora_path",
|
228 |
-
"lora_base",
|
229 |
-
"n_ctx",
|
230 |
-
"n_parts",
|
231 |
-
"seed",
|
232 |
-
"f16_kv",
|
233 |
-
"logits_all",
|
234 |
-
"vocab_only",
|
235 |
-
"use_mlock",
|
236 |
-
"n_threads",
|
237 |
-
"n_batch",
|
238 |
-
"use_mmap",
|
239 |
-
"last_n_tokens_size",
|
240 |
-
]
|
241 |
-
model_params = {k: values[k] for k in model_param_names}
|
242 |
-
# For backwards compatibility, only include if non-null.
|
243 |
-
if values["n_gpu_layers"] is not None:
|
244 |
-
model_params["n_gpu_layers"] = values["n_gpu_layers"]
|
245 |
-
|
246 |
-
try:
|
247 |
-
from llama_cpp import Llama
|
248 |
-
|
249 |
-
values["client"] = Llama(model_path, **model_params)
|
250 |
-
except ImportError:
|
251 |
-
raise ModuleNotFoundError(
|
252 |
-
"Could not import llama-cpp-python library. "
|
253 |
-
"Please install the llama-cpp-python library to "
|
254 |
-
"use this embedding model: pip install llama-cpp-python"
|
255 |
-
)
|
256 |
-
except Exception as e:
|
257 |
-
raise ValueError(
|
258 |
-
f"Could not load Llama model from path: {model_path}. "
|
259 |
-
f"Received error {e}"
|
260 |
-
)
|
261 |
-
else:
|
262 |
-
values["client"] = values["model_path"]
|
263 |
-
return values
|
264 |
-
|
265 |
-
def _call(
|
266 |
-
self,
|
267 |
-
prompt: str,
|
268 |
-
stop: Optional[List[str]] = None,
|
269 |
-
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
270 |
-
**kwargs,
|
271 |
-
) -> str:
|
272 |
-
verbose = False
|
273 |
-
# tokenize twice, just to count tokens, since llama cpp python wrapper has no way to truncate
|
274 |
-
# still have to avoid crazy sizes, else hit llama_tokenize: too many tokens -- might still hit, not fatal
|
275 |
-
prompt = prompt[-self.n_ctx * 4:]
|
276 |
-
prompt_tokens = self.client.tokenize(b" " + prompt.encode("utf-8"))
|
277 |
-
num_prompt_tokens = len(prompt_tokens)
|
278 |
-
if num_prompt_tokens > self.n_ctx:
|
279 |
-
# conservative by using int()
|
280 |
-
chars_per_token = int(len(prompt) / num_prompt_tokens)
|
281 |
-
prompt = prompt[-self.n_ctx * chars_per_token:]
|
282 |
-
if verbose:
|
283 |
-
print("reducing tokens, assuming average of %s chars/token: %s" % chars_per_token, flush=True)
|
284 |
-
prompt_tokens2 = self.client.tokenize(b" " + prompt.encode("utf-8"))
|
285 |
-
num_prompt_tokens2 = len(prompt_tokens2)
|
286 |
-
print("reduced tokens from %d -> %d" % (num_prompt_tokens, num_prompt_tokens2), flush=True)
|
287 |
-
|
288 |
-
# use instruct prompting
|
289 |
-
data_point = dict(context=self.context, instruction=prompt, input=self.iinput)
|
290 |
-
prompt = self.prompter.generate_prompt(data_point)
|
291 |
-
|
292 |
-
if verbose:
|
293 |
-
print("_call prompt: %s" % prompt, flush=True)
|
294 |
-
|
295 |
-
if self.streaming:
|
296 |
-
text_callback = None
|
297 |
-
if run_manager:
|
298 |
-
text_callback = partial(
|
299 |
-
run_manager.on_llm_new_token, verbose=self.verbose
|
300 |
-
)
|
301 |
-
# parent handler of streamer expects to see prompt first else output="" and lose if prompt=None in prompter
|
302 |
-
if text_callback:
|
303 |
-
text_callback(prompt)
|
304 |
-
text = ""
|
305 |
-
for token in self.stream(prompt=prompt, stop=stop, run_manager=run_manager):
|
306 |
-
text_chunk = token["choices"][0]["text"]
|
307 |
-
# self.stream already calls text_callback
|
308 |
-
# if text_callback:
|
309 |
-
# text_callback(text_chunk)
|
310 |
-
text += text_chunk
|
311 |
-
return text
|
312 |
-
else:
|
313 |
-
params = self._get_parameters(stop)
|
314 |
-
params = {**params, **kwargs}
|
315 |
-
result = self.client(prompt=prompt, **params)
|
316 |
-
return result["choices"][0]["text"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
gpt_langchain.py
DELETED
The diff for this file is too large to render.
See raw diff
|
|
gradio_runner.py
DELETED
The diff for this file is too large to render.
See raw diff
|
|
gradio_themes.py
DELETED
@@ -1,231 +0,0 @@
|
|
1 |
-
from __future__ import annotations
|
2 |
-
|
3 |
-
from typing import Iterable
|
4 |
-
|
5 |
-
from gradio.themes.soft import Soft
|
6 |
-
from gradio.themes import Color, Size
|
7 |
-
from gradio.themes.utils import colors, sizes, fonts
|
8 |
-
|
9 |
-
h2o_yellow = Color(
|
10 |
-
name="yellow",
|
11 |
-
c50="#fffef2",
|
12 |
-
c100="#fff9e6",
|
13 |
-
c200="#ffecb3",
|
14 |
-
c300="#ffe28c",
|
15 |
-
c400="#ffd659",
|
16 |
-
c500="#fec925",
|
17 |
-
c600="#e6ac00",
|
18 |
-
c700="#bf8f00",
|
19 |
-
c800="#a67c00",
|
20 |
-
c900="#664d00",
|
21 |
-
c950="#403000",
|
22 |
-
)
|
23 |
-
h2o_gray = Color(
|
24 |
-
name="gray",
|
25 |
-
c50="#f8f8f8",
|
26 |
-
c100="#e5e5e5",
|
27 |
-
c200="#cccccc",
|
28 |
-
c300="#b2b2b2",
|
29 |
-
c400="#999999",
|
30 |
-
c500="#7f7f7f",
|
31 |
-
c600="#666666",
|
32 |
-
c700="#4c4c4c",
|
33 |
-
c800="#333333",
|
34 |
-
c900="#191919",
|
35 |
-
c950="#0d0d0d",
|
36 |
-
)
|
37 |
-
|
38 |
-
|
39 |
-
text_xsm = Size(
|
40 |
-
name="text_xsm",
|
41 |
-
xxs="4px",
|
42 |
-
xs="5px",
|
43 |
-
sm="6px",
|
44 |
-
md="7px",
|
45 |
-
lg="8px",
|
46 |
-
xl="10px",
|
47 |
-
xxl="12px",
|
48 |
-
)
|
49 |
-
|
50 |
-
|
51 |
-
spacing_xsm = Size(
|
52 |
-
name="spacing_xsm",
|
53 |
-
xxs="1px",
|
54 |
-
xs="1px",
|
55 |
-
sm="1px",
|
56 |
-
md="2px",
|
57 |
-
lg="3px",
|
58 |
-
xl="5px",
|
59 |
-
xxl="7px",
|
60 |
-
)
|
61 |
-
|
62 |
-
|
63 |
-
radius_xsm = Size(
|
64 |
-
name="radius_xsm",
|
65 |
-
xxs="1px",
|
66 |
-
xs="1px",
|
67 |
-
sm="1px",
|
68 |
-
md="2px",
|
69 |
-
lg="3px",
|
70 |
-
xl="5px",
|
71 |
-
xxl="7px",
|
72 |
-
)
|
73 |
-
|
74 |
-
|
75 |
-
class H2oTheme(Soft):
|
76 |
-
def __init__(
|
77 |
-
self,
|
78 |
-
*,
|
79 |
-
primary_hue: colors.Color | str = h2o_yellow,
|
80 |
-
secondary_hue: colors.Color | str = h2o_yellow,
|
81 |
-
neutral_hue: colors.Color | str = h2o_gray,
|
82 |
-
spacing_size: sizes.Size | str = sizes.spacing_md,
|
83 |
-
radius_size: sizes.Size | str = sizes.radius_md,
|
84 |
-
text_size: sizes.Size | str = sizes.text_lg,
|
85 |
-
font: fonts.Font
|
86 |
-
| str
|
87 |
-
| Iterable[fonts.Font | str] = (
|
88 |
-
fonts.GoogleFont("Montserrat"),
|
89 |
-
"ui-sans-serif",
|
90 |
-
"system-ui",
|
91 |
-
"sans-serif",
|
92 |
-
),
|
93 |
-
font_mono: fonts.Font
|
94 |
-
| str
|
95 |
-
| Iterable[fonts.Font | str] = (
|
96 |
-
fonts.GoogleFont("IBM Plex Mono"),
|
97 |
-
"ui-monospace",
|
98 |
-
"Consolas",
|
99 |
-
"monospace",
|
100 |
-
),
|
101 |
-
):
|
102 |
-
super().__init__(
|
103 |
-
primary_hue=primary_hue,
|
104 |
-
secondary_hue=secondary_hue,
|
105 |
-
neutral_hue=neutral_hue,
|
106 |
-
spacing_size=spacing_size,
|
107 |
-
radius_size=radius_size,
|
108 |
-
text_size=text_size,
|
109 |
-
font=font,
|
110 |
-
font_mono=font_mono,
|
111 |
-
)
|
112 |
-
super().set(
|
113 |
-
link_text_color="#3344DD",
|
114 |
-
link_text_color_hover="#3344DD",
|
115 |
-
link_text_color_visited="#3344DD",
|
116 |
-
link_text_color_dark="#74abff",
|
117 |
-
link_text_color_hover_dark="#a3c8ff",
|
118 |
-
link_text_color_active_dark="#a3c8ff",
|
119 |
-
link_text_color_visited_dark="#74abff",
|
120 |
-
button_primary_text_color="*neutral_950",
|
121 |
-
button_primary_text_color_dark="*neutral_950",
|
122 |
-
button_primary_background_fill="*primary_500",
|
123 |
-
button_primary_background_fill_dark="*primary_500",
|
124 |
-
block_label_background_fill="*primary_500",
|
125 |
-
block_label_background_fill_dark="*primary_500",
|
126 |
-
block_label_text_color="*neutral_950",
|
127 |
-
block_label_text_color_dark="*neutral_950",
|
128 |
-
block_title_text_color="*neutral_950",
|
129 |
-
block_title_text_color_dark="*neutral_950",
|
130 |
-
block_background_fill_dark="*neutral_950",
|
131 |
-
body_background_fill="*neutral_50",
|
132 |
-
body_background_fill_dark="*neutral_900",
|
133 |
-
background_fill_primary_dark="*block_background_fill",
|
134 |
-
block_radius="0 0 8px 8px",
|
135 |
-
checkbox_label_text_color_selected_dark='#000000',
|
136 |
-
#checkbox_label_text_size="*text_xs", # too small for iPhone etc. but good if full large screen zoomed to fit
|
137 |
-
checkbox_label_text_size="*text_sm",
|
138 |
-
#radio_circle="""url("data:image/svg+xml,%3csvg viewBox='0 0 32 32' fill='white' xmlns='http://www.w3.org/2000/svg'%3e%3ccircle cx='32' cy='32' r='1'/%3e%3c/svg%3e")""",
|
139 |
-
#checkbox_border_width=1,
|
140 |
-
#heckbox_border_width_dark=1,
|
141 |
-
)
|
142 |
-
|
143 |
-
|
144 |
-
class SoftTheme(Soft):
|
145 |
-
def __init__(
|
146 |
-
self,
|
147 |
-
*,
|
148 |
-
primary_hue: colors.Color | str = colors.indigo,
|
149 |
-
secondary_hue: colors.Color | str = colors.indigo,
|
150 |
-
neutral_hue: colors.Color | str = colors.gray,
|
151 |
-
spacing_size: sizes.Size | str = sizes.spacing_md,
|
152 |
-
radius_size: sizes.Size | str = sizes.radius_md,
|
153 |
-
text_size: sizes.Size | str = sizes.text_md,
|
154 |
-
font: fonts.Font
|
155 |
-
| str
|
156 |
-
| Iterable[fonts.Font | str] = (
|
157 |
-
fonts.GoogleFont("Montserrat"),
|
158 |
-
"ui-sans-serif",
|
159 |
-
"system-ui",
|
160 |
-
"sans-serif",
|
161 |
-
),
|
162 |
-
font_mono: fonts.Font
|
163 |
-
| str
|
164 |
-
| Iterable[fonts.Font | str] = (
|
165 |
-
fonts.GoogleFont("IBM Plex Mono"),
|
166 |
-
"ui-monospace",
|
167 |
-
"Consolas",
|
168 |
-
"monospace",
|
169 |
-
),
|
170 |
-
):
|
171 |
-
super().__init__(
|
172 |
-
primary_hue=primary_hue,
|
173 |
-
secondary_hue=secondary_hue,
|
174 |
-
neutral_hue=neutral_hue,
|
175 |
-
spacing_size=spacing_size,
|
176 |
-
radius_size=radius_size,
|
177 |
-
text_size=text_size,
|
178 |
-
font=font,
|
179 |
-
font_mono=font_mono,
|
180 |
-
)
|
181 |
-
super().set(
|
182 |
-
checkbox_label_text_size="*text_sm",
|
183 |
-
)
|
184 |
-
|
185 |
-
|
186 |
-
h2o_logo = '<svg id="Layer_1" data-name="Layer 1" xmlns="http://www.w3.org/2000/svg" width="100%" height="100%"' \
|
187 |
-
' viewBox="0 0 600.28 600.28"><defs><style>.cls-1{fill:#fec925;}.cls-2{fill:#161616;}.cls-3{fill:' \
|
188 |
-
'#54585a;}</style></defs><g id="Fill-1"><rect class="cls-1" width="600.28" height="600.28" ' \
|
189 |
-
'rx="23.24"/></g><path class="cls-2" d="M174.33,246.06v92.78H152.86v-38H110.71v38H89.24V246.06h21.' \
|
190 |
-
'47v36.58h42.15V246.06Z"/><path class="cls-2" d="M259.81,321.34v17.5H189.7V324.92l35.78-33.8c8.22-7.' \
|
191 |
-
'82,9.68-12.59,9.68-17.09,0-7.29-5-11.53-14.85-11.53-7.95,0-14.71,3-19.21,9.27L185.46,261.7c7.15-10' \
|
192 |
-
'.47,20.14-17.23,36.84-17.23,20.68,0,34.46,10.6,34.46,27.44,0,9-2.52,17.22-15.51,29.29l-21.33,20.14Z"' \
|
193 |
-
'/><path class="cls-2" d="M268.69,292.45c0-27.57,21.47-48,50.76-48s50.76,20.28,50.76,48-21.6,48-50.' \
|
194 |
-
'76,48S268.69,320,268.69,292.45Zm79.78,0c0-17.63-12.46-29.69-29-29.69s-29,12.06-29,29.69,12.46,29.69' \
|
195 |
-
',29,29.69S348.47,310.08,348.47,292.45Z"/><path class="cls-3" d="M377.23,326.91c0-7.69,5.7-12.73,12.' \
|
196 |
-
'85-12.73s12.86,5,12.86,12.73a12.86,12.86,0,1,1-25.71,0Z"/><path class="cls-3" d="M481.4,298.15v40.' \
|
197 |
-
'69H462.05V330c-3.84,6.49-11.27,9.94-21.74,9.94-16.7,0-26.64-9.28-26.64-21.61,0-12.59,8.88-21.34,30.' \
|
198 |
-
'62-21.34h16.43c0-8.87-5.3-14-16.43-14-7.55,0-15.37,2.51-20.54,6.62l-7.43-14.44c7.82-5.57,19.35-8.' \
|
199 |
-
'62,30.75-8.62C468.81,266.47,481.4,276.54,481.4,298.15Zm-20.68,18.16V309H446.54c-9.67,0-12.72,3.57-' \
|
200 |
-
'12.72,8.35,0,5.16,4.37,8.61,11.66,8.61C452.37,326,458.34,322.8,460.72,316.31Z"/><path class="cls-3"' \
|
201 |
-
' d="M497.56,246.06c0-6.49,5.17-11.53,12.86-11.53s12.86,4.77,12.86,11.13c0,6.89-5.17,11.93-12.86,' \
|
202 |
-
'11.93S497.56,252.55,497.56,246.06Zm2.52,21.47h20.68v71.31H500.08Z"/></svg>'
|
203 |
-
|
204 |
-
|
205 |
-
def get_h2o_title(title, description):
|
206 |
-
# NOTE: Check full width desktop, smallest width browser desktop, iPhone browsers to ensure no overlap etc.
|
207 |
-
return f"""<div style="float:left; justify-content:left; height: 80px; width: 195px; margin-top:0px">
|
208 |
-
{description}
|
209 |
-
</div>
|
210 |
-
<div style="display:flex; justify-content:center; margin-bottom:30px; margin-right:330px;">
|
211 |
-
<div style="height: 60px; width: 60px; margin-right:20px;">{h2o_logo}</div>
|
212 |
-
<h1 style="line-height:60px">{title}</h1>
|
213 |
-
</div>
|
214 |
-
<div style="float:right; height: 80px; width: 80px; margin-top:-100px">
|
215 |
-
<img src="https://raw.githubusercontent.com/h2oai/h2ogpt/main/docs/h2o-qr.png">
|
216 |
-
</div>
|
217 |
-
"""
|
218 |
-
|
219 |
-
|
220 |
-
def get_simple_title(title, description):
|
221 |
-
return f"""{description}<h1 align="center"> {title}</h1>"""
|
222 |
-
|
223 |
-
|
224 |
-
def get_dark_js():
|
225 |
-
return """() => {
|
226 |
-
if (document.querySelectorAll('.dark').length) {
|
227 |
-
document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'));
|
228 |
-
} else {
|
229 |
-
document.querySelector('body').classList.add('dark');
|
230 |
-
}
|
231 |
-
}"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
h2oai_pipeline.py
DELETED
@@ -1,201 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
|
3 |
-
from transformers import TextGenerationPipeline
|
4 |
-
from transformers.pipelines.text_generation import ReturnType
|
5 |
-
|
6 |
-
from stopping import get_stopping
|
7 |
-
from prompter import Prompter, PromptType
|
8 |
-
|
9 |
-
|
10 |
-
class H2OTextGenerationPipeline(TextGenerationPipeline):
|
11 |
-
def __init__(self, *args, debug=False, chat=False, stream_output=False,
|
12 |
-
sanitize_bot_response=False,
|
13 |
-
use_prompter=True, prompter=None,
|
14 |
-
context='', iinput='',
|
15 |
-
prompt_type=None, prompt_dict=None,
|
16 |
-
max_input_tokens=2048 - 256, **kwargs):
|
17 |
-
"""
|
18 |
-
HF-like pipeline, but handle instruction prompting and stopping (for some models)
|
19 |
-
:param args:
|
20 |
-
:param debug:
|
21 |
-
:param chat:
|
22 |
-
:param stream_output:
|
23 |
-
:param sanitize_bot_response:
|
24 |
-
:param use_prompter: Whether to use prompter. If pass prompt_type, will make prompter
|
25 |
-
:param prompter: prompter, can pass if have already
|
26 |
-
:param prompt_type: prompt_type, e.g. human_bot. See prompt_type to model mapping in from prompter.py.
|
27 |
-
If use_prompter, then will make prompter and use it.
|
28 |
-
:param prompt_dict: dict of get_prompt(, return_dict=True) for prompt_type=custom
|
29 |
-
:param max_input_tokens:
|
30 |
-
:param kwargs:
|
31 |
-
"""
|
32 |
-
super().__init__(*args, **kwargs)
|
33 |
-
self.prompt_text = None
|
34 |
-
self.use_prompter = use_prompter
|
35 |
-
self.prompt_type = prompt_type
|
36 |
-
self.prompt_dict = prompt_dict
|
37 |
-
self.prompter = prompter
|
38 |
-
self.context = context
|
39 |
-
self.iinput = iinput
|
40 |
-
if self.use_prompter:
|
41 |
-
if self.prompter is not None:
|
42 |
-
assert self.prompter.prompt_type is not None
|
43 |
-
else:
|
44 |
-
self.prompter = Prompter(self.prompt_type, self.prompt_dict, debug=debug, chat=chat,
|
45 |
-
stream_output=stream_output)
|
46 |
-
self.human = self.prompter.humanstr
|
47 |
-
self.bot = self.prompter.botstr
|
48 |
-
self.can_stop = True
|
49 |
-
else:
|
50 |
-
self.prompter = None
|
51 |
-
self.human = None
|
52 |
-
self.bot = None
|
53 |
-
self.can_stop = False
|
54 |
-
self.sanitize_bot_response = sanitize_bot_response
|
55 |
-
self.max_input_tokens = max_input_tokens # not for generate, so ok that not kwargs
|
56 |
-
|
57 |
-
@staticmethod
|
58 |
-
def limit_prompt(prompt_text, tokenizer, max_prompt_length=None):
|
59 |
-
verbose = bool(int(os.getenv('VERBOSE_PIPELINE', '0')))
|
60 |
-
|
61 |
-
if hasattr(tokenizer, 'model_max_length'):
|
62 |
-
# model_max_length only defined for generate.py, not raw use of h2oai_pipeline.py
|
63 |
-
model_max_length = tokenizer.model_max_length
|
64 |
-
if max_prompt_length is not None:
|
65 |
-
model_max_length = min(model_max_length, max_prompt_length)
|
66 |
-
# cut at some upper likely limit to avoid excessive tokenization etc
|
67 |
-
# upper bound of 10 chars/token, e.g. special chars sometimes are long
|
68 |
-
if len(prompt_text) > model_max_length * 10:
|
69 |
-
len0 = len(prompt_text)
|
70 |
-
prompt_text = prompt_text[-model_max_length * 10:]
|
71 |
-
if verbose:
|
72 |
-
print("Cut of input: %s -> %s" % (len0, len(prompt_text)), flush=True)
|
73 |
-
else:
|
74 |
-
# unknown
|
75 |
-
model_max_length = None
|
76 |
-
|
77 |
-
num_prompt_tokens = None
|
78 |
-
if model_max_length is not None:
|
79 |
-
# can't wait for "hole" if not plain prompt_type, since would lose prefix like <human>:
|
80 |
-
# For https://github.com/h2oai/h2ogpt/issues/192
|
81 |
-
for trial in range(0, 3):
|
82 |
-
prompt_tokens = tokenizer(prompt_text)['input_ids']
|
83 |
-
num_prompt_tokens = len(prompt_tokens)
|
84 |
-
if num_prompt_tokens > model_max_length:
|
85 |
-
# conservative by using int()
|
86 |
-
chars_per_token = int(len(prompt_text) / num_prompt_tokens)
|
87 |
-
# keep tail, where question is if using langchain
|
88 |
-
prompt_text = prompt_text[-model_max_length * chars_per_token:]
|
89 |
-
if verbose:
|
90 |
-
print("reducing %s tokens, assuming average of %s chars/token for %s characters" % (
|
91 |
-
num_prompt_tokens, chars_per_token, len(prompt_text)), flush=True)
|
92 |
-
else:
|
93 |
-
if verbose:
|
94 |
-
print("using %s tokens with %s chars" % (num_prompt_tokens, len(prompt_text)), flush=True)
|
95 |
-
break
|
96 |
-
|
97 |
-
# Why Below False: don't limit max_new_tokens more, just rely upon stopping to reach limit of model
|
98 |
-
if False:
|
99 |
-
# if input prompt is some number of tokens, despite user request, can't have max_new_tokens more
|
100 |
-
#
|
101 |
-
assert num_prompt_tokens is not None
|
102 |
-
if self.prompt_type not in [PromptType.plain.name, PromptType.plain.value]:
|
103 |
-
# then give room for prompt
|
104 |
-
fudge = 20
|
105 |
-
else:
|
106 |
-
fudge = 0
|
107 |
-
max_new_tokens = max(0, min(generate_kwargs['max_new_tokens'],
|
108 |
-
model_max_length - (num_prompt_tokens + fudge)))
|
109 |
-
if max_new_tokens < generate_kwargs['max_new_tokens']:
|
110 |
-
if verbose:
|
111 |
-
print("Reduced max_new_tokens from %s -> %s" % (
|
112 |
-
generate_kwargs['max_new_tokens'], max_new_tokens))
|
113 |
-
generate_kwargs['max_new_tokens'] = max_new_tokens
|
114 |
-
return prompt_text, num_prompt_tokens
|
115 |
-
|
116 |
-
def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
|
117 |
-
prompt_text, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt_text, self.tokenizer)
|
118 |
-
|
119 |
-
data_point = dict(context=self.context, instruction=prompt_text, input=self.iinput)
|
120 |
-
if self.prompter is not None:
|
121 |
-
prompt_text = self.prompter.generate_prompt(data_point)
|
122 |
-
self.prompt_text = prompt_text
|
123 |
-
if handle_long_generation is None:
|
124 |
-
# forces truncation of inputs to avoid critical failure
|
125 |
-
handle_long_generation = None # disable with new approaches
|
126 |
-
return super().preprocess(prompt_text, prefix=prefix, handle_long_generation=handle_long_generation,
|
127 |
-
**generate_kwargs)
|
128 |
-
|
129 |
-
def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True):
|
130 |
-
records = super().postprocess(model_outputs, return_type=return_type,
|
131 |
-
clean_up_tokenization_spaces=clean_up_tokenization_spaces)
|
132 |
-
for rec in records:
|
133 |
-
if self.use_prompter:
|
134 |
-
outputs = rec['generated_text']
|
135 |
-
outputs = self.prompter.get_response(outputs, prompt=self.prompt_text,
|
136 |
-
sanitize_bot_response=self.sanitize_bot_response)
|
137 |
-
elif self.bot and self.human:
|
138 |
-
outputs = rec['generated_text'].split(self.bot)[1].split(self.human)[0]
|
139 |
-
else:
|
140 |
-
outputs = rec['generated_text']
|
141 |
-
rec['generated_text'] = outputs
|
142 |
-
print("prompt: %s\noutputs: %s\n\n" % (self.prompt_text, outputs), flush=True)
|
143 |
-
return records
|
144 |
-
|
145 |
-
def _forward(self, model_inputs, **generate_kwargs):
|
146 |
-
if self.can_stop:
|
147 |
-
stopping_criteria = get_stopping(self.prompt_type, self.prompt_dict,
|
148 |
-
self.tokenizer, self.device,
|
149 |
-
human=self.human, bot=self.bot,
|
150 |
-
model_max_length=self.tokenizer.model_max_length)
|
151 |
-
generate_kwargs['stopping_criteria'] = stopping_criteria
|
152 |
-
# return super()._forward(model_inputs, **generate_kwargs)
|
153 |
-
return self.__forward(model_inputs, **generate_kwargs)
|
154 |
-
|
155 |
-
# FIXME: Copy-paste of original _forward, but removed copy.deepcopy()
|
156 |
-
# FIXME: https://github.com/h2oai/h2ogpt/issues/172
|
157 |
-
def __forward(self, model_inputs, **generate_kwargs):
|
158 |
-
input_ids = model_inputs["input_ids"]
|
159 |
-
attention_mask = model_inputs.get("attention_mask", None)
|
160 |
-
# Allow empty prompts
|
161 |
-
if input_ids.shape[1] == 0:
|
162 |
-
input_ids = None
|
163 |
-
attention_mask = None
|
164 |
-
in_b = 1
|
165 |
-
else:
|
166 |
-
in_b = input_ids.shape[0]
|
167 |
-
prompt_text = model_inputs.pop("prompt_text")
|
168 |
-
|
169 |
-
## If there is a prefix, we may need to adjust the generation length. Do so without permanently modifying
|
170 |
-
## generate_kwargs, as some of the parameterization may come from the initialization of the pipeline.
|
171 |
-
# generate_kwargs = copy.deepcopy(generate_kwargs)
|
172 |
-
prefix_length = generate_kwargs.pop("prefix_length", 0)
|
173 |
-
if prefix_length > 0:
|
174 |
-
has_max_new_tokens = "max_new_tokens" in generate_kwargs or (
|
175 |
-
"generation_config" in generate_kwargs
|
176 |
-
and generate_kwargs["generation_config"].max_new_tokens is not None
|
177 |
-
)
|
178 |
-
if not has_max_new_tokens:
|
179 |
-
generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.model.config.max_length
|
180 |
-
generate_kwargs["max_length"] += prefix_length
|
181 |
-
has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
|
182 |
-
"generation_config" in generate_kwargs
|
183 |
-
and generate_kwargs["generation_config"].min_new_tokens is not None
|
184 |
-
)
|
185 |
-
if not has_min_new_tokens and "min_length" in generate_kwargs:
|
186 |
-
generate_kwargs["min_length"] += prefix_length
|
187 |
-
|
188 |
-
# BS x SL
|
189 |
-
generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
|
190 |
-
out_b = generated_sequence.shape[0]
|
191 |
-
if self.framework == "pt":
|
192 |
-
generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
|
193 |
-
elif self.framework == "tf":
|
194 |
-
from transformers import is_tf_available
|
195 |
-
if is_tf_available():
|
196 |
-
import tensorflow as tf
|
197 |
-
generated_sequence = tf.reshape(generated_sequence,
|
198 |
-
(in_b, out_b // in_b, *generated_sequence.shape[1:]))
|
199 |
-
else:
|
200 |
-
raise ValueError("TF not avaialble.")
|
201 |
-
return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
loaders.py
DELETED
@@ -1,61 +0,0 @@
|
|
1 |
-
import functools
|
2 |
-
|
3 |
-
|
4 |
-
def get_loaders(model_name, reward_type, llama_type=None, load_gptq=''):
|
5 |
-
# NOTE: Some models need specific new prompt_type
|
6 |
-
# E.g. t5_xxl_true_nli_mixture has input format: "premise: PREMISE_TEXT hypothesis: HYPOTHESIS_TEXT".)
|
7 |
-
if load_gptq:
|
8 |
-
from transformers import AutoTokenizer
|
9 |
-
from auto_gptq import AutoGPTQForCausalLM
|
10 |
-
use_triton = False
|
11 |
-
functools.partial(AutoGPTQForCausalLM.from_quantized, quantize_config=None, use_triton=use_triton)
|
12 |
-
return AutoGPTQForCausalLM.from_quantized, AutoTokenizer
|
13 |
-
if llama_type is None:
|
14 |
-
llama_type = "llama" in model_name.lower()
|
15 |
-
if llama_type:
|
16 |
-
from transformers import LlamaForCausalLM, LlamaTokenizer
|
17 |
-
return LlamaForCausalLM.from_pretrained, LlamaTokenizer
|
18 |
-
elif 'distilgpt2' in model_name.lower():
|
19 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
20 |
-
return AutoModelForCausalLM.from_pretrained, AutoTokenizer
|
21 |
-
elif 'gpt2' in model_name.lower():
|
22 |
-
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
23 |
-
return GPT2LMHeadModel.from_pretrained, GPT2Tokenizer
|
24 |
-
elif 'mbart-' in model_name.lower():
|
25 |
-
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
|
26 |
-
return MBartForConditionalGeneration.from_pretrained, MBart50TokenizerFast
|
27 |
-
elif 't5' == model_name.lower() or \
|
28 |
-
't5-' in model_name.lower() or \
|
29 |
-
'flan-' in model_name.lower():
|
30 |
-
from transformers import AutoTokenizer, T5ForConditionalGeneration
|
31 |
-
return T5ForConditionalGeneration.from_pretrained, AutoTokenizer
|
32 |
-
elif 'bigbird' in model_name:
|
33 |
-
from transformers import BigBirdPegasusForConditionalGeneration, AutoTokenizer
|
34 |
-
return BigBirdPegasusForConditionalGeneration.from_pretrained, AutoTokenizer
|
35 |
-
elif 'bart-large-cnn-samsum' in model_name or 'flan-t5-base-samsum' in model_name:
|
36 |
-
from transformers import pipeline
|
37 |
-
return pipeline, "summarization"
|
38 |
-
elif reward_type or 'OpenAssistant/reward-model'.lower() in model_name.lower():
|
39 |
-
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
40 |
-
return AutoModelForSequenceClassification.from_pretrained, AutoTokenizer
|
41 |
-
else:
|
42 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM
|
43 |
-
model_loader = AutoModelForCausalLM
|
44 |
-
tokenizer_loader = AutoTokenizer
|
45 |
-
return model_loader.from_pretrained, tokenizer_loader
|
46 |
-
|
47 |
-
|
48 |
-
def get_tokenizer(tokenizer_loader, tokenizer_base_model, local_files_only, resume_download, use_auth_token):
|
49 |
-
tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
|
50 |
-
local_files_only=local_files_only,
|
51 |
-
resume_download=resume_download,
|
52 |
-
use_auth_token=use_auth_token,
|
53 |
-
padding_side='left')
|
54 |
-
|
55 |
-
tokenizer.pad_token_id = 0 # different from the eos token
|
56 |
-
# when generating, we will use the logits of right-most token to predict the next token
|
57 |
-
# so the padding should be on the left,
|
58 |
-
# e.g. see: https://huggingface.co/transformers/v4.11.3/model_doc/t5.html#inference
|
59 |
-
tokenizer.padding_side = "left" # Allow batched inference
|
60 |
-
|
61 |
-
return tokenizer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
prompter.py
DELETED
@@ -1,871 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import ast
|
3 |
-
import time
|
4 |
-
from enums import PromptType # also supports imports from this file from other files
|
5 |
-
|
6 |
-
non_hf_types = ['gpt4all_llama', 'llama', 'gptj']
|
7 |
-
|
8 |
-
prompt_type_to_model_name = {
|
9 |
-
'plain': [
|
10 |
-
'EleutherAI/gpt-j-6B',
|
11 |
-
'EleutherAI/pythia-6.9b',
|
12 |
-
'EleutherAI/pythia-12b',
|
13 |
-
'EleutherAI/pythia-12b-deduped',
|
14 |
-
'EleutherAI/gpt-neox-20b',
|
15 |
-
'openlm-research/open_llama_7b_700bt_preview',
|
16 |
-
'decapoda-research/llama-7b-hf',
|
17 |
-
'decapoda-research/llama-13b-hf',
|
18 |
-
'decapoda-research/llama-30b-hf',
|
19 |
-
'decapoda-research/llama-65b-hf',
|
20 |
-
'facebook/mbart-large-50-many-to-many-mmt',
|
21 |
-
'philschmid/bart-large-cnn-samsum',
|
22 |
-
'philschmid/flan-t5-base-samsum',
|
23 |
-
'gpt2',
|
24 |
-
'distilgpt2',
|
25 |
-
'mosaicml/mpt-7b-storywriter',
|
26 |
-
],
|
27 |
-
'gptj': ['gptj', 'gpt4all_llama'],
|
28 |
-
'prompt_answer': [
|
29 |
-
'h2oai/h2ogpt-gm-oasst1-en-1024-20b',
|
30 |
-
'h2oai/h2ogpt-gm-oasst1-en-1024-12b',
|
31 |
-
'h2oai/h2ogpt-gm-oasst1-multilang-1024-20b',
|
32 |
-
'h2oai/h2ogpt-gm-oasst1-multilang-2048-falcon-7b',
|
33 |
-
'h2oai/h2ogpt-gm-oasst1-multilang-2048-falcon-7b-v2',
|
34 |
-
'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3',
|
35 |
-
'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b',
|
36 |
-
'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v2',
|
37 |
-
'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-40b-v1',
|
38 |
-
'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-40b-v2',
|
39 |
-
'h2oai/h2ogpt-gm-oasst1-en-xgen-7b-8k',
|
40 |
-
'h2oai/h2ogpt-gm-oasst1-multilang-xgen-7b-8k',
|
41 |
-
'TheBloke/h2ogpt-gm-oasst1-en-2048-falcon-40b-v2-GPTQ',
|
42 |
-
],
|
43 |
-
'prompt_answer_openllama': [
|
44 |
-
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt',
|
45 |
-
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2',
|
46 |
-
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-700bt',
|
47 |
-
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b',
|
48 |
-
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-13b',
|
49 |
-
],
|
50 |
-
'instruct': ['TheBloke/llama-30b-supercot-SuperHOT-8K-fp16'], # https://huggingface.co/TheBloke/llama-30b-supercot-SuperHOT-8K-fp16#prompting
|
51 |
-
'instruct_with_end': ['databricks/dolly-v2-12b'],
|
52 |
-
'quality': [],
|
53 |
-
'human_bot': [
|
54 |
-
'h2oai/h2ogpt-oasst1-512-12b',
|
55 |
-
'h2oai/h2ogpt-oasst1-512-20b',
|
56 |
-
'h2oai/h2ogpt-oig-oasst1-256-6_9b',
|
57 |
-
'h2oai/h2ogpt-oig-oasst1-512-6_9b',
|
58 |
-
'h2oai/h2ogpt-oig-oasst1-256-6.9b', # legacy
|
59 |
-
'h2oai/h2ogpt-oig-oasst1-512-6.9b', # legacy
|
60 |
-
'h2oai/h2ogpt-research-oasst1-512-30b',
|
61 |
-
'h2oai/h2ogpt-research-oasst1-llama-65b',
|
62 |
-
'h2oai/h2ogpt-oasst1-falcon-40b',
|
63 |
-
'h2oai/h2ogpt-oig-oasst1-falcon-40b',
|
64 |
-
],
|
65 |
-
'dai_faq': [],
|
66 |
-
'summarize': [],
|
67 |
-
'simple_instruct': ['t5-small', 't5-large', 'google/flan-t5', 'google/flan-t5-xxl', 'google/flan-ul2'],
|
68 |
-
'instruct_vicuna': ['AlekseyKorshuk/vicuna-7b', 'TheBloke/stable-vicuna-13B-HF', 'junelee/wizard-vicuna-13b'],
|
69 |
-
'human_bot_orig': ['togethercomputer/GPT-NeoXT-Chat-Base-20B'],
|
70 |
-
"open_assistant": ['OpenAssistant/oasst-sft-7-llama-30b-xor', 'oasst-sft-7-llama-30b'],
|
71 |
-
"wizard_lm": ['ehartford/WizardLM-7B-Uncensored', 'ehartford/WizardLM-13B-Uncensored'],
|
72 |
-
"wizard_mega": ['openaccess-ai-collective/wizard-mega-13b'],
|
73 |
-
"instruct_simple": ['JosephusCheung/Guanaco'],
|
74 |
-
"wizard_vicuna": ['ehartford/Wizard-Vicuna-13B-Uncensored'],
|
75 |
-
"wizard2": ['llama'],
|
76 |
-
"mptinstruct": ['mosaicml/mpt-30b-instruct', 'mosaicml/mpt-7b-instruct', 'mosaicml/mpt-30b-instruct'],
|
77 |
-
"mptchat": ['mosaicml/mpt-7b-chat', 'mosaicml/mpt-30b-chat', 'TheBloke/mpt-30B-chat-GGML'],
|
78 |
-
"vicuna11": ['lmsys/vicuna-33b-v1.3'],
|
79 |
-
"falcon": ['tiiuae/falcon-40b-instruct', 'tiiuae/falcon-40b', 'tiiuae/falcon-7b-instruct', 'tiiuae/falcon-7b'],
|
80 |
-
"llama2": [
|
81 |
-
'meta-llama/Llama-2-7b-chat-hf',
|
82 |
-
'meta-llama/Llama-2-13b-chat-hf',
|
83 |
-
'meta-llama/Llama-2-34b-chat-hf',
|
84 |
-
'meta-llama/Llama-2-70b-chat-hf',
|
85 |
-
],
|
86 |
-
# could be plain, but default is correct prompt_type for default TheBloke model ggml-wizardLM-7B.q4_2.bin
|
87 |
-
}
|
88 |
-
if os.getenv('OPENAI_API_KEY'):
|
89 |
-
prompt_type_to_model_name.update({
|
90 |
-
"openai": ["text-davinci-003", "text-curie-001", "text-babbage-001", "text-ada-001"],
|
91 |
-
"openai_chat": ["gpt-3.5-turbo", "gpt-3.5-turbo-16k"],
|
92 |
-
})
|
93 |
-
|
94 |
-
inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
|
95 |
-
inv_prompt_type_to_model_lower = {v.strip().lower(): k for k, l in prompt_type_to_model_name.items() for v in l}
|
96 |
-
|
97 |
-
prompt_types_strings = []
|
98 |
-
for p in PromptType:
|
99 |
-
prompt_types_strings.extend([p.name])
|
100 |
-
|
101 |
-
prompt_types = []
|
102 |
-
for p in PromptType:
|
103 |
-
prompt_types.extend([p.name, p.value, str(p.value)])
|
104 |
-
|
105 |
-
|
106 |
-
def get_prompt(prompt_type, prompt_dict, chat, context, reduced, making_context, return_dict=False):
|
107 |
-
prompt_dict_error = ''
|
108 |
-
generates_leading_space = False
|
109 |
-
|
110 |
-
if prompt_type == PromptType.custom.name and not isinstance(prompt_dict, dict):
|
111 |
-
try:
|
112 |
-
prompt_dict = ast.literal_eval(prompt_dict)
|
113 |
-
except BaseException as e:
|
114 |
-
prompt_dict_error = str(e)
|
115 |
-
if prompt_dict_error:
|
116 |
-
promptA = None
|
117 |
-
promptB = None
|
118 |
-
PreInstruct = None
|
119 |
-
PreInput = ''
|
120 |
-
PreResponse = ''
|
121 |
-
terminate_response = None
|
122 |
-
chat_sep = ''
|
123 |
-
chat_turn_sep = ''
|
124 |
-
humanstr = ''
|
125 |
-
botstr = ''
|
126 |
-
generates_leading_space = False
|
127 |
-
elif prompt_type in [PromptType.custom.value, str(PromptType.custom.value),
|
128 |
-
PromptType.custom.name]:
|
129 |
-
promptA = prompt_dict.get('promptA', '')
|
130 |
-
promptB = prompt_dict.get('promptB', '')
|
131 |
-
PreInstruct = prompt_dict.get('PreInstruct', '')
|
132 |
-
PreInput = prompt_dict.get('PreInput', '')
|
133 |
-
PreResponse = prompt_dict.get('PreResponse', '')
|
134 |
-
terminate_response = prompt_dict.get('terminate_response', None)
|
135 |
-
chat_sep = prompt_dict.get('chat_sep', '\n')
|
136 |
-
chat_turn_sep = prompt_dict.get('chat_turn_sep', '\n')
|
137 |
-
humanstr = prompt_dict.get('humanstr', '')
|
138 |
-
botstr = prompt_dict.get('botstr', '')
|
139 |
-
elif prompt_type in [PromptType.plain.value, str(PromptType.plain.value),
|
140 |
-
PromptType.plain.name]:
|
141 |
-
promptA = promptB = PreInstruct = PreInput = PreResponse = None
|
142 |
-
terminate_response = []
|
143 |
-
chat_turn_sep = chat_sep = ''
|
144 |
-
# plain should have None for human/bot, so nothing truncated out, not '' that would truncate after first token
|
145 |
-
humanstr = None
|
146 |
-
botstr = None
|
147 |
-
elif prompt_type == 'simple_instruct':
|
148 |
-
promptA = promptB = PreInstruct = PreInput = PreResponse = None
|
149 |
-
terminate_response = []
|
150 |
-
chat_turn_sep = chat_sep = '\n'
|
151 |
-
humanstr = None
|
152 |
-
botstr = None
|
153 |
-
elif prompt_type in [PromptType.instruct.value, str(PromptType.instruct.value),
|
154 |
-
PromptType.instruct.name] + [PromptType.instruct_with_end.value,
|
155 |
-
str(PromptType.instruct_with_end.value),
|
156 |
-
PromptType.instruct_with_end.name]:
|
157 |
-
promptA = 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n' if not (
|
158 |
-
chat and reduced) else ''
|
159 |
-
promptB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n' if not (
|
160 |
-
chat and reduced) else ''
|
161 |
-
|
162 |
-
PreInstruct = """
|
163 |
-
### Instruction:
|
164 |
-
"""
|
165 |
-
|
166 |
-
PreInput = """
|
167 |
-
### Input:
|
168 |
-
"""
|
169 |
-
|
170 |
-
PreResponse = """
|
171 |
-
### Response:
|
172 |
-
"""
|
173 |
-
if prompt_type in [PromptType.instruct_with_end.value, str(PromptType.instruct_with_end.value),
|
174 |
-
PromptType.instruct_with_end.name]:
|
175 |
-
terminate_response = ['### End']
|
176 |
-
else:
|
177 |
-
terminate_response = None
|
178 |
-
chat_turn_sep = chat_sep = '\n'
|
179 |
-
humanstr = PreInstruct
|
180 |
-
botstr = PreResponse
|
181 |
-
elif prompt_type in [PromptType.quality.value, str(PromptType.quality.value),
|
182 |
-
PromptType.quality.name]:
|
183 |
-
promptA = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction as applied on the Input.\n' if not (
|
184 |
-
chat and reduced) else ''
|
185 |
-
promptB = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction.\n' if not (
|
186 |
-
chat and reduced) else ''
|
187 |
-
|
188 |
-
PreInstruct = """
|
189 |
-
### Instruction:
|
190 |
-
"""
|
191 |
-
|
192 |
-
PreInput = """
|
193 |
-
### Input:
|
194 |
-
"""
|
195 |
-
|
196 |
-
PreResponse = """
|
197 |
-
### Response:
|
198 |
-
"""
|
199 |
-
terminate_response = None
|
200 |
-
chat_turn_sep = chat_sep = '\n'
|
201 |
-
humanstr = PreInstruct # first thing human says
|
202 |
-
botstr = PreResponse # first thing bot says
|
203 |
-
elif prompt_type in [PromptType.human_bot.value, str(PromptType.human_bot.value),
|
204 |
-
PromptType.human_bot.name] + [PromptType.human_bot_orig.value,
|
205 |
-
str(PromptType.human_bot_orig.value),
|
206 |
-
PromptType.human_bot_orig.name]:
|
207 |
-
human = '<human>:'
|
208 |
-
bot = "<bot>:"
|
209 |
-
if reduced or context or prompt_type in [PromptType.human_bot.value, str(PromptType.human_bot.value),
|
210 |
-
PromptType.human_bot.name]:
|
211 |
-
preprompt = ''
|
212 |
-
else:
|
213 |
-
cur_date = time.strftime('%Y-%m-%d')
|
214 |
-
cur_time = time.strftime('%H:%M:%S %p %Z')
|
215 |
-
|
216 |
-
PRE_PROMPT = """\
|
217 |
-
Current Date: {}
|
218 |
-
Current Time: {}
|
219 |
-
|
220 |
-
"""
|
221 |
-
preprompt = PRE_PROMPT.format(cur_date, cur_time)
|
222 |
-
start = ''
|
223 |
-
promptB = promptA = '%s%s' % (preprompt, start)
|
224 |
-
|
225 |
-
PreInstruct = human + ' '
|
226 |
-
|
227 |
-
PreInput = None
|
228 |
-
|
229 |
-
if making_context:
|
230 |
-
# when making context, want it to appear as-if LLM generated, which starts with space after :
|
231 |
-
PreResponse = bot + ' '
|
232 |
-
else:
|
233 |
-
# normally LLM adds space after this, because was how trained.
|
234 |
-
# if add space here, non-unique tokenization will often make LLM produce wrong output
|
235 |
-
PreResponse = bot
|
236 |
-
|
237 |
-
terminate_response = ['\n' + human, '\n' + bot, human, bot, PreResponse]
|
238 |
-
chat_turn_sep = chat_sep = '\n'
|
239 |
-
humanstr = human # tag before human talks
|
240 |
-
botstr = bot # tag before bot talks
|
241 |
-
generates_leading_space = True
|
242 |
-
elif prompt_type in [PromptType.dai_faq.value, str(PromptType.dai_faq.value),
|
243 |
-
PromptType.dai_faq.name]:
|
244 |
-
promptA = ''
|
245 |
-
promptB = 'Answer the following Driverless AI question.\n'
|
246 |
-
|
247 |
-
PreInstruct = """
|
248 |
-
### Driverless AI frequently asked question:
|
249 |
-
"""
|
250 |
-
|
251 |
-
PreInput = None
|
252 |
-
|
253 |
-
PreResponse = """
|
254 |
-
### Driverless AI documentation answer:
|
255 |
-
"""
|
256 |
-
terminate_response = ['\n\n']
|
257 |
-
chat_turn_sep = chat_sep = terminate_response
|
258 |
-
humanstr = PreInstruct
|
259 |
-
botstr = PreResponse
|
260 |
-
elif prompt_type in [PromptType.summarize.value, str(PromptType.summarize.value),
|
261 |
-
PromptType.summarize.name]:
|
262 |
-
promptA = promptB = PreInput = ''
|
263 |
-
PreInstruct = '## Main Text\n\n'
|
264 |
-
PreResponse = '\n\n## Summary\n\n'
|
265 |
-
terminate_response = None
|
266 |
-
chat_turn_sep = chat_sep = '\n'
|
267 |
-
humanstr = PreInstruct
|
268 |
-
botstr = PreResponse
|
269 |
-
elif prompt_type in [PromptType.instruct_vicuna.value, str(PromptType.instruct_vicuna.value),
|
270 |
-
PromptType.instruct_vicuna.name]:
|
271 |
-
promptA = promptB = "A chat between a curious human and an artificial intelligence assistant. " \
|
272 |
-
"The assistant gives helpful, detailed, and polite answers to the human's questions." if not (
|
273 |
-
chat and reduced) else ''
|
274 |
-
|
275 |
-
PreInstruct = """
|
276 |
-
### Human:
|
277 |
-
"""
|
278 |
-
|
279 |
-
PreInput = None
|
280 |
-
|
281 |
-
PreResponse = """
|
282 |
-
### Assistant:
|
283 |
-
"""
|
284 |
-
terminate_response = [
|
285 |
-
'### Human:'] # but only allow terminate after prompt is found correctly, else can't terminate
|
286 |
-
chat_turn_sep = chat_sep = '\n'
|
287 |
-
humanstr = PreInstruct
|
288 |
-
botstr = PreResponse
|
289 |
-
elif prompt_type in [PromptType.prompt_answer.value, str(PromptType.prompt_answer.value),
|
290 |
-
PromptType.prompt_answer.name]:
|
291 |
-
preprompt = ''
|
292 |
-
prompt_tokens = "<|prompt|>"
|
293 |
-
answer_tokens = "<|answer|>"
|
294 |
-
start = ''
|
295 |
-
promptB = promptA = '%s%s' % (preprompt, start)
|
296 |
-
PreInstruct = prompt_tokens
|
297 |
-
PreInput = None
|
298 |
-
PreResponse = answer_tokens
|
299 |
-
eos = '<|endoftext|>' # neox eos
|
300 |
-
humanstr = prompt_tokens
|
301 |
-
botstr = answer_tokens
|
302 |
-
terminate_response = [humanstr, PreResponse, eos]
|
303 |
-
chat_sep = eos
|
304 |
-
chat_turn_sep = eos
|
305 |
-
elif prompt_type in [PromptType.prompt_answer_openllama.value, str(PromptType.prompt_answer_openllama.value),
|
306 |
-
PromptType.prompt_answer_openllama.name]:
|
307 |
-
preprompt = ''
|
308 |
-
prompt_tokens = "<|prompt|>"
|
309 |
-
answer_tokens = "<|answer|>"
|
310 |
-
start = ''
|
311 |
-
promptB = promptA = '%s%s' % (preprompt, start)
|
312 |
-
PreInstruct = prompt_tokens
|
313 |
-
PreInput = None
|
314 |
-
PreResponse = answer_tokens
|
315 |
-
eos = '</s>' # llama eos
|
316 |
-
humanstr = prompt_tokens
|
317 |
-
botstr = answer_tokens
|
318 |
-
terminate_response = [humanstr, PreResponse, eos]
|
319 |
-
chat_sep = eos
|
320 |
-
chat_turn_sep = eos
|
321 |
-
elif prompt_type in [PromptType.open_assistant.value, str(PromptType.open_assistant.value),
|
322 |
-
PromptType.open_assistant.name]:
|
323 |
-
# From added_tokens.json
|
324 |
-
preprompt = ''
|
325 |
-
prompt_tokens = "<|prompter|>"
|
326 |
-
answer_tokens = "<|assistant|>"
|
327 |
-
start = ''
|
328 |
-
promptB = promptA = '%s%s' % (preprompt, start)
|
329 |
-
PreInstruct = prompt_tokens
|
330 |
-
PreInput = None
|
331 |
-
PreResponse = answer_tokens
|
332 |
-
pend = "<|prefix_end|>"
|
333 |
-
eos = "</s>"
|
334 |
-
humanstr = prompt_tokens
|
335 |
-
botstr = answer_tokens
|
336 |
-
terminate_response = [humanstr, PreResponse, pend, eos]
|
337 |
-
chat_turn_sep = chat_sep = eos
|
338 |
-
elif prompt_type in [PromptType.wizard_lm.value, str(PromptType.wizard_lm.value),
|
339 |
-
PromptType.wizard_lm.name]:
|
340 |
-
# https://github.com/ehartford/WizardLM/blob/main/src/train_freeform.py
|
341 |
-
preprompt = ''
|
342 |
-
start = ''
|
343 |
-
promptB = promptA = '%s%s' % (preprompt, start)
|
344 |
-
PreInstruct = ""
|
345 |
-
PreInput = None
|
346 |
-
PreResponse = "\n\n### Response\n"
|
347 |
-
eos = "</s>"
|
348 |
-
terminate_response = [PreResponse, eos]
|
349 |
-
chat_turn_sep = chat_sep = eos
|
350 |
-
humanstr = promptA
|
351 |
-
botstr = PreResponse
|
352 |
-
elif prompt_type in [PromptType.wizard_mega.value, str(PromptType.wizard_mega.value),
|
353 |
-
PromptType.wizard_mega.name]:
|
354 |
-
preprompt = ''
|
355 |
-
start = ''
|
356 |
-
promptB = promptA = '%s%s' % (preprompt, start)
|
357 |
-
PreInstruct = """
|
358 |
-
### Instruction:
|
359 |
-
"""
|
360 |
-
PreInput = None
|
361 |
-
PreResponse = """
|
362 |
-
### Assistant:
|
363 |
-
"""
|
364 |
-
terminate_response = [PreResponse]
|
365 |
-
chat_turn_sep = chat_sep = '\n'
|
366 |
-
humanstr = PreInstruct
|
367 |
-
botstr = PreResponse
|
368 |
-
elif prompt_type in [PromptType.instruct_vicuna2.value, str(PromptType.instruct_vicuna2.value),
|
369 |
-
PromptType.instruct_vicuna2.name]:
|
370 |
-
promptA = promptB = "" if not (chat and reduced) else ''
|
371 |
-
|
372 |
-
PreInstruct = """
|
373 |
-
HUMAN:
|
374 |
-
"""
|
375 |
-
|
376 |
-
PreInput = None
|
377 |
-
|
378 |
-
PreResponse = """
|
379 |
-
ASSISTANT:
|
380 |
-
"""
|
381 |
-
terminate_response = [
|
382 |
-
'HUMAN:'] # but only allow terminate after prompt is found correctly, else can't terminate
|
383 |
-
chat_turn_sep = chat_sep = '\n'
|
384 |
-
humanstr = PreInstruct
|
385 |
-
botstr = PreResponse
|
386 |
-
elif prompt_type in [PromptType.instruct_vicuna3.value, str(PromptType.instruct_vicuna3.value),
|
387 |
-
PromptType.instruct_vicuna3.name]:
|
388 |
-
promptA = promptB = "" if not (chat and reduced) else ''
|
389 |
-
|
390 |
-
PreInstruct = """
|
391 |
-
### User:
|
392 |
-
"""
|
393 |
-
|
394 |
-
PreInput = None
|
395 |
-
|
396 |
-
PreResponse = """
|
397 |
-
### Assistant:
|
398 |
-
"""
|
399 |
-
terminate_response = [
|
400 |
-
'### User:'] # but only allow terminate after prompt is found correctly, else can't terminate
|
401 |
-
chat_turn_sep = chat_sep = '\n'
|
402 |
-
humanstr = PreInstruct
|
403 |
-
botstr = PreResponse
|
404 |
-
elif prompt_type in [PromptType.wizard2.value, str(PromptType.wizard2.value),
|
405 |
-
PromptType.wizard2.name]:
|
406 |
-
# https://huggingface.co/TheBloke/WizardLM-7B-uncensored-GGML
|
407 |
-
preprompt = """Below is an instruction that describes a task. Write a response that appropriately completes the request.""" if not (
|
408 |
-
chat and reduced) else ''
|
409 |
-
start = ''
|
410 |
-
promptB = promptA = '%s%s' % (preprompt, start)
|
411 |
-
PreInstruct = """
|
412 |
-
### Instruction:
|
413 |
-
"""
|
414 |
-
PreInput = None
|
415 |
-
PreResponse = """
|
416 |
-
### Response:
|
417 |
-
"""
|
418 |
-
terminate_response = [PreResponse]
|
419 |
-
chat_turn_sep = chat_sep = '\n'
|
420 |
-
humanstr = PreInstruct
|
421 |
-
botstr = PreResponse
|
422 |
-
elif prompt_type in [PromptType.wizard3.value, str(PromptType.wizard3.value),
|
423 |
-
PromptType.wizard3.name]:
|
424 |
-
# https://huggingface.co/TheBloke/wizardLM-13B-1.0-GGML
|
425 |
-
preprompt = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.""" if not (
|
426 |
-
chat and reduced) else ''
|
427 |
-
start = ''
|
428 |
-
promptB = promptA = '%s%s' % (preprompt, start)
|
429 |
-
PreInstruct = """USER: """
|
430 |
-
PreInput = None
|
431 |
-
PreResponse = """ASSISTANT: """
|
432 |
-
terminate_response = [PreResponse]
|
433 |
-
chat_turn_sep = chat_sep = '\n'
|
434 |
-
humanstr = PreInstruct
|
435 |
-
botstr = PreResponse
|
436 |
-
elif prompt_type in [PromptType.wizard_vicuna.value, str(PromptType.wizard_vicuna.value),
|
437 |
-
PromptType.wizard_vicuna.name]:
|
438 |
-
preprompt = ''
|
439 |
-
start = ''
|
440 |
-
promptB = promptA = '%s%s' % (preprompt, start)
|
441 |
-
PreInstruct = """USER: """
|
442 |
-
PreInput = None
|
443 |
-
PreResponse = """ASSISTANT: """
|
444 |
-
terminate_response = [PreResponse]
|
445 |
-
chat_turn_sep = chat_sep = '\n'
|
446 |
-
humanstr = PreInstruct
|
447 |
-
botstr = PreResponse
|
448 |
-
|
449 |
-
elif prompt_type in [PromptType.instruct_simple.value, str(PromptType.instruct_simple.value),
|
450 |
-
PromptType.instruct_simple.name]:
|
451 |
-
promptB = promptA = '' if not (chat and reduced) else ''
|
452 |
-
|
453 |
-
PreInstruct = """
|
454 |
-
### Instruction:
|
455 |
-
"""
|
456 |
-
|
457 |
-
PreInput = """
|
458 |
-
### Input:
|
459 |
-
"""
|
460 |
-
|
461 |
-
PreResponse = """
|
462 |
-
### Response:
|
463 |
-
"""
|
464 |
-
terminate_response = None
|
465 |
-
chat_turn_sep = chat_sep = '\n'
|
466 |
-
humanstr = PreInstruct
|
467 |
-
botstr = PreResponse
|
468 |
-
elif prompt_type in [PromptType.openai.value, str(PromptType.openai.value),
|
469 |
-
PromptType.openai.name]:
|
470 |
-
preprompt = """The following is a conversation with an AI assistant. The assistant is helpful, creative, clever, and very friendly.""" if not (
|
471 |
-
chat and reduced) else ''
|
472 |
-
start = ''
|
473 |
-
promptB = promptA = '%s%s' % (preprompt, start)
|
474 |
-
PreInstruct = "\nHuman: "
|
475 |
-
PreInput = None
|
476 |
-
PreResponse = "\nAI:"
|
477 |
-
terminate_response = [PreResponse] + [" Human:", " AI:"]
|
478 |
-
chat_turn_sep = chat_sep = '\n'
|
479 |
-
humanstr = PreInstruct
|
480 |
-
botstr = PreResponse
|
481 |
-
elif prompt_type in [PromptType.gptj.value, str(PromptType.gptj.value),
|
482 |
-
PromptType.gptj.name]:
|
483 |
-
preprompt = "### Instruction:\n The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response." if not (
|
484 |
-
chat and reduced) else ''
|
485 |
-
start = ''
|
486 |
-
promptB = promptA = '%s%s' % (preprompt, start)
|
487 |
-
PreInstruct = "\n### Prompt: "
|
488 |
-
PreInput = None
|
489 |
-
PreResponse = "\n### Response: "
|
490 |
-
terminate_response = [PreResponse] + ["Prompt:", "Response:"]
|
491 |
-
chat_turn_sep = chat_sep = '\n'
|
492 |
-
humanstr = PreInstruct
|
493 |
-
botstr = PreResponse
|
494 |
-
elif prompt_type in [PromptType.openai_chat.value, str(PromptType.openai_chat.value),
|
495 |
-
PromptType.openai_chat.name]:
|
496 |
-
# prompting and termination all handled by endpoint
|
497 |
-
preprompt = """"""
|
498 |
-
start = ''
|
499 |
-
promptB = promptA = '%s%s' % (preprompt, start)
|
500 |
-
PreInstruct = ""
|
501 |
-
PreInput = None
|
502 |
-
PreResponse = ""
|
503 |
-
terminate_response = []
|
504 |
-
chat_turn_sep = chat_sep = '\n'
|
505 |
-
humanstr = None
|
506 |
-
botstr = None
|
507 |
-
elif prompt_type in [PromptType.vicuna11.value, str(PromptType.vicuna11.value),
|
508 |
-
PromptType.vicuna11.name]:
|
509 |
-
preprompt = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. """ if not (
|
510 |
-
chat and reduced) else ''
|
511 |
-
start = ''
|
512 |
-
promptB = promptA = '%s%s' % (preprompt, start)
|
513 |
-
eos = '</s>'
|
514 |
-
PreInstruct = """USER: """
|
515 |
-
PreInput = None
|
516 |
-
PreResponse = """ASSISTANT:"""
|
517 |
-
terminate_response = [PreResponse]
|
518 |
-
chat_sep = ' '
|
519 |
-
chat_turn_sep = eos
|
520 |
-
humanstr = PreInstruct
|
521 |
-
botstr = PreResponse
|
522 |
-
|
523 |
-
if making_context:
|
524 |
-
# when making context, want it to appear as-if LLM generated, which starts with space after :
|
525 |
-
PreResponse = PreResponse + ' '
|
526 |
-
else:
|
527 |
-
# normally LLM adds space after this, because was how trained.
|
528 |
-
# if add space here, non-unique tokenization will often make LLM produce wrong output
|
529 |
-
PreResponse = PreResponse
|
530 |
-
elif prompt_type in [PromptType.mptinstruct.value, str(PromptType.mptinstruct.value),
|
531 |
-
PromptType.mptinstruct.name]:
|
532 |
-
# https://huggingface.co/mosaicml/mpt-30b-instruct#formatting
|
533 |
-
promptA = promptB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n' if not (
|
534 |
-
chat and reduced) else ''
|
535 |
-
|
536 |
-
PreInstruct = """
|
537 |
-
### Instruction
|
538 |
-
"""
|
539 |
-
|
540 |
-
PreInput = """
|
541 |
-
### Input
|
542 |
-
"""
|
543 |
-
|
544 |
-
PreResponse = """
|
545 |
-
### Response
|
546 |
-
"""
|
547 |
-
terminate_response = None
|
548 |
-
chat_turn_sep = chat_sep = '\n'
|
549 |
-
humanstr = PreInstruct
|
550 |
-
botstr = PreResponse
|
551 |
-
elif prompt_type in [PromptType.mptchat.value, str(PromptType.mptchat.value),
|
552 |
-
PromptType.mptchat.name]:
|
553 |
-
# https://huggingface.co/TheBloke/mpt-30B-chat-GGML#prompt-template
|
554 |
-
promptA = promptB = """<|im_start|>system\nA conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.\n<|im_end|>""" if not (
|
555 |
-
chat and reduced) else ''
|
556 |
-
|
557 |
-
PreInstruct = """<|im_start|>user
|
558 |
-
"""
|
559 |
-
|
560 |
-
PreInput = None
|
561 |
-
|
562 |
-
PreResponse = """<|im_end|><|im_start|>assistant
|
563 |
-
"""
|
564 |
-
terminate_response = ['<|im_end|>']
|
565 |
-
chat_sep = ''
|
566 |
-
chat_turn_sep = '<|im_end|>'
|
567 |
-
humanstr = PreInstruct
|
568 |
-
botstr = PreResponse
|
569 |
-
elif prompt_type in [PromptType.falcon.value, str(PromptType.falcon.value),
|
570 |
-
PromptType.falcon.name]:
|
571 |
-
promptA = promptB = "" if not (chat and reduced) else ''
|
572 |
-
|
573 |
-
PreInstruct = """User: """
|
574 |
-
|
575 |
-
PreInput = None
|
576 |
-
|
577 |
-
PreResponse = """Assistant:"""
|
578 |
-
terminate_response = ['\nUser', "<|endoftext|>"]
|
579 |
-
chat_sep = '\n\n'
|
580 |
-
chat_turn_sep = '\n\n'
|
581 |
-
humanstr = PreInstruct
|
582 |
-
botstr = PreResponse
|
583 |
-
if making_context:
|
584 |
-
# when making context, want it to appear as-if LLM generated, which starts with space after :
|
585 |
-
PreResponse = 'Assistant: '
|
586 |
-
else:
|
587 |
-
# normally LLM adds space after this, because was how trained.
|
588 |
-
# if add space here, non-unique tokenization will often make LLM produce wrong output
|
589 |
-
PreResponse = PreResponse
|
590 |
-
# generates_leading_space = True
|
591 |
-
elif prompt_type in [PromptType.guanaco.value, str(PromptType.guanaco.value),
|
592 |
-
PromptType.guanaco.name]:
|
593 |
-
# https://huggingface.co/TheBloke/guanaco-65B-GPTQ
|
594 |
-
promptA = promptB = "" if not (chat and reduced) else ''
|
595 |
-
|
596 |
-
PreInstruct = """### Human: """
|
597 |
-
|
598 |
-
PreInput = None
|
599 |
-
|
600 |
-
PreResponse = """### Assistant:"""
|
601 |
-
terminate_response = ['### Human:'] # but only allow terminate after prompt is found correctly, else can't terminate
|
602 |
-
chat_turn_sep = chat_sep = '\n'
|
603 |
-
humanstr = PreInstruct
|
604 |
-
botstr = PreResponse
|
605 |
-
elif prompt_type in [PromptType.llama2.value, str(PromptType.llama2.value),
|
606 |
-
PromptType.llama2.name]:
|
607 |
-
PreInstruct = ""
|
608 |
-
llama2_sys = "<<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
|
609 |
-
prompt = "<s>[INST] "
|
610 |
-
enable_sys = False # too much safety, hurts accuracy
|
611 |
-
if not (chat and reduced):
|
612 |
-
if enable_sys:
|
613 |
-
promptA = promptB = prompt + llama2_sys
|
614 |
-
else:
|
615 |
-
promptA = promptB = prompt
|
616 |
-
else:
|
617 |
-
promptA = promptB = ''
|
618 |
-
PreInput = None
|
619 |
-
PreResponse = ""
|
620 |
-
terminate_response = ["[INST]", "</s>"]
|
621 |
-
chat_sep = ' [/INST]'
|
622 |
-
chat_turn_sep = ' </s><s>[INST] '
|
623 |
-
humanstr = PreInstruct
|
624 |
-
botstr = PreResponse
|
625 |
-
if making_context:
|
626 |
-
PreResponse += " "
|
627 |
-
else:
|
628 |
-
raise RuntimeError("No such prompt_type=%s" % prompt_type)
|
629 |
-
|
630 |
-
if isinstance(terminate_response, (tuple, list)):
|
631 |
-
assert '' not in terminate_response, "Bad terminate_response"
|
632 |
-
|
633 |
-
ret_dict = dict(promptA=promptA, promptB=promptB, PreInstruct=PreInstruct, PreInput=PreInput,
|
634 |
-
PreResponse=PreResponse, terminate_response=terminate_response, chat_sep=chat_sep,
|
635 |
-
chat_turn_sep=chat_turn_sep,
|
636 |
-
humanstr=humanstr, botstr=botstr,
|
637 |
-
generates_leading_space=generates_leading_space)
|
638 |
-
|
639 |
-
if return_dict:
|
640 |
-
return ret_dict, prompt_dict_error
|
641 |
-
else:
|
642 |
-
return tuple(list(ret_dict.values()))
|
643 |
-
|
644 |
-
|
645 |
-
def generate_prompt(data_point, prompt_type, prompt_dict, chat, reduced, making_context):
|
646 |
-
context = data_point.get('context')
|
647 |
-
if context is None:
|
648 |
-
context = ''
|
649 |
-
instruction = data_point.get('instruction')
|
650 |
-
input = data_point.get('input')
|
651 |
-
output = data_point.get('output')
|
652 |
-
prompt_type = data_point.get('prompt_type', prompt_type)
|
653 |
-
prompt_dict = data_point.get('prompt_dict', prompt_dict)
|
654 |
-
assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
|
655 |
-
promptA, promptB, PreInstruct, PreInput, PreResponse, \
|
656 |
-
terminate_response, chat_sep, chat_turn_sep, humanstr, botstr, \
|
657 |
-
generates_leading_space = get_prompt(prompt_type, prompt_dict, chat,
|
658 |
-
context, reduced, making_context)
|
659 |
-
|
660 |
-
# could avoid if reduce=True, but too complex for parent functions to handle
|
661 |
-
prompt = context
|
662 |
-
|
663 |
-
if input and promptA:
|
664 |
-
prompt += f"""{promptA}"""
|
665 |
-
elif promptB:
|
666 |
-
prompt += f"""{promptB}"""
|
667 |
-
|
668 |
-
if instruction and PreInstruct is not None and input and PreInput is not None:
|
669 |
-
prompt += f"""{PreInstruct}{instruction}{PreInput}{input}"""
|
670 |
-
prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
|
671 |
-
elif instruction and input and PreInstruct is None and PreInput is not None:
|
672 |
-
prompt += f"""{PreInput}{instruction}
|
673 |
-
{input}"""
|
674 |
-
prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
|
675 |
-
elif input and instruction and PreInput is None and PreInstruct is not None:
|
676 |
-
prompt += f"""{PreInstruct}{instruction}
|
677 |
-
{input}"""
|
678 |
-
prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
|
679 |
-
elif instruction and PreInstruct is not None:
|
680 |
-
prompt += f"""{PreInstruct}{instruction}"""
|
681 |
-
prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
|
682 |
-
elif input and PreInput is not None:
|
683 |
-
prompt += f"""{PreInput}{input}"""
|
684 |
-
prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
|
685 |
-
elif input and instruction and PreInput is not None:
|
686 |
-
prompt += f"""{PreInput}{instruction}{input}"""
|
687 |
-
prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
|
688 |
-
elif input and instruction and PreInstruct is not None:
|
689 |
-
prompt += f"""{PreInstruct}{instruction}{input}"""
|
690 |
-
prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
|
691 |
-
elif input and instruction:
|
692 |
-
# i.e. for simple_instruct
|
693 |
-
prompt += f"""{instruction}: {input}"""
|
694 |
-
prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
|
695 |
-
elif input:
|
696 |
-
prompt += f"""{input}"""
|
697 |
-
prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
|
698 |
-
elif instruction:
|
699 |
-
prompt += f"""{instruction}"""
|
700 |
-
prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
|
701 |
-
|
702 |
-
if PreResponse is not None:
|
703 |
-
prompt += f"""{PreResponse}"""
|
704 |
-
pre_response = PreResponse # Don't use strip
|
705 |
-
else:
|
706 |
-
pre_response = ''
|
707 |
-
|
708 |
-
if output:
|
709 |
-
prompt += f"""{output}"""
|
710 |
-
|
711 |
-
return prompt, pre_response, terminate_response, chat_sep, chat_turn_sep
|
712 |
-
|
713 |
-
|
714 |
-
def inject_chatsep(prompt_type, prompt, chat_sep=None):
|
715 |
-
if chat_sep:
|
716 |
-
# only add new line if structured prompt, while 'plain' is just generation of next tokens from input
|
717 |
-
prompt += chat_sep
|
718 |
-
return prompt
|
719 |
-
|
720 |
-
|
721 |
-
class Prompter(object):
|
722 |
-
def __init__(self, prompt_type, prompt_dict, debug=False, chat=False, stream_output=False, repeat_penalty=True,
|
723 |
-
allowed_repeat_line_length=10):
|
724 |
-
self.prompt_type = prompt_type
|
725 |
-
self.prompt_dict = prompt_dict
|
726 |
-
self.debug = debug
|
727 |
-
self.chat = chat
|
728 |
-
self.stream_output = stream_output
|
729 |
-
self.repeat_penalty = repeat_penalty
|
730 |
-
self.allowed_repeat_line_length = allowed_repeat_line_length
|
731 |
-
self.prompt = None
|
732 |
-
context = "" # not for chat context
|
733 |
-
reduced = False # not for chat context
|
734 |
-
making_context = False # not for chat context
|
735 |
-
self.promptA, self.promptB, self.PreInstruct, self.PreInput, self.PreResponse, \
|
736 |
-
self.terminate_response, self.chat_sep, self.chat_turn_sep, self.humanstr, self.botstr, \
|
737 |
-
self.generates_leading_space = \
|
738 |
-
get_prompt(self.prompt_type, self.prompt_dict, chat, context, reduced, making_context)
|
739 |
-
self.pre_response = self.PreResponse
|
740 |
-
|
741 |
-
def generate_prompt(self, data_point, reduced=None):
|
742 |
-
"""
|
743 |
-
data_point['context'] is assumed to be like a system prompt or pre-conversation, not inserted after user prompt
|
744 |
-
:param data_point:
|
745 |
-
:param reduced:
|
746 |
-
:return:
|
747 |
-
"""
|
748 |
-
reduced = data_point.get('context') not in ['', None] if reduced is None else reduced
|
749 |
-
making_context = False # whether really making final prompt or just generating context
|
750 |
-
prompt, _, _, _, _ = generate_prompt(data_point, self.prompt_type, self.prompt_dict, self.chat, reduced,
|
751 |
-
making_context)
|
752 |
-
if self.debug:
|
753 |
-
print("prompt: %s" % prompt, flush=True)
|
754 |
-
# if have context, should have always reduced and only preappend promptA/B here
|
755 |
-
if data_point.get('context'):
|
756 |
-
if data_point.get('input') and self.promptA:
|
757 |
-
prompt = self.promptA + prompt
|
758 |
-
elif self.promptB:
|
759 |
-
prompt = self.promptB + prompt
|
760 |
-
|
761 |
-
self.prompt = prompt
|
762 |
-
return prompt
|
763 |
-
|
764 |
-
def get_response(self, outputs, prompt=None, sanitize_bot_response=False):
|
765 |
-
if isinstance(outputs, str):
|
766 |
-
outputs = [outputs]
|
767 |
-
if self.debug:
|
768 |
-
print("output:\n%s" % '\n\n'.join(outputs), flush=True)
|
769 |
-
if prompt is not None:
|
770 |
-
self.prompt = prompt
|
771 |
-
|
772 |
-
def clean_response(response):
|
773 |
-
meaningless_words = ['<pad>', '</s>', '<|endoftext|>']
|
774 |
-
for word in meaningless_words:
|
775 |
-
response = response.replace(word, "")
|
776 |
-
if sanitize_bot_response:
|
777 |
-
from better_profanity import profanity
|
778 |
-
response = profanity.censor(response)
|
779 |
-
if self.generates_leading_space and isinstance(response, str) and len(response) > 0 and response[0] == ' ':
|
780 |
-
response = response[1:]
|
781 |
-
return response
|
782 |
-
|
783 |
-
def clean_repeats(response):
|
784 |
-
lines = response.split('\n')
|
785 |
-
new_lines = []
|
786 |
-
[new_lines.append(line) for line in lines if
|
787 |
-
line not in new_lines or len(line) < self.allowed_repeat_line_length]
|
788 |
-
if self.debug and len(lines) != len(new_lines):
|
789 |
-
print("cleaned repeats: %s %s" % (len(lines), len(new_lines)), flush=True)
|
790 |
-
response = '\n'.join(new_lines)
|
791 |
-
return response
|
792 |
-
|
793 |
-
multi_output = len(outputs) > 1
|
794 |
-
|
795 |
-
for oi, output in enumerate(outputs):
|
796 |
-
if self.prompt_type in [PromptType.plain.value, str(PromptType.plain.value), PromptType.plain.name]:
|
797 |
-
output = clean_response(output)
|
798 |
-
elif prompt is None:
|
799 |
-
# then use most basic parsing like pipeline
|
800 |
-
if not self.botstr:
|
801 |
-
pass
|
802 |
-
elif self.botstr in output:
|
803 |
-
if self.humanstr:
|
804 |
-
output = clean_response(output.split(self.botstr)[1].split(self.humanstr)[0])
|
805 |
-
else:
|
806 |
-
# i.e. use after bot but only up to next bot
|
807 |
-
output = clean_response(output.split(self.botstr)[1].split(self.botstr)[0])
|
808 |
-
else:
|
809 |
-
# output = clean_response(output)
|
810 |
-
# assume just not printed yet
|
811 |
-
output = ""
|
812 |
-
else:
|
813 |
-
# find first instance of prereponse
|
814 |
-
# prompt sometimes has odd characters, that mutate length,
|
815 |
-
# so can't go by length alone
|
816 |
-
if self.pre_response:
|
817 |
-
outputi = output.find(prompt)
|
818 |
-
if outputi >= 0:
|
819 |
-
output = output[outputi + len(prompt):]
|
820 |
-
allow_terminate = True
|
821 |
-
else:
|
822 |
-
# subtraction is risky due to space offsets sometimes, so only do if necessary
|
823 |
-
output = output[len(prompt) - len(self.pre_response):]
|
824 |
-
# [1] to avoid repeated pre_response, just take first (after prompt - pre_response for chat)
|
825 |
-
if self.pre_response in output:
|
826 |
-
output = output.split(self.pre_response)[1]
|
827 |
-
allow_terminate = True
|
828 |
-
else:
|
829 |
-
if output:
|
830 |
-
print("Failure of parsing or not enough output yet: %s" % output, flush=True)
|
831 |
-
allow_terminate = False
|
832 |
-
else:
|
833 |
-
allow_terminate = True
|
834 |
-
output = output[len(prompt):]
|
835 |
-
# clean after subtract prompt out, so correct removal of pre_response
|
836 |
-
output = clean_response(output)
|
837 |
-
if self.repeat_penalty:
|
838 |
-
output = clean_repeats(output)
|
839 |
-
if self.terminate_response and allow_terminate:
|
840 |
-
finds = []
|
841 |
-
for term in self.terminate_response:
|
842 |
-
finds.append(output.find(term))
|
843 |
-
finds = [x for x in finds if x >= 0]
|
844 |
-
if len(finds) > 0:
|
845 |
-
termi = finds[0]
|
846 |
-
output = output[:termi]
|
847 |
-
else:
|
848 |
-
output = output
|
849 |
-
if multi_output:
|
850 |
-
# prefix with output counter
|
851 |
-
output = "\n=========== Output %d\n\n" % (1 + oi) + output
|
852 |
-
if oi > 0:
|
853 |
-
# post fix outputs with seperator
|
854 |
-
output += '\n'
|
855 |
-
output = self.fix_text(self.prompt_type, output)
|
856 |
-
outputs[oi] = output
|
857 |
-
# join all outputs, only one extra new line between outputs
|
858 |
-
output = '\n'.join(outputs)
|
859 |
-
if self.debug:
|
860 |
-
print("outputclean:\n%s" % '\n\n'.join(outputs), flush=True)
|
861 |
-
return output
|
862 |
-
|
863 |
-
@staticmethod
|
864 |
-
def fix_text(prompt_type1, text1):
|
865 |
-
if prompt_type1 == 'human_bot':
|
866 |
-
# hack bug in vLLM with stopping, stops right, but doesn't return last token
|
867 |
-
hfix = '<human'
|
868 |
-
if text1.endswith(hfix):
|
869 |
-
text1 = text1[:-len(hfix)]
|
870 |
-
return text1
|
871 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
stopping.py
DELETED
@@ -1,78 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
from transformers import StoppingCriteria, StoppingCriteriaList
|
3 |
-
|
4 |
-
from enums import PromptType
|
5 |
-
|
6 |
-
|
7 |
-
class StoppingCriteriaSub(StoppingCriteria):
|
8 |
-
|
9 |
-
def __init__(self, stops=[], encounters=[], device="cuda", model_max_length=None):
|
10 |
-
super().__init__()
|
11 |
-
assert len(stops) % len(encounters) == 0, "Number of stops and encounters must match"
|
12 |
-
self.encounters = encounters
|
13 |
-
self.stops = [stop.to(device) for stop in stops]
|
14 |
-
self.num_stops = [0] * len(stops)
|
15 |
-
self.model_max_length = model_max_length
|
16 |
-
|
17 |
-
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
18 |
-
for stopi, stop in enumerate(self.stops):
|
19 |
-
if torch.all((stop == input_ids[0][-len(stop):])).item():
|
20 |
-
self.num_stops[stopi] += 1
|
21 |
-
if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
|
22 |
-
# print("Stopped", flush=True)
|
23 |
-
return True
|
24 |
-
if self.model_max_length is not None and input_ids[0].shape[0] >= self.model_max_length:
|
25 |
-
# critical limit
|
26 |
-
return True
|
27 |
-
# print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
|
28 |
-
# print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
|
29 |
-
return False
|
30 |
-
|
31 |
-
|
32 |
-
def get_stopping(prompt_type, prompt_dict, tokenizer, device, human='<human>:', bot="<bot>:", model_max_length=None):
|
33 |
-
# FIXME: prompt_dict unused currently
|
34 |
-
if prompt_type in [PromptType.human_bot.name, PromptType.instruct_vicuna.name, PromptType.instruct_with_end.name]:
|
35 |
-
if prompt_type == PromptType.human_bot.name:
|
36 |
-
# encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
|
37 |
-
# stopping only starts once output is beyond prompt
|
38 |
-
# 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added
|
39 |
-
stop_words = [human, bot, '\n' + human, '\n' + bot]
|
40 |
-
encounters = [1, 2]
|
41 |
-
elif prompt_type == PromptType.instruct_vicuna.name:
|
42 |
-
# even below is not enough, generic strings and many ways to encode
|
43 |
-
stop_words = [
|
44 |
-
'### Human:',
|
45 |
-
"""
|
46 |
-
### Human:""",
|
47 |
-
"""
|
48 |
-
### Human:
|
49 |
-
""",
|
50 |
-
'### Assistant:',
|
51 |
-
"""
|
52 |
-
### Assistant:""",
|
53 |
-
"""
|
54 |
-
### Assistant:
|
55 |
-
""",
|
56 |
-
]
|
57 |
-
encounters = [1, 2]
|
58 |
-
else:
|
59 |
-
# some instruct prompts have this as end, doesn't hurt to stop on it since not common otherwise
|
60 |
-
stop_words = ['### End']
|
61 |
-
encounters = [1]
|
62 |
-
stop_words_ids = [
|
63 |
-
tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]
|
64 |
-
# handle single token case
|
65 |
-
stop_words_ids = [x if len(x.shape) > 0 else torch.tensor([x]) for x in stop_words_ids]
|
66 |
-
stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0]
|
67 |
-
# avoid padding in front of tokens
|
68 |
-
if tokenizer._pad_token: # use hidden variable to avoid annoying properly logger bug
|
69 |
-
stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids]
|
70 |
-
# handle fake \n added
|
71 |
-
stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
|
72 |
-
# build stopper
|
73 |
-
stopping_criteria = StoppingCriteriaList(
|
74 |
-
[StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters, device=device,
|
75 |
-
model_max_length=model_max_length)])
|
76 |
-
else:
|
77 |
-
stopping_criteria = StoppingCriteriaList()
|
78 |
-
return stopping_criteria
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils.py
DELETED
@@ -1,1080 +0,0 @@
|
|
1 |
-
import contextlib
|
2 |
-
import functools
|
3 |
-
import hashlib
|
4 |
-
import inspect
|
5 |
-
import os
|
6 |
-
import gc
|
7 |
-
import pathlib
|
8 |
-
import pickle
|
9 |
-
import random
|
10 |
-
import shutil
|
11 |
-
import subprocess
|
12 |
-
import sys
|
13 |
-
import threading
|
14 |
-
import time
|
15 |
-
import traceback
|
16 |
-
import zipfile
|
17 |
-
from datetime import datetime
|
18 |
-
|
19 |
-
import filelock
|
20 |
-
import requests, uuid
|
21 |
-
from typing import Tuple, Callable, Dict
|
22 |
-
from tqdm.auto import tqdm
|
23 |
-
from joblib import Parallel
|
24 |
-
from concurrent.futures import ProcessPoolExecutor
|
25 |
-
import numpy as np
|
26 |
-
import pandas as pd
|
27 |
-
|
28 |
-
|
29 |
-
def set_seed(seed: int):
|
30 |
-
"""
|
31 |
-
Sets the seed of the entire notebook so results are the same every time we run.
|
32 |
-
This is for REPRODUCIBILITY.
|
33 |
-
"""
|
34 |
-
import torch
|
35 |
-
np.random.seed(seed)
|
36 |
-
random_state = np.random.RandomState(seed)
|
37 |
-
random.seed(seed)
|
38 |
-
torch.manual_seed(seed)
|
39 |
-
torch.cuda.manual_seed(seed)
|
40 |
-
torch.backends.cudnn.deterministic = True
|
41 |
-
torch.backends.cudnn.benchmark = False
|
42 |
-
os.environ['PYTHONHASHSEED'] = str(seed)
|
43 |
-
return random_state
|
44 |
-
|
45 |
-
|
46 |
-
def flatten_list(lis):
|
47 |
-
"""Given a list, possibly nested to any level, return it flattened."""
|
48 |
-
new_lis = []
|
49 |
-
for item in lis:
|
50 |
-
if type(item) == type([]):
|
51 |
-
new_lis.extend(flatten_list(item))
|
52 |
-
else:
|
53 |
-
new_lis.append(item)
|
54 |
-
return new_lis
|
55 |
-
|
56 |
-
|
57 |
-
def clear_torch_cache():
|
58 |
-
import torch
|
59 |
-
if torch.cuda.is_available():
|
60 |
-
torch.cuda.empty_cache()
|
61 |
-
torch.cuda.ipc_collect()
|
62 |
-
gc.collect()
|
63 |
-
|
64 |
-
|
65 |
-
def ping():
|
66 |
-
try:
|
67 |
-
print('Ping: %s' % str(datetime.now()), flush=True)
|
68 |
-
except AttributeError:
|
69 |
-
# some programs wrap print and will fail with flush passed
|
70 |
-
pass
|
71 |
-
|
72 |
-
|
73 |
-
def ping_gpu():
|
74 |
-
try:
|
75 |
-
print('Ping_GPU: %s %s' % (str(datetime.now()), system_info()), flush=True)
|
76 |
-
except AttributeError:
|
77 |
-
# some programs wrap print and will fail with flush passed
|
78 |
-
pass
|
79 |
-
try:
|
80 |
-
ping_gpu_memory()
|
81 |
-
except Exception as e:
|
82 |
-
print('Ping_GPU memory failure: %s' % str(e), flush=True)
|
83 |
-
|
84 |
-
|
85 |
-
def ping_gpu_memory():
|
86 |
-
from models.gpu_mem_track import MemTracker
|
87 |
-
gpu_tracker = MemTracker() # define a GPU tracker
|
88 |
-
from torch.cuda import memory_summary
|
89 |
-
gpu_tracker.track()
|
90 |
-
|
91 |
-
|
92 |
-
def get_torch_allocated():
|
93 |
-
import torch
|
94 |
-
return torch.cuda.memory_allocated()
|
95 |
-
|
96 |
-
|
97 |
-
def get_device():
|
98 |
-
import torch
|
99 |
-
if torch.cuda.is_available():
|
100 |
-
device = "cuda"
|
101 |
-
elif torch.backends.mps.is_built():
|
102 |
-
device = "mps"
|
103 |
-
else:
|
104 |
-
device = "cpu"
|
105 |
-
|
106 |
-
return device
|
107 |
-
|
108 |
-
|
109 |
-
def system_info():
|
110 |
-
import psutil
|
111 |
-
|
112 |
-
system = {}
|
113 |
-
# https://stackoverflow.com/questions/48951136/plot-multiple-graphs-in-one-plot-using-tensorboard
|
114 |
-
# https://arshren.medium.com/monitoring-your-devices-in-python-5191d672f749
|
115 |
-
try:
|
116 |
-
temps = psutil.sensors_temperatures(fahrenheit=False)
|
117 |
-
if 'coretemp' in temps:
|
118 |
-
coretemp = temps['coretemp']
|
119 |
-
temp_dict = {k.label: k.current for k in coretemp}
|
120 |
-
for k, v in temp_dict.items():
|
121 |
-
system['CPU_C/%s' % k] = v
|
122 |
-
except AttributeError:
|
123 |
-
pass
|
124 |
-
|
125 |
-
# https://github.com/gpuopenanalytics/pynvml/blob/master/help_query_gpu.txt
|
126 |
-
try:
|
127 |
-
from pynvml.smi import nvidia_smi
|
128 |
-
nvsmi = nvidia_smi.getInstance()
|
129 |
-
|
130 |
-
gpu_power_dict = {'W_gpu%d' % i: x['power_readings']['power_draw'] for i, x in
|
131 |
-
enumerate(nvsmi.DeviceQuery('power.draw')['gpu'])}
|
132 |
-
for k, v in gpu_power_dict.items():
|
133 |
-
system['GPU_W/%s' % k] = v
|
134 |
-
|
135 |
-
gpu_temp_dict = {'C_gpu%d' % i: x['temperature']['gpu_temp'] for i, x in
|
136 |
-
enumerate(nvsmi.DeviceQuery('temperature.gpu')['gpu'])}
|
137 |
-
for k, v in gpu_temp_dict.items():
|
138 |
-
system['GPU_C/%s' % k] = v
|
139 |
-
|
140 |
-
gpu_memory_free_dict = {'MiB_gpu%d' % i: x['fb_memory_usage']['free'] for i, x in
|
141 |
-
enumerate(nvsmi.DeviceQuery('memory.free')['gpu'])}
|
142 |
-
gpu_memory_total_dict = {'MiB_gpu%d' % i: x['fb_memory_usage']['total'] for i, x in
|
143 |
-
enumerate(nvsmi.DeviceQuery('memory.total')['gpu'])}
|
144 |
-
gpu_memory_frac_dict = {k: gpu_memory_free_dict[k] / gpu_memory_total_dict[k] for k in gpu_memory_total_dict}
|
145 |
-
for k, v in gpu_memory_frac_dict.items():
|
146 |
-
system[f'GPU_M/%s' % k] = v
|
147 |
-
except (KeyError, ModuleNotFoundError):
|
148 |
-
pass
|
149 |
-
system['hash'] = get_githash()
|
150 |
-
|
151 |
-
return system
|
152 |
-
|
153 |
-
|
154 |
-
def system_info_print():
|
155 |
-
try:
|
156 |
-
df = pd.DataFrame.from_dict(system_info(), orient='index')
|
157 |
-
# avoid slamming GPUs
|
158 |
-
time.sleep(1)
|
159 |
-
return df.to_markdown()
|
160 |
-
except Exception as e:
|
161 |
-
return "Error: %s" % str(e)
|
162 |
-
|
163 |
-
|
164 |
-
def zip_data(root_dirs=None, zip_file=None, base_dir='./', fail_any_exception=False):
|
165 |
-
try:
|
166 |
-
return _zip_data(zip_file=zip_file, base_dir=base_dir, root_dirs=root_dirs)
|
167 |
-
except Exception as e:
|
168 |
-
traceback.print_exc()
|
169 |
-
print('Exception in zipping: %s' % str(e))
|
170 |
-
if not fail_any_exception:
|
171 |
-
raise
|
172 |
-
|
173 |
-
|
174 |
-
def _zip_data(root_dirs=None, zip_file=None, base_dir='./'):
|
175 |
-
if isinstance(root_dirs, str):
|
176 |
-
root_dirs = [root_dirs]
|
177 |
-
if zip_file is None:
|
178 |
-
datetime_str = str(datetime.now()).replace(" ", "_").replace(":", "_")
|
179 |
-
host_name = os.getenv('HF_HOSTNAME', 'emptyhost')
|
180 |
-
zip_file = "data_%s_%s.zip" % (datetime_str, host_name)
|
181 |
-
assert root_dirs is not None
|
182 |
-
if not os.path.isdir(os.path.dirname(zip_file)) and os.path.dirname(zip_file):
|
183 |
-
os.makedirs(os.path.dirname(zip_file), exist_ok=True)
|
184 |
-
with zipfile.ZipFile(zip_file, "w") as expt_zip:
|
185 |
-
for root_dir in root_dirs:
|
186 |
-
if root_dir is None:
|
187 |
-
continue
|
188 |
-
for root, d, files in os.walk(root_dir):
|
189 |
-
for file in files:
|
190 |
-
file_to_archive = os.path.join(root, file)
|
191 |
-
assert os.path.exists(file_to_archive)
|
192 |
-
path_to_archive = os.path.relpath(file_to_archive, base_dir)
|
193 |
-
expt_zip.write(filename=file_to_archive, arcname=path_to_archive)
|
194 |
-
return zip_file, zip_file
|
195 |
-
|
196 |
-
|
197 |
-
def save_generate_output(prompt=None, output=None, base_model=None, save_dir=None, where_from='unknown where from',
|
198 |
-
extra_dict={}):
|
199 |
-
try:
|
200 |
-
return _save_generate_output(prompt=prompt, output=output, base_model=base_model, save_dir=save_dir,
|
201 |
-
where_from=where_from, extra_dict=extra_dict)
|
202 |
-
except Exception as e:
|
203 |
-
traceback.print_exc()
|
204 |
-
print('Exception in saving: %s' % str(e))
|
205 |
-
|
206 |
-
|
207 |
-
def _save_generate_output(prompt=None, output=None, base_model=None, save_dir=None, where_from='unknown where from',
|
208 |
-
extra_dict={}):
|
209 |
-
"""
|
210 |
-
Save conversation to .json, row by row.
|
211 |
-
json_file_path is path to final JSON file. If not in ., then will attempt to make directories.
|
212 |
-
Appends if file exists
|
213 |
-
"""
|
214 |
-
prompt = '<not set>' if prompt is None else prompt
|
215 |
-
output = '<not set>' if output is None else output
|
216 |
-
assert save_dir, "save_dir must be provided"
|
217 |
-
if os.path.exists(save_dir) and not os.path.isdir(save_dir):
|
218 |
-
raise RuntimeError("save_dir already exists and is not a directory!")
|
219 |
-
os.makedirs(save_dir, exist_ok=True)
|
220 |
-
import json
|
221 |
-
dict_to_save = dict(prompt=prompt, text=output, time=time.ctime(), base_model=base_model, where_from=where_from)
|
222 |
-
dict_to_save.update(extra_dict)
|
223 |
-
with filelock.FileLock("save_dir.lock"):
|
224 |
-
# lock logging in case have concurrency
|
225 |
-
with open(os.path.join(save_dir, "history.json"), "a") as f:
|
226 |
-
# just add [ at start, and ] at end, and have proper JSON dataset
|
227 |
-
f.write(
|
228 |
-
" " + json.dumps(
|
229 |
-
dict_to_save
|
230 |
-
) + ",\n"
|
231 |
-
)
|
232 |
-
|
233 |
-
|
234 |
-
def s3up(filename):
|
235 |
-
try:
|
236 |
-
return _s3up(filename)
|
237 |
-
except Exception as e:
|
238 |
-
traceback.print_exc()
|
239 |
-
print('Exception for file %s in s3up: %s' % (filename, str(e)))
|
240 |
-
return "Failed to upload %s: Error: %s" % (filename, str(e))
|
241 |
-
|
242 |
-
|
243 |
-
def _s3up(filename):
|
244 |
-
import boto3
|
245 |
-
|
246 |
-
aws_access_key_id = os.getenv('AWS_SERVER_PUBLIC_KEY')
|
247 |
-
aws_secret_access_key = os.getenv('AWS_SERVER_SECRET_KEY')
|
248 |
-
bucket = os.getenv('AWS_BUCKET')
|
249 |
-
assert aws_access_key_id, "Set AWS key"
|
250 |
-
assert aws_secret_access_key, "Set AWS secret"
|
251 |
-
assert bucket, "Set AWS Bucket"
|
252 |
-
|
253 |
-
s3 = boto3.client('s3',
|
254 |
-
aws_access_key_id=os.getenv('AWS_SERVER_PUBLIC_KEY'),
|
255 |
-
aws_secret_access_key=os.getenv('AWS_SERVER_SECRET_KEY'),
|
256 |
-
)
|
257 |
-
ret = s3.upload_file(
|
258 |
-
Filename=filename,
|
259 |
-
Bucket=os.getenv('AWS_BUCKET'),
|
260 |
-
Key=filename,
|
261 |
-
)
|
262 |
-
if ret in [None, '']:
|
263 |
-
return "Successfully uploaded %s" % filename
|
264 |
-
|
265 |
-
|
266 |
-
def get_githash():
|
267 |
-
try:
|
268 |
-
githash = subprocess.run(['git', 'rev-parse', 'HEAD'], stdout=subprocess.PIPE).stdout.decode('utf-8')[0:-1]
|
269 |
-
except:
|
270 |
-
githash = ''
|
271 |
-
return githash
|
272 |
-
|
273 |
-
|
274 |
-
def copy_code(run_id):
|
275 |
-
"""
|
276 |
-
copy code to track changes
|
277 |
-
:param run_id:
|
278 |
-
:return:
|
279 |
-
"""
|
280 |
-
rnd_num = str(random.randint(0, 2 ** 31))
|
281 |
-
run_id = 'run_' + str(run_id)
|
282 |
-
os.makedirs(run_id, exist_ok=True)
|
283 |
-
me_full = os.path.join(pathlib.Path(__file__).parent.resolve(), __file__)
|
284 |
-
me_file = os.path.basename(__file__)
|
285 |
-
new_me = os.path.join(run_id, me_file + '_' + get_githash())
|
286 |
-
if os.path.isfile(new_me):
|
287 |
-
new_me = os.path.join(run_id, me_file + '_' + get_githash() + '_' + rnd_num)
|
288 |
-
shutil.copy(me_full, new_me)
|
289 |
-
else:
|
290 |
-
shutil.copy(me_full, new_me)
|
291 |
-
|
292 |
-
|
293 |
-
class NullContext(threading.local):
|
294 |
-
"""No-op context manager, executes block without doing any additional processing.
|
295 |
-
|
296 |
-
Used as a stand-in if a particular block of code is only sometimes
|
297 |
-
used with a normal context manager:
|
298 |
-
"""
|
299 |
-
|
300 |
-
def __init__(self, *args, **kwargs):
|
301 |
-
pass
|
302 |
-
|
303 |
-
def __enter__(self):
|
304 |
-
return self
|
305 |
-
|
306 |
-
def __exit__(self, exc_type, exc_value, exc_traceback):
|
307 |
-
self.finally_act()
|
308 |
-
|
309 |
-
def finally_act(self):
|
310 |
-
pass
|
311 |
-
|
312 |
-
|
313 |
-
def wrapped_partial(func, *args, **kwargs):
|
314 |
-
"""
|
315 |
-
Give partial properties of normal function, like __name__ attribute etc.
|
316 |
-
:param func:
|
317 |
-
:param args:
|
318 |
-
:param kwargs:
|
319 |
-
:return:
|
320 |
-
"""
|
321 |
-
partial_func = functools.partial(func, *args, **kwargs)
|
322 |
-
functools.update_wrapper(partial_func, func)
|
323 |
-
return partial_func
|
324 |
-
|
325 |
-
|
326 |
-
class ThreadException(Exception):
|
327 |
-
pass
|
328 |
-
|
329 |
-
|
330 |
-
class EThread(threading.Thread):
|
331 |
-
# Function that raises the custom exception
|
332 |
-
def __init__(self, group=None, target=None, name=None,
|
333 |
-
args=(), kwargs=None, *, daemon=None, streamer=None, bucket=None):
|
334 |
-
self.bucket = bucket
|
335 |
-
self.streamer = streamer
|
336 |
-
self.exc = None
|
337 |
-
self._return = None
|
338 |
-
super().__init__(group=group, target=target, name=name, args=args, kwargs=kwargs, daemon=daemon)
|
339 |
-
|
340 |
-
def run(self):
|
341 |
-
# Variable that stores the exception, if raised by someFunction
|
342 |
-
try:
|
343 |
-
if self._target is not None:
|
344 |
-
self._return = self._target(*self._args, **self._kwargs)
|
345 |
-
except BaseException as e:
|
346 |
-
print("thread exception: %s" % str(sys.exc_info()))
|
347 |
-
self.bucket.put(sys.exc_info())
|
348 |
-
self.exc = e
|
349 |
-
if self.streamer:
|
350 |
-
print("make stop: %s" % str(sys.exc_info()), flush=True)
|
351 |
-
self.streamer.do_stop = True
|
352 |
-
finally:
|
353 |
-
# Avoid a refcycle if the thread is running a function with
|
354 |
-
# an argument that has a member that points to the thread.
|
355 |
-
del self._target, self._args, self._kwargs
|
356 |
-
|
357 |
-
def join(self, timeout=None):
|
358 |
-
threading.Thread.join(self)
|
359 |
-
# Since join() returns in caller thread
|
360 |
-
# we re-raise the caught exception
|
361 |
-
# if any was caught
|
362 |
-
if self.exc:
|
363 |
-
raise self.exc
|
364 |
-
return self._return
|
365 |
-
|
366 |
-
|
367 |
-
def import_matplotlib():
|
368 |
-
import matplotlib
|
369 |
-
matplotlib.use('agg')
|
370 |
-
# KEEP THESE HERE! START
|
371 |
-
import matplotlib.pyplot as plt
|
372 |
-
import pandas as pd
|
373 |
-
# to avoid dlopen deadlock in fork
|
374 |
-
import pandas.core.computation.expressions as pd_expressions
|
375 |
-
import pandas._libs.groupby as pd_libgroupby
|
376 |
-
import pandas._libs.reduction as pd_libreduction
|
377 |
-
import pandas.core.algorithms as pd_algorithms
|
378 |
-
import pandas.core.common as pd_com
|
379 |
-
import numpy as np
|
380 |
-
# KEEP THESE HERE! END
|
381 |
-
|
382 |
-
|
383 |
-
def get_sha(value):
|
384 |
-
return hashlib.md5(str(value).encode('utf-8')).hexdigest()
|
385 |
-
|
386 |
-
|
387 |
-
def sanitize_filename(name):
|
388 |
-
"""
|
389 |
-
Sanitize file *base* names.
|
390 |
-
:param name: name to sanitize
|
391 |
-
:return:
|
392 |
-
"""
|
393 |
-
bad_chars = ['[', ']', ',', '/', '\\', '\\w', '\\s', '-', '+', '\"', '\'', '>', '<', ' ', '=', ')', '(', ':', '^']
|
394 |
-
for char in bad_chars:
|
395 |
-
name = name.replace(char, "_")
|
396 |
-
|
397 |
-
length = len(name)
|
398 |
-
file_length_limit = 250 # bit smaller than 256 for safety
|
399 |
-
sha_length = 32
|
400 |
-
real_length_limit = file_length_limit - (sha_length + 2)
|
401 |
-
if length > file_length_limit:
|
402 |
-
sha = get_sha(name)
|
403 |
-
half_real_length_limit = max(1, int(real_length_limit / 2))
|
404 |
-
name = name[0:half_real_length_limit] + "_" + sha + "_" + name[length - half_real_length_limit:length]
|
405 |
-
|
406 |
-
return name
|
407 |
-
|
408 |
-
|
409 |
-
def shutil_rmtree(*args, **kwargs):
|
410 |
-
return shutil.rmtree(*args, **kwargs)
|
411 |
-
|
412 |
-
|
413 |
-
def remove(path: str):
|
414 |
-
try:
|
415 |
-
if path is not None and os.path.exists(path):
|
416 |
-
if os.path.isdir(path):
|
417 |
-
shutil_rmtree(path, ignore_errors=True)
|
418 |
-
else:
|
419 |
-
with contextlib.suppress(FileNotFoundError):
|
420 |
-
os.remove(path)
|
421 |
-
except:
|
422 |
-
pass
|
423 |
-
|
424 |
-
|
425 |
-
def makedirs(path, exist_ok=True):
|
426 |
-
"""
|
427 |
-
Avoid some inefficiency in os.makedirs()
|
428 |
-
:param path:
|
429 |
-
:param exist_ok:
|
430 |
-
:return:
|
431 |
-
"""
|
432 |
-
if os.path.isdir(path) and os.path.exists(path):
|
433 |
-
assert exist_ok, "Path already exists"
|
434 |
-
return path
|
435 |
-
os.makedirs(path, exist_ok=exist_ok)
|
436 |
-
|
437 |
-
|
438 |
-
def atomic_move_simple(src, dst):
|
439 |
-
try:
|
440 |
-
shutil.move(src, dst)
|
441 |
-
except (shutil.Error, FileExistsError):
|
442 |
-
pass
|
443 |
-
remove(src)
|
444 |
-
|
445 |
-
|
446 |
-
def download_simple(url, dest=None, print_func=None):
|
447 |
-
if print_func is not None:
|
448 |
-
print_func("BEGIN get url %s" % str(url))
|
449 |
-
if url.startswith("file://"):
|
450 |
-
from requests_file import FileAdapter
|
451 |
-
s = requests.Session()
|
452 |
-
s.mount('file://', FileAdapter())
|
453 |
-
url_data = s.get(url, stream=True)
|
454 |
-
else:
|
455 |
-
url_data = requests.get(url, stream=True)
|
456 |
-
if dest is None:
|
457 |
-
dest = os.path.basename(url)
|
458 |
-
if url_data.status_code != requests.codes.ok:
|
459 |
-
msg = "Cannot get url %s, code: %s, reason: %s" % (
|
460 |
-
str(url),
|
461 |
-
str(url_data.status_code),
|
462 |
-
str(url_data.reason),
|
463 |
-
)
|
464 |
-
raise requests.exceptions.RequestException(msg)
|
465 |
-
url_data.raw.decode_content = True
|
466 |
-
makedirs(os.path.dirname(dest), exist_ok=True)
|
467 |
-
uuid_tmp = str(uuid.uuid4())[:6]
|
468 |
-
dest_tmp = dest + "_dl_" + uuid_tmp + ".tmp"
|
469 |
-
with open(dest_tmp, "wb") as f:
|
470 |
-
shutil.copyfileobj(url_data.raw, f)
|
471 |
-
atomic_move_simple(dest_tmp, dest)
|
472 |
-
if print_func is not None:
|
473 |
-
print_func("END get url %s" % str(url))
|
474 |
-
|
475 |
-
|
476 |
-
def download(url, dest=None, dest_path=None):
|
477 |
-
if dest_path is not None:
|
478 |
-
dest = os.path.join(dest_path, os.path.basename(url))
|
479 |
-
if os.path.isfile(dest):
|
480 |
-
print("already downloaded %s -> %s" % (url, dest))
|
481 |
-
return dest
|
482 |
-
elif dest is not None:
|
483 |
-
if os.path.exists(dest):
|
484 |
-
print("already downloaded %s -> %s" % (url, dest))
|
485 |
-
return dest
|
486 |
-
else:
|
487 |
-
uuid_tmp = "dl2_" + str(uuid.uuid4())[:6]
|
488 |
-
dest = uuid_tmp + os.path.basename(url)
|
489 |
-
|
490 |
-
print("downloading %s to %s" % (url, dest))
|
491 |
-
|
492 |
-
if url.startswith("file://"):
|
493 |
-
from requests_file import FileAdapter
|
494 |
-
s = requests.Session()
|
495 |
-
s.mount('file://', FileAdapter())
|
496 |
-
url_data = s.get(url, stream=True)
|
497 |
-
else:
|
498 |
-
url_data = requests.get(url, stream=True)
|
499 |
-
|
500 |
-
if url_data.status_code != requests.codes.ok:
|
501 |
-
msg = "Cannot get url %s, code: %s, reason: %s" % (
|
502 |
-
str(url), str(url_data.status_code), str(url_data.reason))
|
503 |
-
raise requests.exceptions.RequestException(msg)
|
504 |
-
url_data.raw.decode_content = True
|
505 |
-
dirname = os.path.dirname(dest)
|
506 |
-
if dirname != "" and not os.path.isdir(dirname):
|
507 |
-
makedirs(os.path.dirname(dest), exist_ok=True)
|
508 |
-
uuid_tmp = "dl3_" + str(uuid.uuid4())[:6]
|
509 |
-
dest_tmp = dest + "_" + uuid_tmp + ".tmp"
|
510 |
-
with open(dest_tmp, 'wb') as f:
|
511 |
-
shutil.copyfileobj(url_data.raw, f)
|
512 |
-
try:
|
513 |
-
shutil.move(dest_tmp, dest)
|
514 |
-
except FileExistsError:
|
515 |
-
pass
|
516 |
-
remove(dest_tmp)
|
517 |
-
return dest
|
518 |
-
|
519 |
-
|
520 |
-
def get_url(x, from_str=False, short_name=False):
|
521 |
-
if not from_str:
|
522 |
-
source = x.metadata['source']
|
523 |
-
else:
|
524 |
-
source = x
|
525 |
-
if short_name:
|
526 |
-
source_name = get_short_name(source)
|
527 |
-
else:
|
528 |
-
source_name = source
|
529 |
-
if source.startswith('http://') or source.startswith('https://'):
|
530 |
-
return """<a href="%s" target="_blank" rel="noopener noreferrer">%s</a>""" % (
|
531 |
-
source, source_name)
|
532 |
-
else:
|
533 |
-
return """<a href="file/%s" target="_blank" rel="noopener noreferrer">%s</a>""" % (
|
534 |
-
source, source_name)
|
535 |
-
|
536 |
-
|
537 |
-
def get_short_name(name, maxl=50):
|
538 |
-
if name is None:
|
539 |
-
return ''
|
540 |
-
length = len(name)
|
541 |
-
if length > maxl:
|
542 |
-
allow_length = maxl - 3
|
543 |
-
half_allowed = max(1, int(allow_length / 2))
|
544 |
-
name = name[0:half_allowed] + "..." + name[length - half_allowed:length]
|
545 |
-
return name
|
546 |
-
|
547 |
-
|
548 |
-
def cuda_vis_check(total_gpus):
|
549 |
-
"""Helper function to count GPUs by environment variable
|
550 |
-
Stolen from Jon's h2o4gpu utils
|
551 |
-
"""
|
552 |
-
cudavis = os.getenv("CUDA_VISIBLE_DEVICES")
|
553 |
-
which_gpus = []
|
554 |
-
if cudavis is not None:
|
555 |
-
# prune away white-space, non-numerics,
|
556 |
-
# except commas for simple checking
|
557 |
-
cudavis = "".join(cudavis.split())
|
558 |
-
import re
|
559 |
-
cudavis = re.sub("[^0-9,]", "", cudavis)
|
560 |
-
|
561 |
-
lencudavis = len(cudavis)
|
562 |
-
if lencudavis == 0:
|
563 |
-
total_gpus = 0
|
564 |
-
else:
|
565 |
-
total_gpus = min(
|
566 |
-
total_gpus,
|
567 |
-
os.getenv("CUDA_VISIBLE_DEVICES").count(",") + 1)
|
568 |
-
which_gpus = os.getenv("CUDA_VISIBLE_DEVICES").split(",")
|
569 |
-
which_gpus = [int(x) for x in which_gpus]
|
570 |
-
else:
|
571 |
-
which_gpus = list(range(0, total_gpus))
|
572 |
-
|
573 |
-
return total_gpus, which_gpus
|
574 |
-
|
575 |
-
|
576 |
-
def get_ngpus_vis(raise_if_exception=True):
|
577 |
-
ngpus_vis1 = 0
|
578 |
-
|
579 |
-
shell = False
|
580 |
-
if shell:
|
581 |
-
cmd = "nvidia-smi -L 2> /dev/null"
|
582 |
-
else:
|
583 |
-
cmd = ["nvidia-smi", "-L"]
|
584 |
-
|
585 |
-
try:
|
586 |
-
timeout = 5 * 3
|
587 |
-
o = subprocess.check_output(cmd, shell=shell, timeout=timeout)
|
588 |
-
lines = o.decode("utf-8").splitlines()
|
589 |
-
ngpus_vis1 = 0
|
590 |
-
for line in lines:
|
591 |
-
if 'Failed to initialize NVML' not in line:
|
592 |
-
ngpus_vis1 += 1
|
593 |
-
except (FileNotFoundError, subprocess.CalledProcessError, OSError):
|
594 |
-
# GPU systems might not have nvidia-smi, so can't fail
|
595 |
-
pass
|
596 |
-
except subprocess.TimeoutExpired as e:
|
597 |
-
print('Failed get_ngpus_vis: %s' % str(e))
|
598 |
-
if raise_if_exception:
|
599 |
-
raise
|
600 |
-
|
601 |
-
ngpus_vis1, which_gpus = cuda_vis_check(ngpus_vis1)
|
602 |
-
return ngpus_vis1
|
603 |
-
|
604 |
-
|
605 |
-
def get_mem_gpus(raise_if_exception=True, ngpus=None):
|
606 |
-
totalmem_gpus1 = 0
|
607 |
-
usedmem_gpus1 = 0
|
608 |
-
freemem_gpus1 = 0
|
609 |
-
|
610 |
-
if ngpus == 0:
|
611 |
-
return totalmem_gpus1, usedmem_gpus1, freemem_gpus1
|
612 |
-
|
613 |
-
try:
|
614 |
-
cmd = "nvidia-smi -q 2> /dev/null | grep -A 3 'FB Memory Usage'"
|
615 |
-
o = subprocess.check_output(cmd, shell=True, timeout=15)
|
616 |
-
lines = o.decode("utf-8").splitlines()
|
617 |
-
for line in lines:
|
618 |
-
if 'Total' in line:
|
619 |
-
totalmem_gpus1 += int(line.split()[2]) * 1024 ** 2
|
620 |
-
if 'Used' in line:
|
621 |
-
usedmem_gpus1 += int(line.split()[2]) * 1024 ** 2
|
622 |
-
if 'Free' in line:
|
623 |
-
freemem_gpus1 += int(line.split()[2]) * 1024 ** 2
|
624 |
-
except (FileNotFoundError, subprocess.CalledProcessError, OSError):
|
625 |
-
# GPU systems might not have nvidia-smi, so can't fail
|
626 |
-
pass
|
627 |
-
except subprocess.TimeoutExpired as e:
|
628 |
-
print('Failed get_mem_gpus: %s' % str(e))
|
629 |
-
if raise_if_exception:
|
630 |
-
raise
|
631 |
-
|
632 |
-
return totalmem_gpus1, usedmem_gpus1, freemem_gpus1
|
633 |
-
|
634 |
-
|
635 |
-
class ForkContext(threading.local):
|
636 |
-
"""
|
637 |
-
Set context for forking
|
638 |
-
Ensures state is returned once done
|
639 |
-
"""
|
640 |
-
|
641 |
-
def __init__(self, args=None, kwargs=None, forkdata_capable=True):
|
642 |
-
"""
|
643 |
-
:param args:
|
644 |
-
:param kwargs:
|
645 |
-
:param forkdata_capable: whether fork is forkdata capable and will use copy-on-write forking of args/kwargs
|
646 |
-
"""
|
647 |
-
self.forkdata_capable = forkdata_capable
|
648 |
-
if self.forkdata_capable:
|
649 |
-
self.has_args = args is not None
|
650 |
-
self.has_kwargs = kwargs is not None
|
651 |
-
forkdatacontext.args = args
|
652 |
-
forkdatacontext.kwargs = kwargs
|
653 |
-
else:
|
654 |
-
self.has_args = False
|
655 |
-
self.has_kwargs = False
|
656 |
-
|
657 |
-
def __enter__(self):
|
658 |
-
try:
|
659 |
-
# flush all outputs so doesn't happen during fork -- don't print/log inside ForkContext contexts!
|
660 |
-
sys.stdout.flush()
|
661 |
-
sys.stderr.flush()
|
662 |
-
except BaseException as e:
|
663 |
-
# exit not called if exception, and don't want to leave forkdatacontext filled in that case
|
664 |
-
print("ForkContext failure on enter: %s" % str(e))
|
665 |
-
self.finally_act()
|
666 |
-
raise
|
667 |
-
return self
|
668 |
-
|
669 |
-
def __exit__(self, exc_type, exc_value, exc_traceback):
|
670 |
-
self.finally_act()
|
671 |
-
|
672 |
-
def finally_act(self):
|
673 |
-
"""
|
674 |
-
Done when exception hit or exit is reached in context
|
675 |
-
first reset forkdatacontext as crucial to have reset even if later 2 calls fail
|
676 |
-
:return: None
|
677 |
-
"""
|
678 |
-
if self.forkdata_capable and (self.has_args or self.has_kwargs):
|
679 |
-
forkdatacontext._reset()
|
680 |
-
|
681 |
-
|
682 |
-
class _ForkDataContext(threading.local):
|
683 |
-
def __init__(
|
684 |
-
self,
|
685 |
-
args=None,
|
686 |
-
kwargs=None,
|
687 |
-
):
|
688 |
-
"""
|
689 |
-
Global context for fork to carry data to subprocess instead of relying upon copy/pickle/serialization
|
690 |
-
|
691 |
-
:param args: args
|
692 |
-
:param kwargs: kwargs
|
693 |
-
"""
|
694 |
-
assert isinstance(args, (tuple, type(None)))
|
695 |
-
assert isinstance(kwargs, (dict, type(None)))
|
696 |
-
self.__args = args
|
697 |
-
self.__kwargs = kwargs
|
698 |
-
|
699 |
-
@property
|
700 |
-
def args(self) -> Tuple:
|
701 |
-
"""returns args"""
|
702 |
-
return self.__args
|
703 |
-
|
704 |
-
@args.setter
|
705 |
-
def args(self, args):
|
706 |
-
if self.__args is not None:
|
707 |
-
raise AttributeError(
|
708 |
-
"args cannot be overwritten: %s %s" % (str(self.__args), str(self.__kwargs))
|
709 |
-
)
|
710 |
-
|
711 |
-
self.__args = args
|
712 |
-
|
713 |
-
@property
|
714 |
-
def kwargs(self) -> Dict:
|
715 |
-
"""returns kwargs"""
|
716 |
-
return self.__kwargs
|
717 |
-
|
718 |
-
@kwargs.setter
|
719 |
-
def kwargs(self, kwargs):
|
720 |
-
if self.__kwargs is not None:
|
721 |
-
raise AttributeError(
|
722 |
-
"kwargs cannot be overwritten: %s %s" % (str(self.__args), str(self.__kwargs))
|
723 |
-
)
|
724 |
-
|
725 |
-
self.__kwargs = kwargs
|
726 |
-
|
727 |
-
def _reset(self):
|
728 |
-
"""Reset fork arg-kwarg context to default values"""
|
729 |
-
self.__args = None
|
730 |
-
self.__kwargs = None
|
731 |
-
|
732 |
-
def get_args_kwargs(self, func, args, kwargs) -> Tuple[Callable, Tuple, Dict]:
|
733 |
-
if self.__args:
|
734 |
-
args = self.__args[1:]
|
735 |
-
if not func:
|
736 |
-
assert len(self.__args) > 0, "if have no func, must have in args"
|
737 |
-
func = self.__args[0] # should always be there
|
738 |
-
if self.__kwargs:
|
739 |
-
kwargs = self.__kwargs
|
740 |
-
try:
|
741 |
-
return func, args, kwargs
|
742 |
-
finally:
|
743 |
-
forkdatacontext._reset()
|
744 |
-
|
745 |
-
@staticmethod
|
746 |
-
def get_args_kwargs_for_traced_func(func, args, kwargs):
|
747 |
-
"""
|
748 |
-
Return args/kwargs out of forkdatacontext when using copy-on-write way of passing args/kwargs
|
749 |
-
:param func: actual function ran by _traced_func, which itself is directly what mppool treats as function
|
750 |
-
:param args:
|
751 |
-
:param kwargs:
|
752 |
-
:return: func, args, kwargs from forkdatacontext if used, else originals
|
753 |
-
"""
|
754 |
-
# first 3 lines are debug
|
755 |
-
func_was_None = func is None
|
756 |
-
args_was_None_or_empty = args is None or len(args) == 0
|
757 |
-
kwargs_was_None_or_empty = kwargs is None or len(kwargs) == 0
|
758 |
-
|
759 |
-
forkdatacontext_args_was_None = forkdatacontext.args is None
|
760 |
-
forkdatacontext_kwargs_was_None = forkdatacontext.kwargs is None
|
761 |
-
func, args, kwargs = forkdatacontext.get_args_kwargs(func, args, kwargs)
|
762 |
-
using_forkdatacontext = func_was_None and func is not None # pulled func out of forkdatacontext.__args[0]
|
763 |
-
assert forkdatacontext.args is None, "forkdatacontext.args should be None after get_args_kwargs"
|
764 |
-
assert forkdatacontext.kwargs is None, "forkdatacontext.kwargs should be None after get_args_kwargs"
|
765 |
-
|
766 |
-
proc_type = kwargs.get('proc_type', 'SUBPROCESS')
|
767 |
-
if using_forkdatacontext:
|
768 |
-
assert proc_type == "SUBPROCESS" or proc_type == "SUBPROCESS"
|
769 |
-
if proc_type == "NORMAL":
|
770 |
-
assert forkdatacontext_args_was_None, "if no fork, expect forkdatacontext.args None entering _traced_func"
|
771 |
-
assert forkdatacontext_kwargs_was_None, "if no fork, expect forkdatacontext.kwargs None entering _traced_func"
|
772 |
-
assert func is not None, "function should not be None, indicates original args[0] was None or args was None"
|
773 |
-
|
774 |
-
return func, args, kwargs
|
775 |
-
|
776 |
-
|
777 |
-
forkdatacontext = _ForkDataContext()
|
778 |
-
|
779 |
-
|
780 |
-
def _traced_func(func, *args, **kwargs):
|
781 |
-
func, args, kwargs = forkdatacontext.get_args_kwargs_for_traced_func(func, args, kwargs)
|
782 |
-
return func(*args, **kwargs)
|
783 |
-
|
784 |
-
|
785 |
-
def call_subprocess_onetask(func, args=None, kwargs=None):
|
786 |
-
import platform
|
787 |
-
if platform.system() in ['Darwin', 'Windows']:
|
788 |
-
return func(*args, **kwargs)
|
789 |
-
if isinstance(args, list):
|
790 |
-
args = tuple(args)
|
791 |
-
if args is None:
|
792 |
-
args = ()
|
793 |
-
if kwargs is None:
|
794 |
-
kwargs = {}
|
795 |
-
args = list(args)
|
796 |
-
args = [func] + args
|
797 |
-
args = tuple(args)
|
798 |
-
with ForkContext(args=args, kwargs=kwargs):
|
799 |
-
args = (None,)
|
800 |
-
kwargs = {}
|
801 |
-
with ProcessPoolExecutor(max_workers=1) as executor:
|
802 |
-
future = executor.submit(_traced_func, *args, **kwargs)
|
803 |
-
return future.result()
|
804 |
-
|
805 |
-
|
806 |
-
class ProgressParallel(Parallel):
|
807 |
-
def __init__(self, use_tqdm=True, total=None, *args, **kwargs):
|
808 |
-
self._use_tqdm = use_tqdm
|
809 |
-
self._total = total
|
810 |
-
super().__init__(*args, **kwargs)
|
811 |
-
|
812 |
-
def __call__(self, *args, **kwargs):
|
813 |
-
with tqdm(disable=not self._use_tqdm, total=self._total) as self._pbar:
|
814 |
-
return Parallel.__call__(self, *args, **kwargs)
|
815 |
-
|
816 |
-
def print_progress(self):
|
817 |
-
if self._total is None:
|
818 |
-
self._pbar.total = self.n_dispatched_tasks
|
819 |
-
self._pbar.n = self.n_completed_tasks
|
820 |
-
self._pbar.refresh()
|
821 |
-
|
822 |
-
|
823 |
-
def get_kwargs(func, exclude_names=None, **kwargs):
|
824 |
-
func_names = list(inspect.signature(func).parameters)
|
825 |
-
missing_kwargs = [x for x in func_names if x not in kwargs]
|
826 |
-
if exclude_names:
|
827 |
-
for k in exclude_names:
|
828 |
-
if k in missing_kwargs:
|
829 |
-
missing_kwargs.remove(k)
|
830 |
-
if k in func_names:
|
831 |
-
func_names.remove(k)
|
832 |
-
assert not missing_kwargs, "Missing %s" % missing_kwargs
|
833 |
-
kwargs = {k: v for k, v in kwargs.items() if k in func_names}
|
834 |
-
return kwargs
|
835 |
-
|
836 |
-
|
837 |
-
import pkg_resources
|
838 |
-
|
839 |
-
have_faiss = False
|
840 |
-
|
841 |
-
try:
|
842 |
-
assert pkg_resources.get_distribution('faiss') is not None
|
843 |
-
have_faiss = True
|
844 |
-
except (pkg_resources.DistributionNotFound, AssertionError):
|
845 |
-
pass
|
846 |
-
try:
|
847 |
-
assert pkg_resources.get_distribution('faiss_gpu') is not None
|
848 |
-
have_faiss = True
|
849 |
-
except (pkg_resources.DistributionNotFound, AssertionError):
|
850 |
-
pass
|
851 |
-
try:
|
852 |
-
assert pkg_resources.get_distribution('faiss_cpu') is not None
|
853 |
-
have_faiss = True
|
854 |
-
except (pkg_resources.DistributionNotFound, AssertionError):
|
855 |
-
pass
|
856 |
-
|
857 |
-
|
858 |
-
def hash_file(file):
|
859 |
-
try:
|
860 |
-
import hashlib
|
861 |
-
|
862 |
-
# BUF_SIZE is totally arbitrary, change for your app!
|
863 |
-
BUF_SIZE = 65536 # lets read stuff in 64kb chunks!
|
864 |
-
|
865 |
-
md5 = hashlib.md5()
|
866 |
-
# sha1 = hashlib.sha1()
|
867 |
-
|
868 |
-
with open(file, 'rb') as f:
|
869 |
-
while True:
|
870 |
-
data = f.read(BUF_SIZE)
|
871 |
-
if not data:
|
872 |
-
break
|
873 |
-
md5.update(data)
|
874 |
-
# sha1.update(data)
|
875 |
-
except BaseException as e:
|
876 |
-
print("Cannot hash %s due to %s" % (file, str(e)))
|
877 |
-
traceback.print_exc()
|
878 |
-
md5 = None
|
879 |
-
return md5.hexdigest()
|
880 |
-
|
881 |
-
|
882 |
-
def start_faulthandler():
|
883 |
-
# If hit server or any subprocess with signal SIGUSR1, it'll print out all threads stack trace, but wont't quit or coredump
|
884 |
-
# If more than one fork tries to write at same time, then looks corrupted.
|
885 |
-
import faulthandler
|
886 |
-
|
887 |
-
# SIGUSR1 in h2oai/__init__.py as well
|
888 |
-
faulthandler.enable()
|
889 |
-
if hasattr(faulthandler, 'register'):
|
890 |
-
# windows/mac
|
891 |
-
import signal
|
892 |
-
faulthandler.register(signal.SIGUSR1)
|
893 |
-
|
894 |
-
|
895 |
-
def get_hf_server(inference_server):
|
896 |
-
inf_split = inference_server.split(" ")
|
897 |
-
assert len(inf_split) == 1 or len(inf_split) == 3
|
898 |
-
inference_server = inf_split[0]
|
899 |
-
if len(inf_split) == 3:
|
900 |
-
headers = {"authorization": "%s %s" % (inf_split[1], inf_split[2])}
|
901 |
-
else:
|
902 |
-
headers = None
|
903 |
-
return inference_server, headers
|
904 |
-
|
905 |
-
|
906 |
-
class FakeTokenizer:
|
907 |
-
"""
|
908 |
-
1) For keeping track of model_max_length
|
909 |
-
2) For when model doesn't directly expose tokenizer but need to count tokens
|
910 |
-
"""
|
911 |
-
|
912 |
-
def __init__(self, model_max_length=2048, encoding_name="cl100k_base"):
|
913 |
-
# dont' push limit, since if using fake tokenizer, only estimate, and seen underestimates by order 250
|
914 |
-
self.model_max_length = model_max_length - 250
|
915 |
-
self.encoding_name = encoding_name
|
916 |
-
# The first time this runs, it will require an internet connection to download. Later runs won't need an internet connection.
|
917 |
-
import tiktoken
|
918 |
-
self.encoding = tiktoken.get_encoding(self.encoding_name)
|
919 |
-
|
920 |
-
def encode(self, x, *args, return_tensors="pt", **kwargs):
|
921 |
-
input_ids = self.encoding.encode(x, disallowed_special=())
|
922 |
-
if return_tensors == 'pt' and isinstance(input_ids, list):
|
923 |
-
import torch
|
924 |
-
input_ids = torch.tensor(input_ids)
|
925 |
-
return dict(input_ids=input_ids)
|
926 |
-
|
927 |
-
def decode(self, x, *args, **kwargs):
|
928 |
-
# input is input_ids[0] form
|
929 |
-
return self.encoding.decode(x)
|
930 |
-
|
931 |
-
def num_tokens_from_string(self, prompt: str) -> int:
|
932 |
-
"""Returns the number of tokens in a text string."""
|
933 |
-
num_tokens = len(self.encoding.encode(prompt))
|
934 |
-
return num_tokens
|
935 |
-
|
936 |
-
def __call__(self, x, *args, **kwargs):
|
937 |
-
return self.encode(x, *args, **kwargs)
|
938 |
-
|
939 |
-
|
940 |
-
def get_local_ip():
|
941 |
-
import socket
|
942 |
-
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
943 |
-
try:
|
944 |
-
# doesn't even have to be reachable
|
945 |
-
s.connect(('10.255.255.255', 1))
|
946 |
-
IP = s.getsockname()[0]
|
947 |
-
except Exception:
|
948 |
-
IP = '127.0.0.1'
|
949 |
-
finally:
|
950 |
-
s.close()
|
951 |
-
return IP
|
952 |
-
|
953 |
-
|
954 |
-
try:
|
955 |
-
assert pkg_resources.get_distribution('langchain') is not None
|
956 |
-
have_langchain = True
|
957 |
-
except (pkg_resources.DistributionNotFound, AssertionError):
|
958 |
-
have_langchain = False
|
959 |
-
|
960 |
-
import distutils.spawn
|
961 |
-
|
962 |
-
have_tesseract = distutils.spawn.find_executable("tesseract")
|
963 |
-
have_libreoffice = distutils.spawn.find_executable("libreoffice")
|
964 |
-
|
965 |
-
import pkg_resources
|
966 |
-
|
967 |
-
try:
|
968 |
-
assert pkg_resources.get_distribution('arxiv') is not None
|
969 |
-
assert pkg_resources.get_distribution('pymupdf') is not None
|
970 |
-
have_arxiv = True
|
971 |
-
except (pkg_resources.DistributionNotFound, AssertionError):
|
972 |
-
have_arxiv = False
|
973 |
-
|
974 |
-
try:
|
975 |
-
assert pkg_resources.get_distribution('pymupdf') is not None
|
976 |
-
have_pymupdf = True
|
977 |
-
except (pkg_resources.DistributionNotFound, AssertionError):
|
978 |
-
have_pymupdf = False
|
979 |
-
|
980 |
-
try:
|
981 |
-
assert pkg_resources.get_distribution('selenium') is not None
|
982 |
-
have_selenium = True
|
983 |
-
except (pkg_resources.DistributionNotFound, AssertionError):
|
984 |
-
have_selenium = False
|
985 |
-
|
986 |
-
try:
|
987 |
-
assert pkg_resources.get_distribution('playwright') is not None
|
988 |
-
have_playwright = True
|
989 |
-
except (pkg_resources.DistributionNotFound, AssertionError):
|
990 |
-
have_playwright = False
|
991 |
-
|
992 |
-
# disable, hangs too often
|
993 |
-
have_playwright = False
|
994 |
-
|
995 |
-
|
996 |
-
def set_openai(inference_server):
|
997 |
-
if inference_server.startswith('vllm'):
|
998 |
-
import openai_vllm
|
999 |
-
openai_vllm.api_key = "EMPTY"
|
1000 |
-
inf_type = inference_server.split(':')[0]
|
1001 |
-
ip_vllm = inference_server.split(':')[1]
|
1002 |
-
port_vllm = inference_server.split(':')[2]
|
1003 |
-
openai_vllm.api_base = f"http://{ip_vllm}:{port_vllm}/v1"
|
1004 |
-
return openai_vllm, inf_type
|
1005 |
-
else:
|
1006 |
-
import openai
|
1007 |
-
openai.api_key = os.getenv("OPENAI_API_KEY")
|
1008 |
-
openai.api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
|
1009 |
-
inf_type = inference_server
|
1010 |
-
return openai, inf_type
|
1011 |
-
|
1012 |
-
|
1013 |
-
visible_langchain_modes_file = 'visible_langchain_modes.pkl'
|
1014 |
-
|
1015 |
-
|
1016 |
-
def save_collection_names(langchain_modes, visible_langchain_modes, langchain_mode_paths, LangChainMode, db1s):
|
1017 |
-
"""
|
1018 |
-
extra controls if UserData type of MyData type
|
1019 |
-
"""
|
1020 |
-
|
1021 |
-
# use first default MyData hash as general user hash to maintain file
|
1022 |
-
# if user moves MyData from langchain modes, db will still survive, so can still use hash
|
1023 |
-
scratch_collection_names = list(db1s.keys())
|
1024 |
-
user_hash = db1s.get(LangChainMode.MY_DATA.value, '')[1]
|
1025 |
-
|
1026 |
-
llms = ['LLM', 'Disabled']
|
1027 |
-
|
1028 |
-
scratch_langchain_modes = [x for x in langchain_modes if x in scratch_collection_names]
|
1029 |
-
scratch_visible_langchain_modes = [x for x in visible_langchain_modes if x in scratch_collection_names]
|
1030 |
-
scratch_langchain_mode_paths = {k: v for k, v in langchain_mode_paths.items() if
|
1031 |
-
k in scratch_collection_names and k not in llms}
|
1032 |
-
|
1033 |
-
user_langchain_modes = [x for x in langchain_modes if x not in scratch_collection_names]
|
1034 |
-
user_visible_langchain_modes = [x for x in visible_langchain_modes if x not in scratch_collection_names]
|
1035 |
-
user_langchain_mode_paths = {k: v for k, v in langchain_mode_paths.items() if
|
1036 |
-
k not in scratch_collection_names and k not in llms}
|
1037 |
-
|
1038 |
-
base_path = 'locks'
|
1039 |
-
makedirs(base_path)
|
1040 |
-
|
1041 |
-
# user
|
1042 |
-
extra = ''
|
1043 |
-
file = "%s%s" % (visible_langchain_modes_file, extra)
|
1044 |
-
with filelock.FileLock(os.path.join(base_path, "%s.lock" % file)):
|
1045 |
-
with open(file, 'wb') as f:
|
1046 |
-
pickle.dump((user_langchain_modes, user_visible_langchain_modes, user_langchain_mode_paths), f)
|
1047 |
-
|
1048 |
-
# scratch
|
1049 |
-
extra = user_hash
|
1050 |
-
file = "%s%s" % (visible_langchain_modes_file, extra)
|
1051 |
-
with filelock.FileLock(os.path.join(base_path, "%s.lock" % file)):
|
1052 |
-
with open(file, 'wb') as f:
|
1053 |
-
pickle.dump((scratch_langchain_modes, scratch_visible_langchain_modes, scratch_langchain_mode_paths), f)
|
1054 |
-
|
1055 |
-
|
1056 |
-
def load_collection_enum(extra):
|
1057 |
-
"""
|
1058 |
-
extra controls if UserData type of MyData type
|
1059 |
-
"""
|
1060 |
-
file = "%s%s" % (visible_langchain_modes_file, extra)
|
1061 |
-
langchain_modes_from_file = []
|
1062 |
-
visible_langchain_modes_from_file = []
|
1063 |
-
langchain_mode_paths_from_file = {}
|
1064 |
-
if os.path.isfile(visible_langchain_modes_file):
|
1065 |
-
try:
|
1066 |
-
with filelock.FileLock("%s.lock" % file):
|
1067 |
-
with open(file, 'rb') as f:
|
1068 |
-
langchain_modes_from_file, visible_langchain_modes_from_file, langchain_mode_paths_from_file = pickle.load(
|
1069 |
-
f)
|
1070 |
-
except BaseException as e:
|
1071 |
-
print("Cannot load %s, ignoring error: %s" % (file, str(e)), flush=True)
|
1072 |
-
for k, v in langchain_mode_paths_from_file.items():
|
1073 |
-
if v is not None and not os.path.isdir(v) and isinstance(v, str):
|
1074 |
-
# assume was deleted, but need to make again to avoid extra code elsewhere
|
1075 |
-
makedirs(v)
|
1076 |
-
return langchain_modes_from_file, visible_langchain_modes_from_file, langchain_mode_paths_from_file
|
1077 |
-
|
1078 |
-
|
1079 |
-
def remove_collection_enum():
|
1080 |
-
remove(visible_langchain_modes_file)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils_langchain.py
DELETED
@@ -1,64 +0,0 @@
|
|
1 |
-
from typing import Any, Dict, List, Union, Optional
|
2 |
-
import time
|
3 |
-
import queue
|
4 |
-
|
5 |
-
from langchain.callbacks.base import BaseCallbackHandler
|
6 |
-
from langchain.schema import LLMResult
|
7 |
-
|
8 |
-
|
9 |
-
class StreamingGradioCallbackHandler(BaseCallbackHandler):
|
10 |
-
"""
|
11 |
-
Similar to H2OTextIteratorStreamer that is for HF backend, but here LangChain backend
|
12 |
-
"""
|
13 |
-
def __init__(self, timeout: Optional[float] = None, block=True):
|
14 |
-
super().__init__()
|
15 |
-
self.text_queue = queue.SimpleQueue()
|
16 |
-
self.stop_signal = None
|
17 |
-
self.do_stop = False
|
18 |
-
self.timeout = timeout
|
19 |
-
self.block = block
|
20 |
-
|
21 |
-
def on_llm_start(
|
22 |
-
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
23 |
-
) -> None:
|
24 |
-
"""Run when LLM starts running. Clean the queue."""
|
25 |
-
while not self.text_queue.empty():
|
26 |
-
try:
|
27 |
-
self.text_queue.get(block=False)
|
28 |
-
except queue.Empty:
|
29 |
-
continue
|
30 |
-
|
31 |
-
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
32 |
-
"""Run on new LLM token. Only available when streaming is enabled."""
|
33 |
-
self.text_queue.put(token)
|
34 |
-
|
35 |
-
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
36 |
-
"""Run when LLM ends running."""
|
37 |
-
self.text_queue.put(self.stop_signal)
|
38 |
-
|
39 |
-
def on_llm_error(
|
40 |
-
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
41 |
-
) -> None:
|
42 |
-
"""Run when LLM errors."""
|
43 |
-
self.text_queue.put(self.stop_signal)
|
44 |
-
|
45 |
-
def __iter__(self):
|
46 |
-
return self
|
47 |
-
|
48 |
-
def __next__(self):
|
49 |
-
while True:
|
50 |
-
try:
|
51 |
-
value = self.stop_signal # value looks unused in pycharm, not true
|
52 |
-
if self.do_stop:
|
53 |
-
print("hit stop", flush=True)
|
54 |
-
# could raise or break, maybe best to raise and make parent see if any exception in thread
|
55 |
-
raise StopIteration()
|
56 |
-
# break
|
57 |
-
value = self.text_queue.get(block=self.block, timeout=self.timeout)
|
58 |
-
break
|
59 |
-
except queue.Empty:
|
60 |
-
time.sleep(0.01)
|
61 |
-
if value == self.stop_signal:
|
62 |
-
raise StopIteration()
|
63 |
-
else:
|
64 |
-
return value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|