pseudotensor commited on
Commit
edf6dca
·
1 Parent(s): ce8ae40

Update with h2oGPT hash a9971663accc92add02bde0be7622726ef2db350

Browse files
client_test.py CHANGED
@@ -7,7 +7,7 @@ python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b
7
 
8
  NOTE: For private models, add --use-auth_token=True
9
 
10
- NOTE: --infer_devices=True (default) must be used for multi-GPU in case see failures with cuda:x cuda:y mismatches.
11
  Currently, this will force model to be on a single GPU.
12
 
13
  Then run this client as:
@@ -98,7 +98,8 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False,
98
  top_k_docs=top_k_docs,
99
  chunk=True,
100
  chunk_size=512,
101
- document_choice=[DocumentChoices.All_Relevant.name],
 
102
  )
103
  from evaluate_params import eval_func_param_names
104
  assert len(set(eval_func_param_names).difference(set(list(kwargs.keys())))) == 0
@@ -203,7 +204,8 @@ def run_client_nochat_api_lean_morestuff(prompt, prompt_type='human_bot', max_ne
203
  langchain_mode='Disabled',
204
  langchain_action=LangChainAction.QUERY.value,
205
  top_k_docs=4,
206
- document_choice=['All'],
 
207
  )
208
 
209
  api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
 
7
 
8
  NOTE: For private models, add --use-auth_token=True
9
 
10
+ NOTE: --use_gpu_id=True (default) must be used for multi-GPU in case see failures with cuda:x cuda:y mismatches.
11
  Currently, this will force model to be on a single GPU.
12
 
13
  Then run this client as:
 
98
  top_k_docs=top_k_docs,
99
  chunk=True,
100
  chunk_size=512,
101
+ document_subset=DocumentChoices.Relevant.name,
102
+ document_choice=[],
103
  )
104
  from evaluate_params import eval_func_param_names
105
  assert len(set(eval_func_param_names).difference(set(list(kwargs.keys())))) == 0
 
204
  langchain_mode='Disabled',
205
  langchain_action=LangChainAction.QUERY.value,
206
  top_k_docs=4,
207
+ document_subset=DocumentChoices.Relevant.name,
208
+ document_choice=[],
209
  )
210
 
211
  api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
enums.py CHANGED
@@ -28,16 +28,21 @@ class PromptType(Enum):
28
  gptj = 22
29
  prompt_answer_openllama = 23
30
  vicuna11 = 24
 
 
 
31
 
32
 
33
  class DocumentChoices(Enum):
34
- All_Relevant = 0
35
- All_Relevant_Only_Sources = 1
36
- Only_All_Sources = 2
37
- Just_LLM = 3
38
 
39
 
40
- non_query_commands = [DocumentChoices.All_Relevant_Only_Sources.name, DocumentChoices.Only_All_Sources.name]
 
 
 
41
 
42
 
43
  class LangChainMode(Enum):
@@ -60,7 +65,7 @@ class LangChainAction(Enum):
60
 
61
  QUERY = "Query"
62
  # WIP:
63
- #SUMMARIZE_MAP = "Summarize_map_reduce"
64
  SUMMARIZE_MAP = "Summarize"
65
  SUMMARIZE_ALL = "Summarize_all"
66
  SUMMARIZE_REFINE = "Summarize_refine"
@@ -68,7 +73,6 @@ class LangChainAction(Enum):
68
 
69
  no_server_str = no_lora_str = no_model_str = '[None/Remove]'
70
 
71
-
72
  # from site-packages/langchain/llms/openai.py
73
  # but needed since ChatOpenAI doesn't have this information
74
  model_token_mapping = {
@@ -77,7 +81,7 @@ model_token_mapping = {
77
  "gpt-4-32k": 32768,
78
  "gpt-4-32k-0314": 32768,
79
  "gpt-3.5-turbo": 4096,
80
- "gpt-3.5-turbo-16k": 16*1024,
81
  "gpt-3.5-turbo-0301": 4096,
82
  "text-ada-001": 2049,
83
  "ada": 2049,
@@ -94,6 +98,5 @@ model_token_mapping = {
94
  "code-cushman-001": 2048,
95
  }
96
 
97
-
98
  source_prefix = "Sources [Score | Link]:"
99
  source_postfix = "End Sources<p>"
 
28
  gptj = 22
29
  prompt_answer_openllama = 23
30
  vicuna11 = 24
31
+ mptinstruct = 25
32
+ mptchat = 26
33
+ falcon = 27
34
 
35
 
36
  class DocumentChoices(Enum):
37
+ Relevant = 0
38
+ Sources = 1
39
+ All = 2
 
40
 
41
 
42
+ non_query_commands = [
43
+ DocumentChoices.Sources.name,
44
+ DocumentChoices.All.name
45
+ ]
46
 
47
 
48
  class LangChainMode(Enum):
 
65
 
66
  QUERY = "Query"
67
  # WIP:
68
+ # SUMMARIZE_MAP = "Summarize_map_reduce"
69
  SUMMARIZE_MAP = "Summarize"
70
  SUMMARIZE_ALL = "Summarize_all"
71
  SUMMARIZE_REFINE = "Summarize_refine"
 
73
 
74
  no_server_str = no_lora_str = no_model_str = '[None/Remove]'
75
 
 
76
  # from site-packages/langchain/llms/openai.py
77
  # but needed since ChatOpenAI doesn't have this information
78
  model_token_mapping = {
 
81
  "gpt-4-32k": 32768,
82
  "gpt-4-32k-0314": 32768,
83
  "gpt-3.5-turbo": 4096,
84
+ "gpt-3.5-turbo-16k": 16 * 1024,
85
  "gpt-3.5-turbo-0301": 4096,
86
  "text-ada-001": 2049,
87
  "ada": 2049,
 
98
  "code-cushman-001": 2048,
99
  }
100
 
 
101
  source_prefix = "Sources [Score | Link]:"
102
  source_postfix = "End Sources<p>"
evaluate_params.py CHANGED
@@ -34,6 +34,7 @@ eval_func_param_names = ['instruction',
34
  'top_k_docs',
35
  'chunk',
36
  'chunk_size',
 
37
  'document_choice',
38
  ]
39
 
@@ -43,5 +44,4 @@ for k in no_default_param_names:
43
  if k in eval_func_param_names_defaults:
44
  eval_func_param_names_defaults.remove(k)
45
 
46
-
47
  eval_extra_columns = ['prompt', 'response', 'score']
 
34
  'top_k_docs',
35
  'chunk',
36
  'chunk_size',
37
+ 'document_subset',
38
  'document_choice',
39
  ]
40
 
 
44
  if k in eval_func_param_names_defaults:
45
  eval_func_param_names_defaults.remove(k)
46
 
 
47
  eval_extra_columns = ['prompt', 'response', 'score']
gen.py CHANGED
@@ -32,7 +32,8 @@ from enums import DocumentChoices, LangChainMode, no_lora_str, model_token_mappi
32
  source_postfix, LangChainAction
33
  from loaders import get_loaders
34
  from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread, get_githash, \
35
- import_matplotlib, get_device, makedirs, get_kwargs, start_faulthandler, get_hf_server, FakeTokenizer, remove
 
36
 
37
  start_faulthandler()
38
  import_matplotlib()
@@ -60,7 +61,9 @@ def main(
60
  load_8bit: bool = False,
61
  load_4bit: bool = False,
62
  load_half: bool = True,
63
- infer_devices: bool = True,
 
 
64
  base_model: str = '',
65
  tokenizer_base_model: str = '',
66
  lora_weights: str = "",
@@ -91,7 +94,7 @@ def main(
91
  memory_restriction_level: int = None,
92
  debug: bool = False,
93
  save_dir: str = None,
94
- share: bool = True,
95
  local_files_only: bool = False,
96
  resume_download: bool = True,
97
  use_auth_token: Union[str, bool] = False,
@@ -138,14 +141,15 @@ def main(
138
  eval_prompts_only_seed: int = 1234,
139
  eval_as_output: bool = False,
140
 
141
- langchain_mode: str = 'Disabled',
142
  langchain_action: str = LangChainAction.QUERY.value,
143
  force_langchain_evaluate: bool = False,
144
  visible_langchain_modes: list = ['UserData', 'MyData'],
145
  # WIP:
146
  # visible_langchain_actions: list = langchain_actions.copy(),
147
  visible_langchain_actions: list = [LangChainAction.QUERY.value, LangChainAction.SUMMARIZE_MAP.value],
148
- document_choice: list = [DocumentChoices.All_Relevant.name],
 
149
  user_path: str = None,
150
  detect_user_path_changes_every_query: bool = False,
151
  load_db_if_exists: bool = True,
@@ -177,11 +181,13 @@ def main(
177
  :param load_8bit: load model in 8-bit using bitsandbytes
178
  :param load_4bit: load model in 4-bit using bitsandbytes
179
  :param load_half: load model in float16
180
- :param infer_devices: whether to control devices with gpu_id. If False, then spread across GPUs
 
 
181
  :param base_model: model HF-type name. If use --base_model to preload model, cannot unload in gradio in models tab
182
  :param tokenizer_base_model: tokenizer HF-type name. Usually not required, inferred from base_model.
183
  :param lora_weights: LORA weights path/HF link
184
- :param gpu_id: if infer_devices, then use gpu_id for cuda device ID, or auto mode if gpu_id != -1
185
  :param compile_model Whether to compile the model
186
  :param use_cache: Whether to use caching in model (some models fail when multiple threads use)
187
  :param inference_server: Consume base_model as type of model at this address
@@ -289,7 +295,8 @@ def main(
289
  Default: If only want to consume local files, e.g. prepared by make_db.py, only include ['UserData']
290
  FIXME: Avoid 'All' for now, not implemented
291
  :param visible_langchain_actions: Which actions to allow
292
- :param document_choice: Default document choice when taking subset of collection
 
293
  :param load_db_if_exists: Whether to load chroma db if exists or re-generate db
294
  :param keep_sources_in_context: Whether to keep url sources in context, not helpful usually
295
  :param db_type: 'faiss' for in-memory or 'chroma' or 'weaviate' for persisted on disk
@@ -379,10 +386,12 @@ def main(
379
  # allow enabling langchain via ENV
380
  # FIRST PLACE where LangChain referenced, but no imports related to it
381
  langchain_mode = os.environ.get("LANGCHAIN_MODE", langchain_mode)
382
- assert langchain_mode in langchain_modes, "Invalid langchain_mode %s" % langchain_mode
 
383
  visible_langchain_modes = ast.literal_eval(os.environ.get("visible_langchain_modes", str(visible_langchain_modes)))
384
  if langchain_mode not in visible_langchain_modes and langchain_mode in langchain_modes:
385
- visible_langchain_modes += [langchain_mode]
 
386
 
387
  assert langchain_action in langchain_actions, "Invalid langchain_action %s" % langchain_action
388
 
@@ -392,6 +401,25 @@ def main(
392
  if LangChainMode.USER_DATA.value not in visible_langchain_modes:
393
  allow_upload_to_user_data = False
394
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
395
  if is_public:
396
  allow_upload_to_user_data = False
397
  input_lines = 1 # ensure set, for ease of use
@@ -458,7 +486,9 @@ def main(
458
  load_8bit = False
459
  load_4bit = False
460
  load_half = False
461
- infer_devices = False
 
 
462
  torch.backends.cudnn.benchmark = True
463
  torch.backends.cudnn.enabled = False
464
  torch.set_default_dtype(torch.float32)
@@ -714,7 +744,9 @@ def get_config(base_model,
714
  return config, model
715
 
716
 
717
- def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
 
 
718
  config, model,
719
  gpu_id=0,
720
  ):
@@ -761,16 +793,25 @@ def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward
761
  load_in_8bit = model_kwargs.get('load_in_8bit', False)
762
  load_in_4bit = model_kwargs.get('load_in_4bit', False)
763
  model_kwargs['device_map'] = device_map
 
764
  pop_unused_model_kwargs(model_kwargs)
765
 
766
- if load_in_8bit or load_in_4bit or not load_half:
767
- model = model_loader.from_pretrained(
 
 
 
 
 
 
 
 
768
  base_model,
769
  config=config,
770
  **model_kwargs,
771
  )
772
  else:
773
- model = model_loader.from_pretrained(
774
  base_model,
775
  config=config,
776
  **model_kwargs,
@@ -778,7 +819,7 @@ def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward
778
  return model
779
 
780
 
781
- def get_client_from_inference_server(inference_server, raise_connection_exception=False):
782
  inference_server, headers = get_hf_server(inference_server)
783
  # preload client since slow for gradio case especially
784
  from gradio_utils.grclient import GradioClient
@@ -786,7 +827,7 @@ def get_client_from_inference_server(inference_server, raise_connection_exceptio
786
  hf_client = None
787
  if headers is None:
788
  try:
789
- print("GR Client Begin: %s" % inference_server, flush=True)
790
  # first do sanity check if alive, else gradio client takes too long by default
791
  requests.get(inference_server, timeout=int(os.getenv('REQUEST_TIMEOUT', '30')))
792
  gr_client = GradioClient(inference_server)
@@ -794,19 +835,19 @@ def get_client_from_inference_server(inference_server, raise_connection_exceptio
794
  except (OSError, ValueError) as e:
795
  # Occurs when wrong endpoint and should have been HF client, so don't hard raise, just move to HF
796
  gr_client = None
797
- print("GR Client Failed %s: %s" % (inference_server, str(e)), flush=True)
798
  except (ConnectTimeoutError, ConnectTimeout, MaxRetryError, ConnectionError, ConnectionError2,
799
  JSONDecodeError, ReadTimeout2, KeyError) as e:
800
  t, v, tb = sys.exc_info()
801
  ex = ''.join(traceback.format_exception(t, v, tb))
802
- print("GR Client Failed %s: %s" % (inference_server, str(ex)), flush=True)
803
  if raise_connection_exception:
804
  raise
805
 
806
  if gr_client is None:
807
  res = None
808
  from text_generation import Client as HFClient
809
- print("HF Client Begin: %s" % inference_server)
810
  try:
811
  hf_client = HFClient(inference_server, headers=headers, timeout=int(os.getenv('REQUEST_TIMEOUT', '30')))
812
  # quick check valid TGI endpoint
@@ -817,10 +858,10 @@ def get_client_from_inference_server(inference_server, raise_connection_exceptio
817
  hf_client = None
818
  t, v, tb = sys.exc_info()
819
  ex = ''.join(traceback.format_exception(t, v, tb))
820
- print("HF Client Failed %s: %s" % (inference_server, str(ex)))
821
  if raise_connection_exception:
822
  raise
823
- print("HF Client End: %s %s" % (inference_server, res))
824
  return inference_server, gr_client, hf_client
825
 
826
 
@@ -828,7 +869,9 @@ def get_model(
828
  load_8bit: bool = False,
829
  load_4bit: bool = False,
830
  load_half: bool = True,
831
- infer_devices: bool = True,
 
 
832
  base_model: str = '',
833
  inference_server: str = "",
834
  tokenizer_base_model: str = '',
@@ -850,7 +893,9 @@ def get_model(
850
  :param load_8bit: load model in 8-bit, not supported by all models
851
  :param load_4bit: load model in 4-bit, not supported by all models
852
  :param load_half: load model in 16-bit
853
- :param infer_devices: Use torch infer of optimal placement of layers on devices (for non-lora case)
 
 
854
  For non-LORA case, False will spread shards across multiple GPUs, but this can lead to cuda:x cuda:y mismatches
855
  So it is not the default
856
  :param base_model: name/path of base model
@@ -868,8 +913,7 @@ def get_model(
868
  :param verbose:
869
  :return:
870
  """
871
- if verbose:
872
- print("Get %s model" % base_model, flush=True)
873
 
874
  triton_attn = False
875
  long_sequence = True
@@ -893,7 +937,8 @@ def get_model(
893
  print("Detected as llama type from"
894
  " config (%s) or name (%s)" % (llama_type_from_config, llama_type_from_name), flush=True)
895
 
896
- model_loader, tokenizer_loader = get_loaders(model_name=base_model, reward_type=reward_type, llama_type=llama_type)
 
897
 
898
  tokenizer_kwargs = dict(local_files_only=local_files_only,
899
  resume_download=resume_download,
@@ -917,7 +962,8 @@ def get_model(
917
  tokenizer = FakeTokenizer()
918
 
919
  if isinstance(inference_server, str) and inference_server.startswith("http"):
920
- inference_server, gr_client, hf_client = get_client_from_inference_server(inference_server)
 
921
  client = gr_client or hf_client
922
  # Don't return None, None for model, tokenizer so triggers
923
  return client, tokenizer, 'http'
@@ -937,7 +983,9 @@ def get_model(
937
  return get_hf_model(load_8bit=load_8bit,
938
  load_4bit=load_4bit,
939
  load_half=load_half,
940
- infer_devices=infer_devices,
 
 
941
  base_model=base_model,
942
  tokenizer_base_model=tokenizer_base_model,
943
  lora_weights=lora_weights,
@@ -961,7 +1009,9 @@ def get_model(
961
  def get_hf_model(load_8bit: bool = False,
962
  load_4bit: bool = False,
963
  load_half: bool = True,
964
- infer_devices: bool = True,
 
 
965
  base_model: str = '',
966
  tokenizer_base_model: str = '',
967
  lora_weights: str = "",
@@ -998,7 +1048,8 @@ def get_hf_model(load_8bit: bool = False,
998
  "Please choose a base model with --base_model (CLI) or load one from Models Tab (gradio)"
999
  )
1000
 
1001
- model_loader, tokenizer_loader = get_loaders(model_name=base_model, reward_type=reward_type, llama_type=llama_type)
 
1002
 
1003
  config, _ = get_config(base_model, return_model=False, raise_exception=True, **config_kwargs)
1004
 
@@ -1015,7 +1066,7 @@ def get_hf_model(load_8bit: bool = False,
1015
  device=0 if device == "cuda" else -1,
1016
  torch_dtype=torch.float16 if device == 'cuda' else torch.float32)
1017
  else:
1018
- assert device in ["cuda", "cpu"], "Unsupported device %s" % device
1019
  model_kwargs = dict(local_files_only=local_files_only,
1020
  torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
1021
  resume_download=resume_download,
@@ -1024,11 +1075,16 @@ def get_hf_model(load_8bit: bool = False,
1024
  offload_folder=offload_folder,
1025
  )
1026
  if 'mbart-' not in base_model.lower() and 'mpt-' not in base_model.lower():
 
 
 
 
1027
  model_kwargs.update(dict(load_in_8bit=load_8bit,
1028
  load_in_4bit=load_4bit,
1029
- device_map={"": 0} if (load_8bit or load_4bit) and device == 'cuda' else "auto",
1030
  ))
1031
  if 'mpt-' in base_model.lower() and gpu_id is not None and gpu_id >= 0:
 
1032
  model_kwargs.update(dict(device_map={"": gpu_id} if device == 'cuda' else "cpu"))
1033
 
1034
  if 'OpenAssistant/reward-model'.lower() in base_model.lower():
@@ -1038,29 +1094,32 @@ def get_hf_model(load_8bit: bool = False,
1038
  pop_unused_model_kwargs(model_kwargs)
1039
 
1040
  if not lora_weights:
1041
- with torch.device(device):
 
 
1042
 
1043
- if infer_devices:
1044
  config, model = get_config(base_model, return_model=True, raise_exception=True, **config_kwargs)
1045
- model = get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
 
1046
  config, model,
1047
  gpu_id=gpu_id,
1048
  )
1049
  else:
1050
  config, _ = get_config(base_model, **config_kwargs)
1051
- if load_half and not (load_8bit or load_4bit):
1052
- model = model_loader.from_pretrained(
1053
  base_model,
1054
  config=config,
1055
  **model_kwargs).half()
1056
  else:
1057
- model = model_loader.from_pretrained(
1058
  base_model,
1059
  config=config,
1060
  **model_kwargs)
1061
  elif load_8bit or load_4bit:
1062
  config, _ = get_config(base_model, **config_kwargs)
1063
- model = model_loader.from_pretrained(
1064
  base_model,
1065
  config=config,
1066
  **model_kwargs
@@ -1080,7 +1139,7 @@ def get_hf_model(load_8bit: bool = False,
1080
  else:
1081
  with torch.device(device):
1082
  config, _ = get_config(base_model, raise_exception=True, **config_kwargs)
1083
- model = model_loader.from_pretrained(
1084
  base_model,
1085
  config=config,
1086
  **model_kwargs
@@ -1097,7 +1156,7 @@ def get_hf_model(load_8bit: bool = False,
1097
  offload_folder=offload_folder,
1098
  device_map="auto",
1099
  )
1100
- if load_half:
1101
  model.half()
1102
 
1103
  # unwind broken decapoda-research config
@@ -1156,7 +1215,8 @@ def get_score_model(score_model: str = None,
1156
  load_8bit: bool = False,
1157
  load_4bit: bool = False,
1158
  load_half: bool = True,
1159
- infer_devices: bool = True,
 
1160
  base_model: str = '',
1161
  inference_server: str = '',
1162
  tokenizer_base_model: str = '',
@@ -1177,6 +1237,8 @@ def get_score_model(score_model: str = None,
1177
  load_8bit = False
1178
  load_4bit = False
1179
  load_half = False
 
 
1180
  base_model = score_model.strip()
1181
  tokenizer_base_model = ''
1182
  lora_weights = ''
@@ -1219,6 +1281,7 @@ def evaluate(
1219
  top_k_docs,
1220
  chunk,
1221
  chunk_size,
 
1222
  document_choice,
1223
  # END NOTE: Examples must have same order of parameters
1224
  src_lang=None,
@@ -1435,6 +1498,7 @@ def evaluate(
1435
  chunk_size=chunk_size,
1436
  langchain_mode=langchain_mode,
1437
  langchain_action=langchain_action,
 
1438
  document_choice=document_choice,
1439
  db_type=db_type,
1440
  top_k_docs=top_k_docs,
@@ -1462,6 +1526,7 @@ def evaluate(
1462
  inference_server=inference_server,
1463
  langchain_mode=langchain_mode,
1464
  langchain_action=langchain_action,
 
1465
  document_choice=document_choice,
1466
  num_prompt_tokens=num_prompt_tokens,
1467
  instruction=instruction,
@@ -1563,7 +1628,8 @@ def evaluate(
1563
  gr_client = None
1564
  hf_client = model
1565
  else:
1566
- inference_server, gr_client, hf_client = get_client_from_inference_server(inference_server)
 
1567
 
1568
  # quick sanity check to avoid long timeouts, just see if can reach server
1569
  requests.get(inference_server, timeout=int(os.getenv('REQUEST_TIMEOUT_FAST', '10')))
@@ -1631,7 +1697,8 @@ def evaluate(
1631
  top_k_docs=top_k_docs,
1632
  chunk=chunk,
1633
  chunk_size=chunk_size,
1634
- document_choice=[DocumentChoices.All_Relevant.name],
 
1635
  )
1636
  api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
1637
  if not stream_output:
@@ -1830,7 +1897,7 @@ def evaluate(
1830
 
1831
  with torch.no_grad():
1832
  have_lora_weights = lora_weights not in [no_lora_str, '', None]
1833
- context_class_cast = NullContext if device == 'cpu' or have_lora_weights else torch.autocast
1834
  with context_class_cast(device):
1835
  # protection for gradio not keeping track of closed users,
1836
  # else hit bitsandbytes lack of thread safety:
@@ -2207,8 +2274,8 @@ y = np.random.randint(0, 1, 100)
2207
 
2208
  # move to correct position
2209
  for example in examples:
2210
- example += [chat, '', '', 'Disabled', LangChainAction.QUERY.value,
2211
- top_k_docs, chunk, chunk_size, [DocumentChoices.All_Relevant.name]
2212
  ]
2213
  # adjust examples if non-chat mode
2214
  if not chat:
@@ -2431,9 +2498,9 @@ def entrypoint_main():
2431
 
2432
  python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='human_bot' --lora_weights='GPT-NeoXT-Chat-Base-20B.merged.json.8_epochs.57b2892c53df5b8cefac45f84d019cace803ef26.28'
2433
 
2434
- must have 4*48GB GPU and run without 8bit in order for sharding to work with infer_devices=False
2435
  can also pass --prompt_type='human_bot' and model can somewhat handle instructions without being instruct tuned
2436
- python generate.py --base_model=decapoda-research/llama-65b-hf --load_8bit=False --infer_devices=False --prompt_type='human_bot'
2437
 
2438
  python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b
2439
  """
 
32
  source_postfix, LangChainAction
33
  from loaders import get_loaders
34
  from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread, get_githash, \
35
+ import_matplotlib, get_device, makedirs, get_kwargs, start_faulthandler, get_hf_server, FakeTokenizer, remove, \
36
+ have_langchain
37
 
38
  start_faulthandler()
39
  import_matplotlib()
 
61
  load_8bit: bool = False,
62
  load_4bit: bool = False,
63
  load_half: bool = True,
64
+ load_gptq: str = '',
65
+ use_safetensors: bool = False,
66
+ use_gpu_id: bool = True,
67
  base_model: str = '',
68
  tokenizer_base_model: str = '',
69
  lora_weights: str = "",
 
94
  memory_restriction_level: int = None,
95
  debug: bool = False,
96
  save_dir: str = None,
97
+ share: bool = False,
98
  local_files_only: bool = False,
99
  resume_download: bool = True,
100
  use_auth_token: Union[str, bool] = False,
 
141
  eval_prompts_only_seed: int = 1234,
142
  eval_as_output: bool = False,
143
 
144
+ langchain_mode: str = None,
145
  langchain_action: str = LangChainAction.QUERY.value,
146
  force_langchain_evaluate: bool = False,
147
  visible_langchain_modes: list = ['UserData', 'MyData'],
148
  # WIP:
149
  # visible_langchain_actions: list = langchain_actions.copy(),
150
  visible_langchain_actions: list = [LangChainAction.QUERY.value, LangChainAction.SUMMARIZE_MAP.value],
151
+ document_subset: str = DocumentChoices.Relevant.name,
152
+ document_choice: list = [],
153
  user_path: str = None,
154
  detect_user_path_changes_every_query: bool = False,
155
  load_db_if_exists: bool = True,
 
181
  :param load_8bit: load model in 8-bit using bitsandbytes
182
  :param load_4bit: load model in 4-bit using bitsandbytes
183
  :param load_half: load model in float16
184
+ :param load_gptq: to load model with GPTQ, put model_basename here, e.g. gptq_model-4bit--1g
185
+ :param use_safetensors: to use safetensors version (assumes file/HF points to safe tensors version)
186
+ :param use_gpu_id: whether to control devices with gpu_id. If False, then spread across GPUs
187
  :param base_model: model HF-type name. If use --base_model to preload model, cannot unload in gradio in models tab
188
  :param tokenizer_base_model: tokenizer HF-type name. Usually not required, inferred from base_model.
189
  :param lora_weights: LORA weights path/HF link
190
+ :param gpu_id: if use_gpu_id, then use gpu_id for cuda device ID, or auto mode if gpu_id != -1
191
  :param compile_model Whether to compile the model
192
  :param use_cache: Whether to use caching in model (some models fail when multiple threads use)
193
  :param inference_server: Consume base_model as type of model at this address
 
295
  Default: If only want to consume local files, e.g. prepared by make_db.py, only include ['UserData']
296
  FIXME: Avoid 'All' for now, not implemented
297
  :param visible_langchain_actions: Which actions to allow
298
+ :param document_subset: Default document choice when taking subset of collection
299
+ :param document_choice: Chosen document(s) by internal name
300
  :param load_db_if_exists: Whether to load chroma db if exists or re-generate db
301
  :param keep_sources_in_context: Whether to keep url sources in context, not helpful usually
302
  :param db_type: 'faiss' for in-memory or 'chroma' or 'weaviate' for persisted on disk
 
386
  # allow enabling langchain via ENV
387
  # FIRST PLACE where LangChain referenced, but no imports related to it
388
  langchain_mode = os.environ.get("LANGCHAIN_MODE", langchain_mode)
389
+ if langchain_mode is not None:
390
+ assert langchain_mode in langchain_modes, "Invalid langchain_mode %s" % langchain_mode
391
  visible_langchain_modes = ast.literal_eval(os.environ.get("visible_langchain_modes", str(visible_langchain_modes)))
392
  if langchain_mode not in visible_langchain_modes and langchain_mode in langchain_modes:
393
+ if langchain_mode is not None:
394
+ visible_langchain_modes += [langchain_mode]
395
 
396
  assert langchain_action in langchain_actions, "Invalid langchain_action %s" % langchain_action
397
 
 
401
  if LangChainMode.USER_DATA.value not in visible_langchain_modes:
402
  allow_upload_to_user_data = False
403
 
404
+ # auto-set langchain_mode
405
+ if have_langchain and langchain_mode is None:
406
+ if allow_upload_to_user_data and not is_public and user_path:
407
+ langchain_mode = 'UserData'
408
+ print("Auto set langchain_mode=%s" % langchain_mode, flush=True)
409
+ elif allow_upload_to_my_data:
410
+ langchain_mode = 'MyData'
411
+ print("Auto set langchain_mode=%s."
412
+ " To use UserData to pull files from disk,"
413
+ " set user_path and ensure allow_upload_to_user_data=True" % langchain_mode, flush=True)
414
+ else:
415
+ raise RuntimeError("Please pass --langchain_mode=<chosen mode> out of %s" % langchain_modes)
416
+ if not have_langchain and langchain_mode not in [None, LangChainMode.DISABLED.value, LangChainMode.LLM.value, LangChainMode.CHAT_LLM.value]:
417
+ raise RuntimeError("Asked for LangChain mode but langchain python package cannot be found.")
418
+ if langchain_mode is None:
419
+ # if not set yet, disable
420
+ langchain_mode = LangChainMode.DISABLED.value
421
+ print("Auto set langchain_mode=%s" % langchain_mode, flush=True)
422
+
423
  if is_public:
424
  allow_upload_to_user_data = False
425
  input_lines = 1 # ensure set, for ease of use
 
486
  load_8bit = False
487
  load_4bit = False
488
  load_half = False
489
+ load_gptq = ''
490
+ use_safetensors = False
491
+ use_gpu_id = False
492
  torch.backends.cudnn.benchmark = True
493
  torch.backends.cudnn.enabled = False
494
  torch.set_default_dtype(torch.float32)
 
744
  return config, model
745
 
746
 
747
+ def get_non_lora_model(base_model, model_loader, load_half,
748
+ load_gptq, use_safetensors,
749
+ model_kwargs, reward_type,
750
  config, model,
751
  gpu_id=0,
752
  ):
 
793
  load_in_8bit = model_kwargs.get('load_in_8bit', False)
794
  load_in_4bit = model_kwargs.get('load_in_4bit', False)
795
  model_kwargs['device_map'] = device_map
796
+ model_kwargs['use_safetensors'] = use_safetensors
797
  pop_unused_model_kwargs(model_kwargs)
798
 
799
+ if load_gptq:
800
+ model_kwargs.pop('torch_dtype', None)
801
+ model_kwargs.pop('device_map')
802
+ model = model_loader(
803
+ model_name_or_path=base_model,
804
+ model_basename=load_gptq,
805
+ **model_kwargs,
806
+ )
807
+ elif load_in_8bit or load_in_4bit or not load_half:
808
+ model = model_loader(
809
  base_model,
810
  config=config,
811
  **model_kwargs,
812
  )
813
  else:
814
+ model = model_loader(
815
  base_model,
816
  config=config,
817
  **model_kwargs,
 
819
  return model
820
 
821
 
822
+ def get_client_from_inference_server(inference_server, base_model=None, raise_connection_exception=False):
823
  inference_server, headers = get_hf_server(inference_server)
824
  # preload client since slow for gradio case especially
825
  from gradio_utils.grclient import GradioClient
 
827
  hf_client = None
828
  if headers is None:
829
  try:
830
+ print("GR Client Begin: %s %s" % (inference_server, base_model), flush=True)
831
  # first do sanity check if alive, else gradio client takes too long by default
832
  requests.get(inference_server, timeout=int(os.getenv('REQUEST_TIMEOUT', '30')))
833
  gr_client = GradioClient(inference_server)
 
835
  except (OSError, ValueError) as e:
836
  # Occurs when wrong endpoint and should have been HF client, so don't hard raise, just move to HF
837
  gr_client = None
838
+ print("GR Client Failed %s %s: %s" % (inference_server, base_model, str(e)), flush=True)
839
  except (ConnectTimeoutError, ConnectTimeout, MaxRetryError, ConnectionError, ConnectionError2,
840
  JSONDecodeError, ReadTimeout2, KeyError) as e:
841
  t, v, tb = sys.exc_info()
842
  ex = ''.join(traceback.format_exception(t, v, tb))
843
+ print("GR Client Failed %s %s: %s" % (inference_server, base_model, str(ex)), flush=True)
844
  if raise_connection_exception:
845
  raise
846
 
847
  if gr_client is None:
848
  res = None
849
  from text_generation import Client as HFClient
850
+ print("HF Client Begin: %s %s" % (inference_server, base_model))
851
  try:
852
  hf_client = HFClient(inference_server, headers=headers, timeout=int(os.getenv('REQUEST_TIMEOUT', '30')))
853
  # quick check valid TGI endpoint
 
858
  hf_client = None
859
  t, v, tb = sys.exc_info()
860
  ex = ''.join(traceback.format_exception(t, v, tb))
861
+ print("HF Client Failed %s %s: %s" % (inference_server, base_model, str(ex)))
862
  if raise_connection_exception:
863
  raise
864
+ print("HF Client End: %s %s : %s" % (inference_server, base_model, res))
865
  return inference_server, gr_client, hf_client
866
 
867
 
 
869
  load_8bit: bool = False,
870
  load_4bit: bool = False,
871
  load_half: bool = True,
872
+ load_gptq: str = '',
873
+ use_safetensors: bool = False,
874
+ use_gpu_id: bool = True,
875
  base_model: str = '',
876
  inference_server: str = "",
877
  tokenizer_base_model: str = '',
 
893
  :param load_8bit: load model in 8-bit, not supported by all models
894
  :param load_4bit: load model in 4-bit, not supported by all models
895
  :param load_half: load model in 16-bit
896
+ :param load_gptq: GPTQ model_basename
897
+ :param use_safetensors: use safetensors file
898
+ :param use_gpu_id: Use torch infer of optimal placement of layers on devices (for non-lora case)
899
  For non-LORA case, False will spread shards across multiple GPUs, but this can lead to cuda:x cuda:y mismatches
900
  So it is not the default
901
  :param base_model: name/path of base model
 
913
  :param verbose:
914
  :return:
915
  """
916
+ print("Starting get_model: %s %s" % (base_model, inference_server), flush=True)
 
917
 
918
  triton_attn = False
919
  long_sequence = True
 
937
  print("Detected as llama type from"
938
  " config (%s) or name (%s)" % (llama_type_from_config, llama_type_from_name), flush=True)
939
 
940
+ model_loader, tokenizer_loader = get_loaders(model_name=base_model, reward_type=reward_type, llama_type=llama_type,
941
+ load_gptq=load_gptq)
942
 
943
  tokenizer_kwargs = dict(local_files_only=local_files_only,
944
  resume_download=resume_download,
 
962
  tokenizer = FakeTokenizer()
963
 
964
  if isinstance(inference_server, str) and inference_server.startswith("http"):
965
+ inference_server, gr_client, hf_client = get_client_from_inference_server(inference_server,
966
+ base_model=base_model)
967
  client = gr_client or hf_client
968
  # Don't return None, None for model, tokenizer so triggers
969
  return client, tokenizer, 'http'
 
983
  return get_hf_model(load_8bit=load_8bit,
984
  load_4bit=load_4bit,
985
  load_half=load_half,
986
+ load_gptq=load_gptq,
987
+ use_safetensors=use_safetensors,
988
+ use_gpu_id=use_gpu_id,
989
  base_model=base_model,
990
  tokenizer_base_model=tokenizer_base_model,
991
  lora_weights=lora_weights,
 
1009
  def get_hf_model(load_8bit: bool = False,
1010
  load_4bit: bool = False,
1011
  load_half: bool = True,
1012
+ load_gptq: str = '',
1013
+ use_safetensors: bool = False,
1014
+ use_gpu_id: bool = True,
1015
  base_model: str = '',
1016
  tokenizer_base_model: str = '',
1017
  lora_weights: str = "",
 
1048
  "Please choose a base model with --base_model (CLI) or load one from Models Tab (gradio)"
1049
  )
1050
 
1051
+ model_loader, tokenizer_loader = get_loaders(model_name=base_model, reward_type=reward_type, llama_type=llama_type,
1052
+ load_gptq=load_gptq)
1053
 
1054
  config, _ = get_config(base_model, return_model=False, raise_exception=True, **config_kwargs)
1055
 
 
1066
  device=0 if device == "cuda" else -1,
1067
  torch_dtype=torch.float16 if device == 'cuda' else torch.float32)
1068
  else:
1069
+ assert device in ["cuda", "cpu", "mps"], "Unsupported device %s" % device
1070
  model_kwargs = dict(local_files_only=local_files_only,
1071
  torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
1072
  resume_download=resume_download,
 
1075
  offload_folder=offload_folder,
1076
  )
1077
  if 'mbart-' not in base_model.lower() and 'mpt-' not in base_model.lower():
1078
+ if use_gpu_id and gpu_id is not None and gpu_id >= 0 and device == 'cuda':
1079
+ device_map = {"": gpu_id}
1080
+ else:
1081
+ device_map = "auto"
1082
  model_kwargs.update(dict(load_in_8bit=load_8bit,
1083
  load_in_4bit=load_4bit,
1084
+ device_map=device_map,
1085
  ))
1086
  if 'mpt-' in base_model.lower() and gpu_id is not None and gpu_id >= 0:
1087
+ # MPT doesn't support spreading over GPUs
1088
  model_kwargs.update(dict(device_map={"": gpu_id} if device == 'cuda' else "cpu"))
1089
 
1090
  if 'OpenAssistant/reward-model'.lower() in base_model.lower():
 
1094
  pop_unused_model_kwargs(model_kwargs)
1095
 
1096
  if not lora_weights:
1097
+ # torch.device context uses twice memory for AutoGPTQ
1098
+ context = NullContext if load_gptq else torch.device
1099
+ with context(device):
1100
 
1101
+ if use_gpu_id:
1102
  config, model = get_config(base_model, return_model=True, raise_exception=True, **config_kwargs)
1103
+ model = get_non_lora_model(base_model, model_loader, load_half, load_gptq, use_safetensors,
1104
+ model_kwargs, reward_type,
1105
  config, model,
1106
  gpu_id=gpu_id,
1107
  )
1108
  else:
1109
  config, _ = get_config(base_model, **config_kwargs)
1110
+ if load_half and not (load_8bit or load_4bit or load_gptq):
1111
+ model = model_loader(
1112
  base_model,
1113
  config=config,
1114
  **model_kwargs).half()
1115
  else:
1116
+ model = model_loader(
1117
  base_model,
1118
  config=config,
1119
  **model_kwargs)
1120
  elif load_8bit or load_4bit:
1121
  config, _ = get_config(base_model, **config_kwargs)
1122
+ model = model_loader(
1123
  base_model,
1124
  config=config,
1125
  **model_kwargs
 
1139
  else:
1140
  with torch.device(device):
1141
  config, _ = get_config(base_model, raise_exception=True, **config_kwargs)
1142
+ model = model_loader(
1143
  base_model,
1144
  config=config,
1145
  **model_kwargs
 
1156
  offload_folder=offload_folder,
1157
  device_map="auto",
1158
  )
1159
+ if load_half and not load_gptq:
1160
  model.half()
1161
 
1162
  # unwind broken decapoda-research config
 
1215
  load_8bit: bool = False,
1216
  load_4bit: bool = False,
1217
  load_half: bool = True,
1218
+ load_gptq: str = '',
1219
+ use_gpu_id: bool = True,
1220
  base_model: str = '',
1221
  inference_server: str = '',
1222
  tokenizer_base_model: str = '',
 
1237
  load_8bit = False
1238
  load_4bit = False
1239
  load_half = False
1240
+ load_gptq = ''
1241
+ use_safetensors = False
1242
  base_model = score_model.strip()
1243
  tokenizer_base_model = ''
1244
  lora_weights = ''
 
1281
  top_k_docs,
1282
  chunk,
1283
  chunk_size,
1284
+ document_subset,
1285
  document_choice,
1286
  # END NOTE: Examples must have same order of parameters
1287
  src_lang=None,
 
1498
  chunk_size=chunk_size,
1499
  langchain_mode=langchain_mode,
1500
  langchain_action=langchain_action,
1501
+ document_subset=document_subset,
1502
  document_choice=document_choice,
1503
  db_type=db_type,
1504
  top_k_docs=top_k_docs,
 
1526
  inference_server=inference_server,
1527
  langchain_mode=langchain_mode,
1528
  langchain_action=langchain_action,
1529
+ document_subset=document_subset,
1530
  document_choice=document_choice,
1531
  num_prompt_tokens=num_prompt_tokens,
1532
  instruction=instruction,
 
1628
  gr_client = None
1629
  hf_client = model
1630
  else:
1631
+ inference_server, gr_client, hf_client = get_client_from_inference_server(inference_server,
1632
+ base_model=base_model)
1633
 
1634
  # quick sanity check to avoid long timeouts, just see if can reach server
1635
  requests.get(inference_server, timeout=int(os.getenv('REQUEST_TIMEOUT_FAST', '10')))
 
1697
  top_k_docs=top_k_docs,
1698
  chunk=chunk,
1699
  chunk_size=chunk_size,
1700
+ document_subset=DocumentChoices.Relevant.name,
1701
+ document_choice=[],
1702
  )
1703
  api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
1704
  if not stream_output:
 
1897
 
1898
  with torch.no_grad():
1899
  have_lora_weights = lora_weights not in [no_lora_str, '', None]
1900
+ context_class_cast = NullContext if device == 'cpu' or have_lora_weights or device == 'mps' else torch.autocast
1901
  with context_class_cast(device):
1902
  # protection for gradio not keeping track of closed users,
1903
  # else hit bitsandbytes lack of thread safety:
 
2274
 
2275
  # move to correct position
2276
  for example in examples:
2277
+ example += [chat, '', '', LangChainMode.DISABLED.value, LangChainAction.QUERY.value,
2278
+ top_k_docs, chunk, chunk_size, [DocumentChoices.Relevant.name], []
2279
  ]
2280
  # adjust examples if non-chat mode
2281
  if not chat:
 
2498
 
2499
  python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='human_bot' --lora_weights='GPT-NeoXT-Chat-Base-20B.merged.json.8_epochs.57b2892c53df5b8cefac45f84d019cace803ef26.28'
2500
 
2501
+ must have 4*48GB GPU and run without 8bit in order for sharding to work with use_gpu_id=False
2502
  can also pass --prompt_type='human_bot' and model can somewhat handle instructions without being instruct tuned
2503
+ python generate.py --base_model=decapoda-research/llama-65b-hf --load_8bit=False --use_gpu_id=False --prompt_type='human_bot'
2504
 
2505
  python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b
2506
  """
gpt_langchain.py CHANGED
@@ -29,7 +29,8 @@ from evaluate_params import gen_hyper
29
  from gen import get_model, SEED
30
  from prompter import non_hf_types, PromptType, Prompter
31
  from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
32
- get_device, ProgressParallel, remove, hash_file, clear_torch_cache, NullContext, get_hf_server, FakeTokenizer
 
33
  from utils_langchain import StreamingGradioCallbackHandler
34
 
35
  import_matplotlib()
@@ -387,7 +388,8 @@ class GradioInference(LLM):
387
  top_k_docs=top_k_docs,
388
  chunk=chunk,
389
  chunk_size=chunk_size,
390
- document_choice=[DocumentChoices.All_Relevant.name],
 
391
  )
392
  api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
393
  if not stream_output:
@@ -913,40 +915,6 @@ def get_dai_docs(from_hf=False, get_pickle=True):
913
  return sources
914
 
915
 
916
- import distutils.spawn
917
-
918
- have_tesseract = distutils.spawn.find_executable("tesseract")
919
- have_libreoffice = distutils.spawn.find_executable("libreoffice")
920
-
921
- import pkg_resources
922
-
923
- try:
924
- assert pkg_resources.get_distribution('arxiv') is not None
925
- assert pkg_resources.get_distribution('pymupdf') is not None
926
- have_arxiv = True
927
- except (pkg_resources.DistributionNotFound, AssertionError):
928
- have_arxiv = False
929
-
930
- try:
931
- assert pkg_resources.get_distribution('pymupdf') is not None
932
- have_pymupdf = True
933
- except (pkg_resources.DistributionNotFound, AssertionError):
934
- have_pymupdf = False
935
-
936
- try:
937
- assert pkg_resources.get_distribution('selenium') is not None
938
- have_selenium = True
939
- except (pkg_resources.DistributionNotFound, AssertionError):
940
- have_selenium = False
941
-
942
- try:
943
- assert pkg_resources.get_distribution('playwright') is not None
944
- have_playwright = True
945
- except (pkg_resources.DistributionNotFound, AssertionError):
946
- have_playwright = False
947
-
948
- # disable, hangs too often
949
- have_playwright = False
950
 
951
  image_types = ["png", "jpg", "jpeg"]
952
  non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
@@ -973,7 +941,7 @@ def add_meta(docs1, file):
973
 
974
 
975
  def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
976
- chunk=True, chunk_size=512,
977
  is_url=False, is_txt=False,
978
  enable_captions=True,
979
  captions_model=None,
@@ -1208,6 +1176,7 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
1208
 
1209
  def path_to_doc1(file, verbose=False, fail_any_exception=False, return_file=True,
1210
  chunk=True, chunk_size=512,
 
1211
  is_url=False, is_txt=False,
1212
  enable_captions=True,
1213
  captions_model=None,
@@ -1224,6 +1193,7 @@ def path_to_doc1(file, verbose=False, fail_any_exception=False, return_file=True
1224
  # don't pass base_path=path, would infinitely recurse
1225
  res = file_to_doc(file, base_path=None, verbose=verbose, fail_any_exception=fail_any_exception,
1226
  chunk=chunk, chunk_size=chunk_size,
 
1227
  is_url=is_url, is_txt=is_txt,
1228
  enable_captions=enable_captions,
1229
  captions_model=captions_model,
@@ -1236,7 +1206,8 @@ def path_to_doc1(file, verbose=False, fail_any_exception=False, return_file=True
1236
  else:
1237
  exception_doc = Document(
1238
  page_content='',
1239
- metadata={"source": file, "exception": str(e), "traceback": traceback.format_exc()})
 
1240
  res = [exception_doc]
1241
  if return_file:
1242
  base_tmp = "temp_path_to_doc1"
@@ -1326,6 +1297,7 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
1326
  kwargs = dict(verbose=verbose, fail_any_exception=fail_any_exception,
1327
  return_file=return_file,
1328
  chunk=chunk, chunk_size=chunk_size,
 
1329
  is_url=is_url,
1330
  is_txt=is_txt,
1331
  enable_captions=enable_captions,
@@ -1802,7 +1774,8 @@ def _run_qa_db(query=None,
1802
  num_return_sequences=1,
1803
  langchain_mode=None,
1804
  langchain_action=None,
1805
- document_choice=[DocumentChoices.All_Relevant.name],
 
1806
  n_jobs=-1,
1807
  verbose=False,
1808
  cli=False,
@@ -1873,19 +1846,13 @@ def _run_qa_db(query=None,
1873
  if isinstance(document_choice, str):
1874
  # support string as well
1875
  document_choice = [document_choice]
1876
- # get first DocumentChoices as command to use, ignore others
1877
- doc_choices_set = set([x.name for x in list(DocumentChoices)])
1878
- cmd = [x for x in document_choice if x in doc_choices_set]
1879
- cmd = None if len(cmd) == 0 else cmd[0]
1880
- # now have cmd, filter out for only docs
1881
- document_choice = [x for x in document_choice if x not in doc_choices_set]
1882
-
1883
- func_names = list(inspect.signature(get_similarity_chain).parameters)
1884
  sim_kwargs = {k: v for k, v in locals().items() if k in func_names}
1885
  missing_kwargs = [x for x in func_names if x not in sim_kwargs]
1886
  assert not missing_kwargs, "Missing: %s" % missing_kwargs
1887
- docs, chain, scores, use_context, have_any_docs = get_similarity_chain(**sim_kwargs)
1888
- if cmd in non_query_commands:
1889
  formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
1890
  yield formatted_doc_chunks, ''
1891
  return
@@ -1963,36 +1930,36 @@ def _run_qa_db(query=None,
1963
  return
1964
 
1965
 
1966
- def get_similarity_chain(query=None,
1967
- iinput=None,
1968
- use_openai_model=False, use_openai_embedding=False,
1969
- first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
1970
- user_path=None,
1971
- detect_user_path_changes_every_query=False,
1972
- db_type='faiss',
1973
- model_name=None,
1974
- inference_server='',
1975
- hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
1976
- prompt_type=None,
1977
- prompt_dict=None,
1978
- cut_distanct=1.1,
1979
- load_db_if_exists=False,
1980
- db=None,
1981
- langchain_mode=None,
1982
- langchain_action=None,
1983
- document_choice=[DocumentChoices.All_Relevant.name],
1984
- n_jobs=-1,
1985
- # beyond run_db_query:
1986
- llm=None,
1987
- tokenizer=None,
1988
- verbose=False,
1989
- cmd=None,
1990
- reverse_docs=True,
1991
-
1992
- # local
1993
- auto_reduce_chunks=True,
1994
- max_chunks=100,
1995
- ):
1996
  # determine whether use of context out of docs is planned
1997
  if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
1998
  if langchain_mode in ['Disabled', 'ChatLLM', 'LLM']:
@@ -2086,12 +2053,25 @@ def get_similarity_chain(query=None,
2086
  use_template = False
2087
 
2088
  if db and use_context:
 
 
 
 
 
 
 
 
2089
  if not isinstance(db, Chroma):
2090
  # only chroma supports filtering
2091
  filter_kwargs = {}
2092
  else:
2093
- # if here then some cmd + documents selected or just documents selected
2094
- if len(document_choice) >= 2:
 
 
 
 
 
2095
  or_filter = [{"source": {"$eq": x}} for x in document_choice]
2096
  filter_kwargs = dict(filter={"$or": or_filter})
2097
  elif len(document_choice) == 1:
@@ -2101,10 +2081,10 @@ def get_similarity_chain(query=None,
2101
  else:
2102
  # shouldn't reach
2103
  filter_kwargs = {}
2104
- if cmd == DocumentChoices.Just_LLM.name:
2105
  docs = []
2106
  scores = []
2107
- elif cmd == DocumentChoices.Only_All_Sources.name or query in [None, '', '\n']:
2108
  db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs)
2109
  # similar to langchain's chroma's _results_to_docs_and_scores
2110
  docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0)
@@ -2127,13 +2107,7 @@ def get_similarity_chain(query=None,
2127
  if top_k_docs == -1 or auto_reduce_chunks:
2128
  # docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:top_k_docs]
2129
  top_k_docs_tokenize = 100
2130
- base_path = 'locks'
2131
- makedirs(base_path)
2132
- if hasattr(db, '_persist_directory'):
2133
- name_path = "sim_%s.lock" % os.path.basename(db._persist_directory)
2134
- else:
2135
- name_path = "sim.lock"
2136
- with filelock.FileLock(os.path.join(base_path, name_path)):
2137
  docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[
2138
  :top_k_docs_tokenize]
2139
  if hasattr(llm, 'pipeline') and hasattr(llm.pipeline, 'tokenizer'):
@@ -2189,7 +2163,8 @@ def get_similarity_chain(query=None,
2189
  top_k_docs = 1
2190
  docs_with_score = docs_with_score[:top_k_docs]
2191
  else:
2192
- docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:top_k_docs]
 
2193
  # put most relevant chunks closest to question,
2194
  # esp. if truncation occurs will be "oldest" or "farthest from response" text that is truncated
2195
  # BUT: for small models, e.g. 6_9 pythia, if sees some stuff related to h2oGPT first, it can connect that and not listen to rest
@@ -2210,7 +2185,7 @@ def get_similarity_chain(query=None,
2210
  # if HF type and have no docs, can bail out
2211
  return docs, None, [], False, have_any_docs
2212
 
2213
- if cmd in non_query_commands:
2214
  # no LLM use
2215
  return docs, None, [], False, have_any_docs
2216
 
 
29
  from gen import get_model, SEED
30
  from prompter import non_hf_types, PromptType, Prompter
31
  from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
32
+ get_device, ProgressParallel, remove, hash_file, clear_torch_cache, NullContext, get_hf_server, FakeTokenizer, \
33
+ have_libreoffice, have_arxiv, have_playwright, have_selenium, have_tesseract, have_pymupdf
34
  from utils_langchain import StreamingGradioCallbackHandler
35
 
36
  import_matplotlib()
 
388
  top_k_docs=top_k_docs,
389
  chunk=chunk,
390
  chunk_size=chunk_size,
391
+ document_subset=DocumentChoices.Relevant.name,
392
+ document_choice=[],
393
  )
394
  api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
395
  if not stream_output:
 
915
  return sources
916
 
917
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
918
 
919
  image_types = ["png", "jpg", "jpeg"]
920
  non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
 
941
 
942
 
943
  def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
944
+ chunk=True, chunk_size=512, n_jobs=-1,
945
  is_url=False, is_txt=False,
946
  enable_captions=True,
947
  captions_model=None,
 
1176
 
1177
  def path_to_doc1(file, verbose=False, fail_any_exception=False, return_file=True,
1178
  chunk=True, chunk_size=512,
1179
+ n_jobs=-1,
1180
  is_url=False, is_txt=False,
1181
  enable_captions=True,
1182
  captions_model=None,
 
1193
  # don't pass base_path=path, would infinitely recurse
1194
  res = file_to_doc(file, base_path=None, verbose=verbose, fail_any_exception=fail_any_exception,
1195
  chunk=chunk, chunk_size=chunk_size,
1196
+ n_jobs=n_jobs,
1197
  is_url=is_url, is_txt=is_txt,
1198
  enable_captions=enable_captions,
1199
  captions_model=captions_model,
 
1206
  else:
1207
  exception_doc = Document(
1208
  page_content='',
1209
+ metadata={"source": file, "exception": '%s hit %s' % (file, str(e)),
1210
+ "traceback": traceback.format_exc()})
1211
  res = [exception_doc]
1212
  if return_file:
1213
  base_tmp = "temp_path_to_doc1"
 
1297
  kwargs = dict(verbose=verbose, fail_any_exception=fail_any_exception,
1298
  return_file=return_file,
1299
  chunk=chunk, chunk_size=chunk_size,
1300
+ n_jobs=n_jobs,
1301
  is_url=is_url,
1302
  is_txt=is_txt,
1303
  enable_captions=enable_captions,
 
1774
  num_return_sequences=1,
1775
  langchain_mode=None,
1776
  langchain_action=None,
1777
+ document_subset=DocumentChoices.Relevant.name,
1778
+ document_choice=[],
1779
  n_jobs=-1,
1780
  verbose=False,
1781
  cli=False,
 
1846
  if isinstance(document_choice, str):
1847
  # support string as well
1848
  document_choice = [document_choice]
1849
+
1850
+ func_names = list(inspect.signature(get_chain).parameters)
 
 
 
 
 
 
1851
  sim_kwargs = {k: v for k, v in locals().items() if k in func_names}
1852
  missing_kwargs = [x for x in func_names if x not in sim_kwargs]
1853
  assert not missing_kwargs, "Missing: %s" % missing_kwargs
1854
+ docs, chain, scores, use_context, have_any_docs = get_chain(**sim_kwargs)
1855
+ if document_subset in non_query_commands:
1856
  formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
1857
  yield formatted_doc_chunks, ''
1858
  return
 
1930
  return
1931
 
1932
 
1933
+ def get_chain(query=None,
1934
+ iinput=None,
1935
+ use_openai_model=False, use_openai_embedding=False,
1936
+ first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
1937
+ user_path=None,
1938
+ detect_user_path_changes_every_query=False,
1939
+ db_type='faiss',
1940
+ model_name=None,
1941
+ inference_server='',
1942
+ hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
1943
+ prompt_type=None,
1944
+ prompt_dict=None,
1945
+ cut_distanct=1.1,
1946
+ load_db_if_exists=False,
1947
+ db=None,
1948
+ langchain_mode=None,
1949
+ langchain_action=None,
1950
+ document_subset=DocumentChoices.Relevant.name,
1951
+ document_choice=[],
1952
+ n_jobs=-1,
1953
+ # beyond run_db_query:
1954
+ llm=None,
1955
+ tokenizer=None,
1956
+ verbose=False,
1957
+ reverse_docs=True,
1958
+
1959
+ # local
1960
+ auto_reduce_chunks=True,
1961
+ max_chunks=100,
1962
+ ):
1963
  # determine whether use of context out of docs is planned
1964
  if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
1965
  if langchain_mode in ['Disabled', 'ChatLLM', 'LLM']:
 
2053
  use_template = False
2054
 
2055
  if db and use_context:
2056
+ base_path = 'locks'
2057
+ makedirs(base_path)
2058
+ if hasattr(db, '_persist_directory'):
2059
+ name_path = "sim_%s.lock" % os.path.basename(db._persist_directory)
2060
+ else:
2061
+ name_path = "sim.lock"
2062
+ lock_file = os.path.join(base_path, name_path)
2063
+
2064
  if not isinstance(db, Chroma):
2065
  # only chroma supports filtering
2066
  filter_kwargs = {}
2067
  else:
2068
+ assert document_choice is not None, "Document choice was None"
2069
+ if len(document_choice) >= 1 and document_choice[0] == DocumentChoices.All.name:
2070
+ filter_kwargs = {}
2071
+ elif len(document_choice) >= 2:
2072
+ if document_choice[0] == DocumentChoices.All.name:
2073
+ # remove 'All'
2074
+ document_choice = document_choice[1:]
2075
  or_filter = [{"source": {"$eq": x}} for x in document_choice]
2076
  filter_kwargs = dict(filter={"$or": or_filter})
2077
  elif len(document_choice) == 1:
 
2081
  else:
2082
  # shouldn't reach
2083
  filter_kwargs = {}
2084
+ if langchain_mode in [LangChainMode.LLM.value, LangChainMode.CHAT_LLM.value]:
2085
  docs = []
2086
  scores = []
2087
+ elif document_subset == DocumentChoices.All.name or query in [None, '', '\n']:
2088
  db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs)
2089
  # similar to langchain's chroma's _results_to_docs_and_scores
2090
  docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0)
 
2107
  if top_k_docs == -1 or auto_reduce_chunks:
2108
  # docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:top_k_docs]
2109
  top_k_docs_tokenize = 100
2110
+ with filelock.FileLock(lock_file):
 
 
 
 
 
 
2111
  docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[
2112
  :top_k_docs_tokenize]
2113
  if hasattr(llm, 'pipeline') and hasattr(llm.pipeline, 'tokenizer'):
 
2163
  top_k_docs = 1
2164
  docs_with_score = docs_with_score[:top_k_docs]
2165
  else:
2166
+ with filelock.FileLock(lock_file):
2167
+ docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:top_k_docs]
2168
  # put most relevant chunks closest to question,
2169
  # esp. if truncation occurs will be "oldest" or "farthest from response" text that is truncated
2170
  # BUT: for small models, e.g. 6_9 pythia, if sees some stuff related to h2oGPT first, it can connect that and not listen to rest
 
2185
  # if HF type and have no docs, can bail out
2186
  return docs, None, [], False, have_any_docs
2187
 
2188
+ if document_subset in non_query_commands:
2189
  # no LLM use
2190
  return docs, None, [], False, have_any_docs
2191
 
gradio_runner.py CHANGED
@@ -20,7 +20,7 @@ import tabulate
20
  from iterators import TimeoutIterator
21
 
22
  from gradio_utils.css import get_css
23
- from gradio_utils.prompt_form import make_prompt_form, make_chatbots
24
 
25
  # This is a hack to prevent Gradio from phoning home when it gets imported
26
  os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
@@ -56,7 +56,7 @@ from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title,
56
  from prompter import prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower, non_hf_types, \
57
  get_prompt
58
  from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
59
- ping, get_short_name, get_url, makedirs, get_kwargs, remove, system_info, ping_gpu
60
  from gen import get_model, languages_covered, evaluate, score_qa, langchain_modes, inputs_kwargs_list, scratch_base_dir, \
61
  get_max_max_new_tokens, get_minmax_top_k_docs, history_to_context, langchain_actions
62
  from evaluate_params import eval_func_param_names, no_default_param_names, eval_func_param_names_defaults
@@ -118,6 +118,13 @@ def go_gradio(**kwargs):
118
  allow_upload = allow_upload_to_user_data or allow_upload_to_my_data
119
  kwargs.update(locals())
120
 
 
 
 
 
 
 
 
121
  if 'mbart-' in kwargs['model_lower']:
122
  instruction_label_nochat = "Text to translate"
123
  else:
@@ -134,8 +141,7 @@ def go_gradio(**kwargs):
134
  """
135
  else:
136
  description = more_info
137
- description_bottom = "If this host is busy, try [LLaMa 65B](https://llama.h2o.ai), [Falcon 40B](https://gpt.h2o.ai), [Falcon 40B](http://falcon.h2o.ai), [HF Spaces1 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot) or [HF Spaces2 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)<br>"
138
- description_bottom += """<p>By using h2oGPT, you accept our [Terms of Service](https://github.com/h2oai/h2ogpt/blob/main/docs/tos.md)</p>"""
139
  if is_hf:
140
  description_bottom += '''<a href="https://huggingface.co/spaces/h2oai/h2ogpt-chatbot?duplicate=true"><img src="https://bit.ly/3gLdBN6" style="white-space: nowrap" alt="Duplicate Space"></a>'''
141
 
@@ -160,7 +166,7 @@ def go_gradio(**kwargs):
160
  theme_kwargs = dict()
161
  if kwargs['gradio_size'] == 'xsmall':
162
  theme_kwargs.update(dict(spacing_size=spacing_xsm, text_size=text_xsm, radius_size=radius_xsm))
163
- elif kwargs['gradio_size'] == 'small':
164
  theme_kwargs.update(dict(spacing_size=gr.themes.sizes.spacing_sm, text_size=gr.themes.sizes.text_sm,
165
  radius_size=gr.themes.sizes.spacing_sm))
166
  elif kwargs['gradio_size'] == 'large':
@@ -262,14 +268,14 @@ def go_gradio(**kwargs):
262
  model_options_state = gr.State([model_options])
263
  lora_options_state = gr.State([lora_options])
264
  server_options_state = gr.State([server_options])
265
- # uuid in db is used as user ID
266
- my_db_state = gr.State([None, str(uuid.uuid4())])
267
  chat_state = gr.State({})
268
- # make user default first and default choice, dedup
269
- docs_state00 = kwargs['document_choice'] + [x.name for x in list(DocumentChoices)]
270
  docs_state0 = []
271
  [docs_state0.append(x) for x in docs_state00 if x not in docs_state0]
272
- docs_state = gr.State(docs_state0) # first is chosen as default
 
 
273
  gr.Markdown(f"""
274
  {get_h2o_title(title, description) if kwargs['h2ocolors'] else get_simple_title(title, description)}
275
  """)
@@ -282,179 +288,208 @@ def go_gradio(**kwargs):
282
  res_value = "Response Score: NA" if not kwargs[
283
  'model_lock'] else "Response Scores: %s" % nas
284
 
285
- normal_block = gr.Row(visible=not base_wanted)
 
 
 
 
 
 
 
 
 
286
  with normal_block:
287
- with gr.Tabs():
288
- with gr.Row():
289
- col_nochat = gr.Column(visible=not kwargs['chat'])
290
- with col_nochat: # FIXME: for model comparison, and check rest
291
- if kwargs['langchain_mode'] == 'Disabled':
292
- text_output_nochat = gr.Textbox(lines=5, label=output_label0, show_copy_button=True)
293
- else:
294
- # text looks a bit worse, but HTML links work
295
- text_output_nochat = gr.HTML(label=output_label0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
296
  instruction_nochat = gr.Textbox(
297
  lines=kwargs['input_lines'],
298
  label=instruction_label_nochat,
299
  placeholder=kwargs['placeholder_instruction'],
 
300
  )
301
  iinput_nochat = gr.Textbox(lines=4, label="Input context for Instruction",
302
- placeholder=kwargs['placeholder_input'])
303
- submit_nochat = gr.Button("Submit")
304
- flag_btn_nochat = gr.Button("Flag")
305
- with gr.Column(visible=kwargs['score_model']):
306
- score_text_nochat = gr.Textbox("Response Score: NA", show_label=False)
307
-
308
- col_chat = gr.Column(visible=kwargs['chat'])
309
- with col_chat:
310
- instruction, submit, stop_btn = make_prompt_form(kwargs, LangChainMode)
311
- text_output, text_output2, text_outputs = make_chatbots(output_label0, output_label0_model2,
312
- **kwargs)
313
-
314
- with gr.Row():
315
- clear = gr.Button("Save Chat / New Chat")
316
- flag_btn = gr.Button("Flag")
317
- with gr.Column(visible=kwargs['score_model']):
318
- score_text = gr.Textbox(res_value,
319
- show_label=False,
320
- visible=True)
321
- score_text2 = gr.Textbox("Response Score2: NA", show_label=False,
322
- visible=False and not kwargs['model_lock'])
323
- retry_btn = gr.Button("Regenerate")
324
- undo = gr.Button("Undo")
325
- submit_nochat_api = gr.Button("Submit nochat API", visible=False)
326
- inputs_dict_str = gr.Textbox(label='API input for nochat', show_label=False, visible=False)
327
- text_output_nochat_api = gr.Textbox(lines=5, label='API nochat output', visible=False,
328
- show_copy_button=True)
329
- with gr.TabItem("Documents"):
330
- langchain_readme = get_url('https://github.com/h2oai/h2ogpt/blob/main/docs/README_LangChain.md',
331
- from_str=True)
332
- gr.HTML(value=f"""LangChain Support Disabled<p>
333
- Run:<p>
334
- <code>
335
- python generate.py --langchain_mode=MyData
336
- </code>
337
- <p>
338
- For more options see: {langchain_readme}""",
339
- visible=kwargs['langchain_mode'] == 'Disabled', interactive=False)
340
- data_row1 = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled')
341
- with data_row1:
342
- if is_hf:
343
- # don't show 'wiki' since only usually useful for internal testing at moment
344
- no_show_modes = ['Disabled', 'wiki']
345
- else:
346
- no_show_modes = ['Disabled']
347
- allowed_modes = visible_langchain_modes.copy()
348
- allowed_modes = [x for x in allowed_modes if x in dbs]
349
- allowed_modes += ['ChatLLM', 'LLM']
350
- if allow_upload_to_my_data and 'MyData' not in allowed_modes:
351
- allowed_modes += ['MyData']
352
- if allow_upload_to_user_data and 'UserData' not in allowed_modes:
353
- allowed_modes += ['UserData']
354
- langchain_mode = gr.Radio(
355
- [x for x in langchain_modes if x in allowed_modes and x not in no_show_modes],
356
- value=kwargs['langchain_mode'],
357
- label="Data Collection of Sources",
358
- visible=kwargs['langchain_mode'] != 'Disabled')
359
- allowed_actions = [x for x in langchain_actions if x in visible_langchain_actions]
360
- langchain_action = gr.Radio(
361
- allowed_actions,
362
- value=allowed_actions[0] if len(allowed_actions) > 0 else None,
363
- label="Data Action",
364
- visible=True)
365
- data_row2 = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled')
366
- with data_row2:
367
- with gr.Column(scale=50):
368
- document_choice = gr.Dropdown(docs_state.value,
369
- label="Choose Subset of Doc(s) in Collection [click get sources to update]",
370
- value=docs_state.value[0],
371
- interactive=True,
372
- multiselect=True,
373
- )
374
- with gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list):
375
- get_sources_btn = gr.Button(value="Get Sources", scale=0, size='sm')
376
- show_sources_btn = gr.Button(value="Show Sources", scale=0, size='sm')
377
- refresh_sources_btn = gr.Button(value="Refresh Sources", scale=0, size='sm')
378
-
379
- # import control
380
- if kwargs['langchain_mode'] != 'Disabled':
381
- from gpt_langchain import file_types, have_arxiv
382
- else:
383
- have_arxiv = False
384
- file_types = []
385
 
386
- upload_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload,
387
- equal_height=False)
388
- with upload_row:
389
- with gr.Column():
390
- file_types_str = '[' + ' '.join(file_types) + ']'
391
- fileup_output = gr.File(label=f'Upload {file_types_str}',
392
- file_types=file_types,
393
- file_count="multiple",
394
- elem_id="warning", elem_classes="feedback")
395
  with gr.Row():
396
- add_to_shared_db_btn = gr.Button("Add File(s) to UserData",
397
- visible=allow_upload_to_user_data,
398
- elem_id='small_btn')
399
- add_to_my_db_btn = gr.Button("Add File(s) to Scratch MyData",
400
- visible=allow_upload_to_my_data and
401
- allow_upload_to_user_data,
402
- elem_id='small_btn' if allow_upload_to_user_data else None,
403
- size='sm' if not allow_upload_to_user_data else None)
404
- with gr.Column(
405
- visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_url_upload):
406
- url_label = 'URL (http/https) or ArXiv:' if have_arxiv else 'URL (http/https)'
407
- url_text = gr.Textbox(label=url_label,
408
- placeholder="Click Add to Submit" if
409
- allow_upload_to_my_data and
410
- allow_upload_to_user_data else
411
- "Enter to Submit",
412
- max_lines=1,
413
- interactive=True)
414
- with gr.Row():
415
- url_user_btn = gr.Button(value='Add URL content to Shared UserData',
416
- visible=allow_upload_to_user_data and allow_upload_to_my_data,
417
- elem_id='small_btn')
418
- url_my_btn = gr.Button(value='Add URL content to Scratch MyData',
419
- visible=allow_upload_to_my_data and allow_upload_to_user_data,
420
- elem_id='small_btn' if allow_upload_to_user_data else None,
421
- size='sm' if not allow_upload_to_user_data else None)
422
- with gr.Column(
423
- visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_text_upload):
424
- user_text_text = gr.Textbox(label='Paste Text [Shift-Enter more lines]',
425
- placeholder="Click Add to Submit" if
426
- allow_upload_to_my_data and
427
- allow_upload_to_user_data else
428
- "Enter to Submit, Shift-Enter for more lines",
429
- interactive=True)
430
- with gr.Row():
431
- user_text_user_btn = gr.Button(value='Add Text to Shared UserData',
432
- visible=allow_upload_to_user_data and allow_upload_to_my_data,
433
- elem_id='small_btn')
434
- user_text_my_btn = gr.Button(value='Add Text to Scratch MyData',
435
- visible=allow_upload_to_my_data and allow_upload_to_user_data,
436
- elem_id='small_btn' if allow_upload_to_user_data else None,
437
- size='sm' if not allow_upload_to_user_data else None)
438
- with gr.Column(visible=False):
439
- # WIP:
440
- with gr.Row(visible=False, equal_height=False):
441
- github_textbox = gr.Textbox(label="Github URL")
442
- with gr.Row(visible=True):
443
- github_shared_btn = gr.Button(value="Add Github to Shared UserData",
444
- visible=allow_upload_to_user_data,
445
- elem_id='small_btn')
446
- github_my_btn = gr.Button(value="Add Github to Scratch MyData",
447
- visible=allow_upload_to_my_data, elem_id='small_btn')
448
  sources_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list,
449
  equal_height=False)
450
  with sources_row:
451
  with gr.Column(scale=1):
452
  file_source = gr.File(interactive=False,
453
- label="Download File w/Sources [click get sources to make file]")
454
  with gr.Column(scale=2):
455
  sources_text = gr.HTML(label='Sources Added', interactive=False)
456
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
457
  with gr.TabItem("Chat History"):
 
 
 
 
 
 
 
 
 
 
 
 
 
458
  with gr.Row():
459
  if 'mbart-' in kwargs['model_lower']:
460
  src_lang = gr.Dropdown(list(languages_covered().keys()),
@@ -463,20 +498,9 @@ def go_gradio(**kwargs):
463
  tgt_lang = gr.Dropdown(list(languages_covered().keys()),
464
  value=kwargs['tgt_lang'],
465
  label="Output Language")
466
- radio_chats = gr.Radio(value=None, label="Saved Chats", visible=True, interactive=True,
467
- type='value')
468
- with gr.Row():
469
- clear_chat_btn = gr.Button(value="Clear Chat", visible=True, size='sm')
470
- export_chats_btn = gr.Button(value="Export Chats to Download", size='sm')
471
- remove_chat_btn = gr.Button(value="Remove Selected Chat", visible=True, size='sm')
472
- add_to_chats_btn = gr.Button("Import Chats from Upload", size='sm')
473
- with gr.Row():
474
- chats_file = gr.File(interactive=False, label="Download Exported Chats")
475
- chatsup_output = gr.File(label="Upload Chat File(s)",
476
- file_types=['.json'],
477
- file_count='multiple',
478
- elem_id="warning", elem_classes="feedback")
479
 
 
 
480
  with gr.TabItem("Expert"):
481
  with gr.Row():
482
  with gr.Column():
@@ -555,7 +579,7 @@ def go_gradio(**kwargs):
555
  info="Directly pre-appended without prompt processing",
556
  interactive=not is_public)
557
  chat = gr.components.Checkbox(label="Chat mode", value=kwargs['chat'],
558
- visible=not kwargs['model_lock'],
559
  interactive=not is_public,
560
  )
561
  count_chat_tokens_btn = gr.Button(value="Count Chat Tokens",
@@ -614,9 +638,9 @@ def go_gradio(**kwargs):
614
  model_load8bit_checkbox = gr.components.Checkbox(
615
  label="Load 8-bit [requires support]",
616
  value=kwargs['load_8bit'], interactive=not is_public)
617
- model_infer_devices_checkbox = gr.components.Checkbox(
618
  label="Choose Devices [If not Checked, use all GPUs]",
619
- value=kwargs['infer_devices'], interactive=not is_public)
620
  model_gpu = gr.Dropdown(n_gpus_list,
621
  label="GPU ID [-1 = all GPUs, if Choose is enabled]",
622
  value=kwargs['gpu_id'], interactive=not is_public)
@@ -649,10 +673,10 @@ def go_gradio(**kwargs):
649
  model_load8bit_checkbox2 = gr.components.Checkbox(
650
  label="Load 8-bit 2 [requires support]",
651
  value=kwargs['load_8bit'], interactive=not is_public)
652
- model_infer_devices_checkbox2 = gr.components.Checkbox(
653
  label="Choose Devices 2 [If not Checked, use all GPUs]",
654
  value=kwargs[
655
- 'infer_devices'], interactive=not is_public)
656
  model_gpu2 = gr.Dropdown(n_gpus_list,
657
  label="GPU ID 2 [-1 = all GPUs, if choose is enabled]",
658
  value=kwargs['gpu_id'], interactive=not is_public)
@@ -679,35 +703,52 @@ def go_gradio(**kwargs):
679
  add_model_lora_server_button = gr.Button("Add new Model, Lora, Server url:port", scale=0,
680
  size='sm', interactive=not is_public)
681
  with gr.TabItem("System"):
 
 
 
 
 
 
 
 
 
 
 
 
 
682
  admin_row = gr.Row()
683
  with admin_row:
684
- admin_pass_textbox = gr.Textbox(label="Admin Password", type='password', visible=is_public)
685
- admin_btn = gr.Button(value="Admin Access", visible=is_public)
 
 
 
686
  system_row = gr.Row(visible=not is_public)
687
  with system_row:
688
  with gr.Column():
689
  with gr.Row():
690
- system_btn = gr.Button(value='Get System Info')
691
  system_text = gr.Textbox(label='System Info', interactive=False, show_copy_button=True)
692
  with gr.Row():
693
  system_input = gr.Textbox(label='System Info Dict Password', interactive=True,
694
  visible=not is_public)
695
- system_btn2 = gr.Button(value='Get System Info Dict', visible=not is_public)
696
  system_text2 = gr.Textbox(label='System Info Dict', interactive=False,
697
  visible=not is_public, show_copy_button=True)
698
  with gr.Row():
699
- system_btn3 = gr.Button(value='Get Hash', visible=not is_public)
700
  system_text3 = gr.Textbox(label='Hash', interactive=False,
701
  visible=not is_public, show_copy_button=True)
702
 
703
  with gr.Row():
704
- zip_btn = gr.Button("Zip")
705
  zip_text = gr.Textbox(label="Zip file name", interactive=False)
706
  file_output = gr.File(interactive=False, label="Zip file to Download")
707
  with gr.Row():
708
- s3up_btn = gr.Button("S3UP")
709
  s3up_text = gr.Textbox(label='S3UP result', interactive=False)
710
- with gr.TabItem("Disclaimers"):
 
711
  description = ""
712
  description += """<p><b> DISCLAIMERS: </b><ul><i><li>The model was trained on The Pile and other data, which may contain objectionable content. Use at own risk.</i></li>"""
713
  if kwargs['load_8bit']:
@@ -718,17 +759,18 @@ def go_gradio(**kwargs):
718
  description += """<i><li>By using h2oGPT, you accept our <a href="https://github.com/h2oai/h2ogpt/blob/main/docs/tos.md">Terms of Service</a></i></li></ul></p>"""
719
  gr.Markdown(value=description, show_label=False, interactive=False)
720
 
721
- gr.Markdown(f"""
722
- {description_bottom}
723
- {task_info_md}
724
- """)
 
725
 
726
  # Get flagged data
727
  zip_data1 = functools.partial(zip_data, root_dirs=['flagged_data_points', kwargs['save_dir']])
728
- zip_btn.click(zip_data1, inputs=None, outputs=[file_output, zip_text], queue=False,
729
- api_name='zip_data' if allow_api else None)
730
- s3up_btn.click(s3up, inputs=zip_text, outputs=s3up_text, queue=False,
731
- api_name='s3up_data' if allow_api else None)
732
 
733
  def clear_file_list():
734
  return None
@@ -746,182 +788,204 @@ def go_gradio(**kwargs):
746
  return tuple([gr.update(interactive=True)] * len(args))
747
 
748
  # Add to UserData
749
- update_user_db_func = functools.partial(update_user_db,
750
- dbs=dbs, db_type=db_type, langchain_mode='UserData',
751
- use_openai_embedding=use_openai_embedding,
752
- hf_embedding_model=hf_embedding_model,
753
- enable_captions=enable_captions,
754
- captions_model=captions_model,
755
- enable_ocr=enable_ocr,
756
- caption_loader=caption_loader,
757
- verbose=kwargs['verbose'],
758
- user_path=kwargs['user_path'],
759
- n_jobs=kwargs['n_jobs'],
760
- )
761
- add_file_outputs = [fileup_output, langchain_mode, add_to_shared_db_btn, add_to_my_db_btn]
762
- add_file_kwargs = dict(fn=update_user_db_func,
763
- inputs=[fileup_output, my_db_state, add_to_shared_db_btn,
764
- add_to_my_db_btn,
765
- chunk, chunk_size],
766
- outputs=add_file_outputs + [sources_text],
767
  queue=queue,
768
- api_name='add_to_shared' if allow_api and allow_upload_to_user_data else None)
769
 
770
- if allow_upload_to_user_data and not allow_upload_to_my_data:
771
- func1 = fileup_output.change
772
- else:
773
- func1 = add_to_shared_db_btn.click
774
  # then no need for add buttons, only single changeable db
775
- eventdb1a = func1(make_non_interactive, inputs=add_file_outputs, outputs=add_file_outputs,
776
- show_progress='minimal')
777
- eventdb1 = eventdb1a.then(**add_file_kwargs, show_progress='minimal')
778
- eventdb1.then(make_interactive, inputs=add_file_outputs, outputs=add_file_outputs, show_progress='minimal')
 
779
 
780
  # note for update_user_db_func output is ignored for db
781
 
782
  def clear_textbox():
783
  return gr.Textbox.update(value='')
784
 
785
- update_user_db_url_func = functools.partial(update_user_db_func, is_url=True)
786
 
787
- add_url_outputs = [url_text, langchain_mode, url_user_btn, url_my_btn]
788
  add_url_kwargs = dict(fn=update_user_db_url_func,
789
- inputs=[url_text, my_db_state, url_user_btn, url_my_btn,
790
- chunk, chunk_size],
791
- outputs=add_url_outputs + [sources_text],
792
  queue=queue,
793
- api_name='add_url_to_shared' if allow_api and allow_upload_to_user_data else None)
794
 
795
- if allow_upload_to_user_data and not allow_upload_to_my_data:
796
- func2 = url_text.submit
797
- else:
798
- func2 = url_user_btn.click
799
- eventdb2a = func2(fn=dummy_fun, inputs=url_text, outputs=url_text, queue=queue,
800
- show_progress='minimal')
801
  # work around https://github.com/gradio-app/gradio/issues/4733
802
  eventdb2b = eventdb2a.then(make_non_interactive, inputs=add_url_outputs, outputs=add_url_outputs,
803
  show_progress='minimal')
804
- eventdb2 = eventdb2b.then(**add_url_kwargs, show_progress='minimal')
805
- eventdb2.then(make_interactive, inputs=add_url_outputs, outputs=add_url_outputs, show_progress='minimal')
 
806
 
807
- update_user_db_txt_func = functools.partial(update_user_db_func, is_txt=True)
808
- add_text_outputs = [user_text_text, langchain_mode, user_text_user_btn, user_text_my_btn]
809
  add_text_kwargs = dict(fn=update_user_db_txt_func,
810
- inputs=[user_text_text, my_db_state, user_text_user_btn, user_text_my_btn,
811
- chunk, chunk_size],
812
- outputs=add_text_outputs + [sources_text],
813
  queue=queue,
814
- api_name='add_text_to_shared' if allow_api and allow_upload_to_user_data else None
815
  )
816
- if allow_upload_to_user_data and not allow_upload_to_my_data:
817
- func3 = user_text_text.submit
818
- else:
819
- func3 = user_text_user_btn.click
820
-
821
- eventdb3a = func3(fn=dummy_fun, inputs=user_text_text, outputs=user_text_text, queue=queue,
822
- show_progress='minimal')
823
  eventdb3b = eventdb3a.then(make_non_interactive, inputs=add_text_outputs, outputs=add_text_outputs,
824
  show_progress='minimal')
825
- eventdb3 = eventdb3b.then(**add_text_kwargs, show_progress='minimal')
826
- eventdb3.then(make_interactive, inputs=add_text_outputs, outputs=add_text_outputs,
827
- show_progress='minimal')
828
-
829
- update_my_db_func = functools.partial(update_user_db, dbs=dbs, db_type=db_type, langchain_mode='MyData',
830
- use_openai_embedding=use_openai_embedding,
831
- hf_embedding_model=hf_embedding_model,
832
- enable_captions=enable_captions,
833
- captions_model=captions_model,
834
- enable_ocr=enable_ocr,
835
- caption_loader=caption_loader,
836
- verbose=kwargs['verbose'],
837
- user_path=kwargs['user_path'],
838
- n_jobs=kwargs['n_jobs'],
839
- )
840
-
841
- add_my_file_outputs = [fileup_output, langchain_mode, my_db_state, add_to_shared_db_btn, add_to_my_db_btn]
842
- add_my_file_kwargs = dict(fn=update_my_db_func,
843
- inputs=[fileup_output, my_db_state, add_to_shared_db_btn, add_to_my_db_btn,
844
- chunk, chunk_size],
845
- outputs=add_my_file_outputs + [sources_text],
846
- queue=queue,
847
- api_name='add_to_my' if allow_api and allow_upload_to_my_data else None)
848
-
849
- if not allow_upload_to_user_data and allow_upload_to_my_data:
850
- func4 = fileup_output.change
851
- else:
852
- func4 = add_to_my_db_btn.click
853
-
854
- eventdb4a = func4(make_non_interactive, inputs=add_my_file_outputs,
855
- outputs=add_my_file_outputs,
856
- show_progress='minimal')
857
- eventdb4 = eventdb4a.then(**add_my_file_kwargs, show_progress='minimal')
858
- eventdb4.then(make_interactive, inputs=add_my_file_outputs, outputs=add_my_file_outputs,
859
- show_progress='minimal')
860
-
861
- update_my_db_url_func = functools.partial(update_my_db_func, is_url=True)
862
- add_my_url_outputs = [url_text, langchain_mode, my_db_state, url_user_btn, url_my_btn]
863
- add_my_url_kwargs = dict(fn=update_my_db_url_func,
864
- inputs=[url_text, my_db_state, url_user_btn, url_my_btn,
865
- chunk, chunk_size],
866
- outputs=add_my_url_outputs + [sources_text],
867
- queue=queue,
868
- api_name='add_url_to_my' if allow_api and allow_upload_to_my_data else None)
869
- if not allow_upload_to_user_data and allow_upload_to_my_data:
870
- func5 = url_text.submit
871
- else:
872
- func5 = url_my_btn.click
873
- eventdb5a = func5(fn=dummy_fun, inputs=url_text, outputs=url_text, queue=queue,
874
- show_progress='minimal')
875
- eventdb5b = eventdb5a.then(make_non_interactive, inputs=add_my_url_outputs, outputs=add_my_url_outputs,
876
- show_progress='minimal')
877
- eventdb5 = eventdb5b.then(**add_my_url_kwargs, show_progress='minimal')
878
- eventdb5.then(make_interactive, inputs=add_my_url_outputs, outputs=add_my_url_outputs,
879
- show_progress='minimal')
880
-
881
- update_my_db_txt_func = functools.partial(update_my_db_func, is_txt=True)
882
-
883
- add_my_text_outputs = [user_text_text, langchain_mode, my_db_state, user_text_user_btn,
884
- user_text_my_btn]
885
- add_my_text_kwargs = dict(fn=update_my_db_txt_func,
886
- inputs=[user_text_text, my_db_state, user_text_user_btn, user_text_my_btn,
887
- chunk, chunk_size],
888
- outputs=add_my_text_outputs + [sources_text],
889
- queue=queue,
890
- api_name='add_txt_to_my' if allow_api and allow_upload_to_my_data else None)
891
- if not allow_upload_to_user_data and allow_upload_to_my_data:
892
- func6 = user_text_text.submit
893
- else:
894
- func6 = user_text_my_btn.click
895
-
896
- eventdb6a = func6(fn=dummy_fun, inputs=user_text_text, outputs=user_text_text, queue=queue,
897
- show_progress='minimal')
898
- eventdb6b = eventdb6a.then(make_non_interactive, inputs=add_my_text_outputs, outputs=add_my_text_outputs,
899
- show_progress='minimal')
900
- eventdb6 = eventdb6b.then(**add_my_text_kwargs, show_progress='minimal')
901
- eventdb6.then(make_interactive, inputs=add_my_text_outputs, outputs=add_my_text_outputs,
902
- show_progress='minimal')
903
 
904
  get_sources1 = functools.partial(get_sources, dbs=dbs, docs_state0=docs_state0)
905
 
906
  # if change collection source, must clear doc selections from it to avoid inconsistency
907
  def clear_doc_choice():
908
- return gr.Dropdown.update(choices=docs_state0, value=[docs_state0[0]])
 
 
909
 
910
- langchain_mode.change(clear_doc_choice, inputs=None, outputs=document_choice)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
911
 
912
  def update_dropdown(x):
913
  return gr.Dropdown.update(choices=x, value=[docs_state0[0]])
914
 
915
- eventdb7 = get_sources_btn.click(get_sources1, inputs=[my_db_state, langchain_mode],
916
- outputs=[file_source, docs_state],
917
- queue=queue,
918
- api_name='get_sources' if allow_api else None) \
 
 
919
  .then(fn=update_dropdown, inputs=docs_state, outputs=document_choice)
920
  # show button, else only show when add. Could add to above get_sources for download/dropdown, but bit much maybe
921
  show_sources1 = functools.partial(get_source_files_given_langchain_mode, dbs=dbs)
922
  eventdb8 = show_sources_btn.click(fn=show_sources1, inputs=[my_db_state, langchain_mode], outputs=sources_text,
923
  api_name='show_sources' if allow_api else None)
924
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
925
  # Get inputs to evaluate() and make_db()
926
  # don't deepcopy, can contain model itself
927
  all_kwargs = kwargs.copy()
@@ -1008,9 +1072,6 @@ def go_gradio(**kwargs):
1008
  **kwargs_evaluate
1009
  )
1010
 
1011
- dark_mode_btn = gr.Button("Dark Mode", variant="primary", size="sm")
1012
- # FIXME: Could add exceptions for non-chat but still streaming
1013
- exception_text = gr.Textbox(value="", visible=kwargs['chat'], label='Chat Exceptions', interactive=False)
1014
  dark_mode_btn.click(
1015
  None,
1016
  None,
@@ -1020,20 +1081,19 @@ def go_gradio(**kwargs):
1020
  queue=False,
1021
  )
1022
 
1023
- # Control chat and non-chat blocks, which can be independently used by chat checkbox swap
1024
- def col_nochat_fun(x):
1025
- return gr.Column.update(visible=not x)
1026
 
1027
- def col_chat_fun(x):
1028
- return gr.Column.update(visible=bool(x))
 
 
1029
 
1030
- def context_fun(x):
1031
- return gr.Textbox.update(visible=not x)
1032
-
1033
- chat.select(col_nochat_fun, chat, col_nochat, api_name="chat_checkbox" if allow_api else None) \
1034
- .then(col_chat_fun, chat, col_chat) \
1035
- .then(context_fun, chat, context) \
1036
- .then(col_chat_fun, chat, exception_text)
1037
 
1038
  # examples after submit or any other buttons for chat or no chat
1039
  if kwargs['examples'] is not None and kwargs['show_examples']:
@@ -1154,6 +1214,7 @@ def go_gradio(**kwargs):
1154
  prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
1155
  langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
1156
  langchain_action1 = args_list[eval_func_param_names.index('langchain_action')]
 
1157
  document_choice1 = args_list[eval_func_param_names.index('document_choice')]
1158
  if not prompt_type1:
1159
  # shouldn't have to specify if CLI launched model
@@ -1186,7 +1247,7 @@ def go_gradio(**kwargs):
1186
  return history
1187
  if user_message1 in ['', None, '\n']:
1188
  if langchain_action1 in LangChainAction.QUERY.value and \
1189
- DocumentChoices.Only_All_Sources.name not in document_choice1 \
1190
  or \
1191
  langchain_mode1 in [LangChainMode.CHAT_LLM.value, LangChainMode.LLM.value]:
1192
  # reject non-retry submit/enter
@@ -1249,6 +1310,7 @@ def go_gradio(**kwargs):
1249
  args_list = args_list[:-3] # only keep rest needed for evaluate()
1250
  langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
1251
  langchain_action1 = args_list[eval_func_param_names.index('langchain_action')]
 
1252
  document_choice1 = args_list[eval_func_param_names.index('document_choice')]
1253
  if not history:
1254
  print("No history", flush=True)
@@ -1261,7 +1323,7 @@ def go_gradio(**kwargs):
1261
  history[-1][1] = None
1262
  elif not instruction1:
1263
  if langchain_action1 in LangChainAction.QUERY.value and \
1264
- DocumentChoices.Only_All_Sources.name not in document_choice1 \
1265
  or \
1266
  langchain_mode1 in [LangChainMode.CHAT_LLM.value, LangChainMode.LLM.value]:
1267
  # if not retrying, then reject empty query
@@ -1432,11 +1494,11 @@ def go_gradio(**kwargs):
1432
  )
1433
  bot_args = dict(fn=bot,
1434
  inputs=inputs_list + [model_state, my_db_state] + [text_output],
1435
- outputs=[text_output, exception_text],
1436
  )
1437
  retry_bot_args = dict(fn=functools.partial(bot, retry=True),
1438
  inputs=inputs_list + [model_state, my_db_state] + [text_output],
1439
- outputs=[text_output, exception_text],
1440
  )
1441
  retry_user_args = dict(fn=functools.partial(user, retry=True),
1442
  inputs=inputs_list + [text_output],
@@ -1454,11 +1516,11 @@ def go_gradio(**kwargs):
1454
  )
1455
  bot_args2 = dict(fn=bot,
1456
  inputs=inputs_list2 + [model_state2, my_db_state] + [text_output2],
1457
- outputs=[text_output2, exception_text],
1458
  )
1459
  retry_bot_args2 = dict(fn=functools.partial(bot, retry=True),
1460
  inputs=inputs_list2 + [model_state2, my_db_state] + [text_output2],
1461
- outputs=[text_output2, exception_text],
1462
  )
1463
  retry_user_args2 = dict(fn=functools.partial(user, retry=True),
1464
  inputs=inputs_list2 + [text_output2],
@@ -1479,11 +1541,11 @@ def go_gradio(**kwargs):
1479
  )
1480
  all_bot_args = dict(fn=functools.partial(all_bot, model_states1=model_states),
1481
  inputs=inputs_list + [my_db_state] + text_outputs,
1482
- outputs=text_outputs + [exception_text],
1483
  )
1484
  all_retry_bot_args = dict(fn=functools.partial(all_bot, model_states1=model_states, retry=True),
1485
  inputs=inputs_list + [my_db_state] + text_outputs,
1486
- outputs=text_outputs + [exception_text],
1487
  )
1488
  all_retry_user_args = dict(fn=functools.partial(all_user, retry=True,
1489
  sanitize_user_prompt=kwargs['sanitize_user_prompt'],
@@ -1681,13 +1743,26 @@ def go_gradio(**kwargs):
1681
  return False
1682
  return is_same
1683
 
1684
- def save_chat(*args):
1685
  args_list = list(args)
1686
- chat_list = args_list[:-1] # list of chatbot histories
 
 
 
 
 
 
 
 
 
 
1687
  # remove None histories
1688
  chat_list_not_none = [x for x in chat_list if x and len(x) > 0 and len(x[0]) == 2 and x[0][1] is not None]
1689
- chat_state1 = args_list[
1690
- -1] # dict with keys of short chat names, values of list of list of chatbot histories
 
 
 
1691
  short_chats = list(chat_state1.keys())
1692
  if len(chat_list_not_none) > 0:
1693
  # make short_chat key from only first history, based upon question that is same anyways
@@ -1699,13 +1774,14 @@ def go_gradio(**kwargs):
1699
  if not already_exists:
1700
  chat_state1[short_chat] = chat_list.copy()
1701
  # clear chat_list so saved and then new conversation starts
1702
- chat_list = [[]] * len(chat_list)
1703
- ret_list = chat_list + [chat_state1]
 
 
 
 
1704
  return tuple(ret_list)
1705
 
1706
- def update_radio_chats(chat_state1):
1707
- return gr.update(choices=list(chat_state1.keys()), value=None)
1708
-
1709
  def switch_chat(chat_key, chat_state1, num_model_lock=0):
1710
  chosen_chat = chat_state1[chat_key]
1711
  # deal with possible different size of chat list vs. current list
@@ -1729,11 +1805,13 @@ def go_gradio(**kwargs):
1729
  .then(clear_scores, outputs=[score_text, score_text2, score_text_nochat])
1730
 
1731
  def remove_chat(chat_key, chat_state1):
1732
- chat_state1.pop(chat_key, None)
1733
- return chat_state1
 
1734
 
1735
- remove_chat_btn.click(remove_chat, inputs=[radio_chats, chat_state], outputs=chat_state) \
1736
- .then(update_radio_chats, inputs=chat_state, outputs=radio_chats)
 
1737
 
1738
  def get_chats1(chat_state1):
1739
  base = 'chats'
@@ -1743,18 +1821,19 @@ def go_gradio(**kwargs):
1743
  f.write(json.dumps(chat_state1, indent=2))
1744
  return filename
1745
 
1746
- export_chats_btn.click(get_chats1, inputs=chat_state, outputs=chats_file, queue=False,
1747
- api_name='export_chats' if allow_api else None)
1748
 
1749
- def add_chats_from_file(file, chat_state1, add_btn):
1750
  if not file:
1751
- return chat_state1, add_btn
1752
  if isinstance(file, str):
1753
  files = [file]
1754
  else:
1755
  files = file
1756
  if not files:
1757
- return chat_state1, add_btn
 
1758
  for file1 in files:
1759
  try:
1760
  if hasattr(file1, 'name'):
@@ -1763,33 +1842,42 @@ def go_gradio(**kwargs):
1763
  new_chats = json.loads(f.read())
1764
  for chat1_k, chat1_v in new_chats.items():
1765
  # ignore chat1_k, regenerate and de-dup to avoid loss
1766
- _, chat_state1 = save_chat(chat1_v, chat_state1)
1767
  except BaseException as e:
1768
  t, v, tb = sys.exc_info()
1769
  ex = ''.join(traceback.format_exception(t, v, tb))
1770
- print("Add chats exception: %s" % str(ex), flush=True)
1771
- return chat_state1, add_btn
 
 
 
1772
 
1773
  # note for update_user_db_func output is ignored for db
1774
- add_to_chats_btn.click(add_chats_from_file,
1775
- inputs=[chatsup_output, chat_state, add_to_chats_btn],
1776
- outputs=[chat_state, add_to_my_db_btn], queue=False,
1777
- api_name='add_to_chats' if allow_api else None) \
1778
- .then(clear_file_list, outputs=chatsup_output, queue=False) \
1779
- .then(update_radio_chats, inputs=chat_state, outputs=radio_chats, queue=False)
1780
-
1781
- clear_chat_btn.click(fn=clear_texts,
1782
- inputs=[text_output, text_output2] + text_outputs,
1783
- outputs=[text_output, text_output2] + text_outputs,
1784
- queue=False, api_name='clear' if allow_api else None) \
 
1785
  .then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=False) \
1786
  .then(clear_scores, outputs=[score_text, score_text2, score_text_nochat])
1787
 
1788
- # does both models
1789
- clear.click(save_chat,
1790
- inputs=[text_output, text_output2] + text_outputs + [chat_state],
1791
- outputs=[text_output, text_output2] + text_outputs + [chat_state],
1792
- api_name='save_chat' if allow_api else None) \
 
 
 
 
 
1793
  .then(update_radio_chats, inputs=chat_state, outputs=radio_chats,
1794
  api_name='update_chats' if allow_api else None) \
1795
  .then(clear_scores, outputs=[score_text, score_text2, score_text_nochat])
@@ -1823,7 +1911,7 @@ def go_gradio(**kwargs):
1823
  .then(clear_torch_cache)
1824
 
1825
  def load_model(model_name, lora_weights, server_name, model_state_old, prompt_type_old, load_8bit,
1826
- infer_devices, gpu_id):
1827
  # ensure no API calls reach here
1828
  if is_public:
1829
  raise RuntimeError("Illegal access for %s" % model_name)
@@ -1867,7 +1955,7 @@ def go_gradio(**kwargs):
1867
  all_kwargs1 = all_kwargs.copy()
1868
  all_kwargs1['base_model'] = model_name.strip()
1869
  all_kwargs1['load_8bit'] = load_8bit
1870
- all_kwargs1['infer_devices'] = infer_devices
1871
  all_kwargs1['gpu_id'] = int(gpu_id) # detranscribe
1872
  model_lower = model_name.strip().lower()
1873
  if model_lower in inv_prompt_type_to_model_lower:
@@ -1920,8 +2008,9 @@ def go_gradio(**kwargs):
1920
 
1921
  get_prompt_str_func1 = functools.partial(get_prompt_str, which=1)
1922
  get_prompt_str_func2 = functools.partial(get_prompt_str, which=2)
1923
- prompt_type.change(fn=get_prompt_str_func1, inputs=[prompt_type, prompt_dict], outputs=prompt_dict)
1924
- prompt_type2.change(fn=get_prompt_str_func2, inputs=[prompt_type2, prompt_dict2], outputs=prompt_dict2)
 
1925
 
1926
  def dropdown_prompt_type_list(x):
1927
  return gr.Dropdown.update(value=x)
@@ -1931,7 +2020,7 @@ def go_gradio(**kwargs):
1931
 
1932
  load_model_args = dict(fn=load_model,
1933
  inputs=[model_choice, lora_choice, server_choice, model_state, prompt_type,
1934
- model_load8bit_checkbox, model_infer_devices_checkbox, model_gpu],
1935
  outputs=[model_state, model_used, lora_used, server_used,
1936
  # if prompt_type changes, prompt_dict will change via change rule
1937
  prompt_type, max_new_tokens, min_new_tokens,
@@ -1939,28 +2028,27 @@ def go_gradio(**kwargs):
1939
  prompt_update_args = dict(fn=dropdown_prompt_type_list, inputs=prompt_type, outputs=prompt_type)
1940
  chatbot_update_args = dict(fn=chatbot_list, inputs=[text_output, model_used], outputs=text_output)
1941
  nochat_update_args = dict(fn=chatbot_list, inputs=[text_output_nochat, model_used], outputs=text_output_nochat)
1942
- if not is_public:
1943
- load_model_event = load_model_button.click(**load_model_args, api_name='load_model' if allow_api else None) \
1944
- .then(**prompt_update_args) \
1945
- .then(**chatbot_update_args) \
1946
- .then(**nochat_update_args) \
1947
- .then(clear_torch_cache)
1948
 
1949
  load_model_args2 = dict(fn=load_model,
1950
  inputs=[model_choice2, lora_choice2, server_choice2, model_state2, prompt_type2,
1951
- model_load8bit_checkbox2, model_infer_devices_checkbox2, model_gpu2],
1952
  outputs=[model_state2, model_used2, lora_used2, server_used2,
1953
  # if prompt_type2 changes, prompt_dict2 will change via change rule
1954
  prompt_type2, max_new_tokens2, min_new_tokens2
1955
  ])
1956
  prompt_update_args2 = dict(fn=dropdown_prompt_type_list, inputs=prompt_type2, outputs=prompt_type2)
1957
  chatbot_update_args2 = dict(fn=chatbot_list, inputs=[text_output2, model_used2], outputs=text_output2)
1958
- if not is_public:
1959
- load_model_event2 = load_model_button2.click(**load_model_args2,
1960
- api_name='load_model2' if allow_api else None) \
1961
- .then(**prompt_update_args2) \
1962
- .then(**chatbot_update_args2) \
1963
- .then(clear_torch_cache)
1964
 
1965
  def dropdown_model_lora_server_list(model_list0, model_x,
1966
  lora_list0, lora_x,
@@ -2009,7 +2097,8 @@ def go_gradio(**kwargs):
2009
  server_options_state],
2010
  queue=False)
2011
 
2012
- go_btn.click(lambda: gr.update(visible=False), None, go_btn, api_name="go" if allow_api else None, queue=False) \
 
2013
  .then(lambda: gr.update(visible=True), None, normal_block, queue=False) \
2014
  .then(**load_model_args, queue=False).then(**prompt_update_args, queue=False)
2015
 
@@ -2077,23 +2166,11 @@ def go_gradio(**kwargs):
2077
  def get_hash():
2078
  return kwargs['git_hash']
2079
 
2080
- system_btn3.click(get_hash,
2081
- outputs=system_text3,
2082
- api_name='system_hash' if allow_api else None,
2083
- queue=False,
2084
- )
2085
-
2086
- # don't pass text_output, don't want to clear output, just stop it
2087
- # cancel only stops outer generation, not inner generation or non-generation
2088
- stop_btn.click(lambda: None, None, None,
2089
- cancels=submits1 + submits2 + submits3 +
2090
- submits4 +
2091
- [submit_event_nochat, submit_event_nochat2] +
2092
- [eventdb1, eventdb2, eventdb3,
2093
- eventdb4, eventdb5, eventdb6] +
2094
- [eventdb7, eventdb8, eventdb9]
2095
- ,
2096
- queue=False, api_name='stop' if allow_api else None).then(clear_torch_cache, queue=False)
2097
 
2098
  def count_chat_tokens(model_state1, chat1, prompt_type1, prompt_dict1,
2099
  memory_restriction_level1=0,
@@ -2121,9 +2198,25 @@ def go_gradio(**kwargs):
2121
  count_chat_tokens_func = functools.partial(count_chat_tokens,
2122
  memory_restriction_level1=memory_restriction_level,
2123
  keep_sources_in_context1=kwargs['keep_sources_in_context'])
2124
- count_chat_tokens_btn.click(fn=count_chat_tokens,
2125
- inputs=[model_state, text_output, prompt_type, prompt_dict],
2126
- outputs=chat_token_count, api_name='count_tokens' if allow_api else None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2127
 
2128
  demo.load(None, None, None, _js=get_dark_js() if kwargs['h2ocolors'] and False else None) # light best
2129
 
@@ -2196,6 +2289,8 @@ def get_inputs_list(inputs_dict, model_lower, model_id=1):
2196
 
2197
 
2198
  def get_sources(db1, langchain_mode, dbs=None, docs_state0=None):
 
 
2199
  if langchain_mode in ['ChatLLM', 'LLM']:
2200
  source_files_added = "NA"
2201
  source_list = []
@@ -2226,9 +2321,24 @@ def get_sources(db1, langchain_mode, dbs=None, docs_state0=None):
2226
  return sources_file, source_list
2227
 
2228
 
2229
- def update_user_db(file, db1, x, y, *args, dbs=None, langchain_mode='UserData', **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
 
2230
  try:
2231
- return _update_user_db(file, db1, x, y, *args, dbs=dbs, langchain_mode=langchain_mode, **kwargs)
 
 
2232
  except BaseException as e:
2233
  print(traceback.format_exc(), flush=True)
2234
  # gradio has issues if except, so fail semi-gracefully, else would hang forever in processing textbox
@@ -2245,15 +2355,14 @@ def update_user_db(file, db1, x, y, *args, dbs=None, langchain_mode='UserData',
2245
  </body>
2246
  </html>
2247
  """.format(ex_str)
2248
- if langchain_mode == 'MyData':
2249
- return None, langchain_mode, db1, x, y, source_files_added
2250
- else:
2251
- return None, langchain_mode, x, y, source_files_added
2252
  finally:
2253
  clear_torch_cache()
2254
 
2255
 
2256
  def get_lock_file(db1, langchain_mode):
 
2257
  assert len(db1) == 2 and db1[1] is not None and isinstance(db1[1], str)
2258
  user_id = db1[1]
2259
  base_path = 'locks'
@@ -2262,7 +2371,10 @@ def get_lock_file(db1, langchain_mode):
2262
  return lock_file
2263
 
2264
 
2265
- def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None, langchain_mode='UserData',
 
 
 
2266
  user_path=None,
2267
  use_openai_embedding=None,
2268
  hf_embedding_model=None,
@@ -2273,6 +2385,9 @@ def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None,
2273
  verbose=None,
2274
  is_url=None, is_txt=None,
2275
  n_jobs=-1):
 
 
 
2276
  assert use_openai_embedding is not None
2277
  assert hf_embedding_model is not None
2278
  assert caption_loader is not None
@@ -2281,6 +2396,8 @@ def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None,
2281
  assert enable_ocr is not None
2282
  assert verbose is not None
2283
 
 
 
2284
  if dbs is None:
2285
  dbs = {}
2286
  assert isinstance(dbs, dict), "Wrong type for dbs: %s" % str(type(dbs))
@@ -2295,6 +2412,14 @@ def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None,
2295
  if not isinstance(file, (list, tuple, typing.Generator)) and isinstance(file, str):
2296
  file = [file]
2297
 
 
 
 
 
 
 
 
 
2298
  if langchain_mode == 'UserData' and user_path is not None:
2299
  # move temp files from gradio upload to stable location
2300
  for fili, fil in enumerate(file):
@@ -2323,6 +2448,7 @@ def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None,
2323
  caption_loader=caption_loader,
2324
  )
2325
  exceptions = [x for x in sources if x.metadata.get('exception')]
 
2326
  sources = [x for x in sources if 'exception' not in x.metadata]
2327
 
2328
  lock_file = get_lock_file(db1, langchain_mode)
@@ -2349,7 +2475,7 @@ def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None,
2349
  if db is not None:
2350
  db1[0] = db
2351
  source_files_added = get_source_files(db=db1[0], exceptions=exceptions)
2352
- return None, langchain_mode, db1, x, y, source_files_added
2353
  else:
2354
  from gpt_langchain import get_persist_directory
2355
  persist_directory = get_persist_directory(langchain_mode)
@@ -2367,10 +2493,10 @@ def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None,
2367
  hf_embedding_model=hf_embedding_model)
2368
  dbs[langchain_mode] = db
2369
  # NOTE we do not return db, because function call always same code path
2370
- # return dbs[langchain_mode], x, y
2371
  # db in this code path is updated in place
2372
  source_files_added = get_source_files(db=dbs[langchain_mode], exceptions=exceptions)
2373
- return None, langchain_mode, x, y, source_files_added
2374
 
2375
 
2376
  def get_db(db1, langchain_mode, dbs=None):
 
20
  from iterators import TimeoutIterator
21
 
22
  from gradio_utils.css import get_css
23
+ from gradio_utils.prompt_form import make_chatbots
24
 
25
  # This is a hack to prevent Gradio from phoning home when it gets imported
26
  os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
 
56
  from prompter import prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower, non_hf_types, \
57
  get_prompt
58
  from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
59
+ ping, get_short_name, makedirs, get_kwargs, remove, system_info, ping_gpu, get_url, get_local_ip
60
  from gen import get_model, languages_covered, evaluate, score_qa, langchain_modes, inputs_kwargs_list, scratch_base_dir, \
61
  get_max_max_new_tokens, get_minmax_top_k_docs, history_to_context, langchain_actions
62
  from evaluate_params import eval_func_param_names, no_default_param_names, eval_func_param_names_defaults
 
118
  allow_upload = allow_upload_to_user_data or allow_upload_to_my_data
119
  kwargs.update(locals())
120
 
121
+ # import control
122
+ if kwargs['langchain_mode'] != 'Disabled':
123
+ from gpt_langchain import file_types, have_arxiv
124
+ else:
125
+ have_arxiv = False
126
+ file_types = []
127
+
128
  if 'mbart-' in kwargs['model_lower']:
129
  instruction_label_nochat = "Text to translate"
130
  else:
 
141
  """
142
  else:
143
  description = more_info
144
+ description_bottom = "If this host is busy, try [Multi-Model](https://gpt.h2o.ai), [Falcon 40B](http://falcon.h2o.ai), [HF Spaces1](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot) or [HF Spaces2](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)<br>"
 
145
  if is_hf:
146
  description_bottom += '''<a href="https://huggingface.co/spaces/h2oai/h2ogpt-chatbot?duplicate=true"><img src="https://bit.ly/3gLdBN6" style="white-space: nowrap" alt="Duplicate Space"></a>'''
147
 
 
166
  theme_kwargs = dict()
167
  if kwargs['gradio_size'] == 'xsmall':
168
  theme_kwargs.update(dict(spacing_size=spacing_xsm, text_size=text_xsm, radius_size=radius_xsm))
169
+ elif kwargs['gradio_size'] in [None, 'small']:
170
  theme_kwargs.update(dict(spacing_size=gr.themes.sizes.spacing_sm, text_size=gr.themes.sizes.text_sm,
171
  radius_size=gr.themes.sizes.spacing_sm))
172
  elif kwargs['gradio_size'] == 'large':
 
268
  model_options_state = gr.State([model_options])
269
  lora_options_state = gr.State([lora_options])
270
  server_options_state = gr.State([server_options])
271
+ my_db_state = gr.State([None, None])
 
272
  chat_state = gr.State({})
273
+ docs_state00 = kwargs['document_choice'] + [DocumentChoices.All.name]
 
274
  docs_state0 = []
275
  [docs_state0.append(x) for x in docs_state00 if x not in docs_state0]
276
+ docs_state = gr.State(docs_state0)
277
+ viewable_docs_state0 = []
278
+ viewable_docs_state = gr.State(viewable_docs_state0)
279
  gr.Markdown(f"""
280
  {get_h2o_title(title, description) if kwargs['h2ocolors'] else get_simple_title(title, description)}
281
  """)
 
288
  res_value = "Response Score: NA" if not kwargs[
289
  'model_lock'] else "Response Scores: %s" % nas
290
 
291
+ if kwargs['langchain_mode'] != LangChainMode.DISABLED.value:
292
+ extra_prompt_form = ". For summarization, empty submission uses first top_k_docs documents."
293
+ else:
294
+ extra_prompt_form = ""
295
+ if kwargs['input_lines'] > 1:
296
+ instruction_label = "Shift-Enter to Submit, Enter for more lines%s" % extra_prompt_form
297
+ else:
298
+ instruction_label = "Enter to Submit, Shift-Enter for more lines%s" % extra_prompt_form
299
+
300
+ normal_block = gr.Row(visible=not base_wanted, equal_height=False)
301
  with normal_block:
302
+ side_bar = gr.Column(elem_id="col_container", scale=1, min_width=100)
303
+ with side_bar:
304
+ with gr.Accordion("Chats", open=False, visible=True):
305
+ radio_chats = gr.Radio(value=None, label="Saved Chats", show_label=False,
306
+ visible=True, interactive=True,
307
+ type='value')
308
+ upload_visible = kwargs['langchain_mode'] != 'Disabled' and allow_upload
309
+ with gr.Accordion("Upload", open=False, visible=upload_visible):
310
+ with gr.Column():
311
+ with gr.Row(equal_height=False):
312
+ file_types_str = '[' + ' '.join(file_types) + ' URL ArXiv TEXT' + ']'
313
+ fileup_output = gr.File(label=f'Upload {file_types_str}',
314
+ show_label=False,
315
+ file_types=file_types,
316
+ file_count="multiple",
317
+ scale=1,
318
+ min_width=0,
319
+ elem_id="warning", elem_classes="feedback")
320
+ url_visible = kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_url_upload
321
+ url_label = 'URL/ArXiv' if have_arxiv else 'URL'
322
+ url_text = gr.Textbox(label=url_label,
323
+ # placeholder="Enter Submits",
324
+ max_lines=1,
325
+ interactive=True)
326
+ text_visible = kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_text_upload
327
+ user_text_text = gr.Textbox(label='Paste Text',
328
+ # placeholder="Enter Submits",
329
+ interactive=True,
330
+ visible=text_visible)
331
+ github_textbox = gr.Textbox(label="Github URL", visible=False) # FIXME WIP
332
+ database_visible = kwargs['langchain_mode'] != 'Disabled'
333
+ with gr.Accordion("Database", open=False, visible=database_visible):
334
+ if is_hf:
335
+ # don't show 'wiki' since only usually useful for internal testing at moment
336
+ no_show_modes = ['Disabled', 'wiki']
337
+ else:
338
+ no_show_modes = ['Disabled']
339
+ allowed_modes = visible_langchain_modes.copy()
340
+ allowed_modes = [x for x in allowed_modes if x in dbs]
341
+ allowed_modes += ['ChatLLM', 'LLM']
342
+ if allow_upload_to_my_data and 'MyData' not in allowed_modes:
343
+ allowed_modes += ['MyData']
344
+ if allow_upload_to_user_data and 'UserData' not in allowed_modes:
345
+ allowed_modes += ['UserData']
346
+ langchain_mode = gr.Radio(
347
+ [x for x in langchain_modes if x in allowed_modes and x not in no_show_modes],
348
+ value=kwargs['langchain_mode'],
349
+ label="Collections",
350
+ show_label=True,
351
+ visible=kwargs['langchain_mode'] != 'Disabled',
352
+ min_width=100)
353
+ document_subset = gr.Radio([x.name for x in DocumentChoices],
354
+ label="Subset",
355
+ value=DocumentChoices.Relevant.name,
356
+ interactive=True,
357
+ )
358
+ allowed_actions = [x for x in langchain_actions if x in visible_langchain_actions]
359
+ langchain_action = gr.Radio(
360
+ allowed_actions,
361
+ value=allowed_actions[0] if len(allowed_actions) > 0 else None,
362
+ label="Action",
363
+ visible=True)
364
+ col_tabs = gr.Column(elem_id="col_container", scale=10)
365
+ with (col_tabs, gr.Tabs()):
366
+ with gr.TabItem("Chat"):
367
+ if kwargs['langchain_mode'] == 'Disabled':
368
+ text_output_nochat = gr.Textbox(lines=5, label=output_label0, show_copy_button=True,
369
+ visible=not kwargs['chat'])
370
+ else:
371
+ # text looks a bit worse, but HTML links work
372
+ text_output_nochat = gr.HTML(label=output_label0, visible=not kwargs['chat'])
373
+ with gr.Row():
374
+ # NOCHAT
375
  instruction_nochat = gr.Textbox(
376
  lines=kwargs['input_lines'],
377
  label=instruction_label_nochat,
378
  placeholder=kwargs['placeholder_instruction'],
379
+ visible=not kwargs['chat'],
380
  )
381
  iinput_nochat = gr.Textbox(lines=4, label="Input context for Instruction",
382
+ placeholder=kwargs['placeholder_input'],
383
+ visible=not kwargs['chat'])
384
+ submit_nochat = gr.Button("Submit", size='sm', visible=not kwargs['chat'])
385
+ flag_btn_nochat = gr.Button("Flag", size='sm', visible=not kwargs['chat'])
386
+ score_text_nochat = gr.Textbox("Response Score: NA", show_label=False,
387
+ visible=not kwargs['chat'])
388
+ submit_nochat_api = gr.Button("Submit nochat API", visible=False)
389
+ inputs_dict_str = gr.Textbox(label='API input for nochat', show_label=False, visible=False)
390
+ text_output_nochat_api = gr.Textbox(lines=5, label='API nochat output', visible=False,
391
+ show_copy_button=True)
392
+
393
+ # CHAT
394
+ col_chat = gr.Column(visible=kwargs['chat'])
395
+ with col_chat:
396
+ with gr.Row(): # elem_id='prompt-form-area'):
397
+ with gr.Column(scale=50):
398
+ instruction = gr.Textbox(
399
+ lines=kwargs['input_lines'],
400
+ label='Ask anything',
401
+ placeholder=instruction_label,
402
+ info=None,
403
+ elem_id='prompt-form',
404
+ container=True,
405
+ )
406
+ submit_buttons = gr.Row(equal_height=False)
407
+ with submit_buttons:
408
+ mw1 = 50
409
+ mw2 = 50
410
+ with gr.Column(min_width=mw1):
411
+ submit = gr.Button(value='Submit', variant='primary', scale=0, size='sm',
412
+ min_width=mw1)
413
+ stop_btn = gr.Button(value="Stop", variant='secondary', scale=0, size='sm',
414
+ min_width=mw1)
415
+ save_chat_btn = gr.Button("Save", size='sm', min_width=mw1)
416
+ with gr.Column(min_width=mw2):
417
+ retry_btn = gr.Button("Redo", size='sm', min_width=mw2)
418
+ undo = gr.Button("Undo", size='sm', min_width=mw2)
419
+ clear_chat_btn = gr.Button(value="Clear", size='sm', min_width=mw2)
420
+ text_output, text_output2, text_outputs = make_chatbots(output_label0, output_label0_model2,
421
+ **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
422
 
 
 
 
 
 
 
 
 
 
423
  with gr.Row():
424
+ with gr.Column(visible=kwargs['score_model']):
425
+ score_text = gr.Textbox(res_value,
426
+ show_label=False,
427
+ visible=True)
428
+ score_text2 = gr.Textbox("Response Score2: NA", show_label=False,
429
+ visible=False and not kwargs['model_lock'])
430
+
431
+ with gr.TabItem("Document Selection"):
432
+ document_choice = gr.Dropdown(docs_state0,
433
+ label="Select Subset of Document(s) %s" % file_types_str,
434
+ value='All',
435
+ interactive=True,
436
+ multiselect=True,
437
+ )
438
+ sources_visible = kwargs['langchain_mode'] != 'Disabled' and enable_sources_list
439
+ with gr.Row():
440
+ get_sources_btn = gr.Button(value="Update UI with Document(s) from DB", scale=0, size='sm',
441
+ visible=sources_visible)
442
+ show_sources_btn = gr.Button(value="Show Sources from DB", scale=0, size='sm',
443
+ visible=sources_visible)
444
+ refresh_sources_btn = gr.Button(value="Update DB with new/changed files on disk", scale=0,
445
+ size='sm',
446
+ visible=sources_visible and allow_upload_to_user_data)
447
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
448
  sources_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list,
449
  equal_height=False)
450
  with sources_row:
451
  with gr.Column(scale=1):
452
  file_source = gr.File(interactive=False,
453
+ label="Download File w/Sources")
454
  with gr.Column(scale=2):
455
  sources_text = gr.HTML(label='Sources Added', interactive=False)
456
 
457
+ doc_exception_text = gr.Textbox(value="", visible=True, label='Document Exceptions',
458
+ interactive=False)
459
+ with gr.TabItem("Document Viewer"):
460
+ with gr.Row():
461
+ with gr.Column(scale=2):
462
+ get_viewable_sources_btn = gr.Button(value="Update UI with Document(s) from DB", scale=0,
463
+ size='sm',
464
+ visible=sources_visible)
465
+ view_document_choice = gr.Dropdown(viewable_docs_state0,
466
+ label="Select Single Document",
467
+ value=None,
468
+ interactive=True,
469
+ multiselect=False,
470
+ )
471
+ with gr.Column(scale=4):
472
+ pass
473
+ document = 'http://infolab.stanford.edu/pub/papers/google.pdf'
474
+ doc_view = gr.HTML(visible=False)
475
+ doc_view2 = gr.Dataframe(visible=False)
476
+ doc_view3 = gr.JSON(visible=False)
477
+ doc_view4 = gr.Markdown(visible=False)
478
+
479
  with gr.TabItem("Chat History"):
480
+ with gr.Row():
481
+ with gr.Column(scale=1):
482
+ remove_chat_btn = gr.Button(value="Remove Selected Saved Chats", visible=True, size='sm')
483
+ flag_btn = gr.Button("Flag Current Chat", size='sm')
484
+ export_chats_btn = gr.Button(value="Export Chats to Download", size='sm')
485
+ with gr.Column(scale=4):
486
+ pass
487
+ with gr.Row():
488
+ chats_file = gr.File(interactive=False, label="Download Exported Chats")
489
+ chatsup_output = gr.File(label="Upload Chat File(s)",
490
+ file_types=['.json'],
491
+ file_count='multiple',
492
+ elem_id="warning", elem_classes="feedback")
493
  with gr.Row():
494
  if 'mbart-' in kwargs['model_lower']:
495
  src_lang = gr.Dropdown(list(languages_covered().keys()),
 
498
  tgt_lang = gr.Dropdown(list(languages_covered().keys()),
499
  value=kwargs['tgt_lang'],
500
  label="Output Language")
 
 
 
 
 
 
 
 
 
 
 
 
 
501
 
502
+ chat_exception_text = gr.Textbox(value="", visible=True, label='Chat Exceptions',
503
+ interactive=False)
504
  with gr.TabItem("Expert"):
505
  with gr.Row():
506
  with gr.Column():
 
579
  info="Directly pre-appended without prompt processing",
580
  interactive=not is_public)
581
  chat = gr.components.Checkbox(label="Chat mode", value=kwargs['chat'],
582
+ visible=False, # no longer support nochat in UI
583
  interactive=not is_public,
584
  )
585
  count_chat_tokens_btn = gr.Button(value="Count Chat Tokens",
 
638
  model_load8bit_checkbox = gr.components.Checkbox(
639
  label="Load 8-bit [requires support]",
640
  value=kwargs['load_8bit'], interactive=not is_public)
641
+ model_use_gpu_id_checkbox = gr.components.Checkbox(
642
  label="Choose Devices [If not Checked, use all GPUs]",
643
+ value=kwargs['use_gpu_id'], interactive=not is_public)
644
  model_gpu = gr.Dropdown(n_gpus_list,
645
  label="GPU ID [-1 = all GPUs, if Choose is enabled]",
646
  value=kwargs['gpu_id'], interactive=not is_public)
 
673
  model_load8bit_checkbox2 = gr.components.Checkbox(
674
  label="Load 8-bit 2 [requires support]",
675
  value=kwargs['load_8bit'], interactive=not is_public)
676
+ model_use_gpu_id_checkbox2 = gr.components.Checkbox(
677
  label="Choose Devices 2 [If not Checked, use all GPUs]",
678
  value=kwargs[
679
+ 'use_gpu_id'], interactive=not is_public)
680
  model_gpu2 = gr.Dropdown(n_gpus_list,
681
  label="GPU ID 2 [-1 = all GPUs, if choose is enabled]",
682
  value=kwargs['gpu_id'], interactive=not is_public)
 
703
  add_model_lora_server_button = gr.Button("Add new Model, Lora, Server url:port", scale=0,
704
  size='sm', interactive=not is_public)
705
  with gr.TabItem("System"):
706
+ with gr.Row():
707
+ with gr.Column(scale=1):
708
+ side_bar_text = gr.Textbox('on', visible=False, interactive=False)
709
+ submit_buttons_text = gr.Textbox('on', visible=False, interactive=False)
710
+
711
+ side_bar_btn = gr.Button("Toggle SideBar", variant="secondary", size="sm")
712
+ submit_buttons_btn = gr.Button("Toggle Submit Buttons", variant="secondary", size="sm")
713
+ col_tabs_scale = gr.Slider(minimum=1, maximum=20, value=10, step=1, label='Window Size')
714
+ text_outputs_height = gr.Slider(minimum=100, maximum=1000, value=kwargs['height'] or 400,
715
+ step=100, label='Chat Height')
716
+ dark_mode_btn = gr.Button("Dark Mode", variant="secondary", size="sm")
717
+ with gr.Column(scale=4):
718
+ pass
719
  admin_row = gr.Row()
720
  with admin_row:
721
+ with gr.Column(scale=1):
722
+ admin_pass_textbox = gr.Textbox(label="Admin Password", type='password', visible=is_public)
723
+ admin_btn = gr.Button(value="Admin Access", visible=is_public, size='sm')
724
+ with gr.Column(scale=4):
725
+ pass
726
  system_row = gr.Row(visible=not is_public)
727
  with system_row:
728
  with gr.Column():
729
  with gr.Row():
730
+ system_btn = gr.Button(value='Get System Info', size='sm')
731
  system_text = gr.Textbox(label='System Info', interactive=False, show_copy_button=True)
732
  with gr.Row():
733
  system_input = gr.Textbox(label='System Info Dict Password', interactive=True,
734
  visible=not is_public)
735
+ system_btn2 = gr.Button(value='Get System Info Dict', visible=not is_public, size='sm')
736
  system_text2 = gr.Textbox(label='System Info Dict', interactive=False,
737
  visible=not is_public, show_copy_button=True)
738
  with gr.Row():
739
+ system_btn3 = gr.Button(value='Get Hash', visible=not is_public, size='sm')
740
  system_text3 = gr.Textbox(label='Hash', interactive=False,
741
  visible=not is_public, show_copy_button=True)
742
 
743
  with gr.Row():
744
+ zip_btn = gr.Button("Zip", size='sm')
745
  zip_text = gr.Textbox(label="Zip file name", interactive=False)
746
  file_output = gr.File(interactive=False, label="Zip file to Download")
747
  with gr.Row():
748
+ s3up_btn = gr.Button("S3UP", size='sm')
749
  s3up_text = gr.Textbox(label='S3UP result', interactive=False)
750
+
751
+ with gr.TabItem("Terms of Service"):
752
  description = ""
753
  description += """<p><b> DISCLAIMERS: </b><ul><i><li>The model was trained on The Pile and other data, which may contain objectionable content. Use at own risk.</i></li>"""
754
  if kwargs['load_8bit']:
 
759
  description += """<i><li>By using h2oGPT, you accept our <a href="https://github.com/h2oai/h2ogpt/blob/main/docs/tos.md">Terms of Service</a></i></li></ul></p>"""
760
  gr.Markdown(value=description, show_label=False, interactive=False)
761
 
762
+ with gr.TabItem("Hosts"):
763
+ gr.Markdown(f"""
764
+ {description_bottom}
765
+ {task_info_md}
766
+ """)
767
 
768
  # Get flagged data
769
  zip_data1 = functools.partial(zip_data, root_dirs=['flagged_data_points', kwargs['save_dir']])
770
+ zip_event = zip_btn.click(zip_data1, inputs=None, outputs=[file_output, zip_text], queue=False,
771
+ api_name='zip_data' if allow_api else None)
772
+ s3up_event = s3up_btn.click(s3up, inputs=zip_text, outputs=s3up_text, queue=False,
773
+ api_name='s3up_data' if allow_api else None)
774
 
775
  def clear_file_list():
776
  return None
 
788
  return tuple([gr.update(interactive=True)] * len(args))
789
 
790
  # Add to UserData
791
+ update_db_func = functools.partial(update_user_db,
792
+ dbs=dbs,
793
+ db_type=db_type,
794
+ use_openai_embedding=use_openai_embedding,
795
+ hf_embedding_model=hf_embedding_model,
796
+ enable_captions=enable_captions,
797
+ captions_model=captions_model,
798
+ enable_ocr=enable_ocr,
799
+ caption_loader=caption_loader,
800
+ verbose=kwargs['verbose'],
801
+ user_path=kwargs['user_path'],
802
+ n_jobs=kwargs['n_jobs'],
803
+ )
804
+ add_file_outputs = [fileup_output, langchain_mode]
805
+ add_file_kwargs = dict(fn=update_db_func,
806
+ inputs=[fileup_output, my_db_state, chunk, chunk_size, langchain_mode],
807
+ outputs=add_file_outputs + [sources_text, doc_exception_text],
 
808
  queue=queue,
809
+ api_name='add_file' if allow_api and allow_upload_to_user_data else None)
810
 
 
 
 
 
811
  # then no need for add buttons, only single changeable db
812
+ eventdb1a = fileup_output.upload(make_non_interactive, inputs=add_file_outputs, outputs=add_file_outputs,
813
+ show_progress='minimal')
814
+ eventdb1 = eventdb1a.then(**add_file_kwargs, show_progress='full')
815
+ eventdb1b = eventdb1.then(make_interactive, inputs=add_file_outputs, outputs=add_file_outputs,
816
+ show_progress='minimal')
817
 
818
  # note for update_user_db_func output is ignored for db
819
 
820
  def clear_textbox():
821
  return gr.Textbox.update(value='')
822
 
823
+ update_user_db_url_func = functools.partial(update_db_func, is_url=True)
824
 
825
+ add_url_outputs = [url_text, langchain_mode]
826
  add_url_kwargs = dict(fn=update_user_db_url_func,
827
+ inputs=[url_text, my_db_state, chunk, chunk_size, langchain_mode],
828
+ outputs=add_url_outputs + [sources_text, doc_exception_text],
 
829
  queue=queue,
830
+ api_name='add_url' if allow_api and allow_upload_to_user_data else None)
831
 
832
+ eventdb2a = url_text.submit(fn=dummy_fun, inputs=url_text, outputs=url_text, queue=queue,
833
+ show_progress='minimal')
 
 
 
 
834
  # work around https://github.com/gradio-app/gradio/issues/4733
835
  eventdb2b = eventdb2a.then(make_non_interactive, inputs=add_url_outputs, outputs=add_url_outputs,
836
  show_progress='minimal')
837
+ eventdb2 = eventdb2b.then(**add_url_kwargs, show_progress='full')
838
+ eventdb2c = eventdb2.then(make_interactive, inputs=add_url_outputs, outputs=add_url_outputs,
839
+ show_progress='minimal')
840
 
841
+ update_user_db_txt_func = functools.partial(update_db_func, is_txt=True)
842
+ add_text_outputs = [user_text_text, langchain_mode]
843
  add_text_kwargs = dict(fn=update_user_db_txt_func,
844
+ inputs=[user_text_text, my_db_state, chunk, chunk_size, langchain_mode],
845
+ outputs=add_text_outputs + [sources_text, doc_exception_text],
 
846
  queue=queue,
847
+ api_name='add_text' if allow_api and allow_upload_to_user_data else None
848
  )
849
+ eventdb3a = user_text_text.submit(fn=dummy_fun, inputs=user_text_text, outputs=user_text_text, queue=queue,
850
+ show_progress='minimal')
 
 
 
 
 
851
  eventdb3b = eventdb3a.then(make_non_interactive, inputs=add_text_outputs, outputs=add_text_outputs,
852
  show_progress='minimal')
853
+ eventdb3 = eventdb3b.then(**add_text_kwargs, show_progress='full')
854
+ eventdb3c = eventdb3.then(make_interactive, inputs=add_text_outputs, outputs=add_text_outputs,
855
+ show_progress='minimal')
856
+ db_events = [eventdb1a, eventdb1, eventdb1b,
857
+ eventdb2a, eventdb2, eventdb2b, eventdb2c,
858
+ eventdb3a, eventdb3b, eventdb3, eventdb3c]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
859
 
860
  get_sources1 = functools.partial(get_sources, dbs=dbs, docs_state0=docs_state0)
861
 
862
  # if change collection source, must clear doc selections from it to avoid inconsistency
863
  def clear_doc_choice():
864
+ return gr.Dropdown.update(choices=docs_state0, value=DocumentChoices.All.name)
865
+
866
+ langchain_mode.change(clear_doc_choice, inputs=None, outputs=document_choice, queue=False)
867
 
868
+ def resize_col_tabs(x):
869
+ return gr.Dropdown.update(scale=x)
870
+
871
+ col_tabs_scale.change(fn=resize_col_tabs, inputs=col_tabs_scale, outputs=col_tabs)
872
+
873
+ def resize_chatbots(x, num_model_lock=0):
874
+ if num_model_lock == 0:
875
+ num_model_lock = 3 # 2 + 1 (which is dup of first)
876
+ else:
877
+ num_model_lock = 2 + num_model_lock
878
+ return tuple([gr.update(height=x)] * num_model_lock)
879
+
880
+ resize_chatbots_func = functools.partial(resize_chatbots, num_model_lock=len(text_outputs))
881
+ text_outputs_height.change(fn=resize_chatbots_func, inputs=text_outputs_height,
882
+ outputs=[text_output, text_output2] + text_outputs)
883
 
884
  def update_dropdown(x):
885
  return gr.Dropdown.update(choices=x, value=[docs_state0[0]])
886
 
887
+ get_sources_args = dict(fn=get_sources1, inputs=[my_db_state, langchain_mode],
888
+ outputs=[file_source, docs_state],
889
+ queue=queue,
890
+ api_name='get_sources' if allow_api else None)
891
+
892
+ eventdb7 = get_sources_btn.click(**get_sources_args) \
893
  .then(fn=update_dropdown, inputs=docs_state, outputs=document_choice)
894
  # show button, else only show when add. Could add to above get_sources for download/dropdown, but bit much maybe
895
  show_sources1 = functools.partial(get_source_files_given_langchain_mode, dbs=dbs)
896
  eventdb8 = show_sources_btn.click(fn=show_sources1, inputs=[my_db_state, langchain_mode], outputs=sources_text,
897
  api_name='show_sources' if allow_api else None)
898
 
899
+ def update_viewable_dropdown(x):
900
+ return gr.Dropdown.update(choices=x,
901
+ value=viewable_docs_state0[0] if len(viewable_docs_state0) > 0 else None)
902
+
903
+ get_viewable_sources1 = functools.partial(get_sources, dbs=dbs, docs_state0=viewable_docs_state0)
904
+ get_viewable_sources_args = dict(fn=get_viewable_sources1, inputs=[my_db_state, langchain_mode],
905
+ outputs=[file_source, viewable_docs_state],
906
+ queue=queue,
907
+ api_name='get_viewable_sources' if allow_api else None)
908
+ eventdb12 = get_viewable_sources_btn.click(**get_viewable_sources_args) \
909
+ .then(fn=update_viewable_dropdown, inputs=viewable_docs_state,
910
+ outputs=view_document_choice)
911
+
912
+ def show_doc(file):
913
+ dummy1 = gr.update(visible=False, value=None)
914
+ dummy_ret = dummy1, dummy1, dummy1, dummy1
915
+ if not isinstance(file, str):
916
+ return dummy_ret
917
+
918
+ if file.endswith('.md'):
919
+ try:
920
+ with open(file, 'rt') as f:
921
+ content = f.read()
922
+ return dummy1, dummy1, dummy1, gr.update(visible=True, value=content)
923
+ except:
924
+ return dummy_ret
925
+
926
+ if file.endswith('.py'):
927
+ try:
928
+ with open(file, 'rt') as f:
929
+ content = f.read()
930
+ content = f"```python\n{content}\n```"
931
+ return dummy1, dummy1, dummy1, gr.update(visible=True, value=content)
932
+ except:
933
+ return dummy_ret
934
+
935
+ if file.endswith('.txt') or file.endswith('.rst') or file.endswith('.rtf') or file.endswith('.toml'):
936
+ try:
937
+ with open(file, 'rt') as f:
938
+ content = f.read()
939
+ content = f"```text\n{content}\n```"
940
+ return dummy1, dummy1, dummy1, gr.update(visible=True, value=content)
941
+ except:
942
+ return dummy_ret
943
+
944
+ func = None
945
+ if file.endswith(".csv"):
946
+ func = pd.read_csv
947
+ elif file.endswith(".pickle"):
948
+ func = pd.read_pickle
949
+ elif file.endswith(".xls") or file.endswith("xlsx"):
950
+ func = pd.read_excel
951
+ elif file.endswith('.json'):
952
+ func = pd.read_json
953
+ elif file.endswith('.xml'):
954
+ func = pd.read_xml
955
+ if func is not None:
956
+ try:
957
+ df = func(file).head(100)
958
+ except:
959
+ return dummy_ret
960
+ return dummy1, gr.update(visible=True, value=df), dummy1, dummy1
961
+ port = int(os.getenv('GRADIO_SERVER_PORT', '7860'))
962
+ import pathlib
963
+ absolute_path_string = os.path.abspath(file)
964
+ url_path = pathlib.Path(absolute_path_string).as_uri()
965
+ url = get_url(absolute_path_string, from_str=True)
966
+ img_url = url.replace("""<a href=""", """<img src=""")
967
+ if file.endswith('.png') or file.endswith('.jpg') or file.endswith('.jpeg'):
968
+ return gr.update(visible=True, value=img_url), dummy1, dummy1, dummy1
969
+ elif file.endswith('.pdf') or 'arxiv.org/pdf' in file:
970
+ if file.startswith('http') or file.startswith('https'):
971
+ # if file is online, then might as well use google(?)
972
+ document1 = file
973
+ return gr.update(visible=True, value=f"""<iframe width="1000" height="800" src="https://docs.google.com/viewerng/viewer?url={document1}&embedded=true" frameborder="0" height="100%" width="100%">
974
+ </iframe>
975
+ """), dummy1, dummy1, dummy1
976
+ else:
977
+ ip = get_local_ip()
978
+ document1 = url_path.replace('file://', f'http://{ip}:{port}/')
979
+ # document1 = url
980
+ return gr.update(visible=True, value=f"""<object data="{document1}" type="application/pdf">
981
+ <iframe src="https://docs.google.com/viewer?url={document1}&embedded=true"></iframe>
982
+ </object>"""), dummy1, dummy1, dummy1
983
+ else:
984
+ return dummy_ret
985
+
986
+ view_document_choice.select(fn=show_doc, inputs=view_document_choice,
987
+ outputs=[doc_view, doc_view2, doc_view3, doc_view4])
988
+
989
  # Get inputs to evaluate() and make_db()
990
  # don't deepcopy, can contain model itself
991
  all_kwargs = kwargs.copy()
 
1072
  **kwargs_evaluate
1073
  )
1074
 
 
 
 
1075
  dark_mode_btn.click(
1076
  None,
1077
  None,
 
1081
  queue=False,
1082
  )
1083
 
1084
+ def visible_toggle(x):
1085
+ x = 'off' if x == 'on' else 'on'
1086
+ return x, gr.Column.update(visible=True if x == 'on' else False)
1087
 
1088
+ side_bar_btn.click(fn=visible_toggle,
1089
+ inputs=side_bar_text,
1090
+ outputs=[side_bar_text, side_bar],
1091
+ queue=False)
1092
 
1093
+ submit_buttons_btn.click(fn=visible_toggle,
1094
+ inputs=submit_buttons_text,
1095
+ outputs=[submit_buttons_text, submit_buttons],
1096
+ queue=False)
 
 
 
1097
 
1098
  # examples after submit or any other buttons for chat or no chat
1099
  if kwargs['examples'] is not None and kwargs['show_examples']:
 
1214
  prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
1215
  langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
1216
  langchain_action1 = args_list[eval_func_param_names.index('langchain_action')]
1217
+ document_subset1 = args_list[eval_func_param_names.index('document_subset')]
1218
  document_choice1 = args_list[eval_func_param_names.index('document_choice')]
1219
  if not prompt_type1:
1220
  # shouldn't have to specify if CLI launched model
 
1247
  return history
1248
  if user_message1 in ['', None, '\n']:
1249
  if langchain_action1 in LangChainAction.QUERY.value and \
1250
+ DocumentChoices.All.name != document_subset1 \
1251
  or \
1252
  langchain_mode1 in [LangChainMode.CHAT_LLM.value, LangChainMode.LLM.value]:
1253
  # reject non-retry submit/enter
 
1310
  args_list = args_list[:-3] # only keep rest needed for evaluate()
1311
  langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
1312
  langchain_action1 = args_list[eval_func_param_names.index('langchain_action')]
1313
+ document_subset1 = args_list[eval_func_param_names.index('document_subset')]
1314
  document_choice1 = args_list[eval_func_param_names.index('document_choice')]
1315
  if not history:
1316
  print("No history", flush=True)
 
1323
  history[-1][1] = None
1324
  elif not instruction1:
1325
  if langchain_action1 in LangChainAction.QUERY.value and \
1326
+ DocumentChoices.All.name != document_choice1 \
1327
  or \
1328
  langchain_mode1 in [LangChainMode.CHAT_LLM.value, LangChainMode.LLM.value]:
1329
  # if not retrying, then reject empty query
 
1494
  )
1495
  bot_args = dict(fn=bot,
1496
  inputs=inputs_list + [model_state, my_db_state] + [text_output],
1497
+ outputs=[text_output, chat_exception_text],
1498
  )
1499
  retry_bot_args = dict(fn=functools.partial(bot, retry=True),
1500
  inputs=inputs_list + [model_state, my_db_state] + [text_output],
1501
+ outputs=[text_output, chat_exception_text],
1502
  )
1503
  retry_user_args = dict(fn=functools.partial(user, retry=True),
1504
  inputs=inputs_list + [text_output],
 
1516
  )
1517
  bot_args2 = dict(fn=bot,
1518
  inputs=inputs_list2 + [model_state2, my_db_state] + [text_output2],
1519
+ outputs=[text_output2, chat_exception_text],
1520
  )
1521
  retry_bot_args2 = dict(fn=functools.partial(bot, retry=True),
1522
  inputs=inputs_list2 + [model_state2, my_db_state] + [text_output2],
1523
+ outputs=[text_output2, chat_exception_text],
1524
  )
1525
  retry_user_args2 = dict(fn=functools.partial(user, retry=True),
1526
  inputs=inputs_list2 + [text_output2],
 
1541
  )
1542
  all_bot_args = dict(fn=functools.partial(all_bot, model_states1=model_states),
1543
  inputs=inputs_list + [my_db_state] + text_outputs,
1544
+ outputs=text_outputs + [chat_exception_text],
1545
  )
1546
  all_retry_bot_args = dict(fn=functools.partial(all_bot, model_states1=model_states, retry=True),
1547
  inputs=inputs_list + [my_db_state] + text_outputs,
1548
+ outputs=text_outputs + [chat_exception_text],
1549
  )
1550
  all_retry_user_args = dict(fn=functools.partial(all_user, retry=True,
1551
  sanitize_user_prompt=kwargs['sanitize_user_prompt'],
 
1743
  return False
1744
  return is_same
1745
 
1746
+ def save_chat(*args, chat_is_list=False):
1747
  args_list = list(args)
1748
+ if not chat_is_list:
1749
+ # list of chatbot histories,
1750
+ # can't pass in list with list of chatbot histories and state due to gradio limits
1751
+ chat_list = args_list[:-1]
1752
+ else:
1753
+ assert len(args_list) == 2
1754
+ chat_list = args_list[0]
1755
+ # if old chat file with single chatbot, get into shape
1756
+ if isinstance(chat_list, list) and len(chat_list) > 0 and isinstance(chat_list[0], list) and len(
1757
+ chat_list[0]) == 2 and isinstance(chat_list[0][0], str) and isinstance(chat_list[0][1], str):
1758
+ chat_list = [chat_list]
1759
  # remove None histories
1760
  chat_list_not_none = [x for x in chat_list if x and len(x) > 0 and len(x[0]) == 2 and x[0][1] is not None]
1761
+ chat_list_none = [x for x in chat_list if x not in chat_list_not_none]
1762
+ if len(chat_list_none) > 0 and len(chat_list_not_none) == 0:
1763
+ raise ValueError("Invalid chat file")
1764
+ # dict with keys of short chat names, values of list of list of chatbot histories
1765
+ chat_state1 = args_list[-1]
1766
  short_chats = list(chat_state1.keys())
1767
  if len(chat_list_not_none) > 0:
1768
  # make short_chat key from only first history, based upon question that is same anyways
 
1774
  if not already_exists:
1775
  chat_state1[short_chat] = chat_list.copy()
1776
  # clear chat_list so saved and then new conversation starts
1777
+ # FIXME: seems less confusing to clear, since have clear button right next
1778
+ # chat_list = [[]] * len(chat_list)
1779
+ if not chat_is_list:
1780
+ ret_list = chat_list + [chat_state1]
1781
+ else:
1782
+ ret_list = [chat_list] + [chat_state1]
1783
  return tuple(ret_list)
1784
 
 
 
 
1785
  def switch_chat(chat_key, chat_state1, num_model_lock=0):
1786
  chosen_chat = chat_state1[chat_key]
1787
  # deal with possible different size of chat list vs. current list
 
1805
  .then(clear_scores, outputs=[score_text, score_text2, score_text_nochat])
1806
 
1807
  def remove_chat(chat_key, chat_state1):
1808
+ if isinstance(chat_key, str):
1809
+ chat_state1.pop(chat_key, None)
1810
+ return gr.update(choices=list(chat_state1.keys()), value=None), chat_state1
1811
 
1812
+ remove_chat_event = remove_chat_btn.click(remove_chat,
1813
+ inputs=[radio_chats, chat_state], outputs=[radio_chats, chat_state],
1814
+ queue=False)
1815
 
1816
  def get_chats1(chat_state1):
1817
  base = 'chats'
 
1821
  f.write(json.dumps(chat_state1, indent=2))
1822
  return filename
1823
 
1824
+ export_chat_event = export_chats_btn.click(get_chats1, inputs=chat_state, outputs=chats_file, queue=False,
1825
+ api_name='export_chats' if allow_api else None)
1826
 
1827
+ def add_chats_from_file(file, chat_state1, radio_chats1, chat_exception_text1):
1828
  if not file:
1829
+ return None, chat_state1, gr.update(choices=list(chat_state1.keys()), value=None), chat_exception_text1
1830
  if isinstance(file, str):
1831
  files = [file]
1832
  else:
1833
  files = file
1834
  if not files:
1835
+ return None, chat_state1, gr.update(choices=list(chat_state1.keys()), value=None), chat_exception_text1
1836
+ chat_exception_list = []
1837
  for file1 in files:
1838
  try:
1839
  if hasattr(file1, 'name'):
 
1842
  new_chats = json.loads(f.read())
1843
  for chat1_k, chat1_v in new_chats.items():
1844
  # ignore chat1_k, regenerate and de-dup to avoid loss
1845
+ _, chat_state1 = save_chat(chat1_v, chat_state1, chat_is_list=True)
1846
  except BaseException as e:
1847
  t, v, tb = sys.exc_info()
1848
  ex = ''.join(traceback.format_exception(t, v, tb))
1849
+ ex_str = "File %s exception: %s" % (file1, str(e))
1850
+ print(ex_str, flush=True)
1851
+ chat_exception_list.append(ex_str)
1852
+ chat_exception_text1 = '\n'.join(chat_exception_list)
1853
+ return None, chat_state1, gr.update(choices=list(chat_state1.keys()), value=None), chat_exception_text1
1854
 
1855
  # note for update_user_db_func output is ignored for db
1856
+ chatup_change_event = chatsup_output.change(add_chats_from_file,
1857
+ inputs=[chatsup_output, chat_state, radio_chats,
1858
+ chat_exception_text],
1859
+ outputs=[chatsup_output, chat_state, radio_chats,
1860
+ chat_exception_text],
1861
+ queue=False,
1862
+ api_name='add_to_chats' if allow_api else None)
1863
+
1864
+ clear_chat_event = clear_chat_btn.click(fn=clear_texts,
1865
+ inputs=[text_output, text_output2] + text_outputs,
1866
+ outputs=[text_output, text_output2] + text_outputs,
1867
+ queue=False, api_name='clear' if allow_api else None) \
1868
  .then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=False) \
1869
  .then(clear_scores, outputs=[score_text, score_text2, score_text_nochat])
1870
 
1871
+ def update_radio_chats(chat_state1):
1872
+ # reverse so newest at top
1873
+ choices = list(chat_state1.keys()).copy()
1874
+ choices.reverse()
1875
+ return gr.update(choices=choices, value=None)
1876
+
1877
+ clear_event = save_chat_btn.click(save_chat,
1878
+ inputs=[text_output, text_output2] + text_outputs + [chat_state],
1879
+ outputs=[text_output, text_output2] + text_outputs + [chat_state],
1880
+ api_name='save_chat' if allow_api else None) \
1881
  .then(update_radio_chats, inputs=chat_state, outputs=radio_chats,
1882
  api_name='update_chats' if allow_api else None) \
1883
  .then(clear_scores, outputs=[score_text, score_text2, score_text_nochat])
 
1911
  .then(clear_torch_cache)
1912
 
1913
  def load_model(model_name, lora_weights, server_name, model_state_old, prompt_type_old, load_8bit,
1914
+ use_gpu_id, gpu_id):
1915
  # ensure no API calls reach here
1916
  if is_public:
1917
  raise RuntimeError("Illegal access for %s" % model_name)
 
1955
  all_kwargs1 = all_kwargs.copy()
1956
  all_kwargs1['base_model'] = model_name.strip()
1957
  all_kwargs1['load_8bit'] = load_8bit
1958
+ all_kwargs1['use_gpu_id'] = use_gpu_id
1959
  all_kwargs1['gpu_id'] = int(gpu_id) # detranscribe
1960
  model_lower = model_name.strip().lower()
1961
  if model_lower in inv_prompt_type_to_model_lower:
 
2008
 
2009
  get_prompt_str_func1 = functools.partial(get_prompt_str, which=1)
2010
  get_prompt_str_func2 = functools.partial(get_prompt_str, which=2)
2011
+ prompt_type.change(fn=get_prompt_str_func1, inputs=[prompt_type, prompt_dict], outputs=prompt_dict, queue=False)
2012
+ prompt_type2.change(fn=get_prompt_str_func2, inputs=[prompt_type2, prompt_dict2], outputs=prompt_dict2,
2013
+ queue=False)
2014
 
2015
  def dropdown_prompt_type_list(x):
2016
  return gr.Dropdown.update(value=x)
 
2020
 
2021
  load_model_args = dict(fn=load_model,
2022
  inputs=[model_choice, lora_choice, server_choice, model_state, prompt_type,
2023
+ model_load8bit_checkbox, model_use_gpu_id_checkbox, model_gpu],
2024
  outputs=[model_state, model_used, lora_used, server_used,
2025
  # if prompt_type changes, prompt_dict will change via change rule
2026
  prompt_type, max_new_tokens, min_new_tokens,
 
2028
  prompt_update_args = dict(fn=dropdown_prompt_type_list, inputs=prompt_type, outputs=prompt_type)
2029
  chatbot_update_args = dict(fn=chatbot_list, inputs=[text_output, model_used], outputs=text_output)
2030
  nochat_update_args = dict(fn=chatbot_list, inputs=[text_output_nochat, model_used], outputs=text_output_nochat)
2031
+ load_model_event = load_model_button.click(**load_model_args,
2032
+ api_name='load_model' if allow_api and is_public else None) \
2033
+ .then(**prompt_update_args) \
2034
+ .then(**chatbot_update_args) \
2035
+ .then(**nochat_update_args) \
2036
+ .then(clear_torch_cache)
2037
 
2038
  load_model_args2 = dict(fn=load_model,
2039
  inputs=[model_choice2, lora_choice2, server_choice2, model_state2, prompt_type2,
2040
+ model_load8bit_checkbox2, model_use_gpu_id_checkbox2, model_gpu2],
2041
  outputs=[model_state2, model_used2, lora_used2, server_used2,
2042
  # if prompt_type2 changes, prompt_dict2 will change via change rule
2043
  prompt_type2, max_new_tokens2, min_new_tokens2
2044
  ])
2045
  prompt_update_args2 = dict(fn=dropdown_prompt_type_list, inputs=prompt_type2, outputs=prompt_type2)
2046
  chatbot_update_args2 = dict(fn=chatbot_list, inputs=[text_output2, model_used2], outputs=text_output2)
2047
+ load_model_event2 = load_model_button2.click(**load_model_args2,
2048
+ api_name='load_model2' if allow_api and is_public else None) \
2049
+ .then(**prompt_update_args2) \
2050
+ .then(**chatbot_update_args2) \
2051
+ .then(clear_torch_cache)
 
2052
 
2053
  def dropdown_model_lora_server_list(model_list0, model_x,
2054
  lora_list0, lora_x,
 
2097
  server_options_state],
2098
  queue=False)
2099
 
2100
+ go_event = go_btn.click(lambda: gr.update(visible=False), None, go_btn, api_name="go" if allow_api else None,
2101
+ queue=False) \
2102
  .then(lambda: gr.update(visible=True), None, normal_block, queue=False) \
2103
  .then(**load_model_args, queue=False).then(**prompt_update_args, queue=False)
2104
 
 
2166
  def get_hash():
2167
  return kwargs['git_hash']
2168
 
2169
+ system_event = system_btn3.click(get_hash,
2170
+ outputs=system_text3,
2171
+ api_name='system_hash' if allow_api else None,
2172
+ queue=False,
2173
+ )
 
 
 
 
 
 
 
 
 
 
 
 
2174
 
2175
  def count_chat_tokens(model_state1, chat1, prompt_type1, prompt_dict1,
2176
  memory_restriction_level1=0,
 
2198
  count_chat_tokens_func = functools.partial(count_chat_tokens,
2199
  memory_restriction_level1=memory_restriction_level,
2200
  keep_sources_in_context1=kwargs['keep_sources_in_context'])
2201
+ count_tokens_event = count_chat_tokens_btn.click(fn=count_chat_tokens,
2202
+ inputs=[model_state, text_output, prompt_type, prompt_dict],
2203
+ outputs=chat_token_count,
2204
+ api_name='count_tokens' if allow_api else None)
2205
+
2206
+ # don't pass text_output, don't want to clear output, just stop it
2207
+ # cancel only stops outer generation, not inner generation or non-generation
2208
+ stop_btn.click(lambda: None, None, None,
2209
+ cancels=submits1 + submits2 + submits3 + submits4 +
2210
+ [submit_event_nochat, submit_event_nochat2] +
2211
+ [eventdb1, eventdb2, eventdb3] +
2212
+ [eventdb7, eventdb8, eventdb9, eventdb12] +
2213
+ db_events +
2214
+ [clear_event] +
2215
+ [submit_event_nochat_api, submit_event_nochat] +
2216
+ [load_model_event, load_model_event2] +
2217
+ [count_tokens_event]
2218
+ ,
2219
+ queue=False, api_name='stop' if allow_api else None).then(clear_torch_cache, queue=False)
2220
 
2221
  demo.load(None, None, None, _js=get_dark_js() if kwargs['h2ocolors'] and False else None) # light best
2222
 
 
2289
 
2290
 
2291
  def get_sources(db1, langchain_mode, dbs=None, docs_state0=None):
2292
+ set_userid(db1)
2293
+
2294
  if langchain_mode in ['ChatLLM', 'LLM']:
2295
  source_files_added = "NA"
2296
  source_list = []
 
2321
  return sources_file, source_list
2322
 
2323
 
2324
+ def set_userid(db1):
2325
+ # can only call this after function called so for specific userr, not in gr.State() that occurs during app init
2326
+ assert db1 is not None and len(db1) == 2
2327
+ if db1[1] is None:
2328
+ # uuid in db is used as user ID
2329
+ db1[1] = str(uuid.uuid4())
2330
+
2331
+
2332
+ def update_user_db(file, db1, chunk, chunk_size, langchain_mode, dbs=None, **kwargs):
2333
+ set_userid(db1)
2334
+
2335
+ if file is None:
2336
+ raise RuntimeError("Don't use change, use input")
2337
+
2338
  try:
2339
+ return _update_user_db(file, db1=db1, chunk=chunk, chunk_size=chunk_size,
2340
+ langchain_mode=langchain_mode, dbs=dbs,
2341
+ **kwargs)
2342
  except BaseException as e:
2343
  print(traceback.format_exc(), flush=True)
2344
  # gradio has issues if except, so fail semi-gracefully, else would hang forever in processing textbox
 
2355
  </body>
2356
  </html>
2357
  """.format(ex_str)
2358
+ doc_exception_text = str(e)
2359
+ return None, langchain_mode, source_files_added, doc_exception_text
 
 
2360
  finally:
2361
  clear_torch_cache()
2362
 
2363
 
2364
  def get_lock_file(db1, langchain_mode):
2365
+ set_userid(db1)
2366
  assert len(db1) == 2 and db1[1] is not None and isinstance(db1[1], str)
2367
  user_id = db1[1]
2368
  base_path = 'locks'
 
2371
  return lock_file
2372
 
2373
 
2374
+ def _update_user_db(file,
2375
+ db1=None,
2376
+ chunk=None, chunk_size=None,
2377
+ dbs=None, db_type=None, langchain_mode='UserData',
2378
  user_path=None,
2379
  use_openai_embedding=None,
2380
  hf_embedding_model=None,
 
2385
  verbose=None,
2386
  is_url=None, is_txt=None,
2387
  n_jobs=-1):
2388
+ assert db1 is not None
2389
+ assert chunk is not None
2390
+ assert chunk_size is not None
2391
  assert use_openai_embedding is not None
2392
  assert hf_embedding_model is not None
2393
  assert caption_loader is not None
 
2396
  assert enable_ocr is not None
2397
  assert verbose is not None
2398
 
2399
+ set_userid(db1)
2400
+
2401
  if dbs is None:
2402
  dbs = {}
2403
  assert isinstance(dbs, dict), "Wrong type for dbs: %s" % str(type(dbs))
 
2412
  if not isinstance(file, (list, tuple, typing.Generator)) and isinstance(file, str):
2413
  file = [file]
2414
 
2415
+ if langchain_mode == LangChainMode.DISABLED.value:
2416
+ return None, langchain_mode, get_source_files(), ""
2417
+
2418
+ if langchain_mode in [LangChainMode.CHAT_LLM.value, LangChainMode.CHAT_LLM.value]:
2419
+ # then switch to MyData, so langchain_mode also becomes way to select where upload goes
2420
+ # but default to mydata if nothing chosen, since safest
2421
+ langchain_mode = LangChainMode.MY_DATA.value
2422
+
2423
  if langchain_mode == 'UserData' and user_path is not None:
2424
  # move temp files from gradio upload to stable location
2425
  for fili, fil in enumerate(file):
 
2448
  caption_loader=caption_loader,
2449
  )
2450
  exceptions = [x for x in sources if x.metadata.get('exception')]
2451
+ exceptions_strs = [x.metadata['exception'] for x in exceptions]
2452
  sources = [x for x in sources if 'exception' not in x.metadata]
2453
 
2454
  lock_file = get_lock_file(db1, langchain_mode)
 
2475
  if db is not None:
2476
  db1[0] = db
2477
  source_files_added = get_source_files(db=db1[0], exceptions=exceptions)
2478
+ return None, langchain_mode, source_files_added, '\n'.join(exceptions_strs)
2479
  else:
2480
  from gpt_langchain import get_persist_directory
2481
  persist_directory = get_persist_directory(langchain_mode)
 
2493
  hf_embedding_model=hf_embedding_model)
2494
  dbs[langchain_mode] = db
2495
  # NOTE we do not return db, because function call always same code path
2496
+ # return dbs[langchain_mode]
2497
  # db in this code path is updated in place
2498
  source_files_added = get_source_files(db=dbs[langchain_mode], exceptions=exceptions)
2499
+ return None, langchain_mode, source_files_added, '\n'.join(exceptions_strs)
2500
 
2501
 
2502
  def get_db(db1, langchain_mode, dbs=None):
gradio_themes.py CHANGED
@@ -133,6 +133,11 @@ class H2oTheme(Soft):
133
  background_fill_primary_dark="*block_background_fill",
134
  block_radius="0 0 8px 8px",
135
  checkbox_label_text_color_selected_dark='#000000',
 
 
 
 
 
136
  )
137
 
138
 
@@ -173,6 +178,9 @@ class SoftTheme(Soft):
173
  font=font,
174
  font_mono=font_mono,
175
  )
 
 
 
176
 
177
 
178
  h2o_logo = '<svg id="Layer_1" data-name="Layer 1" xmlns="http://www.w3.org/2000/svg" width="100%" height="100%"' \
 
133
  background_fill_primary_dark="*block_background_fill",
134
  block_radius="0 0 8px 8px",
135
  checkbox_label_text_color_selected_dark='#000000',
136
+ #checkbox_label_text_size="*text_xs", # too small for iPhone etc. but good if full large screen zoomed to fit
137
+ checkbox_label_text_size="*text_sm",
138
+ #radio_circle="""url("data:image/svg+xml,%3csvg viewBox='0 0 32 32' fill='white' xmlns='http://www.w3.org/2000/svg'%3e%3ccircle cx='32' cy='32' r='1'/%3e%3c/svg%3e")""",
139
+ #checkbox_border_width=1,
140
+ #heckbox_border_width_dark=1,
141
  )
142
 
143
 
 
178
  font=font,
179
  font_mono=font_mono,
180
  )
181
+ super().set(
182
+ checkbox_label_text_size="*text_sm",
183
+ )
184
 
185
 
186
  h2o_logo = '<svg id="Layer_1" data-name="Layer 1" xmlns="http://www.w3.org/2000/svg" width="100%" height="100%"' \
gradio_utils/__pycache__/css.cpython-310.pyc CHANGED
Binary files a/gradio_utils/__pycache__/css.cpython-310.pyc and b/gradio_utils/__pycache__/css.cpython-310.pyc differ
 
gradio_utils/__pycache__/prompt_form.cpython-310.pyc CHANGED
Binary files a/gradio_utils/__pycache__/prompt_form.cpython-310.pyc and b/gradio_utils/__pycache__/prompt_form.cpython-310.pyc differ
 
gradio_utils/css.py CHANGED
@@ -12,7 +12,10 @@ def get_css(kwargs) -> str:
12
 
13
 
14
  def make_css_base() -> str:
15
- return """
 
 
 
16
  @import url('https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600&display=swap');
17
 
18
  body.dark{#warning {background-color: #555555};}
 
12
 
13
 
14
  def make_css_base() -> str:
15
+ css1 = """
16
+ #col_container {margin-left: auto; margin-right: auto; text-align: left;}
17
+ """
18
+ return css1 + """
19
  @import url('https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600&display=swap');
20
 
21
  body.dark{#warning {background-color: #555555};}
gradio_utils/prompt_form.py CHANGED
@@ -93,30 +93,3 @@ def make_chatbots(output_label0, output_label0_model2, **kwargs):
93
  text_output2 = gr.Chatbot(label=output_label0_model2,
94
  visible=False and not kwargs['model_lock'], height=kwargs['height'] or 400)
95
  return text_output, text_output2, text_outputs
96
-
97
-
98
- def make_prompt_form(kwargs, LangChainMode):
99
- if kwargs['langchain_mode'] != LangChainMode.DISABLED.value:
100
- extra_prompt_form = ". For summarization, empty submission uses first top_k_docs documents."
101
- else:
102
- extra_prompt_form = ""
103
- if kwargs['input_lines'] > 1:
104
- instruction_label = "Shift-Enter to Submit, Enter for more lines%s" % extra_prompt_form
105
- else:
106
- instruction_label = "Enter to Submit, Shift-Enter for more lines%s" % extra_prompt_form
107
-
108
- with gr.Row():#elem_id='prompt-form-area'):
109
- with gr.Column(scale=50):
110
- instruction = gr.Textbox(
111
- lines=kwargs['input_lines'],
112
- label='Ask anything',
113
- placeholder=instruction_label,
114
- info=None,
115
- elem_id='prompt-form',
116
- container=True,
117
- )
118
- with gr.Row():
119
- submit = gr.Button(value='Submit', variant='primary', scale=0, size='sm')
120
- stop_btn = gr.Button(value="Stop", variant='secondary', scale=0, size='sm')
121
-
122
- return instruction, submit, stop_btn
 
93
  text_output2 = gr.Chatbot(label=output_label0_model2,
94
  visible=False and not kwargs['model_lock'], height=kwargs['height'] or 400)
95
  return text_output, text_output2, text_outputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
loaders.py CHANGED
@@ -1,40 +1,48 @@
1
- def get_loaders(model_name, reward_type, llama_type=None):
 
 
 
2
  # NOTE: Some models need specific new prompt_type
3
  # E.g. t5_xxl_true_nli_mixture has input format: "premise: PREMISE_TEXT hypothesis: HYPOTHESIS_TEXT".)
 
 
 
 
 
 
4
  if llama_type is None:
5
  llama_type = "llama" in model_name.lower()
6
  if llama_type:
7
  from transformers import LlamaForCausalLM, LlamaTokenizer
8
- model_loader = LlamaForCausalLM
9
- tokenizer_loader = LlamaTokenizer
10
  elif 'distilgpt2' in model_name.lower():
11
  from transformers import AutoModelForCausalLM, AutoTokenizer
12
- return AutoModelForCausalLM, AutoTokenizer
13
  elif 'gpt2' in model_name.lower():
14
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
15
- return GPT2LMHeadModel, GPT2Tokenizer
16
  elif 'mbart-' in model_name.lower():
17
  from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
18
- return MBartForConditionalGeneration, MBart50TokenizerFast
19
  elif 't5' == model_name.lower() or \
20
  't5-' in model_name.lower() or \
21
  'flan-' in model_name.lower():
22
  from transformers import AutoTokenizer, T5ForConditionalGeneration
23
- return T5ForConditionalGeneration, AutoTokenizer
24
  elif 'bigbird' in model_name:
25
  from transformers import BigBirdPegasusForConditionalGeneration, AutoTokenizer
26
- return BigBirdPegasusForConditionalGeneration, AutoTokenizer
27
  elif 'bart-large-cnn-samsum' in model_name or 'flan-t5-base-samsum' in model_name:
28
  from transformers import pipeline
29
  return pipeline, "summarization"
30
  elif reward_type or 'OpenAssistant/reward-model'.lower() in model_name.lower():
31
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
32
- return AutoModelForSequenceClassification, AutoTokenizer
33
  else:
34
  from transformers import AutoTokenizer, AutoModelForCausalLM
35
  model_loader = AutoModelForCausalLM
36
  tokenizer_loader = AutoTokenizer
37
- return model_loader, tokenizer_loader
38
 
39
 
40
  def get_tokenizer(tokenizer_loader, tokenizer_base_model, local_files_only, resume_download, use_auth_token):
 
1
+ import functools
2
+
3
+
4
+ def get_loaders(model_name, reward_type, llama_type=None, load_gptq=''):
5
  # NOTE: Some models need specific new prompt_type
6
  # E.g. t5_xxl_true_nli_mixture has input format: "premise: PREMISE_TEXT hypothesis: HYPOTHESIS_TEXT".)
7
+ if load_gptq:
8
+ from transformers import AutoTokenizer
9
+ from auto_gptq import AutoGPTQForCausalLM
10
+ use_triton = False
11
+ functools.partial(AutoGPTQForCausalLM.from_quantized, quantize_config=None, use_triton=use_triton)
12
+ return AutoGPTQForCausalLM.from_quantized, AutoTokenizer
13
  if llama_type is None:
14
  llama_type = "llama" in model_name.lower()
15
  if llama_type:
16
  from transformers import LlamaForCausalLM, LlamaTokenizer
17
+ return LlamaForCausalLM.from_pretrained, LlamaTokenizer
 
18
  elif 'distilgpt2' in model_name.lower():
19
  from transformers import AutoModelForCausalLM, AutoTokenizer
20
+ return AutoModelForCausalLM.from_pretrained, AutoTokenizer
21
  elif 'gpt2' in model_name.lower():
22
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
23
+ return GPT2LMHeadModel.from_pretrained, GPT2Tokenizer
24
  elif 'mbart-' in model_name.lower():
25
  from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
26
+ return MBartForConditionalGeneration.from_pretrained, MBart50TokenizerFast
27
  elif 't5' == model_name.lower() or \
28
  't5-' in model_name.lower() or \
29
  'flan-' in model_name.lower():
30
  from transformers import AutoTokenizer, T5ForConditionalGeneration
31
+ return T5ForConditionalGeneration.from_pretrained, AutoTokenizer
32
  elif 'bigbird' in model_name:
33
  from transformers import BigBirdPegasusForConditionalGeneration, AutoTokenizer
34
+ return BigBirdPegasusForConditionalGeneration.from_pretrained, AutoTokenizer
35
  elif 'bart-large-cnn-samsum' in model_name or 'flan-t5-base-samsum' in model_name:
36
  from transformers import pipeline
37
  return pipeline, "summarization"
38
  elif reward_type or 'OpenAssistant/reward-model'.lower() in model_name.lower():
39
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
40
+ return AutoModelForSequenceClassification.from_pretrained, AutoTokenizer
41
  else:
42
  from transformers import AutoTokenizer, AutoModelForCausalLM
43
  model_loader = AutoModelForCausalLM
44
  tokenizer_loader = AutoTokenizer
45
+ return model_loader.from_pretrained, tokenizer_loader
46
 
47
 
48
  def get_tokenizer(tokenizer_loader, tokenizer_base_model, local_files_only, resume_download, use_auth_token):
prompter.py CHANGED
@@ -23,9 +23,6 @@ prompt_type_to_model_name = {
23
  'gpt2',
24
  'distilgpt2',
25
  'mosaicml/mpt-7b-storywriter',
26
- 'mosaicml/mpt-7b-instruct', # internal code handles instruct
27
- 'mosaicml/mpt-7b-chat', # NC, internal code handles instruct
28
- 'mosaicml/mpt-30b-instruct', # internal code handles instruct
29
  ],
30
  'gptj': ['gptj', 'gpt4all_llama'],
31
  'prompt_answer': [
@@ -41,6 +38,7 @@ prompt_type_to_model_name = {
41
  'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-40b-v2',
42
  'h2oai/h2ogpt-gm-oasst1-en-xgen-7b-8k',
43
  'h2oai/h2ogpt-gm-oasst1-multilang-xgen-7b-8k',
 
44
  ],
45
  'prompt_answer_openllama': [
46
  'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt',
@@ -49,7 +47,7 @@ prompt_type_to_model_name = {
49
  'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b',
50
  'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-13b',
51
  ],
52
- 'instruct': [],
53
  'instruct_with_end': ['databricks/dolly-v2-12b'],
54
  'quality': [],
55
  'human_bot': [
@@ -74,8 +72,11 @@ prompt_type_to_model_name = {
74
  "wizard_mega": ['openaccess-ai-collective/wizard-mega-13b'],
75
  "instruct_simple": ['JosephusCheung/Guanaco'],
76
  "wizard_vicuna": ['ehartford/Wizard-Vicuna-13B-Uncensored'],
77
- "wizard2": ['llama', 'mosaicml/mpt-30b-instruct'],
 
 
78
  "vicuna11": ['lmsys/vicuna-33b-v1.3'],
 
79
  # could be plain, but default is correct prompt_type for default TheBloke model ggml-wizardLM-7B.q4_2.bin
80
  }
81
  if os.getenv('OPENAI_API_KEY'):
@@ -293,7 +294,7 @@ Current Time: {}
293
  humanstr = prompt_tokens
294
  botstr = answer_tokens
295
  terminate_response = [humanstr, PreResponse, eos]
296
- chat_sep = ''
297
  chat_turn_sep = eos
298
  elif prompt_type in [PromptType.prompt_answer_openllama.value, str(PromptType.prompt_answer_openllama.value),
299
  PromptType.prompt_answer_openllama.name]:
@@ -309,7 +310,7 @@ Current Time: {}
309
  humanstr = prompt_tokens
310
  botstr = answer_tokens
311
  terminate_response = [humanstr, PreResponse, eos]
312
- chat_sep = ''
313
  chat_turn_sep = eos
314
  elif prompt_type in [PromptType.open_assistant.value, str(PromptType.open_assistant.value),
315
  PromptType.open_assistant.name]:
@@ -520,6 +521,67 @@ ASSISTANT:
520
  # normally LLM adds space after this, because was how trained.
521
  # if add space here, non-unique tokenization will often make LLM produce wrong output
522
  PreResponse = PreResponse
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
523
  else:
524
  raise RuntimeError("No such prompt_type=%s" % prompt_type)
525
 
 
23
  'gpt2',
24
  'distilgpt2',
25
  'mosaicml/mpt-7b-storywriter',
 
 
 
26
  ],
27
  'gptj': ['gptj', 'gpt4all_llama'],
28
  'prompt_answer': [
 
38
  'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-40b-v2',
39
  'h2oai/h2ogpt-gm-oasst1-en-xgen-7b-8k',
40
  'h2oai/h2ogpt-gm-oasst1-multilang-xgen-7b-8k',
41
+ 'TheBloke/h2ogpt-gm-oasst1-en-2048-falcon-40b-v2-GPTQ',
42
  ],
43
  'prompt_answer_openllama': [
44
  'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt',
 
47
  'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b',
48
  'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-13b',
49
  ],
50
+ 'instruct': ['TheBloke/llama-30b-supercot-SuperHOT-8K-fp16'], # https://huggingface.co/TheBloke/llama-30b-supercot-SuperHOT-8K-fp16#prompting
51
  'instruct_with_end': ['databricks/dolly-v2-12b'],
52
  'quality': [],
53
  'human_bot': [
 
72
  "wizard_mega": ['openaccess-ai-collective/wizard-mega-13b'],
73
  "instruct_simple": ['JosephusCheung/Guanaco'],
74
  "wizard_vicuna": ['ehartford/Wizard-Vicuna-13B-Uncensored'],
75
+ "wizard2": ['llama'],
76
+ "mptinstruct": ['mosaicml/mpt-30b-instruct', 'mosaicml/mpt-7b-instruct', 'mosaicml/mpt-30b-instruct'],
77
+ "mptchat": ['mosaicml/mpt-7b-chat', 'mosaicml/mpt-30b-chat', 'TheBloke/mpt-30B-chat-GGML'],
78
  "vicuna11": ['lmsys/vicuna-33b-v1.3'],
79
+ "falcon": ['tiiuae/falcon-40b-instruct', 'tiiuae/falcon-40b', 'tiiuae/falcon-7b-instruct', 'tiiuae/falcon-7b'],
80
  # could be plain, but default is correct prompt_type for default TheBloke model ggml-wizardLM-7B.q4_2.bin
81
  }
82
  if os.getenv('OPENAI_API_KEY'):
 
294
  humanstr = prompt_tokens
295
  botstr = answer_tokens
296
  terminate_response = [humanstr, PreResponse, eos]
297
+ chat_sep = eos
298
  chat_turn_sep = eos
299
  elif prompt_type in [PromptType.prompt_answer_openllama.value, str(PromptType.prompt_answer_openllama.value),
300
  PromptType.prompt_answer_openllama.name]:
 
310
  humanstr = prompt_tokens
311
  botstr = answer_tokens
312
  terminate_response = [humanstr, PreResponse, eos]
313
+ chat_sep = eos
314
  chat_turn_sep = eos
315
  elif prompt_type in [PromptType.open_assistant.value, str(PromptType.open_assistant.value),
316
  PromptType.open_assistant.name]:
 
521
  # normally LLM adds space after this, because was how trained.
522
  # if add space here, non-unique tokenization will often make LLM produce wrong output
523
  PreResponse = PreResponse
524
+ elif prompt_type in [PromptType.mptinstruct.value, str(PromptType.mptinstruct.value),
525
+ PromptType.mptinstruct.name]:
526
+ # https://huggingface.co/mosaicml/mpt-30b-instruct#formatting
527
+ promptA = promptB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n' if not (
528
+ chat and reduced) else ''
529
+
530
+ PreInstruct = """
531
+ ### Instruction
532
+ """
533
+
534
+ PreInput = """
535
+ ### Input
536
+ """
537
+
538
+ PreResponse = """
539
+ ### Response
540
+ """
541
+ terminate_response = None
542
+ chat_turn_sep = chat_sep = '\n'
543
+ humanstr = PreInstruct
544
+ botstr = PreResponse
545
+ elif prompt_type in [PromptType.mptchat.value, str(PromptType.mptchat.value),
546
+ PromptType.mptchat.name]:
547
+ # https://huggingface.co/TheBloke/mpt-30B-chat-GGML#prompt-template
548
+ promptA = promptB = """<|im_start|>system\nA conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.\n<|im_end|>""" if not (
549
+ chat and reduced) else ''
550
+
551
+ PreInstruct = """<|im_start|>user
552
+ """
553
+
554
+ PreInput = None
555
+
556
+ PreResponse = """<|im_end|><|im_start|>assistant
557
+ """
558
+ terminate_response = ['<|im_end|>']
559
+ chat_sep = ''
560
+ chat_turn_sep = '<|im_end|>'
561
+ humanstr = PreInstruct
562
+ botstr = PreResponse
563
+ elif prompt_type in [PromptType.falcon.value, str(PromptType.falcon.value),
564
+ PromptType.falcon.name]:
565
+ promptA = promptB = "" if not (chat and reduced) else ''
566
+
567
+ PreInstruct = """User: """
568
+
569
+ PreInput = None
570
+
571
+ PreResponse = """Assistant:"""
572
+ terminate_response = ['\nUser', "<|endoftext|>"]
573
+ chat_sep = '\n\n'
574
+ chat_turn_sep = '\n\n'
575
+ humanstr = PreInstruct
576
+ botstr = PreResponse
577
+ if making_context:
578
+ # when making context, want it to appear as-if LLM generated, which starts with space after :
579
+ PreResponse = 'Assistant: '
580
+ else:
581
+ # normally LLM adds space after this, because was how trained.
582
+ # if add space here, non-unique tokenization will often make LLM produce wrong output
583
+ PreResponse = PreResponse
584
+ # generates_leading_space = True
585
  else:
586
  raise RuntimeError("No such prompt_type=%s" % prompt_type)
587
 
requirements.txt CHANGED
@@ -6,7 +6,7 @@ huggingface_hub==0.15.1
6
  appdirs==1.4.4
7
  fire==0.5.0
8
  docutils==0.20.1
9
- torch==2.0.1
10
  evaluate==0.4.0
11
  rouge_score==0.1.2
12
  sacrebleu==2.3.1
@@ -19,7 +19,7 @@ matplotlib==3.7.1
19
  loralib==0.1.1
20
  bitsandbytes==0.39.0
21
  accelerate==0.20.3
22
- git+https://github.com/huggingface/peft.git@0b62b4378b4ce9367932c73540349da9a41bdea8
23
  transformers==4.30.2
24
  tokenizers==0.13.3
25
  APScheduler==3.10.1
@@ -45,8 +45,8 @@ pytest-xdist==3.2.1
45
  nltk==3.8.1
46
  textstat==0.7.3
47
  # pandoc==2.3
48
- #pypandoc==1.11
49
- pypandoc_binary==1.11
50
  openpyxl==3.1.2
51
  lm_dataformat==0.0.20
52
  bioc==2.0
 
6
  appdirs==1.4.4
7
  fire==0.5.0
8
  docutils==0.20.1
9
+ torch==2.0.1; sys_platform != "darwin" and platform_machine != "arm64"
10
  evaluate==0.4.0
11
  rouge_score==0.1.2
12
  sacrebleu==2.3.1
 
19
  loralib==0.1.1
20
  bitsandbytes==0.39.0
21
  accelerate==0.20.3
22
+ git+https://github.com/huggingface/peft.git@06fd06a4d2e8ed8c3a253c67d9c3cb23e0f497ad
23
  transformers==4.30.2
24
  tokenizers==0.13.3
25
  APScheduler==3.10.1
 
45
  nltk==3.8.1
46
  textstat==0.7.3
47
  # pandoc==2.3
48
+ pypandoc==1.11; sys_platform == "darwin" and platform_machine == "arm64"
49
+ pypandoc_binary==1.11; platform_machine == "x86_64"
50
  openpyxl==3.1.2
51
  lm_dataformat==0.0.20
52
  bioc==2.0
utils.py CHANGED
@@ -97,6 +97,8 @@ def get_device():
97
  import torch
98
  if torch.cuda.is_available():
99
  device = "cuda"
 
 
100
  else:
101
  device = "cpu"
102
 
@@ -138,7 +140,7 @@ def system_info():
138
  gpu_memory_frac_dict = {k: gpu_memory_free_dict[k] / gpu_memory_total_dict[k] for k in gpu_memory_total_dict}
139
  for k, v in gpu_memory_frac_dict.items():
140
  system[f'GPU_M/%s' % k] = v
141
- except ModuleNotFoundError:
142
  pass
143
  system['hash'] = get_githash()
144
 
@@ -926,3 +928,60 @@ class FakeTokenizer:
926
 
927
  def __call__(self, x, *args, **kwargs):
928
  return self.encode(x, *args, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  import torch
98
  if torch.cuda.is_available():
99
  device = "cuda"
100
+ elif torch.backends.mps.is_built():
101
+ device = "mps"
102
  else:
103
  device = "cpu"
104
 
 
140
  gpu_memory_frac_dict = {k: gpu_memory_free_dict[k] / gpu_memory_total_dict[k] for k in gpu_memory_total_dict}
141
  for k, v in gpu_memory_frac_dict.items():
142
  system[f'GPU_M/%s' % k] = v
143
+ except (KeyError, ModuleNotFoundError):
144
  pass
145
  system['hash'] = get_githash()
146
 
 
928
 
929
  def __call__(self, x, *args, **kwargs):
930
  return self.encode(x, *args, **kwargs)
931
+
932
+
933
+ def get_local_ip():
934
+ import socket
935
+ s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
936
+ try:
937
+ # doesn't even have to be reachable
938
+ s.connect(('10.255.255.255', 1))
939
+ IP = s.getsockname()[0]
940
+ except Exception:
941
+ IP = '127.0.0.1'
942
+ finally:
943
+ s.close()
944
+ return IP
945
+
946
+
947
+ try:
948
+ assert pkg_resources.get_distribution('langchain') is not None
949
+ have_langchain = True
950
+ except (pkg_resources.DistributionNotFound, AssertionError):
951
+ have_langchain = False
952
+
953
+
954
+ import distutils.spawn
955
+
956
+ have_tesseract = distutils.spawn.find_executable("tesseract")
957
+ have_libreoffice = distutils.spawn.find_executable("libreoffice")
958
+
959
+ import pkg_resources
960
+
961
+ try:
962
+ assert pkg_resources.get_distribution('arxiv') is not None
963
+ assert pkg_resources.get_distribution('pymupdf') is not None
964
+ have_arxiv = True
965
+ except (pkg_resources.DistributionNotFound, AssertionError):
966
+ have_arxiv = False
967
+
968
+ try:
969
+ assert pkg_resources.get_distribution('pymupdf') is not None
970
+ have_pymupdf = True
971
+ except (pkg_resources.DistributionNotFound, AssertionError):
972
+ have_pymupdf = False
973
+
974
+ try:
975
+ assert pkg_resources.get_distribution('selenium') is not None
976
+ have_selenium = True
977
+ except (pkg_resources.DistributionNotFound, AssertionError):
978
+ have_selenium = False
979
+
980
+ try:
981
+ assert pkg_resources.get_distribution('playwright') is not None
982
+ have_playwright = True
983
+ except (pkg_resources.DistributionNotFound, AssertionError):
984
+ have_playwright = False
985
+
986
+ # disable, hangs too often
987
+ have_playwright = False