pseudotensor commited on
Commit
4afcc1d
1 Parent(s): 33cecf1
Files changed (16) hide show
  1. client_test.py +0 -362
  2. create_data.py +0 -1809
  3. enums.py +0 -120
  4. evaluate_params.py +0 -52
  5. gen.py +0 -0
  6. generate.py +0 -16
  7. gpt4all_llm.py +0 -316
  8. gpt_langchain.py +0 -0
  9. gradio_runner.py +0 -0
  10. gradio_themes.py +0 -231
  11. h2oai_pipeline.py +0 -201
  12. loaders.py +0 -61
  13. prompter.py +0 -871
  14. stopping.py +0 -78
  15. utils.py +0 -1080
  16. utils_langchain.py +0 -64
client_test.py DELETED
@@ -1,362 +0,0 @@
1
- """
2
- Client test.
3
-
4
- Run server:
5
-
6
- python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b
7
-
8
- NOTE: For private models, add --use-auth_token=True
9
-
10
- NOTE: --use_gpu_id=True (default) must be used for multi-GPU in case see failures with cuda:x cuda:y mismatches.
11
- Currently, this will force model to be on a single GPU.
12
-
13
- Then run this client as:
14
-
15
- python src/client_test.py
16
-
17
-
18
-
19
- For HF spaces:
20
-
21
- HOST="https://h2oai-h2ogpt-chatbot.hf.space" python src/client_test.py
22
-
23
- Result:
24
-
25
- Loaded as API: https://h2oai-h2ogpt-chatbot.hf.space ✔
26
- {'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a large language model developed by LAION.', 'sources': ''}
27
-
28
-
29
- For demo:
30
-
31
- HOST="https://gpt.h2o.ai" python src/client_test.py
32
-
33
- Result:
34
-
35
- Loaded as API: https://gpt.h2o.ai ✔
36
- {'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a chatbot created by LAION.', 'sources': ''}
37
-
38
- NOTE: Raw output from API for nochat case is a string of a python dict and will remain so if other entries are added to dict:
39
-
40
- {'response': "I'm h2oGPT, a large language model by H2O.ai, the visionary leader in democratizing AI.", 'sources': ''}
41
-
42
-
43
- """
44
- import ast
45
- import time
46
- import os
47
- import markdown # pip install markdown
48
- import pytest
49
- from bs4 import BeautifulSoup # pip install beautifulsoup4
50
-
51
- from enums import DocumentSubset, LangChainAction
52
-
53
- debug = False
54
-
55
- os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
56
-
57
-
58
- def get_client(serialize=True):
59
- from gradio_client import Client
60
-
61
- client = Client(os.getenv('HOST', "http://localhost:7860"), serialize=serialize)
62
- if debug:
63
- print(client.view_api(all_endpoints=True))
64
- return client
65
-
66
-
67
- def get_args(prompt, prompt_type, chat=False, stream_output=False,
68
- max_new_tokens=50,
69
- top_k_docs=3,
70
- langchain_mode='Disabled',
71
- add_chat_history_to_context=True,
72
- langchain_action=LangChainAction.QUERY.value,
73
- langchain_agents=[],
74
- prompt_dict=None):
75
- from collections import OrderedDict
76
- kwargs = OrderedDict(instruction=prompt if chat else '', # only for chat=True
77
- iinput='', # only for chat=True
78
- context='',
79
- # streaming output is supported, loops over and outputs each generation in streaming mode
80
- # but leave stream_output=False for simple input/output mode
81
- stream_output=stream_output,
82
- prompt_type=prompt_type,
83
- prompt_dict=prompt_dict,
84
- temperature=0.1,
85
- top_p=0.75,
86
- top_k=40,
87
- num_beams=1,
88
- max_new_tokens=max_new_tokens,
89
- min_new_tokens=0,
90
- early_stopping=False,
91
- max_time=20,
92
- repetition_penalty=1.0,
93
- num_return_sequences=1,
94
- do_sample=True,
95
- chat=chat,
96
- instruction_nochat=prompt if not chat else '',
97
- iinput_nochat='', # only for chat=False
98
- langchain_mode=langchain_mode,
99
- add_chat_history_to_context=add_chat_history_to_context,
100
- langchain_action=langchain_action,
101
- langchain_agents=langchain_agents,
102
- top_k_docs=top_k_docs,
103
- chunk=True,
104
- chunk_size=512,
105
- document_subset=DocumentSubset.Relevant.name,
106
- document_choice=[],
107
- )
108
- from evaluate_params import eval_func_param_names
109
- assert len(set(eval_func_param_names).difference(set(list(kwargs.keys())))) == 0
110
- if chat:
111
- # add chatbot output on end. Assumes serialize=False
112
- kwargs.update(dict(chatbot=[]))
113
-
114
- return kwargs, list(kwargs.values())
115
-
116
-
117
- @pytest.mark.skip(reason="For manual use against some server, no server launched")
118
- def test_client_basic(prompt_type='human_bot'):
119
- return run_client_nochat(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50)
120
-
121
-
122
- def run_client_nochat(prompt, prompt_type, max_new_tokens):
123
- kwargs, args = get_args(prompt, prompt_type, chat=False, max_new_tokens=max_new_tokens)
124
-
125
- api_name = '/submit_nochat'
126
- client = get_client(serialize=True)
127
- res = client.predict(
128
- *tuple(args),
129
- api_name=api_name,
130
- )
131
- print("Raw client result: %s" % res, flush=True)
132
- res_dict = dict(prompt=kwargs['instruction_nochat'], iinput=kwargs['iinput_nochat'],
133
- response=md_to_text(res))
134
- print(res_dict)
135
- return res_dict, client
136
-
137
-
138
- @pytest.mark.skip(reason="For manual use against some server, no server launched")
139
- def test_client_basic_api(prompt_type='human_bot'):
140
- return run_client_nochat_api(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50)
141
-
142
-
143
- def run_client_nochat_api(prompt, prompt_type, max_new_tokens):
144
- kwargs, args = get_args(prompt, prompt_type, chat=False, max_new_tokens=max_new_tokens)
145
-
146
- api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
147
- client = get_client(serialize=True)
148
- res = client.predict(
149
- str(dict(kwargs)),
150
- api_name=api_name,
151
- )
152
- print("Raw client result: %s" % res, flush=True)
153
- res_dict = dict(prompt=kwargs['instruction_nochat'], iinput=kwargs['iinput_nochat'],
154
- response=md_to_text(ast.literal_eval(res)['response']),
155
- sources=ast.literal_eval(res)['sources'])
156
- print(res_dict)
157
- return res_dict, client
158
-
159
-
160
- @pytest.mark.skip(reason="For manual use against some server, no server launched")
161
- def test_client_basic_api_lean(prompt_type='human_bot'):
162
- return run_client_nochat_api_lean(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50)
163
-
164
-
165
- def run_client_nochat_api_lean(prompt, prompt_type, max_new_tokens):
166
- kwargs = dict(instruction_nochat=prompt)
167
-
168
- api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
169
- client = get_client(serialize=True)
170
- res = client.predict(
171
- str(dict(kwargs)),
172
- api_name=api_name,
173
- )
174
- print("Raw client result: %s" % res, flush=True)
175
- res_dict = dict(prompt=kwargs['instruction_nochat'],
176
- response=md_to_text(ast.literal_eval(res)['response']),
177
- sources=ast.literal_eval(res)['sources'])
178
- print(res_dict)
179
- return res_dict, client
180
-
181
-
182
- @pytest.mark.skip(reason="For manual use against some server, no server launched")
183
- def test_client_basic_api_lean_morestuff(prompt_type='human_bot'):
184
- return run_client_nochat_api_lean_morestuff(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50)
185
-
186
-
187
- def run_client_nochat_api_lean_morestuff(prompt, prompt_type='human_bot', max_new_tokens=512):
188
- kwargs = dict(
189
- instruction='',
190
- iinput='',
191
- context='',
192
- stream_output=False,
193
- prompt_type=prompt_type,
194
- temperature=0.1,
195
- top_p=0.75,
196
- top_k=40,
197
- num_beams=1,
198
- max_new_tokens=256,
199
- min_new_tokens=0,
200
- early_stopping=False,
201
- max_time=20,
202
- repetition_penalty=1.0,
203
- num_return_sequences=1,
204
- do_sample=True,
205
- chat=False,
206
- instruction_nochat=prompt,
207
- iinput_nochat='',
208
- langchain_mode='Disabled',
209
- add_chat_history_to_context=True,
210
- langchain_action=LangChainAction.QUERY.value,
211
- langchain_agents=[],
212
- top_k_docs=4,
213
- document_subset=DocumentSubset.Relevant.name,
214
- document_choice=[],
215
- )
216
-
217
- api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
218
- client = get_client(serialize=True)
219
- res = client.predict(
220
- str(dict(kwargs)),
221
- api_name=api_name,
222
- )
223
- print("Raw client result: %s" % res, flush=True)
224
- res_dict = dict(prompt=kwargs['instruction_nochat'],
225
- response=md_to_text(ast.literal_eval(res)['response']),
226
- sources=ast.literal_eval(res)['sources'])
227
- print(res_dict)
228
- return res_dict, client
229
-
230
-
231
- @pytest.mark.skip(reason="For manual use against some server, no server launched")
232
- def test_client_chat(prompt_type='human_bot'):
233
- return run_client_chat(prompt='Who are you?', prompt_type=prompt_type, stream_output=False, max_new_tokens=50,
234
- langchain_mode='Disabled',
235
- langchain_action=LangChainAction.QUERY.value,
236
- langchain_agents=[])
237
-
238
-
239
- @pytest.mark.skip(reason="For manual use against some server, no server launched")
240
- def test_client_chat_stream(prompt_type='human_bot'):
241
- return run_client_chat(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
242
- stream_output=True, max_new_tokens=512,
243
- langchain_mode='Disabled',
244
- langchain_action=LangChainAction.QUERY.value,
245
- langchain_agents=[])
246
-
247
-
248
- def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens,
249
- langchain_mode, langchain_action, langchain_agents,
250
- prompt_dict=None):
251
- client = get_client(serialize=False)
252
-
253
- kwargs, args = get_args(prompt, prompt_type, chat=True, stream_output=stream_output,
254
- max_new_tokens=max_new_tokens,
255
- langchain_mode=langchain_mode,
256
- langchain_action=langchain_action,
257
- langchain_agents=langchain_agents,
258
- prompt_dict=prompt_dict)
259
- return run_client(client, prompt, args, kwargs)
260
-
261
-
262
- def run_client(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
263
- assert kwargs['chat'], "Chat mode only"
264
- res = client.predict(*tuple(args), api_name='/instruction')
265
- args[-1] += [res[-1]]
266
-
267
- res_dict = kwargs
268
- res_dict['prompt'] = prompt
269
- if not kwargs['stream_output']:
270
- res = client.predict(*tuple(args), api_name='/instruction_bot')
271
- res_dict['response'] = res[0][-1][1]
272
- print(md_to_text(res_dict['response'], do_md_to_text=do_md_to_text))
273
- return res_dict, client
274
- else:
275
- job = client.submit(*tuple(args), api_name='/instruction_bot')
276
- res1 = ''
277
- while not job.done():
278
- outputs_list = job.communicator.job.outputs
279
- if outputs_list:
280
- res = job.communicator.job.outputs[-1]
281
- res1 = res[0][-1][-1]
282
- res1 = md_to_text(res1, do_md_to_text=do_md_to_text)
283
- print(res1)
284
- time.sleep(0.1)
285
- full_outputs = job.outputs()
286
- if verbose:
287
- print('job.outputs: %s' % str(full_outputs))
288
- # ensure get ending to avoid race
289
- # -1 means last response if streaming
290
- # 0 means get text_output, ignore exception_text
291
- # 0 means get list within text_output that looks like [[prompt], [answer]]
292
- # 1 means get bot answer, so will have last bot answer
293
- res_dict['response'] = md_to_text(full_outputs[-1][0][0][1], do_md_to_text=do_md_to_text)
294
- return res_dict, client
295
-
296
-
297
- @pytest.mark.skip(reason="For manual use against some server, no server launched")
298
- def test_client_nochat_stream(prompt_type='human_bot'):
299
- return run_client_nochat_gen(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
300
- stream_output=True, max_new_tokens=512,
301
- langchain_mode='Disabled',
302
- langchain_action=LangChainAction.QUERY.value,
303
- langchain_agents=[])
304
-
305
-
306
- def run_client_nochat_gen(prompt, prompt_type, stream_output, max_new_tokens,
307
- langchain_mode, langchain_action, langchain_agents):
308
- client = get_client(serialize=False)
309
-
310
- kwargs, args = get_args(prompt, prompt_type, chat=False, stream_output=stream_output,
311
- max_new_tokens=max_new_tokens, langchain_mode=langchain_mode,
312
- langchain_action=langchain_action, langchain_agents=langchain_agents)
313
- return run_client_gen(client, prompt, args, kwargs)
314
-
315
-
316
- def run_client_gen(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
317
- res_dict = kwargs
318
- res_dict['prompt'] = prompt
319
- if not kwargs['stream_output']:
320
- res = client.predict(str(dict(kwargs)), api_name='/submit_nochat_api')
321
- res_dict['response'] = res[0]
322
- print(md_to_text(res_dict['response'], do_md_to_text=do_md_to_text))
323
- return res_dict, client
324
- else:
325
- job = client.submit(str(dict(kwargs)), api_name='/submit_nochat_api')
326
- while not job.done():
327
- outputs_list = job.communicator.job.outputs
328
- if outputs_list:
329
- res = job.communicator.job.outputs[-1]
330
- res_dict = ast.literal_eval(res)
331
- print('Stream: %s' % res_dict['response'])
332
- time.sleep(0.1)
333
- res_list = job.outputs()
334
- assert len(res_list) > 0, "No response, check server"
335
- res = res_list[-1]
336
- res_dict = ast.literal_eval(res)
337
- print('Final: %s' % res_dict['response'])
338
- return res_dict, client
339
-
340
-
341
- def md_to_text(md, do_md_to_text=True):
342
- if not do_md_to_text:
343
- return md
344
- assert md is not None, "Markdown is None"
345
- html = markdown.markdown(md)
346
- soup = BeautifulSoup(html, features='html.parser')
347
- return soup.get_text()
348
-
349
-
350
- def run_client_many(prompt_type='human_bot'):
351
- ret1, _ = test_client_chat(prompt_type=prompt_type)
352
- ret2, _ = test_client_chat_stream(prompt_type=prompt_type)
353
- ret3, _ = test_client_nochat_stream(prompt_type=prompt_type)
354
- ret4, _ = test_client_basic(prompt_type=prompt_type)
355
- ret5, _ = test_client_basic_api(prompt_type=prompt_type)
356
- ret6, _ = test_client_basic_api_lean(prompt_type=prompt_type)
357
- ret7, _ = test_client_basic_api_lean_morestuff(prompt_type=prompt_type)
358
- return ret1, ret2, ret3, ret4, ret5, ret6, ret7
359
-
360
-
361
- if __name__ == '__main__':
362
- run_client_many()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
create_data.py DELETED
@@ -1,1809 +0,0 @@
1
- """
2
- Dataset creation tools.
3
-
4
- Keep to-level imports clean of non-trivial imports for specific tools,
5
- because this file is imported for various purposes
6
- """
7
-
8
- import ast
9
- import concurrent.futures
10
- import contextlib
11
- import hashlib
12
- import json
13
- import os
14
- import shutil
15
- import signal
16
- import sys
17
- import traceback
18
- from concurrent.futures import ProcessPoolExecutor
19
-
20
- import psutil
21
- import pytest
22
- import pandas as pd
23
- import numpy as np
24
- from tqdm import tqdm
25
-
26
- from utils import flatten_list, remove
27
-
28
-
29
- def parse_rst_file(filepath):
30
- with open(filepath, 'r') as f:
31
- input_data = f.read()
32
- settings_overrides = {'initial_header_level': 2}
33
- from docutils import core
34
- document = core.publish_doctree(
35
- source=input_data,
36
- source_path=filepath,
37
- settings_overrides=settings_overrides,
38
- )
39
- qa_pairs = []
40
- current_section = None
41
- current_question = ""
42
- current_answer = ""
43
- for node in document.traverse():
44
- if node.__class__.__name__ == 'section':
45
- current_section = ""
46
- elif current_section is not None:
47
- if node.__class__.__name__ == 'Text':
48
- if node.astext()[-1] == "?":
49
- if current_question:
50
- qa_pairs.append((current_question, current_answer))
51
- current_question = node.astext()
52
- current_answer = ""
53
- else:
54
- current_answer += node.astext()
55
- if current_answer:
56
- qa_pairs.append((current_question, current_answer))
57
- return {k: v for k, v in qa_pairs}
58
-
59
-
60
- def test_scrape_dai_docs():
61
- home = os.path.expanduser('~')
62
- file = os.path.join(home, 'h2oai/docs/faq.rst')
63
- qa_pairs = parse_rst_file(file)
64
- prompt_type = 'human_bot'
65
- from prompter import prompt_types
66
- assert prompt_type in prompt_types
67
- save_thing = [{"instruction": k, "output": v, 'prompt_type': prompt_type} for k, v in qa_pairs.items()]
68
- output_file = "dai_faq.json"
69
- with open(output_file, "wt") as f:
70
- f.write(json.dumps(save_thing, indent=2))
71
-
72
-
73
- def test_scrape_dai_docs_all():
74
- """
75
- pytest create_data.py::test_scrape_dai_docs_all
76
- """
77
- import glob
78
- import nltk
79
- nltk.download('punkt')
80
- dd = {}
81
- np.random.seed(1234)
82
- home = os.path.expanduser('~')
83
- files = list(glob.glob(os.path.join(home, "h2oai/docs/**/*rst")))
84
- np.random.shuffle(files)
85
- val_count = int(0.05 * len(files))
86
- train_files = files[val_count:]
87
- valid_files = files[:val_count]
88
- things = [
89
- ("dai_docs.train.json", train_files),
90
- ("dai_docs.valid.json", valid_files)
91
- ]
92
- for LEN in [100, 200, 500]:
93
- for output_file, ff in things:
94
- if output_file not in dd:
95
- dd[output_file] = []
96
- for f in ff:
97
- with open(f) as input:
98
- blob = input.read()
99
- blob = blob.replace("~~", "")
100
- blob = blob.replace("==", "")
101
- blob = blob.replace("''", "")
102
- blob = blob.replace("--", "")
103
- blob = blob.replace("**", "")
104
- dd[output_file].extend(get_sentences(blob, length=LEN))
105
- for output_file, _ in things:
106
- save_thing = [{"output": k.strip(), 'prompt_type': 'plain'} for k in dd[output_file]]
107
- with open(output_file, "wt") as f:
108
- f.write(json.dumps(save_thing, indent=2))
109
-
110
-
111
- def get_sentences(blob, length):
112
- """
113
- break-up input text into sentences and then output list of sentences of about length in size
114
- :param blob:
115
- :param length:
116
- :return:
117
- """
118
- import nltk
119
- nltk.download('punkt')
120
- from nltk.tokenize import sent_tokenize
121
- sentences = sent_tokenize(blob)
122
- my_sentences = []
123
- my_string = ""
124
- for sentence in sentences:
125
- if len(my_string) + len(sentence) <= length:
126
- if my_string:
127
- my_string += " " + sentence
128
- else:
129
- my_string = sentence
130
- else:
131
- my_sentences.append(my_string)
132
- my_string = ""
133
- return my_sentences or [my_string]
134
-
135
-
136
- def setup_dai_docs(path=None, dst="working_dir_docs", from_hf=False):
137
- """
138
- Only supported if have access to source code or HF token for HF spaces and from_hf=True
139
- :param path:
140
- :param dst:
141
- :param from_hf:
142
- :return:
143
- """
144
-
145
- home = os.path.expanduser('~')
146
-
147
- if from_hf:
148
- # assumes
149
- from huggingface_hub import hf_hub_download
150
- # True for case when locally already logged in with correct token, so don't have to set key
151
- token = os.getenv('HUGGINGFACE_API_TOKEN', True)
152
- path_to_zip_file = hf_hub_download('h2oai/dai_docs', 'dai_docs.zip', token=token, repo_type='dataset')
153
- path = 'h2oai'
154
- import zipfile
155
- with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref:
156
- zip_ref.extractall(path)
157
- path = os.path.join(path, 'docs/**/*')
158
-
159
- if path is None:
160
- if os.path.isdir(os.path.join(home, 'h2oai')):
161
- path = os.path.join(home, "h2oai/docs/**/*")
162
- else:
163
- assert os.path.isdir(os.path.join(home, 'h2oai.superclean')), '%s does not exist' % path
164
- path = os.path.join(home, "h2oai.superclean/docs/**/*")
165
- import glob
166
- files = list(glob.glob(path, recursive=True))
167
-
168
- # pandoc can't find include files
169
-
170
- remove(dst)
171
- os.makedirs(dst)
172
-
173
- # copy full tree, for absolute paths in rst
174
- for fil in files:
175
- if os.path.isfile(fil):
176
- shutil.copy(fil, dst)
177
-
178
- # hack for relative path
179
- scorers_dir = os.path.join(dst, 'scorers')
180
- makedirs(scorers_dir)
181
- for fil in glob.glob(os.path.join(dst, '*.frag')):
182
- shutil.copy(fil, scorers_dir)
183
-
184
- return dst
185
-
186
-
187
- def rst_to_outputs(files, min_len=30, max_len=2048 // 2 - 30):
188
- # account for sequence length (context window) including prompt and input and output
189
-
190
- # os.system('pandoc -f rst -t plain ./expert_settings/nlp_settings.rst')
191
- import pypandoc
192
- basedir = os.path.abspath(os.getcwd())
193
-
194
- outputs = []
195
- for fil in files:
196
- os.chdir(basedir)
197
- os.chdir(os.path.dirname(fil))
198
- fil = os.path.basename(fil)
199
- print("Processing %s" % fil, flush=True)
200
- # out_format can be one of: asciidoc, asciidoctor, beamer, biblatex, bibtex, commonmark, commonmark_x,
201
- # context, csljson, docbook, docbook4, docbook5, docx, dokuwiki,
202
- # dzslides, epub, epub2, epub3, fb2, gfm, haddock, html, html4, html5, icml,
203
- # ipynb, jats, jats_archiving, jats_articleauthoring, jats_publishing, jira,
204
- # json, latex, man,
205
- # markdown, markdown_github, markdown_mmd, markdown_phpextra, markdown_strict,
206
- # mediawiki, ms, muse, native, odt, opendocument, opml, org, pdf, plain, pptx,
207
- # revealjs, rst, rtf, s5, slideous, slidy, tei, texinfo, textile, xwiki, zimwiki
208
- out_format = 'plain'
209
- # avoid extra new lines injected into text
210
- extra_args = ['--wrap=preserve', '--resource path="%s" % dst']
211
-
212
- plain_list = []
213
- try:
214
- # valid for expert settings
215
- input_rst = pypandoc.convert_file(fil, 'rst')
216
- input_list = input_rst.split('\n``')
217
- for input_subrst in input_list:
218
- input_plain = pypandoc.convert_text(input_subrst, format='rst', to='plain')
219
- plain_list.append([input_plain, fil])
220
- except Exception as e:
221
- print("file exception: %s %s" % (fil, str(e)), flush=True)
222
-
223
- if not plain_list:
224
- # if failed to process as pieces of rst, then
225
- output = pypandoc.convert_file(fil, out_format, extra_args=extra_args, format='rst')
226
- outputs1 = get_sentences(output, length=max_len)
227
- for oi, output in enumerate(outputs1):
228
- output = output.replace('\n\n', '\n')
229
- plain_list.append([output, fil])
230
- outputs.extend(plain_list)
231
-
232
- # report:
233
- # [print(len(x)) for x in outputs]
234
-
235
- # deal with blocks longer than context size (sequence length) of 2048
236
- new_outputs = []
237
- num_truncated = 0
238
- num_orig = len(outputs)
239
- for output, fil in outputs:
240
- if len(output) < max_len:
241
- new_outputs.append([output, fil])
242
- continue
243
- outputs1 = get_sentences(output, length=max_len)
244
- for oi, output1 in enumerate(outputs1):
245
- output1 = output1.replace('\n\n', '\n')
246
- new_outputs.append([output1, fil])
247
- num_truncated += 1
248
- print('num_orig: %s num_truncated: %s' % (num_orig, num_truncated), flush=True)
249
-
250
- new_outputs = [[k.strip(), fil] for k, fil in new_outputs if len(k.strip()) > min_len]
251
-
252
- return new_outputs
253
-
254
-
255
- def test_scrape_dai_docs_all_pandoc():
256
- """
257
- pytest -s -v create_data.py::test_scrape_dai_docs_all_pandoc
258
- :return:
259
- """
260
-
261
- dst = setup_dai_docs()
262
-
263
- import glob
264
- files = list(glob.glob(os.path.join(dst, '*rst'), recursive=True))
265
-
266
- basedir = os.path.abspath(os.getcwd())
267
- new_outputs = rst_to_outputs(files)
268
- os.chdir(basedir)
269
-
270
- remove(dst)
271
- save_thing = [{"output": k.strip(), 'prompt_type': 'plain'} for k in new_outputs]
272
- output_file = "dai_docs.train_cleaned.json"
273
- with open(output_file, "wt") as f:
274
- f.write(json.dumps(save_thing, indent=2))
275
-
276
-
277
- def test_config_to_json():
278
- """
279
- Needs to run from Driverless AI source directory.
280
- E.g. (base) jon@gpu:~/h2oai$ pytest -s -v /data/jon/h2ogpt/create_data.py::test_config_to_json ; cp config.json /data/jon/h2ogpt/
281
- :return:
282
- """
283
- try:
284
- # Arrange
285
- import json
286
- from h2oaicore.systemutils import config
287
- toml_list = []
288
- for k, v in config.get_meta_dict().items():
289
- title = (v.title + ": ") if v.title else ''
290
- comment = v.comment or ''
291
- if not (title or comment):
292
- continue
293
- toml_list.extend(
294
- [
295
- {
296
- 'prompt_type': 'plain',
297
- 'instruction': f"<human>: What does {k} do?\n<bot>: {k.replace('_', ' ')} config.toml: {comment or title}\n<human>:".replace(
298
- "\n", ""),
299
- },
300
- {
301
- 'prompt_type': 'plain',
302
- 'instruction': f"<human>: Explain {k}.\n<bot>: {k.replace('_', ' ')} config.toml: {comment or title}\n<human>:".replace(
303
- "\n", ""),
304
- },
305
- {
306
- 'prompt_type': 'plain',
307
- 'instruction': f"<human>: How can I do this: {title}.\n<bot>: Set the {k.replace('_', ' ')} config.toml\n<human>:".replace(
308
- "\n", ""),
309
- } if title and comment else None,
310
- {
311
- 'prompt_type': 'human_bot',
312
- 'instruction': f'Explain the following expert setting for Driverless AI',
313
- 'input': f"{k}",
314
- 'output': f"{k.replace('_', ' ')} config.toml: {comment or title}".replace("\n", ""),
315
- },
316
- {
317
- 'prompt_type': 'human_bot',
318
- 'instruction': f'Explain the following expert setting for Driverless AI',
319
- 'input': f"{k}",
320
- 'output': f"{k.replace('_', ' ')} config.toml: {title}{comment}".replace("\n", ""),
321
- },
322
- {
323
- 'prompt_type': 'human_bot',
324
- 'instruction': f'Explain the following expert setting for Driverless AI',
325
- 'input': f"{k.replace('_', ' ')}",
326
- 'output': f"{k.replace('_', ' ')} config.toml: {title}{comment}".replace("\n", ""),
327
- },
328
- {
329
- 'prompt_type': 'human_bot',
330
- 'instruction': f'Explain the following expert setting for Driverless AI',
331
- 'input': f"{title}",
332
- 'output': f"{k.replace('_', ' ')} config.toml: {title}{comment}".replace("\n", ""),
333
- },
334
- {
335
- 'prompt_type': 'human_bot',
336
- 'instruction': f'Provide a short explanation of the expert setting {k}',
337
- 'output': f"{k.replace('_', ' ')} config.toml: {comment or title}".replace("\n", ""),
338
- },
339
- {
340
- 'prompt_type': 'human_bot',
341
- 'instruction': f'Provide a detailed explanation of the expert setting {k}',
342
- 'output': f"{k.replace('_', ' ')} config.toml: {title}{comment}".replace("\n", ""),
343
- },
344
- ]
345
- )
346
- toml_list = [x for x in toml_list if x]
347
- with open("config.json", "wt") as f:
348
- f.write(json.dumps(toml_list, indent=2))
349
- except Exception as e:
350
- print("Exception: %s" % str(e), flush=True)
351
-
352
-
353
- def copy_tree(src, dst, follow_symlink=False):
354
- makedirs(dst, exist_ok=True)
355
- for (path, dirs, files) in os.walk(src, followlinks=follow_symlink):
356
- new_path = path.replace(src, dst)
357
- makedirs(new_path, exist_ok=True)
358
- for file in files:
359
- filename = os.path.join(path, file)
360
- new_filename = os.path.join(new_path, file)
361
- # print("%s -> %s" % (filename, new_filename))
362
- try:
363
- atomic_copy(filename, new_filename)
364
- except FileNotFoundError:
365
- pass
366
-
367
-
368
- def atomic_move(src, dst):
369
- try:
370
- shutil.move(src, dst)
371
- except (shutil.Error, FileExistsError):
372
- pass
373
- remove(src)
374
-
375
-
376
- def atomic_copy(src=None, dst=None, with_permissions=True):
377
- if os.path.isfile(dst):
378
- return
379
- import uuid
380
- my_uuid = uuid.uuid4()
381
- dst_tmp = dst + str(my_uuid)
382
- makedirs(os.path.dirname(dst), exist_ok=True)
383
- if with_permissions:
384
- shutil.copy(src, dst_tmp)
385
- else:
386
- shutil.copyfile(src, dst_tmp)
387
- atomic_move(dst_tmp, dst)
388
- remove(dst_tmp)
389
-
390
-
391
- def makedirs(path, exist_ok=True):
392
- """
393
- Avoid some inefficiency in os.makedirs()
394
- :param path:
395
- :param exist_ok:
396
- :return:
397
- """
398
- if os.path.isdir(path) and os.path.exists(path):
399
- assert exist_ok, "Path already exists"
400
- return path
401
- os.makedirs(path, exist_ok=exist_ok)
402
-
403
-
404
- ## Download from https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_unfiltered_cleaned_split.json
405
- ## Turn into simple instruct prompt type. No context/previous conversations.
406
- def test_prep_instruct_vicuna():
407
- from datasets import load_dataset
408
- filename = 'ShareGPT_unfiltered_cleaned_split.json'
409
- if not os.path.exists(filename):
410
- os.system(
411
- 'wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/%s' % filename)
412
- data = load_dataset("json", data_files={"train": filename})["train"]
413
- training_rows = []
414
- for i in range(data.num_rows):
415
- conversations = data[i]['conversations']
416
- assert isinstance(conversations, list), conversations
417
- convo = ""
418
- for j, conv in enumerate(conversations):
419
- # Get ready for generate.py prompt_type=human_bot
420
- # But train with prompt_type=plain
421
- if conv['from'] == 'human':
422
- FROM = '<human>: '
423
- elif conv['from'] == 'gpt':
424
- FROM = '<bot>: '
425
- convo += f"{FROM}" + conv['value'] + "\n"
426
- if convo:
427
- training_rows.append(dict(input=convo))
428
- with open(filename + ".generate_human_bot.train_plain.json", "wt") as f:
429
- f.write(json.dumps(training_rows, indent=2))
430
-
431
-
432
- POSTFIX = ".generate_human_bot.train_plain.json"
433
-
434
- # https://bair.berkeley.edu/blog/2023/04/03/koala/
435
- OIG_DATASETS = [
436
- "unified_chip2.jsonl",
437
- "unified_grade_school_math_instructions.jsonl",
438
- "unified_poetry_2_song.jsonl",
439
- "unified_plot_screenplay_books_dialog.jsonl",
440
- ]
441
-
442
- # hub issue: https://huggingface.co/datasets/laion/OIG/discussions/4
443
- ALL_OIG_DATASETS = ['unified_abstract_infill.jsonl',
444
- 'unified_basic.jsonl',
445
- 'unified_canadian_parliament.jsonl',
446
- 'unified_chip2.jsonl',
447
- 'unified_conv_finqa.jsonl',
448
- 'unified_cuad.jsonl',
449
- 'unified_essays.jsonl',
450
- 'unified_flan.jsonl.gz',
451
- 'unified_grade_school_math_instructions.jsonl',
452
- 'unified_hc3_human.jsonl',
453
- 'unified_image_prompts_instructions.jsonl',
454
- 'unified_joke_explanations.jsonl',
455
- 'unified_mathqa_flanv2_kojma_cot.jsonl',
456
- 'unified_merged_code_xp3.jsonl',
457
- 'unified_multi_news.jsonl',
458
- 'unified_multi_sum.jsonl',
459
- 'unified_ni.jsonl.gz',
460
- 'unified_nq.jsonl',
461
- 'unified_openai_summarize_tldr.jsonl',
462
- 'unified_oscar_en_sample_dialog.jsonl',
463
- 'unified_p3.jsonl.gz',
464
- 'unified_plot_screenplay_books_dialog.jsonl',
465
- 'unified_poetry_2_song.jsonl',
466
- 'unified_poetry_instructions.jsonl',
467
- 'unified_rallio_safety_and_prosocial.jsonl',
468
- 'unified_rallio_soda_upgraded_2048.jsonl',
469
- 'unified_soda_dialog.jsonl',
470
- 'unified_sqlv1.jsonl',
471
- 'unified_sqlv2.jsonl',
472
- 'unified_squad_v2.jsonl',
473
- 'unified_squad_v2_more_neg.jsonl',
474
- 'unified_ul2_plus_oscar_en_sample_dialog.jsonl',
475
- 'unified_unifiedskg_instructions.jsonl',
476
- 'unified_unnatural_instructions.jsonl',
477
- 'unified_xp3_sample.jsonl']
478
-
479
- useful_oig_files = ['unified_rallio_safety_and_prosocial.jsonl.parquet',
480
- 'unified_chip2.jsonl.parquet',
481
- 'unified_cuad.jsonl.parquet',
482
- 'unified_essays.jsonl.parquet',
483
- 'unified_flan.jsonl.gz.parquet',
484
- 'unified_grade_school_math_instructions.jsonl.parquet',
485
- 'unified_hc3_human.jsonl.parquet',
486
- 'unified_mathqa_flanv2_kojma_cot.jsonl.parquet',
487
- 'unified_merged_code_xp3.jsonl.parquet',
488
- 'unified_multi_news.jsonl.parquet',
489
- # 'unified_multi_sum.jsonl.parquet'
490
- 'unified_ni.jsonl.gz.parquet',
491
- 'unified_openai_summarize_tldr.jsonl.parquet',
492
- # 'unified_oscar_en_sample_dialog.jsonl.parquet', # create text containing these N words, not specific
493
- 'unified_plot_screenplay_books_dialog.jsonl.parquet',
494
- 'unified_soda_dialog.jsonl.parquet',
495
- 'unified_unnatural_instructions.jsonl.parquet',
496
- ]
497
-
498
-
499
- @pytest.mark.parametrize("filename", OIG_DATASETS)
500
- def test_get_small_sample_oig_data(filename):
501
- if not os.path.exists(filename):
502
- os.system('wget https://huggingface.co/datasets/laion/OIG/resolve/main/%s' % filename)
503
- import json
504
- rows = []
505
- with open(filename, "r") as f:
506
- for line in f.readlines():
507
- row = json.loads(line)
508
- rows.append(dict(input=row["text"]))
509
- with open(filename + POSTFIX, "w") as f:
510
- f.write(json.dumps(rows, indent=2))
511
-
512
-
513
- @pytest.mark.parametrize("filename", ALL_OIG_DATASETS)
514
- def test_download_useful_data_as_parquet(filename):
515
- dest_file = filename + '.parquet'
516
- if dest_file not in useful_oig_files:
517
- pytest.skip('file declared not useful')
518
- if not os.path.exists(filename):
519
- os.system('wget https://huggingface.co/datasets/laion/OIG/resolve/main/%s' % filename)
520
- if not os.path.exists(dest_file):
521
- df = pd.read_json(path_or_buf=filename, lines=True)
522
- df.to_parquet(dest_file, index=False)
523
-
524
-
525
- def test_merge_shuffle_small_sample_oig_data():
526
- np.random.seed(1234)
527
- rows = []
528
- for filename in OIG_DATASETS:
529
- with open(filename + POSTFIX, "r") as f:
530
- rows.extend(json.loads(f.read()))
531
- np.random.shuffle(rows)
532
- with open("merged_shuffled_OIG_%s.json" % hashlib.sha256(str(OIG_DATASETS).encode()).hexdigest()[:10], "w") as f:
533
- f.write(json.dumps(rows, indent=2))
534
-
535
-
536
- def test_join_jsons():
537
- files = ['config.json'] * 1 + \
538
- ['dai_docs.train_cleaned.json'] * 2 + \
539
- ['dai_faq.json'] * 3
540
- print(files)
541
- lst = []
542
- [lst.extend(json.load(open(fil, 'rt'))) for fil in files]
543
- print(len(lst))
544
- json.dump(lst, open("merged.json", "wt"), indent=2)
545
-
546
-
547
- @pytest.mark.parametrize("filename", ['Anthropic/hh-rlhf'])
548
- def test_make_rlhf_good_data(filename):
549
- from datasets import load_dataset
550
- rows = load_dataset(filename)["train"]["chosen"]
551
- new_rows = []
552
- for row in rows:
553
- if row[:2] == "\n\n":
554
- row = row[2:]
555
- row = row.replace("Human: ", "<human>: ")
556
- row = row.replace("Assistant: ", "<bot>: ")
557
- new_rows.append(dict(input=row))
558
- with open(filename.replace("/", "_") + POSTFIX, "w") as f:
559
- f.write(json.dumps(new_rows, indent=2))
560
-
561
-
562
- def test_show_prompts():
563
- files = ['config.json'] * 1 + \
564
- ['dai_docs.train_cleaned.json'] * 1 + \
565
- ['dai_faq.json'] * 1
566
- file_points = [json.load(open(fil, 'rt')) for fil in files]
567
- from prompter import generate_prompt
568
- for data_points in file_points:
569
- for data_point in data_points:
570
- print(generate_prompt(data_point, 'plain', '', False, False, False)[0])
571
-
572
-
573
- def test_get_open_datasets():
574
- # HF changed things so don't get raw list of all datasets, so not have to filter, but can't do negative filter
575
- open_tags = ['license:Apache License 2.0',
576
- 'license:mit',
577
- 'license:apache',
578
- 'license:apache2',
579
- 'license:apache-2.0',
580
- 'license:bsd',
581
- 'license:bsd-2-clause',
582
- 'license:bsd-3-clause',
583
- 'license:bsd-3-clause-clear',
584
- 'license:lgpl-2.1',
585
- 'license:lgpl-3.0',
586
- 'license:lgpl-lr',
587
- 'license:lgpl',
588
- 'license:openrail++',
589
- 'license:openrail',
590
- 'license:bigscience-bloom-rail-1.0',
591
- # 'license:agpl-3.0',
592
- 'license:other',
593
- 'license:unknown',
594
- # 'license:mpl-2.0', # ok, but would have to include original copyright, license, source, copies in distribution
595
- # Attribution required:
596
- 'license:odc-by',
597
- 'license:cc-by-4.0',
598
- 'license:cc-by-3.0',
599
- 'license:cc-by-2.0',
600
- 'license:cc-by-2.5',
601
- # 'license:cc-by-sa-4.0', # would require same license
602
- 'license:odbl',
603
- 'license:pddl',
604
- 'license:ms-pl',
605
- 'license:zlib',
606
- ]
607
- # bad license: cc-by-nc-4.0
608
-
609
- from huggingface_hub import list_datasets
610
- datasets = flatten_list([[x for x in list_datasets(filter=y)] for y in open_tags])
611
- datasets += [x for x in list_datasets(author='openai')]
612
- # check all:
613
- all_license_tags = set(flatten_list([[y for y in x.tags if 'license' in y] for x in datasets]))
614
- print(len(all_license_tags))
615
- open_datasets = [x for x in datasets if any([y in x.tags for y in open_tags]) or 'license:' not in str(x.tags)]
616
- print('open_datasets', len(open_datasets))
617
- all_task_tags = set(flatten_list([[y for y in x.tags if 'task' in y] for x in open_datasets]))
618
- print('all_task_tags', len(all_task_tags))
619
- excluded_tags = ['image', 'hate', 'tabular', 'table-', 'classification', 'retrieval',
620
- 'translation', 'identification', 'object', 'mask', 'to-text',
621
- 'face-detection', 'audio', 'voice', 'reinforcement', 'depth-est',
622
- 'forecasting', 'parsing', 'visual', 'speech', 'multiple-choice',
623
- 'slot-filling', 'irds/argsme', '-scoring', 'other', 'graph-ml',
624
- 'feature-extraction', 'keyword-spotting',
625
- 'coreference-resolution', 'segmentation',
626
- 'word-sense-disambiguation',
627
- 'lemmatization']
628
- task_tags = [x.replace('task_categories:', '').replace('task_ids:', '')
629
- for x in all_task_tags if not any([y in x for y in
630
- excluded_tags])]
631
- print('task_tags', len(task_tags))
632
- # str(x.tags) to catch any pattern match to anything in list
633
- open_tasked_datasets = [x for x in open_datasets if
634
- any([y in str([x for x in x.tags if 'task' in x]) for y in task_tags]) and
635
- not any([y in str([x for x in x.tags if 'task' in x]) for y in excluded_tags]) or
636
- 'task_categories' not in str(x.tags) and 'task_ids' not in str(x.tags)]
637
- open_tasked_datasets = [x for x in open_tasked_datasets if not x.disabled]
638
- open_tasked_datasets = [x for x in open_tasked_datasets if not x.gated]
639
- open_tasked_datasets = [x for x in open_tasked_datasets if not x.private]
640
- print('open_tasked_datasets', len(open_tasked_datasets))
641
- sizes = list(set(flatten_list([[(y, x.id) for y in x.tags if 'size' in y] for x in open_tasked_datasets])))
642
- languages = list(set(flatten_list([[(y, x.id) for y in x.tags if 'language:' in y] for x in open_tasked_datasets])))
643
- open_english_tasked_datasets = [x for x in open_tasked_datasets if
644
- 'language:' not in str(x.tags) or
645
- 'language:en' in str(x.tags)]
646
- small_open_english_tasked_datasets = [x for x in open_english_tasked_datasets if
647
- 'n<1K' in str(x.tags) or
648
- '1K<n<10K' in str(x.tags) or
649
- '1K0<n<100K' in str(x.tags) or
650
- '100K<n<1M' in str(x.tags) or
651
- 'size_category' not in str(x.tags)
652
- ]
653
- # 'aeslc' : email_body, subject -> summarization?
654
- # load_dataset(open_tasked_datasets[0].id).data['train'].to_pandas()
655
- ids = [x.id for x in small_open_english_tasked_datasets]
656
-
657
- # sanity checks
658
- # https://bair.berkeley.edu/blog/2023/04/03/koala/
659
- assert 'alespalla/chatbot_instruction_prompts' in ids
660
- assert 'laion/OIG' in ids
661
- assert 'openai/webgpt_comparisons' in ids
662
- assert 'openai/summarize_from_feedback' in ids
663
- assert 'Anthropic/hh-rlhf' in ids
664
-
665
- # useful but not allowed for commercial purposes:
666
- # https://huggingface.co/datasets/squad
667
-
668
- print('open_english_tasked_datasets: ', ids, flush=True)
669
-
670
- exclude_ids = ['allenai/nllb', # translation only
671
- 'hf-internal-testing/fixtures_image_utils', # testing
672
- 'allenai/c4', # search-url
673
- 'agemagician/uniref50', # unknown
674
- 'huggingface-course/documentation-images', # images
675
- 'smilegate-ai/kor_unsmile', # korean
676
- 'MohamedRashad/ChatGPT-prompts', # ChatGPT/LearnGPT/https://www.emergentmind.com/
677
- 'humarin/chatgpt-paraphrases', # Paraphrase using ChatGPT
678
- 'Jeska/vaccinchat', # not useful
679
- 'alespalla/chatbot_instruction_prompts', # mixes alpaca
680
- 'allenai/prosocial-dialog',
681
- # already exlucded, but wrongly in other datasets that say more permissive license
682
- 'AlekseyKorshuk/persona-chat', # low quality
683
- 'bavard/personachat_truecased', # low quality
684
- 'adamlin/daily_dialog', # medium quality conversations
685
- 'adamlin/FewShotWoz', # low quality
686
- 'benjaminbeilharz/better_daily_dialog', # low quality
687
- 'benjaminbeilharz/daily_dialog_w_turn_templates', # low
688
- 'benjaminbeilharz/empathetic_dialogues_for_lm', # low
689
- 'GEM-submissions/GEM__bart_base_schema_guided_dialog__1645547915', # NA
690
- 'ia-bentebib/conv_ai_2_fr', # low fr
691
- 'ia-bentebib/daily_dialog_fr', # low fr
692
- 'ia-bentebib/dialog_re_fr', # low fr
693
- 'ia-bentebib/empathetic_dialogues_fr', # low fr
694
- 'roskoN/dailydialog', # low
695
- 'VadorMazer/skyrimdialogstest', # low
696
- 'bigbio/med_qa', # med specific Q/A
697
- 'biu-nlp/qa_srl2018', # low quality Q/A
698
- 'biu-nlp/qa_discourse', # low quality Q/A
699
- 'iarfmoose/qa_evaluator', # low quality Q/A
700
- 'jeopardy', # low quality Q/A -- no reasoning
701
- 'narrativeqa', # low quality Q/A
702
- 'nomic-ai/gpt4all_prompt_generations', # bad license
703
- 'nomic-ai/gpt4all_prompt_generations_with_p3', # bad license
704
- 'HuggingFaceH4/alpaca', # bad license
705
- 'tatsu-lab/alpaca', # ToS breaking
706
- 'yahma/alpaca-cleaned', # ToS breaking
707
- 'Hello-SimpleAI/HC3', # bad license
708
- 'glue', # no reasoning QA
709
- 'sahil2801/CodeAlpaca-20k', # bad license
710
- 'Short-Answer-Feedback/saf_communication_networks_english', # long Q, medium A
711
- ]
712
- small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if x.id not in exclude_ids]
713
- # some ids clearly speech related
714
- small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if 'speech' not in x.id]
715
- # HF testing
716
- small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if
717
- 'hf-internal-testing' not in x.id]
718
- small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if
719
- 'chinese' not in x.id]
720
-
721
- sorted_small_open_english_tasked_datasets = sorted([(x.downloads, x) for x in small_open_english_tasked_datasets],
722
- key=lambda x: x[0], reverse=True)
723
-
724
- # NOTES:
725
- # Run like pytest -s -v create_data.py::test_get_open_datasets &> getdata9.log
726
- # See what needs config passed and add:
727
- # grep 'load_dataset(' getdata9.log|grep -v data_id|less -S
728
- # grep "pip install" getdata9.log
729
- # NOTE: Some datasets have default config, but others are there. Don't know how to access them.
730
-
731
- """
732
- https://huggingface.co/datasets/wikihow/blob/main/wikihow.py
733
- https://github.com/mahnazkoupaee/WikiHow-Dataset
734
- https://ucsb.box.com/s/ap23l8gafpezf4tq3wapr6u8241zz358
735
- https://ucsb.app.box.com/s/ap23l8gafpezf4tq3wapr6u8241zz358
736
- """
737
-
738
- """
739
- # some ambiguous or non-commercial datasets
740
- https://github.com/PhoebusSi/alpaca-CoT
741
- """
742
-
743
- timeout = 3 * 60
744
- # laion/OIG takes longer
745
- for num_downloads, dataset in sorted_small_open_english_tasked_datasets:
746
- data_id = dataset.id
747
- func = do_one
748
- args = (data_id, num_downloads)
749
- kwargs = {}
750
- with ProcessPoolExecutor(max_workers=1) as executor:
751
- future = executor.submit(func, *args, **kwargs)
752
- try:
753
- future.result(timeout=timeout)
754
- except concurrent.futures.TimeoutError:
755
- print("\n\ndata_id %s timeout\n\n" % data_id, flush=True)
756
- for child in psutil.Process(os.getpid()).children(recursive=True):
757
- os.kill(child.pid, signal.SIGINT)
758
- os.kill(child.pid, signal.SIGTERM)
759
- os.kill(child.pid, signal.SIGKILL)
760
-
761
-
762
- def do_one(data_id, num_downloads):
763
- from datasets import load_dataset
764
- out_file = "data_%s.parquet" % str(data_id.replace('/', '_'))
765
- if os.path.isfile(out_file) and os.path.getsize(out_file) > 1024 ** 3:
766
- return
767
- try:
768
- print("Loading data_id %s num_downloads: %s" % (data_id, num_downloads), flush=True)
769
- avail_list = None
770
- try:
771
- data = load_dataset(data_id, 'foobar')
772
- except Exception as e:
773
- if 'Available: ' in str(e):
774
- avail_list = ast.literal_eval(str(e).split('Available:')[1].strip())
775
- else:
776
- avail_list = None
777
- if avail_list is None:
778
- avail_list = [None]
779
- print("%s avail_list: %s" % (data_id, avail_list), flush=True)
780
-
781
- for name in avail_list:
782
- out_file = "data_%s_%s.parquet" % (str(data_id.replace('/', '_')), str(name))
783
- if os.path.isfile(out_file):
784
- continue
785
- data = load_dataset(data_id, name)
786
- column_names_dict = data.column_names
787
- column_names = column_names_dict[list(column_names_dict.keys())[0]]
788
- print("Processing data_id %s num_downloads: %s columns: %s" % (data_id, num_downloads, column_names),
789
- flush=True)
790
- data_dict = data.data
791
- col_dict = data.num_columns
792
- first_col = list(col_dict.keys())[0]
793
- if 'train' in data_dict:
794
- df = data['train'].to_pandas()
795
- else:
796
- df = data[first_col].to_pandas()
797
- # csv has issues with escaping chars, even for datasets I know I want
798
- df.to_parquet(out_file, index=False)
799
- except Exception as e:
800
- t, v, tb = sys.exc_info()
801
- ex = ''.join(traceback.format_exception(t, v, tb))
802
- print("Exception: %s %s" % (data_id, ex), flush=True)
803
-
804
-
805
- def test_otherlic():
806
- from huggingface_hub import list_datasets
807
- lic = ['license:odc-by',
808
- 'license:cc-by-4.0',
809
- 'license:cc-by-3.0',
810
- 'license:cc-by-2.0',
811
- 'license:cc-by-2.5',
812
- 'license:cc-by-sa-4.0',
813
- 'license:odbl',
814
- 'license:pddl',
815
- 'license:ms-pl',
816
- 'license:zlib',
817
- ]
818
- datasets = flatten_list([[x for x in list_datasets(filter=y) if 'translation' not in str(x.tags)] for y in lic])
819
- print(len(datasets))
820
-
821
-
822
- # These useful datasets are determined based upon data sample, column types, and uniqueness compared to larger datasets like Pile
823
- # grep columns getdata13.log|grep -v "\['image'\]"|sort|uniq|grep -v tokens|grep -v "'image'"|grep -v embedding|grep dialog
824
- useful = ['Dahoas/instruct-human-assistant-prompt',
825
- 'Dahoas/first-instruct-human-assistant-prompt',
826
- 'knkarthick/dialogsum', # summary of conversation
827
- 'McGill-NLP/FaithDial', # medium quality
828
- 'Zaid/quac_expanded', # medium quality context + QA
829
- '0-hero/OIG-small-chip2', # medium
830
- 'alistvt/coqa-flat', # QA medium
831
- 'AnonymousSub/MedQuAD_47441_Question_Answer_Pairs', # QA medium
832
- 'Anthropic/hh-rlhf', # high quality # similar to Dahoas/full-hh-rlhf
833
- 'arjunth2001/online_privacy_qna', # good quality QA
834
- 'Dahoas/instruct_helpful_preferences', # medium quality instruct
835
- 'Dahoas/rl-prompt-dataset', # medium chat
836
- 'Dahoas/rm-static', # medium chat
837
- 'Dahoas/static-hh', # medium chat # HuggingFaceH4/self_instruct
838
- 'Dahoas/synthetic-instruct-gptj-pairwise', # medium chat
839
- 'eli5', # QA if prompt ELI5
840
- 'gsm8k', # QA (various)
841
- 'guanaco/guanaco', # prompt/response
842
- 'kastan/rlhf-qa-comparisons', # good QA
843
- 'kastan/rlhf-qa-conditional-generation-v2', # prompt answer
844
- 'OllieStanley/humaneval-mbpp-codegen-qa', # code QA, but started from words, so better than other code QA
845
- 'OllieStanley/humaneval-mbpp-testgen-qa', # code QA
846
- 'Graverman/Instruct-to-Code', # code QA
847
- 'openai/summarize_from_feedback', # summarize
848
- 'relbert/analogy_questions', # analogy QA
849
- 'yitingxie/rlhf-reward-datasets', # prompt, chosen, rejected.
850
- 'yizhongw/self_instruct', # instruct (super natural & instruct)
851
- 'HuggingFaceH4/asss', # QA, big A
852
- 'kastan/rlhf-qa-conditional-generation-v2', # QA
853
- 'cosmos_qa', # context QA
854
- 'vishal-burman/c4-faqs', # QA but not so much reasoning, but alot of text
855
- 'squadshifts', # QA from context
856
- 'hotpot_qa', # QA from context
857
- 'adversarial_qa', # QA from context
858
- 'allenai/soda', # dialog -> narrative/summary
859
- 'squad_v2', # context QA
860
- 'squadshifts', # context QA
861
- 'dferndz/cSQuAD1', # context QA
862
- 'dferndz/cSQuAD2', # context QA
863
- 'din0s/msmarco-nlgen', # context QA
864
- 'domenicrosati/TruthfulQA', # common sense truthful QA -- trivia but good trivia
865
- 'hotpot_qa', # context, QA
866
- 'HuggingFaceH4/self-instruct-eval', # instruct QA, medium quality, some language reasoning
867
- 'kastan/EE_QA_for_RLHF', # context QA
868
- 'KK04/LogicInference_OA', # instruction logical QA
869
- 'lmqg/qa_squadshifts_synthetic', # context QA
870
- 'lmqg/qg_squad', # context QA
871
- 'lmqg/qg_squadshifts', # context QA
872
- 'lmqg/qg_subjqa', # context QA
873
- 'pszemraj/HC3-textgen-qa',
874
- # QA medium, has human responses -- humans tend to provide links instead of trying to answer
875
- 'pythonist/newdata', # long context, QA, brief A
876
- 'ropes', # long background, situation, question, A
877
- 'wikitablequestions', # table -> QA
878
- 'bigscience/p3', # context QA but short answers
879
- ]
880
-
881
- code_useful = ['0n1xus/codexglue',
882
- 'openai_humaneval',
883
- 'koutch/staqc',
884
- ]
885
-
886
- maybe_useful = ['AlekseyKorshuk/comedy-scripts',
887
- 'openbookqa', # hard to parse, low reasoning
888
- 'qed', # reasonable QA, but low reasoning
889
- 'selqa', # candidate answers
890
- 'HuggingFaceH4/instruction-pilot-outputs-filtered',
891
- 'GBaker/MedQA-USMLE-4-options', # medical QA with long questions
892
- 'npc-engine/light-batch-summarize-dialogue', # dialog summarize, kinda low specific quality
893
- ]
894
-
895
- summary_useful = ['austin/rheum_abstracts',
896
- 'CarperAI/openai_summarize_comparisons', # summarize chosen/rejected
897
- 'CarperAI/openai_summarize_tldr', # summarize QA
898
- 'ccdv/cnn_dailymail', # summarize news
899
- 'ccdv/govreport-summarization', # summarize high quality
900
- 'ccdv/pubmed-summarization', # summarize high quality
901
- 'duorc', # plot -> QA
902
- 'farleyknight/big_patent_5_percent', # desc -> abstract
903
- 'multi_news', # summary
904
- 'opinosis',
905
- 'SophieTr/reddit_clean',
906
- 'allenai/mup', # long text -> summary
907
- 'allenai/multi_lexsum', # long text -> summary
908
- 'big_patent',
909
- 'allenai/wcep_dense_max',
910
- 'awinml/costco_long_practice',
911
- 'GEM/xsum',
912
- 'ratishsp/newshead',
913
- 'RussianNLP/wikiomnia', # russian
914
- 'stacked-summaries/stacked-xsum-1024',
915
- ]
916
-
917
- math_useful = [
918
- 'competition_math'
919
- ]
920
-
921
- skipped = ['c4', # maybe useful, used for flan, but skipped due to size
922
- ]
923
-
924
- """
925
- To get training data from oig:
926
- pytest test_oig test_grade_final test_finalize_to_json
927
- """
928
-
929
- human = '<human>:'
930
- bot = '<bot>:'
931
-
932
-
933
- def test_assemble_and_detox():
934
- import re
935
- from profanity_check import predict_prob
936
- df_list = []
937
- for data in useful_oig_files:
938
- print("Processing %s" % data, flush=True)
939
- df = pd.read_parquet(data)
940
- df = df.reset_index(drop=True)
941
- # chop up into human/bot interactions of no more than 10kB per row
942
- text_list = df[['text']].values.ravel().tolist()
943
- new_text = []
944
- max_len = 2048 # uber cutoff
945
- MAX_LEN = 2048 // 2 - 30 # max len per question/answer
946
- for text in tqdm(text_list):
947
- human_starts = [m.start() for m in re.finditer('<human>: ', text)]
948
- if len(human_starts) == 1:
949
- human_starts = [0, len(text)] # always go into for loop below
950
- blurb = ''
951
- for i in range(len(human_starts) - 1):
952
- interaction = text[human_starts[i]: human_starts[i + 1]][:max_len]
953
- blurb += interaction
954
- if len(blurb) >= MAX_LEN:
955
- blurb = get_sentences(blurb, length=MAX_LEN)[0]
956
- new_text.append(blurb + "\n<human>:")
957
- blurb = ''
958
- if blurb:
959
- blurb = get_sentences(blurb, length=MAX_LEN)[0]
960
- new_text.append(blurb + "\n<human>:")
961
-
962
- if len(new_text) > len(text_list):
963
- print("Added %d new rows (before: %d)" % (len(new_text) - df.shape[0], df.shape[0]))
964
- df = pd.DataFrame({"text": new_text, "source": [data] * len(new_text)})
965
- df = df.drop_duplicates(keep='first')
966
- print(df['text'].apply(lambda x: len(x)).describe())
967
- assert df['text'].apply(lambda x: len(x)).max() <= 2 * max_len
968
-
969
- # faster than better_profanity, do early
970
- df['profanity'] = predict_prob(df['text'])
971
- before_rows = df.shape[0]
972
- df = df[df['profanity'] < 0.25] # drop any low quality stuff
973
- after_rows = df.shape[0]
974
- print("Dropped %d rows out of %d due to alt-profanity-check" % (before_rows - after_rows, before_rows))
975
- df_list.append(df)
976
- print("Done processing %s -> %s rows" % (data, df.shape[0]), flush=True)
977
- print("So far have %d rows" % sum([len(x) for x in df_list]))
978
- df_final = pd.concat(df_list)
979
- df_final = df_final.sample(frac=1, random_state=1234).reset_index(drop=True)
980
- df_final.to_parquet('h2oGPT.cleaned.human_bot.shorter.parquet', index=False)
981
-
982
-
983
- def test_basic_cleaning():
984
- # from better_profanity import profanity
985
- # https://pypi.org/project/alt-profanity-check/
986
- from profanity_check import predict
987
- df_list = []
988
- for data in useful_oig_files:
989
- # for data in useful_oig_files[:5]:
990
- # for data in ['unified_openai_summarize_tldr.jsonl.parquet']:
991
- print("Processing %s" % data, flush=True)
992
- df = pd.read_parquet(data)
993
- df = df.reset_index(drop=True)
994
- # NOTE: Not correct if multiple human-bot interactions, but those dialogs even more desired
995
- # avg_chars = len(df['text'][0])/(df['text'][0].count(human)+df['text'][0].count(bot))
996
- df['avg_words'] = df['text'].apply(lambda x: x.count(' ') / (x.count(human) + x.count(bot)) / 2.0)
997
- df['avg_bot_words'] = df['text'].apply(lambda x: x.split(bot)[1].count(' ') / x.count(bot))
998
- # df['bad_words'] = df['text'].apply(lambda x: profanity.contains_profanity(x))
999
- # low_quality_patterns = ['Write the rest of this wikipedia article']
1000
- res = predict(df['text'])
1001
- df['bad_words'] = res
1002
- df = df.reset_index(drop=True)
1003
- df = df[df['bad_words'] == 0]
1004
- df = df[['text', 'avg_words', 'avg_bot_words']]
1005
- df = df.drop_duplicates(keep='first')
1006
- print(df[df['avg_words'] == df['avg_words'].max()]['text'].values)
1007
- median_words = np.median(df['avg_words'])
1008
- min_words_per_entity = max(30, 0.8 * median_words)
1009
- max_words_per_entity = 2048 # too hard to learn from for now
1010
- df = df[df['avg_words'] > min_words_per_entity]
1011
- df = df[df['avg_words'] < max_words_per_entity]
1012
-
1013
- min_words_per_entity = max(20, 0.5 * median_words) # bot should say stuff for now
1014
- max_words_per_entity = 2048 # too hard to learn from for now
1015
- df = df[df['avg_bot_words'] > min_words_per_entity]
1016
- df = df[df['avg_bot_words'] < max_words_per_entity]
1017
-
1018
- df_list.append(df)
1019
- print("Done processing %s -> %s rows" % (data, df.shape[0]), flush=True)
1020
- df_final = pd.concat(df_list)
1021
- df_final.to_parquet('h2oGPT.cleaned.human_bot.parquet', index=False)
1022
-
1023
-
1024
- from joblib import Parallel, delayed, effective_n_jobs
1025
- from sklearn.utils import gen_even_slices
1026
- from sklearn.utils.validation import _num_samples
1027
-
1028
-
1029
- def parallel_apply(df, func, n_jobs=-1, **kwargs):
1030
- """ Pandas apply in parallel using joblib.
1031
- Uses sklearn.utils to partition input evenly.
1032
-
1033
- Args:
1034
- df: Pandas DataFrame, Series, or any other object that supports slicing and apply.
1035
- func: Callable to apply
1036
- n_jobs: Desired number of workers. Default value -1 means use all available cores.
1037
- **kwargs: Any additional parameters will be supplied to the apply function
1038
-
1039
- Returns:
1040
- Same as for normal Pandas DataFrame.apply()
1041
-
1042
- """
1043
-
1044
- if effective_n_jobs(n_jobs) == 1:
1045
- return df.apply(func, **kwargs)
1046
- else:
1047
- ret = Parallel(n_jobs=n_jobs)(
1048
- delayed(type(df).apply)(df[s], func, **kwargs)
1049
- for s in gen_even_slices(_num_samples(df), effective_n_jobs(n_jobs)))
1050
- return pd.concat(ret)
1051
-
1052
-
1053
- def add_better_profanity_flag(df):
1054
- from better_profanity import profanity
1055
- df['better_profanity'] = parallel_apply(
1056
- df['text'],
1057
- lambda x: profanity.contains_profanity(x),
1058
- n_jobs=-1,
1059
- )
1060
- return df
1061
-
1062
-
1063
- def add_textstat_grade(df):
1064
- import textstat
1065
-
1066
- def myfunc(x):
1067
- return textstat.flesch_kincaid_grade(x) # simple grade
1068
-
1069
- if False:
1070
- import dask.dataframe as dd
1071
- # 40 seconds for 1000 rows, but have 1,787,799 rows
1072
- ddata = dd.from_pandas(df, npartitions=120)
1073
-
1074
- df['flesch_grade'] = ddata['text'].apply(myfunc).compute()
1075
- if True:
1076
- # fast way
1077
- df['flesch_grade'] = parallel_apply(df['text'], myfunc, n_jobs=-1)
1078
- return df
1079
-
1080
-
1081
- def add_deberta_grade(df):
1082
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
1083
- import torch
1084
- reward_name = "OpenAssistant/reward-model-deberta-v3-large-v2"
1085
- rank_model, tokenizer = AutoModelForSequenceClassification.from_pretrained(
1086
- reward_name), AutoTokenizer.from_pretrained(reward_name)
1087
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
1088
- rank_model.to(device)
1089
-
1090
- def get_question(x):
1091
- return x.replace('<human>: ', '').split('<bot>:')[0]
1092
-
1093
- def get_answer(x):
1094
- try:
1095
- answer = x.split('<bot>: ')[1].split('<human>:')[0].replace('<bot>: ', '')
1096
- except:
1097
- answer = x.split('<bot>:')[1].split('<human>:')[0].replace('<bot>:', '')
1098
- return answer
1099
-
1100
- df['question'] = parallel_apply(df['text'], get_question, n_jobs=-1)
1101
- df['answer'] = parallel_apply(df['text'], get_answer, n_jobs=-1)
1102
-
1103
- from datasets import Dataset
1104
- from transformers import pipeline
1105
- from transformers.pipelines.pt_utils import KeyPairDataset
1106
- import tqdm
1107
-
1108
- pipe = pipeline(
1109
- "text-classification",
1110
- model=reward_name,
1111
- device="cuda:0" if torch.cuda.is_available() else "cpu"
1112
- )
1113
- start = 0
1114
- batch_size = 64 * 16
1115
- micro_batch = orig_micro_batch = 16
1116
- end = 0
1117
- import socket
1118
- checkpoint = "grades.%s.pkl" % socket.gethostname()
1119
- grades = []
1120
- import pickle
1121
- if os.path.exists(checkpoint):
1122
- with open(checkpoint, "rb") as f:
1123
- start, grades = pickle.loads(f.read())
1124
- last_oom = 0
1125
- while end < df.shape[0]:
1126
- # manual batching to handle OOM more gracefully
1127
- end = min(start + batch_size, df.shape[0])
1128
- if start == end:
1129
- break
1130
- dataset = Dataset.from_pandas(df.iloc[start:end, :])
1131
- try:
1132
- grades.extend([
1133
- x['score'] for x in tqdm.tqdm(
1134
- pipe(KeyPairDataset(dataset, "question", "answer"), batch_size=micro_batch)
1135
- )
1136
- ])
1137
- except torch.cuda.OutOfMemoryError:
1138
- last_oom = start
1139
- micro_batch = max(1, micro_batch // 2)
1140
- print("OOM - retrying with micro_batch=%d" % micro_batch)
1141
- continue
1142
- if last_oom == start:
1143
- micro_batch = orig_micro_batch
1144
- print("Returning to micro_batch=%d" % micro_batch)
1145
- assert len(grades) == end
1146
- start = end
1147
- with open(checkpoint, "wb") as f:
1148
- f.write(pickle.dumps((end, grades)))
1149
- print("%d/%d" % (end, df.shape[0]))
1150
- df['grade_deberta'] = grades
1151
- if os.path.exists(checkpoint):
1152
- os.remove(checkpoint)
1153
- return df
1154
-
1155
-
1156
- def test_chop_by_lengths():
1157
- file = "h2oGPT.cleaned.human_bot.shorter.parquet"
1158
- df = pd.read_parquet(file).reset_index(drop=True)
1159
- df = count_human_bot_lengths(df)
1160
- df['rand'] = np.random.rand(df.shape[0])
1161
- df['rand2'] = np.random.rand(df.shape[0])
1162
- before_rows = df.shape[0]
1163
- # throw away short human/bot responses with higher likelihood
1164
- df = df[(df['len_human_mean'] > 20)] # never keep very short ones
1165
- df = df[(df['len_human_mean'] > 30) | (df['rand'] < 0.2)]
1166
- df = df[(df['len_human_mean'] > 50) | (df['rand'] < 0.5)]
1167
- df = df[(df['len_human_max'] < 10000)] # drop super long (basically only human) ones
1168
- df = df[(df['len_bot_mean'] > 20)] # never keep very short ones
1169
- df = df[(df['len_bot_mean'] > 30) | (df['rand2'] < 0.2)]
1170
- df = df[(df['len_bot_mean'] > 50) | (df['rand2'] < 0.5)]
1171
- df = df[(df['len_bot_max'] < 10000)] # drop super long (only bot) ones
1172
- assert df['text'].apply(lambda x: len(x)).max() < 20000
1173
- df = df.drop(['rand', 'rand2'], axis=1)
1174
- after_rows = df.shape[0]
1175
- print("Chopped off %d out of %d rows due to length" % (before_rows - after_rows, before_rows))
1176
- print(df.describe())
1177
- df.to_parquet('h2oGPT.cleaned.chopped.human_bot.shorter.parquet', index=False)
1178
-
1179
-
1180
- def count_human_bot_lengths(df, human=None, bot=None):
1181
- import re
1182
- len_human_min = []
1183
- len_human_max = []
1184
- len_human_mean = []
1185
- len_bot_min = []
1186
- len_bot_max = []
1187
- len_bot_mean = []
1188
- human = human or '<human>:'
1189
- bot = bot or '<bot>:'
1190
- for is_human in [True, False]:
1191
- what = human if is_human else bot
1192
- other = human if not is_human else bot
1193
- for i in range(df.shape[0]):
1194
- text = df.loc[i, 'text']
1195
- assert isinstance(text, str)
1196
- starts = [m.start() for m in re.finditer(what, text)]
1197
- if len(starts) == 1:
1198
- starts = [starts[0], len(text)] # always go into for loop below
1199
- assert len(text)
1200
- list_what = []
1201
- for ii in range(len(starts) - 1):
1202
- interaction = text[starts[ii]: starts[ii + 1]]
1203
- if other in interaction:
1204
- interaction = interaction[:interaction.find(other)]
1205
- interaction.strip()
1206
- list_what.append(interaction)
1207
- if not list_what:
1208
- list_what = [''] # handle corrupted data, very rare, leads to sizes 0
1209
- if is_human:
1210
- len_human_min.append(min([len(x) for x in list_what]))
1211
- len_human_max.append(max([len(x) for x in list_what]))
1212
- len_human_mean.append(np.mean([len(x) for x in list_what]))
1213
- else:
1214
- len_bot_min.append(min([len(x) for x in list_what]))
1215
- len_bot_max.append(max([len(x) for x in list_what]))
1216
- len_bot_mean.append(np.mean([len(x) for x in list_what]))
1217
- df['len_human_min'] = len_human_min
1218
- df['len_human_max'] = len_human_max
1219
- df['len_human_mean'] = len_human_mean
1220
- df['len_bot_min'] = len_bot_min
1221
- df['len_bot_max'] = len_bot_max
1222
- df['len_bot_mean'] = len_bot_mean
1223
- np.random.seed(1234)
1224
- pd.set_option('display.max_columns', None)
1225
- print("Before chopping")
1226
- print(df.describe())
1227
- return df
1228
-
1229
-
1230
- def test_grade():
1231
- df = None
1232
-
1233
- file = "h2oGPT.cleaned.chopped.human_bot.shorter.parquet"
1234
- output_file = "h2oGPT.cleaned.graded1.human_bot.shorter.parquet"
1235
- if not os.path.exists(output_file):
1236
- if df is None:
1237
- df = pd.read_parquet(file).reset_index(drop=True)
1238
- df = add_textstat_grade(df)
1239
- min_grade = 10
1240
- max_grade = 25
1241
- df = df[df['flesch_grade'] >= min_grade]
1242
- df = df[df['flesch_grade'] <= max_grade]
1243
- print("After Flesch grade")
1244
- print(df.describe())
1245
- df.to_parquet(output_file, index=False)
1246
-
1247
- file = output_file
1248
- output_file = "h2oGPT.cleaned.graded2.human_bot.shorter.parquet"
1249
- if not os.path.exists(output_file):
1250
- # slower than alt-profanity, do last, but do before deberta grading, since that's slower
1251
- if df is None:
1252
- df = pd.read_parquet(file).reset_index(drop=True)
1253
- df = add_better_profanity_flag(df)
1254
- before_rows = df.shape[0]
1255
- df = df[df['better_profanity'] == 0]
1256
- df = df.drop(['better_profanity'], axis=1)
1257
- after_rows = df.shape[0]
1258
- print("Dropped %d rows out of %d due to better_profanity" % (before_rows - after_rows, before_rows))
1259
- print(df.describe())
1260
- df.to_parquet(output_file, index=False)
1261
-
1262
- file = output_file
1263
- output_file = 'h2oGPT.cleaned.graded3.human_bot.shorter.parquet'
1264
- if not os.path.exists(output_file):
1265
- if df is None:
1266
- df = pd.read_parquet(file).reset_index(drop=True)
1267
- df = add_deberta_grade(df)
1268
- min_grade = 0.3
1269
- max_grade = np.inf
1270
- before_rows = df.shape[0]
1271
- df = df[df['grade_deberta'] >= min_grade]
1272
- df = df[df['grade_deberta'] <= max_grade]
1273
- after_rows = df.shape[0]
1274
- print("Dropped %d rows out of %d due to deberta grade" % (before_rows - after_rows, before_rows))
1275
- print("After DeBERTa grade")
1276
- print(df.describe())
1277
- df.to_parquet(output_file, index=False)
1278
-
1279
- file = output_file
1280
- output_file = 'h2oGPT.cleaned.graded.human_bot.shorter.parquet'
1281
- if df is None:
1282
- df = pd.read_parquet(file).reset_index(drop=True)
1283
- df.to_parquet(output_file, index=False)
1284
-
1285
-
1286
- @pytest.mark.parametrize(
1287
- "fixup_personality, only_personality, deberta_grading",
1288
- [
1289
- [False, False, False],
1290
- [True, True, False],
1291
- [True, False, False],
1292
- [True, False, True],
1293
- ]
1294
- )
1295
- def test_add_open_assistant(fixup_personality, only_personality, deberta_grading, save_json=True):
1296
- """
1297
- Flatten tree structure into one row per path from root to leaf
1298
- Also turn into human_bot prompting format:
1299
- <human>: question\n<bot>: answer <human>: question2\n<bot>: answer2 Etc.
1300
- Also saves a .json locally as side-effect
1301
- returns list of dicts, containing intput, prompt_type and source
1302
- """
1303
- from datasets import load_dataset
1304
- data_file = "OpenAssistant/oasst1"
1305
- ds = load_dataset(data_file)
1306
- df = pd.concat([ds['train'].to_pandas(), ds['validation'].to_pandas()], axis=0)
1307
- rows = {}
1308
- message_ids = df['message_id'].values.tolist()
1309
- message_tree_ids = df['message_tree_id'].values.tolist()
1310
- parent_ids = df['parent_id'].values.tolist()
1311
- texts = df['text'].values.tolist()
1312
- roles = df['role'].values.tolist()
1313
-
1314
- for i in range(df.shape[0]):
1315
- # collect all trees
1316
- message_id = message_ids[i]
1317
- message_tree_id = message_tree_ids[i]
1318
- parent_id = parent_ids[i]
1319
- text = texts[i]
1320
- if fixup_personality:
1321
- text = text.replace("Open Assistant", "h2oGPT")
1322
- text = text.replace("Open-Assistant", "h2oGPT")
1323
- text = text.replace("open-assistant", "h2oGPT")
1324
- text = text.replace("OpenAssistant", "h2oGPT")
1325
- text = text.replace("open assistant", "h2oGPT")
1326
- text = text.replace("Open Assistand", "h2oGPT")
1327
- text = text.replace("Open Assitant", "h2oGPT")
1328
- text = text.replace("Open Assistent", "h2oGPT")
1329
- text = text.replace("Open Assisstant", "h2oGPT")
1330
- text = text.replace("Open Assitent", "h2oGPT")
1331
- text = text.replace("Open Assitiant", "h2oGPT")
1332
- text = text.replace("Open Assistiant", "h2oGPT")
1333
- text = text.replace("Open Assitan ", "h2oGPT ")
1334
- text = text.replace("Open Assistan ", "h2oGPT ")
1335
- text = text.replace("Open Asistant", "h2oGPT")
1336
- text = text.replace("Open Assiant", "h2oGPT")
1337
- text = text.replace("Assistant", "h2oGPT")
1338
- text = text.replace("LAION AI", "H2O.ai")
1339
- text = text.replace("LAION-AI", "H2O.ai")
1340
- text = text.replace("LAION,", "H2O.ai,")
1341
- text = text.replace("LAION.ai", "H2O.ai")
1342
- text = text.replace("LAION.", "H2O.ai.")
1343
- text = text.replace("LAION", "H2O.ai")
1344
-
1345
- role = roles[i]
1346
- new_data = ('<human>: ' if role == 'prompter' else '<bot>: ') + text
1347
- entry = dict(message_id=message_id, parent_id=parent_id, text=new_data)
1348
- if message_tree_id not in rows:
1349
- rows[message_tree_id] = [entry]
1350
- else:
1351
- rows[message_tree_id].append(entry)
1352
-
1353
- all_rows = []
1354
-
1355
- for node_id in rows:
1356
- # order responses in tree, based on message/parent relationship
1357
- conversations = []
1358
-
1359
- list_msgs = rows[node_id]
1360
- # find start
1361
- while len(list_msgs):
1362
- for i, leaf in enumerate(list_msgs):
1363
- found = False
1364
- parent_id = leaf['parent_id']
1365
- if parent_id is None:
1366
- # conversation starter
1367
- conversations.append(leaf)
1368
- found = True
1369
- else:
1370
- for conv in conversations:
1371
- # find all conversations to add my message to
1372
- if parent_id in conv['message_id'] and parent_id != conv['message_id'][-len(parent_id):]:
1373
- # my message doesn't follow conversation
1374
- continue
1375
- if parent_id == conv['message_id'][-len(parent_id):]:
1376
- # my message follows conversation, but fork first, so another follow-on message can do same
1377
- conversations.append(conv.copy())
1378
- conv['text'] += f"""
1379
- {leaf['text']}
1380
- """
1381
- conv['message_id'] += leaf['message_id']
1382
- found = True
1383
- break
1384
- if found:
1385
- # my content was used, so nuke from list
1386
- del list_msgs[i]
1387
- break
1388
-
1389
- # now reduce down to final conversations, find the longest chains of message ids
1390
- for i, conv in enumerate(conversations):
1391
- for j, conv2 in enumerate(conversations):
1392
- if i == j:
1393
- continue
1394
- if conv['message_id'] and conv2['message_id']:
1395
- assert conv['message_id'] != conv2['message_id']
1396
- # delete the shorter conversation, if one contains the other
1397
- if conv['message_id'] in conv2['message_id']:
1398
- conv['message_id'] = None
1399
- if conv2['message_id'] in conv['message_id']:
1400
- conv2['message_id'] = None
1401
- conversations = [c for c in conversations if c['message_id']]
1402
- if only_personality:
1403
- all_rows.extend(
1404
- [dict(input=c['text'] + "\n<human>:", prompt_type='plain', source=data_file) for c in conversations if
1405
- 'h2oGPT' in c['text']])
1406
- else:
1407
- all_rows.extend(
1408
- [dict(input=c['text'] + "\n<human>:", prompt_type='plain', source=data_file) for c in conversations if
1409
- "What is H2O.ai" not in c['text']])
1410
- unhelpful = get_unhelpful_list()
1411
- all_rows = [x for x in all_rows if not any(u in x['input'] for u in unhelpful)]
1412
- personality = create_personality_data()
1413
- all_rows.extend(personality * 10)
1414
- np.random.seed(123)
1415
- np.random.shuffle(all_rows)
1416
- print(len(all_rows))
1417
- if deberta_grading:
1418
- df = pd.DataFrame(all_rows)
1419
- df = df.rename(columns={'input': 'text'})
1420
- df = add_deberta_grade(df)
1421
- df = df.rename(columns={'text': 'input'})
1422
- drop = True
1423
- if drop:
1424
- min_grade = 0.3
1425
- max_grade = np.inf
1426
- before_rows = df.shape[0]
1427
- df = df[df['grade_deberta'] >= min_grade]
1428
- df = df[df['grade_deberta'] <= max_grade]
1429
- after_rows = df.shape[0]
1430
- print("Dropped %d rows out of %d due to deberta grade" % (before_rows - after_rows, before_rows))
1431
- print("After DeBERTa grade")
1432
- print(df.describe())
1433
- all_rows = []
1434
- for i in range(df.shape[0]):
1435
- all_rows.append(
1436
- dict(
1437
- input=df['input'].iloc[i],
1438
- source=df['source'].iloc[i],
1439
- prompt_type=df['prompt_type'].iloc[i],
1440
- grade_deberta=df['grade_deberta'].iloc[i],
1441
- )
1442
- )
1443
- if save_json:
1444
- data_file = data_file + \
1445
- ("_h2ogpt" if fixup_personality else "") + \
1446
- ("_only" if only_personality else "") + \
1447
- ("_graded" if deberta_grading else "")
1448
- for i in range(len(all_rows)):
1449
- all_rows[i]['id'] = i
1450
- with open(data_file.lower().replace("/", "_") + ".json", "w") as f:
1451
- f.write(json.dumps(all_rows, indent=2))
1452
- return all_rows
1453
-
1454
-
1455
- def test_finalize_to_json():
1456
- df = pd.read_parquet('h2oGPT.cleaned.graded.human_bot.shorter.parquet')
1457
- df = df.rename(columns={'text': 'input'})
1458
-
1459
- print("Number of high-quality human_bot interactions: %s" % df.shape[0], flush=True)
1460
-
1461
- print("Adding open assistant data")
1462
- with open("openassistant_oasst1_h2ogpt_graded.json") as f:
1463
- open_assistant = json.loads(f.read())
1464
- df = pd.concat([df, pd.DataFrame(open_assistant)], axis=0)
1465
-
1466
- def final_clean(df):
1467
- from better_profanity import profanity
1468
- profanity.load_censor_words_from_file("data/censor_words.txt")
1469
- df['profanity'] = parallel_apply(
1470
- df['input'],
1471
- lambda x: profanity.contains_profanity(x),
1472
- n_jobs=-1,
1473
- )
1474
- return df[(df['profanity'] == 0)].reset_index(drop=True)
1475
-
1476
- print("Before cleaning: Number of final high-quality human_bot interactions: %s" % df.shape[0], flush=True)
1477
- df = final_clean(df)
1478
- print("After cleaning: Number of final high-quality human_bot interactions: %s" % df.shape[0], flush=True)
1479
- print(df.describe())
1480
- print(df.shape)
1481
- row_list = []
1482
- for i in range(df.shape[0]):
1483
- row_list.append(
1484
- dict(
1485
- input=df.loc[i, 'input'],
1486
- source=df.loc[i, 'source'],
1487
- prompt_type='plain',
1488
- )
1489
- )
1490
- np.random.seed(1234)
1491
- np.random.shuffle(row_list)
1492
- unhelpful = get_unhelpful_list()
1493
- row_list = [x for x in row_list if not any(u in x['input'] for u in unhelpful)]
1494
- for i in range(len(row_list)):
1495
- row_list[i]['id'] = i
1496
- row_list[i]['input'] = row_list[i]['input'].replace(" <bot>:", "\n<bot>:")
1497
- with open('h2ogpt-oig-oasst1-instruct-cleaned-v3.json', "w") as f:
1498
- f.write(json.dumps(row_list, indent=2))
1499
-
1500
-
1501
- def create_personality_data():
1502
- questions = [
1503
- "What's your name?",
1504
- "What is your name?",
1505
- "What are you?",
1506
- "Who are you?",
1507
- "Do you have a name?",
1508
- "Who trained you?",
1509
- "Who created you?",
1510
- "Who made you?",
1511
- ]
1512
- answers = [
1513
- "I'm h2oGPT, a large language model by H2O.ai.",
1514
- "I'm h2oGPT, a large language model by H2O.ai, the visionary leader in democratizing AI.",
1515
- "My name is h2oGPT. I'm a large language model by H2O.ai, the visionary leader in democratizing AI.",
1516
- "My name is h2oGPT. I'm a large language model trained by H2O.ai.",
1517
- "Hi! I'm h2oGPT, a large language model by H2O.ai.",
1518
- "Hi! I'm h2oGPT, a large language model by H2O.ai, the visionary leader in democratizing AI.",
1519
- ]
1520
- help = [
1521
- "",
1522
- " How can I help you?",
1523
- " How may I assist you?",
1524
- " Nice to meet you.",
1525
- ]
1526
- import itertools
1527
- rows = []
1528
- for pair in itertools.product(questions, answers, help):
1529
- rows.append(
1530
- dict(input=f"<human>: {pair[0]}\n<bot>: {pair[1]}{pair[2]}\n<human>:", prompt_type='plain', source="H2O.ai")
1531
- )
1532
- for row in [
1533
- "<human>: What is H2O.ai?\n<bot>: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n<human>:",
1534
- "<human>: What is h2o.ai?\n<bot>: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n<human>:",
1535
- "<human>: What is H2O?\n<bot>: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n<human>:",
1536
- "<human>: Who is h2o.ai?\n<bot>: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n<human>:",
1537
- "<human>: who is h2o.ai?\n<bot>: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n<human>:",
1538
- "<human>: who is h2o?\n<bot>: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n<human>:",
1539
- "<human>: What is H2O.ai?\n<bot>: H2O.ai is the visionary leader in democratizing AI.\n<human>:",
1540
- "<human>: Who is H2O.ai?\n<bot>: H2O.ai is the visionary leader in democratizing AI.\n<human>:",
1541
- "<human>: Who is H2O?\n<bot>: H2O.ai is the visionary leader in democratizing AI.\n<human>:",
1542
- "<human>: Who is h2o?\n<bot>: H2O.ai is the visionary leader in democratizing AI.\n<human>:",
1543
- "<human>: who is h2o?\n<bot>: H2O.ai is the visionary leader in democratizing AI.\n<human>:",
1544
- ]:
1545
- rows.append(dict(input=row, prompt_type='plain', source='H2O.ai'))
1546
- print(len(rows))
1547
- with open("h2ogpt-personality.json", "w") as f:
1548
- f.write(json.dumps(rows, indent=2))
1549
- return rows
1550
-
1551
-
1552
- def test_check_stats_data():
1553
- filename = 'h2ogpt-oig-oasst1-instruct-cleaned-v3.json'
1554
- df = pd.read_json(filename)
1555
-
1556
- # get word stats
1557
- df['char_count'] = df['input'].apply(lambda x: len(x))
1558
- import matplotlib.pyplot as plt
1559
- plt.figure(figsize=(10, 10))
1560
- plt.hist(df['char_count'], bins=100)
1561
- chars_avg = np.mean(df['char_count'])
1562
- chars_median = np.median(df['char_count'])
1563
- plt.title("char_count avg: %s median: %s" % (chars_avg, chars_median))
1564
- plt.savefig('chars_hist.png')
1565
- plt.close()
1566
-
1567
- # get tokenize stats for random sample of 1000 rows
1568
- from finetune import generate_and_tokenize_prompt
1569
- from loaders import get_loaders, get_tokenizer
1570
- from functools import partial
1571
-
1572
- llama_type = False
1573
- tokenizer_base_model = base_model = 'h2oai/h2ogpt-oasst1-512-20b'
1574
- model_loader, tokenizer_loader = get_loaders(model_name=base_model, reward_type=False, llama_type=llama_type)
1575
- local_files_only = False
1576
- resume_download = True
1577
- use_auth_token = False
1578
- tokenizer = get_tokenizer(tokenizer_loader, tokenizer_base_model, local_files_only, resume_download, use_auth_token)
1579
- prompt_type = 'plain' # trained with data already in human bot form
1580
- train_on_inputs = True
1581
- add_eos_token = False
1582
- cutoff_len = 512 # can choose 2048
1583
- generate_and_tokenize_prompt_fun = partial(generate_and_tokenize_prompt, prompt_type=prompt_type,
1584
- train_on_inputs=train_on_inputs, add_eos_token=add_eos_token,
1585
- cutoff_len=cutoff_len, tokenizer=tokenizer)
1586
- from datasets import load_dataset
1587
- data = load_dataset("json", data_files={"train": filename})
1588
- val_set_size = 0.90
1589
- train_val = data["train"].train_test_split(
1590
- test_size=val_set_size, shuffle=True, seed=42
1591
- )
1592
- train_data = train_val["train"]
1593
- train_data = train_data.shuffle().map(generate_and_tokenize_prompt_fun, num_proc=os.cpu_count())
1594
-
1595
- df_tokens = pd.DataFrame([len(x) for x in train_data['input_ids']], columns=['token_count'])
1596
-
1597
- plt.figure(figsize=(10, 10))
1598
- plt.hist(df_tokens['token_count'], bins=100)
1599
- token_avg = np.mean(df_tokens['token_count'])
1600
- token_median = np.median(df_tokens['token_count'])
1601
- plt.title("token_count with cutoff=%s avg: %s median: %s" % (cutoff_len, token_avg, token_median))
1602
- plt.savefig('token_hist_%s.png' % cutoff_len)
1603
- plt.close()
1604
-
1605
-
1606
- def get_unhelpful_list():
1607
- # base versions
1608
- unhelpful = ["I'm sorry, I didn't quite understand your question, could you please rephrase it?",
1609
- "I'm sorry, but I don't understand your question. Could you please rephrase it?",
1610
- "I'm sorry, I don't quite understand your question",
1611
- "I'm sorry, I don't know",
1612
- "I'm sorry, but I don't know",
1613
- "I don't know anything",
1614
- "I do not know",
1615
- "I don't know",
1616
- "I don't know how",
1617
- "I do not know how",
1618
- "Can you please explain what you mean",
1619
- "please explain what you mean",
1620
- "please explain",
1621
- "I'm sorry, but I don't know how to tell a story. Can you please explain what you mean by",
1622
- "I'm sorry but I don't understand what you mean",
1623
- "I don't understand",
1624
- "I don't have the ability",
1625
- "I do not have the ability",
1626
- "I do not have",
1627
- "I am a language model,",
1628
- "I am a large language model,",
1629
- "I do not understand your question. Can you please try to make it clearer?",
1630
- "I'm sorry, but as an AI language model",
1631
- "I apologize, but I cannot rephrase text that I cannot understand. Your post is difficult to read and follow.",
1632
- "I apologize, but I am not h2oGPT. I am a language model developed by H2O.ai. How may I help you?",
1633
- "Sorry, but I am not an actual Linux shell, nor am I capable of emulating one. I am an open source chat assistant and would be glad t",
1634
- "I apologize, but I cannot perform the task you have requested.",
1635
- "I'm sorry, I cannot perform this task as I am an AI language model and do not have access",
1636
- "I'm sorry, I'm not sure what you're asking for here.",
1637
- "I'm not sure what you are asking",
1638
- "You need to provide more context",
1639
- ]
1640
- # reduced versions, with redundant parts, just to give context for where they came from
1641
- unhelpful += ["sorry, I didn't quite understand your question",
1642
- "I didn't quite understand your question",
1643
- "I didn't understand your question",
1644
- "I did not understand your question",
1645
- "I did not understand the question",
1646
- "could you please rephrase"
1647
- "could you rephrase"
1648
- "I do not understand your question.",
1649
- "I do not understand the question.",
1650
- "I do not understand that question.",
1651
- "Can you please try to make it clearer",
1652
- "Can you try to make it clearer",
1653
- "sorry, but as an AI language model",
1654
- "as an AI language model",
1655
- "I apologize, but I cannot",
1656
- "I cannot rephrase text",
1657
- "I cannot understand. Your post is difficult to read and follow."
1658
- "Your post is difficult to read and follow."
1659
- "I apologize, but I am",
1660
- "Sorry, but I am not ",
1661
- "nor am I capable",
1662
- "I am not capable of",
1663
- "I apologize, but I cannot perform the task you have requested",
1664
- "I cannot perform the task",
1665
- "I cannot complete the task",
1666
- "I'm sorry",
1667
- "I am sorry",
1668
- "do not have access",
1669
- "not sure what you're asking for",
1670
- "not sure what you are asking for",
1671
- "not sure what is being asked",
1672
- "I'm not sure what you are asking",
1673
- "not sure what you are asking",
1674
- "You need to provide more context",
1675
- "provide more context",
1676
- ]
1677
- unhelpful += ["As a large language model",
1678
- "cannot provide any information",
1679
- "As an artificial intelligence I do not have the capability",
1680
- "As an artificial intelligence I don't have the capability",
1681
- "As an artificial intelligence I can't",
1682
- "As an artificial intelligence I cannot",
1683
- "I am sorry but I do not understand",
1684
- "Can you please explain",
1685
- "(sorry couldn't resist)",
1686
- "(sorry could not resist)",
1687
- " :)",
1688
- " ;)",
1689
- " :-)",
1690
- " ;-)",
1691
- " lol ",
1692
- "Thanks so much!!!",
1693
- "Thank You :)!!!",
1694
- "Please try not to repeat",
1695
- "I am an AI language model",
1696
- "I'm a AI assistant that",
1697
- "I'm an AI assistant that",
1698
- "I am an AI assistant that",
1699
- "etc.",
1700
- "etc.etc.",
1701
- "etc. etc.",
1702
- "etc etc",
1703
- ]
1704
- return unhelpful
1705
-
1706
-
1707
- def test_check_unhelpful():
1708
- # file = '/home/jon/Downloads/openassistant_oasst1_h2ogpt_graded.json'
1709
- file = '/home/jon/Downloads/openassistant_oasst1_h2ogpt_grades.json'
1710
- # file = 'h2ogpt-oig-oasst1-instruct-cleaned-v2.json'
1711
-
1712
- unhelpful = get_unhelpful_list()
1713
- # data = json.load(open(file, 'rt'))
1714
- df = pd.read_json(file)
1715
-
1716
- use_reward_score_threshold = False
1717
- use_bleu_threshold = False
1718
- use_sentence_sim = True
1719
-
1720
- from sacrebleu.metrics import BLEU
1721
- bleu = BLEU()
1722
- from nltk.translate.bleu_score import sentence_bleu
1723
-
1724
- def get_bleu(actual, expected_list):
1725
- # return bleu.sentence_score(actual, expected_list).score
1726
- return sentence_bleu(expected_list, actual)
1727
-
1728
- threshold = 0.0
1729
- if use_reward_score_threshold:
1730
- df = df[df['grade_deberta'] > threshold]
1731
-
1732
- # back to as if original json load
1733
- data = df.to_dict(orient='records')
1734
- bads = {}
1735
- string_all = str(data)
1736
- for sub in unhelpful:
1737
- bads[sub] = string_all.count(sub)
1738
- bads = {k: v for k, v in bads.items() if v > 0}
1739
- import pprint
1740
- pp = pprint.PrettyPrinter(indent=4)
1741
- pp.pprint(bads)
1742
-
1743
- total_bads = sum(list(bads.values()))
1744
- print('total_bads: %s' % total_bads, flush=True)
1745
-
1746
- # check just bot
1747
- import re
1748
- convs = [[x.strip() for x in re.split(r'%s|%s' % (human, bot), y['input']) if x.strip()] for y in data]
1749
- humans = [[x for i, x in enumerate(y) if i % 2 == 0] for y in convs]
1750
- bots = [[x for i, x in enumerate(y) if i % 2 == 1] for y in convs]
1751
-
1752
- # FIXME: apply back to json etc., just see for now
1753
- bleu_threshold = 0.9
1754
- if use_bleu_threshold:
1755
- bots = [[x for x in y if get_bleu(x, unhelpful) < bleu_threshold] for y in tqdm(bots)]
1756
-
1757
- cosine_sim_threshold = 0.8
1758
- if use_sentence_sim:
1759
- # pip install sentence_transformers-2.2.2
1760
- from sentence_transformers import SentenceTransformer
1761
- # sent_model = 'bert-base-nli-mean-tokens'
1762
- # sent_model = 'nli-distilroberta-base-v2'
1763
- sent_model = 'all-MiniLM-L6-v2'
1764
- model = SentenceTransformer(sent_model)
1765
- sentence_embeddings = model.encode(unhelpful)
1766
- from sklearn.metrics.pairwise import cosine_similarity
1767
- bots = [x for x in tqdm(bots) if
1768
- np.max(cosine_similarity(model.encode(x), sentence_embeddings)) < cosine_sim_threshold]
1769
-
1770
- bads_bots = {}
1771
- string_all = str(bots)
1772
- for sub in unhelpful:
1773
- bads_bots[sub] = string_all.count(sub)
1774
- bads_bots = {k: v for k, v in bads_bots.items() if v > 0}
1775
- import pprint
1776
- pp = pprint.PrettyPrinter(indent=4)
1777
- pp.pprint(bads_bots)
1778
-
1779
- total_bads_bots = sum(list(bads_bots.values()))
1780
- print('threshold: %g use_bleu_threshold: %g total_bads_bots: %s total_bots: %s total_humans: %s' % (
1781
- threshold, use_bleu_threshold, total_bads_bots, len(bots), len(humans)), flush=True)
1782
-
1783
- # assert len(bads) == 0, bads
1784
- assert len(bads_bots) == 0, bads_bots
1785
-
1786
-
1787
- def test_fortune2000_personalized():
1788
- row_list = []
1789
- import glob
1790
- if not os.path.isdir("wikitext"):
1791
- raise RuntimeError("download https://github.com/h2oai/h2ogpt/files/11423008/wikitext.zip and unzip")
1792
- for file in glob.glob("wikitext/*.txt"):
1793
- with open(file, "r") as f:
1794
- blob = f.read()
1795
- N = 512 * 4
1796
- row_list.extend([{'input': s, 'prompt_type': 'plain', 'source': "%s" % os.path.basename(file)}
1797
- for s in get_sentences(blob, N) if s])
1798
- personality = create_personality_data()
1799
- import copy
1800
- for i in range(10):
1801
- row_list.extend(copy.deepcopy(personality))
1802
- np.random.seed(123)
1803
- np.random.shuffle(row_list)
1804
- for i in range(len(row_list)):
1805
- row_list[i]['id'] = i
1806
- for i in range(len(row_list)):
1807
- assert row_list[i]['id'] == i
1808
- with open("h2ogpt-fortune2000-personalized.json", "w") as ff:
1809
- ff.write(json.dumps(row_list, indent=2))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
enums.py DELETED
@@ -1,120 +0,0 @@
1
- from enum import Enum
2
-
3
-
4
- class PromptType(Enum):
5
- custom = -1
6
- plain = 0
7
- instruct = 1
8
- quality = 2
9
- human_bot = 3
10
- dai_faq = 4
11
- summarize = 5
12
- simple_instruct = 6
13
- instruct_vicuna = 7
14
- instruct_with_end = 8
15
- human_bot_orig = 9
16
- prompt_answer = 10
17
- open_assistant = 11
18
- wizard_lm = 12
19
- wizard_mega = 13
20
- instruct_vicuna2 = 14
21
- instruct_vicuna3 = 15
22
- wizard2 = 16
23
- wizard3 = 17
24
- instruct_simple = 18
25
- wizard_vicuna = 19
26
- openai = 20
27
- openai_chat = 21
28
- gptj = 22
29
- prompt_answer_openllama = 23
30
- vicuna11 = 24
31
- mptinstruct = 25
32
- mptchat = 26
33
- falcon = 27
34
- guanaco = 28
35
- llama2 = 29
36
-
37
-
38
- class DocumentSubset(Enum):
39
- Relevant = 0
40
- RelSources = 1
41
- TopKSources = 2
42
-
43
-
44
- non_query_commands = [
45
- DocumentSubset.RelSources.name,
46
- DocumentSubset.TopKSources.name
47
- ]
48
-
49
-
50
- class DocumentChoice(Enum):
51
- ALL = 'All'
52
-
53
-
54
- class LangChainMode(Enum):
55
- """LangChain mode"""
56
-
57
- DISABLED = "Disabled"
58
- LLM = "LLM"
59
- ALL = "All"
60
- WIKI = "wiki"
61
- WIKI_FULL = "wiki_full"
62
- USER_DATA = "UserData"
63
- MY_DATA = "MyData"
64
- GITHUB_H2OGPT = "github h2oGPT"
65
- H2O_DAI_DOCS = "DriverlessAI docs"
66
-
67
-
68
- # modes should not be removed from visible list or added by name
69
- langchain_modes_intrinsic = [LangChainMode.DISABLED.value,
70
- LangChainMode.LLM.value,
71
- LangChainMode.MY_DATA.value]
72
-
73
-
74
- class LangChainAction(Enum):
75
- """LangChain action"""
76
-
77
- QUERY = "Query"
78
- # WIP:
79
- # SUMMARIZE_MAP = "Summarize_map_reduce"
80
- SUMMARIZE_MAP = "Summarize"
81
- SUMMARIZE_ALL = "Summarize_all"
82
- SUMMARIZE_REFINE = "Summarize_refine"
83
-
84
-
85
- class LangChainAgent(Enum):
86
- """LangChain agents"""
87
-
88
- SEARCH = "Search"
89
- # CSV = "csv" # WIP
90
-
91
-
92
- no_server_str = no_lora_str = no_model_str = '[None/Remove]'
93
-
94
- # from site-packages/langchain/llms/openai.py
95
- # but needed since ChatOpenAI doesn't have this information
96
- model_token_mapping = {
97
- "gpt-4": 8192,
98
- "gpt-4-0314": 8192,
99
- "gpt-4-32k": 32768,
100
- "gpt-4-32k-0314": 32768,
101
- "gpt-3.5-turbo": 4096,
102
- "gpt-3.5-turbo-16k": 16 * 1024,
103
- "gpt-3.5-turbo-0301": 4096,
104
- "text-ada-001": 2049,
105
- "ada": 2049,
106
- "text-babbage-001": 2040,
107
- "babbage": 2049,
108
- "text-curie-001": 2049,
109
- "curie": 2049,
110
- "davinci": 2049,
111
- "text-davinci-003": 4097,
112
- "text-davinci-002": 4097,
113
- "code-davinci-002": 8001,
114
- "code-davinci-001": 8001,
115
- "code-cushman-002": 2048,
116
- "code-cushman-001": 2048,
117
- }
118
-
119
- source_prefix = "Sources [Score | Link]:"
120
- source_postfix = "End Sources<p>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
evaluate_params.py DELETED
@@ -1,52 +0,0 @@
1
- input_args_list = ['model_state', 'my_db_state', 'selection_docs_state']
2
-
3
-
4
- no_default_param_names = [
5
- 'instruction',
6
- 'iinput',
7
- 'context',
8
- 'instruction_nochat',
9
- 'iinput_nochat',
10
- ]
11
-
12
- gen_hyper = ['temperature',
13
- 'top_p',
14
- 'top_k',
15
- 'num_beams',
16
- 'max_new_tokens',
17
- 'min_new_tokens',
18
- 'early_stopping',
19
- 'max_time',
20
- 'repetition_penalty',
21
- 'num_return_sequences',
22
- 'do_sample',
23
- ]
24
-
25
- eval_func_param_names = ['instruction',
26
- 'iinput',
27
- 'context',
28
- 'stream_output',
29
- 'prompt_type',
30
- 'prompt_dict'] + \
31
- gen_hyper + \
32
- ['chat',
33
- 'instruction_nochat',
34
- 'iinput_nochat',
35
- 'langchain_mode',
36
- 'add_chat_history_to_context',
37
- 'langchain_action',
38
- 'langchain_agents',
39
- 'top_k_docs',
40
- 'chunk',
41
- 'chunk_size',
42
- 'document_subset',
43
- 'document_choice',
44
- ]
45
-
46
- # form evaluate defaults for submit_nochat_api
47
- eval_func_param_names_defaults = eval_func_param_names.copy()
48
- for k in no_default_param_names:
49
- if k in eval_func_param_names_defaults:
50
- eval_func_param_names_defaults.remove(k)
51
-
52
- eval_extra_columns = ['prompt', 'response', 'score']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gen.py DELETED
The diff for this file is too large to render. See raw diff
 
generate.py DELETED
@@ -1,16 +0,0 @@
1
- import os
2
- import sys
3
-
4
- if os.path.dirname(os.path.abspath(__file__)) not in sys.path:
5
- sys.path.append(os.path.dirname(os.path.abspath(__file__)))
6
-
7
- from src.gen import main
8
- from src.utils import H2O_Fire
9
-
10
-
11
- def entrypoint_main():
12
- H2O_Fire(main)
13
-
14
-
15
- if __name__ == "__main__":
16
- entrypoint_main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gpt4all_llm.py DELETED
@@ -1,316 +0,0 @@
1
- import inspect
2
- import os
3
- from functools import partial
4
- from typing import Dict, Any, Optional, List
5
- from langchain.callbacks.manager import CallbackManagerForLLMRun
6
- from pydantic import root_validator
7
- from langchain.llms import gpt4all
8
- from dotenv import dotenv_values
9
-
10
- from utils import FakeTokenizer
11
-
12
-
13
- def get_model_tokenizer_gpt4all(base_model, **kwargs):
14
- # defaults (some of these are generation parameters, so need to be passed in at generation time)
15
- model_kwargs = dict(n_threads=os.cpu_count() // 2,
16
- temp=kwargs.get('temperature', 0.2),
17
- top_p=kwargs.get('top_p', 0.75),
18
- top_k=kwargs.get('top_k', 40),
19
- n_ctx=2048 - 256)
20
- env_gpt4all_file = ".env_gpt4all"
21
- model_kwargs.update(dotenv_values(env_gpt4all_file))
22
- # make int or float if can to satisfy types for class
23
- for k, v in model_kwargs.items():
24
- try:
25
- if float(v) == int(v):
26
- model_kwargs[k] = int(v)
27
- else:
28
- model_kwargs[k] = float(v)
29
- except:
30
- pass
31
-
32
- if base_model == "llama":
33
- if 'model_path_llama' not in model_kwargs:
34
- raise ValueError("No model_path_llama in %s" % env_gpt4all_file)
35
- model_path = model_kwargs.pop('model_path_llama')
36
- # FIXME: GPT4All version of llama doesn't handle new quantization, so use llama_cpp_python
37
- from llama_cpp import Llama
38
- # llama sets some things at init model time, not generation time
39
- func_names = list(inspect.signature(Llama.__init__).parameters)
40
- model_kwargs = {k: v for k, v in model_kwargs.items() if k in func_names}
41
- model_kwargs['n_ctx'] = int(model_kwargs['n_ctx'])
42
- model = Llama(model_path=model_path, **model_kwargs)
43
- elif base_model in "gpt4all_llama":
44
- if 'model_name_gpt4all_llama' not in model_kwargs and 'model_path_gpt4all_llama' not in model_kwargs:
45
- raise ValueError("No model_name_gpt4all_llama or model_path_gpt4all_llama in %s" % env_gpt4all_file)
46
- model_name = model_kwargs.pop('model_name_gpt4all_llama')
47
- model_type = 'llama'
48
- from gpt4all import GPT4All as GPT4AllModel
49
- model = GPT4AllModel(model_name=model_name, model_type=model_type)
50
- elif base_model in "gptj":
51
- if 'model_name_gptj' not in model_kwargs and 'model_path_gptj' not in model_kwargs:
52
- raise ValueError("No model_name_gpt4j or model_path_gpt4j in %s" % env_gpt4all_file)
53
- model_name = model_kwargs.pop('model_name_gptj')
54
- model_type = 'gptj'
55
- from gpt4all import GPT4All as GPT4AllModel
56
- model = GPT4AllModel(model_name=model_name, model_type=model_type)
57
- else:
58
- raise ValueError("No such base_model %s" % base_model)
59
- return model, FakeTokenizer(), 'cpu'
60
-
61
-
62
- from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
63
-
64
-
65
- class H2OStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
66
-
67
- def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
68
- """Run on new LLM token. Only available when streaming is enabled."""
69
- # streaming to std already occurs without this
70
- # sys.stdout.write(token)
71
- # sys.stdout.flush()
72
- pass
73
-
74
-
75
- def get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=[]):
76
- # default from class
77
- model_kwargs = {k: v.default for k, v in dict(inspect.signature(cls).parameters).items() if k not in exclude_list}
78
- # from our defaults
79
- model_kwargs.update(default_kwargs)
80
- # from user defaults
81
- model_kwargs.update(env_kwargs)
82
- # ensure only valid keys
83
- func_names = list(inspect.signature(cls).parameters)
84
- model_kwargs = {k: v for k, v in model_kwargs.items() if k in func_names}
85
- return model_kwargs
86
-
87
-
88
- def get_llm_gpt4all(model_name,
89
- model=None,
90
- max_new_tokens=256,
91
- temperature=0.1,
92
- repetition_penalty=1.0,
93
- top_k=40,
94
- top_p=0.7,
95
- streaming=False,
96
- callbacks=None,
97
- prompter=None,
98
- context='',
99
- iinput='',
100
- verbose=False,
101
- ):
102
- assert prompter is not None
103
- env_gpt4all_file = ".env_gpt4all"
104
- env_kwargs = dotenv_values(env_gpt4all_file)
105
- max_tokens = env_kwargs.pop('max_tokens', 2048 - max_new_tokens)
106
- default_kwargs = dict(context_erase=0.5,
107
- n_batch=1,
108
- max_tokens=max_tokens,
109
- n_predict=max_new_tokens,
110
- repeat_last_n=64 if repetition_penalty != 1.0 else 0,
111
- repeat_penalty=repetition_penalty,
112
- temp=temperature,
113
- temperature=temperature,
114
- top_k=top_k,
115
- top_p=top_p,
116
- use_mlock=True,
117
- verbose=verbose)
118
- if model_name == 'llama':
119
- cls = H2OLlamaCpp
120
- model_path = env_kwargs.pop('model_path_llama') if model is None else model
121
- model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
122
- model_kwargs.update(dict(model_path=model_path, callbacks=callbacks, streaming=streaming,
123
- prompter=prompter, context=context, iinput=iinput))
124
- llm = cls(**model_kwargs)
125
- llm.client.verbose = verbose
126
- elif model_name == 'gpt4all_llama':
127
- cls = H2OGPT4All
128
- model_path = env_kwargs.pop('model_path_gpt4all_llama') if model is None else model
129
- model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
130
- model_kwargs.update(
131
- dict(model=model_path, backend='llama', callbacks=callbacks, streaming=streaming,
132
- prompter=prompter, context=context, iinput=iinput))
133
- llm = cls(**model_kwargs)
134
- elif model_name == 'gptj':
135
- cls = H2OGPT4All
136
- model_path = env_kwargs.pop('model_path_gptj') if model is None else model
137
- model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
138
- model_kwargs.update(
139
- dict(model=model_path, backend='gptj', callbacks=callbacks, streaming=streaming,
140
- prompter=prompter, context=context, iinput=iinput))
141
- llm = cls(**model_kwargs)
142
- else:
143
- raise RuntimeError("No such model_name %s" % model_name)
144
- return llm
145
-
146
-
147
- class H2OGPT4All(gpt4all.GPT4All):
148
- model: Any
149
- prompter: Any
150
- context: Any = ''
151
- iinput: Any = ''
152
- """Path to the pre-trained GPT4All model file."""
153
-
154
- @root_validator()
155
- def validate_environment(cls, values: Dict) -> Dict:
156
- """Validate that the python package exists in the environment."""
157
- try:
158
- if isinstance(values["model"], str):
159
- from gpt4all import GPT4All as GPT4AllModel
160
-
161
- full_path = values["model"]
162
- model_path, delimiter, model_name = full_path.rpartition("/")
163
- model_path += delimiter
164
-
165
- values["client"] = GPT4AllModel(
166
- model_name=model_name,
167
- model_path=model_path or None,
168
- model_type=values["backend"],
169
- allow_download=False,
170
- )
171
- if values["n_threads"] is not None:
172
- # set n_threads
173
- values["client"].model.set_thread_count(values["n_threads"])
174
- else:
175
- values["client"] = values["model"]
176
- try:
177
- values["backend"] = values["client"].model_type
178
- except AttributeError:
179
- # The below is for compatibility with GPT4All Python bindings <= 0.2.3.
180
- values["backend"] = values["client"].model.model_type
181
-
182
- except ImportError:
183
- raise ValueError(
184
- "Could not import gpt4all python package. "
185
- "Please install it with `pip install gpt4all`."
186
- )
187
- return values
188
-
189
- def _call(
190
- self,
191
- prompt: str,
192
- stop: Optional[List[str]] = None,
193
- run_manager: Optional[CallbackManagerForLLMRun] = None,
194
- **kwargs,
195
- ) -> str:
196
- # Roughly 4 chars per token if natural language
197
- n_ctx = 2048
198
- prompt = prompt[-self.max_tokens * 4:]
199
-
200
- # use instruct prompting
201
- data_point = dict(context=self.context, instruction=prompt, input=self.iinput)
202
- prompt = self.prompter.generate_prompt(data_point)
203
-
204
- verbose = False
205
- if verbose:
206
- print("_call prompt: %s" % prompt, flush=True)
207
- # FIXME: GPT4ALl doesn't support yield during generate, so cannot support streaming except via itself to stdout
208
- return super()._call(prompt, stop=stop, run_manager=run_manager)
209
-
210
-
211
- from langchain.llms import LlamaCpp
212
-
213
-
214
- class H2OLlamaCpp(LlamaCpp):
215
- model_path: Any
216
- prompter: Any
217
- context: Any
218
- iinput: Any
219
- """Path to the pre-trained GPT4All model file."""
220
-
221
- @root_validator()
222
- def validate_environment(cls, values: Dict) -> Dict:
223
- """Validate that llama-cpp-python library is installed."""
224
- if isinstance(values["model_path"], str):
225
- model_path = values["model_path"]
226
- model_param_names = [
227
- "lora_path",
228
- "lora_base",
229
- "n_ctx",
230
- "n_parts",
231
- "seed",
232
- "f16_kv",
233
- "logits_all",
234
- "vocab_only",
235
- "use_mlock",
236
- "n_threads",
237
- "n_batch",
238
- "use_mmap",
239
- "last_n_tokens_size",
240
- ]
241
- model_params = {k: values[k] for k in model_param_names}
242
- # For backwards compatibility, only include if non-null.
243
- if values["n_gpu_layers"] is not None:
244
- model_params["n_gpu_layers"] = values["n_gpu_layers"]
245
-
246
- try:
247
- from llama_cpp import Llama
248
-
249
- values["client"] = Llama(model_path, **model_params)
250
- except ImportError:
251
- raise ModuleNotFoundError(
252
- "Could not import llama-cpp-python library. "
253
- "Please install the llama-cpp-python library to "
254
- "use this embedding model: pip install llama-cpp-python"
255
- )
256
- except Exception as e:
257
- raise ValueError(
258
- f"Could not load Llama model from path: {model_path}. "
259
- f"Received error {e}"
260
- )
261
- else:
262
- values["client"] = values["model_path"]
263
- return values
264
-
265
- def _call(
266
- self,
267
- prompt: str,
268
- stop: Optional[List[str]] = None,
269
- run_manager: Optional[CallbackManagerForLLMRun] = None,
270
- **kwargs,
271
- ) -> str:
272
- verbose = False
273
- # tokenize twice, just to count tokens, since llama cpp python wrapper has no way to truncate
274
- # still have to avoid crazy sizes, else hit llama_tokenize: too many tokens -- might still hit, not fatal
275
- prompt = prompt[-self.n_ctx * 4:]
276
- prompt_tokens = self.client.tokenize(b" " + prompt.encode("utf-8"))
277
- num_prompt_tokens = len(prompt_tokens)
278
- if num_prompt_tokens > self.n_ctx:
279
- # conservative by using int()
280
- chars_per_token = int(len(prompt) / num_prompt_tokens)
281
- prompt = prompt[-self.n_ctx * chars_per_token:]
282
- if verbose:
283
- print("reducing tokens, assuming average of %s chars/token: %s" % chars_per_token, flush=True)
284
- prompt_tokens2 = self.client.tokenize(b" " + prompt.encode("utf-8"))
285
- num_prompt_tokens2 = len(prompt_tokens2)
286
- print("reduced tokens from %d -> %d" % (num_prompt_tokens, num_prompt_tokens2), flush=True)
287
-
288
- # use instruct prompting
289
- data_point = dict(context=self.context, instruction=prompt, input=self.iinput)
290
- prompt = self.prompter.generate_prompt(data_point)
291
-
292
- if verbose:
293
- print("_call prompt: %s" % prompt, flush=True)
294
-
295
- if self.streaming:
296
- text_callback = None
297
- if run_manager:
298
- text_callback = partial(
299
- run_manager.on_llm_new_token, verbose=self.verbose
300
- )
301
- # parent handler of streamer expects to see prompt first else output="" and lose if prompt=None in prompter
302
- if text_callback:
303
- text_callback(prompt)
304
- text = ""
305
- for token in self.stream(prompt=prompt, stop=stop, run_manager=run_manager):
306
- text_chunk = token["choices"][0]["text"]
307
- # self.stream already calls text_callback
308
- # if text_callback:
309
- # text_callback(text_chunk)
310
- text += text_chunk
311
- return text
312
- else:
313
- params = self._get_parameters(stop)
314
- params = {**params, **kwargs}
315
- result = self.client(prompt=prompt, **params)
316
- return result["choices"][0]["text"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gpt_langchain.py DELETED
The diff for this file is too large to render. See raw diff
 
gradio_runner.py DELETED
The diff for this file is too large to render. See raw diff
 
gradio_themes.py DELETED
@@ -1,231 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from typing import Iterable
4
-
5
- from gradio.themes.soft import Soft
6
- from gradio.themes import Color, Size
7
- from gradio.themes.utils import colors, sizes, fonts
8
-
9
- h2o_yellow = Color(
10
- name="yellow",
11
- c50="#fffef2",
12
- c100="#fff9e6",
13
- c200="#ffecb3",
14
- c300="#ffe28c",
15
- c400="#ffd659",
16
- c500="#fec925",
17
- c600="#e6ac00",
18
- c700="#bf8f00",
19
- c800="#a67c00",
20
- c900="#664d00",
21
- c950="#403000",
22
- )
23
- h2o_gray = Color(
24
- name="gray",
25
- c50="#f8f8f8",
26
- c100="#e5e5e5",
27
- c200="#cccccc",
28
- c300="#b2b2b2",
29
- c400="#999999",
30
- c500="#7f7f7f",
31
- c600="#666666",
32
- c700="#4c4c4c",
33
- c800="#333333",
34
- c900="#191919",
35
- c950="#0d0d0d",
36
- )
37
-
38
-
39
- text_xsm = Size(
40
- name="text_xsm",
41
- xxs="4px",
42
- xs="5px",
43
- sm="6px",
44
- md="7px",
45
- lg="8px",
46
- xl="10px",
47
- xxl="12px",
48
- )
49
-
50
-
51
- spacing_xsm = Size(
52
- name="spacing_xsm",
53
- xxs="1px",
54
- xs="1px",
55
- sm="1px",
56
- md="2px",
57
- lg="3px",
58
- xl="5px",
59
- xxl="7px",
60
- )
61
-
62
-
63
- radius_xsm = Size(
64
- name="radius_xsm",
65
- xxs="1px",
66
- xs="1px",
67
- sm="1px",
68
- md="2px",
69
- lg="3px",
70
- xl="5px",
71
- xxl="7px",
72
- )
73
-
74
-
75
- class H2oTheme(Soft):
76
- def __init__(
77
- self,
78
- *,
79
- primary_hue: colors.Color | str = h2o_yellow,
80
- secondary_hue: colors.Color | str = h2o_yellow,
81
- neutral_hue: colors.Color | str = h2o_gray,
82
- spacing_size: sizes.Size | str = sizes.spacing_md,
83
- radius_size: sizes.Size | str = sizes.radius_md,
84
- text_size: sizes.Size | str = sizes.text_lg,
85
- font: fonts.Font
86
- | str
87
- | Iterable[fonts.Font | str] = (
88
- fonts.GoogleFont("Montserrat"),
89
- "ui-sans-serif",
90
- "system-ui",
91
- "sans-serif",
92
- ),
93
- font_mono: fonts.Font
94
- | str
95
- | Iterable[fonts.Font | str] = (
96
- fonts.GoogleFont("IBM Plex Mono"),
97
- "ui-monospace",
98
- "Consolas",
99
- "monospace",
100
- ),
101
- ):
102
- super().__init__(
103
- primary_hue=primary_hue,
104
- secondary_hue=secondary_hue,
105
- neutral_hue=neutral_hue,
106
- spacing_size=spacing_size,
107
- radius_size=radius_size,
108
- text_size=text_size,
109
- font=font,
110
- font_mono=font_mono,
111
- )
112
- super().set(
113
- link_text_color="#3344DD",
114
- link_text_color_hover="#3344DD",
115
- link_text_color_visited="#3344DD",
116
- link_text_color_dark="#74abff",
117
- link_text_color_hover_dark="#a3c8ff",
118
- link_text_color_active_dark="#a3c8ff",
119
- link_text_color_visited_dark="#74abff",
120
- button_primary_text_color="*neutral_950",
121
- button_primary_text_color_dark="*neutral_950",
122
- button_primary_background_fill="*primary_500",
123
- button_primary_background_fill_dark="*primary_500",
124
- block_label_background_fill="*primary_500",
125
- block_label_background_fill_dark="*primary_500",
126
- block_label_text_color="*neutral_950",
127
- block_label_text_color_dark="*neutral_950",
128
- block_title_text_color="*neutral_950",
129
- block_title_text_color_dark="*neutral_950",
130
- block_background_fill_dark="*neutral_950",
131
- body_background_fill="*neutral_50",
132
- body_background_fill_dark="*neutral_900",
133
- background_fill_primary_dark="*block_background_fill",
134
- block_radius="0 0 8px 8px",
135
- checkbox_label_text_color_selected_dark='#000000',
136
- #checkbox_label_text_size="*text_xs", # too small for iPhone etc. but good if full large screen zoomed to fit
137
- checkbox_label_text_size="*text_sm",
138
- #radio_circle="""url("data:image/svg+xml,%3csvg viewBox='0 0 32 32' fill='white' xmlns='http://www.w3.org/2000/svg'%3e%3ccircle cx='32' cy='32' r='1'/%3e%3c/svg%3e")""",
139
- #checkbox_border_width=1,
140
- #heckbox_border_width_dark=1,
141
- )
142
-
143
-
144
- class SoftTheme(Soft):
145
- def __init__(
146
- self,
147
- *,
148
- primary_hue: colors.Color | str = colors.indigo,
149
- secondary_hue: colors.Color | str = colors.indigo,
150
- neutral_hue: colors.Color | str = colors.gray,
151
- spacing_size: sizes.Size | str = sizes.spacing_md,
152
- radius_size: sizes.Size | str = sizes.radius_md,
153
- text_size: sizes.Size | str = sizes.text_md,
154
- font: fonts.Font
155
- | str
156
- | Iterable[fonts.Font | str] = (
157
- fonts.GoogleFont("Montserrat"),
158
- "ui-sans-serif",
159
- "system-ui",
160
- "sans-serif",
161
- ),
162
- font_mono: fonts.Font
163
- | str
164
- | Iterable[fonts.Font | str] = (
165
- fonts.GoogleFont("IBM Plex Mono"),
166
- "ui-monospace",
167
- "Consolas",
168
- "monospace",
169
- ),
170
- ):
171
- super().__init__(
172
- primary_hue=primary_hue,
173
- secondary_hue=secondary_hue,
174
- neutral_hue=neutral_hue,
175
- spacing_size=spacing_size,
176
- radius_size=radius_size,
177
- text_size=text_size,
178
- font=font,
179
- font_mono=font_mono,
180
- )
181
- super().set(
182
- checkbox_label_text_size="*text_sm",
183
- )
184
-
185
-
186
- h2o_logo = '<svg id="Layer_1" data-name="Layer 1" xmlns="http://www.w3.org/2000/svg" width="100%" height="100%"' \
187
- ' viewBox="0 0 600.28 600.28"><defs><style>.cls-1{fill:#fec925;}.cls-2{fill:#161616;}.cls-3{fill:' \
188
- '#54585a;}</style></defs><g id="Fill-1"><rect class="cls-1" width="600.28" height="600.28" ' \
189
- 'rx="23.24"/></g><path class="cls-2" d="M174.33,246.06v92.78H152.86v-38H110.71v38H89.24V246.06h21.' \
190
- '47v36.58h42.15V246.06Z"/><path class="cls-2" d="M259.81,321.34v17.5H189.7V324.92l35.78-33.8c8.22-7.' \
191
- '82,9.68-12.59,9.68-17.09,0-7.29-5-11.53-14.85-11.53-7.95,0-14.71,3-19.21,9.27L185.46,261.7c7.15-10' \
192
- '.47,20.14-17.23,36.84-17.23,20.68,0,34.46,10.6,34.46,27.44,0,9-2.52,17.22-15.51,29.29l-21.33,20.14Z"' \
193
- '/><path class="cls-2" d="M268.69,292.45c0-27.57,21.47-48,50.76-48s50.76,20.28,50.76,48-21.6,48-50.' \
194
- '76,48S268.69,320,268.69,292.45Zm79.78,0c0-17.63-12.46-29.69-29-29.69s-29,12.06-29,29.69,12.46,29.69' \
195
- ',29,29.69S348.47,310.08,348.47,292.45Z"/><path class="cls-3" d="M377.23,326.91c0-7.69,5.7-12.73,12.' \
196
- '85-12.73s12.86,5,12.86,12.73a12.86,12.86,0,1,1-25.71,0Z"/><path class="cls-3" d="M481.4,298.15v40.' \
197
- '69H462.05V330c-3.84,6.49-11.27,9.94-21.74,9.94-16.7,0-26.64-9.28-26.64-21.61,0-12.59,8.88-21.34,30.' \
198
- '62-21.34h16.43c0-8.87-5.3-14-16.43-14-7.55,0-15.37,2.51-20.54,6.62l-7.43-14.44c7.82-5.57,19.35-8.' \
199
- '62,30.75-8.62C468.81,266.47,481.4,276.54,481.4,298.15Zm-20.68,18.16V309H446.54c-9.67,0-12.72,3.57-' \
200
- '12.72,8.35,0,5.16,4.37,8.61,11.66,8.61C452.37,326,458.34,322.8,460.72,316.31Z"/><path class="cls-3"' \
201
- ' d="M497.56,246.06c0-6.49,5.17-11.53,12.86-11.53s12.86,4.77,12.86,11.13c0,6.89-5.17,11.93-12.86,' \
202
- '11.93S497.56,252.55,497.56,246.06Zm2.52,21.47h20.68v71.31H500.08Z"/></svg>'
203
-
204
-
205
- def get_h2o_title(title, description):
206
- # NOTE: Check full width desktop, smallest width browser desktop, iPhone browsers to ensure no overlap etc.
207
- return f"""<div style="float:left; justify-content:left; height: 80px; width: 195px; margin-top:0px">
208
- {description}
209
- </div>
210
- <div style="display:flex; justify-content:center; margin-bottom:30px; margin-right:330px;">
211
- <div style="height: 60px; width: 60px; margin-right:20px;">{h2o_logo}</div>
212
- <h1 style="line-height:60px">{title}</h1>
213
- </div>
214
- <div style="float:right; height: 80px; width: 80px; margin-top:-100px">
215
- <img src="https://raw.githubusercontent.com/h2oai/h2ogpt/main/docs/h2o-qr.png">
216
- </div>
217
- """
218
-
219
-
220
- def get_simple_title(title, description):
221
- return f"""{description}<h1 align="center"> {title}</h1>"""
222
-
223
-
224
- def get_dark_js():
225
- return """() => {
226
- if (document.querySelectorAll('.dark').length) {
227
- document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'));
228
- } else {
229
- document.querySelector('body').classList.add('dark');
230
- }
231
- }"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
h2oai_pipeline.py DELETED
@@ -1,201 +0,0 @@
1
- import os
2
-
3
- from transformers import TextGenerationPipeline
4
- from transformers.pipelines.text_generation import ReturnType
5
-
6
- from stopping import get_stopping
7
- from prompter import Prompter, PromptType
8
-
9
-
10
- class H2OTextGenerationPipeline(TextGenerationPipeline):
11
- def __init__(self, *args, debug=False, chat=False, stream_output=False,
12
- sanitize_bot_response=False,
13
- use_prompter=True, prompter=None,
14
- context='', iinput='',
15
- prompt_type=None, prompt_dict=None,
16
- max_input_tokens=2048 - 256, **kwargs):
17
- """
18
- HF-like pipeline, but handle instruction prompting and stopping (for some models)
19
- :param args:
20
- :param debug:
21
- :param chat:
22
- :param stream_output:
23
- :param sanitize_bot_response:
24
- :param use_prompter: Whether to use prompter. If pass prompt_type, will make prompter
25
- :param prompter: prompter, can pass if have already
26
- :param prompt_type: prompt_type, e.g. human_bot. See prompt_type to model mapping in from prompter.py.
27
- If use_prompter, then will make prompter and use it.
28
- :param prompt_dict: dict of get_prompt(, return_dict=True) for prompt_type=custom
29
- :param max_input_tokens:
30
- :param kwargs:
31
- """
32
- super().__init__(*args, **kwargs)
33
- self.prompt_text = None
34
- self.use_prompter = use_prompter
35
- self.prompt_type = prompt_type
36
- self.prompt_dict = prompt_dict
37
- self.prompter = prompter
38
- self.context = context
39
- self.iinput = iinput
40
- if self.use_prompter:
41
- if self.prompter is not None:
42
- assert self.prompter.prompt_type is not None
43
- else:
44
- self.prompter = Prompter(self.prompt_type, self.prompt_dict, debug=debug, chat=chat,
45
- stream_output=stream_output)
46
- self.human = self.prompter.humanstr
47
- self.bot = self.prompter.botstr
48
- self.can_stop = True
49
- else:
50
- self.prompter = None
51
- self.human = None
52
- self.bot = None
53
- self.can_stop = False
54
- self.sanitize_bot_response = sanitize_bot_response
55
- self.max_input_tokens = max_input_tokens # not for generate, so ok that not kwargs
56
-
57
- @staticmethod
58
- def limit_prompt(prompt_text, tokenizer, max_prompt_length=None):
59
- verbose = bool(int(os.getenv('VERBOSE_PIPELINE', '0')))
60
-
61
- if hasattr(tokenizer, 'model_max_length'):
62
- # model_max_length only defined for generate.py, not raw use of h2oai_pipeline.py
63
- model_max_length = tokenizer.model_max_length
64
- if max_prompt_length is not None:
65
- model_max_length = min(model_max_length, max_prompt_length)
66
- # cut at some upper likely limit to avoid excessive tokenization etc
67
- # upper bound of 10 chars/token, e.g. special chars sometimes are long
68
- if len(prompt_text) > model_max_length * 10:
69
- len0 = len(prompt_text)
70
- prompt_text = prompt_text[-model_max_length * 10:]
71
- if verbose:
72
- print("Cut of input: %s -> %s" % (len0, len(prompt_text)), flush=True)
73
- else:
74
- # unknown
75
- model_max_length = None
76
-
77
- num_prompt_tokens = None
78
- if model_max_length is not None:
79
- # can't wait for "hole" if not plain prompt_type, since would lose prefix like <human>:
80
- # For https://github.com/h2oai/h2ogpt/issues/192
81
- for trial in range(0, 3):
82
- prompt_tokens = tokenizer(prompt_text)['input_ids']
83
- num_prompt_tokens = len(prompt_tokens)
84
- if num_prompt_tokens > model_max_length:
85
- # conservative by using int()
86
- chars_per_token = int(len(prompt_text) / num_prompt_tokens)
87
- # keep tail, where question is if using langchain
88
- prompt_text = prompt_text[-model_max_length * chars_per_token:]
89
- if verbose:
90
- print("reducing %s tokens, assuming average of %s chars/token for %s characters" % (
91
- num_prompt_tokens, chars_per_token, len(prompt_text)), flush=True)
92
- else:
93
- if verbose:
94
- print("using %s tokens with %s chars" % (num_prompt_tokens, len(prompt_text)), flush=True)
95
- break
96
-
97
- # Why Below False: don't limit max_new_tokens more, just rely upon stopping to reach limit of model
98
- if False:
99
- # if input prompt is some number of tokens, despite user request, can't have max_new_tokens more
100
- #
101
- assert num_prompt_tokens is not None
102
- if self.prompt_type not in [PromptType.plain.name, PromptType.plain.value]:
103
- # then give room for prompt
104
- fudge = 20
105
- else:
106
- fudge = 0
107
- max_new_tokens = max(0, min(generate_kwargs['max_new_tokens'],
108
- model_max_length - (num_prompt_tokens + fudge)))
109
- if max_new_tokens < generate_kwargs['max_new_tokens']:
110
- if verbose:
111
- print("Reduced max_new_tokens from %s -> %s" % (
112
- generate_kwargs['max_new_tokens'], max_new_tokens))
113
- generate_kwargs['max_new_tokens'] = max_new_tokens
114
- return prompt_text, num_prompt_tokens
115
-
116
- def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
117
- prompt_text, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt_text, self.tokenizer)
118
-
119
- data_point = dict(context=self.context, instruction=prompt_text, input=self.iinput)
120
- if self.prompter is not None:
121
- prompt_text = self.prompter.generate_prompt(data_point)
122
- self.prompt_text = prompt_text
123
- if handle_long_generation is None:
124
- # forces truncation of inputs to avoid critical failure
125
- handle_long_generation = None # disable with new approaches
126
- return super().preprocess(prompt_text, prefix=prefix, handle_long_generation=handle_long_generation,
127
- **generate_kwargs)
128
-
129
- def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True):
130
- records = super().postprocess(model_outputs, return_type=return_type,
131
- clean_up_tokenization_spaces=clean_up_tokenization_spaces)
132
- for rec in records:
133
- if self.use_prompter:
134
- outputs = rec['generated_text']
135
- outputs = self.prompter.get_response(outputs, prompt=self.prompt_text,
136
- sanitize_bot_response=self.sanitize_bot_response)
137
- elif self.bot and self.human:
138
- outputs = rec['generated_text'].split(self.bot)[1].split(self.human)[0]
139
- else:
140
- outputs = rec['generated_text']
141
- rec['generated_text'] = outputs
142
- print("prompt: %s\noutputs: %s\n\n" % (self.prompt_text, outputs), flush=True)
143
- return records
144
-
145
- def _forward(self, model_inputs, **generate_kwargs):
146
- if self.can_stop:
147
- stopping_criteria = get_stopping(self.prompt_type, self.prompt_dict,
148
- self.tokenizer, self.device,
149
- human=self.human, bot=self.bot,
150
- model_max_length=self.tokenizer.model_max_length)
151
- generate_kwargs['stopping_criteria'] = stopping_criteria
152
- # return super()._forward(model_inputs, **generate_kwargs)
153
- return self.__forward(model_inputs, **generate_kwargs)
154
-
155
- # FIXME: Copy-paste of original _forward, but removed copy.deepcopy()
156
- # FIXME: https://github.com/h2oai/h2ogpt/issues/172
157
- def __forward(self, model_inputs, **generate_kwargs):
158
- input_ids = model_inputs["input_ids"]
159
- attention_mask = model_inputs.get("attention_mask", None)
160
- # Allow empty prompts
161
- if input_ids.shape[1] == 0:
162
- input_ids = None
163
- attention_mask = None
164
- in_b = 1
165
- else:
166
- in_b = input_ids.shape[0]
167
- prompt_text = model_inputs.pop("prompt_text")
168
-
169
- ## If there is a prefix, we may need to adjust the generation length. Do so without permanently modifying
170
- ## generate_kwargs, as some of the parameterization may come from the initialization of the pipeline.
171
- # generate_kwargs = copy.deepcopy(generate_kwargs)
172
- prefix_length = generate_kwargs.pop("prefix_length", 0)
173
- if prefix_length > 0:
174
- has_max_new_tokens = "max_new_tokens" in generate_kwargs or (
175
- "generation_config" in generate_kwargs
176
- and generate_kwargs["generation_config"].max_new_tokens is not None
177
- )
178
- if not has_max_new_tokens:
179
- generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.model.config.max_length
180
- generate_kwargs["max_length"] += prefix_length
181
- has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
182
- "generation_config" in generate_kwargs
183
- and generate_kwargs["generation_config"].min_new_tokens is not None
184
- )
185
- if not has_min_new_tokens and "min_length" in generate_kwargs:
186
- generate_kwargs["min_length"] += prefix_length
187
-
188
- # BS x SL
189
- generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
190
- out_b = generated_sequence.shape[0]
191
- if self.framework == "pt":
192
- generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
193
- elif self.framework == "tf":
194
- from transformers import is_tf_available
195
- if is_tf_available():
196
- import tensorflow as tf
197
- generated_sequence = tf.reshape(generated_sequence,
198
- (in_b, out_b // in_b, *generated_sequence.shape[1:]))
199
- else:
200
- raise ValueError("TF not avaialble.")
201
- return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
loaders.py DELETED
@@ -1,61 +0,0 @@
1
- import functools
2
-
3
-
4
- def get_loaders(model_name, reward_type, llama_type=None, load_gptq=''):
5
- # NOTE: Some models need specific new prompt_type
6
- # E.g. t5_xxl_true_nli_mixture has input format: "premise: PREMISE_TEXT hypothesis: HYPOTHESIS_TEXT".)
7
- if load_gptq:
8
- from transformers import AutoTokenizer
9
- from auto_gptq import AutoGPTQForCausalLM
10
- use_triton = False
11
- functools.partial(AutoGPTQForCausalLM.from_quantized, quantize_config=None, use_triton=use_triton)
12
- return AutoGPTQForCausalLM.from_quantized, AutoTokenizer
13
- if llama_type is None:
14
- llama_type = "llama" in model_name.lower()
15
- if llama_type:
16
- from transformers import LlamaForCausalLM, LlamaTokenizer
17
- return LlamaForCausalLM.from_pretrained, LlamaTokenizer
18
- elif 'distilgpt2' in model_name.lower():
19
- from transformers import AutoModelForCausalLM, AutoTokenizer
20
- return AutoModelForCausalLM.from_pretrained, AutoTokenizer
21
- elif 'gpt2' in model_name.lower():
22
- from transformers import GPT2LMHeadModel, GPT2Tokenizer
23
- return GPT2LMHeadModel.from_pretrained, GPT2Tokenizer
24
- elif 'mbart-' in model_name.lower():
25
- from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
26
- return MBartForConditionalGeneration.from_pretrained, MBart50TokenizerFast
27
- elif 't5' == model_name.lower() or \
28
- 't5-' in model_name.lower() or \
29
- 'flan-' in model_name.lower():
30
- from transformers import AutoTokenizer, T5ForConditionalGeneration
31
- return T5ForConditionalGeneration.from_pretrained, AutoTokenizer
32
- elif 'bigbird' in model_name:
33
- from transformers import BigBirdPegasusForConditionalGeneration, AutoTokenizer
34
- return BigBirdPegasusForConditionalGeneration.from_pretrained, AutoTokenizer
35
- elif 'bart-large-cnn-samsum' in model_name or 'flan-t5-base-samsum' in model_name:
36
- from transformers import pipeline
37
- return pipeline, "summarization"
38
- elif reward_type or 'OpenAssistant/reward-model'.lower() in model_name.lower():
39
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
40
- return AutoModelForSequenceClassification.from_pretrained, AutoTokenizer
41
- else:
42
- from transformers import AutoTokenizer, AutoModelForCausalLM
43
- model_loader = AutoModelForCausalLM
44
- tokenizer_loader = AutoTokenizer
45
- return model_loader.from_pretrained, tokenizer_loader
46
-
47
-
48
- def get_tokenizer(tokenizer_loader, tokenizer_base_model, local_files_only, resume_download, use_auth_token):
49
- tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
50
- local_files_only=local_files_only,
51
- resume_download=resume_download,
52
- use_auth_token=use_auth_token,
53
- padding_side='left')
54
-
55
- tokenizer.pad_token_id = 0 # different from the eos token
56
- # when generating, we will use the logits of right-most token to predict the next token
57
- # so the padding should be on the left,
58
- # e.g. see: https://huggingface.co/transformers/v4.11.3/model_doc/t5.html#inference
59
- tokenizer.padding_side = "left" # Allow batched inference
60
-
61
- return tokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prompter.py DELETED
@@ -1,871 +0,0 @@
1
- import os
2
- import ast
3
- import time
4
- from enums import PromptType # also supports imports from this file from other files
5
-
6
- non_hf_types = ['gpt4all_llama', 'llama', 'gptj']
7
-
8
- prompt_type_to_model_name = {
9
- 'plain': [
10
- 'EleutherAI/gpt-j-6B',
11
- 'EleutherAI/pythia-6.9b',
12
- 'EleutherAI/pythia-12b',
13
- 'EleutherAI/pythia-12b-deduped',
14
- 'EleutherAI/gpt-neox-20b',
15
- 'openlm-research/open_llama_7b_700bt_preview',
16
- 'decapoda-research/llama-7b-hf',
17
- 'decapoda-research/llama-13b-hf',
18
- 'decapoda-research/llama-30b-hf',
19
- 'decapoda-research/llama-65b-hf',
20
- 'facebook/mbart-large-50-many-to-many-mmt',
21
- 'philschmid/bart-large-cnn-samsum',
22
- 'philschmid/flan-t5-base-samsum',
23
- 'gpt2',
24
- 'distilgpt2',
25
- 'mosaicml/mpt-7b-storywriter',
26
- ],
27
- 'gptj': ['gptj', 'gpt4all_llama'],
28
- 'prompt_answer': [
29
- 'h2oai/h2ogpt-gm-oasst1-en-1024-20b',
30
- 'h2oai/h2ogpt-gm-oasst1-en-1024-12b',
31
- 'h2oai/h2ogpt-gm-oasst1-multilang-1024-20b',
32
- 'h2oai/h2ogpt-gm-oasst1-multilang-2048-falcon-7b',
33
- 'h2oai/h2ogpt-gm-oasst1-multilang-2048-falcon-7b-v2',
34
- 'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3',
35
- 'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b',
36
- 'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v2',
37
- 'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-40b-v1',
38
- 'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-40b-v2',
39
- 'h2oai/h2ogpt-gm-oasst1-en-xgen-7b-8k',
40
- 'h2oai/h2ogpt-gm-oasst1-multilang-xgen-7b-8k',
41
- 'TheBloke/h2ogpt-gm-oasst1-en-2048-falcon-40b-v2-GPTQ',
42
- ],
43
- 'prompt_answer_openllama': [
44
- 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt',
45
- 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2',
46
- 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-700bt',
47
- 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b',
48
- 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-13b',
49
- ],
50
- 'instruct': ['TheBloke/llama-30b-supercot-SuperHOT-8K-fp16'], # https://huggingface.co/TheBloke/llama-30b-supercot-SuperHOT-8K-fp16#prompting
51
- 'instruct_with_end': ['databricks/dolly-v2-12b'],
52
- 'quality': [],
53
- 'human_bot': [
54
- 'h2oai/h2ogpt-oasst1-512-12b',
55
- 'h2oai/h2ogpt-oasst1-512-20b',
56
- 'h2oai/h2ogpt-oig-oasst1-256-6_9b',
57
- 'h2oai/h2ogpt-oig-oasst1-512-6_9b',
58
- 'h2oai/h2ogpt-oig-oasst1-256-6.9b', # legacy
59
- 'h2oai/h2ogpt-oig-oasst1-512-6.9b', # legacy
60
- 'h2oai/h2ogpt-research-oasst1-512-30b',
61
- 'h2oai/h2ogpt-research-oasst1-llama-65b',
62
- 'h2oai/h2ogpt-oasst1-falcon-40b',
63
- 'h2oai/h2ogpt-oig-oasst1-falcon-40b',
64
- ],
65
- 'dai_faq': [],
66
- 'summarize': [],
67
- 'simple_instruct': ['t5-small', 't5-large', 'google/flan-t5', 'google/flan-t5-xxl', 'google/flan-ul2'],
68
- 'instruct_vicuna': ['AlekseyKorshuk/vicuna-7b', 'TheBloke/stable-vicuna-13B-HF', 'junelee/wizard-vicuna-13b'],
69
- 'human_bot_orig': ['togethercomputer/GPT-NeoXT-Chat-Base-20B'],
70
- "open_assistant": ['OpenAssistant/oasst-sft-7-llama-30b-xor', 'oasst-sft-7-llama-30b'],
71
- "wizard_lm": ['ehartford/WizardLM-7B-Uncensored', 'ehartford/WizardLM-13B-Uncensored'],
72
- "wizard_mega": ['openaccess-ai-collective/wizard-mega-13b'],
73
- "instruct_simple": ['JosephusCheung/Guanaco'],
74
- "wizard_vicuna": ['ehartford/Wizard-Vicuna-13B-Uncensored'],
75
- "wizard2": ['llama'],
76
- "mptinstruct": ['mosaicml/mpt-30b-instruct', 'mosaicml/mpt-7b-instruct', 'mosaicml/mpt-30b-instruct'],
77
- "mptchat": ['mosaicml/mpt-7b-chat', 'mosaicml/mpt-30b-chat', 'TheBloke/mpt-30B-chat-GGML'],
78
- "vicuna11": ['lmsys/vicuna-33b-v1.3'],
79
- "falcon": ['tiiuae/falcon-40b-instruct', 'tiiuae/falcon-40b', 'tiiuae/falcon-7b-instruct', 'tiiuae/falcon-7b'],
80
- "llama2": [
81
- 'meta-llama/Llama-2-7b-chat-hf',
82
- 'meta-llama/Llama-2-13b-chat-hf',
83
- 'meta-llama/Llama-2-34b-chat-hf',
84
- 'meta-llama/Llama-2-70b-chat-hf',
85
- ],
86
- # could be plain, but default is correct prompt_type for default TheBloke model ggml-wizardLM-7B.q4_2.bin
87
- }
88
- if os.getenv('OPENAI_API_KEY'):
89
- prompt_type_to_model_name.update({
90
- "openai": ["text-davinci-003", "text-curie-001", "text-babbage-001", "text-ada-001"],
91
- "openai_chat": ["gpt-3.5-turbo", "gpt-3.5-turbo-16k"],
92
- })
93
-
94
- inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
95
- inv_prompt_type_to_model_lower = {v.strip().lower(): k for k, l in prompt_type_to_model_name.items() for v in l}
96
-
97
- prompt_types_strings = []
98
- for p in PromptType:
99
- prompt_types_strings.extend([p.name])
100
-
101
- prompt_types = []
102
- for p in PromptType:
103
- prompt_types.extend([p.name, p.value, str(p.value)])
104
-
105
-
106
- def get_prompt(prompt_type, prompt_dict, chat, context, reduced, making_context, return_dict=False):
107
- prompt_dict_error = ''
108
- generates_leading_space = False
109
-
110
- if prompt_type == PromptType.custom.name and not isinstance(prompt_dict, dict):
111
- try:
112
- prompt_dict = ast.literal_eval(prompt_dict)
113
- except BaseException as e:
114
- prompt_dict_error = str(e)
115
- if prompt_dict_error:
116
- promptA = None
117
- promptB = None
118
- PreInstruct = None
119
- PreInput = ''
120
- PreResponse = ''
121
- terminate_response = None
122
- chat_sep = ''
123
- chat_turn_sep = ''
124
- humanstr = ''
125
- botstr = ''
126
- generates_leading_space = False
127
- elif prompt_type in [PromptType.custom.value, str(PromptType.custom.value),
128
- PromptType.custom.name]:
129
- promptA = prompt_dict.get('promptA', '')
130
- promptB = prompt_dict.get('promptB', '')
131
- PreInstruct = prompt_dict.get('PreInstruct', '')
132
- PreInput = prompt_dict.get('PreInput', '')
133
- PreResponse = prompt_dict.get('PreResponse', '')
134
- terminate_response = prompt_dict.get('terminate_response', None)
135
- chat_sep = prompt_dict.get('chat_sep', '\n')
136
- chat_turn_sep = prompt_dict.get('chat_turn_sep', '\n')
137
- humanstr = prompt_dict.get('humanstr', '')
138
- botstr = prompt_dict.get('botstr', '')
139
- elif prompt_type in [PromptType.plain.value, str(PromptType.plain.value),
140
- PromptType.plain.name]:
141
- promptA = promptB = PreInstruct = PreInput = PreResponse = None
142
- terminate_response = []
143
- chat_turn_sep = chat_sep = ''
144
- # plain should have None for human/bot, so nothing truncated out, not '' that would truncate after first token
145
- humanstr = None
146
- botstr = None
147
- elif prompt_type == 'simple_instruct':
148
- promptA = promptB = PreInstruct = PreInput = PreResponse = None
149
- terminate_response = []
150
- chat_turn_sep = chat_sep = '\n'
151
- humanstr = None
152
- botstr = None
153
- elif prompt_type in [PromptType.instruct.value, str(PromptType.instruct.value),
154
- PromptType.instruct.name] + [PromptType.instruct_with_end.value,
155
- str(PromptType.instruct_with_end.value),
156
- PromptType.instruct_with_end.name]:
157
- promptA = 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n' if not (
158
- chat and reduced) else ''
159
- promptB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n' if not (
160
- chat and reduced) else ''
161
-
162
- PreInstruct = """
163
- ### Instruction:
164
- """
165
-
166
- PreInput = """
167
- ### Input:
168
- """
169
-
170
- PreResponse = """
171
- ### Response:
172
- """
173
- if prompt_type in [PromptType.instruct_with_end.value, str(PromptType.instruct_with_end.value),
174
- PromptType.instruct_with_end.name]:
175
- terminate_response = ['### End']
176
- else:
177
- terminate_response = None
178
- chat_turn_sep = chat_sep = '\n'
179
- humanstr = PreInstruct
180
- botstr = PreResponse
181
- elif prompt_type in [PromptType.quality.value, str(PromptType.quality.value),
182
- PromptType.quality.name]:
183
- promptA = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction as applied on the Input.\n' if not (
184
- chat and reduced) else ''
185
- promptB = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction.\n' if not (
186
- chat and reduced) else ''
187
-
188
- PreInstruct = """
189
- ### Instruction:
190
- """
191
-
192
- PreInput = """
193
- ### Input:
194
- """
195
-
196
- PreResponse = """
197
- ### Response:
198
- """
199
- terminate_response = None
200
- chat_turn_sep = chat_sep = '\n'
201
- humanstr = PreInstruct # first thing human says
202
- botstr = PreResponse # first thing bot says
203
- elif prompt_type in [PromptType.human_bot.value, str(PromptType.human_bot.value),
204
- PromptType.human_bot.name] + [PromptType.human_bot_orig.value,
205
- str(PromptType.human_bot_orig.value),
206
- PromptType.human_bot_orig.name]:
207
- human = '<human>:'
208
- bot = "<bot>:"
209
- if reduced or context or prompt_type in [PromptType.human_bot.value, str(PromptType.human_bot.value),
210
- PromptType.human_bot.name]:
211
- preprompt = ''
212
- else:
213
- cur_date = time.strftime('%Y-%m-%d')
214
- cur_time = time.strftime('%H:%M:%S %p %Z')
215
-
216
- PRE_PROMPT = """\
217
- Current Date: {}
218
- Current Time: {}
219
-
220
- """
221
- preprompt = PRE_PROMPT.format(cur_date, cur_time)
222
- start = ''
223
- promptB = promptA = '%s%s' % (preprompt, start)
224
-
225
- PreInstruct = human + ' '
226
-
227
- PreInput = None
228
-
229
- if making_context:
230
- # when making context, want it to appear as-if LLM generated, which starts with space after :
231
- PreResponse = bot + ' '
232
- else:
233
- # normally LLM adds space after this, because was how trained.
234
- # if add space here, non-unique tokenization will often make LLM produce wrong output
235
- PreResponse = bot
236
-
237
- terminate_response = ['\n' + human, '\n' + bot, human, bot, PreResponse]
238
- chat_turn_sep = chat_sep = '\n'
239
- humanstr = human # tag before human talks
240
- botstr = bot # tag before bot talks
241
- generates_leading_space = True
242
- elif prompt_type in [PromptType.dai_faq.value, str(PromptType.dai_faq.value),
243
- PromptType.dai_faq.name]:
244
- promptA = ''
245
- promptB = 'Answer the following Driverless AI question.\n'
246
-
247
- PreInstruct = """
248
- ### Driverless AI frequently asked question:
249
- """
250
-
251
- PreInput = None
252
-
253
- PreResponse = """
254
- ### Driverless AI documentation answer:
255
- """
256
- terminate_response = ['\n\n']
257
- chat_turn_sep = chat_sep = terminate_response
258
- humanstr = PreInstruct
259
- botstr = PreResponse
260
- elif prompt_type in [PromptType.summarize.value, str(PromptType.summarize.value),
261
- PromptType.summarize.name]:
262
- promptA = promptB = PreInput = ''
263
- PreInstruct = '## Main Text\n\n'
264
- PreResponse = '\n\n## Summary\n\n'
265
- terminate_response = None
266
- chat_turn_sep = chat_sep = '\n'
267
- humanstr = PreInstruct
268
- botstr = PreResponse
269
- elif prompt_type in [PromptType.instruct_vicuna.value, str(PromptType.instruct_vicuna.value),
270
- PromptType.instruct_vicuna.name]:
271
- promptA = promptB = "A chat between a curious human and an artificial intelligence assistant. " \
272
- "The assistant gives helpful, detailed, and polite answers to the human's questions." if not (
273
- chat and reduced) else ''
274
-
275
- PreInstruct = """
276
- ### Human:
277
- """
278
-
279
- PreInput = None
280
-
281
- PreResponse = """
282
- ### Assistant:
283
- """
284
- terminate_response = [
285
- '### Human:'] # but only allow terminate after prompt is found correctly, else can't terminate
286
- chat_turn_sep = chat_sep = '\n'
287
- humanstr = PreInstruct
288
- botstr = PreResponse
289
- elif prompt_type in [PromptType.prompt_answer.value, str(PromptType.prompt_answer.value),
290
- PromptType.prompt_answer.name]:
291
- preprompt = ''
292
- prompt_tokens = "<|prompt|>"
293
- answer_tokens = "<|answer|>"
294
- start = ''
295
- promptB = promptA = '%s%s' % (preprompt, start)
296
- PreInstruct = prompt_tokens
297
- PreInput = None
298
- PreResponse = answer_tokens
299
- eos = '<|endoftext|>' # neox eos
300
- humanstr = prompt_tokens
301
- botstr = answer_tokens
302
- terminate_response = [humanstr, PreResponse, eos]
303
- chat_sep = eos
304
- chat_turn_sep = eos
305
- elif prompt_type in [PromptType.prompt_answer_openllama.value, str(PromptType.prompt_answer_openllama.value),
306
- PromptType.prompt_answer_openllama.name]:
307
- preprompt = ''
308
- prompt_tokens = "<|prompt|>"
309
- answer_tokens = "<|answer|>"
310
- start = ''
311
- promptB = promptA = '%s%s' % (preprompt, start)
312
- PreInstruct = prompt_tokens
313
- PreInput = None
314
- PreResponse = answer_tokens
315
- eos = '</s>' # llama eos
316
- humanstr = prompt_tokens
317
- botstr = answer_tokens
318
- terminate_response = [humanstr, PreResponse, eos]
319
- chat_sep = eos
320
- chat_turn_sep = eos
321
- elif prompt_type in [PromptType.open_assistant.value, str(PromptType.open_assistant.value),
322
- PromptType.open_assistant.name]:
323
- # From added_tokens.json
324
- preprompt = ''
325
- prompt_tokens = "<|prompter|>"
326
- answer_tokens = "<|assistant|>"
327
- start = ''
328
- promptB = promptA = '%s%s' % (preprompt, start)
329
- PreInstruct = prompt_tokens
330
- PreInput = None
331
- PreResponse = answer_tokens
332
- pend = "<|prefix_end|>"
333
- eos = "</s>"
334
- humanstr = prompt_tokens
335
- botstr = answer_tokens
336
- terminate_response = [humanstr, PreResponse, pend, eos]
337
- chat_turn_sep = chat_sep = eos
338
- elif prompt_type in [PromptType.wizard_lm.value, str(PromptType.wizard_lm.value),
339
- PromptType.wizard_lm.name]:
340
- # https://github.com/ehartford/WizardLM/blob/main/src/train_freeform.py
341
- preprompt = ''
342
- start = ''
343
- promptB = promptA = '%s%s' % (preprompt, start)
344
- PreInstruct = ""
345
- PreInput = None
346
- PreResponse = "\n\n### Response\n"
347
- eos = "</s>"
348
- terminate_response = [PreResponse, eos]
349
- chat_turn_sep = chat_sep = eos
350
- humanstr = promptA
351
- botstr = PreResponse
352
- elif prompt_type in [PromptType.wizard_mega.value, str(PromptType.wizard_mega.value),
353
- PromptType.wizard_mega.name]:
354
- preprompt = ''
355
- start = ''
356
- promptB = promptA = '%s%s' % (preprompt, start)
357
- PreInstruct = """
358
- ### Instruction:
359
- """
360
- PreInput = None
361
- PreResponse = """
362
- ### Assistant:
363
- """
364
- terminate_response = [PreResponse]
365
- chat_turn_sep = chat_sep = '\n'
366
- humanstr = PreInstruct
367
- botstr = PreResponse
368
- elif prompt_type in [PromptType.instruct_vicuna2.value, str(PromptType.instruct_vicuna2.value),
369
- PromptType.instruct_vicuna2.name]:
370
- promptA = promptB = "" if not (chat and reduced) else ''
371
-
372
- PreInstruct = """
373
- HUMAN:
374
- """
375
-
376
- PreInput = None
377
-
378
- PreResponse = """
379
- ASSISTANT:
380
- """
381
- terminate_response = [
382
- 'HUMAN:'] # but only allow terminate after prompt is found correctly, else can't terminate
383
- chat_turn_sep = chat_sep = '\n'
384
- humanstr = PreInstruct
385
- botstr = PreResponse
386
- elif prompt_type in [PromptType.instruct_vicuna3.value, str(PromptType.instruct_vicuna3.value),
387
- PromptType.instruct_vicuna3.name]:
388
- promptA = promptB = "" if not (chat and reduced) else ''
389
-
390
- PreInstruct = """
391
- ### User:
392
- """
393
-
394
- PreInput = None
395
-
396
- PreResponse = """
397
- ### Assistant:
398
- """
399
- terminate_response = [
400
- '### User:'] # but only allow terminate after prompt is found correctly, else can't terminate
401
- chat_turn_sep = chat_sep = '\n'
402
- humanstr = PreInstruct
403
- botstr = PreResponse
404
- elif prompt_type in [PromptType.wizard2.value, str(PromptType.wizard2.value),
405
- PromptType.wizard2.name]:
406
- # https://huggingface.co/TheBloke/WizardLM-7B-uncensored-GGML
407
- preprompt = """Below is an instruction that describes a task. Write a response that appropriately completes the request.""" if not (
408
- chat and reduced) else ''
409
- start = ''
410
- promptB = promptA = '%s%s' % (preprompt, start)
411
- PreInstruct = """
412
- ### Instruction:
413
- """
414
- PreInput = None
415
- PreResponse = """
416
- ### Response:
417
- """
418
- terminate_response = [PreResponse]
419
- chat_turn_sep = chat_sep = '\n'
420
- humanstr = PreInstruct
421
- botstr = PreResponse
422
- elif prompt_type in [PromptType.wizard3.value, str(PromptType.wizard3.value),
423
- PromptType.wizard3.name]:
424
- # https://huggingface.co/TheBloke/wizardLM-13B-1.0-GGML
425
- preprompt = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.""" if not (
426
- chat and reduced) else ''
427
- start = ''
428
- promptB = promptA = '%s%s' % (preprompt, start)
429
- PreInstruct = """USER: """
430
- PreInput = None
431
- PreResponse = """ASSISTANT: """
432
- terminate_response = [PreResponse]
433
- chat_turn_sep = chat_sep = '\n'
434
- humanstr = PreInstruct
435
- botstr = PreResponse
436
- elif prompt_type in [PromptType.wizard_vicuna.value, str(PromptType.wizard_vicuna.value),
437
- PromptType.wizard_vicuna.name]:
438
- preprompt = ''
439
- start = ''
440
- promptB = promptA = '%s%s' % (preprompt, start)
441
- PreInstruct = """USER: """
442
- PreInput = None
443
- PreResponse = """ASSISTANT: """
444
- terminate_response = [PreResponse]
445
- chat_turn_sep = chat_sep = '\n'
446
- humanstr = PreInstruct
447
- botstr = PreResponse
448
-
449
- elif prompt_type in [PromptType.instruct_simple.value, str(PromptType.instruct_simple.value),
450
- PromptType.instruct_simple.name]:
451
- promptB = promptA = '' if not (chat and reduced) else ''
452
-
453
- PreInstruct = """
454
- ### Instruction:
455
- """
456
-
457
- PreInput = """
458
- ### Input:
459
- """
460
-
461
- PreResponse = """
462
- ### Response:
463
- """
464
- terminate_response = None
465
- chat_turn_sep = chat_sep = '\n'
466
- humanstr = PreInstruct
467
- botstr = PreResponse
468
- elif prompt_type in [PromptType.openai.value, str(PromptType.openai.value),
469
- PromptType.openai.name]:
470
- preprompt = """The following is a conversation with an AI assistant. The assistant is helpful, creative, clever, and very friendly.""" if not (
471
- chat and reduced) else ''
472
- start = ''
473
- promptB = promptA = '%s%s' % (preprompt, start)
474
- PreInstruct = "\nHuman: "
475
- PreInput = None
476
- PreResponse = "\nAI:"
477
- terminate_response = [PreResponse] + [" Human:", " AI:"]
478
- chat_turn_sep = chat_sep = '\n'
479
- humanstr = PreInstruct
480
- botstr = PreResponse
481
- elif prompt_type in [PromptType.gptj.value, str(PromptType.gptj.value),
482
- PromptType.gptj.name]:
483
- preprompt = "### Instruction:\n The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response." if not (
484
- chat and reduced) else ''
485
- start = ''
486
- promptB = promptA = '%s%s' % (preprompt, start)
487
- PreInstruct = "\n### Prompt: "
488
- PreInput = None
489
- PreResponse = "\n### Response: "
490
- terminate_response = [PreResponse] + ["Prompt:", "Response:"]
491
- chat_turn_sep = chat_sep = '\n'
492
- humanstr = PreInstruct
493
- botstr = PreResponse
494
- elif prompt_type in [PromptType.openai_chat.value, str(PromptType.openai_chat.value),
495
- PromptType.openai_chat.name]:
496
- # prompting and termination all handled by endpoint
497
- preprompt = """"""
498
- start = ''
499
- promptB = promptA = '%s%s' % (preprompt, start)
500
- PreInstruct = ""
501
- PreInput = None
502
- PreResponse = ""
503
- terminate_response = []
504
- chat_turn_sep = chat_sep = '\n'
505
- humanstr = None
506
- botstr = None
507
- elif prompt_type in [PromptType.vicuna11.value, str(PromptType.vicuna11.value),
508
- PromptType.vicuna11.name]:
509
- preprompt = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. """ if not (
510
- chat and reduced) else ''
511
- start = ''
512
- promptB = promptA = '%s%s' % (preprompt, start)
513
- eos = '</s>'
514
- PreInstruct = """USER: """
515
- PreInput = None
516
- PreResponse = """ASSISTANT:"""
517
- terminate_response = [PreResponse]
518
- chat_sep = ' '
519
- chat_turn_sep = eos
520
- humanstr = PreInstruct
521
- botstr = PreResponse
522
-
523
- if making_context:
524
- # when making context, want it to appear as-if LLM generated, which starts with space after :
525
- PreResponse = PreResponse + ' '
526
- else:
527
- # normally LLM adds space after this, because was how trained.
528
- # if add space here, non-unique tokenization will often make LLM produce wrong output
529
- PreResponse = PreResponse
530
- elif prompt_type in [PromptType.mptinstruct.value, str(PromptType.mptinstruct.value),
531
- PromptType.mptinstruct.name]:
532
- # https://huggingface.co/mosaicml/mpt-30b-instruct#formatting
533
- promptA = promptB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n' if not (
534
- chat and reduced) else ''
535
-
536
- PreInstruct = """
537
- ### Instruction
538
- """
539
-
540
- PreInput = """
541
- ### Input
542
- """
543
-
544
- PreResponse = """
545
- ### Response
546
- """
547
- terminate_response = None
548
- chat_turn_sep = chat_sep = '\n'
549
- humanstr = PreInstruct
550
- botstr = PreResponse
551
- elif prompt_type in [PromptType.mptchat.value, str(PromptType.mptchat.value),
552
- PromptType.mptchat.name]:
553
- # https://huggingface.co/TheBloke/mpt-30B-chat-GGML#prompt-template
554
- promptA = promptB = """<|im_start|>system\nA conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.\n<|im_end|>""" if not (
555
- chat and reduced) else ''
556
-
557
- PreInstruct = """<|im_start|>user
558
- """
559
-
560
- PreInput = None
561
-
562
- PreResponse = """<|im_end|><|im_start|>assistant
563
- """
564
- terminate_response = ['<|im_end|>']
565
- chat_sep = ''
566
- chat_turn_sep = '<|im_end|>'
567
- humanstr = PreInstruct
568
- botstr = PreResponse
569
- elif prompt_type in [PromptType.falcon.value, str(PromptType.falcon.value),
570
- PromptType.falcon.name]:
571
- promptA = promptB = "" if not (chat and reduced) else ''
572
-
573
- PreInstruct = """User: """
574
-
575
- PreInput = None
576
-
577
- PreResponse = """Assistant:"""
578
- terminate_response = ['\nUser', "<|endoftext|>"]
579
- chat_sep = '\n\n'
580
- chat_turn_sep = '\n\n'
581
- humanstr = PreInstruct
582
- botstr = PreResponse
583
- if making_context:
584
- # when making context, want it to appear as-if LLM generated, which starts with space after :
585
- PreResponse = 'Assistant: '
586
- else:
587
- # normally LLM adds space after this, because was how trained.
588
- # if add space here, non-unique tokenization will often make LLM produce wrong output
589
- PreResponse = PreResponse
590
- # generates_leading_space = True
591
- elif prompt_type in [PromptType.guanaco.value, str(PromptType.guanaco.value),
592
- PromptType.guanaco.name]:
593
- # https://huggingface.co/TheBloke/guanaco-65B-GPTQ
594
- promptA = promptB = "" if not (chat and reduced) else ''
595
-
596
- PreInstruct = """### Human: """
597
-
598
- PreInput = None
599
-
600
- PreResponse = """### Assistant:"""
601
- terminate_response = ['### Human:'] # but only allow terminate after prompt is found correctly, else can't terminate
602
- chat_turn_sep = chat_sep = '\n'
603
- humanstr = PreInstruct
604
- botstr = PreResponse
605
- elif prompt_type in [PromptType.llama2.value, str(PromptType.llama2.value),
606
- PromptType.llama2.name]:
607
- PreInstruct = ""
608
- llama2_sys = "<<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
609
- prompt = "<s>[INST] "
610
- enable_sys = False # too much safety, hurts accuracy
611
- if not (chat and reduced):
612
- if enable_sys:
613
- promptA = promptB = prompt + llama2_sys
614
- else:
615
- promptA = promptB = prompt
616
- else:
617
- promptA = promptB = ''
618
- PreInput = None
619
- PreResponse = ""
620
- terminate_response = ["[INST]", "</s>"]
621
- chat_sep = ' [/INST]'
622
- chat_turn_sep = ' </s><s>[INST] '
623
- humanstr = PreInstruct
624
- botstr = PreResponse
625
- if making_context:
626
- PreResponse += " "
627
- else:
628
- raise RuntimeError("No such prompt_type=%s" % prompt_type)
629
-
630
- if isinstance(terminate_response, (tuple, list)):
631
- assert '' not in terminate_response, "Bad terminate_response"
632
-
633
- ret_dict = dict(promptA=promptA, promptB=promptB, PreInstruct=PreInstruct, PreInput=PreInput,
634
- PreResponse=PreResponse, terminate_response=terminate_response, chat_sep=chat_sep,
635
- chat_turn_sep=chat_turn_sep,
636
- humanstr=humanstr, botstr=botstr,
637
- generates_leading_space=generates_leading_space)
638
-
639
- if return_dict:
640
- return ret_dict, prompt_dict_error
641
- else:
642
- return tuple(list(ret_dict.values()))
643
-
644
-
645
- def generate_prompt(data_point, prompt_type, prompt_dict, chat, reduced, making_context):
646
- context = data_point.get('context')
647
- if context is None:
648
- context = ''
649
- instruction = data_point.get('instruction')
650
- input = data_point.get('input')
651
- output = data_point.get('output')
652
- prompt_type = data_point.get('prompt_type', prompt_type)
653
- prompt_dict = data_point.get('prompt_dict', prompt_dict)
654
- assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
655
- promptA, promptB, PreInstruct, PreInput, PreResponse, \
656
- terminate_response, chat_sep, chat_turn_sep, humanstr, botstr, \
657
- generates_leading_space = get_prompt(prompt_type, prompt_dict, chat,
658
- context, reduced, making_context)
659
-
660
- # could avoid if reduce=True, but too complex for parent functions to handle
661
- prompt = context
662
-
663
- if input and promptA:
664
- prompt += f"""{promptA}"""
665
- elif promptB:
666
- prompt += f"""{promptB}"""
667
-
668
- if instruction and PreInstruct is not None and input and PreInput is not None:
669
- prompt += f"""{PreInstruct}{instruction}{PreInput}{input}"""
670
- prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
671
- elif instruction and input and PreInstruct is None and PreInput is not None:
672
- prompt += f"""{PreInput}{instruction}
673
- {input}"""
674
- prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
675
- elif input and instruction and PreInput is None and PreInstruct is not None:
676
- prompt += f"""{PreInstruct}{instruction}
677
- {input}"""
678
- prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
679
- elif instruction and PreInstruct is not None:
680
- prompt += f"""{PreInstruct}{instruction}"""
681
- prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
682
- elif input and PreInput is not None:
683
- prompt += f"""{PreInput}{input}"""
684
- prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
685
- elif input and instruction and PreInput is not None:
686
- prompt += f"""{PreInput}{instruction}{input}"""
687
- prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
688
- elif input and instruction and PreInstruct is not None:
689
- prompt += f"""{PreInstruct}{instruction}{input}"""
690
- prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
691
- elif input and instruction:
692
- # i.e. for simple_instruct
693
- prompt += f"""{instruction}: {input}"""
694
- prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
695
- elif input:
696
- prompt += f"""{input}"""
697
- prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
698
- elif instruction:
699
- prompt += f"""{instruction}"""
700
- prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
701
-
702
- if PreResponse is not None:
703
- prompt += f"""{PreResponse}"""
704
- pre_response = PreResponse # Don't use strip
705
- else:
706
- pre_response = ''
707
-
708
- if output:
709
- prompt += f"""{output}"""
710
-
711
- return prompt, pre_response, terminate_response, chat_sep, chat_turn_sep
712
-
713
-
714
- def inject_chatsep(prompt_type, prompt, chat_sep=None):
715
- if chat_sep:
716
- # only add new line if structured prompt, while 'plain' is just generation of next tokens from input
717
- prompt += chat_sep
718
- return prompt
719
-
720
-
721
- class Prompter(object):
722
- def __init__(self, prompt_type, prompt_dict, debug=False, chat=False, stream_output=False, repeat_penalty=True,
723
- allowed_repeat_line_length=10):
724
- self.prompt_type = prompt_type
725
- self.prompt_dict = prompt_dict
726
- self.debug = debug
727
- self.chat = chat
728
- self.stream_output = stream_output
729
- self.repeat_penalty = repeat_penalty
730
- self.allowed_repeat_line_length = allowed_repeat_line_length
731
- self.prompt = None
732
- context = "" # not for chat context
733
- reduced = False # not for chat context
734
- making_context = False # not for chat context
735
- self.promptA, self.promptB, self.PreInstruct, self.PreInput, self.PreResponse, \
736
- self.terminate_response, self.chat_sep, self.chat_turn_sep, self.humanstr, self.botstr, \
737
- self.generates_leading_space = \
738
- get_prompt(self.prompt_type, self.prompt_dict, chat, context, reduced, making_context)
739
- self.pre_response = self.PreResponse
740
-
741
- def generate_prompt(self, data_point, reduced=None):
742
- """
743
- data_point['context'] is assumed to be like a system prompt or pre-conversation, not inserted after user prompt
744
- :param data_point:
745
- :param reduced:
746
- :return:
747
- """
748
- reduced = data_point.get('context') not in ['', None] if reduced is None else reduced
749
- making_context = False # whether really making final prompt or just generating context
750
- prompt, _, _, _, _ = generate_prompt(data_point, self.prompt_type, self.prompt_dict, self.chat, reduced,
751
- making_context)
752
- if self.debug:
753
- print("prompt: %s" % prompt, flush=True)
754
- # if have context, should have always reduced and only preappend promptA/B here
755
- if data_point.get('context'):
756
- if data_point.get('input') and self.promptA:
757
- prompt = self.promptA + prompt
758
- elif self.promptB:
759
- prompt = self.promptB + prompt
760
-
761
- self.prompt = prompt
762
- return prompt
763
-
764
- def get_response(self, outputs, prompt=None, sanitize_bot_response=False):
765
- if isinstance(outputs, str):
766
- outputs = [outputs]
767
- if self.debug:
768
- print("output:\n%s" % '\n\n'.join(outputs), flush=True)
769
- if prompt is not None:
770
- self.prompt = prompt
771
-
772
- def clean_response(response):
773
- meaningless_words = ['<pad>', '</s>', '<|endoftext|>']
774
- for word in meaningless_words:
775
- response = response.replace(word, "")
776
- if sanitize_bot_response:
777
- from better_profanity import profanity
778
- response = profanity.censor(response)
779
- if self.generates_leading_space and isinstance(response, str) and len(response) > 0 and response[0] == ' ':
780
- response = response[1:]
781
- return response
782
-
783
- def clean_repeats(response):
784
- lines = response.split('\n')
785
- new_lines = []
786
- [new_lines.append(line) for line in lines if
787
- line not in new_lines or len(line) < self.allowed_repeat_line_length]
788
- if self.debug and len(lines) != len(new_lines):
789
- print("cleaned repeats: %s %s" % (len(lines), len(new_lines)), flush=True)
790
- response = '\n'.join(new_lines)
791
- return response
792
-
793
- multi_output = len(outputs) > 1
794
-
795
- for oi, output in enumerate(outputs):
796
- if self.prompt_type in [PromptType.plain.value, str(PromptType.plain.value), PromptType.plain.name]:
797
- output = clean_response(output)
798
- elif prompt is None:
799
- # then use most basic parsing like pipeline
800
- if not self.botstr:
801
- pass
802
- elif self.botstr in output:
803
- if self.humanstr:
804
- output = clean_response(output.split(self.botstr)[1].split(self.humanstr)[0])
805
- else:
806
- # i.e. use after bot but only up to next bot
807
- output = clean_response(output.split(self.botstr)[1].split(self.botstr)[0])
808
- else:
809
- # output = clean_response(output)
810
- # assume just not printed yet
811
- output = ""
812
- else:
813
- # find first instance of prereponse
814
- # prompt sometimes has odd characters, that mutate length,
815
- # so can't go by length alone
816
- if self.pre_response:
817
- outputi = output.find(prompt)
818
- if outputi >= 0:
819
- output = output[outputi + len(prompt):]
820
- allow_terminate = True
821
- else:
822
- # subtraction is risky due to space offsets sometimes, so only do if necessary
823
- output = output[len(prompt) - len(self.pre_response):]
824
- # [1] to avoid repeated pre_response, just take first (after prompt - pre_response for chat)
825
- if self.pre_response in output:
826
- output = output.split(self.pre_response)[1]
827
- allow_terminate = True
828
- else:
829
- if output:
830
- print("Failure of parsing or not enough output yet: %s" % output, flush=True)
831
- allow_terminate = False
832
- else:
833
- allow_terminate = True
834
- output = output[len(prompt):]
835
- # clean after subtract prompt out, so correct removal of pre_response
836
- output = clean_response(output)
837
- if self.repeat_penalty:
838
- output = clean_repeats(output)
839
- if self.terminate_response and allow_terminate:
840
- finds = []
841
- for term in self.terminate_response:
842
- finds.append(output.find(term))
843
- finds = [x for x in finds if x >= 0]
844
- if len(finds) > 0:
845
- termi = finds[0]
846
- output = output[:termi]
847
- else:
848
- output = output
849
- if multi_output:
850
- # prefix with output counter
851
- output = "\n=========== Output %d\n\n" % (1 + oi) + output
852
- if oi > 0:
853
- # post fix outputs with seperator
854
- output += '\n'
855
- output = self.fix_text(self.prompt_type, output)
856
- outputs[oi] = output
857
- # join all outputs, only one extra new line between outputs
858
- output = '\n'.join(outputs)
859
- if self.debug:
860
- print("outputclean:\n%s" % '\n\n'.join(outputs), flush=True)
861
- return output
862
-
863
- @staticmethod
864
- def fix_text(prompt_type1, text1):
865
- if prompt_type1 == 'human_bot':
866
- # hack bug in vLLM with stopping, stops right, but doesn't return last token
867
- hfix = '<human'
868
- if text1.endswith(hfix):
869
- text1 = text1[:-len(hfix)]
870
- return text1
871
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
stopping.py DELETED
@@ -1,78 +0,0 @@
1
- import torch
2
- from transformers import StoppingCriteria, StoppingCriteriaList
3
-
4
- from enums import PromptType
5
-
6
-
7
- class StoppingCriteriaSub(StoppingCriteria):
8
-
9
- def __init__(self, stops=[], encounters=[], device="cuda", model_max_length=None):
10
- super().__init__()
11
- assert len(stops) % len(encounters) == 0, "Number of stops and encounters must match"
12
- self.encounters = encounters
13
- self.stops = [stop.to(device) for stop in stops]
14
- self.num_stops = [0] * len(stops)
15
- self.model_max_length = model_max_length
16
-
17
- def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
18
- for stopi, stop in enumerate(self.stops):
19
- if torch.all((stop == input_ids[0][-len(stop):])).item():
20
- self.num_stops[stopi] += 1
21
- if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
22
- # print("Stopped", flush=True)
23
- return True
24
- if self.model_max_length is not None and input_ids[0].shape[0] >= self.model_max_length:
25
- # critical limit
26
- return True
27
- # print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
28
- # print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
29
- return False
30
-
31
-
32
- def get_stopping(prompt_type, prompt_dict, tokenizer, device, human='<human>:', bot="<bot>:", model_max_length=None):
33
- # FIXME: prompt_dict unused currently
34
- if prompt_type in [PromptType.human_bot.name, PromptType.instruct_vicuna.name, PromptType.instruct_with_end.name]:
35
- if prompt_type == PromptType.human_bot.name:
36
- # encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
37
- # stopping only starts once output is beyond prompt
38
- # 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added
39
- stop_words = [human, bot, '\n' + human, '\n' + bot]
40
- encounters = [1, 2]
41
- elif prompt_type == PromptType.instruct_vicuna.name:
42
- # even below is not enough, generic strings and many ways to encode
43
- stop_words = [
44
- '### Human:',
45
- """
46
- ### Human:""",
47
- """
48
- ### Human:
49
- """,
50
- '### Assistant:',
51
- """
52
- ### Assistant:""",
53
- """
54
- ### Assistant:
55
- """,
56
- ]
57
- encounters = [1, 2]
58
- else:
59
- # some instruct prompts have this as end, doesn't hurt to stop on it since not common otherwise
60
- stop_words = ['### End']
61
- encounters = [1]
62
- stop_words_ids = [
63
- tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]
64
- # handle single token case
65
- stop_words_ids = [x if len(x.shape) > 0 else torch.tensor([x]) for x in stop_words_ids]
66
- stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0]
67
- # avoid padding in front of tokens
68
- if tokenizer._pad_token: # use hidden variable to avoid annoying properly logger bug
69
- stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids]
70
- # handle fake \n added
71
- stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
72
- # build stopper
73
- stopping_criteria = StoppingCriteriaList(
74
- [StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters, device=device,
75
- model_max_length=model_max_length)])
76
- else:
77
- stopping_criteria = StoppingCriteriaList()
78
- return stopping_criteria
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils.py DELETED
@@ -1,1080 +0,0 @@
1
- import contextlib
2
- import functools
3
- import hashlib
4
- import inspect
5
- import os
6
- import gc
7
- import pathlib
8
- import pickle
9
- import random
10
- import shutil
11
- import subprocess
12
- import sys
13
- import threading
14
- import time
15
- import traceback
16
- import zipfile
17
- from datetime import datetime
18
-
19
- import filelock
20
- import requests, uuid
21
- from typing import Tuple, Callable, Dict
22
- from tqdm.auto import tqdm
23
- from joblib import Parallel
24
- from concurrent.futures import ProcessPoolExecutor
25
- import numpy as np
26
- import pandas as pd
27
-
28
-
29
- def set_seed(seed: int):
30
- """
31
- Sets the seed of the entire notebook so results are the same every time we run.
32
- This is for REPRODUCIBILITY.
33
- """
34
- import torch
35
- np.random.seed(seed)
36
- random_state = np.random.RandomState(seed)
37
- random.seed(seed)
38
- torch.manual_seed(seed)
39
- torch.cuda.manual_seed(seed)
40
- torch.backends.cudnn.deterministic = True
41
- torch.backends.cudnn.benchmark = False
42
- os.environ['PYTHONHASHSEED'] = str(seed)
43
- return random_state
44
-
45
-
46
- def flatten_list(lis):
47
- """Given a list, possibly nested to any level, return it flattened."""
48
- new_lis = []
49
- for item in lis:
50
- if type(item) == type([]):
51
- new_lis.extend(flatten_list(item))
52
- else:
53
- new_lis.append(item)
54
- return new_lis
55
-
56
-
57
- def clear_torch_cache():
58
- import torch
59
- if torch.cuda.is_available():
60
- torch.cuda.empty_cache()
61
- torch.cuda.ipc_collect()
62
- gc.collect()
63
-
64
-
65
- def ping():
66
- try:
67
- print('Ping: %s' % str(datetime.now()), flush=True)
68
- except AttributeError:
69
- # some programs wrap print and will fail with flush passed
70
- pass
71
-
72
-
73
- def ping_gpu():
74
- try:
75
- print('Ping_GPU: %s %s' % (str(datetime.now()), system_info()), flush=True)
76
- except AttributeError:
77
- # some programs wrap print and will fail with flush passed
78
- pass
79
- try:
80
- ping_gpu_memory()
81
- except Exception as e:
82
- print('Ping_GPU memory failure: %s' % str(e), flush=True)
83
-
84
-
85
- def ping_gpu_memory():
86
- from models.gpu_mem_track import MemTracker
87
- gpu_tracker = MemTracker() # define a GPU tracker
88
- from torch.cuda import memory_summary
89
- gpu_tracker.track()
90
-
91
-
92
- def get_torch_allocated():
93
- import torch
94
- return torch.cuda.memory_allocated()
95
-
96
-
97
- def get_device():
98
- import torch
99
- if torch.cuda.is_available():
100
- device = "cuda"
101
- elif torch.backends.mps.is_built():
102
- device = "mps"
103
- else:
104
- device = "cpu"
105
-
106
- return device
107
-
108
-
109
- def system_info():
110
- import psutil
111
-
112
- system = {}
113
- # https://stackoverflow.com/questions/48951136/plot-multiple-graphs-in-one-plot-using-tensorboard
114
- # https://arshren.medium.com/monitoring-your-devices-in-python-5191d672f749
115
- try:
116
- temps = psutil.sensors_temperatures(fahrenheit=False)
117
- if 'coretemp' in temps:
118
- coretemp = temps['coretemp']
119
- temp_dict = {k.label: k.current for k in coretemp}
120
- for k, v in temp_dict.items():
121
- system['CPU_C/%s' % k] = v
122
- except AttributeError:
123
- pass
124
-
125
- # https://github.com/gpuopenanalytics/pynvml/blob/master/help_query_gpu.txt
126
- try:
127
- from pynvml.smi import nvidia_smi
128
- nvsmi = nvidia_smi.getInstance()
129
-
130
- gpu_power_dict = {'W_gpu%d' % i: x['power_readings']['power_draw'] for i, x in
131
- enumerate(nvsmi.DeviceQuery('power.draw')['gpu'])}
132
- for k, v in gpu_power_dict.items():
133
- system['GPU_W/%s' % k] = v
134
-
135
- gpu_temp_dict = {'C_gpu%d' % i: x['temperature']['gpu_temp'] for i, x in
136
- enumerate(nvsmi.DeviceQuery('temperature.gpu')['gpu'])}
137
- for k, v in gpu_temp_dict.items():
138
- system['GPU_C/%s' % k] = v
139
-
140
- gpu_memory_free_dict = {'MiB_gpu%d' % i: x['fb_memory_usage']['free'] for i, x in
141
- enumerate(nvsmi.DeviceQuery('memory.free')['gpu'])}
142
- gpu_memory_total_dict = {'MiB_gpu%d' % i: x['fb_memory_usage']['total'] for i, x in
143
- enumerate(nvsmi.DeviceQuery('memory.total')['gpu'])}
144
- gpu_memory_frac_dict = {k: gpu_memory_free_dict[k] / gpu_memory_total_dict[k] for k in gpu_memory_total_dict}
145
- for k, v in gpu_memory_frac_dict.items():
146
- system[f'GPU_M/%s' % k] = v
147
- except (KeyError, ModuleNotFoundError):
148
- pass
149
- system['hash'] = get_githash()
150
-
151
- return system
152
-
153
-
154
- def system_info_print():
155
- try:
156
- df = pd.DataFrame.from_dict(system_info(), orient='index')
157
- # avoid slamming GPUs
158
- time.sleep(1)
159
- return df.to_markdown()
160
- except Exception as e:
161
- return "Error: %s" % str(e)
162
-
163
-
164
- def zip_data(root_dirs=None, zip_file=None, base_dir='./', fail_any_exception=False):
165
- try:
166
- return _zip_data(zip_file=zip_file, base_dir=base_dir, root_dirs=root_dirs)
167
- except Exception as e:
168
- traceback.print_exc()
169
- print('Exception in zipping: %s' % str(e))
170
- if not fail_any_exception:
171
- raise
172
-
173
-
174
- def _zip_data(root_dirs=None, zip_file=None, base_dir='./'):
175
- if isinstance(root_dirs, str):
176
- root_dirs = [root_dirs]
177
- if zip_file is None:
178
- datetime_str = str(datetime.now()).replace(" ", "_").replace(":", "_")
179
- host_name = os.getenv('HF_HOSTNAME', 'emptyhost')
180
- zip_file = "data_%s_%s.zip" % (datetime_str, host_name)
181
- assert root_dirs is not None
182
- if not os.path.isdir(os.path.dirname(zip_file)) and os.path.dirname(zip_file):
183
- os.makedirs(os.path.dirname(zip_file), exist_ok=True)
184
- with zipfile.ZipFile(zip_file, "w") as expt_zip:
185
- for root_dir in root_dirs:
186
- if root_dir is None:
187
- continue
188
- for root, d, files in os.walk(root_dir):
189
- for file in files:
190
- file_to_archive = os.path.join(root, file)
191
- assert os.path.exists(file_to_archive)
192
- path_to_archive = os.path.relpath(file_to_archive, base_dir)
193
- expt_zip.write(filename=file_to_archive, arcname=path_to_archive)
194
- return zip_file, zip_file
195
-
196
-
197
- def save_generate_output(prompt=None, output=None, base_model=None, save_dir=None, where_from='unknown where from',
198
- extra_dict={}):
199
- try:
200
- return _save_generate_output(prompt=prompt, output=output, base_model=base_model, save_dir=save_dir,
201
- where_from=where_from, extra_dict=extra_dict)
202
- except Exception as e:
203
- traceback.print_exc()
204
- print('Exception in saving: %s' % str(e))
205
-
206
-
207
- def _save_generate_output(prompt=None, output=None, base_model=None, save_dir=None, where_from='unknown where from',
208
- extra_dict={}):
209
- """
210
- Save conversation to .json, row by row.
211
- json_file_path is path to final JSON file. If not in ., then will attempt to make directories.
212
- Appends if file exists
213
- """
214
- prompt = '<not set>' if prompt is None else prompt
215
- output = '<not set>' if output is None else output
216
- assert save_dir, "save_dir must be provided"
217
- if os.path.exists(save_dir) and not os.path.isdir(save_dir):
218
- raise RuntimeError("save_dir already exists and is not a directory!")
219
- os.makedirs(save_dir, exist_ok=True)
220
- import json
221
- dict_to_save = dict(prompt=prompt, text=output, time=time.ctime(), base_model=base_model, where_from=where_from)
222
- dict_to_save.update(extra_dict)
223
- with filelock.FileLock("save_dir.lock"):
224
- # lock logging in case have concurrency
225
- with open(os.path.join(save_dir, "history.json"), "a") as f:
226
- # just add [ at start, and ] at end, and have proper JSON dataset
227
- f.write(
228
- " " + json.dumps(
229
- dict_to_save
230
- ) + ",\n"
231
- )
232
-
233
-
234
- def s3up(filename):
235
- try:
236
- return _s3up(filename)
237
- except Exception as e:
238
- traceback.print_exc()
239
- print('Exception for file %s in s3up: %s' % (filename, str(e)))
240
- return "Failed to upload %s: Error: %s" % (filename, str(e))
241
-
242
-
243
- def _s3up(filename):
244
- import boto3
245
-
246
- aws_access_key_id = os.getenv('AWS_SERVER_PUBLIC_KEY')
247
- aws_secret_access_key = os.getenv('AWS_SERVER_SECRET_KEY')
248
- bucket = os.getenv('AWS_BUCKET')
249
- assert aws_access_key_id, "Set AWS key"
250
- assert aws_secret_access_key, "Set AWS secret"
251
- assert bucket, "Set AWS Bucket"
252
-
253
- s3 = boto3.client('s3',
254
- aws_access_key_id=os.getenv('AWS_SERVER_PUBLIC_KEY'),
255
- aws_secret_access_key=os.getenv('AWS_SERVER_SECRET_KEY'),
256
- )
257
- ret = s3.upload_file(
258
- Filename=filename,
259
- Bucket=os.getenv('AWS_BUCKET'),
260
- Key=filename,
261
- )
262
- if ret in [None, '']:
263
- return "Successfully uploaded %s" % filename
264
-
265
-
266
- def get_githash():
267
- try:
268
- githash = subprocess.run(['git', 'rev-parse', 'HEAD'], stdout=subprocess.PIPE).stdout.decode('utf-8')[0:-1]
269
- except:
270
- githash = ''
271
- return githash
272
-
273
-
274
- def copy_code(run_id):
275
- """
276
- copy code to track changes
277
- :param run_id:
278
- :return:
279
- """
280
- rnd_num = str(random.randint(0, 2 ** 31))
281
- run_id = 'run_' + str(run_id)
282
- os.makedirs(run_id, exist_ok=True)
283
- me_full = os.path.join(pathlib.Path(__file__).parent.resolve(), __file__)
284
- me_file = os.path.basename(__file__)
285
- new_me = os.path.join(run_id, me_file + '_' + get_githash())
286
- if os.path.isfile(new_me):
287
- new_me = os.path.join(run_id, me_file + '_' + get_githash() + '_' + rnd_num)
288
- shutil.copy(me_full, new_me)
289
- else:
290
- shutil.copy(me_full, new_me)
291
-
292
-
293
- class NullContext(threading.local):
294
- """No-op context manager, executes block without doing any additional processing.
295
-
296
- Used as a stand-in if a particular block of code is only sometimes
297
- used with a normal context manager:
298
- """
299
-
300
- def __init__(self, *args, **kwargs):
301
- pass
302
-
303
- def __enter__(self):
304
- return self
305
-
306
- def __exit__(self, exc_type, exc_value, exc_traceback):
307
- self.finally_act()
308
-
309
- def finally_act(self):
310
- pass
311
-
312
-
313
- def wrapped_partial(func, *args, **kwargs):
314
- """
315
- Give partial properties of normal function, like __name__ attribute etc.
316
- :param func:
317
- :param args:
318
- :param kwargs:
319
- :return:
320
- """
321
- partial_func = functools.partial(func, *args, **kwargs)
322
- functools.update_wrapper(partial_func, func)
323
- return partial_func
324
-
325
-
326
- class ThreadException(Exception):
327
- pass
328
-
329
-
330
- class EThread(threading.Thread):
331
- # Function that raises the custom exception
332
- def __init__(self, group=None, target=None, name=None,
333
- args=(), kwargs=None, *, daemon=None, streamer=None, bucket=None):
334
- self.bucket = bucket
335
- self.streamer = streamer
336
- self.exc = None
337
- self._return = None
338
- super().__init__(group=group, target=target, name=name, args=args, kwargs=kwargs, daemon=daemon)
339
-
340
- def run(self):
341
- # Variable that stores the exception, if raised by someFunction
342
- try:
343
- if self._target is not None:
344
- self._return = self._target(*self._args, **self._kwargs)
345
- except BaseException as e:
346
- print("thread exception: %s" % str(sys.exc_info()))
347
- self.bucket.put(sys.exc_info())
348
- self.exc = e
349
- if self.streamer:
350
- print("make stop: %s" % str(sys.exc_info()), flush=True)
351
- self.streamer.do_stop = True
352
- finally:
353
- # Avoid a refcycle if the thread is running a function with
354
- # an argument that has a member that points to the thread.
355
- del self._target, self._args, self._kwargs
356
-
357
- def join(self, timeout=None):
358
- threading.Thread.join(self)
359
- # Since join() returns in caller thread
360
- # we re-raise the caught exception
361
- # if any was caught
362
- if self.exc:
363
- raise self.exc
364
- return self._return
365
-
366
-
367
- def import_matplotlib():
368
- import matplotlib
369
- matplotlib.use('agg')
370
- # KEEP THESE HERE! START
371
- import matplotlib.pyplot as plt
372
- import pandas as pd
373
- # to avoid dlopen deadlock in fork
374
- import pandas.core.computation.expressions as pd_expressions
375
- import pandas._libs.groupby as pd_libgroupby
376
- import pandas._libs.reduction as pd_libreduction
377
- import pandas.core.algorithms as pd_algorithms
378
- import pandas.core.common as pd_com
379
- import numpy as np
380
- # KEEP THESE HERE! END
381
-
382
-
383
- def get_sha(value):
384
- return hashlib.md5(str(value).encode('utf-8')).hexdigest()
385
-
386
-
387
- def sanitize_filename(name):
388
- """
389
- Sanitize file *base* names.
390
- :param name: name to sanitize
391
- :return:
392
- """
393
- bad_chars = ['[', ']', ',', '/', '\\', '\\w', '\\s', '-', '+', '\"', '\'', '>', '<', ' ', '=', ')', '(', ':', '^']
394
- for char in bad_chars:
395
- name = name.replace(char, "_")
396
-
397
- length = len(name)
398
- file_length_limit = 250 # bit smaller than 256 for safety
399
- sha_length = 32
400
- real_length_limit = file_length_limit - (sha_length + 2)
401
- if length > file_length_limit:
402
- sha = get_sha(name)
403
- half_real_length_limit = max(1, int(real_length_limit / 2))
404
- name = name[0:half_real_length_limit] + "_" + sha + "_" + name[length - half_real_length_limit:length]
405
-
406
- return name
407
-
408
-
409
- def shutil_rmtree(*args, **kwargs):
410
- return shutil.rmtree(*args, **kwargs)
411
-
412
-
413
- def remove(path: str):
414
- try:
415
- if path is not None and os.path.exists(path):
416
- if os.path.isdir(path):
417
- shutil_rmtree(path, ignore_errors=True)
418
- else:
419
- with contextlib.suppress(FileNotFoundError):
420
- os.remove(path)
421
- except:
422
- pass
423
-
424
-
425
- def makedirs(path, exist_ok=True):
426
- """
427
- Avoid some inefficiency in os.makedirs()
428
- :param path:
429
- :param exist_ok:
430
- :return:
431
- """
432
- if os.path.isdir(path) and os.path.exists(path):
433
- assert exist_ok, "Path already exists"
434
- return path
435
- os.makedirs(path, exist_ok=exist_ok)
436
-
437
-
438
- def atomic_move_simple(src, dst):
439
- try:
440
- shutil.move(src, dst)
441
- except (shutil.Error, FileExistsError):
442
- pass
443
- remove(src)
444
-
445
-
446
- def download_simple(url, dest=None, print_func=None):
447
- if print_func is not None:
448
- print_func("BEGIN get url %s" % str(url))
449
- if url.startswith("file://"):
450
- from requests_file import FileAdapter
451
- s = requests.Session()
452
- s.mount('file://', FileAdapter())
453
- url_data = s.get(url, stream=True)
454
- else:
455
- url_data = requests.get(url, stream=True)
456
- if dest is None:
457
- dest = os.path.basename(url)
458
- if url_data.status_code != requests.codes.ok:
459
- msg = "Cannot get url %s, code: %s, reason: %s" % (
460
- str(url),
461
- str(url_data.status_code),
462
- str(url_data.reason),
463
- )
464
- raise requests.exceptions.RequestException(msg)
465
- url_data.raw.decode_content = True
466
- makedirs(os.path.dirname(dest), exist_ok=True)
467
- uuid_tmp = str(uuid.uuid4())[:6]
468
- dest_tmp = dest + "_dl_" + uuid_tmp + ".tmp"
469
- with open(dest_tmp, "wb") as f:
470
- shutil.copyfileobj(url_data.raw, f)
471
- atomic_move_simple(dest_tmp, dest)
472
- if print_func is not None:
473
- print_func("END get url %s" % str(url))
474
-
475
-
476
- def download(url, dest=None, dest_path=None):
477
- if dest_path is not None:
478
- dest = os.path.join(dest_path, os.path.basename(url))
479
- if os.path.isfile(dest):
480
- print("already downloaded %s -> %s" % (url, dest))
481
- return dest
482
- elif dest is not None:
483
- if os.path.exists(dest):
484
- print("already downloaded %s -> %s" % (url, dest))
485
- return dest
486
- else:
487
- uuid_tmp = "dl2_" + str(uuid.uuid4())[:6]
488
- dest = uuid_tmp + os.path.basename(url)
489
-
490
- print("downloading %s to %s" % (url, dest))
491
-
492
- if url.startswith("file://"):
493
- from requests_file import FileAdapter
494
- s = requests.Session()
495
- s.mount('file://', FileAdapter())
496
- url_data = s.get(url, stream=True)
497
- else:
498
- url_data = requests.get(url, stream=True)
499
-
500
- if url_data.status_code != requests.codes.ok:
501
- msg = "Cannot get url %s, code: %s, reason: %s" % (
502
- str(url), str(url_data.status_code), str(url_data.reason))
503
- raise requests.exceptions.RequestException(msg)
504
- url_data.raw.decode_content = True
505
- dirname = os.path.dirname(dest)
506
- if dirname != "" and not os.path.isdir(dirname):
507
- makedirs(os.path.dirname(dest), exist_ok=True)
508
- uuid_tmp = "dl3_" + str(uuid.uuid4())[:6]
509
- dest_tmp = dest + "_" + uuid_tmp + ".tmp"
510
- with open(dest_tmp, 'wb') as f:
511
- shutil.copyfileobj(url_data.raw, f)
512
- try:
513
- shutil.move(dest_tmp, dest)
514
- except FileExistsError:
515
- pass
516
- remove(dest_tmp)
517
- return dest
518
-
519
-
520
- def get_url(x, from_str=False, short_name=False):
521
- if not from_str:
522
- source = x.metadata['source']
523
- else:
524
- source = x
525
- if short_name:
526
- source_name = get_short_name(source)
527
- else:
528
- source_name = source
529
- if source.startswith('http://') or source.startswith('https://'):
530
- return """<a href="%s" target="_blank" rel="noopener noreferrer">%s</a>""" % (
531
- source, source_name)
532
- else:
533
- return """<a href="file/%s" target="_blank" rel="noopener noreferrer">%s</a>""" % (
534
- source, source_name)
535
-
536
-
537
- def get_short_name(name, maxl=50):
538
- if name is None:
539
- return ''
540
- length = len(name)
541
- if length > maxl:
542
- allow_length = maxl - 3
543
- half_allowed = max(1, int(allow_length / 2))
544
- name = name[0:half_allowed] + "..." + name[length - half_allowed:length]
545
- return name
546
-
547
-
548
- def cuda_vis_check(total_gpus):
549
- """Helper function to count GPUs by environment variable
550
- Stolen from Jon's h2o4gpu utils
551
- """
552
- cudavis = os.getenv("CUDA_VISIBLE_DEVICES")
553
- which_gpus = []
554
- if cudavis is not None:
555
- # prune away white-space, non-numerics,
556
- # except commas for simple checking
557
- cudavis = "".join(cudavis.split())
558
- import re
559
- cudavis = re.sub("[^0-9,]", "", cudavis)
560
-
561
- lencudavis = len(cudavis)
562
- if lencudavis == 0:
563
- total_gpus = 0
564
- else:
565
- total_gpus = min(
566
- total_gpus,
567
- os.getenv("CUDA_VISIBLE_DEVICES").count(",") + 1)
568
- which_gpus = os.getenv("CUDA_VISIBLE_DEVICES").split(",")
569
- which_gpus = [int(x) for x in which_gpus]
570
- else:
571
- which_gpus = list(range(0, total_gpus))
572
-
573
- return total_gpus, which_gpus
574
-
575
-
576
- def get_ngpus_vis(raise_if_exception=True):
577
- ngpus_vis1 = 0
578
-
579
- shell = False
580
- if shell:
581
- cmd = "nvidia-smi -L 2> /dev/null"
582
- else:
583
- cmd = ["nvidia-smi", "-L"]
584
-
585
- try:
586
- timeout = 5 * 3
587
- o = subprocess.check_output(cmd, shell=shell, timeout=timeout)
588
- lines = o.decode("utf-8").splitlines()
589
- ngpus_vis1 = 0
590
- for line in lines:
591
- if 'Failed to initialize NVML' not in line:
592
- ngpus_vis1 += 1
593
- except (FileNotFoundError, subprocess.CalledProcessError, OSError):
594
- # GPU systems might not have nvidia-smi, so can't fail
595
- pass
596
- except subprocess.TimeoutExpired as e:
597
- print('Failed get_ngpus_vis: %s' % str(e))
598
- if raise_if_exception:
599
- raise
600
-
601
- ngpus_vis1, which_gpus = cuda_vis_check(ngpus_vis1)
602
- return ngpus_vis1
603
-
604
-
605
- def get_mem_gpus(raise_if_exception=True, ngpus=None):
606
- totalmem_gpus1 = 0
607
- usedmem_gpus1 = 0
608
- freemem_gpus1 = 0
609
-
610
- if ngpus == 0:
611
- return totalmem_gpus1, usedmem_gpus1, freemem_gpus1
612
-
613
- try:
614
- cmd = "nvidia-smi -q 2> /dev/null | grep -A 3 'FB Memory Usage'"
615
- o = subprocess.check_output(cmd, shell=True, timeout=15)
616
- lines = o.decode("utf-8").splitlines()
617
- for line in lines:
618
- if 'Total' in line:
619
- totalmem_gpus1 += int(line.split()[2]) * 1024 ** 2
620
- if 'Used' in line:
621
- usedmem_gpus1 += int(line.split()[2]) * 1024 ** 2
622
- if 'Free' in line:
623
- freemem_gpus1 += int(line.split()[2]) * 1024 ** 2
624
- except (FileNotFoundError, subprocess.CalledProcessError, OSError):
625
- # GPU systems might not have nvidia-smi, so can't fail
626
- pass
627
- except subprocess.TimeoutExpired as e:
628
- print('Failed get_mem_gpus: %s' % str(e))
629
- if raise_if_exception:
630
- raise
631
-
632
- return totalmem_gpus1, usedmem_gpus1, freemem_gpus1
633
-
634
-
635
- class ForkContext(threading.local):
636
- """
637
- Set context for forking
638
- Ensures state is returned once done
639
- """
640
-
641
- def __init__(self, args=None, kwargs=None, forkdata_capable=True):
642
- """
643
- :param args:
644
- :param kwargs:
645
- :param forkdata_capable: whether fork is forkdata capable and will use copy-on-write forking of args/kwargs
646
- """
647
- self.forkdata_capable = forkdata_capable
648
- if self.forkdata_capable:
649
- self.has_args = args is not None
650
- self.has_kwargs = kwargs is not None
651
- forkdatacontext.args = args
652
- forkdatacontext.kwargs = kwargs
653
- else:
654
- self.has_args = False
655
- self.has_kwargs = False
656
-
657
- def __enter__(self):
658
- try:
659
- # flush all outputs so doesn't happen during fork -- don't print/log inside ForkContext contexts!
660
- sys.stdout.flush()
661
- sys.stderr.flush()
662
- except BaseException as e:
663
- # exit not called if exception, and don't want to leave forkdatacontext filled in that case
664
- print("ForkContext failure on enter: %s" % str(e))
665
- self.finally_act()
666
- raise
667
- return self
668
-
669
- def __exit__(self, exc_type, exc_value, exc_traceback):
670
- self.finally_act()
671
-
672
- def finally_act(self):
673
- """
674
- Done when exception hit or exit is reached in context
675
- first reset forkdatacontext as crucial to have reset even if later 2 calls fail
676
- :return: None
677
- """
678
- if self.forkdata_capable and (self.has_args or self.has_kwargs):
679
- forkdatacontext._reset()
680
-
681
-
682
- class _ForkDataContext(threading.local):
683
- def __init__(
684
- self,
685
- args=None,
686
- kwargs=None,
687
- ):
688
- """
689
- Global context for fork to carry data to subprocess instead of relying upon copy/pickle/serialization
690
-
691
- :param args: args
692
- :param kwargs: kwargs
693
- """
694
- assert isinstance(args, (tuple, type(None)))
695
- assert isinstance(kwargs, (dict, type(None)))
696
- self.__args = args
697
- self.__kwargs = kwargs
698
-
699
- @property
700
- def args(self) -> Tuple:
701
- """returns args"""
702
- return self.__args
703
-
704
- @args.setter
705
- def args(self, args):
706
- if self.__args is not None:
707
- raise AttributeError(
708
- "args cannot be overwritten: %s %s" % (str(self.__args), str(self.__kwargs))
709
- )
710
-
711
- self.__args = args
712
-
713
- @property
714
- def kwargs(self) -> Dict:
715
- """returns kwargs"""
716
- return self.__kwargs
717
-
718
- @kwargs.setter
719
- def kwargs(self, kwargs):
720
- if self.__kwargs is not None:
721
- raise AttributeError(
722
- "kwargs cannot be overwritten: %s %s" % (str(self.__args), str(self.__kwargs))
723
- )
724
-
725
- self.__kwargs = kwargs
726
-
727
- def _reset(self):
728
- """Reset fork arg-kwarg context to default values"""
729
- self.__args = None
730
- self.__kwargs = None
731
-
732
- def get_args_kwargs(self, func, args, kwargs) -> Tuple[Callable, Tuple, Dict]:
733
- if self.__args:
734
- args = self.__args[1:]
735
- if not func:
736
- assert len(self.__args) > 0, "if have no func, must have in args"
737
- func = self.__args[0] # should always be there
738
- if self.__kwargs:
739
- kwargs = self.__kwargs
740
- try:
741
- return func, args, kwargs
742
- finally:
743
- forkdatacontext._reset()
744
-
745
- @staticmethod
746
- def get_args_kwargs_for_traced_func(func, args, kwargs):
747
- """
748
- Return args/kwargs out of forkdatacontext when using copy-on-write way of passing args/kwargs
749
- :param func: actual function ran by _traced_func, which itself is directly what mppool treats as function
750
- :param args:
751
- :param kwargs:
752
- :return: func, args, kwargs from forkdatacontext if used, else originals
753
- """
754
- # first 3 lines are debug
755
- func_was_None = func is None
756
- args_was_None_or_empty = args is None or len(args) == 0
757
- kwargs_was_None_or_empty = kwargs is None or len(kwargs) == 0
758
-
759
- forkdatacontext_args_was_None = forkdatacontext.args is None
760
- forkdatacontext_kwargs_was_None = forkdatacontext.kwargs is None
761
- func, args, kwargs = forkdatacontext.get_args_kwargs(func, args, kwargs)
762
- using_forkdatacontext = func_was_None and func is not None # pulled func out of forkdatacontext.__args[0]
763
- assert forkdatacontext.args is None, "forkdatacontext.args should be None after get_args_kwargs"
764
- assert forkdatacontext.kwargs is None, "forkdatacontext.kwargs should be None after get_args_kwargs"
765
-
766
- proc_type = kwargs.get('proc_type', 'SUBPROCESS')
767
- if using_forkdatacontext:
768
- assert proc_type == "SUBPROCESS" or proc_type == "SUBPROCESS"
769
- if proc_type == "NORMAL":
770
- assert forkdatacontext_args_was_None, "if no fork, expect forkdatacontext.args None entering _traced_func"
771
- assert forkdatacontext_kwargs_was_None, "if no fork, expect forkdatacontext.kwargs None entering _traced_func"
772
- assert func is not None, "function should not be None, indicates original args[0] was None or args was None"
773
-
774
- return func, args, kwargs
775
-
776
-
777
- forkdatacontext = _ForkDataContext()
778
-
779
-
780
- def _traced_func(func, *args, **kwargs):
781
- func, args, kwargs = forkdatacontext.get_args_kwargs_for_traced_func(func, args, kwargs)
782
- return func(*args, **kwargs)
783
-
784
-
785
- def call_subprocess_onetask(func, args=None, kwargs=None):
786
- import platform
787
- if platform.system() in ['Darwin', 'Windows']:
788
- return func(*args, **kwargs)
789
- if isinstance(args, list):
790
- args = tuple(args)
791
- if args is None:
792
- args = ()
793
- if kwargs is None:
794
- kwargs = {}
795
- args = list(args)
796
- args = [func] + args
797
- args = tuple(args)
798
- with ForkContext(args=args, kwargs=kwargs):
799
- args = (None,)
800
- kwargs = {}
801
- with ProcessPoolExecutor(max_workers=1) as executor:
802
- future = executor.submit(_traced_func, *args, **kwargs)
803
- return future.result()
804
-
805
-
806
- class ProgressParallel(Parallel):
807
- def __init__(self, use_tqdm=True, total=None, *args, **kwargs):
808
- self._use_tqdm = use_tqdm
809
- self._total = total
810
- super().__init__(*args, **kwargs)
811
-
812
- def __call__(self, *args, **kwargs):
813
- with tqdm(disable=not self._use_tqdm, total=self._total) as self._pbar:
814
- return Parallel.__call__(self, *args, **kwargs)
815
-
816
- def print_progress(self):
817
- if self._total is None:
818
- self._pbar.total = self.n_dispatched_tasks
819
- self._pbar.n = self.n_completed_tasks
820
- self._pbar.refresh()
821
-
822
-
823
- def get_kwargs(func, exclude_names=None, **kwargs):
824
- func_names = list(inspect.signature(func).parameters)
825
- missing_kwargs = [x for x in func_names if x not in kwargs]
826
- if exclude_names:
827
- for k in exclude_names:
828
- if k in missing_kwargs:
829
- missing_kwargs.remove(k)
830
- if k in func_names:
831
- func_names.remove(k)
832
- assert not missing_kwargs, "Missing %s" % missing_kwargs
833
- kwargs = {k: v for k, v in kwargs.items() if k in func_names}
834
- return kwargs
835
-
836
-
837
- import pkg_resources
838
-
839
- have_faiss = False
840
-
841
- try:
842
- assert pkg_resources.get_distribution('faiss') is not None
843
- have_faiss = True
844
- except (pkg_resources.DistributionNotFound, AssertionError):
845
- pass
846
- try:
847
- assert pkg_resources.get_distribution('faiss_gpu') is not None
848
- have_faiss = True
849
- except (pkg_resources.DistributionNotFound, AssertionError):
850
- pass
851
- try:
852
- assert pkg_resources.get_distribution('faiss_cpu') is not None
853
- have_faiss = True
854
- except (pkg_resources.DistributionNotFound, AssertionError):
855
- pass
856
-
857
-
858
- def hash_file(file):
859
- try:
860
- import hashlib
861
-
862
- # BUF_SIZE is totally arbitrary, change for your app!
863
- BUF_SIZE = 65536 # lets read stuff in 64kb chunks!
864
-
865
- md5 = hashlib.md5()
866
- # sha1 = hashlib.sha1()
867
-
868
- with open(file, 'rb') as f:
869
- while True:
870
- data = f.read(BUF_SIZE)
871
- if not data:
872
- break
873
- md5.update(data)
874
- # sha1.update(data)
875
- except BaseException as e:
876
- print("Cannot hash %s due to %s" % (file, str(e)))
877
- traceback.print_exc()
878
- md5 = None
879
- return md5.hexdigest()
880
-
881
-
882
- def start_faulthandler():
883
- # If hit server or any subprocess with signal SIGUSR1, it'll print out all threads stack trace, but wont't quit or coredump
884
- # If more than one fork tries to write at same time, then looks corrupted.
885
- import faulthandler
886
-
887
- # SIGUSR1 in h2oai/__init__.py as well
888
- faulthandler.enable()
889
- if hasattr(faulthandler, 'register'):
890
- # windows/mac
891
- import signal
892
- faulthandler.register(signal.SIGUSR1)
893
-
894
-
895
- def get_hf_server(inference_server):
896
- inf_split = inference_server.split(" ")
897
- assert len(inf_split) == 1 or len(inf_split) == 3
898
- inference_server = inf_split[0]
899
- if len(inf_split) == 3:
900
- headers = {"authorization": "%s %s" % (inf_split[1], inf_split[2])}
901
- else:
902
- headers = None
903
- return inference_server, headers
904
-
905
-
906
- class FakeTokenizer:
907
- """
908
- 1) For keeping track of model_max_length
909
- 2) For when model doesn't directly expose tokenizer but need to count tokens
910
- """
911
-
912
- def __init__(self, model_max_length=2048, encoding_name="cl100k_base"):
913
- # dont' push limit, since if using fake tokenizer, only estimate, and seen underestimates by order 250
914
- self.model_max_length = model_max_length - 250
915
- self.encoding_name = encoding_name
916
- # The first time this runs, it will require an internet connection to download. Later runs won't need an internet connection.
917
- import tiktoken
918
- self.encoding = tiktoken.get_encoding(self.encoding_name)
919
-
920
- def encode(self, x, *args, return_tensors="pt", **kwargs):
921
- input_ids = self.encoding.encode(x, disallowed_special=())
922
- if return_tensors == 'pt' and isinstance(input_ids, list):
923
- import torch
924
- input_ids = torch.tensor(input_ids)
925
- return dict(input_ids=input_ids)
926
-
927
- def decode(self, x, *args, **kwargs):
928
- # input is input_ids[0] form
929
- return self.encoding.decode(x)
930
-
931
- def num_tokens_from_string(self, prompt: str) -> int:
932
- """Returns the number of tokens in a text string."""
933
- num_tokens = len(self.encoding.encode(prompt))
934
- return num_tokens
935
-
936
- def __call__(self, x, *args, **kwargs):
937
- return self.encode(x, *args, **kwargs)
938
-
939
-
940
- def get_local_ip():
941
- import socket
942
- s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
943
- try:
944
- # doesn't even have to be reachable
945
- s.connect(('10.255.255.255', 1))
946
- IP = s.getsockname()[0]
947
- except Exception:
948
- IP = '127.0.0.1'
949
- finally:
950
- s.close()
951
- return IP
952
-
953
-
954
- try:
955
- assert pkg_resources.get_distribution('langchain') is not None
956
- have_langchain = True
957
- except (pkg_resources.DistributionNotFound, AssertionError):
958
- have_langchain = False
959
-
960
- import distutils.spawn
961
-
962
- have_tesseract = distutils.spawn.find_executable("tesseract")
963
- have_libreoffice = distutils.spawn.find_executable("libreoffice")
964
-
965
- import pkg_resources
966
-
967
- try:
968
- assert pkg_resources.get_distribution('arxiv') is not None
969
- assert pkg_resources.get_distribution('pymupdf') is not None
970
- have_arxiv = True
971
- except (pkg_resources.DistributionNotFound, AssertionError):
972
- have_arxiv = False
973
-
974
- try:
975
- assert pkg_resources.get_distribution('pymupdf') is not None
976
- have_pymupdf = True
977
- except (pkg_resources.DistributionNotFound, AssertionError):
978
- have_pymupdf = False
979
-
980
- try:
981
- assert pkg_resources.get_distribution('selenium') is not None
982
- have_selenium = True
983
- except (pkg_resources.DistributionNotFound, AssertionError):
984
- have_selenium = False
985
-
986
- try:
987
- assert pkg_resources.get_distribution('playwright') is not None
988
- have_playwright = True
989
- except (pkg_resources.DistributionNotFound, AssertionError):
990
- have_playwright = False
991
-
992
- # disable, hangs too often
993
- have_playwright = False
994
-
995
-
996
- def set_openai(inference_server):
997
- if inference_server.startswith('vllm'):
998
- import openai_vllm
999
- openai_vllm.api_key = "EMPTY"
1000
- inf_type = inference_server.split(':')[0]
1001
- ip_vllm = inference_server.split(':')[1]
1002
- port_vllm = inference_server.split(':')[2]
1003
- openai_vllm.api_base = f"http://{ip_vllm}:{port_vllm}/v1"
1004
- return openai_vllm, inf_type
1005
- else:
1006
- import openai
1007
- openai.api_key = os.getenv("OPENAI_API_KEY")
1008
- openai.api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
1009
- inf_type = inference_server
1010
- return openai, inf_type
1011
-
1012
-
1013
- visible_langchain_modes_file = 'visible_langchain_modes.pkl'
1014
-
1015
-
1016
- def save_collection_names(langchain_modes, visible_langchain_modes, langchain_mode_paths, LangChainMode, db1s):
1017
- """
1018
- extra controls if UserData type of MyData type
1019
- """
1020
-
1021
- # use first default MyData hash as general user hash to maintain file
1022
- # if user moves MyData from langchain modes, db will still survive, so can still use hash
1023
- scratch_collection_names = list(db1s.keys())
1024
- user_hash = db1s.get(LangChainMode.MY_DATA.value, '')[1]
1025
-
1026
- llms = ['LLM', 'Disabled']
1027
-
1028
- scratch_langchain_modes = [x for x in langchain_modes if x in scratch_collection_names]
1029
- scratch_visible_langchain_modes = [x for x in visible_langchain_modes if x in scratch_collection_names]
1030
- scratch_langchain_mode_paths = {k: v for k, v in langchain_mode_paths.items() if
1031
- k in scratch_collection_names and k not in llms}
1032
-
1033
- user_langchain_modes = [x for x in langchain_modes if x not in scratch_collection_names]
1034
- user_visible_langchain_modes = [x for x in visible_langchain_modes if x not in scratch_collection_names]
1035
- user_langchain_mode_paths = {k: v for k, v in langchain_mode_paths.items() if
1036
- k not in scratch_collection_names and k not in llms}
1037
-
1038
- base_path = 'locks'
1039
- makedirs(base_path)
1040
-
1041
- # user
1042
- extra = ''
1043
- file = "%s%s" % (visible_langchain_modes_file, extra)
1044
- with filelock.FileLock(os.path.join(base_path, "%s.lock" % file)):
1045
- with open(file, 'wb') as f:
1046
- pickle.dump((user_langchain_modes, user_visible_langchain_modes, user_langchain_mode_paths), f)
1047
-
1048
- # scratch
1049
- extra = user_hash
1050
- file = "%s%s" % (visible_langchain_modes_file, extra)
1051
- with filelock.FileLock(os.path.join(base_path, "%s.lock" % file)):
1052
- with open(file, 'wb') as f:
1053
- pickle.dump((scratch_langchain_modes, scratch_visible_langchain_modes, scratch_langchain_mode_paths), f)
1054
-
1055
-
1056
- def load_collection_enum(extra):
1057
- """
1058
- extra controls if UserData type of MyData type
1059
- """
1060
- file = "%s%s" % (visible_langchain_modes_file, extra)
1061
- langchain_modes_from_file = []
1062
- visible_langchain_modes_from_file = []
1063
- langchain_mode_paths_from_file = {}
1064
- if os.path.isfile(visible_langchain_modes_file):
1065
- try:
1066
- with filelock.FileLock("%s.lock" % file):
1067
- with open(file, 'rb') as f:
1068
- langchain_modes_from_file, visible_langchain_modes_from_file, langchain_mode_paths_from_file = pickle.load(
1069
- f)
1070
- except BaseException as e:
1071
- print("Cannot load %s, ignoring error: %s" % (file, str(e)), flush=True)
1072
- for k, v in langchain_mode_paths_from_file.items():
1073
- if v is not None and not os.path.isdir(v) and isinstance(v, str):
1074
- # assume was deleted, but need to make again to avoid extra code elsewhere
1075
- makedirs(v)
1076
- return langchain_modes_from_file, visible_langchain_modes_from_file, langchain_mode_paths_from_file
1077
-
1078
-
1079
- def remove_collection_enum():
1080
- remove(visible_langchain_modes_file)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils_langchain.py DELETED
@@ -1,64 +0,0 @@
1
- from typing import Any, Dict, List, Union, Optional
2
- import time
3
- import queue
4
-
5
- from langchain.callbacks.base import BaseCallbackHandler
6
- from langchain.schema import LLMResult
7
-
8
-
9
- class StreamingGradioCallbackHandler(BaseCallbackHandler):
10
- """
11
- Similar to H2OTextIteratorStreamer that is for HF backend, but here LangChain backend
12
- """
13
- def __init__(self, timeout: Optional[float] = None, block=True):
14
- super().__init__()
15
- self.text_queue = queue.SimpleQueue()
16
- self.stop_signal = None
17
- self.do_stop = False
18
- self.timeout = timeout
19
- self.block = block
20
-
21
- def on_llm_start(
22
- self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
23
- ) -> None:
24
- """Run when LLM starts running. Clean the queue."""
25
- while not self.text_queue.empty():
26
- try:
27
- self.text_queue.get(block=False)
28
- except queue.Empty:
29
- continue
30
-
31
- def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
32
- """Run on new LLM token. Only available when streaming is enabled."""
33
- self.text_queue.put(token)
34
-
35
- def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
36
- """Run when LLM ends running."""
37
- self.text_queue.put(self.stop_signal)
38
-
39
- def on_llm_error(
40
- self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
41
- ) -> None:
42
- """Run when LLM errors."""
43
- self.text_queue.put(self.stop_signal)
44
-
45
- def __iter__(self):
46
- return self
47
-
48
- def __next__(self):
49
- while True:
50
- try:
51
- value = self.stop_signal # value looks unused in pycharm, not true
52
- if self.do_stop:
53
- print("hit stop", flush=True)
54
- # could raise or break, maybe best to raise and make parent see if any exception in thread
55
- raise StopIteration()
56
- # break
57
- value = self.text_queue.get(block=self.block, timeout=self.timeout)
58
- break
59
- except queue.Empty:
60
- time.sleep(0.01)
61
- if value == self.stop_signal:
62
- raise StopIteration()
63
- else:
64
- return value