pseudotensor committed
Commit 454e203
1 Parent(s): eac73aa

Update with h2oGPT hash ad9d685b188cece0b9c69716ea8e320b74f0caf7

client_test.py CHANGED
@@ -48,7 +48,7 @@ import markdown # pip install markdown
48
  import pytest
49
  from bs4 import BeautifulSoup # pip install beautifulsoup4
50
 
51
- from enums import DocumentChoices, LangChainAction
52
 
53
  debug = False
54
 
@@ -68,7 +68,9 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False,
68
  max_new_tokens=50,
69
  top_k_docs=3,
70
  langchain_mode='Disabled',
 
71
  langchain_action=LangChainAction.QUERY.value,
 
72
  prompt_dict=None):
73
  from collections import OrderedDict
74
  kwargs = OrderedDict(instruction=prompt if chat else '', # only for chat=True
@@ -94,11 +96,13 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False,
94
  instruction_nochat=prompt if not chat else '',
95
  iinput_nochat='', # only for chat=False
96
  langchain_mode=langchain_mode,
 
97
  langchain_action=langchain_action,
 
98
  top_k_docs=top_k_docs,
99
  chunk=True,
100
  chunk_size=512,
101
- document_subset=DocumentChoices.Relevant.name,
102
  document_choice=[],
103
  )
104
  from evaluate_params import eval_func_param_names
@@ -202,9 +206,11 @@ def run_client_nochat_api_lean_morestuff(prompt, prompt_type='human_bot', max_ne
202
  instruction_nochat=prompt,
203
  iinput_nochat='',
204
  langchain_mode='Disabled',
 
205
  langchain_action=LangChainAction.QUERY.value,
 
206
  top_k_docs=4,
207
- document_subset=DocumentChoices.Relevant.name,
208
  document_choice=[],
209
  )
210
 
@@ -225,23 +231,30 @@ def run_client_nochat_api_lean_morestuff(prompt, prompt_type='human_bot', max_ne
225
  @pytest.mark.skip(reason="For manual use against some server, no server launched")
226
  def test_client_chat(prompt_type='human_bot'):
227
  return run_client_chat(prompt='Who are you?', prompt_type=prompt_type, stream_output=False, max_new_tokens=50,
228
- langchain_mode='Disabled', langchain_action=LangChainAction.QUERY.value)
 
 
229
 
230
 
231
  @pytest.mark.skip(reason="For manual use against some server, no server launched")
232
  def test_client_chat_stream(prompt_type='human_bot'):
233
  return run_client_chat(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
234
  stream_output=True, max_new_tokens=512,
235
- langchain_mode='Disabled', langchain_action=LangChainAction.QUERY.value)
 
 
236
 
237
 
238
- def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens, langchain_mode, langchain_action,
 
239
  prompt_dict=None):
240
  client = get_client(serialize=False)
241
 
242
  kwargs, args = get_args(prompt, prompt_type, chat=True, stream_output=stream_output,
243
- max_new_tokens=max_new_tokens, langchain_mode=langchain_mode,
 
244
  langchain_action=langchain_action,
 
245
  prompt_dict=prompt_dict)
246
  return run_client(client, prompt, args, kwargs)
247
 
@@ -285,15 +298,18 @@ def run_client(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
285
  def test_client_nochat_stream(prompt_type='human_bot'):
286
  return run_client_nochat_gen(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
287
  stream_output=True, max_new_tokens=512,
288
- langchain_mode='Disabled', langchain_action=LangChainAction.QUERY.value)
 
 
289
 
290
 
291
- def run_client_nochat_gen(prompt, prompt_type, stream_output, max_new_tokens, langchain_mode, langchain_action):
 
292
  client = get_client(serialize=False)
293
 
294
  kwargs, args = get_args(prompt, prompt_type, chat=False, stream_output=stream_output,
295
  max_new_tokens=max_new_tokens, langchain_mode=langchain_mode,
296
- langchain_action=langchain_action)
297
  return run_client_gen(client, prompt, args, kwargs)
298
 
299
 
 
48
  import pytest
49
  from bs4 import BeautifulSoup # pip install beautifulsoup4
50
 
51
+ from enums import DocumentSubset, LangChainAction
52
 
53
  debug = False
54
 
 
68
  max_new_tokens=50,
69
  top_k_docs=3,
70
  langchain_mode='Disabled',
71
+ add_chat_history_to_context=True,
72
  langchain_action=LangChainAction.QUERY.value,
73
+ langchain_agents=[],
74
  prompt_dict=None):
75
  from collections import OrderedDict
76
  kwargs = OrderedDict(instruction=prompt if chat else '', # only for chat=True
 
96
  instruction_nochat=prompt if not chat else '',
97
  iinput_nochat='', # only for chat=False
98
  langchain_mode=langchain_mode,
99
+ add_chat_history_to_context=add_chat_history_to_context,
100
  langchain_action=langchain_action,
101
+ langchain_agents=langchain_agents,
102
  top_k_docs=top_k_docs,
103
  chunk=True,
104
  chunk_size=512,
105
+ document_subset=DocumentSubset.Relevant.name,
106
  document_choice=[],
107
  )
108
  from evaluate_params import eval_func_param_names
 
206
  instruction_nochat=prompt,
207
  iinput_nochat='',
208
  langchain_mode='Disabled',
209
+ add_chat_history_to_context=True,
210
  langchain_action=LangChainAction.QUERY.value,
211
+ langchain_agents=[],
212
  top_k_docs=4,
213
+ document_subset=DocumentSubset.Relevant.name,
214
  document_choice=[],
215
  )
216
 
 
231
  @pytest.mark.skip(reason="For manual use against some server, no server launched")
232
  def test_client_chat(prompt_type='human_bot'):
233
  return run_client_chat(prompt='Who are you?', prompt_type=prompt_type, stream_output=False, max_new_tokens=50,
234
+ langchain_mode='Disabled',
235
+ langchain_action=LangChainAction.QUERY.value,
236
+ langchain_agents=[])
237
 
238
 
239
  @pytest.mark.skip(reason="For manual use against some server, no server launched")
240
  def test_client_chat_stream(prompt_type='human_bot'):
241
  return run_client_chat(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
242
  stream_output=True, max_new_tokens=512,
243
+ langchain_mode='Disabled',
244
+ langchain_action=LangChainAction.QUERY.value,
245
+ langchain_agents=[])
246
 
247
 
248
+ def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens,
249
+ langchain_mode, langchain_action, langchain_agents,
250
  prompt_dict=None):
251
  client = get_client(serialize=False)
252
 
253
  kwargs, args = get_args(prompt, prompt_type, chat=True, stream_output=stream_output,
254
+ max_new_tokens=max_new_tokens,
255
+ langchain_mode=langchain_mode,
256
  langchain_action=langchain_action,
257
+ langchain_agents=langchain_agents,
258
  prompt_dict=prompt_dict)
259
  return run_client(client, prompt, args, kwargs)
260
 
 
298
  def test_client_nochat_stream(prompt_type='human_bot'):
299
  return run_client_nochat_gen(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
300
  stream_output=True, max_new_tokens=512,
301
+ langchain_mode='Disabled',
302
+ langchain_action=LangChainAction.QUERY.value,
303
+ langchain_agents=[])
304
 
305
 
306
+ def run_client_nochat_gen(prompt, prompt_type, stream_output, max_new_tokens,
307
+ langchain_mode, langchain_action, langchain_agents):
308
  client = get_client(serialize=False)
309
 
310
  kwargs, args = get_args(prompt, prompt_type, chat=False, stream_output=stream_output,
311
  max_new_tokens=max_new_tokens, langchain_mode=langchain_mode,
312
+ langchain_action=langchain_action, langchain_agents=langchain_agents)
313
  return run_client_gen(client, prompt, args, kwargs)
314
 
315
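Editor's note: the net effect of the client_test.py changes on an API payload is that the nochat dict gains add_chat_history_to_context and langchain_agents, document_subset switches from the removed DocumentChoices enum to DocumentSubset, and document_choice uses the new DocumentChoice.ALL sentinel. The sketch below is not part of the commit; it assumes a locally running h2oGPT Gradio server, the gradio_client package, and this repo's updated enums.py on the import path.

import ast
from gradio_client import Client  # pip install gradio_client
from enums import DocumentSubset, DocumentChoice, LangChainAction  # this repo's enums.py

HOST = "http://localhost:7860"  # assumption: local h2oGPT server

kwargs = dict(
    instruction_nochat="Who are you?",
    iinput_nochat="",
    langchain_mode="Disabled",
    add_chat_history_to_context=True,               # new in this commit
    langchain_action=LangChainAction.QUERY.value,
    langchain_agents=[],                            # new in this commit
    top_k_docs=3,
    chunk=True,
    chunk_size=512,
    document_subset=DocumentSubset.Relevant.name,   # was DocumentChoices.Relevant.name
    document_choice=[DocumentChoice.ALL.value],     # 'All' sentinel instead of []
)

client = Client(HOST)
# /submit_nochat_api passes the whole dict as a string (see gen.py hunk below)
res = client.predict(str(kwargs), api_name="/submit_nochat_api")
print(ast.literal_eval(res)['response'])  # assumption: server returns a str(dict) with 'response'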
 
enums.py CHANGED
@@ -31,25 +31,30 @@ class PromptType(Enum):
31
  mptinstruct = 25
32
  mptchat = 26
33
  falcon = 27
 
 
34
 
35
 
36
- class DocumentChoices(Enum):
37
  Relevant = 0
38
- Sources = 1
39
- All = 2
40
 
41
 
42
  non_query_commands = [
43
- DocumentChoices.Sources.name,
44
- DocumentChoices.All.name
45
  ]
46
 
47
 
 
 
 
 
48
  class LangChainMode(Enum):
49
  """LangChain mode"""
50
 
51
  DISABLED = "Disabled"
52
- CHAT_LLM = "ChatLLM"
53
  LLM = "LLM"
54
  ALL = "All"
55
  WIKI = "wiki"
@@ -60,6 +65,12 @@ class LangChainMode(Enum):
60
  H2O_DAI_DOCS = "DriverlessAI docs"
61
 
62
 
63
  class LangChainAction(Enum):
64
  """LangChain action"""
65
 
@@ -71,6 +82,13 @@ class LangChainAction(Enum):
71
  SUMMARIZE_REFINE = "Summarize_refine"
72
 
73
 
74
  no_server_str = no_lora_str = no_model_str = '[None/Remove]'
75
 
76
  # from site-packages/langchain/llms/openai.py
 
31
  mptinstruct = 25
32
  mptchat = 26
33
  falcon = 27
34
+ guanaco = 28
35
+ llama2 = 29
36
 
37
 
38
+ class DocumentSubset(Enum):
39
  Relevant = 0
40
+ RelSources = 1
41
+ TopKSources = 2
42
 
43
 
44
  non_query_commands = [
45
+ DocumentSubset.RelSources.name,
46
+ DocumentSubset.TopKSources.name
47
  ]
48
 
49
 
50
+ class DocumentChoice(Enum):
51
+ ALL = 'All'
52
+
53
+
54
  class LangChainMode(Enum):
55
  """LangChain mode"""
56
 
57
  DISABLED = "Disabled"
 
58
  LLM = "LLM"
59
  ALL = "All"
60
  WIKI = "wiki"
 
65
  H2O_DAI_DOCS = "DriverlessAI docs"
66
 
67
 
68
+ # modes should not be removed from visible list or added by name
69
+ langchain_modes_intrinsic = [LangChainMode.DISABLED.value,
70
+ LangChainMode.LLM.value,
71
+ LangChainMode.MY_DATA.value]
72
+
73
+
74
  class LangChainAction(Enum):
75
  """LangChain action"""
76
 
 
82
  SUMMARIZE_REFINE = "Summarize_refine"
83
 
84
 
85
+ class LangChainAgent(Enum):
86
+ """LangChain agents"""
87
+
88
+ SEARCH = "Search"
89
+ # CSV = "csv" # WIP
90
+
91
+
92
  no_server_str = no_lora_str = no_model_str = '[None/Remove]'
93
 
94
  # from site-packages/langchain/llms/openai.py
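Editor's note: a small sketch (not in the commit) of how the renamed and new enums above might be used, assuming the updated enums.py is importable.

from enums import (DocumentSubset, DocumentChoice, LangChainAction,
                   LangChainAgent, LangChainMode, non_query_commands)

# DocumentSubset replaces DocumentChoices; only the source-listing subsets count as
# non-query commands now, while 'All' moved into the separate DocumentChoice enum.
assert DocumentSubset.RelSources.name in non_query_commands
assert DocumentSubset.Relevant.name not in non_query_commands

document_choice = [DocumentChoice.ALL.value]        # ['All'] sentinel for "use every doc"
langchain_agents = [LangChainAgent.SEARCH.value]    # web search agent; needs SERPAPI_API_KEY per gen.py docs

print(LangChainMode.LLM.value,                      # 'ChatLLM' is gone; plain 'LLM' remains
      LangChainAction.QUERY.value,
      document_choice, langchain_agents)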
evaluate_params.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  no_default_param_names = [
2
  'instruction',
3
  'iinput',
@@ -30,7 +33,9 @@ eval_func_param_names = ['instruction',
30
  'instruction_nochat',
31
  'iinput_nochat',
32
  'langchain_mode',
 
33
  'langchain_action',
 
34
  'top_k_docs',
35
  'chunk',
36
  'chunk_size',
 
1
+ input_args_list = ['model_state', 'my_db_state', 'selection_docs_state']
2
+
3
+
4
  no_default_param_names = [
5
  'instruction',
6
  'iinput',
 
33
  'instruction_nochat',
34
  'iinput_nochat',
35
  'langchain_mode',
36
+ 'add_chat_history_to_context',
37
  'langchain_action',
38
+ 'langchain_agents',
39
  'top_k_docs',
40
  'chunk',
41
  'chunk_size',
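Editor's note: one practical consequence of this evaluate_params.py change, sketched below (not part of the commit): eval_func_param_names defines the positional order of evaluate()'s API parameters after the states in input_args_list, so inserting 'add_chat_history_to_context' and 'langchain_agents' shifts every later positional slot. Clients should build payloads from the list by name rather than by hard-coded index.

from evaluate_params import input_args_list, eval_func_param_names

# States come first in the Gradio inputs, then the evaluate() parameters in this order.
print(input_args_list)  # ['model_state', 'my_db_state', 'selection_docs_state']

payload = {k: None for k in eval_func_param_names}
payload.update(langchain_mode='Disabled',
               add_chat_history_to_context=True,   # new key in this commit
               langchain_agents=[])                 # new key in this commit
ordered_args = [payload[k] for k in eval_func_param_names]  # safe positional ordering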
gen.py CHANGED
@@ -8,7 +8,6 @@ import sys
8
  import os
9
  import time
10
  import traceback
11
- import types
12
  import typing
13
  import warnings
14
  from datetime import datetime
@@ -28,12 +27,12 @@ os.environ['BITSANDBYTES_NOWELCOME'] = '1'
28
  warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
29
 
30
  from evaluate_params import eval_func_param_names, no_default_param_names
31
- from enums import DocumentChoices, LangChainMode, no_lora_str, model_token_mapping, no_model_str, source_prefix, \
32
- source_postfix, LangChainAction
33
  from loaders import get_loaders
34
  from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread, get_githash, \
35
  import_matplotlib, get_device, makedirs, get_kwargs, start_faulthandler, get_hf_server, FakeTokenizer, remove, \
36
- have_langchain
37
 
38
  start_faulthandler()
39
  import_matplotlib()
@@ -50,10 +49,10 @@ from transformers import GenerationConfig, AutoModel, TextIteratorStreamer
50
  from prompter import Prompter, inv_prompt_type_to_model_lower, non_hf_types, PromptType, get_prompt, generate_prompt
51
  from stopping import get_stopping
52
 
53
- langchain_modes = [x.value for x in list(LangChainMode)]
54
-
55
  langchain_actions = [x.value for x in list(LangChainAction)]
56
 
 
 
57
  scratch_base_dir = '/tmp/'
58
 
59
 
@@ -114,6 +113,7 @@ def main(
114
  show_examples: bool = None,
115
  verbose: bool = False,
116
  h2ocolors: bool = True,
 
117
  height: int = 600,
118
  show_lora: bool = True,
119
  login_mode_if_model0: bool = False,
@@ -134,7 +134,7 @@ def main(
134
  extra_lora_options: typing.List[str] = [],
135
  extra_server_options: typing.List[str] = [],
136
 
137
- score_model: str = 'OpenAssistant/reward-model-deberta-v3-large-v2',
138
 
139
  eval_filename: str = None,
140
  eval_prompts_only_num: int = 0,
@@ -143,22 +143,30 @@ def main(
143
 
144
  langchain_mode: str = None,
145
  langchain_action: str = LangChainAction.QUERY.value,
 
146
  force_langchain_evaluate: bool = False,
 
147
  visible_langchain_modes: list = ['UserData', 'MyData'],
148
  # WIP:
149
  # visible_langchain_actions: list = langchain_actions.copy(),
150
  visible_langchain_actions: list = [LangChainAction.QUERY.value, LangChainAction.SUMMARIZE_MAP.value],
151
- document_subset: str = DocumentChoices.Relevant.name,
152
- document_choice: list = [],
 
153
  user_path: str = None,
 
154
  detect_user_path_changes_every_query: bool = False,
 
155
  load_db_if_exists: bool = True,
156
  keep_sources_in_context: bool = False,
157
  db_type: str = 'chroma',
158
  use_openai_embedding: bool = False,
159
  use_openai_model: bool = False,
160
  hf_embedding_model: str = None,
 
 
161
  allow_upload_to_user_data: bool = True,
 
162
  allow_upload_to_my_data: bool = True,
163
  enable_url_upload: bool = True,
164
  enable_text_upload: bool = True,
@@ -175,6 +183,7 @@ def main(
175
  pre_load_caption_model: bool = False,
176
  caption_gpu: bool = True,
177
  enable_ocr: bool = False,
 
178
  ):
179
  """
180
 
@@ -196,6 +205,8 @@ def main(
196
  Or Address can be "openai_chat" or "openai" for OpenAI API
197
  e.g. python generate.py --inference_server="openai_chat" --base_model=gpt-3.5-turbo
198
  e.g. python generate.py --inference_server="openai" --base_model=text-davinci-003
 
 
199
  :param prompt_type: type of prompt, usually matched to fine-tuned model or plain for foundational model
200
  :param prompt_dict: If prompt_type=custom, then expects (some) items returned by get_prompt(..., return_dict=True)
201
  :param model_lock: Lock models to specific combinations, for ease of use and extending to many models
@@ -252,6 +263,7 @@ def main(
252
  :param show_examples: whether to show clickable examples in gradio
253
  :param verbose: whether to show verbose prints
254
  :param h2ocolors: whether to use H2O.ai theme
 
255
  :param height: height of chat window
256
  :param show_lora: whether to show LORA options in UI (expert so can be hard to understand)
257
  :param login_mode_if_model0: set to True to load --base_model after client logs in, to be able to free GPU memory when model is swapped
@@ -271,49 +283,73 @@ def main(
271
  :param extra_model_options: extra models to show in list in gradio
272
  :param extra_lora_options: extra LORA to show in list in gradio
273
  :param extra_server_options: extra servers to show in list in gradio
274
- :param score_model: which model to score responses (None means no scoring)
 
 
 
275
  :param eval_filename: json file to use for evaluation, if None is sharegpt
276
  :param eval_prompts_only_num: for no gradio benchmark, if using eval_filename prompts for eval instead of examples
277
  :param eval_prompts_only_seed: for no gradio benchmark, seed for eval_filename sampling
278
  :param eval_as_output: for no gradio benchmark, whether to test eval_filename output itself
279
  :param langchain_mode: Data source to include. Choose "UserData" to only consume files from make_db.py.
 
280
  WARNING: wiki_full requires extra data processing via read_wiki_full.py and requires really good workstation to generate db, unless already present.
281
  :param langchain_action: Mode langchain operations in on documents.
282
  Query: Make query of document(s)
283
  Summarize or Summarize_map_reduce: Summarize document(s) via map_reduce
284
  Summarize_all: Summarize document(s) using entire document at once
285
  Summarize_refine: Summarize document(s) using entire document, and try to refine before returning summary
 
 
286
  :param force_langchain_evaluate: Whether to force langchain LLM use even if not doing langchain, mostly for testing.
287
  :param user_path: user path to glob from to generate db for vector search, for 'UserData' langchain mode.
288
  If already have db, any new/changed files are added automatically if path set, does not have to be same path used for prior db sources
 
 
 
 
289
  :param detect_user_path_changes_every_query: whether to detect if any files changed or added every similarity search (by file hashes).
290
  Expensive for large number of files, so not done by default. By default only detect changes during db loading.
 
291
  :param visible_langchain_modes: dbs to generate at launch to be ready for LLM
292
  Can be up to ['wiki', 'wiki_full', 'UserData', 'MyData', 'github h2oGPT', 'DriverlessAI docs']
293
  But wiki_full is expensive and requires preparation
294
  To allow scratch space only live in session, add 'MyData' to list
295
  Default: If only want to consume local files, e.g. prepared by make_db.py, only include ['UserData']
296
- FIXME: Avoid 'All' for now, not implemented
 
 
 
 
297
  :param visible_langchain_actions: Which actions to allow
 
298
  :param document_subset: Default document choice when taking subset of collection
299
- :param document_choice: Chosen document(s) by internal name
 
300
  :param load_db_if_exists: Whether to load chroma db if exists or re-generate db
301
  :param keep_sources_in_context: Whether to keep url sources in context, not helpful usually
302
  :param db_type: 'faiss' for in-memory or 'chroma' or 'weaviate' for persisted on disk
303
  :param use_openai_embedding: Whether to use OpenAI embeddings for vector db
304
  :param use_openai_model: Whether to use OpenAI model for use with vector db
305
  :param hf_embedding_model: Which HF embedding model to use for vector db
306
- Default is instructor-large with 768 parameters per embedding if have GPUs, else all-MiniLM-L6-v1 if no GPUs
307
  Can also choose simpler model with 384 parameters per embedding: "sentence-transformers/all-MiniLM-L6-v2"
308
  Can also choose even better embedding with 1024 parameters: 'hkunlp/instructor-xl'
309
  We support automatically changing of embeddings for chroma, with a backup of db made if this is done
310
- :param allow_upload_to_user_data: Whether to allow file uploads to update shared vector db
  :param allow_upload_to_my_data: Whether to allow file uploads to update scratch vector db
312
  :param enable_url_upload: Whether to allow upload from URL
313
  :param enable_text_upload: Whether to allow upload of text
314
  :param enable_sources_list: Whether to allow list (or download for non-shared db) of list of sources for chosen db
315
  :param chunk: Whether to chunk data (True unless know data is already optimally chunked)
316
- :param chunk_size: Size of chunks, with typically top-4 passed to LLM, so neesd to be in context length
317
  :param top_k_docs: number of chunks to give LLM
318
  :param reverse_docs: whether to reverse docs order so most relevant is closest to question.
319
  Best choice for sufficiently smart model, and truncation occurs for oldest context, so best then too.
@@ -327,11 +363,15 @@ def main(
327
  captions_model: str = "Salesforce/blip2-flan-t5-xl", # question/answer capable, 16GB state
328
  captions_model: str = "Salesforce/blip2-flan-t5-xxl", # question/answer capable, 60GB state
329
  Note: opt-based blip2 are not permissive license due to opt and Meta license restrictions
 
330
  :param pre_load_caption_model: Whether to preload caption model, or load after forking parallel doc loader
331
  parallel loading disabled if preload and have images, to prevent deadlocking on cuda context
332
  Recommended if using larger caption model
333
  :param caption_gpu: If support caption, then use GPU if exists
334
  :param enable_ocr: Whether to support OCR on images
 
 
 
335
  :return:
336
  """
337
  if base_model is None:
@@ -393,7 +433,29 @@ def main(
393
  if langchain_mode is not None:
394
  visible_langchain_modes += [langchain_mode]
395
 
396
  assert langchain_action in langchain_actions, "Invalid langchain_action %s" % langchain_action
 
 
397
 
398
  # if specifically chose not to show My or User Data, disable upload, so gradio elements are simpler
399
  if LangChainMode.MY_DATA.value not in visible_langchain_modes:
@@ -404,21 +466,22 @@ def main(
404
  # auto-set langchain_mode
405
  if have_langchain and langchain_mode is None:
406
  # start in chat mode, in case just want to chat and don't want to get "No documents to query" by default.
407
- langchain_mode = LangChainMode.CHAT_LLM.value
408
- if allow_upload_to_user_data and not is_public and user_path:
409
  print("Auto set langchain_mode=%s. Could use UserData instead." % langchain_mode, flush=True)
410
  elif allow_upload_to_my_data:
411
  print("Auto set langchain_mode=%s. Could use MyData instead."
412
  " To allow UserData to pull files from disk,"
413
- " set user_path and ensure allow_upload_to_user_data=True" % langchain_mode, flush=True)
 
414
  else:
415
  raise RuntimeError("Please pass --langchain_mode=<chosen mode> out of %s" % langchain_modes)
416
- if not have_langchain and langchain_mode not in [None, LangChainMode.DISABLED.value, LangChainMode.LLM.value, LangChainMode.CHAT_LLM.value]:
417
  raise RuntimeError("Asked for LangChain mode but langchain python package cannot be found.")
418
  if langchain_mode is None:
419
  # if not set yet, disable
420
  langchain_mode = LangChainMode.DISABLED.value
421
- print("Auto set langchain_mode=%s" % langchain_mode, flush=True)
422
 
423
  if is_public:
424
  allow_upload_to_user_data = False
@@ -474,7 +537,7 @@ def main(
474
  # HF accounted for later in get_max_max_new_tokens()
475
  save_dir = os.getenv('SAVE_DIR', save_dir)
476
  score_model = os.getenv('SCORE_MODEL', score_model)
477
- if score_model == 'None' or score_model is None:
478
  score_model = ''
479
  concurrency_count = int(os.getenv('CONCURRENCY_COUNT', concurrency_count))
480
  api_open = bool(int(os.getenv('API_OPEN', str(int(api_open)))))
@@ -482,6 +545,7 @@ def main(
482
 
483
  n_gpus = torch.cuda.device_count() if torch.cuda.is_available else 0
484
  if n_gpus == 0:
 
485
  gpu_id = None
486
  load_8bit = False
487
  load_4bit = False
@@ -499,7 +563,11 @@ def main(
499
  if hf_embedding_model is None:
500
  # if no GPUs, use simpler embedding model to avoid cost in time
501
  hf_embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
 
 
502
  else:
 
 
503
  if hf_embedding_model is None:
504
  # if still None, then set default
505
  hf_embedding_model = 'hkunlp/instructor-large'
@@ -524,8 +592,6 @@ def main(
524
 
525
  if offload_folder:
526
  makedirs(offload_folder)
527
- if user_path:
528
- makedirs(user_path)
529
 
530
  placeholder_instruction, placeholder_input, \
531
  stream_output, show_examples, \
@@ -551,7 +617,7 @@ def main(
551
  verbose,
552
  )
553
 
554
- git_hash = get_githash()
555
  locals_dict = locals()
556
  locals_print = '\n'.join(['%s: %s' % (k, v) for k, v in locals_dict.items()])
557
  if verbose:
@@ -565,7 +631,7 @@ def main(
565
  get_some_dbs_from_hf()
566
  dbs = {}
567
  for langchain_mode1 in visible_langchain_modes:
568
- if langchain_mode1 in ['MyData']:
569
  # don't use what is on disk, remove it instead
570
  for gpath1 in glob.glob(os.path.join(scratch_base_dir, 'db_dir_%s*' % langchain_mode1)):
571
  if os.path.isdir(gpath1):
@@ -580,7 +646,7 @@ def main(
580
  db = prep_langchain(persist_directory1,
581
  load_db_if_exists,
582
  db_type, use_openai_embedding,
583
- langchain_mode1, user_path,
584
  hf_embedding_model,
585
  kwargs_make_db=locals())
586
  finally:
@@ -599,6 +665,14 @@ def main(
599
  model_state_none = dict(model=None, tokenizer=None, device=None,
600
  base_model=None, tokenizer_base_model=None, lora_weights=None,
601
  inference_server=None, prompt_type=None, prompt_dict=None)
602
 
603
  if cli:
604
  from cli import run_cli
@@ -967,11 +1041,13 @@ def get_model(
967
  client = gr_client or hf_client
968
  # Don't return None, None for model, tokenizer so triggers
969
  return client, tokenizer, 'http'
970
- if isinstance(inference_server, str) and inference_server.startswith('openai'):
971
- assert os.getenv('OPENAI_API_KEY'), "Set environment for OPENAI_API_KEY"
972
- # Don't return None, None for model, tokenizer so triggers
973
- # include small token cushion
974
- tokenizer = FakeTokenizer(model_max_length=model_token_mapping[base_model] - 50)
 
 
975
  return inference_server, tokenizer, inference_server
976
  assert not inference_server, "Malformed inference_server=%s" % inference_server
977
  if base_model in non_hf_types:
@@ -1255,6 +1331,7 @@ def get_score_model(score_model: str = None,
1255
  def evaluate(
1256
  model_state,
1257
  my_db_state,
 
1258
  # START NOTE: Examples must have same order of parameters
1259
  instruction,
1260
  iinput,
@@ -1277,7 +1354,9 @@ def evaluate(
1277
  instruction_nochat,
1278
  iinput_nochat,
1279
  langchain_mode,
 
1280
  langchain_action,
 
1281
  top_k_docs,
1282
  chunk,
1283
  chunk_size,
@@ -1291,6 +1370,9 @@ def evaluate(
1291
  save_dir=None,
1292
  sanitize_bot_response=False,
1293
  model_state0=None,
 
 
 
1294
  memory_restriction_level=None,
1295
  max_max_new_tokens=None,
1296
  is_public=None,
@@ -1298,13 +1380,14 @@ def evaluate(
1298
  raise_generate_gpu_exceptions=None,
1299
  chat_context=None,
1300
  lora_weights=None,
 
1301
  load_db_if_exists=True,
1302
  dbs=None,
1303
- user_path=None,
1304
  detect_user_path_changes_every_query=None,
1305
  use_openai_embedding=None,
1306
  use_openai_model=None,
1307
  hf_embedding_model=None,
 
1308
  db_type=None,
1309
  n_jobs=None,
1310
  first_para=None,
@@ -1333,6 +1416,16 @@ def evaluate(
1333
  assert chunk_size is not None and isinstance(chunk_size, int)
1334
  assert n_jobs is not None
1335
  assert first_para is not None
1336
 
1337
  if debug:
1338
  locals_dict = locals().copy()
@@ -1452,18 +1545,24 @@ def evaluate(
1452
  # THIRD PLACE where LangChain referenced, but imports only occur if enabled and have db to use
1453
  assert langchain_mode in langchain_modes, "Invalid langchain_mode %s" % langchain_mode
1454
  assert langchain_action in langchain_actions, "Invalid langchain_action %s" % langchain_action
1455
- if langchain_mode in ['MyData'] and my_db_state is not None and len(my_db_state) > 0 and my_db_state[0] is not None:
1456
- db1 = my_db_state[0]
1457
- elif dbs is not None and langchain_mode in dbs:
1458
- db1 = dbs[langchain_mode]
  else:
1460
- db1 = None
1461
- do_langchain_path = langchain_mode not in [False, 'Disabled', 'ChatLLM', 'LLM'] or \
1462
  base_model in non_hf_types or \
1463
  force_langchain_evaluate
1464
  if do_langchain_path:
1465
  outr = ""
1466
- # use smaller cut_distanct for wiki_full since so many matches could be obtained, and often irrelevant unless close
1467
  from gpt_langchain import run_qa_db
1468
  gen_hyper_langchain = dict(do_sample=do_sample,
1469
  temperature=temperature,
@@ -1484,11 +1583,13 @@ def evaluate(
1484
  inference_server=inference_server,
1485
  stream_output=stream_output,
1486
  prompter=prompter,
 
1487
  load_db_if_exists=load_db_if_exists,
1488
- db=db1,
1489
- user_path=user_path,
1490
  detect_user_path_changes_every_query=detect_user_path_changes_every_query,
1491
- cut_distanct=1.1 if langchain_mode in ['wiki_full'] else 1.64, # FIXME, too arbitrary
 
1492
  use_openai_embedding=use_openai_embedding,
1493
  use_openai_model=use_openai_model,
1494
  hf_embedding_model=hf_embedding_model,
@@ -1498,6 +1599,7 @@ def evaluate(
1498
  chunk_size=chunk_size,
1499
  langchain_mode=langchain_mode,
1500
  langchain_action=langchain_action,
 
1501
  document_subset=document_subset,
1502
  document_choice=document_choice,
1503
  db_type=db_type,
@@ -1526,6 +1628,7 @@ def evaluate(
1526
  inference_server=inference_server,
1527
  langchain_mode=langchain_mode,
1528
  langchain_action=langchain_action,
 
1529
  document_subset=document_subset,
1530
  document_choice=document_choice,
1531
  num_prompt_tokens=num_prompt_tokens,
@@ -1549,12 +1652,12 @@ def evaluate(
1549
  clear_torch_cache()
1550
  return
1551
 
1552
- if inference_server.startswith('openai') or inference_server.startswith('http'):
1553
- if inference_server.startswith('openai'):
1554
- import openai
1555
  where_from = "openai_client"
 
1556
 
1557
- openai.api_key = os.getenv("OPENAI_API_KEY")
1558
  terminate_response = prompter.terminate_response or []
1559
  stop_sequences = list(set(terminate_response + [prompter.PreResponse]))
1560
  stop_sequences = [x for x in stop_sequences if x]
@@ -1567,7 +1670,7 @@ def evaluate(
1567
  n=num_return_sequences,
1568
  presence_penalty=1.07 - repetition_penalty + 0.6, # so good default
1569
  )
1570
- if inference_server == 'openai':
1571
  response = openai.Completion.create(
1572
  model=base_model,
1573
  prompt=prompt,
@@ -1590,7 +1693,9 @@ def evaluate(
1590
  yield dict(response=prompter.get_response(prompt + text, prompt=prompt,
1591
  sanitize_bot_response=sanitize_bot_response),
1592
  sources='')
1593
- elif inference_server == 'openai_chat':
 
 
1594
  response = openai.ChatCompletion.create(
1595
  model=base_model,
1596
  messages=[
@@ -1642,7 +1747,9 @@ def evaluate(
1642
  chat_client = False
1643
  where_from = "gr_client"
1644
  client_langchain_mode = 'Disabled'
 
1645
  client_langchain_action = LangChainAction.QUERY.value
 
1646
  gen_server_kwargs = dict(temperature=temperature,
1647
  top_p=top_p,
1648
  top_k=top_k,
@@ -1694,12 +1801,14 @@ def evaluate(
1694
  instruction_nochat=gr_prompt if not chat_client else '',
1695
  iinput_nochat=gr_iinput, # only for chat=False
1696
  langchain_mode=client_langchain_mode,
 
1697
  langchain_action=client_langchain_action,
 
1698
  top_k_docs=top_k_docs,
1699
  chunk=chunk,
1700
  chunk_size=chunk_size,
1701
- document_subset=DocumentChoices.Relevant.name,
1702
- document_choice=[],
1703
  )
1704
  api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
1705
  if not stream_output:
@@ -1993,7 +2102,7 @@ def evaluate(
1993
 
1994
 
1995
  inputs_list_names = list(inspect.signature(evaluate).parameters)
1996
- state_names = ['model_state', 'my_db_state']
1997
  inputs_kwargs_list = [x for x in inputs_list_names if x not in eval_func_param_names + state_names]
1998
 
1999
 
@@ -2276,8 +2385,8 @@ y = np.random.randint(0, 1, 100)
2276
 
2277
  # move to correct position
2278
  for example in examples:
2279
- example += [chat, '', '', LangChainMode.DISABLED.value, LangChainAction.QUERY.value,
2280
- top_k_docs, chunk, chunk_size, [DocumentChoices.Relevant.name], []
2281
  ]
2282
  # adjust examples if non-chat mode
2283
  if not chat:
@@ -2337,7 +2446,7 @@ def score_qa(smodel, stokenizer, max_length_tokenize, question, answer, cutoff_l
2337
  truncation=True,
2338
  max_length=max_length_tokenize).to(smodel.device)
2339
  try:
2340
- score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
2341
  except torch.cuda.OutOfMemoryError as e:
2342
  print("GPU OOM 3: question: %s answer: %s exception: %s" % (question, answer, str(e)), flush=True)
2343
  del inputs
@@ -2383,14 +2492,14 @@ def check_locals(**kwargs):
2383
 
2384
 
2385
  def get_model_max_length(model_state):
2386
- if not isinstance(model_state['tokenizer'], (str, types.NoneType)):
2387
  return model_state['tokenizer'].model_max_length
2388
  else:
2389
  return 2048
2390
 
2391
 
2392
  def get_max_max_new_tokens(model_state, **kwargs):
2393
- if not isinstance(model_state['tokenizer'], (str, types.NoneType)):
2394
  max_max_new_tokens = model_state['tokenizer'].model_max_length
2395
  else:
2396
  max_max_new_tokens = None
@@ -2422,12 +2531,15 @@ def get_minmax_top_k_docs(is_public):
2422
  return min_top_k_docs, max_top_k_docs, label_top_k_docs
2423
 
2424
 
2425
- def history_to_context(history, langchain_mode1, prompt_type1, prompt_dict1, chat1, model_max_length1,
 
 
2426
  memory_restriction_level1, keep_sources_in_context1):
2427
  """
2428
  consumes all history up to (but not including) latest history item that is presumed to be an [instruction, None] pair
2429
  :param history:
2430
  :param langchain_mode1:
 
2431
  :param prompt_type1:
2432
  :param prompt_dict1:
2433
  :param chat1:
@@ -2440,7 +2552,7 @@ def history_to_context(history, langchain_mode1, prompt_type1, prompt_dict1, cha
2440
  _, _, _, max_prompt_length = get_cutoffs(memory_restriction_level1,
2441
  for_context=True, model_max_length=model_max_length1)
2442
  context1 = ''
2443
- if max_prompt_length is not None and langchain_mode1 not in ['LLM']:
2444
  context1 = ''
2445
  # - 1 below because current instruction already in history from user()
2446
  for histi in range(0, len(history) - 1):
@@ -2476,6 +2588,22 @@ def history_to_context(history, langchain_mode1, prompt_type1, prompt_dict1, cha
2476
  return context1
2477
 
2478
 
2479
  def entrypoint_main():
2480
  """
2481
  Examples:
 
8
  import os
9
  import time
10
  import traceback
 
11
  import typing
12
  import warnings
13
  from datetime import datetime
 
27
  warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
28
 
29
  from evaluate_params import eval_func_param_names, no_default_param_names
30
+ from enums import DocumentSubset, LangChainMode, no_lora_str, model_token_mapping, no_model_str, source_prefix, \
31
+ source_postfix, LangChainAction, LangChainAgent, DocumentChoice
32
  from loaders import get_loaders
33
  from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread, get_githash, \
34
  import_matplotlib, get_device, makedirs, get_kwargs, start_faulthandler, get_hf_server, FakeTokenizer, remove, \
35
+ have_langchain, set_openai, load_collection_enum
36
 
37
  start_faulthandler()
38
  import_matplotlib()
 
49
  from prompter import Prompter, inv_prompt_type_to_model_lower, non_hf_types, PromptType, get_prompt, generate_prompt
50
  from stopping import get_stopping
51
 
 
 
52
  langchain_actions = [x.value for x in list(LangChainAction)]
53
 
54
+ langchain_agents_list = [x.value for x in list(LangChainAgent)]
55
+
56
  scratch_base_dir = '/tmp/'
57
 
58
 
 
113
  show_examples: bool = None,
114
  verbose: bool = False,
115
  h2ocolors: bool = True,
116
+ dark: bool = False, # light tends to be best
117
  height: int = 600,
118
  show_lora: bool = True,
119
  login_mode_if_model0: bool = False,
 
134
  extra_lora_options: typing.List[str] = [],
135
  extra_server_options: typing.List[str] = [],
136
 
137
+ score_model: str = 'auto',
138
 
139
  eval_filename: str = None,
140
  eval_prompts_only_num: int = 0,
 
143
 
144
  langchain_mode: str = None,
145
  langchain_action: str = LangChainAction.QUERY.value,
146
+ langchain_agents: list = [],
147
  force_langchain_evaluate: bool = False,
148
+ langchain_modes: list = [x.value for x in list(LangChainMode)],
149
  visible_langchain_modes: list = ['UserData', 'MyData'],
150
  # WIP:
151
  # visible_langchain_actions: list = langchain_actions.copy(),
152
  visible_langchain_actions: list = [LangChainAction.QUERY.value, LangChainAction.SUMMARIZE_MAP.value],
153
+ visible_langchain_agents: list = langchain_agents_list.copy(),
154
+ document_subset: str = DocumentSubset.Relevant.name,
155
+ document_choice: list = [DocumentChoice.ALL.value],
156
  user_path: str = None,
157
+ langchain_mode_paths: dict = {'UserData': None},
158
  detect_user_path_changes_every_query: bool = False,
159
+ use_llm_if_no_docs: bool = False,
160
  load_db_if_exists: bool = True,
161
  keep_sources_in_context: bool = False,
162
  db_type: str = 'chroma',
163
  use_openai_embedding: bool = False,
164
  use_openai_model: bool = False,
165
  hf_embedding_model: str = None,
166
+ cut_distance: float = 1.64,
167
+ add_chat_history_to_context: bool = True,
168
  allow_upload_to_user_data: bool = True,
169
+ reload_langchain_state: bool = True,
170
  allow_upload_to_my_data: bool = True,
171
  enable_url_upload: bool = True,
172
  enable_text_upload: bool = True,
 
183
  pre_load_caption_model: bool = False,
184
  caption_gpu: bool = True,
185
  enable_ocr: bool = False,
186
+ enable_pdf_ocr: str = 'auto',
187
  ):
188
  """
189
 
 
205
  Or Address can be "openai_chat" or "openai" for OpenAI API
206
  e.g. python generate.py --inference_server="openai_chat" --base_model=gpt-3.5-turbo
207
  e.g. python generate.py --inference_server="openai" --base_model=text-davinci-003
208
+ Or Address can be "vllm:IP:port" or "vllm:IP:port" for OpenAI-compliant vLLM endpoint
209
+ Note: vllm_chat not supported by vLLM project.
210
  :param prompt_type: type of prompt, usually matched to fine-tuned model or plain for foundational model
211
  :param prompt_dict: If prompt_type=custom, then expects (some) items returned by get_prompt(..., return_dict=True)
212
  :param model_lock: Lock models to specific combinations, for ease of use and extending to many models
 
263
  :param show_examples: whether to show clickable examples in gradio
264
  :param verbose: whether to show verbose prints
265
  :param h2ocolors: whether to use H2O.ai theme
266
+ :param dark: whether to use dark mode for UI by default (still controlled in UI)
267
  :param height: height of chat window
268
  :param show_lora: whether to show LORA options in UI (expert so can be hard to understand)
269
  :param login_mode_if_model0: set to True to load --base_model after client logs in, to be able to free GPU memory when model is swapped
 
283
  :param extra_model_options: extra models to show in list in gradio
284
  :param extra_lora_options: extra LORA to show in list in gradio
285
  :param extra_server_options: extra servers to show in list in gradio
286
+ :param score_model: which model to score responses
287
+ None: no response scoring
288
+ 'auto': auto mode, '' (no model) for CPU, 'OpenAssistant/reward-model-deberta-v3-large-v2' for GPU,
289
+ because on CPU takes too much compute just for scoring response
290
  :param eval_filename: json file to use for evaluation, if None is sharegpt
291
  :param eval_prompts_only_num: for no gradio benchmark, if using eval_filename prompts for eval instead of examples
292
  :param eval_prompts_only_seed: for no gradio benchmark, seed for eval_filename sampling
293
  :param eval_as_output: for no gradio benchmark, whether to test eval_filename output itself
294
  :param langchain_mode: Data source to include. Choose "UserData" to only consume files from make_db.py.
295
+ None: auto mode, check if langchain package exists, at least do LLM if so, else Disabled
296
  WARNING: wiki_full requires extra data processing via read_wiki_full.py and requires really good workstation to generate db, unless already present.
297
  :param langchain_action: Mode langchain operations in on documents.
298
  Query: Make query of document(s)
299
  Summarize or Summarize_map_reduce: Summarize document(s) via map_reduce
300
  Summarize_all: Summarize document(s) using entire document at once
301
  Summarize_refine: Summarize document(s) using entire document, and try to refine before returning summary
302
+ :param langchain_agents: Which agents to use
303
+ 'search': Use Web Search as context for LLM response, e.g. SERP if have SERPAPI_API_KEY in env
304
  :param force_langchain_evaluate: Whether to force langchain LLM use even if not doing langchain, mostly for testing.
305
  :param user_path: user path to glob from to generate db for vector search, for 'UserData' langchain mode.
306
  If already have db, any new/changed files are added automatically if path set, does not have to be same path used for prior db sources
307
+ :param langchain_mode_paths: dict of langchain_mode keys and disk path values to use for source of documents
308
+ E.g. "{'UserData2': 'userpath2'}"
309
+ Can be None even if existing DB, to avoid new documents being added from that path, source links that are on disk still work.
310
+ If user_path is not None, that path is used for 'UserData' instead of the value in this dict
311
  :param detect_user_path_changes_every_query: whether to detect if any files changed or added every similarity search (by file hashes).
312
  Expensive for large number of files, so not done by default. By default only detect changes during db loading.
313
+ :param langchain_modes: names of collections/dbs to potentially have
314
  :param visible_langchain_modes: dbs to generate at launch to be ready for LLM
315
  Can be up to ['wiki', 'wiki_full', 'UserData', 'MyData', 'github h2oGPT', 'DriverlessAI docs']
316
  But wiki_full is expensive and requires preparation
317
  To allow scratch space only live in session, add 'MyData' to list
318
  Default: If only want to consume local files, e.g. prepared by make_db.py, only include ['UserData']
319
+ If have own user modes, need to add these here or add in UI.
320
+ A state file is stored in visible_langchain_modes.pkl containing last UI-selected values of:
321
+ langchain_modes, visible_langchain_modes, and langchain_mode_paths
322
+ Delete the file if you want to start fresh,
323
+ but in any case the user_path passed in CLI is used for UserData even if was None or different
324
  :param visible_langchain_actions: Which actions to allow
325
+ :param visible_langchain_agents: Which agents to allow
326
  :param document_subset: Default document choice when taking subset of collection
327
+ :param document_choice: Chosen document(s) by internal name, 'All' means use all docs
328
+ :param use_llm_if_no_docs: Whether to use LLM even if no documents, when langchain_mode=UserData or MyData or custom
329
  :param load_db_if_exists: Whether to load chroma db if exists or re-generate db
330
  :param keep_sources_in_context: Whether to keep url sources in context, not helpful usually
331
  :param db_type: 'faiss' for in-memory or 'chroma' or 'weaviate' for persisted on disk
332
  :param use_openai_embedding: Whether to use OpenAI embeddings for vector db
333
  :param use_openai_model: Whether to use OpenAI model for use with vector db
334
  :param hf_embedding_model: Which HF embedding model to use for vector db
335
+ Default is instructor-large with 768 parameters per embedding if have GPUs, else all-MiniLM-L6-v2 if no GPUs
336
  Can also choose simpler model with 384 parameters per embedding: "sentence-transformers/all-MiniLM-L6-v2"
337
  Can also choose even better embedding with 1024 parameters: 'hkunlp/instructor-xl'
338
  We support automatically changing of embeddings for chroma, with a backup of db made if this is done
339
+ :param cut_distance: Distance to cut off references with larger distances when showing references.
340
+ 1.64 is good to avoid dropping references for all-MiniLM-L6-v2, but instructor-large will always show excessive references.
341
+ For all-MiniLM-L6-v2, a value of 1.5 can push out even more references, or a large value of 100 can avoid any loss of references.
342
+ :param add_chat_history_to_context: Include chat context when performing action
343
+ Not supported yet for openai_chat when using document collection instead of LLM
344
+ Also not supported when using CLI mode
345
+ :param allow_upload_to_user_data: Whether to allow file uploads to update shared vector db (UserData or custom user dbs)
346
+ :param reload_langchain_state: Whether to reload visible_langchain_modes.pkl file that contains any new user collections.
347
  :param allow_upload_to_my_data: Whether to allow file uploads to update scratch vector db
348
  :param enable_url_upload: Whether to allow upload from URL
349
  :param enable_text_upload: Whether to allow upload of text
350
  :param enable_sources_list: Whether to allow list (or download for non-shared db) of list of sources for chosen db
351
  :param chunk: Whether to chunk data (True unless know data is already optimally chunked)
352
+ :param chunk_size: Size of chunks, with typically top-4 passed to LLM, so needs to be in context length
353
  :param top_k_docs: number of chunks to give LLM
354
  :param reverse_docs: whether to reverse docs order so most relevant is closest to question.
355
  Best choice for sufficiently smart model, and truncation occurs for oldest context, so best then too.
 
363
  captions_model: str = "Salesforce/blip2-flan-t5-xl", # question/answer capable, 16GB state
364
  captions_model: str = "Salesforce/blip2-flan-t5-xxl", # question/answer capable, 60GB state
365
  Note: opt-based blip2 are not permissive license due to opt and Meta license restrictions
366
+ Disabled for CPU since BLIP requires CUDA
367
  :param pre_load_caption_model: Whether to preload caption model, or load after forking parallel doc loader
368
  parallel loading disabled if preload and have images, to prevent deadlocking on cuda context
369
  Recommended if using larger caption model
370
  :param caption_gpu: If support caption, then use GPU if exists
371
  :param enable_ocr: Whether to support OCR on images
372
+ :param enable_pdf_ocr: 'auto' means only use OCR if normal text extraction fails. Useful for pure image-based PDFs with text
373
+ 'on' means always do OCR as additional parsing of same documents
374
+ 'off' means don't do OCR (e.g. because it's slow even if 'auto' only would trigger if nothing else worked)
375
  :return:
376
  """
377
  if base_model is None:
 
433
  if langchain_mode is not None:
434
  visible_langchain_modes += [langchain_mode]
435
 
436
+ # update
437
+ if isinstance(langchain_mode_paths, str):
438
+ langchain_mode_paths = ast.literal_eval(langchain_mode_paths)
439
+ assert isinstance(langchain_mode_paths, dict)
440
+ if user_path:
441
+ langchain_mode_paths['UserData'] = user_path
442
+ makedirs(user_path)
443
+
444
+ if is_public:
445
+ allow_upload_to_user_data = False
446
+ if LangChainMode.USER_DATA.value in visible_langchain_modes:
447
+ visible_langchain_modes.remove(LangChainMode.USER_DATA.value)
448
+
449
+ # in-place, for non-scratch dbs
450
+ if allow_upload_to_user_data:
451
+ update_langchain(langchain_modes, visible_langchain_modes, langchain_mode_paths, '')
452
+ # always listen to CLI-passed user_path if passed
453
+ if user_path:
454
+ langchain_mode_paths['UserData'] = user_path
455
+
456
  assert langchain_action in langchain_actions, "Invalid langchain_action %s" % langchain_action
457
+ assert len(
458
+ set(langchain_agents).difference(langchain_agents_list)) == 0, "Invalid langchain_agents %s" % langchain_agents
459
 
460
  # if specifically chose not to show My or User Data, disable upload, so gradio elements are simpler
461
  if LangChainMode.MY_DATA.value not in visible_langchain_modes:
 
466
  # auto-set langchain_mode
467
  if have_langchain and langchain_mode is None:
468
  # start in chat mode, in case just want to chat and don't want to get "No documents to query" by default.
469
+ langchain_mode = LangChainMode.LLM.value
470
+ if allow_upload_to_user_data and not is_public and langchain_mode_paths['UserData']:
471
  print("Auto set langchain_mode=%s. Could use UserData instead." % langchain_mode, flush=True)
472
  elif allow_upload_to_my_data:
473
  print("Auto set langchain_mode=%s. Could use MyData instead."
474
  " To allow UserData to pull files from disk,"
475
+ " set user_path or langchain_mode_paths, and ensure allow_upload_to_user_data=True" % langchain_mode,
476
+ flush=True)
477
  else:
478
  raise RuntimeError("Please pass --langchain_mode=<chosen mode> out of %s" % langchain_modes)
479
+ if not have_langchain and langchain_mode not in [None, LangChainMode.DISABLED.value, LangChainMode.LLM.value]:
480
  raise RuntimeError("Asked for LangChain mode but langchain python package cannot be found.")
481
  if langchain_mode is None:
482
  # if not set yet, disable
483
  langchain_mode = LangChainMode.DISABLED.value
484
+ print("Auto set langchain_mode=%s Have langchain package: %s" % (langchain_mode, have_langchain), flush=True)
485
 
486
  if is_public:
487
  allow_upload_to_user_data = False
 
537
  # HF accounted for later in get_max_max_new_tokens()
538
  save_dir = os.getenv('SAVE_DIR', save_dir)
539
  score_model = os.getenv('SCORE_MODEL', score_model)
540
+ if str(score_model) == 'None':
541
  score_model = ''
542
  concurrency_count = int(os.getenv('CONCURRENCY_COUNT', concurrency_count))
543
  api_open = bool(int(os.getenv('API_OPEN', str(int(api_open)))))
 
545
 
546
  n_gpus = torch.cuda.device_count() if torch.cuda.is_available else 0
547
  if n_gpus == 0:
548
+ enable_captions = False
549
  gpu_id = None
550
  load_8bit = False
551
  load_4bit = False
 
563
  if hf_embedding_model is None:
564
  # if no GPUs, use simpler embedding model to avoid cost in time
565
  hf_embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
566
+ if score_model == 'auto':
567
+ score_model = ''
568
  else:
569
+ if score_model == 'auto':
570
+ score_model = 'OpenAssistant/reward-model-deberta-v3-large-v2'
571
  if hf_embedding_model is None:
572
  # if still None, then set default
573
  hf_embedding_model = 'hkunlp/instructor-large'
 
592
 
593
  if offload_folder:
594
  makedirs(offload_folder)
 
 
595
 
596
  placeholder_instruction, placeholder_input, \
597
  stream_output, show_examples, \
 
617
  verbose,
618
  )
619
 
620
+ git_hash = get_githash() if is_public or os.getenv('GET_GITHASH') else "GET_GITHASH"
621
  locals_dict = locals()
622
  locals_print = '\n'.join(['%s: %s' % (k, v) for k, v in locals_dict.items()])
623
  if verbose:
 
631
  get_some_dbs_from_hf()
632
  dbs = {}
633
  for langchain_mode1 in visible_langchain_modes:
634
+ if langchain_mode1 in ['MyData']: # FIXME: Remove other custom temp dbs
635
  # don't use what is on disk, remove it instead
636
  for gpath1 in glob.glob(os.path.join(scratch_base_dir, 'db_dir_%s*' % langchain_mode1)):
637
  if os.path.isdir(gpath1):
 
646
  db = prep_langchain(persist_directory1,
647
  load_db_if_exists,
648
  db_type, use_openai_embedding,
649
+ langchain_mode1, langchain_mode_paths,
650
  hf_embedding_model,
651
  kwargs_make_db=locals())
652
  finally:
 
665
  model_state_none = dict(model=None, tokenizer=None, device=None,
666
  base_model=None, tokenizer_base_model=None, lora_weights=None,
667
  inference_server=None, prompt_type=None, prompt_dict=None)
668
+ my_db_state0 = {LangChainMode.MY_DATA.value: [None, None]}
669
+ selection_docs_state0 = dict(visible_langchain_modes=visible_langchain_modes,
670
+ langchain_mode_paths=langchain_mode_paths,
671
+ langchain_modes=langchain_modes)
672
+ selection_docs_state = selection_docs_state0
673
+ langchain_modes0 = langchain_modes
674
+ langchain_mode_paths0 = langchain_mode_paths
675
+ visible_langchain_modes0 = visible_langchain_modes
676
 
677
  if cli:
678
  from cli import run_cli
 
1041
  client = gr_client or hf_client
1042
  # Don't return None, None for model, tokenizer so triggers
1043
  return client, tokenizer, 'http'
1044
+ if isinstance(inference_server, str) and (
1045
+ inference_server.startswith('openai') or inference_server.startswith('vllm')):
1046
+ if inference_server.startswith('openai'):
1047
+ assert os.getenv('OPENAI_API_KEY'), "Set environment for OPENAI_API_KEY"
1048
+ # Don't return None, None for model, tokenizer so triggers
1049
+ # include small token cushion
1050
+ tokenizer = FakeTokenizer(model_max_length=model_token_mapping[base_model] - 50)
1051
  return inference_server, tokenizer, inference_server
1052
  assert not inference_server, "Malformed inference_server=%s" % inference_server
1053
  if base_model in non_hf_types:
 
1331
  def evaluate(
1332
  model_state,
1333
  my_db_state,
1334
+ selection_docs_state,
1335
  # START NOTE: Examples must have same order of parameters
1336
  instruction,
1337
  iinput,
 
1354
  instruction_nochat,
1355
  iinput_nochat,
1356
  langchain_mode,
1357
+ add_chat_history_to_context,
1358
  langchain_action,
1359
+ langchain_agents,
1360
  top_k_docs,
1361
  chunk,
1362
  chunk_size,
 
1370
  save_dir=None,
1371
  sanitize_bot_response=False,
1372
  model_state0=None,
1373
+ langchain_modes0=None,
1374
+ langchain_mode_paths0=None,
1375
+ visible_langchain_modes0=None,
1376
  memory_restriction_level=None,
1377
  max_max_new_tokens=None,
1378
  is_public=None,
 
1380
  raise_generate_gpu_exceptions=None,
1381
  chat_context=None,
1382
  lora_weights=None,
1383
+ use_llm_if_no_docs=False,
1384
  load_db_if_exists=True,
1385
  dbs=None,
 
1386
  detect_user_path_changes_every_query=None,
1387
  use_openai_embedding=None,
1388
  use_openai_model=None,
1389
  hf_embedding_model=None,
1390
+ cut_distance=None,
1391
  db_type=None,
1392
  n_jobs=None,
1393
  first_para=None,
 
1416
  assert chunk_size is not None and isinstance(chunk_size, int)
1417
  assert n_jobs is not None
1418
  assert first_para is not None
1419
+ assert isinstance(add_chat_history_to_context, bool)
1420
+
1421
+ if selection_docs_state is not None:
1422
+ langchain_modes = selection_docs_state.get('langchain_modes', langchain_modes0)
1423
+ langchain_mode_paths = selection_docs_state.get('langchain_mode_paths', langchain_mode_paths0)
1424
+ visible_langchain_modes = selection_docs_state.get('visible_langchain_modes', visible_langchain_modes0)
1425
+ else:
1426
+ langchain_modes = langchain_modes0
1427
+ langchain_mode_paths = langchain_mode_paths0
1428
+ visible_langchain_modes = visible_langchain_modes0
1429
 
1430
  if debug:
1431
  locals_dict = locals().copy()
 
1545
  # THIRD PLACE where LangChain referenced, but imports only occur if enabled and have db to use
1546
  assert langchain_mode in langchain_modes, "Invalid langchain_mode %s" % langchain_mode
1547
  assert langchain_action in langchain_actions, "Invalid langchain_action %s" % langchain_action
1548
+ assert len(
1549
+ set(langchain_agents).difference(langchain_agents_list)) == 0, "Invalid langchain_agents %s" % langchain_agents
1550
+ if dbs is not None and langchain_mode in dbs:
1551
+ db = dbs[langchain_mode]
1552
+ elif my_db_state is not None and langchain_mode in my_db_state:
1553
+ db1 = my_db_state[langchain_mode]
1554
+ if db1 is not None and len(db1) == 2:
1555
+ db = db1[0]
1556
+ else:
1557
+ db = None
1558
  else:
1559
+ db = None
1560
+ do_langchain_path = langchain_mode not in [False, 'Disabled', 'LLM'] or \
1561
  base_model in non_hf_types or \
1562
  force_langchain_evaluate
1563
  if do_langchain_path:
1564
  outr = ""
1565
+ # use smaller cut_distance for wiki_full since so many matches could be obtained, and often irrelevant unless close
1566
  from gpt_langchain import run_qa_db
1567
  gen_hyper_langchain = dict(do_sample=do_sample,
1568
  temperature=temperature,
 
1583
  inference_server=inference_server,
1584
  stream_output=stream_output,
1585
  prompter=prompter,
1586
+ use_llm_if_no_docs=use_llm_if_no_docs,
1587
  load_db_if_exists=load_db_if_exists,
1588
+ db=db,
1589
+ langchain_mode_paths=langchain_mode_paths,
1590
  detect_user_path_changes_every_query=detect_user_path_changes_every_query,
1591
+ cut_distance=1.1 if langchain_mode in ['wiki_full'] else cut_distance,
1592
+ add_chat_history_to_context=add_chat_history_to_context,
1593
  use_openai_embedding=use_openai_embedding,
1594
  use_openai_model=use_openai_model,
1595
  hf_embedding_model=hf_embedding_model,
 
1599
  chunk_size=chunk_size,
1600
  langchain_mode=langchain_mode,
1601
  langchain_action=langchain_action,
1602
+ langchain_agents=langchain_agents,
1603
  document_subset=document_subset,
1604
  document_choice=document_choice,
1605
  db_type=db_type,
 
1628
  inference_server=inference_server,
1629
  langchain_mode=langchain_mode,
1630
  langchain_action=langchain_action,
1631
+ langchain_agents=langchain_agents,
1632
  document_subset=document_subset,
1633
  document_choice=document_choice,
1634
  num_prompt_tokens=num_prompt_tokens,
 
1652
  clear_torch_cache()
1653
  return
1654
 
1655
+ if inference_server.startswith('vllm') or inference_server.startswith('openai') or inference_server.startswith(
1656
+ 'http'):
1657
+ if inference_server.startswith('vllm') or inference_server.startswith('openai'):
1658
  where_from = "openai_client"
1659
+ openai, inf_type = set_openai(inference_server)
1660
 
 
1661
  terminate_response = prompter.terminate_response or []
1662
  stop_sequences = list(set(terminate_response + [prompter.PreResponse]))
1663
  stop_sequences = [x for x in stop_sequences if x]
 
1670
  n=num_return_sequences,
1671
  presence_penalty=1.07 - repetition_penalty + 0.6, # so good default
1672
  )
1673
+ if inf_type == 'vllm' or inference_server == 'openai':
1674
  response = openai.Completion.create(
1675
  model=base_model,
1676
  prompt=prompt,
 
1693
  yield dict(response=prompter.get_response(prompt + text, prompt=prompt,
1694
  sanitize_bot_response=sanitize_bot_response),
1695
  sources='')
1696
+ elif inf_type == 'vllm_chat' or inference_server == 'openai_chat':
1697
+ if inf_type == 'vllm_chat':
1698
+ raise NotImplementedError('%s not supported by vLLM' % inf_type)
1699
  response = openai.ChatCompletion.create(
1700
  model=base_model,
1701
  messages=[
 
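A minimal sketch of the routing in the hunk above, assuming set_openai() returns the patched openai module plus an inf_type tag ('vllm', 'vllm_chat', 'openai', 'openai_chat') as used in the diff; the helper name and generation kwargs are illustrative, not part of the commit.

    def pick_completion_api(inference_server, prompt, base_model):
        # set_openai() points the openai module at either api.openai.com or the vLLM server
        openai, inf_type = set_openai(inference_server)
        if inf_type == 'vllm' or inference_server == 'openai':
            # plain completion endpoint; prompt is already fully formatted by the prompter
            return openai.Completion.create(model=base_model, prompt=prompt, max_tokens=256)
        elif inf_type == 'vllm_chat' or inference_server == 'openai_chat':
            if inf_type == 'vllm_chat':
                raise NotImplementedError('%s not supported by vLLM' % inf_type)
            # chat endpoint wraps the prompt as a single user message
            return openai.ChatCompletion.create(model=base_model,
                                                messages=[dict(role='user', content=prompt)])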
1747
  chat_client = False
1748
  where_from = "gr_client"
1749
  client_langchain_mode = 'Disabled'
1750
+ client_add_chat_history_to_context = True
1751
  client_langchain_action = LangChainAction.QUERY.value
1752
+ client_langchain_agents = []
1753
  gen_server_kwargs = dict(temperature=temperature,
1754
  top_p=top_p,
1755
  top_k=top_k,
 
1801
  instruction_nochat=gr_prompt if not chat_client else '',
1802
  iinput_nochat=gr_iinput, # only for chat=False
1803
  langchain_mode=client_langchain_mode,
1804
+ add_chat_history_to_context=client_add_chat_history_to_context,
1805
  langchain_action=client_langchain_action,
1806
+ langchain_agents=client_langchain_agents,
1807
  top_k_docs=top_k_docs,
1808
  chunk=chunk,
1809
  chunk_size=chunk_size,
1810
+ document_subset=DocumentSubset.Relevant.name,
1811
+ document_choice=[DocumentChoice.ALL.value],
1812
  )
1813
  api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
1814
  if not stream_output:
 
2102
 
2103
 
2104
  inputs_list_names = list(inspect.signature(evaluate).parameters)
2105
+ state_names = ['model_state', 'my_db_state', 'selection_docs_state']
2106
  inputs_kwargs_list = [x for x in inputs_list_names if x not in eval_func_param_names + state_names]
2107
 
2108
 
 
2385
 
2386
  # move to correct position
2387
  for example in examples:
2388
+ example += [chat, '', '', LangChainMode.DISABLED.value, True, LangChainAction.QUERY.value, [],
2389
+ top_k_docs, chunk, chunk_size, DocumentSubset.Relevant.name, []
2390
  ]
2391
  # adjust examples if non-chat mode
2392
  if not chat:
 
2446
  truncation=True,
2447
  max_length=max_length_tokenize).to(smodel.device)
2448
  try:
2449
+ score = torch.sigmoid(smodel(**inputs.to(smodel.device)).logits[0].float()).cpu().detach().numpy()[0]
2450
  except torch.cuda.OutOfMemoryError as e:
2451
  print("GPU OOM 3: question: %s answer: %s exception: %s" % (question, answer, str(e)), flush=True)
2452
  del inputs
 
2492
 
2493
 
2494
  def get_model_max_length(model_state):
2495
+ if not isinstance(model_state['tokenizer'], (str, type(None))):
2496
  return model_state['tokenizer'].model_max_length
2497
  else:
2498
  return 2048
2499
 
2500
 
2501
  def get_max_max_new_tokens(model_state, **kwargs):
2502
+ if not isinstance(model_state['tokenizer'], (str, type(None))):
2503
  max_max_new_tokens = model_state['tokenizer'].model_max_length
2504
  else:
2505
  max_max_new_tokens = None
 
2531
  return min_top_k_docs, max_top_k_docs, label_top_k_docs
2532
 
2533
 
2534
+ def history_to_context(history, langchain_mode1,
2535
+ add_chat_history_to_context,
2536
+ prompt_type1, prompt_dict1, chat1, model_max_length1,
2537
  memory_restriction_level1, keep_sources_in_context1):
2538
  """
2539
  consumes all history up to (but not including) latest history item that is presumed to be an [instruction, None] pair
2540
  :param history:
2541
  :param langchain_mode1:
2542
+ :param add_chat_history_to_context:
2543
  :param prompt_type1:
2544
  :param prompt_dict1:
2545
  :param chat1:
 
2552
  _, _, _, max_prompt_length = get_cutoffs(memory_restriction_level1,
2553
  for_context=True, model_max_length=model_max_length1)
2554
  context1 = ''
2555
+ if max_prompt_length is not None and add_chat_history_to_context:
2556
  context1 = ''
2557
  # - 1 below because current instruction already in history from user()
2558
  for histi in range(0, len(history) - 1):
 
2588
  return context1
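A hedged usage sketch for the updated history_to_context() signature; only the parameter list comes from the hunk above, and all argument values below are placeholders.

    # history is a list of [instruction, output] pairs; the last item is the pending instruction
    history = [["What is h2oGPT?", "An open-source LLM project."],
               ["Does it support LangChain?", None]]
    context1 = history_to_context(history, 'UserData',
                                  True,               # add_chat_history_to_context
                                  'human_bot', None,  # prompt_type1, prompt_dict1
                                  True,               # chat1
                                  2048,               # model_max_length1
                                  0,                  # memory_restriction_level1
                                  False)              # keep_sources_in_context1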
2589
 
2590
 
2591
+ def update_langchain(langchain_modes, visible_langchain_modes, langchain_mode_paths, extra):
2592
+ # update from saved state on disk
2593
+ langchain_modes_from_file, visible_langchain_modes_from_file, langchain_mode_paths_from_file = \
2594
+ load_collection_enum(extra)
2595
+
2596
+ visible_langchain_modes_temp = visible_langchain_modes.copy() + visible_langchain_modes_from_file
2597
+ visible_langchain_modes.clear() # don't lose original reference
2598
+ [visible_langchain_modes.append(x) for x in visible_langchain_modes_temp if x not in visible_langchain_modes]
2599
+
2600
+ langchain_mode_paths.update(langchain_mode_paths_from_file)
2601
+
2602
+ langchain_modes_temp = langchain_modes.copy() + langchain_modes_from_file
2603
+ langchain_modes.clear() # don't lose original reference
2604
+ [langchain_modes.append(x) for x in langchain_modes_temp if x not in langchain_modes]
2605
+
2606
+
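The clear()-then-append pattern in update_langchain() keeps the original list objects alive, so anything else holding the same reference (e.g. UI state) sees the merged entries; a small self-contained illustration of that design choice, not part of the diff:

    modes = ['UserData', 'MyData']
    alias = modes                      # another component holds the same list object
    merged = modes.copy() + ['wiki']   # e.g. entries loaded from disk via load_collection_enum()
    modes.clear()                      # don't lose original reference
    [modes.append(x) for x in merged if x not in modes]
    assert alias == ['UserData', 'MyData', 'wiki']   # the aliased reference sees the update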
2607
  def entrypoint_main():
2608
  """
2609
  Examples:
gpt4all_llm.py CHANGED
@@ -95,15 +95,17 @@ def get_llm_gpt4all(model_name,
95
  streaming=False,
96
  callbacks=None,
97
  prompter=None,
 
 
98
  verbose=False,
99
  ):
100
  assert prompter is not None
101
  env_gpt4all_file = ".env_gpt4all"
102
  env_kwargs = dotenv_values(env_gpt4all_file)
103
- n_ctx = env_kwargs.pop('n_ctx', 2048 - max_new_tokens)
104
  default_kwargs = dict(context_erase=0.5,
105
  n_batch=1,
106
- n_ctx=n_ctx,
107
  n_predict=max_new_tokens,
108
  repeat_last_n=64 if repetition_penalty != 1.0 else 0,
109
  repeat_penalty=repetition_penalty,
@@ -117,7 +119,8 @@ def get_llm_gpt4all(model_name,
117
  cls = H2OLlamaCpp
118
  model_path = env_kwargs.pop('model_path_llama') if model is None else model
119
  model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
120
- model_kwargs.update(dict(model_path=model_path, callbacks=callbacks, streaming=streaming, prompter=prompter))
 
121
  llm = cls(**model_kwargs)
122
  llm.client.verbose = verbose
123
  elif model_name == 'gpt4all_llama':
@@ -125,14 +128,16 @@ def get_llm_gpt4all(model_name,
125
  model_path = env_kwargs.pop('model_path_gpt4all_llama') if model is None else model
126
  model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
127
  model_kwargs.update(
128
- dict(model=model_path, backend='llama', callbacks=callbacks, streaming=streaming, prompter=prompter))
 
129
  llm = cls(**model_kwargs)
130
  elif model_name == 'gptj':
131
  cls = H2OGPT4All
132
  model_path = env_kwargs.pop('model_path_gptj') if model is None else model
133
  model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
134
  model_kwargs.update(
135
- dict(model=model_path, backend='gptj', callbacks=callbacks, streaming=streaming, prompter=prompter))
 
136
  llm = cls(**model_kwargs)
137
  else:
138
  raise RuntimeError("No such model_name %s" % model_name)
@@ -142,6 +147,8 @@ def get_llm_gpt4all(model_name,
142
  class H2OGPT4All(gpt4all.GPT4All):
143
  model: Any
144
  prompter: Any
 
 
145
  """Path to the pre-trained GPT4All model file."""
146
 
147
  @root_validator()
@@ -187,10 +194,11 @@ class H2OGPT4All(gpt4all.GPT4All):
187
  **kwargs,
188
  ) -> str:
189
  # Roughly 4 chars per token if natural language
190
- prompt = prompt[-self.n_ctx * 4:]
 
191
 
192
  # use instruct prompting
193
- data_point = dict(context='', instruction=prompt, input='')
194
  prompt = self.prompter.generate_prompt(data_point)
195
 
196
  verbose = False
@@ -206,6 +214,8 @@ from langchain.llms import LlamaCpp
206
  class H2OLlamaCpp(LlamaCpp):
207
  model_path: Any
208
  prompter: Any
 
 
209
  """Path to the pre-trained GPT4All model file."""
210
 
211
  @root_validator()
@@ -276,7 +286,7 @@ class H2OLlamaCpp(LlamaCpp):
276
  print("reduced tokens from %d -> %d" % (num_prompt_tokens, num_prompt_tokens2), flush=True)
277
 
278
  # use instruct prompting
279
- data_point = dict(context='', instruction=prompt, input='')
280
  prompt = self.prompter.generate_prompt(data_point)
281
 
282
  if verbose:
 
95
  streaming=False,
96
  callbacks=None,
97
  prompter=None,
98
+ context='',
99
+ iinput='',
100
  verbose=False,
101
  ):
102
  assert prompter is not None
103
  env_gpt4all_file = ".env_gpt4all"
104
  env_kwargs = dotenv_values(env_gpt4all_file)
105
+ max_tokens = env_kwargs.pop('max_tokens', 2048 - max_new_tokens)
106
  default_kwargs = dict(context_erase=0.5,
107
  n_batch=1,
108
+ max_tokens=max_tokens,
109
  n_predict=max_new_tokens,
110
  repeat_last_n=64 if repetition_penalty != 1.0 else 0,
111
  repeat_penalty=repetition_penalty,
 
119
  cls = H2OLlamaCpp
120
  model_path = env_kwargs.pop('model_path_llama') if model is None else model
121
  model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
122
+ model_kwargs.update(dict(model_path=model_path, callbacks=callbacks, streaming=streaming,
123
+ prompter=prompter, context=context, iinput=iinput))
124
  llm = cls(**model_kwargs)
125
  llm.client.verbose = verbose
126
  elif model_name == 'gpt4all_llama':
 
128
  model_path = env_kwargs.pop('model_path_gpt4all_llama') if model is None else model
129
  model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
130
  model_kwargs.update(
131
+ dict(model=model_path, backend='llama', callbacks=callbacks, streaming=streaming,
132
+ prompter=prompter, context=context, iinput=iinput))
133
  llm = cls(**model_kwargs)
134
  elif model_name == 'gptj':
135
  cls = H2OGPT4All
136
  model_path = env_kwargs.pop('model_path_gptj') if model is None else model
137
  model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
138
  model_kwargs.update(
139
+ dict(model=model_path, backend='gptj', callbacks=callbacks, streaming=streaming,
140
+ prompter=prompter, context=context, iinput=iinput))
141
  llm = cls(**model_kwargs)
142
  else:
143
  raise RuntimeError("No such model_name %s" % model_name)
 
147
  class H2OGPT4All(gpt4all.GPT4All):
148
  model: Any
149
  prompter: Any
150
+ context: Any = ''
151
+ iinput: Any = ''
152
  """Path to the pre-trained GPT4All model file."""
153
 
154
  @root_validator()
 
194
  **kwargs,
195
  ) -> str:
196
  # Roughly 4 chars per token if natural language
197
+ n_ctx = 2048
198
+ prompt = prompt[-self.max_tokens * 4:]
199
 
200
  # use instruct prompting
201
+ data_point = dict(context=self.context, instruction=prompt, input=self.iinput)
202
  prompt = self.prompter.generate_prompt(data_point)
203
 
204
  verbose = False
 
214
  class H2OLlamaCpp(LlamaCpp):
215
  model_path: Any
216
  prompter: Any
217
+ context: Any
218
+ iinput: Any
219
  """Path to the pre-trained GPT4All model file."""
220
 
221
  @root_validator()
 
286
  print("reduced tokens from %d -> %d" % (num_prompt_tokens, num_prompt_tokens2), flush=True)
287
 
288
  # use instruct prompting
289
+ data_point = dict(context=self.context, instruction=prompt, input=self.iinput)
290
  prompt = self.prompter.generate_prompt(data_point)
291
 
292
  if verbose:
gpt_langchain.py CHANGED
@@ -21,16 +21,17 @@ import filelock
21
  from joblib import delayed
22
  from langchain.callbacks import streaming_stdout
23
  from langchain.embeddings import HuggingFaceInstructEmbeddings
 
24
  from tqdm import tqdm
25
 
26
- from enums import DocumentChoices, no_lora_str, model_token_mapping, source_prefix, source_postfix, non_query_commands, \
27
- LangChainAction, LangChainMode
28
  from evaluate_params import gen_hyper
29
  from gen import get_model, SEED
30
  from prompter import non_hf_types, PromptType, Prompter
31
  from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
32
  get_device, ProgressParallel, remove, hash_file, clear_torch_cache, NullContext, get_hf_server, FakeTokenizer, \
33
- have_libreoffice, have_arxiv, have_playwright, have_selenium, have_tesseract, have_pymupdf
34
  from utils_langchain import StreamingGradioCallbackHandler
35
 
36
  import_matplotlib()
@@ -95,11 +96,15 @@ def get_db(sources, use_openai_embedding=False, db_type='faiss',
95
  db = get_existing_db(None, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
96
  hf_embedding_model, verbose=False)
97
  if db is None:
 
 
 
 
98
  db = Chroma.from_documents(documents=sources,
99
  embedding=embedding,
100
  persist_directory=persist_directory,
101
  collection_name=collection_name,
102
- anonymized_telemetry=False)
103
  db.persist()
104
  clear_embedding(db)
105
  save_embed(db, use_openai_embedding, hf_embedding_model)
@@ -276,15 +281,7 @@ from typing import Any, Dict, List, Optional, Set
276
 
277
  from pydantic import Extra, Field, root_validator
278
 
279
- from langchain.callbacks.manager import CallbackManagerForLLMRun
280
-
281
- """Wrapper around Huggingface text generation inference API."""
282
- from functools import partial
283
- from typing import Any, Dict, List, Optional
284
-
285
- from pydantic import Extra, Field, root_validator
286
-
287
- from langchain.callbacks.manager import CallbackManagerForLLMRun
288
  from langchain.llms.base import LLM
289
 
290
 
@@ -312,6 +309,8 @@ class GradioInference(LLM):
312
  sanitize_bot_response: bool = False
313
 
314
  prompter: Any = None
 
 
315
  client: Any = None
316
 
317
  class Config:
@@ -355,13 +354,15 @@ class GradioInference(LLM):
355
  stream_output = self.stream
356
  gr_client = self.client
357
  client_langchain_mode = 'Disabled'
 
358
  client_langchain_action = LangChainAction.QUERY.value
 
359
  top_k_docs = 1
360
  chunk = True
361
  chunk_size = 512
362
  client_kwargs = dict(instruction=prompt if self.chat_client else '', # only for chat=True
363
- iinput='', # only for chat=True
364
- context='',
365
  # streaming output is supported, loops over and outputs each generation in streaming mode
366
  # but leave stream_output=False for simple input/output mode
367
  stream_output=stream_output,
@@ -382,14 +383,16 @@ class GradioInference(LLM):
382
  chat=self.chat_client,
383
 
384
  instruction_nochat=prompt if not self.chat_client else '',
385
- iinput_nochat='', # only for chat=False
386
  langchain_mode=client_langchain_mode,
 
387
  langchain_action=client_langchain_action,
 
388
  top_k_docs=top_k_docs,
389
  chunk=chunk,
390
  chunk_size=chunk_size,
391
- document_subset=DocumentChoices.Relevant.name,
392
- document_choice=[],
393
  )
394
  api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
395
  if not stream_output:
@@ -459,6 +462,8 @@ class H2OHuggingFaceTextGenInference(HuggingFaceTextGenInference):
459
  stream: bool = False
460
  sanitize_bot_response: bool = False
461
  prompter: Any = None
 
 
462
  tokenizer: Any = None
463
  client: Any = None
464
 
@@ -500,7 +505,7 @@ class H2OHuggingFaceTextGenInference(HuggingFaceTextGenInference):
500
  prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)
501
 
502
  # NOTE: TGI server does not add prompting, so must do here
503
- data_point = dict(context='', instruction=prompt, input='')
504
  prompt = self.prompter.generate_prompt(data_point)
505
 
506
  gen_server_kwargs = dict(do_sample=self.do_sample,
@@ -566,6 +571,94 @@ class H2OHuggingFaceTextGenInference(HuggingFaceTextGenInference):
566
 
567
 
568
  from langchain.chat_models import ChatOpenAI
569
 
570
 
571
  class H2OChatOpenAI(ChatOpenAI):
@@ -596,17 +689,36 @@ def get_llm(use_openai_model=False,
596
  prompt_type=None,
597
  prompt_dict=None,
598
  prompter=None,
 
 
599
  sanitize_bot_response=False,
600
  verbose=False,
601
  ):
602
- if use_openai_model or inference_server in ['openai', 'openai_chat']:
 
 
603
  if use_openai_model and model_name is None:
604
  model_name = "gpt-3.5-turbo"
605
- if inference_server == 'openai':
606
- from langchain.llms import OpenAI
607
- cls = OpenAI
608
- else:
609
  cls = H2OChatOpenAI
 
610
  callbacks = [StreamingGradioCallbackHandler()]
611
  llm = cls(model_name=model_name,
612
  temperature=temperature if do_sample else 0,
@@ -616,11 +728,18 @@ def get_llm(use_openai_model=False,
616
  frequency_penalty=0,
617
  presence_penalty=1.07 - repetition_penalty + 0.6, # so good default
618
  callbacks=callbacks if stream_output else None,
 
 
 
 
 
 
619
  )
620
  streamer = callbacks[0] if stream_output else None
621
  if inference_server in ['openai', 'openai_chat']:
622
  prompt_type = inference_server
623
  else:
 
624
  prompt_type = prompt_type or 'plain'
625
  elif inference_server:
626
  assert inference_server.startswith(
@@ -669,6 +788,8 @@ def get_llm(use_openai_model=False,
669
  callbacks=callbacks if stream_output else None,
670
  stream=stream_output,
671
  prompter=prompter,
 
 
672
  client=gr_client,
673
  sanitize_bot_response=sanitize_bot_response,
674
  )
@@ -689,6 +810,8 @@ def get_llm(use_openai_model=False,
689
  callbacks=callbacks if stream_output else None,
690
  stream=stream_output,
691
  prompter=prompter,
 
 
692
  tokenizer=tokenizer,
693
  client=hf_client,
694
  timeout=max_time,
@@ -721,6 +844,8 @@ def get_llm(use_openai_model=False,
721
  verbose=verbose,
722
  streaming=stream_output,
723
  prompter=prompter,
 
 
724
  )
725
  else:
726
  if model is None:
@@ -763,6 +888,8 @@ def get_llm(use_openai_model=False,
763
  from h2oai_pipeline import H2OTextGenerationPipeline
764
  pipe = H2OTextGenerationPipeline(model=model, use_prompter=True,
765
  prompter=prompter,
 
 
766
  prompt_type=prompt_type,
767
  prompt_dict=prompt_dict,
768
  sanitize_bot_response=sanitize_bot_response,
@@ -916,7 +1043,6 @@ def get_dai_docs(from_hf=False, get_pickle=True):
916
  return sources
917
 
918
 
919
-
920
  image_types = ["png", "jpg", "jpeg"]
921
  non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
922
  "md",
@@ -927,7 +1053,8 @@ non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
927
  ]
928
  # "msg", GPL3
929
 
930
- if have_libreoffice:
 
931
  non_image_types.extend(["docx", "doc", "xls", "xlsx"])
932
 
933
  file_types = non_image_types + image_types
@@ -936,9 +1063,11 @@ file_types = non_image_types + image_types
936
  def add_meta(docs1, file):
937
  file_extension = pathlib.Path(file).suffix
938
  hashid = hash_file(file)
 
939
  if not isinstance(docs1, (list, tuple, types.GeneratorType)):
940
  docs1 = [docs1]
941
- [x.metadata.update(dict(input_type=file_extension, date=str(datetime.now()), hashid=hashid)) for x in docs1]
 
942
 
943
 
944
  def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
@@ -946,7 +1075,7 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
946
  is_url=False, is_txt=False,
947
  enable_captions=True,
948
  captions_model=None,
949
- enable_ocr=False, caption_loader=None,
950
  headsize=50):
951
  if file is None:
952
  if fail_any_exception:
@@ -963,6 +1092,7 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
963
  base_name = sanitize_filename(base_name) + "_" + str(uuid.uuid4())[:10]
964
  base_path = os.path.join(dir_name, base_name)
965
  if is_url:
 
966
  if file.lower().startswith('arxiv:'):
967
  query = file.lower().split('arxiv:')
968
  if len(query) == 2 and have_arxiv:
@@ -1011,11 +1141,11 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
1011
  add_meta(docs1, file)
1012
  docs1 = clean_doc(docs1)
1013
  doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size, language=Language.HTML)
1014
- elif (file.lower().endswith('.docx') or file.lower().endswith('.doc')) and have_libreoffice:
1015
  docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
1016
  add_meta(docs1, file)
1017
  doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
1018
- elif (file.lower().endswith('.xlsx') or file.lower().endswith('.xls')) and have_libreoffice:
1019
  docs1 = UnstructuredExcelLoader(file_path=file).load()
1020
  add_meta(docs1, file)
1021
  doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
@@ -1114,21 +1244,54 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
1114
  from dotenv import dotenv_values
1115
  env_kwargs = dotenv_values(env_gpt4all_file)
1116
  pdf_class_name = env_kwargs.get('PDF_CLASS_NAME', 'PyMuPDFParser')
 
 
1117
  if have_pymupdf and pdf_class_name == 'PyMuPDFParser':
1118
  # GPL, only use if installed
1119
  from langchain.document_loaders import PyMuPDFLoader
1120
  # load() still chunks by pages, but every page has title at start to help
1121
  doc1 = PyMuPDFLoader(file).load()
 
 
 
1122
  doc1 = clean_doc(doc1)
1123
- elif pdf_class_name == 'UnstructuredPDFLoader':
1124
  doc1 = UnstructuredPDFLoader(file).load()
 
 
 
1125
  # seems to not need cleaning in most cases
1126
- else:
1127
  # open-source fallback
1128
  # load() still chunks by pages, but every page has title at start to help
1129
  doc1 = PyPDFLoader(file).load()
 
 
 
 
 
 
 
 
 
 
 
 
1130
  doc1 = clean_doc(doc1)
 
 
 
 
 
 
 
1131
  # Some PDFs return nothing or junk from PDFMinerLoader
 
 
 
 
 
 
1132
  doc1 = chunk_sources(doc1, chunk=chunk, chunk_size=chunk_size)
1133
  add_meta(doc1, file)
1134
  elif file.lower().endswith('.csv'):
@@ -1181,7 +1344,7 @@ def path_to_doc1(file, verbose=False, fail_any_exception=False, return_file=True
1181
  is_url=False, is_txt=False,
1182
  enable_captions=True,
1183
  captions_model=None,
1184
- enable_ocr=False, caption_loader=None):
1185
  if verbose:
1186
  if is_url:
1187
  print("Ingesting URL: %s" % file, flush=True)
@@ -1199,6 +1362,7 @@ def path_to_doc1(file, verbose=False, fail_any_exception=False, return_file=True
1199
  enable_captions=enable_captions,
1200
  captions_model=captions_model,
1201
  enable_ocr=enable_ocr,
 
1202
  caption_loader=caption_loader)
1203
  except BaseException as e:
1204
  print("Failed to ingest %s due to %s" % (file, traceback.format_exc()))
@@ -1207,7 +1371,7 @@ def path_to_doc1(file, verbose=False, fail_any_exception=False, return_file=True
1207
  else:
1208
  exception_doc = Document(
1209
  page_content='',
1210
- metadata={"source": file, "exception": '%s hit %s' % (file, str(e)),
1211
  "traceback": traceback.format_exc()})
1212
  res = [exception_doc]
1213
  if return_file:
@@ -1228,6 +1392,7 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
1228
  captions_model=None,
1229
  caption_loader=None,
1230
  enable_ocr=False,
 
1231
  existing_files=[],
1232
  existing_hash_ids={},
1233
  ):
@@ -1249,11 +1414,15 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
1249
  [globs_non_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True))
1250
  for ftype in non_image_types]
1251
  else:
1252
- if isinstance(path_or_paths, str) and (os.path.isfile(path_or_paths) or os.path.isdir(path_or_paths)):
1253
- path_or_paths = [path_or_paths]
 
 
 
 
1254
  # list/tuple of files (consume what can, and exception those that selected but cannot consume so user knows)
1255
- assert isinstance(path_or_paths, (list, tuple, types.GeneratorType)), "Wrong type for path_or_paths: %s" % type(
1256
- path_or_paths)
1257
  # reform out of allowed types
1258
  globs_image_types.extend(flatten_list([[x for x in path_or_paths if x.endswith(y)] for y in image_types]))
1259
  # could do below:
@@ -1305,6 +1474,7 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
1305
  captions_model=captions_model,
1306
  caption_loader=caption_loader,
1307
  enable_ocr=enable_ocr,
 
1308
  )
1309
 
1310
  if n_jobs != 1 and len(globs_non_image_types) > 1:
@@ -1337,7 +1507,7 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
1337
  with open(fil, 'rb') as f:
1338
  documents.extend(pickle.load(f))
1339
  # remove temp pickle
1340
- os.remove(fil)
1341
  else:
1342
  documents = reduce(concat, documents)
1343
  return documents
@@ -1345,7 +1515,7 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
1345
 
1346
  def prep_langchain(persist_directory,
1347
  load_db_if_exists,
1348
- db_type, use_openai_embedding, langchain_mode, user_path,
1349
  hf_embedding_model, n_jobs=-1, kwargs_make_db={}):
1350
  """
1351
  do prep first time, involving downloads
@@ -1355,6 +1525,7 @@ def prep_langchain(persist_directory,
1355
  assert langchain_mode not in ['MyData'], "Should not prep scratch data"
1356
 
1357
  db_dir_exists = os.path.isdir(persist_directory)
 
1358
 
1359
  if db_dir_exists and user_path is None:
1360
  print("Prep: persist_directory=%s exists, using" % persist_directory, flush=True)
@@ -1490,7 +1661,7 @@ def make_db(**langchain_kwargs):
1490
  langchain_kwargs[k] = defaults_db[k]
1491
  # final check for missing
1492
  missing_kwargs = [x for x in func_names if x not in langchain_kwargs]
1493
- assert not missing_kwargs, "Missing kwargs: %s" % missing_kwargs
1494
  # only keep actual used
1495
  langchain_kwargs = {k: v for k, v in langchain_kwargs.items() if k in func_names}
1496
  return _make_db(**langchain_kwargs)
@@ -1524,13 +1695,14 @@ def _make_db(use_openai_embedding=False,
1524
  first_para=False, text_limit=None,
1525
  chunk=True, chunk_size=512,
1526
  langchain_mode=None,
1527
- user_path=None,
1528
  db_type='faiss',
1529
  load_db_if_exists=True,
1530
  db=None,
1531
  n_jobs=-1,
1532
  verbose=False):
1533
  persist_directory = get_persist_directory(langchain_mode)
 
1534
  # see if can get persistent chroma db
1535
  db_trial = get_existing_db(db, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
1536
  hf_embedding_model, verbose=verbose)
@@ -1538,23 +1710,8 @@ def _make_db(use_openai_embedding=False,
1538
  db = db_trial
1539
 
1540
  sources = []
1541
- if not db and langchain_mode not in ['MyData'] or \
1542
- user_path is not None and \
1543
- langchain_mode in ['UserData']:
1544
- # Should not make MyData db this way, why avoided, only upload from UI
1545
- assert langchain_mode not in ['MyData'], "Should not make MyData db this way"
1546
- if verbose:
1547
- if langchain_mode in ['UserData']:
1548
- if user_path is not None:
1549
- print("Checking if changed or new sources in %s, and generating sources them" % user_path,
1550
- flush=True)
1551
- elif db is None:
1552
- print("user_path not passed and no db, no sources", flush=True)
1553
- else:
1554
- print("user_path not passed, using only existing db, no new sources", flush=True)
1555
- else:
1556
- print("Generating %s sources" % langchain_mode, flush=True)
1557
- if langchain_mode in ['wiki_full', 'All', "'All'"]:
1558
  from read_wiki_full import get_all_documents
1559
  small_test = None
1560
  print("Generating new wiki", flush=True)
@@ -1564,55 +1721,48 @@ def _make_db(use_openai_embedding=False,
1564
  sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
1565
  print("Chunked new wiki", flush=True)
1566
  sources.extend(sources1)
1567
- if langchain_mode in ['wiki', 'All', "'All'"]:
1568
  sources1 = get_wiki_sources(first_para=first_para, text_limit=text_limit)
1569
  if chunk:
1570
  sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
1571
  sources.extend(sources1)
1572
- if langchain_mode in ['github h2oGPT', 'All', "'All'"]:
1573
  # sources = get_github_docs("dagster-io", "dagster")
1574
  sources1 = get_github_docs("h2oai", "h2ogpt")
1575
  # FIXME: always chunk for now
1576
  sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
1577
  sources.extend(sources1)
1578
- if langchain_mode in ['DriverlessAI docs', 'All', "'All'"]:
1579
  sources1 = get_dai_docs(from_hf=True)
1580
  if chunk and False: # FIXME: DAI docs are already chunked well, should only chunk more if over limit
1581
  sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
1582
  sources.extend(sources1)
1583
- if langchain_mode in ['All', 'UserData']:
1584
- if user_path:
1585
- if db is not None:
1586
- # NOTE: Ignore file names for now, only go by hash ids
1587
- # existing_files = get_existing_files(db)
1588
- existing_files = []
1589
- existing_hash_ids = get_existing_hash_ids(db)
1590
- else:
1591
- # pretend no existing files so won't filter
1592
- existing_files = []
1593
- existing_hash_ids = []
1594
- # chunk internally for speed over multiple docs
1595
- # FIXME: If first had old Hash=None and switch embeddings,
1596
- # then re-embed, and then hit here and reload so have hash, and then re-embed.
1597
- sources1 = path_to_docs(user_path, n_jobs=n_jobs, chunk=chunk, chunk_size=chunk_size,
1598
- existing_files=existing_files, existing_hash_ids=existing_hash_ids)
1599
- new_metadata_sources = set([x.metadata['source'] for x in sources1])
1600
- if new_metadata_sources:
1601
- print("Loaded %s new files as sources to add to UserData" % len(new_metadata_sources), flush=True)
1602
- if verbose:
1603
- print("Files added: %s" % '\n'.join(new_metadata_sources), flush=True)
1604
- sources.extend(sources1)
1605
- print("Loaded %s sources for potentially adding to UserData" % len(sources), flush=True)
1606
- else:
1607
- print("Chose UserData but user_path is empty/None", flush=True)
1608
- if False and langchain_mode in ['urls', 'All', "'All'"]:
1609
- # from langchain.document_loaders import UnstructuredURLLoader
1610
- # loader = UnstructuredURLLoader(urls=urls)
1611
- urls = ["https://www.birdsongsf.com/who-we-are/"]
1612
- from langchain.document_loaders import PlaywrightURLLoader
1613
- loader = PlaywrightURLLoader(urls=urls, remove_selectors=["header", "footer"])
1614
- sources1 = loader.load()
1615
- sources.extend(sources1)
1616
  if not sources:
1617
  if verbose:
1618
  if db is not None:
@@ -1635,7 +1785,7 @@ def _make_db(use_openai_embedding=False,
1635
  else:
1636
  print("Did not generate db since no sources", flush=True)
1637
  new_sources_metadata = [x.metadata for x in sources]
1638
- elif user_path is not None and langchain_mode in ['UserData']:
1639
  print("Existing db, potentially adding %s sources from user_path=%s" % (len(sources), user_path), flush=True)
1640
  db, num_new_sources, new_sources_metadata = add_to_db(db, sources, db_type=db_type,
1641
  use_openai_embedding=use_openai_embedding,
@@ -1733,7 +1883,7 @@ def run_qa_db(**kwargs):
1733
  kwargs['answer_with_sources'] = True
1734
  kwargs['show_rank'] = False
1735
  missing_kwargs = [x for x in func_names if x not in kwargs]
1736
- assert not missing_kwargs, "Missing kwargs: %s" % missing_kwargs
1737
  # only keep actual used
1738
  kwargs = {k: v for k, v in kwargs.items() if k in func_names}
1739
  try:
@@ -1747,7 +1897,7 @@ def _run_qa_db(query=None,
1747
  context=None,
1748
  use_openai_model=False, use_openai_embedding=False,
1749
  first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
1750
- user_path=None,
1751
  detect_user_path_changes_every_query=False,
1752
  db_type='faiss',
1753
  model_name=None, model=None, tokenizer=None, inference_server=None,
@@ -1757,9 +1907,11 @@ def _run_qa_db(query=None,
1757
  prompt_type=None,
1758
  prompt_dict=None,
1759
  answer_with_sources=True,
1760
- cut_distanct=1.1,
 
1761
  sanitize_bot_response=False,
1762
  show_rank=False,
 
1763
  load_db_if_exists=False,
1764
  db=None,
1765
  do_sample=False,
@@ -1775,8 +1927,9 @@ def _run_qa_db(query=None,
1775
  num_return_sequences=1,
1776
  langchain_mode=None,
1777
  langchain_action=None,
1778
- document_subset=DocumentChoices.Relevant.name,
1779
- document_choice=[],
 
1780
  n_jobs=-1,
1781
  verbose=False,
1782
  cli=False,
@@ -1795,7 +1948,7 @@ def _run_qa_db(query=None,
1795
  :param top_k_docs:
1796
  :param chunk:
1797
  :param chunk_size:
1798
- :param user_path: user path to glob recursively from
1799
  :param db_type: 'faiss' for in-memory db or 'chroma' or 'weaviate' for persistent db
1800
  :param model_name: model name, used to switch behaviors
1801
  :param model: pre-initialized model, else will make new one
@@ -1803,6 +1956,7 @@ def _run_qa_db(query=None,
1803
  :param answer_with_sources
1804
  :return:
1805
  """
 
1806
  if model is not None:
1807
  assert model_name is not None # require so can make decisions
1808
  assert query is not None
@@ -1817,6 +1971,8 @@ def _run_qa_db(query=None,
1817
  else:
1818
  prompt_dict = ''
1819
  assert len(set(gen_hyper).difference(inspect.signature(get_llm).parameters)) == 0
 
 
1820
  llm, model_name, streamer, prompt_type_out = get_llm(use_openai_model=use_openai_model, model_name=model_name,
1821
  model=model,
1822
  tokenizer=tokenizer,
@@ -1836,11 +1992,13 @@ def _run_qa_db(query=None,
1836
  prompt_type=prompt_type,
1837
  prompt_dict=prompt_dict,
1838
  prompter=prompter,
 
 
1839
  sanitize_bot_response=sanitize_bot_response,
1840
  verbose=verbose,
1841
  )
1842
 
1843
- use_context = False
1844
  scores = []
1845
  chain = None
1846
 
@@ -1852,25 +2010,29 @@ def _run_qa_db(query=None,
1852
  sim_kwargs = {k: v for k, v in locals().items() if k in func_names}
1853
  missing_kwargs = [x for x in func_names if x not in sim_kwargs]
1854
  assert not missing_kwargs, "Missing: %s" % missing_kwargs
1855
- docs, chain, scores, use_context, have_any_docs = get_chain(**sim_kwargs)
1856
  if document_subset in non_query_commands:
1857
  formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
 
 
 
 
1858
  yield formatted_doc_chunks, ''
1859
  return
1860
- if not docs and langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
1861
- LangChainAction.SUMMARIZE_ALL.value,
1862
- LangChainAction.SUMMARIZE_REFINE.value]:
1863
- ret = 'No relevant documents to summarize.' if have_any_docs else 'No documents to summarize.'
1864
- extra = ''
1865
- yield ret, extra
1866
- return
1867
- if not docs and langchain_mode not in [LangChainMode.DISABLED.value,
1868
- LangChainMode.CHAT_LLM.value,
1869
- LangChainMode.LLM.value]:
1870
- ret = 'No relevant documents to query.' if have_any_docs else 'No documents to query.'
1871
- extra = ''
1872
- yield ret, extra
1873
- return
1874
 
1875
  if chain is None and model_name not in non_hf_types:
1876
  # here if no docs at all and not HF type
@@ -1921,7 +2083,7 @@ def _run_qa_db(query=None,
1921
  else:
1922
  answer = chain()
1923
 
1924
- if not use_context:
1925
  ret = answer['output_text']
1926
  extra = ''
1927
  yield ret, extra
@@ -1933,9 +2095,10 @@ def _run_qa_db(query=None,
1933
 
1934
  def get_chain(query=None,
1935
  iinput=None,
 
1936
  use_openai_model=False, use_openai_embedding=False,
1937
  first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
1938
- user_path=None,
1939
  detect_user_path_changes_every_query=False,
1940
  db_type='faiss',
1941
  model_name=None,
@@ -1943,13 +2106,15 @@ def get_chain(query=None,
1943
  hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
1944
  prompt_type=None,
1945
  prompt_dict=None,
1946
- cut_distanct=1.1,
 
1947
  load_db_if_exists=False,
1948
  db=None,
1949
  langchain_mode=None,
1950
  langchain_action=None,
1951
- document_subset=DocumentChoices.Relevant.name,
1952
- document_choice=[],
 
1953
  n_jobs=-1,
1954
  # beyond run_db_query:
1955
  llm=None,
@@ -1961,14 +2126,15 @@ def get_chain(query=None,
1961
  auto_reduce_chunks=True,
1962
  max_chunks=100,
1963
  ):
 
1964
  # determine whether use of context out of docs is planned
1965
  if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
1966
- if langchain_mode in ['Disabled', 'ChatLLM', 'LLM']:
1967
- use_context = False
1968
  else:
1969
- use_context = True
1970
  else:
1971
- use_context = True
1972
 
1973
  # https://github.com/hwchase17/langchain/issues/1946
1974
  # FIXME: Seems to way to get size of chroma db to limit top_k_docs to avoid
@@ -1985,14 +2151,17 @@ def get_chain(query=None,
1985
  # avoid looking at user_path during similarity search db handling,
1986
  # if already have db and not updating from user_path every query
1987
  # but if db is None, no db yet loaded (e.g. from prep), so allow user_path to be whatever it was
1988
- user_path = None
 
 
 
1989
  db, num_new_sources, new_sources_metadata = make_db(use_openai_embedding=use_openai_embedding,
1990
  hf_embedding_model=hf_embedding_model,
1991
  first_para=first_para, text_limit=text_limit,
1992
  chunk=chunk,
1993
  chunk_size=chunk_size,
1994
  langchain_mode=langchain_mode,
1995
- user_path=user_path,
1996
  db_type=db_type,
1997
  load_db_if_exists=load_db_if_exists,
1998
  db=db,
@@ -2012,7 +2181,7 @@ def get_chain(query=None,
2012
  else:
2013
  extra = ""
2014
  prefix = ""
2015
- if langchain_mode in ['Disabled', 'ChatLLM', 'LLM'] or not use_context:
2016
  template_if_no_docs = template = """%s{context}{question}""" % prefix
2017
  else:
2018
  template = """%s
@@ -2053,7 +2222,7 @@ def get_chain(query=None,
2053
  else:
2054
  use_template = False
2055
 
2056
- if db and use_context:
2057
  base_path = 'locks'
2058
  makedirs(base_path)
2059
  if hasattr(db, '_persist_directory'):
@@ -2067,10 +2236,10 @@ def get_chain(query=None,
2067
  filter_kwargs = {}
2068
  else:
2069
  assert document_choice is not None, "Document choice was None"
2070
- if len(document_choice) >= 1 and document_choice[0] == DocumentChoices.All.name:
2071
  filter_kwargs = {}
2072
  elif len(document_choice) >= 2:
2073
- if document_choice[0] == DocumentChoices.All.name:
2074
  # remove 'All'
2075
  document_choice = document_choice[1:]
2076
  or_filter = [{"source": {"$eq": x}} for x in document_choice]
@@ -2082,18 +2251,18 @@ def get_chain(query=None,
2082
  else:
2083
  # shouldn't reach
2084
  filter_kwargs = {}
2085
- if langchain_mode in [LangChainMode.LLM.value, LangChainMode.CHAT_LLM.value]:
2086
  docs = []
2087
  scores = []
2088
- elif document_subset == DocumentChoices.All.name or query in [None, '', '\n']:
2089
  db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs)
2090
  # similar to langchain's chroma's _results_to_docs_and_scores
2091
  docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0)
2092
  for result in zip(db_documents, db_metadatas)]
2093
 
2094
  # order documents
2095
- doc_hashes = [x['doc_hash'] for x in db_metadatas]
2096
- doc_chunk_ids = [x['chunk_id'] for x in db_metadatas]
2097
  docs_with_score = [x for _, _, x in
2098
  sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score), key=lambda x: (x[0], x[1]))
2099
  ]
@@ -2173,8 +2342,8 @@ def get_chain(query=None,
2173
  docs_with_score.reverse()
2174
  # cut off so no high distance docs/sources considered
2175
  have_any_docs |= len(docs_with_score) > 0 # before cut
2176
- docs = [x[0] for x in docs_with_score if x[1] < cut_distanct]
2177
- scores = [x[1] for x in docs_with_score if x[1] < cut_distanct]
2178
  if len(scores) > 0 and verbose:
2179
  print("Distance: min: %s max: %s mean: %s median: %s" %
2180
  (scores[0], scores[-1], np.mean(scores), np.median(scores)), flush=True)
@@ -2182,7 +2351,7 @@ def get_chain(query=None,
2182
  docs = []
2183
  scores = []
2184
 
2185
- if not docs and use_context and model_name not in non_hf_types:
2186
  # if HF type and have no docs, can bail out
2187
  return docs, None, [], False, have_any_docs
2188
 
@@ -2205,7 +2374,7 @@ def get_chain(query=None,
2205
 
2206
  if len(docs) == 0:
2207
  # avoid context == in prompt then
2208
- use_context = False
2209
  template = template_if_no_docs
2210
 
2211
  if langchain_action == LangChainAction.QUERY.value:
@@ -2221,7 +2390,7 @@ def get_chain(query=None,
2221
  else:
2222
  # only if use_openai_model = True, unused normally except in testing
2223
  chain = load_qa_with_sources_chain(llm)
2224
- if not use_context:
2225
  chain_kwargs = dict(input_documents=[], question=query)
2226
  else:
2227
  chain_kwargs = dict(input_documents=docs, question=query)
@@ -2248,7 +2417,7 @@ def get_chain(query=None,
2248
  else:
2249
  raise RuntimeError("No such langchain_action=%s" % langchain_action)
2250
 
2251
- return docs, target, scores, use_context, have_any_docs
2252
 
2253
 
2254
  def get_sources_answer(query, answer, scores, show_rank, answer_with_sources, verbose=False):
@@ -2302,6 +2471,7 @@ def clean_doc(docs1):
2302
 
2303
  def chunk_sources(sources, chunk=True, chunk_size=512, language=None):
2304
  if not chunk:
 
2305
  return sources
2306
  if not isinstance(sources, (list, tuple, types.GeneratorType)) and not callable(sources):
2307
  # if just one document
@@ -2320,8 +2490,7 @@ def chunk_sources(sources, chunk=True, chunk_size=512, language=None):
2320
  source_chunks = splitter.split_documents(sources)
2321
 
2322
  # currently in order, but when pull from db won't be, so mark order and document by hash
2323
- doc_hash = str(uuid.uuid4())[:10]
2324
- [x.metadata.update(dict(doc_hash=doc_hash, chunk_id=chunk_id)) for chunk_id, x in enumerate(source_chunks)]
2325
 
2326
  return source_chunks
2327
 
 
21
  from joblib import delayed
22
  from langchain.callbacks import streaming_stdout
23
  from langchain.embeddings import HuggingFaceInstructEmbeddings
24
+ from langchain.schema import LLMResult
25
  from tqdm import tqdm
26
 
27
+ from enums import DocumentSubset, no_lora_str, model_token_mapping, source_prefix, source_postfix, non_query_commands, \
28
+ LangChainAction, LangChainMode, DocumentChoice
29
  from evaluate_params import gen_hyper
30
  from gen import get_model, SEED
31
  from prompter import non_hf_types, PromptType, Prompter
32
  from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
33
  get_device, ProgressParallel, remove, hash_file, clear_torch_cache, NullContext, get_hf_server, FakeTokenizer, \
34
+ have_libreoffice, have_arxiv, have_playwright, have_selenium, have_tesseract, have_pymupdf, set_openai
35
  from utils_langchain import StreamingGradioCallbackHandler
36
 
37
  import_matplotlib()
 
96
  db = get_existing_db(None, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
97
  hf_embedding_model, verbose=False)
98
  if db is None:
99
+ from chromadb.config import Settings
100
+ client_settings = Settings(anonymized_telemetry=False,
101
+ chroma_db_impl="duckdb+parquet",
102
+ persist_directory=persist_directory)
103
  db = Chroma.from_documents(documents=sources,
104
  embedding=embedding,
105
  persist_directory=persist_directory,
106
  collection_name=collection_name,
107
+ client_settings=client_settings)
108
  db.persist()
109
  clear_embedding(db)
110
  save_embed(db, use_openai_embedding, hf_embedding_model)
 
281
 
282
  from pydantic import Extra, Field, root_validator
283
 
284
+ from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks
285
  from langchain.llms.base import LLM
286
 
287
 
 
309
  sanitize_bot_response: bool = False
310
 
311
  prompter: Any = None
312
+ context: Any = ''
313
+ iinput: Any = ''
314
  client: Any = None
315
 
316
  class Config:
 
354
  stream_output = self.stream
355
  gr_client = self.client
356
  client_langchain_mode = 'Disabled'
357
+ client_add_chat_history_to_context = True
358
  client_langchain_action = LangChainAction.QUERY.value
359
+ client_langchain_agents = []
360
  top_k_docs = 1
361
  chunk = True
362
  chunk_size = 512
363
  client_kwargs = dict(instruction=prompt if self.chat_client else '', # only for chat=True
364
+ iinput=self.iinput if self.chat_client else '', # only for chat=True
365
+ context=self.context,
366
  # streaming output is supported, loops over and outputs each generation in streaming mode
367
  # but leave stream_output=False for simple input/output mode
368
  stream_output=stream_output,
 
383
  chat=self.chat_client,
384
 
385
  instruction_nochat=prompt if not self.chat_client else '',
386
+ iinput_nochat=self.iinput if not self.chat_client else '',
387
  langchain_mode=client_langchain_mode,
388
+ add_chat_history_to_context=client_add_chat_history_to_context,
389
  langchain_action=client_langchain_action,
390
+ langchain_agents=client_langchain_agents,
391
  top_k_docs=top_k_docs,
392
  chunk=chunk,
393
  chunk_size=chunk_size,
394
+ document_subset=DocumentSubset.Relevant.name,
395
+ document_choice=[DocumentChoice.ALL.value],
396
  )
397
  api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
398
  if not stream_output:
 
462
  stream: bool = False
463
  sanitize_bot_response: bool = False
464
  prompter: Any = None
465
+ context: Any = ''
466
+ iinput: Any = ''
467
  tokenizer: Any = None
468
  client: Any = None
469
 
 
505
  prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)
506
 
507
  # NOTE: TGI server does not add prompting, so must do here
508
+ data_point = dict(context=self.context, instruction=prompt, input=self.iinput)
509
  prompt = self.prompter.generate_prompt(data_point)
510
 
511
  gen_server_kwargs = dict(do_sample=self.do_sample,
 
571
 
572
 
573
  from langchain.chat_models import ChatOpenAI
574
+ from langchain.llms import OpenAI
575
+ from langchain.llms.openai import _streaming_response_template, completion_with_retry, _update_response, \
576
+ update_token_usage
577
+
578
+
579
+ class H2OOpenAI(OpenAI):
580
+ """
581
+ New class to handle vLLM's use of the OpenAI API; vllm_chat is not supported, so it is only needed here.
582
+ Handles the prompting that OpenAI does not do itself, as well as stop sequences.
583
+ """
584
+ stop_sequences: Any = None
585
+ sanitize_bot_response: bool = False
586
+ prompter: Any = None
587
+ context: Any = ''
588
+ iinput: Any = ''
589
+ tokenizer: Any = None
590
+
591
+ @classmethod
592
+ def all_required_field_names(cls) -> Set:
593
+ all_required_field_names = super(OpenAI, cls).all_required_field_names()
594
+ all_required_field_names.update(
595
+ {'top_p', 'frequency_penalty', 'presence_penalty', 'stop_sequences', 'sanitize_bot_response', 'prompter',
596
+ 'tokenizer'})
597
+ return all_required_field_names
598
+
599
+ def _generate(
600
+ self,
601
+ prompts: List[str],
602
+ stop: Optional[List[str]] = None,
603
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
604
+ **kwargs: Any,
605
+ ) -> LLMResult:
606
+ stop = self.stop_sequences if not stop else self.stop_sequences + stop
607
+
608
+ # HF inference server needs control over input tokens
609
+ assert self.tokenizer is not None
610
+ from h2oai_pipeline import H2OTextGenerationPipeline
611
+ for prompti, prompt in enumerate(prompts):
612
+ prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)
613
+ # NOTE: OpenAI/vLLM server does not add prompting, so must do here
614
+ data_point = dict(context=self.context, instruction=prompt, input=self.iinput)
615
+ prompt = self.prompter.generate_prompt(data_point)
616
+ prompts[prompti] = prompt
617
+
618
+ params = self._invocation_params
619
+ params = {**params, **kwargs}
620
+ sub_prompts = self.get_sub_prompts(params, prompts, stop)
621
+ choices = []
622
+ token_usage: Dict[str, int] = {}
623
+ # Get the token usage from the response.
624
+ # Includes prompt, completion, and total tokens used.
625
+ _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
626
+ text = ''
627
+ for _prompts in sub_prompts:
628
+ if self.streaming:
629
+ text_with_prompt = ""
630
+ prompt = _prompts[0]
631
+ if len(_prompts) > 1:
632
+ raise ValueError("Cannot stream results with multiple prompts.")
633
+ params["stream"] = True
634
+ response = _streaming_response_template()
635
+ first = True
636
+ for stream_resp in completion_with_retry(
637
+ self, prompt=_prompts, **params
638
+ ):
639
+ if first:
640
+ stream_resp["choices"][0]["text"] = prompt + stream_resp["choices"][0]["text"]
641
+ first = False
642
+ text_chunk = stream_resp["choices"][0]["text"]
643
+ text_with_prompt += text_chunk
644
+ text = self.prompter.get_response(text_with_prompt, prompt=prompt,
645
+ sanitize_bot_response=self.sanitize_bot_response)
646
+ if run_manager:
647
+ run_manager.on_llm_new_token(
648
+ text_chunk,
649
+ verbose=self.verbose,
650
+ logprobs=stream_resp["choices"][0]["logprobs"],
651
+ )
652
+ _update_response(response, stream_resp)
653
+ choices.extend(response["choices"])
654
+ else:
655
+ response = completion_with_retry(self, prompt=_prompts, **params)
656
+ choices.extend(response["choices"])
657
+ if not self.streaming:
658
+ # Can't update token usage if streaming
659
+ update_token_usage(_keys, response, token_usage)
660
+ choices[0]['text'] = text
661
+ return self.create_llm_result(choices, prompts, token_usage)
662
 
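A hedged construction example for H2OOpenAI against a vLLM endpoint; the keyword names mirror the get_llm() changes further below, while the URL, model name, prompter, and tokenizer are placeholders.

    llm = H2OOpenAI(model_name='h2oai/h2ogpt-oasst1-512-12b',      # placeholder model
                    temperature=0,
                    max_tokens=256,
                    openai_api_key='EMPTY',                        # vLLM ignores the key
                    openai_api_base='http://localhost:5000/v1',    # assumed vLLM OpenAI-compatible URL
                    logit_bias=None,                               # vLLM path in get_llm()
                    max_retries=2,
                    streaming=False,
                    stop_sequences=['<human>:'],
                    sanitize_bot_response=False,
                    prompter=prompter,                             # existing Prompter instance
                    context='', iinput='',
                    tokenizer=tokenizer)                           # HF tokenizer used to limit prompt length
    text = llm('What is h2oGPT?')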
663
 
664
  class H2OChatOpenAI(ChatOpenAI):
 
689
  prompt_type=None,
690
  prompt_dict=None,
691
  prompter=None,
692
+ context=None,
693
+ iinput=None,
694
  sanitize_bot_response=False,
695
  verbose=False,
696
  ):
697
+ if inference_server is None:
698
+ inference_server = ''
699
+ if use_openai_model or inference_server.startswith('openai') or inference_server.startswith('vllm'):
700
  if use_openai_model and model_name is None:
701
  model_name = "gpt-3.5-turbo"
702
+ # FIXME: Will later import be ignored? I think so, so should be fine
703
+ openai, inf_type = set_openai(inference_server)
704
+ kwargs_extra = {}
705
+ if inference_server == 'openai_chat' or inf_type == 'vllm_chat':
706
  cls = H2OChatOpenAI
707
+ # FIXME: Support context, iinput
708
+ else:
709
+ cls = H2OOpenAI
710
+ if inf_type == 'vllm':
711
+ terminate_response = prompter.terminate_response or []
712
+ stop_sequences = list(set(terminate_response + [prompter.PreResponse]))
713
+ stop_sequences = [x for x in stop_sequences if x]
714
+ kwargs_extra = dict(stop_sequences=stop_sequences,
715
+ sanitize_bot_response=sanitize_bot_response,
716
+ prompter=prompter,
717
+ context=context,
718
+ iinput=iinput,
719
+ tokenizer=tokenizer,
720
+ client=None)
721
+
722
  callbacks = [StreamingGradioCallbackHandler()]
723
  llm = cls(model_name=model_name,
724
  temperature=temperature if do_sample else 0,
 
728
  frequency_penalty=0,
729
  presence_penalty=1.07 - repetition_penalty + 0.6, # so good default
730
  callbacks=callbacks if stream_output else None,
731
+ openai_api_key=openai.api_key,
732
+ openai_api_base=openai.api_base,
733
+ logit_bias=None if inf_type == 'vllm' else {},
734
+ max_retries=2,
735
+ streaming=stream_output,
736
+ **kwargs_extra
737
  )
738
  streamer = callbacks[0] if stream_output else None
739
  if inference_server in ['openai', 'openai_chat']:
740
  prompt_type = inference_server
741
  else:
742
+ # vllm goes here
743
  prompt_type = prompt_type or 'plain'
744
  elif inference_server:
745
  assert inference_server.startswith(
 
788
  callbacks=callbacks if stream_output else None,
789
  stream=stream_output,
790
  prompter=prompter,
791
+ context=context,
792
+ iinput=iinput,
793
  client=gr_client,
794
  sanitize_bot_response=sanitize_bot_response,
795
  )
 
810
  callbacks=callbacks if stream_output else None,
811
  stream=stream_output,
812
  prompter=prompter,
813
+ context=context,
814
+ iinput=iinput,
815
  tokenizer=tokenizer,
816
  client=hf_client,
817
  timeout=max_time,
 
844
  verbose=verbose,
845
  streaming=stream_output,
846
  prompter=prompter,
847
+ context=context,
848
+ iinput=iinput,
849
  )
850
  else:
851
  if model is None:
 
888
  from h2oai_pipeline import H2OTextGenerationPipeline
889
  pipe = H2OTextGenerationPipeline(model=model, use_prompter=True,
890
  prompter=prompter,
891
+ context=context,
892
+ iinput=iinput,
893
  prompt_type=prompt_type,
894
  prompt_dict=prompt_dict,
895
  sanitize_bot_response=sanitize_bot_response,
 
1043
  return sources
1044
 
1045
 
 
1046
  image_types = ["png", "jpg", "jpeg"]
1047
  non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
1048
  "md",
 
1053
  ]
1054
  # "msg", GPL3
1055
 
1056
+ if have_libreoffice or True:
1057
+ # or True so it still tries to load, e.g. on Mac/Windows, even without libreoffice, since these loaders work without it
1058
  non_image_types.extend(["docx", "doc", "xls", "xlsx"])
1059
 
1060
  file_types = non_image_types + image_types
 
1063
  def add_meta(docs1, file):
1064
  file_extension = pathlib.Path(file).suffix
1065
  hashid = hash_file(file)
1066
+ doc_hash = str(uuid.uuid4())[:10]
1067
  if not isinstance(docs1, (list, tuple, types.GeneratorType)):
1068
  docs1 = [docs1]
1069
+ [x.metadata.update(dict(input_type=file_extension, date=str(datetime.now()), hashid=hashid, doc_hash=doc_hash)) for
1070
+ x in docs1]
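A hedged illustration of the metadata add_meta() now attaches, assuming the path points at a file that exists on disk (hash_file() reads it); the path and values are placeholders.

    from langchain.docstore.document import Document
    doc = Document(page_content='example text', metadata=dict(source='user_path/report.pdf'))
    add_meta(doc, 'user_path/report.pdf')
    # doc.metadata now also contains:
    #   input_type='.pdf', date='<timestamp>', hashid='<hash of file bytes>',
    #   doc_hash='<10-char uuid shared by all docs from this file>'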
1071
 
1072
 
1073
  def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
 
1075
  is_url=False, is_txt=False,
1076
  enable_captions=True,
1077
  captions_model=None,
1078
+ enable_ocr=False, enable_pdf_ocr='auto', caption_loader=None,
1079
  headsize=50):
1080
  if file is None:
1081
  if fail_any_exception:
 
1092
  base_name = sanitize_filename(base_name) + "_" + str(uuid.uuid4())[:10]
1093
  base_path = os.path.join(dir_name, base_name)
1094
  if is_url:
1095
+ file = file.strip()  # in case of accidental leading or trailing spaces
1096
  if file.lower().startswith('arxiv:'):
1097
  query = file.lower().split('arxiv:')
1098
  if len(query) == 2 and have_arxiv:
 
1141
  add_meta(docs1, file)
1142
  docs1 = clean_doc(docs1)
1143
  doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size, language=Language.HTML)
1144
+ elif (file.lower().endswith('.docx') or file.lower().endswith('.doc')) and (have_libreoffice or True):
1145
  docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
1146
  add_meta(docs1, file)
1147
  doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
1148
+ elif (file.lower().endswith('.xlsx') or file.lower().endswith('.xls')) and (have_libreoffice or True):
1149
  docs1 = UnstructuredExcelLoader(file_path=file).load()
1150
  add_meta(docs1, file)
1151
  doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
 
1244
  from dotenv import dotenv_values
1245
  env_kwargs = dotenv_values(env_gpt4all_file)
1246
  pdf_class_name = env_kwargs.get('PDF_CLASS_NAME', 'PyMuPDFParser')
1247
+ doc1 = []
1248
+ handled = False
1249
  if have_pymupdf and pdf_class_name == 'PyMuPDFParser':
1250
  # GPL, only use if installed
1251
  from langchain.document_loaders import PyMuPDFLoader
1252
  # load() still chunks by pages, but every page has title at start to help
1253
  doc1 = PyMuPDFLoader(file).load()
1254
+ # remove empty documents
1255
+ handled |= len(doc1) > 0
1256
+ doc1 = [x for x in doc1 if x.page_content]
1257
  doc1 = clean_doc(doc1)
1258
+ if len(doc1) == 0:
1259
  doc1 = UnstructuredPDFLoader(file).load()
1260
+ handled |= len(doc1) > 0
1261
+ # remove empty documents
1262
+ doc1 = [x for x in doc1 if x.page_content]
1263
  # seems to not need cleaning in most cases
1264
+ if len(doc1) == 0:
1265
  # open-source fallback
1266
  # load() still chunks by pages, but every page has title at start to help
1267
  doc1 = PyPDFLoader(file).load()
1268
+ handled |= len(doc1) > 0
1269
+ # remove empty documents
1270
+ doc1 = [x for x in doc1 if x.page_content]
1271
+ doc1 = clean_doc(doc1)
1272
+ if have_pymupdf and len(doc1) == 0:
1273
+ # GPL, only use if installed
1274
+ from langchain.document_loaders import PyMuPDFLoader
1275
+ # load() still chunks by pages, but every page has title at start to help
1276
+ doc1 = PyMuPDFLoader(file).load()
1277
+ handled |= len(doc1) > 0
1278
+ # remove empty documents
1279
+ doc1 = [x for x in doc1 if x.page_content]
1280
  doc1 = clean_doc(doc1)
1281
+ if (len(doc1) == 0 and enable_pdf_ocr == 'auto') or enable_pdf_ocr == 'on':
1282
+ # try OCR in end since slowest, but works on pure image pages well
1283
+ doc1 = UnstructuredPDFLoader(file, strategy='ocr_only').load()
1284
+ handled |= len(doc1) > 0
1285
+ # remove empty documents
1286
+ doc1 = [x for x in doc1 if x.page_content]
1287
+ # seems to not need cleaning in most cases
1288
  # Some PDFs return nothing or junk from PDFMinerLoader
1289
+ if len(doc1) == 0:
1290
+ # if literally nothing, show failed to parse so user knows, since unlikely nothing in PDF at all.
1291
+ if handled:
1292
+ raise ValueError("%s had no valid text, but meta data was parsed" % file)
1293
+ else:
1294
+ raise ValueError("%s had no valid text and no meta data was parsed" % file)
1295
  doc1 = chunk_sources(doc1, chunk=chunk, chunk_size=chunk_size)
1296
  add_meta(doc1, file)
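The PDF branch above now walks a chain of loaders and keeps the first non-empty result; a generic sketch of that fallback pattern (the helper and loader list are illustrative, not part of the diff):

    def load_pdf_with_fallbacks(file, loaders):
        # loaders: callables such as PyMuPDFLoader, UnstructuredPDFLoader, PyPDFLoader, or an OCR variant
        handled = False
        for make_loader in loaders:
            pages = make_loader(file).load()
            handled |= len(pages) > 0                      # something parsed, even if only metadata
            pages = [d for d in pages if d.page_content]   # drop empty pages
            if pages:
                return pages
        if handled:
            raise ValueError("%s had no valid text, but metadata was parsed" % file)
        raise ValueError("%s had no valid text and no metadata was parsed" % file)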
1297
  elif file.lower().endswith('.csv'):
 
1344
  is_url=False, is_txt=False,
1345
  enable_captions=True,
1346
  captions_model=None,
1347
+ enable_ocr=False, enable_pdf_ocr='auto', caption_loader=None):
1348
  if verbose:
1349
  if is_url:
1350
  print("Ingesting URL: %s" % file, flush=True)
 
1362
  enable_captions=enable_captions,
1363
  captions_model=captions_model,
1364
  enable_ocr=enable_ocr,
1365
+ enable_pdf_ocr=enable_pdf_ocr,
1366
  caption_loader=caption_loader)
1367
  except BaseException as e:
1368
  print("Failed to ingest %s due to %s" % (file, traceback.format_exc()))
 
1371
  else:
1372
  exception_doc = Document(
1373
  page_content='',
1374
+ metadata={"source": file, "exception": '%s Exception: %s' % (file, str(e)),
1375
  "traceback": traceback.format_exc()})
1376
  res = [exception_doc]
1377
  if return_file:
 
1392
  captions_model=None,
1393
  caption_loader=None,
1394
  enable_ocr=False,
1395
+ enable_pdf_ocr='auto',
1396
  existing_files=[],
1397
  existing_hash_ids={},
1398
  ):
 
1414
  [globs_non_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True))
1415
  for ftype in non_image_types]
1416
  else:
1417
+ if isinstance(path_or_paths, str):
1418
+ if os.path.isfile(path_or_paths) or os.path.isdir(path_or_paths):
1419
+ path_or_paths = [path_or_paths]
1420
+ else:
1421
+ # path was deleted etc.
1422
+ return []
1423
  # list/tuple of files (consume what can, and exception those that selected but cannot consume so user knows)
1424
+ assert isinstance(path_or_paths, (list, tuple, types.GeneratorType)), \
1425
+ "Wrong type for path_or_paths: %s %s" % (path_or_paths, type(path_or_paths))
1426
  # reform out of allowed types
1427
  globs_image_types.extend(flatten_list([[x for x in path_or_paths if x.endswith(y)] for y in image_types]))
1428
  # could do below:
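The same normalization in isolation, with `normalize_path_or_paths` as a hypothetical helper name:

```python
import os
import types

def normalize_path_or_paths(path_or_paths):
    # accept a single path or a list of files, tolerate a deleted path,
    # then insist on an iterable of paths
    if isinstance(path_or_paths, str):
        if os.path.isfile(path_or_paths) or os.path.isdir(path_or_paths):
            path_or_paths = [path_or_paths]
        else:
            # path was deleted etc.
            return []
    assert isinstance(path_or_paths, (list, tuple, types.GeneratorType)), \
        "Wrong type for path_or_paths: %s %s" % (path_or_paths, type(path_or_paths))
    return list(path_or_paths)

print(normalize_path_or_paths('.'))               # existing directory -> ['.']
print(normalize_path_or_paths(['a.txt', 'b.md']))  # list passes through
```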
 
1474
  captions_model=captions_model,
1475
  caption_loader=caption_loader,
1476
  enable_ocr=enable_ocr,
1477
+ enable_pdf_ocr=enable_pdf_ocr,
1478
  )
1479
 
1480
  if n_jobs != 1 and len(globs_non_image_types) > 1:
 
1507
  with open(fil, 'rb') as f:
1508
  documents.extend(pickle.load(f))
1509
  # remove temp pickle
1510
+ remove(fil)
1511
  else:
1512
  documents = reduce(concat, documents)
1513
  return documents
 
1515
 
1516
  def prep_langchain(persist_directory,
1517
  load_db_if_exists,
1518
+ db_type, use_openai_embedding, langchain_mode, langchain_mode_paths,
1519
  hf_embedding_model, n_jobs=-1, kwargs_make_db={}):
1520
  """
1521
  do prep first time, involving downloads
 
1525
  assert langchain_mode not in ['MyData'], "Should not prep scratch data"
1526
 
1527
  db_dir_exists = os.path.isdir(persist_directory)
1528
+ user_path = langchain_mode_paths.get(langchain_mode)
1529
 
1530
  if db_dir_exists and user_path is None:
1531
  print("Prep: persist_directory=%s exists, using" % persist_directory, flush=True)
 
1661
  langchain_kwargs[k] = defaults_db[k]
1662
  # final check for missing
1663
  missing_kwargs = [x for x in func_names if x not in langchain_kwargs]
1664
+ assert not missing_kwargs, "Missing kwargs for make_db: %s" % missing_kwargs
1665
  # only keep actual used
1666
  langchain_kwargs = {k: v for k, v in langchain_kwargs.items() if k in func_names}
1667
  return _make_db(**langchain_kwargs)
 
1695
  first_para=False, text_limit=None,
1696
  chunk=True, chunk_size=512,
1697
  langchain_mode=None,
1698
+ langchain_mode_paths=None,
1699
  db_type='faiss',
1700
  load_db_if_exists=True,
1701
  db=None,
1702
  n_jobs=-1,
1703
  verbose=False):
1704
  persist_directory = get_persist_directory(langchain_mode)
1705
+ user_path = langchain_mode_paths.get(langchain_mode)
1706
  # see if can get persistent chroma db
1707
  db_trial = get_existing_db(db, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
1708
  hf_embedding_model, verbose=verbose)
 
1710
  db = db_trial
1711
 
1712
  sources = []
1713
+ if not db:
1714
+ if langchain_mode in ['wiki_full']:
 
 
 
1715
  from read_wiki_full import get_all_documents
1716
  small_test = None
1717
  print("Generating new wiki", flush=True)
 
1721
  sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
1722
  print("Chunked new wiki", flush=True)
1723
  sources.extend(sources1)
1724
+ elif langchain_mode in ['wiki']:
1725
  sources1 = get_wiki_sources(first_para=first_para, text_limit=text_limit)
1726
  if chunk:
1727
  sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
1728
  sources.extend(sources1)
1729
+ elif langchain_mode in ['github h2oGPT']:
1730
  # sources = get_github_docs("dagster-io", "dagster")
1731
  sources1 = get_github_docs("h2oai", "h2ogpt")
1732
  # FIXME: always chunk for now
1733
  sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
1734
  sources.extend(sources1)
1735
+ elif langchain_mode in ['DriverlessAI docs']:
1736
  sources1 = get_dai_docs(from_hf=True)
1737
  if chunk and False: # FIXME: DAI docs are already chunked well, should only chunk more if over limit
1738
  sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
1739
  sources.extend(sources1)
1740
+ if user_path:
1741
+ # UserData or custom, which has to be from user's disk
1742
+ if db is not None:
1743
+ # NOTE: Ignore file names for now, only go by hash ids
1744
+ # existing_files = get_existing_files(db)
1745
+ existing_files = []
1746
+ existing_hash_ids = get_existing_hash_ids(db)
1747
+ else:
1748
+ # pretend no existing files so won't filter
1749
+ existing_files = []
1750
+ existing_hash_ids = []
1751
+ # chunk internally for speed over multiple docs
1752
+ # FIXME: If the db first had old Hash=None and embeddings were switched,
1753
+ # then we re-embed; then we hit here, reload so hashes exist, and re-embed again.
1754
+ sources1 = path_to_docs(user_path, n_jobs=n_jobs, chunk=chunk, chunk_size=chunk_size,
1755
+ existing_files=existing_files, existing_hash_ids=existing_hash_ids)
1756
+ new_metadata_sources = set([x.metadata['source'] for x in sources1])
1757
+ if new_metadata_sources:
1758
+ print("Loaded %s new files as sources to add to %s" % (len(new_metadata_sources), langchain_mode),
1759
+ flush=True)
1760
+ if verbose:
1761
+ print("Files added: %s" % '\n'.join(new_metadata_sources), flush=True)
1762
+ sources.extend(sources1)
1763
+ print("Loaded %s sources for potentially adding to %s" % (len(sources), langchain_mode), flush=True)
1764
+
1765
+ # see if got sources
 
 
 
 
 
 
 
1766
  if not sources:
1767
  if verbose:
1768
  if db is not None:
 
1785
  else:
1786
  print("Did not generate db since no sources", flush=True)
1787
  new_sources_metadata = [x.metadata for x in sources]
1788
+ elif user_path is not None:
1789
  print("Existing db, potentially adding %s sources from user_path=%s" % (len(sources), user_path), flush=True)
1790
  db, num_new_sources, new_sources_metadata = add_to_db(db, sources, db_type=db_type,
1791
  use_openai_embedding=use_openai_embedding,
 
1883
  kwargs['answer_with_sources'] = True
1884
  kwargs['show_rank'] = False
1885
  missing_kwargs = [x for x in func_names if x not in kwargs]
1886
+ assert not missing_kwargs, "Missing kwargs for run_qa_db: %s" % missing_kwargs
1887
  # only keep actual used
1888
  kwargs = {k: v for k, v in kwargs.items() if k in func_names}
1889
  try:
 
1897
  context=None,
1898
  use_openai_model=False, use_openai_embedding=False,
1899
  first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
1900
+ langchain_mode_paths={},
1901
  detect_user_path_changes_every_query=False,
1902
  db_type='faiss',
1903
  model_name=None, model=None, tokenizer=None, inference_server=None,
 
1907
  prompt_type=None,
1908
  prompt_dict=None,
1909
  answer_with_sources=True,
1910
+ cut_distance=1.64,
1911
+ add_chat_history_to_context=True,
1912
  sanitize_bot_response=False,
1913
  show_rank=False,
1914
+ use_llm_if_no_docs=False,
1915
  load_db_if_exists=False,
1916
  db=None,
1917
  do_sample=False,
 
1927
  num_return_sequences=1,
1928
  langchain_mode=None,
1929
  langchain_action=None,
1930
+ langchain_agents=None,
1931
+ document_subset=DocumentSubset.Relevant.name,
1932
+ document_choice=[DocumentChoice.ALL.value],
1933
  n_jobs=-1,
1934
  verbose=False,
1935
  cli=False,
 
1948
  :param top_k_docs:
1949
  :param chunk:
1950
  :param chunk_size:
1951
+ :param langchain_mode_paths: dict of langchain_mode -> user path to glob recursively from
1952
  :param db_type: 'faiss' for in-memory db or 'chroma' or 'weaviate' for persistent db
1953
  :param model_name: model name, used to switch behaviors
1954
  :param model: pre-initialized model, else will make new one
 
1956
  :param answer_with_sources
1957
  :return:
1958
  """
1959
+ assert langchain_mode_paths is not None
1960
  if model is not None:
1961
  assert model_name is not None # require so can make decisions
1962
  assert query is not None
 
1971
  else:
1972
  prompt_dict = ''
1973
  assert len(set(gen_hyper).difference(inspect.signature(get_llm).parameters)) == 0
1974
+ # pass in context to LLM directly, since already has prompt_type structure
1975
+ # can't pass through langchain in get_chain() to LLM: https://github.com/hwchase17/langchain/issues/6638
1976
  llm, model_name, streamer, prompt_type_out = get_llm(use_openai_model=use_openai_model, model_name=model_name,
1977
  model=model,
1978
  tokenizer=tokenizer,
 
1992
  prompt_type=prompt_type,
1993
  prompt_dict=prompt_dict,
1994
  prompter=prompter,
1995
+ context=context if add_chat_history_to_context else '',
1996
+ iinput=iinput if add_chat_history_to_context else '',
1997
  sanitize_bot_response=sanitize_bot_response,
1998
  verbose=verbose,
1999
  )
2000
 
2001
+ use_docs_planned = False
2002
  scores = []
2003
  chain = None
2004
 
 
2010
  sim_kwargs = {k: v for k, v in locals().items() if k in func_names}
2011
  missing_kwargs = [x for x in func_names if x not in sim_kwargs]
2012
  assert not missing_kwargs, "Missing: %s" % missing_kwargs
2013
+ docs, chain, scores, use_docs_planned, have_any_docs = get_chain(**sim_kwargs)
2014
  if document_subset in non_query_commands:
2015
  formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
2016
+ if not formatted_doc_chunks and not use_llm_if_no_docs:
2017
+ yield "No sources", ''
2018
+ return
2019
+ # if no sources, outside gpt_langchain, LLM will be used with '' input
2020
  yield formatted_doc_chunks, ''
2021
  return
2022
+ if not use_llm_if_no_docs:
2023
+ if not docs and langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
2024
+ LangChainAction.SUMMARIZE_ALL.value,
2025
+ LangChainAction.SUMMARIZE_REFINE.value]:
2026
+ ret = 'No relevant documents to summarize.' if have_any_docs else 'No documents to summarize.'
2027
+ extra = ''
2028
+ yield ret, extra
2029
+ return
2030
+ if not docs and langchain_mode not in [LangChainMode.DISABLED.value,
2031
+ LangChainMode.LLM.value]:
2032
+ ret = 'No relevant documents to query.' if have_any_docs else 'No documents to query.'
2033
+ extra = ''
2034
+ yield ret, extra
2035
+ return
2036
 
2037
  if chain is None and model_name not in non_hf_types:
2038
  # here if no docs at all and not HF type
 
2083
  else:
2084
  answer = chain()
2085
 
2086
+ if not use_docs_planned:
2087
  ret = answer['output_text']
2088
  extra = ''
2089
  yield ret, extra
 
2095
 
2096
  def get_chain(query=None,
2097
  iinput=None,
2098
+ context=None, # FIXME: https://github.com/hwchase17/langchain/issues/6638
2099
  use_openai_model=False, use_openai_embedding=False,
2100
  first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
2101
+ langchain_mode_paths=None,
2102
  detect_user_path_changes_every_query=False,
2103
  db_type='faiss',
2104
  model_name=None,
 
2106
  hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
2107
  prompt_type=None,
2108
  prompt_dict=None,
2109
+ cut_distance=1.1,
2110
+ add_chat_history_to_context=True, # FIXME: https://github.com/hwchase17/langchain/issues/6638
2111
  load_db_if_exists=False,
2112
  db=None,
2113
  langchain_mode=None,
2114
  langchain_action=None,
2115
+ langchain_agents=None,
2116
+ document_subset=DocumentSubset.Relevant.name,
2117
+ document_choice=[DocumentChoice.ALL.value],
2118
  n_jobs=-1,
2119
  # beyond run_db_query:
2120
  llm=None,
 
2126
  auto_reduce_chunks=True,
2127
  max_chunks=100,
2128
  ):
2129
+ assert langchain_agents is not None # should be at least []
2130
  # determine whether use of context out of docs is planned
2131
  if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
2132
+ if langchain_mode in ['Disabled', 'LLM']:
2133
+ use_docs_planned = False
2134
  else:
2135
+ use_docs_planned = True
2136
  else:
2137
+ use_docs_planned = True
2138
 
2139
  # https://github.com/hwchase17/langchain/issues/1946
2140
  # FIXME: Seems no way to get size of chroma db to limit top_k_docs to avoid
 
2151
  # avoid looking at user_path during similarity search db handling,
2152
  # if already have db and not updating from user_path every query
2153
  # but if db is None, no db yet loaded (e.g. from prep), so allow user_path to be whatever it was
2154
+ if langchain_mode_paths is None:
2155
+ langchain_mode_paths = {}
2156
+ langchain_mode_paths = langchain_mode_paths.copy()
2157
+ langchain_mode_paths[langchain_mode] = None
2158
  db, num_new_sources, new_sources_metadata = make_db(use_openai_embedding=use_openai_embedding,
2159
  hf_embedding_model=hf_embedding_model,
2160
  first_para=first_para, text_limit=text_limit,
2161
  chunk=chunk,
2162
  chunk_size=chunk_size,
2163
  langchain_mode=langchain_mode,
2164
+ langchain_mode_paths=langchain_mode_paths,
2165
  db_type=db_type,
2166
  load_db_if_exists=load_db_if_exists,
2167
  db=db,
 
2181
  else:
2182
  extra = ""
2183
  prefix = ""
2184
+ if langchain_mode in ['Disabled', 'LLM'] or not use_docs_planned:
2185
  template_if_no_docs = template = """%s{context}{question}""" % prefix
2186
  else:
2187
  template = """%s
 
2222
  else:
2223
  use_template = False
2224
 
2225
+ if db and use_docs_planned:
2226
  base_path = 'locks'
2227
  makedirs(base_path)
2228
  if hasattr(db, '_persist_directory'):
 
2236
  filter_kwargs = {}
2237
  else:
2238
  assert document_choice is not None, "Document choice was None"
2239
+ if len(document_choice) >= 1 and document_choice[0] == DocumentChoice.ALL.value:
2240
  filter_kwargs = {}
2241
  elif len(document_choice) >= 2:
2242
+ if document_choice[0] == DocumentChoice.ALL.value:
2243
  # remove 'All'
2244
  document_choice = document_choice[1:]
2245
  or_filter = [{"source": {"$eq": x}} for x in document_choice]
 
2251
  else:
2252
  # shouldn't reach
2253
  filter_kwargs = {}
2254
+ if langchain_mode in [LangChainMode.LLM.value]:
2255
  docs = []
2256
  scores = []
2257
+ elif document_subset == DocumentSubset.TopKSources.name or query in [None, '', '\n']:
2258
  db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs)
2259
  # similar to langchain's chroma's _results_to_docs_and_scores
2260
  docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0)
2261
  for result in zip(db_documents, db_metadatas)]
2262
 
2263
  # order documents
2264
+ doc_hashes = [x.get('doc_hash', 'None') for x in db_metadatas]
2265
+ doc_chunk_ids = [x.get('chunk_id', 0) for x in db_metadatas]
2266
  docs_with_score = [x for _, _, x in
2267
  sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score), key=lambda x: (x[0], x[1]))
2268
  ]
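A toy run of the ordering step above, with invented hash and chunk-id metadata:

```python
# chunks come back from the vector db in arbitrary order, so restore document order
# from the (doc_hash, chunk_id) pairs stored in metadata
db_metadatas = [
    {'doc_hash': 'beta', 'chunk_id': 1},
    {'doc_hash': 'alpha', 'chunk_id': 2},
    {'doc_hash': 'alpha', 'chunk_id': 0},
]
docs_with_score = [('beta chunk 1', 0.2), ('alpha chunk 2', 0.3), ('alpha chunk 0', 0.1)]

doc_hashes = [x.get('doc_hash', 'None') for x in db_metadatas]
doc_chunk_ids = [x.get('chunk_id', 0) for x in db_metadatas]
docs_with_score = [x for _, _, x in
                   sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score),
                          key=lambda x: (x[0], x[1]))]
assert [d for d, _ in docs_with_score] == ['alpha chunk 0', 'alpha chunk 2', 'beta chunk 1']
```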
 
2342
  docs_with_score.reverse()
2343
  # cut off so no high distance docs/sources considered
2344
  have_any_docs |= len(docs_with_score) > 0 # before cut
2345
+ docs = [x[0] for x in docs_with_score if x[1] < cut_distance]
2346
+ scores = [x[1] for x in docs_with_score if x[1] < cut_distance]
2347
  if len(scores) > 0 and verbose:
2348
  print("Distance: min: %s max: %s mean: %s median: %s" %
2349
  (scores[0], scores[-1], np.mean(scores), np.median(scores)), flush=True)
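The `cut_distance` filter in miniature (per this diff, 1.64 in `run_qa_db` and 1.1 in `get_chain`); the scores here are made up:

```python
# drop retrieved chunks whose embedding distance exceeds the threshold
cut_distance = 1.64
docs_with_score = [('close doc', 0.31), ('ok doc', 1.20), ('far doc', 2.50)]
docs = [x[0] for x in docs_with_score if x[1] < cut_distance]
scores = [x[1] for x in docs_with_score if x[1] < cut_distance]
assert docs == ['close doc', 'ok doc'] and scores == [0.31, 1.20]
```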
 
2351
  docs = []
2352
  scores = []
2353
 
2354
+ if not docs and use_docs_planned and model_name not in non_hf_types:
2355
  # if HF type and have no docs, can bail out
2356
  return docs, None, [], False, have_any_docs
2357
 
 
2374
 
2375
  if len(docs) == 0:
2376
  # avoid context == in prompt then
2377
+ use_docs_planned = False
2378
  template = template_if_no_docs
2379
 
2380
  if langchain_action == LangChainAction.QUERY.value:
 
2390
  else:
2391
  # only if use_openai_model = True, unused normally except in testing
2392
  chain = load_qa_with_sources_chain(llm)
2393
+ if not use_docs_planned:
2394
  chain_kwargs = dict(input_documents=[], question=query)
2395
  else:
2396
  chain_kwargs = dict(input_documents=docs, question=query)
 
2417
  else:
2418
  raise RuntimeError("No such langchain_action=%s" % langchain_action)
2419
 
2420
+ return docs, target, scores, use_docs_planned, have_any_docs
2421
 
2422
 
2423
  def get_sources_answer(query, answer, scores, show_rank, answer_with_sources, verbose=False):
 
2471
 
2472
  def chunk_sources(sources, chunk=True, chunk_size=512, language=None):
2473
  if not chunk:
2474
+ [x.metadata.update(dict(chunk_id=chunk_id)) for chunk_id, x in enumerate(sources)]
2475
  return sources
2476
  if not isinstance(sources, (list, tuple, types.GeneratorType)) and not callable(sources):
2477
  # if just one document
 
2490
  source_chunks = splitter.split_documents(sources)
2491
 
2492
  # currently in order, but when pulled from the db they won't be, so mark order and document by hash
2493
+ [x.metadata.update(dict(chunk_id=chunk_id)) for chunk_id, x in enumerate(source_chunks)]
 
2494
 
2495
  return source_chunks
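A quick check of the `chunk_id` bookkeeping that the (doc_hash, chunk_id) sort above relies on:

```python
from langchain.docstore.document import Document

# every chunk records its position in metadata so order can be reconstructed later
source_chunks = [Document(page_content='part one', metadata={}),
                 Document(page_content='part two', metadata={})]
[x.metadata.update(dict(chunk_id=chunk_id)) for chunk_id, x in enumerate(source_chunks)]
assert [x.metadata['chunk_id'] for x in source_chunks] == [0, 1]
```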
2496
 
gradio_runner.py CHANGED
@@ -50,16 +50,20 @@ def fix_pydantic_duplicate_validators_error():
50
 
51
  fix_pydantic_duplicate_validators_error()
52
 
53
- from enums import DocumentChoices, no_model_str, no_lora_str, no_server_str, LangChainAction, LangChainMode
 
54
  from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js, spacing_xsm, radius_xsm, \
55
  text_xsm
56
  from prompter import prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower, non_hf_types, \
57
  get_prompt
58
- from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
59
- ping, get_short_name, makedirs, get_kwargs, remove, system_info, ping_gpu, get_url, get_local_ip
60
- from gen import get_model, languages_covered, evaluate, score_qa, langchain_modes, inputs_kwargs_list, scratch_base_dir, \
61
- get_max_max_new_tokens, get_minmax_top_k_docs, history_to_context, langchain_actions
62
- from evaluate_params import eval_func_param_names, no_default_param_names, eval_func_param_names_defaults
 
 
 
63
 
64
  from apscheduler.schedulers.background import BackgroundScheduler
65
 
@@ -94,13 +98,11 @@ def go_gradio(**kwargs):
94
  memory_restriction_level = kwargs['memory_restriction_level']
95
  n_gpus = kwargs['n_gpus']
96
  admin_pass = kwargs['admin_pass']
97
- model_state0 = kwargs['model_state0']
98
  model_states = kwargs['model_states']
99
- score_model_state0 = kwargs['score_model_state0']
100
  dbs = kwargs['dbs']
101
  db_type = kwargs['db_type']
102
- visible_langchain_modes = kwargs['visible_langchain_modes']
103
  visible_langchain_actions = kwargs['visible_langchain_actions']
 
104
  allow_upload_to_user_data = kwargs['allow_upload_to_user_data']
105
  allow_upload_to_my_data = kwargs['allow_upload_to_my_data']
106
  enable_sources_list = kwargs['enable_sources_list']
@@ -111,8 +113,19 @@ def go_gradio(**kwargs):
111
  enable_captions = kwargs['enable_captions']
112
  captions_model = kwargs['captions_model']
113
  enable_ocr = kwargs['enable_ocr']
 
114
  caption_loader = kwargs['caption_loader']
115
 
 
 
 
 
 
 
 
 
 
 
116
  # easy update of kwargs needed for evaluate() etc.
117
  queue = True
118
  allow_upload = allow_upload_to_user_data or allow_upload_to_my_data
@@ -132,25 +145,11 @@ def go_gradio(**kwargs):
132
  " use Enter for multiple input lines)"
133
 
134
  title = 'h2oGPT'
135
- more_info = """<iframe src="https://ghbtns.com/github-btn.html?user=h2oai&repo=h2ogpt&type=star&count=true&size=small" frameborder="0" scrolling="0" width="250" height="20" title="GitHub"></iframe><small><a href="https://github.com/h2oai/h2ogpt">h2oGPT</a> <a href="https://github.com/h2oai/h2o-llmstudio">H2O LLM Studio</a><br><a href="https://huggingface.co/h2oai">🤗 Models</a>"""
136
- if kwargs['verbose']:
137
- description = f"""Model {kwargs['base_model']} Instruct dataset.
138
- For more information, visit our GitHub pages: [h2oGPT](https://github.com/h2oai/h2ogpt) and [H2O LLM Studio](https://github.com/h2oai/h2o-llmstudio).
139
- Command: {str(' '.join(sys.argv))}
140
- Hash: {get_githash()}
141
- """
142
- else:
143
- description = more_info
144
- description_bottom = "If this host is busy, try [Multi-Model](https://gpt.h2o.ai), [Falcon 40B](http://falcon.h2o.ai), [HF Spaces1](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot) or [HF Spaces2](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)<br>"
145
  if is_hf:
146
  description_bottom += '''<a href="https://huggingface.co/spaces/h2oai/h2ogpt-chatbot?duplicate=true"><img src="https://bit.ly/3gLdBN6" style="white-space: nowrap" alt="Duplicate Space"></a>'''
147
-
148
- if kwargs['verbose']:
149
- task_info_md = f"""
150
- ### Task: {kwargs['task_info']}"""
151
- else:
152
- task_info_md = ''
153
-
154
  css_code = get_css(kwargs)
155
 
156
  if kwargs['gradio_offline_level'] >= 0:
@@ -180,9 +179,9 @@ def go_gradio(**kwargs):
180
  demo = gr.Blocks(theme=theme, css=css_code, title="h2oGPT", analytics_enabled=False)
181
  callback = gr.CSVLogger()
182
 
183
- model_options = flatten_list(list(prompt_type_to_model_name.values())) + kwargs['extra_model_options']
184
- if kwargs['base_model'].strip() not in model_options:
185
- model_options = [kwargs['base_model'].strip()] + model_options
186
  lora_options = kwargs['extra_lora_options']
187
  if kwargs['lora_weights'].strip() not in lora_options:
188
  lora_options = [kwargs['lora_weights'].strip()] + lora_options
@@ -197,7 +196,7 @@ def go_gradio(**kwargs):
197
 
198
  # always add in no lora case
199
  # add fake space so doesn't go away in gradio dropdown
200
- model_options = [no_model_str] + model_options
201
  lora_options = [no_lora_str] + lora_options
202
  server_options = [no_server_str] + server_options
203
  # always add in no model case so can free memory
@@ -251,6 +250,14 @@ def go_gradio(**kwargs):
251
  # else gets input_list at time of submit that is old, and shows up as truncated in chatbot
252
  return x
253
 
 
 
 
 
 
 
 
 
254
  with demo:
255
  # avoid actual model/tokenizer here or anything that would be bad to deepcopy
256
  # https://github.com/gradio-app/gradio/issues/3558
@@ -264,18 +271,32 @@ def go_gradio(**kwargs):
264
  prompt_dict=kwargs['prompt_dict'],
265
  )
266
  )
 
 
 
 
 
 
 
 
 
 
 
267
  model_state2 = gr.State(kwargs['model_state_none'].copy())
268
- model_options_state = gr.State([model_options])
269
  lora_options_state = gr.State([lora_options])
270
  server_options_state = gr.State([server_options])
271
- my_db_state = gr.State([None, None])
272
  chat_state = gr.State({})
273
- docs_state00 = kwargs['document_choice'] + [DocumentChoices.All.name]
274
  docs_state0 = []
275
  [docs_state0.append(x) for x in docs_state00 if x not in docs_state0]
276
  docs_state = gr.State(docs_state0)
277
  viewable_docs_state0 = []
278
  viewable_docs_state = gr.State(viewable_docs_state0)
 
 
 
279
  gr.Markdown(f"""
280
  {get_h2o_title(title, description) if kwargs['h2ocolors'] else get_simple_title(title, description)}
281
  """)
@@ -289,7 +310,7 @@ def go_gradio(**kwargs):
289
  'model_lock'] else "Response Scores: %s" % nas
290
 
291
  if kwargs['langchain_mode'] != LangChainMode.DISABLED.value:
292
- extra_prompt_form = ". For summarization, empty submission uses first top_k_docs documents."
293
  else:
294
  extra_prompt_form = ""
295
  if kwargs['input_lines'] > 1:
@@ -297,6 +318,34 @@ def go_gradio(**kwargs):
297
  else:
298
  instruction_label = "Enter to Submit, Shift-Enter for more lines%s" % extra_prompt_form
299
 
 
 
 
300
  normal_block = gr.Row(visible=not base_wanted, equal_height=False)
301
  with normal_block:
302
  side_bar = gr.Column(elem_id="col_container", scale=1, min_width=100)
@@ -317,6 +366,7 @@ def go_gradio(**kwargs):
317
  scale=1,
318
  min_width=0,
319
  elem_id="warning", elem_classes="feedback")
 
320
  url_visible = kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_url_upload
321
  url_label = 'URL/ArXiv' if have_arxiv else 'URL'
322
  url_text = gr.Textbox(label=url_label,
@@ -330,29 +380,20 @@ def go_gradio(**kwargs):
330
  visible=text_visible)
331
  github_textbox = gr.Textbox(label="Github URL", visible=False) # FIXME WIP
332
  database_visible = kwargs['langchain_mode'] != 'Disabled'
333
- with gr.Accordion("Database", open=False, visible=database_visible):
334
- if is_hf:
335
- # don't show 'wiki' since only usually useful for internal testing at moment
336
- no_show_modes = ['Disabled', 'wiki']
337
- else:
338
- no_show_modes = ['Disabled']
339
- allowed_modes = visible_langchain_modes.copy()
340
- allowed_modes = [x for x in allowed_modes if x in dbs]
341
- allowed_modes += ['ChatLLM', 'LLM']
342
- if allow_upload_to_my_data and 'MyData' not in allowed_modes:
343
- allowed_modes += ['MyData']
344
- if allow_upload_to_user_data and 'UserData' not in allowed_modes:
345
- allowed_modes += ['UserData']
346
  langchain_mode = gr.Radio(
347
- [x for x in langchain_modes if x in allowed_modes and x not in no_show_modes],
348
  value=kwargs['langchain_mode'],
349
  label="Collections",
350
  show_label=True,
351
  visible=kwargs['langchain_mode'] != 'Disabled',
352
  min_width=100)
353
- document_subset = gr.Radio([x.name for x in DocumentChoices],
 
 
354
  label="Subset",
355
- value=DocumentChoices.Relevant.name,
356
  interactive=True,
357
  )
358
  allowed_actions = [x for x in langchain_actions if x in visible_langchain_actions]
@@ -361,6 +402,14 @@ def go_gradio(**kwargs):
361
  value=allowed_actions[0] if len(allowed_actions) > 0 else None,
362
  label="Action",
363
  visible=True)
 
 
 
 
 
 
 
 
364
  col_tabs = gr.Column(elem_id="col_container", scale=10)
365
  with (col_tabs, gr.Tabs()):
366
  with gr.TabItem("Chat"):
@@ -408,9 +457,9 @@ def go_gradio(**kwargs):
408
  mw1 = 50
409
  mw2 = 50
410
  with gr.Column(min_width=mw1):
411
- submit = gr.Button(value='Submit', variant='primary', scale=0, size='sm',
412
  min_width=mw1)
413
- stop_btn = gr.Button(value="Stop", variant='secondary', scale=0, size='sm',
414
  min_width=mw1)
415
  save_chat_btn = gr.Button("Save", size='sm', min_width=mw1)
416
  with gr.Column(min_width=mw2):
@@ -431,20 +480,50 @@ def go_gradio(**kwargs):
431
  with gr.TabItem("Document Selection"):
432
  document_choice = gr.Dropdown(docs_state0,
433
  label="Select Subset of Document(s) %s" % file_types_str,
434
- value='All',
435
  interactive=True,
436
  multiselect=True,
437
  visible=kwargs['langchain_mode'] != 'Disabled',
438
  )
439
  sources_visible = kwargs['langchain_mode'] != 'Disabled' and enable_sources_list
440
  with gr.Row():
441
- get_sources_btn = gr.Button(value="Update UI with Document(s) from DB", scale=0, size='sm',
442
- visible=sources_visible)
443
- show_sources_btn = gr.Button(value="Show Sources from DB", scale=0, size='sm',
444
- visible=sources_visible)
445
- refresh_sources_btn = gr.Button(value="Update DB with new/changed files on disk", scale=0,
446
- size='sm',
447
- visible=sources_visible and allow_upload_to_user_data)
 
 
 
 
448
 
449
  sources_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list,
450
  equal_height=False)
@@ -469,6 +548,7 @@ def go_gradio(**kwargs):
469
  value=None,
470
  interactive=True,
471
  multiselect=False,
 
472
  )
473
  with gr.Column(scale=4):
474
  pass
@@ -713,19 +793,20 @@ def go_gradio(**kwargs):
713
  side_bar_btn = gr.Button("Toggle SideBar", variant="secondary", size="sm")
714
  submit_buttons_btn = gr.Button("Toggle Submit Buttons", variant="secondary", size="sm")
715
  col_tabs_scale = gr.Slider(minimum=1, maximum=20, value=10, step=1, label='Window Size')
716
- text_outputs_height = gr.Slider(minimum=100, maximum=1000, value=kwargs['height'] or 400,
717
- step=100, label='Chat Height')
718
  dark_mode_btn = gr.Button("Dark Mode", variant="secondary", size="sm")
719
  with gr.Column(scale=4):
720
  pass
 
721
  admin_row = gr.Row()
722
  with admin_row:
723
  with gr.Column(scale=1):
724
- admin_pass_textbox = gr.Textbox(label="Admin Password", type='password', visible=is_public)
725
- admin_btn = gr.Button(value="Admin Access", visible=is_public, size='sm')
726
  with gr.Column(scale=4):
727
  pass
728
- system_row = gr.Row(visible=not is_public)
729
  with system_row:
730
  with gr.Column():
731
  with gr.Row():
@@ -789,23 +870,24 @@ def go_gradio(**kwargs):
789
  else:
790
  return tuple([gr.update(interactive=True)] * len(args))
791
 
792
- # Add to UserData
793
  update_db_func = functools.partial(update_user_db,
794
  dbs=dbs,
795
  db_type=db_type,
796
  use_openai_embedding=use_openai_embedding,
797
  hf_embedding_model=hf_embedding_model,
798
- enable_captions=enable_captions,
799
  captions_model=captions_model,
800
- enable_ocr=enable_ocr,
801
  caption_loader=caption_loader,
 
 
802
  verbose=kwargs['verbose'],
803
- user_path=kwargs['user_path'],
804
  n_jobs=kwargs['n_jobs'],
805
  )
806
  add_file_outputs = [fileup_output, langchain_mode]
807
  add_file_kwargs = dict(fn=update_db_func,
808
- inputs=[fileup_output, my_db_state, chunk, chunk_size, langchain_mode],
 
809
  outputs=add_file_outputs + [sources_text, doc_exception_text],
810
  queue=queue,
811
  api_name='add_file' if allow_api and allow_upload_to_user_data else None)
@@ -817,6 +899,15 @@ def go_gradio(**kwargs):
817
  eventdb1b = eventdb1.then(make_interactive, inputs=add_file_outputs, outputs=add_file_outputs,
818
  show_progress='minimal')
819
 
 
 
 
 
 
 
 
 
 
820
  # note for update_user_db_func output is ignored for db
821
 
822
  def clear_textbox():
@@ -826,7 +917,8 @@ def go_gradio(**kwargs):
826
 
827
  add_url_outputs = [url_text, langchain_mode]
828
  add_url_kwargs = dict(fn=update_user_db_url_func,
829
- inputs=[url_text, my_db_state, chunk, chunk_size, langchain_mode],
 
830
  outputs=add_url_outputs + [sources_text, doc_exception_text],
831
  queue=queue,
832
  api_name='add_url' if allow_api and allow_upload_to_user_data else None)
@@ -843,7 +935,8 @@ def go_gradio(**kwargs):
843
  update_user_db_txt_func = functools.partial(update_db_func, is_txt=True)
844
  add_text_outputs = [user_text_text, langchain_mode]
845
  add_text_kwargs = dict(fn=update_user_db_txt_func,
846
- inputs=[user_text_text, my_db_state, chunk, chunk_size, langchain_mode],
 
847
  outputs=add_text_outputs + [sources_text, doc_exception_text],
848
  queue=queue,
849
  api_name='add_text' if allow_api and allow_upload_to_user_data else None
@@ -855,7 +948,7 @@ def go_gradio(**kwargs):
855
  eventdb3 = eventdb3b.then(**add_text_kwargs, show_progress='full')
856
  eventdb3c = eventdb3.then(make_interactive, inputs=add_text_outputs, outputs=add_text_outputs,
857
  show_progress='minimal')
858
- db_events = [eventdb1a, eventdb1, eventdb1b,
859
  eventdb2a, eventdb2, eventdb2b, eventdb2c,
860
  eventdb3a, eventdb3b, eventdb3, eventdb3c]
861
 
@@ -863,14 +956,14 @@ def go_gradio(**kwargs):
863
 
864
  # if change collection source, must clear doc selections from it to avoid inconsistency
865
  def clear_doc_choice():
866
- return gr.Dropdown.update(choices=docs_state0, value=DocumentChoices.All.name)
867
 
868
  langchain_mode.change(clear_doc_choice, inputs=None, outputs=document_choice, queue=False)
869
 
870
  def resize_col_tabs(x):
871
  return gr.Dropdown.update(scale=x)
872
 
873
- col_tabs_scale.change(fn=resize_col_tabs, inputs=col_tabs_scale, outputs=col_tabs)
874
 
875
  def resize_chatbots(x, num_model_lock=0):
876
  if num_model_lock == 0:
@@ -881,7 +974,7 @@ def go_gradio(**kwargs):
881
 
882
  resize_chatbots_func = functools.partial(resize_chatbots, num_model_lock=len(text_outputs))
883
  text_outputs_height.change(fn=resize_chatbots_func, inputs=text_outputs_height,
884
- outputs=[text_output, text_output2] + text_outputs)
885
 
886
  def update_dropdown(x):
887
  return gr.Dropdown.update(choices=x, value=[docs_state0[0]])
@@ -972,7 +1065,8 @@ def go_gradio(**kwargs):
972
  if file.startswith('http') or file.startswith('https'):
973
  # if file is online, then might as well use google(?)
974
  document1 = file
975
- return gr.update(visible=True, value=f"""<iframe width="1000" height="800" src="https://docs.google.com/viewerng/viewer?url={document1}&embedded=true" frameborder="0" height="100%" width="100%">
 
976
  </iframe>
977
  """), dummy1, dummy1, dummy1
978
  else:
@@ -995,9 +1089,11 @@ def go_gradio(**kwargs):
995
 
996
  refresh_sources1 = functools.partial(update_and_get_source_files_given_langchain_mode,
997
  **get_kwargs(update_and_get_source_files_given_langchain_mode,
998
- exclude_names=['db1', 'langchain_mode'],
 
999
  **all_kwargs))
1000
- eventdb9 = refresh_sources_btn.click(fn=refresh_sources1, inputs=[my_db_state, langchain_mode],
 
1001
  outputs=sources_text,
1002
  api_name='refresh_sources' if allow_api else None)
1003
 
@@ -1007,9 +1103,153 @@ def go_gradio(**kwargs):
1007
  def close_admin(x):
1008
  return gr.update(visible=not (x == admin_pass))
1009
 
1010
- admin_btn.click(check_admin_pass, inputs=admin_pass_textbox, outputs=system_row, queue=False) \
1011
  .then(close_admin, inputs=admin_pass_textbox, outputs=admin_row, queue=False)
1012
 
 
 
 
1013
  inputs_list, inputs_dict = get_inputs_list(all_kwargs, kwargs['model_lower'], model_id=1)
1014
  inputs_list2, inputs_dict2 = get_inputs_list(all_kwargs, kwargs['model_lower'], model_id=2)
1015
  from functools import partial
@@ -1021,11 +1261,11 @@ def go_gradio(**kwargs):
1021
  def evaluate_nochat(*args1, default_kwargs1=None, str_api=False, **kwargs1):
1022
  args_list = list(args1)
1023
  if str_api:
1024
- user_kwargs = args_list[2]
1025
  assert isinstance(user_kwargs, str)
1026
  user_kwargs = ast.literal_eval(user_kwargs)
1027
  else:
1028
- user_kwargs = {k: v for k, v in zip(eval_func_param_names, args_list[2:])}
1029
  # only used for submit_nochat_api
1030
  user_kwargs['chat'] = False
1031
  if 'stream_output' not in user_kwargs:
@@ -1035,6 +1275,8 @@ def go_gradio(**kwargs):
1035
  user_kwargs['langchain_mode'] = 'Disabled'
1036
  if 'langchain_action' not in user_kwargs:
1037
  user_kwargs['langchain_action'] = LangChainAction.QUERY.value
 
 
1038
 
1039
  set1 = set(list(default_kwargs1.keys()))
1040
  set2 = set(eval_func_param_names)
@@ -1042,10 +1284,11 @@ def go_gradio(**kwargs):
1042
  # correct ordering. Note some things may not be in default_kwargs, so can't be default of user_kwargs.get()
1043
  model_state1 = args_list[0]
1044
  my_db_state1 = args_list[1]
 
1045
  args_list = [user_kwargs[k] if k in user_kwargs and user_kwargs[k] is not None else default_kwargs1[k] for k
1046
  in eval_func_param_names]
1047
  assert len(args_list) == len(eval_func_param_names)
1048
- args_list = [model_state1, my_db_state1] + args_list
1049
 
1050
  try:
1051
  for res_dict in evaluate(*tuple(args_list), **kwargs1):
@@ -1216,6 +1459,7 @@ def go_gradio(**kwargs):
1216
  prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
1217
  langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
1218
  langchain_action1 = args_list[eval_func_param_names.index('langchain_action')]
 
1219
  document_subset1 = args_list[eval_func_param_names.index('document_subset')]
1220
  document_choice1 = args_list[eval_func_param_names.index('document_choice')]
1221
  if not prompt_type1:
@@ -1248,10 +1492,7 @@ def go_gradio(**kwargs):
1248
  history[-1][1] = None
1249
  return history
1250
  if user_message1 in ['', None, '\n']:
1251
- if langchain_action1 in LangChainAction.QUERY.value and \
1252
- DocumentChoices.All.name != document_subset1 \
1253
- or \
1254
- langchain_mode1 in [LangChainMode.CHAT_LLM.value, LangChainMode.LLM.value]:
1255
  # reject non-retry submit/enter
1256
  return history
1257
  user_message1 = fix_text_for_gradio(user_message1)
@@ -1298,10 +1539,12 @@ def go_gradio(**kwargs):
1298
  API only called for which_model=0, default for inputs_list, but rest should ignore inputs_list
1299
  :return: last element is True if should run bot, False if should just yield history
1300
  """
 
1301
  # don't deepcopy, can contain model itself
1302
  args_list = list(args).copy()
1303
- model_state1 = args_list[-3]
1304
- my_db_state1 = args_list[-2]
 
1305
  history = args_list[-1]
1306
  prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
1307
  prompt_dict1 = args_list[eval_func_param_names.index('prompt_dict')]
@@ -1309,9 +1552,11 @@ def go_gradio(**kwargs):
1309
  if model_state1['model'] is None or model_state1['model'] == no_model_str:
1310
  return history, None, None, None
1311
 
1312
- args_list = args_list[:-3] # only keep rest needed for evaluate()
1313
  langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
 
1314
  langchain_action1 = args_list[eval_func_param_names.index('langchain_action')]
 
1315
  document_subset1 = args_list[eval_func_param_names.index('document_subset')]
1316
  document_choice1 = args_list[eval_func_param_names.index('document_choice')]
1317
  if not history:
@@ -1324,10 +1569,7 @@ def go_gradio(**kwargs):
1324
  instruction1 = history[-1][0]
1325
  history[-1][1] = None
1326
  elif not instruction1:
1327
- if langchain_action1 in LangChainAction.QUERY.value and \
1328
- DocumentChoices.All.name != document_choice1 \
1329
- or \
1330
- langchain_mode1 in [LangChainMode.CHAT_LLM.value, LangChainMode.LLM.value]:
1331
  # if not retrying, then reject empty query
1332
  return history, None, None, None
1333
  elif len(history) > 0 and history[-1][1] not in [None, '']:
@@ -1344,7 +1586,9 @@ def go_gradio(**kwargs):
1344
 
1345
  chat1 = args_list[eval_func_param_names.index('chat')]
1346
  model_max_length1 = get_model_max_length(model_state1)
1347
- context1 = history_to_context(history, langchain_mode1, prompt_type1, prompt_dict1, chat1,
 
 
1348
  model_max_length1, memory_restriction_level,
1349
  kwargs['keep_sources_in_context'])
1350
  args_list[0] = instruction1 # override original instruction with history from user
@@ -1353,6 +1597,7 @@ def go_gradio(**kwargs):
1353
  fun1 = partial(evaluate,
1354
  model_state1,
1355
  my_db_state1,
 
1356
  *tuple(args_list),
1357
  **kwargs_evaluate)
1358
 
@@ -1398,24 +1643,26 @@ def go_gradio(**kwargs):
1398
  clear_torch_cache()
1399
  return
1400
 
1401
- def clear_embeddings(langchain_mode1, my_db):
1402
  # clear any use of embedding that sits on GPU, else keeps accumulating GPU usage even if clear torch cache
1403
- if db_type == 'chroma' and langchain_mode1 not in ['ChatLLM', 'LLM', 'Disabled', None, '']:
1404
  from gpt_langchain import clear_embedding
1405
  db = dbs.get('langchain_mode1')
1406
  if db is not None and not isinstance(db, str):
1407
  clear_embedding(db)
1408
- if langchain_mode1 == LangChainMode.MY_DATA.value and my_db is not None:
1409
- clear_embedding(my_db[0])
 
 
1410
 
1411
  def bot(*args, retry=False):
1412
- history, fun1, langchain_mode1, my_db_state1 = prep_bot(*args, retry=retry)
1413
  try:
1414
  for res in get_response(fun1, history):
1415
  yield res
1416
  finally:
1417
  clear_torch_cache()
1418
- clear_embeddings(langchain_mode1, my_db_state1)
1419
 
1420
  def all_bot(*args, retry=False, model_states1=None):
1421
  args_list = list(args).copy()
@@ -1425,12 +1672,14 @@ def go_gradio(**kwargs):
1425
  stream_output1 = args_list[eval_func_param_names.index('stream_output')]
1426
  max_time1 = args_list[eval_func_param_names.index('max_time')]
1427
  langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
1428
- my_db_state1 = None # will be filled below by some bot
 
1429
  try:
1430
  gen_list = []
1431
  for chatboti, (chatbot1, model_state1) in enumerate(zip(chatbots, model_states1)):
1432
  args_list1 = args_list0.copy()
1433
- args_list1.insert(-1, model_state1) # insert at -1 so is at -2
 
1434
  # if at start, have None in response still, replace with '' so client etc. acts like normal
1435
  # assumes other parts of code treat '' and None as if no response yet from bot
1436
  # can't do this later in bot code as racy with threaded generators
@@ -1440,8 +1689,8 @@ def go_gradio(**kwargs):
1440
  # so consistent with prep_bot()
1441
  # with model_state1 at -3, my_db_state1 at -2, and history(chatbot) at -1
1442
  # langchain_mode1 and my_db_state1 should be same for every bot
1443
- history, fun1, langchain_mode1, my_db_state1 = prep_bot(*tuple(args_list1), retry=retry,
1444
- which_model=chatboti)
1445
  gen1 = get_response(fun1, history)
1446
  if stream_output1:
1447
  gen1 = TimeoutIterator(gen1, timeout=0.01, sentinel=None, raise_on_exception=False)
@@ -1487,7 +1736,7 @@ def go_gradio(**kwargs):
1487
  print("Generate exceptions: %s" % exceptions, flush=True)
1488
  finally:
1489
  clear_torch_cache()
1490
- clear_embeddings(langchain_mode1, my_db_state1)
1491
 
1492
  # NORMAL MODEL
1493
  user_args = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt']),
@@ -1495,11 +1744,11 @@ def go_gradio(**kwargs):
1495
  outputs=text_output,
1496
  )
1497
  bot_args = dict(fn=bot,
1498
- inputs=inputs_list + [model_state, my_db_state] + [text_output],
1499
  outputs=[text_output, chat_exception_text],
1500
  )
1501
  retry_bot_args = dict(fn=functools.partial(bot, retry=True),
1502
- inputs=inputs_list + [model_state, my_db_state] + [text_output],
1503
  outputs=[text_output, chat_exception_text],
1504
  )
1505
  retry_user_args = dict(fn=functools.partial(user, retry=True),
@@ -1517,11 +1766,11 @@ def go_gradio(**kwargs):
1517
  outputs=text_output2,
1518
  )
1519
  bot_args2 = dict(fn=bot,
1520
- inputs=inputs_list2 + [model_state2, my_db_state] + [text_output2],
1521
  outputs=[text_output2, chat_exception_text],
1522
  )
1523
  retry_bot_args2 = dict(fn=functools.partial(bot, retry=True),
1524
- inputs=inputs_list2 + [model_state2, my_db_state] + [text_output2],
1525
  outputs=[text_output2, chat_exception_text],
1526
  )
1527
  retry_user_args2 = dict(fn=functools.partial(user, retry=True),
@@ -1542,11 +1791,11 @@ def go_gradio(**kwargs):
1542
  outputs=text_outputs,
1543
  )
1544
  all_bot_args = dict(fn=functools.partial(all_bot, model_states1=model_states),
1545
- inputs=inputs_list + [my_db_state] + text_outputs,
1546
  outputs=text_outputs + [chat_exception_text],
1547
  )
1548
  all_retry_bot_args = dict(fn=functools.partial(all_bot, model_states1=model_states, retry=True),
1549
- inputs=inputs_list + [my_db_state] + text_outputs,
1550
  outputs=text_outputs + [chat_exception_text],
1551
  )
1552
  all_retry_user_args = dict(fn=functools.partial(all_user, retry=True,
@@ -1708,6 +1957,11 @@ def go_gradio(**kwargs):
1708
  def get_short_chat(x, short_chats, short_len=20, words=4):
1709
  if x and len(x[0]) == 2 and x[0][0] is not None:
1710
  short_chat = ' '.join(x[0][0][:short_len].split(' ')[:words]).strip()
 
 
 
 
 
1711
  short_chat = dedup(short_chat, short_chats)
1712
  else:
1713
  short_chat = None
@@ -1775,14 +2029,12 @@ def go_gradio(**kwargs):
1775
  already_exists = any([is_chat_same(chat_list, x) for x in old_chat_lists])
1776
  if not already_exists:
1777
  chat_state1[short_chat] = chat_list.copy()
1778
- # clear chat_list so saved and then new conversation starts
1779
- # FIXME: seems less confusing to clear, since have clear button right next
1780
- # chat_list = [[]] * len(chat_list)
1781
- if not chat_is_list:
1782
- ret_list = chat_list + [chat_state1]
1783
- else:
1784
- ret_list = [chat_list] + [chat_state1]
1785
- return tuple(ret_list)
1786
 
1787
  def switch_chat(chat_key, chat_state1, num_model_lock=0):
1788
  chosen_chat = chat_state1[chat_key]
@@ -1813,7 +2065,7 @@ def go_gradio(**kwargs):
1813
 
1814
  remove_chat_event = remove_chat_btn.click(remove_chat,
1815
  inputs=[radio_chats, chat_state], outputs=[radio_chats, chat_state],
1816
- queue=False)
1817
 
1818
  def get_chats1(chat_state1):
1819
  base = 'chats'
@@ -1844,7 +2096,7 @@ def go_gradio(**kwargs):
1844
  new_chats = json.loads(f.read())
1845
  for chat1_k, chat1_v in new_chats.items():
1846
  # ignore chat1_k, regenerate and de-dup to avoid loss
1847
- _, chat_state1 = save_chat(chat1_v, chat_state1, chat_is_list=True)
1848
  except BaseException as e:
1849
  t, v, tb = sys.exc_info()
1850
  ex = ''.join(traceback.format_exception(t, v, tb))
@@ -1870,24 +2122,17 @@ def go_gradio(**kwargs):
1870
  .then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=False) \
1871
  .then(clear_scores, outputs=[score_text, score_text2, score_text_nochat])
1872
 
1873
- def update_radio_chats(chat_state1):
1874
- # reverse so newest at top
1875
- choices = list(chat_state1.keys()).copy()
1876
- choices.reverse()
1877
- return gr.update(choices=choices, value=None)
1878
-
1879
  clear_event = save_chat_btn.click(save_chat,
1880
  inputs=[text_output, text_output2] + text_outputs + [chat_state],
1881
- outputs=[text_output, text_output2] + text_outputs + [chat_state],
1882
- api_name='save_chat' if allow_api else None) \
1883
- .then(update_radio_chats, inputs=chat_state, outputs=radio_chats,
1884
- api_name='update_chats' if allow_api else None) \
1885
- .then(clear_scores, outputs=[score_text, score_text2, score_text_nochat])
1886
 
1887
  # NOTE: clear of instruction/iinput for nochat has to come after score,
1888
  # because score for nochat consumes actual textbox, while chat consumes chat history filled by user()
1889
  no_chat_args = dict(fn=fun,
1890
- inputs=[model_state, my_db_state] + inputs_list,
1891
  outputs=text_output_nochat,
1892
  queue=queue,
1893
  )
@@ -1906,7 +2151,8 @@ def go_gradio(**kwargs):
1906
  .then(clear_torch_cache)
1907
 
1908
  submit_event_nochat_api = submit_nochat_api.click(fun_with_dict_str,
1909
- inputs=[model_state, my_db_state, inputs_dict_str],
 
1910
  outputs=text_output_nochat_api,
1911
  queue=True, # required for generator
1912
  api_name='submit_nochat_api' if allow_api else None) \
@@ -2156,6 +2402,8 @@ def go_gradio(**kwargs):
2156
  print("Exception: %s" % str(e), flush=True)
2157
  return json.dumps(sys_dict)
2158
 
 
 
2159
  get_system_info_dict_func = functools.partial(get_system_info_dict, **all_kwargs)
2160
 
2161
  system_dict_event = system_btn2.click(get_system_info_dict_func,
@@ -2185,12 +2433,15 @@ def go_gradio(**kwargs):
2185
  else:
2186
  tokenizer = None
2187
  if tokenizer is not None:
2188
- langchain_mode1 = 'ChatLLM'
 
2189
  # fake user message to mimic bot()
2190
  chat1 = copy.deepcopy(chat1)
2191
  chat1 = chat1 + [['user_message1', None]]
2192
  model_max_length1 = tokenizer.model_max_length
2193
- context1 = history_to_context(chat1, langchain_mode1, prompt_type1, prompt_dict1, chat1,
 
 
2194
  model_max_length1,
2195
  memory_restriction_level1, keep_sources_in_context1)
2196
  return str(tokenizer(context1, return_tensors="pt")['input_ids'].shape[1])
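The token count at the end of this hunk, shown standalone; `gpt2` is only an example tokenizer, not necessarily the model the server runs.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')
context1 = "<human>: Hello\n<bot>: Hi, how can I help you today?\n"
# same count as input_ids.shape[1] above, without building a tensor
print(len(tokenizer(context1)['input_ids']))
```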
@@ -2220,7 +2471,7 @@ def go_gradio(**kwargs):
2220
  ,
2221
  queue=False, api_name='stop' if allow_api else None).then(clear_torch_cache, queue=False)
2222
 
2223
- demo.load(None, None, None, _js=get_dark_js() if kwargs['h2ocolors'] and False else None) # light best
2224
 
2225
  demo.queue(concurrency_count=kwargs['concurrency_count'], api_open=kwargs['api_open'])
2226
  favicon_path = "h2o-logo.svg"
@@ -2235,7 +2486,8 @@ def go_gradio(**kwargs):
2235
  # FIXME: disable for gptj, langchain or gpt4all modify print itself
2236
  # FIXME: and any multi-threaded/async print will enter model output!
2237
  scheduler.add_job(func=ping, trigger="interval", seconds=60)
2238
- scheduler.add_job(func=ping_gpu, trigger="interval", seconds=60 * 10)
 
2239
  scheduler.start()
2240
 
2241
  # import control
@@ -2254,9 +2506,6 @@ def go_gradio(**kwargs):
2254
  demo.block_thread()
2255
 
2256
 
2257
- input_args_list = ['model_state', 'my_db_state']
2258
-
2259
-
2260
  def get_inputs_list(inputs_dict, model_lower, model_id=1):
2261
  """
2262
  map gradio objects in locals() to inputs for evaluate().
@@ -2290,8 +2539,9 @@ def get_inputs_list(inputs_dict, model_lower, model_id=1):
2290
  return inputs_list, inputs_dict_out
2291
 
2292
 
2293
- def get_sources(db1, langchain_mode, dbs=None, docs_state0=None):
2294
- set_userid(db1)
 
2295
 
2296
  if langchain_mode in ['ChatLLM', 'LLM']:
2297
  source_files_added = "NA"
@@ -2300,7 +2550,8 @@ def get_sources(db1, langchain_mode, dbs=None, docs_state0=None):
2300
  source_files_added = "Not showing wiki_full, takes about 20 seconds and makes 4MB file." \
2301
  " Ask jon.mckinney@h2o.ai for file if required."
2302
  source_list = []
2303
- elif langchain_mode == 'MyData' and len(db1) > 0 and db1[0] is not None:
 
2304
  from gpt_langchain import get_metadatas
2305
  metadatas = get_metadatas(db1[0])
2306
  source_list = sorted(set([x['source'] for x in metadatas]))
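The source-list derivation in miniature, with invented metadata dicts:

```python
# each stored chunk carries its originating file in metadata['source'];
# the UI shows the de-duplicated, sorted set
metadatas = [{'source': 'b.pdf'}, {'source': 'a.txt'}, {'source': 'a.txt'}]
source_list = sorted(set(x['source'] for x in metadatas))
assert source_list == ['a.txt', 'b.pdf']
```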
@@ -2331,14 +2582,13 @@ def set_userid(db1):
2331
  db1[1] = str(uuid.uuid4())
2332
 
2333
 
2334
- def update_user_db(file, db1, chunk, chunk_size, langchain_mode, dbs=None, **kwargs):
2335
- set_userid(db1)
2336
-
2337
  if file is None:
2338
  raise RuntimeError("Don't use change, use input")
2339
 
2340
  try:
2341
- return _update_user_db(file, db1=db1, chunk=chunk, chunk_size=chunk_size,
2342
  langchain_mode=langchain_mode, dbs=dbs,
2343
  **kwargs)
2344
  except BaseException as e:
@@ -2369,25 +2619,30 @@ def get_lock_file(db1, langchain_mode):
2369
  user_id = db1[1]
2370
  base_path = 'locks'
2371
  makedirs(base_path)
2372
- lock_file = "db_%s_%s.lock" % (langchain_mode.replace(' ', '_'), user_id)
2373
  return lock_file
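A sketch of the per-user, per-collection lock pattern; details such as joining `base_path` are assumptions, not necessarily what `get_lock_file` does in the repo.

```python
import os
import filelock

def example_lock_file(user_id, langchain_mode, base_path='locks'):
    # one lock file per (collection, user) so concurrent uploads to the same db serialize
    os.makedirs(base_path, exist_ok=True)
    return os.path.join(base_path, "db_%s_%s.lock" % (langchain_mode.replace(' ', '_'), user_id))

with filelock.FileLock(example_lock_file('user123', 'MyData')):
    pass  # mutate the user's db here, protected from concurrent writers
```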
2374
 
2375
 
2376
  def _update_user_db(file,
2377
- db1=None,
2378
  chunk=None, chunk_size=None,
2379
- dbs=None, db_type=None, langchain_mode='UserData',
2380
- user_path=None,
 
 
 
2381
  use_openai_embedding=None,
2382
  hf_embedding_model=None,
2383
  caption_loader=None,
2384
  enable_captions=None,
2385
  captions_model=None,
2386
  enable_ocr=None,
 
2387
  verbose=None,
 
2388
  is_url=None, is_txt=None,
2389
- n_jobs=-1):
2390
- assert db1 is not None
2391
  assert chunk is not None
2392
  assert chunk_size is not None
2393
  assert use_openai_embedding is not None
@@ -2396,10 +2651,9 @@ def _update_user_db(file,
2396
  assert enable_captions is not None
2397
  assert captions_model is not None
2398
  assert enable_ocr is not None
 
2399
  assert verbose is not None
2400
 
2401
- set_userid(db1)
2402
-
2403
  if dbs is None:
2404
  dbs = {}
2405
  assert isinstance(dbs, dict), "Wrong type for dbs: %s" % str(type(dbs))
@@ -2417,17 +2671,22 @@ def _update_user_db(file,
2417
  if langchain_mode == LangChainMode.DISABLED.value:
2418
  return None, langchain_mode, get_source_files(), ""
2419
 
2420
- if langchain_mode in [LangChainMode.CHAT_LLM.value, LangChainMode.CHAT_LLM.value]:
2421
  # then switch to MyData, so langchain_mode also becomes way to select where upload goes
2422
  # but default to mydata if nothing chosen, since safest
2423
- langchain_mode = LangChainMode.MY_DATA.value
2424
-
2425
- if langchain_mode == 'UserData' and user_path is not None:
 
 
 
 
 
2426
  # move temp files from gradio upload to stable location
2427
  for fili, fil in enumerate(file):
2428
- if isinstance(fil, str):
2429
- if fil.startswith('/tmp/gradio/'):
2430
- new_fil = os.path.join(user_path, os.path.basename(fil))
2431
  if os.path.isfile(new_fil):
2432
  remove(new_fil)
2433
  try:
@@ -2447,15 +2706,22 @@ def _update_user_db(file,
2447
  enable_captions=enable_captions,
2448
  captions_model=captions_model,
2449
  enable_ocr=enable_ocr,
 
2450
  caption_loader=caption_loader,
2451
  )
2452
  exceptions = [x for x in sources if x.metadata.get('exception')]
2453
  exceptions_strs = [x.metadata['exception'] for x in exceptions]
2454
  sources = [x for x in sources if 'exception' not in x.metadata]
2455
 
2456
- lock_file = get_lock_file(db1, langchain_mode)
 
 
 
 
 
 
2457
  with filelock.FileLock(lock_file):
2458
- if langchain_mode == 'MyData':
2459
  if db1[0] is not None:
2460
  # then add
2461
  db, num_new_sources, new_sources_metadata = add_to_db(db1[0], sources, db_type=db_type,
@@ -2465,7 +2731,8 @@ def _update_user_db(file,
2465
  # in testing expect:
2466
  # assert len(db1) == 2 and db1[1] is None, "Bad MyData db: %s" % db1
2467
  # for production hit, when user gets clicky:
2468
- assert len(db1) == 2, "Bad MyData db: %s" % db1
 
2469
  # then create
2470
  # if added has to original state and didn't change, then would be shared db for all users
2471
  persist_directory = os.path.join(scratch_base_dir, 'db_dir_%s_%s' % (langchain_mode, db1[1]))
@@ -2487,7 +2754,7 @@ def _update_user_db(file,
2487
  use_openai_embedding=use_openai_embedding,
2488
  hf_embedding_model=hf_embedding_model)
2489
  else:
2490
- # then create
2491
  db = get_db(sources, use_openai_embedding=use_openai_embedding,
2492
  db_type=db_type,
2493
  persist_directory=persist_directory,
@@ -2501,14 +2768,15 @@ def _update_user_db(file,
2501
  return None, langchain_mode, source_files_added, '\n'.join(exceptions_strs)
2502
 
2503
 
2504
- def get_db(db1, langchain_mode, dbs=None):
2505
- lock_file = get_lock_file(db1, langchain_mode)
 
2506
 
2507
  with filelock.FileLock(lock_file):
2508
  if langchain_mode in ['wiki_full']:
2509
  # NOTE: avoid showing full wiki. Takes about 30 seconds over about 90k entries, but not useful for now
2510
  db = None
2511
- elif langchain_mode == 'MyData' and len(db1) > 0 and db1[0] is not None:
2512
  db = db1[0]
2513
  elif dbs is not None and langchain_mode in dbs and dbs[langchain_mode] is not None:
2514
  db = dbs[langchain_mode]
@@ -2517,8 +2785,8 @@ def get_db(db1, langchain_mode, dbs=None):
2517
  return db
2518
 
2519
 
2520
- def get_source_files_given_langchain_mode(db1, langchain_mode='UserData', dbs=None):
2521
- db = get_db(db1, langchain_mode, dbs=dbs)
2522
  if langchain_mode in ['ChatLLM', 'LLM'] or db is None:
2523
  return "Sources: N/A"
2524
  return get_source_files(db=db, exceptions=None)
@@ -2617,11 +2885,19 @@ def get_source_files(db=None, exceptions=None, metadatas=None):
2617
  return source_files_added
2618
 
2619
 
2620
- def update_and_get_source_files_given_langchain_mode(db1, langchain_mode, dbs=None, first_para=None,
2621
- text_limit=None, chunk=None, chunk_size=None,
2622
- user_path=None, db_type=None, load_db_if_exists=None,
 
2623
  n_jobs=None, verbose=None):
2624
- db = get_db(db1, langchain_mode, dbs=dbs)
 
 
 
 
 
 
 
2625
 
2626
  from gpt_langchain import make_db
2627
  db, num_new_sources, new_sources_metadata = make_db(use_openai_embedding=False,
@@ -2630,11 +2906,27 @@ def update_and_get_source_files_given_langchain_mode(db1, langchain_mode, dbs=No
2630
  chunk=chunk,
2631
  chunk_size=chunk_size,
2632
  langchain_mode=langchain_mode,
2633
- user_path=user_path,
2634
  db_type=db_type,
2635
  load_db_if_exists=load_db_if_exists,
2636
  db=db,
2637
  n_jobs=n_jobs,
2638
  verbose=verbose)
 
 
 
 
 
 
 
2639
  # return only new sources with text saying such
2640
  return get_source_files(db=None, exceptions=None, metadatas=new_sources_metadata)
 
 
 
 
 
 
 
 
 
 
50
 
51
  fix_pydantic_duplicate_validators_error()
52
 
53
+ from enums import DocumentSubset, no_model_str, no_lora_str, no_server_str, LangChainAction, LangChainMode, \
54
+ DocumentChoice, langchain_modes_intrinsic
55
  from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js, spacing_xsm, radius_xsm, \
56
  text_xsm
57
  from prompter import prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower, non_hf_types, \
58
  get_prompt
59
+ from utils import flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
60
+ ping, get_short_name, makedirs, get_kwargs, remove, system_info, ping_gpu, get_url, get_local_ip, \
61
+ save_collection_names
62
+ from gen import get_model, languages_covered, evaluate, score_qa, inputs_kwargs_list, scratch_base_dir, \
63
+ get_max_max_new_tokens, get_minmax_top_k_docs, history_to_context, langchain_actions, langchain_agents_list, \
64
+ update_langchain
65
+ from evaluate_params import eval_func_param_names, no_default_param_names, eval_func_param_names_defaults, \
66
+ input_args_list
67
 
68
  from apscheduler.schedulers.background import BackgroundScheduler
69
 
 
98
  memory_restriction_level = kwargs['memory_restriction_level']
99
  n_gpus = kwargs['n_gpus']
100
  admin_pass = kwargs['admin_pass']
 
101
  model_states = kwargs['model_states']
 
102
  dbs = kwargs['dbs']
103
  db_type = kwargs['db_type']
 
104
  visible_langchain_actions = kwargs['visible_langchain_actions']
105
+ visible_langchain_agents = kwargs['visible_langchain_agents']
106
  allow_upload_to_user_data = kwargs['allow_upload_to_user_data']
107
  allow_upload_to_my_data = kwargs['allow_upload_to_my_data']
108
  enable_sources_list = kwargs['enable_sources_list']
 
113
  enable_captions = kwargs['enable_captions']
114
  captions_model = kwargs['captions_model']
115
  enable_ocr = kwargs['enable_ocr']
116
+ enable_pdf_ocr = kwargs['enable_pdf_ocr']
117
  caption_loader = kwargs['caption_loader']
118
 
119
+ # for dynamic state per user session in gradio
120
+ model_state0 = kwargs['model_state0']
121
+ score_model_state0 = kwargs['score_model_state0']
122
+ my_db_state0 = kwargs['my_db_state0']
123
+ selection_docs_state0 = kwargs['selection_docs_state0']
124
+ # for evaluate defaults
125
+ langchain_modes0 = kwargs['langchain_modes']
126
+ visible_langchain_modes0 = kwargs['visible_langchain_modes']
127
+ langchain_mode_paths0 = kwargs['langchain_mode_paths']
128
+
129
  # easy update of kwargs needed for evaluate() etc.
130
  queue = True
131
  allow_upload = allow_upload_to_user_data or allow_upload_to_my_data
 
145
  " use Enter for multiple input lines)"
146
 
147
  title = 'h2oGPT'
148
+ description = """<iframe src="https://ghbtns.com/github-btn.html?user=h2oai&repo=h2ogpt&type=star&count=true&size=small" frameborder="0" scrolling="0" width="250" height="20" title="GitHub"></iframe><small><a href="https://github.com/h2oai/h2ogpt">h2oGPT</a> <a href="https://github.com/h2oai/h2o-llmstudio">H2O LLM Studio</a><br><a href="https://huggingface.co/h2oai">🤗 Models</a>"""
149
+ description_bottom = "If this host is busy, try<br>[Multi-Model](https://gpt.h2o.ai)<br>[Falcon 40B](https://falcon.h2o.ai)<br>[Vicuna 33B](https://wizardvicuna.h2o.ai)<br>[MPT 30B-Chat](https://mpt.h2o.ai)<br>[HF Spaces1](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot)<br>[HF Spaces2](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)<br>"
150
  if is_hf:
151
  description_bottom += '''<a href="https://huggingface.co/spaces/h2oai/h2ogpt-chatbot?duplicate=true"><img src="https://bit.ly/3gLdBN6" style="white-space: nowrap" alt="Duplicate Space"></a>'''
152
+ task_info_md = ''
153
  css_code = get_css(kwargs)
154
 
155
  if kwargs['gradio_offline_level'] >= 0:
 
179
  demo = gr.Blocks(theme=theme, css=css_code, title="h2oGPT", analytics_enabled=False)
180
  callback = gr.CSVLogger()
181
 
182
+ model_options0 = flatten_list(list(prompt_type_to_model_name.values())) + kwargs['extra_model_options']
183
+ if kwargs['base_model'].strip() not in model_options0:
184
+ model_options0 = [kwargs['base_model'].strip()] + model_options0
185
  lora_options = kwargs['extra_lora_options']
186
  if kwargs['lora_weights'].strip() not in lora_options:
187
  lora_options = [kwargs['lora_weights'].strip()] + lora_options
 
196
 
197
  # always add in no lora case
198
  # add fake space so doesn't go away in gradio dropdown
199
+ model_options0 = [no_model_str] + model_options0
200
  lora_options = [no_lora_str] + lora_options
201
  server_options = [no_server_str] + server_options
202
  # always add in no model case so can free memory
 
250
  # else gets input_list at time of submit that is old, and shows up as truncated in chatbot
251
  return x
252
 
253
+ def allow_empty_instruction(langchain_mode1, document_subset1, langchain_action1):
254
+ allow = False
255
+ allow |= langchain_action1 not in LangChainAction.QUERY.value
256
+ allow |= document_subset1 in DocumentSubset.TopKSources.name
257
+ if langchain_mode1 in [LangChainMode.LLM.value]:
258
+ allow = False
259
+ return allow
260
+
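The allow_empty_instruction helper added above decides when a blank prompt may still be submitted (e.g. summarization over the selected documents). A minimal standalone sketch of that rule, with plain strings standing in for the DocumentSubset/LangChainMode/LangChainAction enum values the real code uses:

    def allow_empty_instruction_sketch(langchain_mode, document_subset, langchain_action):
        # empty queries are fine for non-query actions (e.g. summarization) or when
        # only listing the top-k sources, but never in plain LLM mode
        allow = langchain_action != "Query" or document_subset == "TopKSources"
        return allow and langchain_mode != "LLM"

    assert allow_empty_instruction_sketch("UserData", "TopKSources", "Query")
    assert not allow_empty_instruction_sketch("LLM", "Relevant", "Query")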
261
  with demo:
262
  # avoid actual model/tokenizer here or anything that would be bad to deepcopy
263
  # https://github.com/gradio-app/gradio/issues/3558
 
271
  prompt_dict=kwargs['prompt_dict'],
272
  )
273
  )
274
+
275
+ def update_langchain_mode_paths(db1s, selection_docs_state1):
276
+ if allow_upload_to_my_data:
277
+ selection_docs_state1['langchain_mode_paths'].update({k: None for k in db1s})
278
+ dup = selection_docs_state1['langchain_mode_paths'].copy()
279
+ for k, v in dup.items():
280
+ if k not in selection_docs_state1['visible_langchain_modes']:
281
+ selection_docs_state1['langchain_mode_paths'].pop(k)
282
+ return selection_docs_state1
283
+
284
+ # Setup some gradio states for per-user dynamic state
285
  model_state2 = gr.State(kwargs['model_state_none'].copy())
286
+ model_options_state = gr.State([model_options0])
287
  lora_options_state = gr.State([lora_options])
288
  server_options_state = gr.State([server_options])
289
+ my_db_state = gr.State(my_db_state0)
290
  chat_state = gr.State({})
291
+ docs_state00 = kwargs['document_choice'] + [DocumentChoice.ALL.value]
292
  docs_state0 = []
293
  [docs_state0.append(x) for x in docs_state00 if x not in docs_state0]
294
  docs_state = gr.State(docs_state0)
295
  viewable_docs_state0 = []
296
  viewable_docs_state = gr.State(viewable_docs_state0)
297
+ selection_docs_state0 = update_langchain_mode_paths(my_db_state0, selection_docs_state0)
298
+ selection_docs_state = gr.State(selection_docs_state0)
299
+
300
  gr.Markdown(f"""
301
  {get_h2o_title(title, description) if kwargs['h2ocolors'] else get_simple_title(title, description)}
302
  """)
 
310
  'model_lock'] else "Response Scores: %s" % nas
311
 
312
  if kwargs['langchain_mode'] != LangChainMode.DISABLED.value:
313
+ extra_prompt_form = ". For summarization, no query required, just click submit"
314
  else:
315
  extra_prompt_form = ""
316
  if kwargs['input_lines'] > 1:
 
318
  else:
319
  instruction_label = "Enter to Submit, Shift-Enter for more lines%s" % extra_prompt_form
320
 
321
+ def get_langchain_choices(selection_docs_state1):
322
+ langchain_modes = selection_docs_state1['langchain_modes']
323
+ visible_langchain_modes = selection_docs_state1['visible_langchain_modes']
324
+
325
+ if is_hf:
326
+ # don't show 'wiki' since only usually useful for internal testing at moment
327
+ no_show_modes = ['Disabled', 'wiki']
328
+ else:
329
+ no_show_modes = ['Disabled']
330
+ allowed_modes = visible_langchain_modes.copy()
331
+ # allowed_modes = [x for x in allowed_modes if x in dbs]
332
+ allowed_modes += ['LLM']
333
+ if allow_upload_to_my_data and 'MyData' not in allowed_modes:
334
+ allowed_modes += ['MyData']
335
+ if allow_upload_to_user_data and 'UserData' not in allowed_modes:
336
+ allowed_modes += ['UserData']
337
+ choices = [x for x in langchain_modes if x in allowed_modes and x not in no_show_modes]
338
+ return choices
339
+
340
+ def get_df_langchain_mode_paths(selection_docs_state1):
341
+ langchain_mode_paths = selection_docs_state1['langchain_mode_paths']
342
+ if langchain_mode_paths:
343
+ df = pd.DataFrame.from_dict(langchain_mode_paths.items(), orient='columns')
344
+ df.columns = ['Collection', 'Path']
345
+ else:
346
+ df = pd.DataFrame(None)
347
+ return df
348
+
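get_df_langchain_mode_paths simply renders the collection-to-path mapping as a two-column table for the Dataframe widget added below. For illustration, with hypothetical entries:

    import pandas as pd

    langchain_mode_paths = {'UserData': 'user_path', 'MyData': None}  # hypothetical
    df = pd.DataFrame(list(langchain_mode_paths.items()), columns=['Collection', 'Path'])
    # df now has one row per collection, with columns 'Collection' and 'Path'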
349
  normal_block = gr.Row(visible=not base_wanted, equal_height=False)
350
  with normal_block:
351
  side_bar = gr.Column(elem_id="col_container", scale=1, min_width=100)
 
366
  scale=1,
367
  min_width=0,
368
  elem_id="warning", elem_classes="feedback")
369
+ fileup_output_text = gr.Textbox(visible=False)
370
  url_visible = kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_url_upload
371
  url_label = 'URL/ArXiv' if have_arxiv else 'URL'
372
  url_text = gr.Textbox(label=url_label,
 
380
  visible=text_visible)
381
  github_textbox = gr.Textbox(label="Github URL", visible=False) # FIXME WIP
382
  database_visible = kwargs['langchain_mode'] != 'Disabled'
383
+ with gr.Accordion("Resources", open=False, visible=database_visible):
384
+ langchain_choices0 = get_langchain_choices(selection_docs_state0)
385
  langchain_mode = gr.Radio(
386
+ langchain_choices0,
387
  value=kwargs['langchain_mode'],
388
  label="Collections",
389
  show_label=True,
390
  visible=kwargs['langchain_mode'] != 'Disabled',
391
  min_width=100)
392
+ add_chat_history_to_context = gr.Checkbox(label="Chat History",
393
+ value=kwargs['add_chat_history_to_context'])
394
+ document_subset = gr.Radio([x.name for x in DocumentSubset],
395
  label="Subset",
396
+ value=DocumentSubset.Relevant.name,
397
  interactive=True,
398
  )
399
  allowed_actions = [x for x in langchain_actions if x in visible_langchain_actions]
 
402
  value=allowed_actions[0] if len(allowed_actions) > 0 else None,
403
  label="Action",
404
  visible=True)
405
+ allowed_agents = [x for x in langchain_agents_list if x in visible_langchain_agents]
406
+ langchain_agents = gr.Dropdown(
407
+ langchain_agents_list,
408
+ value=kwargs['langchain_agents'],
409
+ label="Agents",
410
+ multiselect=True,
411
+ interactive=True,
412
+ visible=False) # WIP
413
  col_tabs = gr.Column(elem_id="col_container", scale=10)
414
  with (col_tabs, gr.Tabs()):
415
  with gr.TabItem("Chat"):
 
457
  mw1 = 50
458
  mw2 = 50
459
  with gr.Column(min_width=mw1):
460
+ submit = gr.Button(value='Submit', variant='primary', size='sm',
461
  min_width=mw1)
462
+ stop_btn = gr.Button(value="Stop", variant='secondary', size='sm',
463
  min_width=mw1)
464
  save_chat_btn = gr.Button("Save", size='sm', min_width=mw1)
465
  with gr.Column(min_width=mw2):
 
480
  with gr.TabItem("Document Selection"):
481
  document_choice = gr.Dropdown(docs_state0,
482
  label="Select Subset of Document(s) %s" % file_types_str,
483
+ value=[DocumentChoice.ALL.value],
484
  interactive=True,
485
  multiselect=True,
486
  visible=kwargs['langchain_mode'] != 'Disabled',
487
  )
488
  sources_visible = kwargs['langchain_mode'] != 'Disabled' and enable_sources_list
489
  with gr.Row():
490
+ with gr.Column(scale=1):
491
+ get_sources_btn = gr.Button(value="Update UI with Document(s) from DB", scale=0, size='sm',
492
+ visible=sources_visible)
493
+ show_sources_btn = gr.Button(value="Show Sources from DB", scale=0, size='sm',
494
+ visible=sources_visible)
495
+ refresh_sources_btn = gr.Button(value="Update DB with new/changed files on disk", scale=0,
496
+ size='sm',
497
+ visible=sources_visible and allow_upload_to_user_data)
498
+ with gr.Column(scale=4):
499
+ pass
500
+ with gr.Row():
501
+ with gr.Column(scale=1):
502
+ add_placeholder = "e.g. UserData2, user_path2 (optional)" \
503
+ if not is_public else "e.g. MyData2"
504
+ remove_placeholder = "e.g. UserData2" if not is_public else "e.g. MyData2"
505
+ new_langchain_mode_text = gr.Textbox(value="", visible=allow_upload_to_user_data or
506
+ allow_upload_to_my_data,
507
+ label='Add Collection',
508
+ placeholder=add_placeholder,
509
+ interactive=True)
510
+ remove_langchain_mode_text = gr.Textbox(value="", visible=allow_upload_to_user_data or
511
+ allow_upload_to_my_data,
512
+ label='Remove Collection',
513
+ placeholder=remove_placeholder,
514
+ interactive=True)
515
+ load_langchain = gr.Button(value="Load LangChain State", scale=0, size='sm',
516
+ visible=allow_upload_to_user_data)
517
+ with gr.Column(scale=1):
518
+ df0 = get_df_langchain_mode_paths(selection_docs_state0)
519
+ langchain_mode_path_text = gr.Dataframe(value=df0,
520
+ visible=allow_upload_to_user_data or
521
+ allow_upload_to_my_data,
522
+ label='LangChain Mode-Path',
523
+ show_label=False,
524
+ interactive=False)
525
+ with gr.Column(scale=4):
526
+ pass
527
 
528
  sources_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list,
529
  equal_height=False)
 
548
  value=None,
549
  interactive=True,
550
  multiselect=False,
551
+ visible=True,
552
  )
553
  with gr.Column(scale=4):
554
  pass
 
793
  side_bar_btn = gr.Button("Toggle SideBar", variant="secondary", size="sm")
794
  submit_buttons_btn = gr.Button("Toggle Submit Buttons", variant="secondary", size="sm")
795
  col_tabs_scale = gr.Slider(minimum=1, maximum=20, value=10, step=1, label='Window Size')
796
+ text_outputs_height = gr.Slider(minimum=100, maximum=2000, value=kwargs['height'] or 400,
797
+ step=50, label='Chat Height')
798
  dark_mode_btn = gr.Button("Dark Mode", variant="secondary", size="sm")
799
  with gr.Column(scale=4):
800
  pass
801
+ system_visible0 = not is_public and not admin_pass
802
  admin_row = gr.Row()
803
  with admin_row:
804
  with gr.Column(scale=1):
805
+ admin_pass_textbox = gr.Textbox(label="Admin Password", type='password',
806
+ visible=not system_visible0)
807
  with gr.Column(scale=4):
808
  pass
809
+ system_row = gr.Row(visible=system_visible0)
810
  with system_row:
811
  with gr.Column():
812
  with gr.Row():
 
870
  else:
871
  return tuple([gr.update(interactive=True)] * len(args))
872
 
873
+ # Add to UserData or custom user db
874
  update_db_func = functools.partial(update_user_db,
875
  dbs=dbs,
876
  db_type=db_type,
877
  use_openai_embedding=use_openai_embedding,
878
  hf_embedding_model=hf_embedding_model,
 
879
  captions_model=captions_model,
880
+ enable_captions=enable_captions,
881
  caption_loader=caption_loader,
882
+ enable_ocr=enable_ocr,
883
+ enable_pdf_ocr=enable_pdf_ocr,
884
  verbose=kwargs['verbose'],
 
885
  n_jobs=kwargs['n_jobs'],
886
  )
887
  add_file_outputs = [fileup_output, langchain_mode]
888
  add_file_kwargs = dict(fn=update_db_func,
889
+ inputs=[fileup_output, my_db_state, selection_docs_state, chunk, chunk_size,
890
+ langchain_mode],
891
  outputs=add_file_outputs + [sources_text, doc_exception_text],
892
  queue=queue,
893
  api_name='add_file' if allow_api and allow_upload_to_user_data else None)
 
899
  eventdb1b = eventdb1.then(make_interactive, inputs=add_file_outputs, outputs=add_file_outputs,
900
  show_progress='minimal')
901
 
902
+ # deal with challenge to have fileup_output itself as input
903
+ add_file_kwargs2 = dict(fn=update_db_func,
904
+ inputs=[fileup_output_text, my_db_state, selection_docs_state, chunk, chunk_size,
905
+ langchain_mode],
906
+ outputs=add_file_outputs + [sources_text, doc_exception_text],
907
+ queue=queue,
908
+ api_name='add_file_api' if allow_api and allow_upload_to_user_data else None)
909
+ eventdb1_api = fileup_output_text.submit(**add_file_kwargs2, show_progress='full')
910
+
911
  # note for update_user_db_func output is ignored for db
912
 
913
  def clear_textbox():
 
917
 
918
  add_url_outputs = [url_text, langchain_mode]
919
  add_url_kwargs = dict(fn=update_user_db_url_func,
920
+ inputs=[url_text, my_db_state, selection_docs_state, chunk, chunk_size,
921
+ langchain_mode],
922
  outputs=add_url_outputs + [sources_text, doc_exception_text],
923
  queue=queue,
924
  api_name='add_url' if allow_api and allow_upload_to_user_data else None)
 
935
  update_user_db_txt_func = functools.partial(update_db_func, is_txt=True)
936
  add_text_outputs = [user_text_text, langchain_mode]
937
  add_text_kwargs = dict(fn=update_user_db_txt_func,
938
+ inputs=[user_text_text, my_db_state, selection_docs_state, chunk, chunk_size,
939
+ langchain_mode],
940
  outputs=add_text_outputs + [sources_text, doc_exception_text],
941
  queue=queue,
942
  api_name='add_text' if allow_api and allow_upload_to_user_data else None
 
948
  eventdb3 = eventdb3b.then(**add_text_kwargs, show_progress='full')
949
  eventdb3c = eventdb3.then(make_interactive, inputs=add_text_outputs, outputs=add_text_outputs,
950
  show_progress='minimal')
951
+ db_events = [eventdb1a, eventdb1, eventdb1b, eventdb1_api,
952
  eventdb2a, eventdb2, eventdb2b, eventdb2c,
953
  eventdb3a, eventdb3b, eventdb3, eventdb3c]
954
 
 
956
 
957
  # if change collection source, must clear doc selections from it to avoid inconsistency
958
  def clear_doc_choice():
959
+ return gr.Dropdown.update(choices=docs_state0, value=DocumentChoice.ALL.value)
960
 
961
  langchain_mode.change(clear_doc_choice, inputs=None, outputs=document_choice, queue=False)
962
 
963
  def resize_col_tabs(x):
964
  return gr.Dropdown.update(scale=x)
965
 
966
+ col_tabs_scale.change(fn=resize_col_tabs, inputs=col_tabs_scale, outputs=col_tabs, queue=False)
967
 
968
  def resize_chatbots(x, num_model_lock=0):
969
  if num_model_lock == 0:
 
974
 
975
  resize_chatbots_func = functools.partial(resize_chatbots, num_model_lock=len(text_outputs))
976
  text_outputs_height.change(fn=resize_chatbots_func, inputs=text_outputs_height,
977
+ outputs=[text_output, text_output2] + text_outputs, queue=False)
978
 
979
  def update_dropdown(x):
980
  return gr.Dropdown.update(choices=x, value=[docs_state0[0]])
 
1065
  if file.startswith('http') or file.startswith('https'):
1066
  # if file is online, then might as well use google(?)
1067
  document1 = file
1068
+ return gr.update(visible=True,
1069
+ value=f"""<iframe width="1000" height="800" src="https://docs.google.com/viewerng/viewer?url={document1}&embedded=true" frameborder="0" height="100%" width="100%">
1070
  </iframe>
1071
  """), dummy1, dummy1, dummy1
1072
  else:
 
1089
 
1090
  refresh_sources1 = functools.partial(update_and_get_source_files_given_langchain_mode,
1091
  **get_kwargs(update_and_get_source_files_given_langchain_mode,
1092
+ exclude_names=['db1s', 'langchain_mode', 'chunk',
1093
+ 'chunk_size'],
1094
  **all_kwargs))
1095
+ eventdb9 = refresh_sources_btn.click(fn=refresh_sources1,
1096
+ inputs=[my_db_state, langchain_mode, chunk, chunk_size],
1097
  outputs=sources_text,
1098
  api_name='refresh_sources' if allow_api else None)
1099
 
 
1103
  def close_admin(x):
1104
  return gr.update(visible=not (x == admin_pass))
1105
 
1106
+ admin_pass_textbox.submit(check_admin_pass, inputs=admin_pass_textbox, outputs=system_row, queue=False) \
1107
  .then(close_admin, inputs=admin_pass_textbox, outputs=admin_row, queue=False)
1108
 
1109
+ def add_langchain_mode(db1s, selection_docs_state1, langchain_mode1, y):
1110
+ for k in db1s:
1111
+ set_userid(db1s[k])
1112
+ langchain_modes = selection_docs_state1['langchain_modes']
1113
+ langchain_mode_paths = selection_docs_state1['langchain_mode_paths']
1114
+ visible_langchain_modes = selection_docs_state1['visible_langchain_modes']
1115
+
1116
+ user_path = None
1117
+ valid = True
1118
+ y2 = y.strip().replace(' ', '').split(',')
1119
+ if len(y2) >= 1:
1120
+ langchain_mode2 = y2[0]
1121
+ if len(langchain_mode2) >= 3 and langchain_mode2.isalnum():
1122
+ # real restriction is:
1123
+ # ValueError: Expected collection name that (1) contains 3-63 characters, (2) starts and ends with an alphanumeric character, (3) otherwise contains only alphanumeric characters, underscores or hyphens (-), (4) contains no two consecutive periods (..) and (5) is not a valid IPv4 address, got me
1124
+ # but just make simpler
1125
+ user_path = y2[1] if len(y2) > 1 else None # assume scratch if don't have user_path
1126
+ if user_path in ['', "''"]:
1127
+ # for scratch spaces
1128
+ user_path = None
1129
+ if langchain_mode2 in langchain_modes_intrinsic:
1130
+ user_path = None
1131
+ textbox = "Invalid access to use internal name: %s" % langchain_mode2
1132
+ valid = False
1133
+ langchain_mode2 = langchain_mode1
1134
+ elif user_path and allow_upload_to_user_data or not user_path and allow_upload_to_my_data:
1135
+ langchain_mode_paths.update({langchain_mode2: user_path})
1136
+ if langchain_mode2 not in visible_langchain_modes:
1137
+ visible_langchain_modes.append(langchain_mode2)
1138
+ if langchain_mode2 not in langchain_modes:
1139
+ langchain_modes.append(langchain_mode2)
1140
+ textbox = ''
1141
+ if user_path:
1142
+ makedirs(user_path, exist_ok=True)
1143
+ else:
1144
+ valid = False
1145
+ langchain_mode2 = langchain_mode1
1146
+ textbox = "Invalid access. user allowed: %s " \
1147
+ "scratch allowed: %s" % (allow_upload_to_user_data, allow_upload_to_my_data)
1148
+ else:
1149
+ valid = False
1150
+ langchain_mode2 = langchain_mode1
1151
+ textbox = "Invalid, collection must be >=3 characters and alphanumeric"
1152
+ else:
1153
+ valid = False
1154
+ langchain_mode2 = langchain_mode1
1155
+ textbox = "Invalid, must be like UserData2, user_path2"
1156
+ selection_docs_state1 = update_langchain_mode_paths(db1s, selection_docs_state1)
1157
+ df_langchain_mode_paths1 = get_df_langchain_mode_paths(selection_docs_state1)
1158
+ choices = get_langchain_choices(selection_docs_state1)
1159
+
1160
+ if valid and not user_path:
1161
+ # needs to have key for it to make it known different from userdata case in _update_user_db()
1162
+ db1s[langchain_mode2] = [None, None]
1163
+ if valid:
1164
+ save_collection_names(langchain_modes, visible_langchain_modes, langchain_mode_paths, LangChainMode,
1165
+ db1s)
1166
+
1167
+ return db1s, selection_docs_state1, gr.update(choices=choices,
1168
+ value=langchain_mode2), textbox, df_langchain_mode_paths1
1169
+
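add_langchain_mode deliberately relaxes the ChromaDB collection-name rule quoted in the comment above to just len >= 3 and isalnum(). For reference, a hedged sketch of a fuller validator matching the quoted rule (alphanumeric ends, only alphanumerics, underscores or hyphens inside, 3-63 characters overall):

    import re

    _COLLECTION_RE = re.compile(r'^[A-Za-z0-9][A-Za-z0-9_-]{1,61}[A-Za-z0-9]$')

    def valid_collection_name(name):
        # sketch only; the IPv4 and consecutive-period clauses are already ruled
        # out because periods are not in the allowed character class
        return bool(_COLLECTION_RE.fullmatch(name))

    assert valid_collection_name('UserData2')
    assert not valid_collection_name('me')        # shorter than 3 characters
    assert not valid_collection_name('bad name')  # whitespace not allowed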
1170
+ def remove_langchain_mode(db1s, selection_docs_state1, langchain_mode1, langchain_mode2, dbsu=None):
1171
+ for k in db1s:
1172
+ set_userid(db1s[k])
1173
+ assert dbsu is not None
1174
+ langchain_modes = selection_docs_state1['langchain_modes']
1175
+ langchain_mode_paths = selection_docs_state1['langchain_mode_paths']
1176
+ visible_langchain_modes = selection_docs_state1['visible_langchain_modes']
1177
+
1178
+ if langchain_mode2 in db1s and not allow_upload_to_my_data or \
1179
+ dbsu is not None and langchain_mode2 in dbsu and not allow_upload_to_user_data or \
1180
+ langchain_mode2 in langchain_modes_intrinsic:
1181
+ # NOTE: Doesn't fail if removing MyData, but odd behavior seen when uploading after removal was not debugged
1182
+ textbox = "Invalid access, cannot remove %s" % langchain_mode2
1183
+ df_langchain_mode_paths1 = get_df_langchain_mode_paths(selection_docs_state1)
1184
+ else:
1185
+ # change global variables
1186
+ if langchain_mode2 in visible_langchain_modes:
1187
+ visible_langchain_modes.remove(langchain_mode2)
1188
+ textbox = ""
1189
+ else:
1190
+ textbox = "%s was not visible" % langchain_mode2
1191
+ if langchain_mode2 in langchain_modes:
1192
+ langchain_modes.remove(langchain_mode2)
1193
+ if langchain_mode2 in langchain_mode_paths:
1194
+ langchain_mode_paths.pop(langchain_mode2)
1195
+ if langchain_mode2 in db1s:
1196
+ # remove db entirely, so not in list, else need to manage visible list in update_langchain_mode_paths()
1197
+ # FIXME: Remove location?
1198
+ if langchain_mode2 != LangChainMode.MY_DATA.value:
1199
+ # don't remove last MyData, used as user hash
1200
+ db1s.pop(langchain_mode2)
1201
+ # only show
1202
+ selection_docs_state1 = update_langchain_mode_paths(db1s, selection_docs_state1)
1203
+ df_langchain_mode_paths1 = get_df_langchain_mode_paths(selection_docs_state1)
1204
+
1205
+ save_collection_names(langchain_modes, visible_langchain_modes, langchain_mode_paths, LangChainMode,
1206
+ db1s)
1207
+
1208
+ return db1s, selection_docs_state1, \
1209
+ gr.update(choices=get_langchain_choices(selection_docs_state1),
1210
+ value=langchain_mode2), textbox, df_langchain_mode_paths1
1211
+
1212
+ new_langchain_mode_text.submit(fn=add_langchain_mode,
1213
+ inputs=[my_db_state, selection_docs_state, langchain_mode,
1214
+ new_langchain_mode_text],
1215
+ outputs=[my_db_state, selection_docs_state, langchain_mode,
1216
+ new_langchain_mode_text,
1217
+ langchain_mode_path_text],
1218
+ api_name='new_langchain_mode_text' if allow_api and allow_upload_to_user_data else None)
1219
+ remove_langchain_mode_func = functools.partial(remove_langchain_mode, dbsu=dbs)
1220
+ remove_langchain_mode_text.submit(fn=remove_langchain_mode_func,
1221
+ inputs=[my_db_state, selection_docs_state, langchain_mode,
1222
+ remove_langchain_mode_text],
1223
+ outputs=[my_db_state, selection_docs_state, langchain_mode,
1224
+ remove_langchain_mode_text,
1225
+ langchain_mode_path_text],
1226
+ api_name='remove_langchain_mode_text' if allow_api and allow_upload_to_user_data else None)
1227
+
1228
+ def update_langchain_gr(db1s, selection_docs_state1, langchain_mode1):
1229
+ for k in db1s:
1230
+ set_userid(db1s[k])
1231
+ langchain_modes = selection_docs_state1['langchain_modes']
1232
+ langchain_mode_paths = selection_docs_state1['langchain_mode_paths']
1233
+ visible_langchain_modes = selection_docs_state1['visible_langchain_modes']
1234
+ # in-place
1235
+
1236
+ # update user collaborative collections
1237
+ update_langchain(langchain_modes, visible_langchain_modes, langchain_mode_paths, '')
1238
+ # update scratch single-user collections
1239
+ user_hash = db1s.get(LangChainMode.MY_DATA.value, '')[1]
1240
+ update_langchain(langchain_modes, visible_langchain_modes, langchain_mode_paths, user_hash)
1241
+
1242
+ selection_docs_state1 = update_langchain_mode_paths(db1s, selection_docs_state1)
1243
+ df_langchain_mode_paths1 = get_df_langchain_mode_paths(selection_docs_state1)
1244
+ return selection_docs_state1, \
1245
+ gr.update(choices=get_langchain_choices(selection_docs_state1),
1246
+ value=langchain_mode1), df_langchain_mode_paths1
1247
+
1248
+ load_langchain.click(fn=update_langchain_gr,
1249
+ inputs=[my_db_state, selection_docs_state, langchain_mode],
1250
+ outputs=[selection_docs_state, langchain_mode, langchain_mode_path_text],
1251
+ api_name='load_langchain' if allow_api and allow_upload_to_user_data else None)
1252
+
1253
  inputs_list, inputs_dict = get_inputs_list(all_kwargs, kwargs['model_lower'], model_id=1)
1254
  inputs_list2, inputs_dict2 = get_inputs_list(all_kwargs, kwargs['model_lower'], model_id=2)
1255
  from functools import partial
 
1261
  def evaluate_nochat(*args1, default_kwargs1=None, str_api=False, **kwargs1):
1262
  args_list = list(args1)
1263
  if str_api:
1264
+ user_kwargs = args_list[len(input_args_list)]
1265
  assert isinstance(user_kwargs, str)
1266
  user_kwargs = ast.literal_eval(user_kwargs)
1267
  else:
1268
+ user_kwargs = {k: v for k, v in zip(eval_func_param_names, args_list[len(input_args_list):])}
1269
  # only used for submit_nochat_api
1270
  user_kwargs['chat'] = False
1271
  if 'stream_output' not in user_kwargs:
 
1275
  user_kwargs['langchain_mode'] = 'Disabled'
1276
  if 'langchain_action' not in user_kwargs:
1277
  user_kwargs['langchain_action'] = LangChainAction.QUERY.value
1278
+ if 'langchain_agents' not in user_kwargs:
1279
+ user_kwargs['langchain_agents'] = []
1280
 
1281
  set1 = set(list(default_kwargs1.keys()))
1282
  set2 = set(eval_func_param_names)
 
1284
  # correct ordering. Note some things may not be in default_kwargs, so can't be default of user_kwargs.get()
1285
  model_state1 = args_list[0]
1286
  my_db_state1 = args_list[1]
1287
+ selection_docs_state1 = args_list[2]
1288
  args_list = [user_kwargs[k] if k in user_kwargs and user_kwargs[k] is not None else default_kwargs1[k] for k
1289
  in eval_func_param_names]
1290
  assert len(args_list) == len(eval_func_param_names)
1291
+ args_list = [model_state1, my_db_state1, selection_docs_state1] + args_list
1292
 
1293
  try:
1294
  for res_dict in evaluate(*tuple(args_list), **kwargs1):
 
1459
  prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
1460
  langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
1461
  langchain_action1 = args_list[eval_func_param_names.index('langchain_action')]
1462
+ langchain_agents1 = args_list[eval_func_param_names.index('langchain_agents')]
1463
  document_subset1 = args_list[eval_func_param_names.index('document_subset')]
1464
  document_choice1 = args_list[eval_func_param_names.index('document_choice')]
1465
  if not prompt_type1:
 
1492
  history[-1][1] = None
1493
  return history
1494
  if user_message1 in ['', None, '\n']:
1495
+ if not allow_empty_instruction(langchain_mode1, document_subset1, langchain_action1):
 
 
 
1496
  # reject non-retry submit/enter
1497
  return history
1498
  user_message1 = fix_text_for_gradio(user_message1)
 
1539
  API only called for which_model=0, default for inputs_list, but rest should ignore inputs_list
1540
  :return: last element is True if should run bot, False if should just yield history
1541
  """
1542
+ isize = len(input_args_list) + 1 # states + chat history
1543
  # don't deepcopy, can contain model itself
1544
  args_list = list(args).copy()
1545
+ model_state1 = args_list[-isize]
1546
+ my_db_state1 = args_list[-isize + 1]
1547
+ selection_docs_state1 = args_list[-isize + 2]
1548
  history = args_list[-1]
1549
  prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
1550
  prompt_dict1 = args_list[eval_func_param_names.index('prompt_dict')]
 
1552
  if model_state1['model'] is None or model_state1['model'] == no_model_str:
1553
  return history, None, None, None
1554
 
1555
+ args_list = args_list[:-isize] # only keep rest needed for evaluate()
1556
  langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
1557
+ add_chat_history_to_context1 = args_list[eval_func_param_names.index('add_chat_history_to_context')]
1558
  langchain_action1 = args_list[eval_func_param_names.index('langchain_action')]
1559
+ langchain_agents1 = args_list[eval_func_param_names.index('langchain_agents')]
1560
  document_subset1 = args_list[eval_func_param_names.index('document_subset')]
1561
  document_choice1 = args_list[eval_func_param_names.index('document_choice')]
1562
  if not history:
 
1569
  instruction1 = history[-1][0]
1570
  history[-1][1] = None
1571
  elif not instruction1:
1572
+ if not allow_empty_instruction(langchain_mode1, document_subset1, langchain_action1):
 
 
 
1573
  # if not retrying, then reject empty query
1574
  return history, None, None, None
1575
  elif len(history) > 0 and history[-1][1] not in [None, '']:
 
1586
 
1587
  chat1 = args_list[eval_func_param_names.index('chat')]
1588
  model_max_length1 = get_model_max_length(model_state1)
1589
+ context1 = history_to_context(history, langchain_mode1,
1590
+ add_chat_history_to_context1,
1591
+ prompt_type1, prompt_dict1, chat1,
1592
  model_max_length1, memory_restriction_level,
1593
  kwargs['keep_sources_in_context'])
1594
  args_list[0] = instruction1 # override original instruction with history from user
 
1597
  fun1 = partial(evaluate,
1598
  model_state1,
1599
  my_db_state1,
1600
+ selection_docs_state1,
1601
  *tuple(args_list),
1602
  **kwargs_evaluate)
1603
 
 
1643
  clear_torch_cache()
1644
  return
1645
 
1646
+ def clear_embeddings(langchain_mode1, db1s):
1647
  # clear any use of embedding that sits on GPU, else keeps accumulating GPU usage even if clear torch cache
1648
+ if db_type == 'chroma' and langchain_mode1 not in ['LLM', 'Disabled', None, '']:
1649
  from gpt_langchain import clear_embedding
1650
  db = dbs.get('langchain_mode1')
1651
  if db is not None and not isinstance(db, str):
1652
  clear_embedding(db)
1653
+ if db1s is not None and langchain_mode1 in db1s:
1654
+ db1 = db1s[langchain_mode1]
1655
+ if len(db1) == 2:
1656
+ clear_embedding(db1[0])
1657
 
1658
  def bot(*args, retry=False):
1659
+ history, fun1, langchain_mode1, db1 = prep_bot(*args, retry=retry)
1660
  try:
1661
  for res in get_response(fun1, history):
1662
  yield res
1663
  finally:
1664
  clear_torch_cache()
1665
+ clear_embeddings(langchain_mode1, db1)
1666
 
1667
  def all_bot(*args, retry=False, model_states1=None):
1668
  args_list = list(args).copy()
 
1672
  stream_output1 = args_list[eval_func_param_names.index('stream_output')]
1673
  max_time1 = args_list[eval_func_param_names.index('max_time')]
1674
  langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
1675
+ isize = len(input_args_list) + 1 # states + chat history
1676
+ db1s = None
1677
  try:
1678
  gen_list = []
1679
  for chatboti, (chatbot1, model_state1) in enumerate(zip(chatbots, model_states1)):
1680
  args_list1 = args_list0.copy()
1681
+ args_list1.insert(-isize + 2,
1682
+ model_state1) # insert at -2 so is at -3, and after chatbot1 added, at -4
1683
  # if at start, have None in response still, replace with '' so client etc. acts like normal
1684
  # assumes other parts of code treat '' and None as if no response yet from bot
1685
  # can't do this later in bot code as racy with threaded generators
 
1689
  # so consistent with prep_bot()
1690
  # with model_state1 at -3, my_db_state1 at -2, and history(chatbot) at -1
1691
  # langchain_mode1 and my_db_state1 should be same for every bot
1692
+ history, fun1, langchain_mode1, db1s = prep_bot(*tuple(args_list1), retry=retry,
1693
+ which_model=chatboti)
1694
  gen1 = get_response(fun1, history)
1695
  if stream_output1:
1696
  gen1 = TimeoutIterator(gen1, timeout=0.01, sentinel=None, raise_on_exception=False)
 
1736
  print("Generate exceptions: %s" % exceptions, flush=True)
1737
  finally:
1738
  clear_torch_cache()
1739
+ clear_embeddings(langchain_mode1, db1s)
1740
 
1741
  # NORMAL MODEL
1742
  user_args = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt']),
 
1744
  outputs=text_output,
1745
  )
1746
  bot_args = dict(fn=bot,
1747
+ inputs=inputs_list + [model_state, my_db_state, selection_docs_state] + [text_output],
1748
  outputs=[text_output, chat_exception_text],
1749
  )
1750
  retry_bot_args = dict(fn=functools.partial(bot, retry=True),
1751
+ inputs=inputs_list + [model_state, my_db_state, selection_docs_state] + [text_output],
1752
  outputs=[text_output, chat_exception_text],
1753
  )
1754
  retry_user_args = dict(fn=functools.partial(user, retry=True),
 
1766
  outputs=text_output2,
1767
  )
1768
  bot_args2 = dict(fn=bot,
1769
+ inputs=inputs_list2 + [model_state2, my_db_state, selection_docs_state] + [text_output2],
1770
  outputs=[text_output2, chat_exception_text],
1771
  )
1772
  retry_bot_args2 = dict(fn=functools.partial(bot, retry=True),
1773
+ inputs=inputs_list2 + [model_state2, my_db_state, selection_docs_state] + [text_output2],
1774
  outputs=[text_output2, chat_exception_text],
1775
  )
1776
  retry_user_args2 = dict(fn=functools.partial(user, retry=True),
 
1791
  outputs=text_outputs,
1792
  )
1793
  all_bot_args = dict(fn=functools.partial(all_bot, model_states1=model_states),
1794
+ inputs=inputs_list + [my_db_state, selection_docs_state] + text_outputs,
1795
  outputs=text_outputs + [chat_exception_text],
1796
  )
1797
  all_retry_bot_args = dict(fn=functools.partial(all_bot, model_states1=model_states, retry=True),
1798
+ inputs=inputs_list + [my_db_state, selection_docs_state] + text_outputs,
1799
  outputs=text_outputs + [chat_exception_text],
1800
  )
1801
  all_retry_user_args = dict(fn=functools.partial(all_user, retry=True,
 
1957
  def get_short_chat(x, short_chats, short_len=20, words=4):
1958
  if x and len(x[0]) == 2 and x[0][0] is not None:
1959
  short_chat = ' '.join(x[0][0][:short_len].split(' ')[:words]).strip()
1960
+ if not short_chat:
1961
+ # e.g. summarization, try using answer
1962
+ short_chat = ' '.join(x[0][1][:short_len].split(' ')[:words]).strip()
1963
+ if not short_chat:
1964
+ short_chat = 'Unk'
1965
  short_chat = dedup(short_chat, short_chats)
1966
  else:
1967
  short_chat = None
 
2029
  already_exists = any([is_chat_same(chat_list, x) for x in old_chat_lists])
2030
  if not already_exists:
2031
  chat_state1[short_chat] = chat_list.copy()
2032
+
2033
+ # reverse so newest at top
2034
+ choices = list(chat_state1.keys()).copy()
2035
+ choices.reverse()
2036
+
2037
+ return chat_state1, gr.update(choices=choices, value=None)
 
 
2038
 
2039
  def switch_chat(chat_key, chat_state1, num_model_lock=0):
2040
  chosen_chat = chat_state1[chat_key]
 
2065
 
2066
  remove_chat_event = remove_chat_btn.click(remove_chat,
2067
  inputs=[radio_chats, chat_state], outputs=[radio_chats, chat_state],
2068
+ queue=False, api_name='remove_chat')
2069
 
2070
  def get_chats1(chat_state1):
2071
  base = 'chats'
 
2096
  new_chats = json.loads(f.read())
2097
  for chat1_k, chat1_v in new_chats.items():
2098
  # ignore chat1_k, regenerate and de-dup to avoid loss
2099
+ chat_state1, _ = save_chat(chat1_v, chat_state1, chat_is_list=True)
2100
  except BaseException as e:
2101
  t, v, tb = sys.exc_info()
2102
  ex = ''.join(traceback.format_exception(t, v, tb))
 
2122
  .then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=False) \
2123
  .then(clear_scores, outputs=[score_text, score_text2, score_text_nochat])
2124
 
2125
  clear_event = save_chat_btn.click(save_chat,
2126
  inputs=[text_output, text_output2] + text_outputs + [chat_state],
2127
+ outputs=[chat_state, radio_chats],
2128
+ api_name='save_chat' if allow_api else None)
2129
+ if kwargs['score_model']:
2130
+ clear_event2 = clear_event.then(clear_scores, outputs=[score_text, score_text2, score_text_nochat])
 
2131
 
2132
  # NOTE: clear of instruction/iinput for nochat has to come after score,
2133
  # because score for nochat consumes actual textbox, while chat consumes chat history filled by user()
2134
  no_chat_args = dict(fn=fun,
2135
+ inputs=[model_state, my_db_state, selection_docs_state] + inputs_list,
2136
  outputs=text_output_nochat,
2137
  queue=queue,
2138
  )
 
2151
  .then(clear_torch_cache)
2152
 
2153
  submit_event_nochat_api = submit_nochat_api.click(fun_with_dict_str,
2154
+ inputs=[model_state, my_db_state, selection_docs_state,
2155
+ inputs_dict_str],
2156
  outputs=text_output_nochat_api,
2157
  queue=True, # required for generator
2158
  api_name='submit_nochat_api' if allow_api else None) \
 
2402
  print("Exception: %s" % str(e), flush=True)
2403
  return json.dumps(sys_dict)
2404
 
2405
+ system_kwargs = all_kwargs.copy()
2406
+ system_kwargs.update(dict(command=str(' '.join(sys.argv))))
2407
  get_system_info_dict_func = functools.partial(get_system_info_dict, **all_kwargs)
2408
 
2409
  system_dict_event = system_btn2.click(get_system_info_dict_func,
 
2433
  else:
2434
  tokenizer = None
2435
  if tokenizer is not None:
2436
+ langchain_mode1 = 'LLM'
2437
+ add_chat_history_to_context1 = True
2438
  # fake user message to mimic bot()
2439
  chat1 = copy.deepcopy(chat1)
2440
  chat1 = chat1 + [['user_message1', None]]
2441
  model_max_length1 = tokenizer.model_max_length
2442
+ context1 = history_to_context(chat1, langchain_mode1,
2443
+ add_chat_history_to_context1,
2444
+ prompt_type1, prompt_dict1, chat1,
2445
  model_max_length1,
2446
  memory_restriction_level1, keep_sources_in_context1)
2447
  return str(tokenizer(context1, return_tensors="pt")['input_ids'].shape[1])
 
2471
  ,
2472
  queue=False, api_name='stop' if allow_api else None).then(clear_torch_cache, queue=False)
2473
 
2474
+ demo.load(None, None, None, _js=get_dark_js() if kwargs['dark'] else None)
2475
 
2476
  demo.queue(concurrency_count=kwargs['concurrency_count'], api_open=kwargs['api_open'])
2477
  favicon_path = "h2o-logo.svg"
 
2486
  # FIXME: disable for gptj, langchain or gpt4all modify print itself
2487
  # FIXME: and any multi-threaded/async print will enter model output!
2488
  scheduler.add_job(func=ping, trigger="interval", seconds=60)
2489
+ if is_public or os.getenv('PING_GPU'):
2490
+ scheduler.add_job(func=ping_gpu, trigger="interval", seconds=60 * 10)
2491
  scheduler.start()
2492
 
2493
  # import control
 
2506
  demo.block_thread()
2507
 
2508
 
 
 
 
2509
  def get_inputs_list(inputs_dict, model_lower, model_id=1):
2510
  """
2511
  map gradio objects in locals() to inputs for evaluate().
 
2539
  return inputs_list, inputs_dict_out
2540
 
2541
 
2542
+ def get_sources(db1s, langchain_mode, dbs=None, docs_state0=None):
2543
+ for k in db1s:
2544
+ set_userid(db1s[k])
2545
 
2546
  if langchain_mode in ['ChatLLM', 'LLM']:
2547
  source_files_added = "NA"
 
2550
  source_files_added = "Not showing wiki_full, takes about 20 seconds and makes 4MB file." \
2551
  " Ask jon.mckinney@h2o.ai for file if required."
2552
  source_list = []
2553
+ elif langchain_mode in db1s and len(db1s[langchain_mode]) == 2 and db1s[langchain_mode][0] is not None:
2554
+ db1 = db1s[langchain_mode]
2555
  from gpt_langchain import get_metadatas
2556
  metadatas = get_metadatas(db1[0])
2557
  source_list = sorted(set([x['source'] for x in metadatas]))
 
2582
  db1[1] = str(uuid.uuid4())
2583
 
2584
 
2585
+ def update_user_db(file, db1s, selection_docs_state1, chunk, chunk_size, langchain_mode, dbs=None, **kwargs):
2586
+ kwargs.update(selection_docs_state1)
 
2587
  if file is None:
2588
  raise RuntimeError("Don't use change, use input")
2589
 
2590
  try:
2591
+ return _update_user_db(file, db1s=db1s, chunk=chunk, chunk_size=chunk_size,
2592
  langchain_mode=langchain_mode, dbs=dbs,
2593
  **kwargs)
2594
  except BaseException as e:
 
2619
  user_id = db1[1]
2620
  base_path = 'locks'
2621
  makedirs(base_path)
2622
+ lock_file = os.path.join(base_path, "db_%s_%s.lock" % (langchain_mode.replace(' ', '_'), user_id))
2623
  return lock_file
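get_lock_file now keys the lock on both the collection name and the per-user hash, so uploads by different users, or into different collections, no longer contend on a single lock. A small illustration with hypothetical values:

    import os

    langchain_mode = 'My Data 2'  # hypothetical collection name
    user_id = '1c9e0a4e'          # hypothetical user hash (db1[1])
    lock_file = os.path.join('locks', "db_%s_%s.lock" % (langchain_mode.replace(' ', '_'), user_id))
    # -> locks/db_My_Data_2_1c9e0a4e.lock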
2624
 
2625
 
2626
  def _update_user_db(file,
2627
+ db1s=None,
2628
  chunk=None, chunk_size=None,
2629
+ dbs=None, db_type=None,
2630
+ langchain_mode='UserData',
2631
+ langchain_modes=None, # unused but required as part of selection_docs_state1
2632
+ langchain_mode_paths=None,
2633
+ visible_langchain_modes=None,
2634
  use_openai_embedding=None,
2635
  hf_embedding_model=None,
2636
  caption_loader=None,
2637
  enable_captions=None,
2638
  captions_model=None,
2639
  enable_ocr=None,
2640
+ enable_pdf_ocr=None,
2641
  verbose=None,
2642
+ n_jobs=-1,
2643
  is_url=None, is_txt=None,
2644
+ ):
2645
+ assert db1s is not None
2646
  assert chunk is not None
2647
  assert chunk_size is not None
2648
  assert use_openai_embedding is not None
 
2651
  assert enable_captions is not None
2652
  assert captions_model is not None
2653
  assert enable_ocr is not None
2654
+ assert enable_pdf_ocr is not None
2655
  assert verbose is not None
2656
 
 
 
2657
  if dbs is None:
2658
  dbs = {}
2659
  assert isinstance(dbs, dict), "Wrong type for dbs: %s" % str(type(dbs))
 
2671
  if langchain_mode == LangChainMode.DISABLED.value:
2672
  return None, langchain_mode, get_source_files(), ""
2673
 
2674
+ if langchain_mode in [LangChainMode.LLM.value]:
2675
  # then switch to MyData, so langchain_mode also becomes way to select where upload goes
2676
  # but default to mydata if nothing chosen, since safest
2677
+ if LangChainMode.MY_DATA.value in visible_langchain_modes:
2678
+ langchain_mode = LangChainMode.MY_DATA.value
2679
+
2680
+ if langchain_mode_paths is None:
2681
+ langchain_mode_paths = {}
2682
+ user_path = langchain_mode_paths.get(langchain_mode)
2683
+ # UserData or custom, which has to be from user's disk
2684
+ if user_path is not None:
2685
  # move temp files from gradio upload to stable location
2686
  for fili, fil in enumerate(file):
2687
+ if isinstance(fil, str) and os.path.isfile(fil): # not url, text
2688
+ new_fil = os.path.normpath(os.path.join(user_path, os.path.basename(fil)))
2689
+ if os.path.normpath(os.path.abspath(fil)) != os.path.normpath(os.path.abspath(new_fil)):
2690
  if os.path.isfile(new_fil):
2691
  remove(new_fil)
2692
  try:
 
2706
  enable_captions=enable_captions,
2707
  captions_model=captions_model,
2708
  enable_ocr=enable_ocr,
2709
+ enable_pdf_ocr=enable_pdf_ocr,
2710
  caption_loader=caption_loader,
2711
  )
2712
  exceptions = [x for x in sources if x.metadata.get('exception')]
2713
  exceptions_strs = [x.metadata['exception'] for x in exceptions]
2714
  sources = [x for x in sources if 'exception' not in x.metadata]
2715
 
2716
+ # below must at least come after langchain_mode is modified in case was LLM -> MyData,
2717
+ # so original langchain mode changed
2718
+ for k in db1s:
2719
+ set_userid(db1s[k])
2720
+ db1 = get_db1(db1s, langchain_mode)
2721
+
2722
+ lock_file = get_lock_file(db1s[LangChainMode.MY_DATA.value], langchain_mode) # user-level lock, not db-level lock
2723
  with filelock.FileLock(lock_file):
2724
+ if langchain_mode in db1s:
2725
  if db1[0] is not None:
2726
  # then add
2727
  db, num_new_sources, new_sources_metadata = add_to_db(db1[0], sources, db_type=db_type,
 
2731
  # in testing expect:
2732
  # assert len(db1) == 2 and db1[1] is None, "Bad MyData db: %s" % db1
2733
  # for production hit, when user gets clicky:
2734
+ assert len(db1) == 2, "Bad %s db: %s" % (langchain_mode, db1)
2735
+ assert db1[1] is not None, "db hash was None, not allowed"
2736
  # then create
2737
  # if added to original state and didn't change, then it would be a shared db for all users
2738
  persist_directory = os.path.join(scratch_base_dir, 'db_dir_%s_%s' % (langchain_mode, db1[1]))
 
2754
  use_openai_embedding=use_openai_embedding,
2755
  hf_embedding_model=hf_embedding_model)
2756
  else:
2757
+ # then create. Or might just be that dbs is unfilled, then it will fill, then add
2758
  db = get_db(sources, use_openai_embedding=use_openai_embedding,
2759
  db_type=db_type,
2760
  persist_directory=persist_directory,
 
2768
  return None, langchain_mode, source_files_added, '\n'.join(exceptions_strs)
2769
 
2770
 
2771
+ def get_db(db1s, langchain_mode, dbs=None):
2772
+ db1 = get_db1(db1s, langchain_mode)
2773
+ lock_file = get_lock_file(db1s[LangChainMode.MY_DATA.value], langchain_mode)
2774
 
2775
  with filelock.FileLock(lock_file):
2776
  if langchain_mode in ['wiki_full']:
2777
  # NOTE: avoid showing full wiki. Takes about 30 seconds over about 90k entries, but not useful for now
2778
  db = None
2779
+ elif langchain_mode in db1s and len(db1) == 2 and db1[0] is not None:
2780
  db = db1[0]
2781
  elif dbs is not None and langchain_mode in dbs and dbs[langchain_mode] is not None:
2782
  db = dbs[langchain_mode]
 
2785
  return db
2786
 
2787
 
2788
+ def get_source_files_given_langchain_mode(db1s, langchain_mode='UserData', dbs=None):
2789
+ db = get_db(db1s, langchain_mode, dbs=dbs)
2790
  if langchain_mode in ['ChatLLM', 'LLM'] or db is None:
2791
  return "Sources: N/A"
2792
  return get_source_files(db=db, exceptions=None)
 
2885
  return source_files_added
2886
 
2887
 
2888
+ def update_and_get_source_files_given_langchain_mode(db1s, langchain_mode, chunk, chunk_size,
2889
+ dbs=None, first_para=None,
2890
+ text_limit=None,
2891
+ langchain_mode_paths=None, db_type=None, load_db_if_exists=None,
2892
  n_jobs=None, verbose=None):
2893
+ has_path = {k: v for k, v in langchain_mode_paths.items() if v}
2894
+ if langchain_mode in [LangChainMode.LLM.value, LangChainMode.MY_DATA.value]:
2895
+ # then assume user really meant UserData, to avoid extra clicks in UI,
2896
+ # since others can't be on disk, except custom user modes, which they should then select to query it
2897
+ if LangChainMode.USER_DATA.value in has_path:
2898
+ langchain_mode = LangChainMode.USER_DATA.value
2899
+
2900
+ db = get_db(db1s, langchain_mode, dbs=dbs)
2901
 
2902
  from gpt_langchain import make_db
2903
  db, num_new_sources, new_sources_metadata = make_db(use_openai_embedding=False,
 
2906
  chunk=chunk,
2907
  chunk_size=chunk_size,
2908
  langchain_mode=langchain_mode,
2909
+ langchain_mode_paths=langchain_mode_paths,
2910
  db_type=db_type,
2911
  load_db_if_exists=load_db_if_exists,
2912
  db=db,
2913
  n_jobs=n_jobs,
2914
  verbose=verbose)
2915
+ # during refresh, a new db may have been "created" that is not yet in dbs[], so insert it back just in case
2916
+ # otherwise, even if persisted, the in-memory dbs would not be kept up to date
2917
+ if langchain_mode in db1s:
2918
+ db1s[langchain_mode][0] = db
2919
+ else:
2920
+ dbs[langchain_mode] = db
2921
+
2922
  # return only new sources with text saying such
2923
  return get_source_files(db=None, exceptions=None, metadatas=new_sources_metadata)
2924
+
2925
+
2926
+ def get_db1(db1s, langchain_mode1):
2927
+ if langchain_mode1 in db1s:
2928
+ db1 = db1s[langchain_mode1]
2929
+ else:
2930
+ # indicates to code that not scratch database
2931
+ db1 = [None, None]
2932
+ return db1
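Taken together, get_db1 and the refactored get_db give a simple precedence: a per-user scratch collection in db1s wins, then the shared dbs dict, and wiki_full is always suppressed. A hedged, lock-free sketch of that lookup order:

    def resolve_db_sketch(db1s, langchain_mode, dbs=None):
        # simplified: no file locking, same precedence as get_db()/get_db1() above
        if langchain_mode == 'wiki_full':
            return None  # deliberately never shown
        db1 = db1s.get(langchain_mode, [None, None])
        if len(db1) == 2 and db1[0] is not None:
            return db1[0]  # per-user scratch collection
        if dbs and dbs.get(langchain_mode) is not None:
            return dbs[langchain_mode]  # shared, server-level collection
        return None

    assert resolve_db_sketch({'MyData': ['scratch-db', 'hash']}, 'MyData') == 'scratch-db'
    assert resolve_db_sketch({}, 'UserData', dbs={'UserData': 'shared-db'}) == 'shared-db'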
gradio_utils/__init__.py ADDED
File without changes
gradio_utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (134 Bytes). View file
 
gradio_utils/__pycache__/css.cpython-310.pyc CHANGED
Binary files a/gradio_utils/__pycache__/css.cpython-310.pyc and b/gradio_utils/__pycache__/css.cpython-310.pyc differ
 
gradio_utils/css.py CHANGED
@@ -53,4 +53,8 @@ def make_css_base() -> str:
53
  margin-bottom: 2.5rem;
54
  }
55
  .chatsmall chatbot {font-size: 10px !important}
 
 
 
 
56
  """
 
53
  margin-bottom: 2.5rem;
54
  }
55
  .chatsmall chatbot {font-size: 10px !important}
56
+
57
+ .gradio-container {
58
+ max-width: none !important;
59
+ }
60
  """
h2oai_pipeline.py CHANGED
@@ -11,6 +11,7 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
11
  def __init__(self, *args, debug=False, chat=False, stream_output=False,
12
  sanitize_bot_response=False,
13
  use_prompter=True, prompter=None,
 
14
  prompt_type=None, prompt_dict=None,
15
  max_input_tokens=2048 - 256, **kwargs):
16
  """
@@ -34,6 +35,8 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
34
  self.prompt_type = prompt_type
35
  self.prompt_dict = prompt_dict
36
  self.prompter = prompter
 
 
37
  if self.use_prompter:
38
  if self.prompter is not None:
39
  assert self.prompter.prompt_type is not None
@@ -113,7 +116,7 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
113
  def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
114
  prompt_text, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt_text, self.tokenizer)
115
 
116
- data_point = dict(context='', instruction=prompt_text, input='')
117
  if self.prompter is not None:
118
  prompt_text = self.prompter.generate_prompt(data_point)
119
  self.prompt_text = prompt_text
 
11
  def __init__(self, *args, debug=False, chat=False, stream_output=False,
12
  sanitize_bot_response=False,
13
  use_prompter=True, prompter=None,
14
+ context='', iinput='',
15
  prompt_type=None, prompt_dict=None,
16
  max_input_tokens=2048 - 256, **kwargs):
17
  """
 
35
  self.prompt_type = prompt_type
36
  self.prompt_dict = prompt_dict
37
  self.prompter = prompter
38
+ self.context = context
39
+ self.iinput = iinput
40
  if self.use_prompter:
41
  if self.prompter is not None:
42
  assert self.prompter.prompt_type is not None
 
116
  def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
117
  prompt_text, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt_text, self.tokenizer)
118
 
119
+ data_point = dict(context=self.context, instruction=prompt_text, input=self.iinput)
120
  if self.prompter is not None:
121
  prompt_text = self.prompter.generate_prompt(data_point)
122
  self.prompt_text = prompt_text
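The pipeline now threads context and iinput through to prompt construction instead of hard-coding empty strings. A hedged sketch of the data_point shape preprocess() builds, with hypothetical values:

    context = "You are a terse assistant."  # hypothetical
    iinput = ""
    prompt_text = "Summarize the attached report."
    data_point = dict(context=context, instruction=prompt_text, input=iinput)
    # the prompter then renders data_point according to the active prompt_type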
iterators/__pycache__/timeout_iterator.cpython-310.pyc CHANGED
Binary files a/iterators/__pycache__/timeout_iterator.cpython-310.pyc and b/iterators/__pycache__/timeout_iterator.cpython-310.pyc differ
 
iterators/timeout_iterator.py CHANGED
@@ -48,7 +48,7 @@ class TimeoutIterator:
48
  def interrupt(self):
49
  """
50
  interrupt and stop the underlying thread.
51
- the thread acutally dies only after interrupt has been set and
52
  the underlying iterator yields a value after that.
53
  """
54
  self._interrupt = True
 
48
  def interrupt(self):
49
  """
50
  interrupt and stop the underlying thread.
51
+ the thread actually dies only after interrupt has been set and
52
  the underlying iterator yields a value after that.
53
  """
54
  self._interrupt = True
prompter.py CHANGED
@@ -77,6 +77,12 @@ prompt_type_to_model_name = {
77
  "mptchat": ['mosaicml/mpt-7b-chat', 'mosaicml/mpt-30b-chat', 'TheBloke/mpt-30B-chat-GGML'],
78
  "vicuna11": ['lmsys/vicuna-33b-v1.3'],
79
  "falcon": ['tiiuae/falcon-40b-instruct', 'tiiuae/falcon-40b', 'tiiuae/falcon-7b-instruct', 'tiiuae/falcon-7b'],
80
  # could be plain, but default is correct prompt_type for default TheBloke model ggml-wizardLM-7B.q4_2.bin
81
  }
82
  if os.getenv('OPENAI_API_KEY'):
@@ -582,6 +588,42 @@ ASSISTANT:
582
  # if add space here, non-unique tokenization will often make LLM produce wrong output
583
  PreResponse = PreResponse
584
  # generates_leading_space = True
585
  else:
586
  raise RuntimeError("No such prompt_type=%s" % prompt_type)
587
 
@@ -810,9 +852,20 @@ class Prompter(object):
810
  if oi > 0:
811
  # post fix outputs with separator
812
  output += '\n'
 
813
  outputs[oi] = output
814
  # join all outputs, only one extra new line between outputs
815
  output = '\n'.join(outputs)
816
  if self.debug:
817
  print("outputclean:\n%s" % '\n\n'.join(outputs), flush=True)
818
  return output
77
  "mptchat": ['mosaicml/mpt-7b-chat', 'mosaicml/mpt-30b-chat', 'TheBloke/mpt-30B-chat-GGML'],
78
  "vicuna11": ['lmsys/vicuna-33b-v1.3'],
79
  "falcon": ['tiiuae/falcon-40b-instruct', 'tiiuae/falcon-40b', 'tiiuae/falcon-7b-instruct', 'tiiuae/falcon-7b'],
80
+ "llama2": [
81
+ 'meta-llama/Llama-2-7b-chat-hf',
82
+ 'meta-llama/Llama-2-13b-chat-hf',
83
+ 'meta-llama/Llama-2-34b-chat-hf',
84
+ 'meta-llama/Llama-2-70b-chat-hf',
85
+ ],
86
  # could be plain, but default is correct prompt_type for default TheBloke model ggml-wizardLM-7B.q4_2.bin
87
  }
88
  if os.getenv('OPENAI_API_KEY'):
 
588
  # if add space here, non-unique tokenization will often make LLM produce wrong output
589
  PreResponse = PreResponse
590
  # generates_leading_space = True
591
+ elif prompt_type in [PromptType.guanaco.value, str(PromptType.guanaco.value),
592
+ PromptType.guanaco.name]:
593
+ # https://huggingface.co/TheBloke/guanaco-65B-GPTQ
594
+ promptA = promptB = "" if not (chat and reduced) else ''
595
+
596
+ PreInstruct = """### Human: """
597
+
598
+ PreInput = None
599
+
600
+ PreResponse = """### Assistant:"""
601
+ terminate_response = ['### Human:'] # but only allow terminate after prompt is found correctly, else can't terminate
602
+ chat_turn_sep = chat_sep = '\n'
603
+ humanstr = PreInstruct
604
+ botstr = PreResponse
605
+ elif prompt_type in [PromptType.llama2.value, str(PromptType.llama2.value),
606
+ PromptType.llama2.name]:
607
+ PreInstruct = ""
608
+ llama2_sys = "<<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
609
+ prompt = "<s>[INST] "
610
+ enable_sys = False # too much safety, hurts accuracy
611
+ if not (chat and reduced):
612
+ if enable_sys:
613
+ promptA = promptB = prompt + llama2_sys
614
+ else:
615
+ promptA = promptB = prompt
616
+ else:
617
+ promptA = promptB = ''
618
+ PreInput = None
619
+ PreResponse = ""
620
+ terminate_response = ["[INST]", "</s>"]
621
+ chat_sep = ' [/INST]'
622
+ chat_turn_sep = ' </s><s>[INST] '
623
+ humanstr = PreInstruct
624
+ botstr = PreResponse
625
+ if making_context:
626
+ PreResponse += " "
627
  else:
628
  raise RuntimeError("No such prompt_type=%s" % prompt_type)
629
 
 
852
  if oi > 0:
853
  # post fix outputs with separator
854
  output += '\n'
855
+ output = self.fix_text(self.prompt_type, output)
856
  outputs[oi] = output
857
  # join all outputs, only one extra new line between outputs
858
  output = '\n'.join(outputs)
859
  if self.debug:
860
  print("outputclean:\n%s" % '\n\n'.join(outputs), flush=True)
861
  return output
862
+
863
+ @staticmethod
864
+ def fix_text(prompt_type1, text1):
865
+ if prompt_type1 == 'human_bot':
866
+ # hack bug in vLLM with stopping, stops right, but doesn't return last token
867
+ hfix = '<human'
868
+ if text1.endswith(hfix):
869
+ text1 = text1[:-len(hfix)]
870
+ return text1
871
+
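With the separators defined for the new llama2 prompt type (and enable_sys = False, so no system block), a multi-turn conversation renders in the familiar Llama-2 chat shape. A hedged sketch assembling two turns by hand rather than through Prompter:

    def llama2_prompt_sketch(turns, new_instruction):
        # turns: list of (user, assistant) pairs already completed
        chat_sep = ' [/INST]'
        chat_turn_sep = ' </s><s>[INST] '
        prompt = '<s>[INST] '
        for user, assistant in turns:
            prompt += user + chat_sep + ' ' + assistant + chat_turn_sep
        return prompt + new_instruction + chat_sep

    print(llama2_prompt_sketch([("Hi", "Hello! How can I help?")], "Tell me a joke"))
    # <s>[INST] Hi [/INST] Hello! How can I help? </s><s>[INST] Tell me a joke [/INST]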
requirements.txt CHANGED
@@ -1,8 +1,8 @@
1
  # for generate (gradio server) and finetune
2
  datasets==2.13.0
3
  sentencepiece==0.1.99
4
- gradio==3.35.2
5
- huggingface_hub==0.15.1
6
  appdirs==1.4.4
7
  fire==0.5.0
8
  docutils==0.20.1
@@ -19,7 +19,7 @@ matplotlib==3.7.1
19
  loralib==0.1.1
20
  bitsandbytes==0.39.0
21
  accelerate==0.20.3
22
- git+https://github.com/huggingface/peft.git@06fd06a4d2e8ed8c3a253c67d9c3cb23e0f497ad
23
  transformers==4.30.2
24
  tokenizers==0.13.3
25
  APScheduler==3.10.1
@@ -35,7 +35,7 @@ tensorboard==2.13.0
35
  neptune==1.2.0
36
 
37
  # for gradio client
38
- gradio_client==0.2.7
39
  beautifulsoup4==4.12.2
40
  markdown==3.4.3
41
 
@@ -64,8 +64,8 @@ tiktoken==0.4.0
64
  # optional: for OpenAI endpoint or embeddings (requires key)
65
  openai==0.27.8
66
  # optional for chat with PDF
67
- langchain==0.0.202
68
- pypdf==3.9.1
69
  # avoid textract, requires old six
70
  #textract==1.6.5
71
 
@@ -78,10 +78,10 @@ chromadb==0.3.25
78
  #pymilvus==2.2.8
79
 
80
  # weak url support, if can't install opencv etc. If comment-in this one, then comment-out unstructured[local-inference]==0.6.6
81
- # unstructured==0.6.6
82
 
83
  # strong support for images
84
- # Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libreoffice
85
  unstructured[local-inference]==0.7.4
86
  #pdf2image==1.16.3
87
  #pytesseract==0.3.10
@@ -104,10 +104,10 @@ tabulate==0.9.0
104
  pip-licenses==4.3.0
105
 
106
  # weaviate vector db
107
- weaviate-client==3.20.0
108
  # optional for chat with PDF
109
- langchain==0.0.202
110
- pypdf==3.9.1
111
  # avoid textract, requires old six
112
  #textract==1.6.5
113
 
@@ -120,10 +120,10 @@ chromadb==0.3.25
120
  #pymilvus==2.2.8
121
 
122
  # weak URL support, for when opencv etc. cannot be installed. If you comment this in, comment out unstructured[local-inference]==0.6.6
123
- # unstructured==0.6.6
124
 
125
  # strong support for images
126
- # Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libreoffice
127
  unstructured[local-inference]==0.7.4
128
  #pdf2image==1.16.3
129
  #pytesseract==0.3.10
@@ -146,8 +146,8 @@ tabulate==0.9.0
146
  pip-licenses==4.3.0
147
 
148
  # weaviate vector db
149
- weaviate-client==3.20.0
150
  faiss-gpu==1.7.2
151
- arxiv==1.4.7
152
- pymupdf==1.22.3 # AGPL license
153
  # extract-msg==0.41.1 # GPL3
 
1
  # for generate (gradio server) and finetune
2
  datasets==2.13.0
3
  sentencepiece==0.1.99
4
+ gradio==3.37.0
5
+ huggingface_hub==0.16.4
6
  appdirs==1.4.4
7
  fire==0.5.0
8
  docutils==0.20.1
 
19
  loralib==0.1.1
20
  bitsandbytes==0.39.0
21
  accelerate==0.20.3
22
+ peft==0.4.0
23
  transformers==4.30.2
24
  tokenizers==0.13.3
25
  APScheduler==3.10.1
 
35
  neptune==1.2.0
36
 
37
  # for gradio client
38
+ gradio_client==0.2.10
39
  beautifulsoup4==4.12.2
40
  markdown==3.4.3
41
 
 
64
  # optional: for OpenAI endpoint or embeddings (requires key)
65
  openai==0.27.8
66
  # optional for chat with PDF
67
+ langchain==0.0.235
68
+ pypdf==3.12.2
69
  # avoid textract, requires old six
70
  #textract==1.6.5
71
 
 
78
  #pymilvus==2.2.8
79
 
80
  # weak URL support, for when opencv etc. cannot be installed. If you comment this in, comment out unstructured[local-inference]==0.6.6
81
+ # unstructured==0.8.1
82
 
83
  # strong support for images
84
+ # Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libtesseract-dev libreoffice
85
  unstructured[local-inference]==0.7.4
86
  #pdf2image==1.16.3
87
  #pytesseract==0.3.10
 
104
  pip-licenses==4.3.0
105
 
106
  # weaviate vector db
107
+ weaviate-client==3.22.1
108
  # optional for chat with PDF
109
+ langchain==0.0.235
110
+ pypdf==3.12.2
111
  # avoid textract, requires old six
112
  #textract==1.6.5
113
 
 
120
  #pymilvus==2.2.8
121
 
122
  # weak URL support, for when opencv etc. cannot be installed. If you comment this in, comment out unstructured[local-inference]==0.6.6
123
+ # unstructured==0.8.1
124
 
125
  # strong support for images
126
+ # Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libtesseract-dev libreoffice
127
  unstructured[local-inference]==0.7.4
128
  #pdf2image==1.16.3
129
  #pytesseract==0.3.10
 
146
  pip-licenses==4.3.0
147
 
148
  # weaviate vector db
149
+ weaviate-client==3.22.1
150
  faiss-gpu==1.7.2
151
+ arxiv==1.4.8
152
+ pymupdf==1.22.5 # AGPL license
153
  # extract-msg==0.41.1 # GPL3
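To confirm the upgraded pins above are what actually got installed, a small check like the following can help (illustrative only; the package names and versions are taken from the file above):

# Illustrative check of key upgraded pins from requirements.txt.
from importlib.metadata import version, PackageNotFoundError

expected = {
    "gradio": "3.37.0",
    "gradio_client": "0.2.10",
    "peft": "0.4.0",
    "langchain": "0.0.235",
    "pypdf": "3.12.2",
    "weaviate-client": "3.22.1",
}
for pkg, want in expected.items():
    try:
        got = version(pkg)
    except PackageNotFoundError:
        got = "not installed"
    print(f"{pkg}: expected {want}, found {got}")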
utils.py CHANGED
@@ -5,6 +5,7 @@ import inspect
5
  import os
6
  import gc
7
  import pathlib
 
8
  import random
9
  import shutil
10
  import subprocess
@@ -111,12 +112,15 @@ def system_info():
111
  system = {}
112
  # https://stackoverflow.com/questions/48951136/plot-multiple-graphs-in-one-plot-using-tensorboard
113
  # https://arshren.medium.com/monitoring-your-devices-in-python-5191d672f749
114
- temps = psutil.sensors_temperatures(fahrenheit=False)
115
- if 'coretemp' in temps:
116
- coretemp = temps['coretemp']
117
- temp_dict = {k.label: k.current for k in coretemp}
118
- for k, v in temp_dict.items():
119
- system['CPU_C/%s' % k] = v
120
 
121
  # https://github.com/gpuopenanalytics/pynvml/blob/master/help_query_gpu.txt
122
  try:
@@ -779,6 +783,9 @@ def _traced_func(func, *args, **kwargs):
779
 
780
 
781
  def call_subprocess_onetask(func, args=None, kwargs=None):
782
  if isinstance(args, list):
783
  args = tuple(args)
784
  if args is None:
@@ -950,7 +957,6 @@ try:
950
  except (pkg_resources.DistributionNotFound, AssertionError):
951
  have_langchain = False
952
 
953
-
954
  import distutils.spawn
955
 
956
  have_tesseract = distutils.spawn.find_executable("tesseract")
@@ -985,3 +991,90 @@ except (pkg_resources.DistributionNotFound, AssertionError):
985
 
986
  # disable, hangs too often
987
  have_playwright = False

5
  import os
6
  import gc
7
  import pathlib
8
+ import pickle
9
  import random
10
  import shutil
11
  import subprocess
 
112
  system = {}
113
  # https://stackoverflow.com/questions/48951136/plot-multiple-graphs-in-one-plot-using-tensorboard
114
  # https://arshren.medium.com/monitoring-your-devices-in-python-5191d672f749
115
+ try:
116
+ temps = psutil.sensors_temperatures(fahrenheit=False)
117
+ if 'coretemp' in temps:
118
+ coretemp = temps['coretemp']
119
+ temp_dict = {k.label: k.current for k in coretemp}
120
+ for k, v in temp_dict.items():
121
+ system['CPU_C/%s' % k] = v
122
+ except AttributeError:
123
+ pass
124
 
125
  # https://github.com/gpuopenanalytics/pynvml/blob/master/help_query_gpu.txt
126
  try:
 
783
 
784
 
785
  def call_subprocess_onetask(func, args=None, kwargs=None):
786
+ import platform
787
+ if platform.system() in ['Darwin', 'Windows']:
788
+ return func(*args, **kwargs)
789
  if isinstance(args, list):
790
  args = tuple(args)
791
  if args is None:
 
957
  except (pkg_resources.DistributionNotFound, AssertionError):
958
  have_langchain = False
959
 
 
960
  import distutils.spawn
961
 
962
  have_tesseract = distutils.spawn.find_executable("tesseract")
 
991
 
992
  # disable, hangs too often
993
  have_playwright = False
994
+
995
+
996
+ def set_openai(inference_server):
997
+ if inference_server.startswith('vllm'):
998
+ import openai_vllm
999
+ openai_vllm.api_key = "EMPTY"
1000
+ inf_type = inference_server.split(':')[0]
1001
+ ip_vllm = inference_server.split(':')[1]
1002
+ port_vllm = inference_server.split(':')[2]
1003
+ openai_vllm.api_base = f"http://{ip_vllm}:{port_vllm}/v1"
1004
+ return openai_vllm, inf_type
1005
+ else:
1006
+ import openai
1007
+ openai.api_key = os.getenv("OPENAI_API_KEY")
1008
+ openai.api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
1009
+ inf_type = inference_server
1010
+ return openai, inf_type
1011
+
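set_openai expects an inference-server string in the form implied by its split logic above; an illustrative call (the host, port, and server strings below are placeholders):

# Illustrative use of set_openai; host/port values are placeholders.
client_module, inf_type = set_openai("vllm:192.168.1.10:5000")
# inf_type == "vllm"; client_module.api_base == "http://192.168.1.10:5000/v1"

# Non-vLLM path reads OPENAI_API_KEY / OPENAI_API_BASE from the environment.
client_module, inf_type = set_openai("openai")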
1012
+
1013
+ visible_langchain_modes_file = 'visible_langchain_modes.pkl'
1014
+
1015
+
1016
+ def save_collection_names(langchain_modes, visible_langchain_modes, langchain_mode_paths, LangChainMode, db1s):
1017
+ """
1018
+ extra controls whether collections are of UserData type or MyData type
1019
+ """
1020
+
1021
+ # use first default MyData hash as general user hash to maintain file
1022
+ # if user moves MyData from langchain modes, db will still survive, so can still use hash
1023
+ scratch_collection_names = list(db1s.keys())
1024
+ user_hash = db1s.get(LangChainMode.MY_DATA.value, '')[1]
1025
+
1026
+ llms = ['ChatLLM', 'LLM', 'Disabled']
1027
+
1028
+ scratch_langchain_modes = [x for x in langchain_modes if x in scratch_collection_names]
1029
+ scratch_visible_langchain_modes = [x for x in visible_langchain_modes if x in scratch_collection_names]
1030
+ scratch_langchain_mode_paths = {k: v for k, v in langchain_mode_paths.items() if
1031
+ k in scratch_collection_names and k not in llms}
1032
+
1033
+ user_langchain_modes = [x for x in langchain_modes if x not in scratch_collection_names]
1034
+ user_visible_langchain_modes = [x for x in visible_langchain_modes if x not in scratch_collection_names]
1035
+ user_langchain_mode_paths = {k: v for k, v in langchain_mode_paths.items() if
1036
+ k not in scratch_collection_names and k not in llms}
1037
+
1038
+ base_path = 'locks'
1039
+ makedirs(base_path)
1040
+
1041
+ # user
1042
+ extra = ''
1043
+ file = "%s%s" % (visible_langchain_modes_file, extra)
1044
+ with filelock.FileLock(os.path.join(base_path, "%s.lock" % file)):
1045
+ with open(file, 'wb') as f:
1046
+ pickle.dump((user_langchain_modes, user_visible_langchain_modes, user_langchain_mode_paths), f)
1047
+
1048
+ # scratch
1049
+ extra = user_hash
1050
+ file = "%s%s" % (visible_langchain_modes_file, extra)
1051
+ with filelock.FileLock(os.path.join(base_path, "%s.lock" % file)):
1052
+ with open(file, 'wb') as f:
1053
+ pickle.dump((scratch_langchain_modes, scratch_visible_langchain_modes, scratch_langchain_mode_paths), f)
1054
+
1055
+
1056
+ def load_collection_enum(extra):
1057
+ """
1058
+ extra controls whether collections are of UserData type or MyData type
1059
+ """
1060
+ file = "%s%s" % (visible_langchain_modes_file, extra)
1061
+ langchain_modes_from_file = []
1062
+ visible_langchain_modes_from_file = []
1063
+ langchain_mode_paths_from_file = {}
1064
+ if os.path.isfile(visible_langchain_modes_file):
1065
+ try:
1066
+ with filelock.FileLock("%s.lock" % file):
1067
+ with open(file, 'rb') as f:
1068
+ langchain_modes_from_file, visible_langchain_modes_from_file, langchain_mode_paths_from_file = pickle.load(
1069
+ f)
1070
+ except BaseException as e:
1071
+ print("Cannot load %s, ignoring error: %s" % (file, str(e)), flush=True)
1072
+ for k, v in langchain_mode_paths_from_file.items():
1073
+ if v is not None and not os.path.isdir(v) and isinstance(v, str):
1074
+ # assume was deleted, but need to make again to avoid extra code elsewhere
1075
+ makedirs(v)
1076
+ return langchain_modes_from_file, visible_langchain_modes_from_file, langchain_mode_paths_from_file
1077
+
1078
+
1079
+ def remove_collection_enum():
1080
+ remove(visible_langchain_modes_file)
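A rough sketch of how the persistence helpers above fit together; all values below are placeholders, and LangChainMode is assumed to be the repo's enum with MY_DATA.value == 'MyData':

# Illustrative round trip for save_collection_names / load_collection_enum; values are placeholders.
langchain_modes = ['UserData', 'MyData', 'LLM']
visible_langchain_modes = ['UserData', 'MyData']
langchain_mode_paths = {'UserData': 'user_path'}
db1s = {'MyData': [None, 'example_user_hash']}   # second element is used as the user hash

save_collection_names(langchain_modes, visible_langchain_modes,
                      langchain_mode_paths, LangChainMode, db1s)

# '' loads the shared/user file; passing the user hash loads the scratch (MyData) file.
modes, visible, paths = load_collection_enum('')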