Spaces:
Runtime error
Runtime error
pseudotensor
commited on
Commit
·
edf6dca
1
Parent(s):
ce8ae40
Update with h2oGPT hash a9971663accc92add02bde0be7622726ef2db350
Browse files- client_test.py +5 -3
- enums.py +12 -9
- evaluate_params.py +1 -1
- gen.py +117 -50
- gpt_langchain.py +67 -92
- gradio_runner.py +572 -446
- gradio_themes.py +8 -0
- gradio_utils/__pycache__/css.cpython-310.pyc +0 -0
- gradio_utils/__pycache__/prompt_form.cpython-310.pyc +0 -0
- gradio_utils/css.py +4 -1
- gradio_utils/prompt_form.py +0 -27
- loaders.py +18 -10
- prompter.py +69 -7
- requirements.txt +4 -4
- utils.py +60 -1
client_test.py
CHANGED
@@ -7,7 +7,7 @@ python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b
|
|
7 |
|
8 |
NOTE: For private models, add --use-auth_token=True
|
9 |
|
10 |
-
NOTE: --
|
11 |
Currently, this will force model to be on a single GPU.
|
12 |
|
13 |
Then run this client as:
|
@@ -98,7 +98,8 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False,
|
|
98 |
top_k_docs=top_k_docs,
|
99 |
chunk=True,
|
100 |
chunk_size=512,
|
101 |
-
|
|
|
102 |
)
|
103 |
from evaluate_params import eval_func_param_names
|
104 |
assert len(set(eval_func_param_names).difference(set(list(kwargs.keys())))) == 0
|
@@ -203,7 +204,8 @@ def run_client_nochat_api_lean_morestuff(prompt, prompt_type='human_bot', max_ne
|
|
203 |
langchain_mode='Disabled',
|
204 |
langchain_action=LangChainAction.QUERY.value,
|
205 |
top_k_docs=4,
|
206 |
-
|
|
|
207 |
)
|
208 |
|
209 |
api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
|
|
|
7 |
|
8 |
NOTE: For private models, add --use-auth_token=True
|
9 |
|
10 |
+
NOTE: --use_gpu_id=True (default) must be used for multi-GPU in case see failures with cuda:x cuda:y mismatches.
|
11 |
Currently, this will force model to be on a single GPU.
|
12 |
|
13 |
Then run this client as:
|
|
|
98 |
top_k_docs=top_k_docs,
|
99 |
chunk=True,
|
100 |
chunk_size=512,
|
101 |
+
document_subset=DocumentChoices.Relevant.name,
|
102 |
+
document_choice=[],
|
103 |
)
|
104 |
from evaluate_params import eval_func_param_names
|
105 |
assert len(set(eval_func_param_names).difference(set(list(kwargs.keys())))) == 0
|
|
|
204 |
langchain_mode='Disabled',
|
205 |
langchain_action=LangChainAction.QUERY.value,
|
206 |
top_k_docs=4,
|
207 |
+
document_subset=DocumentChoices.Relevant.name,
|
208 |
+
document_choice=[],
|
209 |
)
|
210 |
|
211 |
api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
|
enums.py
CHANGED
@@ -28,16 +28,21 @@ class PromptType(Enum):
|
|
28 |
gptj = 22
|
29 |
prompt_answer_openllama = 23
|
30 |
vicuna11 = 24
|
|
|
|
|
|
|
31 |
|
32 |
|
33 |
class DocumentChoices(Enum):
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
Just_LLM = 3
|
38 |
|
39 |
|
40 |
-
non_query_commands = [
|
|
|
|
|
|
|
41 |
|
42 |
|
43 |
class LangChainMode(Enum):
|
@@ -60,7 +65,7 @@ class LangChainAction(Enum):
|
|
60 |
|
61 |
QUERY = "Query"
|
62 |
# WIP:
|
63 |
-
#SUMMARIZE_MAP = "Summarize_map_reduce"
|
64 |
SUMMARIZE_MAP = "Summarize"
|
65 |
SUMMARIZE_ALL = "Summarize_all"
|
66 |
SUMMARIZE_REFINE = "Summarize_refine"
|
@@ -68,7 +73,6 @@ class LangChainAction(Enum):
|
|
68 |
|
69 |
no_server_str = no_lora_str = no_model_str = '[None/Remove]'
|
70 |
|
71 |
-
|
72 |
# from site-packages/langchain/llms/openai.py
|
73 |
# but needed since ChatOpenAI doesn't have this information
|
74 |
model_token_mapping = {
|
@@ -77,7 +81,7 @@ model_token_mapping = {
|
|
77 |
"gpt-4-32k": 32768,
|
78 |
"gpt-4-32k-0314": 32768,
|
79 |
"gpt-3.5-turbo": 4096,
|
80 |
-
"gpt-3.5-turbo-16k": 16*1024,
|
81 |
"gpt-3.5-turbo-0301": 4096,
|
82 |
"text-ada-001": 2049,
|
83 |
"ada": 2049,
|
@@ -94,6 +98,5 @@ model_token_mapping = {
|
|
94 |
"code-cushman-001": 2048,
|
95 |
}
|
96 |
|
97 |
-
|
98 |
source_prefix = "Sources [Score | Link]:"
|
99 |
source_postfix = "End Sources<p>"
|
|
|
28 |
gptj = 22
|
29 |
prompt_answer_openllama = 23
|
30 |
vicuna11 = 24
|
31 |
+
mptinstruct = 25
|
32 |
+
mptchat = 26
|
33 |
+
falcon = 27
|
34 |
|
35 |
|
36 |
class DocumentChoices(Enum):
|
37 |
+
Relevant = 0
|
38 |
+
Sources = 1
|
39 |
+
All = 2
|
|
|
40 |
|
41 |
|
42 |
+
non_query_commands = [
|
43 |
+
DocumentChoices.Sources.name,
|
44 |
+
DocumentChoices.All.name
|
45 |
+
]
|
46 |
|
47 |
|
48 |
class LangChainMode(Enum):
|
|
|
65 |
|
66 |
QUERY = "Query"
|
67 |
# WIP:
|
68 |
+
# SUMMARIZE_MAP = "Summarize_map_reduce"
|
69 |
SUMMARIZE_MAP = "Summarize"
|
70 |
SUMMARIZE_ALL = "Summarize_all"
|
71 |
SUMMARIZE_REFINE = "Summarize_refine"
|
|
|
73 |
|
74 |
no_server_str = no_lora_str = no_model_str = '[None/Remove]'
|
75 |
|
|
|
76 |
# from site-packages/langchain/llms/openai.py
|
77 |
# but needed since ChatOpenAI doesn't have this information
|
78 |
model_token_mapping = {
|
|
|
81 |
"gpt-4-32k": 32768,
|
82 |
"gpt-4-32k-0314": 32768,
|
83 |
"gpt-3.5-turbo": 4096,
|
84 |
+
"gpt-3.5-turbo-16k": 16 * 1024,
|
85 |
"gpt-3.5-turbo-0301": 4096,
|
86 |
"text-ada-001": 2049,
|
87 |
"ada": 2049,
|
|
|
98 |
"code-cushman-001": 2048,
|
99 |
}
|
100 |
|
|
|
101 |
source_prefix = "Sources [Score | Link]:"
|
102 |
source_postfix = "End Sources<p>"
|
evaluate_params.py
CHANGED
@@ -34,6 +34,7 @@ eval_func_param_names = ['instruction',
|
|
34 |
'top_k_docs',
|
35 |
'chunk',
|
36 |
'chunk_size',
|
|
|
37 |
'document_choice',
|
38 |
]
|
39 |
|
@@ -43,5 +44,4 @@ for k in no_default_param_names:
|
|
43 |
if k in eval_func_param_names_defaults:
|
44 |
eval_func_param_names_defaults.remove(k)
|
45 |
|
46 |
-
|
47 |
eval_extra_columns = ['prompt', 'response', 'score']
|
|
|
34 |
'top_k_docs',
|
35 |
'chunk',
|
36 |
'chunk_size',
|
37 |
+
'document_subset',
|
38 |
'document_choice',
|
39 |
]
|
40 |
|
|
|
44 |
if k in eval_func_param_names_defaults:
|
45 |
eval_func_param_names_defaults.remove(k)
|
46 |
|
|
|
47 |
eval_extra_columns = ['prompt', 'response', 'score']
|
gen.py
CHANGED
@@ -32,7 +32,8 @@ from enums import DocumentChoices, LangChainMode, no_lora_str, model_token_mappi
|
|
32 |
source_postfix, LangChainAction
|
33 |
from loaders import get_loaders
|
34 |
from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread, get_githash, \
|
35 |
-
import_matplotlib, get_device, makedirs, get_kwargs, start_faulthandler, get_hf_server, FakeTokenizer, remove
|
|
|
36 |
|
37 |
start_faulthandler()
|
38 |
import_matplotlib()
|
@@ -60,7 +61,9 @@ def main(
|
|
60 |
load_8bit: bool = False,
|
61 |
load_4bit: bool = False,
|
62 |
load_half: bool = True,
|
63 |
-
|
|
|
|
|
64 |
base_model: str = '',
|
65 |
tokenizer_base_model: str = '',
|
66 |
lora_weights: str = "",
|
@@ -91,7 +94,7 @@ def main(
|
|
91 |
memory_restriction_level: int = None,
|
92 |
debug: bool = False,
|
93 |
save_dir: str = None,
|
94 |
-
share: bool =
|
95 |
local_files_only: bool = False,
|
96 |
resume_download: bool = True,
|
97 |
use_auth_token: Union[str, bool] = False,
|
@@ -138,14 +141,15 @@ def main(
|
|
138 |
eval_prompts_only_seed: int = 1234,
|
139 |
eval_as_output: bool = False,
|
140 |
|
141 |
-
langchain_mode: str =
|
142 |
langchain_action: str = LangChainAction.QUERY.value,
|
143 |
force_langchain_evaluate: bool = False,
|
144 |
visible_langchain_modes: list = ['UserData', 'MyData'],
|
145 |
# WIP:
|
146 |
# visible_langchain_actions: list = langchain_actions.copy(),
|
147 |
visible_langchain_actions: list = [LangChainAction.QUERY.value, LangChainAction.SUMMARIZE_MAP.value],
|
148 |
-
|
|
|
149 |
user_path: str = None,
|
150 |
detect_user_path_changes_every_query: bool = False,
|
151 |
load_db_if_exists: bool = True,
|
@@ -177,11 +181,13 @@ def main(
|
|
177 |
:param load_8bit: load model in 8-bit using bitsandbytes
|
178 |
:param load_4bit: load model in 4-bit using bitsandbytes
|
179 |
:param load_half: load model in float16
|
180 |
-
:param
|
|
|
|
|
181 |
:param base_model: model HF-type name. If use --base_model to preload model, cannot unload in gradio in models tab
|
182 |
:param tokenizer_base_model: tokenizer HF-type name. Usually not required, inferred from base_model.
|
183 |
:param lora_weights: LORA weights path/HF link
|
184 |
-
:param gpu_id: if
|
185 |
:param compile_model Whether to compile the model
|
186 |
:param use_cache: Whether to use caching in model (some models fail when multiple threads use)
|
187 |
:param inference_server: Consume base_model as type of model at this address
|
@@ -289,7 +295,8 @@ def main(
|
|
289 |
Default: If only want to consume local files, e.g. prepared by make_db.py, only include ['UserData']
|
290 |
FIXME: Avoid 'All' for now, not implemented
|
291 |
:param visible_langchain_actions: Which actions to allow
|
292 |
-
:param
|
|
|
293 |
:param load_db_if_exists: Whether to load chroma db if exists or re-generate db
|
294 |
:param keep_sources_in_context: Whether to keep url sources in context, not helpful usually
|
295 |
:param db_type: 'faiss' for in-memory or 'chroma' or 'weaviate' for persisted on disk
|
@@ -379,10 +386,12 @@ def main(
|
|
379 |
# allow enabling langchain via ENV
|
380 |
# FIRST PLACE where LangChain referenced, but no imports related to it
|
381 |
langchain_mode = os.environ.get("LANGCHAIN_MODE", langchain_mode)
|
382 |
-
|
|
|
383 |
visible_langchain_modes = ast.literal_eval(os.environ.get("visible_langchain_modes", str(visible_langchain_modes)))
|
384 |
if langchain_mode not in visible_langchain_modes and langchain_mode in langchain_modes:
|
385 |
-
|
|
|
386 |
|
387 |
assert langchain_action in langchain_actions, "Invalid langchain_action %s" % langchain_action
|
388 |
|
@@ -392,6 +401,25 @@ def main(
|
|
392 |
if LangChainMode.USER_DATA.value not in visible_langchain_modes:
|
393 |
allow_upload_to_user_data = False
|
394 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
395 |
if is_public:
|
396 |
allow_upload_to_user_data = False
|
397 |
input_lines = 1 # ensure set, for ease of use
|
@@ -458,7 +486,9 @@ def main(
|
|
458 |
load_8bit = False
|
459 |
load_4bit = False
|
460 |
load_half = False
|
461 |
-
|
|
|
|
|
462 |
torch.backends.cudnn.benchmark = True
|
463 |
torch.backends.cudnn.enabled = False
|
464 |
torch.set_default_dtype(torch.float32)
|
@@ -714,7 +744,9 @@ def get_config(base_model,
|
|
714 |
return config, model
|
715 |
|
716 |
|
717 |
-
def get_non_lora_model(base_model, model_loader, load_half,
|
|
|
|
|
718 |
config, model,
|
719 |
gpu_id=0,
|
720 |
):
|
@@ -761,16 +793,25 @@ def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward
|
|
761 |
load_in_8bit = model_kwargs.get('load_in_8bit', False)
|
762 |
load_in_4bit = model_kwargs.get('load_in_4bit', False)
|
763 |
model_kwargs['device_map'] = device_map
|
|
|
764 |
pop_unused_model_kwargs(model_kwargs)
|
765 |
|
766 |
-
if
|
767 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
768 |
base_model,
|
769 |
config=config,
|
770 |
**model_kwargs,
|
771 |
)
|
772 |
else:
|
773 |
-
model = model_loader
|
774 |
base_model,
|
775 |
config=config,
|
776 |
**model_kwargs,
|
@@ -778,7 +819,7 @@ def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward
|
|
778 |
return model
|
779 |
|
780 |
|
781 |
-
def get_client_from_inference_server(inference_server, raise_connection_exception=False):
|
782 |
inference_server, headers = get_hf_server(inference_server)
|
783 |
# preload client since slow for gradio case especially
|
784 |
from gradio_utils.grclient import GradioClient
|
@@ -786,7 +827,7 @@ def get_client_from_inference_server(inference_server, raise_connection_exceptio
|
|
786 |
hf_client = None
|
787 |
if headers is None:
|
788 |
try:
|
789 |
-
print("GR Client Begin: %s" % inference_server, flush=True)
|
790 |
# first do sanity check if alive, else gradio client takes too long by default
|
791 |
requests.get(inference_server, timeout=int(os.getenv('REQUEST_TIMEOUT', '30')))
|
792 |
gr_client = GradioClient(inference_server)
|
@@ -794,19 +835,19 @@ def get_client_from_inference_server(inference_server, raise_connection_exceptio
|
|
794 |
except (OSError, ValueError) as e:
|
795 |
# Occurs when wrong endpoint and should have been HF client, so don't hard raise, just move to HF
|
796 |
gr_client = None
|
797 |
-
print("GR Client Failed %s: %s" % (inference_server, str(e)), flush=True)
|
798 |
except (ConnectTimeoutError, ConnectTimeout, MaxRetryError, ConnectionError, ConnectionError2,
|
799 |
JSONDecodeError, ReadTimeout2, KeyError) as e:
|
800 |
t, v, tb = sys.exc_info()
|
801 |
ex = ''.join(traceback.format_exception(t, v, tb))
|
802 |
-
print("GR Client Failed %s: %s" % (inference_server, str(ex)), flush=True)
|
803 |
if raise_connection_exception:
|
804 |
raise
|
805 |
|
806 |
if gr_client is None:
|
807 |
res = None
|
808 |
from text_generation import Client as HFClient
|
809 |
-
print("HF Client Begin: %s" % inference_server)
|
810 |
try:
|
811 |
hf_client = HFClient(inference_server, headers=headers, timeout=int(os.getenv('REQUEST_TIMEOUT', '30')))
|
812 |
# quick check valid TGI endpoint
|
@@ -817,10 +858,10 @@ def get_client_from_inference_server(inference_server, raise_connection_exceptio
|
|
817 |
hf_client = None
|
818 |
t, v, tb = sys.exc_info()
|
819 |
ex = ''.join(traceback.format_exception(t, v, tb))
|
820 |
-
print("HF Client Failed %s: %s" % (inference_server, str(ex)))
|
821 |
if raise_connection_exception:
|
822 |
raise
|
823 |
-
print("HF Client End: %s %s" % (inference_server, res))
|
824 |
return inference_server, gr_client, hf_client
|
825 |
|
826 |
|
@@ -828,7 +869,9 @@ def get_model(
|
|
828 |
load_8bit: bool = False,
|
829 |
load_4bit: bool = False,
|
830 |
load_half: bool = True,
|
831 |
-
|
|
|
|
|
832 |
base_model: str = '',
|
833 |
inference_server: str = "",
|
834 |
tokenizer_base_model: str = '',
|
@@ -850,7 +893,9 @@ def get_model(
|
|
850 |
:param load_8bit: load model in 8-bit, not supported by all models
|
851 |
:param load_4bit: load model in 4-bit, not supported by all models
|
852 |
:param load_half: load model in 16-bit
|
853 |
-
:param
|
|
|
|
|
854 |
For non-LORA case, False will spread shards across multiple GPUs, but this can lead to cuda:x cuda:y mismatches
|
855 |
So it is not the default
|
856 |
:param base_model: name/path of base model
|
@@ -868,8 +913,7 @@ def get_model(
|
|
868 |
:param verbose:
|
869 |
:return:
|
870 |
"""
|
871 |
-
|
872 |
-
print("Get %s model" % base_model, flush=True)
|
873 |
|
874 |
triton_attn = False
|
875 |
long_sequence = True
|
@@ -893,7 +937,8 @@ def get_model(
|
|
893 |
print("Detected as llama type from"
|
894 |
" config (%s) or name (%s)" % (llama_type_from_config, llama_type_from_name), flush=True)
|
895 |
|
896 |
-
model_loader, tokenizer_loader = get_loaders(model_name=base_model, reward_type=reward_type, llama_type=llama_type
|
|
|
897 |
|
898 |
tokenizer_kwargs = dict(local_files_only=local_files_only,
|
899 |
resume_download=resume_download,
|
@@ -917,7 +962,8 @@ def get_model(
|
|
917 |
tokenizer = FakeTokenizer()
|
918 |
|
919 |
if isinstance(inference_server, str) and inference_server.startswith("http"):
|
920 |
-
inference_server, gr_client, hf_client = get_client_from_inference_server(inference_server
|
|
|
921 |
client = gr_client or hf_client
|
922 |
# Don't return None, None for model, tokenizer so triggers
|
923 |
return client, tokenizer, 'http'
|
@@ -937,7 +983,9 @@ def get_model(
|
|
937 |
return get_hf_model(load_8bit=load_8bit,
|
938 |
load_4bit=load_4bit,
|
939 |
load_half=load_half,
|
940 |
-
|
|
|
|
|
941 |
base_model=base_model,
|
942 |
tokenizer_base_model=tokenizer_base_model,
|
943 |
lora_weights=lora_weights,
|
@@ -961,7 +1009,9 @@ def get_model(
|
|
961 |
def get_hf_model(load_8bit: bool = False,
|
962 |
load_4bit: bool = False,
|
963 |
load_half: bool = True,
|
964 |
-
|
|
|
|
|
965 |
base_model: str = '',
|
966 |
tokenizer_base_model: str = '',
|
967 |
lora_weights: str = "",
|
@@ -998,7 +1048,8 @@ def get_hf_model(load_8bit: bool = False,
|
|
998 |
"Please choose a base model with --base_model (CLI) or load one from Models Tab (gradio)"
|
999 |
)
|
1000 |
|
1001 |
-
model_loader, tokenizer_loader = get_loaders(model_name=base_model, reward_type=reward_type, llama_type=llama_type
|
|
|
1002 |
|
1003 |
config, _ = get_config(base_model, return_model=False, raise_exception=True, **config_kwargs)
|
1004 |
|
@@ -1015,7 +1066,7 @@ def get_hf_model(load_8bit: bool = False,
|
|
1015 |
device=0 if device == "cuda" else -1,
|
1016 |
torch_dtype=torch.float16 if device == 'cuda' else torch.float32)
|
1017 |
else:
|
1018 |
-
assert device in ["cuda", "cpu"], "Unsupported device %s" % device
|
1019 |
model_kwargs = dict(local_files_only=local_files_only,
|
1020 |
torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
|
1021 |
resume_download=resume_download,
|
@@ -1024,11 +1075,16 @@ def get_hf_model(load_8bit: bool = False,
|
|
1024 |
offload_folder=offload_folder,
|
1025 |
)
|
1026 |
if 'mbart-' not in base_model.lower() and 'mpt-' not in base_model.lower():
|
|
|
|
|
|
|
|
|
1027 |
model_kwargs.update(dict(load_in_8bit=load_8bit,
|
1028 |
load_in_4bit=load_4bit,
|
1029 |
-
device_map=
|
1030 |
))
|
1031 |
if 'mpt-' in base_model.lower() and gpu_id is not None and gpu_id >= 0:
|
|
|
1032 |
model_kwargs.update(dict(device_map={"": gpu_id} if device == 'cuda' else "cpu"))
|
1033 |
|
1034 |
if 'OpenAssistant/reward-model'.lower() in base_model.lower():
|
@@ -1038,29 +1094,32 @@ def get_hf_model(load_8bit: bool = False,
|
|
1038 |
pop_unused_model_kwargs(model_kwargs)
|
1039 |
|
1040 |
if not lora_weights:
|
1041 |
-
|
|
|
|
|
1042 |
|
1043 |
-
if
|
1044 |
config, model = get_config(base_model, return_model=True, raise_exception=True, **config_kwargs)
|
1045 |
-
model = get_non_lora_model(base_model, model_loader, load_half,
|
|
|
1046 |
config, model,
|
1047 |
gpu_id=gpu_id,
|
1048 |
)
|
1049 |
else:
|
1050 |
config, _ = get_config(base_model, **config_kwargs)
|
1051 |
-
if load_half and not (load_8bit or load_4bit):
|
1052 |
-
model = model_loader
|
1053 |
base_model,
|
1054 |
config=config,
|
1055 |
**model_kwargs).half()
|
1056 |
else:
|
1057 |
-
model = model_loader
|
1058 |
base_model,
|
1059 |
config=config,
|
1060 |
**model_kwargs)
|
1061 |
elif load_8bit or load_4bit:
|
1062 |
config, _ = get_config(base_model, **config_kwargs)
|
1063 |
-
model = model_loader
|
1064 |
base_model,
|
1065 |
config=config,
|
1066 |
**model_kwargs
|
@@ -1080,7 +1139,7 @@ def get_hf_model(load_8bit: bool = False,
|
|
1080 |
else:
|
1081 |
with torch.device(device):
|
1082 |
config, _ = get_config(base_model, raise_exception=True, **config_kwargs)
|
1083 |
-
model = model_loader
|
1084 |
base_model,
|
1085 |
config=config,
|
1086 |
**model_kwargs
|
@@ -1097,7 +1156,7 @@ def get_hf_model(load_8bit: bool = False,
|
|
1097 |
offload_folder=offload_folder,
|
1098 |
device_map="auto",
|
1099 |
)
|
1100 |
-
if load_half:
|
1101 |
model.half()
|
1102 |
|
1103 |
# unwind broken decapoda-research config
|
@@ -1156,7 +1215,8 @@ def get_score_model(score_model: str = None,
|
|
1156 |
load_8bit: bool = False,
|
1157 |
load_4bit: bool = False,
|
1158 |
load_half: bool = True,
|
1159 |
-
|
|
|
1160 |
base_model: str = '',
|
1161 |
inference_server: str = '',
|
1162 |
tokenizer_base_model: str = '',
|
@@ -1177,6 +1237,8 @@ def get_score_model(score_model: str = None,
|
|
1177 |
load_8bit = False
|
1178 |
load_4bit = False
|
1179 |
load_half = False
|
|
|
|
|
1180 |
base_model = score_model.strip()
|
1181 |
tokenizer_base_model = ''
|
1182 |
lora_weights = ''
|
@@ -1219,6 +1281,7 @@ def evaluate(
|
|
1219 |
top_k_docs,
|
1220 |
chunk,
|
1221 |
chunk_size,
|
|
|
1222 |
document_choice,
|
1223 |
# END NOTE: Examples must have same order of parameters
|
1224 |
src_lang=None,
|
@@ -1435,6 +1498,7 @@ def evaluate(
|
|
1435 |
chunk_size=chunk_size,
|
1436 |
langchain_mode=langchain_mode,
|
1437 |
langchain_action=langchain_action,
|
|
|
1438 |
document_choice=document_choice,
|
1439 |
db_type=db_type,
|
1440 |
top_k_docs=top_k_docs,
|
@@ -1462,6 +1526,7 @@ def evaluate(
|
|
1462 |
inference_server=inference_server,
|
1463 |
langchain_mode=langchain_mode,
|
1464 |
langchain_action=langchain_action,
|
|
|
1465 |
document_choice=document_choice,
|
1466 |
num_prompt_tokens=num_prompt_tokens,
|
1467 |
instruction=instruction,
|
@@ -1563,7 +1628,8 @@ def evaluate(
|
|
1563 |
gr_client = None
|
1564 |
hf_client = model
|
1565 |
else:
|
1566 |
-
inference_server, gr_client, hf_client = get_client_from_inference_server(inference_server
|
|
|
1567 |
|
1568 |
# quick sanity check to avoid long timeouts, just see if can reach server
|
1569 |
requests.get(inference_server, timeout=int(os.getenv('REQUEST_TIMEOUT_FAST', '10')))
|
@@ -1631,7 +1697,8 @@ def evaluate(
|
|
1631 |
top_k_docs=top_k_docs,
|
1632 |
chunk=chunk,
|
1633 |
chunk_size=chunk_size,
|
1634 |
-
|
|
|
1635 |
)
|
1636 |
api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
|
1637 |
if not stream_output:
|
@@ -1830,7 +1897,7 @@ def evaluate(
|
|
1830 |
|
1831 |
with torch.no_grad():
|
1832 |
have_lora_weights = lora_weights not in [no_lora_str, '', None]
|
1833 |
-
context_class_cast = NullContext if device == 'cpu' or have_lora_weights else torch.autocast
|
1834 |
with context_class_cast(device):
|
1835 |
# protection for gradio not keeping track of closed users,
|
1836 |
# else hit bitsandbytes lack of thread safety:
|
@@ -2207,8 +2274,8 @@ y = np.random.randint(0, 1, 100)
|
|
2207 |
|
2208 |
# move to correct position
|
2209 |
for example in examples:
|
2210 |
-
example += [chat, '', '',
|
2211 |
-
top_k_docs, chunk, chunk_size, [DocumentChoices.
|
2212 |
]
|
2213 |
# adjust examples if non-chat mode
|
2214 |
if not chat:
|
@@ -2431,9 +2498,9 @@ def entrypoint_main():
|
|
2431 |
|
2432 |
python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='human_bot' --lora_weights='GPT-NeoXT-Chat-Base-20B.merged.json.8_epochs.57b2892c53df5b8cefac45f84d019cace803ef26.28'
|
2433 |
|
2434 |
-
must have 4*48GB GPU and run without 8bit in order for sharding to work with
|
2435 |
can also pass --prompt_type='human_bot' and model can somewhat handle instructions without being instruct tuned
|
2436 |
-
python generate.py --base_model=decapoda-research/llama-65b-hf --load_8bit=False --
|
2437 |
|
2438 |
python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b
|
2439 |
"""
|
|
|
32 |
source_postfix, LangChainAction
|
33 |
from loaders import get_loaders
|
34 |
from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread, get_githash, \
|
35 |
+
import_matplotlib, get_device, makedirs, get_kwargs, start_faulthandler, get_hf_server, FakeTokenizer, remove, \
|
36 |
+
have_langchain
|
37 |
|
38 |
start_faulthandler()
|
39 |
import_matplotlib()
|
|
|
61 |
load_8bit: bool = False,
|
62 |
load_4bit: bool = False,
|
63 |
load_half: bool = True,
|
64 |
+
load_gptq: str = '',
|
65 |
+
use_safetensors: bool = False,
|
66 |
+
use_gpu_id: bool = True,
|
67 |
base_model: str = '',
|
68 |
tokenizer_base_model: str = '',
|
69 |
lora_weights: str = "",
|
|
|
94 |
memory_restriction_level: int = None,
|
95 |
debug: bool = False,
|
96 |
save_dir: str = None,
|
97 |
+
share: bool = False,
|
98 |
local_files_only: bool = False,
|
99 |
resume_download: bool = True,
|
100 |
use_auth_token: Union[str, bool] = False,
|
|
|
141 |
eval_prompts_only_seed: int = 1234,
|
142 |
eval_as_output: bool = False,
|
143 |
|
144 |
+
langchain_mode: str = None,
|
145 |
langchain_action: str = LangChainAction.QUERY.value,
|
146 |
force_langchain_evaluate: bool = False,
|
147 |
visible_langchain_modes: list = ['UserData', 'MyData'],
|
148 |
# WIP:
|
149 |
# visible_langchain_actions: list = langchain_actions.copy(),
|
150 |
visible_langchain_actions: list = [LangChainAction.QUERY.value, LangChainAction.SUMMARIZE_MAP.value],
|
151 |
+
document_subset: str = DocumentChoices.Relevant.name,
|
152 |
+
document_choice: list = [],
|
153 |
user_path: str = None,
|
154 |
detect_user_path_changes_every_query: bool = False,
|
155 |
load_db_if_exists: bool = True,
|
|
|
181 |
:param load_8bit: load model in 8-bit using bitsandbytes
|
182 |
:param load_4bit: load model in 4-bit using bitsandbytes
|
183 |
:param load_half: load model in float16
|
184 |
+
:param load_gptq: to load model with GPTQ, put model_basename here, e.g. gptq_model-4bit--1g
|
185 |
+
:param use_safetensors: to use safetensors version (assumes file/HF points to safe tensors version)
|
186 |
+
:param use_gpu_id: whether to control devices with gpu_id. If False, then spread across GPUs
|
187 |
:param base_model: model HF-type name. If use --base_model to preload model, cannot unload in gradio in models tab
|
188 |
:param tokenizer_base_model: tokenizer HF-type name. Usually not required, inferred from base_model.
|
189 |
:param lora_weights: LORA weights path/HF link
|
190 |
+
:param gpu_id: if use_gpu_id, then use gpu_id for cuda device ID, or auto mode if gpu_id != -1
|
191 |
:param compile_model Whether to compile the model
|
192 |
:param use_cache: Whether to use caching in model (some models fail when multiple threads use)
|
193 |
:param inference_server: Consume base_model as type of model at this address
|
|
|
295 |
Default: If only want to consume local files, e.g. prepared by make_db.py, only include ['UserData']
|
296 |
FIXME: Avoid 'All' for now, not implemented
|
297 |
:param visible_langchain_actions: Which actions to allow
|
298 |
+
:param document_subset: Default document choice when taking subset of collection
|
299 |
+
:param document_choice: Chosen document(s) by internal name
|
300 |
:param load_db_if_exists: Whether to load chroma db if exists or re-generate db
|
301 |
:param keep_sources_in_context: Whether to keep url sources in context, not helpful usually
|
302 |
:param db_type: 'faiss' for in-memory or 'chroma' or 'weaviate' for persisted on disk
|
|
|
386 |
# allow enabling langchain via ENV
|
387 |
# FIRST PLACE where LangChain referenced, but no imports related to it
|
388 |
langchain_mode = os.environ.get("LANGCHAIN_MODE", langchain_mode)
|
389 |
+
if langchain_mode is not None:
|
390 |
+
assert langchain_mode in langchain_modes, "Invalid langchain_mode %s" % langchain_mode
|
391 |
visible_langchain_modes = ast.literal_eval(os.environ.get("visible_langchain_modes", str(visible_langchain_modes)))
|
392 |
if langchain_mode not in visible_langchain_modes and langchain_mode in langchain_modes:
|
393 |
+
if langchain_mode is not None:
|
394 |
+
visible_langchain_modes += [langchain_mode]
|
395 |
|
396 |
assert langchain_action in langchain_actions, "Invalid langchain_action %s" % langchain_action
|
397 |
|
|
|
401 |
if LangChainMode.USER_DATA.value not in visible_langchain_modes:
|
402 |
allow_upload_to_user_data = False
|
403 |
|
404 |
+
# auto-set langchain_mode
|
405 |
+
if have_langchain and langchain_mode is None:
|
406 |
+
if allow_upload_to_user_data and not is_public and user_path:
|
407 |
+
langchain_mode = 'UserData'
|
408 |
+
print("Auto set langchain_mode=%s" % langchain_mode, flush=True)
|
409 |
+
elif allow_upload_to_my_data:
|
410 |
+
langchain_mode = 'MyData'
|
411 |
+
print("Auto set langchain_mode=%s."
|
412 |
+
" To use UserData to pull files from disk,"
|
413 |
+
" set user_path and ensure allow_upload_to_user_data=True" % langchain_mode, flush=True)
|
414 |
+
else:
|
415 |
+
raise RuntimeError("Please pass --langchain_mode=<chosen mode> out of %s" % langchain_modes)
|
416 |
+
if not have_langchain and langchain_mode not in [None, LangChainMode.DISABLED.value, LangChainMode.LLM.value, LangChainMode.CHAT_LLM.value]:
|
417 |
+
raise RuntimeError("Asked for LangChain mode but langchain python package cannot be found.")
|
418 |
+
if langchain_mode is None:
|
419 |
+
# if not set yet, disable
|
420 |
+
langchain_mode = LangChainMode.DISABLED.value
|
421 |
+
print("Auto set langchain_mode=%s" % langchain_mode, flush=True)
|
422 |
+
|
423 |
if is_public:
|
424 |
allow_upload_to_user_data = False
|
425 |
input_lines = 1 # ensure set, for ease of use
|
|
|
486 |
load_8bit = False
|
487 |
load_4bit = False
|
488 |
load_half = False
|
489 |
+
load_gptq = ''
|
490 |
+
use_safetensors = False
|
491 |
+
use_gpu_id = False
|
492 |
torch.backends.cudnn.benchmark = True
|
493 |
torch.backends.cudnn.enabled = False
|
494 |
torch.set_default_dtype(torch.float32)
|
|
|
744 |
return config, model
|
745 |
|
746 |
|
747 |
+
def get_non_lora_model(base_model, model_loader, load_half,
|
748 |
+
load_gptq, use_safetensors,
|
749 |
+
model_kwargs, reward_type,
|
750 |
config, model,
|
751 |
gpu_id=0,
|
752 |
):
|
|
|
793 |
load_in_8bit = model_kwargs.get('load_in_8bit', False)
|
794 |
load_in_4bit = model_kwargs.get('load_in_4bit', False)
|
795 |
model_kwargs['device_map'] = device_map
|
796 |
+
model_kwargs['use_safetensors'] = use_safetensors
|
797 |
pop_unused_model_kwargs(model_kwargs)
|
798 |
|
799 |
+
if load_gptq:
|
800 |
+
model_kwargs.pop('torch_dtype', None)
|
801 |
+
model_kwargs.pop('device_map')
|
802 |
+
model = model_loader(
|
803 |
+
model_name_or_path=base_model,
|
804 |
+
model_basename=load_gptq,
|
805 |
+
**model_kwargs,
|
806 |
+
)
|
807 |
+
elif load_in_8bit or load_in_4bit or not load_half:
|
808 |
+
model = model_loader(
|
809 |
base_model,
|
810 |
config=config,
|
811 |
**model_kwargs,
|
812 |
)
|
813 |
else:
|
814 |
+
model = model_loader(
|
815 |
base_model,
|
816 |
config=config,
|
817 |
**model_kwargs,
|
|
|
819 |
return model
|
820 |
|
821 |
|
822 |
+
def get_client_from_inference_server(inference_server, base_model=None, raise_connection_exception=False):
|
823 |
inference_server, headers = get_hf_server(inference_server)
|
824 |
# preload client since slow for gradio case especially
|
825 |
from gradio_utils.grclient import GradioClient
|
|
|
827 |
hf_client = None
|
828 |
if headers is None:
|
829 |
try:
|
830 |
+
print("GR Client Begin: %s %s" % (inference_server, base_model), flush=True)
|
831 |
# first do sanity check if alive, else gradio client takes too long by default
|
832 |
requests.get(inference_server, timeout=int(os.getenv('REQUEST_TIMEOUT', '30')))
|
833 |
gr_client = GradioClient(inference_server)
|
|
|
835 |
except (OSError, ValueError) as e:
|
836 |
# Occurs when wrong endpoint and should have been HF client, so don't hard raise, just move to HF
|
837 |
gr_client = None
|
838 |
+
print("GR Client Failed %s %s: %s" % (inference_server, base_model, str(e)), flush=True)
|
839 |
except (ConnectTimeoutError, ConnectTimeout, MaxRetryError, ConnectionError, ConnectionError2,
|
840 |
JSONDecodeError, ReadTimeout2, KeyError) as e:
|
841 |
t, v, tb = sys.exc_info()
|
842 |
ex = ''.join(traceback.format_exception(t, v, tb))
|
843 |
+
print("GR Client Failed %s %s: %s" % (inference_server, base_model, str(ex)), flush=True)
|
844 |
if raise_connection_exception:
|
845 |
raise
|
846 |
|
847 |
if gr_client is None:
|
848 |
res = None
|
849 |
from text_generation import Client as HFClient
|
850 |
+
print("HF Client Begin: %s %s" % (inference_server, base_model))
|
851 |
try:
|
852 |
hf_client = HFClient(inference_server, headers=headers, timeout=int(os.getenv('REQUEST_TIMEOUT', '30')))
|
853 |
# quick check valid TGI endpoint
|
|
|
858 |
hf_client = None
|
859 |
t, v, tb = sys.exc_info()
|
860 |
ex = ''.join(traceback.format_exception(t, v, tb))
|
861 |
+
print("HF Client Failed %s %s: %s" % (inference_server, base_model, str(ex)))
|
862 |
if raise_connection_exception:
|
863 |
raise
|
864 |
+
print("HF Client End: %s %s : %s" % (inference_server, base_model, res))
|
865 |
return inference_server, gr_client, hf_client
|
866 |
|
867 |
|
|
|
869 |
load_8bit: bool = False,
|
870 |
load_4bit: bool = False,
|
871 |
load_half: bool = True,
|
872 |
+
load_gptq: str = '',
|
873 |
+
use_safetensors: bool = False,
|
874 |
+
use_gpu_id: bool = True,
|
875 |
base_model: str = '',
|
876 |
inference_server: str = "",
|
877 |
tokenizer_base_model: str = '',
|
|
|
893 |
:param load_8bit: load model in 8-bit, not supported by all models
|
894 |
:param load_4bit: load model in 4-bit, not supported by all models
|
895 |
:param load_half: load model in 16-bit
|
896 |
+
:param load_gptq: GPTQ model_basename
|
897 |
+
:param use_safetensors: use safetensors file
|
898 |
+
:param use_gpu_id: Use torch infer of optimal placement of layers on devices (for non-lora case)
|
899 |
For non-LORA case, False will spread shards across multiple GPUs, but this can lead to cuda:x cuda:y mismatches
|
900 |
So it is not the default
|
901 |
:param base_model: name/path of base model
|
|
|
913 |
:param verbose:
|
914 |
:return:
|
915 |
"""
|
916 |
+
print("Starting get_model: %s %s" % (base_model, inference_server), flush=True)
|
|
|
917 |
|
918 |
triton_attn = False
|
919 |
long_sequence = True
|
|
|
937 |
print("Detected as llama type from"
|
938 |
" config (%s) or name (%s)" % (llama_type_from_config, llama_type_from_name), flush=True)
|
939 |
|
940 |
+
model_loader, tokenizer_loader = get_loaders(model_name=base_model, reward_type=reward_type, llama_type=llama_type,
|
941 |
+
load_gptq=load_gptq)
|
942 |
|
943 |
tokenizer_kwargs = dict(local_files_only=local_files_only,
|
944 |
resume_download=resume_download,
|
|
|
962 |
tokenizer = FakeTokenizer()
|
963 |
|
964 |
if isinstance(inference_server, str) and inference_server.startswith("http"):
|
965 |
+
inference_server, gr_client, hf_client = get_client_from_inference_server(inference_server,
|
966 |
+
base_model=base_model)
|
967 |
client = gr_client or hf_client
|
968 |
# Don't return None, None for model, tokenizer so triggers
|
969 |
return client, tokenizer, 'http'
|
|
|
983 |
return get_hf_model(load_8bit=load_8bit,
|
984 |
load_4bit=load_4bit,
|
985 |
load_half=load_half,
|
986 |
+
load_gptq=load_gptq,
|
987 |
+
use_safetensors=use_safetensors,
|
988 |
+
use_gpu_id=use_gpu_id,
|
989 |
base_model=base_model,
|
990 |
tokenizer_base_model=tokenizer_base_model,
|
991 |
lora_weights=lora_weights,
|
|
|
1009 |
def get_hf_model(load_8bit: bool = False,
|
1010 |
load_4bit: bool = False,
|
1011 |
load_half: bool = True,
|
1012 |
+
load_gptq: str = '',
|
1013 |
+
use_safetensors: bool = False,
|
1014 |
+
use_gpu_id: bool = True,
|
1015 |
base_model: str = '',
|
1016 |
tokenizer_base_model: str = '',
|
1017 |
lora_weights: str = "",
|
|
|
1048 |
"Please choose a base model with --base_model (CLI) or load one from Models Tab (gradio)"
|
1049 |
)
|
1050 |
|
1051 |
+
model_loader, tokenizer_loader = get_loaders(model_name=base_model, reward_type=reward_type, llama_type=llama_type,
|
1052 |
+
load_gptq=load_gptq)
|
1053 |
|
1054 |
config, _ = get_config(base_model, return_model=False, raise_exception=True, **config_kwargs)
|
1055 |
|
|
|
1066 |
device=0 if device == "cuda" else -1,
|
1067 |
torch_dtype=torch.float16 if device == 'cuda' else torch.float32)
|
1068 |
else:
|
1069 |
+
assert device in ["cuda", "cpu", "mps"], "Unsupported device %s" % device
|
1070 |
model_kwargs = dict(local_files_only=local_files_only,
|
1071 |
torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
|
1072 |
resume_download=resume_download,
|
|
|
1075 |
offload_folder=offload_folder,
|
1076 |
)
|
1077 |
if 'mbart-' not in base_model.lower() and 'mpt-' not in base_model.lower():
|
1078 |
+
if use_gpu_id and gpu_id is not None and gpu_id >= 0 and device == 'cuda':
|
1079 |
+
device_map = {"": gpu_id}
|
1080 |
+
else:
|
1081 |
+
device_map = "auto"
|
1082 |
model_kwargs.update(dict(load_in_8bit=load_8bit,
|
1083 |
load_in_4bit=load_4bit,
|
1084 |
+
device_map=device_map,
|
1085 |
))
|
1086 |
if 'mpt-' in base_model.lower() and gpu_id is not None and gpu_id >= 0:
|
1087 |
+
# MPT doesn't support spreading over GPUs
|
1088 |
model_kwargs.update(dict(device_map={"": gpu_id} if device == 'cuda' else "cpu"))
|
1089 |
|
1090 |
if 'OpenAssistant/reward-model'.lower() in base_model.lower():
|
|
|
1094 |
pop_unused_model_kwargs(model_kwargs)
|
1095 |
|
1096 |
if not lora_weights:
|
1097 |
+
# torch.device context uses twice memory for AutoGPTQ
|
1098 |
+
context = NullContext if load_gptq else torch.device
|
1099 |
+
with context(device):
|
1100 |
|
1101 |
+
if use_gpu_id:
|
1102 |
config, model = get_config(base_model, return_model=True, raise_exception=True, **config_kwargs)
|
1103 |
+
model = get_non_lora_model(base_model, model_loader, load_half, load_gptq, use_safetensors,
|
1104 |
+
model_kwargs, reward_type,
|
1105 |
config, model,
|
1106 |
gpu_id=gpu_id,
|
1107 |
)
|
1108 |
else:
|
1109 |
config, _ = get_config(base_model, **config_kwargs)
|
1110 |
+
if load_half and not (load_8bit or load_4bit or load_gptq):
|
1111 |
+
model = model_loader(
|
1112 |
base_model,
|
1113 |
config=config,
|
1114 |
**model_kwargs).half()
|
1115 |
else:
|
1116 |
+
model = model_loader(
|
1117 |
base_model,
|
1118 |
config=config,
|
1119 |
**model_kwargs)
|
1120 |
elif load_8bit or load_4bit:
|
1121 |
config, _ = get_config(base_model, **config_kwargs)
|
1122 |
+
model = model_loader(
|
1123 |
base_model,
|
1124 |
config=config,
|
1125 |
**model_kwargs
|
|
|
1139 |
else:
|
1140 |
with torch.device(device):
|
1141 |
config, _ = get_config(base_model, raise_exception=True, **config_kwargs)
|
1142 |
+
model = model_loader(
|
1143 |
base_model,
|
1144 |
config=config,
|
1145 |
**model_kwargs
|
|
|
1156 |
offload_folder=offload_folder,
|
1157 |
device_map="auto",
|
1158 |
)
|
1159 |
+
if load_half and not load_gptq:
|
1160 |
model.half()
|
1161 |
|
1162 |
# unwind broken decapoda-research config
|
|
|
1215 |
load_8bit: bool = False,
|
1216 |
load_4bit: bool = False,
|
1217 |
load_half: bool = True,
|
1218 |
+
load_gptq: str = '',
|
1219 |
+
use_gpu_id: bool = True,
|
1220 |
base_model: str = '',
|
1221 |
inference_server: str = '',
|
1222 |
tokenizer_base_model: str = '',
|
|
|
1237 |
load_8bit = False
|
1238 |
load_4bit = False
|
1239 |
load_half = False
|
1240 |
+
load_gptq = ''
|
1241 |
+
use_safetensors = False
|
1242 |
base_model = score_model.strip()
|
1243 |
tokenizer_base_model = ''
|
1244 |
lora_weights = ''
|
|
|
1281 |
top_k_docs,
|
1282 |
chunk,
|
1283 |
chunk_size,
|
1284 |
+
document_subset,
|
1285 |
document_choice,
|
1286 |
# END NOTE: Examples must have same order of parameters
|
1287 |
src_lang=None,
|
|
|
1498 |
chunk_size=chunk_size,
|
1499 |
langchain_mode=langchain_mode,
|
1500 |
langchain_action=langchain_action,
|
1501 |
+
document_subset=document_subset,
|
1502 |
document_choice=document_choice,
|
1503 |
db_type=db_type,
|
1504 |
top_k_docs=top_k_docs,
|
|
|
1526 |
inference_server=inference_server,
|
1527 |
langchain_mode=langchain_mode,
|
1528 |
langchain_action=langchain_action,
|
1529 |
+
document_subset=document_subset,
|
1530 |
document_choice=document_choice,
|
1531 |
num_prompt_tokens=num_prompt_tokens,
|
1532 |
instruction=instruction,
|
|
|
1628 |
gr_client = None
|
1629 |
hf_client = model
|
1630 |
else:
|
1631 |
+
inference_server, gr_client, hf_client = get_client_from_inference_server(inference_server,
|
1632 |
+
base_model=base_model)
|
1633 |
|
1634 |
# quick sanity check to avoid long timeouts, just see if can reach server
|
1635 |
requests.get(inference_server, timeout=int(os.getenv('REQUEST_TIMEOUT_FAST', '10')))
|
|
|
1697 |
top_k_docs=top_k_docs,
|
1698 |
chunk=chunk,
|
1699 |
chunk_size=chunk_size,
|
1700 |
+
document_subset=DocumentChoices.Relevant.name,
|
1701 |
+
document_choice=[],
|
1702 |
)
|
1703 |
api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
|
1704 |
if not stream_output:
|
|
|
1897 |
|
1898 |
with torch.no_grad():
|
1899 |
have_lora_weights = lora_weights not in [no_lora_str, '', None]
|
1900 |
+
context_class_cast = NullContext if device == 'cpu' or have_lora_weights or device == 'mps' else torch.autocast
|
1901 |
with context_class_cast(device):
|
1902 |
# protection for gradio not keeping track of closed users,
|
1903 |
# else hit bitsandbytes lack of thread safety:
|
|
|
2274 |
|
2275 |
# move to correct position
|
2276 |
for example in examples:
|
2277 |
+
example += [chat, '', '', LangChainMode.DISABLED.value, LangChainAction.QUERY.value,
|
2278 |
+
top_k_docs, chunk, chunk_size, [DocumentChoices.Relevant.name], []
|
2279 |
]
|
2280 |
# adjust examples if non-chat mode
|
2281 |
if not chat:
|
|
|
2498 |
|
2499 |
python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='human_bot' --lora_weights='GPT-NeoXT-Chat-Base-20B.merged.json.8_epochs.57b2892c53df5b8cefac45f84d019cace803ef26.28'
|
2500 |
|
2501 |
+
must have 4*48GB GPU and run without 8bit in order for sharding to work with use_gpu_id=False
|
2502 |
can also pass --prompt_type='human_bot' and model can somewhat handle instructions without being instruct tuned
|
2503 |
+
python generate.py --base_model=decapoda-research/llama-65b-hf --load_8bit=False --use_gpu_id=False --prompt_type='human_bot'
|
2504 |
|
2505 |
python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b
|
2506 |
"""
|
gpt_langchain.py
CHANGED
@@ -29,7 +29,8 @@ from evaluate_params import gen_hyper
|
|
29 |
from gen import get_model, SEED
|
30 |
from prompter import non_hf_types, PromptType, Prompter
|
31 |
from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
|
32 |
-
get_device, ProgressParallel, remove, hash_file, clear_torch_cache, NullContext, get_hf_server, FakeTokenizer
|
|
|
33 |
from utils_langchain import StreamingGradioCallbackHandler
|
34 |
|
35 |
import_matplotlib()
|
@@ -387,7 +388,8 @@ class GradioInference(LLM):
|
|
387 |
top_k_docs=top_k_docs,
|
388 |
chunk=chunk,
|
389 |
chunk_size=chunk_size,
|
390 |
-
|
|
|
391 |
)
|
392 |
api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
|
393 |
if not stream_output:
|
@@ -913,40 +915,6 @@ def get_dai_docs(from_hf=False, get_pickle=True):
|
|
913 |
return sources
|
914 |
|
915 |
|
916 |
-
import distutils.spawn
|
917 |
-
|
918 |
-
have_tesseract = distutils.spawn.find_executable("tesseract")
|
919 |
-
have_libreoffice = distutils.spawn.find_executable("libreoffice")
|
920 |
-
|
921 |
-
import pkg_resources
|
922 |
-
|
923 |
-
try:
|
924 |
-
assert pkg_resources.get_distribution('arxiv') is not None
|
925 |
-
assert pkg_resources.get_distribution('pymupdf') is not None
|
926 |
-
have_arxiv = True
|
927 |
-
except (pkg_resources.DistributionNotFound, AssertionError):
|
928 |
-
have_arxiv = False
|
929 |
-
|
930 |
-
try:
|
931 |
-
assert pkg_resources.get_distribution('pymupdf') is not None
|
932 |
-
have_pymupdf = True
|
933 |
-
except (pkg_resources.DistributionNotFound, AssertionError):
|
934 |
-
have_pymupdf = False
|
935 |
-
|
936 |
-
try:
|
937 |
-
assert pkg_resources.get_distribution('selenium') is not None
|
938 |
-
have_selenium = True
|
939 |
-
except (pkg_resources.DistributionNotFound, AssertionError):
|
940 |
-
have_selenium = False
|
941 |
-
|
942 |
-
try:
|
943 |
-
assert pkg_resources.get_distribution('playwright') is not None
|
944 |
-
have_playwright = True
|
945 |
-
except (pkg_resources.DistributionNotFound, AssertionError):
|
946 |
-
have_playwright = False
|
947 |
-
|
948 |
-
# disable, hangs too often
|
949 |
-
have_playwright = False
|
950 |
|
951 |
image_types = ["png", "jpg", "jpeg"]
|
952 |
non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
|
@@ -973,7 +941,7 @@ def add_meta(docs1, file):
|
|
973 |
|
974 |
|
975 |
def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
|
976 |
-
chunk=True, chunk_size=512,
|
977 |
is_url=False, is_txt=False,
|
978 |
enable_captions=True,
|
979 |
captions_model=None,
|
@@ -1208,6 +1176,7 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
|
|
1208 |
|
1209 |
def path_to_doc1(file, verbose=False, fail_any_exception=False, return_file=True,
|
1210 |
chunk=True, chunk_size=512,
|
|
|
1211 |
is_url=False, is_txt=False,
|
1212 |
enable_captions=True,
|
1213 |
captions_model=None,
|
@@ -1224,6 +1193,7 @@ def path_to_doc1(file, verbose=False, fail_any_exception=False, return_file=True
|
|
1224 |
# don't pass base_path=path, would infinitely recurse
|
1225 |
res = file_to_doc(file, base_path=None, verbose=verbose, fail_any_exception=fail_any_exception,
|
1226 |
chunk=chunk, chunk_size=chunk_size,
|
|
|
1227 |
is_url=is_url, is_txt=is_txt,
|
1228 |
enable_captions=enable_captions,
|
1229 |
captions_model=captions_model,
|
@@ -1236,7 +1206,8 @@ def path_to_doc1(file, verbose=False, fail_any_exception=False, return_file=True
|
|
1236 |
else:
|
1237 |
exception_doc = Document(
|
1238 |
page_content='',
|
1239 |
-
metadata={"source": file, "exception":
|
|
|
1240 |
res = [exception_doc]
|
1241 |
if return_file:
|
1242 |
base_tmp = "temp_path_to_doc1"
|
@@ -1326,6 +1297,7 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
|
|
1326 |
kwargs = dict(verbose=verbose, fail_any_exception=fail_any_exception,
|
1327 |
return_file=return_file,
|
1328 |
chunk=chunk, chunk_size=chunk_size,
|
|
|
1329 |
is_url=is_url,
|
1330 |
is_txt=is_txt,
|
1331 |
enable_captions=enable_captions,
|
@@ -1802,7 +1774,8 @@ def _run_qa_db(query=None,
|
|
1802 |
num_return_sequences=1,
|
1803 |
langchain_mode=None,
|
1804 |
langchain_action=None,
|
1805 |
-
|
|
|
1806 |
n_jobs=-1,
|
1807 |
verbose=False,
|
1808 |
cli=False,
|
@@ -1873,19 +1846,13 @@ def _run_qa_db(query=None,
|
|
1873 |
if isinstance(document_choice, str):
|
1874 |
# support string as well
|
1875 |
document_choice = [document_choice]
|
1876 |
-
|
1877 |
-
|
1878 |
-
cmd = [x for x in document_choice if x in doc_choices_set]
|
1879 |
-
cmd = None if len(cmd) == 0 else cmd[0]
|
1880 |
-
# now have cmd, filter out for only docs
|
1881 |
-
document_choice = [x for x in document_choice if x not in doc_choices_set]
|
1882 |
-
|
1883 |
-
func_names = list(inspect.signature(get_similarity_chain).parameters)
|
1884 |
sim_kwargs = {k: v for k, v in locals().items() if k in func_names}
|
1885 |
missing_kwargs = [x for x in func_names if x not in sim_kwargs]
|
1886 |
assert not missing_kwargs, "Missing: %s" % missing_kwargs
|
1887 |
-
docs, chain, scores, use_context, have_any_docs =
|
1888 |
-
if
|
1889 |
formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
|
1890 |
yield formatted_doc_chunks, ''
|
1891 |
return
|
@@ -1963,36 +1930,36 @@ def _run_qa_db(query=None,
|
|
1963 |
return
|
1964 |
|
1965 |
|
1966 |
-
def
|
1967 |
-
|
1968 |
-
|
1969 |
-
|
1970 |
-
|
1971 |
-
|
1972 |
-
|
1973 |
-
|
1974 |
-
|
1975 |
-
|
1976 |
-
|
1977 |
-
|
1978 |
-
|
1979 |
-
|
1980 |
-
|
1981 |
-
|
1982 |
-
|
1983 |
-
|
1984 |
-
|
1985 |
-
|
1986 |
-
|
1987 |
-
|
1988 |
-
|
1989 |
-
|
1990 |
-
|
1991 |
-
|
1992 |
-
|
1993 |
-
|
1994 |
-
|
1995 |
-
|
1996 |
# determine whether use of context out of docs is planned
|
1997 |
if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
|
1998 |
if langchain_mode in ['Disabled', 'ChatLLM', 'LLM']:
|
@@ -2086,12 +2053,25 @@ def get_similarity_chain(query=None,
|
|
2086 |
use_template = False
|
2087 |
|
2088 |
if db and use_context:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2089 |
if not isinstance(db, Chroma):
|
2090 |
# only chroma supports filtering
|
2091 |
filter_kwargs = {}
|
2092 |
else:
|
2093 |
-
|
2094 |
-
if len(document_choice) >=
|
|
|
|
|
|
|
|
|
|
|
2095 |
or_filter = [{"source": {"$eq": x}} for x in document_choice]
|
2096 |
filter_kwargs = dict(filter={"$or": or_filter})
|
2097 |
elif len(document_choice) == 1:
|
@@ -2101,10 +2081,10 @@ def get_similarity_chain(query=None,
|
|
2101 |
else:
|
2102 |
# shouldn't reach
|
2103 |
filter_kwargs = {}
|
2104 |
-
if
|
2105 |
docs = []
|
2106 |
scores = []
|
2107 |
-
elif
|
2108 |
db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs)
|
2109 |
# similar to langchain's chroma's _results_to_docs_and_scores
|
2110 |
docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0)
|
@@ -2127,13 +2107,7 @@ def get_similarity_chain(query=None,
|
|
2127 |
if top_k_docs == -1 or auto_reduce_chunks:
|
2128 |
# docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:top_k_docs]
|
2129 |
top_k_docs_tokenize = 100
|
2130 |
-
|
2131 |
-
makedirs(base_path)
|
2132 |
-
if hasattr(db, '_persist_directory'):
|
2133 |
-
name_path = "sim_%s.lock" % os.path.basename(db._persist_directory)
|
2134 |
-
else:
|
2135 |
-
name_path = "sim.lock"
|
2136 |
-
with filelock.FileLock(os.path.join(base_path, name_path)):
|
2137 |
docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[
|
2138 |
:top_k_docs_tokenize]
|
2139 |
if hasattr(llm, 'pipeline') and hasattr(llm.pipeline, 'tokenizer'):
|
@@ -2189,7 +2163,8 @@ def get_similarity_chain(query=None,
|
|
2189 |
top_k_docs = 1
|
2190 |
docs_with_score = docs_with_score[:top_k_docs]
|
2191 |
else:
|
2192 |
-
|
|
|
2193 |
# put most relevant chunks closest to question,
|
2194 |
# esp. if truncation occurs will be "oldest" or "farthest from response" text that is truncated
|
2195 |
# BUT: for small models, e.g. 6_9 pythia, if sees some stuff related to h2oGPT first, it can connect that and not listen to rest
|
@@ -2210,7 +2185,7 @@ def get_similarity_chain(query=None,
|
|
2210 |
# if HF type and have no docs, can bail out
|
2211 |
return docs, None, [], False, have_any_docs
|
2212 |
|
2213 |
-
if
|
2214 |
# no LLM use
|
2215 |
return docs, None, [], False, have_any_docs
|
2216 |
|
|
|
29 |
from gen import get_model, SEED
|
30 |
from prompter import non_hf_types, PromptType, Prompter
|
31 |
from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
|
32 |
+
get_device, ProgressParallel, remove, hash_file, clear_torch_cache, NullContext, get_hf_server, FakeTokenizer, \
|
33 |
+
have_libreoffice, have_arxiv, have_playwright, have_selenium, have_tesseract, have_pymupdf
|
34 |
from utils_langchain import StreamingGradioCallbackHandler
|
35 |
|
36 |
import_matplotlib()
|
|
|
388 |
top_k_docs=top_k_docs,
|
389 |
chunk=chunk,
|
390 |
chunk_size=chunk_size,
|
391 |
+
document_subset=DocumentChoices.Relevant.name,
|
392 |
+
document_choice=[],
|
393 |
)
|
394 |
api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
|
395 |
if not stream_output:
|
|
|
915 |
return sources
|
916 |
|
917 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
918 |
|
919 |
image_types = ["png", "jpg", "jpeg"]
|
920 |
non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
|
|
|
941 |
|
942 |
|
943 |
def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
|
944 |
+
chunk=True, chunk_size=512, n_jobs=-1,
|
945 |
is_url=False, is_txt=False,
|
946 |
enable_captions=True,
|
947 |
captions_model=None,
|
|
|
1176 |
|
1177 |
def path_to_doc1(file, verbose=False, fail_any_exception=False, return_file=True,
|
1178 |
chunk=True, chunk_size=512,
|
1179 |
+
n_jobs=-1,
|
1180 |
is_url=False, is_txt=False,
|
1181 |
enable_captions=True,
|
1182 |
captions_model=None,
|
|
|
1193 |
# don't pass base_path=path, would infinitely recurse
|
1194 |
res = file_to_doc(file, base_path=None, verbose=verbose, fail_any_exception=fail_any_exception,
|
1195 |
chunk=chunk, chunk_size=chunk_size,
|
1196 |
+
n_jobs=n_jobs,
|
1197 |
is_url=is_url, is_txt=is_txt,
|
1198 |
enable_captions=enable_captions,
|
1199 |
captions_model=captions_model,
|
|
|
1206 |
else:
|
1207 |
exception_doc = Document(
|
1208 |
page_content='',
|
1209 |
+
metadata={"source": file, "exception": '%s hit %s' % (file, str(e)),
|
1210 |
+
"traceback": traceback.format_exc()})
|
1211 |
res = [exception_doc]
|
1212 |
if return_file:
|
1213 |
base_tmp = "temp_path_to_doc1"
|
|
|
1297 |
kwargs = dict(verbose=verbose, fail_any_exception=fail_any_exception,
|
1298 |
return_file=return_file,
|
1299 |
chunk=chunk, chunk_size=chunk_size,
|
1300 |
+
n_jobs=n_jobs,
|
1301 |
is_url=is_url,
|
1302 |
is_txt=is_txt,
|
1303 |
enable_captions=enable_captions,
|
|
|
1774 |
num_return_sequences=1,
|
1775 |
langchain_mode=None,
|
1776 |
langchain_action=None,
|
1777 |
+
document_subset=DocumentChoices.Relevant.name,
|
1778 |
+
document_choice=[],
|
1779 |
n_jobs=-1,
|
1780 |
verbose=False,
|
1781 |
cli=False,
|
|
|
1846 |
if isinstance(document_choice, str):
|
1847 |
# support string as well
|
1848 |
document_choice = [document_choice]
|
1849 |
+
|
1850 |
+
func_names = list(inspect.signature(get_chain).parameters)
|
|
|
|
|
|
|
|
|
|
|
|
|
1851 |
sim_kwargs = {k: v for k, v in locals().items() if k in func_names}
|
1852 |
missing_kwargs = [x for x in func_names if x not in sim_kwargs]
|
1853 |
assert not missing_kwargs, "Missing: %s" % missing_kwargs
|
1854 |
+
docs, chain, scores, use_context, have_any_docs = get_chain(**sim_kwargs)
|
1855 |
+
if document_subset in non_query_commands:
|
1856 |
formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
|
1857 |
yield formatted_doc_chunks, ''
|
1858 |
return
|
|
|
1930 |
return
|
1931 |
|
1932 |
|
1933 |
+
def get_chain(query=None,
|
1934 |
+
iinput=None,
|
1935 |
+
use_openai_model=False, use_openai_embedding=False,
|
1936 |
+
first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
|
1937 |
+
user_path=None,
|
1938 |
+
detect_user_path_changes_every_query=False,
|
1939 |
+
db_type='faiss',
|
1940 |
+
model_name=None,
|
1941 |
+
inference_server='',
|
1942 |
+
hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
|
1943 |
+
prompt_type=None,
|
1944 |
+
prompt_dict=None,
|
1945 |
+
cut_distanct=1.1,
|
1946 |
+
load_db_if_exists=False,
|
1947 |
+
db=None,
|
1948 |
+
langchain_mode=None,
|
1949 |
+
langchain_action=None,
|
1950 |
+
document_subset=DocumentChoices.Relevant.name,
|
1951 |
+
document_choice=[],
|
1952 |
+
n_jobs=-1,
|
1953 |
+
# beyond run_db_query:
|
1954 |
+
llm=None,
|
1955 |
+
tokenizer=None,
|
1956 |
+
verbose=False,
|
1957 |
+
reverse_docs=True,
|
1958 |
+
|
1959 |
+
# local
|
1960 |
+
auto_reduce_chunks=True,
|
1961 |
+
max_chunks=100,
|
1962 |
+
):
|
1963 |
# determine whether use of context out of docs is planned
|
1964 |
if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
|
1965 |
if langchain_mode in ['Disabled', 'ChatLLM', 'LLM']:
|
|
|
2053 |
use_template = False
|
2054 |
|
2055 |
if db and use_context:
|
2056 |
+
base_path = 'locks'
|
2057 |
+
makedirs(base_path)
|
2058 |
+
if hasattr(db, '_persist_directory'):
|
2059 |
+
name_path = "sim_%s.lock" % os.path.basename(db._persist_directory)
|
2060 |
+
else:
|
2061 |
+
name_path = "sim.lock"
|
2062 |
+
lock_file = os.path.join(base_path, name_path)
|
2063 |
+
|
2064 |
if not isinstance(db, Chroma):
|
2065 |
# only chroma supports filtering
|
2066 |
filter_kwargs = {}
|
2067 |
else:
|
2068 |
+
assert document_choice is not None, "Document choice was None"
|
2069 |
+
if len(document_choice) >= 1 and document_choice[0] == DocumentChoices.All.name:
|
2070 |
+
filter_kwargs = {}
|
2071 |
+
elif len(document_choice) >= 2:
|
2072 |
+
if document_choice[0] == DocumentChoices.All.name:
|
2073 |
+
# remove 'All'
|
2074 |
+
document_choice = document_choice[1:]
|
2075 |
or_filter = [{"source": {"$eq": x}} for x in document_choice]
|
2076 |
filter_kwargs = dict(filter={"$or": or_filter})
|
2077 |
elif len(document_choice) == 1:
|
|
|
2081 |
else:
|
2082 |
# shouldn't reach
|
2083 |
filter_kwargs = {}
|
2084 |
+
if langchain_mode in [LangChainMode.LLM.value, LangChainMode.CHAT_LLM.value]:
|
2085 |
docs = []
|
2086 |
scores = []
|
2087 |
+
elif document_subset == DocumentChoices.All.name or query in [None, '', '\n']:
|
2088 |
db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs)
|
2089 |
# similar to langchain's chroma's _results_to_docs_and_scores
|
2090 |
docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0)
|
|
|
2107 |
if top_k_docs == -1 or auto_reduce_chunks:
|
2108 |
# docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:top_k_docs]
|
2109 |
top_k_docs_tokenize = 100
|
2110 |
+
with filelock.FileLock(lock_file):
|
|
|
|
|
|
|
|
|
|
|
|
|
2111 |
docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[
|
2112 |
:top_k_docs_tokenize]
|
2113 |
if hasattr(llm, 'pipeline') and hasattr(llm.pipeline, 'tokenizer'):
|
|
|
2163 |
top_k_docs = 1
|
2164 |
docs_with_score = docs_with_score[:top_k_docs]
|
2165 |
else:
|
2166 |
+
with filelock.FileLock(lock_file):
|
2167 |
+
docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:top_k_docs]
|
2168 |
# put most relevant chunks closest to question,
|
2169 |
# esp. if truncation occurs will be "oldest" or "farthest from response" text that is truncated
|
2170 |
# BUT: for small models, e.g. 6_9 pythia, if sees some stuff related to h2oGPT first, it can connect that and not listen to rest
|
|
|
2185 |
# if HF type and have no docs, can bail out
|
2186 |
return docs, None, [], False, have_any_docs
|
2187 |
|
2188 |
+
if document_subset in non_query_commands:
|
2189 |
# no LLM use
|
2190 |
return docs, None, [], False, have_any_docs
|
2191 |
|
gradio_runner.py
CHANGED
@@ -20,7 +20,7 @@ import tabulate
|
|
20 |
from iterators import TimeoutIterator
|
21 |
|
22 |
from gradio_utils.css import get_css
|
23 |
-
from gradio_utils.prompt_form import
|
24 |
|
25 |
# This is a hack to prevent Gradio from phoning home when it gets imported
|
26 |
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
@@ -56,7 +56,7 @@ from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title,
|
|
56 |
from prompter import prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower, non_hf_types, \
|
57 |
get_prompt
|
58 |
from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
|
59 |
-
ping, get_short_name,
|
60 |
from gen import get_model, languages_covered, evaluate, score_qa, langchain_modes, inputs_kwargs_list, scratch_base_dir, \
|
61 |
get_max_max_new_tokens, get_minmax_top_k_docs, history_to_context, langchain_actions
|
62 |
from evaluate_params import eval_func_param_names, no_default_param_names, eval_func_param_names_defaults
|
@@ -118,6 +118,13 @@ def go_gradio(**kwargs):
|
|
118 |
allow_upload = allow_upload_to_user_data or allow_upload_to_my_data
|
119 |
kwargs.update(locals())
|
120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
if 'mbart-' in kwargs['model_lower']:
|
122 |
instruction_label_nochat = "Text to translate"
|
123 |
else:
|
@@ -134,8 +141,7 @@ def go_gradio(**kwargs):
|
|
134 |
"""
|
135 |
else:
|
136 |
description = more_info
|
137 |
-
description_bottom = "If this host is busy, try [
|
138 |
-
description_bottom += """<p>By using h2oGPT, you accept our [Terms of Service](https://github.com/h2oai/h2ogpt/blob/main/docs/tos.md)</p>"""
|
139 |
if is_hf:
|
140 |
description_bottom += '''<a href="https://huggingface.co/spaces/h2oai/h2ogpt-chatbot?duplicate=true"><img src="https://bit.ly/3gLdBN6" style="white-space: nowrap" alt="Duplicate Space"></a>'''
|
141 |
|
@@ -160,7 +166,7 @@ def go_gradio(**kwargs):
|
|
160 |
theme_kwargs = dict()
|
161 |
if kwargs['gradio_size'] == 'xsmall':
|
162 |
theme_kwargs.update(dict(spacing_size=spacing_xsm, text_size=text_xsm, radius_size=radius_xsm))
|
163 |
-
elif kwargs['gradio_size']
|
164 |
theme_kwargs.update(dict(spacing_size=gr.themes.sizes.spacing_sm, text_size=gr.themes.sizes.text_sm,
|
165 |
radius_size=gr.themes.sizes.spacing_sm))
|
166 |
elif kwargs['gradio_size'] == 'large':
|
@@ -262,14 +268,14 @@ def go_gradio(**kwargs):
|
|
262 |
model_options_state = gr.State([model_options])
|
263 |
lora_options_state = gr.State([lora_options])
|
264 |
server_options_state = gr.State([server_options])
|
265 |
-
|
266 |
-
my_db_state = gr.State([None, str(uuid.uuid4())])
|
267 |
chat_state = gr.State({})
|
268 |
-
|
269 |
-
docs_state00 = kwargs['document_choice'] + [x.name for x in list(DocumentChoices)]
|
270 |
docs_state0 = []
|
271 |
[docs_state0.append(x) for x in docs_state00 if x not in docs_state0]
|
272 |
-
docs_state = gr.State(docs_state0)
|
|
|
|
|
273 |
gr.Markdown(f"""
|
274 |
{get_h2o_title(title, description) if kwargs['h2ocolors'] else get_simple_title(title, description)}
|
275 |
""")
|
@@ -282,179 +288,208 @@ def go_gradio(**kwargs):
|
|
282 |
res_value = "Response Score: NA" if not kwargs[
|
283 |
'model_lock'] else "Response Scores: %s" % nas
|
284 |
|
285 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
286 |
with normal_block:
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
296 |
instruction_nochat = gr.Textbox(
|
297 |
lines=kwargs['input_lines'],
|
298 |
label=instruction_label_nochat,
|
299 |
placeholder=kwargs['placeholder_instruction'],
|
|
|
300 |
)
|
301 |
iinput_nochat = gr.Textbox(lines=4, label="Input context for Instruction",
|
302 |
-
placeholder=kwargs['placeholder_input']
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
if is_hf:
|
343 |
-
# don't show 'wiki' since only usually useful for internal testing at moment
|
344 |
-
no_show_modes = ['Disabled', 'wiki']
|
345 |
-
else:
|
346 |
-
no_show_modes = ['Disabled']
|
347 |
-
allowed_modes = visible_langchain_modes.copy()
|
348 |
-
allowed_modes = [x for x in allowed_modes if x in dbs]
|
349 |
-
allowed_modes += ['ChatLLM', 'LLM']
|
350 |
-
if allow_upload_to_my_data and 'MyData' not in allowed_modes:
|
351 |
-
allowed_modes += ['MyData']
|
352 |
-
if allow_upload_to_user_data and 'UserData' not in allowed_modes:
|
353 |
-
allowed_modes += ['UserData']
|
354 |
-
langchain_mode = gr.Radio(
|
355 |
-
[x for x in langchain_modes if x in allowed_modes and x not in no_show_modes],
|
356 |
-
value=kwargs['langchain_mode'],
|
357 |
-
label="Data Collection of Sources",
|
358 |
-
visible=kwargs['langchain_mode'] != 'Disabled')
|
359 |
-
allowed_actions = [x for x in langchain_actions if x in visible_langchain_actions]
|
360 |
-
langchain_action = gr.Radio(
|
361 |
-
allowed_actions,
|
362 |
-
value=allowed_actions[0] if len(allowed_actions) > 0 else None,
|
363 |
-
label="Data Action",
|
364 |
-
visible=True)
|
365 |
-
data_row2 = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled')
|
366 |
-
with data_row2:
|
367 |
-
with gr.Column(scale=50):
|
368 |
-
document_choice = gr.Dropdown(docs_state.value,
|
369 |
-
label="Choose Subset of Doc(s) in Collection [click get sources to update]",
|
370 |
-
value=docs_state.value[0],
|
371 |
-
interactive=True,
|
372 |
-
multiselect=True,
|
373 |
-
)
|
374 |
-
with gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list):
|
375 |
-
get_sources_btn = gr.Button(value="Get Sources", scale=0, size='sm')
|
376 |
-
show_sources_btn = gr.Button(value="Show Sources", scale=0, size='sm')
|
377 |
-
refresh_sources_btn = gr.Button(value="Refresh Sources", scale=0, size='sm')
|
378 |
-
|
379 |
-
# import control
|
380 |
-
if kwargs['langchain_mode'] != 'Disabled':
|
381 |
-
from gpt_langchain import file_types, have_arxiv
|
382 |
-
else:
|
383 |
-
have_arxiv = False
|
384 |
-
file_types = []
|
385 |
|
386 |
-
upload_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload,
|
387 |
-
equal_height=False)
|
388 |
-
with upload_row:
|
389 |
-
with gr.Column():
|
390 |
-
file_types_str = '[' + ' '.join(file_types) + ']'
|
391 |
-
fileup_output = gr.File(label=f'Upload {file_types_str}',
|
392 |
-
file_types=file_types,
|
393 |
-
file_count="multiple",
|
394 |
-
elem_id="warning", elem_classes="feedback")
|
395 |
with gr.Row():
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
|
412 |
-
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
419 |
-
|
420 |
-
elem_id='small_btn' if allow_upload_to_user_data else None,
|
421 |
-
size='sm' if not allow_upload_to_user_data else None)
|
422 |
-
with gr.Column(
|
423 |
-
visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_text_upload):
|
424 |
-
user_text_text = gr.Textbox(label='Paste Text [Shift-Enter more lines]',
|
425 |
-
placeholder="Click Add to Submit" if
|
426 |
-
allow_upload_to_my_data and
|
427 |
-
allow_upload_to_user_data else
|
428 |
-
"Enter to Submit, Shift-Enter for more lines",
|
429 |
-
interactive=True)
|
430 |
-
with gr.Row():
|
431 |
-
user_text_user_btn = gr.Button(value='Add Text to Shared UserData',
|
432 |
-
visible=allow_upload_to_user_data and allow_upload_to_my_data,
|
433 |
-
elem_id='small_btn')
|
434 |
-
user_text_my_btn = gr.Button(value='Add Text to Scratch MyData',
|
435 |
-
visible=allow_upload_to_my_data and allow_upload_to_user_data,
|
436 |
-
elem_id='small_btn' if allow_upload_to_user_data else None,
|
437 |
-
size='sm' if not allow_upload_to_user_data else None)
|
438 |
-
with gr.Column(visible=False):
|
439 |
-
# WIP:
|
440 |
-
with gr.Row(visible=False, equal_height=False):
|
441 |
-
github_textbox = gr.Textbox(label="Github URL")
|
442 |
-
with gr.Row(visible=True):
|
443 |
-
github_shared_btn = gr.Button(value="Add Github to Shared UserData",
|
444 |
-
visible=allow_upload_to_user_data,
|
445 |
-
elem_id='small_btn')
|
446 |
-
github_my_btn = gr.Button(value="Add Github to Scratch MyData",
|
447 |
-
visible=allow_upload_to_my_data, elem_id='small_btn')
|
448 |
sources_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list,
|
449 |
equal_height=False)
|
450 |
with sources_row:
|
451 |
with gr.Column(scale=1):
|
452 |
file_source = gr.File(interactive=False,
|
453 |
-
label="Download File w/Sources
|
454 |
with gr.Column(scale=2):
|
455 |
sources_text = gr.HTML(label='Sources Added', interactive=False)
|
456 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
457 |
with gr.TabItem("Chat History"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
458 |
with gr.Row():
|
459 |
if 'mbart-' in kwargs['model_lower']:
|
460 |
src_lang = gr.Dropdown(list(languages_covered().keys()),
|
@@ -463,20 +498,9 @@ def go_gradio(**kwargs):
|
|
463 |
tgt_lang = gr.Dropdown(list(languages_covered().keys()),
|
464 |
value=kwargs['tgt_lang'],
|
465 |
label="Output Language")
|
466 |
-
radio_chats = gr.Radio(value=None, label="Saved Chats", visible=True, interactive=True,
|
467 |
-
type='value')
|
468 |
-
with gr.Row():
|
469 |
-
clear_chat_btn = gr.Button(value="Clear Chat", visible=True, size='sm')
|
470 |
-
export_chats_btn = gr.Button(value="Export Chats to Download", size='sm')
|
471 |
-
remove_chat_btn = gr.Button(value="Remove Selected Chat", visible=True, size='sm')
|
472 |
-
add_to_chats_btn = gr.Button("Import Chats from Upload", size='sm')
|
473 |
-
with gr.Row():
|
474 |
-
chats_file = gr.File(interactive=False, label="Download Exported Chats")
|
475 |
-
chatsup_output = gr.File(label="Upload Chat File(s)",
|
476 |
-
file_types=['.json'],
|
477 |
-
file_count='multiple',
|
478 |
-
elem_id="warning", elem_classes="feedback")
|
479 |
|
|
|
|
|
480 |
with gr.TabItem("Expert"):
|
481 |
with gr.Row():
|
482 |
with gr.Column():
|
@@ -555,7 +579,7 @@ def go_gradio(**kwargs):
|
|
555 |
info="Directly pre-appended without prompt processing",
|
556 |
interactive=not is_public)
|
557 |
chat = gr.components.Checkbox(label="Chat mode", value=kwargs['chat'],
|
558 |
-
visible=
|
559 |
interactive=not is_public,
|
560 |
)
|
561 |
count_chat_tokens_btn = gr.Button(value="Count Chat Tokens",
|
@@ -614,9 +638,9 @@ def go_gradio(**kwargs):
|
|
614 |
model_load8bit_checkbox = gr.components.Checkbox(
|
615 |
label="Load 8-bit [requires support]",
|
616 |
value=kwargs['load_8bit'], interactive=not is_public)
|
617 |
-
|
618 |
label="Choose Devices [If not Checked, use all GPUs]",
|
619 |
-
value=kwargs['
|
620 |
model_gpu = gr.Dropdown(n_gpus_list,
|
621 |
label="GPU ID [-1 = all GPUs, if Choose is enabled]",
|
622 |
value=kwargs['gpu_id'], interactive=not is_public)
|
@@ -649,10 +673,10 @@ def go_gradio(**kwargs):
|
|
649 |
model_load8bit_checkbox2 = gr.components.Checkbox(
|
650 |
label="Load 8-bit 2 [requires support]",
|
651 |
value=kwargs['load_8bit'], interactive=not is_public)
|
652 |
-
|
653 |
label="Choose Devices 2 [If not Checked, use all GPUs]",
|
654 |
value=kwargs[
|
655 |
-
'
|
656 |
model_gpu2 = gr.Dropdown(n_gpus_list,
|
657 |
label="GPU ID 2 [-1 = all GPUs, if choose is enabled]",
|
658 |
value=kwargs['gpu_id'], interactive=not is_public)
|
@@ -679,35 +703,52 @@ def go_gradio(**kwargs):
|
|
679 |
add_model_lora_server_button = gr.Button("Add new Model, Lora, Server url:port", scale=0,
|
680 |
size='sm', interactive=not is_public)
|
681 |
with gr.TabItem("System"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
682 |
admin_row = gr.Row()
|
683 |
with admin_row:
|
684 |
-
|
685 |
-
|
|
|
|
|
|
|
686 |
system_row = gr.Row(visible=not is_public)
|
687 |
with system_row:
|
688 |
with gr.Column():
|
689 |
with gr.Row():
|
690 |
-
system_btn = gr.Button(value='Get System Info')
|
691 |
system_text = gr.Textbox(label='System Info', interactive=False, show_copy_button=True)
|
692 |
with gr.Row():
|
693 |
system_input = gr.Textbox(label='System Info Dict Password', interactive=True,
|
694 |
visible=not is_public)
|
695 |
-
system_btn2 = gr.Button(value='Get System Info Dict', visible=not is_public)
|
696 |
system_text2 = gr.Textbox(label='System Info Dict', interactive=False,
|
697 |
visible=not is_public, show_copy_button=True)
|
698 |
with gr.Row():
|
699 |
-
system_btn3 = gr.Button(value='Get Hash', visible=not is_public)
|
700 |
system_text3 = gr.Textbox(label='Hash', interactive=False,
|
701 |
visible=not is_public, show_copy_button=True)
|
702 |
|
703 |
with gr.Row():
|
704 |
-
zip_btn = gr.Button("Zip")
|
705 |
zip_text = gr.Textbox(label="Zip file name", interactive=False)
|
706 |
file_output = gr.File(interactive=False, label="Zip file to Download")
|
707 |
with gr.Row():
|
708 |
-
s3up_btn = gr.Button("S3UP")
|
709 |
s3up_text = gr.Textbox(label='S3UP result', interactive=False)
|
710 |
-
|
|
|
711 |
description = ""
|
712 |
description += """<p><b> DISCLAIMERS: </b><ul><i><li>The model was trained on The Pile and other data, which may contain objectionable content. Use at own risk.</i></li>"""
|
713 |
if kwargs['load_8bit']:
|
@@ -718,17 +759,18 @@ def go_gradio(**kwargs):
|
|
718 |
description += """<i><li>By using h2oGPT, you accept our <a href="https://github.com/h2oai/h2ogpt/blob/main/docs/tos.md">Terms of Service</a></i></li></ul></p>"""
|
719 |
gr.Markdown(value=description, show_label=False, interactive=False)
|
720 |
|
721 |
-
|
722 |
-
|
723 |
-
|
724 |
-
|
|
|
725 |
|
726 |
# Get flagged data
|
727 |
zip_data1 = functools.partial(zip_data, root_dirs=['flagged_data_points', kwargs['save_dir']])
|
728 |
-
zip_btn.click(zip_data1, inputs=None, outputs=[file_output, zip_text], queue=False,
|
729 |
-
|
730 |
-
s3up_btn.click(s3up, inputs=zip_text, outputs=s3up_text, queue=False,
|
731 |
-
|
732 |
|
733 |
def clear_file_list():
|
734 |
return None
|
@@ -746,182 +788,204 @@ def go_gradio(**kwargs):
|
|
746 |
return tuple([gr.update(interactive=True)] * len(args))
|
747 |
|
748 |
# Add to UserData
|
749 |
-
|
750 |
-
|
751 |
-
|
752 |
-
|
753 |
-
|
754 |
-
|
755 |
-
|
756 |
-
|
757 |
-
|
758 |
-
|
759 |
-
|
760 |
-
|
761 |
-
|
762 |
-
|
763 |
-
|
764 |
-
|
765 |
-
|
766 |
-
outputs=add_file_outputs + [sources_text],
|
767 |
queue=queue,
|
768 |
-
api_name='
|
769 |
|
770 |
-
if allow_upload_to_user_data and not allow_upload_to_my_data:
|
771 |
-
func1 = fileup_output.change
|
772 |
-
else:
|
773 |
-
func1 = add_to_shared_db_btn.click
|
774 |
# then no need for add buttons, only single changeable db
|
775 |
-
eventdb1a =
|
776 |
-
|
777 |
-
eventdb1 = eventdb1a.then(**add_file_kwargs, show_progress='
|
778 |
-
eventdb1.then(make_interactive, inputs=add_file_outputs, outputs=add_file_outputs,
|
|
|
779 |
|
780 |
# note for update_user_db_func output is ignored for db
|
781 |
|
782 |
def clear_textbox():
|
783 |
return gr.Textbox.update(value='')
|
784 |
|
785 |
-
update_user_db_url_func = functools.partial(
|
786 |
|
787 |
-
add_url_outputs = [url_text, langchain_mode
|
788 |
add_url_kwargs = dict(fn=update_user_db_url_func,
|
789 |
-
inputs=[url_text, my_db_state,
|
790 |
-
|
791 |
-
outputs=add_url_outputs + [sources_text],
|
792 |
queue=queue,
|
793 |
-
api_name='
|
794 |
|
795 |
-
|
796 |
-
|
797 |
-
else:
|
798 |
-
func2 = url_user_btn.click
|
799 |
-
eventdb2a = func2(fn=dummy_fun, inputs=url_text, outputs=url_text, queue=queue,
|
800 |
-
show_progress='minimal')
|
801 |
# work around https://github.com/gradio-app/gradio/issues/4733
|
802 |
eventdb2b = eventdb2a.then(make_non_interactive, inputs=add_url_outputs, outputs=add_url_outputs,
|
803 |
show_progress='minimal')
|
804 |
-
eventdb2 = eventdb2b.then(**add_url_kwargs, show_progress='
|
805 |
-
eventdb2.then(make_interactive, inputs=add_url_outputs, outputs=add_url_outputs,
|
|
|
806 |
|
807 |
-
update_user_db_txt_func = functools.partial(
|
808 |
-
add_text_outputs = [user_text_text, langchain_mode
|
809 |
add_text_kwargs = dict(fn=update_user_db_txt_func,
|
810 |
-
inputs=[user_text_text, my_db_state,
|
811 |
-
|
812 |
-
outputs=add_text_outputs + [sources_text],
|
813 |
queue=queue,
|
814 |
-
api_name='
|
815 |
)
|
816 |
-
|
817 |
-
|
818 |
-
else:
|
819 |
-
func3 = user_text_user_btn.click
|
820 |
-
|
821 |
-
eventdb3a = func3(fn=dummy_fun, inputs=user_text_text, outputs=user_text_text, queue=queue,
|
822 |
-
show_progress='minimal')
|
823 |
eventdb3b = eventdb3a.then(make_non_interactive, inputs=add_text_outputs, outputs=add_text_outputs,
|
824 |
show_progress='minimal')
|
825 |
-
eventdb3 = eventdb3b.then(**add_text_kwargs, show_progress='
|
826 |
-
eventdb3.then(make_interactive, inputs=add_text_outputs, outputs=add_text_outputs,
|
827 |
-
|
828 |
-
|
829 |
-
|
830 |
-
|
831 |
-
hf_embedding_model=hf_embedding_model,
|
832 |
-
enable_captions=enable_captions,
|
833 |
-
captions_model=captions_model,
|
834 |
-
enable_ocr=enable_ocr,
|
835 |
-
caption_loader=caption_loader,
|
836 |
-
verbose=kwargs['verbose'],
|
837 |
-
user_path=kwargs['user_path'],
|
838 |
-
n_jobs=kwargs['n_jobs'],
|
839 |
-
)
|
840 |
-
|
841 |
-
add_my_file_outputs = [fileup_output, langchain_mode, my_db_state, add_to_shared_db_btn, add_to_my_db_btn]
|
842 |
-
add_my_file_kwargs = dict(fn=update_my_db_func,
|
843 |
-
inputs=[fileup_output, my_db_state, add_to_shared_db_btn, add_to_my_db_btn,
|
844 |
-
chunk, chunk_size],
|
845 |
-
outputs=add_my_file_outputs + [sources_text],
|
846 |
-
queue=queue,
|
847 |
-
api_name='add_to_my' if allow_api and allow_upload_to_my_data else None)
|
848 |
-
|
849 |
-
if not allow_upload_to_user_data and allow_upload_to_my_data:
|
850 |
-
func4 = fileup_output.change
|
851 |
-
else:
|
852 |
-
func4 = add_to_my_db_btn.click
|
853 |
-
|
854 |
-
eventdb4a = func4(make_non_interactive, inputs=add_my_file_outputs,
|
855 |
-
outputs=add_my_file_outputs,
|
856 |
-
show_progress='minimal')
|
857 |
-
eventdb4 = eventdb4a.then(**add_my_file_kwargs, show_progress='minimal')
|
858 |
-
eventdb4.then(make_interactive, inputs=add_my_file_outputs, outputs=add_my_file_outputs,
|
859 |
-
show_progress='minimal')
|
860 |
-
|
861 |
-
update_my_db_url_func = functools.partial(update_my_db_func, is_url=True)
|
862 |
-
add_my_url_outputs = [url_text, langchain_mode, my_db_state, url_user_btn, url_my_btn]
|
863 |
-
add_my_url_kwargs = dict(fn=update_my_db_url_func,
|
864 |
-
inputs=[url_text, my_db_state, url_user_btn, url_my_btn,
|
865 |
-
chunk, chunk_size],
|
866 |
-
outputs=add_my_url_outputs + [sources_text],
|
867 |
-
queue=queue,
|
868 |
-
api_name='add_url_to_my' if allow_api and allow_upload_to_my_data else None)
|
869 |
-
if not allow_upload_to_user_data and allow_upload_to_my_data:
|
870 |
-
func5 = url_text.submit
|
871 |
-
else:
|
872 |
-
func5 = url_my_btn.click
|
873 |
-
eventdb5a = func5(fn=dummy_fun, inputs=url_text, outputs=url_text, queue=queue,
|
874 |
-
show_progress='minimal')
|
875 |
-
eventdb5b = eventdb5a.then(make_non_interactive, inputs=add_my_url_outputs, outputs=add_my_url_outputs,
|
876 |
-
show_progress='minimal')
|
877 |
-
eventdb5 = eventdb5b.then(**add_my_url_kwargs, show_progress='minimal')
|
878 |
-
eventdb5.then(make_interactive, inputs=add_my_url_outputs, outputs=add_my_url_outputs,
|
879 |
-
show_progress='minimal')
|
880 |
-
|
881 |
-
update_my_db_txt_func = functools.partial(update_my_db_func, is_txt=True)
|
882 |
-
|
883 |
-
add_my_text_outputs = [user_text_text, langchain_mode, my_db_state, user_text_user_btn,
|
884 |
-
user_text_my_btn]
|
885 |
-
add_my_text_kwargs = dict(fn=update_my_db_txt_func,
|
886 |
-
inputs=[user_text_text, my_db_state, user_text_user_btn, user_text_my_btn,
|
887 |
-
chunk, chunk_size],
|
888 |
-
outputs=add_my_text_outputs + [sources_text],
|
889 |
-
queue=queue,
|
890 |
-
api_name='add_txt_to_my' if allow_api and allow_upload_to_my_data else None)
|
891 |
-
if not allow_upload_to_user_data and allow_upload_to_my_data:
|
892 |
-
func6 = user_text_text.submit
|
893 |
-
else:
|
894 |
-
func6 = user_text_my_btn.click
|
895 |
-
|
896 |
-
eventdb6a = func6(fn=dummy_fun, inputs=user_text_text, outputs=user_text_text, queue=queue,
|
897 |
-
show_progress='minimal')
|
898 |
-
eventdb6b = eventdb6a.then(make_non_interactive, inputs=add_my_text_outputs, outputs=add_my_text_outputs,
|
899 |
-
show_progress='minimal')
|
900 |
-
eventdb6 = eventdb6b.then(**add_my_text_kwargs, show_progress='minimal')
|
901 |
-
eventdb6.then(make_interactive, inputs=add_my_text_outputs, outputs=add_my_text_outputs,
|
902 |
-
show_progress='minimal')
|
903 |
|
904 |
get_sources1 = functools.partial(get_sources, dbs=dbs, docs_state0=docs_state0)
|
905 |
|
906 |
# if change collection source, must clear doc selections from it to avoid inconsistency
|
907 |
def clear_doc_choice():
|
908 |
-
return gr.Dropdown.update(choices=docs_state0, value=
|
|
|
|
|
909 |
|
910 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
911 |
|
912 |
def update_dropdown(x):
|
913 |
return gr.Dropdown.update(choices=x, value=[docs_state0[0]])
|
914 |
|
915 |
-
|
916 |
-
|
917 |
-
|
918 |
-
|
|
|
|
|
919 |
.then(fn=update_dropdown, inputs=docs_state, outputs=document_choice)
|
920 |
# show button, else only show when add. Could add to above get_sources for download/dropdown, but bit much maybe
|
921 |
show_sources1 = functools.partial(get_source_files_given_langchain_mode, dbs=dbs)
|
922 |
eventdb8 = show_sources_btn.click(fn=show_sources1, inputs=[my_db_state, langchain_mode], outputs=sources_text,
|
923 |
api_name='show_sources' if allow_api else None)
|
924 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
925 |
# Get inputs to evaluate() and make_db()
|
926 |
# don't deepcopy, can contain model itself
|
927 |
all_kwargs = kwargs.copy()
|
@@ -1008,9 +1072,6 @@ def go_gradio(**kwargs):
|
|
1008 |
**kwargs_evaluate
|
1009 |
)
|
1010 |
|
1011 |
-
dark_mode_btn = gr.Button("Dark Mode", variant="primary", size="sm")
|
1012 |
-
# FIXME: Could add exceptions for non-chat but still streaming
|
1013 |
-
exception_text = gr.Textbox(value="", visible=kwargs['chat'], label='Chat Exceptions', interactive=False)
|
1014 |
dark_mode_btn.click(
|
1015 |
None,
|
1016 |
None,
|
@@ -1020,20 +1081,19 @@ def go_gradio(**kwargs):
|
|
1020 |
queue=False,
|
1021 |
)
|
1022 |
|
1023 |
-
|
1024 |
-
|
1025 |
-
return gr.Column.update(visible=
|
1026 |
|
1027 |
-
|
1028 |
-
|
|
|
|
|
1029 |
|
1030 |
-
|
1031 |
-
|
1032 |
-
|
1033 |
-
|
1034 |
-
.then(col_chat_fun, chat, col_chat) \
|
1035 |
-
.then(context_fun, chat, context) \
|
1036 |
-
.then(col_chat_fun, chat, exception_text)
|
1037 |
|
1038 |
# examples after submit or any other buttons for chat or no chat
|
1039 |
if kwargs['examples'] is not None and kwargs['show_examples']:
|
@@ -1154,6 +1214,7 @@ def go_gradio(**kwargs):
|
|
1154 |
prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
|
1155 |
langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
|
1156 |
langchain_action1 = args_list[eval_func_param_names.index('langchain_action')]
|
|
|
1157 |
document_choice1 = args_list[eval_func_param_names.index('document_choice')]
|
1158 |
if not prompt_type1:
|
1159 |
# shouldn't have to specify if CLI launched model
|
@@ -1186,7 +1247,7 @@ def go_gradio(**kwargs):
|
|
1186 |
return history
|
1187 |
if user_message1 in ['', None, '\n']:
|
1188 |
if langchain_action1 in LangChainAction.QUERY.value and \
|
1189 |
-
DocumentChoices.
|
1190 |
or \
|
1191 |
langchain_mode1 in [LangChainMode.CHAT_LLM.value, LangChainMode.LLM.value]:
|
1192 |
# reject non-retry submit/enter
|
@@ -1249,6 +1310,7 @@ def go_gradio(**kwargs):
|
|
1249 |
args_list = args_list[:-3] # only keep rest needed for evaluate()
|
1250 |
langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
|
1251 |
langchain_action1 = args_list[eval_func_param_names.index('langchain_action')]
|
|
|
1252 |
document_choice1 = args_list[eval_func_param_names.index('document_choice')]
|
1253 |
if not history:
|
1254 |
print("No history", flush=True)
|
@@ -1261,7 +1323,7 @@ def go_gradio(**kwargs):
|
|
1261 |
history[-1][1] = None
|
1262 |
elif not instruction1:
|
1263 |
if langchain_action1 in LangChainAction.QUERY.value and \
|
1264 |
-
DocumentChoices.
|
1265 |
or \
|
1266 |
langchain_mode1 in [LangChainMode.CHAT_LLM.value, LangChainMode.LLM.value]:
|
1267 |
# if not retrying, then reject empty query
|
@@ -1432,11 +1494,11 @@ def go_gradio(**kwargs):
|
|
1432 |
)
|
1433 |
bot_args = dict(fn=bot,
|
1434 |
inputs=inputs_list + [model_state, my_db_state] + [text_output],
|
1435 |
-
outputs=[text_output,
|
1436 |
)
|
1437 |
retry_bot_args = dict(fn=functools.partial(bot, retry=True),
|
1438 |
inputs=inputs_list + [model_state, my_db_state] + [text_output],
|
1439 |
-
outputs=[text_output,
|
1440 |
)
|
1441 |
retry_user_args = dict(fn=functools.partial(user, retry=True),
|
1442 |
inputs=inputs_list + [text_output],
|
@@ -1454,11 +1516,11 @@ def go_gradio(**kwargs):
|
|
1454 |
)
|
1455 |
bot_args2 = dict(fn=bot,
|
1456 |
inputs=inputs_list2 + [model_state2, my_db_state] + [text_output2],
|
1457 |
-
outputs=[text_output2,
|
1458 |
)
|
1459 |
retry_bot_args2 = dict(fn=functools.partial(bot, retry=True),
|
1460 |
inputs=inputs_list2 + [model_state2, my_db_state] + [text_output2],
|
1461 |
-
outputs=[text_output2,
|
1462 |
)
|
1463 |
retry_user_args2 = dict(fn=functools.partial(user, retry=True),
|
1464 |
inputs=inputs_list2 + [text_output2],
|
@@ -1479,11 +1541,11 @@ def go_gradio(**kwargs):
|
|
1479 |
)
|
1480 |
all_bot_args = dict(fn=functools.partial(all_bot, model_states1=model_states),
|
1481 |
inputs=inputs_list + [my_db_state] + text_outputs,
|
1482 |
-
outputs=text_outputs + [
|
1483 |
)
|
1484 |
all_retry_bot_args = dict(fn=functools.partial(all_bot, model_states1=model_states, retry=True),
|
1485 |
inputs=inputs_list + [my_db_state] + text_outputs,
|
1486 |
-
outputs=text_outputs + [
|
1487 |
)
|
1488 |
all_retry_user_args = dict(fn=functools.partial(all_user, retry=True,
|
1489 |
sanitize_user_prompt=kwargs['sanitize_user_prompt'],
|
@@ -1681,13 +1743,26 @@ def go_gradio(**kwargs):
|
|
1681 |
return False
|
1682 |
return is_same
|
1683 |
|
1684 |
-
def save_chat(*args):
|
1685 |
args_list = list(args)
|
1686 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1687 |
# remove None histories
|
1688 |
chat_list_not_none = [x for x in chat_list if x and len(x) > 0 and len(x[0]) == 2 and x[0][1] is not None]
|
1689 |
-
|
1690 |
-
|
|
|
|
|
|
|
1691 |
short_chats = list(chat_state1.keys())
|
1692 |
if len(chat_list_not_none) > 0:
|
1693 |
# make short_chat key from only first history, based upon question that is same anyways
|
@@ -1699,13 +1774,14 @@ def go_gradio(**kwargs):
|
|
1699 |
if not already_exists:
|
1700 |
chat_state1[short_chat] = chat_list.copy()
|
1701 |
# clear chat_list so saved and then new conversation starts
|
1702 |
-
|
1703 |
-
|
|
|
|
|
|
|
|
|
1704 |
return tuple(ret_list)
|
1705 |
|
1706 |
-
def update_radio_chats(chat_state1):
|
1707 |
-
return gr.update(choices=list(chat_state1.keys()), value=None)
|
1708 |
-
|
1709 |
def switch_chat(chat_key, chat_state1, num_model_lock=0):
|
1710 |
chosen_chat = chat_state1[chat_key]
|
1711 |
# deal with possible different size of chat list vs. current list
|
@@ -1729,11 +1805,13 @@ def go_gradio(**kwargs):
|
|
1729 |
.then(clear_scores, outputs=[score_text, score_text2, score_text_nochat])
|
1730 |
|
1731 |
def remove_chat(chat_key, chat_state1):
|
1732 |
-
|
1733 |
-
|
|
|
1734 |
|
1735 |
-
remove_chat_btn.click(remove_chat,
|
1736 |
-
|
|
|
1737 |
|
1738 |
def get_chats1(chat_state1):
|
1739 |
base = 'chats'
|
@@ -1743,18 +1821,19 @@ def go_gradio(**kwargs):
|
|
1743 |
f.write(json.dumps(chat_state1, indent=2))
|
1744 |
return filename
|
1745 |
|
1746 |
-
export_chats_btn.click(get_chats1, inputs=chat_state, outputs=chats_file, queue=False,
|
1747 |
-
|
1748 |
|
1749 |
-
def add_chats_from_file(file, chat_state1,
|
1750 |
if not file:
|
1751 |
-
return chat_state1,
|
1752 |
if isinstance(file, str):
|
1753 |
files = [file]
|
1754 |
else:
|
1755 |
files = file
|
1756 |
if not files:
|
1757 |
-
return chat_state1,
|
|
|
1758 |
for file1 in files:
|
1759 |
try:
|
1760 |
if hasattr(file1, 'name'):
|
@@ -1763,33 +1842,42 @@ def go_gradio(**kwargs):
|
|
1763 |
new_chats = json.loads(f.read())
|
1764 |
for chat1_k, chat1_v in new_chats.items():
|
1765 |
# ignore chat1_k, regenerate and de-dup to avoid loss
|
1766 |
-
_, chat_state1 = save_chat(chat1_v, chat_state1)
|
1767 |
except BaseException as e:
|
1768 |
t, v, tb = sys.exc_info()
|
1769 |
ex = ''.join(traceback.format_exception(t, v, tb))
|
1770 |
-
|
1771 |
-
|
|
|
|
|
|
|
1772 |
|
1773 |
# note for update_user_db_func output is ignored for db
|
1774 |
-
|
1775 |
-
|
1776 |
-
|
1777 |
-
|
1778 |
-
|
1779 |
-
|
1780 |
-
|
1781 |
-
|
1782 |
-
|
1783 |
-
|
1784 |
-
|
|
|
1785 |
.then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=False) \
|
1786 |
.then(clear_scores, outputs=[score_text, score_text2, score_text_nochat])
|
1787 |
|
1788 |
-
|
1789 |
-
|
1790 |
-
|
1791 |
-
|
1792 |
-
|
|
|
|
|
|
|
|
|
|
|
1793 |
.then(update_radio_chats, inputs=chat_state, outputs=radio_chats,
|
1794 |
api_name='update_chats' if allow_api else None) \
|
1795 |
.then(clear_scores, outputs=[score_text, score_text2, score_text_nochat])
|
@@ -1823,7 +1911,7 @@ def go_gradio(**kwargs):
|
|
1823 |
.then(clear_torch_cache)
|
1824 |
|
1825 |
def load_model(model_name, lora_weights, server_name, model_state_old, prompt_type_old, load_8bit,
|
1826 |
-
|
1827 |
# ensure no API calls reach here
|
1828 |
if is_public:
|
1829 |
raise RuntimeError("Illegal access for %s" % model_name)
|
@@ -1867,7 +1955,7 @@ def go_gradio(**kwargs):
|
|
1867 |
all_kwargs1 = all_kwargs.copy()
|
1868 |
all_kwargs1['base_model'] = model_name.strip()
|
1869 |
all_kwargs1['load_8bit'] = load_8bit
|
1870 |
-
all_kwargs1['
|
1871 |
all_kwargs1['gpu_id'] = int(gpu_id) # detranscribe
|
1872 |
model_lower = model_name.strip().lower()
|
1873 |
if model_lower in inv_prompt_type_to_model_lower:
|
@@ -1920,8 +2008,9 @@ def go_gradio(**kwargs):
|
|
1920 |
|
1921 |
get_prompt_str_func1 = functools.partial(get_prompt_str, which=1)
|
1922 |
get_prompt_str_func2 = functools.partial(get_prompt_str, which=2)
|
1923 |
-
prompt_type.change(fn=get_prompt_str_func1, inputs=[prompt_type, prompt_dict], outputs=prompt_dict)
|
1924 |
-
prompt_type2.change(fn=get_prompt_str_func2, inputs=[prompt_type2, prompt_dict2], outputs=prompt_dict2
|
|
|
1925 |
|
1926 |
def dropdown_prompt_type_list(x):
|
1927 |
return gr.Dropdown.update(value=x)
|
@@ -1931,7 +2020,7 @@ def go_gradio(**kwargs):
|
|
1931 |
|
1932 |
load_model_args = dict(fn=load_model,
|
1933 |
inputs=[model_choice, lora_choice, server_choice, model_state, prompt_type,
|
1934 |
-
model_load8bit_checkbox,
|
1935 |
outputs=[model_state, model_used, lora_used, server_used,
|
1936 |
# if prompt_type changes, prompt_dict will change via change rule
|
1937 |
prompt_type, max_new_tokens, min_new_tokens,
|
@@ -1939,28 +2028,27 @@ def go_gradio(**kwargs):
|
|
1939 |
prompt_update_args = dict(fn=dropdown_prompt_type_list, inputs=prompt_type, outputs=prompt_type)
|
1940 |
chatbot_update_args = dict(fn=chatbot_list, inputs=[text_output, model_used], outputs=text_output)
|
1941 |
nochat_update_args = dict(fn=chatbot_list, inputs=[text_output_nochat, model_used], outputs=text_output_nochat)
|
1942 |
-
|
1943 |
-
|
1944 |
-
|
1945 |
-
|
1946 |
-
|
1947 |
-
|
1948 |
|
1949 |
load_model_args2 = dict(fn=load_model,
|
1950 |
inputs=[model_choice2, lora_choice2, server_choice2, model_state2, prompt_type2,
|
1951 |
-
model_load8bit_checkbox2,
|
1952 |
outputs=[model_state2, model_used2, lora_used2, server_used2,
|
1953 |
# if prompt_type2 changes, prompt_dict2 will change via change rule
|
1954 |
prompt_type2, max_new_tokens2, min_new_tokens2
|
1955 |
])
|
1956 |
prompt_update_args2 = dict(fn=dropdown_prompt_type_list, inputs=prompt_type2, outputs=prompt_type2)
|
1957 |
chatbot_update_args2 = dict(fn=chatbot_list, inputs=[text_output2, model_used2], outputs=text_output2)
|
1958 |
-
|
1959 |
-
|
1960 |
-
|
1961 |
-
|
1962 |
-
|
1963 |
-
.then(clear_torch_cache)
|
1964 |
|
1965 |
def dropdown_model_lora_server_list(model_list0, model_x,
|
1966 |
lora_list0, lora_x,
|
@@ -2009,7 +2097,8 @@ def go_gradio(**kwargs):
|
|
2009 |
server_options_state],
|
2010 |
queue=False)
|
2011 |
|
2012 |
-
go_btn.click(lambda: gr.update(visible=False), None, go_btn, api_name="go" if allow_api else None,
|
|
|
2013 |
.then(lambda: gr.update(visible=True), None, normal_block, queue=False) \
|
2014 |
.then(**load_model_args, queue=False).then(**prompt_update_args, queue=False)
|
2015 |
|
@@ -2077,23 +2166,11 @@ def go_gradio(**kwargs):
|
|
2077 |
def get_hash():
|
2078 |
return kwargs['git_hash']
|
2079 |
|
2080 |
-
system_btn3.click(get_hash,
|
2081 |
-
|
2082 |
-
|
2083 |
-
|
2084 |
-
|
2085 |
-
|
2086 |
-
# don't pass text_output, don't want to clear output, just stop it
|
2087 |
-
# cancel only stops outer generation, not inner generation or non-generation
|
2088 |
-
stop_btn.click(lambda: None, None, None,
|
2089 |
-
cancels=submits1 + submits2 + submits3 +
|
2090 |
-
submits4 +
|
2091 |
-
[submit_event_nochat, submit_event_nochat2] +
|
2092 |
-
[eventdb1, eventdb2, eventdb3,
|
2093 |
-
eventdb4, eventdb5, eventdb6] +
|
2094 |
-
[eventdb7, eventdb8, eventdb9]
|
2095 |
-
,
|
2096 |
-
queue=False, api_name='stop' if allow_api else None).then(clear_torch_cache, queue=False)
|
2097 |
|
2098 |
def count_chat_tokens(model_state1, chat1, prompt_type1, prompt_dict1,
|
2099 |
memory_restriction_level1=0,
|
@@ -2121,9 +2198,25 @@ def go_gradio(**kwargs):
|
|
2121 |
count_chat_tokens_func = functools.partial(count_chat_tokens,
|
2122 |
memory_restriction_level1=memory_restriction_level,
|
2123 |
keep_sources_in_context1=kwargs['keep_sources_in_context'])
|
2124 |
-
count_chat_tokens_btn.click(fn=count_chat_tokens,
|
2125 |
-
|
2126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2127 |
|
2128 |
demo.load(None, None, None, _js=get_dark_js() if kwargs['h2ocolors'] and False else None) # light best
|
2129 |
|
@@ -2196,6 +2289,8 @@ def get_inputs_list(inputs_dict, model_lower, model_id=1):
|
|
2196 |
|
2197 |
|
2198 |
def get_sources(db1, langchain_mode, dbs=None, docs_state0=None):
|
|
|
|
|
2199 |
if langchain_mode in ['ChatLLM', 'LLM']:
|
2200 |
source_files_added = "NA"
|
2201 |
source_list = []
|
@@ -2226,9 +2321,24 @@ def get_sources(db1, langchain_mode, dbs=None, docs_state0=None):
|
|
2226 |
return sources_file, source_list
|
2227 |
|
2228 |
|
2229 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2230 |
try:
|
2231 |
-
return _update_user_db(file, db1,
|
|
|
|
|
2232 |
except BaseException as e:
|
2233 |
print(traceback.format_exc(), flush=True)
|
2234 |
# gradio has issues if except, so fail semi-gracefully, else would hang forever in processing textbox
|
@@ -2245,15 +2355,14 @@ def update_user_db(file, db1, x, y, *args, dbs=None, langchain_mode='UserData',
|
|
2245 |
</body>
|
2246 |
</html>
|
2247 |
""".format(ex_str)
|
2248 |
-
|
2249 |
-
|
2250 |
-
else:
|
2251 |
-
return None, langchain_mode, x, y, source_files_added
|
2252 |
finally:
|
2253 |
clear_torch_cache()
|
2254 |
|
2255 |
|
2256 |
def get_lock_file(db1, langchain_mode):
|
|
|
2257 |
assert len(db1) == 2 and db1[1] is not None and isinstance(db1[1], str)
|
2258 |
user_id = db1[1]
|
2259 |
base_path = 'locks'
|
@@ -2262,7 +2371,10 @@ def get_lock_file(db1, langchain_mode):
|
|
2262 |
return lock_file
|
2263 |
|
2264 |
|
2265 |
-
def _update_user_db(file,
|
|
|
|
|
|
|
2266 |
user_path=None,
|
2267 |
use_openai_embedding=None,
|
2268 |
hf_embedding_model=None,
|
@@ -2273,6 +2385,9 @@ def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None,
|
|
2273 |
verbose=None,
|
2274 |
is_url=None, is_txt=None,
|
2275 |
n_jobs=-1):
|
|
|
|
|
|
|
2276 |
assert use_openai_embedding is not None
|
2277 |
assert hf_embedding_model is not None
|
2278 |
assert caption_loader is not None
|
@@ -2281,6 +2396,8 @@ def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None,
|
|
2281 |
assert enable_ocr is not None
|
2282 |
assert verbose is not None
|
2283 |
|
|
|
|
|
2284 |
if dbs is None:
|
2285 |
dbs = {}
|
2286 |
assert isinstance(dbs, dict), "Wrong type for dbs: %s" % str(type(dbs))
|
@@ -2295,6 +2412,14 @@ def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None,
|
|
2295 |
if not isinstance(file, (list, tuple, typing.Generator)) and isinstance(file, str):
|
2296 |
file = [file]
|
2297 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2298 |
if langchain_mode == 'UserData' and user_path is not None:
|
2299 |
# move temp files from gradio upload to stable location
|
2300 |
for fili, fil in enumerate(file):
|
@@ -2323,6 +2448,7 @@ def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None,
|
|
2323 |
caption_loader=caption_loader,
|
2324 |
)
|
2325 |
exceptions = [x for x in sources if x.metadata.get('exception')]
|
|
|
2326 |
sources = [x for x in sources if 'exception' not in x.metadata]
|
2327 |
|
2328 |
lock_file = get_lock_file(db1, langchain_mode)
|
@@ -2349,7 +2475,7 @@ def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None,
|
|
2349 |
if db is not None:
|
2350 |
db1[0] = db
|
2351 |
source_files_added = get_source_files(db=db1[0], exceptions=exceptions)
|
2352 |
-
return None, langchain_mode,
|
2353 |
else:
|
2354 |
from gpt_langchain import get_persist_directory
|
2355 |
persist_directory = get_persist_directory(langchain_mode)
|
@@ -2367,10 +2493,10 @@ def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None,
|
|
2367 |
hf_embedding_model=hf_embedding_model)
|
2368 |
dbs[langchain_mode] = db
|
2369 |
# NOTE we do not return db, because function call always same code path
|
2370 |
-
# return dbs[langchain_mode]
|
2371 |
# db in this code path is updated in place
|
2372 |
source_files_added = get_source_files(db=dbs[langchain_mode], exceptions=exceptions)
|
2373 |
-
return None, langchain_mode,
|
2374 |
|
2375 |
|
2376 |
def get_db(db1, langchain_mode, dbs=None):
|
|
|
20 |
from iterators import TimeoutIterator
|
21 |
|
22 |
from gradio_utils.css import get_css
|
23 |
+
from gradio_utils.prompt_form import make_chatbots
|
24 |
|
25 |
# This is a hack to prevent Gradio from phoning home when it gets imported
|
26 |
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
|
|
56 |
from prompter import prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower, non_hf_types, \
|
57 |
get_prompt
|
58 |
from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
|
59 |
+
ping, get_short_name, makedirs, get_kwargs, remove, system_info, ping_gpu, get_url, get_local_ip
|
60 |
from gen import get_model, languages_covered, evaluate, score_qa, langchain_modes, inputs_kwargs_list, scratch_base_dir, \
|
61 |
get_max_max_new_tokens, get_minmax_top_k_docs, history_to_context, langchain_actions
|
62 |
from evaluate_params import eval_func_param_names, no_default_param_names, eval_func_param_names_defaults
|
|
|
118 |
allow_upload = allow_upload_to_user_data or allow_upload_to_my_data
|
119 |
kwargs.update(locals())
|
120 |
|
121 |
+
# import control
|
122 |
+
if kwargs['langchain_mode'] != 'Disabled':
|
123 |
+
from gpt_langchain import file_types, have_arxiv
|
124 |
+
else:
|
125 |
+
have_arxiv = False
|
126 |
+
file_types = []
|
127 |
+
|
128 |
if 'mbart-' in kwargs['model_lower']:
|
129 |
instruction_label_nochat = "Text to translate"
|
130 |
else:
|
|
|
141 |
"""
|
142 |
else:
|
143 |
description = more_info
|
144 |
+
description_bottom = "If this host is busy, try [Multi-Model](https://gpt.h2o.ai), [Falcon 40B](http://falcon.h2o.ai), [HF Spaces1](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot) or [HF Spaces2](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)<br>"
|
|
|
145 |
if is_hf:
|
146 |
description_bottom += '''<a href="https://huggingface.co/spaces/h2oai/h2ogpt-chatbot?duplicate=true"><img src="https://bit.ly/3gLdBN6" style="white-space: nowrap" alt="Duplicate Space"></a>'''
|
147 |
|
|
|
166 |
theme_kwargs = dict()
|
167 |
if kwargs['gradio_size'] == 'xsmall':
|
168 |
theme_kwargs.update(dict(spacing_size=spacing_xsm, text_size=text_xsm, radius_size=radius_xsm))
|
169 |
+
elif kwargs['gradio_size'] in [None, 'small']:
|
170 |
theme_kwargs.update(dict(spacing_size=gr.themes.sizes.spacing_sm, text_size=gr.themes.sizes.text_sm,
|
171 |
radius_size=gr.themes.sizes.spacing_sm))
|
172 |
elif kwargs['gradio_size'] == 'large':
|
|
|
268 |
model_options_state = gr.State([model_options])
|
269 |
lora_options_state = gr.State([lora_options])
|
270 |
server_options_state = gr.State([server_options])
|
271 |
+
my_db_state = gr.State([None, None])
|
|
|
272 |
chat_state = gr.State({})
|
273 |
+
docs_state00 = kwargs['document_choice'] + [DocumentChoices.All.name]
|
|
|
274 |
docs_state0 = []
|
275 |
[docs_state0.append(x) for x in docs_state00 if x not in docs_state0]
|
276 |
+
docs_state = gr.State(docs_state0)
|
277 |
+
viewable_docs_state0 = []
|
278 |
+
viewable_docs_state = gr.State(viewable_docs_state0)
|
279 |
gr.Markdown(f"""
|
280 |
{get_h2o_title(title, description) if kwargs['h2ocolors'] else get_simple_title(title, description)}
|
281 |
""")
|
|
|
288 |
res_value = "Response Score: NA" if not kwargs[
|
289 |
'model_lock'] else "Response Scores: %s" % nas
|
290 |
|
291 |
+
if kwargs['langchain_mode'] != LangChainMode.DISABLED.value:
|
292 |
+
extra_prompt_form = ". For summarization, empty submission uses first top_k_docs documents."
|
293 |
+
else:
|
294 |
+
extra_prompt_form = ""
|
295 |
+
if kwargs['input_lines'] > 1:
|
296 |
+
instruction_label = "Shift-Enter to Submit, Enter for more lines%s" % extra_prompt_form
|
297 |
+
else:
|
298 |
+
instruction_label = "Enter to Submit, Shift-Enter for more lines%s" % extra_prompt_form
|
299 |
+
|
300 |
+
normal_block = gr.Row(visible=not base_wanted, equal_height=False)
|
301 |
with normal_block:
|
302 |
+
side_bar = gr.Column(elem_id="col_container", scale=1, min_width=100)
|
303 |
+
with side_bar:
|
304 |
+
with gr.Accordion("Chats", open=False, visible=True):
|
305 |
+
radio_chats = gr.Radio(value=None, label="Saved Chats", show_label=False,
|
306 |
+
visible=True, interactive=True,
|
307 |
+
type='value')
|
308 |
+
upload_visible = kwargs['langchain_mode'] != 'Disabled' and allow_upload
|
309 |
+
with gr.Accordion("Upload", open=False, visible=upload_visible):
|
310 |
+
with gr.Column():
|
311 |
+
with gr.Row(equal_height=False):
|
312 |
+
file_types_str = '[' + ' '.join(file_types) + ' URL ArXiv TEXT' + ']'
|
313 |
+
fileup_output = gr.File(label=f'Upload {file_types_str}',
|
314 |
+
show_label=False,
|
315 |
+
file_types=file_types,
|
316 |
+
file_count="multiple",
|
317 |
+
scale=1,
|
318 |
+
min_width=0,
|
319 |
+
elem_id="warning", elem_classes="feedback")
|
320 |
+
url_visible = kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_url_upload
|
321 |
+
url_label = 'URL/ArXiv' if have_arxiv else 'URL'
|
322 |
+
url_text = gr.Textbox(label=url_label,
|
323 |
+
# placeholder="Enter Submits",
|
324 |
+
max_lines=1,
|
325 |
+
interactive=True)
|
326 |
+
text_visible = kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_text_upload
|
327 |
+
user_text_text = gr.Textbox(label='Paste Text',
|
328 |
+
# placeholder="Enter Submits",
|
329 |
+
interactive=True,
|
330 |
+
visible=text_visible)
|
331 |
+
github_textbox = gr.Textbox(label="Github URL", visible=False) # FIXME WIP
|
332 |
+
database_visible = kwargs['langchain_mode'] != 'Disabled'
|
333 |
+
with gr.Accordion("Database", open=False, visible=database_visible):
|
334 |
+
if is_hf:
|
335 |
+
# don't show 'wiki' since only usually useful for internal testing at moment
|
336 |
+
no_show_modes = ['Disabled', 'wiki']
|
337 |
+
else:
|
338 |
+
no_show_modes = ['Disabled']
|
339 |
+
allowed_modes = visible_langchain_modes.copy()
|
340 |
+
allowed_modes = [x for x in allowed_modes if x in dbs]
|
341 |
+
allowed_modes += ['ChatLLM', 'LLM']
|
342 |
+
if allow_upload_to_my_data and 'MyData' not in allowed_modes:
|
343 |
+
allowed_modes += ['MyData']
|
344 |
+
if allow_upload_to_user_data and 'UserData' not in allowed_modes:
|
345 |
+
allowed_modes += ['UserData']
|
346 |
+
langchain_mode = gr.Radio(
|
347 |
+
[x for x in langchain_modes if x in allowed_modes and x not in no_show_modes],
|
348 |
+
value=kwargs['langchain_mode'],
|
349 |
+
label="Collections",
|
350 |
+
show_label=True,
|
351 |
+
visible=kwargs['langchain_mode'] != 'Disabled',
|
352 |
+
min_width=100)
|
353 |
+
document_subset = gr.Radio([x.name for x in DocumentChoices],
|
354 |
+
label="Subset",
|
355 |
+
value=DocumentChoices.Relevant.name,
|
356 |
+
interactive=True,
|
357 |
+
)
|
358 |
+
allowed_actions = [x for x in langchain_actions if x in visible_langchain_actions]
|
359 |
+
langchain_action = gr.Radio(
|
360 |
+
allowed_actions,
|
361 |
+
value=allowed_actions[0] if len(allowed_actions) > 0 else None,
|
362 |
+
label="Action",
|
363 |
+
visible=True)
|
364 |
+
col_tabs = gr.Column(elem_id="col_container", scale=10)
|
365 |
+
with (col_tabs, gr.Tabs()):
|
366 |
+
with gr.TabItem("Chat"):
|
367 |
+
if kwargs['langchain_mode'] == 'Disabled':
|
368 |
+
text_output_nochat = gr.Textbox(lines=5, label=output_label0, show_copy_button=True,
|
369 |
+
visible=not kwargs['chat'])
|
370 |
+
else:
|
371 |
+
# text looks a bit worse, but HTML links work
|
372 |
+
text_output_nochat = gr.HTML(label=output_label0, visible=not kwargs['chat'])
|
373 |
+
with gr.Row():
|
374 |
+
# NOCHAT
|
375 |
instruction_nochat = gr.Textbox(
|
376 |
lines=kwargs['input_lines'],
|
377 |
label=instruction_label_nochat,
|
378 |
placeholder=kwargs['placeholder_instruction'],
|
379 |
+
visible=not kwargs['chat'],
|
380 |
)
|
381 |
iinput_nochat = gr.Textbox(lines=4, label="Input context for Instruction",
|
382 |
+
placeholder=kwargs['placeholder_input'],
|
383 |
+
visible=not kwargs['chat'])
|
384 |
+
submit_nochat = gr.Button("Submit", size='sm', visible=not kwargs['chat'])
|
385 |
+
flag_btn_nochat = gr.Button("Flag", size='sm', visible=not kwargs['chat'])
|
386 |
+
score_text_nochat = gr.Textbox("Response Score: NA", show_label=False,
|
387 |
+
visible=not kwargs['chat'])
|
388 |
+
submit_nochat_api = gr.Button("Submit nochat API", visible=False)
|
389 |
+
inputs_dict_str = gr.Textbox(label='API input for nochat', show_label=False, visible=False)
|
390 |
+
text_output_nochat_api = gr.Textbox(lines=5, label='API nochat output', visible=False,
|
391 |
+
show_copy_button=True)
|
392 |
+
|
393 |
+
# CHAT
|
394 |
+
col_chat = gr.Column(visible=kwargs['chat'])
|
395 |
+
with col_chat:
|
396 |
+
with gr.Row(): # elem_id='prompt-form-area'):
|
397 |
+
with gr.Column(scale=50):
|
398 |
+
instruction = gr.Textbox(
|
399 |
+
lines=kwargs['input_lines'],
|
400 |
+
label='Ask anything',
|
401 |
+
placeholder=instruction_label,
|
402 |
+
info=None,
|
403 |
+
elem_id='prompt-form',
|
404 |
+
container=True,
|
405 |
+
)
|
406 |
+
submit_buttons = gr.Row(equal_height=False)
|
407 |
+
with submit_buttons:
|
408 |
+
mw1 = 50
|
409 |
+
mw2 = 50
|
410 |
+
with gr.Column(min_width=mw1):
|
411 |
+
submit = gr.Button(value='Submit', variant='primary', scale=0, size='sm',
|
412 |
+
min_width=mw1)
|
413 |
+
stop_btn = gr.Button(value="Stop", variant='secondary', scale=0, size='sm',
|
414 |
+
min_width=mw1)
|
415 |
+
save_chat_btn = gr.Button("Save", size='sm', min_width=mw1)
|
416 |
+
with gr.Column(min_width=mw2):
|
417 |
+
retry_btn = gr.Button("Redo", size='sm', min_width=mw2)
|
418 |
+
undo = gr.Button("Undo", size='sm', min_width=mw2)
|
419 |
+
clear_chat_btn = gr.Button(value="Clear", size='sm', min_width=mw2)
|
420 |
+
text_output, text_output2, text_outputs = make_chatbots(output_label0, output_label0_model2,
|
421 |
+
**kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
422 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
423 |
with gr.Row():
|
424 |
+
with gr.Column(visible=kwargs['score_model']):
|
425 |
+
score_text = gr.Textbox(res_value,
|
426 |
+
show_label=False,
|
427 |
+
visible=True)
|
428 |
+
score_text2 = gr.Textbox("Response Score2: NA", show_label=False,
|
429 |
+
visible=False and not kwargs['model_lock'])
|
430 |
+
|
431 |
+
with gr.TabItem("Document Selection"):
|
432 |
+
document_choice = gr.Dropdown(docs_state0,
|
433 |
+
label="Select Subset of Document(s) %s" % file_types_str,
|
434 |
+
value='All',
|
435 |
+
interactive=True,
|
436 |
+
multiselect=True,
|
437 |
+
)
|
438 |
+
sources_visible = kwargs['langchain_mode'] != 'Disabled' and enable_sources_list
|
439 |
+
with gr.Row():
|
440 |
+
get_sources_btn = gr.Button(value="Update UI with Document(s) from DB", scale=0, size='sm',
|
441 |
+
visible=sources_visible)
|
442 |
+
show_sources_btn = gr.Button(value="Show Sources from DB", scale=0, size='sm',
|
443 |
+
visible=sources_visible)
|
444 |
+
refresh_sources_btn = gr.Button(value="Update DB with new/changed files on disk", scale=0,
|
445 |
+
size='sm',
|
446 |
+
visible=sources_visible and allow_upload_to_user_data)
|
447 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
448 |
sources_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list,
|
449 |
equal_height=False)
|
450 |
with sources_row:
|
451 |
with gr.Column(scale=1):
|
452 |
file_source = gr.File(interactive=False,
|
453 |
+
label="Download File w/Sources")
|
454 |
with gr.Column(scale=2):
|
455 |
sources_text = gr.HTML(label='Sources Added', interactive=False)
|
456 |
|
457 |
+
doc_exception_text = gr.Textbox(value="", visible=True, label='Document Exceptions',
|
458 |
+
interactive=False)
|
459 |
+
with gr.TabItem("Document Viewer"):
|
460 |
+
with gr.Row():
|
461 |
+
with gr.Column(scale=2):
|
462 |
+
get_viewable_sources_btn = gr.Button(value="Update UI with Document(s) from DB", scale=0,
|
463 |
+
size='sm',
|
464 |
+
visible=sources_visible)
|
465 |
+
view_document_choice = gr.Dropdown(viewable_docs_state0,
|
466 |
+
label="Select Single Document",
|
467 |
+
value=None,
|
468 |
+
interactive=True,
|
469 |
+
multiselect=False,
|
470 |
+
)
|
471 |
+
with gr.Column(scale=4):
|
472 |
+
pass
|
473 |
+
document = 'http://infolab.stanford.edu/pub/papers/google.pdf'
|
474 |
+
doc_view = gr.HTML(visible=False)
|
475 |
+
doc_view2 = gr.Dataframe(visible=False)
|
476 |
+
doc_view3 = gr.JSON(visible=False)
|
477 |
+
doc_view4 = gr.Markdown(visible=False)
|
478 |
+
|
479 |
with gr.TabItem("Chat History"):
|
480 |
+
with gr.Row():
|
481 |
+
with gr.Column(scale=1):
|
482 |
+
remove_chat_btn = gr.Button(value="Remove Selected Saved Chats", visible=True, size='sm')
|
483 |
+
flag_btn = gr.Button("Flag Current Chat", size='sm')
|
484 |
+
export_chats_btn = gr.Button(value="Export Chats to Download", size='sm')
|
485 |
+
with gr.Column(scale=4):
|
486 |
+
pass
|
487 |
+
with gr.Row():
|
488 |
+
chats_file = gr.File(interactive=False, label="Download Exported Chats")
|
489 |
+
chatsup_output = gr.File(label="Upload Chat File(s)",
|
490 |
+
file_types=['.json'],
|
491 |
+
file_count='multiple',
|
492 |
+
elem_id="warning", elem_classes="feedback")
|
493 |
with gr.Row():
|
494 |
if 'mbart-' in kwargs['model_lower']:
|
495 |
src_lang = gr.Dropdown(list(languages_covered().keys()),
|
|
|
498 |
tgt_lang = gr.Dropdown(list(languages_covered().keys()),
|
499 |
value=kwargs['tgt_lang'],
|
500 |
label="Output Language")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
501 |
|
502 |
+
chat_exception_text = gr.Textbox(value="", visible=True, label='Chat Exceptions',
|
503 |
+
interactive=False)
|
504 |
with gr.TabItem("Expert"):
|
505 |
with gr.Row():
|
506 |
with gr.Column():
|
|
|
579 |
info="Directly pre-appended without prompt processing",
|
580 |
interactive=not is_public)
|
581 |
chat = gr.components.Checkbox(label="Chat mode", value=kwargs['chat'],
|
582 |
+
visible=False, # no longer support nochat in UI
|
583 |
interactive=not is_public,
|
584 |
)
|
585 |
count_chat_tokens_btn = gr.Button(value="Count Chat Tokens",
|
|
|
638 |
model_load8bit_checkbox = gr.components.Checkbox(
|
639 |
label="Load 8-bit [requires support]",
|
640 |
value=kwargs['load_8bit'], interactive=not is_public)
|
641 |
+
model_use_gpu_id_checkbox = gr.components.Checkbox(
|
642 |
label="Choose Devices [If not Checked, use all GPUs]",
|
643 |
+
value=kwargs['use_gpu_id'], interactive=not is_public)
|
644 |
model_gpu = gr.Dropdown(n_gpus_list,
|
645 |
label="GPU ID [-1 = all GPUs, if Choose is enabled]",
|
646 |
value=kwargs['gpu_id'], interactive=not is_public)
|
|
|
673 |
model_load8bit_checkbox2 = gr.components.Checkbox(
|
674 |
label="Load 8-bit 2 [requires support]",
|
675 |
value=kwargs['load_8bit'], interactive=not is_public)
|
676 |
+
model_use_gpu_id_checkbox2 = gr.components.Checkbox(
|
677 |
label="Choose Devices 2 [If not Checked, use all GPUs]",
|
678 |
value=kwargs[
|
679 |
+
'use_gpu_id'], interactive=not is_public)
|
680 |
model_gpu2 = gr.Dropdown(n_gpus_list,
|
681 |
label="GPU ID 2 [-1 = all GPUs, if choose is enabled]",
|
682 |
value=kwargs['gpu_id'], interactive=not is_public)
|
|
|
703 |
add_model_lora_server_button = gr.Button("Add new Model, Lora, Server url:port", scale=0,
|
704 |
size='sm', interactive=not is_public)
|
705 |
with gr.TabItem("System"):
|
706 |
+
with gr.Row():
|
707 |
+
with gr.Column(scale=1):
|
708 |
+
side_bar_text = gr.Textbox('on', visible=False, interactive=False)
|
709 |
+
submit_buttons_text = gr.Textbox('on', visible=False, interactive=False)
|
710 |
+
|
711 |
+
side_bar_btn = gr.Button("Toggle SideBar", variant="secondary", size="sm")
|
712 |
+
submit_buttons_btn = gr.Button("Toggle Submit Buttons", variant="secondary", size="sm")
|
713 |
+
col_tabs_scale = gr.Slider(minimum=1, maximum=20, value=10, step=1, label='Window Size')
|
714 |
+
text_outputs_height = gr.Slider(minimum=100, maximum=1000, value=kwargs['height'] or 400,
|
715 |
+
step=100, label='Chat Height')
|
716 |
+
dark_mode_btn = gr.Button("Dark Mode", variant="secondary", size="sm")
|
717 |
+
with gr.Column(scale=4):
|
718 |
+
pass
|
719 |
admin_row = gr.Row()
|
720 |
with admin_row:
|
721 |
+
with gr.Column(scale=1):
|
722 |
+
admin_pass_textbox = gr.Textbox(label="Admin Password", type='password', visible=is_public)
|
723 |
+
admin_btn = gr.Button(value="Admin Access", visible=is_public, size='sm')
|
724 |
+
with gr.Column(scale=4):
|
725 |
+
pass
|
726 |
system_row = gr.Row(visible=not is_public)
|
727 |
with system_row:
|
728 |
with gr.Column():
|
729 |
with gr.Row():
|
730 |
+
system_btn = gr.Button(value='Get System Info', size='sm')
|
731 |
system_text = gr.Textbox(label='System Info', interactive=False, show_copy_button=True)
|
732 |
with gr.Row():
|
733 |
system_input = gr.Textbox(label='System Info Dict Password', interactive=True,
|
734 |
visible=not is_public)
|
735 |
+
system_btn2 = gr.Button(value='Get System Info Dict', visible=not is_public, size='sm')
|
736 |
system_text2 = gr.Textbox(label='System Info Dict', interactive=False,
|
737 |
visible=not is_public, show_copy_button=True)
|
738 |
with gr.Row():
|
739 |
+
system_btn3 = gr.Button(value='Get Hash', visible=not is_public, size='sm')
|
740 |
system_text3 = gr.Textbox(label='Hash', interactive=False,
|
741 |
visible=not is_public, show_copy_button=True)
|
742 |
|
743 |
with gr.Row():
|
744 |
+
zip_btn = gr.Button("Zip", size='sm')
|
745 |
zip_text = gr.Textbox(label="Zip file name", interactive=False)
|
746 |
file_output = gr.File(interactive=False, label="Zip file to Download")
|
747 |
with gr.Row():
|
748 |
+
s3up_btn = gr.Button("S3UP", size='sm')
|
749 |
s3up_text = gr.Textbox(label='S3UP result', interactive=False)
|
750 |
+
|
751 |
+
with gr.TabItem("Terms of Service"):
|
752 |
description = ""
|
753 |
description += """<p><b> DISCLAIMERS: </b><ul><i><li>The model was trained on The Pile and other data, which may contain objectionable content. Use at own risk.</i></li>"""
|
754 |
if kwargs['load_8bit']:
|
|
|
759 |
description += """<i><li>By using h2oGPT, you accept our <a href="https://github.com/h2oai/h2ogpt/blob/main/docs/tos.md">Terms of Service</a></i></li></ul></p>"""
|
760 |
gr.Markdown(value=description, show_label=False, interactive=False)
|
761 |
|
762 |
+
with gr.TabItem("Hosts"):
|
763 |
+
gr.Markdown(f"""
|
764 |
+
{description_bottom}
|
765 |
+
{task_info_md}
|
766 |
+
""")
|
767 |
|
768 |
# Get flagged data
|
769 |
zip_data1 = functools.partial(zip_data, root_dirs=['flagged_data_points', kwargs['save_dir']])
|
770 |
+
zip_event = zip_btn.click(zip_data1, inputs=None, outputs=[file_output, zip_text], queue=False,
|
771 |
+
api_name='zip_data' if allow_api else None)
|
772 |
+
s3up_event = s3up_btn.click(s3up, inputs=zip_text, outputs=s3up_text, queue=False,
|
773 |
+
api_name='s3up_data' if allow_api else None)
|
774 |
|
775 |
def clear_file_list():
|
776 |
return None
|
|
|
788 |
return tuple([gr.update(interactive=True)] * len(args))
|
789 |
|
790 |
# Add to UserData
|
791 |
+
update_db_func = functools.partial(update_user_db,
|
792 |
+
dbs=dbs,
|
793 |
+
db_type=db_type,
|
794 |
+
use_openai_embedding=use_openai_embedding,
|
795 |
+
hf_embedding_model=hf_embedding_model,
|
796 |
+
enable_captions=enable_captions,
|
797 |
+
captions_model=captions_model,
|
798 |
+
enable_ocr=enable_ocr,
|
799 |
+
caption_loader=caption_loader,
|
800 |
+
verbose=kwargs['verbose'],
|
801 |
+
user_path=kwargs['user_path'],
|
802 |
+
n_jobs=kwargs['n_jobs'],
|
803 |
+
)
|
804 |
+
add_file_outputs = [fileup_output, langchain_mode]
|
805 |
+
add_file_kwargs = dict(fn=update_db_func,
|
806 |
+
inputs=[fileup_output, my_db_state, chunk, chunk_size, langchain_mode],
|
807 |
+
outputs=add_file_outputs + [sources_text, doc_exception_text],
|
|
|
808 |
queue=queue,
|
809 |
+
api_name='add_file' if allow_api and allow_upload_to_user_data else None)
|
810 |
|
|
|
|
|
|
|
|
|
811 |
# then no need for add buttons, only single changeable db
|
812 |
+
eventdb1a = fileup_output.upload(make_non_interactive, inputs=add_file_outputs, outputs=add_file_outputs,
|
813 |
+
show_progress='minimal')
|
814 |
+
eventdb1 = eventdb1a.then(**add_file_kwargs, show_progress='full')
|
815 |
+
eventdb1b = eventdb1.then(make_interactive, inputs=add_file_outputs, outputs=add_file_outputs,
|
816 |
+
show_progress='minimal')
|
817 |
|
818 |
# note for update_user_db_func output is ignored for db
|
819 |
|
820 |
def clear_textbox():
|
821 |
return gr.Textbox.update(value='')
|
822 |
|
823 |
+
update_user_db_url_func = functools.partial(update_db_func, is_url=True)
|
824 |
|
825 |
+
add_url_outputs = [url_text, langchain_mode]
|
826 |
add_url_kwargs = dict(fn=update_user_db_url_func,
|
827 |
+
inputs=[url_text, my_db_state, chunk, chunk_size, langchain_mode],
|
828 |
+
outputs=add_url_outputs + [sources_text, doc_exception_text],
|
|
|
829 |
queue=queue,
|
830 |
+
api_name='add_url' if allow_api and allow_upload_to_user_data else None)
|
831 |
|
832 |
+
eventdb2a = url_text.submit(fn=dummy_fun, inputs=url_text, outputs=url_text, queue=queue,
|
833 |
+
show_progress='minimal')
|
|
|
|
|
|
|
|
|
834 |
# work around https://github.com/gradio-app/gradio/issues/4733
|
835 |
eventdb2b = eventdb2a.then(make_non_interactive, inputs=add_url_outputs, outputs=add_url_outputs,
|
836 |
show_progress='minimal')
|
837 |
+
eventdb2 = eventdb2b.then(**add_url_kwargs, show_progress='full')
|
838 |
+
eventdb2c = eventdb2.then(make_interactive, inputs=add_url_outputs, outputs=add_url_outputs,
|
839 |
+
show_progress='minimal')
|
840 |
|
841 |
+
update_user_db_txt_func = functools.partial(update_db_func, is_txt=True)
|
842 |
+
add_text_outputs = [user_text_text, langchain_mode]
|
843 |
add_text_kwargs = dict(fn=update_user_db_txt_func,
|
844 |
+
inputs=[user_text_text, my_db_state, chunk, chunk_size, langchain_mode],
|
845 |
+
outputs=add_text_outputs + [sources_text, doc_exception_text],
|
|
|
846 |
queue=queue,
|
847 |
+
api_name='add_text' if allow_api and allow_upload_to_user_data else None
|
848 |
)
|
849 |
+
eventdb3a = user_text_text.submit(fn=dummy_fun, inputs=user_text_text, outputs=user_text_text, queue=queue,
|
850 |
+
show_progress='minimal')
|
|
|
|
|
|
|
|
|
|
|
851 |
eventdb3b = eventdb3a.then(make_non_interactive, inputs=add_text_outputs, outputs=add_text_outputs,
|
852 |
show_progress='minimal')
|
853 |
+
eventdb3 = eventdb3b.then(**add_text_kwargs, show_progress='full')
|
854 |
+
eventdb3c = eventdb3.then(make_interactive, inputs=add_text_outputs, outputs=add_text_outputs,
|
855 |
+
show_progress='minimal')
|
856 |
+
db_events = [eventdb1a, eventdb1, eventdb1b,
|
857 |
+
eventdb2a, eventdb2, eventdb2b, eventdb2c,
|
858 |
+
eventdb3a, eventdb3b, eventdb3, eventdb3c]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
859 |
|
860 |
get_sources1 = functools.partial(get_sources, dbs=dbs, docs_state0=docs_state0)
|
861 |
|
862 |
# if change collection source, must clear doc selections from it to avoid inconsistency
|
863 |
def clear_doc_choice():
|
864 |
+
return gr.Dropdown.update(choices=docs_state0, value=DocumentChoices.All.name)
|
865 |
+
|
866 |
+
langchain_mode.change(clear_doc_choice, inputs=None, outputs=document_choice, queue=False)
|
867 |
|
868 |
+
def resize_col_tabs(x):
|
869 |
+
return gr.Dropdown.update(scale=x)
|
870 |
+
|
871 |
+
col_tabs_scale.change(fn=resize_col_tabs, inputs=col_tabs_scale, outputs=col_tabs)
|
872 |
+
|
873 |
+
def resize_chatbots(x, num_model_lock=0):
|
874 |
+
if num_model_lock == 0:
|
875 |
+
num_model_lock = 3 # 2 + 1 (which is dup of first)
|
876 |
+
else:
|
877 |
+
num_model_lock = 2 + num_model_lock
|
878 |
+
return tuple([gr.update(height=x)] * num_model_lock)
|
879 |
+
|
880 |
+
resize_chatbots_func = functools.partial(resize_chatbots, num_model_lock=len(text_outputs))
|
881 |
+
text_outputs_height.change(fn=resize_chatbots_func, inputs=text_outputs_height,
|
882 |
+
outputs=[text_output, text_output2] + text_outputs)
|
883 |
|
884 |
def update_dropdown(x):
|
885 |
return gr.Dropdown.update(choices=x, value=[docs_state0[0]])
|
886 |
|
887 |
+
get_sources_args = dict(fn=get_sources1, inputs=[my_db_state, langchain_mode],
|
888 |
+
outputs=[file_source, docs_state],
|
889 |
+
queue=queue,
|
890 |
+
api_name='get_sources' if allow_api else None)
|
891 |
+
|
892 |
+
eventdb7 = get_sources_btn.click(**get_sources_args) \
|
893 |
.then(fn=update_dropdown, inputs=docs_state, outputs=document_choice)
|
894 |
# show button, else only show when add. Could add to above get_sources for download/dropdown, but bit much maybe
|
895 |
show_sources1 = functools.partial(get_source_files_given_langchain_mode, dbs=dbs)
|
896 |
eventdb8 = show_sources_btn.click(fn=show_sources1, inputs=[my_db_state, langchain_mode], outputs=sources_text,
|
897 |
api_name='show_sources' if allow_api else None)
|
898 |
|
899 |
+
def update_viewable_dropdown(x):
|
900 |
+
return gr.Dropdown.update(choices=x,
|
901 |
+
value=viewable_docs_state0[0] if len(viewable_docs_state0) > 0 else None)
|
902 |
+
|
903 |
+
get_viewable_sources1 = functools.partial(get_sources, dbs=dbs, docs_state0=viewable_docs_state0)
|
904 |
+
get_viewable_sources_args = dict(fn=get_viewable_sources1, inputs=[my_db_state, langchain_mode],
|
905 |
+
outputs=[file_source, viewable_docs_state],
|
906 |
+
queue=queue,
|
907 |
+
api_name='get_viewable_sources' if allow_api else None)
|
908 |
+
eventdb12 = get_viewable_sources_btn.click(**get_viewable_sources_args) \
|
909 |
+
.then(fn=update_viewable_dropdown, inputs=viewable_docs_state,
|
910 |
+
outputs=view_document_choice)
|
911 |
+
|
912 |
+
def show_doc(file):
|
913 |
+
dummy1 = gr.update(visible=False, value=None)
|
914 |
+
dummy_ret = dummy1, dummy1, dummy1, dummy1
|
915 |
+
if not isinstance(file, str):
|
916 |
+
return dummy_ret
|
917 |
+
|
918 |
+
if file.endswith('.md'):
|
919 |
+
try:
|
920 |
+
with open(file, 'rt') as f:
|
921 |
+
content = f.read()
|
922 |
+
return dummy1, dummy1, dummy1, gr.update(visible=True, value=content)
|
923 |
+
except:
|
924 |
+
return dummy_ret
|
925 |
+
|
926 |
+
if file.endswith('.py'):
|
927 |
+
try:
|
928 |
+
with open(file, 'rt') as f:
|
929 |
+
content = f.read()
|
930 |
+
content = f"```python\n{content}\n```"
|
931 |
+
return dummy1, dummy1, dummy1, gr.update(visible=True, value=content)
|
932 |
+
except:
|
933 |
+
return dummy_ret
|
934 |
+
|
935 |
+
if file.endswith('.txt') or file.endswith('.rst') or file.endswith('.rtf') or file.endswith('.toml'):
|
936 |
+
try:
|
937 |
+
with open(file, 'rt') as f:
|
938 |
+
content = f.read()
|
939 |
+
content = f"```text\n{content}\n```"
|
940 |
+
return dummy1, dummy1, dummy1, gr.update(visible=True, value=content)
|
941 |
+
except:
|
942 |
+
return dummy_ret
|
943 |
+
|
944 |
+
func = None
|
945 |
+
if file.endswith(".csv"):
|
946 |
+
func = pd.read_csv
|
947 |
+
elif file.endswith(".pickle"):
|
948 |
+
func = pd.read_pickle
|
949 |
+
elif file.endswith(".xls") or file.endswith("xlsx"):
|
950 |
+
func = pd.read_excel
|
951 |
+
elif file.endswith('.json'):
|
952 |
+
func = pd.read_json
|
953 |
+
elif file.endswith('.xml'):
|
954 |
+
func = pd.read_xml
|
955 |
+
if func is not None:
|
956 |
+
try:
|
957 |
+
df = func(file).head(100)
|
958 |
+
except:
|
959 |
+
return dummy_ret
|
960 |
+
return dummy1, gr.update(visible=True, value=df), dummy1, dummy1
|
961 |
+
port = int(os.getenv('GRADIO_SERVER_PORT', '7860'))
|
962 |
+
import pathlib
|
963 |
+
absolute_path_string = os.path.abspath(file)
|
964 |
+
url_path = pathlib.Path(absolute_path_string).as_uri()
|
965 |
+
url = get_url(absolute_path_string, from_str=True)
|
966 |
+
img_url = url.replace("""<a href=""", """<img src=""")
|
967 |
+
if file.endswith('.png') or file.endswith('.jpg') or file.endswith('.jpeg'):
|
968 |
+
return gr.update(visible=True, value=img_url), dummy1, dummy1, dummy1
|
969 |
+
elif file.endswith('.pdf') or 'arxiv.org/pdf' in file:
|
970 |
+
if file.startswith('http') or file.startswith('https'):
|
971 |
+
# if file is online, then might as well use google(?)
|
972 |
+
document1 = file
|
973 |
+
return gr.update(visible=True, value=f"""<iframe width="1000" height="800" src="https://docs.google.com/viewerng/viewer?url={document1}&embedded=true" frameborder="0" height="100%" width="100%">
|
974 |
+
</iframe>
|
975 |
+
"""), dummy1, dummy1, dummy1
|
976 |
+
else:
|
977 |
+
ip = get_local_ip()
|
978 |
+
document1 = url_path.replace('file://', f'http://{ip}:{port}/')
|
979 |
+
# document1 = url
|
980 |
+
return gr.update(visible=True, value=f"""<object data="{document1}" type="application/pdf">
|
981 |
+
<iframe src="https://docs.google.com/viewer?url={document1}&embedded=true"></iframe>
|
982 |
+
</object>"""), dummy1, dummy1, dummy1
|
983 |
+
else:
|
984 |
+
return dummy_ret
|
985 |
+
|
986 |
+
view_document_choice.select(fn=show_doc, inputs=view_document_choice,
|
987 |
+
outputs=[doc_view, doc_view2, doc_view3, doc_view4])
|
988 |
+
|
989 |
# Get inputs to evaluate() and make_db()
|
990 |
# don't deepcopy, can contain model itself
|
991 |
all_kwargs = kwargs.copy()
|
|
|
1072 |
**kwargs_evaluate
|
1073 |
)
|
1074 |
|
|
|
|
|
|
|
1075 |
dark_mode_btn.click(
|
1076 |
None,
|
1077 |
None,
|
|
|
1081 |
queue=False,
|
1082 |
)
|
1083 |
|
1084 |
+
def visible_toggle(x):
|
1085 |
+
x = 'off' if x == 'on' else 'on'
|
1086 |
+
return x, gr.Column.update(visible=True if x == 'on' else False)
|
1087 |
|
1088 |
+
side_bar_btn.click(fn=visible_toggle,
|
1089 |
+
inputs=side_bar_text,
|
1090 |
+
outputs=[side_bar_text, side_bar],
|
1091 |
+
queue=False)
|
1092 |
|
1093 |
+
submit_buttons_btn.click(fn=visible_toggle,
|
1094 |
+
inputs=submit_buttons_text,
|
1095 |
+
outputs=[submit_buttons_text, submit_buttons],
|
1096 |
+
queue=False)
|
|
|
|
|
|
|
1097 |
|
1098 |
# examples after submit or any other buttons for chat or no chat
|
1099 |
if kwargs['examples'] is not None and kwargs['show_examples']:
|
|
|
1214 |
prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
|
1215 |
langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
|
1216 |
langchain_action1 = args_list[eval_func_param_names.index('langchain_action')]
|
1217 |
+
document_subset1 = args_list[eval_func_param_names.index('document_subset')]
|
1218 |
document_choice1 = args_list[eval_func_param_names.index('document_choice')]
|
1219 |
if not prompt_type1:
|
1220 |
# shouldn't have to specify if CLI launched model
|
|
|
1247 |
return history
|
1248 |
if user_message1 in ['', None, '\n']:
|
1249 |
if langchain_action1 in LangChainAction.QUERY.value and \
|
1250 |
+
DocumentChoices.All.name != document_subset1 \
|
1251 |
or \
|
1252 |
langchain_mode1 in [LangChainMode.CHAT_LLM.value, LangChainMode.LLM.value]:
|
1253 |
# reject non-retry submit/enter
|
|
|
1310 |
args_list = args_list[:-3] # only keep rest needed for evaluate()
|
1311 |
langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
|
1312 |
langchain_action1 = args_list[eval_func_param_names.index('langchain_action')]
|
1313 |
+
document_subset1 = args_list[eval_func_param_names.index('document_subset')]
|
1314 |
document_choice1 = args_list[eval_func_param_names.index('document_choice')]
|
1315 |
if not history:
|
1316 |
print("No history", flush=True)
|
|
|
1323 |
history[-1][1] = None
|
1324 |
elif not instruction1:
|
1325 |
if langchain_action1 in LangChainAction.QUERY.value and \
|
1326 |
+
DocumentChoices.All.name != document_choice1 \
|
1327 |
or \
|
1328 |
langchain_mode1 in [LangChainMode.CHAT_LLM.value, LangChainMode.LLM.value]:
|
1329 |
# if not retrying, then reject empty query
|
|
|
1494 |
)
|
1495 |
bot_args = dict(fn=bot,
|
1496 |
inputs=inputs_list + [model_state, my_db_state] + [text_output],
|
1497 |
+
outputs=[text_output, chat_exception_text],
|
1498 |
)
|
1499 |
retry_bot_args = dict(fn=functools.partial(bot, retry=True),
|
1500 |
inputs=inputs_list + [model_state, my_db_state] + [text_output],
|
1501 |
+
outputs=[text_output, chat_exception_text],
|
1502 |
)
|
1503 |
retry_user_args = dict(fn=functools.partial(user, retry=True),
|
1504 |
inputs=inputs_list + [text_output],
|
|
|
1516 |
)
|
1517 |
bot_args2 = dict(fn=bot,
|
1518 |
inputs=inputs_list2 + [model_state2, my_db_state] + [text_output2],
|
1519 |
+
outputs=[text_output2, chat_exception_text],
|
1520 |
)
|
1521 |
retry_bot_args2 = dict(fn=functools.partial(bot, retry=True),
|
1522 |
inputs=inputs_list2 + [model_state2, my_db_state] + [text_output2],
|
1523 |
+
outputs=[text_output2, chat_exception_text],
|
1524 |
)
|
1525 |
retry_user_args2 = dict(fn=functools.partial(user, retry=True),
|
1526 |
inputs=inputs_list2 + [text_output2],
|
|
|
1541 |
)
|
1542 |
all_bot_args = dict(fn=functools.partial(all_bot, model_states1=model_states),
|
1543 |
inputs=inputs_list + [my_db_state] + text_outputs,
|
1544 |
+
outputs=text_outputs + [chat_exception_text],
|
1545 |
)
|
1546 |
all_retry_bot_args = dict(fn=functools.partial(all_bot, model_states1=model_states, retry=True),
|
1547 |
inputs=inputs_list + [my_db_state] + text_outputs,
|
1548 |
+
outputs=text_outputs + [chat_exception_text],
|
1549 |
)
|
1550 |
all_retry_user_args = dict(fn=functools.partial(all_user, retry=True,
|
1551 |
sanitize_user_prompt=kwargs['sanitize_user_prompt'],
|
|
|
1743 |
return False
|
1744 |
return is_same
|
1745 |
|
1746 |
+
def save_chat(*args, chat_is_list=False):
|
1747 |
args_list = list(args)
|
1748 |
+
if not chat_is_list:
|
1749 |
+
# list of chatbot histories,
|
1750 |
+
# can't pass in list with list of chatbot histories and state due to gradio limits
|
1751 |
+
chat_list = args_list[:-1]
|
1752 |
+
else:
|
1753 |
+
assert len(args_list) == 2
|
1754 |
+
chat_list = args_list[0]
|
1755 |
+
# if old chat file with single chatbot, get into shape
|
1756 |
+
if isinstance(chat_list, list) and len(chat_list) > 0 and isinstance(chat_list[0], list) and len(
|
1757 |
+
chat_list[0]) == 2 and isinstance(chat_list[0][0], str) and isinstance(chat_list[0][1], str):
|
1758 |
+
chat_list = [chat_list]
|
1759 |
# remove None histories
|
1760 |
chat_list_not_none = [x for x in chat_list if x and len(x) > 0 and len(x[0]) == 2 and x[0][1] is not None]
|
1761 |
+
chat_list_none = [x for x in chat_list if x not in chat_list_not_none]
|
1762 |
+
if len(chat_list_none) > 0 and len(chat_list_not_none) == 0:
|
1763 |
+
raise ValueError("Invalid chat file")
|
1764 |
+
# dict with keys of short chat names, values of list of list of chatbot histories
|
1765 |
+
chat_state1 = args_list[-1]
|
1766 |
short_chats = list(chat_state1.keys())
|
1767 |
if len(chat_list_not_none) > 0:
|
1768 |
# make short_chat key from only first history, based upon question that is same anyways
|
|
|
1774 |
if not already_exists:
|
1775 |
chat_state1[short_chat] = chat_list.copy()
|
1776 |
# clear chat_list so saved and then new conversation starts
|
1777 |
+
# FIXME: seems less confusing to clear, since have clear button right next
|
1778 |
+
# chat_list = [[]] * len(chat_list)
|
1779 |
+
if not chat_is_list:
|
1780 |
+
ret_list = chat_list + [chat_state1]
|
1781 |
+
else:
|
1782 |
+
ret_list = [chat_list] + [chat_state1]
|
1783 |
return tuple(ret_list)
|
1784 |
|
|
|
|
|
|
|
1785 |
def switch_chat(chat_key, chat_state1, num_model_lock=0):
|
1786 |
chosen_chat = chat_state1[chat_key]
|
1787 |
# deal with possible different size of chat list vs. current list
|
|
|
1805 |
.then(clear_scores, outputs=[score_text, score_text2, score_text_nochat])
|
1806 |
|
1807 |
def remove_chat(chat_key, chat_state1):
|
1808 |
+
if isinstance(chat_key, str):
|
1809 |
+
chat_state1.pop(chat_key, None)
|
1810 |
+
return gr.update(choices=list(chat_state1.keys()), value=None), chat_state1
|
1811 |
|
1812 |
+
remove_chat_event = remove_chat_btn.click(remove_chat,
|
1813 |
+
inputs=[radio_chats, chat_state], outputs=[radio_chats, chat_state],
|
1814 |
+
queue=False)
|
1815 |
|
1816 |
def get_chats1(chat_state1):
|
1817 |
base = 'chats'
|
|
|
1821 |
f.write(json.dumps(chat_state1, indent=2))
|
1822 |
return filename
|
1823 |
|
1824 |
+
export_chat_event = export_chats_btn.click(get_chats1, inputs=chat_state, outputs=chats_file, queue=False,
|
1825 |
+
api_name='export_chats' if allow_api else None)
|
1826 |
|
1827 |
+
def add_chats_from_file(file, chat_state1, radio_chats1, chat_exception_text1):
|
1828 |
if not file:
|
1829 |
+
return None, chat_state1, gr.update(choices=list(chat_state1.keys()), value=None), chat_exception_text1
|
1830 |
if isinstance(file, str):
|
1831 |
files = [file]
|
1832 |
else:
|
1833 |
files = file
|
1834 |
if not files:
|
1835 |
+
return None, chat_state1, gr.update(choices=list(chat_state1.keys()), value=None), chat_exception_text1
|
1836 |
+
chat_exception_list = []
|
1837 |
for file1 in files:
|
1838 |
try:
|
1839 |
if hasattr(file1, 'name'):
|
|
|
1842 |
new_chats = json.loads(f.read())
|
1843 |
for chat1_k, chat1_v in new_chats.items():
|
1844 |
# ignore chat1_k, regenerate and de-dup to avoid loss
|
1845 |
+
_, chat_state1 = save_chat(chat1_v, chat_state1, chat_is_list=True)
|
1846 |
except BaseException as e:
|
1847 |
t, v, tb = sys.exc_info()
|
1848 |
ex = ''.join(traceback.format_exception(t, v, tb))
|
1849 |
+
ex_str = "File %s exception: %s" % (file1, str(e))
|
1850 |
+
print(ex_str, flush=True)
|
1851 |
+
chat_exception_list.append(ex_str)
|
1852 |
+
chat_exception_text1 = '\n'.join(chat_exception_list)
|
1853 |
+
return None, chat_state1, gr.update(choices=list(chat_state1.keys()), value=None), chat_exception_text1
|
1854 |
|
1855 |
# note for update_user_db_func output is ignored for db
|
1856 |
+
chatup_change_event = chatsup_output.change(add_chats_from_file,
|
1857 |
+
inputs=[chatsup_output, chat_state, radio_chats,
|
1858 |
+
chat_exception_text],
|
1859 |
+
outputs=[chatsup_output, chat_state, radio_chats,
|
1860 |
+
chat_exception_text],
|
1861 |
+
queue=False,
|
1862 |
+
api_name='add_to_chats' if allow_api else None)
|
1863 |
+
|
1864 |
+
clear_chat_event = clear_chat_btn.click(fn=clear_texts,
|
1865 |
+
inputs=[text_output, text_output2] + text_outputs,
|
1866 |
+
outputs=[text_output, text_output2] + text_outputs,
|
1867 |
+
queue=False, api_name='clear' if allow_api else None) \
|
1868 |
.then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=False) \
|
1869 |
.then(clear_scores, outputs=[score_text, score_text2, score_text_nochat])
|
1870 |
|
1871 |
+
def update_radio_chats(chat_state1):
|
1872 |
+
# reverse so newest at top
|
1873 |
+
choices = list(chat_state1.keys()).copy()
|
1874 |
+
choices.reverse()
|
1875 |
+
return gr.update(choices=choices, value=None)
|
1876 |
+
|
1877 |
+
clear_event = save_chat_btn.click(save_chat,
|
1878 |
+
inputs=[text_output, text_output2] + text_outputs + [chat_state],
|
1879 |
+
outputs=[text_output, text_output2] + text_outputs + [chat_state],
|
1880 |
+
api_name='save_chat' if allow_api else None) \
|
1881 |
.then(update_radio_chats, inputs=chat_state, outputs=radio_chats,
|
1882 |
api_name='update_chats' if allow_api else None) \
|
1883 |
.then(clear_scores, outputs=[score_text, score_text2, score_text_nochat])
|
|
|
1911 |
.then(clear_torch_cache)
|
1912 |
|
1913 |
def load_model(model_name, lora_weights, server_name, model_state_old, prompt_type_old, load_8bit,
|
1914 |
+
use_gpu_id, gpu_id):
|
1915 |
# ensure no API calls reach here
|
1916 |
if is_public:
|
1917 |
raise RuntimeError("Illegal access for %s" % model_name)
|
|
|
1955 |
all_kwargs1 = all_kwargs.copy()
|
1956 |
all_kwargs1['base_model'] = model_name.strip()
|
1957 |
all_kwargs1['load_8bit'] = load_8bit
|
1958 |
+
all_kwargs1['use_gpu_id'] = use_gpu_id
|
1959 |
all_kwargs1['gpu_id'] = int(gpu_id) # detranscribe
|
1960 |
model_lower = model_name.strip().lower()
|
1961 |
if model_lower in inv_prompt_type_to_model_lower:
|
|
|
2008 |
|
2009 |
get_prompt_str_func1 = functools.partial(get_prompt_str, which=1)
|
2010 |
get_prompt_str_func2 = functools.partial(get_prompt_str, which=2)
|
2011 |
+
prompt_type.change(fn=get_prompt_str_func1, inputs=[prompt_type, prompt_dict], outputs=prompt_dict, queue=False)
|
2012 |
+
prompt_type2.change(fn=get_prompt_str_func2, inputs=[prompt_type2, prompt_dict2], outputs=prompt_dict2,
|
2013 |
+
queue=False)
|
2014 |
|
2015 |
def dropdown_prompt_type_list(x):
|
2016 |
return gr.Dropdown.update(value=x)
|
|
|
2020 |
|
2021 |
load_model_args = dict(fn=load_model,
|
2022 |
inputs=[model_choice, lora_choice, server_choice, model_state, prompt_type,
|
2023 |
+
model_load8bit_checkbox, model_use_gpu_id_checkbox, model_gpu],
|
2024 |
outputs=[model_state, model_used, lora_used, server_used,
|
2025 |
# if prompt_type changes, prompt_dict will change via change rule
|
2026 |
prompt_type, max_new_tokens, min_new_tokens,
|
|
|
2028 |
prompt_update_args = dict(fn=dropdown_prompt_type_list, inputs=prompt_type, outputs=prompt_type)
|
2029 |
chatbot_update_args = dict(fn=chatbot_list, inputs=[text_output, model_used], outputs=text_output)
|
2030 |
nochat_update_args = dict(fn=chatbot_list, inputs=[text_output_nochat, model_used], outputs=text_output_nochat)
|
2031 |
+
load_model_event = load_model_button.click(**load_model_args,
|
2032 |
+
api_name='load_model' if allow_api and is_public else None) \
|
2033 |
+
.then(**prompt_update_args) \
|
2034 |
+
.then(**chatbot_update_args) \
|
2035 |
+
.then(**nochat_update_args) \
|
2036 |
+
.then(clear_torch_cache)
|
2037 |
|
2038 |
load_model_args2 = dict(fn=load_model,
|
2039 |
inputs=[model_choice2, lora_choice2, server_choice2, model_state2, prompt_type2,
|
2040 |
+
model_load8bit_checkbox2, model_use_gpu_id_checkbox2, model_gpu2],
|
2041 |
outputs=[model_state2, model_used2, lora_used2, server_used2,
|
2042 |
# if prompt_type2 changes, prompt_dict2 will change via change rule
|
2043 |
prompt_type2, max_new_tokens2, min_new_tokens2
|
2044 |
])
|
2045 |
prompt_update_args2 = dict(fn=dropdown_prompt_type_list, inputs=prompt_type2, outputs=prompt_type2)
|
2046 |
chatbot_update_args2 = dict(fn=chatbot_list, inputs=[text_output2, model_used2], outputs=text_output2)
|
2047 |
+
load_model_event2 = load_model_button2.click(**load_model_args2,
|
2048 |
+
api_name='load_model2' if allow_api and is_public else None) \
|
2049 |
+
.then(**prompt_update_args2) \
|
2050 |
+
.then(**chatbot_update_args2) \
|
2051 |
+
.then(clear_torch_cache)
|
|
|
2052 |
|
2053 |
def dropdown_model_lora_server_list(model_list0, model_x,
|
2054 |
lora_list0, lora_x,
|
|
|
2097 |
server_options_state],
|
2098 |
queue=False)
|
2099 |
|
2100 |
+
go_event = go_btn.click(lambda: gr.update(visible=False), None, go_btn, api_name="go" if allow_api else None,
|
2101 |
+
queue=False) \
|
2102 |
.then(lambda: gr.update(visible=True), None, normal_block, queue=False) \
|
2103 |
.then(**load_model_args, queue=False).then(**prompt_update_args, queue=False)
|
2104 |
|
|
|
2166 |
def get_hash():
|
2167 |
return kwargs['git_hash']
|
2168 |
|
2169 |
+
system_event = system_btn3.click(get_hash,
|
2170 |
+
outputs=system_text3,
|
2171 |
+
api_name='system_hash' if allow_api else None,
|
2172 |
+
queue=False,
|
2173 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2174 |
|
2175 |
def count_chat_tokens(model_state1, chat1, prompt_type1, prompt_dict1,
|
2176 |
memory_restriction_level1=0,
|
|
|
2198 |
count_chat_tokens_func = functools.partial(count_chat_tokens,
|
2199 |
memory_restriction_level1=memory_restriction_level,
|
2200 |
keep_sources_in_context1=kwargs['keep_sources_in_context'])
|
2201 |
+
count_tokens_event = count_chat_tokens_btn.click(fn=count_chat_tokens,
|
2202 |
+
inputs=[model_state, text_output, prompt_type, prompt_dict],
|
2203 |
+
outputs=chat_token_count,
|
2204 |
+
api_name='count_tokens' if allow_api else None)
|
2205 |
+
|
2206 |
+
# don't pass text_output, don't want to clear output, just stop it
|
2207 |
+
# cancel only stops outer generation, not inner generation or non-generation
|
2208 |
+
stop_btn.click(lambda: None, None, None,
|
2209 |
+
cancels=submits1 + submits2 + submits3 + submits4 +
|
2210 |
+
[submit_event_nochat, submit_event_nochat2] +
|
2211 |
+
[eventdb1, eventdb2, eventdb3] +
|
2212 |
+
[eventdb7, eventdb8, eventdb9, eventdb12] +
|
2213 |
+
db_events +
|
2214 |
+
[clear_event] +
|
2215 |
+
[submit_event_nochat_api, submit_event_nochat] +
|
2216 |
+
[load_model_event, load_model_event2] +
|
2217 |
+
[count_tokens_event]
|
2218 |
+
,
|
2219 |
+
queue=False, api_name='stop' if allow_api else None).then(clear_torch_cache, queue=False)
|
2220 |
|
2221 |
demo.load(None, None, None, _js=get_dark_js() if kwargs['h2ocolors'] and False else None) # light best
|
2222 |
|
|
|
2289 |
|
2290 |
|
2291 |
def get_sources(db1, langchain_mode, dbs=None, docs_state0=None):
|
2292 |
+
set_userid(db1)
|
2293 |
+
|
2294 |
if langchain_mode in ['ChatLLM', 'LLM']:
|
2295 |
source_files_added = "NA"
|
2296 |
source_list = []
|
|
|
2321 |
return sources_file, source_list
|
2322 |
|
2323 |
|
2324 |
+
def set_userid(db1):
|
2325 |
+
# can only call this after function called so for specific userr, not in gr.State() that occurs during app init
|
2326 |
+
assert db1 is not None and len(db1) == 2
|
2327 |
+
if db1[1] is None:
|
2328 |
+
# uuid in db is used as user ID
|
2329 |
+
db1[1] = str(uuid.uuid4())
|
2330 |
+
|
2331 |
+
|
2332 |
+
def update_user_db(file, db1, chunk, chunk_size, langchain_mode, dbs=None, **kwargs):
|
2333 |
+
set_userid(db1)
|
2334 |
+
|
2335 |
+
if file is None:
|
2336 |
+
raise RuntimeError("Don't use change, use input")
|
2337 |
+
|
2338 |
try:
|
2339 |
+
return _update_user_db(file, db1=db1, chunk=chunk, chunk_size=chunk_size,
|
2340 |
+
langchain_mode=langchain_mode, dbs=dbs,
|
2341 |
+
**kwargs)
|
2342 |
except BaseException as e:
|
2343 |
print(traceback.format_exc(), flush=True)
|
2344 |
# gradio has issues if except, so fail semi-gracefully, else would hang forever in processing textbox
|
|
|
2355 |
</body>
|
2356 |
</html>
|
2357 |
""".format(ex_str)
|
2358 |
+
doc_exception_text = str(e)
|
2359 |
+
return None, langchain_mode, source_files_added, doc_exception_text
|
|
|
|
|
2360 |
finally:
|
2361 |
clear_torch_cache()
|
2362 |
|
2363 |
|
2364 |
def get_lock_file(db1, langchain_mode):
|
2365 |
+
set_userid(db1)
|
2366 |
assert len(db1) == 2 and db1[1] is not None and isinstance(db1[1], str)
|
2367 |
user_id = db1[1]
|
2368 |
base_path = 'locks'
|
|
|
2371 |
return lock_file
|
2372 |
|
2373 |
|
2374 |
+
def _update_user_db(file,
|
2375 |
+
db1=None,
|
2376 |
+
chunk=None, chunk_size=None,
|
2377 |
+
dbs=None, db_type=None, langchain_mode='UserData',
|
2378 |
user_path=None,
|
2379 |
use_openai_embedding=None,
|
2380 |
hf_embedding_model=None,
|
|
|
2385 |
verbose=None,
|
2386 |
is_url=None, is_txt=None,
|
2387 |
n_jobs=-1):
|
2388 |
+
assert db1 is not None
|
2389 |
+
assert chunk is not None
|
2390 |
+
assert chunk_size is not None
|
2391 |
assert use_openai_embedding is not None
|
2392 |
assert hf_embedding_model is not None
|
2393 |
assert caption_loader is not None
|
|
|
2396 |
assert enable_ocr is not None
|
2397 |
assert verbose is not None
|
2398 |
|
2399 |
+
set_userid(db1)
|
2400 |
+
|
2401 |
if dbs is None:
|
2402 |
dbs = {}
|
2403 |
assert isinstance(dbs, dict), "Wrong type for dbs: %s" % str(type(dbs))
|
|
|
2412 |
if not isinstance(file, (list, tuple, typing.Generator)) and isinstance(file, str):
|
2413 |
file = [file]
|
2414 |
|
2415 |
+
if langchain_mode == LangChainMode.DISABLED.value:
|
2416 |
+
return None, langchain_mode, get_source_files(), ""
|
2417 |
+
|
2418 |
+
if langchain_mode in [LangChainMode.CHAT_LLM.value, LangChainMode.CHAT_LLM.value]:
|
2419 |
+
# then switch to MyData, so langchain_mode also becomes way to select where upload goes
|
2420 |
+
# but default to mydata if nothing chosen, since safest
|
2421 |
+
langchain_mode = LangChainMode.MY_DATA.value
|
2422 |
+
|
2423 |
if langchain_mode == 'UserData' and user_path is not None:
|
2424 |
# move temp files from gradio upload to stable location
|
2425 |
for fili, fil in enumerate(file):
|
|
|
2448 |
caption_loader=caption_loader,
|
2449 |
)
|
2450 |
exceptions = [x for x in sources if x.metadata.get('exception')]
|
2451 |
+
exceptions_strs = [x.metadata['exception'] for x in exceptions]
|
2452 |
sources = [x for x in sources if 'exception' not in x.metadata]
|
2453 |
|
2454 |
lock_file = get_lock_file(db1, langchain_mode)
|
|
|
2475 |
if db is not None:
|
2476 |
db1[0] = db
|
2477 |
source_files_added = get_source_files(db=db1[0], exceptions=exceptions)
|
2478 |
+
return None, langchain_mode, source_files_added, '\n'.join(exceptions_strs)
|
2479 |
else:
|
2480 |
from gpt_langchain import get_persist_directory
|
2481 |
persist_directory = get_persist_directory(langchain_mode)
|
|
|
2493 |
hf_embedding_model=hf_embedding_model)
|
2494 |
dbs[langchain_mode] = db
|
2495 |
# NOTE we do not return db, because function call always same code path
|
2496 |
+
# return dbs[langchain_mode]
|
2497 |
# db in this code path is updated in place
|
2498 |
source_files_added = get_source_files(db=dbs[langchain_mode], exceptions=exceptions)
|
2499 |
+
return None, langchain_mode, source_files_added, '\n'.join(exceptions_strs)
|
2500 |
|
2501 |
|
2502 |
def get_db(db1, langchain_mode, dbs=None):
|
gradio_themes.py
CHANGED
@@ -133,6 +133,11 @@ class H2oTheme(Soft):
|
|
133 |
background_fill_primary_dark="*block_background_fill",
|
134 |
block_radius="0 0 8px 8px",
|
135 |
checkbox_label_text_color_selected_dark='#000000',
|
|
|
|
|
|
|
|
|
|
|
136 |
)
|
137 |
|
138 |
|
@@ -173,6 +178,9 @@ class SoftTheme(Soft):
|
|
173 |
font=font,
|
174 |
font_mono=font_mono,
|
175 |
)
|
|
|
|
|
|
|
176 |
|
177 |
|
178 |
h2o_logo = '<svg id="Layer_1" data-name="Layer 1" xmlns="http://www.w3.org/2000/svg" width="100%" height="100%"' \
|
|
|
133 |
background_fill_primary_dark="*block_background_fill",
|
134 |
block_radius="0 0 8px 8px",
|
135 |
checkbox_label_text_color_selected_dark='#000000',
|
136 |
+
#checkbox_label_text_size="*text_xs", # too small for iPhone etc. but good if full large screen zoomed to fit
|
137 |
+
checkbox_label_text_size="*text_sm",
|
138 |
+
#radio_circle="""url("data:image/svg+xml,%3csvg viewBox='0 0 32 32' fill='white' xmlns='http://www.w3.org/2000/svg'%3e%3ccircle cx='32' cy='32' r='1'/%3e%3c/svg%3e")""",
|
139 |
+
#checkbox_border_width=1,
|
140 |
+
#heckbox_border_width_dark=1,
|
141 |
)
|
142 |
|
143 |
|
|
|
178 |
font=font,
|
179 |
font_mono=font_mono,
|
180 |
)
|
181 |
+
super().set(
|
182 |
+
checkbox_label_text_size="*text_sm",
|
183 |
+
)
|
184 |
|
185 |
|
186 |
h2o_logo = '<svg id="Layer_1" data-name="Layer 1" xmlns="http://www.w3.org/2000/svg" width="100%" height="100%"' \
|
gradio_utils/__pycache__/css.cpython-310.pyc
CHANGED
Binary files a/gradio_utils/__pycache__/css.cpython-310.pyc and b/gradio_utils/__pycache__/css.cpython-310.pyc differ
|
|
gradio_utils/__pycache__/prompt_form.cpython-310.pyc
CHANGED
Binary files a/gradio_utils/__pycache__/prompt_form.cpython-310.pyc and b/gradio_utils/__pycache__/prompt_form.cpython-310.pyc differ
|
|
gradio_utils/css.py
CHANGED
@@ -12,7 +12,10 @@ def get_css(kwargs) -> str:
|
|
12 |
|
13 |
|
14 |
def make_css_base() -> str:
|
15 |
-
|
|
|
|
|
|
|
16 |
@import url('https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600&display=swap');
|
17 |
|
18 |
body.dark{#warning {background-color: #555555};}
|
|
|
12 |
|
13 |
|
14 |
def make_css_base() -> str:
|
15 |
+
css1 = """
|
16 |
+
#col_container {margin-left: auto; margin-right: auto; text-align: left;}
|
17 |
+
"""
|
18 |
+
return css1 + """
|
19 |
@import url('https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600&display=swap');
|
20 |
|
21 |
body.dark{#warning {background-color: #555555};}
|
gradio_utils/prompt_form.py
CHANGED
@@ -93,30 +93,3 @@ def make_chatbots(output_label0, output_label0_model2, **kwargs):
|
|
93 |
text_output2 = gr.Chatbot(label=output_label0_model2,
|
94 |
visible=False and not kwargs['model_lock'], height=kwargs['height'] or 400)
|
95 |
return text_output, text_output2, text_outputs
|
96 |
-
|
97 |
-
|
98 |
-
def make_prompt_form(kwargs, LangChainMode):
|
99 |
-
if kwargs['langchain_mode'] != LangChainMode.DISABLED.value:
|
100 |
-
extra_prompt_form = ". For summarization, empty submission uses first top_k_docs documents."
|
101 |
-
else:
|
102 |
-
extra_prompt_form = ""
|
103 |
-
if kwargs['input_lines'] > 1:
|
104 |
-
instruction_label = "Shift-Enter to Submit, Enter for more lines%s" % extra_prompt_form
|
105 |
-
else:
|
106 |
-
instruction_label = "Enter to Submit, Shift-Enter for more lines%s" % extra_prompt_form
|
107 |
-
|
108 |
-
with gr.Row():#elem_id='prompt-form-area'):
|
109 |
-
with gr.Column(scale=50):
|
110 |
-
instruction = gr.Textbox(
|
111 |
-
lines=kwargs['input_lines'],
|
112 |
-
label='Ask anything',
|
113 |
-
placeholder=instruction_label,
|
114 |
-
info=None,
|
115 |
-
elem_id='prompt-form',
|
116 |
-
container=True,
|
117 |
-
)
|
118 |
-
with gr.Row():
|
119 |
-
submit = gr.Button(value='Submit', variant='primary', scale=0, size='sm')
|
120 |
-
stop_btn = gr.Button(value="Stop", variant='secondary', scale=0, size='sm')
|
121 |
-
|
122 |
-
return instruction, submit, stop_btn
|
|
|
93 |
text_output2 = gr.Chatbot(label=output_label0_model2,
|
94 |
visible=False and not kwargs['model_lock'], height=kwargs['height'] or 400)
|
95 |
return text_output, text_output2, text_outputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
loaders.py
CHANGED
@@ -1,40 +1,48 @@
|
|
1 |
-
|
|
|
|
|
|
|
2 |
# NOTE: Some models need specific new prompt_type
|
3 |
# E.g. t5_xxl_true_nli_mixture has input format: "premise: PREMISE_TEXT hypothesis: HYPOTHESIS_TEXT".)
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
if llama_type is None:
|
5 |
llama_type = "llama" in model_name.lower()
|
6 |
if llama_type:
|
7 |
from transformers import LlamaForCausalLM, LlamaTokenizer
|
8 |
-
|
9 |
-
tokenizer_loader = LlamaTokenizer
|
10 |
elif 'distilgpt2' in model_name.lower():
|
11 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
12 |
-
return AutoModelForCausalLM, AutoTokenizer
|
13 |
elif 'gpt2' in model_name.lower():
|
14 |
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
15 |
-
return GPT2LMHeadModel, GPT2Tokenizer
|
16 |
elif 'mbart-' in model_name.lower():
|
17 |
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
|
18 |
-
return MBartForConditionalGeneration, MBart50TokenizerFast
|
19 |
elif 't5' == model_name.lower() or \
|
20 |
't5-' in model_name.lower() or \
|
21 |
'flan-' in model_name.lower():
|
22 |
from transformers import AutoTokenizer, T5ForConditionalGeneration
|
23 |
-
return T5ForConditionalGeneration, AutoTokenizer
|
24 |
elif 'bigbird' in model_name:
|
25 |
from transformers import BigBirdPegasusForConditionalGeneration, AutoTokenizer
|
26 |
-
return BigBirdPegasusForConditionalGeneration, AutoTokenizer
|
27 |
elif 'bart-large-cnn-samsum' in model_name or 'flan-t5-base-samsum' in model_name:
|
28 |
from transformers import pipeline
|
29 |
return pipeline, "summarization"
|
30 |
elif reward_type or 'OpenAssistant/reward-model'.lower() in model_name.lower():
|
31 |
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
32 |
-
return AutoModelForSequenceClassification, AutoTokenizer
|
33 |
else:
|
34 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
35 |
model_loader = AutoModelForCausalLM
|
36 |
tokenizer_loader = AutoTokenizer
|
37 |
-
|
38 |
|
39 |
|
40 |
def get_tokenizer(tokenizer_loader, tokenizer_base_model, local_files_only, resume_download, use_auth_token):
|
|
|
1 |
+
import functools
|
2 |
+
|
3 |
+
|
4 |
+
def get_loaders(model_name, reward_type, llama_type=None, load_gptq=''):
|
5 |
# NOTE: Some models need specific new prompt_type
|
6 |
# E.g. t5_xxl_true_nli_mixture has input format: "premise: PREMISE_TEXT hypothesis: HYPOTHESIS_TEXT".)
|
7 |
+
if load_gptq:
|
8 |
+
from transformers import AutoTokenizer
|
9 |
+
from auto_gptq import AutoGPTQForCausalLM
|
10 |
+
use_triton = False
|
11 |
+
functools.partial(AutoGPTQForCausalLM.from_quantized, quantize_config=None, use_triton=use_triton)
|
12 |
+
return AutoGPTQForCausalLM.from_quantized, AutoTokenizer
|
13 |
if llama_type is None:
|
14 |
llama_type = "llama" in model_name.lower()
|
15 |
if llama_type:
|
16 |
from transformers import LlamaForCausalLM, LlamaTokenizer
|
17 |
+
return LlamaForCausalLM.from_pretrained, LlamaTokenizer
|
|
|
18 |
elif 'distilgpt2' in model_name.lower():
|
19 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
20 |
+
return AutoModelForCausalLM.from_pretrained, AutoTokenizer
|
21 |
elif 'gpt2' in model_name.lower():
|
22 |
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
23 |
+
return GPT2LMHeadModel.from_pretrained, GPT2Tokenizer
|
24 |
elif 'mbart-' in model_name.lower():
|
25 |
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
|
26 |
+
return MBartForConditionalGeneration.from_pretrained, MBart50TokenizerFast
|
27 |
elif 't5' == model_name.lower() or \
|
28 |
't5-' in model_name.lower() or \
|
29 |
'flan-' in model_name.lower():
|
30 |
from transformers import AutoTokenizer, T5ForConditionalGeneration
|
31 |
+
return T5ForConditionalGeneration.from_pretrained, AutoTokenizer
|
32 |
elif 'bigbird' in model_name:
|
33 |
from transformers import BigBirdPegasusForConditionalGeneration, AutoTokenizer
|
34 |
+
return BigBirdPegasusForConditionalGeneration.from_pretrained, AutoTokenizer
|
35 |
elif 'bart-large-cnn-samsum' in model_name or 'flan-t5-base-samsum' in model_name:
|
36 |
from transformers import pipeline
|
37 |
return pipeline, "summarization"
|
38 |
elif reward_type or 'OpenAssistant/reward-model'.lower() in model_name.lower():
|
39 |
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
40 |
+
return AutoModelForSequenceClassification.from_pretrained, AutoTokenizer
|
41 |
else:
|
42 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
43 |
model_loader = AutoModelForCausalLM
|
44 |
tokenizer_loader = AutoTokenizer
|
45 |
+
return model_loader.from_pretrained, tokenizer_loader
|
46 |
|
47 |
|
48 |
def get_tokenizer(tokenizer_loader, tokenizer_base_model, local_files_only, resume_download, use_auth_token):
|
prompter.py
CHANGED
@@ -23,9 +23,6 @@ prompt_type_to_model_name = {
|
|
23 |
'gpt2',
|
24 |
'distilgpt2',
|
25 |
'mosaicml/mpt-7b-storywriter',
|
26 |
-
'mosaicml/mpt-7b-instruct', # internal code handles instruct
|
27 |
-
'mosaicml/mpt-7b-chat', # NC, internal code handles instruct
|
28 |
-
'mosaicml/mpt-30b-instruct', # internal code handles instruct
|
29 |
],
|
30 |
'gptj': ['gptj', 'gpt4all_llama'],
|
31 |
'prompt_answer': [
|
@@ -41,6 +38,7 @@ prompt_type_to_model_name = {
|
|
41 |
'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-40b-v2',
|
42 |
'h2oai/h2ogpt-gm-oasst1-en-xgen-7b-8k',
|
43 |
'h2oai/h2ogpt-gm-oasst1-multilang-xgen-7b-8k',
|
|
|
44 |
],
|
45 |
'prompt_answer_openllama': [
|
46 |
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt',
|
@@ -49,7 +47,7 @@ prompt_type_to_model_name = {
|
|
49 |
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b',
|
50 |
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-13b',
|
51 |
],
|
52 |
-
'instruct': [],
|
53 |
'instruct_with_end': ['databricks/dolly-v2-12b'],
|
54 |
'quality': [],
|
55 |
'human_bot': [
|
@@ -74,8 +72,11 @@ prompt_type_to_model_name = {
|
|
74 |
"wizard_mega": ['openaccess-ai-collective/wizard-mega-13b'],
|
75 |
"instruct_simple": ['JosephusCheung/Guanaco'],
|
76 |
"wizard_vicuna": ['ehartford/Wizard-Vicuna-13B-Uncensored'],
|
77 |
-
"wizard2": ['llama'
|
|
|
|
|
78 |
"vicuna11": ['lmsys/vicuna-33b-v1.3'],
|
|
|
79 |
# could be plain, but default is correct prompt_type for default TheBloke model ggml-wizardLM-7B.q4_2.bin
|
80 |
}
|
81 |
if os.getenv('OPENAI_API_KEY'):
|
@@ -293,7 +294,7 @@ Current Time: {}
|
|
293 |
humanstr = prompt_tokens
|
294 |
botstr = answer_tokens
|
295 |
terminate_response = [humanstr, PreResponse, eos]
|
296 |
-
chat_sep =
|
297 |
chat_turn_sep = eos
|
298 |
elif prompt_type in [PromptType.prompt_answer_openllama.value, str(PromptType.prompt_answer_openllama.value),
|
299 |
PromptType.prompt_answer_openllama.name]:
|
@@ -309,7 +310,7 @@ Current Time: {}
|
|
309 |
humanstr = prompt_tokens
|
310 |
botstr = answer_tokens
|
311 |
terminate_response = [humanstr, PreResponse, eos]
|
312 |
-
chat_sep =
|
313 |
chat_turn_sep = eos
|
314 |
elif prompt_type in [PromptType.open_assistant.value, str(PromptType.open_assistant.value),
|
315 |
PromptType.open_assistant.name]:
|
@@ -520,6 +521,67 @@ ASSISTANT:
|
|
520 |
# normally LLM adds space after this, because was how trained.
|
521 |
# if add space here, non-unique tokenization will often make LLM produce wrong output
|
522 |
PreResponse = PreResponse
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
523 |
else:
|
524 |
raise RuntimeError("No such prompt_type=%s" % prompt_type)
|
525 |
|
|
|
23 |
'gpt2',
|
24 |
'distilgpt2',
|
25 |
'mosaicml/mpt-7b-storywriter',
|
|
|
|
|
|
|
26 |
],
|
27 |
'gptj': ['gptj', 'gpt4all_llama'],
|
28 |
'prompt_answer': [
|
|
|
38 |
'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-40b-v2',
|
39 |
'h2oai/h2ogpt-gm-oasst1-en-xgen-7b-8k',
|
40 |
'h2oai/h2ogpt-gm-oasst1-multilang-xgen-7b-8k',
|
41 |
+
'TheBloke/h2ogpt-gm-oasst1-en-2048-falcon-40b-v2-GPTQ',
|
42 |
],
|
43 |
'prompt_answer_openllama': [
|
44 |
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt',
|
|
|
47 |
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b',
|
48 |
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-13b',
|
49 |
],
|
50 |
+
'instruct': ['TheBloke/llama-30b-supercot-SuperHOT-8K-fp16'], # https://huggingface.co/TheBloke/llama-30b-supercot-SuperHOT-8K-fp16#prompting
|
51 |
'instruct_with_end': ['databricks/dolly-v2-12b'],
|
52 |
'quality': [],
|
53 |
'human_bot': [
|
|
|
72 |
"wizard_mega": ['openaccess-ai-collective/wizard-mega-13b'],
|
73 |
"instruct_simple": ['JosephusCheung/Guanaco'],
|
74 |
"wizard_vicuna": ['ehartford/Wizard-Vicuna-13B-Uncensored'],
|
75 |
+
"wizard2": ['llama'],
|
76 |
+
"mptinstruct": ['mosaicml/mpt-30b-instruct', 'mosaicml/mpt-7b-instruct', 'mosaicml/mpt-30b-instruct'],
|
77 |
+
"mptchat": ['mosaicml/mpt-7b-chat', 'mosaicml/mpt-30b-chat', 'TheBloke/mpt-30B-chat-GGML'],
|
78 |
"vicuna11": ['lmsys/vicuna-33b-v1.3'],
|
79 |
+
"falcon": ['tiiuae/falcon-40b-instruct', 'tiiuae/falcon-40b', 'tiiuae/falcon-7b-instruct', 'tiiuae/falcon-7b'],
|
80 |
# could be plain, but default is correct prompt_type for default TheBloke model ggml-wizardLM-7B.q4_2.bin
|
81 |
}
|
82 |
if os.getenv('OPENAI_API_KEY'):
|
|
|
294 |
humanstr = prompt_tokens
|
295 |
botstr = answer_tokens
|
296 |
terminate_response = [humanstr, PreResponse, eos]
|
297 |
+
chat_sep = eos
|
298 |
chat_turn_sep = eos
|
299 |
elif prompt_type in [PromptType.prompt_answer_openllama.value, str(PromptType.prompt_answer_openllama.value),
|
300 |
PromptType.prompt_answer_openllama.name]:
|
|
|
310 |
humanstr = prompt_tokens
|
311 |
botstr = answer_tokens
|
312 |
terminate_response = [humanstr, PreResponse, eos]
|
313 |
+
chat_sep = eos
|
314 |
chat_turn_sep = eos
|
315 |
elif prompt_type in [PromptType.open_assistant.value, str(PromptType.open_assistant.value),
|
316 |
PromptType.open_assistant.name]:
|
|
|
521 |
# normally LLM adds space after this, because was how trained.
|
522 |
# if add space here, non-unique tokenization will often make LLM produce wrong output
|
523 |
PreResponse = PreResponse
|
524 |
+
elif prompt_type in [PromptType.mptinstruct.value, str(PromptType.mptinstruct.value),
|
525 |
+
PromptType.mptinstruct.name]:
|
526 |
+
# https://huggingface.co/mosaicml/mpt-30b-instruct#formatting
|
527 |
+
promptA = promptB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n' if not (
|
528 |
+
chat and reduced) else ''
|
529 |
+
|
530 |
+
PreInstruct = """
|
531 |
+
### Instruction
|
532 |
+
"""
|
533 |
+
|
534 |
+
PreInput = """
|
535 |
+
### Input
|
536 |
+
"""
|
537 |
+
|
538 |
+
PreResponse = """
|
539 |
+
### Response
|
540 |
+
"""
|
541 |
+
terminate_response = None
|
542 |
+
chat_turn_sep = chat_sep = '\n'
|
543 |
+
humanstr = PreInstruct
|
544 |
+
botstr = PreResponse
|
545 |
+
elif prompt_type in [PromptType.mptchat.value, str(PromptType.mptchat.value),
|
546 |
+
PromptType.mptchat.name]:
|
547 |
+
# https://huggingface.co/TheBloke/mpt-30B-chat-GGML#prompt-template
|
548 |
+
promptA = promptB = """<|im_start|>system\nA conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.\n<|im_end|>""" if not (
|
549 |
+
chat and reduced) else ''
|
550 |
+
|
551 |
+
PreInstruct = """<|im_start|>user
|
552 |
+
"""
|
553 |
+
|
554 |
+
PreInput = None
|
555 |
+
|
556 |
+
PreResponse = """<|im_end|><|im_start|>assistant
|
557 |
+
"""
|
558 |
+
terminate_response = ['<|im_end|>']
|
559 |
+
chat_sep = ''
|
560 |
+
chat_turn_sep = '<|im_end|>'
|
561 |
+
humanstr = PreInstruct
|
562 |
+
botstr = PreResponse
|
563 |
+
elif prompt_type in [PromptType.falcon.value, str(PromptType.falcon.value),
|
564 |
+
PromptType.falcon.name]:
|
565 |
+
promptA = promptB = "" if not (chat and reduced) else ''
|
566 |
+
|
567 |
+
PreInstruct = """User: """
|
568 |
+
|
569 |
+
PreInput = None
|
570 |
+
|
571 |
+
PreResponse = """Assistant:"""
|
572 |
+
terminate_response = ['\nUser', "<|endoftext|>"]
|
573 |
+
chat_sep = '\n\n'
|
574 |
+
chat_turn_sep = '\n\n'
|
575 |
+
humanstr = PreInstruct
|
576 |
+
botstr = PreResponse
|
577 |
+
if making_context:
|
578 |
+
# when making context, want it to appear as-if LLM generated, which starts with space after :
|
579 |
+
PreResponse = 'Assistant: '
|
580 |
+
else:
|
581 |
+
# normally LLM adds space after this, because was how trained.
|
582 |
+
# if add space here, non-unique tokenization will often make LLM produce wrong output
|
583 |
+
PreResponse = PreResponse
|
584 |
+
# generates_leading_space = True
|
585 |
else:
|
586 |
raise RuntimeError("No such prompt_type=%s" % prompt_type)
|
587 |
|
requirements.txt
CHANGED
@@ -6,7 +6,7 @@ huggingface_hub==0.15.1
|
|
6 |
appdirs==1.4.4
|
7 |
fire==0.5.0
|
8 |
docutils==0.20.1
|
9 |
-
torch==2.0.1
|
10 |
evaluate==0.4.0
|
11 |
rouge_score==0.1.2
|
12 |
sacrebleu==2.3.1
|
@@ -19,7 +19,7 @@ matplotlib==3.7.1
|
|
19 |
loralib==0.1.1
|
20 |
bitsandbytes==0.39.0
|
21 |
accelerate==0.20.3
|
22 |
-
git+https://github.com/huggingface/peft.git@
|
23 |
transformers==4.30.2
|
24 |
tokenizers==0.13.3
|
25 |
APScheduler==3.10.1
|
@@ -45,8 +45,8 @@ pytest-xdist==3.2.1
|
|
45 |
nltk==3.8.1
|
46 |
textstat==0.7.3
|
47 |
# pandoc==2.3
|
48 |
-
|
49 |
-
pypandoc_binary==1.11
|
50 |
openpyxl==3.1.2
|
51 |
lm_dataformat==0.0.20
|
52 |
bioc==2.0
|
|
|
6 |
appdirs==1.4.4
|
7 |
fire==0.5.0
|
8 |
docutils==0.20.1
|
9 |
+
torch==2.0.1; sys_platform != "darwin" and platform_machine != "arm64"
|
10 |
evaluate==0.4.0
|
11 |
rouge_score==0.1.2
|
12 |
sacrebleu==2.3.1
|
|
|
19 |
loralib==0.1.1
|
20 |
bitsandbytes==0.39.0
|
21 |
accelerate==0.20.3
|
22 |
+
git+https://github.com/huggingface/peft.git@06fd06a4d2e8ed8c3a253c67d9c3cb23e0f497ad
|
23 |
transformers==4.30.2
|
24 |
tokenizers==0.13.3
|
25 |
APScheduler==3.10.1
|
|
|
45 |
nltk==3.8.1
|
46 |
textstat==0.7.3
|
47 |
# pandoc==2.3
|
48 |
+
pypandoc==1.11; sys_platform == "darwin" and platform_machine == "arm64"
|
49 |
+
pypandoc_binary==1.11; platform_machine == "x86_64"
|
50 |
openpyxl==3.1.2
|
51 |
lm_dataformat==0.0.20
|
52 |
bioc==2.0
|
utils.py
CHANGED
@@ -97,6 +97,8 @@ def get_device():
|
|
97 |
import torch
|
98 |
if torch.cuda.is_available():
|
99 |
device = "cuda"
|
|
|
|
|
100 |
else:
|
101 |
device = "cpu"
|
102 |
|
@@ -138,7 +140,7 @@ def system_info():
|
|
138 |
gpu_memory_frac_dict = {k: gpu_memory_free_dict[k] / gpu_memory_total_dict[k] for k in gpu_memory_total_dict}
|
139 |
for k, v in gpu_memory_frac_dict.items():
|
140 |
system[f'GPU_M/%s' % k] = v
|
141 |
-
except ModuleNotFoundError:
|
142 |
pass
|
143 |
system['hash'] = get_githash()
|
144 |
|
@@ -926,3 +928,60 @@ class FakeTokenizer:
|
|
926 |
|
927 |
def __call__(self, x, *args, **kwargs):
|
928 |
return self.encode(x, *args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
import torch
|
98 |
if torch.cuda.is_available():
|
99 |
device = "cuda"
|
100 |
+
elif torch.backends.mps.is_built():
|
101 |
+
device = "mps"
|
102 |
else:
|
103 |
device = "cpu"
|
104 |
|
|
|
140 |
gpu_memory_frac_dict = {k: gpu_memory_free_dict[k] / gpu_memory_total_dict[k] for k in gpu_memory_total_dict}
|
141 |
for k, v in gpu_memory_frac_dict.items():
|
142 |
system[f'GPU_M/%s' % k] = v
|
143 |
+
except (KeyError, ModuleNotFoundError):
|
144 |
pass
|
145 |
system['hash'] = get_githash()
|
146 |
|
|
|
928 |
|
929 |
def __call__(self, x, *args, **kwargs):
|
930 |
return self.encode(x, *args, **kwargs)
|
931 |
+
|
932 |
+
|
933 |
+
def get_local_ip():
|
934 |
+
import socket
|
935 |
+
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
936 |
+
try:
|
937 |
+
# doesn't even have to be reachable
|
938 |
+
s.connect(('10.255.255.255', 1))
|
939 |
+
IP = s.getsockname()[0]
|
940 |
+
except Exception:
|
941 |
+
IP = '127.0.0.1'
|
942 |
+
finally:
|
943 |
+
s.close()
|
944 |
+
return IP
|
945 |
+
|
946 |
+
|
947 |
+
try:
|
948 |
+
assert pkg_resources.get_distribution('langchain') is not None
|
949 |
+
have_langchain = True
|
950 |
+
except (pkg_resources.DistributionNotFound, AssertionError):
|
951 |
+
have_langchain = False
|
952 |
+
|
953 |
+
|
954 |
+
import distutils.spawn
|
955 |
+
|
956 |
+
have_tesseract = distutils.spawn.find_executable("tesseract")
|
957 |
+
have_libreoffice = distutils.spawn.find_executable("libreoffice")
|
958 |
+
|
959 |
+
import pkg_resources
|
960 |
+
|
961 |
+
try:
|
962 |
+
assert pkg_resources.get_distribution('arxiv') is not None
|
963 |
+
assert pkg_resources.get_distribution('pymupdf') is not None
|
964 |
+
have_arxiv = True
|
965 |
+
except (pkg_resources.DistributionNotFound, AssertionError):
|
966 |
+
have_arxiv = False
|
967 |
+
|
968 |
+
try:
|
969 |
+
assert pkg_resources.get_distribution('pymupdf') is not None
|
970 |
+
have_pymupdf = True
|
971 |
+
except (pkg_resources.DistributionNotFound, AssertionError):
|
972 |
+
have_pymupdf = False
|
973 |
+
|
974 |
+
try:
|
975 |
+
assert pkg_resources.get_distribution('selenium') is not None
|
976 |
+
have_selenium = True
|
977 |
+
except (pkg_resources.DistributionNotFound, AssertionError):
|
978 |
+
have_selenium = False
|
979 |
+
|
980 |
+
try:
|
981 |
+
assert pkg_resources.get_distribution('playwright') is not None
|
982 |
+
have_playwright = True
|
983 |
+
except (pkg_resources.DistributionNotFound, AssertionError):
|
984 |
+
have_playwright = False
|
985 |
+
|
986 |
+
# disable, hangs too often
|
987 |
+
have_playwright = False
|