pseudotensor commited on
Commit
33cecf1
·
1 Parent(s): 292256e

Update with h2oGPT hash 5e48d85682992dae799309dabaf0a4fa501a096f

Browse files
Files changed (4) hide show
  1. src/client_test.py +17 -8
  2. src/gen.py +44 -7
  3. src/gpt_langchain.py +337 -279
  4. src/gradio_runner.py +1 -0
src/client_test.py CHANGED
@@ -80,7 +80,7 @@ def get_args(prompt, prompt_type=None, chat=False, stream_output=False,
80
  version=None,
81
  h2ogpt_key=None,
82
  visible_models=None,
83
- system_prompt='', # default of no system prompt tiggered by empty string
84
  add_search_to_context=False,
85
  chat_conversation=None,
86
  text_context_list=None,
@@ -256,13 +256,18 @@ def run_client_nochat_api(prompt, prompt_type, max_new_tokens, version=None, h2o
256
 
257
 
258
  @pytest.mark.skip(reason="For manual use against some server, no server launched")
259
- def test_client_basic_api_lean(prompt_type='human_bot', version=None, h2ogpt_key=None):
260
- return run_client_nochat_api_lean(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50,
261
- version=version, h2ogpt_key=h2ogpt_key)
 
 
 
262
 
263
 
264
- def run_client_nochat_api_lean(prompt, prompt_type, max_new_tokens, version=None, h2ogpt_key=None):
265
- kwargs = dict(instruction_nochat=prompt, h2ogpt_key=h2ogpt_key)
 
 
266
 
267
  api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
268
  client = get_client(serialize=True)
@@ -362,7 +367,9 @@ def run_client_chat(prompt='',
362
  langchain_agents=[],
363
  prompt_type=None, prompt_dict=None,
364
  version=None,
365
- h2ogpt_key=None):
 
 
366
  client = get_client(serialize=False)
367
 
368
  kwargs, args = get_args(prompt, prompt_type, chat=True, stream_output=stream_output,
@@ -372,7 +379,9 @@ def run_client_chat(prompt='',
372
  langchain_agents=langchain_agents,
373
  prompt_dict=prompt_dict,
374
  version=version,
375
- h2ogpt_key=h2ogpt_key)
 
 
376
  return run_client(client, prompt, args, kwargs)
377
 
378
 
 
80
  version=None,
81
  h2ogpt_key=None,
82
  visible_models=None,
83
+ system_prompt='', # default of no system prompt triggered by empty string
84
  add_search_to_context=False,
85
  chat_conversation=None,
86
  text_context_list=None,
 
256
 
257
 
258
  @pytest.mark.skip(reason="For manual use against some server, no server launched")
259
+ def test_client_basic_api_lean(prompt='Who are you?', prompt_type='human_bot', version=None, h2ogpt_key=None,
260
+ chat_conversation=None, system_prompt=''):
261
+ return run_client_nochat_api_lean(prompt=prompt, prompt_type=prompt_type, max_new_tokens=50,
262
+ version=version, h2ogpt_key=h2ogpt_key,
263
+ chat_conversation=chat_conversation,
264
+ system_prompt=system_prompt)
265
 
266
 
267
+ def run_client_nochat_api_lean(prompt, prompt_type, max_new_tokens, version=None, h2ogpt_key=None,
268
+ chat_conversation=None, system_prompt=''):
269
+ kwargs = dict(instruction_nochat=prompt, h2ogpt_key=h2ogpt_key, chat_conversation=chat_conversation,
270
+ system_prompt=system_prompt)
271
 
272
  api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
273
  client = get_client(serialize=True)
 
367
  langchain_agents=[],
368
  prompt_type=None, prompt_dict=None,
369
  version=None,
370
+ h2ogpt_key=None,
371
+ chat_conversation=None,
372
+ system_prompt=''):
373
  client = get_client(serialize=False)
374
 
375
  kwargs, args = get_args(prompt, prompt_type, chat=True, stream_output=stream_output,
 
379
  langchain_agents=langchain_agents,
380
  prompt_dict=prompt_dict,
381
  version=version,
382
+ h2ogpt_key=h2ogpt_key,
383
+ chat_conversation=chat_conversation,
384
+ system_prompt=system_prompt)
385
  return run_client(client, prompt, args, kwargs)
386
 
387
 
src/gen.py CHANGED
@@ -335,7 +335,7 @@ def main(
335
 
336
  Or Address can be for vLLM:
337
  Use: "vllm:IP:port" for OpenAI-compliant vLLM endpoint
338
- Note: vllm_chat not supported by vLLM project.
339
 
340
  Or Address can be replicate:
341
  Use:
@@ -2236,6 +2236,17 @@ def evaluate(
2236
  instruction = instruction_nochat
2237
  iinput = iinput_nochat
2238
 
 
 
 
 
 
 
 
 
 
 
 
2239
  # in some cases, like lean nochat API, don't want to force sending prompt_type, allow default choice
2240
  model_lower = base_model.lower()
2241
  if not prompt_type and model_lower in inv_prompt_type_to_model_lower and prompt_type != 'custom':
@@ -2484,7 +2495,8 @@ def evaluate(
2484
  prompt, \
2485
  instruction, iinput, context, \
2486
  num_prompt_tokens, max_new_tokens, num_prompt_tokens0, num_prompt_tokens_actual, \
2487
- chat_index, top_k_docs_trial, one_doc_size = \
 
2488
  get_limited_prompt(instruction,
2489
  iinput,
2490
  tokenizer,
@@ -2552,8 +2564,6 @@ def evaluate(
2552
  sanitize_bot_response=sanitize_bot_response)
2553
  yield dict(response=response, sources=sources, save_dict=dict())
2554
  elif inf_type == 'vllm_chat' or inference_server == 'openai_chat':
2555
- if inf_type == 'vllm_chat':
2556
- raise NotImplementedError('%s not supported by vLLM' % inf_type)
2557
  if system_prompt in [None, 'None', 'auto']:
2558
  openai_system_prompt = "You are a helpful assistant."
2559
  else:
@@ -2561,7 +2571,16 @@ def evaluate(
2561
  messages0 = []
2562
  if openai_system_prompt:
2563
  messages0.append({"role": "system", "content": openai_system_prompt})
2564
- messages0.append({'role': 'user', 'content': prompt})
 
 
 
 
 
 
 
 
 
2565
  responses = openai.ChatCompletion.create(
2566
  model=base_model,
2567
  messages=messages0,
@@ -3609,13 +3628,27 @@ def get_limited_prompt(instruction,
3609
  stream_output = prompter.stream_output
3610
  system_prompt = prompter.system_prompt
3611
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3612
  # merge handles if chat_conversation is None
3613
  history = []
3614
  history = merge_chat_conversation_history(chat_conversation, history)
3615
  history_to_context_func = functools.partial(history_to_context,
3616
  langchain_mode=langchain_mode,
3617
  add_chat_history_to_context=add_chat_history_to_context,
3618
- prompt_type=prompt_type,
3619
  prompt_dict=prompt_dict,
3620
  chat=chat,
3621
  model_max_length=model_max_length,
@@ -3748,6 +3781,9 @@ def get_limited_prompt(instruction,
3748
  stream_output = False # doesn't matter
3749
  prompter = Prompter(prompt_type, prompt_dict, debug=debug, chat=chat, stream_output=stream_output,
3750
  system_prompt=system_prompt)
 
 
 
3751
 
3752
  data_point = dict(context=context, instruction=instruction, input=iinput)
3753
  # handle promptA/promptB addition if really from history.
@@ -3760,7 +3796,8 @@ def get_limited_prompt(instruction,
3760
  return prompt, \
3761
  instruction, iinput, context, \
3762
  num_prompt_tokens, max_new_tokens, num_prompt_tokens0, num_prompt_tokens_actual, \
3763
- chat_index, top_k_docs, one_doc_size
 
3764
 
3765
 
3766
  def get_docs_tokens(tokenizer, text_context_list=[], max_input_tokens=None):
 
335
 
336
  Or Address can be for vLLM:
337
  Use: "vllm:IP:port" for OpenAI-compliant vLLM endpoint
338
+ Use: "vllm_chat:IP:port" for OpenAI-Chat-compliant vLLM endpoint
339
 
340
  Or Address can be replicate:
341
  Use:
 
2236
  instruction = instruction_nochat
2237
  iinput = iinput_nochat
2238
 
2239
+ # avoid instruction in chat_conversation itself, since always used as additional context to prompt in what follows
2240
+ if isinstance(chat_conversation, list) and \
2241
+ len(chat_conversation) > 0 and \
2242
+ len(chat_conversation[-1]) == 2 and \
2243
+ chat_conversation[-1][0] == instruction:
2244
+ chat_conversation = chat_conversation[:-1]
2245
+ if not add_chat_history_to_context:
2246
+ # make it easy to ignore without needing add_chat_history_to_context
2247
+ # some langchain or unit test may need to then handle more general case
2248
+ chat_conversation = []
2249
+
2250
  # in some cases, like lean nochat API, don't want to force sending prompt_type, allow default choice
2251
  model_lower = base_model.lower()
2252
  if not prompt_type and model_lower in inv_prompt_type_to_model_lower and prompt_type != 'custom':
 
2495
  prompt, \
2496
  instruction, iinput, context, \
2497
  num_prompt_tokens, max_new_tokens, num_prompt_tokens0, num_prompt_tokens_actual, \
2498
+ chat_index, external_handle_chat_conversation, \
2499
+ top_k_docs_trial, one_doc_size = \
2500
  get_limited_prompt(instruction,
2501
  iinput,
2502
  tokenizer,
 
2564
  sanitize_bot_response=sanitize_bot_response)
2565
  yield dict(response=response, sources=sources, save_dict=dict())
2566
  elif inf_type == 'vllm_chat' or inference_server == 'openai_chat':
 
 
2567
  if system_prompt in [None, 'None', 'auto']:
2568
  openai_system_prompt = "You are a helpful assistant."
2569
  else:
 
2571
  messages0 = []
2572
  if openai_system_prompt:
2573
  messages0.append({"role": "system", "content": openai_system_prompt})
2574
+ if chat_conversation and add_chat_history_to_context:
2575
+ assert external_handle_chat_conversation, "Should be handling only externally"
2576
+ # chat_index handles token counting issues
2577
+ for message1 in chat_conversation[chat_index:]:
2578
+ if len(message1) == 2:
2579
+ messages0.append(
2580
+ {'role': 'user', 'content': message1[0] if message1[0] is not None else ''})
2581
+ messages0.append(
2582
+ {'role': 'assistant', 'content': message1[1] if message1[1] is not None else ''})
2583
+ messages0.append({'role': 'user', 'content': prompt if prompt is not None else ''})
2584
  responses = openai.ChatCompletion.create(
2585
  model=base_model,
2586
  messages=messages0,
 
3628
  stream_output = prompter.stream_output
3629
  system_prompt = prompter.system_prompt
3630
 
3631
+ generate_prompt_type = prompt_type
3632
+ external_handle_chat_conversation = False
3633
+ if any(inference_server.startswith(x) for x in ['openai_chat', 'openai_azure_chat', 'vllm_chat']):
3634
+ # Chat APIs do not take prompting
3635
+ # Replicate does not need prompting if no chat history, but in general can take prompting
3636
+ # if using prompter, prompter.system_prompt will already be filled with automatic (e.g. from llama-2),
3637
+ # so if replicate final prompt with system prompt still correct because only access prompter.system_prompt that was already set
3638
+ # below already true for openai,
3639
+ # but not vllm by default as that can be any model and handled by FastChat API inside vLLM itself
3640
+ generate_prompt_type = 'plain'
3641
+ # Chat APIs don't handle chat history via single prompt, but in messages, assumed to be handled outside this function
3642
+ chat_conversation = []
3643
+ external_handle_chat_conversation = True
3644
+
3645
  # merge handles if chat_conversation is None
3646
  history = []
3647
  history = merge_chat_conversation_history(chat_conversation, history)
3648
  history_to_context_func = functools.partial(history_to_context,
3649
  langchain_mode=langchain_mode,
3650
  add_chat_history_to_context=add_chat_history_to_context,
3651
+ prompt_type=generate_prompt_type,
3652
  prompt_dict=prompt_dict,
3653
  chat=chat,
3654
  model_max_length=model_max_length,
 
3781
  stream_output = False # doesn't matter
3782
  prompter = Prompter(prompt_type, prompt_dict, debug=debug, chat=chat, stream_output=stream_output,
3783
  system_prompt=system_prompt)
3784
+ if prompt_type != generate_prompt_type:
3785
+ # override just this attribute, keep system_prompt etc. from original prompt_type
3786
+ prompter.prompt_type = generate_prompt_type
3787
 
3788
  data_point = dict(context=context, instruction=instruction, input=iinput)
3789
  # handle promptA/promptB addition if really from history.
 
3796
  return prompt, \
3797
  instruction, iinput, context, \
3798
  num_prompt_tokens, max_new_tokens, num_prompt_tokens0, num_prompt_tokens_actual, \
3799
+ chat_index, external_handle_chat_conversation, \
3800
+ top_k_docs, one_doc_size
3801
 
3802
 
3803
  def get_docs_tokens(tokenizer, text_context_list=[], max_input_tokens=None):
src/gpt_langchain.py CHANGED
@@ -29,10 +29,11 @@ import yaml
29
 
30
  from joblib import delayed
31
  from langchain.callbacks import streaming_stdout
 
32
  from langchain.embeddings import HuggingFaceInstructEmbeddings
33
  from langchain.llms.huggingface_pipeline import VALID_TASKS
34
  from langchain.llms.utils import enforce_stop_tokens
35
- from langchain.schema import LLMResult, Generation
36
  from langchain.tools import PythonREPLTool
37
  from langchain.tools.json.tool import JsonSpec
38
  from tqdm import tqdm
@@ -944,7 +945,10 @@ class H2OReplicate(Replicate):
944
  assert self.tokenizer is not None
945
  from h2oai_pipeline import H2OTextGenerationPipeline
946
  prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)
947
- # Note Replicate handles the prompting of the specific model
 
 
 
948
  return super()._call(prompt, stop=stop, run_manager=run_manager, **kwargs)
949
 
950
  def get_token_ids(self, text: str) -> List[int]:
@@ -953,21 +957,98 @@ class H2OReplicate(Replicate):
953
  # return _get_token_ids_default_method(text)
954
 
955
 
956
- class H2OChatOpenAI(ChatOpenAI):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
957
  @classmethod
958
  def _all_required_field_names(cls) -> Set:
959
  _all_required_field_names = super(ChatOpenAI, cls)._all_required_field_names()
960
  _all_required_field_names.update({'top_p', 'frequency_penalty', 'presence_penalty', 'logit_bias'})
961
  return _all_required_field_names
962
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
963
 
964
- class H2OAzureChatOpenAI(AzureChatOpenAI):
965
  @classmethod
966
  def _all_required_field_names(cls) -> Set:
967
  _all_required_field_names = super(AzureChatOpenAI, cls)._all_required_field_names()
968
  _all_required_field_names.update({'top_p', 'frequency_penalty', 'presence_penalty', 'logit_bias'})
969
  return _all_required_field_names
970
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
971
 
972
  class H2OAzureOpenAI(AzureOpenAI):
973
  @classmethod
@@ -1052,7 +1133,7 @@ def get_llm(use_openai_model=False,
1052
  if 'meta/llama' in model_string:
1053
  temperature = max(0.01, temperature if do_sample else 0)
1054
  else:
1055
- temperature =temperature if do_sample else 0
1056
  gen_kwargs = dict(temperature=temperature,
1057
  seed=1234,
1058
  max_length=max_new_tokens, # langchain
@@ -1068,8 +1149,7 @@ def get_llm(use_openai_model=False,
1068
  if system_prompt:
1069
  gen_kwargs.update(dict(system_prompt=system_prompt))
1070
 
1071
- # replicate handles prompting, so avoid get_response() filter
1072
- prompter.prompt_type = 'plain'
1073
  if stream_output:
1074
  callbacks = [StreamingGradioCallbackHandler()]
1075
  streamer = callbacks[0] if stream_output else None
@@ -1108,8 +1188,8 @@ def get_llm(use_openai_model=False,
1108
  if inf_type == 'openai_chat' or inf_type == 'vllm_chat':
1109
  cls = H2OChatOpenAI
1110
  # FIXME: Support context, iinput
1111
- # if inf_type == 'vllm_chat':
1112
- # kwargs_extra.update(dict(tokenizer=tokenizer))
1113
  openai_api_key = openai.api_key
1114
  elif inf_type == 'openai_azure_chat':
1115
  cls = H2OAzureChatOpenAI
@@ -1168,6 +1248,8 @@ def get_llm(use_openai_model=False,
1168
  logit_bias=None if inf_type == 'vllm' else {},
1169
  max_retries=6,
1170
  streaming=stream_output,
 
 
1171
  **kwargs_extra
1172
  )
1173
  streamer = callbacks[0] if stream_output else None
@@ -3500,7 +3582,6 @@ Respond to prompt of Final Answer with your final high-quality bullet list answe
3500
  prompter = Prompter(prompt_type, prompt_dict, debug=False, chat=chat, stream_output=stream_output,
3501
  system_prompt=system_prompt)
3502
 
3503
- use_docs_planned = False
3504
  scores = []
3505
  chain = None
3506
 
@@ -3517,8 +3598,8 @@ Respond to prompt of Final Answer with your final high-quality bullet list answe
3517
  missing_kwargs = [x for x in func_names if x not in sim_kwargs]
3518
  assert not missing_kwargs, "Missing: %s" % missing_kwargs
3519
  docs, chain, scores, \
3520
- use_docs_planned, num_docs_before_cut, \
3521
- use_llm_if_no_docs, llm_mode, top_k_docs_max_show = \
3522
  get_chain(**sim_kwargs)
3523
  if document_subset in non_query_commands:
3524
  formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
@@ -3539,23 +3620,21 @@ Respond to prompt of Final Answer with your final high-quality bullet list answe
3539
  ret, extra = get_sources_answer(*get_answer_args, **get_answer_kwargs)
3540
  yield dict(prompt=prompt_basic, response=formatted_doc_chunks, sources=extra, num_prompt_tokens=0)
3541
  return
3542
- if not use_llm_if_no_docs:
3543
- if not docs and langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
3544
- LangChainAction.SUMMARIZE_ALL.value,
3545
- LangChainAction.SUMMARIZE_REFINE.value]:
3546
- ret = 'No relevant documents to summarize.' if num_docs_before_cut else 'No documents to summarize.'
3547
- extra = ''
3548
- yield dict(prompt=prompt_basic, response=ret, sources=extra, num_prompt_tokens=0)
3549
- return
3550
- if not docs and not llm_mode:
3551
- ret = 'No relevant documents to query (for chatting with LLM, pick Resources->Collections->LLM).' if num_docs_before_cut else 'No documents to query (for chatting with LLM, pick Resources->Collections->LLM).'
3552
  extra = ''
3553
  yield dict(prompt=prompt_basic, response=ret, sources=extra, num_prompt_tokens=0)
3554
  return
3555
 
3556
- if chain is None and not langchain_only_model:
3557
- # here if no docs at all and not HF type
3558
- # can only return if HF type
3559
  return
3560
 
3561
  # context stuff similar to used in evaluate()
@@ -3656,7 +3735,8 @@ Respond to prompt of Final Answer with your final high-quality bullet list answe
3656
  prompt = prompt_basic
3657
  num_prompt_tokens = get_token_count(prompt, tokenizer)
3658
 
3659
- if not use_docs_planned:
 
3660
  ret = answer
3661
  extra = ''
3662
  yield dict(prompt=prompt, response=ret, sources=extra, num_prompt_tokens=num_prompt_tokens)
@@ -3815,8 +3895,7 @@ def get_chain(query=None,
3815
  if text_context_list is None:
3816
  text_context_list = []
3817
 
3818
- # default value:
3819
- llm_mode = langchain_mode in ['Disabled', 'LLM'] and len(text_context_list) == 0
3820
  query_action = langchain_action == LangChainAction.QUERY.value
3821
  summarize_action = langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
3822
  LangChainAction.SUMMARIZE_ALL.value,
@@ -3848,8 +3927,6 @@ def get_chain(query=None,
3848
  add_search_to_context &= len(docs_search) > 0
3849
  top_k_docs_max_show = max(top_k_docs_max_show, len(docs_search))
3850
 
3851
- if len(text_context_list) > 0:
3852
- llm_mode = False
3853
  use_llm_if_no_docs = True
3854
 
3855
  from src.output_parser import H2OMRKLOutputParser
@@ -3877,10 +3954,9 @@ def get_chain(query=None,
3877
 
3878
  docs = []
3879
  scores = []
3880
- use_docs_planned = False
3881
  num_docs_before_cut = 0
3882
  use_llm_if_no_docs = True
3883
- return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
3884
 
3885
  if LangChainAgent.COLLECTION.value in langchain_agents:
3886
  output_parser = H2OMRKLOutputParser()
@@ -3899,10 +3975,9 @@ def get_chain(query=None,
3899
 
3900
  docs = []
3901
  scores = []
3902
- use_docs_planned = False
3903
  num_docs_before_cut = 0
3904
  use_llm_if_no_docs = True
3905
- return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
3906
 
3907
  if LangChainAgent.PYTHON.value in langchain_agents and inference_server.startswith('openai'):
3908
  chain = create_python_agent(
@@ -3918,10 +3993,9 @@ def get_chain(query=None,
3918
 
3919
  docs = []
3920
  scores = []
3921
- use_docs_planned = False
3922
  num_docs_before_cut = 0
3923
  use_llm_if_no_docs = True
3924
- return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
3925
 
3926
  if LangChainAgent.PANDAS.value in langchain_agents and inference_server.startswith('openai_chat'):
3927
  # FIXME: DATA
@@ -3938,10 +4012,9 @@ def get_chain(query=None,
3938
 
3939
  docs = []
3940
  scores = []
3941
- use_docs_planned = False
3942
  num_docs_before_cut = 0
3943
  use_llm_if_no_docs = True
3944
- return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
3945
 
3946
  if isinstance(document_choice, str):
3947
  document_choice = [document_choice]
@@ -3971,10 +4044,9 @@ def get_chain(query=None,
3971
 
3972
  docs = []
3973
  scores = []
3974
- use_docs_planned = False
3975
  num_docs_before_cut = 0
3976
  use_llm_if_no_docs = True
3977
- return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
3978
 
3979
  if isinstance(document_choice, str):
3980
  document_choice = [document_choice]
@@ -3985,7 +4057,7 @@ def get_chain(query=None,
3985
  document_choice_agent = [x for x in document_choice_agent if x.endswith('.csv')]
3986
  if LangChainAgent.CSV.value in langchain_agents and len(document_choice_agent) == 1 and document_choice_agent[
3987
  0].endswith(
3988
- '.csv'):
3989
  data_file = document_choice[0]
3990
  if inference_server.startswith('openai_chat'):
3991
  chain = create_csv_agent(
@@ -4006,19 +4078,9 @@ def get_chain(query=None,
4006
 
4007
  docs = []
4008
  scores = []
4009
- use_docs_planned = False
4010
  num_docs_before_cut = 0
4011
  use_llm_if_no_docs = True
4012
- return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
4013
-
4014
- # determine whether use of context out of docs is planned
4015
- if not use_openai_model and prompt_type not in ['plain'] or langchain_only_model:
4016
- if llm_mode:
4017
- use_docs_planned = False
4018
- else:
4019
- use_docs_planned = True
4020
- else:
4021
- use_docs_planned = True
4022
 
4023
  # https://github.com/hwchase17/langchain/issues/1946
4024
  # FIXME: Seems to way to get size of chroma db to limit top_k_docs to avoid
@@ -4090,8 +4152,7 @@ def get_chain(query=None,
4090
  pre_prompt_query, prompt_query,
4091
  pre_prompt_summary, prompt_summary,
4092
  langchain_action,
4093
- llm_mode,
4094
- use_docs_planned,
4095
  auto_reduce_chunks,
4096
  got_db_docs,
4097
  add_search_to_context)
@@ -4099,239 +4160,242 @@ def get_chain(query=None,
4099
  max_input_tokens = get_max_input_tokens(llm=llm, tokenizer=tokenizer, inference_server=inference_server,
4100
  model_name=model_name, max_new_tokens=max_new_tokens)
4101
 
4102
- if (db or text_context_list) and use_docs_planned:
4103
- if hasattr(db, '_persist_directory'):
4104
- lock_file = get_db_lock_file(db, lock_type='sim')
4105
- else:
4106
- base_path = 'locks'
4107
- base_path = makedirs(base_path, exist_ok=True, tmp_ok=True, use_base=True)
4108
- name_path = "sim.lock"
4109
- lock_file = os.path.join(base_path, name_path)
4110
-
4111
- if not (isinstance(db, Chroma) or isinstance(db, ChromaMig) or ChromaMig.__name__ in str(db)):
4112
- # only chroma supports filtering
4113
- filter_kwargs = {}
4114
- filter_kwargs_backup = {}
4115
- else:
4116
- import logging
4117
- logging.getLogger("chromadb").setLevel(logging.ERROR)
4118
- assert document_choice is not None, "Document choice was None"
4119
- if isinstance(db, Chroma):
4120
- filter_kwargs_backup = {} # shouldn't ever need backup
4121
- # chroma >= 0.4
4122
- if len(document_choice) == 0 or len(document_choice) >= 1 and document_choice[
4123
- 0] == DocumentChoice.ALL.value:
4124
- filter_kwargs = {"filter": {"chunk_id": {"$gte": 0}}} if query_action else \
4125
- {"filter": {"chunk_id": {"$eq": -1}}}
4126
- else:
4127
- if document_choice[0] == DocumentChoice.ALL.value:
4128
- document_choice = document_choice[1:]
4129
- if len(document_choice) == 0:
4130
- filter_kwargs = {}
4131
- elif len(document_choice) > 1:
4132
- or_filter = [
4133
- {"$and": [dict(source={"$eq": x}), dict(chunk_id={"$gte": 0})]} if query_action else {
4134
- "$and": [dict(source={"$eq": x}), dict(chunk_id={"$eq": -1})]}
4135
- for x in document_choice]
4136
- filter_kwargs = dict(filter={"$or": or_filter})
4137
- else:
4138
- # still chromadb UX bug, have to do different thing for 1 vs. 2+ docs when doing filter
4139
- one_filter = \
4140
- [{"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else {
4141
- "source": {"$eq": x},
4142
- "chunk_id": {
4143
- "$eq": -1}}
4144
- for x in document_choice][0]
4145
-
4146
- filter_kwargs = dict(filter={"$and": [dict(source=one_filter['source']),
4147
- dict(chunk_id=one_filter['chunk_id'])]})
4148
  else:
4149
- # migration for chroma < 0.4
4150
- if len(document_choice) == 0 or len(document_choice) >= 1 and document_choice[
4151
- 0] == DocumentChoice.ALL.value:
4152
- filter_kwargs = {"filter": {"chunk_id": {"$gte": 0}}} if query_action else \
4153
- {"filter": {"chunk_id": {"$eq": -1}}}
4154
- filter_kwargs_backup = {"filter": {"chunk_id": {"$gte": 0}}}
4155
- elif len(document_choice) >= 2:
4156
- if document_choice[0] == DocumentChoice.ALL.value:
4157
- document_choice = document_choice[1:]
4158
  or_filter = [
4159
- {"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else {"source": {"$eq": x},
4160
- "chunk_id": {
4161
- "$eq": -1}}
4162
  for x in document_choice]
4163
  filter_kwargs = dict(filter={"$or": or_filter})
4164
- or_filter_backup = [
4165
- {"source": {"$eq": x}} if query_action else {"source": {"$eq": x}}
4166
- for x in document_choice]
4167
- filter_kwargs_backup = dict(filter={"$or": or_filter_backup})
4168
- elif len(document_choice) == 1:
4169
- # degenerate UX bug in chroma
4170
  one_filter = \
4171
- [{"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else {"source": {"$eq": x},
4172
- "chunk_id": {
4173
- "$eq": -1}}
4174
- for x in document_choice][0]
4175
- filter_kwargs = dict(filter=one_filter)
4176
- one_filter_backup = \
4177
- [{"source": {"$eq": x}} if query_action else {"source": {"$eq": x}}
4178
  for x in document_choice][0]
4179
- filter_kwargs_backup = dict(filter=one_filter_backup)
4180
- else:
4181
- # shouldn't reach
4182
- filter_kwargs = {}
4183
- filter_kwargs_backup = {}
4184
 
4185
- if llm_mode:
4186
- docs = []
4187
- scores = []
4188
- elif document_subset == DocumentSubset.TopKSources.name or query in [None, '', '\n']:
4189
- db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs,
4190
- text_context_list=text_context_list)
4191
- if len(db_documents) == 0 and filter_kwargs_backup:
4192
- db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs_backup,
4193
- text_context_list=text_context_list)
4194
-
4195
- if top_k_docs == -1:
4196
- top_k_docs = len(db_documents)
4197
- # similar to langchain's chroma's _results_to_docs_and_scores
4198
- docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0)
4199
- for result in zip(db_documents, db_metadatas)]
4200
- # set in metadata original order of docs
4201
- [x[0].metadata.update(orig_index=ii) for ii, x in enumerate(docs_with_score)]
4202
-
4203
- # order documents
4204
- doc_hashes = [x.get('doc_hash', 'None') for x in db_metadatas]
4205
- if query_action:
4206
- doc_chunk_ids = [x.get('chunk_id', 0) for x in db_metadatas]
4207
- docs_with_score2 = [x for hx, cx, x in
4208
- sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score), key=lambda x: (x[0], x[1]))
4209
- if cx >= 0]
 
 
 
 
 
 
 
 
 
4210
  else:
4211
- assert summarize_action
4212
- doc_chunk_ids = [x.get('chunk_id', -1) for x in db_metadatas]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4213
  docs_with_score2 = [x for hx, cx, x in
4214
- sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score), key=lambda x: (x[0], x[1]))
4215
- if cx == -1
4216
  ]
4217
- if len(docs_with_score2) == 0 and len(docs_with_score) > 0:
4218
- # old database without chunk_id, migration added 0 but didn't make -1 as that would be expensive
4219
- # just do again and relax filter, let summarize operate on actual chunks if nothing else
4220
- docs_with_score2 = [x for hx, cx, x in
4221
- sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score),
4222
- key=lambda x: (x[0], x[1]))
4223
- ]
4224
- docs_with_score = docs_with_score2
4225
 
4226
- docs_with_score = docs_with_score[:top_k_docs]
4227
- docs = [x[0] for x in docs_with_score]
4228
- scores = [x[1] for x in docs_with_score]
4229
- num_docs_before_cut = len(docs)
4230
- else:
4231
- with filelock.FileLock(lock_file):
4232
- docs_with_score, got_db_docs = get_docs_with_score(query, k_db, filter_kwargs, db, db_type,
 
 
 
 
 
 
4233
  text_context_list=text_context_list,
4234
  verbose=verbose)
4235
- if len(docs_with_score) == 0 and filter_kwargs_backup:
4236
- docs_with_score, got_db_docs = get_docs_with_score(query, k_db, filter_kwargs_backup, db,
4237
- db_type,
4238
- text_context_list=text_context_list,
4239
- verbose=verbose)
4240
-
4241
- tokenizer = get_tokenizer(db=db, llm=llm, tokenizer=tokenizer, inference_server=inference_server,
4242
- use_openai_model=use_openai_model,
4243
- db_type=db_type)
4244
- # NOTE: if map_reduce, then no need to auto reduce chunks
4245
- if query_action and (top_k_docs == -1 or auto_reduce_chunks):
4246
- top_k_docs_tokenize = 100
4247
- docs_with_score = docs_with_score[:top_k_docs_tokenize]
4248
-
4249
- prompt_no_docs = template.format(context='', question=query)
4250
-
4251
- model_max_length = tokenizer.model_max_length
4252
- chat = True # FIXME?
4253
-
4254
- # first docs_with_score are most important with highest score
4255
- full_prompt, \
4256
- instruction, iinput, context, \
4257
- num_prompt_tokens, max_new_tokens, \
4258
- num_prompt_tokens0, num_prompt_tokens_actual, \
4259
- chat_index, top_k_docs_trial, one_doc_size = \
4260
- get_limited_prompt(prompt_no_docs,
4261
- iinput,
4262
- tokenizer,
4263
- prompter=prompter,
4264
- inference_server=inference_server,
4265
- prompt_type=prompt_type,
4266
- prompt_dict=prompt_dict,
4267
- chat=chat,
4268
- max_new_tokens=max_new_tokens,
4269
- system_prompt=system_prompt,
4270
- context=context,
4271
- chat_conversation=chat_conversation,
4272
- text_context_list=[x[0].page_content for x in docs_with_score],
4273
- keep_sources_in_context=keep_sources_in_context,
4274
- model_max_length=model_max_length,
4275
- memory_restriction_level=memory_restriction_level,
4276
- langchain_mode=langchain_mode,
4277
- add_chat_history_to_context=add_chat_history_to_context,
4278
- min_max_new_tokens=min_max_new_tokens,
4279
- )
 
 
 
 
 
 
 
 
4280
  # avoid craziness
4281
- if 0 < top_k_docs_trial < max_chunks:
4282
- # avoid craziness
4283
- if top_k_docs == -1:
4284
- top_k_docs = top_k_docs_trial
4285
- else:
4286
- top_k_docs = min(top_k_docs, top_k_docs_trial)
4287
- elif top_k_docs_trial >= max_chunks:
4288
- top_k_docs = max_chunks
4289
- if top_k_docs > 0:
4290
- docs_with_score = docs_with_score[:top_k_docs]
4291
- elif one_doc_size is not None:
4292
- docs_with_score = [docs_with_score[0][:one_doc_size]]
4293
  else:
4294
- docs_with_score = []
 
 
 
 
 
 
4295
  else:
4296
- if total_tokens_for_docs is not None:
4297
- # used to limit tokens for summarization, e.g. public instance
4298
- top_k_docs, one_doc_size, num_doc_tokens = \
4299
- get_docs_tokens(tokenizer,
4300
- text_context_list=[x[0].page_content for x in docs_with_score],
4301
- max_input_tokens=total_tokens_for_docs)
 
 
4302
 
4303
- docs_with_score = docs_with_score[:top_k_docs]
4304
 
4305
- # put most relevant chunks closest to question,
4306
- # esp. if truncation occurs will be "oldest" or "farthest from response" text that is truncated
4307
- # BUT: for small models, e.g. 6_9 pythia, if sees some stuff related to h2oGPT first, it can connect that and not listen to rest
4308
- if docs_ordering_type in ['best_first']:
4309
- pass
4310
- elif docs_ordering_type in ['best_near_prompt', 'reverse_sort']:
4311
- docs_with_score.reverse()
4312
- elif docs_ordering_type in ['', None, 'reverse_ucurve_sort']:
4313
- docs_with_score = reverse_ucurve_list(docs_with_score)
4314
- else:
4315
- raise ValueError("No such docs_ordering_type=%s" % docs_ordering_type)
4316
-
4317
- # cut off so no high distance docs/sources considered
4318
- num_docs_before_cut = len(docs_with_score)
4319
- docs = [x[0] for x in docs_with_score if x[1] < cut_distance]
4320
- scores = [x[1] for x in docs_with_score if x[1] < cut_distance]
4321
- if len(scores) > 0 and verbose:
4322
- print("Distance: min: %s max: %s mean: %s median: %s" %
4323
- (scores[0], scores[-1], np.mean(scores), np.median(scores)), flush=True)
4324
- else:
4325
- docs = []
4326
- scores = []
4327
 
4328
- if not docs and use_docs_planned and not langchain_only_model:
4329
- # if HF type and have no docs, can bail out
4330
- return docs, None, [], False, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
4331
 
4332
  if document_subset in non_query_commands:
4333
- # no LLM use
4334
- return docs, None, [], False, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
4335
 
4336
  # FIXME: WIP
4337
  common_words_file = "data/NGSL_1.2_stats.csv.zip"
@@ -4349,7 +4413,6 @@ def get_chain(query=None,
4349
 
4350
  if len(docs) == 0:
4351
  # avoid context == in prompt then
4352
- use_docs_planned = False
4353
  template = template_if_no_docs
4354
 
4355
  got_db_docs = got_db_docs and len(text_context_list) < len(docs)
@@ -4361,8 +4424,7 @@ def get_chain(query=None,
4361
  pre_prompt_query, prompt_query,
4362
  pre_prompt_summary, prompt_summary,
4363
  langchain_action,
4364
- llm_mode,
4365
- use_docs_planned,
4366
  auto_reduce_chunks,
4367
  got_db_docs,
4368
  add_search_to_context)
@@ -4380,10 +4442,7 @@ def get_chain(query=None,
4380
  else:
4381
  # only if use_openai_model = True, unused normally except in testing
4382
  chain = load_qa_with_sources_chain(llm)
4383
- if not use_docs_planned:
4384
- chain_kwargs = dict(input_documents=[], question=query)
4385
- else:
4386
- chain_kwargs = dict(input_documents=docs, question=query)
4387
  target = wrapped_partial(chain, chain_kwargs)
4388
  elif langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
4389
  LangChainAction.SUMMARIZE_REFINE,
@@ -4427,7 +4486,7 @@ def get_chain(query=None,
4427
  else:
4428
  raise RuntimeError("No such langchain_action=%s" % langchain_action)
4429
 
4430
- return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
4431
 
4432
 
4433
  def get_max_model_length(llm=None, tokenizer=None, inference_server=None, model_name=None):
@@ -4473,11 +4532,11 @@ def get_tokenizer(db=None, llm=None, tokenizer=None, inference_server=None, use_
4473
  if hasattr(llm, 'pipeline') and hasattr(llm.pipeline, 'tokenizer'):
4474
  # more accurate
4475
  return llm.pipeline.tokenizer
4476
- elif hasattr(llm, 'tokenizer'):
4477
  # e.g. TGI client mode etc.
4478
  return llm.tokenizer
4479
  elif inference_server in ['openai', 'openai_chat', 'openai_azure',
4480
- 'openai_azure_chat']:
4481
  return tokenizer
4482
  elif isinstance(tokenizer, FakeTokenizer):
4483
  return tokenizer
@@ -4500,8 +4559,7 @@ def get_template(query, iinput,
4500
  pre_prompt_query, prompt_query,
4501
  pre_prompt_summary, prompt_summary,
4502
  langchain_action,
4503
- llm_mode,
4504
- use_docs_planned,
4505
  auto_reduce_chunks,
4506
  got_db_docs,
4507
  add_search_to_context):
@@ -4523,7 +4581,7 @@ def get_template(query, iinput,
4523
  if langchain_action == LangChainAction.QUERY.value:
4524
  if iinput:
4525
  query = "%s\n%s" % (query, iinput)
4526
- if llm_mode or not use_docs_planned:
4527
  template_if_no_docs = template = """{context}{question}"""
4528
  else:
4529
  template = """%s
 
29
 
30
  from joblib import delayed
31
  from langchain.callbacks import streaming_stdout
32
+ from langchain.callbacks.base import Callbacks
33
  from langchain.embeddings import HuggingFaceInstructEmbeddings
34
  from langchain.llms.huggingface_pipeline import VALID_TASKS
35
  from langchain.llms.utils import enforce_stop_tokens
36
+ from langchain.schema import LLMResult, Generation, PromptValue
37
  from langchain.tools import PythonREPLTool
38
  from langchain.tools.json.tool import JsonSpec
39
  from tqdm import tqdm
 
945
  assert self.tokenizer is not None
946
  from h2oai_pipeline import H2OTextGenerationPipeline
947
  prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)
948
+ # Note Replicate handles the prompting of the specific model, but not if history, so just do it all on our side
949
+ data_point = dict(context=self.context, instruction=prompt, input=self.iinput)
950
+ prompt = self.prompter.generate_prompt(data_point)
951
+
952
  return super()._call(prompt, stop=stop, run_manager=run_manager, **kwargs)
953
 
954
  def get_token_ids(self, text: str) -> List[int]:
 
957
  # return _get_token_ids_default_method(text)
958
 
959
 
960
+ class ExtraChat:
961
+ def get_messages(self, prompts):
962
+ from langchain.schema import AIMessage, SystemMessage, HumanMessage
963
+ messages = []
964
+ if self.system_prompt:
965
+ messages.append(SystemMessage(content=self.system_prompt))
966
+ if self.chat_conversation:
967
+ for messages1 in self.chat_conversation:
968
+ messages.append(HumanMessage(content=messages1[0] if messages1[0] is not None else ''))
969
+ messages.append(AIMessage(content=messages1[1] if messages1[1] is not None else ''))
970
+ assert len(prompts) == 1, "Not implemented"
971
+ messages.append(HumanMessage(content=prompts[0].text if prompts[0].text is not None else ''))
972
+ return [messages]
973
+
974
+
975
+ class H2OChatOpenAI(ChatOpenAI, ExtraChat):
976
+ tokenizer: Any = None # for vllm_chat
977
+ system_prompt: Any = None
978
+ chat_conversation: Any = []
979
+
980
  @classmethod
981
  def _all_required_field_names(cls) -> Set:
982
  _all_required_field_names = super(ChatOpenAI, cls)._all_required_field_names()
983
  _all_required_field_names.update({'top_p', 'frequency_penalty', 'presence_penalty', 'logit_bias'})
984
  return _all_required_field_names
985
 
986
+ def get_token_ids(self, text: str) -> List[int]:
987
+ if self.tokenizer is not None:
988
+ return self.tokenizer.encode(text)
989
+ else:
990
+ # OpenAI uses tiktoken
991
+ return super().get_token_ids(text)
992
+
993
+ def generate_prompt(
994
+ self,
995
+ prompts: List[PromptValue],
996
+ stop: Optional[List[str]] = None,
997
+ callbacks: Callbacks = None,
998
+ **kwargs: Any,
999
+ ) -> LLMResult:
1000
+ prompt_messages = self.get_messages(prompts)
1001
+ # prompt_messages = [p.to_messages() for p in prompts]
1002
+ return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)
1003
+
1004
+ async def agenerate_prompt(
1005
+ self,
1006
+ prompts: List[PromptValue],
1007
+ stop: Optional[List[str]] = None,
1008
+ callbacks: Callbacks = None,
1009
+ **kwargs: Any,
1010
+ ) -> LLMResult:
1011
+ prompt_messages = self.get_messages(prompts)
1012
+ # prompt_messages = [p.to_messages() for p in prompts]
1013
+ return await self.agenerate(
1014
+ prompt_messages, stop=stop, callbacks=callbacks, **kwargs
1015
+ )
1016
+
1017
+
1018
+ class H2OAzureChatOpenAI(AzureChatOpenAI, ExtraChat):
1019
+ system_prompt: Any = None
1020
+ chat_conversation: Any = []
1021
 
 
1022
  @classmethod
1023
  def _all_required_field_names(cls) -> Set:
1024
  _all_required_field_names = super(AzureChatOpenAI, cls)._all_required_field_names()
1025
  _all_required_field_names.update({'top_p', 'frequency_penalty', 'presence_penalty', 'logit_bias'})
1026
  return _all_required_field_names
1027
 
1028
+ def generate_prompt(
1029
+ self,
1030
+ prompts: List[PromptValue],
1031
+ stop: Optional[List[str]] = None,
1032
+ callbacks: Callbacks = None,
1033
+ **kwargs: Any,
1034
+ ) -> LLMResult:
1035
+ prompt_messages = self.get_messages(prompts)
1036
+ # prompt_messages = [p.to_messages() for p in prompts]
1037
+ return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)
1038
+
1039
+ async def agenerate_prompt(
1040
+ self,
1041
+ prompts: List[PromptValue],
1042
+ stop: Optional[List[str]] = None,
1043
+ callbacks: Callbacks = None,
1044
+ **kwargs: Any,
1045
+ ) -> LLMResult:
1046
+ prompt_messages = self.get_messages(prompts)
1047
+ # prompt_messages = [p.to_messages() for p in prompts]
1048
+ return await self.agenerate(
1049
+ prompt_messages, stop=stop, callbacks=callbacks, **kwargs
1050
+ )
1051
+
1052
 
1053
  class H2OAzureOpenAI(AzureOpenAI):
1054
  @classmethod
 
1133
  if 'meta/llama' in model_string:
1134
  temperature = max(0.01, temperature if do_sample else 0)
1135
  else:
1136
+ temperature = temperature if do_sample else 0
1137
  gen_kwargs = dict(temperature=temperature,
1138
  seed=1234,
1139
  max_length=max_new_tokens, # langchain
 
1149
  if system_prompt:
1150
  gen_kwargs.update(dict(system_prompt=system_prompt))
1151
 
1152
+ # replicate handles prompting if no conversation, but in general has no chat API, so do all handling of prompting in h2oGPT
 
1153
  if stream_output:
1154
  callbacks = [StreamingGradioCallbackHandler()]
1155
  streamer = callbacks[0] if stream_output else None
 
1188
  if inf_type == 'openai_chat' or inf_type == 'vllm_chat':
1189
  cls = H2OChatOpenAI
1190
  # FIXME: Support context, iinput
1191
+ if inf_type == 'vllm_chat':
1192
+ kwargs_extra.update(dict(tokenizer=tokenizer))
1193
  openai_api_key = openai.api_key
1194
  elif inf_type == 'openai_azure_chat':
1195
  cls = H2OAzureChatOpenAI
 
1248
  logit_bias=None if inf_type == 'vllm' else {},
1249
  max_retries=6,
1250
  streaming=stream_output,
1251
+ system_prompt=system_prompt,
1252
+ # chat_conversation=chat_conversation, # don't do here, not token aware
1253
  **kwargs_extra
1254
  )
1255
  streamer = callbacks[0] if stream_output else None
 
3582
  prompter = Prompter(prompt_type, prompt_dict, debug=False, chat=chat, stream_output=stream_output,
3583
  system_prompt=system_prompt)
3584
 
 
3585
  scores = []
3586
  chain = None
3587
 
 
3598
  missing_kwargs = [x for x in func_names if x not in sim_kwargs]
3599
  assert not missing_kwargs, "Missing: %s" % missing_kwargs
3600
  docs, chain, scores, \
3601
+ num_docs_before_cut, \
3602
+ use_llm_if_no_docs, top_k_docs_max_show = \
3603
  get_chain(**sim_kwargs)
3604
  if document_subset in non_query_commands:
3605
  formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
 
3620
  ret, extra = get_sources_answer(*get_answer_args, **get_answer_kwargs)
3621
  yield dict(prompt=prompt_basic, response=formatted_doc_chunks, sources=extra, num_prompt_tokens=0)
3622
  return
3623
+ if langchain_mode not in langchain_modes_intrinsic and not use_llm_if_no_docs:
3624
+ if not docs:
3625
+ if langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
3626
+ LangChainAction.SUMMARIZE_ALL.value,
3627
+ LangChainAction.SUMMARIZE_REFINE.value]:
3628
+ ret = 'No relevant documents to summarize.' if num_docs_before_cut else 'No documents to summarize.'
3629
+ else:
3630
+ ret = 'No relevant documents to query (for chatting with LLM, pick Resources->Collections->LLM).' if num_docs_before_cut else 'No documents to query (for chatting with LLM, pick Resources->Collections->LLM).'
 
 
3631
  extra = ''
3632
  yield dict(prompt=prompt_basic, response=ret, sources=extra, num_prompt_tokens=0)
3633
  return
3634
 
3635
+ # NOTE: If chain=None, could return if HF type (i.e. not langchain_only_model), but makes code too complex
3636
+ # only return now if no chain at all, e.g. when only returning sources
3637
+ if chain is None:
3638
  return
3639
 
3640
  # context stuff similar to used in evaluate()
 
3735
  prompt = prompt_basic
3736
  num_prompt_tokens = get_token_count(prompt, tokenizer)
3737
 
3738
+ if len(docs) == 0:
3739
+ # if no docs, then no sources to cite
3740
  ret = answer
3741
  extra = ''
3742
  yield dict(prompt=prompt, response=ret, sources=extra, num_prompt_tokens=num_prompt_tokens)
 
3895
  if text_context_list is None:
3896
  text_context_list = []
3897
 
3898
+ # NOTE: Could try to establish if pure llm mode or not, but makes code too complex
 
3899
  query_action = langchain_action == LangChainAction.QUERY.value
3900
  summarize_action = langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
3901
  LangChainAction.SUMMARIZE_ALL.value,
 
3927
  add_search_to_context &= len(docs_search) > 0
3928
  top_k_docs_max_show = max(top_k_docs_max_show, len(docs_search))
3929
 
 
 
3930
  use_llm_if_no_docs = True
3931
 
3932
  from src.output_parser import H2OMRKLOutputParser
 
3954
 
3955
  docs = []
3956
  scores = []
 
3957
  num_docs_before_cut = 0
3958
  use_llm_if_no_docs = True
3959
+ return docs, target, scores, num_docs_before_cut, use_llm_if_no_docs, top_k_docs_max_show
3960
 
3961
  if LangChainAgent.COLLECTION.value in langchain_agents:
3962
  output_parser = H2OMRKLOutputParser()
 
3975
 
3976
  docs = []
3977
  scores = []
 
3978
  num_docs_before_cut = 0
3979
  use_llm_if_no_docs = True
3980
+ return docs, target, scores, num_docs_before_cut, use_llm_if_no_docs, top_k_docs_max_show
3981
 
3982
  if LangChainAgent.PYTHON.value in langchain_agents and inference_server.startswith('openai'):
3983
  chain = create_python_agent(
 
3993
 
3994
  docs = []
3995
  scores = []
 
3996
  num_docs_before_cut = 0
3997
  use_llm_if_no_docs = True
3998
+ return docs, target, scores, num_docs_before_cut, use_llm_if_no_docs, top_k_docs_max_show
3999
 
4000
  if LangChainAgent.PANDAS.value in langchain_agents and inference_server.startswith('openai_chat'):
4001
  # FIXME: DATA
 
4012
 
4013
  docs = []
4014
  scores = []
 
4015
  num_docs_before_cut = 0
4016
  use_llm_if_no_docs = True
4017
+ return docs, target, scores, num_docs_before_cut, use_llm_if_no_docs, top_k_docs_max_show
4018
 
4019
  if isinstance(document_choice, str):
4020
  document_choice = [document_choice]
 
4044
 
4045
  docs = []
4046
  scores = []
 
4047
  num_docs_before_cut = 0
4048
  use_llm_if_no_docs = True
4049
+ return docs, target, scores, num_docs_before_cut, use_llm_if_no_docs, top_k_docs_max_show
4050
 
4051
  if isinstance(document_choice, str):
4052
  document_choice = [document_choice]
 
4057
  document_choice_agent = [x for x in document_choice_agent if x.endswith('.csv')]
4058
  if LangChainAgent.CSV.value in langchain_agents and len(document_choice_agent) == 1 and document_choice_agent[
4059
  0].endswith(
4060
+ '.csv'):
4061
  data_file = document_choice[0]
4062
  if inference_server.startswith('openai_chat'):
4063
  chain = create_csv_agent(
 
4078
 
4079
  docs = []
4080
  scores = []
 
4081
  num_docs_before_cut = 0
4082
  use_llm_if_no_docs = True
4083
+ return docs, target, scores, num_docs_before_cut, use_llm_if_no_docs, top_k_docs_max_show
 
 
 
 
 
 
 
 
 
4084
 
4085
  # https://github.com/hwchase17/langchain/issues/1946
4086
  # FIXME: Seems to way to get size of chroma db to limit top_k_docs to avoid
 
4152
  pre_prompt_query, prompt_query,
4153
  pre_prompt_summary, prompt_summary,
4154
  langchain_action,
4155
+ True, # just to overestimate prompting
 
4156
  auto_reduce_chunks,
4157
  got_db_docs,
4158
  add_search_to_context)
 
4160
  max_input_tokens = get_max_input_tokens(llm=llm, tokenizer=tokenizer, inference_server=inference_server,
4161
  model_name=model_name, max_new_tokens=max_new_tokens)
4162
 
4163
+ if hasattr(db, '_persist_directory'):
4164
+ lock_file = get_db_lock_file(db, lock_type='sim')
4165
+ else:
4166
+ base_path = 'locks'
4167
+ base_path = makedirs(base_path, exist_ok=True, tmp_ok=True, use_base=True)
4168
+ name_path = "sim.lock"
4169
+ lock_file = os.path.join(base_path, name_path)
4170
+
4171
+ if not (isinstance(db, Chroma) or isinstance(db, ChromaMig) or ChromaMig.__name__ in str(db)):
4172
+ # only chroma supports filtering
4173
+ filter_kwargs = {}
4174
+ filter_kwargs_backup = {}
4175
+ else:
4176
+ import logging
4177
+ logging.getLogger("chromadb").setLevel(logging.ERROR)
4178
+ assert document_choice is not None, "Document choice was None"
4179
+ if isinstance(db, Chroma):
4180
+ filter_kwargs_backup = {} # shouldn't ever need backup
4181
+ # chroma >= 0.4
4182
+ if len(document_choice) == 0 or len(document_choice) >= 1 and document_choice[
4183
+ 0] == DocumentChoice.ALL.value:
4184
+ filter_kwargs = {"filter": {"chunk_id": {"$gte": 0}}} if query_action else \
4185
+ {"filter": {"chunk_id": {"$eq": -1}}}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4186
  else:
4187
+ if document_choice[0] == DocumentChoice.ALL.value:
4188
+ document_choice = document_choice[1:]
4189
+ if len(document_choice) == 0:
4190
+ filter_kwargs = {}
4191
+ elif len(document_choice) > 1:
 
 
 
 
4192
  or_filter = [
4193
+ {"$and": [dict(source={"$eq": x}), dict(chunk_id={"$gte": 0})]} if query_action else {
4194
+ "$and": [dict(source={"$eq": x}), dict(chunk_id={"$eq": -1})]}
 
4195
  for x in document_choice]
4196
  filter_kwargs = dict(filter={"$or": or_filter})
4197
+ else:
4198
+ # still chromadb UX bug, have to do different thing for 1 vs. 2+ docs when doing filter
 
 
 
 
4199
  one_filter = \
4200
+ [{"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else {
4201
+ "source": {"$eq": x},
4202
+ "chunk_id": {
4203
+ "$eq": -1}}
 
 
 
4204
  for x in document_choice][0]
 
 
 
 
 
4205
 
4206
+ filter_kwargs = dict(filter={"$and": [dict(source=one_filter['source']),
4207
+ dict(chunk_id=one_filter['chunk_id'])]})
4208
+ else:
4209
+ # migration for chroma < 0.4
4210
+ if len(document_choice) == 0 or len(document_choice) >= 1 and document_choice[
4211
+ 0] == DocumentChoice.ALL.value:
4212
+ filter_kwargs = {"filter": {"chunk_id": {"$gte": 0}}} if query_action else \
4213
+ {"filter": {"chunk_id": {"$eq": -1}}}
4214
+ filter_kwargs_backup = {"filter": {"chunk_id": {"$gte": 0}}}
4215
+ elif len(document_choice) >= 2:
4216
+ if document_choice[0] == DocumentChoice.ALL.value:
4217
+ document_choice = document_choice[1:]
4218
+ or_filter = [
4219
+ {"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else {"source": {"$eq": x},
4220
+ "chunk_id": {
4221
+ "$eq": -1}}
4222
+ for x in document_choice]
4223
+ filter_kwargs = dict(filter={"$or": or_filter})
4224
+ or_filter_backup = [
4225
+ {"source": {"$eq": x}} if query_action else {"source": {"$eq": x}}
4226
+ for x in document_choice]
4227
+ filter_kwargs_backup = dict(filter={"$or": or_filter_backup})
4228
+ elif len(document_choice) == 1:
4229
+ # degenerate UX bug in chroma
4230
+ one_filter = \
4231
+ [{"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else {"source": {"$eq": x},
4232
+ "chunk_id": {
4233
+ "$eq": -1}}
4234
+ for x in document_choice][0]
4235
+ filter_kwargs = dict(filter=one_filter)
4236
+ one_filter_backup = \
4237
+ [{"source": {"$eq": x}} if query_action else {"source": {"$eq": x}}
4238
+ for x in document_choice][0]
4239
+ filter_kwargs_backup = dict(filter=one_filter_backup)
4240
  else:
4241
+ # shouldn't reach
4242
+ filter_kwargs = {}
4243
+ filter_kwargs_backup = {}
4244
+
4245
+ if document_subset == DocumentSubset.TopKSources.name or query in [None, '', '\n']:
4246
+ db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs,
4247
+ text_context_list=text_context_list)
4248
+ if len(db_documents) == 0 and filter_kwargs_backup:
4249
+ db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs_backup,
4250
+ text_context_list=text_context_list)
4251
+
4252
+ if top_k_docs == -1:
4253
+ top_k_docs = len(db_documents)
4254
+ # similar to langchain's chroma's _results_to_docs_and_scores
4255
+ docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0)
4256
+ for result in zip(db_documents, db_metadatas)]
4257
+ # set in metadata original order of docs
4258
+ [x[0].metadata.update(orig_index=ii) for ii, x in enumerate(docs_with_score)]
4259
+
4260
+ # order documents
4261
+ doc_hashes = [x.get('doc_hash', 'None') for x in db_metadatas]
4262
+ if query_action:
4263
+ doc_chunk_ids = [x.get('chunk_id', 0) for x in db_metadatas]
4264
+ docs_with_score2 = [x for hx, cx, x in
4265
+ sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score), key=lambda x: (x[0], x[1]))
4266
+ if cx >= 0]
4267
+ else:
4268
+ assert summarize_action
4269
+ doc_chunk_ids = [x.get('chunk_id', -1) for x in db_metadatas]
4270
+ docs_with_score2 = [x for hx, cx, x in
4271
+ sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score), key=lambda x: (x[0], x[1]))
4272
+ if cx == -1
4273
+ ]
4274
+ if len(docs_with_score2) == 0 and len(docs_with_score) > 0:
4275
+ # old database without chunk_id, migration added 0 but didn't make -1 as that would be expensive
4276
+ # just do again and relax filter, let summarize operate on actual chunks if nothing else
4277
  docs_with_score2 = [x for hx, cx, x in
4278
+ sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score),
4279
+ key=lambda x: (x[0], x[1]))
4280
  ]
4281
+ docs_with_score = docs_with_score2
 
 
 
 
 
 
 
4282
 
4283
+ docs_with_score = docs_with_score[:top_k_docs]
4284
+ docs = [x[0] for x in docs_with_score]
4285
+ scores = [x[1] for x in docs_with_score]
4286
+ num_docs_before_cut = len(docs)
4287
+ else:
4288
+ # for db=None too
4289
+ with filelock.FileLock(lock_file):
4290
+ docs_with_score, got_db_docs = get_docs_with_score(query, k_db, filter_kwargs, db, db_type,
4291
+ text_context_list=text_context_list,
4292
+ verbose=verbose)
4293
+ if len(docs_with_score) == 0 and filter_kwargs_backup:
4294
+ docs_with_score, got_db_docs = get_docs_with_score(query, k_db, filter_kwargs_backup, db,
4295
+ db_type,
4296
  text_context_list=text_context_list,
4297
  verbose=verbose)
4298
+
4299
+ tokenizer = get_tokenizer(db=db, llm=llm, tokenizer=tokenizer, inference_server=inference_server,
4300
+ use_openai_model=use_openai_model,
4301
+ db_type=db_type)
4302
+ # NOTE: if map_reduce, then no need to auto reduce chunks
4303
+ if query_action and (top_k_docs == -1 or auto_reduce_chunks):
4304
+ top_k_docs_tokenize = 100
4305
+ docs_with_score = docs_with_score[:top_k_docs_tokenize]
4306
+ if docs_with_score:
4307
+ estimated_prompt_no_docs = template.format(context='', question=query)
4308
+ else:
4309
+ estimated_prompt_no_docs = template_if_no_docs.format(context='', question=query)
4310
+
4311
+ model_max_length = tokenizer.model_max_length
4312
+ chat = True # FIXME?
4313
+
4314
+ # first docs_with_score are most important with highest score
4315
+ estimated_full_prompt, \
4316
+ instruction, iinput, context, \
4317
+ num_prompt_tokens, max_new_tokens, \
4318
+ num_prompt_tokens0, num_prompt_tokens_actual, \
4319
+ chat_index, external_handle_chat_conversation, \
4320
+ top_k_docs_trial, one_doc_size = \
4321
+ get_limited_prompt(estimated_prompt_no_docs,
4322
+ iinput,
4323
+ tokenizer,
4324
+ prompter=prompter,
4325
+ inference_server=inference_server,
4326
+ prompt_type=prompt_type,
4327
+ prompt_dict=prompt_dict,
4328
+ chat=chat,
4329
+ max_new_tokens=max_new_tokens,
4330
+ system_prompt=system_prompt,
4331
+ context=context,
4332
+ chat_conversation=chat_conversation,
4333
+ text_context_list=[x[0].page_content for x in docs_with_score],
4334
+ keep_sources_in_context=keep_sources_in_context,
4335
+ model_max_length=model_max_length,
4336
+ memory_restriction_level=memory_restriction_level,
4337
+ langchain_mode=langchain_mode,
4338
+ add_chat_history_to_context=add_chat_history_to_context,
4339
+ min_max_new_tokens=min_max_new_tokens,
4340
+ )
4341
+ if hasattr(llm, 'chat_conversation'):
4342
+ # means LLM will handle
4343
+ assert external_handle_chat_conversation, "Should be handling only externally"
4344
+ llm.chat_conversation = chat_conversation[chat_index:]
4345
+ if hasattr(llm, 'context'):
4346
+ llm.context = context
4347
+ if hasattr(llm, 'iinput'):
4348
+ llm.iinput = iinput
4349
+ # avoid craziness
4350
+ if 0 < top_k_docs_trial < max_chunks:
4351
  # avoid craziness
4352
+ if top_k_docs == -1:
4353
+ top_k_docs = top_k_docs_trial
 
 
 
 
 
 
 
 
 
 
4354
  else:
4355
+ top_k_docs = min(top_k_docs, top_k_docs_trial)
4356
+ elif top_k_docs_trial >= max_chunks:
4357
+ top_k_docs = max_chunks
4358
+ if top_k_docs > 0:
4359
+ docs_with_score = docs_with_score[:top_k_docs]
4360
+ elif one_doc_size is not None:
4361
+ docs_with_score = [docs_with_score[0][:one_doc_size]]
4362
  else:
4363
+ docs_with_score = []
4364
+ else:
4365
+ if total_tokens_for_docs is not None:
4366
+ # used to limit tokens for summarization, e.g. public instance
4367
+ top_k_docs, one_doc_size, num_doc_tokens = \
4368
+ get_docs_tokens(tokenizer,
4369
+ text_context_list=[x[0].page_content for x in docs_with_score],
4370
+ max_input_tokens=total_tokens_for_docs)
4371
 
4372
+ docs_with_score = docs_with_score[:top_k_docs]
4373
 
4374
+ # put most relevant chunks closest to question,
4375
+ # esp. if truncation occurs will be "oldest" or "farthest from response" text that is truncated
4376
+ # BUT: for small models, e.g. 6_9 pythia, if sees some stuff related to h2oGPT first, it can connect that and not listen to rest
4377
+ if docs_ordering_type in ['best_first']:
4378
+ pass
4379
+ elif docs_ordering_type in ['best_near_prompt', 'reverse_sort']:
4380
+ docs_with_score.reverse()
4381
+ elif docs_ordering_type in ['', None, 'reverse_ucurve_sort']:
4382
+ docs_with_score = reverse_ucurve_list(docs_with_score)
4383
+ else:
4384
+ raise ValueError("No such docs_ordering_type=%s" % docs_ordering_type)
4385
+
4386
+ # cut off so no high distance docs/sources considered
4387
+ num_docs_before_cut = len(docs_with_score)
4388
+ docs = [x[0] for x in docs_with_score if x[1] < cut_distance]
4389
+ scores = [x[1] for x in docs_with_score if x[1] < cut_distance]
4390
+ if len(scores) > 0 and verbose:
4391
+ print("Distance: min: %s max: %s mean: %s median: %s" %
4392
+ (scores[0], scores[-1], np.mean(scores), np.median(scores)), flush=True)
 
 
 
4393
 
4394
+ # if HF type and have no docs, could bail out, but makes code too complex
 
 
4395
 
4396
  if document_subset in non_query_commands:
4397
+ # no LLM use at all, just sources
4398
+ return docs, None, [], num_docs_before_cut, use_llm_if_no_docs, top_k_docs_max_show
4399
 
4400
  # FIXME: WIP
4401
  common_words_file = "data/NGSL_1.2_stats.csv.zip"
 
4413
 
4414
  if len(docs) == 0:
4415
  # avoid context == in prompt then
 
4416
  template = template_if_no_docs
4417
 
4418
  got_db_docs = got_db_docs and len(text_context_list) < len(docs)
 
4424
  pre_prompt_query, prompt_query,
4425
  pre_prompt_summary, prompt_summary,
4426
  langchain_action,
4427
+ got_db_docs,
 
4428
  auto_reduce_chunks,
4429
  got_db_docs,
4430
  add_search_to_context)
 
4442
  else:
4443
  # only if use_openai_model = True, unused normally except in testing
4444
  chain = load_qa_with_sources_chain(llm)
4445
+ chain_kwargs = dict(input_documents=docs, question=query)
 
 
 
4446
  target = wrapped_partial(chain, chain_kwargs)
4447
  elif langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
4448
  LangChainAction.SUMMARIZE_REFINE,
 
4486
  else:
4487
  raise RuntimeError("No such langchain_action=%s" % langchain_action)
4488
 
4489
+ return docs, target, scores, num_docs_before_cut, use_llm_if_no_docs, top_k_docs_max_show
4490
 
4491
 
4492
  def get_max_model_length(llm=None, tokenizer=None, inference_server=None, model_name=None):
 
4532
  if hasattr(llm, 'pipeline') and hasattr(llm.pipeline, 'tokenizer'):
4533
  # more accurate
4534
  return llm.pipeline.tokenizer
4535
+ elif hasattr(llm, 'tokenizer') and llm.tokenizer is not None:
4536
  # e.g. TGI client mode etc.
4537
  return llm.tokenizer
4538
  elif inference_server in ['openai', 'openai_chat', 'openai_azure',
4539
+ 'openai_azure_chat'] and tokenizer is not None:
4540
  return tokenizer
4541
  elif isinstance(tokenizer, FakeTokenizer):
4542
  return tokenizer
 
4559
  pre_prompt_query, prompt_query,
4560
  pre_prompt_summary, prompt_summary,
4561
  langchain_action,
4562
+ got_docs,
 
4563
  auto_reduce_chunks,
4564
  got_db_docs,
4565
  add_search_to_context):
 
4581
  if langchain_action == LangChainAction.QUERY.value:
4582
  if iinput:
4583
  query = "%s\n%s" % (query, iinput)
4584
+ if not got_docs:
4585
  template_if_no_docs = template = """{context}{question}"""
4586
  else:
4587
  template = """%s
src/gradio_runner.py CHANGED
@@ -2881,6 +2881,7 @@ def go_gradio(**kwargs):
2881
  history = args_list[-1]
2882
  if not history:
2883
  history = []
 
2884
  prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
2885
  prompt_dict1 = args_list[eval_func_param_names.index('prompt_dict')]
2886
  langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
 
2881
  history = args_list[-1]
2882
  if not history:
2883
  history = []
2884
+ # NOTE: For these, could check if None, then automatically use CLI values, but too complex behavior
2885
  prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
2886
  prompt_dict1 = args_list[eval_func_param_names.index('prompt_dict')]
2887
  langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]