pseudotensor commited on
Commit
54f4f91
·
1 Parent(s): 3f42f2e

Update with h2oGPT hash e4482a4c59016517cd0d5513bc15b78b46f4598a

Browse files
LICENSE DELETED
@@ -1,201 +0,0 @@
1
- Apache License
2
- Version 2.0, January 2004
3
- http://www.apache.org/licenses/
4
-
5
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
-
7
- 1. Definitions.
8
-
9
- "License" shall mean the terms and conditions for use, reproduction,
10
- and distribution as defined by Sections 1 through 9 of this document.
11
-
12
- "Licensor" shall mean the copyright owner or entity authorized by
13
- the copyright owner that is granting the License.
14
-
15
- "Legal Entity" shall mean the union of the acting entity and all
16
- other entities that control, are controlled by, or are under common
17
- control with that entity. For the purposes of this definition,
18
- "control" means (i) the power, direct or indirect, to cause the
19
- direction or management of such entity, whether by contract or
20
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
- outstanding shares, or (iii) beneficial ownership of such entity.
22
-
23
- "You" (or "Your") shall mean an individual or Legal Entity
24
- exercising permissions granted by this License.
25
-
26
- "Source" form shall mean the preferred form for making modifications,
27
- including but not limited to software source code, documentation
28
- source, and configuration files.
29
-
30
- "Object" form shall mean any form resulting from mechanical
31
- transformation or translation of a Source form, including but
32
- not limited to compiled object code, generated documentation,
33
- and conversions to other media types.
34
-
35
- "Work" shall mean the work of authorship, whether in Source or
36
- Object form, made available under the License, as indicated by a
37
- copyright notice that is included in or attached to the work
38
- (an example is provided in the Appendix below).
39
-
40
- "Derivative Works" shall mean any work, whether in Source or Object
41
- form, that is based on (or derived from) the Work and for which the
42
- editorial revisions, annotations, elaborations, or other modifications
43
- represent, as a whole, an original work of authorship. For the purposes
44
- of this License, Derivative Works shall not include works that remain
45
- separable from, or merely link (or bind by name) to the interfaces of,
46
- the Work and Derivative Works thereof.
47
-
48
- "Contribution" shall mean any work of authorship, including
49
- the original version of the Work and any modifications or additions
50
- to that Work or Derivative Works thereof, that is intentionally
51
- submitted to Licensor for inclusion in the Work by the copyright owner
52
- or by an individual or Legal Entity authorized to submit on behalf of
53
- the copyright owner. For the purposes of this definition, "submitted"
54
- means any form of electronic, verbal, or written communication sent
55
- to the Licensor or its representatives, including but not limited to
56
- communication on electronic mailing lists, source code control systems,
57
- and issue tracking systems that are managed by, or on behalf of, the
58
- Licensor for the purpose of discussing and improving the Work, but
59
- excluding communication that is conspicuously marked or otherwise
60
- designated in writing by the copyright owner as "Not a Contribution."
61
-
62
- "Contributor" shall mean Licensor and any individual or Legal Entity
63
- on behalf of whom a Contribution has been received by Licensor and
64
- subsequently incorporated within the Work.
65
-
66
- 2. Grant of Copyright License. Subject to the terms and conditions of
67
- this License, each Contributor hereby grants to You a perpetual,
68
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
- copyright license to reproduce, prepare Derivative Works of,
70
- publicly display, publicly perform, sublicense, and distribute the
71
- Work and such Derivative Works in Source or Object form.
72
-
73
- 3. Grant of Patent License. Subject to the terms and conditions of
74
- this License, each Contributor hereby grants to You a perpetual,
75
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
- (except as stated in this section) patent license to make, have made,
77
- use, offer to sell, sell, import, and otherwise transfer the Work,
78
- where such license applies only to those patent claims licensable
79
- by such Contributor that are necessarily infringed by their
80
- Contribution(s) alone or by combination of their Contribution(s)
81
- with the Work to which such Contribution(s) was submitted. If You
82
- institute patent litigation against any entity (including a
83
- cross-claim or counterclaim in a lawsuit) alleging that the Work
84
- or a Contribution incorporated within the Work constitutes direct
85
- or contributory patent infringement, then any patent licenses
86
- granted to You under this License for that Work shall terminate
87
- as of the date such litigation is filed.
88
-
89
- 4. Redistribution. You may reproduce and distribute copies of the
90
- Work or Derivative Works thereof in any medium, with or without
91
- modifications, and in Source or Object form, provided that You
92
- meet the following conditions:
93
-
94
- (a) You must give any other recipients of the Work or
95
- Derivative Works a copy of this License; and
96
-
97
- (b) You must cause any modified files to carry prominent notices
98
- stating that You changed the files; and
99
-
100
- (c) You must retain, in the Source form of any Derivative Works
101
- that You distribute, all copyright, patent, trademark, and
102
- attribution notices from the Source form of the Work,
103
- excluding those notices that do not pertain to any part of
104
- the Derivative Works; and
105
-
106
- (d) If the Work includes a "NOTICE" text file as part of its
107
- distribution, then any Derivative Works that You distribute must
108
- include a readable copy of the attribution notices contained
109
- within such NOTICE file, excluding those notices that do not
110
- pertain to any part of the Derivative Works, in at least one
111
- of the following places: within a NOTICE text file distributed
112
- as part of the Derivative Works; within the Source form or
113
- documentation, if provided along with the Derivative Works; or,
114
- within a display generated by the Derivative Works, if and
115
- wherever such third-party notices normally appear. The contents
116
- of the NOTICE file are for informational purposes only and
117
- do not modify the License. You may add Your own attribution
118
- notices within Derivative Works that You distribute, alongside
119
- or as an addendum to the NOTICE text from the Work, provided
120
- that such additional attribution notices cannot be construed
121
- as modifying the License.
122
-
123
- You may add Your own copyright statement to Your modifications and
124
- may provide additional or different license terms and conditions
125
- for use, reproduction, or distribution of Your modifications, or
126
- for any such Derivative Works as a whole, provided Your use,
127
- reproduction, and distribution of the Work otherwise complies with
128
- the conditions stated in this License.
129
-
130
- 5. Submission of Contributions. Unless You explicitly state otherwise,
131
- any Contribution intentionally submitted for inclusion in the Work
132
- by You to the Licensor shall be under the terms and conditions of
133
- this License, without any additional terms or conditions.
134
- Notwithstanding the above, nothing herein shall supersede or modify
135
- the terms of any separate license agreement you may have executed
136
- with Licensor regarding such Contributions.
137
-
138
- 6. Trademarks. This License does not grant permission to use the trade
139
- names, trademarks, service marks, or product names of the Licensor,
140
- except as required for reasonable and customary use in describing the
141
- origin of the Work and reproducing the content of the NOTICE file.
142
-
143
- 7. Disclaimer of Warranty. Unless required by applicable law or
144
- agreed to in writing, Licensor provides the Work (and each
145
- Contributor provides its Contributions) on an "AS IS" BASIS,
146
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
- implied, including, without limitation, any warranties or conditions
148
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
- PARTICULAR PURPOSE. You are solely responsible for determining the
150
- appropriateness of using or redistributing the Work and assume any
151
- risks associated with Your exercise of permissions under this License.
152
-
153
- 8. Limitation of Liability. In no event and under no legal theory,
154
- whether in tort (including negligence), contract, or otherwise,
155
- unless required by applicable law (such as deliberate and grossly
156
- negligent acts) or agreed to in writing, shall any Contributor be
157
- liable to You for damages, including any direct, indirect, special,
158
- incidental, or consequential damages of any character arising as a
159
- result of this License or out of the use or inability to use the
160
- Work (including but not limited to damages for loss of goodwill,
161
- work stoppage, computer failure or malfunction, or any and all
162
- other commercial damages or losses), even if such Contributor
163
- has been advised of the possibility of such damages.
164
-
165
- 9. Accepting Warranty or Additional Liability. While redistributing
166
- the Work or Derivative Works thereof, You may choose to offer,
167
- and charge a fee for, acceptance of support, warranty, indemnity,
168
- or other liability obligations and/or rights consistent with this
169
- License. However, in accepting such obligations, You may act only
170
- on Your own behalf and on Your sole responsibility, not on behalf
171
- of any other Contributor, and only if You agree to indemnify,
172
- defend, and hold each Contributor harmless for any liability
173
- incurred by, or claims asserted against, such Contributor by reason
174
- of your accepting any such warranty or additional liability.
175
-
176
- END OF TERMS AND CONDITIONS
177
-
178
- APPENDIX: How to apply the Apache License to your work.
179
-
180
- To apply the Apache License to your work, attach the following
181
- boilerplate notice, with the fields enclosed by brackets "[]"
182
- replaced with your own identifying information. (Don't include
183
- the brackets!) The text should be enclosed in the appropriate
184
- comment syntax for the file format. We also recommend that a
185
- file or class name and description of purpose be included on the
186
- same "printed page" as the copyright notice for easier
187
- identification within third-party archives.
188
-
189
- Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
190
-
191
- Licensed under the Apache License, Version 2.0 (the "License");
192
- you may not use this file except in compliance with the License.
193
- You may obtain a copy of the License at
194
-
195
- http://www.apache.org/licenses/LICENSE-2.0
196
-
197
- Unless required by applicable law or agreed to in writing, software
198
- distributed under the License is distributed on an "AS IS" BASIS,
199
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
- See the License for the specific language governing permissions and
201
- limitations under the License.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
client_test.py CHANGED
@@ -12,13 +12,13 @@ Currently, this will force model to be on a single GPU.
12
 
13
  Then run this client as:
14
 
15
- python client_test.py
16
 
17
 
18
 
19
  For HF spaces:
20
 
21
- HOST="https://h2oai-h2ogpt-chatbot.hf.space" python client_test.py
22
 
23
  Result:
24
 
@@ -28,7 +28,7 @@ Loaded as API: https://h2oai-h2ogpt-chatbot.hf.space ✔
28
 
29
  For demo:
30
 
31
- HOST="https://gpt.h2o.ai" python client_test.py
32
 
33
  Result:
34
 
@@ -48,7 +48,7 @@ import markdown # pip install markdown
48
  import pytest
49
  from bs4 import BeautifulSoup # pip install beautifulsoup4
50
 
51
- from enums import DocumentChoices
52
 
53
  debug = False
54
 
@@ -67,7 +67,9 @@ def get_client(serialize=True):
67
  def get_args(prompt, prompt_type, chat=False, stream_output=False,
68
  max_new_tokens=50,
69
  top_k_docs=3,
70
- langchain_mode='Disabled'):
 
 
71
  from collections import OrderedDict
72
  kwargs = OrderedDict(instruction=prompt if chat else '', # only for chat=True
73
  iinput='', # only for chat=True
@@ -76,7 +78,7 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False,
76
  # but leave stream_output=False for simple input/output mode
77
  stream_output=stream_output,
78
  prompt_type=prompt_type,
79
- prompt_dict='',
80
  temperature=0.1,
81
  top_p=0.75,
82
  top_k=40,
@@ -92,12 +94,13 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False,
92
  instruction_nochat=prompt if not chat else '',
93
  iinput_nochat='', # only for chat=False
94
  langchain_mode=langchain_mode,
 
95
  top_k_docs=top_k_docs,
96
  chunk=True,
97
  chunk_size=512,
98
  document_choice=[DocumentChoices.All_Relevant.name],
99
  )
100
- from generate import eval_func_param_names
101
  assert len(set(eval_func_param_names).difference(set(list(kwargs.keys())))) == 0
102
  if chat:
103
  # add chatbot output on end. Assumes serialize=False
@@ -198,6 +201,7 @@ def run_client_nochat_api_lean_morestuff(prompt, prompt_type='human_bot', max_ne
198
  instruction_nochat=prompt,
199
  iinput_nochat='',
200
  langchain_mode='Disabled',
 
201
  top_k_docs=4,
202
  document_choice=['All'],
203
  )
@@ -219,21 +223,24 @@ def run_client_nochat_api_lean_morestuff(prompt, prompt_type='human_bot', max_ne
219
  @pytest.mark.skip(reason="For manual use against some server, no server launched")
220
  def test_client_chat(prompt_type='human_bot'):
221
  return run_client_chat(prompt='Who are you?', prompt_type=prompt_type, stream_output=False, max_new_tokens=50,
222
- langchain_mode='Disabled')
223
 
224
 
225
  @pytest.mark.skip(reason="For manual use against some server, no server launched")
226
  def test_client_chat_stream(prompt_type='human_bot'):
227
  return run_client_chat(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
228
  stream_output=True, max_new_tokens=512,
229
- langchain_mode='Disabled')
230
 
231
 
232
- def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens, langchain_mode):
 
233
  client = get_client(serialize=False)
234
 
235
  kwargs, args = get_args(prompt, prompt_type, chat=True, stream_output=stream_output,
236
- max_new_tokens=max_new_tokens, langchain_mode=langchain_mode)
 
 
237
  return run_client(client, prompt, args, kwargs)
238
 
239
 
@@ -276,14 +283,15 @@ def run_client(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
276
  def test_client_nochat_stream(prompt_type='human_bot'):
277
  return run_client_nochat_gen(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
278
  stream_output=True, max_new_tokens=512,
279
- langchain_mode='Disabled')
280
 
281
 
282
- def run_client_nochat_gen(prompt, prompt_type, stream_output, max_new_tokens, langchain_mode):
283
  client = get_client(serialize=False)
284
 
285
  kwargs, args = get_args(prompt, prompt_type, chat=False, stream_output=stream_output,
286
- max_new_tokens=max_new_tokens, langchain_mode=langchain_mode)
 
287
  return run_client_gen(client, prompt, args, kwargs)
288
 
289
 
 
12
 
13
  Then run this client as:
14
 
15
+ python src/client_test.py
16
 
17
 
18
 
19
  For HF spaces:
20
 
21
+ HOST="https://h2oai-h2ogpt-chatbot.hf.space" python src/client_test.py
22
 
23
  Result:
24
 
 
28
 
29
  For demo:
30
 
31
+ HOST="https://gpt.h2o.ai" python src/client_test.py
32
 
33
  Result:
34
 
 
48
  import pytest
49
  from bs4 import BeautifulSoup # pip install beautifulsoup4
50
 
51
+ from enums import DocumentChoices, LangChainAction
52
 
53
  debug = False
54
 
 
67
  def get_args(prompt, prompt_type, chat=False, stream_output=False,
68
  max_new_tokens=50,
69
  top_k_docs=3,
70
+ langchain_mode='Disabled',
71
+ langchain_action=LangChainAction.QUERY.value,
72
+ prompt_dict=None):
73
  from collections import OrderedDict
74
  kwargs = OrderedDict(instruction=prompt if chat else '', # only for chat=True
75
  iinput='', # only for chat=True
 
78
  # but leave stream_output=False for simple input/output mode
79
  stream_output=stream_output,
80
  prompt_type=prompt_type,
81
+ prompt_dict=prompt_dict,
82
  temperature=0.1,
83
  top_p=0.75,
84
  top_k=40,
 
94
  instruction_nochat=prompt if not chat else '',
95
  iinput_nochat='', # only for chat=False
96
  langchain_mode=langchain_mode,
97
+ langchain_action=langchain_action,
98
  top_k_docs=top_k_docs,
99
  chunk=True,
100
  chunk_size=512,
101
  document_choice=[DocumentChoices.All_Relevant.name],
102
  )
103
+ from src.gen import eval_func_param_names
104
  assert len(set(eval_func_param_names).difference(set(list(kwargs.keys())))) == 0
105
  if chat:
106
  # add chatbot output on end. Assumes serialize=False
 
201
  instruction_nochat=prompt,
202
  iinput_nochat='',
203
  langchain_mode='Disabled',
204
+ langchain_action=LangChainAction.QUERY.value,
205
  top_k_docs=4,
206
  document_choice=['All'],
207
  )
 
223
  @pytest.mark.skip(reason="For manual use against some server, no server launched")
224
  def test_client_chat(prompt_type='human_bot'):
225
  return run_client_chat(prompt='Who are you?', prompt_type=prompt_type, stream_output=False, max_new_tokens=50,
226
+ langchain_mode='Disabled', langchain_action=LangChainAction.QUERY.value)
227
 
228
 
229
  @pytest.mark.skip(reason="For manual use against some server, no server launched")
230
  def test_client_chat_stream(prompt_type='human_bot'):
231
  return run_client_chat(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
232
  stream_output=True, max_new_tokens=512,
233
+ langchain_mode='Disabled', langchain_action=LangChainAction.QUERY.value)
234
 
235
 
236
+ def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens, langchain_mode, langchain_action,
237
+ prompt_dict=None):
238
  client = get_client(serialize=False)
239
 
240
  kwargs, args = get_args(prompt, prompt_type, chat=True, stream_output=stream_output,
241
+ max_new_tokens=max_new_tokens, langchain_mode=langchain_mode,
242
+ langchain_action=langchain_action,
243
+ prompt_dict=prompt_dict)
244
  return run_client(client, prompt, args, kwargs)
245
 
246
 
 
283
  def test_client_nochat_stream(prompt_type='human_bot'):
284
  return run_client_nochat_gen(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
285
  stream_output=True, max_new_tokens=512,
286
+ langchain_mode='Disabled', langchain_action=LangChainAction.QUERY.value)
287
 
288
 
289
+ def run_client_nochat_gen(prompt, prompt_type, stream_output, max_new_tokens, langchain_mode, langchain_action):
290
  client = get_client(serialize=False)
291
 
292
  kwargs, args = get_args(prompt, prompt_type, chat=False, stream_output=stream_output,
293
+ max_new_tokens=max_new_tokens, langchain_mode=langchain_mode,
294
+ langchain_action=langchain_action)
295
  return run_client_gen(client, prompt, args, kwargs)
296
 
297
 
enums.py CHANGED
@@ -37,6 +37,9 @@ class DocumentChoices(Enum):
37
  Just_LLM = 3
38
 
39
 
 
 
 
40
  class LangChainMode(Enum):
41
  """LangChain mode"""
42
 
@@ -52,10 +55,22 @@ class LangChainMode(Enum):
52
  H2O_DAI_DOCS = "DriverlessAI docs"
53
 
54
 
 
 
 
 
 
 
 
 
 
 
 
55
  no_server_str = no_lora_str = no_model_str = '[None/Remove]'
56
 
57
 
58
- # from site-packages/langchain/llms/openai.py, but needed since ChatOpenAI doesn't have this information
 
59
  model_token_mapping = {
60
  "gpt-4": 8192,
61
  "gpt-4-0314": 8192,
 
37
  Just_LLM = 3
38
 
39
 
40
+ non_query_commands = [DocumentChoices.All_Relevant_Only_Sources.name, DocumentChoices.Only_All_Sources.name]
41
+
42
+
43
  class LangChainMode(Enum):
44
  """LangChain mode"""
45
 
 
55
  H2O_DAI_DOCS = "DriverlessAI docs"
56
 
57
 
58
+ class LangChainAction(Enum):
59
+ """LangChain action"""
60
+
61
+ QUERY = "Query"
62
+ # WIP:
63
+ #SUMMARIZE_MAP = "Summarize_map_reduce"
64
+ SUMMARIZE_MAP = "Summarize"
65
+ SUMMARIZE_ALL = "Summarize_all"
66
+ SUMMARIZE_REFINE = "Summarize_refine"
67
+
68
+
69
  no_server_str = no_lora_str = no_model_str = '[None/Remove]'
70
 
71
 
72
+ # from site-packages/langchain/llms/openai.py
73
+ # but needed since ChatOpenAI doesn't have this information
74
  model_token_mapping = {
75
  "gpt-4": 8192,
76
  "gpt-4-0314": 8192,
finetune.py DELETED
@@ -1,676 +0,0 @@
1
- import os
2
- import sys
3
- from functools import partial
4
- from typing import List, Union
5
- import fire
6
- import numpy as np
7
-
8
- if os.path.dirname(os.path.abspath(__file__)) not in sys.path:
9
- sys.path.append(os.path.dirname(os.path.abspath(__file__)))
10
-
11
- from loaders import get_loaders, get_tokenizer
12
- from prompter import generate_prompt, prompt_types, PromptType
13
- from utils import get_githash, copy_code
14
- import torch
15
-
16
-
17
- def log(*args, **kwargs):
18
- if int(os.environ.get("LOCAL_RANK", 0)) == 0:
19
- if 'flush' not in kwargs:
20
- kwargs['flush'] = True
21
- print(*args, **kwargs)
22
-
23
-
24
- # supported by huggingface evaluate
25
- supported_metrics = ['bleu', 'rouge', 'sacrebleu', 'meteor']
26
-
27
-
28
- def train(
29
- save_code: bool = False,
30
- run_id: int = None,
31
-
32
- base_model: str = 'h2oai/h2ogpt-oig-oasst1-512-6_9b',
33
- # base_model: str = 'h2oai/h2ogpt-oasst1-512-12b',
34
- # base_model: str = 'h2oai/h2ogpt-oasst1-512-20b',
35
- # base_model: str = 'EleutherAI/gpt-neox-20b',
36
- # base_model: str = 'EleutherAI/pythia-12b-deduped',
37
- # base_model: str = 'togethercomputer/GPT-NeoXT-Chat-Base-20B',
38
- # base_model: str = 'decapoda-research/llama-7b-hf',
39
- # base_model: str = 'decapoda-research/llama-13b-hf',
40
- # base_model: str = 'decapoda-research/llama-30b-hf',
41
- # base_model: str = 'EleutherAI/gpt-j-6B',
42
-
43
- # only needed if base_model is self-exported HF state without tokenizer
44
- tokenizer_base_model: str = None,
45
- # tokenizer_base_model: str = 'EleutherAI/gpt-neox-20b',
46
-
47
- data_path: str = "h2oai/openassistant_oasst1_h2ogpt",
48
- data_col_dict: dict = None,
49
- # data_path: str = "./dai_docs.train.json",
50
- prompt_type: Union[str, int] = "plain", # "plain", "instruct", "quality", "human_bot", "dai_faq"
51
-
52
- valid_path: str = None,
53
- # valid_path: str = "./dai_docs.valid.json",
54
-
55
- # data_mix_in_path: str = "laion/OIG", # way too big, medium quality
56
- data_mix_in_path: str = "0-hero/OIG-small-chip2", # high quality, 50 MB, good enough for now
57
- data_mix_in_factor: float = 0.0, # >1: more mix-in data, <1: more of data_path data
58
- data_mix_in_col_dict: dict = {'user': 'instruction', 'chip2': 'output'},
59
- data_mix_in_prompt_type: str = "instruct", # just instruction->output, same as instruct
60
-
61
- output_dir: str = None,
62
-
63
- # LoRA checkpoint continuation
64
- lora_weights: str = "",
65
-
66
- # batching training hyperparams
67
- batch_size: int = 128,
68
- micro_batch_size: int = 4,
69
- gradient_checkpointing=False, # unnecessary with gradient accumulation enabled
70
- fp16=True,
71
- train_8bit=False,
72
- train_4bit=False,
73
-
74
- # general training hyperparams
75
- num_epochs: float = 1,
76
- learning_rate: float = 3e-4,
77
-
78
- # validation settings
79
- val_set_size: int = None,
80
- val_metrics: List[str] = [],
81
- eval_steps: int = None, # to control eval steps via steps
82
- eval_epochs: float = None, # to control eval steps via epochs
83
-
84
- # lora hyperparams
85
- lora_r: int = 8,
86
- lora_alpha: int = 16,
87
- lora_dropout: float = 0.05,
88
- lora_target_modules: List[str] = None,
89
- llama_type: bool = None,
90
- llama_flash_attn: bool = False,
91
-
92
- # llm hyperparams
93
- train_on_inputs: bool = True, # if False, masks out inputs in loss
94
- group_by_length: bool = False, # if True, faster, but produces an odd training loss curve
95
- resume_from_checkpoint: str = None, # either training checkpoint or final adapter
96
- cutoff_len: int = 512, # larger values use more memory
97
- drop_truncations: bool = False, # if True, drop any truncated long sequences
98
-
99
- # torch training params
100
- ddp: bool = True, # set to False if OOM with True, for multi-GPU model parallelism
101
- local_files_only: bool = False, # else will download new versions, normally unwanted
102
- resume_download: bool = True,
103
- use_auth_token: Union[str, bool] = False, # True requires CLI did huggingface-cli login before running
104
- warmup_steps: int = 100,
105
- logging_steps: int = 1,
106
- save_steps: int = None, # must be round multiple of eval_steps
107
- save_total_limit: int = 3,
108
- add_eos_token: bool = False,
109
- ):
110
- if llama_flash_attn:
111
- # Need to call this before importing transformers.
112
- from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
113
- replace_llama_attn_with_flash_attn()
114
-
115
- # allow set token directly
116
- use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
117
-
118
- prompt_type = str(prompt_type) # migration from integers
119
- assert prompt_type in prompt_types
120
-
121
- world_size = int(os.getenv("WORLD_SIZE", 1))
122
- local_rank = int(os.getenv("LOCAL_RANK", 0))
123
- rank = int(os.getenv("RANK", 0))
124
- print(f"local_rank: {local_rank}")
125
- print(f"global rank: {rank}")
126
-
127
- gpus = max(world_size, torch.cuda.device_count())
128
- run_id = run_id or 0
129
- if not data_path:
130
- raise ValueError("No data_path provided")
131
- if not output_dir:
132
- output_dir = f"{base_model.split('/')[-1]}.{data_path.replace('/', '')}.{num_epochs}_epochs.{get_githash() or 'nogit'}.{run_id}"
133
- if os.path.exists(output_dir) and not resume_from_checkpoint:
134
- raise FileExistsError(
135
- f"output_dir {output_dir} based on run_id {run_id} already exists. Please pick a different run_id.")
136
- else:
137
- if os.path.exists(output_dir) and not resume_from_checkpoint:
138
- raise FileExistsError(
139
- f"output_dir {output_dir} already exists. Please pick a different output_dir, or specify a run_id instead.")
140
- device_map = "auto"
141
-
142
- if save_code:
143
- copy_code(run_id)
144
- if tokenizer_base_model is None:
145
- tokenizer_base_model = base_model
146
- if llama_type is None:
147
- llama_type = "llama" in base_model.lower()
148
- if llama_type and llama_flash_attn:
149
- import pkg_resources
150
- try:
151
- pkg_resources.get_distribution('flash_attn')
152
- can_do_flash_attn = True
153
- except (pkg_resources.DistributionNotFound, pkg_resources.ContextualVersionConflict):
154
- can_do_flash_attn = False
155
-
156
- if not can_do_flash_attn:
157
- raise RuntimeError("""Flash attention not installed.
158
- NOTE: for current pytorch 2.0, flash attention requires installing cuda 11.7 via https://developer.nvidia.com/cuda-11-7-0-download-archive?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version=20.04&target_type=runfile_local and then when running, to avoid installing driver, docs, samples, just install toolkit. Then when pip installing flash attention do:
159
-
160
- CUDA_HOME=/usr/local/cuda-11.7 pip install flash-attn""")
161
- assert (
162
- base_model
163
- ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
164
- gradient_accumulation_steps = batch_size // micro_batch_size
165
- assert gradient_accumulation_steps >= world_size, "must increase batch_size for multi-GPU"
166
-
167
- device_map = "auto"
168
-
169
- locals_dict = locals()
170
- locals_print = '\n'.join(['%s: %s' % (k, v) for k, v in locals_dict.items()])
171
- log(f"Training model with params:\n{locals_print}")
172
- log("Command: %s\nHash: %s" % (str(' '.join(sys.argv)), get_githash()))
173
-
174
- max_memory = None
175
- if gpus > 1:
176
- if ddp:
177
- log("Distributed: data parallel")
178
- device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
179
- gradient_accumulation_steps = gradient_accumulation_steps // world_size
180
- else:
181
- free_in_GB = int(min(torch.cuda.mem_get_info()) / 1024 ** 3)
182
- max_memory = f"{free_in_GB - 2}GB"
183
- max_memory = {i: max_memory for i in range(gpus)}
184
- log("world_size: %d" % world_size)
185
- log("num_gpus: %d" % gpus)
186
- log("max mem: %s" % max_memory)
187
-
188
- model_loader, tokenizer_loader = get_loaders(model_name=base_model, reward_type=False, llama_type=llama_type)
189
-
190
- model = model_loader.from_pretrained(
191
- base_model,
192
- load_in_8bit=train_8bit,
193
- load_in_4bit=train_4bit,
194
- device_map=device_map,
195
- torch_dtype=torch.float16,
196
- max_memory=max_memory,
197
- local_files_only=local_files_only,
198
- trust_remote_code=True,
199
- resume_download=resume_download,
200
- use_auth_token=use_auth_token,
201
- )
202
- if gpus > 1:
203
- if not ddp:
204
- log("model parallel")
205
- model.is_parallelizable = True
206
- model.model_parallel = True
207
-
208
- tokenizer = get_tokenizer(tokenizer_loader, tokenizer_base_model, local_files_only, resume_download, use_auth_token)
209
-
210
- if train_8bit or train_4bit:
211
- from peft import (
212
- prepare_model_for_kbit_training,
213
- )
214
-
215
- model = prepare_model_for_kbit_training(model)
216
-
217
- from peft import LoraConfig, get_peft_model, set_peft_model_state_dict
218
- try:
219
- from peft import utils
220
- lora_mappings = utils.TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy()
221
- except AttributeError:
222
- from peft import mapping
223
- lora_mappings = mapping.TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy()
224
- lora_mappings['distilgpt2'] = ["c_attn"]
225
-
226
- if lora_weights:
227
-
228
- from peft import PeftModel
229
- model = PeftModel.from_pretrained(
230
- model,
231
- lora_weights,
232
- torch_dtype=torch.float16,
233
- device_map=device_map,
234
- local_files_only=local_files_only,
235
- resume_download=resume_download,
236
- use_auth_token=use_auth_token,
237
- )
238
- elif lora_r > 0:
239
- if lora_target_modules is None:
240
- base_model_lower = base_model.lower()
241
- if base_model_lower in lora_mappings:
242
- lora_target_modules_cand = [lora_mappings[base_model_lower]]
243
- else:
244
- lora_target_modules_cand = [["query_key_value"], ["q_proj", "v_proj"]]
245
- else:
246
- lora_target_modules_cand = [lora_target_modules]
247
-
248
- for lora_target_modules in lora_target_modules_cand:
249
- try:
250
- config = LoraConfig(
251
- r=lora_r,
252
- lora_alpha=lora_alpha,
253
- target_modules=lora_target_modules,
254
- lora_dropout=lora_dropout,
255
- bias="none",
256
- task_type="CAUSAL_LM",
257
- )
258
- model = get_peft_model(model, config)
259
- break
260
- except ValueError as e:
261
- if "Target modules" in str(e) and "not found" in str(e):
262
- continue
263
- else:
264
- raise
265
- from peft import PeftModel
266
- assert isinstance(model, PeftModel), "LoRA failed. Please provide --lora_target_modules explicitly."
267
- if resume_from_checkpoint:
268
- # Check the available weights and load them
269
- checkpoint_name = os.path.join(
270
- resume_from_checkpoint, "pytorch_model.bin"
271
- ) # Full checkpoint
272
- if not os.path.exists(checkpoint_name):
273
- checkpoint_name = os.path.join(
274
- resume_from_checkpoint, "adapter_model.bin"
275
- ) # only LoRA model - LoRA config above has to fit
276
- resume_from_checkpoint = False # So the trainer won't try loading its state
277
- # The two files above have a different name depending on how they were saved, but are actually the same.
278
- if os.path.exists(checkpoint_name):
279
- log(f"Restarting from {checkpoint_name}")
280
- adapters_weights = torch.load(checkpoint_name)
281
- set_peft_model_state_dict(model, adapters_weights)
282
- else:
283
- log(f"Checkpoint {checkpoint_name} not found")
284
-
285
- print(model)
286
- try:
287
- # only for PeftModel
288
- model.print_trainable_parameters() # Be more transparent about the % of trainable params.
289
- except:
290
- pass
291
-
292
- metrics = {}
293
- for name in supported_metrics:
294
- if name in val_metrics:
295
- import evaluate # Causes hang for 'python generate.py' on dual 4090 if imported early, 100% reproducible
296
- metrics[name] = evaluate.load(name)
297
- log("Using Validation Metrics: %s" % str(list(metrics.keys())))
298
- log("Supported Metrics: %s" % supported_metrics)
299
-
300
- if val_set_size is None:
301
- if len(metrics) == 0:
302
- val_set_size = 1000
303
- else:
304
- val_set_size = 100
305
- log("Auto set val_set_size %s" % val_set_size)
306
- elif val_set_size < 1.0 and val_set_size != 0:
307
- raise RuntimeError("Fractional validation size not supported.")
308
-
309
- from datasets import load_dataset, concatenate_datasets
310
- if valid_path:
311
- data = load_dataset("json", data_files={"train": data_path, "valid": valid_path})
312
- else:
313
- if "json" in data_path:
314
- data = load_dataset("json", data_files={"train": data_path})
315
- else:
316
- data = load_dataset(data_path)
317
- data = data.rename_columns(data_col_dict or {})
318
-
319
- valid_data = None
320
- train_data_mix_in = None
321
- valid_data_mix_in = None
322
-
323
- if data_mix_in_path and data_mix_in_factor > 0:
324
- # get mix-in training/validation data - to keep model "sane"
325
- num_rows = data["train"].num_rows
326
- log("Loading mix-in dataset: %s" % data_mix_in_path)
327
- if "json" in data_mix_in_path:
328
- data_mix_in = load_dataset("json", data_files={"train": data_mix_in_path})["train"]
329
- else:
330
- data_mix_in = load_dataset(data_mix_in_path)["train"] # can be large
331
- data_mix_in = data_mix_in.rename_columns(data_mix_in_col_dict or {})
332
- mix_in_rows = int(num_rows * data_mix_in_factor)
333
-
334
- if mix_in_rows > data_mix_in.num_rows:
335
- # duplicate rows if mix-in is smaller than required
336
- log("Duplicating mixin to compensate for its size for training size and mixin fraction")
337
- data_mix_in = concatenate_datasets([data_mix_in] * int(np.ceil(mix_in_rows / data_mix_in.num_rows)))
338
-
339
- # only get as much as we need to balance
340
- valid_size = min(data_mix_in.num_rows // 2, val_set_size or 0)
341
- train_size = max(1, min(data_mix_in.num_rows - valid_size, mix_in_rows))
342
- mixin_small = data_mix_in.train_test_split(
343
- test_size=train_size + valid_size,
344
- shuffle=True, seed=np.random.randint(10000),
345
- )["test"]
346
- if valid_size:
347
- mixin_train_test = mixin_small.train_test_split(
348
- test_size=valid_size, shuffle=False,
349
- )
350
- train_data_mix_in = mixin_train_test["train"]
351
- valid_data_mix_in = mixin_train_test["test"]
352
- else:
353
- train_data_mix_in = mixin_small
354
-
355
- if "prompt_type" not in train_data_mix_in.column_names:
356
- train_data_mix_in = train_data_mix_in.add_column(
357
- "prompt_type",
358
- [data_mix_in_prompt_type] * train_data_mix_in.num_rows,
359
- )
360
- log("Added prompt type %s to mix-in training data" % data_mix_in_prompt_type)
361
- if valid_data_mix_in and "prompt_type" not in valid_data_mix_in.column_names:
362
- valid_data_mix_in = valid_data_mix_in.add_column(
363
- "prompt_type",
364
- [data_mix_in_prompt_type] * valid_data_mix_in.num_rows,
365
- )
366
- log("Added prompt type %s to mix-in validation data" % data_mix_in_prompt_type)
367
- log("Created mix-in data:\nTrain %s\nValid %s" % (train_data_mix_in, valid_data_mix_in))
368
-
369
- # get our own training/validation data - for fine-tuning
370
- if val_set_size > 0 and not valid_path and not data_mix_in_path:
371
- # create valid split from train
372
- train_val = data["train"].train_test_split(
373
- test_size=val_set_size, shuffle=True, seed=42
374
- )
375
- train_data = train_val["train"]
376
- valid_data = train_val["test"]
377
- else:
378
- train_data = data["train"]
379
- if valid_path:
380
- # use given valid split, has priority over data_mix_in_path
381
- valid_data = data["valid"]
382
- if "prompt_type" not in train_data.column_names:
383
- train_data = train_data.add_column(
384
- "prompt_type",
385
- [prompt_type] * train_data.num_rows,
386
- )
387
- log("Added prompt type %s to training data" % prompt_type)
388
- if valid_data and "prompt_type" not in valid_data.column_names:
389
- valid_data = valid_data.add_column(
390
- "prompt_type",
391
- [prompt_type] * valid_data.num_rows,
392
- )
393
- log("Added prompt type %s to validation data" % prompt_type)
394
-
395
- assert train_data is not None
396
-
397
- generate_and_tokenize_prompt_fun = partial(generate_and_tokenize_prompt, prompt_type=prompt_type,
398
- train_on_inputs=train_on_inputs, add_eos_token=add_eos_token,
399
- cutoff_len=cutoff_len, tokenizer=tokenizer)
400
-
401
- # shuffle and tokenize data
402
- if train_data_mix_in:
403
- train_data = concatenate_datasets([train_data, train_data_mix_in])
404
- log("Tokenizing %s training rows" % train_data.num_rows)
405
- train_data = train_data.shuffle().map(generate_and_tokenize_prompt_fun,
406
- num_proc=os.cpu_count() // torch.cuda.device_count())
407
- if drop_truncations:
408
- log("avoid keeping truncated cases to avoid contaminating model with truncation cases. Original size: %s" % train_data.num_rows)
409
- prune_long_sequences_func = partial(prune_long_sequences, cutoff_len=cutoff_len)
410
- train_data = train_data.filter(prune_long_sequences_func, num_proc=os.cpu_count() // torch.cuda.device_count())
411
- log("avoid keeping truncated cases to avoid contaminating model with truncation cases. New size: %s" % train_data.num_rows)
412
- train_set_size = len(train_data)
413
-
414
- if valid_data and valid_data_mix_in:
415
- valid_data = concatenate_datasets([valid_data, valid_data_mix_in])
416
- elif valid_data_mix_in:
417
- valid_data = valid_data_mix_in
418
-
419
- if valid_data:
420
- log("Tokenizing %s validation rows" % valid_data.num_rows)
421
- valid_data = valid_data.shuffle().map(generate_and_tokenize_prompt_fun,
422
- num_proc=os.cpu_count() // torch.cuda.device_count())
423
- val_set_size = len(valid_data)
424
- else:
425
- val_set_size = 0
426
- log("Final fine-tuning data:\nTrain %s\nValid %s" % (train_data, valid_data))
427
- sample_row_dict = train_data[:1]
428
- del sample_row_dict['input_ids']
429
- del sample_row_dict['attention_mask']
430
- del sample_row_dict['labels']
431
- log("Sample input: %s" % sample_row_dict)
432
-
433
- try:
434
- import neptune
435
- from transformers.integrations import NeptuneCallback
436
-
437
- neptune_run = neptune.init_run(
438
- source_files=[],
439
- )
440
- log("Connected to Neptune.")
441
- except ImportError:
442
- neptune_run = None
443
- log("Please pip install neptune for tracking.")
444
- except neptune.exceptions.NeptuneMissingApiTokenException:
445
- neptune_run = None
446
- os.environ["NEPTUNE_MODE"] = 'debug'
447
- log("No neptune configured, set NEPTUNE_API_TOKEN env var.")
448
-
449
- if neptune_run:
450
- neptune_callback = NeptuneCallback(run=neptune_run)
451
- callbacks = [neptune_callback]
452
- else:
453
- from transformers.integrations import TensorBoardCallback, is_tensorboard_available
454
- if is_tensorboard_available:
455
- # tensorboard --logdir=runs/
456
- from torch.utils.tensorboard import SummaryWriter
457
- tb_writer = SummaryWriter()
458
- callbacks = [TensorBoardCallback(tb_writer=tb_writer)]
459
- else:
460
- callbacks = []
461
-
462
- expected_steps = (train_set_size * num_epochs) // batch_size
463
- if eval_steps is None and eval_epochs is None:
464
- # 20 evaluations for a run
465
- eval_steps = max(1, int(expected_steps / 20))
466
- log("Auto set eval_steps to %s out of %s total training steps" % (eval_steps, expected_steps))
467
- elif eval_steps is None and eval_epochs is not None:
468
- eval_steps = max(1, int(expected_steps * eval_epochs / num_epochs))
469
- log("Auto converted eval_epochs=%s to eval_steps %s"
470
- " out of %s total training steps" % (eval_epochs, eval_steps, expected_steps))
471
- if save_steps is None:
472
- save_steps = eval_steps
473
- log("Auto step save_steps to %s" % save_steps)
474
- elif save_steps > eval_steps:
475
- # save steps must be round multiple of eval_steps
476
- save_steps0 = save_steps
477
- save_steps = max(1, (save_steps // eval_steps)) * eval_steps
478
- if save_steps0 != save_steps:
479
- log("Auto converted save_steps from %s to %s" % (save_steps0, save_steps))
480
-
481
- def compute_metrics(eval_preds):
482
- # e.g. see: https://huggingface.co/docs/transformers/v4.25.1/en/tasks/translation#evaluate
483
- inputs = eval_preds.inputs
484
- label_ids = eval_preds.label_ids
485
- predictions = eval_preds.predictions
486
-
487
- # inputs = np.where(inputs != -100, inputs, tokenizer.pad_token_id)
488
- # decoded_inputs = tokenizer.batch_decode(inputs, skip_special_tokens=True)
489
- # decoded_inputs = [pred.strip() for pred in decoded_inputs]
490
-
491
- label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)
492
- # tokenizer behavior like generate time
493
- decoded_labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True,
494
- clean_up_tokenization_spaces=True)
495
- decoded_labels = [pred.strip() for pred in decoded_labels]
496
-
497
- predictions = np.argmax(predictions, -1)
498
- predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
499
- # tokenizer behavior like generate time
500
- decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True,
501
- clean_up_tokenization_spaces=True)
502
- decoded_predictions = [pred.strip() for pred in decoded_predictions]
503
-
504
- result = {}
505
- for metric in metrics.values():
506
- result1 = metric.compute(predictions=decoded_predictions, references=decoded_labels)
507
- # get rid of lists, for precision etc., for now
508
- numeric_results = {k: v for k, v in result1.items() if isinstance(v, (int, float))}
509
- result.update(numeric_results)
510
- return result
511
-
512
- # the callback that computes metrics of interest
513
- if val_metrics:
514
- trainer_kwargs = dict(compute_metrics=compute_metrics)
515
- else:
516
- trainer_kwargs = dict()
517
-
518
- import transformers
519
- trainer = transformers.Trainer(
520
- model=model,
521
- tokenizer=tokenizer,
522
- train_dataset=train_data,
523
- eval_dataset=valid_data,
524
- # FIXME: might need Seq2SeqTrainingArguments for some models
525
- args=transformers.TrainingArguments(
526
- per_device_train_batch_size=micro_batch_size,
527
- per_device_eval_batch_size=1,
528
- eval_accumulation_steps=10,
529
- # predict_with_generate=True, # SEQ2SEQ only
530
- include_inputs_for_metrics=True,
531
- gradient_accumulation_steps=gradient_accumulation_steps,
532
- warmup_steps=warmup_steps,
533
- num_train_epochs=num_epochs,
534
- learning_rate=learning_rate,
535
- gradient_checkpointing=gradient_checkpointing,
536
- fp16=fp16,
537
- # cosnider 8-bit adam: https://huggingface.co/docs/transformers/v4.18.0/en/performance#8bit-adam
538
- optim="adamw_torch", # consider "adafactor" to save memory
539
- logging_steps=logging_steps,
540
- logging_strategy="steps",
541
- evaluation_strategy="steps" if val_set_size > 0 else "no",
542
- save_strategy="steps",
543
- eval_steps=eval_steps if val_set_size > 0 else None,
544
- save_steps=save_steps,
545
- output_dir=output_dir,
546
- save_total_limit=save_total_limit,
547
- load_best_model_at_end=True if val_set_size > 0 else False,
548
- ddp_find_unused_parameters=False if ddp else None,
549
- group_by_length=group_by_length,
550
- # fsdp="shard_grad_op auto_wrap" if gpus > 1 and not ddp else None,
551
- # fsdp_min_num_params=20000 if gpus > 1 and not ddp else None,
552
- report_to='tensorboard' if not neptune_run else 'neptune',
553
- ),
554
- data_collator=transformers.DataCollatorForSeq2Seq(
555
- tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
556
- ),
557
- callbacks=callbacks,
558
- **trainer_kwargs,
559
- )
560
- model.config.use_cache = False
561
-
562
- if torch.__version__ >= "2" and sys.platform != "win32":
563
- model = torch.compile(model)
564
- # WIP (not generally replacing layers until pytorch 2.1)
565
- if not llama_flash_attn:
566
- torch.backends.cuda.enable_flash_sdp(True)
567
-
568
- if gpus > 1 and not ddp:
569
- assert trainer.is_model_parallel
570
- else:
571
- assert not trainer.is_model_parallel
572
- trainer.train(resume_from_checkpoint=resume_from_checkpoint)
573
-
574
- model.save_pretrained(output_dir)
575
-
576
- log("\n If there's a warning about missing keys above, please disregard :)")
577
-
578
-
579
- def tokenize(prompt, tokenizer, cutoff_len, add_eos_token=False):
580
- # there's probably a way to do this with the tokenizer settings
581
- # but again, gotta move fast
582
- result = tokenizer(
583
- prompt,
584
- truncation=True,
585
- max_length=cutoff_len,
586
- padding=False,
587
- return_tensors=None,
588
- )
589
- if (
590
- result["input_ids"][-1] != tokenizer.eos_token_id
591
- and len(result["input_ids"]) < cutoff_len
592
- and add_eos_token
593
- ):
594
- result["input_ids"].append(tokenizer.eos_token_id)
595
- result["attention_mask"].append(1)
596
-
597
- result["labels"] = result["input_ids"].copy()
598
-
599
- return result
600
-
601
-
602
- def prune_long_sequences(data_point, cutoff_len=None):
603
- """
604
- Prune if too long for tokenizer, so truncation doesn't lead training to learn from truncated language
605
- :param data_point:
606
- :param cutoff_len:
607
- :return:
608
- """
609
- assert cutoff_len is not None
610
- return len(data_point['input_ids']) < cutoff_len
611
-
612
-
613
- def generate_and_tokenize_prompt(data_point, prompt_type=None, train_on_inputs=False, add_eos_token=False,
614
- cutoff_len=None, tokenizer=None):
615
- assert prompt_type is not None
616
- assert cutoff_len is not None
617
- assert tokenizer is not None
618
- prompt_dict = '' # only for custom prompt_type
619
- assert prompt_type != PromptType.custom.name, "custom not setup for finetune"
620
- full_prompt, _, _, _, _ = generate_prompt(data_point, prompt_type, prompt_dict, False, False, False)
621
- tokenized_full_prompt = tokenize(full_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token)
622
- if not train_on_inputs:
623
- user_prompt, _, _, _, _ = generate_prompt({**data_point, "output": ""}, prompt_type, prompt_dict, False, False, False)
624
- tokenized_user_prompt = tokenize(user_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token)
625
- user_prompt_len = len(tokenized_user_prompt["input_ids"])
626
- if add_eos_token:
627
- user_prompt_len -= 1
628
-
629
- # ignore_index=-100 ensures torch/tf don't include padding token id in CrossEntropyLoss
630
- tokenized_full_prompt["labels"] = [
631
- -100
632
- ] * user_prompt_len + tokenized_full_prompt["labels"][
633
- user_prompt_len:
634
- ] # could be sped up, probably
635
- return tokenized_full_prompt
636
-
637
-
638
- def test_debug():
639
- fire.Fire(train)
640
-
641
-
642
- def entrypoint_main():
643
- CONFIG = "NCCL_P2P_LEVEL=LOC WORLD_SIZE=5 torchrun --nnodes=5 --master_addr=10.10.10.2 --master_port=1111 --nproc_per_node=1"
644
- CMD = "finetune.py --data_path=config.json --num_epochs=1 --base_model=decapoda-research/llama-13b-hf"
645
- log(f"""
646
- Example runs on 4 GPUs:
647
- WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='decapoda-research/llama-7b-hf' --data_path=data/config.json --run_id=0 &> 0.log
648
- WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='decapoda-research/llama-30b-hf' --data_path=data/config.json --batch_size=16 --micro_batch_size=1 --run_id=1 --save_code=True &> 1.log
649
- WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='EleutherAI/gpt-j-6B' --data_path=data/config.json --run_id=2 &> 2.log
650
- WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='EleutherAI/gpt-neox-20b' --data_path=data/config.json --run_id=8 --batch_size=16 --micro_batch_size=4 &> 8.log
651
- WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --data_path=data/config.json --prompt_type='dai_faq' --run_id=13 --batch_size=16 --micro_batch_size=4 --num_epochs=100 --val_set_size=0 data_mix_in_path='' &> 13.log
652
- WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --data_path=data/config.json --run_id=28 --batch_size=16 --micro_batch_size=4 --num_epochs=8 --val_set_size=0 --data_mix_in_factor=0.1 --data_mix_in_prompt_type='human_bot' --save_code=True --cutoff_len=512 &> 28.log
653
-
654
- All metrics:
655
- CUDA_VISIBLE_DEVICES= finetune.py --data_mix_in_factor=0 --eval_steps=100 --warmup_steps=2 --val_set_size=100 --val_metrics="['bleu', 'rouge', 'sacrebleu', 'meteor']"
656
-
657
- # Fine-tune 20B on 24GB GPUs across 3 nodes with 3+2+2 GPUs
658
- rippa>
659
- NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1,2" torchrun --node_rank 0 --nproc_per_node=3 --master_port=1234 --nnodes=3 --master_addr=10.10.10.2 finetune.py --data_path=merged_shuffled_OIG_87f6a1e788.json --micro_batch_size=1 --batch_size=7 --cutoff_len=512 --run_id=17 &>log.17.rank0
660
- ova>
661
- NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1" torchrun --node_rank 1 --nproc_per_node=2 --master_port=1234 --nnodes=3 --master_addr=10.10.10.2 finetune.py --data_path=merged_shuffled_OIG_87f6a1e788.json --micro_batch_size=1 --batch_size=7 --cutoff_len=512 --run_id=17 &>log.17.rank1
662
- timemachine>
663
- NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1" torchrun --node_rank 2 --nproc_per_node=2 --master_port=1234 --nnodes=3 --master_addr=10.10.10.2 finetune.py --data_path=merged_shuffled_OIG_87f6a1e788.json --micro_batch_size=1 --batch_size=7 --cutoff_len=512 --run_id=17 &>log.17.rank2
664
-
665
- """, flush=True)
666
-
667
- if os.environ.get("LOCAL_RANK") is None:
668
- # then not using torchrun, so can't do distributed, ensure CVD set
669
- assert os.environ.get(
670
- "CUDA_VISIBLE_DEVICES") is not None, "Run python script using: torchrun finetune.py OR set CUDA_VISIBLE_DEVICES to single GPU"
671
-
672
- fire.Fire(train)
673
-
674
-
675
- if __name__ == "__main__":
676
- entrypoint_main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
generate.py DELETED
The diff for this file is too large to render. See raw diff
 
gpt4all_llm.py CHANGED
@@ -19,6 +19,15 @@ def get_model_tokenizer_gpt4all(base_model, **kwargs):
19
  n_ctx=2048 - 256)
20
  env_gpt4all_file = ".env_gpt4all"
21
  model_kwargs.update(dotenv_values(env_gpt4all_file))
 
 
 
 
 
 
 
 
 
22
 
23
  if base_model == "llama":
24
  if 'model_path_llama' not in model_kwargs:
 
19
  n_ctx=2048 - 256)
20
  env_gpt4all_file = ".env_gpt4all"
21
  model_kwargs.update(dotenv_values(env_gpt4all_file))
22
+ # make int or float if can to satisfy types for class
23
+ for k, v in model_kwargs.items():
24
+ try:
25
+ if float(v) == int(v):
26
+ model_kwargs[k] = int(v)
27
+ else:
28
+ model_kwargs[k] = float(v)
29
+ except:
30
+ pass
31
 
32
  if base_model == "llama":
33
  if 'model_path_llama' not in model_kwargs:
gpt_langchain.py CHANGED
@@ -23,8 +23,9 @@ from langchain.callbacks import streaming_stdout
23
  from langchain.embeddings import HuggingFaceInstructEmbeddings
24
  from tqdm import tqdm
25
 
26
- from enums import DocumentChoices, no_lora_str, model_token_mapping, source_prefix, source_postfix
27
- from generate import gen_hyper, get_model, SEED
 
28
  from prompter import non_hf_types, PromptType, Prompter
29
  from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
30
  get_device, ProgressParallel, remove, hash_file, clear_torch_cache, NullContext, get_hf_server, FakeTokenizer
@@ -43,7 +44,8 @@ from langchain.chains.qa_with_sources import load_qa_with_sources_chain
43
  from langchain.document_loaders import PyPDFLoader, TextLoader, CSVLoader, PythonLoader, TomlLoader, \
44
  UnstructuredURLLoader, UnstructuredHTMLLoader, UnstructuredWordDocumentLoader, UnstructuredMarkdownLoader, \
45
  EverNoteLoader, UnstructuredEmailLoader, UnstructuredODTLoader, UnstructuredPowerPointLoader, \
46
- UnstructuredEPubLoader, UnstructuredImageLoader, UnstructuredRTFLoader, ArxivLoader, UnstructuredPDFLoader
 
47
  from langchain.text_splitter import RecursiveCharacterTextSplitter, Language
48
  from langchain.chains.question_answering import load_qa_chain
49
  from langchain.docstore.document import Document
@@ -351,6 +353,7 @@ class GradioInference(LLM):
351
  stream_output = self.stream
352
  gr_client = self.client
353
  client_langchain_mode = 'Disabled'
 
354
  top_k_docs = 1
355
  chunk = True
356
  chunk_size = 512
@@ -379,6 +382,7 @@ class GradioInference(LLM):
379
  instruction_nochat=prompt if not self.chat_client else '',
380
  iinput_nochat='', # only for chat=False
381
  langchain_mode=client_langchain_mode,
 
382
  top_k_docs=top_k_docs,
383
  chunk=chunk,
384
  chunk_size=chunk_size,
@@ -637,6 +641,7 @@ def get_llm(use_openai_model=False,
637
  callbacks = [StreamingGradioCallbackHandler()]
638
  assert prompter is not None
639
  stop_sequences = list(set(prompter.terminate_response + [prompter.PreResponse]))
 
640
 
641
  if gr_client:
642
  chat_client = False
@@ -744,7 +749,7 @@ def get_llm(use_openai_model=False,
744
 
745
  if stream_output:
746
  skip_prompt = False
747
- from generate import H2OTextIteratorStreamer
748
  decoder_kwargs = {}
749
  streamer = H2OTextIteratorStreamer(tokenizer, skip_prompt=skip_prompt, block=False, **decoder_kwargs)
750
  gen_kwargs.update(dict(streamer=streamer))
@@ -944,14 +949,16 @@ have_playwright = False
944
 
945
  image_types = ["png", "jpg", "jpeg"]
946
  non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
947
- "md", "html",
 
948
  "enex", "eml", "epub", "odt", "pptx", "ppt",
949
  "zip", "urls",
 
950
  ]
951
  # "msg", GPL3
952
 
953
  if have_libreoffice:
954
- non_image_types.extend(["docx", "doc"])
955
 
956
  file_types = non_image_types + image_types
957
 
@@ -961,7 +968,7 @@ def add_meta(docs1, file):
961
  hashid = hash_file(file)
962
  if not isinstance(docs1, (list, tuple, types.GeneratorType)):
963
  docs1 = [docs1]
964
- [x.metadata.update(dict(input_type=file_extension, date=str(datetime.now), hashid=hashid)) for x in docs1]
965
 
966
 
967
  def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
@@ -1038,6 +1045,10 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
1038
  docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
1039
  add_meta(docs1, file)
1040
  doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
 
 
 
 
1041
  elif file.lower().endswith('.odt'):
1042
  docs1 = UnstructuredODTLoader(file_path=file).load()
1043
  add_meta(docs1, file)
@@ -1758,6 +1769,8 @@ def run_qa_db(**kwargs):
1758
 
1759
 
1760
  def _run_qa_db(query=None,
 
 
1761
  use_openai_model=False, use_openai_embedding=False,
1762
  first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
1763
  user_path=None,
@@ -1787,6 +1800,7 @@ def _run_qa_db(query=None,
1787
  repetition_penalty=1.0,
1788
  num_return_sequences=1,
1789
  langchain_mode=None,
 
1790
  document_choice=[DocumentChoices.All_Relevant.name],
1791
  n_jobs=-1,
1792
  verbose=False,
@@ -1803,7 +1817,7 @@ def _run_qa_db(query=None,
1803
  :param use_openai_embedding:
1804
  :param first_para:
1805
  :param text_limit:
1806
- :param k:
1807
  :param chunk:
1808
  :param chunk_size:
1809
  :param user_path: user path to glob recursively from
@@ -1869,12 +1883,28 @@ def _run_qa_db(query=None,
1869
  sim_kwargs = {k: v for k, v in locals().items() if k in func_names}
1870
  missing_kwargs = [x for x in func_names if x not in sim_kwargs]
1871
  assert not missing_kwargs, "Missing: %s" % missing_kwargs
1872
- docs, chain, scores, use_context = get_similarity_chain(**sim_kwargs)
1873
- if cmd in [DocumentChoices.All_Relevant_Only_Sources.name, DocumentChoices.Only_All_Sources.name]:
1874
  formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
1875
  yield formatted_doc_chunks, ''
1876
  return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1877
  if chain is None and model_name not in non_hf_types:
 
1878
  # can only return if HF type
1879
  return
1880
 
@@ -1933,6 +1963,7 @@ def _run_qa_db(query=None,
1933
 
1934
 
1935
  def get_similarity_chain(query=None,
 
1936
  use_openai_model=False, use_openai_embedding=False,
1937
  first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
1938
  user_path=None,
@@ -1947,6 +1978,7 @@ def get_similarity_chain(query=None,
1947
  load_db_if_exists=False,
1948
  db=None,
1949
  langchain_mode=None,
 
1950
  document_choice=[DocumentChoices.All_Relevant.name],
1951
  n_jobs=-1,
1952
  # beyond run_db_query:
@@ -1997,25 +2029,56 @@ def get_similarity_chain(query=None,
1997
  db=db,
1998
  n_jobs=n_jobs,
1999
  verbose=verbose)
2000
-
2001
- if 'falcon' in model_name:
2002
- extra = "According to only the information in the document sources provided within the context above, "
2003
- prefix = "Pay attention and remember information below, which will help to answer the question or imperative after the context ends."
2004
- elif inference_server in ['openai', 'openai_chat']:
2005
- extra = "According to (primarily) the information in the document sources provided within context above, "
2006
- prefix = "Pay attention and remember information below, which will help to answer the question or imperative after the context ends. If the answer cannot be primarily obtained from information within the context, then respond that the answer does not appear in the context of the documents."
2007
- else:
2008
- extra = ""
2009
- prefix = ""
2010
- if langchain_mode in ['Disabled', 'ChatLLM', 'LLM'] or not use_context:
2011
- template_if_no_docs = template = """%s{context}{question}""" % prefix
2012
- else:
2013
- template = """%s
2014
- \"\"\"
2015
- {context}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2016
  \"\"\"
2017
- %s{question}""" % (prefix, extra)
2018
- template_if_no_docs = """%s{context}%s{question}""" % (prefix, extra)
 
 
 
 
 
 
 
2019
  if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
2020
  use_template = True
2021
  else:
@@ -2040,14 +2103,26 @@ def get_similarity_chain(query=None,
2040
  if cmd == DocumentChoices.Just_LLM.name:
2041
  docs = []
2042
  scores = []
2043
- elif cmd == DocumentChoices.Only_All_Sources.name:
2044
  db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs)
2045
  # similar to langchain's chroma's _results_to_docs_and_scores
2046
  docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0)
2047
- for result in zip(db_documents, db_metadatas)][:top_k_docs]
 
 
 
 
 
 
 
 
 
2048
  docs = [x[0] for x in docs_with_score]
2049
  scores = [x[1] for x in docs_with_score]
 
2050
  else:
 
 
2051
  if top_k_docs == -1 or auto_reduce_chunks:
2052
  # docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:top_k_docs]
2053
  top_k_docs_tokenize = 100
@@ -2120,6 +2195,7 @@ def get_similarity_chain(query=None,
2120
  if reverse_docs:
2121
  docs_with_score.reverse()
2122
  # cut off so no high distance docs/sources considered
 
2123
  docs = [x[0] for x in docs_with_score if x[1] < cut_distanct]
2124
  scores = [x[1] for x in docs_with_score if x[1] < cut_distanct]
2125
  if len(scores) > 0 and verbose:
@@ -2131,14 +2207,14 @@ def get_similarity_chain(query=None,
2131
 
2132
  if not docs and use_context and model_name not in non_hf_types:
2133
  # if HF type and have no docs, can bail out
2134
- return docs, None, [], False
2135
 
2136
- if cmd in [DocumentChoices.All_Relevant_Only_Sources.name, DocumentChoices.Only_All_Sources.name]:
2137
  # no LLM use
2138
- return docs, None, [], False
2139
 
2140
  common_words_file = "data/NGSL_1.2_stats.csv.zip"
2141
- if os.path.isfile(common_words_file):
2142
  df = pd.read_csv("data/NGSL_1.2_stats.csv.zip")
2143
  import string
2144
  reduced_query = query.translate(str.maketrans(string.punctuation, ' ' * len(string.punctuation))).strip()
@@ -2155,25 +2231,47 @@ def get_similarity_chain(query=None,
2155
  use_context = False
2156
  template = template_if_no_docs
2157
 
2158
- if use_template:
2159
- # instruct-like, rather than few-shot prompt_type='plain' as default
2160
- # but then sources confuse the model with how inserted among rest of text, so avoid
2161
- prompt = PromptTemplate(
2162
- # input_variables=["summaries", "question"],
2163
- input_variables=["context", "question"],
2164
- template=template,
2165
- )
2166
- chain = load_qa_chain(llm, prompt=prompt)
2167
- else:
2168
- chain = load_qa_with_sources_chain(llm)
2169
-
2170
- if not use_context:
2171
- chain_kwargs = dict(input_documents=[], question=query)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2172
  else:
2173
- chain_kwargs = dict(input_documents=docs, question=query)
2174
 
2175
- target = wrapped_partial(chain, chain_kwargs)
2176
- return docs, target, scores, use_context
2177
 
2178
 
2179
  def get_sources_answer(query, answer, scores, show_rank, answer_with_sources, verbose=False):
@@ -2243,6 +2341,11 @@ def chunk_sources(sources, chunk=True, chunk_size=512, language=None):
2243
  splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=0, keep_separator=keep_separator,
2244
  separators=separators)
2245
  source_chunks = splitter.split_documents(sources)
 
 
 
 
 
2246
  return source_chunks
2247
 
2248
 
 
23
  from langchain.embeddings import HuggingFaceInstructEmbeddings
24
  from tqdm import tqdm
25
 
26
+ from enums import DocumentChoices, no_lora_str, model_token_mapping, source_prefix, source_postfix, non_query_commands, \
27
+ LangChainAction, LangChainMode
28
+ from src.gen import gen_hyper, get_model, SEED
29
  from prompter import non_hf_types, PromptType, Prompter
30
  from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
31
  get_device, ProgressParallel, remove, hash_file, clear_torch_cache, NullContext, get_hf_server, FakeTokenizer
 
44
  from langchain.document_loaders import PyPDFLoader, TextLoader, CSVLoader, PythonLoader, TomlLoader, \
45
  UnstructuredURLLoader, UnstructuredHTMLLoader, UnstructuredWordDocumentLoader, UnstructuredMarkdownLoader, \
46
  EverNoteLoader, UnstructuredEmailLoader, UnstructuredODTLoader, UnstructuredPowerPointLoader, \
47
+ UnstructuredEPubLoader, UnstructuredImageLoader, UnstructuredRTFLoader, ArxivLoader, UnstructuredPDFLoader, \
48
+ UnstructuredExcelLoader
49
  from langchain.text_splitter import RecursiveCharacterTextSplitter, Language
50
  from langchain.chains.question_answering import load_qa_chain
51
  from langchain.docstore.document import Document
 
353
  stream_output = self.stream
354
  gr_client = self.client
355
  client_langchain_mode = 'Disabled'
356
+ client_langchain_action = LangChainAction.QUERY.value
357
  top_k_docs = 1
358
  chunk = True
359
  chunk_size = 512
 
382
  instruction_nochat=prompt if not self.chat_client else '',
383
  iinput_nochat='', # only for chat=False
384
  langchain_mode=client_langchain_mode,
385
+ langchain_action=client_langchain_action,
386
  top_k_docs=top_k_docs,
387
  chunk=chunk,
388
  chunk_size=chunk_size,
 
641
  callbacks = [StreamingGradioCallbackHandler()]
642
  assert prompter is not None
643
  stop_sequences = list(set(prompter.terminate_response + [prompter.PreResponse]))
644
+ stop_sequences = [x for x in stop_sequences if x]
645
 
646
  if gr_client:
647
  chat_client = False
 
749
 
750
  if stream_output:
751
  skip_prompt = False
752
+ from src.gen import H2OTextIteratorStreamer
753
  decoder_kwargs = {}
754
  streamer = H2OTextIteratorStreamer(tokenizer, skip_prompt=skip_prompt, block=False, **decoder_kwargs)
755
  gen_kwargs.update(dict(streamer=streamer))
 
949
 
950
  image_types = ["png", "jpg", "jpeg"]
951
  non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
952
+ "md",
953
+ "html", "mhtml",
954
  "enex", "eml", "epub", "odt", "pptx", "ppt",
955
  "zip", "urls",
956
+
957
  ]
958
  # "msg", GPL3
959
 
960
  if have_libreoffice:
961
+ non_image_types.extend(["docx", "doc", "xls", "xlsx"])
962
 
963
  file_types = non_image_types + image_types
964
 
 
968
  hashid = hash_file(file)
969
  if not isinstance(docs1, (list, tuple, types.GeneratorType)):
970
  docs1 = [docs1]
971
+ [x.metadata.update(dict(input_type=file_extension, date=str(datetime.now()), hashid=hashid)) for x in docs1]
972
 
973
 
974
  def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
 
1045
  docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
1046
  add_meta(docs1, file)
1047
  doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
1048
+ elif (file.lower().endswith('.xlsx') or file.lower().endswith('.xls')) and have_libreoffice:
1049
+ docs1 = UnstructuredExcelLoader(file_path=file).load()
1050
+ add_meta(docs1, file)
1051
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
1052
  elif file.lower().endswith('.odt'):
1053
  docs1 = UnstructuredODTLoader(file_path=file).load()
1054
  add_meta(docs1, file)
 
1769
 
1770
 
1771
  def _run_qa_db(query=None,
1772
+ iinput=None,
1773
+ context=None,
1774
  use_openai_model=False, use_openai_embedding=False,
1775
  first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
1776
  user_path=None,
 
1800
  repetition_penalty=1.0,
1801
  num_return_sequences=1,
1802
  langchain_mode=None,
1803
+ langchain_action=None,
1804
  document_choice=[DocumentChoices.All_Relevant.name],
1805
  n_jobs=-1,
1806
  verbose=False,
 
1817
  :param use_openai_embedding:
1818
  :param first_para:
1819
  :param text_limit:
1820
+ :param top_k_docs:
1821
  :param chunk:
1822
  :param chunk_size:
1823
  :param user_path: user path to glob recursively from
 
1883
  sim_kwargs = {k: v for k, v in locals().items() if k in func_names}
1884
  missing_kwargs = [x for x in func_names if x not in sim_kwargs]
1885
  assert not missing_kwargs, "Missing: %s" % missing_kwargs
1886
+ docs, chain, scores, use_context, have_any_docs = get_similarity_chain(**sim_kwargs)
1887
+ if cmd in non_query_commands:
1888
  formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
1889
  yield formatted_doc_chunks, ''
1890
  return
1891
+ if not docs and langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
1892
+ LangChainAction.SUMMARIZE_ALL.value,
1893
+ LangChainAction.SUMMARIZE_REFINE.value]:
1894
+ ret = 'No relevant documents to summarize.' if have_any_docs else 'No documents to summarize.'
1895
+ extra = ''
1896
+ yield ret, extra
1897
+ return
1898
+ if not docs and langchain_mode not in [LangChainMode.DISABLED.value,
1899
+ LangChainMode.CHAT_LLM.value,
1900
+ LangChainMode.LLM.value]:
1901
+ ret = 'No relevant documents to query.' if have_any_docs else 'No documents to query.'
1902
+ extra = ''
1903
+ yield ret, extra
1904
+ return
1905
+
1906
  if chain is None and model_name not in non_hf_types:
1907
+ # here if no docs at all and not HF type
1908
  # can only return if HF type
1909
  return
1910
 
 
1963
 
1964
 
1965
  def get_similarity_chain(query=None,
1966
+ iinput=None,
1967
  use_openai_model=False, use_openai_embedding=False,
1968
  first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
1969
  user_path=None,
 
1978
  load_db_if_exists=False,
1979
  db=None,
1980
  langchain_mode=None,
1981
+ langchain_action=None,
1982
  document_choice=[DocumentChoices.All_Relevant.name],
1983
  n_jobs=-1,
1984
  # beyond run_db_query:
 
2029
  db=db,
2030
  n_jobs=n_jobs,
2031
  verbose=verbose)
2032
+ have_any_docs = db is not None
2033
+ if langchain_action == LangChainAction.QUERY.value:
2034
+ if iinput:
2035
+ query = "%s\n%s" % (query, iinput)
2036
+
2037
+ if 'falcon' in model_name:
2038
+ extra = "According to only the information in the document sources provided within the context above, "
2039
+ prefix = "Pay attention and remember information below, which will help to answer the question or imperative after the context ends."
2040
+ elif inference_server in ['openai', 'openai_chat']:
2041
+ extra = "According to (primarily) the information in the document sources provided within context above, "
2042
+ prefix = "Pay attention and remember information below, which will help to answer the question or imperative after the context ends. If the answer cannot be primarily obtained from information within the context, then respond that the answer does not appear in the context of the documents."
2043
+ else:
2044
+ extra = ""
2045
+ prefix = ""
2046
+ if langchain_mode in ['Disabled', 'ChatLLM', 'LLM'] or not use_context:
2047
+ template_if_no_docs = template = """%s{context}{question}""" % prefix
2048
+ else:
2049
+ template = """%s
2050
+ \"\"\"
2051
+ {context}
2052
+ \"\"\"
2053
+ %s{question}""" % (prefix, extra)
2054
+ template_if_no_docs = """%s{context}%s{question}""" % (prefix, extra)
2055
+ elif langchain_action in [LangChainAction.SUMMARIZE_ALL.value, LangChainAction.SUMMARIZE_MAP.value]:
2056
+ none = ['', '\n', None]
2057
+ if query in none and iinput in none:
2058
+ prompt_summary = "Using only the text above, write a condensed and concise summary:\n"
2059
+ elif query not in none:
2060
+ prompt_summary = "Focusing on %s, write a condensed and concise Summary:\n" % query
2061
+ elif iinput not in None:
2062
+ prompt_summary = iinput
2063
+ else:
2064
+ prompt_summary = "Focusing on %s, %s:\n" % (query, iinput)
2065
+ # don't auto reduce
2066
+ auto_reduce_chunks = False
2067
+ if langchain_action == LangChainAction.SUMMARIZE_MAP.value:
2068
+ fstring = '{text}'
2069
+ else:
2070
+ fstring = '{input_documents}'
2071
+ template = """In order to write a concise single-paragraph or bulleted list summary, pay attention to the following text:
2072
  \"\"\"
2073
+ %s
2074
+ \"\"\"\n%s""" % (fstring, prompt_summary)
2075
+ template_if_no_docs = "Exactly only say: There are no documents to summarize."
2076
+ elif langchain_action in [LangChainAction.SUMMARIZE_REFINE]:
2077
+ template = '' # unused
2078
+ template_if_no_docs = '' # unused
2079
+ else:
2080
+ raise RuntimeError("No such langchain_action=%s" % langchain_action)
2081
+
2082
  if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
2083
  use_template = True
2084
  else:
 
2103
  if cmd == DocumentChoices.Just_LLM.name:
2104
  docs = []
2105
  scores = []
2106
+ elif cmd == DocumentChoices.Only_All_Sources.name or query in [None, '', '\n']:
2107
  db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs)
2108
  # similar to langchain's chroma's _results_to_docs_and_scores
2109
  docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0)
2110
+ for result in zip(db_documents, db_metadatas)]
2111
+
2112
+ # order documents
2113
+ doc_hashes = [x['doc_hash'] for x in db_metadatas]
2114
+ doc_chunk_ids = [x['chunk_id'] for x in db_metadatas]
2115
+ docs_with_score = [x for _, _, x in
2116
+ sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score), key=lambda x: (x[0], x[1]))
2117
+ ]
2118
+
2119
+ docs_with_score = docs_with_score[:top_k_docs]
2120
  docs = [x[0] for x in docs_with_score]
2121
  scores = [x[1] for x in docs_with_score]
2122
+ have_any_docs |= len(docs) > 0
2123
  else:
2124
+ # FIXME: if langchain_action == LangChainAction.SUMMARIZE_MAP.value
2125
+ # if map_reduce, then no need to auto reduce chunks
2126
  if top_k_docs == -1 or auto_reduce_chunks:
2127
  # docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:top_k_docs]
2128
  top_k_docs_tokenize = 100
 
2195
  if reverse_docs:
2196
  docs_with_score.reverse()
2197
  # cut off so no high distance docs/sources considered
2198
+ have_any_docs |= len(docs_with_score) > 0 # before cut
2199
  docs = [x[0] for x in docs_with_score if x[1] < cut_distanct]
2200
  scores = [x[1] for x in docs_with_score if x[1] < cut_distanct]
2201
  if len(scores) > 0 and verbose:
 
2207
 
2208
  if not docs and use_context and model_name not in non_hf_types:
2209
  # if HF type and have no docs, can bail out
2210
+ return docs, None, [], False, have_any_docs
2211
 
2212
+ if cmd in non_query_commands:
2213
  # no LLM use
2214
+ return docs, None, [], False, have_any_docs
2215
 
2216
  common_words_file = "data/NGSL_1.2_stats.csv.zip"
2217
+ if os.path.isfile(common_words_file) and langchain_mode == LangChainAction.QUERY.value:
2218
  df = pd.read_csv("data/NGSL_1.2_stats.csv.zip")
2219
  import string
2220
  reduced_query = query.translate(str.maketrans(string.punctuation, ' ' * len(string.punctuation))).strip()
 
2231
  use_context = False
2232
  template = template_if_no_docs
2233
 
2234
+ if langchain_action == LangChainAction.QUERY.value:
2235
+ if use_template:
2236
+ # instruct-like, rather than few-shot prompt_type='plain' as default
2237
+ # but then sources confuse the model with how inserted among rest of text, so avoid
2238
+ prompt = PromptTemplate(
2239
+ # input_variables=["summaries", "question"],
2240
+ input_variables=["context", "question"],
2241
+ template=template,
2242
+ )
2243
+ chain = load_qa_chain(llm, prompt=prompt)
2244
+ else:
2245
+ # only if use_openai_model = True, unused normally except in testing
2246
+ chain = load_qa_with_sources_chain(llm)
2247
+ if not use_context:
2248
+ chain_kwargs = dict(input_documents=[], question=query)
2249
+ else:
2250
+ chain_kwargs = dict(input_documents=docs, question=query)
2251
+ target = wrapped_partial(chain, chain_kwargs)
2252
+ elif langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
2253
+ LangChainAction.SUMMARIZE_REFINE,
2254
+ LangChainAction.SUMMARIZE_ALL.value]:
2255
+ from langchain.chains.summarize import load_summarize_chain
2256
+ if langchain_action == LangChainAction.SUMMARIZE_MAP.value:
2257
+ prompt = PromptTemplate(input_variables=["text"], template=template)
2258
+ chain = load_summarize_chain(llm, chain_type="map_reduce",
2259
+ map_prompt=prompt, combine_prompt=prompt, return_intermediate_steps=True)
2260
+ target = wrapped_partial(chain, {"input_documents": docs}) # , return_only_outputs=True)
2261
+ elif langchain_action == LangChainAction.SUMMARIZE_ALL.value:
2262
+ assert use_template
2263
+ prompt = PromptTemplate(input_variables=["text"], template=template)
2264
+ chain = load_summarize_chain(llm, chain_type="stuff", prompt=prompt, return_intermediate_steps=True)
2265
+ target = wrapped_partial(chain)
2266
+ elif langchain_action == LangChainAction.SUMMARIZE_REFINE.value:
2267
+ chain = load_summarize_chain(llm, chain_type="refine", return_intermediate_steps=True)
2268
+ target = wrapped_partial(chain)
2269
+ else:
2270
+ raise RuntimeError("No such langchain_action=%s" % langchain_action)
2271
  else:
2272
+ raise RuntimeError("No such langchain_action=%s" % langchain_action)
2273
 
2274
+ return docs, target, scores, use_context, have_any_docs
 
2275
 
2276
 
2277
  def get_sources_answer(query, answer, scores, show_rank, answer_with_sources, verbose=False):
 
2341
  splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=0, keep_separator=keep_separator,
2342
  separators=separators)
2343
  source_chunks = splitter.split_documents(sources)
2344
+
2345
+ # currently in order, but when pull from db won't be, so mark order and document by hash
2346
+ doc_hash = str(uuid.uuid4())[:10]
2347
+ [x.metadata.update(dict(doc_hash=doc_hash, chunk_id=chunk_id)) for chunk_id, x in enumerate(source_chunks)]
2348
+
2349
  return source_chunks
2350
 
2351
 
gradio_runner.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import copy
2
  import functools
3
  import inspect
@@ -49,16 +50,16 @@ def fix_pydantic_duplicate_validators_error():
49
 
50
  fix_pydantic_duplicate_validators_error()
51
 
52
- from enums import DocumentChoices, no_model_str, no_lora_str, no_server_str, LangChainMode
53
  from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js, spacing_xsm, radius_xsm, \
54
  text_xsm
55
  from prompter import prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower, non_hf_types, \
56
  get_prompt
57
  from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
58
  ping, get_short_name, get_url, makedirs, get_kwargs, remove, system_info, ping_gpu
59
- from generate import get_model, languages_covered, evaluate, eval_func_param_names, score_qa, langchain_modes, \
60
- inputs_kwargs_list, scratch_base_dir, evaluate_from_str, no_default_param_names, \
61
- eval_func_param_names_defaults, get_max_max_new_tokens, get_minmax_top_k_docs, history_to_context
62
 
63
  from apscheduler.schedulers.background import BackgroundScheduler
64
 
@@ -99,6 +100,7 @@ def go_gradio(**kwargs):
99
  dbs = kwargs['dbs']
100
  db_type = kwargs['db_type']
101
  visible_langchain_modes = kwargs['visible_langchain_modes']
 
102
  allow_upload_to_user_data = kwargs['allow_upload_to_user_data']
103
  allow_upload_to_my_data = kwargs['allow_upload_to_my_data']
104
  enable_sources_list = kwargs['enable_sources_list']
@@ -213,7 +215,28 @@ def go_gradio(**kwargs):
213
  'base_model') else no_model_msg
214
  output_label0_model2 = no_model_msg
215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  default_kwargs = {k: kwargs[k] for k in eval_func_param_names_defaults}
 
 
 
 
217
  for k in no_default_param_names:
218
  default_kwargs[k] = ''
219
 
@@ -239,7 +262,8 @@ def go_gradio(**kwargs):
239
  model_options_state = gr.State([model_options])
240
  lora_options_state = gr.State([lora_options])
241
  server_options_state = gr.State([server_options])
242
- my_db_state = gr.State([None, None])
 
243
  chat_state = gr.State({})
244
  # make user default first and default choice, dedup
245
  docs_state00 = kwargs['document_choice'] + [x.name for x in list(DocumentChoices)]
@@ -332,6 +356,12 @@ def go_gradio(**kwargs):
332
  value=kwargs['langchain_mode'],
333
  label="Data Collection of Sources",
334
  visible=kwargs['langchain_mode'] != 'Disabled')
 
 
 
 
 
 
335
  data_row2 = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled')
336
  with data_row2:
337
  with gr.Column(scale=50):
@@ -920,19 +950,59 @@ def go_gradio(**kwargs):
920
  for k in inputs_kwargs_list:
921
  assert k in kwargs_evaluate, "Missing %s" % k
922
 
923
- def evaluate_gradio(*args1, **kwargs1):
924
- for res_dict in evaluate(*args1, **kwargs1):
925
- if kwargs['langchain_mode'] == 'Disabled':
926
- yield fix_text_for_gradio(res_dict['response'])
927
- else:
928
- yield '<br>' + fix_text_for_gradio(res_dict['response'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
929
 
930
- fun = partial(evaluate_gradio,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
931
  **kwargs_evaluate)
932
- fun2 = partial(evaluate_gradio,
 
 
933
  **kwargs_evaluate)
934
- fun_with_dict_str = partial(evaluate_from_str,
935
- default_kwargs=default_kwargs,
 
936
  **kwargs_evaluate
937
  )
938
 
@@ -1072,14 +1142,17 @@ def go_gradio(**kwargs):
1072
  User that fills history for bot
1073
  :param args:
1074
  :param undo:
 
1075
  :param sanitize_user_prompt:
1076
- :param model2:
1077
  :return:
1078
  """
1079
  args_list = list(args)
1080
  user_message = args_list[eval_func_param_names.index('instruction')] # chat only
1081
  input1 = args_list[eval_func_param_names.index('iinput')] # chat only
1082
  prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
 
 
 
1083
  if not prompt_type1:
1084
  # shouldn't have to specify if CLI launched model
1085
  prompt_type1 = kwargs['prompt_type']
@@ -1110,8 +1183,12 @@ def go_gradio(**kwargs):
1110
  history[-1][1] = None
1111
  return history
1112
  if user_message1 in ['', None, '\n']:
1113
- # reject non-retry submit/enter
1114
- return history
 
 
 
 
1115
  user_message1 = fix_text_for_gradio(user_message1)
1116
  return history + [[user_message1, None]]
1117
 
@@ -1147,11 +1224,13 @@ def go_gradio(**kwargs):
1147
  else:
1148
  return 2000
1149
 
1150
- def prep_bot(*args, retry=False):
1151
  """
1152
 
1153
  :param args:
1154
  :param retry:
 
 
1155
  :return: last element is True if should run bot, False if should just yield history
1156
  """
1157
  # don't deepcopy, can contain model itself
@@ -1159,12 +1238,16 @@ def go_gradio(**kwargs):
1159
  model_state1 = args_list[-3]
1160
  my_db_state1 = args_list[-2]
1161
  history = args_list[-1]
1162
- langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
 
1163
 
1164
  if model_state1['model'] is None or model_state1['model'] == no_model_str:
1165
  return history, None, None, None
1166
 
1167
  args_list = args_list[:-3] # only keep rest needed for evaluate()
 
 
 
1168
  if not history:
1169
  print("No history", flush=True)
1170
  history = []
@@ -1175,22 +1258,23 @@ def go_gradio(**kwargs):
1175
  instruction1 = history[-1][0]
1176
  history[-1][1] = None
1177
  elif not instruction1:
1178
- # if not retrying, then reject empty query
1179
- return history, None, None, None
 
 
 
 
1180
  elif len(history) > 0 and history[-1][1] not in [None, '']:
1181
  # reject submit button if already filled and not retrying
1182
  # None when not filling with '' to keep client happy
1183
  return history, None, None, None
1184
 
1185
  # shouldn't have to specify in API prompt_type if CLI launched model, so prefer global CLI one if have it
1186
- prompt_type1 = kwargs.get('prompt_type', args_list[eval_func_param_names.index('prompt_type')])
1187
- # prefer model specific prompt type instead of global one, and apply back to args_list for evaluate()
1188
- args_list[eval_func_param_names.index('prompt_type')] = prompt_type1 = \
1189
- model_state1.get('prompt_type', prompt_type1)
1190
-
1191
- prompt_dict1 = kwargs.get('prompt_dict', args_list[eval_func_param_names.index('prompt_dict')])
1192
- args_list[eval_func_param_names.index('prompt_dict')] = prompt_dict1 = \
1193
- model_state1.get('prompt_dict', prompt_dict1)
1194
 
1195
  chat1 = args_list[eval_func_param_names.index('chat')]
1196
  model_max_length1 = get_model_max_length(model_state1)
@@ -1264,6 +1348,7 @@ def go_gradio(**kwargs):
1264
  for res in get_response(fun1, history):
1265
  yield res
1266
  finally:
 
1267
  clear_embeddings(langchain_mode1, my_db_state1)
1268
 
1269
  def all_bot(*args, retry=False, model_states1=None):
@@ -1277,7 +1362,7 @@ def go_gradio(**kwargs):
1277
  my_db_state1 = None # will be filled below by some bot
1278
  try:
1279
  gen_list = []
1280
- for chatbot1, model_state1 in zip(chatbots, model_states1):
1281
  args_list1 = args_list0.copy()
1282
  args_list1.insert(-1, model_state1) # insert at -1 so is at -2
1283
  # if at start, have None in response still, replace with '' so client etc. acts like normal
@@ -1289,7 +1374,8 @@ def go_gradio(**kwargs):
1289
  # so consistent with prep_bot()
1290
  # with model_state1 at -3, my_db_state1 at -2, and history(chatbot) at -1
1291
  # langchain_mode1 and my_db_state1 should be same for every bot
1292
- history, fun1, langchain_mode1, my_db_state1 = prep_bot(*tuple(args_list1), retry=retry)
 
1293
  gen1 = get_response(fun1, history)
1294
  if stream_output1:
1295
  gen1 = TimeoutIterator(gen1, timeout=0.01, sentinel=None, raise_on_exception=False)
@@ -1301,6 +1387,7 @@ def go_gradio(**kwargs):
1301
  tgen0 = time.time()
1302
  for res1 in itertools.zip_longest(*gen_list):
1303
  if time.time() - tgen0 > max_time1:
 
1304
  break
1305
 
1306
  bots = [x[0] if x is not None and not isinstance(x, BaseException) else y for x, y in
@@ -1735,6 +1822,9 @@ def go_gradio(**kwargs):
1735
 
1736
  def load_model(model_name, lora_weights, server_name, model_state_old, prompt_type_old, load_8bit,
1737
  infer_devices, gpu_id):
 
 
 
1738
  # ensure old model removed from GPU memory
1739
  if kwargs['debug']:
1740
  print("Pre-switch pre-del GPU memory: %s" % get_torch_allocated(), flush=True)
@@ -2161,6 +2251,15 @@ def update_user_db(file, db1, x, y, *args, dbs=None, langchain_mode='UserData',
2161
  clear_torch_cache()
2162
 
2163
 
 
 
 
 
 
 
 
 
 
2164
  def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None, langchain_mode='UserData',
2165
  user_path=None,
2166
  use_openai_embedding=None,
@@ -2222,7 +2321,8 @@ def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None,
2222
  exceptions = [x for x in sources if x.metadata.get('exception')]
2223
  sources = [x for x in sources if 'exception' not in x.metadata]
2224
 
2225
- with filelock.FileLock("db_%s.lock" % langchain_mode.replace(' ', '_')):
 
2226
  if langchain_mode == 'MyData':
2227
  if db1[0] is not None:
2228
  # then add
@@ -2235,18 +2335,14 @@ def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None,
2235
  # for production hit, when user gets clicky:
2236
  assert len(db1) == 2, "Bad MyData db: %s" % db1
2237
  # then create
2238
- # assign fresh hash for this user session, so not shared
2239
  # if added has to original state and didn't change, then would be shared db for all users
2240
- db1[1] = str(uuid.uuid4())
2241
  persist_directory = os.path.join(scratch_base_dir, 'db_dir_%s_%s' % (langchain_mode, db1[1]))
2242
  db = get_db(sources, use_openai_embedding=use_openai_embedding,
2243
  db_type=db_type,
2244
  persist_directory=persist_directory,
2245
  langchain_mode=langchain_mode,
2246
  hf_embedding_model=hf_embedding_model)
2247
- if db is None:
2248
- db1[1] = None
2249
- else:
2250
  db1[0] = db
2251
  source_files_added = get_source_files(db=db1[0], exceptions=exceptions)
2252
  return None, langchain_mode, db1, x, y, source_files_added
@@ -2274,7 +2370,9 @@ def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None,
2274
 
2275
 
2276
  def get_db(db1, langchain_mode, dbs=None):
2277
- with filelock.FileLock("db_%s.lock" % langchain_mode.replace(' ', '_')):
 
 
2278
  if langchain_mode in ['wiki_full']:
2279
  # NOTE: avoid showing full wiki. Takes about 30 seconds over about 90k entries, but not useful for now
2280
  db = None
 
1
+ import ast
2
  import copy
3
  import functools
4
  import inspect
 
50
 
51
  fix_pydantic_duplicate_validators_error()
52
 
53
+ from enums import DocumentChoices, no_model_str, no_lora_str, no_server_str, LangChainAction, LangChainMode
54
  from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js, spacing_xsm, radius_xsm, \
55
  text_xsm
56
  from prompter import prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower, non_hf_types, \
57
  get_prompt
58
  from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
59
  ping, get_short_name, get_url, makedirs, get_kwargs, remove, system_info, ping_gpu
60
+ from src.gen import get_model, languages_covered, evaluate, eval_func_param_names, score_qa, langchain_modes, \
61
+ inputs_kwargs_list, scratch_base_dir, no_default_param_names, \
62
+ eval_func_param_names_defaults, get_max_max_new_tokens, get_minmax_top_k_docs, history_to_context, langchain_actions
63
 
64
  from apscheduler.schedulers.background import BackgroundScheduler
65
 
 
100
  dbs = kwargs['dbs']
101
  db_type = kwargs['db_type']
102
  visible_langchain_modes = kwargs['visible_langchain_modes']
103
+ visible_langchain_actions = kwargs['visible_langchain_actions']
104
  allow_upload_to_user_data = kwargs['allow_upload_to_user_data']
105
  allow_upload_to_my_data = kwargs['allow_upload_to_my_data']
106
  enable_sources_list = kwargs['enable_sources_list']
 
215
  'base_model') else no_model_msg
216
  output_label0_model2 = no_model_msg
217
 
218
+ def update_prompt(prompt_type1, prompt_dict1, model_state1, which_model=0):
219
+ if not prompt_type1 or which_model != 0:
220
+ # keep prompt_type and prompt_dict in sync if possible
221
+ prompt_type1 = kwargs.get('prompt_type', prompt_type1)
222
+ prompt_dict1 = kwargs.get('prompt_dict', prompt_dict1)
223
+ # prefer model specific prompt type instead of global one
224
+ if not prompt_type1 or which_model != 0:
225
+ prompt_type1 = model_state1.get('prompt_type', prompt_type1)
226
+ prompt_dict1 = model_state1.get('prompt_dict', prompt_dict1)
227
+
228
+ if not prompt_dict1 or which_model != 0:
229
+ # if still not defined, try to get
230
+ prompt_dict1 = kwargs.get('prompt_dict', prompt_dict1)
231
+ if not prompt_dict1 or which_model != 0:
232
+ prompt_dict1 = model_state1.get('prompt_dict', prompt_dict1)
233
+ return prompt_type1, prompt_dict1
234
+
235
  default_kwargs = {k: kwargs[k] for k in eval_func_param_names_defaults}
236
+ # ensure prompt_type consistent with prep_bot(), so nochat API works same way
237
+ default_kwargs['prompt_type'], default_kwargs['prompt_dict'] = \
238
+ update_prompt(default_kwargs['prompt_type'], default_kwargs['prompt_dict'],
239
+ model_state1=model_state0, which_model=0)
240
  for k in no_default_param_names:
241
  default_kwargs[k] = ''
242
 
 
262
  model_options_state = gr.State([model_options])
263
  lora_options_state = gr.State([lora_options])
264
  server_options_state = gr.State([server_options])
265
+ # uuid in db is used as user ID
266
+ my_db_state = gr.State([None, str(uuid.uuid4())])
267
  chat_state = gr.State({})
268
  # make user default first and default choice, dedup
269
  docs_state00 = kwargs['document_choice'] + [x.name for x in list(DocumentChoices)]
 
356
  value=kwargs['langchain_mode'],
357
  label="Data Collection of Sources",
358
  visible=kwargs['langchain_mode'] != 'Disabled')
359
+ allowed_actions = [x for x in langchain_actions if x in visible_langchain_actions]
360
+ langchain_action = gr.Radio(
361
+ allowed_actions,
362
+ value=allowed_actions[0] if len(allowed_actions) > 0 else None,
363
+ label="Data Action",
364
+ visible=True)
365
  data_row2 = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled')
366
  with data_row2:
367
  with gr.Column(scale=50):
 
950
  for k in inputs_kwargs_list:
951
  assert k in kwargs_evaluate, "Missing %s" % k
952
 
953
+ def evaluate_nochat(*args1, default_kwargs1=None, str_api=False, **kwargs1):
954
+ args_list = list(args1)
955
+ if str_api:
956
+ user_kwargs = args_list[2]
957
+ assert isinstance(user_kwargs, str)
958
+ user_kwargs = ast.literal_eval(user_kwargs)
959
+ else:
960
+ user_kwargs = {k: v for k, v in zip(eval_func_param_names, args_list[2:])}
961
+ # only used for submit_nochat_api
962
+ user_kwargs['chat'] = False
963
+ if 'stream_output' not in user_kwargs:
964
+ user_kwargs['stream_output'] = False
965
+ if 'langchain_mode' not in user_kwargs:
966
+ # if user doesn't specify, then assume disabled, not use default
967
+ user_kwargs['langchain_mode'] = 'Disabled'
968
+ if 'langchain_action' not in user_kwargs:
969
+ user_kwargs['langchain_action'] = LangChainAction.QUERY.value
970
+
971
+ set1 = set(list(default_kwargs1.keys()))
972
+ set2 = set(eval_func_param_names)
973
+ assert set1 == set2, "Set diff: %s %s: %s" % (set1, set2, set1.symmetric_difference(set2))
974
+ # correct ordering. Note some things may not be in default_kwargs, so can't be default of user_kwargs.get()
975
+ model_state1 = args_list[0]
976
+ my_db_state1 = args_list[1]
977
+ args_list = [user_kwargs[k] if k in user_kwargs and user_kwargs[k] is not None else default_kwargs1[k] for k
978
+ in eval_func_param_names]
979
+ assert len(args_list) == len(eval_func_param_names)
980
+ args_list = [model_state1, my_db_state1] + args_list
981
 
982
+ try:
983
+ for res_dict in evaluate(*tuple(args_list), **kwargs1):
984
+ if str_api:
985
+ # full return of dict
986
+ yield res_dict
987
+ elif kwargs['langchain_mode'] == 'Disabled':
988
+ yield fix_text_for_gradio(res_dict['response'])
989
+ else:
990
+ yield '<br>' + fix_text_for_gradio(res_dict['response'])
991
+ finally:
992
+ clear_torch_cache()
993
+ clear_embeddings(user_kwargs['langchain_mode'], my_db_state1)
994
+
995
+ fun = partial(evaluate_nochat,
996
+ default_kwargs1=default_kwargs,
997
+ str_api=False,
998
  **kwargs_evaluate)
999
+ fun2 = partial(evaluate_nochat,
1000
+ default_kwargs1=default_kwargs,
1001
+ str_api=False,
1002
  **kwargs_evaluate)
1003
+ fun_with_dict_str = partial(evaluate_nochat,
1004
+ default_kwargs1=default_kwargs,
1005
+ str_api=True,
1006
  **kwargs_evaluate
1007
  )
1008
 
 
1142
  User that fills history for bot
1143
  :param args:
1144
  :param undo:
1145
+ :param retry:
1146
  :param sanitize_user_prompt:
 
1147
  :return:
1148
  """
1149
  args_list = list(args)
1150
  user_message = args_list[eval_func_param_names.index('instruction')] # chat only
1151
  input1 = args_list[eval_func_param_names.index('iinput')] # chat only
1152
  prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
1153
+ langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
1154
+ langchain_action1 = args_list[eval_func_param_names.index('langchain_action')]
1155
+ document_choice1 = args_list[eval_func_param_names.index('document_choice')]
1156
  if not prompt_type1:
1157
  # shouldn't have to specify if CLI launched model
1158
  prompt_type1 = kwargs['prompt_type']
 
1183
  history[-1][1] = None
1184
  return history
1185
  if user_message1 in ['', None, '\n']:
1186
+ if langchain_action1 in LangChainAction.QUERY.value and \
1187
+ DocumentChoices.Only_All_Sources.name not in document_choice1 \
1188
+ or \
1189
+ langchain_mode1 in [LangChainMode.CHAT_LLM.value, LangChainMode.LLM.value]:
1190
+ # reject non-retry submit/enter
1191
+ return history
1192
  user_message1 = fix_text_for_gradio(user_message1)
1193
  return history + [[user_message1, None]]
1194
 
 
1224
  else:
1225
  return 2000
1226
 
1227
+ def prep_bot(*args, retry=False, which_model=0):
1228
  """
1229
 
1230
  :param args:
1231
  :param retry:
1232
+ :param which_model: identifies which model if doing model_lock
1233
+ API only called for which_model=0, default for inputs_list, but rest should ignore inputs_list
1234
  :return: last element is True if should run bot, False if should just yield history
1235
  """
1236
  # don't deepcopy, can contain model itself
 
1238
  model_state1 = args_list[-3]
1239
  my_db_state1 = args_list[-2]
1240
  history = args_list[-1]
1241
+ prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
1242
+ prompt_dict1 = args_list[eval_func_param_names.index('prompt_dict')]
1243
 
1244
  if model_state1['model'] is None or model_state1['model'] == no_model_str:
1245
  return history, None, None, None
1246
 
1247
  args_list = args_list[:-3] # only keep rest needed for evaluate()
1248
+ langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
1249
+ langchain_action1 = args_list[eval_func_param_names.index('langchain_action')]
1250
+ document_choice1 = args_list[eval_func_param_names.index('document_choice')]
1251
  if not history:
1252
  print("No history", flush=True)
1253
  history = []
 
1258
  instruction1 = history[-1][0]
1259
  history[-1][1] = None
1260
  elif not instruction1:
1261
+ if langchain_action1 in LangChainAction.QUERY.value and \
1262
+ DocumentChoices.Only_All_Sources.name not in document_choice1 \
1263
+ or \
1264
+ langchain_mode1 in [LangChainMode.CHAT_LLM.value, LangChainMode.LLM.value]:
1265
+ # if not retrying, then reject empty query
1266
+ return history, None, None, None
1267
  elif len(history) > 0 and history[-1][1] not in [None, '']:
1268
  # reject submit button if already filled and not retrying
1269
  # None when not filling with '' to keep client happy
1270
  return history, None, None, None
1271
 
1272
  # shouldn't have to specify in API prompt_type if CLI launched model, so prefer global CLI one if have it
1273
+ prompt_type1, prompt_dict1 = update_prompt(prompt_type1, prompt_dict1, model_state1,
1274
+ which_model=which_model)
1275
+ # apply back to args_list for evaluate()
1276
+ args_list[eval_func_param_names.index('prompt_type')] = prompt_type1
1277
+ args_list[eval_func_param_names.index('prompt_dict')] = prompt_dict1
 
 
 
1278
 
1279
  chat1 = args_list[eval_func_param_names.index('chat')]
1280
  model_max_length1 = get_model_max_length(model_state1)
 
1348
  for res in get_response(fun1, history):
1349
  yield res
1350
  finally:
1351
+ clear_torch_cache()
1352
  clear_embeddings(langchain_mode1, my_db_state1)
1353
 
1354
  def all_bot(*args, retry=False, model_states1=None):
 
1362
  my_db_state1 = None # will be filled below by some bot
1363
  try:
1364
  gen_list = []
1365
+ for chatboti, (chatbot1, model_state1) in enumerate(zip(chatbots, model_states1)):
1366
  args_list1 = args_list0.copy()
1367
  args_list1.insert(-1, model_state1) # insert at -1 so is at -2
1368
  # if at start, have None in response still, replace with '' so client etc. acts like normal
 
1374
  # so consistent with prep_bot()
1375
  # with model_state1 at -3, my_db_state1 at -2, and history(chatbot) at -1
1376
  # langchain_mode1 and my_db_state1 should be same for every bot
1377
+ history, fun1, langchain_mode1, my_db_state1 = prep_bot(*tuple(args_list1), retry=retry,
1378
+ which_model=chatboti)
1379
  gen1 = get_response(fun1, history)
1380
  if stream_output1:
1381
  gen1 = TimeoutIterator(gen1, timeout=0.01, sentinel=None, raise_on_exception=False)
 
1387
  tgen0 = time.time()
1388
  for res1 in itertools.zip_longest(*gen_list):
1389
  if time.time() - tgen0 > max_time1:
1390
+ print("Took too long: %s" % max_time1, flush=True)
1391
  break
1392
 
1393
  bots = [x[0] if x is not None and not isinstance(x, BaseException) else y for x, y in
 
1822
 
1823
  def load_model(model_name, lora_weights, server_name, model_state_old, prompt_type_old, load_8bit,
1824
  infer_devices, gpu_id):
1825
+ # ensure no API calls reach here
1826
+ if is_public:
1827
+ raise RuntimeError("Illegal access for %s" % model_name)
1828
  # ensure old model removed from GPU memory
1829
  if kwargs['debug']:
1830
  print("Pre-switch pre-del GPU memory: %s" % get_torch_allocated(), flush=True)
 
2251
  clear_torch_cache()
2252
 
2253
 
2254
+ def get_lock_file(db1, langchain_mode):
2255
+ assert len(db1) == 2 and db1[1] is not None and isinstance(db1[1], str)
2256
+ user_id = db1[1]
2257
+ base_path = 'locks'
2258
+ makedirs(base_path)
2259
+ lock_file = "db_%s_%s.lock" % (langchain_mode.replace(' ', '_'), user_id)
2260
+ return lock_file
2261
+
2262
+
2263
  def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None, langchain_mode='UserData',
2264
  user_path=None,
2265
  use_openai_embedding=None,
 
2321
  exceptions = [x for x in sources if x.metadata.get('exception')]
2322
  sources = [x for x in sources if 'exception' not in x.metadata]
2323
 
2324
+ lock_file = get_lock_file(db1, langchain_mode)
2325
+ with filelock.FileLock(lock_file):
2326
  if langchain_mode == 'MyData':
2327
  if db1[0] is not None:
2328
  # then add
 
2335
  # for production hit, when user gets clicky:
2336
  assert len(db1) == 2, "Bad MyData db: %s" % db1
2337
  # then create
 
2338
  # if added has to original state and didn't change, then would be shared db for all users
 
2339
  persist_directory = os.path.join(scratch_base_dir, 'db_dir_%s_%s' % (langchain_mode, db1[1]))
2340
  db = get_db(sources, use_openai_embedding=use_openai_embedding,
2341
  db_type=db_type,
2342
  persist_directory=persist_directory,
2343
  langchain_mode=langchain_mode,
2344
  hf_embedding_model=hf_embedding_model)
2345
+ if db is not None:
 
 
2346
  db1[0] = db
2347
  source_files_added = get_source_files(db=db1[0], exceptions=exceptions)
2348
  return None, langchain_mode, db1, x, y, source_files_added
 
2370
 
2371
 
2372
  def get_db(db1, langchain_mode, dbs=None):
2373
+ lock_file = get_lock_file(db1, langchain_mode)
2374
+
2375
+ with filelock.FileLock(lock_file):
2376
  if langchain_mode in ['wiki_full']:
2377
  # NOTE: avoid showing full wiki. Takes about 30 seconds over about 90k entries, but not useful for now
2378
  db = None
gradio_utils/__pycache__/css.cpython-310.pyc DELETED
Binary file (1.53 kB)
 
gradio_utils/__pycache__/grclient.cpython-310.pyc DELETED
Binary file (2.69 kB)
 
gradio_utils/__pycache__/prompt_form.cpython-310.pyc DELETED
Binary file (3.59 kB)
 
gradio_utils/css.py DELETED
@@ -1,53 +0,0 @@
1
- def get_css(kwargs) -> str:
2
- if kwargs['h2ocolors']:
3
- css_code = """footer {visibility: hidden;}
4
- body{background:linear-gradient(#f5f5f5,#e5e5e5);}
5
- body.dark{background:linear-gradient(#000000,#0d0d0d);}
6
- """
7
- else:
8
- css_code = """footer {visibility: hidden}"""
9
-
10
- css_code += make_css_base()
11
- return css_code
12
-
13
-
14
- def make_css_base() -> str:
15
- return """
16
- @import url('https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600&display=swap');
17
-
18
- body.dark{#warning {background-color: #555555};}
19
-
20
- #small_btn {
21
- margin: 0.6em 0em 0.55em 0;
22
- max-width: 20em;
23
- min-width: 5em !important;
24
- height: 5em;
25
- font-size: 14px !important;
26
- }
27
-
28
- #prompt-form {
29
- border: 1px solid var(--primary-500) !important;
30
- }
31
-
32
- #prompt-form.block {
33
- border-radius: var(--block-radius) !important;
34
- }
35
-
36
- #prompt-form textarea {
37
- border: 1px solid rgb(209, 213, 219);
38
- }
39
-
40
- #prompt-form label > div {
41
- margin-top: 4px;
42
- }
43
-
44
- button.primary:hover {
45
- background-color: var(--primary-600) !important;
46
- transition: .2s;
47
- }
48
-
49
- #prompt-form-area {
50
- margin-bottom: 2.5rem;
51
- }
52
- .chatsmall chatbot {font-size: 10px !important}
53
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_utils/grclient.py DELETED
@@ -1,82 +0,0 @@
1
- import traceback
2
- from typing import Callable
3
- import os
4
-
5
- from gradio_client.client import Job
6
-
7
- os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
8
-
9
- from gradio_client import Client
10
-
11
-
12
- class GradioClient(Client):
13
- """
14
- Parent class of gradio client
15
- To handle automatically refreshing client if detect gradio server changed
16
- """
17
-
18
- def __init__(self, *args, **kwargs):
19
- self.args = args
20
- self.kwargs = kwargs
21
- super().__init__(*args, **kwargs)
22
- self.server_hash = self.get_server_hash()
23
-
24
- def get_server_hash(self):
25
- """
26
- Get server hash using super without any refresh action triggered
27
- Returns: git hash of gradio server
28
- """
29
- return super().submit(api_name='/system_hash').result()
30
-
31
- def refresh_client_if_should(self):
32
- # get current hash in order to update api_name -> fn_index map in case gradio server changed
33
- # FIXME: Could add cli api as hash
34
- server_hash = self.get_server_hash()
35
- if self.server_hash != server_hash:
36
- self.refresh_client()
37
- self.server_hash = server_hash
38
- else:
39
- self.reset_session()
40
-
41
- def refresh_client(self):
42
- """
43
- Ensure every client call is independent
44
- Also ensure map between api_name and fn_index is updated in case server changed (e.g. restarted with new code)
45
- Returns:
46
- """
47
- # need session hash to be new every time, to avoid "generator already executing"
48
- self.reset_session()
49
-
50
- client = Client(*self.args, **self.kwargs)
51
- for k, v in client.__dict__.items():
52
- setattr(self, k, v)
53
-
54
- def submit(
55
- self,
56
- *args,
57
- api_name: str | None = None,
58
- fn_index: int | None = None,
59
- result_callbacks: Callable | list[Callable] | None = None,
60
- ) -> Job:
61
- # Note predict calls submit
62
- try:
63
- self.refresh_client_if_should()
64
- job = super().submit(*args, api_name=api_name, fn_index=fn_index)
65
- except Exception as e:
66
- print("Hit e=%s" % str(e), flush=True)
67
- # force reconfig in case only that
68
- self.refresh_client()
69
- job = super().submit(*args, api_name=api_name, fn_index=fn_index)
70
-
71
- # see if immediately failed
72
- e = job.future._exception
73
- if e is not None:
74
- print("GR job failed: %s %s" % (str(e), ''.join(traceback.format_tb(e.__traceback__))), flush=True)
75
- # force reconfig in case only that
76
- self.refresh_client()
77
- job = super().submit(*args, api_name=api_name, fn_index=fn_index)
78
- e2 = job.future._exception
79
- if e2 is not None:
80
- print("GR job failed again: %s\n%s" % (str(e2), ''.join(traceback.format_tb(e2.__traceback__))), flush=True)
81
-
82
- return job
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_utils/prompt_form.py DELETED
@@ -1,118 +0,0 @@
1
- import os
2
- import math
3
-
4
- import gradio as gr
5
-
6
-
7
- def make_chatbots(output_label0, output_label0_model2, **kwargs):
8
- text_outputs = []
9
- chat_kwargs = []
10
- for model_state_lock in kwargs['model_states']:
11
- if os.environ.get('DEBUG_MODEL_LOCK'):
12
- model_name = model_state_lock["base_model"] + " : " + model_state_lock["inference_server"]
13
- else:
14
- model_name = model_state_lock["base_model"]
15
- output_label = f'h2oGPT [{model_name}]'
16
- min_width = 250 if kwargs['gradio_size'] in ['small', 'large', 'medium'] else 160
17
- chat_kwargs.append(dict(label=output_label, visible=kwargs['model_lock'], elem_classes='chatsmall',
18
- height=kwargs['height'] or 400, min_width=min_width))
19
-
20
- if kwargs['model_lock_columns'] == -1:
21
- kwargs['model_lock_columns'] = len(kwargs['model_states'])
22
- if kwargs['model_lock_columns'] is None:
23
- kwargs['model_lock_columns'] = 3
24
-
25
- ncols = kwargs['model_lock_columns']
26
- if kwargs['model_states'] == 0:
27
- nrows = 0
28
- else:
29
- nrows = math.ceil(len(kwargs['model_states']) / kwargs['model_lock_columns'])
30
-
31
- if kwargs['model_lock_columns'] == 0:
32
- # not using model_lock
33
- pass
34
- elif nrows <= 1:
35
- with gr.Row():
36
- for chat_kwargs1, model_state_lock in zip(chat_kwargs, kwargs['model_states']):
37
- text_outputs.append(gr.Chatbot(**chat_kwargs1))
38
- elif nrows == kwargs['model_states']:
39
- with gr.Row():
40
- for chat_kwargs1, model_state_lock in zip(chat_kwargs, kwargs['model_states']):
41
- text_outputs.append(gr.Chatbot(**chat_kwargs1))
42
- elif nrows == 2:
43
- with gr.Row():
44
- for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
45
- if mii >= len(kwargs['model_states']) / 2:
46
- continue
47
- text_outputs.append(gr.Chatbot(**chat_kwargs1))
48
- with gr.Row():
49
- for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
50
- if mii < len(kwargs['model_states']) / 2:
51
- continue
52
- text_outputs.append(gr.Chatbot(**chat_kwargs1))
53
- elif nrows == 3:
54
- with gr.Row():
55
- for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
56
- if mii >= 1 * len(kwargs['model_states']) / 3:
57
- continue
58
- text_outputs.append(gr.Chatbot(**chat_kwargs1))
59
- with gr.Row():
60
- for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
61
- if mii < 1 * len(kwargs['model_states']) / 3 or mii >= 2 * len(kwargs['model_states']) / 3:
62
- continue
63
- text_outputs.append(gr.Chatbot(**chat_kwargs1))
64
- with gr.Row():
65
- for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
66
- if mii < 2 * len(kwargs['model_states']) / 3:
67
- continue
68
- text_outputs.append(gr.Chatbot(**chat_kwargs1))
69
- elif nrows >= 4:
70
- with gr.Row():
71
- for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
72
- if mii >= 1 * len(kwargs['model_states']) / 4:
73
- continue
74
- text_outputs.append(gr.Chatbot(**chat_kwargs1))
75
- with gr.Row():
76
- for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
77
- if mii < 1 * len(kwargs['model_states']) / 4 or mii >= 2 * len(kwargs['model_states']) / 4:
78
- continue
79
- text_outputs.append(gr.Chatbot(**chat_kwargs1))
80
- with gr.Row():
81
- for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
82
- if mii < 2 * len(kwargs['model_states']) / 4 or mii >= 3 * len(kwargs['model_states']) / 4:
83
- continue
84
- text_outputs.append(gr.Chatbot(**chat_kwargs1))
85
- with gr.Row():
86
- for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
87
- if mii < 3 * len(kwargs['model_states']) / 4:
88
- continue
89
- text_outputs.append(gr.Chatbot(**chat_kwargs1))
90
-
91
- with gr.Row():
92
- text_output = gr.Chatbot(label=output_label0, visible=not kwargs['model_lock'], height=kwargs['height'] or 400)
93
- text_output2 = gr.Chatbot(label=output_label0_model2,
94
- visible=False and not kwargs['model_lock'], height=kwargs['height'] or 400)
95
- return text_output, text_output2, text_outputs
96
-
97
-
98
- def make_prompt_form(kwargs):
99
- if kwargs['input_lines'] > 1:
100
- instruction_label = "Shift-Enter to Submit, Enter for more lines"
101
- else:
102
- instruction_label = "Enter to Submit, Shift-Enter for more lines"
103
-
104
- with gr.Row():#elem_id='prompt-form-area'):
105
- with gr.Column(scale=50):
106
- instruction = gr.Textbox(
107
- lines=kwargs['input_lines'],
108
- label='Ask anything',
109
- placeholder=instruction_label,
110
- info=None,
111
- elem_id='prompt-form',
112
- container=True,
113
- )
114
- with gr.Row():
115
- submit = gr.Button(value='Submit', variant='primary', scale=0, size='sm')
116
- stop_btn = gr.Button(value="Stop", variant='secondary', scale=0, size='sm')
117
-
118
- return instruction, submit, stop_btn
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
h2o-logo.svg DELETED
h2oai_pipeline.py CHANGED
@@ -136,6 +136,7 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
136
  else:
137
  outputs = rec['generated_text']
138
  rec['generated_text'] = outputs
 
139
  return records
140
 
141
  def _forward(self, model_inputs, **generate_kwargs):
 
136
  else:
137
  outputs = rec['generated_text']
138
  rec['generated_text'] = outputs
139
+ print("prompt: %s\noutputs: %s\n\n" % (self.prompt_text, outputs), flush=True)
140
  return records
141
 
142
  def _forward(self, model_inputs, **generate_kwargs):
iterators/__init__.py DELETED
@@ -1,4 +0,0 @@
1
- from .timeout_iterator import TimeoutIterator, AsyncTimeoutIterator
2
- from .iterator_pipe import IteratorPipe, AsyncIteratorPipe
3
-
4
- __all__ = ["TimeoutIterator", "AsyncTimeoutIterator", "IteratorPipe", "AsyncIteratorPipe"]
 
 
 
 
 
iterators/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (337 Bytes)
 
iterators/__pycache__/iterator_pipe.cpython-310.pyc DELETED
Binary file (2.71 kB)
 
iterators/__pycache__/timeout_iterator.cpython-310.pyc DELETED
Binary file (5.63 kB)
 
iterators/iterator_pipe.py DELETED
@@ -1,93 +0,0 @@
1
- import queue
2
- import asyncio
3
-
4
-
5
- class IteratorPipe:
6
- """
7
- Iterator Pipe creates an iterator that can be fed in data from another block of code or thread of execution
8
- """
9
-
10
- def __init__(self, sentinel=object()):
11
- self._q = queue.Queue()
12
- self._sentinel = sentinel
13
- self._sentinel_pushed = False
14
- self._closed = False
15
-
16
- def __iter__(self):
17
- return self
18
-
19
- def __next__(self):
20
- if self._closed:
21
- raise StopIteration
22
-
23
- data = self._q.get(block=True)
24
- if data is self._sentinel:
25
- self._closed = True
26
- raise StopIteration
27
-
28
- return data
29
-
30
- def put(self, data) -> bool:
31
- """
32
- Pushes next item to Iterator and returns True
33
- If iterator has been closed via close(), doesn't push anything and returns False
34
- """
35
- if self._sentinel_pushed:
36
- return False
37
-
38
- self._q.put(data)
39
- return True
40
-
41
- def close(self):
42
- """
43
- Close is idempotent. Calling close multiple times is safe
44
- Iterator will raise StopIteration only after all elements pushed before close have been iterated
45
- """
46
- # make close idempotent
47
- if not self._sentinel_pushed:
48
- self._sentinel_pushed = True
49
- self._q.put(self._sentinel)
50
-
51
-
52
- class AsyncIteratorPipe:
53
-
54
- def __init__(self, sentinel=object()):
55
- self._q = asyncio.Queue()
56
- self._sentinel = sentinel
57
- self._sentinel_pushed = False
58
- self._closed = False
59
-
60
- def __aiter__(self):
61
- return self
62
-
63
- async def __anext__(self):
64
- if self._closed:
65
- raise StopAsyncIteration
66
-
67
- data = await self._q.get()
68
- if data is self._sentinel:
69
- self._closed = True
70
- raise StopAsyncIteration
71
-
72
- return data
73
-
74
- async def put(self, data) -> bool:
75
- """
76
- Pushes next item to Iterator and returns True
77
- If iterator has been closed via close(), doesn't push anything and returns False
78
- """
79
- if self._sentinel_pushed:
80
- return False
81
-
82
- await self._q.put(data)
83
- return True
84
-
85
- async def close(self):
86
- """
87
- Close is idempotent. Calling close multiple times is safe
88
- Iterator will raise StopIteration only after all elements pushed before close have been iterated
89
- """
90
- # make close idempotent
91
- if not self._sentinel_pushed:
92
- self._sentinel_pushed = True
93
- await self._q.put(self._sentinel)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
iterators/timeout_iterator.py DELETED
@@ -1,170 +0,0 @@
1
- import queue
2
- import asyncio
3
- import threading
4
- import traceback
5
-
6
-
7
- class TimeoutIterator:
8
- """
9
- Wrapper class to add timeout feature to synchronous iterators
10
- - timeout: timeout for next(). Default=ZERO_TIMEOUT i.e. no timeout or blocking calls to next. Updated using set_timeout()
11
- - sentinel: the object returned by iterator when timeout happens
12
- - reset_on_next: if set to True, timeout is reset to the value of ZERO_TIMEOUT on each iteration
13
-
14
- TimeoutIterator uses a thread internally.
15
- The thread stops once the iterator exhausts or raises an exception during iteration.
16
-
17
- Any exceptions raised within the wrapped iterator are propagated as it is.
18
- Exception is raised when all elements generated by the actual iterator before exception have been consumed
19
- Timeout can be set dynamically before going for iteration
20
- """
21
- ZERO_TIMEOUT = 0.0
22
-
23
- def __init__(self, iterator, timeout=0.0, sentinel=object(), reset_on_next=False, raise_on_exception=True):
24
- self._iterator = iterator
25
- self._timeout = timeout
26
- self._sentinel = sentinel
27
- self._reset_on_next = reset_on_next
28
- self._raise_on_exception = raise_on_exception
29
-
30
- self._interrupt = False
31
- self._done = False
32
- self._buffer = queue.Queue()
33
- self._thread = threading.Thread(target=self.__lookahead)
34
- self._thread.start()
35
-
36
- def get_sentinel(self):
37
- return self._sentinel
38
-
39
- def set_reset_on_next(self, reset_on_next):
40
- self._reset_on_next = reset_on_next
41
-
42
- def set_timeout(self, timeout: float):
43
- """
44
- Set timeout for next iteration
45
- """
46
- self._timeout = timeout
47
-
48
- def interrupt(self):
49
- """
50
- interrupt and stop the underlying thread.
51
- the thread acutally dies only after interrupt has been set and
52
- the underlying iterator yields a value after that.
53
- """
54
- self._interrupt = True
55
-
56
- def __iter__(self):
57
- return self
58
-
59
- def __next__(self):
60
- """
61
- yield the result from iterator
62
- if timeout > 0:
63
- yield data if available.
64
- otherwise yield sentinal
65
- """
66
- if self._done:
67
- raise StopIteration
68
-
69
- data = self._sentinel
70
- try:
71
- if self._timeout > self.ZERO_TIMEOUT:
72
- data = self._buffer.get(timeout=self._timeout)
73
- else:
74
- data = self._buffer.get()
75
- except queue.Empty:
76
- pass
77
- finally:
78
- # see if timeout needs to be reset
79
- if self._reset_on_next:
80
- self._timeout = self.ZERO_TIMEOUT
81
-
82
- # propagate any exceptions including StopIteration
83
- if isinstance(data, BaseException):
84
- self._done = True
85
- if isinstance(data, StopIteration):
86
- raise data
87
- ex = ''.join(traceback.format_tb(data.__traceback__))
88
- print("Generation Failed: %s %s" % (str(data), str(ex)), flush=True)
89
- if self._raise_on_exception:
90
- raise data
91
- else:
92
- return data
93
-
94
- return data
95
-
96
- def __lookahead(self):
97
- try:
98
- while True:
99
- self._buffer.put(next(self._iterator))
100
- if self._interrupt:
101
- raise StopIteration()
102
- except BaseException as e:
103
- self._buffer.put(e)
104
-
105
-
106
- class AsyncTimeoutIterator:
107
- """
108
- Async version of TimeoutIterator. See method documentation of TimeoutIterator
109
- """
110
- ZERO_TIMEOUT = 0.0
111
-
112
- def __init__(self, iterator, timeout=0.0, sentinel=object(), reset_on_next=False):
113
- self._iterator = iterator
114
- self._timeout = timeout
115
- self._sentinel = sentinel
116
- self._reset_on_next = reset_on_next
117
-
118
- self._interrupt = False
119
- self._done = False
120
- self._buffer = asyncio.Queue()
121
- self._task = asyncio.get_event_loop().create_task(self.__lookahead())
122
-
123
- def get_sentinel(self):
124
- return self._sentinel
125
-
126
- def set_reset_on_next(self, reset_on_next):
127
- self._reset_on_next = reset_on_next
128
-
129
- def set_timeout(self, timeout: float):
130
- self._timeout = timeout
131
-
132
- def interrupt(self):
133
- self._interrupt = True
134
-
135
- def __aiter__(self):
136
- return self
137
-
138
- async def __anext__(self):
139
- if self._done:
140
- raise StopAsyncIteration
141
-
142
- data = self._sentinel
143
- try:
144
- if self._timeout > self.ZERO_TIMEOUT:
145
- data = await asyncio.wait_for(self._buffer.get(), self._timeout)
146
- else:
147
- data = await self._buffer.get()
148
- except asyncio.TimeoutError:
149
- pass
150
- finally:
151
- # see if timeout needs to be reset
152
- if self._reset_on_next:
153
- self._timeout = self.ZERO_TIMEOUT
154
-
155
- # propagate any exceptions including StopIteration
156
- if isinstance(data, BaseException):
157
- self._done = True
158
- raise data
159
-
160
- return data
161
-
162
- async def __lookahead(self):
163
- try:
164
- while True:
165
- data = await self._iterator.__anext__()
166
- await self._buffer.put(data)
167
- if self._interrupt:
168
- raise StopAsyncIteration()
169
- except BaseException as e:
170
- await self._buffer.put(e)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prompter.py CHANGED
@@ -120,7 +120,7 @@ def get_prompt(prompt_type, prompt_dict, chat, context, reduced, making_context,
120
  elif prompt_type in [PromptType.custom.value, str(PromptType.custom.value),
121
  PromptType.custom.name]:
122
  promptA = prompt_dict.get('promptA', '')
123
- promptB = prompt_dict('promptB', '')
124
  PreInstruct = prompt_dict.get('PreInstruct', '')
125
  PreInput = prompt_dict.get('PreInput', '')
126
  PreResponse = prompt_dict.get('PreResponse', '')
@@ -693,7 +693,9 @@ class Prompter(object):
693
  output = clean_response(output)
694
  elif prompt is None:
695
  # then use most basic parsing like pipeline
696
- if self.botstr in output:
 
 
697
  if self.humanstr:
698
  output = clean_response(output.split(self.botstr)[1].split(self.humanstr)[0])
699
  else:
 
120
  elif prompt_type in [PromptType.custom.value, str(PromptType.custom.value),
121
  PromptType.custom.name]:
122
  promptA = prompt_dict.get('promptA', '')
123
+ promptB = prompt_dict.get('promptB', '')
124
  PreInstruct = prompt_dict.get('PreInstruct', '')
125
  PreInput = prompt_dict.get('PreInput', '')
126
  PreResponse = prompt_dict.get('PreResponse', '')
 
693
  output = clean_response(output)
694
  elif prompt is None:
695
  # then use most basic parsing like pipeline
696
+ if not self.botstr:
697
+ pass
698
+ elif self.botstr in output:
699
  if self.humanstr:
700
  output = clean_response(output.split(self.botstr)[1].split(self.humanstr)[0])
701
  else:
requirements.txt CHANGED
@@ -1,153 +0,0 @@
1
- # for generate (gradio server) and finetune
2
- datasets==2.13.0
3
- sentencepiece==0.1.99
4
- gradio==3.35.2
5
- huggingface_hub==0.15.1
6
- appdirs==1.4.4
7
- fire==0.5.0
8
- docutils==0.20.1
9
- torch==2.0.1
10
- evaluate==0.4.0
11
- rouge_score==0.1.2
12
- sacrebleu==2.3.1
13
- scikit-learn==1.2.2
14
- alt-profanity-check==1.2.2
15
- better-profanity==0.7.0
16
- numpy==1.24.3
17
- pandas==2.0.2
18
- matplotlib==3.7.1
19
- loralib==0.1.1
20
- bitsandbytes==0.39.0
21
- accelerate==0.20.3
22
- git+https://github.com/huggingface/peft.git@0b62b4378b4ce9367932c73540349da9a41bdea8
23
- transformers==4.30.2
24
- tokenizers==0.13.3
25
- APScheduler==3.10.1
26
-
27
- # optional for generate
28
- pynvml==11.5.0
29
- psutil==5.9.5
30
- boto3==1.26.101
31
- botocore==1.29.101
32
-
33
- # optional for finetune
34
- tensorboard==2.13.0
35
- neptune==1.2.0
36
-
37
- # for gradio client
38
- gradio_client==0.2.7
39
- beautifulsoup4==4.12.2
40
- markdown==3.4.3
41
-
42
- # data and testing
43
- pytest==7.2.2
44
- pytest-xdist==3.2.1
45
- nltk==3.8.1
46
- textstat==0.7.3
47
- # pandoc==2.3
48
- #pypandoc==1.11
49
- pypandoc_binary==1.11
50
- openpyxl==3.1.2
51
- lm_dataformat==0.0.20
52
- bioc==2.0
53
-
54
- # falcon
55
- einops==0.6.1
56
- instructorembedding==1.0.1
57
-
58
- # for gpt4all .env file, but avoid worrying about imports
59
- python-dotenv==1.0.0
60
-
61
- text-generation==0.6.0
62
- # for tokenization when don't have HF tokenizer
63
- tiktoken==0.4.0
64
- # optional: for OpenAI endpoint or embeddings (requires key)
65
- openai==0.27.8
66
- # optional for chat with PDF
67
- langchain==0.0.202
68
- pypdf==3.9.1
69
- # avoid textract, requires old six
70
- #textract==1.6.5
71
-
72
- # for HF embeddings
73
- sentence_transformers==2.2.2
74
-
75
- # local vector db
76
- chromadb==0.3.25
77
- # server vector db
78
- #pymilvus==2.2.8
79
-
80
- # weak url support, if can't install opencv etc. If comment-in this one, then comment-out unstructured[local-inference]==0.6.6
81
- # unstructured==0.6.6
82
-
83
- # strong support for images
84
- # Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libreoffice
85
- unstructured[local-inference]==0.7.4
86
- #pdf2image==1.16.3
87
- #pytesseract==0.3.10
88
- pillow
89
-
90
- pdfminer.six==20221105
91
- urllib3
92
- requests_file
93
-
94
- #pdf2image==1.16.3
95
- #pytesseract==0.3.10
96
- tabulate==0.9.0
97
- # FYI pandoc already part of requirements.txt
98
-
99
- # JSONLoader, but makes some trouble for some users
100
- # jq==1.4.1
101
-
102
- # to check licenses
103
- # Run: pip-licenses|grep -v 'BSD\|Apache\|MIT'
104
- pip-licenses==4.3.0
105
-
106
- # weaviate vector db
107
- weaviate-client==3.20.0
108
- # optional for chat with PDF
109
- langchain==0.0.202
110
- pypdf==3.9.1
111
- # avoid textract, requires old six
112
- #textract==1.6.5
113
-
114
- # for HF embeddings
115
- sentence_transformers==2.2.2
116
-
117
- # local vector db
118
- chromadb==0.3.25
119
- # server vector db
120
- #pymilvus==2.2.8
121
-
122
- # weak url support, if can't install opencv etc. If comment-in this one, then comment-out unstructured[local-inference]==0.6.6
123
- # unstructured==0.6.6
124
-
125
- # strong support for images
126
- # Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libreoffice
127
- unstructured[local-inference]==0.7.4
128
- #pdf2image==1.16.3
129
- #pytesseract==0.3.10
130
- pillow
131
-
132
- pdfminer.six==20221105
133
- urllib3
134
- requests_file
135
-
136
- #pdf2image==1.16.3
137
- #pytesseract==0.3.10
138
- tabulate==0.9.0
139
- # FYI pandoc already part of requirements.txt
140
-
141
- # JSONLoader, but makes some trouble for some users
142
- # jq==1.4.1
143
-
144
- # to check licenses
145
- # Run: pip-licenses|grep -v 'BSD\|Apache\|MIT'
146
- pip-licenses==4.3.0
147
-
148
- # weaviate vector db
149
- weaviate-client==3.20.0
150
- faiss-gpu==1.7.2
151
- arxiv==1.4.7
152
- pymupdf==1.22.3 # AGPL license
153
- # extract-msg==0.41.1 # GPL3